In [1]:
# encoding=utf-8
import os.path as osp
import os
import copy
import matplotlib.pyplot as plt
import torch
from torch.nn import Linear
from sklearn.metrics import average_precision_score, roc_auc_score
from torch_geometric.data import TemporalData
from torch_geometric.datasets import JODIEDataset
from torch_geometric.datasets import ICEWS18
from torch_geometric.nn import TGNMemory, TransformerConv
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn.models.tgn import (LastNeighborLoader, IdentityMessage, MeanAggregator,
                                           LastAggregator)
from torch_geometric import *
from torch_geometric.utils import negative_sampling
from tqdm import tqdm
import networkx as nx
import numpy as np
import math
import copy
import re
import time
import json
import pandas as pd
from random import choice
import gc
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
# The rest of the notebook reads this module-level `device`.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# device = 'cpu'
# msg structure:      [src_node_feature,edge_attr,dst_node_feature]

# compute the best partition
import datetime
# import community as community_louvain

import xxhash

# Find the edge index which the edge vector is corresponding to
def tensor_find(t,x):
    """
    Locate a value inside a tensor and report where it first appears.

    The tensor is copied to host memory and scanned with NumPy. Matches are
    enumerated by ``np.argwhere`` and only the leading-axis coordinate of the
    first one is used.

    Parameters:
    t (torch.Tensor): Tensor to search.
    x (Any): Value compared element-wise against the tensor.

    Returns:
    int: Leading-axis index of the first match, shifted by one (1-based).

    Raises:
    IndexError: If no element of the tensor equals ``x``.
    """
    matches = np.argwhere(t.cpu().numpy() == x)
    first_match = matches[0]
    return first_match[0] + 1


def std(t):
    """
    Population standard deviation of a collection of numbers.

    Parameters:
    t (list, tuple, or ndarray): Values to summarise; converted to a NumPy
                                 array before the computation.

    Returns:
    float: Standard deviation over all elements (NumPy default, ddof=0).
    """
    values = np.array(t)
    return values.std()


def var(t):
    """
    Population variance of a collection of numbers.

    Parameters:
    t (list, tuple, or ndarray): Values to summarise; converted to a NumPy
                                 array before the computation.

    Returns:
    float: Variance over all elements (NumPy default, ddof=0).
    """
    values = np.array(t)
    return values.var()


def mean(t):
    """
    Arithmetic mean of a collection of numbers.

    Parameters:
    t (list, tuple, or ndarray): Values to average; converted to a NumPy
                                 array before the computation.

    Returns:
    float: Mean over all elements.
    """
    values = np.array(t)
    return values.mean()

def hashgen(l):
    """
    Fold a sequence of values into one 64-bit xxHash digest.

    Every element is fed, in order, into a single streaming ``xxh64`` state,
    so the result depends on both the contents and the order of the list.

    Parameters:
    l (list of str): Values to hash, e.g. the properties of a node or edge.

    Returns:
    int: Integer digest of the accumulated input.
    """
    digest = xxhash.xxh64()
    for item in l:
        digest.update(item)
    return digest.intdigest()


def cal_pos_edges_loss(link_pred_ratio):
    """
    Score every prediction against an all-ones (positive) target.

    Uses the module-level ``criterion``; each element of the input is scored
    on its own, and the per-element losses are packed into a new tensor.

    Parameters:
    link_pred_ratio (list or Tensor): Predicted link scores, one per edge.

    Returns:
    Tensor: One loss value per prediction, in input order.
    """
    # A fresh torch.ones(1) target is built for each prediction, as before.
    return torch.tensor([criterion(pred, torch.ones(1)) for pred in link_pred_ratio])


def cal_pos_edges_loss_multiclass(link_pred_ratio, labels):
    """
    Compute one multi-class loss value per edge.

    Each prediction is reshaped to a single-row batch and each label to a
    one-element vector before being passed to the module-level ``criterion``.

    Parameters:
    link_pred_ratio (list of Tensors): Per-edge class scores.
    labels (list of Tensors): Per-edge ground-truth class indices.

    Returns:
    Tensor: One loss value per edge, in input order.
    """
    per_edge = [
        criterion(link_pred_ratio[k].reshape(1, -1), labels[k].reshape(-1))
        for k in range(len(link_pred_ratio))
    ]
    return torch.tensor(per_edge)


def cal_pos_edges_loss_autoencoder(decoded, msg):
    """
    Compute one reconstruction loss value per message.

    Each decoded output is reshaped to a single-row batch and each original
    message is flattened before being passed to the module-level ``criterion``.

    Parameters:
    decoded (list of Tensors): Decoder outputs, one per message.
    msg (list of Tensors): Original messages that were fed to the autoencoder.

    Returns:
    Tensor: One loss value per decoded message, in input order.
    """
    per_message = [
        criterion(decoded[k].reshape(1, -1), msg[k].reshape(-1))
        for k in range(len(decoded))
    ]
    return torch.tensor(per_message)

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
%autosave 120  

Autosaving every 120 seconds


In [3]:
from datetime import datetime, timezone
import time
import pytz
from time import mktime
from datetime import datetime
import time
def ns_time_to_datetime(ns):
    """
    Render a nanosecond Unix timestamp in the machine's local time zone.

    :param ns: nanosecond timestamp (anything accepted by ``int()``)
    :return: str   format: 2013-10-10 23:40:00.000000000
    """
    whole_seconds, nano_part = divmod(int(ns), 1000000000)
    stamp = datetime.fromtimestamp(whole_seconds).strftime('%Y-%m-%d %H:%M:%S')
    # The sub-second remainder is appended as exactly nine digits.
    return stamp + '.' + str(nano_part).zfill(9)

def ns_time_to_datetime_US(ns):
    """
    Render a nanosecond Unix timestamp as US/Eastern wall-clock time.

    :param ns: nanosecond timestamp (anything accepted by ``int()``)
    :return: str   format: 2013-10-10 23:40:00.000000000
    """
    eastern = pytz.timezone('US/Eastern')
    whole_seconds, nano_part = divmod(int(ns), 1000000000)
    stamp = pytz.datetime.datetime.fromtimestamp(whole_seconds, eastern).strftime('%Y-%m-%d %H:%M:%S')
    # The sub-second remainder is appended as exactly nine digits.
    return stamp + '.' + str(nano_part).zfill(9)

def time_to_datetime_US(s):
    """
    Render a Unix timestamp in seconds as US/Eastern wall-clock time.

    :param s: timestamp in whole seconds (anything accepted by ``int()``)
    :return: str   format: 2013-10-10 23:40:00
    """
    eastern = pytz.timezone('US/Eastern')
    moment = pytz.datetime.datetime.fromtimestamp(int(s), eastern)
    return moment.strftime('%Y-%m-%d %H:%M:%S')

def datetime_to_ns_time(date):
    """
    Convert a local-time date string to a nanosecond Unix timestamp.

    The string is interpreted in the machine's local time zone.

    :param date: str   format: %Y-%m-%d %H:%M:%S   e.g. 2013-10-10 23:40:00
    :return: int nanosecond timestamp
    """
    parsed = time.strptime(date, "%Y-%m-%d %H:%M:%S")
    whole_seconds = int(time.mktime(parsed))
    return whole_seconds * 1000000000

def datetime_to_ns_time_US(date):
    """
    Convert a US/Eastern wall-clock date string to a nanosecond Unix timestamp.

    The string is parsed directly into a naive datetime and then localized to
    US/Eastern. It is deliberately not passed through ``time.mktime`` /
    ``datetime.fromtimestamp``: that round trip interprets the value in the
    machine's own time zone first, which shifts times that fall into that
    zone's DST gap.

    :param date: str   format: %Y-%m-%d %H:%M:%S   e.g. 2013-10-10 23:40:00
    :return: int nanosecond timestamp
    """
    eastern = pytz.timezone('US/Eastern')
    naive = datetime.strptime(date, "%Y-%m-%d %H:%M:%S")
    # The format has no sub-second field, so the epoch value is a whole number
    # of seconds; scale it with integer arithmetic.
    whole_seconds = int(eastern.localize(naive).timestamp())
    return whole_seconds * 1000000000

def datetime_to_timestamp_US(date):
    """
    Convert a US/Eastern wall-clock date string to a Unix timestamp in seconds.

    The string is parsed directly into a naive datetime and then localized to
    US/Eastern. It is deliberately not passed through ``time.mktime`` /
    ``datetime.fromtimestamp``: that round trip interprets the value in the
    machine's own time zone first, which shifts times that fall into that
    zone's DST gap.

    :param date: str   format: %Y-%m-%d %H:%M:%S   e.g. 2013-10-10 23:40:00
    :return: int timestamp in seconds (not nanoseconds)
    """
    eastern = pytz.timezone('US/Eastern')
    naive = datetime.strptime(date, "%Y-%m-%d %H:%M:%S")
    return int(eastern.localize(naive).timestamp())

In [4]:
import psycopg2

from psycopg2 import extras as ex

# Connection settings can be overridden through the standard libpq environment
# variables (PGDATABASE, PGHOST, PGUSER, PGPASSWORD, PGPORT), so credentials do
# not have to be edited into the notebook. The fallbacks are the values this
# notebook has always used.
# NOTE(review): the fallback password is still stored in the notebook; prefer
# setting PGPASSWORD and removing the default before sharing.
connect = psycopg2.connect(database = os.environ.get('PGDATABASE', 'tc_e5_theia_dataset_db'),
                           host = os.environ.get('PGHOST', 'localhost'),
                           user = os.environ.get('PGUSER', 'postgres'),
                           password = os.environ.get('PGPASSWORD', '123456'),
                           port = os.environ.get('PGPORT', '5432')
                          )

# Cursor shared by the cells that query the database.
cur = connect.cursor()

In [5]:
# Load the two graphs used for training and move them to the compute device.
# NOTE(review): torch.load unpickles the file contents; only load graph files
# from a trusted source.
graph_5_8=torch.load("./train_graphs/graph_5_8.TemporalData.simple").to(device=device)
graph_5_9=torch.load("./train_graphs/graph_5_9.TemporalData.simple").to(device=device)


# `train_data` is read at module level further down to size the memory module,
# the GNN and the link predictor from its message width.
train_data=graph_5_8

In [6]:
# Build the lookup table used to describe nodes in the evaluation logs.
sql="select * from node2id ORDER BY index_id;"
cur.execute(sql)
rows = cur.fetchall()

# One dict holds two mappings per row:
#   first column -> last column
#   last column  -> {second column: third column}
# NOTE(review): presumably the first column is the node hash and the last one
# the integer node index — confirm against the node2id schema.
nodeid2msg={}
for row in rows:
    first_col, last_col = row[0], row[-1]
    nodeid2msg[first_col]=last_col
    nodeid2msg[last_col]={row[1]:row[2]}

In [7]:
# Two-way lookup between event-type names and their 1-based ids:
# rel2id[id] -> name and rel2id[name] -> id. Entries are inserted id first,
# then name, for each event type in turn.
rel2id = {
    key: value
    for rel_id, rel_name in enumerate(
        (
            'EVENT_CONNECT',
            'EVENT_EXECUTE',
            'EVENT_OPEN',
            'EVENT_READ',
            'EVENT_RECVFROM',
            'EVENT_RECVMSG',
            'EVENT_SENDMSG',
            'EVENT_SENDTO',
            'EVENT_WRITE',
        ),
        start=1,
    )
    for key, value in ((rel_id, rel_name), (rel_name, rel_id))
}

In [8]:
# train_data, val_data, test_data = data.train_val_test_split(val_ratio=0.15, test_ratio=0.15)
# max_node_num = max(torch.cat([data.dst,data.src]))+1
# max_node_num = data.num_nodes+1
# Hard-coded size of the node-id space; it sizes the TGN memory, the neighbor
# loader and the `assoc` index map below.
# NOTE(review): presumably the node count of the node2id table — confirm it is
# larger than every node id appearing in the graphs.
max_node_num = 967389  # +1
# min_dst_idx, max_dst_idx = int(data.dst.min()), int(data.dst.max())
min_dst_idx, max_dst_idx = 0, max_node_num
# Keeps, for every node, its most recent neighbors (up to `size`).
neighbor_loader = LastNeighborLoader(max_node_num, size=20, device=device)

In [9]:
class GraphAttentionEmbedding(torch.nn.Module):
    """
    Two-layer TransformerConv encoder that turns node memory into embeddings.

    Edge attributes are the concatenation of a relative-time encoding and the
    raw edge message, so attention can take both the event content and its
    timing into account.

    Parameters:
    - in_channels (int): Input feature dimensionality for nodes.
    - out_channels (int): Output feature dimensionality for nodes.
    - msg_dim (int): Dimensionality of edge message features.
    - time_enc (TimeEncoder): Time encoding module; its ``out_channels`` is
      added to ``msg_dim`` to obtain the edge-attribute width.
    """
    def __init__(self, in_channels, out_channels, msg_dim, time_enc):
        super().__init__()
        self.time_enc = time_enc
        edge_dim = msg_dim + time_enc.out_channels

        # First layer: 8 attention heads, concatenated -> out_channels*8 features.
        self.conv = TransformerConv(in_channels, out_channels, heads=8,
                                    dropout=0.0, edge_dim=edge_dim)

        # Second layer: a single head without concatenation -> out_channels features.
        self.conv2 = TransformerConv(out_channels*8, out_channels,heads=1, concat=False,
                             dropout=0.0, edge_dim=edge_dim)

    def forward(self, x, last_update, edge_index, t, msg):
        """
        Forward pass of the GraphAttentionEmbedding model.

        Parameters:
        - x (torch.Tensor): Node features of shape (num_nodes, in_channels).
        - last_update (torch.Tensor): Timestamps of last updates for each node.
        - edge_index (torch.Tensor): Edge index in COO format of shape (2, num_edges).
        - t (torch.Tensor): Edge timestamps of shape (num_edges,).
        - msg (torch.Tensor): Edge message features of shape (num_edges, msg_dim).

        Returns:
        - torch.Tensor: Updated node embeddings of shape (num_nodes, out_channels).
        """
        # Tensor.to() is not in-place: the result has to be re-bound, otherwise
        # the transfer is silently dropped.
        last_update = last_update.to(device)
        x = x.to(device)
        t = t.to(device)

        # Time elapsed between each edge's source-node update and the edge
        # itself, encoded in the same dtype as the node features.
        rel_t = last_update[edge_index[0]] - t
        rel_t_enc = self.time_enc(rel_t.to(x.dtype))

        # Edge attributes = [relative-time encoding | message features].
        edge_attr = torch.cat([rel_t_enc, msg], dim=-1)

        # Both layers share the same edge attributes and are followed by ReLU.
        x = F.relu(self.conv(x, edge_index, edge_attr))
        x = F.relu(self.conv2(x, edge_index, edge_attr))
        return x


class LinkPredictor(torch.nn.Module):
    """
    A neural network for event-type prediction: it transforms the source and
    destination node embeddings and predicts the properties of the link
    between them.

    Parameters:
    - in_channels (int): The dimensionality of the input node embeddings.
    - out_channels (int, optional): Number of output scores per edge. When
      omitted, it falls back to ``train_data.msg.shape[1] - 32`` read from the
      module-level ``train_data`` (the behaviour of the original class), so
      existing ``LinkPredictor(in_channels=...)`` calls are unaffected.
    """

    def __init__(self, in_channels, out_channels=None):
        super().__init__()

        if out_channels is None:
            # train_data.msg contains the feature vectors of src and dest plus
            # the edge (event) information; train_data.msg.shape[1]-32 is the
            # size of the edge (event) part.
            out_channels = train_data.msg.shape[1]-32

        # Linear transformations for source and destination embeddings
        self.lin_src = Linear(in_channels, in_channels*2)
        self.lin_dst = Linear(in_channels, in_channels*2)

        # Sequential feedforward network for link prediction
        self.lin_seq = nn.Sequential(
            Linear(in_channels*4, in_channels*8),
            torch.nn.Dropout(0.5),
            nn.Tanh(),
            Linear(in_channels*8, in_channels*2),
            torch.nn.Dropout(0.5),
            nn.Tanh(),
            Linear(in_channels*2, int(in_channels//2)),
            torch.nn.Dropout(0.5),
            nn.Tanh(),

            # Final prediction layer: one score per event type.
            Linear(int(in_channels//2), out_channels)
        )

    def forward(self, z_src, z_dst):
        """
        Forward pass for the LinkPredictor.

        Parameters:
        - z_src (torch.Tensor): Source node embeddings, shape (batch_size, in_channels).
        - z_dst (torch.Tensor): Destination node embeddings, shape (batch_size, in_channels).

        Returns:
        - torch.Tensor: Predicted edge (event) scores, shape (batch_size, out_channels).
        """
        # Project both endpoints, concatenate them, then run the MLP head.
        h = torch.cat([self.lin_src(z_src) , self.lin_dst(z_dst)],dim=-1)
        return self.lin_seq(h)

memory_dim = 100         # width of each node's memory (state) vector
time_dim = 100           # width of the time encoding
embedding_dim = 100      # out_channels of the GNN, i.e. input width of the link predictor

# create the memory
# Raw messages have the width of train_data.msg; IdentityMessage is used as the
# message function and LastAggregator as the aggregator.
memory = TGNMemory(
    max_node_num,
    train_data.msg.size(-1),
    memory_dim,
    time_dim,
    message_module=IdentityMessage(train_data.msg.size(-1), memory_dim, time_dim),
    aggregator_module=LastAggregator(),
).to(device)

# create the graph neural network with transformer
# It reuses the memory module's time encoder, so that encoder's parameters
# belong to both `memory` and `gnn`.
gnn = GraphAttentionEmbedding(
    in_channels=memory_dim,
    out_channels=embedding_dim,
    msg_dim=train_data.msg.size(-1),
    time_enc=memory.time_enc,
).to(device)

# create the MLP link predictor
link_pred = LinkPredictor(in_channels=embedding_dim).to(device)

# set Adam optimizer for all 3 networks
# The parameter collections are merged as sets so that the shared time-encoder
# parameters are handed to the optimizer only once.
optimizer = torch.optim.Adam(
    set(memory.parameters()) | set(gnn.parameters())
    | set(link_pred.parameters()), lr=0.00005, eps=1e-08,weight_decay=0.01)


# scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.1)
# Shared loss function; also read by the cal_pos_edges_loss* helpers.
criterion = nn.CrossEntropyLoss()

# Helper vector to map global node indices to local ones.
# Allocated uninitialised; entries are only meaningful after being written for
# the nodes of the current batch.
assoc = torch.empty(max_node_num, dtype=torch.long, device=device)

saved_nodes=set()

In [10]:
BATCH=1024

def train(train_data):
    """
    Trains the Temporal Graph Network (TGN) for one pass over a graph,
    processing its events in chronological batches of ``BATCH``.

    The model is asked to predict the event type of every edge from the
    embeddings of its two endpoints. Memory and neighbor state are reset at
    the start, so every call begins from an empty graph.

    Parameters:
    - train_data: Dataset object with sequential batches, timestamps, and messages.

    Returns:
    - float: Average loss over all events in the training data.

    Raises:
    - ValueError: If a message has no entry equal to 1 in its event-type slice.
    """

    # Set networks to training mode
    memory.train()
    gnn.train()
    link_pred.train()

    memory.reset_state()  # Start with a fresh memory.
    neighbor_loader.reset_state()  # Start with an empty graph.

    # Tracks total loss over the training data
    total_loss = 0

    for batch in train_data.seq_batches(batch_size=BATCH):
        optimizer.zero_grad()

        # Extract batch data: source, destination, timestamps, and messages
        src, pos_dst, t, msg = batch.src, batch.dst, batch.t, batch.msg

        # Nodes touched by this batch plus their stored recent neighbors
        n_id = torch.cat([src, pos_dst]).unique()
        n_id, edge_index, e_id = neighbor_loader(n_id)

        # Helper vector to map global node indices to local ones.
        assoc[n_id] = torch.arange(n_id.size(0), device=device)

        # Get updated memory of all nodes involved in the computation.
        z, last_update = memory(n_id)

        # Pass embeddings through the GNN to update them
        z = gnn(z, last_update, edge_index, train_data.t[e_id], train_data.msg[e_id])

        # Event-type scores for the edges of this batch
        y_pred = link_pred(z[assoc[src]], z[assoc[pos_dst]])

        # Ground-truth labels. msg is [src features | event type | dst features]
        # with 16 columns on each side; the label is the position of the first
        # entry equal to 1 in the middle slice. This is computed for the whole
        # batch at once rather than with one device-to-host copy per event.
        event_onehot = msg[:, 16:-16] == 1
        if not bool(event_onehot.any(dim=1).all()):
            raise ValueError("message without an event-type entry equal to 1")
        # argmax returns the first maximal index, i.e. the first 1 in each row.
        y_true = event_onehot.float().argmax(dim=1).to(device=device)

        # Compute loss using predictions and ground-truth labels
        loss = criterion(y_pred, y_true)

        # Update memory and neighbor loader with ground-truth state.
        memory.update_state(src, pos_dst, t, msg)
        neighbor_loader.insert(src, pos_dst)

        # Backpropagation
        loss.backward()
        optimizer.step()

        # Detach memory so the autograd graph does not grow across batches
        memory.detach()

        # Accumulate total loss scaled by the number of events in the batch
        total_loss += float(loss) * batch.num_events

    # Return the average loss over all events
    return total_loss / train_data.num_events

In [11]:
# train on benign graphs
train_graphs=[graph_5_8,graph_5_9]

# train for 30 epochs; every epoch makes one pass over each training graph
for epoch in tqdm(range(1, 31)):
    for g in train_graphs:
        loss = train(g)
        print(f'  Epoch: {epoch:02d}, Loss: {loss:.4f}')

# store the models in file
# The neighbor loader is saved together with the three networks.
model=[memory,gnn, link_pred,neighbor_loader]
# os.makedirs replaces the shell call `mkdir -p`, so no shell is needed.
os.makedirs("./models/", exist_ok=True)
torch.save(model,"./models/model_saved_share.pt")

  0%|          | 0/30 [00:00<?, ?it/s]

  Epoch: 01, Loss: 0.4587


  3%|▎         | 1/30 [1:04:52<31:21:11, 3892.12s/it]

  Epoch: 01, Loss: 0.2003
  Epoch: 02, Loss: 0.2427


  7%|▋         | 2/30 [2:15:27<31:50:40, 4094.31s/it]

  Epoch: 02, Loss: 0.1733
  Epoch: 03, Loss: 0.2272


 10%|█         | 3/30 [3:37:47<33:36:06, 4480.24s/it]

  Epoch: 03, Loss: 0.1668
  Epoch: 04, Loss: 0.2313


 13%|█▎        | 4/30 [5:26:57<38:15:32, 5297.41s/it]

  Epoch: 04, Loss: 0.1602
  Epoch: 05, Loss: 0.2200


 17%|█▋        | 5/30 [6:53:44<36:33:40, 5264.83s/it]

  Epoch: 05, Loss: 0.1588
  Epoch: 06, Loss: 0.2219


 20%|██        | 6/30 [8:36:29<37:08:18, 5570.78s/it]

  Epoch: 06, Loss: 0.1575
  Epoch: 07, Loss: 0.2210


 23%|██▎       | 7/30 [10:11:47<35:53:56, 5618.97s/it]

  Epoch: 07, Loss: 0.1571
  Epoch: 08, Loss: 0.2159


 27%|██▋       | 8/30 [12:10:20<37:14:44, 6094.74s/it]

  Epoch: 08, Loss: 0.1558
  Epoch: 09, Loss: 0.2160


 30%|███       | 9/30 [13:54:48<35:52:05, 6148.82s/it]

  Epoch: 09, Loss: 0.1544
  Epoch: 10, Loss: 0.2238


 33%|███▎      | 10/30 [15:37:35<34:11:29, 6154.45s/it]

  Epoch: 10, Loss: 0.1536
  Epoch: 11, Loss: 0.2171


 37%|███▋      | 11/30 [17:26:27<33:05:26, 6269.84s/it]

  Epoch: 11, Loss: 0.1534
  Epoch: 12, Loss: 0.2129


 40%|████      | 12/30 [19:02:52<30:36:43, 6122.42s/it]

  Epoch: 12, Loss: 0.1521
  Epoch: 13, Loss: 0.2180


 43%|████▎     | 13/30 [21:00:07<30:13:01, 6398.91s/it]

  Epoch: 13, Loss: 0.1522
  Epoch: 14, Loss: 0.2136


 47%|████▋     | 14/30 [22:41:10<27:59:17, 6297.34s/it]

  Epoch: 14, Loss: 0.1515
  Epoch: 15, Loss: 0.2103


 50%|█████     | 15/30 [24:30:50<26:35:39, 6382.67s/it]

  Epoch: 15, Loss: 0.1520
  Epoch: 16, Loss: 0.2139


 53%|█████▎    | 16/30 [26:26:30<25:28:24, 6550.34s/it]

  Epoch: 16, Loss: 0.1512
  Epoch: 17, Loss: 0.2129


 57%|█████▋    | 17/30 [28:27:10<24:24:11, 6757.83s/it]

  Epoch: 17, Loss: 0.1516
  Epoch: 18, Loss: 0.2127


 60%|██████    | 18/30 [30:12:36<22:05:37, 6628.14s/it]

  Epoch: 18, Loss: 0.1508
  Epoch: 19, Loss: 0.2108


 63%|██████▎   | 19/30 [31:53:25<19:43:13, 6453.97s/it]

  Epoch: 19, Loss: 0.1506
  Epoch: 20, Loss: 0.2080


 67%|██████▋   | 20/30 [33:31:19<17:26:40, 6280.01s/it]

  Epoch: 20, Loss: 0.1505
  Epoch: 21, Loss: 0.2098


 70%|███████   | 21/30 [35:25:10<16:06:46, 6445.19s/it]

  Epoch: 21, Loss: 0.1503
  Epoch: 22, Loss: 0.2095


 73%|███████▎  | 22/30 [37:15:00<14:25:09, 6488.74s/it]

  Epoch: 22, Loss: 0.1504
  Epoch: 23, Loss: 0.2095


 77%|███████▋  | 23/30 [39:02:24<12:35:27, 6475.29s/it]

  Epoch: 23, Loss: 0.1499
  Epoch: 24, Loss: 0.2106


 80%|████████  | 24/30 [40:48:22<10:44:01, 6440.23s/it]

  Epoch: 24, Loss: 0.1498
  Epoch: 25, Loss: 0.2169


 83%|████████▎ | 25/30 [42:49:15<9:16:59, 6683.88s/it] 

  Epoch: 25, Loss: 0.1502


 83%|████████▎ | 25/30 [43:00:05<8:36:01, 6192.22s/it]


KeyboardInterrupt: 

In [12]:
# Save the current networks and neighbor loader. This cell can be run on its
# own, e.g. after the training loop has been interrupted.
model=[memory,gnn, link_pred,neighbor_loader]
# os.makedirs replaces the shell call `mkdir -p`, so no shell is needed.
os.makedirs("./models/", exist_ok=True)
torch.save(model,"./models/model_saved_share.pt")

# Test

In [13]:
import time 

# Disable gradient computation for inference
@torch.no_grad()
def test_day_new(inference_data, path):
    """
    Evaluate the Temporal Graph Network (TGN) on inference data.
    Stores result for each time window in txt file.
    Each text file contains event information. 
    Example: {'loss': 2.8073678016662598, 'srcnode': 76938, 'dstnode': 868, 'srcmsg': "{'subject': 'system_server'}", 'dstmsg': "{'file': '/data/system/sync/pending.xml'}", 'edge_type': 'EVENT_OPEN', 'time': 1522814400030000000}

    Events are processed in chronological batches of ``BATCH``. A window is
    closed at the first batch whose last event is more than 15 minutes after
    the window start; its edges are then written to
    ``<path>/<start>~<end>.txt``, sorted by decreasing loss. Events after the
    last closed window are scored but not written to any file.

    Parameters:
    - inference_data: Dataset containing sequential batches for temporal graph inference.
    - path (str): Directory for the per-window log files; created if missing.

    Returns:
    - dict: Maps each window label to a dict with the window's average edge
      loss ('loss'), the number of distinct nodes seen so far ('nodes_count'),
      the number of events processed so far ('total_edges') and the elapsed
      seconds since the start of the run ('costed_time').

    Raises:
    - ValueError: If a message has no entry equal to 1 in its event-type slice.
    """
    # Window length: 15 minutes, in nanoseconds (timestamps are nanoseconds).
    window_ns = 60000000000*15

    os.makedirs(path, exist_ok=True)

    # Set networks to evaluation mode
    memory.eval()
    gnn.eval()
    link_pred.eval()

    memory.reset_state()  # Start with a fresh memory.  
    neighbor_loader.reset_state()  # Start with an empty graph.

    # Tracking state
    time_with_loss={}       # window label -> summary dict (the return value)
    edge_list=[]            # per-edge records of the window currently open
    unique_nodes=torch.tensor([]).to(device=device)
    total_edges=0           # events processed since the start of the run
    event_count=0           # events in the window currently open
    start_time=inference_data.t[0]

    print("after merge:",inference_data)

    # Record the running time to evaluate the performance
    start = time.perf_counter()

    for batch in inference_data.seq_batches(batch_size=BATCH):

        src, pos_dst, t, msg = batch.src, batch.dst, batch.t, batch.msg
        unique_nodes=torch.cat([unique_nodes,src,pos_dst]).unique()
        # Count the events actually in the batch; the last batch may be short.
        total_edges+=batch.num_events

        # Retrieve embeddings and neighbors
        n_id = torch.cat([src, pos_dst]).unique()       
        n_id, edge_index, e_id = neighbor_loader(n_id)
        assoc[n_id] = torch.arange(n_id.size(0), device=device)

        # Get memory and update embeddings via GNN
        z, last_update = memory(n_id)
        z = gnn(z, last_update, edge_index, inference_data.t[e_id], inference_data.msg[e_id])

        # Predict edge properties
        pos_out = link_pred(z[assoc[src]], z[assoc[pos_dst]])

        # Ground-truth labels. msg is [src features | event type | dst features]
        # with 16 columns on each side; the label is the position of the first
        # entry equal to 1 in the middle slice, computed for the whole batch.
        event_onehot = msg[:, 16:-16] == 1
        if not bool(event_onehot.any(dim=1).all()):
            raise ValueError("message without an event-type entry equal to 1")
        # argmax returns the first maximal index, i.e. the first 1 in each row.
        y_true = event_onehot.float().argmax(dim=1).to(device=device)

        # update the edges in the batch to the memory and neighbor_loader
        memory.update_state(src, pos_dst, t, msg)
        neighbor_loader.insert(src, pos_dst)

        # Per-edge loss: benign behaviour resembles the training data and gets
        # a low loss, while unseen (possibly anomalous) behaviour is predicted
        # poorly and gets a high loss.
        each_edge_loss= cal_pos_edges_loss_multiclass(pos_out,y_true)

        # Move the batch to plain Python values once instead of once per edge.
        src_ids = src.tolist()
        dst_ids = pos_dst.tolist()
        event_times = t.tolist()
        event_labels = y_true.tolist()
        edge_losses = each_edge_loss.tolist()

        # Process edges for logging
        for i in range(len(pos_out)):
            srcnode=src_ids[i]
            dstnode=dst_ids[i]

            temp_dic={}
            temp_dic['loss']=edge_losses[i]
            temp_dic['srcnode']=srcnode
            temp_dic['dstnode']=dstnode
            temp_dic['srcmsg']=str(nodeid2msg[srcnode])
            temp_dic['dstmsg']=str(nodeid2msg[dstnode])
            # rel2id is keyed by 1-based event ids; labels are 0-based.
            temp_dic['edge_type']=rel2id[event_labels[i]+1]
            temp_dic['time']=event_times[i]

            edge_list.append(temp_dic)

        # Check if the time interval is over (15 minutes - Window Size)
        event_count+=len(batch.src)
        if t[-1]>start_time+window_ns:
            # Here is a checkpoint, which records all edge losses in the current time window
            time_interval=ns_time_to_datetime_US(start_time)+"~"+ns_time_to_datetime_US(t[-1])

            end = time.perf_counter()

            # Average loss of the window. It is computed before being stored
            # and starts from zero, so no stale per-edge value is included.
            avg_loss = sum(e['loss'] for e in edge_list)/event_count

            time_with_loss[time_interval]={'loss':avg_loss,
                                          'nodes_count':len(unique_nodes),
                                          'total_edges':total_edges,
                                          'costed_time':(end-start)}

            print(f'Time: {time_interval}, Loss: {avg_loss:.4f}, Nodes_count: {len(unique_nodes)}, Cost Time: {(end-start):.2f}s')

            # Rank the results based on edge losses and save them to the log file
            edge_list.sort(key=lambda x:x['loss'],reverse=True)
            with open(path+"/"+time_interval+".txt",'w') as log:
                for e in edge_list: 
                    log.write(str(e))
                    log.write("\n") 

            # Reset tracking variables for the next interval
            event_count=0
            start_time=t[-1]
            edge_list.clear()

    return time_with_loss

In [14]:
# Load the remaining graphs that are scored below and move them to the compute
# device.
# NOTE(review): torch.load unpickles the file contents; only load graph files
# from a trusted source.
graph_5_11=torch.load("./train_graphs/graph_5_11.TemporalData.simple").to(device=device)
graph_5_14=torch.load("./train_graphs/graph_5_14.TemporalData.simple").to(device=device)
graph_5_15=torch.load("./train_graphs/graph_5_15.TemporalData.simple").to(device=device)

In [None]:
# Restore the saved checkpoint and rebind the module-level networks and
# neighbor loader that test_day_new reads.
# NOTE(review): torch.load unpickles whole Python objects; only load
# checkpoints from a trusted source.
model=torch.load("./models/model_saved_share.pt")
memory,gnn, link_pred,neighbor_loader=model

In [17]:
# Score the 2019-05-08 graph: per-window edge-loss logs are written under
# ./graph_5_8/ and the per-window summary is kept in ans_5_8.
ans_5_8=test_day_new(graph_5_8,"graph_5_8")

after merge: TemporalData(dst=[9622633], msg=[9622633, 41], src=[9622633], t=[9622633])
Time: 2019-05-08 00:00:01.330026026~2019-05-08 00:15:29.528868509, Loss: 0.2662, Nodes_count: 2356, Cost Time: 10.62s
Time: 2019-05-08 00:15:29.528868509~2019-05-08 00:34:01.494985889, Loss: 0.1486, Nodes_count: 82297, Cost Time: 40.50s
Time: 2019-05-08 00:34:01.494985889~2019-05-08 00:49:53.634712131, Loss: 0.1085, Nodes_count: 96517, Cost Time: 82.49s
Time: 2019-05-08 00:49:53.634712131~2019-05-08 01:05:30.922847747, Loss: 0.0524, Nodes_count: 96653, Cost Time: 124.98s
Time: 2019-05-08 01:05:30.922847747~2019-05-08 01:20:31.497361112, Loss: 0.2423, Nodes_count: 96737, Cost Time: 134.82s
Time: 2019-05-08 01:20:31.497361112~2019-05-08 01:35:41.695160354, Loss: 0.1922, Nodes_count: 96837, Cost Time: 146.28s
Time: 2019-05-08 01:35:41.695160354~2019-05-08 01:53:50.594670267, Loss: 0.0796, Nodes_count: 97677, Cost Time: 220.94s
Time: 2019-05-08 01:53:50.594670267~2019-05-08 02:09:43.350837131, Loss: 0.2

In [18]:

# Score the 2019-05-09 graph: per-window edge-loss logs are written under
# ./graph_5_9/ and the per-window summary is kept in ans_5_9.
ans_5_9=test_day_new(graph_5_9,"graph_5_9")

after merge: TemporalData(dst=[6898178], msg=[6898178, 41], src=[6898178], t=[6898178])
Time: 2019-05-09 00:00:00.207323955~2019-05-09 00:15:01.493243668, Loss: 0.3950, Nodes_count: 73, Cost Time: 0.27s
Time: 2019-05-09 00:15:01.493243668~2019-05-09 00:32:01.509088277, Loss: 0.1783, Nodes_count: 124, Cost Time: 0.70s
Time: 2019-05-09 00:32:01.509088277~2019-05-09 00:50:31.512334564, Loss: 0.2014, Nodes_count: 138, Cost Time: 1.17s
Time: 2019-05-09 00:50:31.512334564~2019-05-09 01:09:31.519558739, Loss: 0.2061, Nodes_count: 152, Cost Time: 1.63s
Time: 2019-05-09 01:09:31.519558739~2019-05-09 01:28:01.511221355, Loss: 0.1224, Nodes_count: 185, Cost Time: 2.09s
Time: 2019-05-09 01:28:01.511221355~2019-05-09 01:43:31.496122592, Loss: 0.1725, Nodes_count: 193, Cost Time: 2.46s
Time: 2019-05-09 01:43:31.496122592~2019-05-09 02:02:31.501625321, Loss: 0.1325, Nodes_count: 204, Cost Time: 2.93s
Time: 2019-05-09 02:02:31.501625321~2019-05-09 02:21:01.508194986, Loss: 0.1382, Nodes_count: 210, Co

In [19]:
# Score the 2019-05-11 graph: per-window edge-loss logs are written under
# ./graph_5_11/ and the per-window summary is kept in ans_5_11.
ans_5_11=test_day_new(graph_5_11,"graph_5_11")

after merge: TemporalData(dst=[6488182], msg=[6488182, 41], src=[6488182], t=[6488182])
Time: 2019-05-11 00:00:00.500131269~2019-05-11 00:15:10.585413361, Loss: 0.4813, Nodes_count: 1858, Cost Time: 7.68s
Time: 2019-05-11 00:15:10.585413361~2019-05-11 00:31:01.430200716, Loss: 0.4722, Nodes_count: 2315, Cost Time: 19.29s
Time: 2019-05-11 00:31:01.430200716~2019-05-11 00:46:29.482831031, Loss: 0.3280, Nodes_count: 47459, Cost Time: 44.37s
Time: 2019-05-11 00:46:29.482831031~2019-05-11 01:01:37.064383098, Loss: 0.4018, Nodes_count: 47932, Cost Time: 58.17s
Time: 2019-05-11 01:01:37.064383098~2019-05-11 01:16:39.942709627, Loss: 0.3520, Nodes_count: 53616, Cost Time: 77.04s
Time: 2019-05-11 01:16:39.942709627~2019-05-11 01:32:01.462298208, Loss: 0.3217, Nodes_count: 54094, Cost Time: 103.11s
Time: 2019-05-11 01:32:01.462298208~2019-05-11 01:47:38.109830109, Loss: 0.2666, Nodes_count: 59005, Cost Time: 126.45s
Time: 2019-05-11 01:47:38.109830109~2019-05-11 02:03:11.829718951, Loss: 0.6230,

In [20]:
# Score the 2019-05-14 graph: per-window edge-loss logs are written under
# ./graph_5_14/ and the per-window summary is kept in ans_5_14.
ans_5_14=test_day_new(graph_5_14,"graph_5_14")

after merge: TemporalData(dst=[13591537], msg=[13591537, 41], src=[13591537], t=[13591537])
Time: 2019-05-14 00:00:00.216652068~2019-05-14 00:15:02.152576344, Loss: 0.2322, Nodes_count: 115823, Cost Time: 26.00s
Time: 2019-05-14 00:15:02.152576344~2019-05-14 00:30:52.456495816, Loss: 0.3641, Nodes_count: 116237, Cost Time: 45.00s
Time: 2019-05-14 00:30:52.456495816~2019-05-14 00:46:31.488675323, Loss: 0.2998, Nodes_count: 116665, Cost Time: 68.78s
Time: 2019-05-14 00:46:31.488675323~2019-05-14 01:02:22.766494572, Loss: 0.1727, Nodes_count: 117163, Cost Time: 99.47s
Time: 2019-05-14 01:02:22.766494572~2019-05-14 01:18:47.800927242, Loss: 0.1523, Nodes_count: 117448, Cost Time: 134.35s
Time: 2019-05-14 01:18:47.800927242~2019-05-14 01:33:57.109874818, Loss: 0.3200, Nodes_count: 117893, Cost Time: 160.91s
Time: 2019-05-14 01:33:57.109874818~2019-05-14 01:49:42.711286391, Loss: 0.1248, Nodes_count: 138660, Cost Time: 235.21s
Time: 2019-05-14 01:49:42.711286391~2019-05-14 02:04:42.847148244

In [21]:
# Score the 2019-05-15 graph: per-window edge-loss logs are written under
# ./graph_5_15/ and the per-window summary is kept in ans_5_15.
ans_5_15=test_day_new(graph_5_15,"graph_5_15")

after merge: TemporalData(dst=[12310324], msg=[12310324, 41], src=[12310324], t=[12310324])
Time: 2019-05-15 00:00:01.490408727~2019-05-15 00:16:14.833595653, Loss: 0.1279, Nodes_count: 1544, Cost Time: 15.97s
Time: 2019-05-15 00:16:14.833595653~2019-05-15 00:32:01.492056162, Loss: 0.1570, Nodes_count: 2670, Cost Time: 34.95s
Time: 2019-05-15 00:32:01.492056162~2019-05-15 00:47:15.554515213, Loss: 0.2672, Nodes_count: 2893, Cost Time: 42.58s
Time: 2019-05-15 00:47:15.554515213~2019-05-15 01:04:31.491761640, Loss: 0.2826, Nodes_count: 3099, Cost Time: 48.01s
Time: 2019-05-15 01:04:31.491761640~2019-05-15 01:20:01.492131631, Loss: 0.1450, Nodes_count: 3276, Cost Time: 57.79s
Time: 2019-05-15 01:20:01.492131631~2019-05-15 01:36:01.495126517, Loss: 0.1847, Nodes_count: 3536, Cost Time: 72.35s
Time: 2019-05-15 01:36:01.495126517~2019-05-15 01:51:02.896532154, Loss: 0.2568, Nodes_count: 3830, Cost Time: 86.79s
Time: 2019-05-15 01:51:02.896532154~2019-05-15 02:07:19.144870551, Loss: 0.1857, N

# Initialize the node IDF

Calculates the Inverse Document Frequency (IDF) of each node, treating every time window as one document.

IDF_of_node_x = log(number_of_time_windows / (1 + number_of_time_windows_that_contain_node_x))

In [22]:
# literal_eval parses the logged dict literals without executing them as code.
import ast

# list of all time windows. Times windows are generated and stored as files using test_day_new()
file_path="graph_5_9/"
file_l=os.listdir(file_path)
file_list=[file_path+name for name in file_l]

# Maps each node description to the set of window files it appears in
node_set = {}
for f_path in tqdm(file_list):
    # `with` closes every window file as soon as it has been read.
    with open(f_path) as f:
        for line in f:
            # Every line is the str() of one edge record written by test_day_new.
            jdata=ast.literal_eval(line.strip())
            if jdata['loss']>0:
                # Both endpoints count as occurrences; netflow nodes are skipped.
                for endpoint in ('srcmsg', 'dstmsg'):
                    node_label=str(jdata[endpoint])
                    if 'netflow' not in node_label:
                        node_set.setdefault(node_label, set()).add(f_path)

# IDF = log(number of windows / (1 + number of windows containing the node))
node_IDF={}
for n in node_set:
    include_count = len(node_set[n])   
    IDF=math.log(len(file_list)/(include_count+1))
    node_IDF[n] = IDF    


torch.save(node_IDF,"node_IDF")
print("IDF weight calculate complete!")


100%|██████████| 85/85 [04:02<00:00,  2.86s/it]


IDF weight calculate complete!


In [23]:
def cal_train_IDF(find_str,file_list):
    """
    Calculate the IDF value for a term in a set of time windows.

    A time window "contains" the term when `find_str` occurs anywhere in the
    text of the corresponding file.

    Parameters:
        find_str (str): The term to calculate the IDF for.
        file_list (list): List of file paths representing the time windows.

    Returns:
        float: log(len(file_list) / (1 + number of files containing find_str)).

    Raises:
        ValueError: If `file_list` is empty (math.log(0) is undefined).
    """
    include_count=0
    for f_path in file_list:
        # Context manager so the handle is released before the next file is opened.
        with open(f_path) as f:
            if find_str in f.read():
                include_count+=1
    return math.log(len(file_list)/(include_count+1))


def cal_IDF(find_str,file_path,file_list):
    """
    Calculate the IDF value of a term over every file in a directory.

    Parameters:
        find_str (str): The term to look for in each file's text.
        file_path (str): Directory holding the time-window files.
        file_list: Ignored; it is replaced by the listing of `file_path`.
                   Kept only so existing callers keep working.

    Returns:
        tuple: (IDF, 1) where IDF is
               log(number of files / (1 + number of files containing find_str)).

    Raises:
        ValueError: If the directory is empty (math.log(0) is undefined).
    """
    file_list=os.listdir(file_path)
    include_count=0
    for f_name in file_list:
        # os.path.join works whether or not file_path ends with a separator.
        with open(os.path.join(file_path,f_name)) as f:
            if find_str in f.read():
                include_count+=1

    IDF=math.log(len(file_list)/(include_count+1))

    return IDF,1

def cal_redundant(find_str,edge_list):
    """
    Count the distinct endpoints of all edges that mention a term, minus two.

    An edge mentions the term when `find_str` occurs in `str(edge)`.

    Parameters:
        find_str (str): The term to look for.
        edge_list (list): Edges, each indexable as edge[0] (source) and edge[1] (destination).

    Returns:
        int: Number of distinct endpoints over the matching edges, minus 2.
    """
    endpoints = {
        node
        for edge in edge_list
        if find_str in str(edge)
        for node in (edge[0], edge[1])
    }
    return len(endpoints) - 2

def cal_anomaly_loss(loss_list,edge_list,file_path):
    """
    Calculate anomaly loss by analyzing loss values exceeding a threshold.

    An edge is anomalous when its loss is greater than
    mean(loss_list) + 1.5 * std(loss_list). The threshold is printed.

    Parameters:
        loss_list (list): List of loss values for each edge.
        edge_list (list): List of edges corresponding to the loss values;
                          edge[0] is the source and edge[1] the destination.
        file_path (str): Path to data files (currently unused).

    Returns:
        tuple: A tuple containing:
            - count (int): Number of anomalous edges.
            - avg_loss (float): Average loss of the anomalous edges
              (0.0 when there are none).
            - node_set (set): Unique nodes that appear on anomalous edges.
            - edge_set (set): Unique anomalous edges, each stored as the
              concatenation source + destination.
        When the two lists differ in length, "error!" is printed and the
        empty result (0, 0.0, set(), set()) is returned so callers that
        unpack four values keep working.
    """
    if len(loss_list)!=len(edge_list):
        print("error!")
        return 0, 0.0, set(), set()

    count = 0  # Count of anomalies
    loss_sum = 0  # Sum of anomalous loss values
    loss_std = std(loss_list)  # Standard deviation of loss values
    loss_mean = mean(loss_list)  # Mean of loss values
    edge_set = set()  # Unique edges with anomalies
    node_set = set()  # Unique nodes involved in anomalies

    thr = loss_mean + 1.5 * loss_std  # Threshold for anomaly detection

    print("thr:",thr)

    for loss, edge in zip(loss_list, edge_list):
        # Only edges whose loss exceeds the threshold are anomalous.
        if loss>thr:
            count+=1
            loss_sum+=loss

            src_node=edge[0]
            dst_node=edge[1]
            node_set.add(src_node)
            node_set.add(dst_node)
            edge_set.add(src_node+dst_node)

    # No loss above the threshold (e.g. all losses equal): avoid dividing by zero.
    avg_loss = loss_sum/count if count else 0.0
    return count, avg_loss, node_set, edge_set

# Construct the relations between time windows

In [24]:
def is_include_key_word(s):
    """
    Tell whether a string contains one of the excluded keywords.

    Parameters:
        s (str): The string to check.

    Returns:
        bool: True when 'netflow', '/dev/pts' or 'proc' occurs in `s`, False otherwise.
    """
    keywords = ('netflow', '/dev/pts', 'proc')
    return any(keyword in s for keyword in keywords)


def cal_set_rel(s1,s2,file_list,idf_threshold=4.5):
    """
    Count the rare nodes shared by two node sets.

    A shared node is counted when it contains none of the keywords rejected
    by is_include_key_word() and its IDF is greater than `idf_threshold`.
    Every counted node is printed together with its IDF.

    The IDF is read from the module-level `node_IDF` dictionary (it is NOT a
    parameter); a node missing from it gets log(len(file_list)).

    Parameters:
        s1 (set): The first set of nodes.
        s2 (set): The second set of nodes.
        file_list (list): Time-window files; only its length is used, for the
                          IDF of nodes missing from `node_IDF`.
        idf_threshold (float): Minimum IDF (exclusive) for a node to count.
                               Defaults to 4.5.

    Returns:
        int: The number of nodes in the intersection of s1 and s2 that pass the IDF threshold.
    """
    count=0
    # Only nodes present in both time windows can relate them.
    for i in s1 & s2:
        if is_include_key_word(i):
            continue

        if i in node_IDF:
            IDF=node_IDF[i]
        else:
            # Unseen node: it occurs in no known window.
            IDF=math.log(len(file_list)/(1))

        if IDF>idf_threshold:
            print("node:",i," IDF:",IDF)
            count+=1
    return count

# label generation

In [25]:
# Ground-truth and predicted label of every time window, keyed by "<day directory>/<file name>".
# All windows start out as benign (0).
labels={}
pred_label={}

for day_dir in ("graph_5_14","graph_5_15"):
    filelist = os.listdir(day_dir)
    for f in filelist:
        window_path=day_dir+"/"+f
        labels[window_path]=0
        pred_label[window_path]=0

In [26]:
# The two attack time windows; their ground-truth label is set to 1.
attack_list=[
    'graph_5_15/2019-05-15 13:58:15.520482252~2019-05-15 14:13:37.257086895.txt',
    'graph_5_15/2019-05-15 14:44:51.773840192~2019-05-15 15:00:26.765466538.txt',
]

labels.update(dict.fromkeys(attack_list,1))

In [27]:
# Class balance of the ground truth: every label is 0 (benign) or 1 (attack),
# so the sum of the values is the number of attack windows.
print(f"Benign count: {len(labels.values()) - sum(labels.values())}")
print(f"Attack count: {sum(labels.values())}")

Benign count: 174
Attack count: 2


# Anomaly Detection
Steps:

1. Create a list of time window queues
2. Process all time window serially
3. For a time window calculate relevance with each of the time window in the time window queues
4. If the relevance passes a certain threshold then push it in the corresponding time window queue and move on to the next queue to check
5. If no relevance is found then create a new queue with the time window and store it

The cells for 5-11, 5-14 and 5-15 follow the same steps; detailed comments have been added to the 5-11 cell only.

## 5-11

In [28]:
# Variable names don't change the results
# NOTE: this cell processes graph_5_11 but keeps its state in the *_5_14 variables.

# node_IDF=torch.load("node_IDF_5_9")
y_data_5_14=[]
df_list_5_14=[]
# node_set_list=[]

# list of time window queues
# It is a list (collection of queues) of list (queue) of dictionary (time window)
history_list_5_14=[]

tw_que=[]
his_tw={}

# Stores data for the time window currently being processed
current_tw={}
loss_list_5_14=[]


file_path_list=[]
file_path="graph_5_11/"
file_l=os.listdir("graph_5_11/")
for i in file_l:
    file_path_list.append(file_path+i)


index_count=0
for f_path in sorted(file_path_list):
    # Context manager so the window file is closed once its edges have been read.
    with open(f_path) as f:

        # List to store loss values for edges
        edge_loss_list=[]

        # List to store edges (source-destination pairs)
        edge_list=[]

        print('index_count:',index_count)

        for line in f:
            l=line.strip()
            # NOTE(review): eval() executes whatever the file contains; only run this
            # on window files produced by this notebook.
            jdata=eval(l)
            edge_loss_list.append(jdata['loss'])
            edge_list.append([str(jdata['srcmsg']),str(jdata['dstmsg'])])
    df_list_5_14.append(pd.DataFrame(edge_loss_list))

    # Calculate anomaly loss metrics for the current time window
    # (the third argument is not used by cal_anomaly_loss)
    count,loss_avg,node_set,edge_set=cal_anomaly_loss(edge_loss_list,edge_list,file_path)
    current_tw['name']=f_path
    current_tw['loss']=loss_avg
    current_tw['index']=index_count
    current_tw['nodeset']=node_set

    # To check if the current time window is related to any historical time window
    added_que_flag=False

    # For each queue
    for hq in history_list_5_14:
        # For each time window in a queue
        for his_tw in hq:
            # Calculate relevance between two time windows; if related, push the current time window
            # into this queue and move on to check the next queue
            if cal_set_rel(current_tw['nodeset'],his_tw['nodeset'],file_list)!=0 and current_tw['name']!=his_tw['name']:
                print("history queue:",his_tw['name'])
                hq.append(copy.deepcopy(current_tw))
                added_que_flag=True
                break
            # NOTE(review): this check sits inside the inner loop, so once the window has joined
            # some queue, later queues are abandoned at their first unrelated window.
            # Left as is to keep the detection results unchanged - confirm the intent.
            if added_que_flag:
                break

    # If no time window in any queue is related, create a new queue with the current time window
    if added_que_flag is False:
        temp_hq=[copy.deepcopy(current_tw)]
        history_list_5_14.append(temp_hq)
    index_count+=1
    loss_list_5_14.append(loss_avg)
    print( f_path,"  ",loss_avg," count:",count," percentage:",count/len(edge_list)," node count:",len(node_set)," edge count:",len(edge_set))

index_count: 0
thr: 1.9822804353454408
graph_5_11/2019-05-11 00:00:00.500131269~2019-05-11 00:15:10.585413361.txt    3.413261017214172  count: 5223  percentage: 0.0864506091101695  node count: 143  edge count: 167
index_count: 1
thr: 2.0174402361424812
graph_5_11/2019-05-11 00:15:10.585413361~2019-05-11 00:31:01.430200716.txt    3.4539998442285835  count: 5515  percentage: 0.08686680947580645  node count: 87  edge count: 101
index_count: 2
thr: 1.429723559543546
graph_5_11/2019-05-11 00:31:01.430200716~2019-05-11 00:46:29.482831031.txt    3.260913855585701  count: 6468  percentage: 0.04448173415492958  node count: 129  edge count: 152
index_count: 3
thr: 1.8731082378571418
graph_5_11/2019-05-11 00:46:29.482831031~2019-05-11 01:01:37.064383098.txt    3.344380599554258  count: 4948  percentage: 0.08477247807017543  node count: 86  edge count: 98
index_count: 4
thr: 1.5793834584651942
graph_5_11/2019-05-11 01:01:37.064383098~2019-05-11 01:16:39.942709627.txt    3.28695102075878  count: 53

In [29]:
# Names of the time windows in the most recently reported anomalous queue
name_list=[]

for hl in history_list_5_14:
    # Queue score: product of (loss + 1) over the time windows of the queue.
    loss_count=0
    for hq in hl:
        factor=hq['loss']+1
        # A running score of 0 is (re)started from the current factor.
        loss_count=factor if loss_count==0 else loss_count*factor

    # A queue whose score is greater than 5 is reported as anomalous
    if loss_count>5:
        name_list=[tw['name'] for tw in hl]
        print(name_list)
        print(loss_count)

['graph_5_11/2019-05-11 01:47:38.109830109~2019-05-11 02:03:11.829718951.txt']
8.260935078199513
['graph_5_11/2019-05-11 05:45:26.427321308~2019-05-11 06:00:32.541397033.txt']
5.703771928552858
['graph_5_11/2019-05-11 07:50:01.504859344~2019-05-11 08:05:27.045713282.txt']
5.540749057518473
['graph_5_11/2019-05-11 09:39:28.743031655~2019-05-11 09:55:12.543113845.txt']
5.01482287867678
['graph_5_11/2019-05-11 10:10:40.816335674~2019-05-11 10:26:38.131479341.txt']
6.653324832366889
['graph_5_11/2019-05-11 14:19:36.223542340~2019-05-11 14:34:58.512935630.txt']
5.200361413868848
['graph_5_11/2019-05-11 15:05:01.505056736~2019-05-11 15:20:31.495771178.txt']
5.163600187078302
['graph_5_11/2019-05-11 17:11:03.959014293~2019-05-11 17:27:33.517836395.txt']
6.687543395062588
['graph_5_11/2019-05-11 17:43:31.490443808~2019-05-11 17:59:31.493502219.txt']
5.099529130379809
['graph_5_11/2019-05-11 19:03:42.787479327~2019-05-11 19:19:45.328437923.txt']
5.299113724009555
['graph_5_11/2019-05-11 20:24:0

## 5-14

In [30]:
# 5-14

# node_IDF=torch.load("node_IDF_5_9")
y_data_5_14=[]
df_list_5_14=[]
# node_set_list=[]

# List of time window queues; each queue is a list of time-window dictionaries
history_list_5_14=[]
tw_que=[]
his_tw={}

# Data of the time window currently being processed
current_tw={}
loss_list_5_14=[]


file_path_list=[]
file_path="graph_5_14/"
file_l=os.listdir("graph_5_14/")
for i in file_l:
    file_path_list.append(file_path+i)


index_count=0
for f_path in sorted(file_path_list):
    # Context manager so the window file is closed once its edges have been read.
    with open(f_path) as f:
        edge_loss_list=[]
        edge_list=[]
        print('index_count:',index_count)

        for line in f:
            l=line.strip()
            # NOTE(review): eval() executes whatever the file contains; only run this
            # on window files produced by this notebook.
            jdata=eval(l)
            edge_loss_list.append(jdata['loss'])
            edge_list.append([str(jdata['srcmsg']),str(jdata['dstmsg'])])
    df_list_5_14.append(pd.DataFrame(edge_loss_list))
    count,loss_avg,node_set,edge_set=cal_anomaly_loss(edge_loss_list,edge_list,"graph_5_14/")

    current_tw['name']=f_path
    current_tw['loss']=loss_avg
    current_tw['index']=index_count
    current_tw['nodeset']=node_set

    # Becomes True once the current window has been pushed into some queue
    added_que_flag=False
    for hq in history_list_5_14:
        for his_tw in hq:

            # Related windows share at least one rare node (see cal_set_rel)
            if cal_set_rel(current_tw['nodeset'],his_tw['nodeset'],file_list)!=0 and current_tw['name']!=his_tw['name']:
                print("history queue:",his_tw['name'])

                hq.append(copy.deepcopy(current_tw))
                added_que_flag=True
                break
            # NOTE(review): this check sits inside the inner loop, so once the window has joined
            # some queue, later queues are abandoned at their first unrelated window.
            # Left as is to keep the detection results unchanged - confirm the intent.
            if added_que_flag:
                break
    # No related window anywhere: the current window starts a new queue
    if added_que_flag is False:
        temp_hq=[copy.deepcopy(current_tw)]
        history_list_5_14.append(temp_hq)
    index_count+=1
    loss_list_5_14.append(loss_avg)
    print( f_path,"  ",loss_avg," count:",count," percentage:",count/len(edge_list)," node count:",len(node_set)," edge count:",len(edge_set))

index_count: 0
thr: 1.0992022639441434
graph_5_14/2019-05-14 00:00:00.216652068~2019-05-14 00:15:02.152576344.txt    2.2715391591619443  count: 9919  percentage: 0.054114656075418995  node count: 484  edge count: 554
index_count: 1
thr: 1.6962835574304205
graph_5_14/2019-05-14 00:15:02.152576344~2019-05-14 00:30:52.456495816.txt    3.3969684547659287  count: 5049  percentage: 0.06403459821428571  node count: 90  edge count: 94
index_count: 2
thr: 1.5208833027290787
graph_5_14/2019-05-14 00:30:52.456495816~2019-05-14 00:46:31.488675323.txt    3.3382268749424404  count: 4948  percentage: 0.053099244505494504  node count: 101  edge count: 111
index_count: 3
thr: 1.199508745825513
graph_5_14/2019-05-14 00:46:31.488675323~2019-05-14 01:02:22.766494572.txt    3.12558900028578  count: 5647  percentage: 0.042420372596153846  node count: 123  edge count: 134
index_count: 4
thr: 1.1171231469970961
graph_5_14/2019-05-14 01:02:22.766494572~2019-05-14 01:18:47.800927242.txt    2.607708723288659  co

In [31]:
# Names of the time windows in the most recently reported anomalous queue
name_list=[]
for hl in history_list_5_14:
    # Queue score: product of (loss + 1) over the time windows of the queue.
    loss_count=0
    for hq in hl:
        factor=hq['loss']+1
        # A running score of 0 is (re)started from the current factor.
        loss_count=factor if loss_count==0 else loss_count*factor

    # A queue whose score is greater than 12 is anomalous: all its windows are predicted as attack
    if loss_count>12:
        name_list=[tw['name'] for tw in hl]
        print(name_list)
        for i in name_list:
            pred_label[i]=1
        print(loss_count)

## 5-15

In [32]:
# 5-15

# node_IDF=torch.load("node_IDF_5_15")
# node_IDF=torch.load("node_IDF_5_9")
y_data_5_15=[]
df_list_5_15=[]
# node_set_list=[]

# List of time window queues; each queue is a list of time-window dictionaries
history_list_5_15=[]
tw_que=[]
his_tw={}

# Data of the time window currently being processed
current_tw={}
loss_list_5_15=[]



file_path_list=[]
file_path="graph_5_15/"
file_l=os.listdir("graph_5_15/")
for i in file_l:
    file_path_list.append(file_path+i)

index_count=0
for f_path in sorted(file_path_list):
    # Context manager so the window file is closed once its edges have been read.
    with open(f_path) as f:
        edge_loss_list=[]
        edge_list=[]
        print('index_count:',index_count)

        for line in f:
            l=line.strip()
            # NOTE(review): eval() executes whatever the file contains; only run this
            # on window files produced by this notebook.
            jdata=eval(l)
            edge_loss_list.append(jdata['loss'])
            edge_list.append([str(jdata['srcmsg']),str(jdata['dstmsg'])])
    df_list_5_15.append(pd.DataFrame(edge_loss_list))
    count,loss_avg,node_set,edge_set=cal_anomaly_loss(edge_loss_list,edge_list,"graph_5_15/")

    current_tw['name']=f_path
    current_tw['loss']=loss_avg
    current_tw['index']=index_count
    current_tw['nodeset']=node_set

    # Becomes True once the current window has been pushed into some queue
    added_que_flag=False
    for hq in history_list_5_15:
        for his_tw in hq:

            # Related windows share at least one rare node (see cal_set_rel)
            if cal_set_rel(current_tw['nodeset'],his_tw['nodeset'],file_list)!=0 and current_tw['name']!=his_tw['name']:
                print("history queue:",his_tw['name'])
                hq.append(copy.deepcopy(current_tw))
                added_que_flag=True
                break
            # NOTE(review): this check sits inside the inner loop, so once the window has joined
            # some queue, later queues are abandoned at their first unrelated window.
            # Left as is to keep the detection results unchanged - confirm the intent.
            if added_que_flag:
                break
    # No related window anywhere: the current window starts a new queue
    if added_que_flag is False:
        temp_hq=[copy.deepcopy(current_tw)]
        history_list_5_15.append(temp_hq)
    index_count+=1
    loss_list_5_15.append(loss_avg)
    print( f_path,"  ",loss_avg," count:",count," percentage:",count/len(edge_list)," node count:",len(node_set)," edge count:",len(edge_set))

index_count: 0
thr: 0.7017160930906966
graph_5_15/2019-05-15 00:00:01.490408727~2019-05-15 00:16:14.833595653.txt    1.569451990186191  count: 7442  percentage: 0.04200912210982659  node count: 430  edge count: 710
index_count: 1
thr: 0.8338568772713835
graph_5_15/2019-05-15 00:16:14.833595653~2019-05-15 00:32:01.492056162.txt    1.6007974350398273  count: 7705  percentage: 0.05532657398897059  node count: 138  edge count: 180
index_count: 2
thr: 1.322699930062194
graph_5_15/2019-05-15 00:32:01.492056162~2019-05-15 00:47:15.554515213.txt    3.302451719005372  count: 1273  percentage: 0.03453233506944445  node count: 56  edge count: 67
index_count: 3
thr: 1.4335928802529354
graph_5_15/2019-05-15 00:47:15.554515213~2019-05-15 01:04:31.491761640.txt    3.393008336262723  count: 1195  percentage: 0.04322193287037037  node count: 55  edge count: 64
index_count: 4
thr: 0.8126172968089068
graph_5_15/2019-05-15 01:04:31.491761640~2019-05-15 01:20:01.492131631.txt    2.0165322843454856  count: 

In [33]:
# Names of the time windows in the most recently reported anomalous queue
name_list=[]
for hl in history_list_5_15:
    # Queue score: product of (loss + 1) over the time windows of the queue.
    loss_count=0
    for hq in hl:
        factor=hq['loss']+1
        # A running score of 0 is (re)started from the current factor.
        loss_count=factor if loss_count==0 else loss_count*factor

    # A queue whose score is greater than 12 is anomalous: all its windows are predicted as attack
    if loss_count>12:
        name_list=[tw['name'] for tw in hl]
        print(name_list)
        for i in name_list:
            pred_label[i]=1
        print(loss_count)

['graph_5_15/2019-05-15 13:58:15.520482252~2019-05-15 14:13:37.257086895.txt']
15.148306795008041


# Evaluation

In [34]:
from sklearn.metrics import average_precision_score, roc_auc_score

from sklearn.metrics import roc_auc_score
import torch
from sklearn import preprocessing
import matplotlib.pyplot as plt
import numpy as np
from sklearn.preprocessing import MinMaxScaler
from sklearn.metrics import confusion_matrix

# Unused and incomplete function
def plot_thr():
    """
    Sweep a decision threshold from -5 to 5 (step 0.01) and plot the metrics.

    For every threshold the scores in `y_test_scores` are binarised
    (score > threshold -> 1, else 0) and evaluated against `y_test` with
    classifier_evaluation(). The metrics at the best F-score are printed and
    all five metric curves are plotted against the threshold.

    NOTE(review): `y_test_scores` and `y_test` are read as module-level names
    and are not defined in the visible part of this notebook - confirm before use.
    """
    np.seterr(invalid='ignore')
    step=0.01
    thr_list=torch.arange(-5,5,step)

    precision_list, recall_list, fscore_list, accuracy_list, auc_val_list = [], [], [], [], []
    for threshold in thr_list:
        y_prediction=[1 if score >threshold else 0 for score in y_test_scores]
        precision,recall,fscore,accuracy,auc_val=classifier_evaluation(y_test, y_prediction)
        precision_list.append(float(precision))
        recall_list.append(float(recall))
        fscore_list.append(float(fscore))
        accuracy_list.append(float(accuracy))
        auc_val_list.append(float(auc_val))

    # Report the metrics at the threshold with the best F-score.
    max_fscore_index=fscore_list.index(max(fscore_list))
    print(max_fscore_index)
    print("max threshold:",thr_list[max_fscore_index])
    print('precision:',precision_list[max_fscore_index])
    print('recall:',recall_list[max_fscore_index])
    print('fscore:',fscore_list[max_fscore_index])
    print('accuracy:',accuracy_list[max_fscore_index])
    print('auc:',auc_val_list[max_fscore_index])

    # One curve per metric: (values, colour, legend label, line style).
    curves=[
        (precision_list,'red','precision','-'),
        (recall_list,'orange','recall','solid'),
        (fscore_list,'y','F-score','dashed'),
        (accuracy_list,'g','accuracy','dashdot'),
        (auc_val_list,'b','auc_val','dotted'),
    ]
    for values,colour,curve_label,line_style in curves:
        plt.plot(thr_list,values,color=colour,label=curve_label,linewidth=2.0,linestyle=line_style)

    plt.xlabel("Threshold", fontdict={'size': 16})
    plt.ylabel("Rate", fontdict={'size': 16})
    plt.title("Different evaluation Indicators by varying threshold value", fontdict={'size': 12})
    plt.legend(loc='best', fontsize=12, markerscale=0.5)
    plt.show()

# Function to compute classification evaluation metrics
def classifier_evaluation(y_test, y_test_pred):
    """
    Calculate evaluation metrics based on ground truth and predicted labels.

    The confusion-matrix counts and all metrics are printed as well as returned.

    Parameters:
        y_test (list or array): Ground truth binary labels (0 or 1).
        y_test_pred (list or array): Predicted binary labels (0 or 1).

    Returns:
        tuple: precision, recall, F-score, accuracy, and AUC values.
               AUC is nan when `y_test` contains a single class, because
               ROC AUC is undefined in that case.
    """
    # labels=[0, 1] forces a 2x2 matrix even when only one class occurs,
    # so unpacking into four counts always works.
    tn, fp, fn, tp = confusion_matrix(y_test, y_test_pred, labels=[0, 1]).ravel()
    print('tn:',tn)
    print('fp:',fp)
    print('fn:',fn)
    print('tp:',tp)

    # Calculate evaluation metrics
    precision=tp/(tp+fp)
    recall=tp/(tp+fn)
    accuracy=(tp+tn)/(tp+tn+fp+fn)
    fscore=2*(precision*recall)/(precision+recall)
    if np.unique(y_test).size > 1:
        auc_val=roc_auc_score(y_test, y_test_pred)
    else:
        # ROC AUC needs both classes in the ground truth.
        auc_val=float('nan')

    print("precision:",precision)
    print("recall:",recall)
    print("fscore:",fscore)
    print("accuracy:",accuracy)
    print("auc_val:",auc_val)

    return precision,recall,fscore,accuracy,auc_val

# Function to apply Min-Max scaling to a dataset
def minmax(data):
    """
    Apply Min-Max scaling to normalize data to a range [0, 1].

    Parameters:
        data (list or array): Input data to be normalized.

    Returns:
        list: Normalized data. Empty input gives an empty list; when all
              values are equal (zero range) every value maps to 0.0.
    """
    if len(data)==0:
        return []

    min_val=min(data)
    max_val=max(data)
    span=max_val-min_val
    if span==0:
        # Constant data has no range to scale by; avoid dividing by zero.
        return [0.0 for _ in data]
    return [(i-min_val)/span for i in data]

In [35]:
# Create list of truth and predicted values from ground truth and predict dictionaries
y=[]
y_pred=[]
for i in labels:
    y.append(labels[i])
    y_pred.append(pred_label[i])

In [36]:
# Score the per-time-window predictions (y_pred) against the ground truth (y).
classifier_evaluation(y,y_pred)

tn: 174
fp: 0
fn: 1
tp: 1
precision: 1.0
recall: 0.5
fscore: 0.6666666666666666
accuracy: 0.9943181818181818
auc_val: 0.75


(1.0, 0.5, 0.6666666666666666, 0.9943181818181818, 0.75)

# Count attack edge numbers

In [37]:
def keyword_hit(line):
    """
    Tell whether a raw edge record mentions one of the attack indicators.

    Parameters:
        line (str): One line of a time-window file.

    Returns:
        bool: True when any attack indicator occurs in `line`, False otherwise.
    """
    attack_nodes = (
        'sshdlog',
        'shm',
        '189.141.204.211',
        '208.203.20.42',
    )
    return any(indicator in line for indicator in attack_nodes)



files=[
    'graph_5_15/2019-05-15 13:58:15.520482252~2019-05-15 14:13:37.257086895.txt',
    'graph_5_15/2019-05-15 14:44:51.773840192~2019-05-15 15:00:26.765466538.txt',]

# Count the edges in these attack time windows whose record mentions an attack indicator
attack_edge_count=0
for fpath in tqdm(files):
    # Context manager so the window file is closed after it has been scanned.
    with open(fpath) as f:
        for line in f:
            if keyword_hit(line):
                attack_edge_count+=1
print(attack_edge_count)

100%|██████████| 2/2 [00:00<00:00,  4.33it/s]

1207





# Visualization

For the provided attack time windows create graphs for visualization

In [38]:
import os

from graphviz import Digraph
import networkx as nx
import datetime
import community.community_louvain as community_louvain
from tqdm import tqdm



# Some common path abstraction for visualization
# Substring -> abstracted name. Insertion order matters: the first matching
# substring decides the replacement.
replace_dic = {
    '/run/shm/': '/run/shm/*',
    '/home/admin/.cache/mozilla/firefox/': '/home/admin/.cache/mozilla/firefox/*',
    '/home/admin/.mozilla/firefox': '/home/admin/.mozilla/firefox*',
    '/data/replay_logdb/': '/data/replay_logdb/*',
    '/home/admin/.local/share/applications/': '/home/admin/.local/share/applications/*',
    '/usr/share/applications/': '/usr/share/applications/*',
    '/lib/x86_64-linux-gnu/': '/lib/x86_64-linux-gnu/*',
    '/proc/': '/proc/*',
    '/stat': '*/stat',
    '/etc/bash_completion.d/': '/etc/bash_completion.d/*',
    '/usr/bin/python2.7': '/usr/bin/python2.7/*',
    '/usr/lib/python2.7': '/usr/lib/python2.7/*',
    '/data/data/org.mozilla.fennec_firefox_dev/cache/': '/data/data/org.mozilla.fennec_firefox_dev/cache/*',
    'UNNAMED': 'UNNAMED*',
    '/etc/fonts/': '/etc/fonts/*',
}


def replace_path_name(path_name):
    """
    Abstract a path name for visualization.

    Parameters:
        path_name: The name to abstract.

    Returns:
        The replacement of the first key of `replace_dic` that occurs in
        `path_name`, or `path_name` itself when no key occurs in it.
    """
    for fragment, abstracted in replace_dic.items():
        if fragment in path_name:
            return abstracted
    return path_name


# Users should manually put the detected anomalous time windows here
attack_list = [
        'graph_5_15/2019-05-15 13:58:15.520482252~2019-05-15 14:13:37.257086895.txt',
        'graph_5_15/2019-05-15 14:44:51.773840192~2019-05-15 15:00:26.765466538.txt',
]

original_edges_count = 0
graphs = []
# Graph of all anomalous edges over every listed time window
gg = nx.DiGraph()
count = 0
for path in tqdm(attack_list):
    if ".txt" in path:
        line_count = 0
        node_set = set()
        tempg = nx.DiGraph()
        edge_list = []
        # Context manager so the window file is closed once its edges have been read.
        with open(path, "r") as f:
            for line in f:
                count += 1
                l = line.strip()
                # NOTE(review): eval() executes whatever the file contains; only run this
                # on window files produced by this notebook.
                jdata = eval(l)
                edge_list.append(jdata)

        edge_list = sorted(edge_list, key=lambda x: x['loss'], reverse=True)
        original_edges_count += len(edge_list)

        # Same anomaly threshold as cal_anomaly_loss: mean + 1.5 * std of the window's losses
        loss_list = [edge['loss'] for edge in edge_list]
        loss_mean = mean(loss_list)
        loss_std = std(loss_list)
        print(loss_mean)
        print(loss_std)
        thr = loss_mean + 1.5 * loss_std
        print("thr:", thr)
        for e in edge_list:
            if e['loss'] > thr:
                # Node ids are hashes of the abstracted names; compute each one once per edge.
                src_id = str(hashgen(replace_path_name(e['srcmsg'])))
                dst_id = str(hashgen(replace_path_name(e['dstmsg'])))
                tempg.add_edge(src_id, dst_id)
                gg.add_edge(src_id, dst_id,
                            loss=e['loss'], srcmsg=e['srcmsg'], dstmsg=e['dstmsg'], edge_type=e['edge_type'],
                            time=e['time'])


# Community id of every node of the (undirected view of the) anomalous-edge graph
partition = community_louvain.best_partition(gg.to_undirected())

# Generate the candidate subgraphs based on community discovery results:
# one directed graph per community id from 0 to the largest id.
max_partition = max([0, *partition.values()])
communities = {community_id: nx.DiGraph() for community_id in range(max_partition + 1)}

# An edge goes into the community of its source and into the community of its destination.
for src, dst in gg.edges:
    communities[partition[src]].add_edge(src, dst)
    communities[partition[dst]].add_edge(src, dst)


# Define the attack nodes. They are **only be used to plot the colors of attack nodes and edges**.
# They won't change the detection results.
# Didn't add too much nodes for coloring. Most of the results are compared with the ground truth documentations manually
def attack_edge_flag(msg):
    """
    Tell whether a node message mentions one of the known attack nodes.

    Parameters:
        msg: The node message; it is converted with str() before matching.

    Returns:
        bool: True when any attack node name occurs in str(msg), False otherwise.
    """
    attack_nodes = (
        '208.203.20.42',
        '189.141.204.211',
        '/var/log/sshdlog',
        '/usr/sbin/sshd',
        '/usr/local/lib/firefox-54.0.1/firefox',
    )
    text = str(msg)
    return any(node in text for node in attack_nodes)


# Plot and render candidate subgraph
def _node_shape(msg):
    """Pick the Graphviz shape of a node from the entity type found in its message."""
    if "'subject': '" in msg:
        return 'box'
    if "'file': '" in msg:
        return 'oval'
    if "'netflow': '" in msg:
        return 'diamond'
    # Unknown entity type: use a plain ellipse instead of reusing the shape of
    # the previously drawn node.
    return 'ellipse'


# Portable replacement for `mkdir -p`
os.makedirs("./graph_visual/", exist_ok=True)
graph_index = 0
for c in communities:
    dot = Digraph(name="MyPicture", comment="the test", format="pdf")
    dot.graph_attr['rankdir'] = 'LR'

    for e in communities[c].edges:
        # The community graphs carry no attributes; read them from the full graph.
        temp_edge = gg.edges[e]
        src_msg = temp_edge['srcmsg']
        dst_msg = temp_edge['dstmsg']
        src_id = str(hashgen(replace_path_name(src_msg)))
        dst_id = str(hashgen(replace_path_name(dst_msg)))
        src_is_attack = attack_edge_flag(src_msg)
        dst_is_attack = attack_edge_flag(dst_msg)

        # source node: label is the abstracted name followed by its community id
        dot.node(name=src_id,
                 label=str(replace_path_name(src_msg) + str(partition[src_id])),
                 color='red' if src_is_attack else 'blue',
                 shape=_node_shape(src_msg))

        # destination node
        dot.node(name=dst_id,
                 label=str(replace_path_name(dst_msg) + str(partition[dst_id])),
                 color='red' if dst_is_attack else 'blue',
                 shape=_node_shape(dst_msg))

        # An edge is red only when both of its endpoints are attack nodes.
        dot.edge(src_id, dst_id, label=temp_edge['edge_type'],
                 color='red' if src_is_attack and dst_is_attack else 'blue')

    dot.render(f'./graph_visual/subgraph_' + str(graph_index), view=False)
    graph_index += 1





 50%|█████     | 1/2 [00:05<00:05,  5.85s/it]

6.879262950067025
4.766587935442699
thr: 14.029144853231074
2.3494681087178666
4.200902737755488
thr: 8.650822215351099


100%|██████████| 2/2 [00:13<00:00,  6.52s/it]
