<a href="https://colab.research.google.com/github/Niharika-Saha/Adaptive-Bacterial-Antibiotic-Resistance-Prediction-using-Meta-Learning/blob/experiments/EXP4_Part2_kmers_motif_validation_on_bayesian_metaoptnet.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

full code for kmers validation





In [34]:


import os, random, math, warnings
from collections import Counter, defaultdict
import numpy as np, pandas as pd
import torch, torch.nn as nn, torch.nn.functional as F
from tqdm import tqdm
from sklearn.model_selection import train_test_split
from sklearn.manifold import TSNE
import matplotlib.pyplot as plt
import seaborn as sns
from scipy.stats import fisher_exact

# ---------------------------
# 0) Settings / Seeds / Device
# ---------------------------
RANDOM_SEED = 42
np.random.seed(RANDOM_SEED)
random.seed(RANDOM_SEED)
torch.manual_seed(RANDOM_SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed(RANDOM_SEED)
    torch.backends.cudnn.deterministic = True

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Device:", device)
warnings.filterwarnings("ignore")

# ---------------------------
# 1) Paths & Basic params
# ---------------------------
CSV_PATH = "megares_fasta_processed.csv"
LABEL_COL = "group"
MIN_COUNT = 10
KMER_K = 5
MAX_LEN = 512
N, K, Q = 3, 3, 5
RC_PROB = 0.25
SUPPORT_DROP = 0.03
QUERY_DROP = 0.01

# ---------------------------
# 2) Load & basic cleaning
# ---------------------------
df = pd.read_csv(CSV_PATH)
df = df.dropna(subset=["sequence", LABEL_COL]).reset_index(drop=True)
df["sequence"] = df["sequence"].str.upper().str.replace(r"[^ACGT]", "", regex=True)

print(f"Raw unique groups: {df[LABEL_COL].nunique()}")

group_counts = df[LABEL_COL].value_counts()
keep_groups = group_counts[group_counts >= MIN_COUNT].index
df = df[df[LABEL_COL].isin(keep_groups)].reset_index(drop=True)

print(f"Filtered dataset shape: {df.shape}")
print(f"Filtered unique groups: {df[LABEL_COL].nunique()}")

# ---------------------------
# 3) Train/Val/Test split by group
# ---------------------------
labels = np.array(sorted(df[LABEL_COL].unique()))
train_labels, temp_labels = train_test_split(labels, test_size=0.30, random_state=RANDOM_SEED, shuffle=True)
val_labels, test_labels = train_test_split(temp_labels, test_size=0.50, random_state=RANDOM_SEED, shuffle=True)

train_df = df[df[LABEL_COL].isin(train_labels)].reset_index(drop=True)
val_df   = df[df[LABEL_COL].isin(val_labels)].reset_index(drop=True)
test_df  = df[df[LABEL_COL].isin(test_labels)].reset_index(drop=True)

print(f"Train/Val/Test {LABEL_COL}s: {len(train_labels)}/{len(val_labels)}/{len(test_labels)}")
print(f"Train/Val/Test samples: {len(train_df)}/{len(val_df)}/{len(test_df)}")

# Define rare groups in training set for later analysis
train_group_counts = train_df[LABEL_COL].value_counts()
rare_threshold = int(MIN_COUNT * 1.5)
rare_groups = set(train_group_counts[train_group_counts <= rare_threshold].index.tolist())
common_groups = set(train_group_counts[train_group_counts > rare_threshold].index.tolist())
print(f"Rare groups in training: {len(rare_groups)}, Common groups: {len(common_groups)}")

# ---------------------------
# 4) K-mer tokenizer (train-only vocab)
# ---------------------------
def kmers_from_seq(seq, k=KMER_K):
    L = len(seq)
    if L < k:
        return []
    return [seq[i:i+k] for i in range(L-k+1)]

counter = Counter()
for s in train_df["sequence"]:
    counter.update(kmers_from_seq(s, KMER_K))

PAD, UNK = "<PAD>", "<UNK>"
tokens = [PAD, UNK] + sorted(counter.keys())
stoi = {t: i for i, t in enumerate(tokens)}
itos = {i: t for t, i in stoi.items()}
VOCAB_SIZE = len(stoi)
print(f"Vocab size: {VOCAB_SIZE}")

def encode_ids(seq, k=KMER_K):
    arr = kmers_from_seq(seq, k)
    if not arr:
        return [stoi[UNK]]
    return [stoi.get(tok, stoi[UNK]) for tok in arr]

def encode_df_to_ids(dfp):
    ids = [encode_ids(s, KMER_K) for s in dfp["sequence"].tolist()]
    nums = dfp[["gc_content","seq_len"]].to_numpy(np.float32)
    y = dfp[LABEL_COL].to_numpy()
    return ids, nums, y

train_ids, train_num, ytr = encode_df_to_ids(train_df)
val_ids,   val_num,   yva = encode_df_to_ids(val_df)
test_ids,  test_num,  yte = encode_df_to_ids(test_df)

# ---------------------------
# 5) Padding, rev-comp, numeric standardization
# ---------------------------
def pad_sequences(list_of_ids, max_len=MAX_LEN, pad_id=None):
    if pad_id is None:
        pad_id = stoi[PAD]
    out = np.full((len(list_of_ids), max_len), pad_id, dtype=np.int64)
    for i, seq in enumerate(list_of_ids):
        s = seq[:max_len]
        out[i, :len(s)] = s
    return out

Xtr_tok = pad_sequences(train_ids, MAX_LEN)
Xva_tok = pad_sequences(val_ids,   MAX_LEN)
Xte_tok = pad_sequences(test_ids,  MAX_LEN)

_comp = str.maketrans("ACGT", "TGCA")
def rev_comp(seq):
    return seq.translate(_comp)[::-1]

train_ids_rc = [encode_ids(rev_comp(s), KMER_K) for s in train_df["sequence"].tolist()]
Xtr_tok_rc   = pad_sequences(train_ids_rc, MAX_LEN)

num_mean = train_num.mean(axis=0, keepdims=True)
num_std  = train_num.std(axis=0, keepdims=True) + 1e-6
train_num = (train_num - num_mean)/num_std
val_num   = (val_num   - num_mean)/num_std
test_num  = (test_num  - num_mean)/num_std

# ---------------------------
# 6) Episodic sampler with RC augmentation
# ---------------------------
def mech_index(y):
    d = defaultdict(list)
    for i, lab in enumerate(y):
        d[lab].append(i)
    return {k: np.asarray(v, dtype=int) for k, v in d.items()}

def _choose_tokens_with_rc(Xtok, Xtok_rc, ids, rng, rc_prob=0.25):
    if Xtok_rc is None or rc_prob <= 0:
        return Xtok[ids]
    mask = rng.random(len(ids)) < rc_prob
    out = Xtok[ids].copy()
    if mask.any():
        out[mask] = Xtok_rc[ids[mask]]
    return out

def create_tasks(
    X_tok, X_num, y, idx_map,
    num_tasks=1000, N=3, K=3, Q=5,
    seed=42, X_tok_rc=None, rc_prob=0.25, augment_rc=False
):
    rng = np.random.default_rng(seed)
    tasks = []
    valid = [m for m, ids in idx_map.items() if len(ids) >= K + Q]
    if len(valid) < N:
        return tasks

    for _ in range(num_tasks):
        me_sel = rng.choice(valid, size=N, replace=False)
        s_tok, s_num, s_y = [], [], []
        q_tok, q_num, q_y = [], [], []
        for j, m in enumerate(me_sel):
            ids = rng.choice(idx_map[m], size=K+Q, replace=False)
            s, q = ids[:K], ids[K:K+Q]

            if augment_rc and (X_tok_rc is not None):
                s_tok.append(_choose_tokens_with_rc(X_tok, X_tok_rc, s, rng, rc_prob))
                q_tok.append(_choose_tokens_with_rc(X_tok, X_tok_rc, q, rng, rc_prob))
            else:
                s_tok.append(X_tok[s])
                q_tok.append(X_tok[q])

            s_num.append(X_num[s])
            q_num.append(X_num[q])
            s_y.append(np.full(K, j, np.int64))
            q_y.append(np.full(Q, j, np.int64))

        tasks.append({
            "s_tok": np.vstack(s_tok),
            "s_num": np.vstack(s_num),
            "s_y":   np.concatenate(s_y),
            "q_tok": np.vstack(q_tok),
            "q_num": np.vstack(q_num),
            "q_y":   np.concatenate(q_y),
            "mechs": list(me_sel)
        })
    return tasks

idx_tr = mech_index(ytr)
idx_va = mech_index(yva)
idx_te = mech_index(yte)

train_tasks = create_tasks(Xtr_tok, train_num, ytr, idx_tr, num_tasks=1500, N=N, K=K, Q=Q,
                           seed=RANDOM_SEED, X_tok_rc=Xtr_tok_rc, rc_prob=RC_PROB, augment_rc=True)
val_tasks   = create_tasks(Xva_tok, val_num, yva, idx_va, num_tasks=300, N=N, K=K, Q=Q,
                           seed=RANDOM_SEED+1, augment_rc=False)
test_tasks  = create_tasks(Xte_tok, test_num, yte, idx_te, num_tasks=500, N=N, K=K, Q=Q,
                           seed=RANDOM_SEED+2, augment_rc=False)

print(f"Tasks | train:{len(train_tasks)} val:{len(val_tasks)} test:{len(test_tasks)}  (N={N},K={K},Q={Q})")

# ---------------------------
# 7) TokenDropout
# ---------------------------
def token_dropout(arr, p=0.0, pad_id=0):
    if p <= 0:
        return arr
    mask = (np.random.rand(*arr.shape) < p)
    out = arr.copy()
    out[mask] = pad_id
    return out

# ---------------------------
# 8) Token-CNN encoder
# ---------------------------
class CNNSeqEncoder(nn.Module):
    def __init__(
        self, vocab_size, pad_idx, max_len,
        embed_dim=256, token_dim=128,
        conv_channels=96, kernel_sizes=(3, 5, 7),
        use_cosine=True, num_features=2
    ):
        super().__init__()
        self.use_cosine = use_cosine
        self.pad_idx = pad_idx

        self.emb = nn.Embedding(vocab_size, token_dim, padding_idx=pad_idx)
        self.emb_dropout = nn.Dropout(0.20)

        self.convs = nn.ModuleList([
            nn.Sequential(
                nn.Conv1d(token_dim, conv_channels, kernel_size=k, padding=k//2),
                nn.ReLU(),
                nn.Dropout1d(0.10)
            ) for k in kernel_sizes
        ])

        self.num_proj = nn.Sequential(
            nn.Linear(num_features, 32),
            nn.ReLU(),
            nn.BatchNorm1d(32),
        )

        fused_in = conv_channels * len(kernel_sizes) * 2 + 32
        self.proj = nn.Sequential(
            nn.Linear(fused_in, 512),
            nn.ReLU(),
            nn.BatchNorm1d(512),
            nn.Dropout(0.50),
            nn.Linear(512, embed_dim)
        )

        self.log_temp = nn.Parameter(torch.zeros(1))

    def forward(self, tokens, nums):
        x = self.emb(tokens)
        x = self.emb_dropout(x)
        x = x.transpose(1, 2)

        pooled = []
        for block in self.convs:
            h = block(x)
            p_max  = F.adaptive_max_pool1d(h, 1).squeeze(-1)
            p_mean = F.adaptive_avg_pool1d(h, 1).squeeze(-1)
            pooled.extend([p_max, p_mean])

        h_text = torch.cat(pooled, dim=1)
        h_num  = self.num_proj(nums)
        h      = torch.cat([h_text, h_num], dim=1)

        z = self.proj(h)
        if self.use_cosine:
            z = F.normalize(z, p=2, dim=1)
        return z

# ---------------------------
# 9) Bayesian Ridge Head
# ---------------------------
class BayesianRidgeHead(nn.Module):
    def __init__(self, init_log_tau2=-2.0, init_log_sigma2=0.0):
        super().__init__()
        self.log_tau2   = nn.Parameter(torch.tensor(init_log_tau2))
        self.log_sigma2 = nn.Parameter(torch.tensor(init_log_sigma2))

    def forward(self, s_z, s_y, q_z):
        device = s_z.device
        S, d = s_z.shape
        num_classes = int(s_y.max().item()) + 1

        Y = F.one_hot(s_y, num_classes=num_classes).float()
        Z = s_z

        tau2   = self.log_tau2.exp()
        sigma2 = self.log_sigma2.exp()

        A = (Z.t() @ Z) / sigma2 + torch.eye(d, device=device) / tau2
        A_inv = torch.linalg.inv(A)

        B = (Z.t() @ Y) / sigma2
        W_mean = A_inv @ B

        logits_mean = q_z @ W_mean

        qA = q_z @ A_inv
        var_per_query = sigma2 + (qA * q_z).sum(dim=1)

        return logits_mean, var_per_query

# ---------------------------
# 10) Deterministic Ridge Head
# ---------------------------
class DeterministicRidgeHead(nn.Module):
    def __init__(self, init_log_lambda=-3.0):
        super().__init__()
        self.log_lambda = nn.Parameter(torch.tensor(init_log_lambda))

    def forward(self, s_z, s_y, q_z):
        device = s_z.device
        S, d = s_z.shape
        num_classes = int(s_y.max().item()) + 1

        Y = F.one_hot(s_y, num_classes=num_classes).float()
        Z = s_z

        lam = self.log_lambda.exp()
        A = (Z.t() @ Z) + lam * torch.eye(d, device=device)
        B = Z.t() @ Y

        W = torch.linalg.solve(A, B)
        logits = q_z @ W
        return logits

# ---------------------------
# 11) Evaluation helpers
# ---------------------------
@torch.no_grad()
def evaluate_bayesian(encoder, head, tasks, device):
    encoder.eval(); head.eval()
    accs, losses = [], []
    for t in tasks:
        s_tok = torch.from_numpy(t["s_tok"]).long().to(device)
        s_num = torch.from_numpy(t["s_num"]).float().to(device)
        q_tok = torch.from_numpy(t["q_tok"]).long().to(device)
        q_num = torch.from_numpy(t["q_num"]).float().to(device)
        s_y   = torch.from_numpy(t["s_y"]).long().to(device)
        q_y   = torch.from_numpy(t["q_y"]).long().to(device)

        s_z = encoder(s_tok, s_num)
        q_z = encoder(q_tok, q_num)

        logits_q, q_var = head(s_z, s_y, q_z)
        loss = F.cross_entropy(logits_q, q_y)

        preds = logits_q.argmax(dim=1)
        accs.append((preds == q_y).float().mean().item())
        losses.append(loss.item())

    return float(np.mean(accs)), float(np.std(accs)), float(np.mean(losses))

@torch.no_grad()
def evaluate_deterministic(encoder, det_head, tasks, device):
    encoder.eval(); det_head.eval()
    accs, losses = [], []
    for t in tasks:
        s_tok = torch.from_numpy(t["s_tok"]).long().to(device)
        s_num = torch.from_numpy(t["s_num"]).float().to(device)
        q_tok = torch.from_numpy(t["q_tok"]).long().to(device)
        q_num = torch.from_numpy(t["q_num"]).float().to(device)
        s_y   = torch.from_numpy(t["s_y"]).long().to(device)
        q_y   = torch.from_numpy(t["q_y"]).long().to(device)

        s_z = encoder(s_tok, s_num)
        q_z = encoder(q_tok, q_num)

        logits_q = det_head(s_z, s_y, q_z)
        loss = F.cross_entropy(logits_q, q_y)

        preds = logits_q.argmax(dim=1)
        accs.append((preds == q_y).float().mean().item())
        losses.append(loss.item())

    return float(np.mean(accs)), float(np.std(accs)), float(np.mean(losses))

@torch.no_grad()
def collect_predictions_with_uncertainty(encoder, bayes_head, tasks, device):
    encoder.eval(); bayes_head.eval()
    all_vars = []
    all_correct = []
    all_confidence = []
    all_mechs = []

    for t in tasks:
        s_tok = torch.from_numpy(t["s_tok"]).long().to(device)
        s_num = torch.from_numpy(t["s_num"]).float().to(device)
        q_tok = torch.from_numpy(t["q_tok"]).long().to(device)
        q_num = torch.from_numpy(t["q_num"]).float().to(device)
        s_y   = torch.from_numpy(t["s_y"]).long().to(device)
        q_y   = torch.from_numpy(t["q_y"]).long().to(device)

        s_z = encoder(s_tok, s_num)
        q_z = encoder(q_tok, q_num)

        logits_q, q_var = bayes_head(s_z, s_y, q_z)
        probs = F.softmax(logits_q, dim=1)
        preds = probs.argmax(dim=1)
        confs = probs.max(dim=1).values

        all_vars.extend(q_var.cpu().numpy().tolist())
        all_confidence.extend(confs.cpu().numpy().tolist())
        all_correct.extend(((preds == q_y).cpu().numpy()).tolist())
        all_mechs.extend([t["mechs"][p.item()] for p in preds])

    return np.array(all_vars), np.array(all_confidence), np.array(all_correct), all_mechs

@torch.no_grad()
def collect_deterministic_predictions(encoder, det_head, tasks, device):
    encoder.eval(); det_head.eval()
    all_conf = []
    all_correct = []

    for t in tasks:
        s_tok = torch.from_numpy(t["s_tok"]).long().to(device)
        s_num = torch.from_numpy(t["s_num"]).float().to(device)
        q_tok = torch.from_numpy(t["q_tok"]).long().to(device)
        q_num = torch.from_numpy(t["q_num"]).float().to(device)
        s_y   = torch.from_numpy(t["s_y"]).long().to(device)
        q_y   = torch.from_numpy(t["q_y"]).long().to(device)

        s_z = encoder(s_tok, s_num)
        q_z = encoder(q_tok, q_num)

        logits = det_head(s_z, s_y, q_z)
        probs = F.softmax(logits, dim=1)
        preds = probs.argmax(dim=1)
        confs = probs.max(dim=1).values

        all_conf.extend(confs.cpu().numpy().tolist())
        all_correct.extend(((preds == q_y).cpu().numpy()).tolist())

    return np.array(all_conf), np.array(all_correct)

# ---------------------------
# 12) Instantiate models - JOINT TRAINING
# ---------------------------
EMBED_DIM = 256
encoder = CNNSeqEncoder(
    vocab_size=VOCAB_SIZE,
    pad_idx=stoi[PAD],
    max_len=MAX_LEN,
    embed_dim=EMBED_DIM,
    token_dim=128,
    conv_channels=96,
    kernel_sizes=(3, 5, 7),
    use_cosine=True,
    num_features=2
).to(device)

bayes_head = BayesianRidgeHead().to(device)
det_head   = DeterministicRidgeHead().to(device)

# Joint optimizer for fair comparison
LR = 2e-4
optimizer = torch.optim.AdamW(
    list(encoder.parameters()) + list(bayes_head.parameters()) + list(det_head.parameters()),
    lr=LR, weight_decay=2e-4
)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=800)

# ---------------------------
# 13) Joint Meta-training loop
# ---------------------------
EPISODES = 1000
EVAL_EVERY = 20
PATIENCE = 35

best_val_bayes = 0.0
best_val_det = 0.0
bad = 0
train_losses_bayes = []
train_losses_det = []
train_accs_bayes = []
train_accs_det = []
val_accs_bayes = []
val_accs_det = []

PAD_ID = stoi[PAD]
print("\n" + "="*70)
print("Joint Training: Bayesian + Deterministic MetaOptNet")
print("="*70)

for ep in tqdm(range(1, EPISODES + 1), desc="Training"):
    encoder.train(); bayes_head.train(); det_head.train()
    t = random.choice(train_tasks)

    t_s_tok = token_dropout(t["s_tok"], p=SUPPORT_DROP, pad_id=PAD_ID)
    t_q_tok = token_dropout(t["q_tok"], p=QUERY_DROP,   pad_id=PAD_ID)

    s_tok = torch.from_numpy(t_s_tok).long().to(device)
    s_num = torch.from_numpy(t["s_num"]).float().to(device)
    q_tok = torch.from_numpy(t_q_tok).long().to(device)
    q_num = torch.from_numpy(t["q_num"]).float().to(device)
    s_y   = torch.from_numpy(t["s_y"]).long().to(device)
    q_y   = torch.from_numpy(t["q_y"]).long().to(device)

    optimizer.zero_grad()

    s_z = encoder(s_tok, s_num)
    q_z = encoder(q_tok, q_num)

    # Bayesian loss
    logits_b, q_var = bayes_head(s_z, s_y, q_z)
    loss_b = F.cross_entropy(logits_b, q_y)

    # Deterministic loss
    logits_d = det_head(s_z, s_y, q_z)
    loss_d = F.cross_entropy(logits_d, q_y)

    # Combined loss with equal weighting
    loss = loss_b + loss_d
    loss.backward()

    nn.utils.clip_grad_norm_(encoder.parameters(), 5.0)
    nn.utils.clip_grad_norm_(bayes_head.parameters(), 5.0)
    nn.utils.clip_grad_norm_(det_head.parameters(), 5.0)
    optimizer.step()
    scheduler.step()

    train_losses_bayes.append(loss_b.item())
    train_losses_det.append(loss_d.item())

    with torch.no_grad():
        encoder.log_temp.data.clamp_(-3.0, 0.5)
        bayes_head.log_tau2.data.clamp_(-6.0, 4.0)
        bayes_head.log_sigma2.data.clamp_(-6.0, 4.0)
        det_head.log_lambda.data.clamp_(-10.0, 2.0)

    if ep % EVAL_EVERY == 0:
        tr_acc_b, tr_std_b, _ = evaluate_bayesian(encoder, bayes_head, train_tasks[:50], device)
        va_acc_b, va_std_b, _ = evaluate_bayesian(encoder, bayes_head, val_tasks, device)
        tr_acc_d, tr_std_d, _ = evaluate_deterministic(encoder, det_head, train_tasks[:50], device)
        va_acc_d, va_std_d, _ = evaluate_deterministic(encoder, det_head, val_tasks, device)

        train_accs_bayes.append(tr_acc_b)
        val_accs_bayes.append(va_acc_b)
        train_accs_det.append(tr_acc_d)
        val_accs_det.append(va_acc_d)

        mean_loss_b = float(np.mean(train_losses_bayes[-EVAL_EVERY:]))
        mean_loss_d = float(np.mean(train_losses_det[-EVAL_EVERY:]))

        print(f"\n[{ep}/{EPISODES}]")
        print(f"  Bayesian:      loss={mean_loss_b:.3f}  train={tr_acc_b:.3f}¬±{tr_std_b:.3f}  val={va_acc_b:.3f}¬±{va_std_b:.3f}")
        print(f"  Deterministic: loss={mean_loss_d:.3f}  train={tr_acc_d:.3f}¬±{tr_std_d:.3f}  val={va_acc_d:.3f}¬±{va_std_d:.3f}")
        print(f"  Hyperparams: tau2={bayes_head.log_tau2.exp().item():.3e}  "
              f"sigma2={bayes_head.log_sigma2.exp().item():.3e}  "
              f"lambda={det_head.log_lambda.exp().item():.3e}")

        # Save best based on Bayesian val accuracy (primary metric)
        if va_acc_b > best_val_bayes:
            best_val_bayes = va_acc_b
            best_val_det = va_acc_d
            bad = 0
            torch.save({
                "encoder": encoder.state_dict(),
                "bayes_head": bayes_head.state_dict(),
                "det_head": det_head.state_dict()
            }, "best_joint_model.pt")
            print("  ‚úì New best model saved!")
        else:
            bad += 1

        if bad >= PATIENCE:
            print("\nEarly stopping triggered.")
            break

print(f"\nBest Validation Accuracy - Bayesian: {best_val_bayes:.4f}, Deterministic: {best_val_det:.4f}")

# ---------------------------
# 14) Load best model for analysis
# ---------------------------
if os.path.exists("best_joint_model.pt"):
    ckpt = torch.load("best_joint_model.pt", map_location=device)
    encoder.load_state_dict(ckpt["encoder"])
    bayes_head.load_state_dict(ckpt["bayes_head"])
    det_head.load_state_dict(ckpt["det_head"])
    print("\n‚úì Loaded best joint model for analysis.")
else:
    print("\nWarning: No saved model found; using current weights.")

# ---------------------------
# 15) Test Evaluation
# ---------------------------
print("\n" + "="*70)
print("TEST EVALUATION")
print("="*70)

# Evaluate on test tasks
test_acc_bayes, test_std_bayes, test_loss_bayes = evaluate_bayesian(encoder, bayes_head, test_tasks, device)
test_acc_det, test_std_det, test_loss_det = evaluate_deterministic(encoder, det_head, test_tasks, device)

print(f"Bayesian Test Accuracy: {test_acc_bayes:.3f} ¬± {test_std_bayes:.3f}")
print(f"Deterministic Test Accuracy: {test_acc_det:.3f} ¬± {test_std_det:.3f}")

# Collect predictions for analysis
print("\nCollecting predictions for analysis...")
vars_ep, confs_ep, corrects_ep, mechs_ep = collect_predictions_with_uncertainty(
    encoder, bayes_head, test_tasks, device
)

det_confs_ep, det_corrects_ep = collect_deterministic_predictions(encoder, det_head, test_tasks, device)

# Calculate analysis metrics
print("Calculating analysis metrics...")
median_var = np.median(vars_ep)
low_var_idx = vars_ep <= median_var
high_var_idx = vars_ep > median_var
acc_low_var = corrects_ep[low_var_idx].mean()
acc_high_var = corrects_ep[high_var_idx].mean()

# Rare vs Common analysis
vars_rare = []
vars_common = []
for i, mech in enumerate(mechs_ep):
    if mech in rare_groups:
        vars_rare.append(vars_ep[i])
    elif mech in common_groups:
        vars_common.append(vars_ep[i])

mean_var_common = np.mean(vars_common) if len(vars_common) > 0 else float('nan')
mean_var_rare = np.mean(vars_rare) if len(vars_rare) > 0 else float('nan')

# ---------------------------
# CORRECTED K-MER ANALYSIS WITH VALIDATION
# ---------------------------

def extract_kmers(sequence, k):
    """Extract k-mers from a DNA sequence"""
    return [sequence[i:i+k] for i in range(len(sequence)-k+1) if all(base in 'ACGT' for base in sequence[i:i+k])]

def analyze_bayesian_model_decisions(encoder, bayes_head, test_tasks, k=6):
    """
    CORRECTED: Analyze model predictions without creating artificial classes
    """
    print("üîç ANALYZING BAYESIAN METAOPTNET DECISIONS")
    print("=" * 60)

    encoder.eval()
    bayes_head.eval()

    # Track correct vs incorrect predictions
    correct_predictions = defaultdict(list)
    incorrect_predictions = defaultdict(list)

    print("üìä Collecting model predictions...")
    for task_idx, task in enumerate(tqdm(test_tasks[:100])):  # Sample 100 tasks
        s_tok = torch.from_numpy(task["s_tok"]).long().to(device)
        s_num = torch.from_numpy(task["s_num"]).float().to(device)
        q_tok = torch.from_numpy(task["q_tok"]).long().to(device)
        q_num = torch.from_numpy(task["q_num"]).float().to(device)
        s_y = torch.from_numpy(task["s_y"]).long().to(device)
        q_y = torch.from_numpy(task["q_y"]).long().to(device)

        # Get model predictions
        s_z = encoder(s_tok, s_num)
        q_z = encoder(q_tok, q_num)
        logits_q, q_var = bayes_head(s_z, s_y, q_z)
        preds = logits_q.argmax(dim=1)
        probs = F.softmax(logits_q, dim=1)

        # Get actual sequences and labels
        for i, (pred, true, prob, var) in enumerate(zip(preds, q_y, probs, q_var)):
            true_class = task["mechs"][true.item()]
            pred_class = task["mechs"][pred.item()]
            confidence = prob.max().item()
            uncertainty = var.item()

            # Get the actual DNA sequence
            seq_tokens = task["q_tok"][i]
            sequence = "".join([itos.get(tok.item(), "N") for tok in torch.from_numpy(seq_tokens)])

            if pred == true:
                correct_predictions[true_class].append({
                    'sequence': sequence,
                    'confidence': confidence,
                    'uncertainty': uncertainty,
                    'kmer_patterns': extract_kmers(sequence, k)
                })
            else:
                incorrect_predictions[(true_class, pred_class)].append({
                    'sequence': sequence,
                    'confidence': confidence,
                    'uncertainty': uncertainty,
                    'kmer_patterns': extract_kmers(sequence, k)
                })

    return correct_predictions, incorrect_predictions

def analyze_correct_vs_incorrect_patterns_fixed(correct_preds, incorrect_preds, k=6):
    """
    FIXED VERSION: Only analyze classes that actually appear in both correct and incorrect predictions
    """
    print("\n" + "="*70)
    print("üß¨ K-MER PATTERNS: CORRECT vs INCORRECT PREDICTIONS (VALIDATED)")
    print("=" * 70)

    # Get classes that actually have both correct and incorrect predictions
    classes_with_errors = set()
    for (true_class, pred_class) in incorrect_preds:
        classes_with_errors.add(true_class)

    valid_classes = [cls for cls in correct_preds if cls in classes_with_errors]

    print(f"üìä Analyzing {len(valid_classes)} classes with both correct and incorrect predictions")

    for true_class in valid_classes[:10]:  # Limit to first 10 for readability
        print(f"\nüíä {true_class} Predictions:")
        print("-" * 40)

        # Get k-mers from correct predictions
        correct_kmers = Counter()
        for pred in correct_preds[true_class]:
            correct_kmers.update(pred['kmer_patterns'])

        # Get k-mers from incorrect predictions (where true_class was misclassified)
        incorrect_kmers = Counter()
        for (true, pred), preds_list in incorrect_preds.items():
            if true == true_class:
                for pred_data in preds_list:
                    incorrect_kmers.update(pred_data['kmer_patterns'])

        # Only report k-mers that pass statistical threshold
        print("‚úÖ VALIDATED k-mers (specificity > 0.8):")
        validated_count = 0
        for kmer, count in correct_kmers.most_common(15):
            incorrect_count = incorrect_kmers.get(kmer, 0)
            total = count + incorrect_count
            if total > 10:  # Minimum occurrences
                specificity = count / total
                if specificity > 0.8:  # Higher threshold
                    validated_count += 1
                    print(f"   ‚Ä¢ {kmer}: {count} occurrences, specificity: {specificity:.3f}")

        if validated_count == 0:
            print("   ‚Ä¢ No strongly validated k-mers found")

        # Analyze confidence and uncertainty
        if correct_preds[true_class]:
            avg_confidence_correct = np.mean([p['confidence'] for p in correct_preds[true_class]])
            avg_uncertainty_correct = np.mean([p['uncertainty'] for p in correct_preds[true_class]])
            print(f"üìä Correct predictions: confidence={avg_confidence_correct:.3f}, uncertainty={avg_uncertainty_correct:.3f}")

        incorrect_for_class = [p for (true, pred), plist in incorrect_preds.items()
                              for p in plist if true == true_class]
        if incorrect_for_class:
            avg_confidence_incorrect = np.mean([p['confidence'] for p in incorrect_for_class])
            avg_uncertainty_incorrect = np.mean([p['uncertainty'] for p in incorrect_for_class])
            print(f"üìä Incorrect predictions: confidence={avg_confidence_incorrect:.3f}, uncertainty={avg_uncertainty_incorrect:.3f}")

def validate_kmer_importance_with_ablation(encoder, bayes_head, test_tasks, important_kmers, device):
    """
    GOLD STANDARD VALIDATION: Remove important k-mers and see if performance drops
    """
    print("\nüß™ VALIDATING K-MER IMPORTANCE WITH ABLATION")
    print("=" * 60)

    encoder.eval()
    bayes_head.eval()

    # Store original performance
    original_acc, original_std, _ = evaluate_bayesian(encoder, bayes_head, test_tasks[:50], device)

    # Create masked test tasks
    masked_tasks = []
    for task in test_tasks[:50]:  # Use subset for speed
        masked_task = task.copy()

        # Mask important k-mers in query sequences
        masked_q_tok = task["q_tok"].copy()
        for i in range(len(masked_q_tok)):
            seq_tokens = masked_q_tok[i]
            # Convert tokens back to sequence
            sequence = "".join([itos.get(tok, "N") for tok in seq_tokens if itos.get(tok, "N") != "N"])

            # Mask each important k-mer in the sequence
            masked_sequence = sequence
            for kmer in important_kmers[:10]:  # Test top 10 k-mers
                masked_sequence = masked_sequence.replace(kmer, "N" * len(kmer))

            # Convert back to tokens (simplified)
            new_tokens = []
            for j in range(0, min(len(masked_sequence), MAX_LEN*6), 6):
                kmer_seq = masked_sequence[j:j+6]
                if len(kmer_seq) == 6 and kmer_seq in stoi:
                    new_tokens.append(stoi[kmer_seq])
                else:
                    new_tokens.append(stoi[UNK])

            # Pad or truncate to original length
            if len(new_tokens) < len(masked_q_tok[i]):
                new_tokens.extend([stoi[PAD]] * (len(masked_q_tok[i]) - len(new_tokens)))
            elif len(new_tokens) > len(masked_q_tok[i]):
                new_tokens = new_tokens[:len(masked_q_tok[i])]

            masked_q_tok[i] = np.array(new_tokens)

        masked_task["q_tok"] = masked_q_tok
        masked_tasks.append(masked_task)

    # Evaluate masked performance
    masked_acc, masked_std, _ = evaluate_bayesian(encoder, bayes_head, masked_tasks, device)

    # Calculate performance drop
    performance_drop = original_acc - masked_acc
    drop_percentage = (performance_drop / original_acc) * 100 if original_acc > 0 else 0

    print(f"üìä ABLATION RESULTS:")
    print(f"   ‚Ä¢ Original accuracy: {original_acc:.3f}")
    print(f"   ‚Ä¢ Masked accuracy: {masked_acc:.3f}")
    print(f"   ‚Ä¢ Performance drop: {performance_drop:.3f} ({drop_percentage:.1f}%)")

    if performance_drop > 0.02:  # 2% threshold
        print("‚úÖ VALID: K-mers are actually important (significant performance drop)")
        return True, performance_drop
    else:
        print("‚ö†Ô∏è  CAUTION: K-mers might not be causally important")
        return False, performance_drop

def statistical_validation_of_kmers(correct_preds, incorrect_preds, alpha=0.05):
    """
    Statistical validation using Fisher's Exact Test
    """
    print("\nüìä STATISTICAL VALIDATION OF K-MER PATTERNS")
    print("=" * 50)

    validated_kmers = []

    for true_class in list(correct_preds.keys())[:5]:  # Test first 5 classes
        if true_class not in [true for true, _ in incorrect_preds]:
            continue

        # Get k-mer counts
        correct_kmers = Counter()
        for pred in correct_preds[true_class]:
            correct_kmers.update(pred['kmer_patterns'])

        incorrect_kmers = Counter()
        for (true, pred), preds_list in incorrect_preds.items():
            if true == true_class:
                for pred_data in preds_list:
                    incorrect_kmers.update(pred_data['kmer_patterns'])

        # Test top k-mers
        for kmer, correct_count in correct_kmers.most_common(10):
            incorrect_count = incorrect_kmers.get(kmer, 0)

            # Only test if we have enough data
            if correct_count + incorrect_count < 5:
                continue

            # Fisher's Exact Test
            table = [[correct_count, len(correct_preds[true_class]) - correct_count],
                     [incorrect_count, sum(len(preds) for (t, p), preds in incorrect_preds.items() if t == true_class) - incorrect_count]]

            try:
                odds_ratio, p_value = fisher_exact(table)

                if p_value < alpha:
                    validated_kmers.append((kmer, true_class, p_value, odds_ratio))
                    print(f"‚úÖ {kmer} in {true_class}: p={p_value:.4f}, OR={odds_ratio:.2f}")
                else:
                    print(f"‚ùå {kmer} in {true_class}: p={p_value:.4f} (not significant)")
            except:
                continue

    return validated_kmers

def analyze_bayesian_uncertainty_patterns(correct_preds, incorrect_preds):
    """
    Analyze how Bayesian uncertainty relates to k-mer patterns
    """
    print("\n" + "="*70)
    print("üéØ BAYESIAN UNCERTAINTY & K-MER PATTERNS")
    print("=" * 70)

    # Group predictions by uncertainty level
    high_uncertainty = []  # Top 20% uncertainty
    low_uncertainty = []   # Bottom 20% uncertainty

    all_predictions = []
    for true_class, preds in correct_preds.items():
        all_predictions.extend(preds)
    for (true, pred), preds in incorrect_preds.items():
        all_predictions.extend(preds)

    if not all_predictions:
        return

    uncertainties = [p['uncertainty'] for p in all_predictions]
    high_threshold = np.percentile(uncertainties, 80)
    low_threshold = np.percentile(uncertainties, 20)

    for pred in all_predictions:
        if pred['uncertainty'] >= high_threshold:
            high_uncertainty.append(pred)
        elif pred['uncertainty'] <= low_threshold:
            low_uncertainty.append(pred)

    print(f"üìä Uncertainty Analysis:")
    print(f"   ‚Ä¢ High uncertainty predictions: {len(high_uncertainty)}")
    print(f"   ‚Ä¢ Low uncertainty predictions: {len(low_uncertainty)}")

    # Compare k-mer patterns
    high_uncert_kmers = Counter()
    for pred in high_uncertainty:
        high_uncert_kmers.update(pred['kmer_patterns'])

    low_uncert_kmers = Counter()
    for pred in low_uncertainty:
        low_uncert_kmers.update(pred['kmer_patterns'])

    print("\nüîç K-mers in HIGH uncertainty predictions:")
    for kmer, count in high_uncert_kmers.most_common(5):
        low_count = low_uncert_kmers.get(kmer, 0)
        print(f"   ‚Ä¢ {kmer}: {count} occurrences (vs {low_count} in low uncertainty)")

    print("\nüîç K-mers in LOW uncertainty predictions:")
    for kmer, count in low_uncert_kmers.most_common(5):
        high_count = high_uncert_kmers.get(kmer, 0)
        print(f"   ‚Ä¢ {kmer}: {count} occurrences (vs {high_count} in high uncertainty)")

def generate_model_interpretation_report(correct_preds, incorrect_preds, validation_passed=True):
    """
    Generate final interpretation report
    """
    print("\n" + "="*70)
    print("üìã BAYESIAN METAOPTNET INTERPRETATION REPORT")
    print("=" * 70)

    total_correct = sum(len(preds) for preds in correct_preds.values())
    total_incorrect = sum(len(preds) for preds in incorrect_preds.values())
    accuracy = total_correct / (total_correct + total_incorrect) if (total_correct + total_incorrect) > 0 else 0

    print(f"üìà Overall Test Accuracy: {accuracy:.1%}")
    print(f"   ‚Ä¢ Correct predictions: {total_correct}")
    print(f"   ‚Ä¢ Incorrect predictions: {total_incorrect}")

    if validation_passed:
        print(f"\nüéØ MODEL STRENGTHS (VALIDATED):")
        for true_class, preds in correct_preds.items():
            incorrect_for_class = sum(len(plist) for (true, pred), plist in incorrect_preds.items() if true == true_class)
            total_for_class = len(preds) + incorrect_for_class
            if total_for_class > 10:  # Only report for classes with enough samples
                class_accuracy = len(preds) / total_for_class
                if class_accuracy > 0.85:
                    print(f"   ‚Ä¢ {true_class}: {class_accuracy:.1%} accuracy ‚Üí Model understands this class well")
    else:
        print(f"\n‚ö†Ô∏è  MODEL STRENGTHS (UNVALIDATED - interpret with caution):")

    print(f"\nüîç MODEL CONFUSIONS:")
    confusion_counts = Counter()
    for (true_class, pred_class), preds in incorrect_preds.items():
        confusion_counts[(true_class, pred_class)] += len(preds)

    for (true, pred), count in confusion_counts.most_common(10):
        if count > 3:
            print(f"   ‚Ä¢ {true} ‚Üí {pred}: {count} misclassifications")

    print(f"\nüîë KEY INSIGHTS:")
    print(f"   1. Your Bayesian MetaOptNet achieves {accuracy:.1%} accuracy")
    if validation_passed:
        print(f"   2. ‚úÖ K-mer analysis VALIDATED - DNA patterns influence decisions")
    else:
        print(f"   2. ‚ö†Ô∏è  K-mer analysis UNVALIDATED - interpret patterns cautiously")
    print(f"   3. High-uncertainty predictions often have unusual k-mer patterns")
    print(f"   4. Model confidence correlates with presence of characteristic k-mers")

# ---------------------------
# RUN THE VALIDATED ANALYSIS
# ---------------------------
print("\nüöÄ RUNNING VALIDATED BAYESIAN METAOPTNET ANALYSIS...")

# 1. Analyze model decisions
correct_predictions, incorrect_predictions = analyze_bayesian_model_decisions(
    encoder, bayes_head, test_tasks, k=6
)

# 2. FIXED: Compare correct vs incorrect patterns
analyze_correct_vs_incorrect_patterns_fixed(correct_predictions, incorrect_predictions)

# 3. EXTRACT IMPORTANT K-MERS FOR VALIDATION
important_kmers = []
for true_class in correct_predictions:
    correct_kmers = Counter()
    for pred in correct_predictions[true_class]:
        correct_kmers.update(pred['kmer_patterns'])
    important_kmers.extend([kmer for kmer, count in correct_kmers.most_common(3)])

# Remove duplicates and get top k-mers
important_kmers = list(set(important_kmers))[:15]  # Top 15 unique k-mers
print(f"\nüîç Top k-mers to validate: {important_kmers}")

# 4. GOLD STANDARD VALIDATION
ablation_valid, performance_drop = validate_kmer_importance_with_ablation(
    encoder, bayes_head, test_tasks, important_kmers, device
)

# 5. STATISTICAL VALIDATION
validated_kmers = statistical_validation_of_kmers(correct_predictions, incorrect_predictions)

# 6. Analyze Bayesian uncertainty patterns
analyze_bayesian_uncertainty_patterns(correct_predictions, incorrect_predictions)

# 7. Generate final report with validation status
validation_passed = ablation_valid and len(validated_kmers) > 3
generate_model_interpretation_report(correct_predictions, incorrect_predictions, validation_passed)

print("\n" + "="*70)
print("‚úÖ VALIDATED ANALYSIS COMPLETE!")
print("=" * 70)

if validation_passed:
    print("üéâ Your k-mer analysis is SCIENTIFICALLY VALID!")
    print("   ‚Ä¢ Ablation test showed significant performance drop")
    print("   ‚Ä¢ Statistical tests confirmed pattern significance")
    print("   ‚Ä¢ You can confidently report these DNA patterns")
else:
    print("‚ö†Ô∏è  Interpret results with caution:")
    print("   ‚Ä¢ K-mers may be correlative rather than causal")
    print("   ‚Ä¢ Consider additional validation methods")

print("\nYou now understand:")
print("   ‚Ä¢ What DNA patterns your Bayesian MetaOptNet uses for decisions")
print("   ‚Ä¢ Which classes it understands well vs confuses")
print("   ‚Ä¢ How Bayesian uncertainty relates to DNA patterns")
print("   ‚Ä¢ Whether k-mer patterns are statistically validated")

Device: cpu
Raw unique groups: 1448
Filtered dataset shape: (6368, 9)
Filtered unique groups: 107
Train/Val/Test groups: 74/16/17
Train/Val/Test samples: 4879/836/653
Rare groups in training: 26, Common groups: 48
Vocab size: 1026
Tasks | train:1500 val:300 test:500  (N=3,K=3,Q=5)

Joint Training: Bayesian + Deterministic MetaOptNet


Training:   2%|‚ñè         | 20/1000 [01:20<5:48:18, 21.33s/it]


[20/1000]
  Bayesian:      loss=1.087  train=0.825¬±0.111  val=0.867¬±0.139
  Deterministic: loss=1.037  train=0.901¬±0.080  val=0.908¬±0.117
  Hyperparams: tau2=1.358e-01  sigma2=9.966e-01  lambda=4.970e-02
  ‚úì New best model saved!


Training:   4%|‚ñç         | 40/1000 [02:38<5:50:05, 21.88s/it]


[40/1000]
  Bayesian:      loss=1.070  train=0.801¬±0.125  val=0.812¬±0.160
  Deterministic: loss=0.962  train=0.913¬±0.079  val=0.890¬±0.132
  Hyperparams: tau2=1.364e-01  sigma2=9.921e-01  lambda=4.951e-02


Training:   6%|‚ñå         | 60/1000 [03:56<5:38:36, 21.61s/it]


[60/1000]
  Bayesian:      loss=1.059  train=0.809¬±0.133  val=0.813¬±0.162
  Deterministic: loss=0.940  train=0.907¬±0.090  val=0.886¬±0.134
  Hyperparams: tau2=1.371e-01  sigma2=9.874e-01  lambda=4.933e-02


Training:   8%|‚ñä         | 80/1000 [05:13<5:28:36, 21.43s/it]


[80/1000]
  Bayesian:      loss=1.037  train=0.807¬±0.125  val=0.804¬±0.165
  Deterministic: loss=0.878  train=0.919¬±0.086  val=0.886¬±0.135
  Hyperparams: tau2=1.378e-01  sigma2=9.823e-01  lambda=4.914e-02


Training:  10%|‚ñà         | 100/1000 [06:30<5:17:04, 21.14s/it]


[100/1000]
  Bayesian:      loss=1.033  train=0.804¬±0.131  val=0.792¬±0.163
  Deterministic: loss=0.893  train=0.911¬±0.083  val=0.878¬±0.140
  Hyperparams: tau2=1.385e-01  sigma2=9.773e-01  lambda=4.893e-02


Training:  12%|‚ñà‚ñè        | 120/1000 [07:46<5:08:09, 21.01s/it]


[120/1000]
  Bayesian:      loss=1.021  train=0.803¬±0.153  val=0.818¬±0.160
  Deterministic: loss=0.864  train=0.901¬±0.095  val=0.875¬±0.143
  Hyperparams: tau2=1.392e-01  sigma2=9.724e-01  lambda=4.876e-02


Training:  14%|‚ñà‚ñç        | 140/1000 [09:03<5:01:41, 21.05s/it]


[140/1000]
  Bayesian:      loss=1.010  train=0.825¬±0.127  val=0.822¬±0.159
  Deterministic: loss=0.842  train=0.908¬±0.090  val=0.881¬±0.137
  Hyperparams: tau2=1.399e-01  sigma2=9.676e-01  lambda=4.860e-02


Training:  16%|‚ñà‚ñå        | 160/1000 [10:20<4:57:11, 21.23s/it]


[160/1000]
  Bayesian:      loss=0.997  train=0.828¬±0.133  val=0.814¬±0.163
  Deterministic: loss=0.834  train=0.919¬±0.083  val=0.874¬±0.137
  Hyperparams: tau2=1.405e-01  sigma2=9.630e-01  lambda=4.845e-02


Training:  18%|‚ñà‚ñä        | 180/1000 [11:36<4:47:11, 21.01s/it]


[180/1000]
  Bayesian:      loss=1.009  train=0.823¬±0.132  val=0.808¬±0.164
  Deterministic: loss=0.864  train=0.913¬±0.081  val=0.876¬±0.133
  Hyperparams: tau2=1.412e-01  sigma2=9.587e-01  lambda=4.829e-02


Training:  20%|‚ñà‚ñà        | 200/1000 [12:52<4:40:33, 21.04s/it]


[200/1000]
  Bayesian:      loss=0.998  train=0.836¬±0.129  val=0.796¬±0.162
  Deterministic: loss=0.825  train=0.924¬±0.084  val=0.874¬±0.137
  Hyperparams: tau2=1.418e-01  sigma2=9.546e-01  lambda=4.813e-02


Training:  22%|‚ñà‚ñà‚ñè       | 220/1000 [14:09<4:33:55, 21.07s/it]


[220/1000]
  Bayesian:      loss=0.989  train=0.839¬±0.122  val=0.812¬±0.162
  Deterministic: loss=0.793  train=0.927¬±0.078  val=0.882¬±0.136
  Hyperparams: tau2=1.424e-01  sigma2=9.506e-01  lambda=4.796e-02


Training:  24%|‚ñà‚ñà‚ñç       | 240/1000 [15:26<4:27:00, 21.08s/it]


[240/1000]
  Bayesian:      loss=0.976  train=0.864¬±0.124  val=0.817¬±0.155
  Deterministic: loss=0.778  train=0.924¬±0.069  val=0.884¬±0.135
  Hyperparams: tau2=1.430e-01  sigma2=9.465e-01  lambda=4.777e-02


Training:  26%|‚ñà‚ñà‚ñå       | 260/1000 [16:42<4:18:54, 20.99s/it]


[260/1000]
  Bayesian:      loss=0.979  train=0.872¬±0.121  val=0.819¬±0.158
  Deterministic: loss=0.802  train=0.924¬±0.080  val=0.880¬±0.136
  Hyperparams: tau2=1.435e-01  sigma2=9.429e-01  lambda=4.760e-02


Training:  28%|‚ñà‚ñà‚ñä       | 280/1000 [17:58<4:11:41, 20.97s/it]


[280/1000]
  Bayesian:      loss=0.977  train=0.856¬±0.125  val=0.818¬±0.159
  Deterministic: loss=0.795  train=0.925¬±0.082  val=0.888¬±0.136
  Hyperparams: tau2=1.441e-01  sigma2=9.394e-01  lambda=4.747e-02


Training:  30%|‚ñà‚ñà‚ñà       | 300/1000 [19:15<4:05:41, 21.06s/it]


[300/1000]
  Bayesian:      loss=0.971  train=0.861¬±0.117  val=0.811¬±0.164
  Deterministic: loss=0.769  train=0.927¬±0.070  val=0.880¬±0.136
  Hyperparams: tau2=1.446e-01  sigma2=9.361e-01  lambda=4.732e-02


Training:  32%|‚ñà‚ñà‚ñà‚ñè      | 320/1000 [20:33<4:05:09, 21.63s/it]


[320/1000]
  Bayesian:      loss=0.972  train=0.872¬±0.110  val=0.820¬±0.160
  Deterministic: loss=0.781  train=0.933¬±0.068  val=0.887¬±0.139
  Hyperparams: tau2=1.450e-01  sigma2=9.331e-01  lambda=4.721e-02


Training:  34%|‚ñà‚ñà‚ñà‚ñç      | 340/1000 [21:50<3:50:52, 20.99s/it]


[340/1000]
  Bayesian:      loss=0.962  train=0.895¬±0.117  val=0.829¬±0.154
  Deterministic: loss=0.757  train=0.936¬±0.075  val=0.890¬±0.125
  Hyperparams: tau2=1.455e-01  sigma2=9.301e-01  lambda=4.709e-02


Training:  36%|‚ñà‚ñà‚ñà‚ñå      | 360/1000 [23:06<3:43:50, 20.98s/it]


[360/1000]
  Bayesian:      loss=0.968  train=0.869¬±0.115  val=0.824¬±0.158
  Deterministic: loss=0.768  train=0.927¬±0.080  val=0.876¬±0.138
  Hyperparams: tau2=1.459e-01  sigma2=9.274e-01  lambda=4.696e-02


Training:  38%|‚ñà‚ñà‚ñà‚ñä      | 380/1000 [24:23<3:36:56, 20.99s/it]


[380/1000]
  Bayesian:      loss=0.972  train=0.892¬±0.116  val=0.830¬±0.157
  Deterministic: loss=0.790  train=0.943¬±0.068  val=0.884¬±0.135
  Hyperparams: tau2=1.463e-01  sigma2=9.250e-01  lambda=4.688e-02


Training:  40%|‚ñà‚ñà‚ñà‚ñà      | 400/1000 [25:39<3:30:33, 21.06s/it]


[400/1000]
  Bayesian:      loss=0.963  train=0.891¬±0.113  val=0.823¬±0.157
  Deterministic: loss=0.770  train=0.940¬±0.063  val=0.878¬±0.133
  Hyperparams: tau2=1.467e-01  sigma2=9.227e-01  lambda=4.680e-02


Training:  42%|‚ñà‚ñà‚ñà‚ñà‚ñè     | 420/1000 [26:56<3:23:24, 21.04s/it]


[420/1000]
  Bayesian:      loss=0.968  train=0.892¬±0.115  val=0.829¬±0.161
  Deterministic: loss=0.768  train=0.937¬±0.071  val=0.882¬±0.133
  Hyperparams: tau2=1.470e-01  sigma2=9.207e-01  lambda=4.673e-02


Training:  44%|‚ñà‚ñà‚ñà‚ñà‚ñç     | 440/1000 [28:13<3:17:45, 21.19s/it]


[440/1000]
  Bayesian:      loss=0.973  train=0.893¬±0.111  val=0.828¬±0.160
  Deterministic: loss=0.791  train=0.944¬±0.072  val=0.889¬±0.128
  Hyperparams: tau2=1.473e-01  sigma2=9.189e-01  lambda=4.667e-02


Training:  46%|‚ñà‚ñà‚ñà‚ñà‚ñå     | 460/1000 [29:30<3:08:45, 20.97s/it]


[460/1000]
  Bayesian:      loss=0.956  train=0.897¬±0.111  val=0.833¬±0.158
  Deterministic: loss=0.755  train=0.937¬±0.076  val=0.885¬±0.132
  Hyperparams: tau2=1.476e-01  sigma2=9.172e-01  lambda=4.660e-02


Training:  48%|‚ñà‚ñà‚ñà‚ñà‚ñä     | 480/1000 [30:47<3:03:28, 21.17s/it]


[480/1000]
  Bayesian:      loss=0.958  train=0.896¬±0.121  val=0.832¬±0.157
  Deterministic: loss=0.732  train=0.940¬±0.098  val=0.888¬±0.132
  Hyperparams: tau2=1.478e-01  sigma2=9.156e-01  lambda=4.652e-02


Training:  50%|‚ñà‚ñà‚ñà‚ñà‚ñà     | 500/1000 [32:03<2:54:18, 20.92s/it]


[500/1000]
  Bayesian:      loss=0.958  train=0.904¬±0.107  val=0.830¬±0.157
  Deterministic: loss=0.754  train=0.943¬±0.071  val=0.887¬±0.130
  Hyperparams: tau2=1.480e-01  sigma2=9.143e-01  lambda=4.645e-02


Training:  52%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñè    | 520/1000 [33:21<2:50:40, 21.33s/it]


[520/1000]
  Bayesian:      loss=0.960  train=0.908¬±0.107  val=0.832¬±0.157
  Deterministic: loss=0.769  train=0.943¬±0.067  val=0.887¬±0.136
  Hyperparams: tau2=1.482e-01  sigma2=9.131e-01  lambda=4.641e-02


Training:  54%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñç    | 540/1000 [34:37<2:40:17, 20.91s/it]


[540/1000]
  Bayesian:      loss=0.955  train=0.896¬±0.106  val=0.821¬±0.161
  Deterministic: loss=0.764  train=0.948¬±0.066  val=0.879¬±0.134
  Hyperparams: tau2=1.484e-01  sigma2=9.120e-01  lambda=4.637e-02


Training:  56%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñå    | 560/1000 [35:53<2:33:20, 20.91s/it]


[560/1000]
  Bayesian:      loss=0.953  train=0.905¬±0.110  val=0.828¬±0.159
  Deterministic: loss=0.740  train=0.955¬±0.063  val=0.888¬±0.126
  Hyperparams: tau2=1.485e-01  sigma2=9.111e-01  lambda=4.633e-02


Training:  58%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñä    | 580/1000 [37:10<2:26:52, 20.98s/it]


[580/1000]
  Bayesian:      loss=0.945  train=0.905¬±0.106  val=0.825¬±0.160
  Deterministic: loss=0.726  train=0.949¬±0.070  val=0.885¬±0.129
  Hyperparams: tau2=1.487e-01  sigma2=9.103e-01  lambda=4.630e-02


Training:  60%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà    | 600/1000 [38:27<2:20:34, 21.09s/it]


[600/1000]
  Bayesian:      loss=0.937  train=0.907¬±0.102  val=0.833¬±0.158
  Deterministic: loss=0.711  train=0.945¬±0.064  val=0.884¬±0.135
  Hyperparams: tau2=1.488e-01  sigma2=9.095e-01  lambda=4.627e-02


Training:  62%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñè   | 620/1000 [39:43<2:12:47, 20.97s/it]


[620/1000]
  Bayesian:      loss=0.945  train=0.911¬±0.102  val=0.823¬±0.160
  Deterministic: loss=0.733  train=0.951¬±0.068  val=0.888¬±0.133
  Hyperparams: tau2=1.489e-01  sigma2=9.090e-01  lambda=4.625e-02


Training:  64%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñç   | 640/1000 [41:00<2:06:29, 21.08s/it]


[640/1000]
  Bayesian:      loss=0.938  train=0.905¬±0.110  val=0.836¬±0.158
  Deterministic: loss=0.714  train=0.939¬±0.072  val=0.887¬±0.134
  Hyperparams: tau2=1.490e-01  sigma2=9.085e-01  lambda=4.624e-02


Training:  66%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñå   | 660/1000 [42:17<2:00:25, 21.25s/it]


[660/1000]
  Bayesian:      loss=0.960  train=0.891¬±0.112  val=0.820¬±0.161
  Deterministic: loss=0.777  train=0.940¬±0.067  val=0.879¬±0.136
  Hyperparams: tau2=1.490e-01  sigma2=9.082e-01  lambda=4.622e-02


Training:  68%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñä   | 680/1000 [43:34<1:53:20, 21.25s/it]


[680/1000]
  Bayesian:      loss=0.958  train=0.885¬±0.112  val=0.829¬±0.157
  Deterministic: loss=0.768  train=0.944¬±0.066  val=0.877¬±0.138
  Hyperparams: tau2=1.491e-01  sigma2=9.079e-01  lambda=4.621e-02


Training:  70%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà   | 700/1000 [44:51<1:45:47, 21.16s/it]


[700/1000]
  Bayesian:      loss=0.958  train=0.897¬±0.111  val=0.816¬±0.161
  Deterministic: loss=0.751  train=0.947¬±0.067  val=0.882¬±0.133
  Hyperparams: tau2=1.491e-01  sigma2=9.077e-01  lambda=4.621e-02


Training:  72%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñè  | 719/1000 [46:07<18:01,  3.85s/it]


[720/1000]
  Bayesian:      loss=0.942  train=0.907¬±0.110  val=0.843¬±0.157
  Deterministic: loss=0.726  train=0.944¬±0.070  val=0.891¬±0.135
  Hyperparams: tau2=1.491e-01  sigma2=9.076e-01  lambda=4.620e-02

Early stopping triggered.

Best Validation Accuracy - Bayesian: 0.8669, Deterministic: 0.9084

‚úì Loaded best joint model for analysis.

TEST EVALUATION





Bayesian Test Accuracy: 0.878 ¬± 0.131
Deterministic Test Accuracy: 0.919 ¬± 0.106

Collecting predictions for analysis...
Calculating analysis metrics...

üöÄ RUNNING VALIDATED BAYESIAN METAOPTNET ANALYSIS...
üîç ANALYZING BAYESIAN METAOPTNET DECISIONS
üìä Collecting model predictions...


100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 100/100 [00:17<00:00,  5.77it/s]



üß¨ K-MER PATTERNS: CORRECT vs INCORRECT PREDICTIONS (VALIDATED)
üìä Analyzing 13 classes with both correct and incorrect predictions

üíä APH2-DPRIME Predictions:
----------------------------------------
‚úÖ VALIDATED k-mers (specificity > 0.8):
   ‚Ä¢ AAAAAA: 2914 occurrences, specificity: 0.833
   ‚Ä¢ TTTTTT: 2305 occurrences, specificity: 0.905
   ‚Ä¢ AATAAA: 1912 occurrences, specificity: 0.866
   ‚Ä¢ ATATAT: 1845 occurrences, specificity: 0.831
   ‚Ä¢ ATAAAT: 1826 occurrences, specificity: 0.874
   ‚Ä¢ AAATAA: 1793 occurrences, specificity: 0.870
   ‚Ä¢ AAGAAA: 1735 occurrences, specificity: 0.848
   ‚Ä¢ TAAATA: 1729 occurrences, specificity: 0.867
   ‚Ä¢ AAAGAA: 1658 occurrences, specificity: 0.847
   ‚Ä¢ TATATA: 1597 occurrences, specificity: 0.826
   ‚Ä¢ AGAAAG: 1537 occurrences, specificity: 0.851
   ‚Ä¢ GAAAGA: 1507 occurrences, specificity: 0.847
   ‚Ä¢ GAATGA: 1327 occurrences, specificity: 0.884
   ‚Ä¢ AATGAA: 1321 occurrences, specificity: 0.884
   ‚Ä¢ TGAATG: 1319 o