In [37]:
import os
os.chdir("../")

In [38]:
import numpy as np
import torch
import json
from torch import nn

from cids.util import misc_funcs as misc
from cids.data import SCVICCIDSDataset

import matplotlib.pyplot as plt
from cids.data import get_SCVIC_dataloader
from cids.util.config import read, transform_ray_hp_config
from cids.training.supervised import train_mlp_ray
from cids.util import misc_funcs as misc
from cids.util.metrics import top1_accuracy, combined_loss
from cids.models.transformer import TransformerEncoder
from cids.models.nn import MLP

In [39]:
tdl, vdl = get_SCVIC_dataloader(8, 2, host_embeddings=True)

INFO:cids.data:load train_dataloader based on /opt/gildemeister/gildemeister-implementation/data/scvic/train_indices.txt
DEBUG:cids.data:Start loading embeddings
INFO:cids.data:Loading SCVIC-CIDS dataset
DEBUG:cids.data:Loading normal
DEBUG:cids.data:Loading malicious


KeyboardInterrupt: 

In [None]:
data = next(iter(tdl))

In [None]:
label = data[-1]
if len(data) == 2:
    inp = torch.flatten(data[0], start_dim=1)
else:
    inps = [torch.flatten(data[i], start_dim=1) for i in range(len(data) - 1)]
    inp = torch.cat(inps, dim=1)

In [None]:
print(inp.shape)
for d in data[:-1]:
    print(d.shape)
    print(torch.flatten(d, start_dim=1).shape)

torch.Size([8, 356])
torch.Size([8, 132])
torch.Size([8, 132])
torch.Size([8, 28, 8])
torch.Size([8, 224])


In [None]:
def train_step_collaborative(src: tuple[torch.Tensor, ...], tgt: torch.Tensor, model: nn.Module, loss_criterion: callable, optimizer: torch.optim.Optimizer, accuracy: callable=None, alpha = 0.):
    
    optimizer.zero_grad()

    outputs = model(*src)

    loss = loss_criterion(*outputs, tgt, alpha=alpha)

    loss.backward()
    optimizer.step()

    if accuracy is not None:
        acc = accuracy(outputs[0], tgt)
        return loss, acc
    
    return loss


In [None]:
config = read(os.path.join(misc.root(), "config/00_hyperparameter_optimization/collaborative_classification/CIDS-1.yaml"), convert_tuple=True)

In [None]:
hp_config = config["hp_config"]
print(hp_config["lr"])

(1e-06, 0.001, 'loguniform')


In [None]:
config = transform_ray_hp_config(hp_config)
# hp_config["model"]["hidden_dims"] = [128, 64]
# hp_config["lr"] = 1e-4

In [None]:
network_config = config["model"]["network"]
host_config = config["model"]["host"]
host_config["n_head"] = host_config["n_head"].sample()
host_config["factor_dim_feedforward"] = host_config["factor_dim_feedforward"].sample()
embedding_config = config["model"]["embedding"]
embedding_config["n_head"] = embedding_config["n_head"].sample()
embedding_config["head"] = embedding_config["head"].sample()
embedding_config["d_model"] = embedding_config["d_model"].sample()
embedding_config["factor_dim_feedforward"] = embedding_config["factor_dim_feedforward"].sample()
aggregation_config = config["model"]["aggregation"]
aggregation_config["hidden_dims"] = aggregation_config["hidden_dims"].sample()

config["model"]["network"] = network_config
config["model"]["host"] = host_config
config["model"]["embedding"] = embedding_config

config["alpha"] = config["alpha"].sample()
config["lr"] = config["lr"].sample()
config["epochs"] = 1

In [None]:
import logging 
logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger()

In [None]:
class CollaborativeIDSNet(nn.Module):

    def __init__(self, network_encoder: nn.Module, host_encoder: nn.Module, embedding_encoder: nn.Module, aggregation_module: nn.Module):
        super().__init__()

        self.network_encoder = network_encoder
        self.host_encoder = nn.Sequential(host_encoder, nn.Flatten())
        self.embedding_encoder = nn.Sequential(embedding_encoder, nn.Flatten())
        self.aggregation_module = nn.Sequential(nn.ReLU(), aggregation_module)

    def forward(self, x_network: torch.Tensor, x_host: torch.Tensor, x_embeddings: torch.Tensor):

        x_network = self.network_encoder(x_network)
        x_host = self.host_encoder(x_host)
        x_embeddings = self.embedding_encoder(x_embeddings)


        x_agg = torch.cat([x_network, x_host, x_embeddings], dim=1)
        logits = self.aggregation_module(x_agg)

        return logits, x_network

In [None]:
device = "cuda:0"
logger.debug(f"Use device {device}")

logger.info("Load model")

network_config = config["model"]["network"]
host_config = config["model"]["host"]
host_config["dim_feedforward"] = int(host_config.pop("factor_dim_feedforward") * host_config["d_model"])
embedding_config = config["model"]["embedding"]
embedding_config["dim_feedforward"] = int(embedding_config.pop("factor_dim_feedforward") * embedding_config["d_model"])
aggregation_config = config["model"]["aggregation"]
aggregation_config["input_dim"] = int(network_config["output_dim"] \
    + host_config["d_model"] * host_config["max_len"] \
    + (embedding_config["head"] if embedding_config["head"] is not None else embedding_config["d_model"]) * embedding_config["max_len"])

model = CollaborativeIDSNet(
    network_encoder=MLP(**network_config),
    host_encoder=TransformerEncoder(**host_config),
    embedding_encoder=TransformerEncoder(**embedding_config),
    aggregation_module=MLP(**aggregation_config)
)

logger.info("load data")

model.to(device=device)


DEBUG:root:Use device cuda:0
INFO:root:Load model
INFO:root:load data


CollaborativeIDSNet(
  (network_encoder): MLP(
    (input_layer): Linear(in_features=132, out_features=128, bias=True)
    (hidden_layers): ModuleList(
      (0): Linear(in_features=128, out_features=64, bias=True)
      (1): Linear(in_features=64, out_features=64, bias=True)
    )
    (output_layer): Linear(in_features=64, out_features=15, bias=True)
    (dropout): Identity()
  )
  (host_encoder): Sequential(
    (0): TransformerEncoder(
      (transformer_layers): ModuleList(
        (0): TransformerEncoderLayer(
          (self_attn): MultiheadAttention(
            (out_proj): NonDynamicallyQuantizableLinear(in_features=8, out_features=8, bias=True)
          )
          (linear1): Linear(in_features=8, out_features=4, bias=True)
          (dropout): Dropout(p=0.1, inplace=False)
          (linear2): Linear(in_features=4, out_features=8, bias=True)
          (norm1): LayerNorm((8,), eps=1e-05, elementwise_affine=True)
          (norm2): LayerNorm((8,), eps=1e-05, elementwise_affine

In [None]:


cooloff = 0
best_loss = float('inf')

params = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.AdamW(params=params, lr=config["lr"])
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=config["epochs"])

logger.info(f"Start training for {config['epochs']} epochs, under ray conditions")

total_steps = 0
alpha = config["alpha"]
loss_fnc = combined_loss
accuracy = top1_accuracy

for epoch in range(config["epochs"]):
    logger.info(f"Start epoch {epoch}")
    total_loss = 0
    total_acc = 0
    steps = 0

    for data in tdl:
        for d in data:
            print(d.dtype)
        label = data[-1]

        inp_network = data[0].to(device)
        inp_host = data[1].to(device)
        inp_embedding = data[2].to(device)
        label = label.to(device=device)

        loss, acc = train_step_collaborative(
            (inp_network, inp_host, inp_embedding),
            label,
            model=model,
            loss_criterion=loss_fnc,
            optimizer=optimizer,
            accuracy=accuracy,
            alpha=alpha
        )

        total_loss += loss.item()
        total_acc += acc.item()
        
        steps += 1
        total_steps += 1

    epoch_loss = total_loss / steps
    epoch_acc = total_acc / steps
    cooloff += 1

INFO:root:Start training for 1 epochs, under ray conditions
INFO:root:Start epoch 0


torch.float32
torch.float32
torch.float32
torch.int64
torch.float32
torch.float32
torch.float32
torch.int64
torch.float32
torch.float32
torch.float32
torch.int64
torch.float32
torch.float32
torch.float32
torch.int64
torch.float32
torch.float32
torch.float32
torch.int64
torch.float32
torch.float32
torch.float32
torch.int64
torch.float32
torch.float32
torch.float32
torch.int64
torch.float32
torch.float32
torch.float32
torch.int64
torch.float32
torch.float32
torch.float32
torch.int64
torch.float32
torch.float32
torch.float32
torch.int64
torch.float32
torch.float32
torch.float32
torch.int64
torch.float32
torch.float32
torch.float32
torch.int64
torch.float32
torch.float32
torch.float32
torch.int64
torch.float32
torch.float32
torch.float32
torch.int64
torch.float32
torch.float32
torch.float32
torch.int64
torch.float32
torch.float32
torch.float32
torch.int64
torch.float32
torch.float32
torch.float32
torch.int64
torch.float32
torch.float32
torch.float32
torch.int64
torch.float32
torch.float32


KeyboardInterrupt: 