# BGL Early Warning (Paper-style + Phase 3)

In [None]:
import os, json, random
import numpy as np
from tqdm import tqdm
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from sklearn.metrics import roc_auc_score, precision_recall_curve

!pip install drain3 -q
from drain3 import TemplateMiner
from drain3.template_miner_config import TemplateMinerConfig

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f"Device: {device}")

# ===== KAGGLE PATHS =====
BASE = "/kaggle/working"
BGL_FILE = "/kaggle/input/loghub-bgl-log-data/BGL.log"   # log file parsed in the PARSE BGL cell
CKPT_DIR = "/kaggle/input/bgltest"                       # directory holding the pretrained checkpoint
EW_OUTPUT = f"{BASE}/output-bglew3"                      # every artifact of this notebook is written here
os.makedirs(EW_OUTPUT, exist_ok=True)

# Model geometry
CONTEXT_LEN = 128
D_MODEL = 256
N_HEADS = 8
N_LAYERS = 4

# Reserved token ids
PAD = 0
CLS = 1
MASK = 2
SEP = 3

# Optimisation / reproducibility
BATCH_SIZE = 64
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Early-warning horizon and windowing
DELTA_MIN = 5
DELTA_SEC = DELTA_MIN * 60   # horizon expressed in seconds
WINDOW_SIZE = 128
STEP_SIZE = 8
SAMPLING_RATIO = 5           # negatives drawn per positive by BalancedDataset

In [None]:
# CREATE WINDOWS (True pre-warning on viable subset)
# Traces are already windows from sliding window sessionization
# Label = 1 if the first alert is at most DELTA_EVENTS events after the prefix end; Label = 0 otherwise

DELTA_EVENTS = 50  # A first alert at most this many events ahead makes a prefix positive


def create_prewarn_data(traces):
    """Build labelled prefixes from failure traces for pre-warning training.

    For every trace, prefixes ending at STRIDE, 2*STRIDE, ... (strictly before
    ``trace['first_alert']``) are emitted. A prefix is labelled 1 when the
    first alert lies at most DELTA_EVENTS events after its end, else 0.

    Args:
        traces: iterable of dicts with keys ``'tids'`` (token-id list) and
            ``'first_alert'`` (index of the first alert event).

    Returns:
        List of dicts with keys ``'tids'``, ``'label'`` and ``'distance'``
        (events between the prefix end and the first alert).
    """
    samples = []
    for trace in traces:
        alert_pos = trace['first_alert']
        token_ids = trace['tids']
        # Only cut points strictly before the alert are considered.
        for cut in range(STRIDE, alert_pos, STRIDE):
            gap = alert_pos - cut
            samples.append({
                'tids': token_ids[:cut],
                'label': int(gap <= DELTA_EVENTS),
                'distance': gap,
            })
    return samples

def create_normal_windows(traces, max_per_trace=3):
    """Build label-0 prefixes from normal traces.

    Prefix end points are STRIDE, 2*STRIDE, ... below
    ``min(WINDOW, len(trace['tids']))``; the cut-point range is sliced with
    ``[:max_per_trace]`` before prefixes are built.

    Args:
        traces: iterable of dicts with a ``'tids'`` list.
        max_per_trace: slice bound applied to each trace's cut points.

    Returns:
        List of ``{'tids': prefix, 'label': 0}`` dicts.
    """
    samples = []
    for trace in traces:
        token_ids = trace['tids']
        cut_points = range(STRIDE, min(WINDOW, len(token_ids)), STRIDE)
        samples.extend(
            {'tids': token_ids[:cut], 'label': 0}
            for cut in cut_points[:max_per_trace]
        )
    return samples

# Split train/val/test: 70% / 15% / 15% of a seeded shuffle, done per class.
rng = np.random.RandomState(SEED)
n_idx = rng.permutation(len(normal_traces))    # drawn first; order fixes the RNG stream
f_idx = rng.permutation(len(failure_traces))
n_split = int(0.7 * len(normal_traces))
f_split = int(0.7 * len(failure_traces))

# Cut each shuffled index array at the 70% and 85% marks.
n_train_part, n_val_part, n_test_part = np.split(
    n_idx, [n_split, int(0.85 * len(normal_traces))])
f_train_part, f_val_part, f_test_part = np.split(
    f_idx, [f_split, int(0.85 * len(failure_traces))])

train_normal = [normal_traces[i] for i in n_train_part]
val_normal = [normal_traces[i] for i in n_val_part]
test_normal = [normal_traces[i] for i in n_test_part]

train_failure = [failure_traces[i] for i in f_train_part]
val_failure = [failure_traces[i] for i in f_val_part]
test_failure = [failure_traces[i] for i in f_test_part]

for split_name, normals, failures in (
        ("Train", train_normal, train_failure),
        ("Val", val_normal, val_failure),
        ("Test", test_normal, test_failure)):
    print(f"{split_name}: {len(normals)} N, {len(failures)} F")

# Training windows: normal prefixes first, then pre-warning prefixes.
train_prewarn = create_prewarn_data(train_failure)
train_norm = create_normal_windows(train_normal)
train_windows = train_norm + train_prewarn

n_pos = sum(w['label'] == 1 for w in train_windows)
print(f"\nTraining windows: {len(train_windows):,} (pre-warning pos={n_pos:,})")


In [None]:
# LOAD CHECKPOINT
# NOTE: torch.load unpickles the file -- only point CKPT_DIR at trusted checkpoints.
ckpt = torch.load(f"{CKPT_DIR}/checkpoints_logbert_ep11.pt", map_location=device)
# Number of rows of the pretrained token-embedding matrix = size of the old vocabulary.
OLD_VOCAB_SIZE = len(ckpt['model_state_dict']['tok.weight'])
print(f"HDFS VOCAB: {OLD_VOCAB_SIZE}")

# Drain template miner used by the PARSE BGL cell to cluster log messages.
config = TemplateMinerConfig()
config.drain_sim_th = 0.4
config.drain_depth = 4
miner = TemplateMiner(config=config)

In [None]:
# PARSE BGL
print("PARSING BGL")

def get_cluster_id(result):
    """Extract the cluster id from a template-miner result.

    Args:
        result: either a dict (``'cluster_id'`` key, 0 when the key is
            absent) or an object exposing a ``cluster_id`` attribute.

    Returns:
        The cluster id carried by ``result``.
    """
    if isinstance(result, dict):
        return result.get('cluster_id', 0)
    return result.cluster_id

# Each parsed event is a tuple (timestamp, token_id, is_anomaly).
bgl_events = []
with open(BGL_FILE, 'r', errors='ignore') as f:
    for line in tqdm(f, desc="BGL"):
        line = line.strip()
        if not line:
            continue
        # At most 10 whitespace-separated fields; the last one is treated as the message.
        parts = line.split(None, 9)
        if len(parts) < 2:
            continue
        # The first field is '-' for normal lines; anything else marks an alert.
        is_anomaly = parts[0] != '-'
        try:
            timestamp = int(parts[1])
        except ValueError:
            # Non-numeric timestamp field: keep the event but without a time.
            # (Was a bare `except:`, which also swallowed KeyboardInterrupt etc.)
            timestamp = None
        message = parts[-1]
        # Offset cluster ids past the pretrained vocabulary so the two id spaces do not collide.
        tid = OLD_VOCAB_SIZE + get_cluster_id(miner.add_log_message(message))
        bgl_events.append((timestamp, tid, is_anomaly))

NUM_CLUSTERS = len(miner.drain.clusters)
# Old vocabulary + one id per Drain cluster + 10 spare slots.
VOCAB_SIZE = OLD_VOCAB_SIZE + NUM_CLUSTERS + 10
print(f"Events: {len(bgl_events):,}, BGL Clusters: {NUM_CLUSTERS}, Total VOCAB: {VOCAB_SIZE}")

In [None]:
# CREATE TRACES (Sliding window, viable pre-warning only)
print("\nCREATING TRACES (Window=100, Stride=20, viable only)")
WINDOW, STRIDE = 100, 20

# Slide a WINDOW-event window over the whole event stream in steps of STRIDE.
# Each trace keeps its token ids, the raw (timestamp, tid, is_anomaly) events,
# and the index of the first alert inside the window (WINDOW when there is none).
all_windows = []
n = len(bgl_events)
for i in range(0, n - WINDOW + 1, STRIDE):
    window_events = bgl_events[i:i+WINDOW]
    tids = [e[1] for e in window_events]
    first_alert = next((j for j, e in enumerate(window_events) if e[2]), WINDOW)
    all_windows.append({
        'tids': tids,
        'events': window_events,
        'first_alert': first_alert,
        'has_alert': first_alert < WINDOW
    })

# Separate normal vs failure
normal_traces = [w for w in all_windows if not w['has_alert']]
failure_traces_all = [w for w in all_windows if w['has_alert']]

# Viable pre-warning: alert at position >= STRIDE, i.e. at least STRIDE
# alert-free events precede the first alert.
failure_traces = [w for w in failure_traces_all if w['first_alert'] >= STRIDE]

print(f"Total windows: {len(all_windows):,}")
print(f"Normal: {len(normal_traces):,}")
print(f"Failure (all): {len(failure_traces_all):,}")
print(f"Failure (viable pre-warning, alert >= {STRIDE}): {len(failure_traces):,}")
print()
# Guard the ratio: a log without any alert window would otherwise raise ZeroDivisionError.
max_recall_pct = 100 * len(failure_traces) / len(failure_traces_all) if failure_traces_all else 0.0
print(f">>> Theoretical max recall: {max_recall_pct:.1f}%")
print(f">>> Training on viable subset only")


In [None]:
# CREATE WINDOWS (Viable Pre-Warning Subset)
WINDOW_SIZE, STEP_SIZE = 128, 8
MIN_TIME_TO_ANOM = 60  # Seconds required between trace start and its first anomaly


def filter_viable_traces(traces, min_time=MIN_TIME_TO_ANOM):
    """Select traces whose first anomaly occurs more than ``min_time`` seconds after the start.

    Traces without events or without any anomalous event are dropped. When
    either the first event's timestamp or the first anomaly's timestamp is
    falsy (e.g. ``None``), the lead time is taken as 0.

    Args:
        traces: iterable of dicts; ``'events'`` holds tuples whose item 0 is a
            timestamp and item 2 the anomaly flag.
        min_time: strict lower bound on the lead time in seconds.

    Returns:
        List of the traces that pass the filter, in input order.
    """
    kept = []
    for trace in traces:
        events = trace.get('events', [])
        if not events:
            continue
        first_anomaly = next((e for e in events if e[2]), None)
        if first_anomaly is None:
            continue  # nothing to warn about in this trace
        start_ts, anomaly_ts = events[0][0], first_anomaly[0]
        if start_ts and anomaly_ts:
            lead_seconds = anomaly_ts - start_ts
        else:
            lead_seconds = 0
        if lead_seconds > min_time:
            kept.append(trace)
    return kept

def create_prewarn_windows(traces):
    """Create windows that lie entirely BEFORE the first anomaly (true pre-warning).

    For each trace, windows of up to WINDOW_SIZE events are slid in steps of
    STEP_SIZE over the anomaly-free prefix ``[0, first_anom_idx)``. A window is
    labelled 1 when the first anomaly occurs at most DELTA_SEC seconds after
    the window's last event, else 0. When either timestamp is falsy the time
    to anomaly is treated as infinite (label 0).

    Behaviour fixed relative to the previous version:
      * traces without any anomaly are skipped (they used to raise IndexError
        on ``events[first_anom_idx]``);
      * the window ending exactly at the anomaly is now included (the loop
        bound was one start position short);
      * a prefix shorter than WINDOW_SIZE yields one truncated window, matching
        how ``create_normal_windows`` treats short traces, instead of nothing.

    Args:
        traces: iterable of dicts with ``'tids'`` and ``'events'``; each event
            is a tuple ``(timestamp, tid, is_anomaly)``.

    Returns:
        List of dicts with keys ``'tids'``, ``'label'`` and ``'time_to_anom'``.
    """
    windows = []
    for trace in traces:
        tids, events = trace['tids'], trace.get('events')
        if not events:
            continue
        first_anom_idx = next((j for j, e in enumerate(events) if e[2]), None)
        # No anomaly -> nothing to pre-warn; anomaly at index 0 -> empty prefix.
        if not first_anom_idx:
            continue
        first_anom_ts = events[first_anom_idx][0]

        # Last start index whose window still ends at or before the anomaly.
        last_start = max(0, first_anom_idx - WINDOW_SIZE)
        for i in range(0, last_start + 1, STEP_SIZE):
            end = min(i + WINDOW_SIZE, first_anom_idx)
            end_time = events[end - 1][0]
            if end_time and first_anom_ts:
                time_to_anom = first_anom_ts - end_time
            else:
                time_to_anom = float('inf')
            label = 1 if time_to_anom <= DELTA_SEC else 0
            windows.append({'tids': tids[i:end], 'label': label, 'time_to_anom': time_to_anom})
    return windows

def create_normal_windows(traces):
    """Cut label-0 windows out of normal traces.

    Windows of up to WINDOW_SIZE token ids start every ``STEP_SIZE * 4``
    positions (a coarser stride than the pre-warning windows, to subsample).
    A trace shorter than WINDOW_SIZE contributes a single window holding the
    whole trace.

    Args:
        traces: iterable of dicts with a ``'tids'`` list.

    Returns:
        List of ``{'tids': window, 'label': 0}`` dicts.
    """
    stride = STEP_SIZE * 4
    samples = []
    for trace in traces:
        token_ids = trace['tids']
        # max(1, ...) guarantees start=0 is visited even for short traces.
        n_starts = max(1, len(token_ids) - WINDOW_SIZE + 1)
        samples.extend(
            {'tids': token_ids[start:start + WINDOW_SIZE], 'label': 0}
            for start in range(0, n_starts, stride)
        )
    return samples

# 80/20 train/test split per class, from one seeded shuffle.
rng = np.random.RandomState(SEED)
n_idx = rng.permutation(len(normal_traces))    # normals are shuffled before failures
f_idx = rng.permutation(len(failure_traces))
n_split = int(0.8*len(normal_traces))
f_split = int(0.8*len(failure_traces))

n_train_part, n_test_part = np.split(n_idx, [n_split])
f_train_part, f_test_part = np.split(f_idx, [f_split])
train_normal = [normal_traces[i] for i in n_train_part]
test_normal = [normal_traces[i] for i in n_test_part]
train_failure_all = [failure_traces[i] for i in f_train_part]
test_failure_all = [failure_traces[i] for i in f_test_part]

# Keep only failure traces with enough lead time before their first anomaly.
train_failure = filter_viable_traces(train_failure_all)
test_failure = filter_viable_traces(test_failure_all)
print(f"Viable: train={len(train_failure)}/{len(train_failure_all)}, test={len(test_failure)}/{len(test_failure_all)}")

# Training set = subsampled normal windows followed by pre-anomaly windows.
train_prewarn = create_prewarn_windows(train_failure)
train_norm_windows = create_normal_windows(train_normal)
train_windows = train_norm_windows + train_prewarn

n_pos = sum(w['label'] == 1 for w in train_windows)
print(f"Train windows: {len(train_windows):,} (prewarn_pos={n_pos:,})")
print(f"Test: {len(test_normal)} N, {len(test_failure)} F (viable)")


In [None]:
# SAVE PARSED DATA (for re-evaluation without re-parsing)
import pickle
# Bundle everything a later session needs to re-evaluate without re-parsing the log.
save_data = dict(
    bgl_events=bgl_events,
    normal_traces=normal_traces,
    failure_traces=failure_traces,
    NUM_CLUSTERS=NUM_CLUSTERS,
    VOCAB_SIZE=VOCAB_SIZE,
    OLD_VOCAB_SIZE=OLD_VOCAB_SIZE,
    train_normal=train_normal,
    test_normal=test_normal,
    train_failure=train_failure,
    test_failure=test_failure,
)
with open(f"{EW_OUTPUT}/bgl_parsed_data.pkl", 'wb') as fh:
    pickle.dump(save_data, fh)
print(f"Saved: {EW_OUTPUT}/bgl_parsed_data.pkl")

In [None]:
# DATASET & MODEL
class BalancedDataset(Dataset):
    """Randomly resampling dataset mixing positive and negative windows.

    Items are not indexed deterministically: every ``ratio + 1``-th index draws
    a random positive window, all other indices draw a random negative one
    (falling back to positives when there are no negatives). Each item is
    tokenised as ``[CLS] tokens [SEP]`` padded to CONTEXT_LEN, with roughly 15%
    of the real tokens (at least one) replaced by MASK for the MLM objective.

    Args:
        windows: list of dicts with ``'tids'`` and ``'label'`` (0 or 1).
        vs: vocabulary size; token ids are clipped to ``vs - 1``.
        ratio: number of negative draws per positive draw.
    """

    def __init__(self, windows, vs, ratio=SAMPLING_RATIO):
        self.vs = vs
        self.ratio = ratio
        self.pos = [w for w in windows if w['label'] == 1]
        self.neg = [w for w in windows if w['label'] == 0]
        # One "epoch" covers every positive once on average plus `ratio` negatives each;
        # without positives it is simply the number of negatives.
        if self.pos:
            self.n = len(self.pos) * (1 + ratio)
        else:
            self.n = len(self.neg)

    def __len__(self):
        return self.n

    def __getitem__(self, idx):
        want_positive = bool(self.pos) and idx % (self.ratio + 1) == 0
        if want_positive:
            w = random.choice(self.pos)
        else:
            w = random.choice(self.neg if self.neg else self.pos)

        # Leave room for CLS and SEP; clip ids that fall outside the vocabulary.
        clipped = [min(t, self.vs - 1) for t in w['tids'][:CONTEXT_LEN - 2]]
        n_real = len(clipped)
        tok = [CLS] + clipped + [SEP] + [PAD] * (CONTEXT_LEN - n_real - 2)

        masked = list(tok)
        mlm_l = [-100] * len(tok)   # -100 = position ignored by the MLM loss
        if n_real > 0:
            n_mask = min(max(1, int(n_real * 0.15)), n_real)
            # +1 shifts from token-list positions to sequence positions (past CLS).
            picks = np.random.choice(n_real, n_mask, replace=False) + 1
            for p in picks:
                mlm_l[p] = masked[p]
                masked[p] = MASK

        return {
            'ids': torch.tensor(masked),
            'mask': torch.tensor([int(t != PAD) for t in masked]),
            'ew_label': torch.tensor(w['label'], dtype=torch.float),
            'mlm_labels': torch.tensor(mlm_l),
        }

class LogBERTEW(nn.Module):
    """Transformer encoder with an MLM head and an early-warning (EW) head.

    ``forward`` returns ``(mlm_logits, cls_embedding, ew_logit)``. The module
    also keeps a running "center" vector ``ctr`` of CLS embeddings, updated by
    ``upd`` and used by the training loops as a compactness target.

    Args:
        vs: vocabulary size (embedding rows and MLM output width).
    """

    def __init__(self, vs):
        super().__init__()
        # Attribute names and creation order are kept stable: checkpoint keys
        # and parameter-initialisation RNG consumption depend on them.
        self.tok = nn.Embedding(vs, D_MODEL, padding_idx=PAD)
        self.pos = nn.Embedding(CONTEXT_LEN, D_MODEL)
        self.drop = nn.Dropout(0.1)
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=D_MODEL,
            nhead=N_HEADS,
            dim_feedforward=D_MODEL * 4,
            dropout=0.1,
            activation='gelu',
            batch_first=True,
        )
        self.enc = nn.TransformerEncoder(encoder_layer, num_layers=N_LAYERS)
        self.head = nn.Linear(D_MODEL, vs)
        self.ew_head = nn.Linear(D_MODEL, 1)
        self.register_buffer('ctr', torch.zeros(D_MODEL))
        self.ci = False   # True once `ctr` holds a real estimate

    def forward(self, ids, mask=None):
        """Encode ``ids``; ``mask`` is 1 for real tokens and 0 for padding (optional)."""
        positions = torch.arange(ids.size(1), device=ids.device)
        x = self.tok(ids) + self.pos(positions)
        if mask is None:
            pad_mask = None
        else:
            pad_mask = (mask == 0)
        h = self.enc(self.drop(x), src_key_padding_mask=pad_mask)
        cls_vec = h[:, 0, :]
        return self.head(h), cls_vec, self.ew_head(cls_vec).squeeze(-1)

    def upd(self, e):
        """Update the center with the batch mean of ``e`` (EMA, 0.9 old / 0.1 new)."""
        with torch.no_grad():
            bc = e.mean(0)
            if self.ci:
                self.ctr = 0.9 * self.ctr + 0.1 * bc
            else:
                self.ctr = bc
            self.ci = True

In [None]:
# MODEL INIT
model = LogBERTEW(VOCAB_SIZE).to(device)
pretrained = ckpt['model_state_dict']

# Everything except the vocabulary-sized tensors is loaded by key; the name
# filter also skips any key containing 'head.' (which includes 'ew_head.').
state = {k: v for k, v in pretrained.items() if not ('tok.weight' in k or 'head.' in k)}
model.load_state_dict(state, strict=False)

with torch.no_grad():
    # Token embeddings: reuse the pretrained rows, draw fresh rows for the new ids.
    model.tok.weight[:OLD_VOCAB_SIZE] = pretrained['tok.weight']
    nn.init.normal_(model.tok.weight[OLD_VOCAB_SIZE:], std=0.02)
    # MLM head: reuse the pretrained rows; rows for new ids keep their default init.
    model.head.weight[:OLD_VOCAB_SIZE] = pretrained['head.weight']
    model.head.bias[:OLD_VOCAB_SIZE] = pretrained['head.bias']

# Start from the center stored with the checkpoint.
model.ctr = ckpt['center'].to(device)
model.ci = True
print(f"Model loaded: {OLD_VOCAB_SIZE} -> {VOCAB_SIZE}")

In [None]:
# TRAINING
# Loss weights: LAMBDA_MLM scales the masked-token cross-entropy, LAMBDA_VHM the
# mean squared distance between CLS embeddings and model.ctr, MU_EW the
# early-warning BCE loss.
LAMBDA_MLM, LAMBDA_VHM, MU_EW = 0.4, 0.1, 1.0
train_pos = sum(1 for w in train_windows if w['label']==1)
# Ratio of non-positive to positive windows; max(..., 1) avoids dividing by zero.
# NOTE(review): BalancedDataset already resamples positives, so this weight is
# applied on top of that resampling -- confirm the double compensation is intended.
POS_WEIGHT = (len(train_windows)-train_pos) / max(train_pos, 1)
train_ds = BalancedDataset(train_windows, VOCAB_SIZE)
train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True)
ew_crit = nn.BCEWithLogitsLoss(pos_weight=torch.tensor([POS_WEIGHT], device=device))
# Gradient scaler for mixed precision; only created when running on CUDA.
scaler = torch.amp.GradScaler('cuda') if device.type=='cuda' else None
# Lowest epoch-average TRAINING loss seen so far (no validation set is used here);
# the Phase 3 cell keeps comparing against this same variable.
best = float('inf')

# Phase 1: train only the early-warning head, everything else frozen.
# Uses the EW loss alone, in full precision (no autocast / scaler).
print("PHASE 1")
for p in model.parameters(): p.requires_grad = False
for p in model.ew_head.parameters(): p.requires_grad = True
opt = torch.optim.Adam(model.ew_head.parameters(), lr=1e-4)
for ep in range(2):
    model.train(); tl = 0
    for b in tqdm(train_loader, desc=f"P1.{ep+1}"):
        opt.zero_grad(); _, _, ew = model(b['ids'].to(device), b['mask'].to(device))
        loss = ew_crit(ew, b['ew_label'].to(device)); loss.backward(); opt.step(); tl += loss.item()
    print(f"  {ep+1}: {tl/len(train_loader):.4f}")

# Phase 2: unfreeze all parameters and optimise the joint loss
# (MLM + center distance + EW) with AdamW.
print("\nPHASE 2")
for p in model.parameters(): p.requires_grad = True
opt = torch.optim.AdamW(model.parameters(), lr=3e-5)
for ep in range(8):
    model.train(); tl = 0
    for b in tqdm(train_loader, desc=f"P2.{ep+1}"):
        ids, mask, ew_l, mlm_l = b['ids'].to(device), b['mask'].to(device), b['ew_label'].to(device), b['mlm_labels'].to(device)
        opt.zero_grad()
        if scaler:
            # Mixed-precision path (CUDA only).
            with torch.amp.autocast('cuda'):
                mlm_lg, cls, ew = model(ids, mask)
                loss = LAMBDA_MLM*F.cross_entropy(mlm_lg.view(-1,VOCAB_SIZE), mlm_l.view(-1), ignore_index=-100) + LAMBDA_VHM*torch.mean((cls-model.ctr)**2) + MU_EW*ew_crit(ew, ew_l)
            scaler.scale(loss).backward(); scaler.step(opt); scaler.update()
        else:
            # Full-precision path; same loss as above.
            mlm_lg, cls, ew = model(ids, mask)
            loss = LAMBDA_MLM*F.cross_entropy(mlm_lg.view(-1,VOCAB_SIZE), mlm_l.view(-1), ignore_index=-100) + LAMBDA_VHM*torch.mean((cls-model.ctr)**2) + MU_EW*ew_crit(ew, ew_l)
            loss.backward(); opt.step()
        # Move the center towards this batch's mean CLS embedding (EMA inside model.upd).
        model.upd(cls.detach()); tl += loss.item()
    avg = tl/len(train_loader); print(f"  {ep+1}: {avg:.4f}")
    # Checkpoint whenever the average training loss improves.
    if avg < best: best = avg; torch.save(model.state_dict(), f"{EW_OUTPUT}/bgl_ew_best.pt")

In [None]:
# PHASE 3: Fine-tune LR=1e-5 + early stop
# Continues from the in-memory model left by Phase 2 (not reloaded from disk),
# with a fresh AdamW optimizer. `best` carries over from the TRAINING cell, so a
# Phase 3 epoch only counts as an improvement if it beats the best Phase 2 loss.
print("\nPHASE 3 (LR=1e-5 + early-stop)")
opt = torch.optim.AdamW(model.parameters(), lr=1e-5)
# Stop after `patience` consecutive epochs without improvement.
patience, no_improve = 2, 0

for ep in range(8):
    model.train(); tl = 0
    for b in tqdm(train_loader, desc=f"P3.{ep+1}"):
        ids, mask, ew_l, mlm_l = b['ids'].to(device), b['mask'].to(device), b['ew_label'].to(device), b['mlm_labels'].to(device)
        opt.zero_grad()
        if scaler:
            # Mixed-precision path (CUDA only).
            with torch.amp.autocast('cuda'):
                mlm_lg, cls, ew = model(ids, mask)
                loss = LAMBDA_MLM*F.cross_entropy(mlm_lg.view(-1,VOCAB_SIZE), mlm_l.view(-1), ignore_index=-100) + LAMBDA_VHM*torch.mean((cls-model.ctr)**2) + MU_EW*ew_crit(ew, ew_l)
            scaler.scale(loss).backward(); scaler.step(opt); scaler.update()
        else:
            # Full-precision path; same joint loss.
            mlm_lg, cls, ew = model(ids, mask)
            loss = LAMBDA_MLM*F.cross_entropy(mlm_lg.view(-1,VOCAB_SIZE), mlm_l.view(-1), ignore_index=-100) + LAMBDA_VHM*torch.mean((cls-model.ctr)**2) + MU_EW*ew_crit(ew, ew_l)
            loss.backward(); opt.step()
        # Keep the center tracking the current CLS embeddings.
        model.upd(cls.detach()); tl += loss.item()
    avg = tl/len(train_loader); print(f"  {ep+1}: {avg:.4f}")
    if avg < best:
        # New best average training loss: overwrite the checkpoint and reset patience.
        best = avg; torch.save(model.state_dict(), f"{EW_OUTPUT}/bgl_ew_best.pt"); no_improve = 0
    else:
        no_improve += 1
        if no_improve >= patience: print(f"  Early stop"); break
print(f"Best loss: {best:.4f}")

In [None]:
# STAGE 2: TRUE PRE-WARNING (Clean windows only)
# The leading newline is written as an escape sequence: a raw line break inside
# a single-quoted string literal is a SyntaxError.
print("\n" + "="*60)
print("STAGE 2: TRUE Pre-Warning (clean windows only)")
print("="*60)

def create_prewarn_windows(traces):
    """Collect anomaly-free windows that are followed by an anomaly within DELTA_SEC.

    Windows of up to WINDOW_SIZE events are slid over each trace in steps of
    STEP_SIZE. A window is kept (label 1) only if it contains no anomalous
    event and, scanning forward from its end, an anomalous event is reached
    before any event timestamped more than DELTA_SEC after the window's last
    event. Windows whose last event has a falsy timestamp are never kept.

    Args:
        traces: iterable of dicts with ``'tids'`` and ``'events'``; each event
            is indexable with item 0 = timestamp and item 2 = anomaly flag.

    Returns:
        List of ``{'tids': window, 'label': 1}`` dicts.
    """
    collected = []
    for trace in traces:
        tids = trace['tids']
        events = trace.get('events')
        if not events:
            continue
        total = len(tids)
        for start in range(0, max(1, total - WINDOW_SIZE + 1), STEP_SIZE):
            stop = min(start + WINDOW_SIZE, total)
            # "Clean" means no anomaly inside the window itself.
            if any(events[k][2] for k in range(start, stop)):
                continue
            window_end_ts = events[stop - 1][0]
            if not window_end_ts:
                continue
            horizon = window_end_ts + DELTA_SEC
            for j in range(stop, total):
                ts, is_anom = events[j][0], events[j][2]
                if ts and ts > horizon:
                    break  # past the warning horizon without meeting an anomaly
                if is_anom:
                    collected.append({'tids': tids[start:stop], 'label': 1})
                    break
    return collected

prewarn_windows = create_prewarn_windows(train_failure)
# Label-0 windows for Stage 2. This used to call `create_windows(...,
# is_failure=False)`, a helper that is not defined anywhere in this notebook
# (NameError); `create_normal_windows` is the label-0 builder that does exist.
s2_normal = create_normal_windows(train_normal)
print(f"Stage 2: {len(prewarn_windows)} pre-warning + {len(s2_normal)} normal")

if len(prewarn_windows) > 0:
    # Restart from the best checkpoint saved by the Phase 2/3 cells.
    model.load_state_dict(torch.load(f"{EW_OUTPUT}/bgl_ew_best.pt", map_location=device))
    s2_windows = s2_normal + prewarn_windows
    s2_pos = sum(1 for w in s2_windows if w['label']==1)
    # Negative/positive ratio; max(..., 1) avoids dividing by zero.
    POS_WEIGHT_S2 = (len(s2_windows)-s2_pos) / max(s2_pos, 1)
    s2_ds = BalancedDataset(s2_windows, VOCAB_SIZE, ratio=3)
    s2_loader = DataLoader(s2_ds, batch_size=BATCH_SIZE, shuffle=True)
    ew_crit_s2 = nn.BCEWithLogitsLoss(pos_weight=torch.tensor([POS_WEIGHT_S2], device=device))
    opt = torch.optim.AdamW(model.parameters(), lr=5e-6)
    best_s2 = float('inf')   # lowest epoch-average training loss of this stage
    for ep in range(10):
        model.train(); tl = 0
        for b in tqdm(s2_loader, desc=f"S2.{ep+1}"):
            ids, mask, ew_l, mlm_l = b['ids'].to(device), b['mask'].to(device), b['ew_label'].to(device), b['mlm_labels'].to(device)
            opt.zero_grad()
            if scaler:
                # Mixed-precision path (CUDA only). Stage-2 weights: 0.2 MLM, 0.1 center, 1.0 EW.
                with torch.amp.autocast('cuda'):
                    mlm_lg, cls, ew = model(ids, mask)
                    loss = 0.2*F.cross_entropy(mlm_lg.view(-1,VOCAB_SIZE), mlm_l.view(-1), ignore_index=-100) + 0.1*torch.mean((cls-model.ctr)**2) + 1.0*ew_crit_s2(ew, ew_l)
                scaler.scale(loss).backward(); scaler.step(opt); scaler.update()
            else:
                mlm_lg, cls, ew = model(ids, mask)
                loss = 0.2*F.cross_entropy(mlm_lg.view(-1,VOCAB_SIZE), mlm_l.view(-1), ignore_index=-100) + 0.1*torch.mean((cls-model.ctr)**2) + 1.0*ew_crit_s2(ew, ew_l)
                loss.backward(); opt.step()
            # Keep the center tracking the current CLS embeddings.
            model.upd(cls.detach()); tl += loss.item()
        avg = tl/len(s2_loader); print(f"  {ep+1}: {avg:.4f}")
        if avg < best_s2: best_s2 = avg; torch.save(model.state_dict(), f"{EW_OUTPUT}/bgl_ew_prewarn.pt")
    print(f"Stage 2 best: {best_s2:.4f}")
    # Continue with the best Stage-2 weights rather than the last epoch's.
    model.load_state_dict(torch.load(f"{EW_OUTPUT}/bgl_ew_prewarn.pt", map_location=device))
else:
    print("No clean pre-warning windows! Skipping Stage 2.")


In [None]:
# STAGE 2: Pre-Warning Only Fine-tuning
print("\n" + "="*60)
print("STAGE 2: Pre-Warning Only (clean windows)")
print("="*60)

# This cell used to test `s2_prewarn`, a name that is never assigned in this
# notebook (NameError); the clean pre-warning windows live in `prewarn_windows`.
if len(prewarn_windows) > 0:
    # Restart from the best Stage-1 checkpoint (Phase 2/3 output).
    model.load_state_dict(torch.load(f"{EW_OUTPUT}/bgl_ew_best.pt", map_location=device))

    # Built here rather than inherited: the previous cell only defines
    # `s2_windows` / `s2_pos` inside its own `if` branch.
    s2_windows = s2_normal + prewarn_windows
    s2_pos = sum(1 for w in s2_windows if w['label']==1)

    s2_ds = BalancedDataset(s2_windows, VOCAB_SIZE, ratio=3)
    s2_loader = DataLoader(s2_ds, batch_size=BATCH_SIZE, shuffle=True)
    # Negative/positive ratio; max(..., 1) avoids dividing by zero.
    s2_pos_weight = (len(s2_windows) - s2_pos) / max(s2_pos, 1)
    ew_crit_s2 = nn.BCEWithLogitsLoss(pos_weight=torch.tensor([s2_pos_weight], device=device))

    opt = torch.optim.AdamW(model.parameters(), lr=5e-6)
    best_s2 = float('inf')   # lowest epoch-average training loss of this stage

    for ep in range(8):
        model.train(); tl = 0
        for b in tqdm(s2_loader, desc=f"S2.{ep+1}"):
            ids, mask, ew_l, mlm_l = b['ids'].to(device), b['mask'].to(device), b['ew_label'].to(device), b['mlm_labels'].to(device)
            opt.zero_grad()
            if scaler:
                # Mixed-precision path (CUDA only). Weights: 0.2 MLM, 0.1 center, 1.0 EW.
                with torch.amp.autocast('cuda'):
                    mlm_lg, cls, ew = model(ids, mask)
                    loss = 0.2*F.cross_entropy(mlm_lg.view(-1,VOCAB_SIZE), mlm_l.view(-1), ignore_index=-100) + 0.1*torch.mean((cls-model.ctr)**2) + ew_crit_s2(ew, ew_l)
                scaler.scale(loss).backward(); scaler.step(opt); scaler.update()
            else:
                mlm_lg, cls, ew = model(ids, mask)
                loss = 0.2*F.cross_entropy(mlm_lg.view(-1,VOCAB_SIZE), mlm_l.view(-1), ignore_index=-100) + 0.1*torch.mean((cls-model.ctr)**2) + ew_crit_s2(ew, ew_l)
                loss.backward(); opt.step()
            # Keep the center tracking the current CLS embeddings.
            model.upd(cls.detach()); tl += loss.item()
        avg = tl/len(s2_loader); print(f"  {ep+1}: {avg:.4f}")
        if avg < best_s2:
            best_s2 = avg
            torch.save(model.state_dict(), f"{EW_OUTPUT}/bgl_ew_prewarn.pt")
    print(f"Stage 2 best: {best_s2:.4f}")
    # Continue with the best Stage-2 weights rather than the last epoch's.
    model.load_state_dict(torch.load(f"{EW_OUTPUT}/bgl_ew_prewarn.pt", map_location=device))
else:
    print("No pre-warning windows! Using Stage 1 model.")


In [None]:
# EVALUATION (True Pre-Warning: score before first anomaly, val threshold)
print("\nEVALUATION (True Pre-Warning)")
model.eval()

@torch.no_grad()
def score_trace(tids, step=4):
    """Score a trace with the early-warning head at successive end positions.

    For end positions ``k = min(WINDOW_SIZE, len(tids)), ... , len(tids)`` in
    increments of ``step``, the up-to-WINDOW_SIZE events ending at ``k``
    (exclusive) are tokenised exactly like the training items
    (``[CLS] tokens [SEP]`` padded to CONTEXT_LEN, ids clipped to the
    vocabulary) and passed through the model. Windows with fewer than 5 events
    are skipped.

    The padding mask is passed to the model here, as it is in every training
    step; scoring without it lets the encoder attend to PAD positions, an
    input condition the model never saw during training.

    Args:
        tids: list of token ids of the trace.
        step: increment between scored end positions.

    Returns:
        List of ``(k, probability)`` tuples, ``probability`` being the sigmoid
        of the early-warning logit.
    """
    probs = []
    n_events = len(tids)
    for k in range(min(WINDOW_SIZE, n_events), n_events + 1, step):
        s = [min(t, VOCAB_SIZE - 1) for t in tids[max(0, k - WINDOW_SIZE):k]]
        if len(s) < 5:
            continue
        # Leave room for CLS and SEP (same truncation as BalancedDataset).
        s = s[:CONTEXT_LEN - 2]
        tok = [CLS] + s + [SEP] + [PAD] * (CONTEXT_LEN - len(s) - 2)
        ids = torch.tensor([tok], device=device)
        mask = (ids != PAD).long()   # 1 = real token, 0 = padding
        _, _, ew = model(ids, mask)
        probs.append((k, torch.sigmoid(ew).item()))
    return probs

def first_anomaly_idx(trace):
    """Return the index of the first anomalous event of ``trace``.

    Events are read from ``trace['events']`` (missing key -> no events); an
    event is anomalous when its item 2 is truthy. When no event is anomalous
    the number of events is returned.
    """
    events = trace.get('events', [])
    return next((pos for pos, event in enumerate(events) if event[2]), len(events))

# Score all traces
# Each failure trace is scored at several end positions k; only scores whose
# end position k is smaller than the first-anomaly index count as pre-warning.
# NOTE(review): score_trace starts at k = min(WINDOW_SIZE, len(tids)); for a
# trace no longer than WINDOW_SIZE the only scored position is k = len(tids),
# so `pre_probs` is empty (max_pre = 0.0) for any such trace that contains an
# anomaly. The traces built in the CREATE TRACES cell are WINDOW=100 events
# long -- confirm this evaluation can produce non-zero recall on them.
print("Scoring test traces...")
test_scores = []
for t in tqdm(test_failure):
    probs = score_trace(t['tids'])
    fa_idx = first_anomaly_idx(t)
    # Max prob BEFORE first anomaly
    pre_probs = [(k, p) for k, p in probs if k < fa_idx]
    max_pre = max((p for _, p in pre_probs), default=0.0) if pre_probs else 0.0
    test_scores.append({'trace': t, 'max_pre': max_pre, 'probs': probs, 'fa_idx': fa_idx, 'pre_probs': pre_probs})

# Normal traces: the trace-level score is the maximum over all scored positions.
normal_scores = []
for t in tqdm(test_normal):
    probs = score_trace(t['tids'])
    max_prob = max((p for _, p in probs), default=0.0)
    normal_scores.append({'trace': t, 'max_prob': max_prob})

print(f"Scored: {len(test_scores)} F, {len(normal_scores)} N")

# Threshold = 0.99 quantile of the trace-level scores of the FIRST HALF of
# normal_scores (which come from test_normal); falls back to 0.5 if that half is empty.
val_normal_probs = [s['max_prob'] for s in normal_scores[:len(normal_scores)//2]]  # first half acts as validation
threshold = np.quantile(val_normal_probs, 0.99) if val_normal_probs else 0.5
print(f"Threshold (0.99 quantile): {threshold:.4f}")

# Evaluate
# tp: failure traces whose pre-anomaly maximum exceeds the threshold.
# fp: normal traces from the SECOND half whose maximum exceeds the threshold.
tp, fp, leads = 0, 0, []
for s in test_scores:
    if s['max_pre'] > threshold:
        tp += 1
        # Lead = first-anomaly index minus the end position of the first alarm
        # (both in events); only positive leads are recorded.
        alarm_idx = next((k for k, p in s['pre_probs'] if p > threshold), s['fa_idx'])
        lead = s['fa_idx'] - alarm_idx
        if lead > 0: leads.append(lead)

for s in normal_scores[len(normal_scores)//2:]:  # held-out half of the normals
    if s['max_prob'] > threshold:
        fp += 1

# Denominators: all scored failure traces, and the held-out half of the normals.
nF = len(test_scores)
nN = len(normal_scores) - len(normal_scores)//2
recall = tp / nF if nF > 0 else 0
far = fp / nN if nN > 0 else 0
precision = tp / (tp + fp) if (tp + fp) > 0 else 0
f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0

print(f"\n=== TRUE PRE-WARNING RESULTS ===")
print(f"Recall: {recall:.3f} ({tp}/{nF})")
print(f"FAR: {far:.3f} ({fp}/{nN})")
print(f"Precision: {precision:.3f}")
print(f"F1: {f1:.3f}")
if leads:
    print(f"Lead: median={np.median(leads):.0f}, mean={np.mean(leads):.1f} events")


In [None]:
# F1-OPTIMAL + METRICS (FIXED: lead to first anomaly)
# `failure_scores` was referenced here without ever being defined, and the
# entries of `test_scores` carry no 'max_prob' key. It is derived from
# `test_scores` with the same trace-level score used for `normal_scores`:
# the maximum probability over all scored positions of the trace.
# NOTE(review): this counts alarms raised anywhere in the trace; switch to
# s['max_pre'] if only pre-anomaly alarms should count as detections.
failure_scores = [
    {**s, 'max_prob': max((p for _, p in s['probs']), default=0.0)}
    for s in test_scores
]

all_probs = [s['max_prob'] for s in failure_scores] + [s['max_prob'] for s in normal_scores]
all_labels = [1]*len(failure_scores) + [0]*len(normal_scores)

# AUC needs both classes; fall back to chance level otherwise.
auc = roc_auc_score(all_labels, all_probs) if len(set(all_labels)) > 1 else 0.5
prec, rec, thr = precision_recall_curve(all_labels, all_probs)
f1 = 2*prec*rec/(prec+rec+1e-9)
best_idx = f1.argmax()
# precision/recall arrays have one more entry than `thr`; clamp to the last threshold.
TH = thr[best_idx] if best_idx < len(thr) else thr[-1] if len(thr) > 0 else 0.5

tp, leads = [], []
for s in failure_scores:
    if s['max_prob'] > TH:
        tp.append(s)
        # FIXED: Lead to FIRST ANOMALY, not trace end
        events = s['trace'].get('events', [])
        first_anom_idx = next((i for i, e in enumerate(events) if e[2]), len(events))
        # Only the first threshold crossing is considered, and only a positive lead is kept.
        for k, p in s['probs']:
            if p > TH:
                lead = first_anom_idx - k
                if lead > 0:
                    leads.append(lead)
                break
fp = [s for s in normal_scores if s['max_prob'] > TH]

recall = len(tp)/len(failure_scores) if failure_scores else 0
far = len(fp)/len(normal_scores) if normal_scores else 0
precision = len(tp)/(len(tp)+len(fp)) if (len(tp)+len(fp))>0 else 0
f1_final = 2*precision*recall/(precision+recall) if (precision+recall)>0 else 0

print(f"\n=== RESULTS ===")
print(f"AUC: {auc:.4f}, TH: {TH:.4f}")
print(f"F1: {f1_final:.4f}, P: {precision:.4f}, R: {recall:.4f}, FAR: {far:.4f}")
if leads: print(f"Lead: median={np.median(leads):.0f}, mean={np.mean(leads):.0f} (to first anomaly)")
print(f"TP: {len(tp)}, FP: {len(fp)}, FN: {len(failure_scores)-len(tp)}, TN: {len(normal_scores)-len(fp)}")

In [None]:
# SAVE
# Scalar metrics are cast to plain floats so they serialise to JSON.
metrics = {
    name: float(value)
    for name, value in (
        ('auc', auc),
        ('threshold', TH),
        ('f1', f1_final),
        ('precision', precision),
        ('recall', recall),
        ('far', far),
    )
}
# Lead statistics default to 0 when no alarm preceded an anomaly.
metrics['lead_median'] = float(np.median(leads)) if leads else 0
metrics['lead_mean'] = float(np.mean(leads)) if leads else 0

results = {'dataset': 'BGL', 'method': 'joint_ew_head_v2', 'metrics': metrics}
with open(f"{EW_OUTPUT}/bgl_ew_results.json", 'w') as fh:
    json.dump(results, fh, indent=2)
print(f"Saved: {EW_OUTPUT}/bgl_ew_results.json")
print(json.dumps(results['metrics'], indent=2))

In [None]:
# Post-Training Analysis: validation-calibrated thresholds, lead-to-first-anomaly, FAR–Recall, PR/ROC, CIs
import math, statistics
from sklearn.metrics import auc, confusion_matrix

# 1) Helpers --------------------------------------------------------------

def first_anomaly_index(trace):
    """Return the index of the first anomalous event in ``trace``, or ``None``.

    Events are read from ``trace['events']`` (missing or falsy -> no events);
    each event is a 3-tuple whose last item is the anomaly flag.
    """
    events = trace.get('events') or []
    return next(
        (pos for pos, (_, _, flagged) in enumerate(events) if flagged),
        None,
    )

@torch.no_grad()
def trace_scores_pre_anom(trace, step=8):
    """Score a trace and summarise the scores that precede its first anomaly.

    Args:
        trace: dict with ``'tids'`` (and optionally ``'events'``).
        step: increment between scored end positions, forwarded to ``score_trace``.

    Returns:
        Tuple ``(max_pre, probs, q)``: ``probs`` is the ``(k, p)`` list from
        ``score_trace``, ``q`` the first-anomaly index (``None`` if the trace
        has none), and ``max_pre`` the largest ``p`` among windows whose last
        event (index ``k - 1``) lies before ``q`` -- over all windows when
        ``q`` is ``None``; 0.0 when no window qualifies.
    """
    probs = score_trace(trace['tids'], step=step)
    q = first_anomaly_index(trace)
    if q is None:
        eligible = (p for _, p in probs)
    else:
        eligible = (p for k, p in probs if (k - 1) < q)
    return max(eligible, default=0.0), probs, q


def earliest_crossing_k(probs, threshold, first_idx):
    """Return the end position of the first window scoring above ``threshold``.

    Args:
        probs: iterable of ``(k, p)`` pairs in scan order.
        threshold: score that must be strictly exceeded.
        first_idx: first-anomaly index; when not ``None`` only windows whose
            last event (index ``k - 1``) precedes it are eligible.

    Returns:
        The ``k`` of the first eligible crossing, or ``None`` if there is none.
    """
    def _eligible(k):
        return first_idx is None or (k - 1) < first_idx

    return next((k for k, p in probs if p > threshold and _eligible(k)), None)


def lead_events_from_crossing(k_cross, first_idx):
    """Return the lead, in events, between an alarm and the first anomaly.

    The lead is ``first_idx - k_cross + 1`` clamped at 0; ``None`` is returned
    when either the crossing position or the first-anomaly index is ``None``.
    """
    have_both = k_cross is not None and first_idx is not None
    if not have_both:
        return None
    lead = first_idx - k_cross + 1
    return lead if lead > 0 else 0


def split_validation_test(norm_list, fail_list, val_ratio=0.3, seed=SEED):
    """Shuffle and split normal and failure traces into validation and test parts.

    Each non-empty list gives ``round(len * val_ratio)`` items (at least one)
    to validation and the rest to test; an empty list yields two empty parts.
    Both shuffles come from one ``RandomState(seed)``, normals first.

    Returns:
        Tuple ``(val_norm, val_fail, tst_norm, tst_fail)``.
    """
    rng = np.random.RandomState(seed)

    def _shuffle_split(items):
        count = len(items)
        if count == 0:
            return [], []
        n_val = int(max(1, round(count * val_ratio)))
        order = rng.permutation(count)
        return [items[i] for i in order[:n_val]], [items[i] for i in order[n_val:]]

    # Normals must be drawn before failures to keep the RNG stream unchanged.
    val_norm, tst_norm = _shuffle_split(norm_list)
    val_fail, tst_fail = _shuffle_split(fail_list)
    return val_norm, val_fail, tst_norm, tst_fail

# 2) Build validation/test splits ----------------------------------------
# The held-out traces are split once more: 30% calibrate thresholds, 70% report metrics.
val_normal, val_failure, final_normal, final_failure = split_validation_test(
    test_normal, test_failure, val_ratio=0.3, seed=SEED)
print(f"Validation: {len(val_normal)} N, {len(val_failure)} F | Test: {len(final_normal)} N, {len(final_failure)} F")

# Validation: only the trace-level pre-anomaly maximum is kept.
val_norm_scores = []
for t in tqdm(val_normal, desc="VAL normal scores"):
    val_norm_scores.append(trace_scores_pre_anom(t)[0])
val_fail_scores = []
for t in tqdm(val_failure, desc="VAL failure scores"):
    val_fail_scores.append(trace_scores_pre_anom(t)[0])

# Test: normals keep the maximum; failures also keep (per-window scores,
# first-anomaly index) so alarms and lead times can be recomputed per threshold.
final_norm_sc = [trace_scores_pre_anom(t)[0] for t in tqdm(final_normal, desc="TEST normal scores")]
final_fail_sc, final_fail_meta = [], []
for t in tqdm(final_failure, desc="TEST failure scores"):
    max_pre, window_probs, first_idx = trace_scores_pre_anom(t)
    final_fail_sc.append(max_pre)
    final_fail_meta.append((window_probs, first_idx))

# 3) Threshold by FAR on validation --------------------------------------
# Target false-alarm rates (fraction of normal traces allowed to raise an alarm)
# for which a threshold is calibrated below.
far_targets = [0.001, 0.005, 0.01, 0.02, 0.05]

def threshold_for_far(norm_scores, far):
    """Return the score threshold matching a target false-alarm rate.

    The threshold is the ``1 - far`` quantile of ``norm_scores`` (quantile
    level clamped to [0, 1]); with no scores at all it is 1.0.
    """
    if not len(norm_scores):
        return 1.0
    level = 1.0 - far
    level = min(1.0, level)
    level = max(0.0, level)
    return float(np.quantile(norm_scores, level))

# Map each target FAR to the threshold calibrated on the validation normals.
calib = {float(ft): threshold_for_far(val_norm_scores, ft) for ft in far_targets}
print("Calibrated thresholds (VAL):", {k: round(v, 4) for k, v in calib.items()})

# 4) Evaluate on TEST -----------------------------------------------------

def evaluate_at_threshold(th):
    """Compute detection metrics on the final test split for threshold ``th``.

    A failure trace is a true positive when an eligible pre-anomaly window
    crosses ``th``; a normal trace is a false positive when its score exceeds
    ``th``. Empty splits use a denominator of 1.

    Returns:
        Dict with ``f1``, ``precision``, ``recall``, ``far``, a ``lead`` dict
        (median / mean / p25 / p75 / count of lead events) and the raw counts
        ``tp``, ``fp``, ``nF``, ``nN``.
    """
    leads = []
    tp = 0
    for probs, q in final_fail_meta:
        kx = earliest_crossing_k(probs, th, q)
        if kx is None:
            continue
        tp += 1
        ld = lead_events_from_crossing(kx, q)
        if ld is not None:
            leads.append(ld)
    fp = sum(1 for s in final_norm_sc if s > th)

    nF = len(final_fail_meta) or 1
    nN = len(final_norm_sc) or 1
    recall = tp / nF
    far = fp / nN
    precision = tp / max(tp + fp, 1)
    f1 = 2 * precision * recall / max(precision + recall, 1e-9)

    if leads:
        lead_stats = {
            'median': float(np.median(leads)),
            'mean': float(np.mean(leads)),
            'p25': float(np.percentile(leads, 25)),
            'p75': float(np.percentile(leads, 75)),
            'count': len(leads),
        }
    else:
        lead_stats = {'median': 0.0, 'mean': 0.0, 'p25': 0.0, 'p75': 0.0, 'count': 0}

    return {
        'f1': float(f1),
        'precision': float(precision),
        'recall': float(recall),
        'far': float(far),
        'lead': lead_stats,
        'tp': int(tp),
        'fp': int(fp),
        'nF': int(nF),
        'nN': int(nN),
    }

# One evaluation per calibrated threshold -> FAR/recall trade-off curve.
far_curve = []
for ft in far_targets:
    th = calib[float(ft)]
    far_curve.append({'far_target': float(ft), 'threshold': th, **evaluate_at_threshold(th)})

# Operating point: the 1% FAR target when available, else the third target
# (or the last one when fewer than three exist).
if 0.01 in far_targets:
    op_target = 0.01
else:
    op_target = far_targets[min(2, len(far_targets)-1)]
op_th = calib[float(op_target)]
op_metrics = evaluate_at_threshold(op_th)

# 5) PR/ROC on TEST -------------------------------------------------------
# Failures first (label 1), then normals (label 0), on the trace-level scores.
all_scores = np.array(final_fail_sc + final_norm_sc, dtype=float)
all_labels = np.array([1]*len(final_fail_sc) + [0]*len(final_norm_sc), dtype=int)
if len(set(all_labels)) > 1:
    roc = float(roc_auc_score(all_labels, all_scores))
else:
    roc = 0.5   # ROC-AUC is undefined with a single class
pr_p, pr_r, _ = precision_recall_curve(all_labels, all_scores)
if len(pr_r) > 1:
    pr = float(auc(pr_r, pr_p))
else:
    pr = 0.0

# 6) Bootstrap 95% CI -----------------------------------------------------

def bootstrap_ci(n_boot=300, seed=SEED):
    """Bootstrap 95% intervals for F1, recall and FAR at the operating threshold.

    Each of the ``n_boot`` rounds resamples the test failure traces and the
    test normal traces with replacement (same sizes as the originals) and
    recomputes the metrics at ``op_th``.

    Returns:
        Dict mapping ``'f1'``, ``'recall'`` and ``'far'`` to ``[low, high]``
        (2.5th and 97.5th percentiles); ``[0.0, 0.0]`` when no round was run.
    """
    rng = np.random.RandomState(seed)
    f_idx = np.arange(len(final_fail_meta))
    n_idx = np.arange(len(final_norm_sc))
    nF = max(1, len(f_idx))
    nN = max(1, len(n_idx))

    f1s, recs, fars = [], [], []
    for _ in range(n_boot):
        # Failures are resampled before normals in every round (fixed RNG order).
        bf = rng.choice(f_idx, size=len(f_idx), replace=True) if len(f_idx) > 0 else []
        bn = rng.choice(n_idx, size=len(n_idx), replace=True) if len(n_idx) > 0 else []

        tp = 0
        for i in bf:
            probs, q = final_fail_meta[i]
            if earliest_crossing_k(probs, op_th, q) is not None:
                tp += 1
        fp = sum(1 for j in bn if final_norm_sc[j] > op_th)

        recall = tp / nF
        far = fp / nN
        precision = tp / max(tp + fp, 1)
        f1s.append(2 * precision * recall / max(precision + recall, 1e-9))
        recs.append(recall)
        fars.append(far)

    def ci(arr):
        if not arr:
            return [0.0, 0.0]
        lo, hi = np.percentile(arr, [2.5, 97.5])
        return [float(lo), float(hi)]

    return {'f1': ci(f1s), 'recall': ci(recs), 'far': ci(fars)}

# 95% bootstrap intervals at the operating threshold (default 300 resamples).
cis = bootstrap_ci()

# 7) Save consolidated JSON -----------------------------------------------
# One report with the configuration, validation-calibrated thresholds, test
# metrics at the operating point, their confidence intervals and the FAR curve.
report = {
    'dataset': 'BGL',
    'method': 'true_prewarn_analysis',
    'config': {
        'delta_min': int(DELTA_MIN),
        'window_size': int(WINDOW_SIZE),
        # NOTE(review): hard-coded value; the scoring in this cell uses
        # trace_scores_pre_anom's default step of 8 and STEP_SIZE is 8 --
        # confirm which step this field is meant to record.
        'step_size': 32,
        'far_targets': far_targets
    },
    'validation': {
        'n_normal': len(val_normal), 'n_failure': len(val_failure),
        'thresholds_by_far': {str(k): float(v) for k, v in calib.items()}
    },
    'test': {
        'n_normal': len(final_normal), 'n_failure': len(final_failure),
        'operating_point': {'far_target': float(op_target), 'threshold': float(op_th)},
        'metrics': {
            'roc_auc': roc,
            'pr_auc': pr,
            # f1 / precision / recall / far / lead stats / counts at the operating threshold
            **op_metrics
        },
        'ci_95': cis
    },
    'far_recall_curve': far_curve,
    'notes': 'Threshold calibrated on validation normals by per-trace FAR; lead measured to FIRST anomaly event.'
}

with open(os.path.join(EW_OUTPUT, 'bgl_ew_analysis.json'), 'w') as f:
    json.dump(report, f, indent=2)
print(f"\n✓ Saved analysis to {EW_OUTPUT}/bgl_ew_analysis.json")
print(json.dumps(report['test']['metrics'], indent=2))