In [10]:
from EGraphSAGE import Model, compute_accuracy, train
import pickle
import dgl
import torch as th
import os
from pathlib import Path
from tqdm import tqdm
import numpy as np
import pandas as pd

In [17]:
features_df = pd.read_csv('raw/NetFlow_v3_Features.csv')
features = list(features_df.Feature)
features

['IPV4_SRC_ADDR',
 'IPV4_DST_ADDR',
 'L4_SRC_PORT',
 'L4_DST_PORT',
 'PROTOCOL',
 'L7_PROTO',
 'IN_BYTES',
 'OUT_BYTES',
 'IN_PKTS',
 'OUT_PKTS',
 'FLOW_DURATION_MILLISECONDS',
 'TCP_FLAGS',
 'CLIENT_TCP_FLAGS',
 'SERVER_TCP_FLAGS',
 'DURATION_IN',
 'DURATION_OUT',
 'MIN_TTL',
 'MAX_TTL',
 'LONGEST_FLOW_PKT',
 'SHORTEST_FLOW_PKT',
 'MIN_IP_PKT_LEN',
 'MAX_IP_PKT_LEN',
 'SRC_TO_DST_SECOND_BYTES',
 'DST_TO_SRC_SECOND_BYTES',
 'RETRANSMITTED_IN_BYTES',
 'RETRANSMITTED_IN_PKTS',
 'RETRANSMITTED_OUT_BYTES',
 'RETRANSMITTED_OUT_PKTS',
 'SRC_TO_DST_AVG_THROUGHPUT',
 'DST_TO_SRC_AVG_THROUGHPUT',
 'NUM_PKTS_UP_TO_128_BYTES',
 'NUM_PKTS_128_TO_256_BYTES',
 'NUM_PKTS_256_TO_512_BYTES',
 'NUM_PKTS_512_TO_1024_BYTES',
 'NUM_PKTS_1024_TO_1514_BYTES',
 'TCP_WIN_MAX_IN',
 'TCP_WIN_MAX_OUT',
 'ICMP_TYPE',
 'ICMP_IPV4_TYPE',
 'DNS_QUERY_ID',
 'DNS_QUERY_TYPE',
 'DNS_TTL_ANSWER',
 'FTP_COMMAND_RET_CODE',
 'FLOW_START_MILLISECONDS',
 'FLOW_END_MILLISECONDS',
 'SRC_TO_DST_IAT_MIN                ',
 'SRC_TO

In [None]:
classes = pd.read_csv('raw/NF-ToN-IoT-v3.csv', dtype='category', usecols=['Attack'])

In [16]:
def get_edge_masks(l, train_split=0.8, valid_test_split=0.5):
    tr = int(l * train_split)
    o = l - tr
    edge_train_mask = np.concatenate((np.ones(tr), np.zeros(o)))
    valid = int(o*valid)
    edge_valid_mask = np.concatenate((np.zeros(tr), np.ones(valid), np.zeros(o - valid)))
    return edge_train_mask, edge_valid_mask

In [None]:
from sklearn.utils import class_weight 
from sklearn.metrics import f1_score
from torch import nn


EPOCHS = 1


root = Path('interm/NF-IoT flowgraphs')
graph_files = [(root / f) for f in os.listdir(root)]

model = Model(
    ...
)
opt = th.optim.Adam(model.parameters())


training_accs, validation_accs = [], []
# training_F1s, validation_F1s = {}, {} # add ALL metrics


for epoch in range(1, EPOCHS+1):

    # temporally segmented flow graphs 
    for i, graph_file in enumerate(graph_files):
        with open(graph_file, 'rb') as f:
            G = pickle.load(f)
            
        train_mask, valid_mask = get_edge_masks(len(G.edata['h']))
        
            
        class_weights = class_weight.compute_class_weight(
            'balanced',
            np.unique(G.edata['Attack'].cpu().numpy()),
            G.edata['Attack'].cpu().numpy())

        class_weights = th.FloatTensor(class_weights)
        criterion = nn.CrossEntropyLoss(weight = class_weights)
        
        node_features = G.ndata['h']
        edge_features = G.edata['h']
        edge_label = G.edata['Attack']

        pred = model(G, node_features, edge_features)#.cuda()
        loss = criterion(pred[train_mask], edge_label[valid_mask])
        opt.zero_grad()
        loss.backward()
        opt.step()
        
        train_acc = compute_accuracy(pred[train_mask], edge_label[train_mask])
        training_accs += [train_acc]
        valid_acc = compute_accuracy(pred[valid_mask], edge_label[valid_mask])
        validation_accs += [valid_acc]
    
    
    if epoch % 100 == 0:
        print('Training acc:', train_acc)
        print('Validation acc:', validation_accs)

    if epoch == epochs-1 and test_acc:
        test_mask = ~np.array(train_mask + valid_mask)
        test_acc = compute_accuracy(pred[test_mask], edge_label[test_mask])
        print('\nFinal test acc:', test_acc)
    else:
        test_acc = None

    
    
    


[PosixPath('interm/NF-IoT flowgraphs/38.pkl'),
 PosixPath('interm/NF-IoT flowgraphs/18.pkl'),
 PosixPath('interm/NF-IoT flowgraphs/25.pkl'),
 PosixPath('interm/NF-IoT flowgraphs/39.pkl'),
 PosixPath('interm/NF-IoT flowgraphs/52.pkl'),
 PosixPath('interm/NF-IoT flowgraphs/46.pkl'),
 PosixPath('interm/NF-IoT flowgraphs/45.pkl'),
 PosixPath('interm/NF-IoT flowgraphs/6.pkl'),
 PosixPath('interm/NF-IoT flowgraphs/1.pkl'),
 PosixPath('interm/NF-IoT flowgraphs/28.pkl'),
 PosixPath('interm/NF-IoT flowgraphs/31.pkl'),
 PosixPath('interm/NF-IoT flowgraphs/24.pkl'),
 PosixPath('interm/NF-IoT flowgraphs/8.pkl'),
 PosixPath('interm/NF-IoT flowgraphs/20.pkl'),
 PosixPath('interm/NF-IoT flowgraphs/55.pkl'),
 PosixPath('interm/NF-IoT flowgraphs/43.pkl'),
 PosixPath('interm/NF-IoT flowgraphs/33.pkl'),
 PosixPath('interm/NF-IoT flowgraphs/50.pkl'),
 PosixPath('interm/NF-IoT flowgraphs/14.pkl'),
 PosixPath('interm/NF-IoT flowgraphs/37.pkl'),
 PosixPath('interm/NF-IoT flowgraphs/16.pkl'),
 PosixPath('inte