In [1]:
# encoding=utf-8
import os.path as osp
import os
import copy
import matplotlib.pyplot as plt
import torch
from torch.nn import Linear
from sklearn.metrics import average_precision_score, roc_auc_score
from torch_geometric.data import TemporalData
from torch_geometric.datasets import JODIEDataset
from torch_geometric.datasets import ICEWS18
from torch_geometric.nn import TGNMemory, TransformerConv
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn.models.tgn import (LastNeighborLoader, IdentityMessage, MeanAggregator,
                                           LastAggregator)
from torch_geometric import *
from torch_geometric.utils import negative_sampling
from tqdm import tqdm
import networkx as nx
import numpy as np
import math
import copy
import re
import time
import json
import pandas as pd
from random import choice
import gc

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# device = 'cpu'
# msg structure:      [src_node_feature,edge_attr,dst_node_feature]

# compute the best partition
import datetime
# import community as community_louvain

import xxhash

# Find the edge index which the edge vector is corresponding to
def tensor_find(t,x):
    """
    Find the 1-based row index of the first occurrence of a value in a PyTorch tensor.

    Parameters:
    t (torch.Tensor): The PyTorch tensor to search.
    x (Any): The value to locate within the tensor.

    Returns:
    int: The 1-based row index of the first occurrence of the value in the tensor.
    """
    t_np=t.cpu().numpy()
    idx=np.argwhere(t_np==x)
    return idx[0][0]+1


def std(t):
    """
    Calculate the standard deviation of elements in the input data.

    Parameters:
    t (list, tuple, or ndarray): The input data to compute the standard deviation.
                                 This can be a Python list, tuple, or a NumPy ndarray.

    Returns:
    float: The standard deviation of the input data.
    """
    t = np.array(t)
    return np.std(t)


def var(t):
    """
    Calculate the variance of elements in the input data.

    Parameters:
    t (list, tuple, or ndarray): The input data to compute the variance.
                                 This can be a Python list, tuple, or a NumPy ndarray.

    Returns:
    float: The variance of the input data.
    """
    t = np.array(t)
    return np.var(t)


def mean(t):
    """
    Calculate the mean of elements in the input data.

    Parameters:
    t (list, tuple, or ndarray): The input data to compute the mean.
                                 This can be a Python list, tuple, or a NumPy ndarray.

    Returns:
    float: The mean of the input data.
    """
    t = np.array(t)
    return np.mean(t)

def hashgen(l):
    """
    Generate a single hash value from a list of string values.

    Parameters:
    l (list of str): A list of string values, which can represent properties of a node or edge.

    Returns:
    int: A single hashed integer value generated from the input list.
    """
    hasher = xxhash.xxh64()
    for e in l:
        hasher.update(e)
    return hasher.intdigest()


def cal_pos_edges_loss(link_pred_ratio):
    """
    Calculate the loss for positive edges using a binary classification approach.

    Parameters:
    link_pred_ratio (list or Tensor): A list or tensor containing predicted link probabilities.

    Returns:
    Tensor: A tensor containing the computed loss for each prediction.
    """
    loss = []
    for i in link_pred_ratio:
        # Compare predicted values with a target tensor of ones (positive class)
        loss.append(criterion(i, torch.ones(1)))
    return torch.tensor(loss)


def cal_pos_edges_loss_multiclass(link_pred_ratio, labels):
    """
    Calculate the loss for positive edges in a multi-class classification setting.

    Parameters:
    link_pred_ratio (list of Tensors): A list of tensors containing predicted class probabilities for links.
    labels (list of Tensors): A list of tensors containing the ground truth labels for the links.

    Returns:
    Tensor: A tensor containing the computed loss for each prediction.
    """
    loss = []
    for i in range(len(link_pred_ratio)):
        # Compare predicted values with ground truth labels for multi-class classification
        loss.append(criterion(link_pred_ratio[i].reshape(1, -1), labels[i].reshape(-1)))
    return torch.tensor(loss)


def cal_pos_edges_loss_autoencoder(decoded, msg):
    """
    Calculate the loss for positive edges in an autoencoder setting.

    Parameters:
    decoded (list of Tensors): A list of tensors representing the decoded outputs.
    msg (list of Tensors): A list of tensors representing the original messages (inputs to the autoencoder).

    Returns:
    Tensor: A tensor containing the computed loss for each decoded message.
    """
    loss = []
    for i in range(len(decoded)):
        # Compare the decoded outputs with the original messages
        loss.append(criterion(decoded[i].reshape(1, -1), msg[i].reshape(-1)))
    return torch.tensor(loss)


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
%autosave 120  

Autosaving every 120 seconds


In [3]:
from datetime import datetime, timezone
import time
import pytz
from time import mktime
from datetime import datetime
import time
def ns_time_to_datetime(ns):
    """
    :param ns: int nano timestamp
    :return: datetime   format: 2013-10-10 23:40:00.000000000
    """
    dt = datetime.fromtimestamp(int(ns) // 1000000000)
    s = dt.strftime('%Y-%m-%d %H:%M:%S')
    s += '.' + str(int(int(ns) % 1000000000)).zfill(9)
    return s

def ns_time_to_datetime_US(ns):
    """
    :param ns: int nano timestamp
    :return: datetime   format: 2013-10-10 23:40:00.000000000
    """
    tz = pytz.timezone('US/Eastern')
    dt = pytz.datetime.datetime.fromtimestamp(int(ns) // 1000000000, tz)
    s = dt.strftime('%Y-%m-%d %H:%M:%S')
    s += '.' + str(int(int(ns) % 1000000000)).zfill(9)
    return s

def time_to_datetime_US(s):
    """
    :param ns: int nano timestamp
    :return: datetime   format: 2013-10-10 23:40:00
    """
    tz = pytz.timezone('US/Eastern')
    dt = pytz.datetime.datetime.fromtimestamp(int(s), tz)
    s = dt.strftime('%Y-%m-%d %H:%M:%S')

    return s

def datetime_to_ns_time(date):
    """
    :param date: str   format: %Y-%m-%d %H:%M:%S   e.g. 2013-10-10 23:40:00
    :return: nano timestamp
    """
    timeArray = time.strptime(date, "%Y-%m-%d %H:%M:%S")
    timeStamp = int(time.mktime(timeArray))
    timeStamp = timeStamp * 1000000000
    return timeStamp

def datetime_to_ns_time_US(date):
    """
    :param date: str   format: %Y-%m-%d %H:%M:%S   e.g. 2013-10-10 23:40:00
    :return: nano timestamp
    """
    tz = pytz.timezone('US/Eastern')
    timeArray = time.strptime(date, "%Y-%m-%d %H:%M:%S")
    dt = datetime.fromtimestamp(mktime(timeArray))
    timestamp = tz.localize(dt)
    timestamp = timestamp.timestamp()
    timeStamp = timestamp * 1000000000
    return int(timeStamp)

def datetime_to_timestamp_US(date):
    """
    :param date: str   format: %Y-%m-%d %H:%M:%S   e.g. 2013-10-10 23:40:00
    :return: nano timestamp
    """
    tz = pytz.timezone('US/Eastern')
    timeArray = time.strptime(date, "%Y-%m-%d %H:%M:%S")
    dt = datetime.fromtimestamp(mktime(timeArray))
    timestamp = tz.localize(dt)
    timestamp = timestamp.timestamp()
    timeStamp = timestamp
    return int(timeStamp)

In [4]:
import psycopg2

from psycopg2 import extras as ex
connect = psycopg2.connect(database = 'tc_clearscope3_dataset_db',
                           host = 'localhost',
                           user = 'postgres',
                           password = '123456',
                           port = '5432'
                          )

cur = connect.cursor()

In [5]:
graph_4_4=torch.load("./train_graphs/graph_4_4.TemporalData.simple").to(device=device)
graph_4_5=torch.load("./train_graphs/graph_4_5.TemporalData.simple").to(device=device)
graph_4_6=torch.load("./train_graphs/graph_4_6.TemporalData.simple").to(device=device)

train_data=graph_4_4

In [6]:
# Constructing the map for nodeid to msg
sql="select * from node2id ORDER BY index_id;"
cur.execute(sql)
rows = cur.fetchall()

# node hash => nodeid (index) and nodeid (index) => msg (dictionary)
nodeid2msg={}
for i in rows:
    nodeid2msg[i[0]]=i[-1]
    nodeid2msg[i[-1]]={i[1]:i[2]}  

In [7]:
rel2id={1: 'EVENT_CLOSE',
 'EVENT_CLOSE': 1,
 2: 'EVENT_OPEN',
 'EVENT_OPEN': 2,
 3: 'EVENT_READ',
 'EVENT_READ': 3,
 4: 'EVENT_WRITE',
 'EVENT_WRITE': 4,
 5: 'EVENT_RECVFROM',
 'EVENT_RECVFROM': 5,
 6: 'EVENT_RECVMSG',
 'EVENT_RECVMSG': 6,
 7: 'EVENT_SENDMSG',
 'EVENT_SENDMSG': 7,
 8: 'EVENT_SENDTO',
 'EVENT_SENDTO': 8}

In [8]:
# train_data, val_data, test_data = data.train_val_test_split(val_ratio=0.15, test_ratio=0.15)
# max_node_num = max(torch.cat([data.dst,data.src]))+1
# max_node_num = data.num_nodes+1
max_node_num = 172724  # +1
# min_dst_idx, max_dst_idx = int(data.dst.min()), int(data.dst.max())
min_dst_idx, max_dst_idx = 0, max_node_num

# The LastNeighborLoader is used in TGN models to handle neighbor sampling dynamically by storing and retrieving the most recent neighbors of nodes in the graph. 
# it receives total number of nodes to store, and max number of neighbors to store for each node
neighbor_loader = LastNeighborLoader(max_node_num, size=20, device=device)

In [9]:
class GraphAttentionEmbedding(torch.nn.Module):
    """
    The GraphAttentionEmbedding class implements a graph attention-based 
    embedding model using a two-layer TransformerConv. 
    It is designed for temporal graph networks, where the embedding considers 
    both node features and edge attributes, including temporal information.

    Parameters:
    - in_channels (int): Input feature dimensionality for nodes.
    - out_channels (int): Output feature dimensionality for nodes.
    - msg_dim (int): Dimensionality of edge message features.
    - time_enc (TimeEncoder): Time encoding module.
    """
    def __init__(self, in_channels, out_channels, msg_dim, time_enc):
        super(GraphAttentionEmbedding, self).__init__()
        self.time_enc = time_enc
        edge_dim = msg_dim + time_enc.out_channels

        # First TransformerConv layer with 8 heads
        self.conv = TransformerConv(in_channels, out_channels, heads=8,
                                    dropout=0.0, edge_dim=edge_dim)
        
        # Second TransformerConv layer with 1 head, no concatenation
        self.conv2 = TransformerConv(out_channels*8, out_channels,heads=1, concat=False,
                             dropout=0.0, edge_dim=edge_dim)

    def forward(self, x, last_update, edge_index, t, msg):
        """
        Forward pass of the GraphAttentionEmbedding model.

        Parameters:
        - x (torch.Tensor): Node features of shape (num_nodes, in_channels).
        - last_update (torch.Tensor): Timestamps of last updates for each node.
        - edge_index (torch.Tensor): Edge index in COO format of shape (2, num_edges).
        - t (torch.Tensor): Edge timestamps of shape (num_edges,).
        - msg (torch.Tensor): Edge message features of shape (num_edges, msg_dim).

        Returns:
        - torch.Tensor: Updated node embeddings of shape (num_nodes, out_channels).
        """
        last_update.to(device)
        x = x.to(device)
        t = t.to(device)

        # Compute relative time differences and encode them
        rel_t = last_update[edge_index[0]] - t
        rel_t_enc = self.time_enc(rel_t.to(x.dtype))

        # Concatenate relative time encoding and message features as edge attributes
        edge_attr = torch.cat([rel_t_enc, msg], dim=-1)

        # Apply the first TransformerConv layer with ReLU activation
        x = F.relu(self.conv(x, edge_index, edge_attr))

        # Apply the second TransformerConv layer with ReLU activation
        x = F.relu(self.conv2(x, edge_index, edge_attr))
        return x


class LinkPredictor(torch.nn.Module):
    """
    A neural network for Event Type pediction tasks, transforming node embeddings
    and predicting link properties.

    Parameters:
    - in_channels (int): The dimensionality of the input node embeddings.
    """

    def __init__(self, in_channels):
        super(LinkPredictor, self).__init__()
        
        # Linear transformations for source and destination embeddings
        self.lin_src = Linear(in_channels, in_channels*2)
        self.lin_dst = Linear(in_channels, in_channels*2)
        
        # Sequential feedforward network for link prediction
        self.lin_seq = nn.Sequential(
            Linear(in_channels*4, in_channels*8),
            torch.nn.Dropout(0.5),
            nn.Tanh(),
            Linear(in_channels*8, in_channels*2),
            torch.nn.Dropout(0.5),
            nn.Tanh(),
            Linear(in_channels*2, int(in_channels//2)),
            torch.nn.Dropout(0.5),
            nn.Tanh(),

            # Final prediction layer
            # train_data.msg contains feature vector of src and dest and edge (event) information.
            # train_data.msg.shape[1]-32 is the size of the edge (event) info
            Linear(int(in_channels//2), train_data.msg.shape[1]-32)                   
        )
        

    def forward(self, z_src, z_dst):
        """
        Forward pass for the LinkPredictor.

        Parameters:
        - z_src (torch.Tensor): Source node embeddings, shape (batch_size, in_channels).
        - z_dst (torch.Tensor): Destination node embeddings, shape (batch_size, in_channels).

        Returns:
        - torch.Tensor: Predicted Edge (event) properties, shape (batch_size, train_data.msg.shape[1] - 32).
        """
        h = torch.cat([self.lin_src(z_src) , self.lin_dst(z_dst)],dim=-1)      
         
        h = self.lin_seq (h)
        
        return h

memory_dim = 100         # node state
time_dim = 100
embedding_dim = 100      # edge embedding

# create the memory
memory = TGNMemory(
    max_node_num,
    train_data.msg.size(-1),
    memory_dim,
    time_dim,
    message_module=IdentityMessage(train_data.msg.size(-1), memory_dim, time_dim),
    aggregator_module=LastAggregator(),
).to(device)

# create the graph neural network with transformer
gnn = GraphAttentionEmbedding(
    in_channels=memory_dim,
    out_channels=embedding_dim,
    msg_dim=train_data.msg.size(-1),
    time_enc=memory.time_enc,
).to(device)

# create the MLP link predictor
link_pred = LinkPredictor(in_channels=embedding_dim).to(device)

# set Adam optimizer for all 3 networks
optimizer = torch.optim.Adam(
    set(memory.parameters()) | set(gnn.parameters())
    | set(link_pred.parameters()), lr=0.00005, eps=1e-08,weight_decay=0.01)


# scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.1)
criterion = nn.CrossEntropyLoss()

# Helper vector to map global node indices to local ones.
assoc = torch.empty(max_node_num, dtype=torch.long, device=device)

saved_nodes=set()

In [10]:
BATCH=1024

def train(train_data):
    """
    Trains the Temporal Graph Network (TGN) using sequential batches of data.

    Parameters:
    - train_data: Dataset object with sequential batches, timestamps, and messages.

    Returns:
    - float: Average loss over all events in the training data.
    """
    
    # Set networks to training mode
    memory.train()
    gnn.train()
    link_pred.train()

    memory.reset_state()  # Start with a fresh memory.
    neighbor_loader.reset_state()  # Start with an empty graph.
    saved_nodes=set()

    # Tracks total loss over the training data
    total_loss = 0
    
    # Process each batch in the training dataset
    for batch in train_data.seq_batches(batch_size=BATCH):
        # Reset gradients
        optimizer.zero_grad()

        # Extract batch data: source, destination, timestamps, and messages
        src, pos_dst, t, msg = batch.src, batch.dst, batch.t, batch.msg        
        
        # Retrieve unique nodes involved in the batch and their neighbors
        n_id = torch.cat([src, pos_dst]).unique()
        n_id, edge_index, e_id = neighbor_loader(n_id)

        # Helper vector to map global node indices to local ones.
        assoc[n_id] = torch.arange(n_id.size(0), device=device)

        # Get updated memory of all nodes involved in the computation.
        z, last_update = memory(n_id)
        
        # Pass embeddings through the GNN to update them
        z = gnn(z, last_update, edge_index, train_data.t[e_id], train_data.msg[e_id])
        
        # Compute predictions for positive edges
        pos_out = link_pred(z[assoc[src]], z[assoc[pos_dst]])       

         # Concatenate predictions
        y_pred = torch.cat([pos_out], dim=0)
        
        # Ground-truth label generation using tensor_find
        # msg contains feature vector of src and dest and edge (event) information.
        # extracting event info from them
        y_true=[]
        for m in msg:
            l=tensor_find(m[16:-16],1)-1
            y_true.append(l)           
          
        y_true = torch.tensor(y_true).to(device=device)
        y_true=y_true.reshape(-1).to(torch.long).to(device=device)
        
        # Compute loss using predictions and ground-truth labels
        loss = criterion(y_pred, y_true)
        
        # Update memory and neighbor loader with ground-truth state.
        memory.update_state(src, pos_dst, t, msg)
        neighbor_loader.insert(src, pos_dst)
        
        # Backpropagation
        loss.backward()
        optimizer.step()

        # Detach memory to free computation graph
        memory.detach()

        # Accumulate total loss scaled by the number of events in the batch
        total_loss += float(loss) * batch.num_events

    # Return the average loss over all events
    return total_loss / train_data.num_events

In [11]:
# train on benign graphs
train_graphs=[graph_4_4, graph_4_5, graph_4_6]

# train fro 30 epochs
for epoch in tqdm(range(1, 31)):
    for g in train_graphs:
        loss = train(g)
        print(f'  Epoch: {epoch:02d}, Loss: {loss:.4f}')

# store the models in file
model=[memory,gnn, link_pred,neighbor_loader]
os.system("mkdir -p ./models/")
torch.save(model,"./models/model_saved_share.pt")

  0%|          | 0/30 [00:00<?, ?it/s]

  Epoch: 01, Loss: 1.3058
  Epoch: 01, Loss: 1.0811


  3%|▎         | 1/30 [02:55<1:24:54, 175.68s/it]

  Epoch: 01, Loss: 1.0880
  Epoch: 02, Loss: 0.8895
  Epoch: 02, Loss: 0.9006


  7%|▋         | 2/30 [05:59<1:24:09, 180.34s/it]

  Epoch: 02, Loss: 1.0012
  Epoch: 03, Loss: 0.8326
  Epoch: 03, Loss: 0.8404


 10%|█         | 3/30 [09:43<1:30:15, 200.57s/it]

  Epoch: 03, Loss: 0.9658
  Epoch: 04, Loss: 0.8042
  Epoch: 04, Loss: 0.8003


 13%|█▎        | 4/30 [13:49<1:34:41, 218.51s/it]

  Epoch: 04, Loss: 0.9366
  Epoch: 05, Loss: 0.7862
  Epoch: 05, Loss: 0.7731


 17%|█▋        | 5/30 [17:54<1:34:55, 227.82s/it]

  Epoch: 05, Loss: 0.9189
  Epoch: 06, Loss: 0.7727
  Epoch: 06, Loss: 0.7500


 20%|██        | 6/30 [21:54<1:32:46, 231.95s/it]

  Epoch: 06, Loss: 0.9024
  Epoch: 07, Loss: 0.7597
  Epoch: 07, Loss: 0.7283


 23%|██▎       | 7/30 [25:50<1:29:27, 233.35s/it]

  Epoch: 07, Loss: 0.8883
  Epoch: 08, Loss: 0.7491
  Epoch: 08, Loss: 0.7126


 27%|██▋       | 8/30 [30:08<1:28:24, 241.12s/it]

  Epoch: 08, Loss: 0.8781
  Epoch: 09, Loss: 0.7408
  Epoch: 09, Loss: 0.6998


 30%|███       | 9/30 [34:51<1:29:01, 254.37s/it]

  Epoch: 09, Loss: 0.8721
  Epoch: 10, Loss: 0.7343
  Epoch: 10, Loss: 0.6919


 33%|███▎      | 10/30 [39:45<1:28:53, 266.68s/it]

  Epoch: 10, Loss: 0.8658
  Epoch: 11, Loss: 0.7295
  Epoch: 11, Loss: 0.6853


 37%|███▋      | 11/30 [45:21<1:31:07, 287.78s/it]

  Epoch: 11, Loss: 0.8643
  Epoch: 12, Loss: 0.7255
  Epoch: 12, Loss: 0.6794


 40%|████      | 12/30 [51:00<1:31:00, 303.38s/it]

  Epoch: 12, Loss: 0.8484
  Epoch: 13, Loss: 0.7210
  Epoch: 13, Loss: 0.6752


 43%|████▎     | 13/30 [56:40<1:29:06, 314.50s/it]

  Epoch: 13, Loss: 0.8290
  Epoch: 14, Loss: 0.7157
  Epoch: 14, Loss: 0.6648


 47%|████▋     | 14/30 [1:02:32<1:26:51, 325.72s/it]

  Epoch: 14, Loss: 0.8190
  Epoch: 15, Loss: 0.7116
  Epoch: 15, Loss: 0.6609


 50%|█████     | 15/30 [1:08:22<1:23:15, 333.03s/it]

  Epoch: 15, Loss: 0.8141
  Epoch: 16, Loss: 0.7095
  Epoch: 16, Loss: 0.6582


 53%|█████▎    | 16/30 [1:14:25<1:19:50, 342.20s/it]

  Epoch: 16, Loss: 0.8110
  Epoch: 17, Loss: 0.7075
  Epoch: 17, Loss: 0.6550


 57%|█████▋    | 17/30 [1:20:28<1:15:28, 348.37s/it]

  Epoch: 17, Loss: 0.8071
  Epoch: 18, Loss: 0.7055
  Epoch: 18, Loss: 0.6536


 60%|██████    | 18/30 [1:26:31<1:10:31, 352.61s/it]

  Epoch: 18, Loss: 0.8060
  Epoch: 19, Loss: 0.7047
  Epoch: 19, Loss: 0.6511


 63%|██████▎   | 19/30 [1:32:44<1:05:46, 358.79s/it]

  Epoch: 19, Loss: 0.8038
  Epoch: 20, Loss: 0.7039
  Epoch: 20, Loss: 0.6504


 67%|██████▋   | 20/30 [1:38:50<1:00:11, 361.16s/it]

  Epoch: 20, Loss: 0.8012
  Epoch: 21, Loss: 0.7038
  Epoch: 21, Loss: 0.6506


 70%|███████   | 21/30 [1:45:03<54:40, 364.45s/it]  

  Epoch: 21, Loss: 0.8009
  Epoch: 22, Loss: 0.7023
  Epoch: 22, Loss: 0.6474


 73%|███████▎  | 22/30 [1:51:16<48:57, 367.13s/it]

  Epoch: 22, Loss: 0.7991
  Epoch: 23, Loss: 0.7012
  Epoch: 23, Loss: 0.6481


 77%|███████▋  | 23/30 [1:57:24<42:51, 367.40s/it]

  Epoch: 23, Loss: 0.7984
  Epoch: 24, Loss: 0.7000
  Epoch: 24, Loss: 0.6456


 80%|████████  | 24/30 [2:03:38<36:55, 369.24s/it]

  Epoch: 24, Loss: 0.7973
  Epoch: 25, Loss: 0.6996
  Epoch: 25, Loss: 0.6445


 83%|████████▎ | 25/30 [2:09:54<30:57, 371.55s/it]

  Epoch: 25, Loss: 0.7955
  Epoch: 26, Loss: 0.6994
  Epoch: 26, Loss: 0.6443


 87%|████████▋ | 26/30 [2:15:54<24:32, 368.02s/it]

  Epoch: 26, Loss: 0.7954
  Epoch: 27, Loss: 0.6984
  Epoch: 27, Loss: 0.6435


 90%|█████████ | 27/30 [2:22:05<18:26, 368.98s/it]

  Epoch: 27, Loss: 0.7940
  Epoch: 28, Loss: 0.6985
  Epoch: 28, Loss: 0.6422


 93%|█████████▎| 28/30 [2:28:20<12:21, 370.79s/it]

  Epoch: 28, Loss: 0.7935
  Epoch: 29, Loss: 0.6973
  Epoch: 29, Loss: 0.6418


 97%|█████████▋| 29/30 [2:34:37<06:12, 372.49s/it]

  Epoch: 29, Loss: 0.7924
  Epoch: 30, Loss: 0.6965
  Epoch: 30, Loss: 0.6409


100%|██████████| 30/30 [2:40:51<00:00, 321.73s/it]

  Epoch: 30, Loss: 0.7918





# Test

In [12]:
import time 

# Disable gradient computation for inference
@torch.no_grad()

def test_day_new(inference_data, path):
    """
    Evaluate the Temporal Graph Network (TGN) on inference data.
    Stores result for each time window in txt file.
    Each text file contains event information. 
    Example: {'loss': 2.8073678016662598, 'srcnode': 76938, 'dstnode': 868, 'srcmsg': "{'subject': 'system_server'}", 'dstmsg': "{'file': '/data/system/sync/pending.xml'}", 'edge_type': 'EVENT_OPEN', 'time': 1522814400030000000}

    Parameters:
    - inference_data: Dataset containing sequential batches for temporal graph inference.
    - path (str): Path to save logs and results.

    Returns:
    - dict: Summary of results, including loss and timing for each time window.
    """

    if os.path.exists(path):
        pass
    else:
        os.mkdir(path)
    
    # Set networks to evaluation mode
    memory.eval()
    gnn.eval()
    link_pred.eval()
    
    memory.reset_state()  # Start with a fresh memory.  
    neighbor_loader.reset_state()  # Start with an empty graph.
    
    # Initialize tracking variables
    time_with_loss={}
    total_loss = 0    
    edge_list=[]
    
    unique_nodes=torch.tensor([]).to(device=device)
    total_edges=0


    start_time=inference_data.t[0]
    event_count=0
    
    pos_o=[]
    loss_list=[]

    print("after merge:",inference_data)
    
    # Record the running time to evaluate the performance
    start = time.perf_counter()

    for batch in inference_data.seq_batches(batch_size=BATCH):
        
        src, pos_dst, t, msg = batch.src, batch.dst, batch.t, batch.msg
        unique_nodes=torch.cat([unique_nodes,src,pos_dst]).unique()
        total_edges+=BATCH
        
        # Retrieve embeddings and neighbors
        n_id = torch.cat([src, pos_dst]).unique()       
        n_id, edge_index, e_id = neighbor_loader(n_id)
        assoc[n_id] = torch.arange(n_id.size(0), device=device)

        # Get memory and update embeddings via GNN
        z, last_update = memory(n_id)
        z = gnn(z, last_update, edge_index, inference_data.t[e_id], inference_data.msg[e_id])

        # Predict edge properties
        pos_out = link_pred(z[assoc[src]], z[assoc[pos_dst]])
        pos_o.append(pos_out)
        y_pred = torch.cat([pos_out], dim=0)

        # Generate ground-truth labels
        y_true=[]
        for m in msg:
            # Extract label index from the message
            l=tensor_find(m[16:-16],1)-1
            y_true.append(l) 
        y_true = torch.tensor(y_true).to(device=device)
        y_true=y_true.reshape(-1).to(torch.long).to(device=device)

        # Only consider which edge hasn't been correctly predicted.
        # For benign graphs, the behaviors patterns are similar and therefore their losses are small
        # For anoamlous behaviors, some behaviors might not be seen before, so the probability of predicting those edges are low. Thus their losses are high.
        loss = criterion(y_pred, y_true)
        total_loss += float(loss) * batch.num_events
        
        # update the edges in the batch to the memory and neighbor_loader
        memory.update_state(src, pos_dst, t, msg)
        neighbor_loader.insert(src, pos_dst)
        
        # compute the loss for each edge
        each_edge_loss= cal_pos_edges_loss_multiclass(pos_out,y_true)
        
        # Process edges for logging
        for i in range(len(pos_out)):
            srcnode=int(src[i])
            dstnode=int(pos_dst[i])  
            
            srcmsg=str(nodeid2msg[srcnode]) 
            dstmsg=str(nodeid2msg[dstnode])
            t_var=int(t[i])
            edgeindex=tensor_find(msg[i][16:-16],1)   
            edge_type=rel2id[edgeindex]
            loss=each_edge_loss[i]    

            temp_dic={}
            temp_dic['loss']=float(loss)
            temp_dic['srcnode']=srcnode
            temp_dic['dstnode']=dstnode
            temp_dic['srcmsg']=srcmsg
            temp_dic['dstmsg']=dstmsg
            temp_dic['edge_type']=edge_type
            temp_dic['time']=t_var
            
            edge_list.append(temp_dic)
        
        # Check if the time interval is over (15 minutes - Window Size)
        event_count+=len(batch.src)
        if t[-1]>start_time+60000000000*15:
            # Here is a checkpoint, which records all edge losses in the current time window
            time_interval=ns_time_to_datetime_US(start_time)+"~"+ns_time_to_datetime_US(t[-1])

            end = time.perf_counter()
            time_with_loss[time_interval]={'loss':loss,
                                          'nodes_count':len(unique_nodes),
                                          'total_edges':total_edges,
                                          'costed_time':(end-start)}
            
            
            log=open(path+"/"+time_interval+".txt",'w')
            
            # Compute average loss for the interval
            for e in edge_list:
                loss+=e['loss']
            loss=loss/event_count   

            # Save results to log file
            print(f'Time: {time_interval}, Loss: {loss:.4f}, Nodes_count: {len(unique_nodes)}, Cost Time: {(end-start):.2f}s')
            edge_list = sorted(edge_list, key=lambda x:x['loss'],reverse=True)  # Rank the results based on edge losses
            for e in edge_list: 
                log.write(str(e))
                log.write("\n") 
            
            # Reset tracking variables for the next interval
            event_count=0
            total_loss=0
            loss=0
            start_time=t[-1]
            log.close()
            edge_list.clear()

    return time_with_loss

In [13]:
graph_4_7=torch.load("./train_graphs/graph_4_7.TemporalData.simple").to(device=device)
graph_4_10=torch.load("./train_graphs/graph_4_10.TemporalData.simple").to(device=device)
graph_4_11=torch.load("./train_graphs/graph_4_11.TemporalData.simple").to(device=device)

In [14]:
model = torch.load("./models/model_saved_share.pt", map_location=device)
memory,gnn, link_pred, neighbor_loader = model

In [15]:
ans_4_4=test_day_new(graph_4_4,"graph_4_4")

after merge: TemporalData(dst=[1357851], msg=[1357851, 40], src=[1357851], t=[1357851])
Time: 2018-04-04 00:00:00.030000000~2018-04-04 00:18:00.409000000, Loss: 2.1723, Nodes_count: 45, Cost Time: 0.11s
Time: 2018-04-04 00:18:00.409000000~2018-04-04 00:44:00.541000000, Loss: 0.2415, Nodes_count: 81, Cost Time: 0.62s
Time: 2018-04-04 00:44:00.541000000~2018-04-04 01:01:40.901000000, Loss: 0.4488, Nodes_count: 103, Cost Time: 0.89s
Time: 2018-04-04 01:01:40.901000000~2018-04-04 01:28:00.498000000, Loss: 0.3832, Nodes_count: 129, Cost Time: 0.98s
Time: 2018-04-04 01:28:00.498000000~2018-04-04 01:46:29.679000000, Loss: 0.4371, Nodes_count: 145, Cost Time: 1.06s
Time: 2018-04-04 01:46:29.679000000~2018-04-04 02:13:00.519000000, Loss: 0.3564, Nodes_count: 160, Cost Time: 1.15s
Time: 2018-04-04 02:13:00.519000000~2018-04-04 02:30:02.355000000, Loss: 0.5438, Nodes_count: 195, Cost Time: 1.24s
Time: 2018-04-04 02:30:02.355000000~2018-04-04 02:57:00.339000000, Loss: 0.3305, Nodes_count: 215, Cos

In [16]:
ans_4_5=test_day_new(graph_4_5,"graph_4_5")

after merge: TemporalData(dst=[840914], msg=[840914, 40], src=[840914], t=[840914])
Time: 2018-04-05 00:00:00.041000000~2018-04-05 00:19:19.331000000, Loss: 2.0655, Nodes_count: 42, Cost Time: 0.07s
Time: 2018-04-05 00:19:19.331000000~2018-04-05 00:44:22.644000000, Loss: 0.2333, Nodes_count: 75, Cost Time: 0.15s
Time: 2018-04-05 00:44:22.644000000~2018-04-05 01:05:24.697000000, Loss: 0.6920, Nodes_count: 93, Cost Time: 0.24s
Time: 2018-04-05 01:05:24.697000000~2018-04-05 01:27:27.947000000, Loss: 0.6294, Nodes_count: 118, Cost Time: 0.33s
Time: 2018-04-05 01:27:27.947000000~2018-04-05 01:50:05.181000000, Loss: 0.5873, Nodes_count: 138, Cost Time: 0.41s
Time: 2018-04-05 01:50:05.181000000~2018-04-05 02:15:00.642000000, Loss: 0.6035, Nodes_count: 168, Cost Time: 0.50s
Time: 2018-04-05 02:15:00.642000000~2018-04-05 02:37:19.553000000, Loss: 0.3142, Nodes_count: 186, Cost Time: 0.59s
Time: 2018-04-05 02:37:19.553000000~2018-04-05 02:58:26.087000000, Loss: 0.3182, Nodes_count: 216, Cost Tim

In [17]:
ans_4_6=test_day_new(graph_4_6,"graph_4_6")

after merge: TemporalData(dst=[1134670], msg=[1134670, 40], src=[1134670], t=[1134670])
Time: 2018-04-06 00:00:00.050000000~2018-04-06 00:19:16.992000000, Loss: 1.0684, Nodes_count: 43, Cost Time: 0.14s
Time: 2018-04-06 00:19:16.992000000~2018-04-06 00:45:05.853000000, Loss: 0.3859, Nodes_count: 87, Cost Time: 0.48s
Time: 2018-04-06 00:45:05.853000000~2018-04-06 01:00:11.671000000, Loss: 0.3000, Nodes_count: 114, Cost Time: 0.67s
Time: 2018-04-06 01:00:11.671000000~2018-04-06 01:18:57.427000000, Loss: 0.4082, Nodes_count: 126, Cost Time: 0.92s
Time: 2018-04-06 01:18:57.427000000~2018-04-06 01:42:40.184000000, Loss: 0.1656, Nodes_count: 141, Cost Time: 1.10s
Time: 2018-04-06 01:42:40.184000000~2018-04-06 02:05:25.221000000, Loss: 0.4362, Nodes_count: 176, Cost Time: 1.45s
Time: 2018-04-06 02:05:25.221000000~2018-04-06 02:22:54.336000000, Loss: 0.3202, Nodes_count: 195, Cost Time: 1.73s
Time: 2018-04-06 02:22:54.336000000~2018-04-06 02:44:45.281000000, Loss: 0.1943, Nodes_count: 225, Cos

In [18]:
ans_4_7=test_day_new(graph_4_7,"graph_4_7")

after merge: TemporalData(dst=[1847921], msg=[1847921, 40], src=[1847921], t=[1847921])
Time: 2018-04-07 00:00:00.040000000~2018-04-07 00:30:11.169000000, Loss: 1.4250, Nodes_count: 75, Cost Time: 0.15s
Time: 2018-04-07 00:30:11.169000000~2018-04-07 00:55:07.465000000, Loss: 0.6400, Nodes_count: 99, Cost Time: 0.32s
Time: 2018-04-07 00:55:07.465000000~2018-04-07 01:18:15.361000000, Loss: 0.4791, Nodes_count: 119, Cost Time: 0.41s
Time: 2018-04-07 01:18:15.361000000~2018-04-07 01:44:15.385000000, Loss: 0.5063, Nodes_count: 142, Cost Time: 0.49s
Time: 2018-04-07 01:44:15.385000000~2018-04-07 02:15:00.653000000, Loss: 0.5642, Nodes_count: 176, Cost Time: 0.67s
Time: 2018-04-07 02:15:00.653000000~2018-04-07 02:30:22.992000000, Loss: 0.8789, Nodes_count: 201, Cost Time: 0.85s
Time: 2018-04-07 02:30:22.992000000~2018-04-07 02:45:49.212000000, Loss: 0.7982, Nodes_count: 210, Cost Time: 0.96s
Time: 2018-04-07 02:45:49.212000000~2018-04-07 03:11:00.395000000, Loss: 0.4518, Nodes_count: 227, Cos

In [19]:
ans_4_10=test_day_new(graph_4_10,"graph_4_10")

after merge: TemporalData(dst=[2554245], msg=[2554245, 40], src=[2554245], t=[2554245])
Time: 2018-04-10 00:00:00.041000000~2018-04-10 01:14:00.502000000, Loss: 2.2879, Nodes_count: 81, Cost Time: 0.07s
Time: 2018-04-10 01:14:00.502000000~2018-04-10 02:24:00.583000000, Loss: 0.8633, Nodes_count: 156, Cost Time: 0.16s
Time: 2018-04-10 02:24:00.583000000~2018-04-10 03:30:37.185000000, Loss: 1.4760, Nodes_count: 226, Cost Time: 0.25s
Time: 2018-04-10 03:30:37.185000000~2018-04-10 04:51:00.454000000, Loss: 1.1623, Nodes_count: 300, Cost Time: 0.34s
Time: 2018-04-10 04:51:00.454000000~2018-04-10 06:00:07.009000000, Loss: 0.8640, Nodes_count: 367, Cost Time: 0.43s
Time: 2018-04-10 06:00:07.009000000~2018-04-10 07:04:52.138000000, Loss: 0.9286, Nodes_count: 442, Cost Time: 0.53s
Time: 2018-04-10 07:04:52.138000000~2018-04-10 07:19:59.157000000, Loss: 0.7921, Nodes_count: 1100, Cost Time: 3.18s
Time: 2018-04-10 07:19:59.157000000~2018-04-10 07:52:43.143000000, Loss: 0.6916, Nodes_count: 1151, 

In [20]:
ans_4_11=test_day_new(graph_4_11,"graph_4_11")

after merge: TemporalData(dst=[1976440], msg=[1976440, 40], src=[1976440], t=[1976440])
Time: 2018-04-11 00:00:00.063000000~2018-04-11 02:00:00.161000000, Loss: 2.3246, Nodes_count: 127, Cost Time: 0.07s
Time: 2018-04-11 02:00:00.161000000~2018-04-11 03:45:00.251000000, Loss: 1.1766, Nodes_count: 230, Cost Time: 0.16s
Time: 2018-04-11 03:45:00.251000000~2018-04-11 05:30:00.544000000, Loss: 1.2557, Nodes_count: 333, Cost Time: 0.26s
Time: 2018-04-11 05:30:00.544000000~2018-04-11 07:00:52.974000000, Loss: 1.2609, Nodes_count: 439, Cost Time: 0.35s
Time: 2018-04-11 07:00:52.974000000~2018-04-11 07:25:16.408000000, Loss: 0.7562, Nodes_count: 482, Cost Time: 0.65s
Time: 2018-04-11 07:25:16.408000000~2018-04-11 07:46:44.962000000, Loss: 0.8973, Nodes_count: 514, Cost Time: 0.85s
Time: 2018-04-11 07:46:44.962000000~2018-04-11 08:05:53.555000000, Loss: 0.7641, Nodes_count: 548, Cost Time: 0.96s
Time: 2018-04-11 08:05:53.555000000~2018-04-11 08:21:10.246000000, Loss: 0.8983, Nodes_count: 744, C

# Initialize the node IDF

In [21]:
node_set=set()

file_list=[]

file_path="graph_4_4/"
file_l=os.listdir("graph_4_4/")
for i in file_l:
    file_list.append(file_path+i)

file_path="graph_4_5/"
file_l=os.listdir("graph_4_5/")
for i in file_l:
    file_list.append(file_path+i)

file_path="graph_4_6/"
file_l=os.listdir("graph_4_6/")
for i in file_l:
    file_list.append(file_path+i)


file_path="graph_4_7/"
file_l=os.listdir("graph_4_7/")
for i in file_l:
    file_list.append(file_path+i)

node_IDF={}
node_set = {}
for f_path in tqdm(file_list):
    f=open(f_path)
    for line in f:
        l=line.strip()
        jdata=eval(l)
        jdata=eval(l)
        if jdata['loss']>0:
            if 'netflow' not in str(jdata['srcmsg']):
                if str(jdata['srcmsg']) not in node_set.keys():
                    node_set[str(jdata['srcmsg'])] = set([f_path])
                else:
                    node_set[str(jdata['srcmsg'])].add(f_path)
            if 'netflow' not in str(jdata['dstmsg']):
                if str(jdata['dstmsg']) not in node_set.keys():
                    node_set[str(jdata['dstmsg'])] = set([f_path])
                else:
                    node_set[str(jdata['dstmsg'])].add(f_path)
for n in node_set:
    include_count = len(node_set[n])   
    IDF=math.log(len(file_list)/(include_count+1))
    node_IDF[n] = IDF    


torch.save(node_IDF,"node_IDF")
print("IDF weight calculate complete!")


100%|██████████| 293/293 [03:00<00:00,  1.63it/s]

IDF weight calculate complete!





In [22]:
def cal_train_IDF(find_str,file_list):
    include_count=0
    for f_path in (file_list):
        f=open(f_path)
        if find_str in f.read():
            include_count+=1             
    IDF=math.log(len(file_list)/(include_count+1))
    return IDF


def cal_IDF(find_str,file_path,file_list):
    file_list=os.listdir(file_path)
    include_count=0
    different_neighbor=set()
    for f_path in (file_list):
        f=open(file_path+f_path)
        if find_str in f.read():
            include_count+=1                
                
    IDF=math.log(len(file_list)/(include_count+1))
    
    return IDF,1

def cal_redundant(find_str,edge_list):
    
    different_neighbor=set()
    for e in edge_list:
        if find_str in str(e):
            different_neighbor.add(e[0])
            different_neighbor.add(e[1])
    return len(different_neighbor)-2

def cal_anomaly_loss(loss_list,edge_list,file_path):
    
    if len(loss_list)!=len(edge_list):
        print("error!")
        return 0
    count=0
    loss_sum=0
    loss_std=std(loss_list)
    loss_mean=mean(loss_list)
    edge_set=set()
    node_set=set()
    node2redundant={}
    
    thr=loss_mean+1.5*loss_std

    print("thr:",thr)

    for i in range(len(loss_list)):
        if loss_list[i]>thr:
            count+=1
            src_node=edge_list[i][0]
            dst_node=edge_list[i][1]
            
            loss_sum+=loss_list[i]
    
            node_set.add(src_node)
            node_set.add(dst_node)
            edge_set.add(edge_list[i][0]+edge_list[i][1])
    return count, loss_sum/(count + 0.00001) ,node_set,edge_set
#     return count, count/len(loss_list)

# Construct the relations between time windows

In [23]:

def is_include_key_word(s):
    keywords=[
         'netflow',

        'glx_alsa_675',
        '/data/system/',
         '/storage/emulated/',
        '/data/data/com.android',
        '/proc/',
        'nz9885vc.default',
      
      ]
    flag=False
    for i in keywords:
        if i in s:
            flag=True
    return flag



def cal_set_rel(s1,s2,node_IDF, file_list):
    new_s=s1 & s2
    count=0
    for i in new_s:
#     jdata=json.loads(i)
        if is_include_key_word(i) is False :
            if i in node_IDF.keys():
                IDF=node_IDF[i]
            else:
                IDF=math.log(len(file_list)/(1))
            if IDF>6:
                print("node:",i," IDF:",IDF)
                count+=1
    return count





# def cal_set_rel(s1,s2,node_IDF, file_list, node_IDF_4_4_7, file_list_4_4_7):
#     new_s=s1 & s2
#     count=0
#     for i in new_s:
# #     jdata=json.loads(i)
#         if 'netflow' not in i and 'glx_alsa_675' not in i and '/data/system/' not in i and '/storage/emulated/' not in i and  '/data/data/com.android' not in i and  '/proc/' not in i and 'nz9885vc.default' not in i :

# #         'netflow' not in i
# #         and 'usr' not in i and 'var' not in i
#             if i in node_IDF.keys():
#                 IDF=node_IDF[i]
#             else:
#                 IDF=math.log(len(file_list)/(1))
                
#             if i in node_IDF_4_4_7.keys():
#                 IDF4=node_IDF_4_4_7[i]
#             else:
#                 IDF4=math.log(len(file_list_4_4_7)/(1))    
            
# #             print(IDF)
#             if (IDF+IDF4)>9:
#                 print("node:",i," IDF:",IDF+IDF4)
#                 count+=1
#     return count



# label generation

In [24]:
labels={}
    
    
filelist = os.listdir("graph_4_10")
for f in filelist:
    labels["graph_4_10/"+f]=0

filelist = os.listdir("graph_4_11")
for f in filelist:
    labels["graph_4_11/"+f]=0

In [25]:
attack_list=[
    'graph_4_11/2018-04-11 13:46:38.658000000~2018-04-11 14:02:21.103000000.txt',
    'graph_4_11/2018-04-11 14:02:21.103000000~2018-04-11 14:18:19.001000000.txt',
    'graph_4_11/2018-04-11 14:18:19.001000000~2018-04-11 14:33:38.600000000.txt',
    'graph_4_11/2018-04-11 14:33:38.600000000~2018-04-11 14:49:05.326000000.txt',
    'graph_4_11/2018-04-11 14:49:05.326000000~2018-04-11 15:04:48.749000000.txt',
]
for i in attack_list:
    labels[i]=1

In [26]:
print(f"Benign count: {len(labels.values()) - sum(labels.values())}")
print(f"Attack count: {sum(labels.values())}")

Benign count: 114
Attack count: 5


In [27]:
pred_label={}

filelist = os.listdir("graph_4_10/")
for f in filelist:
    pred_label["graph_4_10/"+f]=0
    
filelist = os.listdir("graph_4_11/")
for f in filelist:
    pred_label["graph_4_11/"+f]=0

In [28]:
file_list=[]

file_path="graph_4_4/"
file_l=os.listdir("graph_4_4/")
for i in file_l:
    file_list.append(file_path+i)

file_path="graph_4_5/"
file_l=os.listdir("graph_4_5/")
for i in file_l:
    file_list.append(file_path+i)

file_path="graph_4_6/"
file_l=os.listdir("graph_4_6/")
for i in file_l:
    file_list.append(file_path+i)


file_path="graph_4_7/"
file_l=os.listdir("graph_4_7/")
for i in file_l:
    file_list.append(file_path+i)

# 4-10

In [29]:
# node_IDF=torch.load("node_IDF_4_10")
# node_IDF_4_7=torch.load("node_IDF_4_4-7")
node_IDF_4_4_7=torch.load("node_IDF")
y_data_4_10=[]
df_list_4_10=[]
# node_set_list=[]
history_list=[]
tw_que=[]
his_tw={}
current_tw={}
loss_list_4_10=[]


file_l=os.listdir("graph_4_10")
index_count=0
for f_path in sorted(file_l):
    f=open("graph_4_10/"+f_path)
    edge_loss_list=[]
    edge_list=[]
    print('index_count:',index_count)
#     print(f_path)
    for line in f:
        l=line.strip()
        jdata=eval(l)
        edge_loss_list.append(jdata['loss'])
        edge_list.append([str(jdata['srcmsg']),str(jdata['dstmsg'])])
    df_list_4_10.append(pd.DataFrame(edge_loss_list))
    count,loss_avg,node_set,edge_set=cal_anomaly_loss(edge_loss_list,edge_list,"graph_4_10/")
    current_tw['name']=f_path
    current_tw['loss']=loss_avg
    current_tw['index']=index_count
    current_tw['nodeset']=node_set

    added_que_flag=False
    for hq in history_list:
        for his_tw in hq:
            if cal_set_rel(current_tw['nodeset'],his_tw['nodeset'],node_IDF_4_4_7, file_list)!=0 and current_tw['name']!=his_tw['name']:
#                 print("history queue:",his_tw['name'])
                hq.append(copy.deepcopy(current_tw))
                added_que_flag=True
                break
            if added_que_flag:
                break
    if added_que_flag is False:
        temp_hq=[copy.deepcopy(current_tw)]
        history_list.append(temp_hq)
    index_count+=1
    loss_list_4_10.append(loss_avg)
    print( f_path,"  ",loss_avg," count:",count," percentage:",count/len(edge_list)," node count:",len(node_set)," edge count:",len(edge_set))

index_count: 0
thr: 2.8531981129075934
2018-04-10 00:00:00.041000000~2018-04-10 01:14:00.502000000.txt    0.0  count: 0  percentage: 0.0  node count: 0  edge count: 0
index_count: 1
thr: 2.62544484760982
2018-04-10 01:14:00.502000000~2018-04-10 02:24:00.583000000.txt    3.256865427079852  count: 111  percentage: 0.1083984375  node count: 71  edge count: 69
index_count: 2
thr: 4.441793639232343
2018-04-10 02:24:00.583000000~2018-04-10 03:30:37.185000000.txt    6.022709648928218  count: 119  percentage: 0.1162109375  node count: 5  edge count: 3
index_count: 3
thr: 3.5811508321365713
2018-04-10 03:30:37.185000000~2018-04-10 04:51:00.454000000.txt    4.999476101846638  count: 110  percentage: 0.107421875  node count: 6  edge count: 5
index_count: 4
thr: 2.608054670654323
2018-04-10 04:51:00.454000000~2018-04-10 06:00:07.009000000.txt    3.9850593269069865  count: 54  percentage: 0.052734375  node count: 11  edge count: 9
index_count: 5
thr: 2.6798890573177845
2018-04-10 06:00:07.009000000

In [30]:
name_list=[]
for hl in history_list:
    loss_count=0
    for hq in hl:
        if loss_count==0:
            loss_count=(loss_count+1)*(hq['loss']+1)
        else:
            loss_count=(loss_count)*(hq['loss']+1)
#     name_list=[]
    if loss_count>9:
        name_list=[]
        for i in hl:
            name_list.append(i['name']) 
        print(name_list)
        for i in name_list:
            pred_label[i]=1
        print(loss_count)

# 4-11

In [31]:
# node_IDF=torch.load("node_IDF_4_11")
# node_IDF_4_4_7=torch.load("node_IDF_4_4-7")
node_IDF_4_4_7=torch.load("node_IDF")
y_data_4_11=[]
df_list_4_11=[]
# node_set_list=[]
history_list=[]
tw_que=[]
his_tw={}
current_tw={}

loss_list_4_11=[]

file_path_list=[]


file_path="graph_4_11/"
file_l=os.listdir("graph_4_11/")
for i in file_l:
    file_path_list.append(file_path+i)

index_count=0
for f_path in sorted(file_path_list):
    f=open(f_path)
    edge_loss_list=[]
    edge_list=[]
    print('index_count:',index_count)
    
    for line in f:
        l=line.strip()
        jdata=eval(l)
        edge_loss_list.append(jdata['loss'])
        edge_list.append([str(jdata['srcmsg']),str(jdata['dstmsg'])])
    df_list_4_11.append(pd.DataFrame(edge_loss_list))
    count,loss_avg,node_set,edge_set=cal_anomaly_loss(edge_loss_list,edge_list,"graph_4_11/")

    current_tw['name']=f_path
    current_tw['loss']=loss_avg
    current_tw['index']=index_count
    current_tw['nodeset']=node_set

    added_que_flag=False
    for hq in history_list:
        for his_tw in hq:
            if cal_set_rel(current_tw['nodeset'],his_tw['nodeset'],node_IDF_4_4_7, file_list)!=0 and current_tw['name']!=his_tw['name']:
#                 print("history queue:",his_tw['name'])
                hq.append(copy.deepcopy(current_tw))
                added_que_flag=True
                break
            if added_que_flag:
                break
    if added_que_flag is False:
        temp_hq=[copy.deepcopy(current_tw)]
        history_list.append(temp_hq)
    index_count+=1
    loss_list_4_11.append(loss_avg)
    print( f_path,"  ",loss_avg," count:",count," percentage:",count/len(edge_list)," node count:",len(node_set)," edge count:",len(edge_set))

index_count: 0
thr: 2.9659018460488697
graph_4_11/2018-04-11 00:00:00.063000000~2018-04-11 02:00:00.161000000.txt    0.0  count: 0  percentage: 0.0  node count: 0  edge count: 0
index_count: 1
thr: 2.7465355863712353
graph_4_11/2018-04-11 02:00:00.161000000~2018-04-11 03:45:00.251000000.txt    3.791358893114378  count: 35  percentage: 0.0341796875  node count: 7  edge count: 5
index_count: 2
thr: 2.908069439074869
graph_4_11/2018-04-11 03:45:00.251000000~2018-04-11 05:30:00.544000000.txt    4.347204530669758  count: 36  percentage: 0.03515625  node count: 6  edge count: 4
index_count: 3
thr: 2.874940965561514
graph_4_11/2018-04-11 05:30:00.544000000~2018-04-11 07:00:52.974000000.txt    4.243466263908712  count: 21  percentage: 0.0205078125  node count: 4  edge count: 2
index_count: 4
thr: 1.7660760315509239
graph_4_11/2018-04-11 07:00:52.974000000~2018-04-11 07:25:16.408000000.txt    2.49229529783925  count: 354  percentage: 0.08642578125  node count: 59  edge count: 56
index_count: 5


In [32]:
name_list=[]
for hl in history_list:
    loss_count=0
    for hq in hl:
        if loss_count==0:
            loss_count=(loss_count+1)*(hq['loss']+1)
        else:
            loss_count=(loss_count)*(hq['loss']+1)
#     name_list=[]
    if loss_count>9:
        name_list=[]
        for i in hl:
            name_list.append(i['name']) 
        print(name_list)
        for i in name_list:
            pred_label[i]=1
        print(loss_count)

In [33]:
from sklearn.metrics import average_precision_score, roc_auc_score

from sklearn.metrics import roc_auc_score
import torch
from sklearn import preprocessing
import matplotlib.pyplot as plt
import numpy as np
from sklearn.preprocessing import MinMaxScaler
from sklearn.metrics import confusion_matrix

def plot_thr():
    np.seterr(invalid='ignore')
    step=0.01
    thr_list=torch.arange(-5,5,step)
    
    

    precision_list=[]
    recall_list=[]
    fscore_list=[]
    accuracy_list=[]
    auc_val_list=[]
    for thr in thr_list:
        threshold=thr
        y_prediction=[]
        for i in y_test_scores:
            if i >threshold:
                y_prediction.append(1)
            else:
                y_prediction.append(0)
        precision,recall,fscore,accuracy,auc_val=classifier_evaluation(y_test, y_prediction)   
        precision_list.append(float(precision))
        recall_list.append(float(recall))
        fscore_list.append(float(fscore))
        accuracy_list.append(float(accuracy))
        auc_val_list.append(float(auc_val))

    max_fscore=max(fscore_list)
    max_fscore_index=fscore_list.index(max_fscore)
    print(max_fscore_index)
    print("max threshold:",thr_list[max_fscore_index])
    print('precision:',precision_list[max_fscore_index])
    print('recall:',recall_list[max_fscore_index])
    print('fscore:',fscore_list[max_fscore_index])
    print('accuracy:',accuracy_list[max_fscore_index])    
    print('auc:',auc_val_list[max_fscore_index])
    
        
     # list tensor
#     precision_list=torch.tensor(precision_list)   
#     recall_list=torch.tensor(recall_list)   
#     fscore_list=torch.tensor(fscore_list)   
#     accuracy_list=torch.tensor(accuracy_list)   
#     auc_val_list=torch.tensor(auc_val_list)   

    


    
    # plt.scatter(attack_x, attack_y, s=20, c='r', label='Attack graph',marker='*')
    # plt.scatter(bengin_x, bengin_y, s=20, c='g', label='Bengin graph',marker='1')
    # plt.scatter(bengin_x, bengin_y, s=20, c='g', label='Bengin graph',marker='1')

    plt.plot(thr_list,precision_list,color='red',label='precision',linewidth=2.0,linestyle='-')
    plt.plot(thr_list,recall_list,color='orange',label='recall',linewidth=2.0,linestyle='solid')
    plt.plot(thr_list,fscore_list,color='y',label='F-score',linewidth=2.0,linestyle='dashed')
    plt.plot(thr_list,accuracy_list,color='g',label='accuracy',linewidth=2.0,linestyle='dashdot')
    plt.plot(thr_list,auc_val_list,color='b',label='auc_val',linewidth=2.0,linestyle='dotted')
    # '-', '--', '-.', ':', 'None', ' ', '', 'solid', 'dashed', 'dashdot', 'dotted'


    # plt.scatter(turnovers, graph_loss, c=color)
    plt.xlabel("Threshold", fontdict={'size': 16})
    plt.ylabel("Rate", fontdict={'size': 16})
    plt.title("Different evaluation Indicators by varying threshold value", fontdict={'size': 12})
    plt.legend(loc='best', fontsize=12, markerscale=0.5)
    plt.show()

def classifier_evaluation(y_test, y_test_pred):
    # groundtruth, pred_value
    tn, fp, fn, tp =confusion_matrix(y_test, y_test_pred).ravel()
#     tn+=100
#     print(clf_name," : ")
    print('tn:',tn)
    print('fp:',fp)
    print('fn:',fn)
    print('tp:',tp)
    precision=tp/(tp+fp)
    recall=tp/(tp+fn)
    accuracy=(tp+tn)/(tp+tn+fp+fn)
    fscore=2*(precision*recall)/(precision+recall)    
    auc_val=roc_auc_score(y_test, y_test_pred)
    print("precision:",precision)
    print("recall:",recall)
    print("fscore:",fscore)
    print("accuracy:",accuracy)
    print("auc_val:",auc_val)
    return precision,recall,fscore,accuracy,auc_val

def minmax(data):
    min_val=min(data)
    max_val=max(data)
    ans=[]
    for i in data:
        ans.append((i-min_val)/(max_val-min_val))
    return ans




In [34]:
y=[]
y_pred=[]
for i in labels:
    y.append(labels[i])
    y_pred.append(pred_label[i])

In [35]:
classifier_evaluation(y,y_pred)

tn: 114
fp: 0
fn: 5
tp: 0
precision: nan
recall: 0.0
fscore: nan
accuracy: 0.957983193277311
auc_val: 0.5


  precision=tp/(tp+fp)


(nan, 0.0, nan, 0.957983193277311, 0.5)

# Count attack edge numbers

In [36]:
def keyword_hit(line):
    attack_nodes=[
            'shared_files',
        'csb.tracee.27331.27355',
        'netrecon',
#         '/data/data/org.mozilla.fennec_firefox_dev/',
     
#             'firefox',
        '153.178.46.202',
       '111.82.111.27',
        '166.199.230.185',
        '140.57.183.17',
      
        
        ]
    flag=False
    for i in attack_nodes:
        if i in line:
            flag=True
            break
    return flag



files=[
    
        'graph_4_11/2018-04-11 13:46:38.658000000~2018-04-11 14:02:21.103000000.txt',
    'graph_4_11/2018-04-11 14:02:21.103000000~2018-04-11 14:18:19.001000000.txt',
    'graph_4_11/2018-04-11 14:18:19.001000000~2018-04-11 14:33:38.600000000.txt',
    'graph_4_11/2018-04-11 14:33:38.600000000~2018-04-11 14:49:05.326000000.txt',
    'graph_4_11/2018-04-11 14:49:05.326000000~2018-04-11 15:04:48.749000000.txt',
]


In [37]:
attack_edge_count=0
for fpath in tqdm(files):
    f=open(fpath)
    for line in f:
        if keyword_hit(line):
            attack_edge_count+=1
print(attack_edge_count)

100%|██████████| 5/5 [00:00<00:00, 26.18it/s]

647





# Visualization

In [38]:
import os

from graphviz import Digraph
import networkx as nx
import datetime
import community.community_louvain as community_louvain
from tqdm import tqdm



# Some common path abstraction for visualization
replace_dic={
 '/data/data/org.mozilla.fennec_firefox_dev/cache/':'/data/data/org.mozilla.fennec_firefox_dev/cache/*',
     '/data/data/org.mozilla.fennec_firefox_dev/files/':'/data/data/org.mozilla.fennec_firefox_dev/files/*',
    '/system/fonts/':'/system/fonts/*',
    '/data/data/com.android.email/cache/':'/data/data/com.android.email/cache/*',
    '/data/data/com.android.email/files/':'/data/data/com.android.email/files/*',
    'UNNAMED':'UNNAMED:*',
    
}


def replace_path_name(path_name):
    for i in replace_dic:
        if i in path_name:
            return replace_dic[i]
    return path_name


# Users should manually put the detected anomalous time windows here
attack_list = [
        'graph_4_11/2018-04-11 13:46:38.658000000~2018-04-11 14:02:21.103000000.txt',
    'graph_4_11/2018-04-11 14:02:21.103000000~2018-04-11 14:18:19.001000000.txt',
    'graph_4_11/2018-04-11 14:18:19.001000000~2018-04-11 14:33:38.600000000.txt',
    'graph_4_11/2018-04-11 14:33:38.600000000~2018-04-11 14:49:05.326000000.txt',
    'graph_4_11/2018-04-11 14:49:05.326000000~2018-04-11 15:04:48.749000000.txt',
]

original_edges_count = 0
graphs = []
gg = nx.DiGraph()
count = 0
for path in tqdm(attack_list):
    if ".txt" in path:
        line_count = 0
        node_set = set()
        tempg = nx.DiGraph()
        f = open(path, "r")
        edge_list = []
        for line in f:
            count += 1
            l = line.strip()
            jdata = eval(l)
            edge_list.append(jdata)

        edge_list = sorted(edge_list, key=lambda x: x['loss'], reverse=True)
        original_edges_count += len(edge_list)

        loss_list = []
        for i in edge_list:
            loss_list.append(i['loss'])
        loss_mean = mean(loss_list)
        loss_std = std(loss_list)
        print(loss_mean)
        print(loss_std)
        thr = loss_mean + 1.5 * loss_std
        print("thr:", thr)
        for e in edge_list:
            if e['loss'] > thr:
                tempg.add_edge(str(hashgen(replace_path_name(e['srcmsg']))),
                               str(hashgen(replace_path_name(e['dstmsg']))))
                gg.add_edge(str(hashgen(replace_path_name(e['srcmsg']))), str(hashgen(replace_path_name(e['dstmsg']))),
                            loss=e['loss'], srcmsg=e['srcmsg'], dstmsg=e['dstmsg'], edge_type=e['edge_type'],
                            time=e['time'])


partition = community_louvain.best_partition(gg.to_undirected())

# Generate the candidate subgraphs based on community discovery results
communities = {}
max_partition = 0
for i in partition:
    if partition[i] > max_partition:
        max_partition = partition[i]
for i in range(max_partition + 1):
    communities[i] = nx.DiGraph()
for e in gg.edges:
    communities[partition[e[0]]].add_edge(e[0], e[1])
    communities[partition[e[1]]].add_edge(e[0], e[1])


# Define the attack nodes. They are **only be used to plot the colors of attack nodes and edges**.
# They won't change the detection results.
# Didn't add too much nodes for coloring. Most of the results are compared with the ground truth documentations manually
def attack_edge_flag(msg):
    attack_nodes = [
        '/data/data/org.mozilla.fennec_firefox_dev/',
        '/data/data/org.mozilla.fennec_firefox_dev/shared_files',
        '/data/local/tmp',
        'csb.tracee.27331.27355',
        '/data/data/org.mozilla.fennec_firefox_dev/csb.tracee.27331.27355',
        '111.82.111.27',
        '166.199.230.185',
        'glx_alsa_675',
    ]
    flag = False
    for i in attack_nodes:
        if i in str(msg):
            flag = True
    return flag


# Plot and render candidate subgraph
os.system(f"mkdir -p ./graph_visual/")
graph_index = 0
for c in communities:
    dot = Digraph(name="MyPicture", comment="the test", format="pdf")
    dot.graph_attr['rankdir'] = 'LR'

    for e in communities[c].edges:
        try:
            temp_edge = gg.edges[e]
            srcnode = e['srcnode']
            dstnode = e['dstnode']
        except:
            pass

        if True:
            # source node
            if "'subject': '" in temp_edge['srcmsg']:
                src_shape = 'box'
            elif "'file': '" in temp_edge['srcmsg']:
                src_shape = 'oval'
            elif "'netflow': '" in temp_edge['srcmsg']:
                src_shape = 'diamond'
            if attack_edge_flag(temp_edge['srcmsg']):
                src_node_color = 'red'
            else:
                src_node_color = 'blue'
            dot.node(name=str(hashgen(replace_path_name(temp_edge['srcmsg']))), label=str(
                replace_path_name(temp_edge['srcmsg']) + str(
                    partition[str(hashgen(replace_path_name(temp_edge['srcmsg'])))])), color=src_node_color,
                     shape=src_shape)

            # destination node
            if "'subject': '" in temp_edge['dstmsg']:
                dst_shape = 'box'
            elif "'file': '" in temp_edge['dstmsg']:
                dst_shape = 'oval'
            elif "'netflow': '" in temp_edge['dstmsg']:
                dst_shape = 'diamond'
            if attack_edge_flag(temp_edge['dstmsg']):
                dst_node_color = 'red'
            else:
                dst_node_color = 'blue'
            dot.node(name=str(hashgen(replace_path_name(temp_edge['dstmsg']))), label=str(
                replace_path_name(temp_edge['dstmsg']) + str(
                    partition[str(hashgen(replace_path_name(temp_edge['dstmsg'])))])), color=dst_node_color,
                     shape=dst_shape)

            if attack_edge_flag(temp_edge['srcmsg']) and attack_edge_flag(temp_edge['dstmsg']):
                edge_color = 'red'
            else:
                edge_color = 'blue'
            dot.edge(str(hashgen(replace_path_name(temp_edge['srcmsg']))),
                     str(hashgen(replace_path_name(temp_edge['dstmsg']))), label=temp_edge['edge_type'],
                     color=edge_color)

    dot.render(f'./graph_visual/subgraph_' + str(graph_index), view=False)
    graph_index += 1

 20%|██        | 1/5 [00:01<00:04,  1.17s/it]

1.6285169453354447
1.6305574435069978
thr: 4.074353110595942


 40%|████      | 2/5 [00:02<00:02,  1.00it/s]

1.6952151654125829
1.4655887472966909
thr: 3.8935982863576193


 60%|██████    | 3/5 [00:02<00:01,  1.05it/s]

1.6349986380277541
1.4537517902799393
thr: 3.815626323447663


 80%|████████  | 4/5 [00:03<00:00,  1.14it/s]

0.9074441219848823
0.9450880798032311
thr: 2.325076241689729


100%|██████████| 5/5 [00:04<00:00,  1.11it/s]

1.340633728827119
1.425944713629589
thr: 3.479550799271503





In [39]:
#END