# HDFS TRUE Pre-Warning (Curriculum Learning)

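- **Stage 1 — detection + pre-warning:** learn failure patterns. A prefix of a failure trace is labelled positive when the trace ends within `LOOKAHEAD_BLOCKS` events.
- **Stage 2 — true pre-warning:** fine-tune only on failure-trace prefixes that end more than `LOOKAHEAD_BLOCKS` events before the trace's end, so the model must warn ahead of time rather than detect a failure already in progress.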

In [None]:
import os, json, random
import numpy as np
import pandas as pd
from tqdm import tqdm
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from sklearn.metrics import roc_auc_score, precision_recall_curve

# --- Runtime device: use the GPU when one is visible --------------------------
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f"Device: {device}")

# --- Paths ---------------------------------------------------------------------
BASE = "/teamspace/studios/this_studio"
V1_PATH = f"{BASE}/content/LogHub_HDFS/HDFS_v1/preprocessed/Event_traces.csv"
CKPT_DIR = f"{BASE}/checkpoints"           # holds the pretrained logbert_ep11.pt
OUTPUT_DIR = f"{BASE}/output-hdfsv1ew1"    # fresh directory for this run's outputs
os.makedirs(OUTPUT_DIR, exist_ok=True)

# --- Transformer geometry --------------------------------------------------------
CONTEXT_LEN = 128   # tokens per sequence, CLS and SEP included
D_MODEL = 256
N_HEADS = 8
N_LAYERS = 4

# --- Special token ids; raw event ids are shifted up by OFF to avoid them -------
PAD = 0
CLS = 1
MASK = 2
SEP = 3
OFF = 4

# --- Batching and reproducibility --------------------------------------------------
BATCH_SIZE = 64
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Lookahead labelling (replaces an earlier PRE_RATIO scheme): a prefix of a
# failure trace is positive when at most this many events of the trace remain.
LOOKAHEAD_BLOCKS = 15
print(f"Using LOOKAHEAD_BLOCKS = {LOOKAHEAD_BLOCKS}")

In [None]:
# LOAD DATA
def _parse_events(raw):
    """Parse a serialized event list such as ``[E5,E22,E5]`` into ``[5, 22, 5]``.

    Brackets and double quotes around the whole list are stripped. Each token
    is then stripped of whitespace and of single/double quotes *before* the
    ``E`` prefix check, so quoted forms such as ``['E5', 'E22']`` parse as well
    (previously the prefix check ran first and quoted tokens were silently
    dropped). Tokens not starting with ``E`` are skipped; ``E`` followed by a
    non-integer raises ValueError from ``int``.
    """
    body = str(raw).strip('[]"')
    events = []
    for token in body.split(','):
        token = token.strip().strip('\'"')
        if token.startswith('E'):
            events.append(int(token[1:]))
    return events


df = pd.read_csv(V1_PATH)
v1_normal, v1_failure = [], []
# Iterate the two needed columns directly; DataFrame.iterrows builds a Series
# per row and is far slower on the ~575k HDFS block traces.
for raw_features, label in tqdm(zip(df['Features'], df['Label']), total=len(df), desc="V1"):
    events = _parse_events(raw_features)
    # Same minimum length the prefix builders below use (min_len=6).
    if len(events) >= 6:
        (v1_normal if label == 'Success' else v1_failure).append(events)

print(f"Normal: {len(v1_normal):,}, Failure: {len(v1_failure):,}")

def seed_split(seqs, r=0.2, seed=None):
    """Deterministically shuffle *seqs* and split it into ``(train, test)`` lists.

    Args:
        seqs: Indexable sequence to split.
        r: Fraction of items for the test split, in ``[0, 1]``. The train split
            receives ``int(len(seqs) * (1 - r))`` items and the test split the rest.
        seed: Seed for the permutation. ``None`` (the default) uses the
            module-level ``SEED``, so existing calls produce the same split.

    Returns:
        Two new lists that together contain every item of *seqs* exactly once.

    Raises:
        ValueError: if ``r`` is outside ``[0, 1]``.
    """
    if not 0.0 <= r <= 1.0:
        raise ValueError(f"r must be in [0, 1], got {r!r}")
    if seed is None:
        seed = SEED
    order = np.random.RandomState(seed).permutation(len(seqs))
    cut = int(len(seqs) * (1 - r))
    return [seqs[i] for i in order[:cut]], [seqs[i] for i in order[cut:]]

# Split normal and failure traces separately with the same ratio, so each
# class is represented in train and test in the same proportion.
normal_split = seed_split(v1_normal)
failure_split = seed_split(v1_failure)
normal_train, normal_test = normal_split
failure_train, failure_test = failure_split
print(f"TRAIN: {len(normal_train):,} N, {len(failure_train):,} F")
print(f"TEST: {len(normal_test):,} N, {len(failure_test):,} F")

In [None]:
# STAGE 1: Detection + Pre-warning (LOOKAHEAD_BLOCKS approach)
def create_prefixes_stage1(traces, is_failure=False, lookahead=LOOKAHEAD_BLOCKS, step=3, min_len=6, max_per=15):
    """Build Stage 1 (detection + pre-warning) training prefixes.

    Every trace with at least ``min_len`` events yields prefixes of length
    ``min_len, min_len + step, ...`` up to the whole trace. Each sample is a
    tuple ``(prefix, label, remaining)``, where ``remaining`` is the number of
    events after the prefix. For failure traces ``label`` is 1 when
    ``remaining <= lookahead``; prefixes of normal traces are always 0.
    ``remaining`` is only carried along in the tuple; this function does not
    use it for anything else.

    A trace that yields more than ``max_per`` prefixes is randomly subsampled:
    at most ``max_per // 2`` positives are kept, and negatives fill the rest.

    Returns:
        Flat list of ``(prefix, label, remaining)`` tuples over all traces.
    """
    out = []
    for trace in traces:
        n = len(trace)
        if n < min_len:
            continue
        cands = []
        for k in range(min_len, n + 1, step):
            left = n - k
            label = 1 if (is_failure and left <= lookahead) else 0
            cands.append((trace[:k], label, left))
        if len(cands) > max_per:
            # Cap positives at half the budget so negatives stay represented.
            positives = [c for c in cands if c[1] == 1]
            negatives = [c for c in cands if c[1] == 0]
            pos_cap = max_per // 2
            if len(positives) > pos_cap:
                positives = random.sample(positives, pos_cap)
            n_neg = min(len(negatives), max_per - len(positives))
            cands = positives + random.sample(negatives, n_neg)
        out.extend(cands)
    return out

print("Stage 1 prefixes (LOOKAHEAD approach)...")
# Normal traces contribute only negatives; failure traces contribute both.
s1_normal = create_prefixes_stage1(normal_train, is_failure=False)
s1_failure = create_prefixes_stage1(failure_train, is_failure=True)
s1_pos = sum(sample[1] == 1 for sample in s1_failure)
print(f"Stage 1: normal={len(s1_normal):,}, failure={len(s1_failure):,}, pos={s1_pos:,}")

In [None]:
# STAGE 2: TRUE Pre-warning (skip last LOOKAHEAD blocks - already in danger zone)
def create_prefixes_stage2(traces, lookahead=LOOKAHEAD_BLOCKS, step=3, min_len=6, max_per=15):
    """Build Stage 2 ("true pre-warning") positive prefixes from failure traces.

    Only prefixes that leave strictly more than ``lookahead`` events until the
    end of the trace are kept. Prefixes inside the final ``lookahead`` events
    are dropped, because flagging them would be detection rather than early
    warning. Every kept prefix is labelled 1.

    Traces shorter than ``min_len + lookahead`` are skipped up front. At most
    ``max_per`` prefixes (random subsample) are kept per trace.

    Returns:
        Flat list of ``(prefix, 1, remaining)`` tuples.
    """
    out = []
    for trace in traces:
        n = len(trace)
        if n < min_len + lookahead:
            continue
        # k < n - lookahead  <=>  remaining = n - k > lookahead; the min() keeps
        # the upper bound from ever exceeding a full-length scan (n + 1).
        stop = min(n + 1, n - lookahead)
        kept = [(trace[:k], 1, n - k) for k in range(min_len, stop, step)]
        if len(kept) > max_per:
            kept = random.sample(kept, max_per)
        out.extend(kept)
    return out

print("Stage 2 prefixes (TRUE pre-warning)...")
# Normal negatives: every 3rd prefix (min length 6) of each normal trace,
# capped at 15 per trace by random subsampling -- the same per-trace cap
# create_prefixes_stage1 applies. A single global slice would cap only the
# total, keeping every prefix of the earliest traces and dropping whole traces
# at the tail of normal_train.
s2_normal = []
for _trace in normal_train:
    _prefixes = [(_trace[:k], 0, len(_trace) - k) for k in range(6, len(_trace) + 1, 3)]
    if len(_prefixes) > 15:
        _prefixes = random.sample(_prefixes, 15)
    s2_normal.extend(_prefixes)
s2_failure = create_prefixes_stage2(failure_train)
s2_pos = sum(1 for p in s2_failure if p[1]==1)
print(f"Stage 2: normal={len(s2_normal):,}, failure pos={s2_pos:,}")

In [None]:
# DATASET
class EWDataset(Dataset):
    """Early-warning dataset with positive oversampling and on-the-fly MLM masking.

    Inputs are sample tuples whose first two fields are ``(prefix, label)``,
    as produced by the prefix builders (which emit
    ``(prefix, label, remaining)``; extra fields are ignored). Positives are
    failure samples labelled 1. Negatives are all ``normal`` samples plus
    failure samples labelled 0.

    Indices are virtual. Every index with ``idx % (ratio + 1) == 0`` draws a
    random positive and all others draw a random negative, both with
    replacement. An epoch of ``len(pos) * (1 + ratio)`` items therefore has one
    positive draw per ``ratio`` negative draws. Without positives, the length
    falls back to ``len(neg)``.

    Each item encodes ``CLS + shifted event ids + SEP``, padded with ``PAD`` to
    ``CONTEXT_LEN``. ``floor(15%)`` of the event tokens (at least one) are
    replaced by ``MASK``, and their original ids are stored as MLM targets
    (``-100`` elsewhere).
    """

    def __init__(self, normal, failure, vs, ratio=5):
        self.pos = [p for p in failure if p[1] == 1]
        self.neg = normal + [p for p in failure if p[1] == 0]
        self.vs, self.ratio = vs, ratio
        self.n = len(self.pos) * (1 + ratio) if self.pos else len(self.neg)

    def __len__(self):
        return self.n

    def __getitem__(self, idx):
        # idx only selects the class; the sample itself is drawn at random.
        if self.pos and idx % (self.ratio + 1) == 0:
            sample = random.choice(self.pos)
        else:
            sample = random.choice(self.neg) if self.neg else random.choice(self.pos)
        # The builders emit (prefix, label, remaining) 3-tuples, so unpacking
        # into two names raises ValueError; take the first two fields instead.
        prefix, ew = sample[0], sample[1]

        # Shift ids past the special tokens, clamp out-of-vocabulary ids to the
        # last id, and leave room for CLS and SEP.
        s = [min(t + OFF, self.vs - 1) for t in prefix[:CONTEXT_LEN - 2]]
        tok = [CLS] + s + [SEP] + [PAD] * (CONTEXT_LEN - len(s) - 2)

        # -100 marks positions ignored by the MLM cross-entropy.
        masked, mlm = tok.copy(), [-100] * len(tok)
        if s:
            n_mask = min(max(1, int(len(s) * 0.15)), len(s))
            # +1 skips the CLS slot so only event tokens get masked.
            for p in np.random.choice(len(s), n_mask, replace=False) + 1:
                mlm[p] = masked[p]
                masked[p] = MASK
        return {'ids': torch.tensor(masked),
                'mask': torch.tensor([1 if t != PAD else 0 for t in masked]),
                'ew': torch.tensor(ew, dtype=torch.float),
                'mlm': torch.tensor(mlm)}

In [None]:
# MODEL
class LogBERTEW(nn.Module):
    """LogBERT-style transformer encoder with an extra early-warning (EW) head.

    ``forward(ids, mask=None)`` returns ``(mlm_logits, cls_vec, ew_logit)``:
    per-position vocabulary logits from ``head``, the hidden state at position
    0 (the CLS slot), and one EW logit per sequence from ``ew_head``. ``mask``
    marks real tokens with non-zero values; zeros become the encoder's key
    padding mask.

    ``ctr`` is a registered buffer holding a running center of CLS embeddings,
    maintained by ``upd``; ``ci`` records whether it has been initialized.
    Submodule names are state_dict keys and must stay stable for checkpoints.
    """

    def __init__(self, vs):
        super().__init__()
        # Construction order fixes parameter registration order and the
        # sequence of random initializations, so it is kept unchanged.
        self.tok = nn.Embedding(vs, D_MODEL, padding_idx=PAD)
        self.pos = nn.Embedding(CONTEXT_LEN, D_MODEL)
        self.drop = nn.Dropout(0.1)
        layer = nn.TransformerEncoderLayer(
            d_model=D_MODEL,
            nhead=N_HEADS,
            dim_feedforward=4 * D_MODEL,
            dropout=0.1,
            activation='gelu',
            batch_first=True,
        )
        self.enc = nn.TransformerEncoder(layer, N_LAYERS)
        self.head = nn.Linear(D_MODEL, vs)
        self.ew_head = nn.Linear(D_MODEL, 1)
        self.register_buffer('ctr', torch.zeros(D_MODEL))
        self.ci = False

    def forward(self, ids, mask=None):
        positions = torch.arange(ids.size(1), device=ids.device)
        x = self.drop(self.tok(ids) + self.pos(positions))
        pad_mask = None if mask is None else (mask == 0)
        h = self.enc(x, src_key_padding_mask=pad_mask)
        cls_vec = h[:, 0, :]
        return self.head(h), cls_vec, self.ew_head(cls_vec).squeeze(-1)

    def upd(self, e):
        """Fold the batch mean of ``e`` (over dim 0) into the running center.

        The first call (``ci`` False) sets ``ctr`` to the batch mean; later
        calls apply an exponential moving average with weight 0.1 on the batch.
        """
        with torch.no_grad():
            batch_center = e.mean(0)
            if self.ci:
                self.ctr = 0.9 * self.ctr + 0.1 * batch_center
            else:
                self.ctr = batch_center
            self.ci = True

# Load the pretrained LogBERT weights and the stored CLS center.
ckpt = torch.load(f"{CKPT_DIR}/logbert_ep11.pt", map_location=device)
# Vocabulary size is read from the checkpoint's token-embedding table.
VOCAB_SIZE = ckpt['model_state_dict']['tok.weight'].shape[0]
model = LogBERTEW(VOCAB_SIZE).to(device)
# strict=False tolerates keys the checkpoint lacks (presumably at least the new
# ew_head -- confirm). Report what was skipped so a renamed or missing layer is
# visible instead of silently staying at random initialization.
load_info = model.load_state_dict(ckpt['model_state_dict'], strict=False)
if load_info.missing_keys:
    print(f"Missing keys (kept at init): {load_info.missing_keys}")
if load_info.unexpected_keys:
    print(f"Unexpected keys (ignored): {load_info.unexpected_keys}")
# The center comes from the checkpoint's separate 'center' entry.
model.ctr = ckpt['center'].to(device); model.ci = True
print(f"Model loaded: VOCAB={VOCAB_SIZE}")

In [None]:
# STAGE 1: Detection + Pre-warning (LOOKAHEAD_BLOCKS approach)
# NOTE(review): this cell repeats the earlier Stage 1 prefix cell verbatim
# (same function, same s1_* rebuild). The Stage 2 cell loads
# hdfs_ew_stage1.pt, which no visible cell writes -- this slot may originally
# have held the Stage 1 training cell. TODO confirm.
def create_prefixes_stage1(traces, is_failure=False, lookahead=LOOKAHEAD_BLOCKS, step=3, min_len=6, max_per=15):
    """Return Stage 1 samples ``(prefix, label, remaining)`` for *traces*.

    Prefix lengths run ``min_len, min_len + step, ...`` up to the full trace;
    traces shorter than ``min_len`` are ignored. ``label`` is 1 only for
    failure traces whose remaining event count is ``<= lookahead``. A trace
    with more than ``max_per`` prefixes keeps a random subset: up to
    ``max_per // 2`` positives, with negatives filling the remaining budget.
    """
    samples = []
    for trace in (t for t in traces if len(t) >= min_len):
        total = len(trace)
        prefixes = [
            (trace[:k], 1 if is_failure and total - k <= lookahead else 0, total - k)
            for k in range(min_len, total + 1, step)
        ]
        if len(prefixes) <= max_per:
            samples.extend(prefixes)
            continue
        pos = [p for p in prefixes if p[1] == 1]
        neg = [p for p in prefixes if p[1] == 0]
        if len(pos) > max_per // 2:
            pos = random.sample(pos, max_per // 2)
        neg = random.sample(neg, min(len(neg), max_per - len(pos)))
        samples.extend(pos + neg)
    return samples

print("Stage 1 prefixes (LOOKAHEAD approach)...")
# Build normal samples first, then failure samples (this order fixes RNG use).
s1_normal, s1_failure = (
    create_prefixes_stage1(normal_train, is_failure=False),
    create_prefixes_stage1(failure_train, is_failure=True),
)
s1_pos = len([s for s in s1_failure if s[1] == 1])
print(f"Stage 1: normal={len(s1_normal):,}, failure={len(s1_failure):,}, pos={s1_pos:,}")

In [None]:
# STAGE 2: TRUE PRE-WARNING
print("\n" + "="*50)
print("STAGE 2: TRUE Pre-Warning")
print("="*50)

# Stage 2 fine-tunes from the Stage 1 weights. NOTE(review): no visible cell
# writes this file, so it must come from a separate Stage 1 training run.
stage1_path = f"{OUTPUT_DIR}/hdfs_ew_stage1.pt"
if not os.path.exists(stage1_path):
    raise FileNotFoundError(f"Stage 1 checkpoint not found: {stage1_path} (run Stage 1 training first)")
model.load_state_dict(torch.load(stage1_path, map_location=device))
print("Loaded Stage 1 checkpoint")

s2_ds = EWDataset(s2_normal, s2_failure, VOCAB_SIZE, ratio=5)
s2_loader = DataLoader(s2_ds, batch_size=BATCH_SIZE, shuffle=True)
# The dataset already rebalances batches to 1 positive per `ratio` negatives,
# so pos_weight = ratio is enough to balance the two classes' contributions.
# Weighting by the raw normal/positive counts on top of the oversampling
# double-counts the imbalance and pushes the head toward false alarms.
POS_WEIGHT_S2 = float(s2_ds.ratio)
ew_crit_s2 = nn.BCEWithLogitsLoss(pos_weight=torch.tensor([POS_WEIGHT_S2], device=device))

# `scaler` was referenced but never defined. A GradScaler with enabled=False is
# a pass-through (scale() returns the loss, step() calls opt.step()), so one
# code path covers both the mixed-precision GPU case and the CPU case.
use_amp = device.type == 'cuda'
scaler = torch.amp.GradScaler('cuda', enabled=use_amp)

opt = torch.optim.AdamW(model.parameters(), lr=1e-5)
best_s2 = float('inf')

for ep in range(5):
    model.train(); tl = 0
    for b in tqdm(s2_loader, desc=f"S2.{ep+1}"):
        ids, mask, ew_l, mlm_l = (b[key].to(device) for key in ('ids', 'mask', 'ew', 'mlm'))
        opt.zero_grad()
        with torch.amp.autocast(device.type, enabled=use_amp):
            mlm_lg, cls, ew = model(ids, mask)
            # Weighted sum: MLM (0.2) + compactness around the CLS center (0.1)
            # + early-warning BCE (1.0).
            loss = (0.2*F.cross_entropy(mlm_lg.view(-1, VOCAB_SIZE), mlm_l.view(-1), ignore_index=-100)
                    + 0.1*torch.mean((cls - model.ctr)**2)
                    + 1.0*ew_crit_s2(ew, ew_l))
        scaler.scale(loss).backward()
        scaler.step(opt)
        scaler.update()
        model.upd(cls.detach()); tl += loss.item()
    avg = tl/max(len(s2_loader), 1); print(f"  {ep+1}: {avg:.4f}")
    if avg < best_s2:
        best_s2 = avg
        torch.save(model.state_dict(), f"{OUTPUT_DIR}/hdfs_ew_prewarn.pt")

print(f"Stage 2 best: {best_s2:.4f}")

In [None]:
# EVALUATION (F1-optimal, per-failure)
print("\nEVALUATION")
# Score with the best Stage 2 (pre-warning) weights saved during training.
_prewarn_state = torch.load(f"{OUTPUT_DIR}/hdfs_ew_prewarn.pt", map_location=device)
model.load_state_dict(_prewarn_state)
del _prewarn_state  # release the extra copy of the weights
model.eval()

@torch.no_grad()
def score_trace(trace, step=3, min_len=6, max_batch=256):
    """Return early-warning probabilities for the prefixes of *trace*.

    Prefix lengths are ``min_len, min_len + step, ...`` up to ``len(trace)``.
    Each prefix is encoded as ``EWDataset`` encodes training items (without
    MLM masking):
      * event ids are shifted by ``OFF`` and clamped to ``VOCAB_SIZE - 1``;
      * only the first ``CONTEXT_LEN - 2`` events are kept;
      * the result is wrapped as ``CLS ... SEP`` and padded with ``PAD`` to
        ``CONTEXT_LEN``.

    Args:
        trace: Sequence of integer event ids.
        step: Spacing between consecutive prefix lengths.
        min_len: Shortest prefix scored; shorter traces yield no scores.
        max_batch: Maximum number of prefixes per forward pass.

    Returns:
        List of ``(k, p)`` tuples, where ``k`` is the prefix length and ``p``
        is the sigmoid EW probability. Empty if ``len(trace) < min_len``.
    """
    if len(trace) < min_len:
        return []
    lengths = list(range(min_len, len(trace) + 1, step))
    rows = []
    for k in lengths:
        shifted = [min(t + OFF, VOCAB_SIZE - 1) for t in trace[:k][:CONTEXT_LEN - 2]]
        rows.append([CLS] + shifted + [SEP] + [PAD] * (CONTEXT_LEN - len(shifted) - 2))

    probs = []
    for start in range(0, len(rows), max_batch):
        ids = torch.tensor(rows[start:start + max_batch], device=device)
        # Pass the padding mask exactly as training does (1 = real token,
        # 0 = PAD). Without it, the encoder attends to PAD positions it never
        # attended to during training, which skews the CLS representation.
        mask = (ids != PAD).long()
        _, _, ew = model(ids, mask)
        probs.extend(torch.sigmoid(ew.float()).tolist())
    return list(zip(lengths, probs))

# Per-trace score = highest EW probability over all of the trace's prefixes.
failure_scores = []
for trace in tqdm(failure_test, desc="F"):
    prefix_probs = score_trace(trace)
    if not prefix_probs:
        continue
    failure_scores.append({
        'trace': trace,
        'max_prob': max(prob for _, prob in prefix_probs),
        'probs': prefix_probs,
    })

normal_scores = []
for trace in tqdm(normal_test, desc="N"):
    prefix_probs = score_trace(trace)
    if not prefix_probs:
        continue
    normal_scores.append({
        'trace': trace,
        'max_prob': max(prob for _, prob in prefix_probs),
    })

print(f"Scored: {len(failure_scores)} F, {len(normal_scores)} N")

In [None]:
# F1-OPTIMAL + METRICS
# NOTE(review): TH is the F1-optimal threshold on the *test* traces themselves,
# so the F1/P/R reported below are tuned on the evaluation data and therefore
# optimistic. A held-out validation split would give an unbiased estimate.
all_probs = [s['max_prob'] for s in failure_scores] + [s['max_prob'] for s in normal_scores]
all_labels = [1]*len(failure_scores) + [0]*len(normal_scores)

auc = roc_auc_score(all_labels, all_probs) if len(set(all_labels)) > 1 else 0.5
prec, rec, thr = precision_recall_curve(all_labels, all_probs)
f1 = 2*prec*rec/(prec+rec+1e-9)
best_idx = f1.argmax()
# prec/rec have one more entry than thr (the final P=1, R=0 point has no
# threshold), hence the bounds check.
TH = thr[best_idx] if best_idx < len(thr) else thr[-1] if len(thr) > 0 else 0.5

# precision_recall_curve counts a score as positive when score >= threshold,
# so use >= here as well. With '>', samples scoring exactly TH (including the
# one that defines the optimum) are dropped, and the reported F1 no longer
# matches the chosen operating point.
tp, leads = [], []
for s in failure_scores:
    if s['max_prob'] >= TH:
        tp.append(s)
        # Lead = events remaining after the first prefix that reaches TH.
        first_k = next(k for k, p in s['probs'] if p >= TH)
        leads.append(len(s['trace']) - first_k)
fp = [s for s in normal_scores if s['max_prob'] >= TH]

recall = len(tp)/len(failure_scores) if failure_scores else 0
far = len(fp)/len(normal_scores) if normal_scores else 0
precision = len(tp)/(len(tp)+len(fp)) if (len(tp)+len(fp))>0 else 0
f1_final = 2*precision*recall/(precision+recall) if (precision+recall)>0 else 0

print(f"\n=== RESULTS ===")
print(f"AUC: {auc:.4f}, TH: {TH:.4f}")
print(f"F1: {f1_final:.4f}, P: {precision:.4f}, R: {recall:.4f}, FAR: {far:.4f}")
if leads: print(f"Lead: median={np.median(leads):.0f}, mean={np.mean(leads):.0f}")

In [None]:
# SAVE
# Persist the headline metrics. Lead-time statistics fall back to 0 when no
# failure trace crossed the threshold.
_lead_stats = {
    'lead_median': float(np.median(leads)) if leads else 0,
    'lead_mean': float(np.mean(leads)) if leads else 0,
}
_metrics = {
    'auc': float(auc),
    'threshold': float(TH),
    'f1': float(f1_final),
    'precision': float(precision),
    'recall': float(recall),
    'far': float(far),
}
_metrics.update(_lead_stats)
results = {'dataset': 'HDFS', 'method': 'curriculum_true_prewarn', 'metrics': _metrics}
with open(f"{OUTPUT_DIR}/hdfs_ew_results.json", 'w') as results_file:
    json.dump(results, results_file, indent=2)
print(f"Saved: {OUTPUT_DIR}/hdfs_ew_results.json")
print(json.dumps(results['metrics'], indent=2))