# Notebook 04: CNN CVE Severity Classification (Refactor v2)

Reiner Refactor der bestehenden CNN-Experimente (PyTorch) zur Reduktion von Redundanz. Keine neuen Features, nur Konsolidierung:
- Zusammengefasste Imports & Konfiguration
- Gemeinsame Hilfsfunktionen (Vokabular, Dataset, Training)
- Baseline Run + Param-Sweep wie original
- Interpretierbarkeit (Filter n-Grams) beibehalten
- Persistenz: `results/metrics_cnn.csv`, `results/cnn_filter_activation_examples.json`


## 1. Imports & Globale Konfiguration

Sammelt alle Bibliotheken, stellt Seed-Konfiguration und globale Hyperparameter bereit (Vokabular, Embedding-Dimensionen, Filtergrößen, Trainingsparameter). Ergebnisse & Artefakte werden nach `results/` geschrieben.


Setzt deterministische Seeds für Python, NumPy und PyTorch (inkl. CUDA falls verfügbar). Hinweis: Volle Deterministik kann Performance kosten, ist hier aber für Vergleichbarkeit wichtiger.

In [None]:
# Consolidated imports & config (placed early so helpers can use them)
import os, random, json, math, time
from pathlib import Path
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
from sklearn.metrics import f1_score, accuracy_score, classification_report, confusion_matrix

SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)
# Seeding alone does not make GPU runs repeatable: cuDNN may choose
# non-deterministic convolution kernels and, with benchmark=True, autotune
# them per run. Pin both so experiments stay comparable (at some speed cost).
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('Device:', device)

# Support running from the repo root or from a sub-directory (e.g. notebooks/).
DATA_PROCESSED = Path('data/processed') if Path('data/processed').exists() else Path('..') / 'data' / 'processed'
RESULTS_DIR = Path('results') if Path('results').exists() else Path('..') / 'results'
RESULTS_DIR.mkdir(exist_ok=True, parents=True)

# Hyperparameters
MAX_LEN = 160               # tokens per document after truncation/padding
BATCH_SIZE = 128
EPOCHS = 5
PATIENCE = 2                # epochs without val macro-F1 improvement before stopping
LR = 1e-3
FILTER_SIZES = [3,4,5]      # conv kernel widths (n-gram sizes)
FILTERS_PER_SIZE = 128
DROPOUT = 0.5
EMBED_DIMS = [50, 100, 200]  # first entry is used by the baseline run
VOCAB_SIZES = [10000, 25000]
MIN_FREQS = [1, 2]
CLASS_NAMES = ['low','medium','high','critical']
CLASS_TO_ID = {c:i for i,c in enumerate(CLASS_NAMES)}

print({'results_dir': str(RESULTS_DIR), 'data_dir': str(DATA_PROCESSED)})

## 2. Daten- & Vokabular-Hilfsfunktionen

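Lädt die Datenvarianten, erkennt Text- und Labelspalten heuristisch, tokenisiert (Kleinschreibung, Aufteilung an Whitespace) und baut das Vokabular mit Mindesthäufigkeit `min_freq` und Obergrenze `vocab_size` auf (inkl. der reservierten Tokens `<pad>` = 0 und `<unk>` = 1). Das Dataset kapselt Padding und Trunkierung auf `MAX_LEN`.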

In [None]:
# Utilities: load variants, detect columns, vocab, dataset
from collections import Counter
from dataclasses import dataclass, asdict

VARIANT_FILES = {
    'raw': 'cves_processed_text_raw.csv',
    'clean': 'cves_processed_text_clean.csv',
    'raw_lemma': 'cves_processed_text_raw_lemma.csv',
    'clean_lemma': 'cves_processed_text_clean_lemma.csv'
}

# Column names tried in order before falling back to heuristics.
TEXT_CANDIDATES = ['description_clean','description','text','summary']
LABEL_CANDIDATES = ['severity','cvss_severity','baseSeverity','label']

def load_variant(name: str) -> pd.DataFrame:
    """Read the processed CSV for preprocessing variant ``name``.

    Raises:
        KeyError: if ``name`` is not a key of ``VARIANT_FILES``.
    """
    return pd.read_csv(DATA_PROCESSED / VARIANT_FILES[name])

def _is_textual_dtype(dtype) -> bool:
    """True for object columns and pandas string dtypes (default str dtype in pandas >= 3)."""
    return dtype == 'object' or pd.api.types.is_string_dtype(dtype)

def detect_columns(df: pd.DataFrame):
    """Return ``(text_col, label_col)`` for ``df``.

    The label column is the first entry of ``LABEL_CANDIDATES`` present in
    ``df``. The text column is the first entry of ``TEXT_CANDIDATES`` present;
    failing that, the textual column (other than the label column) with the
    highest mean whitespace-token count.

    Raises:
        ValueError: if no label column or no usable text column is found.
    """
    label_col = next((c for c in LABEL_CANDIDATES if c in df.columns), None)
    if label_col is None:
        raise ValueError('No label column found.')
    text_col = next((c for c in TEXT_CANDIDATES if c in df.columns), None)
    if text_col is None:
        # Exclude the label column: using it as input would leak the target.
        object_cols = [c for c in df.columns if c != label_col and _is_textual_dtype(df[c].dtype)]
        if not object_cols:
            raise ValueError('No text column found.')
        avg_lengths = {c: df[c].astype(str).str.split().str.len().mean() for c in object_cols}
        text_col = max(avg_lengths, key=avg_lengths.get)
    return text_col, label_col

def tokenize(text: str):
    """Lowercase ``text`` and split it on whitespace."""
    return text.lower().strip().split()

def build_vocab(texts, vocab_size: int, min_freq: int):
    """Build a token -> id mapping from ``texts``.

    Ids 0 and 1 are reserved for ``'<pad>'`` and ``'<unk>'``. The remaining
    ``vocab_size - 2`` slots go to tokens with frequency >= ``min_freq``,
    ordered by descending frequency, ties broken alphabetically. Ids are
    always contiguous ``0 .. len(vocab) - 1``, so ``len(vocab)`` is a valid
    ``nn.Embedding`` size.

    Args:
        texts: Iterable of raw strings.
        vocab_size: Total vocabulary size including the two special tokens;
            values below 2 yield only the special tokens.
        min_freq: Minimum corpus frequency for a token to be kept.
    """
    specials = {'<pad>': 0, '<unk>': 1}
    counter = Counter()
    for t in texts:
        counter.update(tokenize(t))
    # Skip literal '<pad>'/'<unk>' tokens from the corpus: they would
    # otherwise be overwritten below, leaving an orphaned id >= len(vocab).
    items = [(tok, freq) for tok, freq in counter.items()
             if freq >= min_freq and tok not in specials]
    items.sort(key=lambda x: (-x[1], x[0]))
    # Clamp at 0: a negative slice bound would keep almost every token.
    trimmed = items[:max(vocab_size - len(specials), 0)]
    stoi = {tok: i + len(specials) for i, (tok, _) in enumerate(trimmed)}
    stoi.update(specials)
    return stoi

class TextDataset(Dataset):
    """Turn raw texts plus integer labels into fixed-length id tensors.

    Each text is tokenized with ``tokenize``, mapped through ``vocab``
    (tokens missing from the vocab get id 1), truncated to ``max_len`` ids
    and right-padded with id 0.
    """

    def __init__(self, texts, labels, vocab, max_len):
        self.texts, self.labels = texts, labels
        self.vocab, self.max_len = vocab, max_len

    def __len__(self):
        return len(self.texts)

    def encode(self, text):
        """Return a list of exactly ``max_len`` token ids for ``text``."""
        lookup = self.vocab.get
        clipped = [lookup(tok, 1) for tok in tokenize(text)][:self.max_len]
        return clipped + [0] * (self.max_len - len(clipped))

    def __getitem__(self, idx):
        token_ids = torch.tensor(self.encode(self.texts[idx]), dtype=torch.long)
        target = torch.tensor(self.labels[idx], dtype=torch.long)
        return token_ids, target


## 3. Modellarchitektur (Multi-Kernel CNN)

Embedding + parallele 1D-Convolutions (Kernelgrößen aus `FILTER_SIZES`), ReLU, Global-Max-Pooling, Konkatenation der Features, Dropout und lineare Klassifikationsschicht. Neben den Logits gibt `forward` die (detachten) Aktivierungen jeder Convolution zurück; diese werden für die Interpretierbarkeit ausgewertet.

In [None]:
# Model definition (multi-kernel CNN matching original logic)
class CNNTextClassifier(nn.Module):
    """Text CNN: embedding -> parallel Conv1d branches -> max pooling -> linear.

    Every entry of ``filter_sizes`` gets its own Conv1d with
    ``filters_per_size`` output channels and ``padding=width // 2``. The
    branch outputs are max-pooled over the sequence axis, concatenated,
    passed through dropout and projected to ``num_classes`` logits.
    """

    def __init__(self, vocab_size: int, embed_dim: int, num_classes: int, filter_sizes, filters_per_size: int, dropout: float, pad_idx: int = 0):
        super().__init__()
        # Creation order (embedding, convs, dropout, fc) is kept stable so
        # seeded weight initialisation is reproducible.
        self.embedding = nn.Embedding(vocab_size, embed_dim, padding_idx=pad_idx)
        branches = []
        for width in filter_sizes:
            branches.append(nn.Conv1d(embed_dim, filters_per_size, width, padding=width // 2))
        self.convs = nn.ModuleList(branches)
        self.dropout = nn.Dropout(dropout)
        pooled_width = filters_per_size * len(filter_sizes)
        self.fc = nn.Linear(pooled_width, num_classes)

    def forward(self, x):
        """Return ``(logits, activations)``.

        ``activations`` holds one detached ReLU feature map per conv branch,
        in the layout Conv1d produces: (batch, filters, positions).
        """
        # Conv1d wants channels first: (B, L, E) -> (B, E, L).
        embedded = self.embedding(x).permute(0, 2, 1)
        feature_maps = [torch.relu(branch(embedded)) for branch in self.convs]
        pooled = torch.cat([fmap.max(dim=2).values for fmap in feature_maps], dim=1)
        logits = self.fc(self.dropout(pooled))
        return logits, [fmap.detach() for fmap in feature_maps]

## 4. Trainings- & Evaluations-Helfer

Enthält Epochen-Training (`train_epoch`), Validierung (`eval_model`) sowie Extraktion aktivster n-Grams je Filter zur Interpretierbarkeit (`extract_filter_ngrams`).

In [None]:
def train_epoch(model, loader, optimizer, criterion):
    """Run one optimisation pass over ``loader``.

    Batches are moved to the module-level ``device``. Returns the mean
    training loss weighted by batch size, i.e. summed per-batch loss times
    batch size divided by ``len(loader.dataset)``.
    """
    model.train()
    loss_sum = 0
    for inputs, targets in loader:
        inputs = inputs.to(device)
        targets = targets.to(device)
        optimizer.zero_grad()
        batch_loss = criterion(model(inputs)[0], targets)
        batch_loss.backward()
        optimizer.step()
        loss_sum += batch_loss.item() * inputs.size(0)
    return loss_sum / len(loader.dataset)

def eval_model(model, loader, criterion):
    """Evaluate ``model`` on ``loader`` without gradient tracking.

    Returns ``(mean_loss, macro_f1, accuracy, (y_true, y_pred))``. The loss is
    weighted by batch size and averaged over ``len(loader.dataset)``. The
    label lists are plain Python ints in loader order.
    """
    model.eval()
    loss_sum = 0
    targets_all, preds_all = [], []
    with torch.no_grad():
        for inputs, targets in loader:
            inputs, targets = inputs.to(device), targets.to(device)
            logits = model(inputs)[0]
            loss_sum += criterion(logits, targets).item() * inputs.size(0)
            targets_all += targets.cpu().tolist()
            preds_all += logits.argmax(dim=1).cpu().tolist()
    macro_f1 = f1_score(targets_all, preds_all, average='macro')
    accuracy = accuracy_score(targets_all, preds_all)
    return loss_sum / len(loader.dataset), macro_f1, accuracy, (targets_all, preds_all)

# Interpretability (filter n-grams) based on original logic

def extract_filter_ngrams(model, vocab_inv, sample_loader, max_per_filter=2):
    """Find, for every convolution filter, the input n-gram that fires it hardest.

    Only the first batch of ``sample_loader`` is inspected. For each conv layer
    ``model.convs[i]`` and each of its filters, the single strongest activation
    across that batch is located and mapped back to the input tokens inside
    its receptive field.

    Args:
        model: Module with parameters and a ``convs`` ModuleList of Conv1d
            layers, whose forward returns ``(logits, activations)`` with one
            ``(B, F, L_out)`` tensor per conv layer.
        vocab_inv: Mapping token id -> token string; unknown ids become ``'<unk>'``.
        sample_loader: Iterable of ``(inputs, labels)`` batches, ``inputs`` being
            a ``(B, L)`` tensor of token ids.
        max_per_filter: Maximum number of examples kept per (layer, filter).

    Returns:
        List of dicts with keys ``layer``, ``kernel_size``, ``filter_index``,
        ``activation`` (float) and ``ngram`` (list of token strings).
    """
    model.eval()
    # Run on the device the model actually lives on.
    model_device = next(model.parameters()).device
    results = []
    with torch.no_grad():
        for xb, _ in sample_loader:
            xb = xb.to(model_device)
            _, activations = model(xb)
            seq_len = xb.size(1)
            for layer_idx, acts in enumerate(activations):
                if acts.numel() == 0:
                    continue
                conv = model.convs[layer_idx]
                ksize = conv.kernel_size[0]
                # With Conv1d(padding=p), output position `pos` covers input
                # tokens pos - p .. pos - p + ksize - 1 (not pos .. pos + ksize - 1).
                pad = conv.padding[0]
                n_filters, out_len = acts.shape[1], acts.shape[2]
                # (F, B * L_out): one max per filter over the whole batch.
                per_filter = acts.permute(1, 0, 2).reshape(n_filters, -1)
                max_vals, flat_idx = per_filter.max(dim=1)
                max_vals, flat_idx = max_vals.tolist(), flat_idx.tolist()
                for f in range(n_filters):
                    b, pos = divmod(flat_idx[f], out_len)
                    start = pos - pad
                    # Clip the window to real input positions (drop conv padding).
                    lo, hi = max(start, 0), min(start + ksize, seq_len)
                    token_ids = xb[b, lo:hi].cpu().tolist()
                    results.append({
                        'layer': layer_idx,
                        'kernel_size': ksize,
                        'filter_index': f,
                        'activation': max_vals[f],
                        'ngram': [vocab_inv.get(tid, '<unk>') for tid in token_ids],
                    })
            break  # only the first batch is inspected
    # Keep at most `max_per_filter` strongest examples per (layer, filter).
    grouped = {}
    for r in results:
        grouped.setdefault((r['layer'], r['filter_index']), []).append(r)
    pruned = []
    for examples in grouped.values():
        examples.sort(key=lambda x: -x['activation'])
        pruned.extend(examples[:max_per_filter])
    return pruned

## 5. Baseline Lauf

Erstellt einen stratifizierten Train/Val/Test-Split (70/15/15), baut ein Vokabular mit dem jeweils ersten Wert der Hyperparameter-Listen (`VOCAB_SIZES`, `MIN_FREQS`, `EMBED_DIMS`) und trainiert ein einzelnes CNN als Ausgangspunkt. Persistiert Metriken und erste Filter-Beispiele.

In [None]:
available_variants = [v for v, f in VARIANT_FILES.items() if (DATA_PROCESSED / f).exists()]
if not available_variants:
    raise FileNotFoundError('No variant files available.')
BASE_VARIANT = available_variants[0]
print('Baseline variant:', BASE_VARIANT)

df_base = load_variant(BASE_VARIANT)
text_col, label_col = detect_columns(df_base)
# Filter to known classes
if label_col == 'severity':
    df_base = df_base[df_base[label_col].str.lower().isin(CLASS_NAMES)].copy()
labels_raw = df_base[label_col].astype(str).str.lower()
# Label ids follow alphabetical order of the observed label strings.
label_to_id = {l: i for i, l in enumerate(sorted(labels_raw.unique()))}
id_to_label = {v: k for k, v in label_to_id.items()}
labels_int = labels_raw.map(label_to_id).values
texts = df_base[text_col].fillna('').astype(str).values

# 70/15/15: hold out 30%, then split the hold-out half/half into val and test.
X_train_txt, X_temp_txt, y_train, y_temp = train_test_split(texts, labels_int, test_size=0.3, stratify=labels_int, random_state=SEED)
X_val_txt, X_test_txt, y_val, y_test = train_test_split(X_temp_txt, y_temp, test_size=0.5, stratify=y_temp, random_state=SEED)
print({'train': len(X_train_txt), 'val': len(X_val_txt), 'test': len(X_test_txt)})

# Vocabulary is built on the training split only (no val/test leakage).
vocab = build_vocab(X_train_txt, vocab_size=VOCAB_SIZES[0], min_freq=MIN_FREQS[0])
vocab_inv = {v: k for k, v in vocab.items()}

train_ds = TextDataset(X_train_txt, y_train, vocab, MAX_LEN)
val_ds = TextDataset(X_val_txt, y_val, vocab, MAX_LEN)
test_ds = TextDataset(X_test_txt, y_test, vocab, MAX_LEN)
train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_ds, batch_size=BATCH_SIZE)
test_loader = DataLoader(test_ds, batch_size=BATCH_SIZE)

model = CNNTextClassifier(vocab_size=len(vocab), embed_dim=EMBED_DIMS[0], num_classes=len(label_to_id), filter_sizes=FILTER_SIZES, filters_per_size=FILTERS_PER_SIZE, dropout=DROPOUT).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=LR)
criterion = nn.CrossEntropyLoss()

best_f1 = -1
best_state = None
epochs_no_improve = 0
history = {'train_loss':[], 'val_loss':[], 'train_f1':[], 'val_f1':[], 'train_acc':[], 'val_acc':[]}
for ep in range(EPOCHS):
    tr_loss = train_epoch(model, train_loader, optimizer, criterion)
    val_loss, val_f1, val_acc, _ = eval_model(model, val_loader, criterion)
    history['train_loss'].append(tr_loss)
    history['val_loss'].append(val_loss)
    history['val_f1'].append(val_f1)
    history['val_acc'].append(val_acc)
    # train_f1/train_acc stay empty: computing them would need an extra pass.
    print(f"Epoch {ep+1}/{EPOCHS} tr_loss={tr_loss:.4f} val_loss={val_loss:.4f} val_f1={val_f1:.4f} val_acc={val_acc:.4f}")
    if val_f1 > best_f1:
        best_f1 = val_f1
        # state_dict() returns live references to the parameters; clone them,
        # otherwise later epochs silently overwrite the "best" snapshot.
        best_state = {name: tensor.clone() for name, tensor in model.state_dict().items()}
        epochs_no_improve = 0
    else:
        epochs_no_improve += 1
        if epochs_no_improve >= PATIENCE:
            print('Early stopping')
            break

if best_state is not None:
    model.load_state_dict(best_state)

# Test evaluation (on the best-validation weights)
_, test_f1, test_acc, (y_true_test, y_pred_test) = eval_model(model, test_loader, criterion)
print('Test macro F1:', test_f1, 'Test Acc:', test_acc)

# Confusion matrix
cm = confusion_matrix(y_true_test, y_pred_test)
print('Confusion matrix:\n', cm)

# Interpretability sample
sample_loader = DataLoader(val_ds, batch_size=BATCH_SIZE)
filter_examples = extract_filter_ngrams(model, vocab_inv, sample_loader, max_per_filter=2)

# Persist single baseline metrics (append or create)
metrics_path = RESULTS_DIR / 'metrics_cnn.csv'
rec = {
    'phase': 'baseline',
    'variant': BASE_VARIANT,
    'embed_dim': EMBED_DIMS[0],
    'vocab_size': VOCAB_SIZES[0],
    'min_freq': MIN_FREQS[0],
    'macro_f1': best_f1,  # best validation macro-F1
    'test_macro_f1': test_f1,
    'test_accuracy': test_acc,
    'n_params': sum(p.numel() for p in model.parameters())
}
if metrics_path.exists():
    dfm = pd.read_csv(metrics_path)
    dfm = pd.concat([dfm, pd.DataFrame([rec])], ignore_index=True)
else:
    dfm = pd.DataFrame([rec])
dfm.to_csv(metrics_path, index=False)

interpret_path = RESULTS_DIR / 'cnn_filter_activation_examples.json'
interpret_data = []
if interpret_path.exists():
    try:
        interpret_data = json.loads(interpret_path.read_text())
    except (ValueError, OSError):
        # Unreadable or corrupt file: start a fresh list (best effort).
        interpret_data = []
interpret_data.append({'config': rec, 'filter_examples': filter_examples})
interpret_path.write_text(json.dumps(interpret_data, indent=2))
print('Baseline recorded.')

## 6. Parameter-Sweep

Iteriert über alle verfügbaren Varianten und alle Hyperparameter-Kombinationen (`embed_dim`, `vocab_size`, `min_freq`) auf einem stratifizierten 80/20-Train/Val-Split. Jede Konfiguration wird inkrementell in `metrics_cnn.csv` gespeichert; für die größte Embedding-Dimension werden zusätzlich Filter-n-Gram-Beispiele abgelegt. Early Stopping erfolgt anhand des Validierungs-Macro-F1.

In [None]:
sweep_records = []
metrics_path = RESULTS_DIR / 'metrics_cnn.csv'
interpret_path = RESULTS_DIR / 'cnn_filter_activation_examples.json'

variants_to_run = [v for v in VARIANT_FILES if (DATA_PROCESSED / VARIANT_FILES[v]).exists()]
print('Variants available for sweep:', variants_to_run)

for variant in variants_to_run:
    dfv = load_variant(variant)
    tcol, lcol = detect_columns(dfv)
    # Same class filter as the baseline run.
    if lcol == 'severity':
        dfv = dfv[dfv[lcol].str.lower().isin(CLASS_NAMES)].copy()
    labels_raw = dfv[lcol].astype(str).str.lower()
    label_map = {l: i for i, l in enumerate(sorted(labels_raw.unique()))}
    inv_label_map = {v: k for k, v in label_map.items()}
    y_all = labels_raw.map(label_map).values
    texts_all = dfv[tcol].fillna('').astype(str).values
    if len(texts_all) < 100:
        print('Skip small variant', variant)
        continue
    # The sweep uses a single stratified 80/20 train/val split (no test split).
    X_train_txt, X_val_txt, y_train, y_val = train_test_split(texts_all, y_all, test_size=0.2, stratify=y_all, random_state=SEED)

    for embed_dim in EMBED_DIMS:
        for vocab_size in VOCAB_SIZES:
            for min_freq in MIN_FREQS:
                print(f"Run variant={variant} emb={embed_dim} vocab={vocab_size} min_freq={min_freq}")
                vocab = build_vocab(X_train_txt, vocab_size=vocab_size, min_freq=min_freq)
                vocab_inv = {v: k for k, v in vocab.items()}
                train_ds = TextDataset(X_train_txt, y_train, vocab, MAX_LEN)
                val_ds = TextDataset(X_val_txt, y_val, vocab, MAX_LEN)
                train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True)
                val_loader = DataLoader(val_ds, batch_size=BATCH_SIZE)
                model = CNNTextClassifier(vocab_size=len(vocab), embed_dim=embed_dim, num_classes=len(label_map), filter_sizes=FILTER_SIZES, filters_per_size=FILTERS_PER_SIZE, dropout=DROPOUT).to(device)
                optimizer = torch.optim.Adam(model.parameters(), lr=LR)
                criterion = nn.CrossEntropyLoss()
                best_f1 = -1
                best_state = None
                epochs_no_improve = 0
                start = time.time()
                for ep in range(EPOCHS):
                    tr_loss = train_epoch(model, train_loader, optimizer, criterion)
                    val_loss, val_f1, val_acc, _ = eval_model(model, val_loader, criterion)
                    print(f"  Ep {ep+1} tr_loss={tr_loss:.3f} val_loss={val_loss:.3f} f1={val_f1:.3f} acc={val_acc:.3f}")
                    if val_f1 > best_f1:
                        best_f1 = val_f1
                        # state_dict() returns live references to the parameters;
                        # clone them so later epochs cannot overwrite the snapshot.
                        best_state = {name: tensor.clone() for name, tensor in model.state_dict().items()}
                        epochs_no_improve = 0
                    else:
                        epochs_no_improve += 1
                        if epochs_no_improve >= PATIENCE:
                            print('  Early stop')
                            break
                duration = round(time.time()-start,2)
                if best_state is not None:
                    # Restore the best epoch and report its validation metrics.
                    model.load_state_dict(best_state)
                    val_loss, val_f1, val_acc, _ = eval_model(model, val_loader, criterion)
                rec = {
                    'phase': 'sweep',
                    'variant': variant,
                    'embed_dim': embed_dim,
                    'vocab_size': vocab_size,
                    'min_freq': min_freq,
                    'macro_f1': val_f1,
                    'accuracy': val_acc,
                    'train_time_s': duration,
                    'n_params': sum(p.numel() for p in model.parameters())
                }
                sweep_records.append(rec)
                # Incremental persistence
                if metrics_path.exists():
                    dfm = pd.read_csv(metrics_path)
                    dfm = pd.concat([dfm, pd.DataFrame([rec])], ignore_index=True)
                else:
                    dfm = pd.DataFrame([rec])
                dfm.to_csv(metrics_path, index=False)
                # Interpretability only for largest embedding dim
                if embed_dim == max(EMBED_DIMS):
                    sample_loader = DataLoader(val_ds, batch_size=BATCH_SIZE)
                    filt_examples = extract_filter_ngrams(model, vocab_inv, sample_loader, max_per_filter=2)
                    interpret_data = []
                    if interpret_path.exists():
                        try:
                            interpret_data = json.loads(interpret_path.read_text())
                        except (ValueError, OSError):
                            # Unreadable or corrupt file: start a fresh list (best effort).
                            interpret_data = []
                    interpret_data.append({'config': rec, 'filter_examples': filt_examples})
                    interpret_path.write_text(json.dumps(interpret_data, indent=2))
print('Sweep done; metrics at', metrics_path)

## 7. Verlauf & Trainingsmetriken (Baseline)

Visualisiert Verlaufsdaten (Loss, Accuracy, Macro-F1) des Baseline-Laufs sofern History vorhanden. Sweep-Metriken werden ausschließlich über die aggregierte CSV ausgewertet (separat in Evaluation Notebook).

In [None]:
import matplotlib.pyplot as plt

# Plot the baseline learning curves; the sweep only writes aggregate CSV rows.
if 'history' not in globals() or not history['train_loss']:
    print('No baseline history to plot.')
else:
    _panels = [
        ('Loss', ['train_loss', 'val_loss']),
        ('Val Acc', ['val_acc']),
        ('Val F1', ['val_f1']),
    ]
    plt.figure(figsize=(12,4))
    for _slot, (_title, _series) in enumerate(_panels, start=1):
        plt.subplot(1, 3, _slot)
        for _key in _series:
            plt.plot(history[_key], label=_key)
        plt.legend()
        plt.title(_title)
    plt.tight_layout()