In [56]:
import pandas as pd
import numpy as np
import dgl
from sklearn.preprocessing import LabelEncoder
from sklearn.preprocessing import StandardScaler
# import h5py


In [54]:

import networkx as nx 
from torch_geometric.utils.convert import from_dgl

# run this function per-epoch, per-chunk
def to_graph(data):
    """Build the DGL line graph for one chunk of preprocessed flows.

    Every row of ``data`` is treated as an edge between ``IPV4_SRC_ADDR`` and
    ``IPV4_DST_ADDR``.  The rows are loaded into a NetworkX multigraph (so
    parallel flows between the same pair of hosts are kept), converted to a
    directed graph, handed to DGL and finally turned into a line graph.  With
    ``shared=True`` the edge data of the flow graph becomes the node data of
    the line graph, which is why callers read ``ndata['x']`` and
    ``ndata['Attack']`` from the result.

    Meant to be run once per chunk, per epoch.  The returned object is a DGL
    graph; convert it with ``from_dgl`` when it is fed to a PyG model.

    Args:
        data: DataFrame with the columns ``IPV4_SRC_ADDR``, ``IPV4_DST_ADDR``,
            ``h`` (one already-processed numeric feature sequence per row) and
            ``Attack`` (values castable to ``int64``).  The frame is left
            unmodified.

    Returns:
        The line graph of the flow graph, carrying ``x`` (float32 features)
        and ``Attack`` (int64 labels) as node data.
    """
    # `assign` returns a new frame, so the caller's DataFrame is not touched.
    # The previous version wrote the converted columns back into `data`.
    flows = data.assign(
        x=data['h'].apply(lambda feats: np.array(feats, dtype=np.float32)),
        Attack=data['Attack'].astype(np.int64),
    )

    flow_graph = nx.from_pandas_edgelist(
        flows,
        source='IPV4_SRC_ADDR',
        target='IPV4_DST_ADDR',
        edge_attr=['x', 'Attack'],
        create_using=nx.MultiGraph(),
    )
    # DGL graphs are directed; make the direction explicit before converting.
    flow_graph = flow_graph.to_directed()

    g = dgl.from_networkx(flow_graph, edge_attrs=['x', 'Attack'])
    # shared=True carries the edge features over as line-graph node features.
    return g.line_graph(shared=True)


In [17]:
from sklearn.utils import class_weight 
from torch import nn
import torch as th

# Load only the label column; it is all that is needed to derive the class
# list and the per-class loss weights.
classes_df = pd.read_csv(
    'raw/NF-ToN-IoT-v3.csv',
    usecols=['Attack'],
    dtype='category',
)
unique_classes = np.array(classes_df['Attack'].unique())

# Weighted cross entropy loss: 'balanced' weights, one per entry of
# `unique_classes`, wrapped in a float tensor for the loss module.
# NOTE(review): weight i belongs to unique_classes[i] (order of appearance in
# the CSV). This only lines up with the loss if the integer labels used for
# training are encoded in that same order -- confirm against the encoder.
class_weights = th.FloatTensor(
    class_weight.compute_class_weight(
        class_weight='balanced',
        classes=unique_classes,
        y=classes_df['Attack'],
    )
)
criterion = nn.CrossEntropyLoss(weight=class_weights)

In [20]:
# Drop the full label column once the class list and weights are computed;
# the visible cells below only use `unique_classes` and `criterion`.
del classes_df # memory risk

In [18]:
# Show the class names in the order the class weights were computed for.
print(unique_classes)

['Benign' 'scanning' 'dos' 'injection' 'ddos' 'password' 'xss'
 'ransomware' 'Backdoor' 'mitm']


In [57]:
from EGraphSAGE import Preprocessing

f = 'raw/NF-ToN-IoT-v3.csv'

# Read just the first 10,000 flows. `nrows` parses the same rows as the first
# chunk of a chunked read, but does not leave a half-consumed reader open the
# way `for chunk0 in pd.read_csv(..., chunksize=...): break` does, and it
# always binds `chunk0` (or raises) instead of silently skipping the loop.
chunk0 = pd.read_csv(f, nrows=10_000)

print('loaded')
# Preprocess the sample and build its line graph; G0 is used below to size
# the model's input layer.
chunk0 = Preprocessing._prepare_flows(chunk0)
G0 = to_graph(chunk0)

loaded


In [58]:
from torch_geometric.nn import GraphSAGE

# GraphSAGE classifier over the line graph; sized from the sample graph G0
# and the list of attack classes.
model = GraphSAGE(
    # width of the per-node feature vector in the sample graph
    in_channels=G0.ndata['x'].shape[1],
    # one output per class; only the count matters here, not the class order
    out_channels=len(unique_classes),
    # 128 in original EGraphSAGE paper
    hidden_channels=64,
    num_layers=2,
    dropout=0.2,
)
model

GraphSAGE(50, 10, num_layers=2)

In [None]:
import torch as th
from tqdm import tqdm
from EGraphSAGE import Model
import torch.nn.functional as F

def train_one_epoch(model, csv_file, criterion, chunksize=10_000, train=0.8,
                    optimizer=None):
    """Run one pass over ``csv_file``, training ``model`` chunk by chunk.

    Each chunk is preprocessed, turned into a line graph with ``to_graph`` and
    converted to a PyG ``Data`` object.  The first ``train`` fraction of the
    graph's nodes is used for the gradient step; the remaining nodes are only
    scored.  Both losses come from the same forward pass, so the held-out loss
    is measured with the model still in training mode (dropout active).

    Args:
        model: module called as ``model(x, edge_index)`` that returns one row
            of class scores per node.
        csv_file: path of the flow CSV, read in chunks with pandas.
        criterion: loss called as ``criterion(scores, labels)``.
        chunksize: number of CSV rows per chunk (one graph per chunk).
        train: fraction of each graph's nodes, taken from the front, that is
            used for training.
        optimizer: optimizer to step.  When ``None`` a fresh ``Adam`` over
            ``model.parameters()`` is created for this call; pass one in to
            keep optimizer state across epochs.

    Returns:
        ``(losses, test_losses)``: two lists with one detached scalar tensor
        per processed chunk.  Chunks without any training node are skipped.
        A chunk without held-out nodes (``train=1.0``) yields a NaN test loss.
    """
    if optimizer is None:
        optimizer = th.optim.Adam(model.parameters())
    model.train()

    losses, test_losses = [], []

    for chunk in tqdm(pd.read_csv(csv_file, chunksize=chunksize)):
        chunk = Preprocessing._prepare_flows(chunk)
        G = to_graph(chunk)
        # The DataFrame is no longer needed once the graph exists.
        del chunk

        size = G.number_of_nodes()
        n_train = int(size * train)
        if n_train == 0:
            # Nothing to train on (empty or single-node graph); an empty
            # selection would only produce a NaN loss.
            del G
            continue

        # Boolean masks: the previous float 0/1 array is not a valid mask for
        # tensor indexing.
        train_mask = th.zeros(size, dtype=th.bool)
        train_mask[:n_train] = True
        test_mask = ~train_mask

        # Grab the labels from the DGL graph before converting it to PyG.
        labels = G.ndata['Attack']
        G = from_dgl(G)

        optimizer.zero_grad()
        pred = model(G.x, G.edge_index)

        loss = criterion(pred[train_mask], labels[train_mask])

        # Held-out loss is for monitoring only; keep it out of autograd.
        with th.no_grad():
            test_loss = criterion(pred[test_mask], labels[test_mask])

        # The actual training step -- previously missing, so the weights were
        # never updated.
        loss.backward()
        optimizer.step()

        # Detach so the stored values do not keep every chunk's computation
        # graph alive for the whole epoch.
        losses.append(loss.detach())
        test_losses.append(test_loss)

        del G

    return losses, test_losses


# Number of passes over the CSV; used for both the loop bound and the
# progress message so the two cannot drift apart.
N_EPOCHS = 2

# One (losses, test_losses) tuple per epoch, as returned by train_one_epoch.
metrics = []

for epoch in range(N_EPOCHS):
    print(f'epoch {epoch+1}/{N_EPOCHS}')
    m = train_one_epoch(
        model=model,
        csv_file=f,
        criterion=criterion,
    )
    print(m)
    metrics.append(m)
    
    
# !! most of the time is in preprocessing

epoch 1/2


79it [02:16,  1.73s/it]


KeyboardInterrupt: 

In [49]:
# Edge count of the line graph built from the first 10,000-row chunk.
G0.number_of_edges()

1190494