# Prefix Scoring with Per-Trace FAR (False-Alarm Rate) Calibration

**Outputs:** FAR–Recall curve, timeline plot, lead-time distribution, and a JSON summary (`early_warning.json`)

In [None]:
#=============================================================================
# SETUP
#=============================================================================
import os, json, random
import numpy as np
import pandas as pd
from tqdm import tqdm
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.nn.functional as F

# Run on the GPU when torch can see one; otherwise fall back to the CPU.
use_cuda = torch.cuda.is_available()
device = torch.device('cuda' if use_cuda else 'cpu')
print(f"Device: {device}")

# Filesystem layout: input trace CSV, checkpoint directory, and the directory
# that receives the plots and the JSON summary (created if missing).
BASE = "/teamspace/studios/this_studio"
V1_PATH = f"{BASE}/content/LogHub_HDFS/HDFS_v1/preprocessed/Event_traces.csv"
CKPT_DIR = f"{BASE}/checkpoints"
OUTPUT_DIR = f"{BASE}/output-prefix"
os.makedirs(OUTPUT_DIR, exist_ok=True)

# Model hyper-parameters; they are the constructor defaults of LogBERT below.
# NOTE(review): presumably these must match the values used when the
# checkpoint was trained -- confirm against the training notebook.
CONTEXT_LEN = 128
D_MODEL = 256
N_HEADS = 8
N_LAYERS = 4

# Reserved token ids. Raw event ids are shifted by OFF before being fed to the
# model so they do not collide with the reserved ids.
PAD = 0
CLS = 1
MASK = 2
SEP = 3
OFF = 4

# Seed every random number generator used in this notebook.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

In [None]:
#=============================================================================
# LOAD DATA & MODEL
#=============================================================================
df = pd.read_csv(V1_PATH)


def _parse_events(features):
    """Turn one 'Features' cell such as ``[E5,E22,E11]`` into ``[5, 22, 11]``.

    Surrounding brackets and quotes are removed first, so both ``E5`` and
    ``'E5'`` are accepted. Tokens that do not start with ``E`` (including the
    ``nan`` produced by a missing cell) are ignored.
    """
    events = []
    for token in str(features).strip('[]"').split(','):
        # Strip quotes BEFORE testing the prefix; otherwise a quoted token
        # like 'E5' would be silently discarded.
        token = token.strip().strip("'\"")
        if token.startswith('E'):
            events.append(int(token[1:]))
    return events


# Traces are bucketed by label: 'Success' goes to v1_normal, anything else to
# v1_failure. Traces without any parsable event are skipped.
v1_normal, v1_failure = [], []
if 'Features' in df.columns:
    feature_col = df['Features']
else:
    # Same fallback as a missing per-row value: every trace parses to empty.
    feature_col = ['[]'] * len(df)

# Iterating the two columns directly avoids building a Series per row, which
# is what makes DataFrame.iterrows() slow on large frames.
for features, label in tqdm(zip(feature_col, df['Label']), total=len(df), desc="V1"):
    events = _parse_events(features)
    if events:
        (v1_normal if label == 'Success' else v1_failure).append(events)

def seed_split(seqs, r=0.2):
    """Shuffle ``seqs`` reproducibly and cut it into two lists.

    A fresh ``RandomState`` seeded with ``SEED`` produces the permutation, so
    the result does not depend on the global NumPy RNG state.

    Args:
        seqs: Indexable collection of items to split.
        r: Fraction of items placed in the second list.

    Returns:
        ``(first, second)`` where ``first`` holds ``int(len(seqs) * (1 - r))``
        items and ``second`` holds the rest, both in shuffled order.
    """
    order = np.random.RandomState(SEED).permutation(len(seqs))
    cut = int(len(seqs) * (1 - r))
    first = [seqs[i] for i in order[:cut]]
    second = [seqs[i] for i in order[cut:]]
    return first, second

# Each label group is split independently; the first part is used for
# calibration ("val") and the second is held out ("test").
normal_val, normal_test = seed_split(v1_normal)
failure_val, failure_test = seed_split(v1_failure)

n_normal = len(v1_normal)
n_failure = len(v1_failure)
print(f"Normal: {n_normal:,}, Failure: {n_failure:,}")
print(f"VAL: {len(normal_val)} N | TEST: {len(normal_test)} N, {len(failure_test)} F")

class LogBERT(nn.Module):
    """Transformer encoder over event-token sequences.

    Token and learned position embeddings are summed, passed through dropout
    and a stack of ``nn.TransformerEncoderLayer`` blocks, then projected back
    to the vocabulary. A ``ctr`` buffer (initialised to zeros) holds the
    centre vector that the scoring code compares the first-position hidden
    state against.

    Args:
        vs: Vocabulary size.
        dm: Embedding / model width.
        nh: Number of attention heads.
        nl: Number of encoder layers.
        ml: Number of rows in the position-embedding table (maximum length).
    """

    def __init__(self, vs, dm=D_MODEL, nh=N_HEADS, nl=N_LAYERS, ml=CONTEXT_LEN):
        super().__init__()
        # Attribute names are part of the checkpoint's state-dict keys.
        self.tok = nn.Embedding(vs, dm, padding_idx=PAD)
        self.pos = nn.Embedding(ml, dm)
        self.drop = nn.Dropout(0.1)
        layer = nn.TransformerEncoderLayer(
            d_model=dm,
            nhead=nh,
            dim_feedforward=dm * 4,
            dropout=0.1,
            activation='gelu',
            batch_first=True,
        )
        self.enc = nn.TransformerEncoder(layer, num_layers=nl)
        self.head = nn.Linear(dm, vs)
        self.register_buffer('ctr', torch.zeros(dm))

    def forward(self, ids, mask=None):
        """Encode ``ids`` and return ``(head output, first-position hidden state)``.

        Args:
            ids: Integer token ids; assumed to be (batch, seq) since the
                encoder is built with ``batch_first=True``.
            mask: Optional tensor in which 0 marks positions to be ignored by
                attention. When omitted, no key-padding mask is applied.
        """
        positions = torch.arange(ids.size(1), device=ids.device)
        embedded = self.tok(ids) + self.pos(positions)
        if mask is None:
            pad_mask = None
        else:
            pad_mask = mask == 0
        hidden = self.enc(self.drop(embedded), src_key_padding_mask=pad_mask)
        return self.head(hidden), hidden[:, 0, :]

# Restore the trained model. The vocabulary size is read from the shape of the
# stored token-embedding matrix rather than being hard-coded.
checkpoint_path = f"{CKPT_DIR}/logbert_ep11.pt"
ckpt = torch.load(checkpoint_path, map_location=device)
state_dict = ckpt['model_state_dict']
VOCAB_SIZE = state_dict['tok.weight'].shape[0]

model = LogBERT(VOCAB_SIZE).to(device)
model.load_state_dict(state_dict)
# The centre vector is stored under its own checkpoint key and overrides
# whatever value load_state_dict put into the 'ctr' buffer.
model.ctr = ckpt['center'].to(device)
model.eval()
print(f"✅ Model: vocab={VOCAB_SIZE}")

In [None]:
#=============================================================================
# SCORING + LENGTH NORMALIZATION
#=============================================================================
@torch.no_grad()
def score_prefix(seq, plen):
    """Return the anomaly score of the first ``plen`` events of ``seq``.

    The prefix is truncated to ``CONTEXT_LEN - 2`` events, shifted by ``OFF``
    (clamped to the last vocabulary id), wrapped in ``CLS`` ... ``SEP`` and
    right-padded with ``PAD`` to ``CONTEXT_LEN`` tokens. The score is the
    squared Euclidean distance between the model's first-position hidden
    state and ``model.ctr``.

    Args:
        seq: Sequence of raw integer event ids.
        plen: Number of leading events to score.

    Returns:
        The score as a Python float.
    """
    budget = CONTEXT_LEN - 2  # room left after the CLS and SEP tokens
    events = seq[:plen][:budget]
    body = [min(ev + OFF, VOCAB_SIZE - 1) for ev in events]
    padding = [PAD] * (budget - len(body))
    ids = torch.tensor([[CLS, *body, SEP, *padding]], device=device)
    # NOTE(review): the model is called without a mask, so PAD positions are
    # not excluded from attention -- confirm this matches how it was trained.
    cls_vec = model(ids)[1]
    return ((cls_vec - model.ctr) ** 2).sum().item()

# Prefix lengths at which the score distribution of normal traces is measured.
# Lengths that are not in this list get no entry in mu_k / sigma_k.
PREFIX_LENGTHS = list(range(10, CONTEXT_LEN+1, 10))
normal_sample = normal_val[:2000]

# Per-length mean and standard deviation of the raw score over normal traces
# that are at least that long.
mu_k, sigma_k = {}, {}
for k in tqdm(PREFIX_LENGTHS, desc="μ,σ"):
    scores = [score_prefix(seq, k) for seq in normal_sample if len(seq) >= k]
    if scores:
        mu_k[k] = np.mean(scores)
        # The small epsilon keeps the later division away from zero.
        sigma_k[k] = np.std(scores) + 1e-6
    else:
        # No sufficiently long trace: identity normalisation.
        mu_k[k] = 0
        sigma_k[k] = 1

def z_score(s, k):
    """Normalise raw score ``s`` by the normal-trace statistics for length ``k``.

    ``mu_k`` / ``sigma_k`` only contain the prefix lengths that were
    calibrated. When ``k`` is not one of them, the statistics of the nearest
    calibrated length are used (ties go to the shorter length). Without this,
    an uncalibrated ``k`` would fall back to mean 0 / std 1 and return the raw
    score, which is not comparable with a z-score threshold.

    Args:
        s: Raw score of a prefix.
        k: Prefix length the score was computed at.

    Returns:
        ``(s - mean) / std`` for the selected calibration length. If nothing
        has been calibrated at all, ``s`` is returned unchanged.
    """
    ref = k
    if ref not in mu_k and mu_k:
        # Sort key (distance, length) makes the tie-break deterministic.
        ref = min(mu_k, key=lambda cal: (abs(cal - k), cal))
    return (s - mu_k.get(ref, 0)) / sigma_k.get(ref, 1)

def trace_score(seq, step=10):
    """Return the largest prefix z-score of ``seq``.

    Prefix lengths 10, 10 + step, ... up to ``min(len(seq), CONTEXT_LEN)`` are
    scored with ``score_prefix`` and normalised with ``z_score``.

    Args:
        seq: Sequence of raw integer event ids.
        step: Spacing between the evaluated prefix lengths.

    Returns:
        The maximum z-score, or 0 when the trace is shorter than 10 events.
    """
    longest = min(len(seq), CONTEXT_LEN)
    return max(
        (z_score(score_prefix(seq, k), k) for k in range(10, longest + 1, step)),
        default=0,
    )

# One score per trace: normal traces come from the calibration sample,
# failure traces from the held-out test split.
print("Computing trace_scores...")
normal_trace_scores = np.array(list(map(trace_score, tqdm(normal_sample, desc="Normal"))))
failure_trace_scores = np.array(list(map(trace_score, tqdm(failure_test, desc="Failure"))))

normal_mean = normal_trace_scores.mean()
failure_mean = failure_trace_scores.mean()
print(f"Normal: mean={normal_mean:.4f}, Failure: mean={failure_mean:.4f}")

In [None]:
#=============================================================================
# FAR-RECALL CURVE (Key visualization)
#=============================================================================
# False-alarm rates at which the detector is evaluated.
FAR_VALUES = [0.001, 0.005, 0.01, 0.02, 0.05, 0.10]  # 0.1% to 10%

# For each target FAR the threshold is the (1 - FAR) quantile of the normal
# trace scores; recall is the share of failure traces strictly above it.
far_recall_data = []
for far in FAR_VALUES:
    th = np.quantile(normal_trace_scores, 1 - far)
    recall = (failure_trace_scores > th).mean()
    far_recall_data.append(dict(far=far, threshold=th, recall=recall))
    print(f"FAR={far*100:5.1f}% → TH={th:.3f} → Recall={recall:.1%}")

# Both axes are plotted in percent.
fars = [100 * point['far'] for point in far_recall_data]
recalls = [100 * point['recall'] for point in far_recall_data]

fig, ax = plt.subplots(figsize=(10, 6))
ax.plot(fars, recalls, 'o-', color='steelblue', linewidth=2, markersize=10)

# Label every point with its recall value.
for far_pct, recall_pct in zip(fars, recalls):
    ax.annotate(
        f"{recall_pct:.0f}%",
        (far_pct, recall_pct),
        textcoords="offset points",
        xytext=(5, 10),
        fontsize=10,
    )

# Vertical marker at the FAR = 1% operating point.
ax.axvline(x=1, color='red', linestyle='--', alpha=0.5, label='Operating point (FAR=1%)')

ax.set_title('FAR–Recall Trade-off (Prefix Scoring)', fontsize=14, fontweight='bold')
ax.set_xlabel('False Alarm Rate (%)', fontsize=12)
ax.set_ylabel('Recall (%)', fontsize=12)

# Log-scaled x axis, but with plain-number ticks at the evaluated FAR values.
ax.set_xscale('log')
ax.set_xticks([0.1, 0.5, 1, 2, 5, 10])
ax.xaxis.set_major_formatter(plt.ScalarFormatter())
ax.legend()
ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig(f'{OUTPUT_DIR}/far_recall_curve.png', dpi=150, bbox_inches='tight')
plt.show()
print(f"✅ Saved: {OUTPUT_DIR}/far_recall_curve.png")

In [None]:
#=============================================================================
# EARLY WARNING AT OPERATING POINT (FAR=1%)
#=============================================================================
# Operating point: the alarm threshold is the (1 - FAR_TARGET) quantile of the
# normal trace scores.
FAR_TARGET = 0.01
TH = np.quantile(normal_trace_scores, 1 - FAR_TARGET)

far_target_pct = FAR_TARGET*100
print(f"Operating point: FAR={far_target_pct}%, TH={TH:.4f}")

def early_detect(seq, th, step=5):
    """Find the shortest evaluated prefix of ``seq`` whose z-score exceeds ``th``.

    Prefix lengths 10, 10 + step, ... up to ``min(len(seq), CONTEXT_LEN)`` are
    checked in increasing order and the search stops at the first hit.

    Args:
        seq: Sequence of raw integer event ids.
        th: Alarm threshold; a prefix fires only when its z-score is strictly
            greater.
        step: Spacing between the evaluated prefix lengths.

    Returns:
        ``(alarm_length, len(seq))``; ``alarm_length`` is ``None`` when no
        evaluated prefix crosses the threshold.
    """
    total = len(seq)
    limit = min(total, CONTEXT_LEN)
    # The generator is lazy, so scoring stops as soon as a prefix fires.
    alarm = next(
        (k for k in range(10, limit + 1, step)
         if z_score(score_prefix(seq, k), k) > th),
        None,
    )
    return alarm, total

# Run the early-warning detector on every held-out failure trace and record
# where (if anywhere) the alarm fired.
print(f"\nEarly warning on TEST failures ({len(failure_test)})...")
ew_results = []
for seq in tqdm(failure_test):
    alarm, total = early_detect(seq, TH)
    # Lead = number of events between the alarm and the end of the trace.
    # Missed traces are recorded with a lead of 0.
    lead = (total - alarm) if alarm is not None else 0
    ew_results.append({'alarm': alarm, 'total': total, 'lead': lead, 'seq': seq})

# "No alarm" is represented by None, so test for that explicitly.
detected = [r for r in ew_results if r['alarm'] is not None]
# Guard against an empty failure set instead of raising ZeroDivisionError.
recall = len(detected) / len(ew_results) if ew_results else 0.0
leads = [r['lead'] for r in detected]

print(f"\n=== EARLY WARNING @ FAR={FAR_TARGET*100}% ===")
print(f"Recall: {len(detected)}/{len(ew_results)} ({recall:.1%})")
if leads:
    print(f"Lead: median={np.median(leads):.0f}, mean={np.mean(leads):.0f}")
    print(f"      p25={np.percentile(leads,25):.0f}, p75={np.percentile(leads,75):.0f}")

In [None]:
#=============================================================================
# TIMELINE PLOT
#=============================================================================
def get_all_z_scores(seq, step=5):
    """Return the z-score curve of ``seq`` as ``(prefix_length, z_score)`` pairs.

    Prefix lengths 10, 10 + step, ... up to ``min(len(seq), CONTEXT_LEN)`` are
    evaluated; the list is empty for traces shorter than 10 events.
    """
    upper = min(len(seq), CONTEXT_LEN)
    curve = []
    for k in range(10, upper + 1, step):
        curve.append((k, z_score(score_prefix(seq, k), k)))
    return curve

# Pick up to four example traces: two detected failures, one missed failure,
# and the first held-out normal trace. Each entry is (label, trace, alarm).
missed = [r for r in ew_results if not r['alarm']]
sample_traces = [('Failure (detected)', r['seq'], r['alarm']) for r in detected[:2]]
sample_traces += [('Failure (missed)', r['seq'], None) for r in missed[:1]]
sample_traces.append(('Normal', normal_test[0], None))

# One panel per example, filled row by row; unused panels stay empty.
fig, axes = plt.subplots(2, 2, figsize=(14, 10))
for ax, (label, seq, alarm_pos) in zip(axes.flat, sample_traces):
    zs = get_all_z_scores(seq, step=3)
    x = [k for k, _ in zs]
    y = [z for _, z in zs]
    color = 'red' if 'Failure' in label else 'blue'
    ax.plot(x, y, color=color, linewidth=2)
    # Horizontal line: alarm threshold. Vertical line: where the alarm fired.
    ax.axhline(y=TH, color='green', linestyle='--', linewidth=2, label=f'TH={TH:.2f}')
    if alarm_pos:
        ax.axvline(x=alarm_pos, color='orange', linestyle=':', linewidth=2, label=f'Alarm@{alarm_pos}')
    ax.set_title(f'{label} (len={len(seq)})', fontsize=12, fontweight='bold')
    ax.set_xlabel('Prefix length')
    ax.set_ylabel('Z-score')
    ax.legend(loc='upper left')
    ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig(f'{OUTPUT_DIR}/timeline_plot.png', dpi=150, bbox_inches='tight')
plt.show()
print(f"✅ Saved: {OUTPUT_DIR}/timeline_plot.png")

In [None]:
#=============================================================================
# LEAD DISTRIBUTION
#=============================================================================
# Histogram of lead (events remaining after the alarm) over detected failures.
# Skipped entirely when nothing was detected.
if leads:
    lead_median = np.median(leads)
    lead_mean = np.mean(leads)

    fig, ax = plt.subplots(figsize=(10, 5))
    ax.hist(leads, bins=30, color='steelblue', edgecolor='white', alpha=0.8)
    ax.axvline(x=lead_median, color='red', linestyle='--', lw=2, label=f'Median={lead_median:.0f}')
    ax.axvline(x=lead_mean, color='orange', linestyle='--', lw=2, label=f'Mean={lead_mean:.0f}')
    ax.set_title('Lead-Time Distribution', fontsize=14, fontweight='bold')
    ax.set_xlabel('Lead events')
    ax.set_ylabel('Count')
    ax.legend()
    ax.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.savefig(f'{OUTPUT_DIR}/lead_distribution.png', dpi=150)
    plt.show()
    print(f"✅ Saved: {OUTPUT_DIR}/lead_distribution.png")

In [None]:
#=============================================================================
# SAVE RESULTS
#=============================================================================
# Empirical false positives: run the early-warning detector on up to 500
# held-out normal traces and count how many raise an alarm.
fp_sample = normal_test[:500]
# Use the real sample size so the rate stays correct when fewer than 500
# normal test traces are available.
fp_tested = len(fp_sample)
fp = sum(1 for seq in tqdm(fp_sample, desc="FP") if early_detect(seq, TH)[0])
fp_rate = fp / fp_tested if fp_tested else 0.0

# Lead statistics are computed once and reused for the JSON and the printout;
# all zeros when no failure was detected.
if leads:
    lead_stats = {
        'mean': float(np.mean(leads)),
        'median': float(np.median(leads)),
        'p25': float(np.percentile(leads, 25)),
        'p75': float(np.percentile(leads, 75)),
    }
else:
    lead_stats = {'mean': 0, 'median': 0, 'p25': 0, 'p75': 0}

results = {
    'method': 'per_trace_far_calibration',
    'operating_point': {'far_target': FAR_TARGET, 'threshold': float(TH)},
    'far_recall_curve': far_recall_data,
    'early_warning': {
        'tested': len(ew_results), 'detected': len(detected), 'recall': float(recall),
        'lead_events': lead_stats,
    },
    'false_positive': {'tested': fp_tested, 'fp_count': fp, 'fp_rate': float(fp_rate)}
}

with open(f'{OUTPUT_DIR}/early_warning.json', 'w') as f:
    json.dump(results, f, indent=2)

print('\n' + '='*60)
print('FINAL RESULTS')
print('='*60)
print(f"FAR-Recall curve: {len(far_recall_data)} points")
print(f"Recall @ FAR={FAR_TARGET*100}%: {recall:.1%}")
if leads:
    print(f"Lead: median={lead_stats['median']:.0f}, mean={lead_stats['mean']:.0f}")
else:
    # Avoid printing NaN (with NumPy warnings) when nothing was detected.
    print("Lead: n/a (no failures detected)")
print(f"FP: {fp}/{fp_tested} ({fp_rate:.1%})")
print(f"\n✅ Saved: {OUTPUT_DIR}/early_warning.json")