In [None]:
import torch.nn as nn
import os
import pickle
import torch
import warnings
import wandb
import time
import numpy as np
import pandas as pd
import torch.nn.functional as F
import pytorch_lightning as pl
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping


from src.models import EGAT, EGRAPHSAGE
# from src.models import EGAT, EGCN, EGRAPHSAGE
from src.lightning_model import GraphModel
from src.lightning_data import GraphDataModule
from src.dataset.dataset_info import datasets
from local_variables import local_datasets_path

# Suppress any warning whose message matches "does not have many workers".
# NOTE(review): presumably this targets the dataloader num_workers hint, since
# the data module in this notebook is created with num_workers=0 - confirm.
warnings.filterwarnings(
    action="ignore",
    message=".*does not have many workers.*",
)


In [None]:
# Run-level switches read by the training cell.
# save_top_k is forwarded to ModelCheckpoint(save_top_k=...); every checkpoint
# kept by that callback is evaluated on the test split afterwards.
save_top_k = 1
# When True, runs are logged through WandbLogger / wandb.log; otherwise the
# trainer's own logger receives the summary metrics.
using_wandb = False

In [None]:
# Keys into the project-level `datasets` registry, in the order the training
# loop will visit them.
_dataset_keys = (
    "cic_ton_iot",
    "cic_ids_2017",
    "cic_ton_iot_modified",
    "ccd_inid_modified",
    "nf_uq_nids_modified",
    "edge_iiot",
    "nf_cse_cic_ids2018",
    "nf_uq_nids",
    "x_iiot",
)
my_datasets = [datasets[key] for key in _dataset_keys]

In [None]:
# --- Training schedule -------------------------------------------------------
# NOTE(review): both values are 1, which looks like a smoke-test setting; the
# previously commented-out alternative was early_stopping_patience = 20.
max_epochs = 1
early_stopping_patience = max_epochs
learning_rate = 0.005
weight_decay = 0.0

# --- Model architecture (forwarded to the EGRAPHSAGE / EGAT constructors) ----
num_layers = 2
ndim_out = [128, 128]
number_neighbors = [25, 10]  # passed as num_neighbors to the sampling variant
aggregation = "mean"
activation = F.relu
dropout = 0.0
residual = True

# --- Task / feature switches -------------------------------------------------
multi_class = True
use_centralities_nfeats = False

# Timestamp recorded in every run's config dict.
run_dtime = time.strftime("%Y%m%d-%H%M%S")

# Name of the graph sub-folder inside each dataset folder. It is assembled from
# "__"-separated tags that reflect the switches above.
_g_type_tags = ["flow"]
if multi_class:
    _g_type_tags.append("multi_class")
if use_centralities_nfeats:
    _g_type_tags.append("n_feats")
_g_type_tags.append("unsorted")
g_type = "__".join(_g_type_tags)

In [None]:
# Train every model on every dataset, evaluate each checkpoint kept by the
# F1 ModelCheckpoint on the test split, and collect the mean test time per
# (dataset, model) in `time_elapsed_dict` ({dataset_name: {model_name: seconds}}).
# NOTE(review): the unit of the "elapsed" metric is defined by GraphModel - confirm.
time_elapsed_dict = {}

# Loop-invariant: a single directory shared by all W&B runs.
wandb_runs_path = os.path.join("logs", "wandb_runs")
os.makedirs(wandb_runs_path, exist_ok=True)

for dataset in my_datasets:
    dataset_folder = os.path.join(local_datasets_path, dataset.name)
    graphs_folder = os.path.join(dataset_folder, g_type)

    logs_folder = os.path.join("logs", dataset.name)
    os.makedirs(logs_folder, exist_ok=True)

    # Binary fallback; replaced by the dataset's own class names when
    # multi_class is enabled.
    labels_mapping = {0: "Normal", 1: "Attack"}
    if multi_class:
        # NOTE(review): pickle.load can execute arbitrary code; only use
        # labels_names.pkl files produced by this project's own preprocessing.
        with open(os.path.join(dataset_folder, "labels_names.pkl"), "rb") as f:
            labels_names = pickle.load(f)
        labels_mapping = labels_names[0]
    num_classes = len(labels_mapping)

    dataset_kwargs = dict(
        use_node_features=use_centralities_nfeats,
        # Follow the notebook-level switch so the data module agrees with
        # g_type and labels_mapping (this used to be hard-coded to True).
        multi_class=multi_class,
        using_masking=False,
        masked_class=2,
        num_workers=0,
        label_col=dataset.label_col,
        class_num_col=dataset.class_num_col,
        device='cuda' if torch.cuda.is_available() else "cpu"
    )

    data_module = GraphDataModule(
        graphs_folder, batch_size=1, **dataset_kwargs)
    data_module.setup()

    # Read the node/edge feature widths from one training batch instead of
    # building and iterating the dataloader twice.
    sample_graph = next(iter(data_module.train_dataloader()))
    ndim = sample_graph.ndata["h"].shape[-1]
    edim = sample_graph.edata['h'].shape[-1]

    my_models = {
        # "e_gcn": EGCN(ndim, edim, ndim_out, num_layers, activation,
        #               dropout, residual, num_classes),
        f"e_graphsage_{aggregation}": EGRAPHSAGE(ndim, edim, ndim_out, num_layers, activation, dropout,
                                                 residual, num_classes, num_neighbors=number_neighbors, aggregation=aggregation),
        f"e_graphsage_{aggregation}_no_sampling": EGRAPHSAGE(ndim, edim, ndim_out, num_layers, activation, dropout,
                                                             residual, num_classes, num_neighbors=None, aggregation=aggregation),
        "e_gat_no_sampling": EGAT(ndim, edim, ndim_out, num_layers, activation, dropout,
                                    residual, num_classes, num_neighbors=None),
        # "e_gat_sampling": EGAT(ndim, edim, ndim_out, num_layers, activation, dropout,
        #                        residual, num_classes, num_neighbors=number_neighbors),
    }

    # Class-weighted loss, using the weights exposed by the training dataset.
    criterion = nn.CrossEntropyLoss(data_module.train_dataset.class_weights)

    elapsed = {}

    for model_name, model in my_models.items():

        # Hyperparameters recorded with the run (passed to GraphModel and,
        # when enabled, to W&B).
        config = {
            "run_dtime": run_dtime,
            "type": "GNN",
            "model_name": model_name,
            "max_epochs": max_epochs,
            "learning_rate": learning_rate,
            "weight_decay": weight_decay,
            "ndim": ndim,
            "edim": edim,
            "ndim_out": ndim_out,
            "num_layers": num_layers,
            "number_neighbors": number_neighbors,
            "activation": activation.__name__,
            "dropout": dropout,
            "residual": residual,
            "multi_class": multi_class,
            "aggregation": aggregation,
            # "details": "updating edge features",
            "early_stopping_patience": early_stopping_patience,
            "use_centralities_nfeats": use_centralities_nfeats,
        }

        graph_model = GraphModel(model, criterion, learning_rate, config, model_name,
                                 labels_mapping, weight_decay=weight_decay, using_wandb=using_wandb,
                                 norm=False, multi_class=multi_class, verbose=False)

        if using_wandb:
            run_logger = WandbLogger(
                project=f"GNN-Analysis-{dataset.name}",
                name=model_name,
                config=config,
                save_dir=wandb_runs_path
            )
        else:
            # Ask for Lightning's default logger explicitly. Passing None may
            # disable logging depending on the Lightning version, and the
            # summary-metric call at the end of this loop needs trainer.logger.
            run_logger = True

        # Keep the save_top_k checkpoints with the highest validation F1.
        f1_checkpoint_callback = ModelCheckpoint(
            monitor="val_f1_score",
            mode="max",
            filename="best-val-f1-{epoch:02d}-{val_f1_score:.2f}",
            save_top_k=save_top_k,
            save_on_train_epoch_end=False,
            verbose=False,
        )
        early_stopping_callback = EarlyStopping(
            monitor="val_loss",
            mode="min",
            patience=early_stopping_patience,
            verbose=False,
        )

        trainer = pl.Trainer(
            max_epochs=max_epochs,
            num_sanity_val_steps=0,
            # log_every_n_steps=0,
            callbacks=[
                f1_checkpoint_callback,
                early_stopping_callback
            ],
            default_root_dir=logs_folder,
            logger=run_logger,
        )

        trainer.fit(graph_model, datamodule=data_module)

        # Evaluate every retained checkpoint on the test split. Each pass uses
        # its own metric prefix so the result keys do not collide.
        test_results = []
        test_elapsed = []
        print(
            f"==>> f1_checkpoint_callback.best_k_models.keys(): {f1_checkpoint_callback.best_k_models.keys()}")
        for i, k in enumerate(f1_checkpoint_callback.best_k_models.keys()):
            graph_model.test_prefix = f"best_f1_{i}"
            results = trainer.test(
                graph_model, datamodule=data_module, ckpt_path=k, verbose=False)
            test_results.append(results[0][f"best_f1_{i}_test_f1"])
            test_elapsed.append(results[0][f"best_f1_{i}_elapsed"])

        if not test_results:
            # No checkpoint was kept (e.g. save_top_k == 0). np.max would raise
            # on an empty list, so record NaN and move on to the next model.
            print(f"==>> model_name: {model_name}")
            print("==>> no checkpoint available; skipping test metrics")
            elapsed[model_name] = float("nan")
            if using_wandb:
                wandb.finish()
            continue

        logs = {
            "median_f1_of_best_f1": np.median(test_results),
            "max_f1_of_best_f1": np.max(test_results),
            "avg_f1_of_best_f1": np.mean(test_results)
        }
        mean_test_elapsed = np.mean(test_elapsed)
        elapsed[model_name] = mean_test_elapsed.item()
        print(f"==>> model_name: {model_name}")
        print(f"==>> test_elapsed: {mean_test_elapsed}")
        if using_wandb:
            wandb.log(logs)
            wandb.finish()
        else:
            trainer.logger.log_metrics(logs, step=trainer.global_step)

    time_elapsed_dict[dataset.name] = elapsed

print(f"==>> time_elapsed_dict: {time_elapsed_dict}")

In [None]:
# Turn {dataset: {model: mean test time}} into a table with one row per
# dataset and one column per model.
df = pd.DataFrame.from_dict(data=time_elapsed_dict, orient="index")
df

In [None]:
# Average time per model, i.e. the mean of each column taken across the
# dataset rows.
average_times = df.mean(axis="index")
average_times