<a href="https://colab.research.google.com/github/esthy13/cil-intrusion-detection/blob/main/notebooks/cybersecurity_icarl_margarita.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import copy
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

In [None]:
class IDSIncrementalDataset(Dataset):
    """Tabular dataset exposing features and integer labels for incremental learning.

    Attributes:
        features: float32 array of shape (n_samples, n_features).
        labels: int64 array of shape (n_samples,).
        indices: positions 0..n_samples-1 of the samples in this dataset.
    """

    def __init__(self, df, feature_cols, label_col):
        """Extract feature and label arrays from a DataFrame.

        Args:
            df: pandas DataFrame holding both features and labels.
            feature_cols: column labels to use as features.
            label_col: column label holding the integer class id.
        """
        self.features = df[feature_cols].values.astype(np.float32)
        # Explicit int64: plain ``int`` maps to int32 on some platforms, and
        # torch's cross-entropy only accepts int64 (long) class targets.
        self.labels = df[label_col].values.astype(np.int64)
        self.indices = np.arange(len(self.labels))

    @classmethod
    def _from_arrays(cls, features, labels):
        """Build a dataset directly from already-typed arrays (no DataFrame)."""
        dataset = cls.__new__(cls)
        dataset.features = features
        dataset.labels = labels
        dataset.indices = np.arange(len(labels))
        return dataset

    def __len__(self):
        """Return the number of samples."""
        return len(self.labels)

    def __getitem__(self, idx):
        """Return the ``(features, label)`` tensors for sample ``idx``."""
        x = torch.tensor(self.features[idx])
        y = torch.tensor(self.labels[idx])
        return x, y

    def filter_by_classes(self, class_list):
        """Return a new dataset containing only samples whose label is in ``class_list``.

        The returned dataset owns copies of the selected rows and is
        re-indexed from 0; it may be empty when no label matches.
        """
        mask = np.isin(self.labels, class_list)
        # Boolean-mask indexing copies the rows, so the subset is independent
        # of this dataset and no DataFrame round trip is needed.
        return IDSIncrementalDataset._from_arrays(
            self.features[mask],
            self.labels[mask],
        )

In [None]:
class IDSNet(nn.Module):
    """Two-layer MLP feature extractor with an optional, resizable linear head.

    ``forward`` returns L2-normalised features. Once a classifier has been
    attached with ``update_classifier`` it returns ``(logits, features)``.
    """

    def __init__(self, input_dim, feature_dim=128):
        """Build the feature extractor.

        Args:
            input_dim: number of input features per sample.
            feature_dim: dimensionality of the embedding produced.
        """
        super().__init__()

        self.feature_extractor = nn.Sequential(
            nn.Linear(input_dim, 256),
            nn.ReLU(),
            nn.Linear(256, feature_dim)
        )

        self.classifier = None  # used ONLY during training

    def forward(self, x):
        """Embed ``x`` (batch, input_dim) and optionally classify it.

        Returns:
            ``(logits, feats)`` when a classifier is attached, otherwise just
            ``feats``. ``feats`` is L2-normalised along dim 1.
        """
        feats = self.feature_extractor(x)
        feats = F.normalize(feats, dim=1)

        if self.classifier is not None:
            logits = self.classifier(feats)
            return logits, feats

        return feats

    def update_classifier(self, num_classes):
        """Replace the head with a freshly initialised ``num_classes``-way linear layer.

        The new layer is placed on the same device and uses the same dtype as
        the feature extractor; otherwise a model that was already moved with
        ``.to(device)`` / ``.double()`` would fail on the next forward pass.
        Previously learned head weights are NOT carried over.
        """
        reference = next(self.feature_extractor.parameters())
        head = nn.Linear(
            self.feature_extractor[-1].out_features,
            num_classes
        )
        self.classifier = head.to(device=reference.device, dtype=reference.dtype)

In [None]:
class ICaRL:
    """Class-incremental learner: distillation training plus nearest-mean prediction.

    The wrapped ``model`` must expose ``classifier`` (``None`` or a linear
    layer), ``update_classifier(num_classes)`` and a ``forward`` returning
    ``(logits, features)`` once a classifier is attached.

    Typical use per task: ``add_classes`` -> ``train_task`` ->
    ``build_exemplars`` -> ``compute_class_means`` -> ``predict``.

    NOTE(review): labels are passed straight to cross-entropy, so class ids are
    assumed to be 0..K-1 in the order they are added - confirm with callers.
    """

    def __init__(self, model, device, memory_size=2000):
        """Store the model on ``device`` and initialise empty memory.

        Args:
            model: network satisfying the contract in the class docstring.
            device: torch device used for training and inference.
            memory_size: total exemplar budget shared by all seen classes.
        """
        self.model = model.to(device)
        self.device = device
        self.memory_size = memory_size

        self.seen_classes = []
        self.exemplars = {}      # class_id -> indices into the dataset given to build_exemplars
        self.class_means = {}    # class_id -> unit-norm mean feature vector

        # Frozen snapshot of the model before the latest head expansion;
        # None until a second group of classes is added.
        self.old_model = None

    def distillation_loss(self, old_logits, new_logits, T=2):
        """Temperature-scaled KL divergence from old to new predictions.

        Both logit tensors must have the same shape (batch, n_old_classes).
        The ``T * T`` factor keeps the gradient magnitude comparable across
        temperatures.
        """
        old_probs = F.softmax(old_logits / T, dim=1)
        new_log_probs = F.log_softmax(new_logits / T, dim=1)
        return F.kl_div(new_log_probs, old_probs, reduction="batchmean") * (T * T)

    def add_classes(self, new_classes):
        """Register new classes and grow the classifier head accordingly.

        If classes were already learned, the current (trained) model is frozen
        into ``old_model`` BEFORE the head is replaced, and the trained weights
        and biases of the old outputs are copied into the new head. Classes
        that are already known are ignored.
        """
        fresh = []
        for cls in new_classes:
            if cls not in self.seen_classes and cls not in fresh:
                fresh.append(cls)
        if not fresh:
            return

        n_old = len(self.seen_classes)
        previous_head = self.model.classifier if n_old else None

        if previous_head is not None:
            # Snapshot first: the teacher must be the model as trained on the
            # previous tasks, not a copy taken after the head was reset.
            snapshot = copy.deepcopy(self.model).eval()
            for param in snapshot.parameters():
                param.requires_grad_(False)
            self.old_model = snapshot

        self.seen_classes.extend(fresh)
        self.model.update_classifier(len(self.seen_classes))
        # The replacement head may have been created on the CPU.
        self.model.to(self.device)

        if previous_head is not None:
            new_head = self.model.classifier
            keep = min(n_old, previous_head.out_features, new_head.out_features)
            with torch.no_grad():
                new_head.weight[:keep].copy_(previous_head.weight[:keep])
                if new_head.bias is not None and previous_head.bias is not None:
                    new_head.bias[:keep].copy_(previous_head.bias[:keep])

    def train_task(self, dataset, epochs=10, batch_size=64, lr=1e-3):
        """Train on ``dataset`` with cross-entropy plus distillation on old classes.

        Args:
            dataset: yields ``(features, label)`` pairs.
            epochs: number of passes over the data.
            batch_size: mini-batch size.
            lr: Adam learning rate.
        """
        self.model.train()

        loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
        optimizer = torch.optim.Adam(self.model.parameters(), lr=lr)
        num_batches = max(len(loader), 1)  # avoid dividing by zero below

        for epoch in range(epochs):
            total_loss = 0.0

            for x, y in loader:
                x = x.to(self.device)
                y = y.to(self.device).long()  # cross-entropy needs int64 targets

                logits, _ = self.model(x)
                loss = F.cross_entropy(logits, y)

                if self.old_model is not None:
                    with torch.no_grad():
                        old_logits, _ = self.old_model(x)
                    # Only the outputs the teacher knows about are distilled;
                    # the new-class logits are left to the classification loss.
                    n_old = old_logits.size(1)
                    loss = loss + self.distillation_loss(old_logits, logits[:, :n_old])

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                total_loss += loss.item()

            print(f"Epoch {epoch+1} | Loss: {total_loss / num_batches:.4f}")

    def _extract_features(self, dataset, indices, batch_size=256):
        """Return CPU features (len(indices), dim) for the given dataset positions."""
        self.model.eval()
        chunks = []
        with torch.no_grad():
            for start in range(0, len(indices), batch_size):
                batch = indices[start:start + batch_size]
                x = torch.stack([dataset[int(i)][0] for i in batch]).to(self.device)
                _, feats = self.model(x)
                chunks.append(feats.cpu())
        return torch.cat(chunks, dim=0)

    def build_exemplars(self, dataset):
        """Select, per class, the samples closest to the class mean in feature space.

        Each class gets ``memory_size // len(seen_classes)`` slots. Stored
        values are positions within ``dataset``. A class with no sample in
        ``dataset`` keeps its previous exemplars, trimmed to the new budget
        (they are ordered closest-first); those positions still refer to the
        dataset they were selected from.
        """
        if not self.seen_classes:
            return

        m = self.memory_size // len(self.seen_classes)
        labels = np.asarray(dataset.labels)

        for cls in self.seen_classes:
            idxs = np.where(labels == cls)[0]

            if idxs.size == 0:
                if cls in self.exemplars:
                    self.exemplars[cls] = self.exemplars[cls][:m]
                continue

            feats = self._extract_features(dataset, idxs)
            class_mean = feats.mean(dim=0)

            distances = torch.norm(feats - class_mean, dim=1)
            closest = torch.argsort(distances)[:m].numpy()

            self.exemplars[cls] = idxs[closest]

    def compute_class_means(self, dataset):
        """Recompute the unit-norm prototype of every class from its exemplars.

        ``dataset`` must be the one the exemplar positions refer to. Classes
        with an empty exemplar set get no prototype.
        """
        self.class_means = {}

        for cls, exemplar_idxs in self.exemplars.items():
            if len(exemplar_idxs) == 0:
                continue

            feats = self._extract_features(dataset, exemplar_idxs)
            # Prototypes are compared against unit-norm features in predict,
            # so they must be normalised as well.
            self.class_means[cls] = F.normalize(feats.mean(dim=0), dim=0)

    def predict(self, x):
        """Assign each row of ``x`` to the class with the nearest prototype.

        Returns:
            1-D CPU tensor of class ids, one per input row.

        Raises:
            ValueError: if no class prototypes have been computed yet.
        """
        if not self.class_means:
            raise ValueError(
                "No class means available; call compute_class_means() first."
            )

        self.model.eval()
        class_ids = list(self.class_means)
        prototypes = torch.stack(
            [self.class_means[cls] for cls in class_ids]
        ).to(self.device)

        with torch.no_grad():
            _, feats = self.model(x.to(self.device))
            feats = F.normalize(feats, dim=1)

            # (batch, 1, dim) - (1, classes, dim) -> pairwise Euclidean distances
            dists = (feats.unsqueeze(1) - prototypes.unsqueeze(0)).norm(dim=2)
            nearest = dists.argmin(dim=1).cpu()

        return torch.tensor(class_ids)[nearest]