In [None]:
# encoding=utf-8
import os.path as osp
import os
import copy
import matplotlib.pyplot as plt
import torch
from torch.nn import Linear
from sklearn.metrics import average_precision_score, roc_auc_score
from torch_geometric.data import TemporalData
from torch_geometric.datasets import JODIEDataset
from torch_geometric.datasets import ICEWS18
from torch_geometric.nn import TGNMemory, TransformerConv
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn.models.tgn import (LastNeighborLoader, IdentityMessage, MeanAggregator,
                                           LastAggregator)
from torch_geometric import *
from torch_geometric.utils import negative_sampling
from tqdm import tqdm
import networkx as nx
import numpy as np
import math
import copy
import re
import time
import json
import pandas as pd
from random import choice
import gc
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# device='cpu'
# msg structure:      [src_node_feature,edge_attr,dst_node_feature]

# compute the best partition 
import datetime
# import community as community_louvain

import xxhash

# Find the 1-based position of a value (e.g. the 1 in a one-hot edge-type vector) inside a tensor
def tensor_find(t,x):
    """
    Locate the first element of a tensor that equals a given value.

    The tensor is copied to host memory and searched with NumPy in
    row-major order.

    Parameters:
    t (torch.Tensor): Tensor to search.
    x (Any): Value every element is compared against.

    Returns:
    NumPy integer: Position along dimension 0 of the first match, counted from 1.

    Raises:
    IndexError: If no element of ``t`` equals ``x``.
    """
    matches = np.argwhere(t.cpu().numpy() == x)
    first_match = matches[0]
    return first_match[0] + 1


def std(t):
    """
    Population standard deviation of a collection of numbers.

    Parameters:
    t (list, tuple, or ndarray): Values to summarise; converted to a NumPy
                                 array before the computation.

    Returns:
    float: Standard deviation over all elements of the input.
    """
    return np.array(t).std()


def var(t):
    """
    Population variance of a collection of numbers.

    Parameters:
    t (list, tuple, or ndarray): Values to summarise; converted to a NumPy
                                 array before the computation.

    Returns:
    float: Variance over all elements of the input.
    """
    return np.array(t).var()


def mean(t):
    """
    Arithmetic mean of a collection of numbers.

    Parameters:
    t (list, tuple, or ndarray): Values to average; converted to a NumPy
                                 array before the computation.

    Returns:
    float: Mean over all elements of the input.
    """
    return np.array(t).mean()

def hashgen(l):
    """
    Fold a sequence of values into one xxHash64 digest.

    The elements are fed to a single xxh64 hasher in iteration order, so the
    result depends on both the values and their order.

    Parameters:
    l (list of str): Values to hash, e.g. the properties of a node or edge.

    Returns:
    int: Integer digest of the concatenated input.
    """
    digest = xxhash.xxh64()
    for piece in l:
        digest.update(piece)
    return digest.intdigest()


def cal_pos_edges_loss(link_pred_ratio):
    """
    Score every prediction against a constant "positive" target.

    Each element of the input is passed to the module-level ``criterion``
    together with ``torch.ones(1)`` as its target.

    Parameters:
    link_pred_ratio (list or Tensor): Predicted values, one entry per edge.

    Returns:
    Tensor: One loss value per entry of ``link_pred_ratio``.
    """
    per_edge = [criterion(pred, torch.ones(1)) for pred in link_pred_ratio]
    return torch.tensor(per_edge)

def cal_pos_edges_loss_multiclass(link_pred_ratio,labels):
    """
    Compute one multi-class loss value per edge.

    Every prediction is reshaped to a single-row batch and every label to a
    one-element vector before both are handed to the module-level
    ``criterion``.

    Parameters:
    link_pred_ratio (Tensor or list of Tensors): Per-edge class scores.
    labels (Tensor or list of Tensors): Ground-truth class of each edge,
                                        aligned with ``link_pred_ratio``.

    Returns:
    Tensor: One loss value per edge, in input order.
    """
    per_edge = [
        criterion(scores.reshape(1, -1), labels[k].reshape(-1))
        for k, scores in enumerate(link_pred_ratio)
    ]
    return torch.tensor(per_edge)

In [None]:
def cal_pos_edges_loss_autoencoder(decoded,msg):
    """
    Compute one reconstruction loss value per edge.

    Each decoded output and its original message are flattened and passed to
    the module-level ``criterion``.

    Parameters:
    decoded (Tensor or list of Tensors): Decoder outputs, one per edge.
    msg (Tensor or list of Tensors): Original messages, aligned with ``decoded``.

    Returns:
    Tensor: One loss value per decoded message, in input order.
    """
    per_edge = [
        criterion(output.reshape(-1), msg[k].reshape(-1))
        for k, output in enumerate(decoded)
    ]
    return torch.tensor(per_edge)

In [4]:
from datetime import datetime, timezone
import time
import pytz
from time import mktime
from datetime import datetime
import time
def ns_time_to_datetime(ns):
    """
    Render a nanosecond timestamp in the machine's local timezone.

    :param ns: int nano timestamp
    :return: datetime   format: 2013-10-10 23:40:00.000000000
    """
    whole_seconds, nanos = divmod(int(ns), 1000000000)
    stamp = datetime.fromtimestamp(whole_seconds).strftime('%Y-%m-%d %H:%M:%S')
    # The sub-second remainder is appended as exactly nine digits.
    return stamp + '.' + str(nanos).zfill(9)

def ns_time_to_datetime_US(ns):
    """
    Render a nanosecond timestamp as US/Eastern wall-clock time.

    :param ns: int nano timestamp
    :return: datetime   format: 2013-10-10 23:40:00.000000000
    """
    eastern = pytz.timezone('US/Eastern')
    whole_seconds, nanos = divmod(int(ns), 1000000000)
    local_dt = pytz.datetime.datetime.fromtimestamp(whole_seconds, eastern)
    # The sub-second remainder is appended as exactly nine digits.
    return local_dt.strftime('%Y-%m-%d %H:%M:%S') + '.' + str(nanos).zfill(9)

def time_to_datetime_US(s):
    """
    Render a Unix timestamp given in seconds as US/Eastern wall-clock time.

    :param s: int timestamp in seconds
    :return: datetime   format: 2013-10-10 23:40:00
    """
    eastern = pytz.timezone('US/Eastern')
    local_dt = pytz.datetime.datetime.fromtimestamp(int(s), eastern)
    return local_dt.strftime('%Y-%m-%d %H:%M:%S')

def datetime_to_ns_time(date):
    """
    Convert a wall-clock string, read in the machine's local timezone, to nanoseconds.

    :param date: str   format: %Y-%m-%d %H:%M:%S   e.g. 2013-10-10 23:40:00
    :return: nano timestamp
    """
    parsed = time.strptime(date, "%Y-%m-%d %H:%M:%S")
    epoch_seconds = int(time.mktime(parsed))
    return epoch_seconds * 1000000000

def datetime_to_ns_time_US(date):
    """
    Convert a US/Eastern wall-clock string to a nanosecond Unix timestamp.

    :param date: str   format: %Y-%m-%d %H:%M:%S   e.g. 2013-10-10 23:40:00
    :return: nano timestamp
    """
    tz = pytz.timezone('US/Eastern')
    # Parse straight into a naive datetime. The previous round trip through
    # time.mktime()/datetime.fromtimestamp() interpreted the string in the
    # machine's own timezone first, so the result could shift for times that
    # fall into a DST gap of that local zone.
    naive = datetime.strptime(date, "%Y-%m-%d %H:%M:%S")
    aware = tz.localize(naive)
    # The input has whole-second resolution, so scale with integer arithmetic.
    return int(aware.timestamp()) * 1000000000

def datetime_to_timestamp_US(date):
    """
    Convert a US/Eastern wall-clock string to a Unix timestamp in seconds.

    :param date: str   format: %Y-%m-%d %H:%M:%S   e.g. 2013-10-10 23:40:00
    :return: timestamp in whole seconds
    """
    tz = pytz.timezone('US/Eastern')
    # Parse straight into a naive datetime. The previous round trip through
    # time.mktime()/datetime.fromtimestamp() interpreted the string in the
    # machine's own timezone first, so the result could shift for times that
    # fall into a DST gap of that local zone.
    naive = datetime.strptime(date, "%Y-%m-%d %H:%M:%S")
    aware = tz.localize(naive)
    return int(aware.timestamp())

In [5]:
# Event-type names in id order; ids are assigned starting from 1.
_EVENT_TYPE_NAMES = (
    'EVENT_CONNECT',
    'EVENT_EXECUTE',
    'EVENT_OPEN',
    'EVENT_READ',
    'EVENT_RECVFROM',
    'EVENT_RECVMSG',
    'EVENT_SENDMSG',
    'EVENT_SENDTO',
    'EVENT_WRITE',
)

# Two-way lookup: id -> name and name -> id live in the same dictionary.
rel2id = {
    key: value
    for type_id, type_name in enumerate(_EVENT_TYPE_NAMES, start=1)
    for key, value in ((type_id, type_name), (type_name, type_id))
}

In [6]:
import psycopg2

from psycopg2 import extras as ex
# Connection settings are taken from the standard PostgreSQL environment
# variables when they are set, and fall back to the values this notebook
# has always used.
# NOTE(review): the fallback password is kept only so existing setups keep
# working; prefer exporting PGPASSWORD instead of relying on it.
connect = psycopg2.connect(database = os.environ.get('PGDATABASE', 'tc_theia_dataset_db'),
                           host = os.environ.get('PGHOST', 'localhost'),
                           user = os.environ.get('PGUSER', 'postgres'),
                           password = os.environ.get('PGPASSWORD', '123456'),
                           port = os.environ.get('PGPORT', '5432')
                          )

cur = connect.cursor()

In [None]:
# Build the node lookup table from the node2id relation, ordered by index_id.
sql="select * from node2id ORDER BY index_id;"
cur.execute(sql)
rows = cur.fetchall()

# Every row contributes two entries to the same dictionary:
#   first column (node hash)  -> last column (node index)
#   last column (node index)  -> {second column: third column}
nodeid2msg={}
for record in rows:
    node_index = record[-1]
    nodeid2msg[record[0]] = node_index
    nodeid2msg[node_index] = {record[1]: record[2]}  #0-828297

## Load data

In [8]:
def _load_to_device(graph_path):
    """Deserialize one saved graph with torch.load and move it to `device`."""
    return torch.load(graph_path).to(device=device)


graph_4_3 = _load_to_device("./train_graph/graph_4_3.TemporalData.simple")
graph_4_4 = _load_to_device("./train_graph/graph_4_4.TemporalData.simple")
graph_4_5 = _load_to_device("./train_graph/graph_4_5.TemporalData.simple")
graph_4_9 = _load_to_device("./train_graph/graph_4_9.TemporalData.simple")

graph_4_10 = _load_to_device("./train_graph/graph_4_10.TemporalData.simple")
graph_4_11 = _load_to_device("./train_graph/graph_4_11.TemporalData.simple")
graph_4_12 = _load_to_device("./train_graph/graph_4_12.TemporalData.simple")
graph_4_13 = _load_to_device("./train_graph/graph_4_13.TemporalData.simple")

# Reference graph whose message width sizes the models built further down.
train_data = graph_4_10

# GNN

In [9]:
# Node count of every loaded graph, in loading order.
[
    graph.num_nodes
    for graph in (
        graph_4_3,
        graph_4_4,
        graph_4_5,
        graph_4_9,
        graph_4_10,
        graph_4_11,
        graph_4_12,
        graph_4_13,
    )
]

[828311, 828304, 828187, 746145, 815985, 826308, 826255, 826255]

In [None]:
# Capacity used for the neighbor loader, the TGN memory and the `assoc`
# helper; it exceeds the largest num_nodes reported for the loaded graphs
# (828311).
max_node_num = 828398
min_dst_idx = 0
max_dst_idx = max_node_num

# Neighbor store queried with size=20 for every node lookup.
neighbor_loader = LastNeighborLoader(max_node_num, size=20, device=device)

class GraphAttentionEmbedding(torch.nn.Module):
    """
    Two-layer TransformerConv encoder that turns node states into embeddings.

    Edge attributes are built from a time encoding of the gap between a
    node's last update and the edge timestamp, concatenated with the edge
    message features.

    Parameters:
    - in_channels (int): Input feature dimensionality for nodes.
    - out_channels (int): Output feature dimensionality for nodes.
    - msg_dim (int): Dimensionality of edge message features.
    - time_enc: Time encoding module; must expose ``out_channels``.
    """
    def __init__(self, in_channels, out_channels, msg_dim, time_enc):
        super().__init__()
        self.time_enc = time_enc
        edge_dim = msg_dim + time_enc.out_channels

        # First TransformerConv layer with 8 heads (outputs are concatenated,
        # hence the out_channels*8 input width of the second layer).
        self.conv = TransformerConv(in_channels, out_channels, heads=8,
                                    dropout=0.0, edge_dim=edge_dim)

        # Second TransformerConv layer with 1 head, no concatenation
        self.conv2 = TransformerConv(out_channels*8, out_channels,heads=1, concat=False,
                             dropout=0.0, edge_dim=edge_dim)

    def forward(self, x, last_update, edge_index, t, msg):
        """
        Forward pass of the GraphAttentionEmbedding model.

        Parameters:
        - x (torch.Tensor): Node features of shape (num_nodes, in_channels).
        - last_update (torch.Tensor): Timestamps of last updates for each node.
        - edge_index (torch.Tensor): Edge index in COO format of shape (2, num_edges).
        - t (torch.Tensor): Edge timestamps of shape (num_edges,).
        - msg (torch.Tensor): Edge message features of shape (num_edges, msg_dim).

        Returns:
        - torch.Tensor: Updated node embeddings of shape (num_nodes, out_channels).
        """
        # Tensor.to() returns a new tensor; the result has to be kept, otherwise
        # the move to `device` silently does nothing.
        last_update = last_update.to(device)
        x = x.to(device)
        t = t.to(device)

        # Compute relative time differences and encode them
        rel_t = last_update[edge_index[0]] - t
        rel_t_enc = self.time_enc(rel_t.to(x.dtype))

        # Concatenate relative time encoding and message features as edge attributes
        edge_attr = torch.cat([rel_t_enc, msg], dim=-1)

        # Apply the first TransformerConv layer with ReLU activation
        x = F.relu(self.conv(x, edge_index, edge_attr))

        # Apply the second TransformerConv layer with ReLU activation
        x = F.relu(self.conv2(x, edge_index, edge_attr))
        return x


class LinkPredictor(torch.nn.Module):
    """
    A neural network for event type prediction: it projects the source and
    destination node embeddings, concatenates them and maps the result to
    one score per event type.

    Parameters:
    - in_channels (int): The dimensionality of the input node embeddings.
    - out_channels (int, optional): Number of output scores. When omitted it
      is derived from the module-level ``train_data`` as
      ``train_data.msg.shape[1] - 32``.
    """

    def __init__(self, in_channels, out_channels=None):
        super().__init__()

        if out_channels is None:
            # train_data.msg contains feature vector of src and dest and edge (event) information.
            # train_data.msg.shape[1]-32 is the size of the edge (event) info
            out_channels = train_data.msg.shape[1]-32

        # Linear transformations for source and destination embeddings
        self.lin_src = Linear(in_channels, in_channels*2)
        self.lin_dst = Linear(in_channels, in_channels*2)

        # Sequential feedforward network for link prediction; its input is the
        # concatenation of both projections, hence in_channels*4.
        self.lin_seq = nn.Sequential(
            Linear(in_channels*4, in_channels*8),
            torch.nn.Dropout(0.5),
            nn.Tanh(),
            Linear(in_channels*8, in_channels*2),
            torch.nn.Dropout(0.5),
            nn.Tanh(),
            Linear(in_channels*2, in_channels//2),
            torch.nn.Dropout(0.5),
            nn.Tanh(),

            # Final prediction layer
            Linear(in_channels//2, out_channels)
        )

    def forward(self, z_src, z_dst):
        """
        Forward pass for the LinkPredictor.

        Parameters:
        - z_src (torch.Tensor): Source node embeddings, shape (batch_size, in_channels).
        - z_dst (torch.Tensor): Destination node embeddings, shape (batch_size, in_channels).

        Returns:
        - torch.Tensor: Predicted edge (event) scores, shape (batch_size, out_channels).
        """
        h = torch.cat([self.lin_src(z_src), self.lin_dst(z_dst)], dim=-1)
        return self.lin_seq(h)

memory_dim = 100         # size of the per-node memory state
time_dim = 100           # size of the time encoding
embedding_dim = 100      # output size of the GNN node embeddings

# create the memory
memory = TGNMemory(
    max_node_num,
    train_data.msg.size(-1),
    memory_dim,
    time_dim,
    message_module=IdentityMessage(train_data.msg.size(-1), memory_dim, time_dim),
    aggregator_module=LastAggregator(),
).to(device)

# create the graph neural network with transformer
# (it reuses the memory's time encoder, so those parameters are shared)
gnn = GraphAttentionEmbedding(
    in_channels=memory_dim,
    out_channels=embedding_dim,
    msg_dim=train_data.msg.size(-1),
    time_enc=memory.time_enc,
).to(device)

# create the MLP link predictor
link_pred = LinkPredictor(in_channels=embedding_dim).to(device)

# set Adam optimizer for all 3 networks
# The shared time encoder appears in both memory and gnn, so parameters are
# de-duplicated. dict.fromkeys keeps first-seen order, which gives the
# optimizer a reproducible parameter ordering (a plain set does not).
trainable_params = list(dict.fromkeys(
    [*memory.parameters(), *gnn.parameters(), *link_pred.parameters()]
))
optimizer = torch.optim.Adam(
    trainable_params, lr=0.00005, eps=1e-08,weight_decay=0.01)


# scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.1)
criterion = nn.CrossEntropyLoss()

# Helper vector to map global node indices to local ones.
assoc = torch.empty(max_node_num, dtype=torch.long, device=device)

saved_nodes=set()

In [None]:
# Number of events per sequential batch; read by train() and test_day_new().
BATCH=1024

def train(train_data):
    """
    Trains the Temporal Graph Network (TGN) on one graph, batch by batch.

    The memory and the neighbor loader are reset first, then every batch of
    BATCH events is used to predict the event type of each edge from the
    embeddings of its two endpoints. Works on the module-level ``memory``,
    ``gnn``, ``link_pred``, ``neighbor_loader``, ``optimizer``, ``criterion``
    and ``assoc`` objects.

    Parameters:
    - train_data: Dataset object with sequential batches, timestamps, and messages.

    Returns:
    - float: Average loss over all events in the training data.

    Raises:
    - ValueError: If a message has no entry equal to 1 in its event-type slice.
    """

    # Set networks to training mode
    memory.train()
    gnn.train()
    link_pred.train()

    memory.reset_state()  # Start with a fresh memory.
    neighbor_loader.reset_state()  # Start with an empty graph.

    # Tracks total loss over the training data
    total_loss = 0

    # Process each batch in the training dataset
    for batch in train_data.seq_batches(batch_size=BATCH):
        # Reset gradients
        optimizer.zero_grad()

        # Extract batch data: source, destination, timestamps, and messages
        src, pos_dst, t, msg = batch.src, batch.dst, batch.t, batch.msg

        # Retrieve unique nodes involved in the batch and their neighbors
        n_id = torch.cat([src, pos_dst]).unique()
        n_id, edge_index, e_id = neighbor_loader(n_id)

        # Helper vector to map global node indices to local ones.
        assoc[n_id] = torch.arange(n_id.size(0), device=device)

        # Get updated memory of all nodes involved in the computation.
        z, last_update = memory(n_id)

        # Pass embeddings through the GNN to update them
        z = gnn(z, last_update, edge_index, train_data.t[e_id], train_data.msg[e_id])

        # Compute predictions for positive edges
        y_pred = link_pred(z[assoc[src]], z[assoc[pos_dst]])

        # Ground-truth labels: msg holds the src features, the event info and
        # the dst features; the event info sits between the first and last
        # 16 columns. The label is the position of the first 1 in that slice,
        # computed for the whole batch at once instead of one host copy per
        # event.
        is_event = msg[:, 16:-16] == 1
        if not bool(is_event.any(dim=1).all()):
            raise ValueError("every message needs an entry equal to 1 in msg[16:-16]")
        y_true = is_event.int().argmax(dim=1).to(device=device)

        # Compute loss using predictions and ground-truth labels
        loss = criterion(y_pred, y_true)

        # Update memory and neighbor loader with ground-truth state.
        memory.update_state(src, pos_dst, t, msg)
        neighbor_loader.insert(src, pos_dst)

        # Backpropagation
        loss.backward()
        optimizer.step()

        # Detach memory to free computation graph
        memory.detach()

        # Accumulate total loss scaled by the number of events in the batch
        total_loss += float(loss) * batch.num_events

    # Return the average loss over all events
    return total_loss / train_data.num_events

## Start to train

In [None]:
# train on benign graphs
train_graphs=[
    graph_4_3,
    graph_4_4, 
    graph_4_5,
#     graph_4_9
]

# train for 50 epochs; every epoch visits each training graph once
for epoch in tqdm(range(1, 51)):
    for g in train_graphs:
        loss = train(g)
        print(f'  Epoch: {epoch:02d}, Loss: {loss:.4f}')

# store the models in file
model=[memory,gnn, link_pred,neighbor_loader]
# Create the output directory without shelling out (portable, no error if it exists).
os.makedirs("./models/", exist_ok=True)
torch.save(model,"./models/model_saved_emb100_BATCH_1024_LastAggregator_multiclass_without_neg_edge.pt")

  0%|          | 0/50 [00:00<?, ?it/s]

  Epoch: 01, Loss: 0.6451
  Epoch: 01, Loss: 0.2991


  2%|▏         | 1/50 [25:03<20:27:52, 1503.52s/it]

  Epoch: 01, Loss: 0.3362
  Epoch: 02, Loss: 0.3019
  Epoch: 02, Loss: 0.2715


  4%|▍         | 2/50 [53:18<21:32:58, 1616.22s/it]

  Epoch: 02, Loss: 0.3097
  Epoch: 03, Loss: 0.2910
  Epoch: 03, Loss: 0.2653


  6%|▌         | 3/50 [1:25:27<22:57:43, 1758.79s/it]

  Epoch: 03, Loss: 0.3036
  Epoch: 04, Loss: 0.2856
  Epoch: 04, Loss: 0.2647


  8%|▊         | 4/50 [2:01:01<24:22:03, 1907.03s/it]

  Epoch: 04, Loss: 0.2971
  Epoch: 05, Loss: 0.2862
  Epoch: 05, Loss: 0.2599


 10%|█         | 5/50 [2:37:32<25:07:12, 2009.60s/it]

  Epoch: 05, Loss: 0.2933
  Epoch: 06, Loss: 0.2832
  Epoch: 06, Loss: 0.2599


 12%|█▏        | 6/50 [3:11:29<24:40:32, 2018.91s/it]

  Epoch: 06, Loss: 0.2921
  Epoch: 07, Loss: 0.2834
  Epoch: 07, Loss: 0.2567


 14%|█▍        | 7/50 [3:43:53<23:49:08, 1994.16s/it]

  Epoch: 07, Loss: 0.2918
  Epoch: 08, Loss: 0.2811
  Epoch: 08, Loss: 0.2564


 16%|█▌        | 8/50 [4:15:07<22:49:09, 1955.93s/it]

  Epoch: 08, Loss: 0.2940
  Epoch: 09, Loss: 0.2810
  Epoch: 09, Loss: 0.2550


 18%|█▊        | 9/50 [4:49:00<22:33:03, 1980.08s/it]

  Epoch: 09, Loss: 0.2903
  Epoch: 10, Loss: 0.2789
  Epoch: 10, Loss: 0.2548


 20%|██        | 10/50 [5:20:51<21:45:48, 1958.70s/it]

  Epoch: 10, Loss: 0.2896
  Epoch: 11, Loss: 0.2777
  Epoch: 11, Loss: 0.2528


 22%|██▏       | 11/50 [5:54:28<21:24:45, 1976.54s/it]

  Epoch: 11, Loss: 0.2885
  Epoch: 12, Loss: 0.2777
  Epoch: 12, Loss: 0.2529


 24%|██▍       | 12/50 [6:31:27<21:38:32, 2050.33s/it]

  Epoch: 12, Loss: 0.2883
  Epoch: 13, Loss: 0.2769
  Epoch: 13, Loss: 0.2508


 26%|██▌       | 13/50 [7:08:09<21:32:46, 2096.40s/it]

  Epoch: 13, Loss: 0.2879
  Epoch: 14, Loss: 0.2756
  Epoch: 14, Loss: 0.2508


 28%|██▊       | 14/50 [7:42:25<20:50:26, 2084.08s/it]

  Epoch: 14, Loss: 0.2877
  Epoch: 15, Loss: 0.2763
  Epoch: 15, Loss: 0.2498


 30%|███       | 15/50 [8:19:22<20:39:05, 2124.16s/it]

  Epoch: 15, Loss: 0.2866
  Epoch: 16, Loss: 0.2771
  Epoch: 16, Loss: 0.2510


 32%|███▏      | 16/50 [8:54:52<20:04:47, 2126.11s/it]

  Epoch: 16, Loss: 0.2857
  Epoch: 17, Loss: 0.2754
  Epoch: 17, Loss: 0.2527


 34%|███▍      | 17/50 [9:31:01<19:36:23, 2138.91s/it]

  Epoch: 17, Loss: 0.2874
  Epoch: 18, Loss: 0.2756
  Epoch: 18, Loss: 0.2488


 36%|███▌      | 18/50 [10:04:44<18:42:14, 2104.20s/it]

  Epoch: 18, Loss: 0.2843
  Epoch: 19, Loss: 0.2771
  Epoch: 19, Loss: 0.2510


 38%|███▊      | 19/50 [10:39:12<18:01:28, 2093.17s/it]

  Epoch: 19, Loss: 0.2861
  Epoch: 20, Loss: 0.2752
  Epoch: 20, Loss: 0.2489


 40%|████      | 20/50 [11:15:10<17:36:23, 2112.78s/it]

  Epoch: 20, Loss: 0.2845
  Epoch: 21, Loss: 0.2772
  Epoch: 21, Loss: 0.2504


 42%|████▏     | 21/50 [11:49:29<16:53:20, 2096.57s/it]

  Epoch: 21, Loss: 0.2879
  Epoch: 22, Loss: 0.2777
  Epoch: 22, Loss: 0.2494


 44%|████▍     | 22/50 [12:23:04<16:06:59, 2072.14s/it]

  Epoch: 22, Loss: 0.2825
  Epoch: 23, Loss: 0.2763
  Epoch: 23, Loss: 0.2468


 46%|████▌     | 23/50 [12:58:13<15:37:22, 2083.06s/it]

  Epoch: 23, Loss: 0.2858
  Epoch: 24, Loss: 0.2772
  Epoch: 24, Loss: 0.2516


 48%|████▊     | 24/50 [13:33:16<15:05:14, 2089.02s/it]

  Epoch: 24, Loss: 0.2815
  Epoch: 25, Loss: 0.2788
  Epoch: 25, Loss: 0.2492


 50%|█████     | 25/50 [14:10:30<14:48:37, 2132.72s/it]

  Epoch: 25, Loss: 0.2775
  Epoch: 26, Loss: 0.2777
  Epoch: 26, Loss: 0.2484


 52%|█████▏    | 26/50 [14:45:23<14:08:14, 2120.62s/it]

  Epoch: 26, Loss: 0.2779
  Epoch: 27, Loss: 0.2763
  Epoch: 27, Loss: 0.2476


 54%|█████▍    | 27/50 [15:24:56<14:01:56, 2196.36s/it]

  Epoch: 27, Loss: 0.2780
  Epoch: 28, Loss: 0.2745
  Epoch: 28, Loss: 0.2481


 56%|█████▌    | 28/50 [15:59:08<13:09:24, 2152.92s/it]

  Epoch: 28, Loss: 0.2769
  Epoch: 29, Loss: 0.2760
  Epoch: 29, Loss: 0.2473


 58%|█████▊    | 29/50 [16:35:17<12:35:16, 2157.95s/it]

  Epoch: 29, Loss: 0.2765
  Epoch: 30, Loss: 0.2759
  Epoch: 30, Loss: 0.2468


 60%|██████    | 30/50 [17:11:01<11:57:52, 2153.64s/it]

  Epoch: 30, Loss: 0.2747
  Epoch: 31, Loss: 0.2756
  Epoch: 31, Loss: 0.2453


 62%|██████▏   | 31/50 [17:45:52<11:16:02, 2134.89s/it]

  Epoch: 31, Loss: 0.2756
  Epoch: 32, Loss: 0.2749
  Epoch: 32, Loss: 0.2459


 64%|██████▍   | 32/50 [18:21:47<10:42:17, 2140.99s/it]

  Epoch: 32, Loss: 0.2727
  Epoch: 33, Loss: 0.2735
  Epoch: 33, Loss: 0.2451


 66%|██████▌   | 33/50 [18:59:54<10:19:01, 2184.79s/it]

  Epoch: 33, Loss: 0.2744
  Epoch: 34, Loss: 0.2759
  Epoch: 34, Loss: 0.2451


 68%|██████▊   | 34/50 [19:37:53<9:50:05, 2212.87s/it] 

  Epoch: 34, Loss: 0.2719
  Epoch: 35, Loss: 0.2752
  Epoch: 35, Loss: 0.2446


 70%|███████   | 35/50 [20:18:34<9:30:24, 2281.60s/it]

  Epoch: 35, Loss: 0.2745
  Epoch: 36, Loss: 0.2747
  Epoch: 36, Loss: 0.2448


 72%|███████▏  | 36/50 [21:00:18<9:07:53, 2348.09s/it]

  Epoch: 36, Loss: 0.2677
  Epoch: 37, Loss: 0.2751
  Epoch: 37, Loss: 0.2452


 74%|███████▍  | 37/50 [21:41:26<8:36:33, 2384.13s/it]

  Epoch: 37, Loss: 0.2746
  Epoch: 38, Loss: 0.2753
  Epoch: 38, Loss: 0.2458


 76%|███████▌  | 38/50 [22:21:04<7:56:27, 2382.26s/it]

  Epoch: 38, Loss: 0.2704
  Epoch: 39, Loss: 0.2754
  Epoch: 39, Loss: 0.2443


 78%|███████▊  | 39/50 [23:01:53<7:20:26, 2402.40s/it]

  Epoch: 39, Loss: 0.2695
  Epoch: 40, Loss: 0.2748
  Epoch: 40, Loss: 0.2415


 80%|████████  | 40/50 [23:45:20<6:50:36, 2463.63s/it]

  Epoch: 40, Loss: 0.2657
  Epoch: 41, Loss: 0.2716
  Epoch: 41, Loss: 0.2399


 82%|████████▏ | 41/50 [24:29:47<6:18:41, 2524.60s/it]

  Epoch: 41, Loss: 0.2663
  Epoch: 42, Loss: 0.2718
  Epoch: 42, Loss: 0.2388


 84%|████████▍ | 42/50 [25:15:11<5:44:35, 2584.49s/it]

  Epoch: 42, Loss: 0.2665
  Epoch: 43, Loss: 0.2705
  Epoch: 43, Loss: 0.2396


 86%|████████▌ | 43/50 [25:56:45<4:58:22, 2557.49s/it]

  Epoch: 43, Loss: 0.2682
  Epoch: 44, Loss: 0.2705
  Epoch: 44, Loss: 0.2382


 88%|████████▊ | 44/50 [26:39:34<4:16:05, 2560.88s/it]

  Epoch: 44, Loss: 0.2663
  Epoch: 45, Loss: 0.2704
  Epoch: 45, Loss: 0.2386


 90%|█████████ | 45/50 [27:23:30<3:35:17, 2583.53s/it]

  Epoch: 45, Loss: 0.2645
  Epoch: 46, Loss: 0.2696
  Epoch: 46, Loss: 0.2387


 92%|█████████▏| 46/50 [28:09:24<2:55:37, 2634.38s/it]

  Epoch: 46, Loss: 0.2627
  Epoch: 47, Loss: 0.2694
  Epoch: 47, Loss: 0.2388


 94%|█████████▍| 47/50 [28:54:03<2:12:23, 2647.84s/it]

  Epoch: 47, Loss: 0.2631
  Epoch: 48, Loss: 0.2696
  Epoch: 48, Loss: 0.2379


 96%|█████████▌| 48/50 [29:38:45<1:28:36, 2658.22s/it]

  Epoch: 48, Loss: 0.2629
  Epoch: 49, Loss: 0.2697
  Epoch: 49, Loss: 0.2376


 98%|█████████▊| 49/50 [30:23:17<44:22, 2662.40s/it]  

  Epoch: 49, Loss: 0.2630
  Epoch: 50, Loss: 0.2688
  Epoch: 50, Loss: 0.2371


100%|██████████| 50/50 [31:05:38<00:00, 2238.77s/it]

  Epoch: 50, Loss: 0.2628





# Test

In [None]:
import time 

# Disable gradient computation for inference
@torch.no_grad()
def test_day_new(inference_data, path, window_ns=60000000000*15):
    """
    Evaluate the Temporal Graph Network (TGN) on inference data.
    Stores result for each time window in txt file.
    Each text file contains event information, one edge per line, sorted by
    descending loss.
    Example: {'loss': 2.8073678016662598, 'srcnode': 76938, 'dstnode': 868, 'srcmsg': "{'subject': 'system_server'}", 'dstmsg': "{'file': '/data/system/sync/pending.xml'}", 'edge_type': 'EVENT_OPEN', 'time': 1522814400030000000}

    Works on the module-level ``memory``, ``gnn``, ``link_pred``,
    ``neighbor_loader`` and ``assoc`` objects; memory and neighbor loader are
    reset at the start.

    Parameters:
    - inference_data: Dataset containing sequential batches for temporal graph inference.
    - path (str): Directory to save logs and results (created if missing).
    - window_ns (int): Minimum length of a time window in nanoseconds
      (default: 15 minutes).

    Returns:
    - dict: Maps "start~end" window labels to a dict with the average edge
      loss of the window ('loss'), the number of distinct nodes seen so far
      ('nodes_count'), the number of edges processed so far ('total_edges')
      and the elapsed seconds since the start ('costed_time').

    Raises:
    - ValueError: If a message has no entry equal to 1 in its event-type slice.
    """

    os.makedirs(path, exist_ok=True)

    # Set networks to evaluation mode
    memory.eval()
    gnn.eval()
    link_pred.eval()

    memory.reset_state()  # Start with a fresh memory.
    neighbor_loader.reset_state()  # Start with an empty graph.

    # Initialize tracking variables
    time_with_loss={}
    edge_list=[]

    # Same integer dtype as the node ids, so concatenation never promotes to float.
    unique_nodes=torch.empty(0, dtype=inference_data.src.dtype, device=device)
    total_edges=0

    start_time=int(inference_data.t[0])

    print("after merge:",inference_data)

    # Record the running time to evaluate the performance
    start = time.perf_counter()

    for batch in inference_data.seq_batches(batch_size=BATCH):

        src, pos_dst, t, msg = batch.src, batch.dst, batch.t, batch.msg
        unique_nodes=torch.cat([unique_nodes,src,pos_dst]).unique()
        # Count the real batch size; the last batch can be shorter than BATCH.
        total_edges+=len(src)

        # Retrieve embeddings and neighbors
        n_id = torch.cat([src, pos_dst]).unique()
        n_id, edge_index, e_id = neighbor_loader(n_id)
        assoc[n_id] = torch.arange(n_id.size(0), device=device)

        # Get memory and update embeddings via GNN
        z, last_update = memory(n_id)
        z = gnn(z, last_update, edge_index, inference_data.t[e_id], inference_data.msg[e_id])

        # Predict edge properties
        pos_out = link_pred(z[assoc[src]], z[assoc[pos_dst]])

        # Ground-truth labels: the event info sits between the first and last
        # 16 columns of msg; the label is the position of the first 1 in that
        # slice, computed for the whole batch at once.
        is_event = msg[:, 16:-16] == 1
        if not bool(is_event.any(dim=1).all()):
            raise ValueError("every message needs an entry equal to 1 in msg[16:-16]")
        y_true = is_event.int().argmax(dim=1).to(device=device)

        # update the edges in the batch to the memory and neighbor_loader
        memory.update_state(src, pos_dst, t, msg)
        neighbor_loader.insert(src, pos_dst)

        # compute the loss for each edge
        # For benign graphs, the behaviors patterns are similar and therefore their losses are small
        # For anomalous behaviors, some behaviors might not be seen before, so the probability of predicting those edges are low. Thus their losses are high.
        each_edge_loss= cal_pos_edges_loss_multiclass(pos_out,y_true)

        # Process edges for logging; tensors are converted to Python lists once
        # per batch instead of once per edge.
        for srcnode, dstnode, t_var, label, edge_loss in zip(
                src.tolist(), pos_dst.tolist(), t.tolist(),
                y_true.tolist(), each_edge_loss.tolist()):
            srcnode=int(srcnode)
            dstnode=int(dstnode)

            temp_dic={}
            temp_dic['loss']=float(edge_loss)
            temp_dic['srcnode']=srcnode
            temp_dic['dstnode']=dstnode
            temp_dic['srcmsg']=str(nodeid2msg[srcnode])
            temp_dic['dstmsg']=str(nodeid2msg[dstnode])
            # rel2id is keyed by 1-based event ids, labels are 0-based.
            temp_dic['edge_type']=rel2id[label+1]
            temp_dic['time']=int(t_var)

            edge_list.append(temp_dic)

        # Check if the time interval is over (window_ns, 15 minutes by default)
        window_end=int(t[-1])
        if window_end>start_time+window_ns:
            # Here is a checkpoint, which records all edge losses in the current time window
            time_interval=ns_time_to_datetime_US(start_time)+"~"+ns_time_to_datetime_US(window_end)

            end = time.perf_counter()

            # Average loss of the window, computed from the logged edges only.
            loss=sum(e['loss'] for e in edge_list)/len(edge_list)

            time_with_loss[time_interval]={'loss':loss,
                                          'nodes_count':len(unique_nodes),
                                          'total_edges':total_edges,
                                          'costed_time':(end-start)}

            print(f'Time: {time_interval}, Loss: {loss:.4f}, Nodes_count: {len(unique_nodes)}, Cost Time: {(end-start):.2f}s')

            # Rank the results based on edge losses and save them to the log file
            edge_list.sort(key=lambda x:x['loss'],reverse=True)
            with open(path+"/"+time_interval+".txt",'w') as log:
                for e in edge_list:
                    log.write(str(e))
                    log.write("\n")

            # Reset tracking variables for the next interval
            start_time=window_end
            edge_list.clear()

    # NOTE(review): edges collected after the last completed window are not
    # written to a file or to the returned summary (unchanged behaviour).
    return time_with_loss

# Test 4-9 ~ 4-12

In [None]:
# Restore the trained components from disk so the evaluation starts from the saved weights.
model = torch.load(
    "./models/model_saved_emb100_BATCH_1024_LastAggregator_multiclass_without_neg_edge.pt",
    map_location=device,
)
memory, gnn, link_pred, neighbor_loader = model

# Test the data for 2018-04-03, create and store time window with loss
ans_4_3 = test_day_new(inference_data=graph_4_3, path="graph_4_3")

after merge: TemporalData(dst=[8230837], msg=[8230837, 41], src=[8230837], t=[8230837])
Time: 2018-04-03 10:02:14.348882872~2018-04-03 10:17:22.845633004, Loss: 0.5784, Nodes_count: 2021, Cost Time: 4.19s
Time: 2018-04-03 10:17:22.845633004~2018-04-03 10:32:52.909926097, Loss: 0.3849, Nodes_count: 5070, Cost Time: 11.01s
Time: 2018-04-03 10:32:52.909926097~2018-04-03 10:47:54.761593566, Loss: 0.2185, Nodes_count: 9219, Cost Time: 23.90s
Time: 2018-04-03 10:47:54.761593566~2018-04-03 11:03:00.878438338, Loss: 0.2021, Nodes_count: 13173, Cost Time: 37.13s
Time: 2018-04-03 11:03:00.878438338~2018-04-03 11:18:31.001956600, Loss: 0.2005, Nodes_count: 17464, Cost Time: 55.36s
Time: 2018-04-03 11:18:31.001956600~2018-04-03 11:33:31.167232690, Loss: 0.2882, Nodes_count: 20766, Cost Time: 74.99s
Time: 2018-04-03 11:33:31.167232690~2018-04-03 11:48:41.419633970, Loss: 0.2392, Nodes_count: 24304, Cost Time: 93.79s
Time: 2018-04-03 11:48:41.419633970~2018-04-03 12:03:45.030247497, Loss: 0.2051, No

In [15]:
# Reload the checkpoint before scoring 2018-04-04.
# Presumably this resets memory / neighbor_loader state left over from the
# previous day's run -- TODO confirm against test_day_new.
model=torch.load("./models/model_saved_emb100_BATCH_1024_LastAggregator_multiclass_without_neg_edge.pt",map_location=device)
memory,gnn, link_pred,neighbor_loader=model
ans_4_4=test_day_new(graph_4_4,"graph_4_4")

after merge: TemporalData(dst=[4930304], msg=[4930304, 41], src=[4930304], t=[4930304])
Time: 2018-04-04 00:00:00.001798512~2018-04-04 00:15:02.197293670, Loss: 0.4715, Nodes_count: 363, Cost Time: 0.92s
Time: 2018-04-04 00:15:02.197293670~2018-04-04 00:30:28.978098397, Loss: 0.1794, Nodes_count: 658, Cost Time: 1.96s
Time: 2018-04-04 00:30:28.978098397~2018-04-04 00:45:32.001547272, Loss: 0.1701, Nodes_count: 987, Cost Time: 2.99s
Time: 2018-04-04 00:45:32.001547272~2018-04-04 01:01:12.511397075, Loss: 0.1535, Nodes_count: 1228, Cost Time: 4.02s
Time: 2018-04-04 01:01:12.511397075~2018-04-04 01:16:50.001162895, Loss: 0.1579, Nodes_count: 1488, Cost Time: 5.10s
Time: 2018-04-04 01:16:50.001162895~2018-04-04 01:32:05.001431940, Loss: 0.1720, Nodes_count: 1751, Cost Time: 6.16s
Time: 2018-04-04 01:32:05.001431940~2018-04-04 01:47:27.933058106, Loss: 0.1663, Nodes_count: 2030, Cost Time: 7.22s
Time: 2018-04-04 01:47:27.933058106~2018-04-04 02:02:58.001409515, Loss: 0.1624, Nodes_count: 22

In [16]:
# Reload the checkpoint before scoring 2018-04-05.
# Presumably this resets memory / neighbor_loader state left over from the
# previous day's run -- TODO confirm against test_day_new.
model=torch.load("./models/model_saved_emb100_BATCH_1024_LastAggregator_multiclass_without_neg_edge.pt",map_location=device)
memory,gnn, link_pred,neighbor_loader=model
ans_4_5=test_day_new(graph_4_5,"graph_4_5")

after merge: TemporalData(dst=[1489011], msg=[1489011, 41], src=[1489011], t=[1489011])
Time: 2018-04-05 00:00:00.001766694~2018-04-05 00:16:00.002476684, Loss: 0.5540, Nodes_count: 346, Cost Time: 0.87s
Time: 2018-04-05 00:16:00.002476684~2018-04-05 00:31:52.143639968, Loss: 0.2162, Nodes_count: 650, Cost Time: 1.84s
Time: 2018-04-05 00:31:52.143639968~2018-04-05 00:47:52.002321096, Loss: 0.2189, Nodes_count: 937, Cost Time: 2.81s
Time: 2018-04-05 00:47:52.002321096~2018-04-05 01:03:35.001510674, Loss: 0.2135, Nodes_count: 1217, Cost Time: 3.78s
Time: 2018-04-05 01:03:35.001510674~2018-04-05 01:19:09.002324977, Loss: 0.2047, Nodes_count: 1489, Cost Time: 4.76s
Time: 2018-04-05 01:19:09.002324977~2018-04-05 01:35:25.002311363, Loss: 0.2005, Nodes_count: 1746, Cost Time: 5.73s
Time: 2018-04-05 01:35:25.002311363~2018-04-05 01:51:29.002313221, Loss: 0.2080, Nodes_count: 2011, Cost Time: 6.71s
Time: 2018-04-05 01:51:29.002313221~2018-04-05 02:07:14.002084758, Loss: 0.2007, Nodes_count: 22

In [17]:
# Reload the checkpoint before scoring 2018-04-09.
# Presumably this resets memory / neighbor_loader state left over from the
# previous day's run -- TODO confirm against test_day_new.
model=torch.load("./models/model_saved_emb100_BATCH_1024_LastAggregator_multiclass_without_neg_edge.pt",map_location=device)
memory,gnn, link_pred,neighbor_loader=model
ans_4_9=test_day_new(graph_4_9,"graph_4_9")

after merge: TemporalData(dst=[685635], msg=[685635, 41], src=[685635], t=[685635])
Time: 2018-04-09 08:46:55.004764124~2018-04-09 09:03:31.001287346, Loss: 2.8597, Nodes_count: 27, Cost Time: 0.06s
Time: 2018-04-09 09:03:31.001287346~2018-04-09 09:20:23.001295997, Loss: 0.1001, Nodes_count: 67, Cost Time: 0.47s
Time: 2018-04-09 09:20:23.001295997~2018-04-09 09:36:02.001305059, Loss: 0.1480, Nodes_count: 100, Cost Time: 0.77s
Time: 2018-04-09 09:36:02.001305059~2018-04-09 09:51:31.358485271, Loss: 0.4357, Nodes_count: 5561, Cost Time: 2.98s
Time: 2018-04-09 09:51:31.358485271~2018-04-09 10:06:48.907525397, Loss: 0.4159, Nodes_count: 6439, Cost Time: 7.21s
Time: 2018-04-09 10:06:48.907525397~2018-04-09 10:23:05.001376520, Loss: 0.3307, Nodes_count: 13241, Cost Time: 14.88s
Time: 2018-04-09 10:23:05.001376520~2018-04-09 10:39:33.001366978, Loss: 0.3014, Nodes_count: 13594, Cost Time: 17.01s
Time: 2018-04-09 10:39:33.001366978~2018-04-09 10:58:03.657948688, Loss: 0.2514, Nodes_count: 1384

In [18]:
# Reload the checkpoint before scoring 2018-04-10.
# Presumably this resets memory / neighbor_loader state left over from the
# previous day's run -- TODO confirm against test_day_new.
model=torch.load("./models/model_saved_emb100_BATCH_1024_LastAggregator_multiclass_without_neg_edge.pt",map_location=device)
memory,gnn, link_pred,neighbor_loader=model
ans_4_10=test_day_new(graph_4_10,"graph_4_10")

after merge: TemporalData(dst=[6274151], msg=[6274151, 41], src=[6274151], t=[6274151])
Time: 2018-04-10 12:44:33.449564893~2018-04-10 13:00:02.700560774, Loss: 0.5550, Nodes_count: 701, Cost Time: 0.39s
Time: 2018-04-10 13:00:02.700560774~2018-04-10 13:16:13.551944728, Loss: 0.1340, Nodes_count: 1373, Cost Time: 8.39s
Time: 2018-04-10 13:16:13.551944728~2018-04-10 13:31:14.548738409, Loss: 0.3707, Nodes_count: 6803, Cost Time: 19.51s
Time: 2018-04-10 13:31:14.548738409~2018-04-10 13:46:36.161065223, Loss: 0.3934, Nodes_count: 11267, Cost Time: 40.40s
Time: 2018-04-10 13:46:36.161065223~2018-04-10 14:02:17.001271389, Loss: 0.1005, Nodes_count: 11336, Cost Time: 42.41s
Time: 2018-04-10 14:02:17.001271389~2018-04-10 14:17:34.001373488, Loss: 0.2240, Nodes_count: 11950, Cost Time: 45.23s
Time: 2018-04-10 14:17:34.001373488~2018-04-10 14:33:18.350772859, Loss: 0.0153, Nodes_count: 13208, Cost Time: 214.04s
Time: 2018-04-10 14:33:18.350772859~2018-04-10 14:48:47.320442910, Loss: 0.3234, Nod

In [19]:
# Reload the checkpoint before scoring 2018-04-11.
# Presumably this resets memory / neighbor_loader state left over from the
# previous day's run -- TODO confirm against test_day_new.
model=torch.load("./models/model_saved_emb100_BATCH_1024_LastAggregator_multiclass_without_neg_edge.pt",map_location=device)
memory,gnn, link_pred,neighbor_loader=model
ans_4_11=test_day_new(graph_4_11,"graph_4_11")

after merge: TemporalData(dst=[7285220], msg=[7285220, 41], src=[7285220], t=[7285220])
Time: 2018-04-11 00:00:00.001151329~2018-04-11 00:16:06.001274623, Loss: 0.4556, Nodes_count: 117, Cost Time: 0.73s
Time: 2018-04-11 00:16:06.001274623~2018-04-11 00:32:08.001169192, Loss: 0.1016, Nodes_count: 257, Cost Time: 1.55s
Time: 2018-04-11 00:32:08.001169192~2018-04-11 00:48:31.001594545, Loss: 0.1004, Nodes_count: 336, Cost Time: 2.37s
Time: 2018-04-11 00:48:31.001594545~2018-04-11 01:04:53.001888013, Loss: 0.0929, Nodes_count: 413, Cost Time: 3.20s
Time: 2018-04-11 01:04:53.001888013~2018-04-11 01:20:59.002307553, Loss: 0.0857, Nodes_count: 492, Cost Time: 4.02s
Time: 2018-04-11 01:20:59.002307553~2018-04-11 01:37:06.826246565, Loss: 0.0996, Nodes_count: 575, Cost Time: 4.85s
Time: 2018-04-11 01:37:06.826246565~2018-04-11 01:53:33.001963923, Loss: 0.0826, Nodes_count: 649, Cost Time: 5.66s
Time: 2018-04-11 01:53:33.001963923~2018-04-11 02:10:00.001194307, Loss: 0.0807, Nodes_count: 724, C

In [20]:
# Reload the checkpoint before scoring 2018-04-12.
# Presumably this resets memory / neighbor_loader state left over from the
# previous day's run -- TODO confirm against test_day_new.
model=torch.load("./models/model_saved_emb100_BATCH_1024_LastAggregator_multiclass_without_neg_edge.pt",map_location=device)
memory,gnn, link_pred,neighbor_loader=model
ans_4_12=test_day_new(graph_4_12,"graph_4_12")

after merge: TemporalData(dst=[7024937], msg=[7024937, 41], src=[7024937], t=[7024937])
Time: 2018-04-12 00:00:00.001773757~2018-04-12 00:15:26.001647081, Loss: 0.5624, Nodes_count: 120, Cost Time: 0.59s
Time: 2018-04-12 00:15:26.001647081~2018-04-12 00:30:58.002002412, Loss: 0.1480, Nodes_count: 252, Cost Time: 1.26s
Time: 2018-04-12 00:30:58.002002412~2018-04-12 00:46:27.001835122, Loss: 0.1243, Nodes_count: 330, Cost Time: 1.93s
Time: 2018-04-12 00:46:27.001835122~2018-04-12 01:02:05.698953597, Loss: 0.1239, Nodes_count: 402, Cost Time: 2.60s
Time: 2018-04-12 01:02:05.698953597~2018-04-12 01:17:27.002277528, Loss: 0.1177, Nodes_count: 477, Cost Time: 3.28s
Time: 2018-04-12 01:17:27.002277528~2018-04-12 01:32:48.001732365, Loss: 0.1406, Nodes_count: 594, Cost Time: 3.95s
Time: 2018-04-12 01:32:48.001732365~2018-04-12 01:48:20.001593370, Loss: 0.1188, Nodes_count: 671, Cost Time: 4.63s
Time: 2018-04-12 01:48:20.001593370~2018-04-12 02:04:08.001895394, Loss: 0.1234, Nodes_count: 742, C

## Compute anomalous score and initialize the node IDF

Calculates the Inverse Document Frequency (IDF) of each node.

IDF_of_node_x = log(number_of_time_windows / (1 + number_of_time_windows_that_contain_node_x))

In [None]:
def cal_train_IDF(find_str,file_list):
    """
    Calculate the IDF value for a term in a set of time windows.

    Each file is read in full and tested with a plain substring match, so the
    IDF is ``log(len(file_list) / (1 + number_of_files_containing_find_str))``.

    Parameters:
        find_str (str): The term to calculate the IDF for.
        file_list (list): List of file paths representing the time windows.

    Returns:
        float: The computed IDF value.

    Raises:
        ValueError: If ``file_list`` is empty (``math.log(0)``).
        OSError: If one of the files cannot be opened.
    """
    include_count=0
    for f_path in file_list:
        # The context manager closes every handle as soon as it has been read;
        # the loop may touch hundreds of window files.
        with open(f_path) as f:
            if find_str in f.read():
                include_count+=1
    IDF=math.log(len(file_list)/(include_count+1))
    return IDF


def cal_IDF(find_str,file_path,file_list):
    """
    Calculate the IDF of a term over every file found in a directory.

    Parameters:
        find_str (str): The term to look for (plain substring match).
        file_path (str): Directory holding the time-window files. A trailing
            separator is accepted but not required.
        file_list: Ignored. The directory listing of ``file_path`` is always
            used instead; the parameter is kept so existing callers keep working.

    Returns:
        tuple: ``(IDF, 1)`` where
            ``IDF = log(number_of_files / (1 + number_of_files_containing_find_str))``.
            The second element is the constant 1.

    Raises:
        ValueError: If the directory is empty (``math.log(0)``).
        OSError: If the directory or one of its entries cannot be read.
    """
    window_files=os.listdir(file_path)
    include_count=0
    for f_name in window_files:
        # os.path.join gives the same path as the previous string concatenation
        # when file_path ends with a separator, and also works when it does not.
        with open(os.path.join(file_path,f_name)) as f:
            if find_str in f.read():
                include_count+=1
    IDF=math.log(len(window_files)/(include_count+1))

    return IDF,1

def cal_IDF_by_file_in_mem(find_str,file_list,base_dir=None):
    """
    Calculate the IDF of a term over an explicit list of file names.

    Parameters:
        find_str (str): The term to look for (plain substring match).
        file_list (list): File names, relative to the base directory.
        base_dir (str, optional): Directory the names in ``file_list`` are
            relative to. When omitted, the module-level ``file_path`` variable
            is used, which is what this function relied on implicitly before.

    Returns:
        tuple: ``(IDF, 1)`` where
            ``IDF = log(len(file_list) / (1 + number_of_files_containing_find_str))``.
            The second element is the constant 1.

    Raises:
        NameError: If ``base_dir`` is omitted and no global ``file_path`` exists.
        ValueError: If ``file_list`` is empty (``math.log(0)``).
        OSError: If one of the files cannot be opened.
    """
    # Fall back to the notebook-level global only when the caller did not say
    # where the files live.
    prefix=file_path if base_dir is None else base_dir
    include_count=0
    for f_name in file_list:
        with open(os.path.join(prefix,f_name)) as f:
            if find_str in f.read():
                include_count+=1
    IDF=math.log(len(file_list)/(include_count+1))

    return IDF,1

def cal_redundant(find_str,edge_list):
    """
    Count the distinct endpoints of all edges that mention a term, minus two.

    An edge "mentions" the term when ``find_str`` is a substring of
    ``str(edge)``. For each such edge its first two items are collected.

    Parameters:
        find_str (str): The term to look for.
        edge_list (iterable): Edges; each must support ``edge[0]`` and ``edge[1]``.

    Returns:
        int: ``number_of_distinct_endpoints - 2`` (so -2 when nothing matches).
    """
    matched_endpoints={
        endpoint
        for edge in edge_list
        if find_str in str(edge)
        for endpoint in (edge[0],edge[1])
    }
    return len(matched_endpoints)-2

def cal_anomaly_loss(loss_list,edge_list,file_path):
    """
    Calculate anomaly loss by analyzing loss values exceeding a threshold.

    The threshold is ``mean(loss_list) + 1.5 * std(loss_list)``; it is printed
    as ``thr: <value>``. ``mean`` and ``std`` are helpers defined elsewhere in
    this notebook.

    Parameters:
        loss_list (list): List of loss values for each edge.
        edge_list (list): List of edges corresponding to the loss values. Each
            edge is indexed as ``edge[0]`` (source) and ``edge[1]`` (destination);
            the two are concatenated with ``+`` to build the edge key.
        file_path (str): Path to data files (currently unused).

    Returns:
        tuple: A tuple containing:
            - count (int): Number of anomalies detected.
            - avg_loss (float): Average loss of anomalous edges
              (``loss_sum / (count + 1e-11)``, hence 0.0 when count is 0).
            - node_set (set): Set of unique nodes involved in anomalies.
            - edge_set (set): Set of unique ``src + dst`` keys of anomalous edges.

    Raises:
        ValueError: If ``loss_list`` and ``edge_list`` differ in length. The
            function used to print "error!" and return the scalar 0, which every
            caller then failed to unpack into four values.
    """
    if len(loss_list)!=len(edge_list):
        raise ValueError(
            "loss_list and edge_list must have the same length "
            f"(got {len(loss_list)} and {len(edge_list)})"
        )

    loss_std=std(loss_list)
    loss_mean=mean(loss_list)
    thr=loss_mean+1.5*loss_std

    print("thr:",thr)

    count=0
    loss_sum=0
    edge_set=set()
    node_set=set()
    for edge_loss,edge in zip(loss_list,edge_list):
        if edge_loss>thr:
            count+=1
            loss_sum+=edge_loss

            src_node=edge[0]
            dst_node=edge[1]
            node_set.add(src_node)
            node_set.add(dst_node)
            edge_set.add(src_node+dst_node)
    # The tiny epsilon avoids a division by zero when no edge exceeds thr.
    return count, loss_sum/(count+0.00000000001),node_set,edge_set

In [None]:
# List of all time-window files. Time windows are generated and stored as files by test_day_new()
file_list=[]

# Stores the IDF for each node
node_IDF={}

# Collect the window files of the three days 2018-04-03 .. 2018-04-05
for file_path in ("graph_4_3/","graph_4_4/","graph_4_5/"):
    file_l=os.listdir(file_path)
    for i in file_l:
        file_list.append(file_path+i)

# Maps each node (string form of srcmsg / dstmsg) to the set of window files it occurs in
node_set = {}
for f_path in tqdm(file_list):
    # Close each window file as soon as it has been scanned
    with open(f_path) as f:
        for line in f:
            l=line.strip()
            # NOTE(review): eval executes whatever the file contains. Every line is
            # expected to be the str() of an edge dict written by test_day_new();
            # do not run this on files from an untrusted source.
            jdata=eval(l)
            if jdata['loss']>0:
                for msg_key in ('srcmsg','dstmsg'):
                    node_set.setdefault(str(jdata[msg_key]),set()).add(f_path)

# IDF = log(number_of_windows / (1 + number_of_windows_containing_the_node))
for n in node_set:
    include_count = len(node_set[n])
    IDF=math.log(len(file_list)/(include_count+1))
    node_IDF[n] = IDF


torch.save(node_IDF,"node_IDF_4_3-5")
print("IDF weight calculate complete!")

100%|██████████| 194/194 [04:34<00:00,  1.41s/it]

IDF weight calculate complete!





In [23]:
# Stores the IDF for each node seen in the 2018-04-10 time windows
node_IDF={}

file_list=[]

file_path="graph_4_10/"
file_l=os.listdir("graph_4_10/")
for i in file_l:
    file_list.append(file_path+i)


# Maps each node (string form of srcmsg / dstmsg) to the set of window files it occurs in
node_set = {}
for f_path in tqdm(file_list):
    # Close each window file as soon as it has been scanned
    with open(f_path) as f:
        for line in f:
            l=line.strip()
            # NOTE(review): eval executes whatever the file contains. Every line is
            # expected to be the str() of an edge dict written by test_day_new();
            # do not run this on files from an untrusted source.
            jdata=eval(l)
            if jdata['loss']>0:
                for msg_key in ('srcmsg','dstmsg'):
                    node_set.setdefault(str(jdata[msg_key]),set()).add(f_path)

# IDF = log(number_of_windows / (1 + number_of_windows_containing_the_node))
for n in node_set:
    include_count = len(node_set[n])
    IDF=math.log(len(file_list)/(include_count+1))
    node_IDF[n] = IDF


torch.save(node_IDF,"node_IDF_4_10")
print("IDF weight calculate complete!")

100%|██████████| 43/43 [01:55<00:00,  2.69s/it]

IDF weight calculate complete!





In [None]:
# Stores the IDF for each node seen in the 2018-04-11 time windows
node_IDF={}

file_list=[]

file_path="graph_4_11/"
file_l=os.listdir("graph_4_11/")
for i in file_l:
    file_list.append(file_path+i)

# Maps each node (string form of srcmsg / dstmsg) to the set of window files it occurs in
node_set = {}
for f_path in tqdm(file_list):
    # Close each window file as soon as it has been scanned
    with open(f_path) as f:
        for line in f:
            l=line.strip()
            # NOTE(review): eval executes whatever the file contains. Every line is
            # expected to be the str() of an edge dict written by test_day_new();
            # do not run this on files from an untrusted source.
            jdata=eval(l)
            if jdata['loss']>0:
                for msg_key in ('srcmsg','dstmsg'):
                    node_set.setdefault(str(jdata[msg_key]),set()).add(f_path)

# IDF = log(number_of_windows / (1 + number_of_windows_containing_the_node))
for n in node_set:
    include_count = len(node_set[n])
    IDF=math.log(len(file_list)/(include_count+1))
    node_IDF[n] = IDF


torch.save(node_IDF,"node_IDF_4_11")
print("IDF weight calculate complete!")

100%|██████████| 91/91 [02:16<00:00,  1.50s/it]


IDF weight calculate complete!


In [25]:
# Stores the IDF for each node seen in the 2018-04-12 time windows
node_IDF={}

file_list=[]

file_path="graph_4_12/"
file_l=os.listdir("graph_4_12/")
for i in file_l:
    file_list.append(file_path+i)


# Maps each node (string form of srcmsg / dstmsg) to the set of window files it occurs in
node_set = {}
for f_path in tqdm(file_list):
    # Close each window file as soon as it has been scanned
    with open(f_path) as f:
        for line in f:
            l=line.strip()
            # NOTE(review): eval executes whatever the file contains. Every line is
            # expected to be the str() of an edge dict written by test_day_new();
            # do not run this on files from an untrusted source.
            jdata=eval(l)
            if jdata['loss']>0:
                for msg_key in ('srcmsg','dstmsg'):
                    node_set.setdefault(str(jdata[msg_key]),set()).add(f_path)

# IDF = log(number_of_windows / (1 + number_of_windows_containing_the_node))
for n in node_set:
    include_count = len(node_set[n])
    IDF=math.log(len(file_list)/(include_count+1))
    node_IDF[n] = IDF


torch.save(node_IDF,"node_IDF_4_12")
print("IDF weight calculate complete!")

100%|██████████| 93/93 [02:09<00:00,  1.39s/it]


IDF weight calculate complete!


# Construct the relations between time windows

In [None]:
# 4-10,11
def is_include_key_word_bak(s):
    """
    Tell whether a string contains at least one of the filter keywords.

    Parameters:
        s (str): The string to inspect.

    Returns:
        bool: True when any keyword is a substring of ``s``, otherwise False.
    """
    filter_fragments=(
        'netflow',
        'null',
        '/dev/pts',
        'salt-minion.log',
        '675',
        'usr',
        'proc',
        '/.cache/mozilla/',
        'tmp',
        'thunderbird',
        '/bin/',
        '/sbin/sysctl',
        '/data/replay_logdb/',
        '/home/admin/eraseme',
        '/stat',
    )
    return any(fragment in s for fragment in filter_fragments)


def cal_set_rel_bak(s1,s2,file_list):
    """
    Count the rare nodes shared by two node sets.

    A shared node is counted when it does not contain a filter keyword
    (see ``is_include_key_word_bak``) and its IDF is greater than
    ``log(0.9 * len(file_list))``. The IDF is read from the module-level
    ``node_IDF`` dictionary; a node missing from it gets ``log(len(file_list))``.
    Every counted node is printed.

    Parameters:
        s1 (set): The first set of nodes.
        s2 (set): The second set of nodes.
        file_list (list): Time-window files; only its length is used.

    Returns:
        int: Number of shared nodes that pass the keyword filter and the IDF threshold.
    """
    rare_shared=0
    for node in s1 & s2:
        # Nodes matching a filter keyword never contribute
        if is_include_key_word_bak(node) is True:
            continue

        # Known nodes use their stored IDF, unknown ones the maximum possible value
        if node in node_IDF:
            node_idf=node_IDF[node]
        else:
            node_idf=math.log(len(file_list)/(1))

        if node_idf>math.log(len(file_list)*0.9/(1)):
            print("node:",node," IDF:",node_idf)
            rare_shared+=1
    return rare_shared

In [27]:
# 4-12

# def is_include_key_word_bak(s):
#     keywords=[
#          'netflow',        
#         '/dev/pts',
#         'salt-minion.log',
#         'null',
#         'usr',
#          'proc',
#         'firefox',
#         'tmp',
#         'thunderbird',
#         'bin/',
#         '/data/replay_logdb',
#         '/stat',
#         '/boot',
#         'qt-opensource-linux-x64',
#         '/eraseme',
#         '675',
        
# #       
# #         'etc',  
# #         'cdrom', 
# #         'shm'
#       ]
#     flag=False
#     for i in keywords:
#         if i in s:
#             flag=True
#     return flag

# def cal_set_rel_bak(s1,s2,file_list):
#     new_s=s1 & s2
#     count=0
#     for i in new_s:
# #     jdata=json.loads(i)
#         if is_include_key_word_bak(i) is not True:
#             if i in node_IDF.keys():
#                 IDF=node_IDF[i]
#             else:
#                 IDF=math.log(len(file_list)/(1))         

#             if (IDF)>math.log(len(file_list)*0.9/(1))  :
#                 print("node:",i," IDF:",IDF)
#                 count+=1
#     return count


def is_include_key_word(s):
    """
    Tell whether a string contains at least one of the filter keywords.

    Parameters:
        s (str): The string to inspect.

    Returns:
        bool: True when any keyword is a substring of ``s``, otherwise False.
    """
    filter_fragments=(
        'netflow',
        '/dev/pts',
        'salt-minion.log',
        'null',
        'usr',
        'proc',
        'firefox',
        'tmp',
        'thunderbird',
        'bin/',
        '/data/replay_logdb',
        '/stat',
        '/boot',
        'qt-opensource-linux-x64',
        '/eraseme',
        '675',
    )
    return any(fragment in s for fragment in filter_fragments)


# Rebuild the module-level file_list with the path of every time-window file
# stored under graph_4_12/ (os.listdir returns the entries in arbitrary order).
file_list=[]

file_path="graph_4_12/"
file_l=os.listdir("graph_4_12/")
for i in file_l:
    file_list.append(file_path+i)
    


def cal_set_rel(s1,s2,file_list, file_list_4_3_5, idf_threshold=5):
    """
    Count the rare nodes shared by two node sets, using two IDF tables.

    A shared node is counted when it does not contain a filter keyword
    (see ``is_include_key_word``) and the sum of its two IDF values is greater
    than ``idf_threshold``:

    * ``IDF``  comes from the module-level ``node_IDF`` dictionary, falling
      back to ``log(len(file_list))`` for unknown nodes;
    * ``IDF3`` comes from the module-level ``node_IDF_3`` dictionary, falling
      back to ``log(len(file_list_4_3_5))`` for unknown nodes.

    Every counted node is printed together with its ``IDF`` value.

    Parameters:
        s1 (set): The first set of nodes.
        s2 (set): The second set of nodes.
        file_list (list): Time-window files of the day under test; only its length is used.
        file_list_4_3_5 (list): Time-window files behind ``node_IDF_3``; only its length is used.
        idf_threshold (float): Minimum value that ``IDF + IDF3`` must exceed. Defaults to 5.

    Returns:
        int: Number of shared nodes that pass the keyword filter and the threshold.

    Raises:
        NameError: If a non-filtered shared node is found while ``node_IDF`` or
            ``node_IDF_3`` has not been defined in the notebook.
    """
    new_s=s1 & s2
    count=0
    for i in new_s:
        if is_include_key_word(i) is not True:
            if i in node_IDF.keys():
                IDF=node_IDF[i]
            else:
                IDF=math.log(len(file_list)/(1))

            if i in node_IDF_3.keys():
                IDF3=node_IDF_3[i]
            else:
                IDF3=math.log(len(file_list_4_3_5)/(1))

            if (IDF+IDF3)>idf_threshold :
                print("node:",i," IDF:",IDF)
                count+=1
    return count

# label generation

In [None]:
# Ground-truth label of every time window, keyed by "<day dir>/<file name>".
# All windows start as 0 (benign).
labels={}

for day_dir in ("graph_4_10/","graph_4_11/","graph_4_12/"):
    filelist = os.listdir(day_dir)
    for f in filelist:
        labels[day_dir+f]=0

In [29]:
# Display the keys of labels (the time-window file paths) in sorted order
sorted(labels)

['graph_4_10/2018-04-10 12:44:33.449564893~2018-04-10 13:00:02.700560774.txt',
 'graph_4_10/2018-04-10 13:00:02.700560774~2018-04-10 13:16:13.551944728.txt',
 'graph_4_10/2018-04-10 13:16:13.551944728~2018-04-10 13:31:14.548738409.txt',
 'graph_4_10/2018-04-10 13:31:14.548738409~2018-04-10 13:46:36.161065223.txt',
 'graph_4_10/2018-04-10 13:46:36.161065223~2018-04-10 14:02:17.001271389.txt',
 'graph_4_10/2018-04-10 14:02:17.001271389~2018-04-10 14:17:34.001373488.txt',
 'graph_4_10/2018-04-10 14:17:34.001373488~2018-04-10 14:33:18.350772859.txt',
 'graph_4_10/2018-04-10 14:33:18.350772859~2018-04-10 14:48:47.320442910.txt',
 'graph_4_10/2018-04-10 14:48:47.320442910~2018-04-10 15:03:54.307022037.txt',
 'graph_4_10/2018-04-10 15:03:54.307022037~2018-04-10 15:19:25.001773315.txt',
 'graph_4_10/2018-04-10 15:19:25.001773315~2018-04-10 15:36:13.002273705.txt',
 'graph_4_10/2018-04-10 15:36:13.002273705~2018-04-10 15:51:24.614585595.txt',
 'graph_4_10/2018-04-10 15:51:24.614585595~2018-04-1

In [None]:
# The 9 attack time windows (5 on 2018-04-10, 4 on 2018-04-12).
# Set their ground-truth value to 1.
attack_list=[
'graph_4_10/2018-04-10 13:31:14.548738409~2018-04-10 13:46:36.161065223.txt',
'graph_4_10/2018-04-10 14:02:17.001271389~2018-04-10 14:17:34.001373488.txt',
'graph_4_10/2018-04-10 14:17:34.001373488~2018-04-10 14:33:18.350772859.txt',
'graph_4_10/2018-04-10 14:33:18.350772859~2018-04-10 14:48:47.320442910.txt',
'graph_4_10/2018-04-10 14:48:47.320442910~2018-04-10 15:03:54.307022037.txt', 
 
'graph_4_12/2018-04-12 12:39:06.592684498~2018-04-12 12:54:44.001888457.txt',
'graph_4_12/2018-04-12 12:54:44.001888457~2018-04-12 13:09:55.026832462.txt',
'graph_4_12/2018-04-12 13:09:55.026832462~2018-04-12 13:25:06.588370709.txt',
'graph_4_12/2018-04-12 13:25:06.588370709~2018-04-12 13:40:07.178206094.txt',
]

# NOTE(review): a name that is not already a key of labels is silently added
# as a new entry -- the file names must match the directory listing exactly.
for i in attack_list:
    labels[i]=1

In [None]:
# Labels are 0 (benign) or 1 (attack), so the sum of the values is the attack count
print(f"Benign count: {len(labels.values()) - sum(labels.values())}")
print(f"Attack count: {sum(labels.values())}")

# Anomaly Detection
Steps:

1. Create a list of time-window queues
2. Process all time windows serially
3. For each time window, calculate its relevance to every time window in the existing queues
4. If the relevance passes a certain threshold, push the time window into the matching queue and stop checking the remaining queues
5. If no relevance is found, create a new queue containing the time window and store it


The cells for 4-9 through 4-12 follow the same steps. Detailed comments have been added to the 4-9 cell only, to reduce redundancy.

## 4-9

In [None]:
# Load the IDF data
node_IDF=torch.load("node_IDF_4_3-5")
file_list=[]

file_path="graph_4_9/"
file_l=os.listdir("graph_4_9/")
for i in file_l:
    file_list.append(file_path+i)

y_data_4_10=[]
df_list_4_10=[]

# list of time window queues
# It is a list (collection of queues) of list (queue) of dictionary (time window)
history_list=[]

tw_que=[]
his_tw={}

# Stores data for the currently processing time window
current_tw={}

# Same paths as file_list; kept as a separate list so file_list (used for the
# IDF threshold) stays untouched
file_path_list=list(file_list)

index_count=0
for f_path in sorted(file_path_list):
    # List to store loss values for edges
    edge_loss_list=[]

    # List to store edges (source-destination pairs)
    edge_list=[]

    print('index_count:',index_count)

    # Close each window file as soon as it has been read
    with open(f_path) as f:
        for line in f:
            l=line.strip()
            # NOTE(review): eval executes whatever the file contains; each line is
            # expected to be the str() of an edge dict written by test_day_new()
            jdata=eval(l)
            edge_loss_list.append(jdata['loss'])
            edge_list.append([str(jdata['srcmsg']),str(jdata['dstmsg'])])

    # Convert edge loss list to a DataFrame and store it
    df_list_4_10.append(pd.DataFrame(edge_loss_list))

    # Calculate anomaly loss metrics for the current time window
    # (the third argument is not used by cal_anomaly_loss)
    count,loss_avg,node_set,edge_set=cal_anomaly_loss(edge_loss_list,edge_list,"graph_4_10_without_neg_edge/")
    current_tw={}
    current_tw['name']=f_path
    current_tw['loss']=loss_avg
    current_tw['index']=index_count
    current_tw['nodeset']=node_set

    # To check if the current time window is related to any historical time window
    added_que_flag=False

    # For each queue
    for hq in history_list:
        # For each time window in a queue
        for his_tw in hq:
            # If the two windows are related, push the current window into this
            # queue and stop searching (a window joins at most one queue)
            if cal_set_rel_bak(current_tw['nodeset'],his_tw['nodeset'],file_list)!=0 and current_tw['name']!=his_tw['name']:
                print("history queue:",his_tw['name'])
                hq.append(copy.deepcopy(current_tw))
                added_que_flag=True
                break
        if added_que_flag:
            break

    # If no time window in any queue is related, start a new queue with the current time window
    if added_que_flag is False:
        temp_hq=[copy.deepcopy(current_tw)]
        history_list.append(temp_hq)

    index_count+=1
    # An empty window file has no edges; report 0.0 instead of dividing by zero
    percentage=count/len(edge_list) if edge_list else 0.0
    print( f_path,"  ",loss_avg," count:",count," percentage:",percentage," node count:",len(node_set)," edge count:",len(edge_set))

index_count: 0
thr: 5.387413690877365
graph_4_9/2018-04-09 08:46:55.004764124~2018-04-09 09:03:31.001287346.txt    0.0  count: 0  percentage: 0.0  node count: 0  edge count: 0
index_count: 1
thr: 1.1272507057874195
graph_4_9/2018-04-09 09:03:31.001287346~2018-04-09 09:20:23.001295997.txt    6.08346370000012  count: 65  percentage: 0.010579427083333334  node count: 12  edge count: 12
index_count: 2
thr: 1.595366808323317
graph_4_9/2018-04-09 09:20:23.001295997~2018-04-09 09:36:02.001305059.txt    5.62544135914846  count: 90  percentage: 0.02197265625  node count: 13  edge count: 15
index_count: 3
thr: 1.792668849508316
graph_4_9/2018-04-09 09:36:02.001305059~2018-04-09 09:51:31.358485271.txt    3.9401310086249968  count: 1168  percentage: 0.04224537037037037  node count: 88  edge count: 106
index_count: 4
thr: 1.9019337936664873
graph_4_9/2018-04-09 09:51:31.358485271~2018-04-09 10:06:48.907525397.txt    3.674581927379024  count: 2362  percentage: 0.051258680555555554  node count: 77  e

In [None]:

# pred_label={}

# files = os.listdir("graph_4_9")
# for f in files:
#     pred_label["graph_4_9/"+f]=0

# files = os.listdir("graph_4_10")
# for f in files:
#     pred_label["graph_4_10/"+f]=0

# files = os.listdir("graph_4_11")
# for f in files:
#     pred_label["graph_4_11/"+f]=0

# files = os.listdir("graph_4_12")
# for f in files:
#     pred_label["graph_4_12/"+f]=0


# Score every queue: the score is the product of (loss + 1) over its time windows
for hl in history_list:
    loss_count=0
    for hq in hl:
        window_factor=hq['loss']+1
        # A score of 0 means "nothing accumulated yet", so start from this factor
        loss_count=window_factor if loss_count==0 else loss_count*window_factor

    # Only queues whose score is greater than 10 are reported as anomalous
    if not loss_count>10:
        continue

    name_list=[]
    for i in hl:
        name_list.append(i['name'])
    for window_name in name_list:
        print(window_name)
    print(loss_count)

## 4-10

In [33]:
# node_IDF_3=torch.load("node_IDF_4_3")
# Load the IDF data computed from the 4-3 .. 4-5 time windows
node_IDF=torch.load("node_IDF_4_3-5")
file_list=[]

file_path="graph_4_10/"
file_l=os.listdir("graph_4_10/")
for i in file_l:
    file_list.append(file_path+i)

y_data_4_10=[]
df_list_4_10=[]
# List of time window queues: a list (queues) of lists (queue) of dicts (time window)
history_list=[]
tw_que=[]
his_tw={}
current_tw={}

# Same paths as file_list; kept as a separate list so file_list (used for the
# IDF threshold) stays untouched
file_path_list=list(file_list)

index_count=0
for f_path in sorted(file_path_list):
    edge_loss_list=[]
    edge_list=[]
    print('index_count:',index_count)

    # Close each window file as soon as it has been read
    with open(f_path) as f:
        for line in f:
            l=line.strip()
            # NOTE(review): eval executes whatever the file contains; each line is
            # expected to be the str() of an edge dict written by test_day_new()
            jdata=eval(l)
            edge_loss_list.append(jdata['loss'])
            edge_list.append([str(jdata['srcmsg']),str(jdata['dstmsg'])])
    df_list_4_10.append(pd.DataFrame(edge_loss_list))

    # Anomaly metrics of the current window (the third argument is not used by cal_anomaly_loss)
    count,loss_avg,node_set,edge_set=cal_anomaly_loss(edge_loss_list,edge_list,"graph_4_10_without_neg_edge/")
    current_tw={}
    current_tw['name']=f_path
    current_tw['loss']=loss_avg
    current_tw['index']=index_count
    current_tw['nodeset']=node_set

    # Append the window to the first queue holding a related window, then stop searching
    added_que_flag=False
    for hq in history_list:
        for his_tw in hq:
            if cal_set_rel_bak(current_tw['nodeset'],his_tw['nodeset'],file_list)!=0 and current_tw['name']!=his_tw['name']:
                print("history queue:",his_tw['name'])
                hq.append(copy.deepcopy(current_tw))
                added_que_flag=True
                break
        if added_que_flag:
            break
    # No related window anywhere: start a new queue
    if added_que_flag is False:
        temp_hq=[copy.deepcopy(current_tw)]
        history_list.append(temp_hq)

    index_count+=1
    # An empty window file has no edges; report 0.0 instead of dividing by zero
    percentage=count/len(edge_list) if edge_list else 0.0
    print( f_path,"  ",loss_avg," count:",count," percentage:",percentage," node count:",len(node_set)," edge count:",len(edge_set))
#     y_data_4_10.append([loss_avg,labels_4_10[f_path],f_path])

index_count: 0
thr: 1.9090705605478997
graph_4_10/2018-04-10 12:44:33.449564893~2018-04-10 13:00:02.700560774.txt    3.5780684604002553  count: 253  percentage: 0.0494140625  node count: 50  edge count: 52
index_count: 1
thr: 1.0730868876354163
graph_4_10/2018-04-10 13:00:02.700560774~2018-04-10 13:16:13.551944728.txt    2.942399242895489  count: 4242  percentage: 0.03540665064102564  node count: 307  edge count: 380
index_count: 2
thr: 1.9097335993665003
graph_4_10/2018-04-10 13:16:13.551944728~2018-04-10 13:31:14.548738409.txt    3.844639180965241  count: 7520  percentage: 0.06385869565217392  node count: 952  edge count: 1660
index_count: 3
thr: 1.848338632776816
graph_4_10/2018-04-10 13:31:14.548738409~2018-04-10 13:46:36.161065223.txt    3.6045949790014675  count: 15709  percentage: 0.06879291619955157  node count: 827  edge count: 1253
index_count: 4
thr: 0.7036397467480755
graph_4_10/2018-04-10 13:46:36.161065223~2018-04-10 14:02:17.001271389.txt    1.6523426022163337  count: 37

In [34]:

# Predicted label of every time window, keyed by "<day dir>/<file name>"; 0 = benign.
# graph_4_9 is deliberately left out (its initialisation was commented out):
# files = os.listdir("graph_4_9")
# for f in files:
#     pred_label["graph_4_9/"+f]=0
pred_label={}

for day_dir in ("graph_4_10","graph_4_11","graph_4_12"):
    files = os.listdir(day_dir)
    for f in files:
        pred_label[day_dir+"/"+f]=0


# Score every queue: the score is the product of (loss + 1) over its time windows
for hl in history_list:
    loss_count=0
    for hq in hl:
        window_factor=hq['loss']+1
        # A score of 0 means "nothing accumulated yet", so start from this factor
        loss_count=window_factor if loss_count==0 else loss_count*window_factor

    # Only queues whose score is greater than 14 are flagged as attacks
    if not loss_count>14:
        continue

    name_list=[]
    for i in hl:
        name_list.append(i['name'])
    for window_name in name_list:
        print(window_name)
    for window_name in name_list:
        pred_label[window_name]=1
    print(loss_count)

graph_4_10/2018-04-10 13:31:14.548738409~2018-04-10 13:46:36.161065223.txt
graph_4_10/2018-04-10 14:02:17.001271389~2018-04-10 14:17:34.001373488.txt
graph_4_10/2018-04-10 14:17:34.001373488~2018-04-10 14:33:18.350772859.txt
graph_4_10/2018-04-10 14:33:18.350772859~2018-04-10 14:48:47.320442910.txt
graph_4_10/2018-04-10 15:03:54.307022037~2018-04-10 15:19:25.001773315.txt
681.9766473550785


## 4-11

In [35]:
# node_IDF_3=torch.load("node_IDF_4_3")
# Loaded from "node_IDF_4_3-5"; not referenced directly in this cell
# (presumably read as a global by helper functions defined earlier -- TODO confirm).
node_IDF=torch.load("node_IDF_4_3-5")

# Paths of all time-window files found in graph_4_11/.
file_path="graph_4_11/"
file_l=os.listdir("graph_4_11/")
file_list=[file_path+i for i in file_l]

# node_IDF_410=torch.load("node_IDF_4_10")
# node_IDF=torch.load("node_IDF_4_12")
# Fresh state for this day's queue construction.
y_data_4_10=[]
df_list_4_10=[]
# node_set_list=[]
history_list=[]
tw_que=[]
his_tw={}
current_tw={}

# Paths iterated (in sorted order) by the window loop that follows.
file_path="graph_4_11/"
file_l=os.listdir("graph_4_11/")
file_path_list=[file_path+i for i in file_l]

# file_path="graph_4_12/"
# file_l=os.listdir("graph_4_12/")
# for i in file_l:
#     file_path_list.append(file_path+i)


# Walk the time windows in sorted path order, score each one, and group related
# windows into queues (history_list is a list of queues; each queue is a list of
# window records).
index_count=0
for f_path in sorted(file_path_list):
    edge_loss_list=[]
    edge_list=[]
    print('index_count:',index_count)

    # The context manager closes the window file once it has been read
    # (previously the handle was opened and never closed).
    with open(f_path) as f:
        for line in f:
            l=line.strip()
            # SECURITY NOTE: eval() executes whatever expression the line contains.
            # Only run this on window files produced by this pipeline.
            jdata=eval(l)
            edge_loss_list.append(jdata['loss'])
            edge_list.append([str(jdata['srcmsg']),str(jdata['dstmsg'])])
    df_list_4_10.append(pd.DataFrame(edge_loss_list))
    count,loss_avg,node_set,edge_set=cal_anomaly_loss(edge_loss_list,edge_list,"graph_4_11_without_neg_edge/")

    # Record describing the current time window.
    current_tw={}
    current_tw['name']=f_path
    current_tw['loss']=loss_avg
    current_tw['index']=index_count
    current_tw['nodeset']=node_set

    # Append the window to the first queue holding a *different* window for which
    # cal_set_rel_bak returns a non-zero value; otherwise start a new queue.
    added_que_flag=False
    for hq in history_list:
        for his_tw in hq:
            if cal_set_rel_bak(current_tw['nodeset'],his_tw['nodeset'],file_list)!=0 and current_tw['name']!=his_tw['name']:
                print("history queue:",his_tw['name'])
                hq.append(copy.deepcopy(current_tw))
                added_que_flag=True
                break
        if added_que_flag:
            break
    if added_que_flag is False:
        temp_hq=[copy.deepcopy(current_tw)]
        history_list.append(temp_hq)

    index_count+=1
#     node_set_list.append(node_set)
    print( f_path,"  ",loss_avg," count:",count," percentage:",count/len(edge_list)," node count:",len(node_set)," edge count:",len(edge_set))
#     y_data_4_10.append([loss_avg,labels_4_10[f_path],f_path])

index_count: 0
thr: 2.3507490297673246
graph_4_11/2018-04-11 00:00:00.001151329~2018-04-11 00:16:06.001274623.txt    4.335114860909239  count: 891  percentage: 0.08701171875  node count: 36  edge count: 39
index_count: 1
thr: 0.747704295762974
graph_4_11/2018-04-11 00:16:06.001274623~2018-04-11 00:32:08.001169192.txt    1.6687422044953704  count: 472  percentage: 0.04609375  node count: 86  edge count: 87
index_count: 2
thr: 0.7916274619777288
graph_4_11/2018-04-11 00:32:08.001169192~2018-04-11 00:48:31.001594545.txt    2.0889478795209517  count: 367  percentage: 0.03583984375  node count: 31  edge count: 36
index_count: 3
thr: 0.6937735663359536
graph_4_11/2018-04-11 00:48:31.001594545~2018-04-11 01:04:53.001888013.txt    1.6538311558921852  count: 451  percentage: 0.04404296875  node count: 35  edge count: 38
index_count: 4
thr: 0.6772393588548198
graph_4_11/2018-04-11 01:04:53.001888013~2018-04-11 01:20:59.002307553.txt    1.6539954994365464  count: 424  percentage: 0.04140625  node

In [36]:
# Score every queue of 4-11 windows; queues scoring above 25 are only printed here.
for hl in history_list:
    # Running product of (loss + 1); a running value of exactly 0 is bumped to 1
    # before multiplying.
    loss_count=0
    for hq in hl:
        multiplier=loss_count+1 if loss_count==0 else loss_count
        loss_count=multiplier*(hq['loss']+1)

    name_list=[]
    # (A threshold of 50 was tried as an alternative.)
    if loss_count>25:
        name_list=[i['name'] for i in hl]
        print(name_list)
        # Updating pred_label is intentionally disabled for this day:
#         for i in name_list:
#             pred_label[i]=1
        print(loss_count)

## 4-12

In [37]:
# Collect the paths of all time-window files in graph_4_3/, graph_4_4/ and graph_4_5/.
file_list_3_5=[]

for file_path in ("graph_4_3/", "graph_4_4/", "graph_4_5/"):
    file_l=os.listdir(file_path)
    for i in file_l:
        file_list_3_5.append(file_path+i)

# Display the number of collected files.
len(file_list_3_5)

194

In [38]:
# Display how many paths file_list currently holds. NOTE(review): this relies on
# whichever earlier cell last assigned file_list -- confirm the execution order.
len(file_list)

91

In [39]:
# Loaded from "node_IDF_4_12" and "node_IDF_4_3-5"; neither is referenced directly
# in this cell (presumably read as globals by helper functions -- TODO confirm).
node_IDF=torch.load("node_IDF_4_12") 
node_IDF_3=torch.load("node_IDF_4_3-5")
# node_IDF=torch.load("node_IDF_4_3")

# Paths of all time-window files found in graph_4_12/.
file_path="graph_4_12/"
file_l=os.listdir("graph_4_12/")
file_list=[file_path+i for i in file_l]

# The "_4_10" suffix in the names below is a leftover; the variable names do not
# change the results.
y_data_4_10=[]
df_list_4_10=[]
history_list=[]
tw_que=[]
his_tw={}
current_tw={}

# Paths iterated (in sorted order) by the window loop that follows.
file_path="graph_4_12/"
file_l=os.listdir("graph_4_12/")
file_path_list=[file_path+i for i in file_l]


# Walk the time windows in sorted path order, score each one, and group related
# windows into queues (history_list is a list of queues; each queue is a list of
# window records).
index_count=0
for f_path in sorted(file_path_list):
    edge_loss_list=[]
    edge_list=[]
    print('index_count:',index_count)

    # The context manager closes the window file once it has been read
    # (previously the handle was opened and never closed).
    with open(f_path) as f:
        for line in f:
            l=line.strip()
            # SECURITY NOTE: eval() executes whatever expression the line contains.
            # Only run this on window files produced by this pipeline.
            jdata=eval(l)
            edge_loss_list.append(jdata['loss'])
            edge_list.append([str(jdata['srcmsg']),str(jdata['dstmsg'])])
    df_list_4_10.append(pd.DataFrame(edge_loss_list))
    count,loss_avg,node_set,edge_set=cal_anomaly_loss(edge_loss_list,edge_list,"graph_4_12_without_neg_edge/")

    # Record describing the current time window.
    current_tw={}
    current_tw['name']=f_path
    current_tw['loss']=loss_avg
    current_tw['index']=index_count
    current_tw['nodeset']=node_set

    # Append the window to the first queue holding a *different* window for which
    # cal_set_rel returns a non-zero value; otherwise start a new queue.
    added_que_flag=False
    for hq in history_list:
        for his_tw in hq:
            cal_re = cal_set_rel(current_tw['nodeset'],his_tw['nodeset'],file_list, file_list_3_5)
            if cal_re != 0 and current_tw['name']!=his_tw['name']:
#             if cal_set_rel_bak(current_tw['nodeset'],his_tw['nodeset'],file_l)!=0 and current_tw['name']!=his_tw['name']:
                hq.append(copy.deepcopy(current_tw))
                added_que_flag=True
                break
        if added_que_flag:
            break
    if added_que_flag is False:
        temp_hq=[copy.deepcopy(current_tw)]
        history_list.append(temp_hq)
    index_count+=1
    print( f_path,"  ",loss_avg," count:",count," percentage:",count/len(edge_list)," node count:",len(node_set)," edge count:",len(edge_set))
    
    
    
#     for i in history_list:
#         print(len(i))
#         if len(i) >= 2:
#             for tw in i:
#                 print(tw['name'])
#     input()
    
    

index_count: 0
thr: 2.647840738583332
graph_4_12/2018-04-12 00:00:00.001773757~2018-04-12 00:15:26.001647081.txt    4.3237672770543485  count: 893  percentage: 0.1090087890625  node count: 39  edge count: 41
index_count: 1
thr: 0.9334369701462344
graph_4_12/2018-04-12 00:15:26.001647081~2018-04-12 00:30:58.002002412.txt    1.9500558467585738  count: 431  percentage: 0.0526123046875  node count: 91  edge count: 92
index_count: 2
thr: 0.8316535527632508
graph_4_12/2018-04-12 00:30:58.002002412~2018-04-12 00:46:27.001835122.txt    1.8446585508841036  count: 395  percentage: 0.0482177734375  node count: 35  edge count: 38
index_count: 3
thr: 0.8294326854125519
graph_4_12/2018-04-12 00:46:27.001835122~2018-04-12 01:02:05.698953597.txt    1.948690851374168  count: 384  percentage: 0.046875  node count: 35  edge count: 36
index_count: 4
thr: 0.8378034875841244
graph_4_12/2018-04-12 01:02:05.698953597~2018-04-12 01:17:27.002277528.txt    2.0346812708623894  count: 348  percentage: 0.0424804687

In [None]:
# Score every queue of 4-12 windows and flag the queues scoring above 20.
for hl in history_list:
    # Queue score: running product of (loss + 1) over its windows. A running
    # value of exactly 0 (the initial state) is bumped to 1 before multiplying.
    loss_count=0
    for hq in hl:
        multiplier=loss_count+1 if loss_count==0 else loss_count
        loss_count=multiplier*(hq['loss']+1)

    name_list=[]
    # A queue whose score exceeds 20 is treated as anomalous.
    # (A threshold of 50 was tried as an alternative.)
    if loss_count>20:
        name_list=[i['name'] for i in hl]
        print(name_list)

        # Set the predicted label of every window in the anomalous queue to 1.
        for i in name_list:
            pred_label[i]=1
        print(loss_count)

['graph_4_12/2018-04-12 09:19:03.668714727~2018-04-12 09:34:16.330657912.txt', 'graph_4_12/2018-04-12 12:39:06.592684498~2018-04-12 12:54:44.001888457.txt', 'graph_4_12/2018-04-12 12:54:44.001888457~2018-04-12 13:09:55.026832462.txt', 'graph_4_12/2018-04-12 13:09:55.026832462~2018-04-12 13:25:06.588370709.txt', 'graph_4_12/2018-04-12 13:25:06.588370709~2018-04-12 13:40:07.178206094.txt']
1991.8845109087163


# Evaluation

In [None]:
from sklearn.metrics import average_precision_score, roc_auc_score

from sklearn.metrics import roc_auc_score
import torch
from sklearn import preprocessing
import matplotlib.pyplot as plt
import numpy as np
from sklearn.preprocessing import MinMaxScaler
from sklearn.metrics import confusion_matrix

def plot_thr():
    """
    Sweep a decision threshold over the scores and plot the resulting metrics.

    Reads the module-level ``y_test_scores`` and ``y_test``. For every threshold in
    ``torch.arange(-5, 5, 0.01)`` a score strictly greater than the threshold is
    predicted as 1 (otherwise 0) and the prediction is evaluated with
    ``classifier_evaluation``. The threshold with the highest F-score is printed
    together with its metrics, then all five metric curves are plotted.

    NaN F-scores (0/0 precision or recall) are ignored when picking the best
    threshold; if every F-score is NaN, index 0 is reported.

    Returns:
        None
    """
    # 0/0 in the metric computations yields NaN; silence the NumPy warning for it.
    np.seterr(invalid='ignore')
    step=0.01
    thr_list=torch.arange(-5,5,step)

    precision_list=[]
    recall_list=[]
    fscore_list=[]
    accuracy_list=[]
    auc_val_list=[]
    for thr in thr_list:
        threshold=thr
        y_prediction=[1 if i>threshold else 0 for i in y_test_scores]
        precision,recall,fscore,accuracy,auc_val=classifier_evaluation(y_test, y_prediction)
        precision_list.append(float(precision))
        recall_list.append(float(recall))
        fscore_list.append(float(fscore))
        accuracy_list.append(float(accuracy))
        auc_val_list.append(float(auc_val))

    # max()/list.index() are order-dependent when the list contains NaN (a leading
    # NaN would win), so select the first maximal non-NaN F-score explicitly.
    fscore_arr=np.asarray(fscore_list,dtype=float)
    if np.isnan(fscore_arr).all():
        max_fscore_index=0
    else:
        max_fscore_index=int(np.nanargmax(fscore_arr))
    max_fscore=fscore_list[max_fscore_index]

    print(max_fscore_index)
    print("max threshold:",thr_list[max_fscore_index])
    print('precision:',precision_list[max_fscore_index])
    print('recall:',recall_list[max_fscore_index])
    print('fscore:',fscore_list[max_fscore_index])
    print('accuracy:',accuracy_list[max_fscore_index])    
    print('auc:',auc_val_list[max_fscore_index])

    # One curve per metric: (values, color, legend label, line style).
    curves=(
        (precision_list,'red','precision','-'),
        (recall_list,'orange','recall','solid'),
        (fscore_list,'y','F-score','dashed'),
        (accuracy_list,'g','accuracy','dashdot'),
        (auc_val_list,'b','auc_val','dotted'),
    )
    for values,color,label,linestyle in curves:
        plt.plot(thr_list,values,color=color,label=label,linewidth=2.0,linestyle=linestyle)

    plt.xlabel("Threshold", fontdict={'size': 16})
    plt.ylabel("Rate", fontdict={'size': 16})
    plt.title("Different evaluation Indicators by varying threshold value", fontdict={'size': 12})
    plt.legend(loc='best', fontsize=12, markerscale=0.5)
    plt.show()

# Function to compute classification evaluation metrics
def classifier_evaluation(y_test, y_test_pred):
    """
    Compute and print binary-classification metrics from true and predicted labels.

    The confusion-matrix counts and every metric are printed before returning.

    Parameters:
        y_test (list or array): Ground truth binary labels.
        y_test_pred (list or array): Predicted binary labels.

    Returns:
        tuple: (precision, recall, fscore, accuracy, auc_val).
    """
    # The flattened matrix is unpacked as tn, fp, fn, tp.
    cells=confusion_matrix(y_test, y_test_pred).ravel()
    tn, fp, fn, tp = cells
    for tag, value in (('tn:',tn), ('fp:',fp), ('fn:',fn), ('tp:',tp)):
        print(tag, value)

    predicted_positive=tp+fp
    actual_positive=tp+fn
    precision=tp/predicted_positive
    recall=tp/actual_positive
    accuracy=(tp+tn)/(tp+tn+fp+fn)
    fscore=2*(precision*recall)/(precision+recall)
    # AUC is computed from the hard 0/1 predictions passed in.
    auc_val=roc_auc_score(y_test, y_test_pred)

    report=(
        ("precision:",precision),
        ("recall:",recall),
        ("fscore:",fscore),
        ("accuracy:",accuracy),
        ("auc_val:",auc_val),
    )
    for tag, value in report:
        print(tag, value)

    return precision,recall,fscore,accuracy,auc_val

# Function to apply Min-Max scaling to a dataset
def minmax(data):
    """
    Apply Min-Max scaling to normalize data to the range [0, 1].

    Parameters:
        data (list or array): Input values to be normalized.

    Returns:
        list: Normalized values, in input order. An empty input yields an empty
        list; if all values are equal (zero range) every value maps to 0.0
        instead of dividing by zero.
    """
    values=list(data)
    if not values:
        return []

    min_val=min(values)
    max_val=max(values)
    span=max_val-min_val
    # Constant input: there is no range to scale by.
    if span==0:
        return [0.0 for _ in values]
    return [(i-min_val)/span for i in values]



In [None]:
# Build aligned lists of ground-truth and predicted labels, one entry per key of
# the ground-truth dictionary `labels` (in its iteration order).
y=[labels[i] for i in labels]
y_pred=[pred_label[i] for i in labels]

In [None]:
# Evaluate the predicted labels against the ground truth; the call prints the
# confusion-matrix counts and metrics and returns the metrics as a tuple.
classifier_evaluation(y,y_pred)

tn: 216
fp: 2
fn: 1
tp: 8
precision: 0.8
recall: 0.8888888888888888
fscore: 0.8421052631578948
accuracy: 0.986784140969163
auc_val: 0.9398572884811417


(0.8,
 0.8888888888888888,
 0.8421052631578948,
 0.986784140969163,
 0.9398572884811417)

# Count the number of attack edges

In [44]:
def keyword_hit(line):
    """
    Tell whether a line mentions any of the known attack indicators.

    Parameters:
        line (str): Text to search (one line of a time-window file).

    Returns:
        bool: True if at least one indicator occurs as a substring of ``line``.
    """
    attack_nodes=(
        '/home/admin/clean',
        '/dev/glx_alsa_675',
        '/home/admin/profile',
#       '/var/log/mail',
        '/tmp/memtrace.so',
        '/var/log/xdev',
        '/var/log/wdev',
        'gtcache',
#       'firefox',
        '161.116.88.72',
        '146.153.68.151',
        '145.199.103.57',
        '61.130.69.232',
        '5.214.163.155',
        '104.228.117.212',
        '141.43.176.203',
        '7.149.198.40',
        '149.52.198.23',
    )
    # any() stops at the first matching indicator.
    return any(i in line for i in attack_nodes)



# Time-window files whose lines are checked for attack indicators.
files=[]
temp_file=[
        '2018-04-10 13:31:14.548738409~2018-04-10 13:46:36.161065223.txt',
        '2018-04-10 14:02:17.001271389~2018-04-10 14:17:34.001373488.txt',
        '2018-04-10 14:17:34.001373488~2018-04-10 14:33:18.350772859.txt',
        '2018-04-10 14:33:18.350772859~2018-04-10 14:48:47.320442910.txt',
        '2018-04-10 14:48:47.320442910~2018-04-10 15:03:54.307022037.txt',
]
for f in temp_file:
    files.append("./graph_4_10/"+f)


temp_file=[
        '2018-04-12 12:39:06.592684498~2018-04-12 12:54:44.001888457.txt',
        '2018-04-12 12:54:44.001888457~2018-04-12 13:09:55.026832462.txt',
        '2018-04-12 13:09:55.026832462~2018-04-12 13:25:06.588370709.txt',
        '2018-04-12 13:25:06.588370709~2018-04-12 13:40:07.178206094.txt',
]
for f in temp_file:
    files.append("./graph_4_12/"+f)

# Count the lines (one edge per line) that mention an attack indicator.
attack_edge_count=0
for fpath in tqdm(files):
    # The context manager closes each file after it is scanned
    # (previously the handles were never closed).
    with open(fpath) as f:
        for line in f:
            if keyword_hit(line):
                attack_edge_count+=1
print(attack_edge_count)

100%|██████████| 9/9 [00:05<00:00,  1.66it/s]

3039





# Visualization

For the provided attack time windows, create graphs for visualization.

In [None]:
import os

from graphviz import Digraph
import networkx as nx
import datetime
import community.community_louvain as community_louvain
from tqdm import tqdm




# Some common path abstraction for visualization
# Substring -> abstracted name. Entries are checked in insertion order and the
# first substring found in a name wins.
replace_dic = {
    '/run/shm/': '/run/shm/*',
    #     '/home/admin/.cache/mozilla/firefox/pe11scpa.default/cache2/entries/':'/home/admin/.cache/mozilla/firefox/pe11scpa.default/cache2/entries/*',
    '/home/admin/.cache/mozilla/firefox/': '/home/admin/.cache/mozilla/firefox/*',
    '/home/admin/.mozilla/firefox': '/home/admin/.mozilla/firefox*',
    '/data/replay_logdb/': '/data/replay_logdb/*',
    '/home/admin/.local/share/applications/': '/home/admin/.local/share/applications/*',
    '/usr/share/applications/': '/usr/share/applications/*',
    '/lib/x86_64-linux-gnu/': '/lib/x86_64-linux-gnu/*',
    '/proc/': '/proc/*',
    '/stat': '*/stat',
    '/etc/bash_completion.d/': '/etc/bash_completion.d/*',
    '/usr/bin/python2.7': '/usr/bin/python2.7/*',
    '/usr/lib/python2.7': '/usr/lib/python2.7/*',
}


def replace_path_name(path_name):
    """
    Abstract a path name using the substring table ``replace_dic``.

    Parameters:
        path_name (str): Name to abstract.

    Returns:
        str: The replacement of the first ``replace_dic`` key (in insertion order)
        that occurs in ``path_name``. The whole name is replaced, not just the
        matching part. If no key occurs, ``path_name`` is returned unchanged.
    """
    matches = (abstract for pattern, abstract in replace_dic.items() if pattern in path_name)
    return next(matches, path_name)


# Time-window files to visualize. This list is maintained by hand: users should
# copy the time windows reported as anomalous by the detection cells into it.
attack_list = [
        'graph_4_10/2018-04-10 13:31:14.548738409~2018-04-10 13:46:36.161065223.txt',
        'graph_4_10/2018-04-10 14:02:17.001271389~2018-04-10 14:17:34.001373488.txt',
        'graph_4_10/2018-04-10 14:17:34.001373488~2018-04-10 14:33:18.350772859.txt',
        'graph_4_10/2018-04-10 14:33:18.350772859~2018-04-10 14:48:47.320442910.txt',
        'graph_4_10/2018-04-10 14:48:47.320442910~2018-04-10 15:03:54.307022037.txt',
    
        'graph_4_12/2018-04-12 12:39:06.592684498~2018-04-12 12:54:44.001888457.txt', 
        'graph_4_12/2018-04-12 12:54:44.001888457~2018-04-12 13:09:55.026832462.txt', 
        'graph_4_12/2018-04-12 13:09:55.026832462~2018-04-12 13:25:06.588370709.txt', 
        'graph_4_12/2018-04-12 13:25:06.588370709~2018-04-12 13:40:07.178206094.txt'
]

original_edges_count = 0  # total number of edges read from all listed windows
graphs = []
gg = nx.DiGraph()         # anomalous edges of all listed windows, merged
count = 0                 # total number of lines read

# Process each listed time-window file: compute a per-window loss threshold and
# add the edges whose loss exceeds it to the merged graph gg.
for path in tqdm(attack_list):
    if ".txt" in path:
        line_count = 0
        node_set = set()
        tempg = nx.DiGraph()
        edge_list = []

        # Collect edge info. The context manager closes the file after reading
        # (previously the handle was never closed).
        with open(path, "r") as f:
            for line in f:
                count += 1
                l = line.strip()
                # SECURITY NOTE: eval() executes whatever expression the line
                # contains. Only run this on files produced by this pipeline.
                jdata = eval(l)
                edge_list.append(jdata)

        edge_list = sorted(edge_list, key=lambda x: x['loss'], reverse=True)
        original_edges_count += len(edge_list)

        loss_list = [i['loss'] for i in edge_list]

        loss_mean = mean(loss_list)
        loss_std = std(loss_list)
        print(loss_mean)
        print(loss_std)

        # Threshold: mean loss plus 1.5 standard deviations of this window.
        thr = loss_mean + 1.5 * loss_std
        print("thr:", thr)

        # Keep the edges above the threshold. Node ids are the hash of the
        # abstracted node message, computed once per endpoint.
        for e in edge_list:
            if e['loss'] > thr:
                src_id = str(hashgen(replace_path_name(e['srcmsg'])))
                dst_id = str(hashgen(replace_path_name(e['dstmsg'])))
                tempg.add_edge(src_id, dst_id)
                gg.add_edge(src_id, dst_id,
                            loss=e['loss'], srcmsg=e['srcmsg'], dstmsg=e['dstmsg'], edge_type=e['edge_type'],
                            time=e['time'])


# Community detection on the undirected view of the anomalous-edge graph.
# partition maps node id -> community id.
# NOTE(review): no random_state is passed, so the partition may differ between
# runs -- confirm whether reproducible communities are required.
partition = community_louvain.best_partition(gg.to_undirected())

# Generate the candidate subgraphs based on community discovery results:
# one directed graph per community id in 0..max_partition.
max_partition = max([0, *partition.values()])
communities = {i: nx.DiGraph() for i in range(max_partition + 1)}

# Each edge is added to the community of its source and to the community of its
# destination, so an edge crossing two communities appears in both subgraphs.
for src, dst in gg.edges:
    for endpoint in (src, dst):
        communities[partition[endpoint]].add_edge(src, dst)


# Define the attack nodes. They are **only used to color the attack nodes and edges in the plots**.
# They do not change the detection results.
def attack_edge_flag(msg):
    """
    Tell whether a node message mentions any of the known attack indicators.

    Parameters:
        msg: Node message; it is converted with ``str()`` before matching.

    Returns:
        bool: True if at least one indicator occurs as a substring of ``str(msg)``.
    """
    attack_nodes = (
        '/home/admin/clean',
        '/dev/glx_alsa_675',
        '/home/admin/profile',
        '/var/log/xdev',
        '/etc/passwd',
        '161.116.88.72',
        '146.153.68.151',
        '/var/log/mail',
        '/tmp/memtrace.so',
        #         '/tmp',
        '/var/log/wdev',
        'gtcache',
        'firefox',
        #         '/var/log',
    )
    text = str(msg)
    # any() stops at the first matching indicator.
    return any(i in text for i in attack_nodes)


# Plot and render candidate subgraph
def _node_shape(msg):
    """Return the Graphviz shape for a node message based on its entity type."""
    if "'subject': '" in msg:
        return 'box'
    if "'file': '" in msg:
        return 'oval'
    if "'netflow': '" in msg:
        return 'diamond'
    # Unknown entity type: use Graphviz's default node shape instead of leaving
    # the shape unbound or stale from a previous edge.
    return 'ellipse'


# Plot and render one PDF per candidate subgraph (community).
os.makedirs("./graph_visual/", exist_ok=True)
graph_index = 0
for c in communities:
    dot = Digraph(name="MyPicture", comment="the test", format="pdf")
    dot.graph_attr['rankdir'] = 'LR'

    for e in communities[c].edges:
        # Edge attributes (messages, edge type, ...) are stored on the merged graph.
        temp_edge = gg.edges[e]

        src_name = replace_path_name(temp_edge['srcmsg'])
        dst_name = replace_path_name(temp_edge['dstmsg'])
        src_id = str(hashgen(src_name))
        dst_id = str(hashgen(dst_name))
        src_is_attack = attack_edge_flag(temp_edge['srcmsg'])
        dst_is_attack = attack_edge_flag(temp_edge['dstmsg'])

        # source node: label is the abstracted name followed by its community id
        dot.node(name=src_id,
                 label=str(src_name + str(partition[src_id])),
                 color='red' if src_is_attack else 'blue',
                 shape=_node_shape(temp_edge['srcmsg']))

        # destination node
        dot.node(name=dst_id,
                 label=str(dst_name + str(partition[dst_id])),
                 color='red' if dst_is_attack else 'blue',
                 shape=_node_shape(temp_edge['dstmsg']))

        # An edge is drawn red only when both of its endpoints are attack nodes.
        edge_color = 'red' if (src_is_attack and dst_is_attack) else 'blue'
        dot.edge(src_id, dst_id, label=temp_edge['edge_type'], color=edge_color)

    dot.render(f'./graph_visual/subgraph_' + str(graph_index), view=False)
    graph_index += 1

  0%|          | 0/9 [00:00<?, ?it/s]

0.39349401528288985
0.9698964116626174
thr: 1.848338632776816


 22%|██▏       | 2/9 [00:05<00:16,  2.29s/it]

0.22399647636335385
0.7824945255043628
thr: 1.397738264619898
0.01529548680480501
0.14745660647336425
thr: 0.2364803965148514


 44%|████▍     | 4/9 [00:53<01:11, 14.30s/it]

0.3234040941077674
0.8278469702270279
thr: 1.5651745494483094


 56%|█████▌    | 5/9 [00:56<00:40, 10.10s/it]

0.28987987909855134
0.7848080091669775
thr: 1.4670918928490175


 67%|██████▋   | 6/9 [00:57<00:21,  7.13s/it]

0.264065730717139
0.7438770927824745
thr: 1.3798813698908508


 78%|███████▊  | 7/9 [00:59<00:10,  5.37s/it]

0.22431252344519415
0.7273518535556521
thr: 1.3153403037786724


 89%|████████▉ | 8/9 [01:00<00:04,  4.23s/it]

2.844655554887345
2.1281240591882953
thr: 6.0368416436697885


100%|██████████| 9/9 [01:03<00:00,  7.06s/it]

0.29558899332005617
0.7688963673759914
thr: 1.4489335443840432



