# HDFS V2 TRUE Pre-Warning (Time-based + Drain3)

Like the BGL notebook: windows are labeled by timestamp, using a look-ahead horizon of DELTA_SEC seconds.

In [None]:
import os, json, random, re
import numpy as np
from tqdm import tqdm
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from sklearn.metrics import roc_auc_score, precision_recall_curve

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Device: {device}")

# ===== KAGGLE PATHS (FIXED) =====
BASE = "/kaggle/working"  # root of everything this notebook writes
CKPT_DIR = "/kaggle/input/hdfs-ew3"  # directory of the pretrained checkpoint loaded below
V2_PATH = "/kaggle/input/loghub-hdfs-hadoop-distributed-file-system-data/HDFS_v2/node_logs"  # raw *.log files
DRAIN_TEMPLATES = None  # Will parse fresh with Drain3
OUTPUT_DIR = f"{BASE}/output-hdfsv2ew3"
os.makedirs(OUTPUT_DIR, exist_ok=True)

# Sanity check of the input mount; os.listdir raises if V2_PATH is missing.
print(f"V2_PATH contents: {os.listdir(V2_PATH)[:5]}...")

# Model geometry: token sequence length, hidden size, attention heads, encoder layers.
CONTEXT_LEN, D_MODEL, N_HEADS, N_LAYERS = 128, 256, 8, 4
# Reserved token ids. OFF is added to every template id so that template tokens
# never collide with PAD/CLS/MASK/SEP.
PAD, CLS, MASK, SEP, OFF = 0, 1, 2, 3, 4
BATCH_SIZE, SEED = 64, 42
# Seed the three RNGs this notebook draws from (random, NumPy, PyTorch).
random.seed(SEED); np.random.seed(SEED); torch.manual_seed(SEED)

# Pre-warning horizon: an anomaly within DELTA_SEC seconds after a window's
# last event makes that window a positive example.
DELTA_MIN = 5
DELTA_SEC = DELTA_MIN * 60
# A gap of more than SESSION_GAP seconds between consecutive events of one
# file starts a new session.
SESSION_GAP = 3600

In [None]:
# ===== LOAD PARSED DATA (FROM LOCAL CPU) =====
# Shortcut cell: restores the output of an earlier parsing run from a pickle
# instead of re-parsing the raw logs.
# NOTE(review): pickle.load executes arbitrary code embedded in the file; only
# use a PKL_FILE you produced yourself.
# NOTE(review): this cell defines `train_windows` but not the `s1_windows` /
# `s2_windows` / `s1_pos` / `s2_pos` names the training cells read, and the
# later parsing cells overwrite several names set here (e.g. `v2_events_raw`,
# `VOCAB_SIZE`) if they are run - confirm the intended cell order.
import pickle

# Kaggle path
PKL_FILE = "/kaggle/input/hdfsv2-parsed/hdfsv2_parsed_data.pkl"

with open(PKL_FILE, 'rb') as f:
    data = pickle.load(f)

# Unpack the saved objects into the module-level names used by later cells.
v2_events_raw = data['v2_events_raw']
normal_traces = data['normal_traces']
failure_traces = data['failure_traces']
train_normal = data['train_normal']
test_normal = data['test_normal']
train_failure = data['train_failure']
test_failure = data['test_failure']
train_windows = data['train_windows']
NUM_TEMPLATES = data['NUM_TEMPLATES']
VOCAB_SIZE = data['VOCAB_SIZE']
OLD_VOCAB_SIZE = data['OLD_VOCAB_SIZE']

print(f"Loaded parsed data!")
print(f"Events: {len(v2_events_raw):,}")
print(f"VOCAB_SIZE: {VOCAB_SIZE}")
print(f"Train windows: {len(train_windows):,}")
print(f"Test: {len(test_normal)} N, {len(test_failure)} F")
print("\n>>> SKIP parsing cells, go to DATASET & MODEL <<<")

In [None]:
# LOAD CHECKPOINT (skip drain templates - parse fresh)
ckpt = torch.load(f"{CKPT_DIR}/checkpoints_logbert_ep11.pt", map_location=device)
# The vocabulary size is dictated by the checkpoint's token-embedding matrix,
# so it overrides any VOCAB_SIZE set by an earlier cell.
VOCAB_SIZE = ckpt['model_state_dict']['tok.weight'].shape[0]
print(f"VOCAB: {VOCAB_SIZE}")

# No template-string -> id mapping is loaded, so get_tid() treats every mined
# template as "unknown".
template_map = {}  # Empty - will parse fresh


In [None]:
# PARSE V2 RAW LOGS WITH DRAIN3 (REUSE OLD TEMPLATE MAPPING)
# NOTE(review): the title mentions reusing the old mapping, but `template_map`
# is created empty in the checkpoint cell, so nothing is actually reused.
print("PARSING V2 LOGS")
!pip install drain3 -q
from drain3 import TemplateMiner
from drain3.template_miner_config import TemplateMinerConfig
from datetime import datetime

# Initialize Drain3
config = TemplateMinerConfig()
config.drain_depth = 4
config.drain_sim_th = 0.4
miner = TemplateMiner(config=config)

# Anomaly patterns
# A line is flagged anomalous when any of these substrings occurs in it,
# case-insensitively (see re.IGNORECASE below).
ANOMALY_PATTERNS = ['ERROR', 'Exception', 'failed', 'timeout', 'IOException']
import re
anomaly_regex = re.compile('|'.join(ANOMALY_PATTERNS), re.IGNORECASE)

def parse_timestamp(line):
    """Return the first timestamp found in ``line`` as integer epoch seconds.

    Two layouts are recognised, tried in this order:

    * ``YYYY-MM-DD HH:MM:SS`` with a space or ``T`` separator; an optional
      ``,mmm`` / ``.mmm`` millisecond suffix is matched but discarded.
    * ``yymmdd HHMMSS`` (two six-digit groups); the year is read as 2000 + yy.

    Both layouts go through a naive ``datetime`` and ``.timestamp()``, so the
    result is interpreted in the machine's local time zone and values from the
    two layouts are directly comparable.

    Returns:
        int epoch seconds, or ``None`` when no timestamp is present or the
        matched digits do not form a valid calendar date/time.
    """
    m = re.search(r'(?P<y>\d{4})-(?P<m>\d{2})-(?P<d>\d{2})[ T]'
                  r'(?P<h>\d{2}):(?P<mi>\d{2}):(?P<s>\d{2})(?:[,\.](?P<ms>\d{3}))?',
                  line)
    if m:
        parts = [int(m.group(name)) for name in ('y', 'm', 'd', 'h', 'mi', 's')]
        try:
            return int(datetime(*parts).timestamp())
        except (ValueError, OverflowError, OSError):
            # Digits matched the layout but are not a real date (e.g. month 13).
            return None
    m2 = re.search(r'(?P<d>\d{6})\s+(?P<t>\d{6})', line)
    if m2:
        d = m2.group('d'); t = m2.group('t')
        try:
            # Use the full date. Keeping only the day-of-month (as an earlier
            # version did) makes timestamps jump backwards at every month
            # boundary and puts them on a different scale from the ISO branch.
            return int(datetime(2000 + int(d[:2]), int(d[2:4]), int(d[4:6]),
                                int(t[:2]), int(t[2:4]), int(t[4:6])).timestamp())
        except (ValueError, OverflowError, OSError):
            return None
    return None

# Track unknown templates
next_unknown_id = VOCAB_SIZE
unknown_templates = {}

def get_tid(line):
    """Feed ``line`` to the Drain3 miner and return a template id for it.

    The mined template string is looked up in ``template_map`` first. Templates
    not found there are numbered in order of first appearance (recorded in
    ``unknown_templates``) and folded into the last ten ids of the vocabulary:
    ``VOCAB_SIZE - 10 + index``, capped at ``VOCAB_SIZE - 1``.

    Returns ``VOCAB_SIZE - 1`` when the miner yields no template at all.

    NOTE(review): with more than ten unknown templates, all further ones share
    the id ``VOCAB_SIZE - 1`` - confirm this collapsing is intended.
    """
    result = miner.add_log_message(line)
    if hasattr(result, 'get_template'):
        # Miner handed back a cluster object directly.
        template_str = result.get_template()
    else:
        # Per the Drain3 documentation, TemplateMiner.add_log_message returns a
        # dict carrying the template text under "template_mined"; it has no
        # "cluster" entry, so looking only for "cluster" sent every line to the
        # unknown fallback. The "cluster" lookup is kept as a secondary path.
        template_str = result.get('template_mined')
        if template_str is None:
            cluster = result.get('cluster')
            if cluster is not None:
                template_str = cluster.get_template() if hasattr(cluster, 'get_template') else str(cluster)
    if template_str is None:
        return VOCAB_SIZE - 1  # Unknown fallback
    # Try to map to old template
    if template_str in template_map:
        return template_map[template_str]
    # New template - assign new ID (track but keep in vocab range)
    if template_str not in unknown_templates:
        unknown_templates[template_str] = len(unknown_templates)
    return min(VOCAB_SIZE - 1, VOCAB_SIZE - 10 + unknown_templates[template_str])

# Parse all log files
# Each kept line becomes one (timestamp, template_id, is_anomaly, file_id)
# tuple; lines without a recognisable timestamp are dropped.
v2_events_raw = []
# os.listdir returns names in arbitrary order. Sorting makes file_id
# assignment - and with it the Drain3 state, session order and the later
# seeded train/test split - reproducible across runs and machines.
log_files = sorted(f for f in os.listdir(V2_PATH) if f.endswith('.log'))
print(f"Parsing {len(log_files)} log files...")

for file_id, log_file in enumerate(tqdm(log_files)):
    filepath = os.path.join(V2_PATH, log_file)
    # errors='ignore' silently drops undecodable bytes instead of aborting.
    with open(filepath, 'r', errors='ignore') as f:
        for line in f:
            line = line.strip()
            if not line: continue
            ts = parse_timestamp(line)
            if ts is None: continue
            is_anomaly = bool(anomaly_regex.search(line))
            tid = get_tid(line)
            v2_events_raw.append((ts, tid, is_anomaly, file_id))

NUM_TEMPLATES = len(miner.drain.clusters)
print(f"Total events (with timestamp): {len(v2_events_raw):,}")
print(f"Templates found: {NUM_TEMPLATES}")
print(f"Mapped to old vocab: {NUM_TEMPLATES - len(unknown_templates)}")
print(f"New unknown templates: {len(unknown_templates)}")
print(f"Anomalies: {sum(1 for _,_,a,_ in v2_events_raw if a):,}")

In [None]:
# === SAVE PARSED DATA (run after parsing has finished) ===
# Persists the raw event tuples plus the template bookkeeping so a later
# session can skip the parsing cell.
import pickle
parsed_data = {
    'events': v2_events_raw,
    'num_templates': NUM_TEMPLATES,
    'unknown_templates': unknown_templates
}
with open(f"{OUTPUT_DIR}/v2_parsed.pkl", 'wb') as f:
    pickle.dump(parsed_data, f)
print(f"✅ Saved parsed data: {len(v2_events_raw):,} events")

# Clear raw events to save memory
# (left disabled: the trace-building cell still reads v2_events_raw)
# del v2_events_raw; import gc; gc.collect()

In [None]:
# CREATE TRACES (per-file, sorted by timestamp)
print("\nCREATING TRACES")
from collections import defaultdict

# Group by file, then sort by timestamp within file
file_events = defaultdict(list)
for ts, tid, anom, file_id in v2_events_raw:
    file_events[file_id].append((ts, tid, anom))

# Create sessions within each file: a new session starts whenever two
# consecutive events are more than SESSION_GAP seconds apart.
sessions = []
for file_id, events in file_events.items():
    events.sort(key=lambda x: x[0])  # Sort by timestamp (stable: ties keep file order)
    current, last_time = [], None
    for ts, tid, anom in events:
        # Test against None explicitly: a timestamp of 0 is a legitimate value
        # and must not be mistaken for "no previous event".
        if last_time is not None and ts - last_time > SESSION_GAP:
            if current: sessions.append(current)
            current = []
        current.append((ts, tid, anom))
        last_time = ts
    if current: sessions.append(current)

print(f"Total sessions: {len(sessions)}")

# Filter to traces: sessions shorter than 20 events are discarded; the rest
# are "failure" traces if they contain at least one anomalous event, "normal"
# traces otherwise.
normal_traces, failure_traces = [], []
for sess in sessions:
    tids = [tid for _, tid, _ in sess]
    if len(tids) < 20:
        continue
    trace = {'tids': tids, 'events': sess}
    if any(a for _, _, a in sess):
        failure_traces.append(trace)
    else:
        normal_traces.append(trace)

print(f"Normal: {len(normal_traces)}, Failure: {len(failure_traces)}")

In [None]:
# CREATE WINDOWS (time-based like BGL)
WINDOW_SIZE, STEP_SIZE = 256, 64

def create_windows(traces, is_failure=False):
    """Slice every trace into sliding windows and attach a 0/1 label.

    Windows are WINDOW_SIZE events long (shorter only when the whole trace is
    shorter) and start every STEP_SIZE events; each trace yields at least one.

    With ``is_failure`` set and events available, a window is labeled 1 when
    it contains an anomalous event, or when an anomalous event follows it no
    later than DELTA_SEC seconds after the window's last event. In every other
    case the label is 0.

    Returns a list of ``{'tids': [...], 'label': 0 or 1}`` dicts.
    """
    out = []
    for trace in traces:
        tids = trace['tids']
        events = trace.get('events')
        total = len(tids)
        last_start = max(1, total - WINDOW_SIZE + 1)
        for start in range(0, last_start, STEP_SIZE):
            stop = min(start + WINDOW_SIZE, total)
            label = 0
            if is_failure and events:
                stop_ts = events[min(stop - 1, total - 1)][0]
                anomaly_ahead = False
                if stop_ts:
                    # Scan forward until the horizon is passed or an anomaly is hit.
                    for j in range(stop, total):
                        ev = events[j]
                        if ev[0] and ev[0] > stop_ts + DELTA_SEC:
                            break
                        if ev[2]:
                            anomaly_ahead = True
                            break
                anomaly_inside = False
                for k in range(start, min(stop, total)):
                    if events[k][2]:
                        anomaly_inside = True
                        break
                if anomaly_inside or anomaly_ahead:
                    label = 1
            out.append({'tids': tids[start:stop], 'label': label})
    return out

def create_prewarn_windows(traces):
    """Build "true pre-warning" windows from failure traces.

    Uses the same sliding-window scheme as the detection windows
    (WINDOW_SIZE long, STEP_SIZE apart), but drops every window that already
    contains an anomalous event. A remaining window is labeled 1 when an
    anomalous event occurs within DELTA_SEC seconds after its last event,
    otherwise 0.

    Returns a list of ``{'tids': [...], 'label': 0 or 1}`` dicts.
    """
    out = []
    for trace in traces:
        tids = trace['tids']
        events = trace.get('events')
        total = len(tids)
        for start in range(0, max(1, total - WINDOW_SIZE + 1), STEP_SIZE):
            stop = min(start + WINDOW_SIZE, total)
            # TRUE pre-warning only: skip windows that already show the anomaly.
            anomaly_inside = False
            for k in range(start, stop):
                if events[k][2]:
                    anomaly_inside = True
                    break
            if anomaly_inside:
                continue
            stop_ts = events[min(stop - 1, total - 1)][0]
            label = 0
            if stop_ts:
                # Look ahead until the horizon is passed or an anomaly is hit.
                for j in range(stop, total):
                    ev = events[j]
                    if ev[0] and ev[0] > stop_ts + DELTA_SEC:
                        break
                    if ev[2]:
                        label = 1
                        break
            out.append({'tids': tids[start:stop], 'label': label})
    return out

# Safe split
def safe_split(n, ratio=0.8):
    """Return how many of ``n`` items go to the training side.

    The result is ``int(n * ratio)``, raised to at least 1. If that would
    take all ``n`` items and ``n`` is greater than 1, one item is held back.
    """
    cut = int(n * ratio)
    if n > 1 and cut == n:
        return n - 1
    return max(1, cut)

# Trace-level 80/20 split with a dedicated seeded RNG, so whole traces (not
# individual windows) end up on one side only.
rng = np.random.RandomState(SEED)
n_idx = rng.permutation(len(normal_traces)) if normal_traces else []
f_idx = rng.permutation(len(failure_traces)) if failure_traces else []
n_split = safe_split(len(normal_traces)) if normal_traces else 0
f_split = safe_split(len(failure_traces)) if failure_traces else 0

train_normal = [normal_traces[i] for i in n_idx[:n_split]] if normal_traces else []
test_normal = [normal_traces[i] for i in n_idx[n_split:]] if normal_traces else []
train_failure = [failure_traces[i] for i in f_idx[:f_split]] if failure_traces else []
test_failure = [failure_traces[i] for i in f_idx[f_split:]] if failure_traces else []

print(f"Train: {len(train_normal)} N, {len(train_failure)} F")
print(f"Test: {len(test_normal)} N, {len(test_failure)} F")

# Stage 1 windows
# Positives = windows containing an anomaly OR followed by one within DELTA_SEC.
s1_windows = create_windows(train_normal) + create_windows(train_failure, True)
s1_pos = sum(1 for w in s1_windows if w['label']==1)
print(f"Stage 1: {len(s1_windows):,} windows (pos={s1_pos:,})")

# Stage 2 windows (TRUE pre-warning)
# Positives = anomaly-free windows followed by an anomaly within DELTA_SEC.
s2_windows = create_windows(train_normal) + create_prewarn_windows(train_failure)
s2_pos = sum(1 for w in s2_windows if w['label']==1)
print(f"Stage 2: {len(s2_windows):,} windows (pos={s2_pos:,})")

In [None]:
# === SAVE TRACES + WINDOWS (run after the windows have been created) ===
# Persists the trace split and both window sets so training can be resumed
# without re-parsing.
import pickle
processed_data = {
    'train_normal': train_normal,
    'train_failure': train_failure,
    'test_normal': test_normal,
    'test_failure': test_failure,
    's1_windows': s1_windows,
    's2_windows': s2_windows,
    's1_pos': s1_pos,
    's2_pos': s2_pos
}
with open(f"{OUTPUT_DIR}/v2_processed.pkl", 'wb') as f:
    pickle.dump(processed_data, f)
print(f"✅ Saved processed data: {len(s1_windows):,} S1, {len(s2_windows):,} S2 windows")

In [None]:
# === LOAD SAVED DATA (use instead of parsing when the pickle already exists) ===
# Uncomment to load the saved data instead of parsing:

# import pickle
# with open(f"{OUTPUT_DIR}/v2_processed.pkl", 'rb') as f:
#     data = pickle.load(f)
# train_normal = data['train_normal']
# train_failure = data['train_failure']
# test_normal = data['test_normal']
# test_failure = data['test_failure']
# s1_windows = data['s1_windows']
# s2_windows = data['s2_windows']
# s1_pos = data['s1_pos']
# s2_pos = data['s2_pos']
# print(f"✅ Loaded: {len(s1_windows):,} S1, {len(s2_windows):,} S2 windows")

In [None]:
# DATASET + MODEL (same as BGL)
class BalancedDataset(Dataset):
    """Window dataset that re-samples positives and negatives at a fixed ratio.

    Items are drawn at random on every access: the index only decides whether
    a positive or a negative window is sampled (one positive slot followed by
    ``ratio`` negative slots). The reported length is
    ``len(pos) * (1 + ratio)``, or ``len(neg)`` when there are no positives.

    Each item is a dict with:
      ids  - token ids ``[CLS] + tokens + [SEP] + PAD...`` of length
             CONTEXT_LEN, with roughly 15% of the tokens replaced by MASK
      mask - 1 for non-PAD positions, 0 for PAD
      ew   - the window's 0/1 early-warning label as a float tensor
      mlm  - original token id at masked positions, -100 elsewhere

    NOTE(review): only the FIRST ``CONTEXT_LEN - 2`` template ids of a window
    are used, while window labels are computed over up to WINDOW_SIZE events -
    confirm that dropping the most recent events is intended.
    """
    def __init__(self, windows, vs, ratio=5):
        self.pos = [w for w in windows if w['label']==1]
        self.neg = [w for w in windows if w['label']==0]
        self.vs, self.ratio = vs, ratio
        self.n = len(self.pos)*(1+ratio) if self.pos else len(self.neg)
    def __len__(self): return self.n
    def __getitem__(self, idx):
        # Every (ratio+1)-th index yields a random positive, the others a
        # random negative; falls back to the other pool when one is empty.
        w = random.choice(self.pos) if (self.pos and idx%(self.ratio+1)==0) else random.choice(self.neg if self.neg else self.pos)
        # Shift template ids past the reserved tokens and clip into the vocabulary.
        s = [min(t+OFF, self.vs-1) for t in w['tids'][:CONTEXT_LEN-2]]
        tok = [CLS]+s+[SEP]+[PAD]*(CONTEXT_LEN-len(s)-2)
        masked, mlm = tok.copy(), [-100]*len(tok)
        if s:
            # Mask 15% of the real tokens (at least one); +1 skips the CLS slot.
            for p in np.random.choice(len(s), min(max(1,int(len(s)*0.15)),len(s)), replace=False)+1:
                mlm[p] = masked[p]; masked[p] = MASK
        return {'ids': torch.tensor(masked), 'mask': torch.tensor([1 if t!=PAD else 0 for t in masked]),
                'ew': torch.tensor(w['label'], dtype=torch.float), 'mlm': torch.tensor(mlm)}

class LogBERTEW(nn.Module):
    """Transformer encoder over template-id sequences with three outputs.

    ``forward`` returns a tuple of:
      * per-position vocabulary logits from ``head`` (used for the MLM loss),
      * the hidden state at position 0 (the CLS slot),
      * one early-warning logit per sequence from ``ew_head``.

    ``ctr`` is a buffer holding a running centre of CLS states; ``ci`` records
    whether it has been initialised.
    """
    def __init__(self, vs):
        super().__init__()
        self.tok = nn.Embedding(vs, D_MODEL, padding_idx=PAD)
        self.pos = nn.Embedding(CONTEXT_LEN, D_MODEL)  # learned absolute positions
        self.drop = nn.Dropout(0.1)
        self.enc = nn.TransformerEncoder(nn.TransformerEncoderLayer(D_MODEL, N_HEADS, D_MODEL*4, 0.1, 'gelu', batch_first=True), N_LAYERS)
        self.head = nn.Linear(D_MODEL, vs)
        self.ew_head = nn.Linear(D_MODEL, 1)
        self.register_buffer('ctr', torch.zeros(D_MODEL)); self.ci = False
    def forward(self, ids, mask=None):
        # Token plus position embeddings; positions broadcast over the batch.
        x = self.tok(ids) + self.pos(torch.arange(ids.size(1), device=ids.device))
        # `mask` uses 1 for real tokens; the encoder expects True at PAD positions.
        h = self.enc(self.drop(x), src_key_padding_mask=(mask==0) if mask is not None else None)
        return self.head(h), h[:,0,:], self.ew_head(h[:,0,:]).squeeze(-1)
    def upd(self, e):
        # First call adopts the batch mean; later calls blend it in as
        # 0.9 * old centre + 0.1 * batch mean.
        with torch.no_grad(): bc = e.mean(0); self.ctr = bc if not self.ci else 0.9*self.ctr+0.1*bc; self.ci = True

model = LogBERTEW(VOCAB_SIZE).to(device)
# strict=False: keys that do not match between checkpoint and model are
# skipped without an error (presumably needed because the checkpoint predates
# ew_head - verify against the checkpoint's key list).
model.load_state_dict(ckpt['model_state_dict'], strict=False)
# Start from the centre stored in the checkpoint and mark it as initialised.
model.ctr = ckpt['center'].to(device); model.ci = True
print(f"Model loaded: VOCAB={VOCAB_SIZE}")

In [None]:
# STAGE 1 TRAINING
# Two phases: (1) train only the early-warning head on a frozen encoder,
# (2) fine-tune the whole model on MLM + centre + early-warning losses.
print("\n" + "="*50)
print("STAGE 1: Detection + Pre-warning")
print("="*50)

# Loss weights: masked-LM term, distance-to-centre term, early-warning BCE term.
LAMBDA_MLM, LAMBDA_VHM, MU_EW = 0.4, 0.1, 1.0
s1_ds = BalancedDataset(s1_windows, VOCAB_SIZE)
s1_loader = DataLoader(s1_ds, batch_size=BATCH_SIZE, shuffle=True)
# NOTE(review): pos_weight is derived from the raw window counts, while
# BalancedDataset already re-samples to a fixed positive:negative ratio - the
# two corrections stack; confirm this is intended.
POS_WEIGHT = (len(s1_windows)-s1_pos) / max(s1_pos, 1)
ew_crit = nn.BCEWithLogitsLoss(pos_weight=torch.tensor([POS_WEIGHT], device=device))
# Mixed-precision gradient scaler, only on CUDA; also reused by the Stage 2 cell.
scaler = torch.amp.GradScaler('cuda') if device.type=='cuda' else None
best = float('inf')

# Phase 1
# Freeze everything except ew_head and train it alone for 2 epochs.
print("Phase 1")
for p in model.parameters(): p.requires_grad = False
for p in model.ew_head.parameters(): p.requires_grad = True
opt = torch.optim.Adam(model.ew_head.parameters(), lr=1e-4)
for ep in range(2):
    model.train(); tl = 0
    for b in tqdm(s1_loader, desc=f"P1.{ep+1}"):
        opt.zero_grad(); _, _, ew = model(b['ids'].to(device), b['mask'].to(device))
        loss = ew_crit(ew, b['ew'].to(device)); loss.backward(); opt.step(); tl += loss.item()
    print(f"  {ep+1}: {tl/len(s1_loader):.4f}")

# Phase 2
# Unfreeze all parameters and train the combined objective for 8 epochs.
print("Phase 2")
for p in model.parameters(): p.requires_grad = True
opt = torch.optim.AdamW(model.parameters(), lr=3e-5)
for ep in range(8):
    model.train(); tl = 0
    for b in tqdm(s1_loader, desc=f"P2.{ep+1}"):
        ids, mask, ew_l, mlm_l = b['ids'].to(device), b['mask'].to(device), b['ew'].to(device), b['mlm'].to(device)
        opt.zero_grad()
        if scaler:
            with torch.amp.autocast('cuda'):
                mlm_lg, cls, ew = model(ids, mask)
                loss = LAMBDA_MLM*F.cross_entropy(mlm_lg.view(-1,VOCAB_SIZE), mlm_l.view(-1), ignore_index=-100) + LAMBDA_VHM*torch.mean((cls-model.ctr)**2) + MU_EW*ew_crit(ew, ew_l)
            scaler.scale(loss).backward(); scaler.step(opt); scaler.update()
        else:
            mlm_lg, cls, ew = model(ids, mask)
            loss = LAMBDA_MLM*F.cross_entropy(mlm_lg.view(-1,VOCAB_SIZE), mlm_l.view(-1), ignore_index=-100) + LAMBDA_VHM*torch.mean((cls-model.ctr)**2) + MU_EW*ew_crit(ew, ew_l)
            loss.backward(); opt.step()
        # Move the running centre towards this batch's CLS states.
        model.upd(cls.detach()); tl += loss.item()
    avg = tl/len(s1_loader); print(f"  {ep+1}: {avg:.4f}")
    # Checkpoint selection uses the mean TRAINING loss of the epoch; no
    # held-out data is involved.
    if avg < best: best = avg; torch.save(model.state_dict(), f"{OUTPUT_DIR}/hdfsv2_ew_stage1.pt")

print(f"Stage 1 best: {best:.4f}")

In [None]:
# STAGE 2: TRUE PRE-WARNING
# Continue from the best Stage 1 weights, now training on windows that contain
# no anomaly themselves (positives are followed by one within DELTA_SEC).
print("\n" + "="*50)
print("STAGE 2: TRUE Pre-Warning")
print("="*50)

model.load_state_dict(torch.load(f"{OUTPUT_DIR}/hdfsv2_ew_stage1.pt", map_location=device))
print("Loaded Stage 1 checkpoint")

s2_ds = BalancedDataset(s2_windows, VOCAB_SIZE)
s2_loader = DataLoader(s2_ds, batch_size=BATCH_SIZE, shuffle=True)
# NOTE(review): as in Stage 1, pos_weight comes from raw window counts on top
# of BalancedDataset's re-sampling - confirm the stacking is intended.
POS_WEIGHT_S2 = (len(s2_windows)-s2_pos) / max(s2_pos, 1)
ew_crit_s2 = nn.BCEWithLogitsLoss(pos_weight=torch.tensor([POS_WEIGHT_S2], device=device))

opt = torch.optim.AdamW(model.parameters(), lr=1e-5)
best_s2 = float('inf')

# `scaler` is the GradScaler (or None) created by the Stage 1 cell.
for ep in range(5):
    model.train(); tl = 0
    for b in tqdm(s2_loader, desc=f"S2.{ep+1}"):
        ids, mask, ew_l, mlm_l = b['ids'].to(device), b['mask'].to(device), b['ew'].to(device), b['mlm'].to(device)
        opt.zero_grad()
        if scaler:
            with torch.amp.autocast('cuda'):
                mlm_lg, cls, ew = model(ids, mask)
                # Weights here: 0.2 MLM, 0.1 centre distance, 1.0 early-warning BCE.
                loss = 0.2*F.cross_entropy(mlm_lg.view(-1,VOCAB_SIZE), mlm_l.view(-1), ignore_index=-100) + 0.1*torch.mean((cls-model.ctr)**2) + 1.0*ew_crit_s2(ew, ew_l)
            scaler.scale(loss).backward(); scaler.step(opt); scaler.update()
        else:
            mlm_lg, cls, ew = model(ids, mask)
            loss = 0.2*F.cross_entropy(mlm_lg.view(-1,VOCAB_SIZE), mlm_l.view(-1), ignore_index=-100) + 0.1*torch.mean((cls-model.ctr)**2) + 1.0*ew_crit_s2(ew, ew_l)
            loss.backward(); opt.step()
        model.upd(cls.detach()); tl += loss.item()
    avg = tl/len(s2_loader); print(f"  {ep+1}: {avg:.4f}")
    # Keep the weights of the epoch with the lowest mean training loss.
    if avg < best_s2: best_s2 = avg; torch.save(model.state_dict(), f"{OUTPUT_DIR}/hdfsv2_ew_prewarn.pt")

print(f"Stage 2 best: {best_s2:.4f}")

In [None]:
# EVALUATION
print("\nEVALUATION")
# Restore the saved Stage 2 weights and switch to inference mode (no dropout).
model.load_state_dict(torch.load(f"{OUTPUT_DIR}/hdfsv2_ew_prewarn.pt", map_location=device))
model.eval()

@torch.no_grad()
def score_trace(tids, step=32):
    """Score growing prefixes of a trace with the early-warning head.

    For prefix lengths ``k`` starting at ``min(WINDOW_SIZE, len(tids))`` and
    increasing by ``step``, the last WINDOW_SIZE template ids before ``k`` are
    encoded exactly as in training (offset by OFF, clipped to the vocabulary,
    truncated to the first ``CONTEXT_LEN - 2`` tokens, wrapped in CLS/SEP and
    padded) and passed through the model.

    Windows with fewer than 10 events are skipped.

    Returns:
        list of ``(k, probability)`` pairs, in increasing ``k``.
    """
    probs = []
    for k in range(min(WINDOW_SIZE, len(tids)), len(tids) + 1, step):
        s = [min(t + OFF, VOCAB_SIZE - 1) for t in tids[max(0, k - WINDOW_SIZE):k]]
        if len(s) < 10:
            continue
        s = s[:CONTEXT_LEN - 2]
        tok = [CLS] + s + [SEP] + [PAD] * (CONTEXT_LEN - len(s) - 2)
        ids = torch.tensor([tok], device=device)
        # Training always passes the padding mask. Without it here, short
        # windows attend to PAD positions and the model sees inputs it was
        # never trained on.
        mask = (ids != PAD).long()
        _, _, ew = model(ids, mask)
        probs.append((k, torch.sigmoid(ew).item()))
    return probs

# Score every held-out trace exactly once. The per-prefix probabilities are
# kept for failure traces because the lead computation needs them; traces
# that yield no scored prefix are left out.
failure_scores = []
for t in tqdm(test_failure):
    trace_probs = score_trace(t['tids'])
    if trace_probs:
        failure_scores.append({'trace': t,
                               'max_prob': max(p for _, p in trace_probs),
                               'probs': trace_probs})
normal_scores = []
for t in tqdm(test_normal):
    trace_probs = score_trace(t['tids'])
    if trace_probs:
        normal_scores.append({'trace': t, 'max_prob': max(p for _, p in trace_probs)})
print(f"Scored: {len(failure_scores)} F, {len(normal_scores)} N")

In [None]:
# F1-OPTIMAL + METRICS
# Trace-level evaluation: a trace's score is its maximum prefix probability.
# NOTE: the threshold is chosen on the same traces it is evaluated on, so
# these numbers are optimistic.
all_probs = [s['max_prob'] for s in failure_scores] + [s['max_prob'] for s in normal_scores]
all_labels = [1]*len(failure_scores) + [0]*len(normal_scores)

auc = roc_auc_score(all_labels, all_probs) if len(set(all_labels)) > 1 else 0.5
prec, rec, thr = precision_recall_curve(all_labels, all_probs)
f1 = 2*prec*rec/(prec+rec+1e-9)
best_idx = f1.argmax()
# prec/rec have one more entry than thr; fall back to the last threshold (or
# 0.5 when there is none) if the extra end point wins.
TH = thr[best_idx] if best_idx < len(thr) else thr[-1] if len(thr) > 0 else 0.5

# precision_recall_curve counts a sample as predicted-positive when
# score >= threshold, so the same inclusive comparison is used here; a strict
# ">" would drop the very trace whose score defines TH.
tp, leads = [], []
for s in failure_scores:
    if s['max_prob'] >= TH:
        tp.append(s)
        # Index of the first anomalous event in the trace (None if absent).
        first_anom = next((i for i, ev in enumerate(s['trace'].get('events') or []) if ev[2]), None)
        for k, p in s['probs']:
            if p >= TH:
                # Lead = number of events between the first alarming prefix
                # (length k) and the first anomaly; 0 if the alarm came late.
                if first_anom is not None:
                    leads.append(max(0, first_anom - k + 1))
                break
fp = [s for s in normal_scores if s['max_prob'] >= TH]

recall = len(tp)/len(failure_scores) if failure_scores else 0
far = len(fp)/len(normal_scores) if normal_scores else 0
precision = len(tp)/(len(tp)+len(fp)) if (len(tp)+len(fp))>0 else 0
f1_final = 2*precision*recall/(precision+recall) if (precision+recall)>0 else 0

print(f"\n=== RESULTS ===")
print(f"AUC: {auc:.4f}, TH: {TH:.4f}")
print(f"F1: {f1_final:.4f}, P: {precision:.4f}, R: {recall:.4f}, FAR: {far:.4f}")
if leads: print(f"Lead: median={np.median(leads):.0f}, mean={np.mean(leads):.0f}")

In [None]:
# SAVE
# Write the trace-level metrics of the F1-optimal operating point to JSON.
# Values are cast to float so NumPy scalars serialise; lead statistics fall
# back to 0 when no lead was recorded.
results = {'dataset': 'HDFS_V2', 'method': 'curriculum_true_prewarn_timebased',
    'metrics': {'auc': float(auc), 'threshold': float(TH), 'f1': float(f1_final),
        'precision': float(precision), 'recall': float(recall), 'far': float(far),
        'lead_median': float(np.median(leads)) if leads else 0,
        'lead_mean': float(np.mean(leads)) if leads else 0}}
with open(f"{OUTPUT_DIR}/hdfsv2_ew_results.json", 'w') as f: json.dump(results, f, indent=2)
print(f"Saved: {OUTPUT_DIR}/hdfsv2_ew_results.json")
print(json.dumps(results['metrics'], indent=2))

In [None]:
# Post-Training Analysis: validation-calibrated thresholds, lead-to-first-anomaly, FAR–Recall, PR/ROC, CIs
import math, statistics
from sklearn.metrics import auc, confusion_matrix

# 1) Helpers --------------------------------------------------------------

def first_anomaly_index(trace):
    """Return the position of the first anomalous event in ``trace['events']``.

    Returns ``None`` when the trace has no events or none is flagged
    anomalous (i.e. a normal trace).
    """
    timeline = trace.get('events') or []
    return next((pos for pos, (_, _, flagged) in enumerate(timeline) if flagged), None)

@torch.no_grad()
def trace_scores_pre_anom(trace, step=32):
    """Score a trace and summarise the part that precedes its first anomaly.

    Returns a tuple ``(max_prob_pre, probs, first_anom_idx)``:
      * ``probs`` - the ``(k, p)`` pairs from ``score_trace``, computed once so
        callers can reuse them;
      * ``first_anom_idx`` - index of the first anomalous event, ``None`` for
        a normal trace;
      * ``max_prob_pre`` - the largest ``p`` among prefixes that end strictly
        before the first anomaly (all prefixes for a normal trace), or 0.0
        when no prefix qualifies.
    """
    window_probs = score_trace(trace['tids'], step=step)
    first_idx = first_anomaly_index(trace)
    if first_idx is None:
        eligible = [p for _, p in window_probs]
    else:
        # A prefix of length k ends at event index k - 1.
        eligible = [p for k, p in window_probs if (k - 1) < first_idx]
    best = max(eligible) if eligible else 0.0
    return best, window_probs, first_idx


def earliest_crossing_k(probs, threshold, first_idx):
    """Return the first prefix length whose probability exceeds ``threshold``.

    When ``first_idx`` is given, only prefixes ending strictly before that
    event index (``k - 1 < first_idx``) are eligible. Returns ``None`` when
    no eligible prefix crosses the threshold.
    """
    unrestricted = first_idx is None
    return next(
        (k for k, p in probs
         if p > threshold and (unrestricted or (k - 1) < first_idx)),
        None,
    )


def lead_events_from_crossing(k_cross, first_idx):
    """Return the lead, in events, from the alarm to the first anomaly.

    Computed as ``first_idx - k_cross + 1`` and floored at 0. Returns
    ``None`` when either argument is ``None``.
    """
    have_both = k_cross is not None and first_idx is not None
    if not have_both:
        return None
    gap = first_idx - k_cross + 1
    return max(0, gap)


def split_validation_test(norm_list, fail_list, val_ratio=0.3, seed=SEED):
    """Randomly split normal and failure traces into validation and test parts.

    Each non-empty list is shuffled with one RNG seeded by ``seed`` (normals
    first, then failures) and its first ``round(len * val_ratio)`` items - at
    least one - go to validation.

    Returns ``(val_norm, val_fail, tst_norm, tst_fail)``.
    """
    rng = np.random.RandomState(seed)

    def _partition(items):
        # Shuffle one list and cut it into (validation, test).
        count = len(items)
        if count == 0:
            return [], []
        n_val = int(max(1, round(count * val_ratio)))
        order = rng.permutation(count)
        return [items[i] for i in order[:n_val]], [items[i] for i in order[n_val:]]

    # Order matters: normals consume the RNG before failures.
    val_norm, tst_norm = _partition(norm_list)
    val_fail, tst_fail = _partition(fail_list)
    return val_norm, val_fail, tst_norm, tst_fail


# 2) Build validation/test splits (from held-out traces) ------------------
val_normal, val_failure, final_normal, final_failure = split_validation_test(
    test_normal, test_failure, val_ratio=0.3, seed=SEED)
print(f"Validation: {len(val_normal)} N, {len(val_failure)} F | Test: {len(final_normal)} N, {len(final_failure)} F")

# Validation side: only the per-trace maximum pre-anomaly probability is kept.
val_norm_scores = []
for t in tqdm(val_normal, desc="VAL normal scores"):
    val_norm_scores.append(trace_scores_pre_anom(t)[0])
val_fail_scores = []
for t in tqdm(val_failure, desc="VAL failure scores"):
    val_fail_scores.append(trace_scores_pre_anom(t)[0])

# Test side: normals keep the score only; failures also cache the per-prefix
# probabilities and the first-anomaly index as (probs, first_idx) pairs.
final_norm_sc = [trace_scores_pre_anom(t)[0]
                 for t in tqdm(final_normal, desc="TEST normal scores")]
final_fail_sc = []
final_fail_meta = []
for t in tqdm(final_failure, desc="TEST failure scores"):
    top_prob, window_probs, first_idx = trace_scores_pre_anom(t)
    final_fail_sc.append(top_prob)
    final_fail_meta.append((window_probs, first_idx))

# 3) Threshold calibration on validation by FAR target --------------------
far_targets = [0.001, 0.005, 0.01, 0.02, 0.05, 0.1]

def threshold_for_far(norm_scores, far):
    """Return the score threshold matching a target false-alarm rate.

    The threshold is the ``1 - far`` quantile of ``norm_scores`` (the
    quantile level is clamped to [0, 1]). With no scores, 1.0 is returned.
    """
    if len(norm_scores) == 0:
        return 1.0
    level = 1.0 - far
    level = min(1.0, level)
    level = max(0.0, level)
    return float(np.quantile(norm_scores, level))

# One threshold per FAR target, derived from the validation normals only.
calib = {float(ft): threshold_for_far(val_norm_scores, ft) for ft in far_targets}
print("Calibrated thresholds (VAL):", {k: round(v, 4) for k, v in calib.items()})

# 4) Evaluate on TEST at each FAR target ----------------------------------

def evaluate_at_threshold(th):
    """Compute trace-level test metrics at threshold ``th``.

    A failure trace counts as a true positive when some prefix ending before
    its first anomaly scores above ``th``; a normal trace counts as a false
    positive when its score is above ``th``. Lead (in events) is collected
    for every true positive.

    Returns a dict with f1, precision, recall, far, lead statistics
    (median/mean/p25/p75/count) and the raw counts tp, fp, nF, nN. Empty
    trace lists are counted as size 1 to avoid division by zero.
    """
    leads = []
    tp = 0
    for probs, q in final_fail_meta:
        kx = earliest_crossing_k(probs, th, q)
        if kx is None:
            continue
        tp += 1
        ld = lead_events_from_crossing(kx, q)
        if ld is not None:
            leads.append(ld)
    fp = sum(1 for score in final_norm_sc if score > th)

    nF = len(final_fail_meta) or 1
    nN = len(final_norm_sc) or 1
    recall = tp / nF
    far = fp / nN
    precision = tp / max(tp + fp, 1)
    f1 = 2 * precision * recall / max(precision + recall, 1e-9)

    if leads:
        lead_stats = {
            'median': float(np.median(leads)),
            'mean': float(np.mean(leads)),
            'p25': float(np.percentile(leads, 25)),
            'p75': float(np.percentile(leads, 75)),
            'count': len(leads),
        }
    else:
        lead_stats = {'median': 0.0, 'mean': 0.0, 'p25': 0.0, 'p75': 0.0, 'count': len(leads)}

    return {'f1': float(f1), 'precision': float(precision), 'recall': float(recall),
            'far': float(far), 'lead': lead_stats,
            'tp': int(tp), 'fp': int(fp), 'nF': int(nF), 'nN': int(nN)}

# One row per FAR target: the calibrated threshold plus the test metrics at it.
far_curve = [
    {'far_target': float(ft), 'threshold': calib[float(ft)],
     **evaluate_at_threshold(calib[float(ft)])}
    for ft in far_targets
]

# Choose operating point FAR=1% if available
if 0.01 in far_targets:
    op_target = 0.01
else:
    op_target = far_targets[min(2, len(far_targets)-1)]
op_th = calib[float(op_target)]
op_metrics = evaluate_at_threshold(op_th)

# 5) PR/ROC on TEST using per-trace max pre-anomaly prob -------------------
# NOTE(review): `auc` here is the sklearn function imported at the top of this
# cell; it replaces the float named `auc` set by the F1-optimal cell, so
# re-running the earlier SAVE cell after this one will fail on float(auc).
# `all_labels` from the F1-optimal cell is likewise overwritten.
all_scores = np.array(final_fail_sc + final_norm_sc, dtype=float)
all_labels = np.array([1]*len(final_fail_sc) + [0]*len(final_norm_sc), dtype=int)
# ROC-AUC is undefined with a single class; report 0.5 in that case.
roc = float(roc_auc_score(all_labels, all_scores)) if len(set(all_labels))>1 else 0.5
pr_p, pr_r, _ = precision_recall_curve(all_labels, all_scores)
# Area under the precision-recall curve (recall on the x axis).
pr = float(auc(pr_r, pr_p)) if len(pr_r)>1 else 0.0

# 6) Bootstrap 95% CI at operating point ----------------------------------

def bootstrap_ci(n_boot=300, seed=SEED):
    """Bootstrap 95% intervals for F1, recall and FAR at the operating point.

    Test failure and normal traces are resampled with replacement (each set
    at its own size) ``n_boot`` times; every resample is scored at ``op_th``.

    Returns ``{'f1': [lo, hi], 'recall': [lo, hi], 'far': [lo, hi]}`` using
    the 2.5th and 97.5th percentiles; ``[0.0, 0.0]`` when nothing was sampled.
    """
    rng = np.random.RandomState(seed)
    fail_ids = np.arange(len(final_fail_meta))
    norm_ids = np.arange(len(final_norm_sc))
    # Decisions per trace do not change between resamples; compute them once.
    fail_hit = [earliest_crossing_k(probs, op_th, q) is not None
                for probs, q in final_fail_meta]
    norm_alarm = [score > op_th for score in final_norm_sc]
    denom_f = max(1, len(fail_ids))
    denom_n = max(1, len(norm_ids))

    samples = {'f1': [], 'recall': [], 'far': []}
    for _ in range(n_boot):
        # Failures are drawn before normals so the RNG stream stays the same.
        drawn_f = rng.choice(fail_ids, size=len(fail_ids), replace=True) if len(fail_ids) > 0 else []
        drawn_n = rng.choice(norm_ids, size=len(norm_ids), replace=True) if len(norm_ids) > 0 else []
        tp = sum(1 for i in drawn_f if fail_hit[i])
        fp = sum(1 for j in drawn_n if norm_alarm[j])
        recall = tp / denom_f
        far = fp / denom_n
        precision = tp / max(tp + fp, 1)
        f1 = 2*precision*recall / max(precision+recall, 1e-9)
        samples['f1'].append(f1)
        samples['recall'].append(recall)
        samples['far'].append(far)

    def _interval(values):
        # Percentile interval; empty input maps to [0.0, 0.0].
        if not values:
            return [0.0, 0.0]
        lo, hi = np.percentile(values, [2.5, 97.5])
        return [float(lo), float(hi)]

    return {name: _interval(samples[name]) for name in ('f1', 'recall', 'far')}

cis = bootstrap_ci()

# 7) Save consolidated JSON -----------------------------------------------
# Collects configuration, validation-calibrated thresholds, test metrics at
# the operating point, bootstrap intervals and the FAR sweep in one report.
report = {
    'dataset': 'HDFS_V2',
    'method': 'true_prewarn_analysis',
    'config': {
        'delta_min': int(DELTA_MIN),
        'window_size': int(WINDOW_SIZE),
        # Hard-coded; mirrors the default `step` of the scoring functions.
        'step_size': 32,
        'far_targets': far_targets
    },
    'validation': {
        'n_normal': len(val_normal), 'n_failure': len(val_failure),
        'thresholds_by_far': {str(k): float(v) for k, v in calib.items()}
    },
    'test': {
        'n_normal': len(final_normal), 'n_failure': len(final_failure),
        'operating_point': {'far_target': float(op_target), 'threshold': float(op_th)},
        'metrics': {
            'roc_auc': roc,
            'pr_auc': pr,
            **op_metrics
        },
        'ci_95': cis
    },
    'far_recall_curve': far_curve,
    'notes': 'Threshold calibrated on validation normals by per-trace FAR; lead measured to FIRST anomaly event.'
}

with open(os.path.join(OUTPUT_DIR, 'hdfsv2_ew_analysis.json'), 'w') as f:
    json.dump(report, f, indent=2)
print(f"\n✓ Saved analysis to {OUTPUT_DIR}/hdfsv2_ew_analysis.json")
print(json.dumps(report['test']['metrics'], indent=2))