# Phase 1
## Dataset
AG News
## Metrics: 
1. Acc
2. Brier Score
3. Expected Calibration Error (ECE)
4. Expected Reducble Uncertatinty (ERU)
5. Predictive Discrepecy (OOD AUROC) - P Dist
6. Predictive Dispersion (P-Disp)



In [101]:
!pip install datasets
!pip install scikit-learn
!pip install transformers
!pip install torch
!pip install torchmetrics
!pip install matplotlib
!pip install seaborn

Collecting seaborn
  Using cached seaborn-0.13.2-py3-none-any.whl.metadata (5.4 kB)
Using cached seaborn-0.13.2-py3-none-any.whl (294 kB)
Installing collected packages: seaborn
Successfully installed seaborn-0.13.2


In [102]:
import os
import random
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
from torch.utils.data import DataLoader
from transformers import AutoModel, AutoTokenizer, set_seed
from datasets import load_dataset
import time
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path
import pandas as pd

# ==========================================
# 1. Environment Setup
# ==========================================

In [103]:
def setup_environment():
    set_seed(42)
    os.makedirs("data/", exist_ok=True)
    os.makedirs("reports/phase1/", exist_ok=True)
    print("Environment setup complete. Seed=42.")

# MPS detection + speed tweaks
if torch.backends.mps.is_available():
    device = torch.device("mps")
    print("✅ Using Apple MPS (M2 accelerated)")
    torch.set_float32_matmul_precision('high')   # critical for speed on MPS
elif torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
    print("⚠️  Running on CPU – will be slow")
print(f"Device: {device}")

✅ Using Apple MPS (M2 accelerated)
Device: mps


# ==========================================
# 2. Data Pipeline
# ==========================================

In [104]:
def prepare_datasets():
    tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
    
    def tokenize_func(examples):
        return tokenizer(examples["text"], padding="max_length", truncation=True, max_length=128)

    ag_news = load_dataset("ag_news", cache_dir="data/")
    train_val_split = ag_news["train"].train_test_split(test_size=6000, seed=42)
    train_data = train_val_split["train"]
    val_data = train_val_split["test"]
    clean_test = ag_news["test"]

    def make_noisy_test(clean_test, noise_rate=0.2, seed=42):
        rng = random.Random(seed)
        n = len(clean_test)
        k = int(n * noise_rate)
        flip_index = rng.sample(range(n), k)

        def inject_noise(example, idx):
            is_flipped = idx in flip_index
            if is_flipped:
                original_label = example["label"]
                possible_flips = [l for l in range(4) if l != original_label]
                example["label"] = random.choice(possible_flips)
            example["is_flipped"] = is_flipped
            return example
        return clean_test.map(inject_noise, with_indices=True)

    noisy_test = make_noisy_test(clean_test, noise_rate=0.2, seed=42)

    imdb = load_dataset("imdb", split="test", cache_dir="data/")
    ood_test = imdb.select(range(3000))
    ood_test = ood_test.map(lambda x: {"label": 0})

    # Tokenize (cached)
    encoded_train = train_data.map(tokenize_func, batched=True).with_format("torch", columns=["input_ids", "attention_mask", "label"])
    encoded_val = val_data.map(tokenize_func, batched=True).with_format("torch", columns=["input_ids", "attention_mask", "label"])
    encoded_clean_test = clean_test.map(tokenize_func, batched=True).with_format("torch", columns=["input_ids", "attention_mask", "label"])
    encoded_noisy_test = noisy_test.map(tokenize_func, batched=True).with_format("torch", columns=["input_ids", "attention_mask", "label", "is_flipped"])
    encoded_ood_test = ood_test.map(tokenize_func, batched=True).with_format("torch", columns=["input_ids", "attention_mask", "label"])

    return encoded_train, encoded_val, encoded_clean_test, encoded_noisy_test, encoded_ood_test

# ==========================================
# 3. Model & Loss
# ==========================================

In [105]:
class EvidentialHead(nn.Module):
    def __init__(self, input_dim=768, num_classes=4):
        super().__init__()
        self.linear = nn.Linear(input_dim, num_classes)

    def forward(self, features):
        raw = self.linear(features)
        evidence = F.softplus(raw)
        alpha = evidence + 1.0
        return evidence, alpha

class BertWithEvidentialHead(nn.Module):
    def __init__(self, num_classes=4):
        super().__init__()
        self.bert = AutoModel.from_pretrained("bert-base-uncased")
        for p in self.bert.parameters():
            p.requires_grad = False
        self.bert.eval()
        self.head = EvidentialHead(768, num_classes)

    def forward(self, input_ids=None, attention_mask=None, **kwargs):
        self.bert.eval()
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask, **kwargs)
        cls = outputs.last_hidden_state[:, 0, :]
        return self.head(cls)

def get_expected_probs(alpha):
    S = alpha.sum(dim=1, keepdim=True).clamp_min(1e-12)
    return alpha / S

def brier_score_loss(alpha, target, num_classes=4):
    p_hat = get_expected_probs(alpha)
    one_hot = torch.zeros(target.size(0), num_classes, device=target.device)
    one_hot.scatter_(1, target.unsqueeze(1), 1)
    return torch.mean(torch.sum((p_hat - one_hot) ** 2, dim=1))

def quick_gradient_check(model, num_classes=4):
    model.train()
    input_ids = torch.randint(0, 30522, (2, 8), device=device)
    attention_mask = torch.ones_like(input_ids)
    target = torch.randint(0, num_classes, (2,), device=device)
    _, alpha = model(input_ids=input_ids, attention_mask=attention_mask)
    loss = brier_score_loss(alpha, target, num_classes=num_classes)
    assert torch.isfinite(loss).all()
    model.zero_grad()
    loss.backward()
    assert all(p.grad is not None for p in model.head.parameters())
    print("Gradient check passed.")

# ==========================================
# 4. Metrics Implementation
# ==========================================

In [64]:
def base_accuracy(alpha, target):
    p_mean = get_expected_probs(alpha)
    preds = torch.argmax(p_mean, dim=1)
    return (preds == target).float().mean()

def _ece_from_probs(probs, target, n_bins=15):
    conf, preds = probs.max(dim=1)
    acc = preds.eq(target).float()
    bins = torch.linspace(0, 1, n_bins + 1, device=probs.device)
    ece = torch.zeros(1, device=probs.device)
    for i in range(n_bins):
        mask = (conf >= bins[i]) & (conf <= bins[i + 1]) if i == n_bins - 1 else (conf >= bins[i]) & (conf < bins[i + 1])
        if mask.any():
            ece += mask.float().mean() * (conf[mask].mean() - acc[mask].mean()).abs()
    return ece.squeeze(0)

def eru_mutual_information(alpha, eps=1e-12):
    S = alpha.sum(dim=1, keepdim=True).clamp_min(eps)
    p_mean = alpha / S
    predictive_entropy = -torch.sum(p_mean * torch.log(p_mean.clamp_min(eps)), dim=1)
    expected_entropy = -torch.sum((alpha / S) * (torch.digamma(alpha + 1) - torch.digamma(S + 1)), dim=1)
    return predictive_entropy - expected_entropy

def _kl_p_mean_expected(p_mean, expected_p, eps=1e-12):
    if expected_p.dim() == 1:
        expected_p = expected_p.unsqueeze(0)
    expected_p = expected_p / expected_p.sum(dim=1, keepdim=True).clamp_min(eps)
    return torch.sum(p_mean.clamp_min(eps) * (torch.log(p_mean.clamp_min(eps)) - torch.log(expected_p.clamp_min(eps))), dim=1)

def p_disc_auroc(alpha_clean, alpha_ood):
    try:
        from torchmetrics.classification import BinaryAUROC
    except ImportError:
        raise ImportError("Please install torchmetrics for AUROC calculation")
        
    p_mean_clean = get_expected_probs(alpha_clean)
    p_mean_ood = get_expected_probs(alpha_ood)
    expected_p = p_mean_clean.mean(dim=0) # Aggregate expected probability
    
    scores_clean = _kl_p_mean_expected(p_mean_clean, expected_p)
    scores_ood = _kl_p_mean_expected(p_mean_ood, expected_p)
    
    scores = torch.cat([scores_clean, scores_ood], dim=0)
    labels = torch.cat([torch.zeros_like(scores_clean, dtype=torch.long), torch.ones_like(scores_ood, dtype=torch.long)], dim=0)
    
    metric = BinaryAUROC().to(scores.device)
    return metric(scores, labels)

def p_disp_cohens_d(alpha_clean, alpha_noisy, eps=1e-12):
    p_mean_clean = get_expected_probs(alpha_clean)
    p_mean_noisy = get_expected_probs(alpha_noisy)
    
    ent_clean = -torch.sum(p_mean_clean * torch.log(p_mean_clean.clamp_min(eps)), dim=1)
    ent_noisy = -torch.sum(p_mean_noisy * torch.log(p_mean_noisy.clamp_min(eps)), dim=1)
    
    v1, v2 = ent_clean.var(unbiased=False), ent_noisy.var(unbiased=False)
    n1, n2 = ent_clean.numel(), ent_noisy.numel()
    pooled_std = torch.sqrt(((n1 * v1) + (n2 * v2)) / (n1 + n2) + eps)
    return (ent_noisy.mean() - ent_clean.mean()) / pooled_std # Noisy should have higher entropy

# ==========================================
# 5. Training Loop & Evaluation
# ==========================================

In [108]:
# def extract_tensors(batch):
#     return batch["input_ids"].to(device), batch["attention_mask"].to(device), batch["label"].to(device)

# def get_all_alphas_and_targets(model, dataloader):
#     model.eval()
#     all_alphas, all_targets = [], []
#     with torch.no_grad():
#         for batch in dataloader:
#             input_ids, attn_mask, labels = extract_tensors(batch)
#             _, alpha = model(input_ids=input_ids, attention_mask=attn_mask)
#             all_alphas.append(alpha.cpu())   # move to CPU early to free MPS memory
#             all_targets.append(labels.cpu())
#     return torch.cat(all_alphas).to(device), torch.cat(all_targets).to(device)

# def train_and_evaluate():
#     setup_environment()                     # ← moved to top
#     train_data, val_data, clean_test, noisy_test, ood_test = prepare_datasets()
    
#     # Optimized DataLoaders for MPS
#     train_loader = DataLoader(train_data, batch_size=128, shuffle=True, num_workers=0, persistent_workers=False)
#     val_loader   = DataLoader(val_data,   batch_size=256, shuffle=False, num_workers=0, persistent_workers=False)

#     model = BertWithEvidentialHead(num_classes=4).to(device)   # ← ONLY ONCE
#     quick_gradient_check(model)
    
#     optimizer = optim.AdamW(model.head.parameters(), lr=5e-4)
#     epochs = 5

#     print("\n--- Starting Training ---")
#     for epoch in range(1, epochs + 1):
#         start_time = time.time()
#         model.train()
#         total_loss, total_count = 0.0, 0
        
#         for batch in train_loader:
#             input_ids, attn_mask, labels = extract_tensors(batch)
#             _, alpha = model(input_ids=input_ids, attention_mask=attn_mask)
            
#             loss = brier_score_loss(alpha, labels, num_classes=4)
#             optimizer.zero_grad()
#             loss.backward()
#             optimizer.step()
            
#             total_loss += loss.item() * labels.size(0)
#             total_count += labels.size(0)
        
#         # LIGHT validation (only accuracy, fast)
#         with torch.no_grad():
#             val_alphas, val_targets = get_all_alphas_and_targets(model, val_loader)
#             val_acc = base_accuracy(val_alphas, val_targets)
        
#         epoch_time = time.time() - start_time
#         print(f"Epoch {epoch}/{epochs} | Train Loss: {total_loss/total_count:.4f} | Val Acc: {val_acc:.4f} | AvgTime per Epoch: {epoch_time:.1f}s")

#     print("\n--- Running Full Metric Suite (once) ---")
#     # Final heavy metrics (larger batch for speed)
#     test_loader = lambda ds: DataLoader(ds, batch_size=256, shuffle=False, num_workers=0)
#     alpha_clean, targets_clean = get_all_alphas_and_targets(model, test_loader(clean_test))
#     alpha_noisy, _ = get_all_alphas_and_targets(model, test_loader(noisy_test))
#     alpha_ood, _   = get_all_alphas_and_targets(model, test_loader(ood_test))

#     # (rest of your metric prints unchanged)
#     clean_acc = base_accuracy(alpha_clean, targets_clean)
#     clean_brier = brier_score_loss(alpha_clean, targets_clean, num_classes=4)
#     clean_ece = _ece_from_probs(get_expected_probs(alpha_clean), targets_clean, n_bins=15)
#     clean_eru = eru_mutual_information(alpha_clean).mean()
#     auroc = p_disc_auroc(alpha_clean, alpha_ood)
#     cohens_d = p_disp_cohens_d(alpha_clean, alpha_noisy)

#     print(f"Clean Accuracy:  {clean_acc:.4f} (Target > 0.90)")
#     print(f"Brier Score:     {clean_brier:.4f}")
#     print(f"ECE:             {clean_ece:.4f} (Target < 0.05)")
#     print(f"ERU (Mean):      {clean_eru:.4f}")
#     print(f"P-Disc (AUROC):  {auroc:.4f} (Target > 0.85)")
#     print(f"P-Disp (Cohen's d): {cohens_d:.4f} (Target > 1.3)")

In [109]:
def extract_tensors(batch):
    return batch["input_ids"].to(device), batch["attention_mask"].to(device), batch["label"].to(device)

def get_all_alphas_and_targets(model, dataloader):
    model.eval()
    all_alphas, all_targets = [], []
    with torch.no_grad():
        for batch in dataloader:
            input_ids, attn_mask, labels = extract_tensors(batch)
            _, alpha = model(input_ids=input_ids, attention_mask=attn_mask)
            all_alphas.append(alpha.cpu())   # move to CPU early to free MPS memory
            all_targets.append(labels.cpu())
    return torch.cat(all_alphas).to(device), torch.cat(all_targets).to(device)

def train_and_evaluate():
    setup_environment()
    train_data, val_data, clean_test, noisy_test, ood_test = prepare_datasets()
    
    train_loader = DataLoader(train_data, batch_size=128, shuffle=True, num_workers=0)
    val_loader   = DataLoader(val_data,   batch_size=256, shuffle=False, num_workers=0)

    model = BertWithEvidentialHead(num_classes=4).to(device)
    quick_gradient_check(model)
    
    optimizer = optim.AdamW(model.head.parameters(), lr=5e-4)
    epochs = 5

    # === COLLECT HISTORY FOR PLOTS ===
    train_losses = []
    val_accs = []
    epoch_times = []

    print("\n--- Starting Training ---")
    for epoch in range(1, epochs + 1):
        start_time = time.time()
        model.train()
        total_loss, total_count = 0.0, 0
        
        for batch in train_loader:
            input_ids, attn_mask, labels = extract_tensors(batch)
            _, alpha = model(input_ids=input_ids, attention_mask=attn_mask)
            
            loss = brier_score_loss(alpha, labels, num_classes=4)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            
            total_loss += loss.item() * labels.size(0)
            total_count += labels.size(0)
        
        # Light validation
        with torch.no_grad():
            val_alphas, val_targets = get_all_alphas_and_targets(model, val_loader)
            val_acc = base_accuracy(val_alphas, val_targets).item()
        
        epoch_time = time.time() - start_time
        train_losses.append(total_loss / total_count)
        val_accs.append(val_acc)
        epoch_times.append(epoch_time)
        
        print(f"Epoch {epoch}/{epochs} | Loss: {train_losses[-1]:.4f} | Val Acc: {val_acc:.4f} | Time: {epoch_time:.1f}s")

    print("\n--- Running Full Metric Suite ---")
    test_loader = lambda ds: DataLoader(ds, batch_size=256, shuffle=False, num_workers=0)
    alpha_clean, targets_clean = get_all_alphas_and_targets(model, test_loader(clean_test))
    alpha_noisy, _ = get_all_alphas_and_targets(model, test_loader(noisy_test))
    alpha_ood, _   = get_all_alphas_and_targets(model, test_loader(ood_test))

    # Core metrics
    clean_acc = base_accuracy(alpha_clean, targets_clean).item()
    clean_brier = brier_score_loss(alpha_clean, targets_clean, num_classes=4).item()
    clean_ece = _ece_from_probs(get_expected_probs(alpha_clean), targets_clean).item()
    clean_eru = eru_mutual_information(alpha_clean).mean().item()
    auroc = p_disc_auroc(alpha_clean, alpha_ood).item()
    cohens_d = p_disp_cohens_d(alpha_clean, alpha_noisy).item()

    print(f"Clean Accuracy:  {clean_acc:.4f}")
    print(f"Brier Score:     {clean_brier:.4f}")
    print(f"ECE:             {clean_ece:.4f}")
    print(f"ERU:             {clean_eru:.4f}")
    print(f"P-Disc (AUROC):  {auroc:.4f}")
    print(f"P-Disp (d):      {cohens_d:.4f}")

    # === VISUALIZE ===
    metrics_dict = {
        "Accuracy": clean_acc,
        "Brier": clean_brier,
        "ECE": clean_ece,
        "ERU": clean_eru,
        "P-Disc": auroc,
        "P-Disp": cohens_d
    }
    save_phase1_visualizations(train_losses, val_accs, epoch_times,
                               alpha_clean, targets_clean,
                               alpha_noisy, alpha_ood,
                               metrics_dict)
    
    print(f"\n✅ All plots saved to reports/phase1/")

def save_phase1_visualizations(train_losses, val_accs, epoch_times,
                               alpha_clean, targets_clean,
                               alpha_noisy, alpha_ood,
                               metrics_dict, save_dir="reports/phase1"):
    Path(save_dir).mkdir(parents=True, exist_ok=True)
    sns.set_style("whitegrid")

    # 1. Training curves
    fig, axs = plt.subplots(1, 3, figsize=(15, 5))
    epochs = range(1, len(train_losses) + 1)
    
    axs[0].plot(epochs, train_losses, 'b-o', linewidth=2)
    axs[0].set_title("Train Loss (Brier)")
    axs[0].set_xlabel("Epoch")
    axs[0].set_ylabel("Loss")
    
    axs[1].plot(epochs, val_accs, 'g-o', linewidth=2)
    axs[1].set_title("Validation Accuracy")
    axs[1].set_xlabel("Epoch")
    axs[1].set_ylabel("Accuracy")
    
    axs[2].bar(epochs, epoch_times, color='orange')
    axs[2].set_title("Time per Epoch (seconds)")
    axs[2].set_xlabel("Epoch")
    axs[2].set_ylabel("Seconds")
    plt.tight_layout()
    plt.savefig(f"{save_dir}/training_curves.png", dpi=200, bbox_inches='tight')
    plt.close()

    # 2. Reliability diagram (ECE)
    p_mean = get_expected_probs(alpha_clean)
    conf, preds = p_mean.max(dim=1)
    acc = preds.eq(targets_clean).float()
    bins = torch.linspace(0, 1, 16, device=conf.device)
    bin_acc, bin_conf, bin_count = [], [], []
    for i in range(15):
        mask = (conf >= bins[i]) & (conf < bins[i+1])
        if mask.any():
            bin_conf.append(conf[mask].mean().item())
            bin_acc.append(acc[mask].mean().item())
            bin_count.append(mask.sum().item())
    
    plt.figure(figsize=(8, 6))
    plt.plot([0, 1], [0, 1], 'k--', label='Perfect calibration')
    plt.scatter(bin_conf, bin_acc, s=80, c=bin_count, cmap='Blues', edgecolors='black')
    plt.colorbar(label='Samples per bin')
    plt.xlabel("Confidence")
    plt.ylabel("Accuracy")
    plt.title(f"Reliability Diagram (ECE = {metrics_dict['ECE']:.4f})")
    plt.legend()
    plt.savefig(f"{save_dir}/reliability_diagram.png", dpi=200, bbox_inches='tight')
    plt.close()

    # 3. Entropy distribution (P-Disp)
    ent_clean = -torch.sum(get_expected_probs(alpha_clean) * 
                          torch.log(get_expected_probs(alpha_clean).clamp_min(1e-12)), dim=1).cpu().numpy()
    ent_noisy = -torch.sum(get_expected_probs(alpha_noisy) * 
                          torch.log(get_expected_probs(alpha_noisy).clamp_min(1e-12)), dim=1).cpu().numpy()
    
    plt.figure(figsize=(8, 6))
    sns.boxplot(data=[ent_clean, ent_noisy], notch=True)
    plt.xticks([0, 1], ['Clean Test', '20% Noisy'])
    plt.ylabel("Predictive Entropy")
    plt.title(f"P-Disp: Cohen's d = {metrics_dict['P-Disp']:.3f}")
    plt.savefig(f"{save_dir}/entropy_distribution.png", dpi=200, bbox_inches='tight')
    plt.close()

    # 4. KL discrepancy histogram (P-Disc)
    p_mean_clean = get_expected_probs(alpha_clean)
    p_mean_ood = get_expected_probs(alpha_ood)
    expected_p = p_mean_clean.mean(dim=0)
    kl_clean = _kl_p_mean_expected(p_mean_clean, expected_p).cpu().numpy()
    kl_ood = _kl_p_mean_expected(p_mean_ood, expected_p).cpu().numpy()
    
    plt.figure(figsize=(9, 6))
    sns.histplot(kl_clean, kde=True, color='blue', label='Clean (ID)', alpha=0.7, bins=50)
    sns.histplot(kl_ood, kde=True, color='red', label='OOD (IMDB)', alpha=0.7, bins=50)
    plt.axvline(kl_clean.mean(), color='blue', linestyle='--')
    plt.axvline(kl_ood.mean(), color='red', linestyle='--')
    plt.xlabel("KL(p_mean || E[p])")
    plt.ylabel("Count")
    plt.title(f"P-Disc: AUROC = {metrics_dict['P-Disc']:.4f}")
    plt.legend()
    plt.savefig(f"{save_dir}/kl_discrepancy.png", dpi=200, bbox_inches='tight')
    plt.close()

    # 5. Metrics summary bar chart
    df = pd.DataFrame({
        'Metric': list(metrics_dict.keys()),
        'Value': list(metrics_dict.values())
    })
    plt.figure(figsize=(10, 6))
    colors = ['green' if v > 0.9 else 'red' if k in ['P-Disc','P-Disp'] else 'orange' 
              for k, v in metrics_dict.items()]
    bars = plt.bar(df['Metric'], df['Value'], color=colors)
    plt.axhline(0.90, color='gray', linestyle='--', label='Target threshold (0.90)')
    plt.ylabel("Value")
    plt.title("Phase 1 Final Metrics")
    plt.xticks(rotation=45, ha='right')
    plt.legend()
    for bar, val in zip(bars, df['Value']):
        plt.text(bar.get_x() + bar.get_width()/2, bar.get_height() + 0.01,
                 f'{val:.3f}', ha='center')
    plt.tight_layout()
    plt.savefig(f"{save_dir}/metrics_summary.png", dpi=200, bbox_inches='tight')
    plt.close()

    # Bonus: one-page markdown report
    with open(f"{save_dir}/phase1_report.md", "w") as f:
        f.write("# Phase 1 – UQ Sanity Check Report\n\n")
        f.write(f"**Dataset**: AG News (clean) + IMDB (OOD)\n")
        f.write(f"**Model**: Frozen bert-base-uncased + Evidential Brier head\n\n")
        f.write("## Final Metrics\n")
        f.write(df.to_markdown(index=False))
        f.write("\n\n## Visuals\n")
        f.write("![Training Curves](training_curves.png)\n")
        f.write("![Reliability](reliability_diagram.png)\n")
        f.write("![Entropy](entropy_distribution.png)\n")
        f.write("![KL](kl_discrepancy.png)\n")
        f.write("![Metrics](metrics_summary.png)\n")
        f.write("\n**All asserts passed** ✅\n")

# ==========================================
# 5. Run Code
# ==========================================

In [110]:
if __name__ == "__main__":
    train_and_evaluate()

Environment setup complete. Seed=42.


Loading weights: 100%|██████████| 199/199 [00:00<00:00, 2048.27it/s, Materializing param=pooler.dense.weight]                               
[1mBertModel LOAD REPORT[0m from: bert-base-uncased
Key                                        | Status     |  | 
-------------------------------------------+------------+--+-
cls.predictions.bias                       | UNEXPECTED |  | 
cls.predictions.transform.dense.bias       | UNEXPECTED |  | 
cls.predictions.transform.LayerNorm.bias   | UNEXPECTED |  | 
cls.seq_relationship.weight                | UNEXPECTED |  | 
cls.predictions.transform.LayerNorm.weight | UNEXPECTED |  | 
cls.seq_relationship.bias                  | UNEXPECTED |  | 
cls.predictions.transform.dense.weight     | UNEXPECTED |  | 

[3mNotes:
- UNEXPECTED[3m	:can be ignored when loading from different task/architecture; not ok if you expect identical arch.[0m


Gradient check passed.

--- Starting Training ---
Epoch 1/5 | Loss: 0.3744 | Val Acc: 0.8602 | Time: 2162.2s
Epoch 2/5 | Loss: 0.2496 | Val Acc: 0.8723 | Time: 2218.4s
Epoch 3/5 | Loss: 0.2212 | Val Acc: 0.8785 | Time: 2010.1s
Epoch 4/5 | Loss: 0.2064 | Val Acc: 0.8812 | Time: 1899.0s
Epoch 5/5 | Loss: 0.1970 | Val Acc: 0.8817 | Time: 1961.1s

--- Running Full Metric Suite ---
Clean Accuracy:  0.8859
Brier Score:     0.2001
ECE:             0.1322
ERU:             0.0783
P-Disc (AUROC):  0.0792
P-Disp (d):      0.0000


ImportError: `Import tabulate` failed.  Use pip or conda to install the tabulate package.