In [14]:
import os
import math
import numpy as np
import pandas as pd
from collections import defaultdict

import torch
import torch.nn as nn
import torch.nn.functional as F

from sklearn.preprocessing import StandardScaler
from sklearn.metrics import confusion_matrix, accuracy_score

from torch_geometric.data import Data
from torch_geometric.nn import GCNConv, SAGEConv

In [15]:
data = pd.read_parquet("custom_features_dataset.parquet")

In [16]:
data.columns

Index(['Timestamp', 'Src IP', 'Dst IP', 'Bwd Packet Length Min', 'Protocol_6',
       'Bwd Packets/s', 'FWD Init Win Bytes', 'Packet Length Std',
       'FIN Flag Count', 'SrcPortRange_registered', 'Packet Length Min',
       'Fwd Seg Size Min', 'DstPortRange_well_known', 'Bwd IAT Total',
       'SYN Flag Count', 'Bwd Packet Length Std', 'target'],
      dtype='object')

In [None]:
split_ratio = 0.8
train_df = data[:int(len(data)*split_ratio)]
test_df = data[int(len(data)*split_ratio):]

In [17]:
TIMESTAMP_COL = "Timestamp"
SRC_COL = "Src IP"
DST_COL = "Dst IP"
LABEL_COL = "target"

In [18]:
EDGE_COLS = None

In [19]:
BIN_SECONDS = 300

In [20]:
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print("Device:", DEVICE)

Device: cuda


In [27]:
def time_posenc(epoch_seconds: np.ndarray, periods=(60, 300, 3600, 86400)) -> np.ndarray:
    """
    epoch_seconds: shape [E]
    returns: shape [E, 2*len(periods)]
    """
    s = np.asarray(epoch_seconds, dtype=float).reshape(-1, 1)   # [E, 1]
    feats = []
    for P in periods:
        w = 2.0 * math.pi / float(P)
        feats.append(np.sin(w * s))    # [E, 1]
        feats.append(np.cos(w * s))    # [E, 1]
    return np.concatenate(feats, axis=1)  # [E, 2*len(periods)]


In [28]:
def build_snapshots(df: pd.DataFrame,
                    edge_cols,
                    fit_scaler: bool,
                    scaler_edge: StandardScaler | None = None,
                    bin_seconds: int = 300,
                    device: str = "cpu"):
    """
    Returns:
      snapshots: list[Data] (PyG) in time order
      ip2idx: dict mapping IP -> node index (stable across this split)
      scaler_edge: fitted StandardScaler for edge features
      edge_cols_kept: list of columns used (existing + non-NA) + time enc names
    """
    ID_COLS = [SRC_COL, DST_COL, TIMESTAMP_COL]
    cols_needed = [c for c in ID_COLS + edge_cols + [LABEL_COL] if c in df.columns]

    df = df[cols_needed].dropna(subset=[SRC_COL, DST_COL])
    df = bin_time(df, bin_seconds=bin_seconds)

    # Edge feature scaler
    if scaler_edge is None:
        scaler_edge = StandardScaler()
        fit_scaler = True
    if fit_scaler:
        scaler_edge.fit(df[edge_cols].astype(float).values)

    # Stable node indexing across all snapshots
    ips = pd.Index(pd.unique(pd.concat([df[SRC_COL], df[DST_COL]])))
    ip2idx = {ip: i for i, ip in enumerate(ips)}

    snapshots = []
    prev_activity = defaultdict(int)  # 1-bin lag activity per node

    for b, g in df.sort_values('_bin').groupby('_bin'):
        # Nodes -> ids
        src = g[SRC_COL].map(ip2idx).astype(int).values
        dst = g[DST_COL].map(ip2idx).astype(int).values
        edge_index = torch.tensor(np.vstack([src, dst]), dtype=torch.long)

        # Edge features = scaled numeric + time encodings
        eX = scaler_edge.transform(g[edge_cols].astype(float).values)
        tfe = time_posenc(g['_epoch'].values)  # [E, 2*len(periods)]
        edge_attr = torch.tensor(np.hstack([eX, tfe]), dtype=torch.float)

        # Labels (edge-level)
        y = torch.tensor(g[LABEL_COL].astype(int).values, dtype=torch.long)

        # Node features (inductive): in/out/total degree + prev-bin activity
        n_nodes = len(ip2idx)
        out_deg = np.bincount(src, minlength=n_nodes)
        in_deg = np.bincount(dst, minlength=n_nodes)
        deg = (out_deg + in_deg).reshape(-1, 1)
        node_feat = np.hstack([
            in_deg.reshape(-1, 1),
            out_deg.reshape(-1, 1),
            deg,
            np.array([prev_activity[i] for i in range(n_nodes)]).reshape(-1, 1)
        ])
        node_feat = np.log1p(node_feat)  # stabilize scale
        x = torch.tensor(node_feat, dtype=torch.float)

        data = Data(
            x=x,
            edge_index=edge_index,
            edge_attr=edge_attr,
            y=y
        )
        data._bin = int(b)
        snapshots.append(data)

        # Update prev activity
        touched = np.bincount(np.concatenate([src, dst]), minlength=n_nodes)
        for i, c in enumerate(touched):
            prev_activity[i] = int(c)

    time_names = [f"time_{i}" for i in range(tfe.shape[1])] if len(snapshots) > 0 else []
    return snapshots, ip2idx, scaler_edge, edge_cols + time_names

In [29]:
class WeightGRU_H(nn.Module):
    """EGCU-H: evolve weights using current node features + previous weights."""
    def __init__(self, in_feat_dim, w_in_dim, w_out_dim):
        super().__init__()
        self.w_in_dim = w_in_dim
        self.w_out_dim = w_out_dim
        self.h_dim = w_in_dim * w_out_dim
        self.xproj = nn.Linear(in_feat_dim, self.h_dim)
        self.gru = nn.GRUCell(self.h_dim, self.h_dim)

    def forward(self, x_t, h_prev):
        x_sum = x_t.mean(dim=0, keepdim=True)           # [1, F]
        x_in = self.xproj(x_sum).squeeze(0)             # [h_dim]
        h_t = self.gru(x_in, h_prev)                    # [h_dim]
        W_t = h_t.view(self.w_out_dim, self.w_in_dim)
        return W_t, h_t


class WeightGRU_O(nn.Module):
    """EGCU-O: evolve weights only from previous weights."""
    def __init__(self, w_in_dim, w_out_dim):
        super().__init__()
        self.w_in_dim = w_in_dim
        self.w_out_dim = w_out_dim
        self.h_dim = w_in_dim * w_out_dim
        self.gru = nn.GRUCell(self.h_dim, self.h_dim)

    def forward(self, _x_t_unused, h_prev):
        zeros = torch.zeros_like(h_prev)
        h_t = self.gru(zeros, h_prev)
        W_t = h_t.view(self.w_out_dim, self.w_in_dim)
        return W_t, h_t


class EvolveGCNEdgeClassifier(nn.Module):
    """
    Drop-in for GraphTimeEdgeClassifier:
    - forward(data) -> logits [E, C]
    - call reset_state() at the start of each snapshot sequence
    """
    def __init__(self, in_node, in_edge, hidden=64, out_hidden=64,
                 num_layers=2, dropout=0.1, variant='H', num_classes=2):
        super().__init__()
        assert num_layers >= 1
        self.num_layers = num_layers
        self.dropout = nn.Dropout(dropout)
        self.act = nn.ReLU()

        # GCN layers (weights updated each step)
        self.convs = nn.ModuleList()
        dims = [in_node] + [hidden]*(num_layers-1) + [out_hidden]
        for i in range(num_layers):
            self.convs.append(GCNConv(dims[i], dims[i+1], bias=False))

        # Evolvers
        self.variant = variant.upper()
        self.evolvers = nn.ModuleList()
        for i in range(num_layers):
            in_dim = dims[i]
            out_dim = dims[i+1]
            if self.variant == 'H':
                self.evolvers.append(WeightGRU_H(in_dim, w_in_dim=in_dim, w_out_dim=out_dim))
            elif self.variant == 'O':
                self.evolvers.append(WeightGRU_O(w_in_dim=in_dim, w_out_dim=out_dim))
            else:
                raise ValueError("variant must be 'H' or 'O'")

        # Hidden states (flattened weights)
        self._states = [None]*num_layers
        self._on_device = None  # track device for init
        self.edge_mlp = nn.Sequential(
            nn.Linear(2*out_hidden + in_edge, out_hidden),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(out_hidden, num_classes)
        )

    @torch.no_grad()
    def reset_state(self, hard=True):
        """Reinitialize the evolving weights (call at start of a sequence)."""
        for i, conv in enumerate(self.convs):
            device = conv.lin.weight.device
            W = torch.empty(conv.out_channels, conv.in_channels, device=device)
            nn.init.xavier_uniform_(W)
            self._states[i] = W.flatten().detach().clone()
            conv.lin.weight.copy_(W)

    def detach_state(self):
        """Detach hidden states between steps (for truncated BPTT if needed)."""
        self._states = [h.detach() for h in self._states]

    def _evolve_and_set(self, layer_idx, x_t):
        h_prev = self._states[layer_idx]
        evolver = self.evolvers[layer_idx]
        W_t, h_t = evolver(x_t, h_prev)
        self._states[layer_idx] = h_t
        with torch.no_grad():
            self.convs[layer_idx].lin.weight.copy_(W_t)
        return W_t

    def forward(self, data: Data):
        x, edge_index, edge_attr = data.x, data.edge_index, data.edge_attr
        for l, conv in enumerate(self.convs):
            self._evolve_and_set(l, x)
            x = conv(x, edge_index)
            x = self.act(x)
            x = self.dropout(x)

        src, dst = edge_index
        z = torch.cat([x[src], x[dst], edge_attr], dim=1)
        logits = self.edge_mlp(z)
        return logits

In [30]:
def pick_edge_cols(df: pd.DataFrame):
    if EDGE_COLS is not None:
        return [c for c in EDGE_COLS if c in df.columns]
    # auto-detect numeric columns (exclude ids and label)
    excluded = {TIMESTAMP_COL, SRC_COL, DST_COL, LABEL_COL}
    cand = []
    for c in df.columns:
        if c in excluded: 
            continue
        # keep only numeric-like columns (ignore strings/categoricals)
        try:
            pd.to_numeric(df[c].dropna().iloc[:10])
            cand.append(c)
        except Exception:
            pass
    return cand

In [31]:
edge_cols = pick_edge_cols(train_df)
print(f"Using {len(edge_cols)} edge feature columns:", edge_cols[:10], "..." if len(edge_cols) > 10 else "")

Using 13 edge feature columns: ['Bwd Packet Length Min', 'Protocol_6', 'Bwd Packets/s', 'FWD Init Win Bytes', 'Packet Length Std', 'FIN Flag Count', 'SrcPortRange_registered', 'Packet Length Min', 'Fwd Seg Size Min', 'DstPortRange_well_known'] ...


In [32]:
print("Building snapshots...")
scaler = StandardScaler().fit(train_df[edge_cols].astype(float).values)

train_snaps, train_ip2idx, scaler, edge_cols_full = build_snapshots(
    train_df, edge_cols=edge_cols, fit_scaler=False, scaler_edge=scaler,
    bin_seconds=BIN_SECONDS, device=DEVICE
)
test_snaps,  test_ip2idx,  _,      _ = build_snapshots(
    test_df,  edge_cols=edge_cols, fit_scaler=False, scaler_edge=scaler,
    bin_seconds=BIN_SECONDS, device=DEVICE
)

print(f"Train snapshots: {len(train_snaps)} | Test snapshots: {len(test_snaps)}")
in_node_dim = train_snaps[0].x.size(1)
in_edge_dim = train_snaps[0].edge_attr.size(1)
print("Node feature dim:", in_node_dim, "| Edge feature dim:", in_edge_dim)

Building snapshots...
Train snapshots: 673 | Test snapshots: 147
Node feature dim: 4 | Edge feature dim: 21


In [33]:
def train_one_sequence(model, optimizer, snapshots, device="cpu", detach_every=None):
    """
    Train across an ordered list of snapshots (one sequence).
    detach_every: if set (int), detach model state every N steps (truncated BPTT)
    """
    model.train()
    model.reset_state(hard=False)  # keep evolving but ensure conv weights are set

    total_loss = 0.0
    total_edges = 0

    for t, data in enumerate(snapshots):
        data = data.to(device)
        logits = model(data)
        loss = F.cross_entropy(logits, data.y)
        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 5.0)
        optimizer.step()

        total_loss += float(loss) * data.y.numel()
        total_edges += int(data.y.numel())

        if detach_every and (t + 1) % detach_every == 0:
            model.detach_state()

    return total_loss / max(total_edges, 1)


In [34]:
@torch.no_grad()
def eval_sequences(model, sequences, device="cpu"):
    model.eval()
    all_preds, all_true = [], []

    for snaps in sequences:
        model.reset_state(hard=False)
        for data in snaps:
            data = data.to(device)
            logits = model(data)
            pred = logits.argmax(dim=1).cpu().numpy()
            y = data.y.cpu().numpy()
            all_preds.append(pred)
            all_true.append(y)
    y_true = np.concatenate(all_true) if all_true else np.array([])
    y_pred = np.concatenate(all_preds) if all_preds else np.array([])
    acc = accuracy_score(y_true, y_pred) if y_true.size else 0.0
    cm = confusion_matrix(y_true, y_pred, labels=[0, 1]) if y_true.size else np.zeros((2, 2), dtype=int)
    return acc, cm

In [35]:
model = EvolveGCNEdgeClassifier(
    in_node=in_node_dim,
    in_edge=in_edge_dim,
    hidden=64,
    out_hidden=64,
    num_layers=2,
    dropout=0.1,
    variant='H',      # 'H' (EGCU-H) or 'O' (EGCU-O)
    num_classes=2
).to(DEVICE)

optimizer = torch.optim.AdamW([p for p in model.parameters() if p.requires_grad], lr=1e-3, weight_decay=1e-4)

# Treat train/test snapshots each as a single ordered sequence.
# If you have multiple sequences, wrap them as list-of-lists.
train_sequences = [train_snaps]
test_sequences  = [test_snaps]

EPOCHS = 10
DETACH_EVERY = 5  # optional truncated BPTT across long sequences

In [36]:
for epoch in range(1, EPOCHS + 1):
    # One pass over the (single) train sequence
    train_loss = train_one_sequence(model, optimizer, train_snaps, device=DEVICE, detach_every=DETACH_EVERY)

    # Eval on test
    acc, cm = eval_sequences(model, test_sequences, device=DEVICE)

    print(f"[Epoch {epoch:02d}] loss={train_loss:.4f} | test acc={acc:.4f}")
    print("Confusion matrix (rows=true [0,1], cols=pred [0,1]):\n", cm)

print("Done.")

Consider using tensor.detach() first. (Triggered internally at C:\actions-runner\_work\pytorch\pytorch\pytorch\torch\csrc\autograd\generated\python_variable_methods.cpp:836.)
  total_loss += float(loss) * data.y.numel()


[Epoch 01] loss=15.5214 | test acc=0.8326
Confusion matrix (rows=true [0,1], cols=pred [0,1]):
 [[5391219       0]
 [1084194       0]]
[Epoch 02] loss=6.9870 | test acc=0.8326
Confusion matrix (rows=true [0,1], cols=pred [0,1]):
 [[5391219       0]
 [1084194       0]]
[Epoch 03] loss=1.0274 | test acc=0.8326
Confusion matrix (rows=true [0,1], cols=pred [0,1]):
 [[5391219       0]
 [1084194       0]]
[Epoch 04] loss=0.8291 | test acc=0.9994
Confusion matrix (rows=true [0,1], cols=pred [0,1]):
 [[5389077    2142]
 [   1901 1082293]]
[Epoch 05] loss=0.3244 | test acc=0.9991
Confusion matrix (rows=true [0,1], cols=pred [0,1]):
 [[5387353    3866]
 [   1901 1082293]]
[Epoch 06] loss=0.0382 | test acc=0.8326
Confusion matrix (rows=true [0,1], cols=pred [0,1]):
 [[5391219       0]
 [1084194       0]]
[Epoch 07] loss=0.2955 | test acc=0.9986
Confusion matrix (rows=true [0,1], cols=pred [0,1]):
 [[5383761    7458]
 [   1901 1082293]]
[Epoch 08] loss=14.1701 | test acc=0.9928
Confusion matrix (r