In [None]:

# ---- Configuration ----
# Forward slashes are accepted by both Windows and POSIX file APIs; backslash-separated
# paths only resolve on Windows (on Linux/macOS "\" is an ordinary filename character).
CKPT_PATH = "models/singletype/CHG.pth"
# NOTE(review): the checkpoint is the CHG model but the test set is CHH -- confirm this
# cross-context evaluation is intentional and not a typo.
TEST_CSV  = "data/test_CHH.csv"

BATCH_SIZE = 512
NUM_WORKERS = 0
# Signal preprocessing used by _collate_cls_13mer:
#   "none"   -> keep each base's raw signal window unchanged;
#   "first"  -> replace the window with its first sample;
#   "center" -> replace the window with its middle sample;
#   "mean"   -> replace the window with its mean (also the fallback for any other value).
# The chosen scalar is broadcast back across the whole window.
SIG_SCALAR_MODE = "none"

import os
import numpy as np
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from tqdm.auto import tqdm
import pandas as pd

from scripts.dataLoader import (
    load_dataset, make_data, MyDataSet, encode_seq_13mer
)
from moduls import MethyNano

def _nz(x: torch.Tensor) -> torch.Tensor:
  
    return torch.nan_to_num(x, nan=0.0, posinf=0.0, neginf=0.0)

def _collate_cls_13mer(batch):
    """Collate ``(seq, nano, label)`` samples into model-ready batch tensors.

    For every sample, ``nano`` is cast to float32 and split column-wise: the first
    3 columns are per-base statistics and the remaining columns are the per-base
    signal window. If the module-level ``SIG_SCALAR_MODE`` is not ``"none"``, each
    base's signal window is collapsed to one scalar -- the ``"first"`` sample, the
    ``"center"`` sample, or the window mean (``"mean"`` and any unrecognised mode) --
    which is then broadcast back to the full window width. Finally, non-finite
    values are zeroed with ``_nz``.

    Args:
        batch: iterable of ``(s, n, y)`` tuples; ``s`` is passed to
            ``encode_seq_13mer``, ``n`` is array-like (expected ``[13, 103]`` --
            TODO confirm against the dataset), ``y`` is convertible to ``int``.

    Returns:
        Tuple ``(seq_ids, sig, stats, labels)``: ``seq_ids`` long, ``sig`` and
        ``stats`` float32, ``labels`` long ``[B]``. For the expected ``[13, 103]``
        input the shapes are ``[B, 13]``, ``[B, 13, 100]``, ``[B, 13, 3]``.
    """
    seq_ids, sig, stats, labels = [], [], [], []
    for s, n, y in batch:
        n = torch.as_tensor(n, dtype=torch.float32)  # expected [13, 103]

        st = n[:, :3]   # per-base statistics
        sg = n[:, 3:]   # per-base signal window, width W = n.size(-1) - 3

        if SIG_SCALAR_MODE != "none":
            width = sg.size(-1)
            if SIG_SCALAR_MODE == "first":
                v = sg[:, 0]
            elif SIG_SCALAR_MODE == "center":
                # Middle sample of the window (index 50 for the usual W = 100);
                # derived from the width instead of hard-coding 50.
                v = sg[:, width // 2]
            else:
                # "mean" and any unrecognised mode both use the window mean.
                v = sg.mean(dim=-1)
            sg = v.unsqueeze(-1).expand(-1, width)  # [13] -> [13, W]

        # NOTE(review): sanitising happens *after* the reduction, so in "mean" mode a
        # single NaN sample zeroes that base's whole signal. Kept as-is so inputs match
        # whatever preprocessing the checkpoint was trained with -- confirm.
        st = _nz(st)
        sg = _nz(sg)

        seq_ids.append(torch.tensor(encode_seq_13mer(s), dtype=torch.long))
        stats.append(st)
        sig.append(sg)
        labels.append(int(y))

    seq_ids = torch.stack(seq_ids, 0)                 # [B, 13]
    stats   = torch.stack(stats,   0)                 # [B, 13, 3]
    sig     = torch.stack(sig,     0)                 # [B, 13, W]
    labels  = torch.tensor(labels, dtype=torch.long)  # [B]
    return seq_ids, sig, stats, labels

def build_test_loader(test_csv, batch_size=256, num_workers=4):
    """Build a non-shuffled evaluation DataLoader over ``test_csv``.

    The CSV is read with ``load_dataset(..., feature_mode="both", mask=-1)``,
    turned into ``(seq, nano, label)`` by ``make_data``, wrapped in ``MyDataSet``,
    and batched with ``_collate_cls_13mer``.

    Args:
        test_csv: path to the test CSV file.
        batch_size: samples per batch.
        num_workers: DataLoader worker processes (0 = load in the main process).

    Returns:
        ``(loader, seq, label)``: the DataLoader plus the ``seq`` and ``label``
        objects produced by ``make_data``. Samples are neither shuffled nor dropped,
        so predictions presumably line up index-for-index with ``seq`` (assuming
        ``MyDataSet[i]`` corresponds to ``seq[i]`` -- verify).
    """
    test_data = load_dataset(test_csv, feature_mode="both", mask=-1)
    seq, nano, label = make_data(test_data)
    ds = MyDataSet(seq, nano, label)
    loader = DataLoader(
        ds,
        batch_size=batch_size,
        shuffle=False,                          # keep input order for alignment
        num_workers=num_workers,
        # Pinned host memory only speeds up host->GPU copies; on a CPU-only host
        # PyTorch warns about it and it just wastes page-locked memory.
        pin_memory=torch.cuda.is_available(),
        persistent_workers=(num_workers > 0),   # only valid with worker processes
        collate_fn=_collate_cls_13mer,
        drop_last=False,                        # evaluate every sample
    )
    return loader, seq, label

# ---- Device, model and checkpoint ----
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

# These hyper-parameters must match the ones the checkpoint was trained with.
model = MethyNano(
    with_projection=False,
    with_classification=True,
    dimension=256,
    n_heads=8,
    dropout=0.1,
    base_sig=160,
).to(device)

print("Loading checkpoint from:", CKPT_PATH)
# NOTE(review): torch.load unpickles the file -- only load checkpoints from a trusted source.
ckpt = torch.load(CKPT_PATH, map_location="cpu")
# Accept either a checkpoint dict holding the weights under "model" or a bare state_dict.
state = ckpt.get("model", ckpt)
missing, unexpected = model.load_state_dict(state, strict=False)
print(f"missing keys: {len(missing)}, unexpected keys: {len(unexpected)}")
# strict=False would otherwise hide a partial load (layers silently left at their
# random initialisation), so name the offending keys.
for _kind, _keys in (("missing", missing), ("unexpected", unexpected)):
    if _keys:
        _shown = ", ".join(_keys[:10])
        _extra = f" ... (+{len(_keys) - 10} more)" if len(_keys) > 10 else ""
        print(f"  {_kind}: {_shown}{_extra}")

model.eval()  # inference mode for layers such as dropout

test_loader, test_seq_list, test_label_tensor = build_test_loader(
    TEST_CSV, batch_size=BATCH_SIZE, num_workers=NUM_WORKERS
)

# ---- Batched inference: gather logits, class-1 probabilities and labels on the CPU ----
all_probs, all_logits, all_labels = [], [], []

with torch.no_grad():
    pbar = tqdm(test_loader, desc="Predict", dynamic_ncols=True)
    for seq_ids, sig, stats, labels in pbar:
        # Re-sanitise the float inputs and move the whole batch to the model device.
        sig, stats = _nz(sig).to(device), _nz(stats).to(device)
        seq_ids, labels = seq_ids.to(device), labels.to(device)

        out = model(sig, seq_ids, stats)
        logits = out["logits"]                   # [B, 2]
        probs = logits.softmax(dim=-1)[:, 1]     # softmax probability of class 1

        for bucket, batch_tensor in (
            (all_logits, logits),
            (all_probs, probs),
            (all_labels, labels),
        ):
            bucket.append(batch_tensor.cpu())

# Concatenate per-batch results: logits [N, 2], probs [N], labels [N].
all_logits, all_probs, all_labels = (
    torch.cat(chunks, dim=0) for chunks in (all_logits, all_probs, all_labels)
)

# Hard label: class 1 whenever its probability is at least 0.5.
preds = all_probs.ge(0.5).long()

# ---- Sanity summary: class balance of ground truth vs predictions, then accuracy ----
true_pos_ratio = float(all_labels.eq(1).float().mean())
pred_pos_ratio = float(preds.eq(1).float().mean())

print("Total samples:", all_labels.numel())
print("Positive ratio (true):", true_pos_ratio)
print("Positive ratio (pred):", pred_pos_ratio)

acc = float(preds.eq(all_labels).float().mean())
print(f"Test Accuracy: {acc:.4f}")

# ---- Test metrics: threshold-0.5 confusion-matrix metrics plus AUROC / AUPRC ----
y_true = all_labels.numpy().astype(int)
y_pred = preds.numpy().astype(int)
y_score = all_probs.numpy().astype(float)

# Confusion-matrix counts for the 0.5-threshold predictions.
tp = int(((y_pred == 1) & (y_true == 1)).sum())
tn = int(((y_pred == 0) & (y_true == 0)).sum())
fp = int(((y_pred == 1) & (y_true == 0)).sum())
fn = int(((y_pred == 0) & (y_true == 1)).sum())

# max(1, ...) / max(1e-12, ...) guard against division by zero on empty or
# single-class test sets.
acc = (tp + tn) / max(1, tp + tn + fp + fn)
precision = tp / max(1, tp + fp)
recall    = tp / max(1, tp + fn)
f1        = 2 * precision * recall / max(1e-12, precision + recall)

# np.trapz was renamed np.trapezoid in NumPy 2.0 (the old name is deprecated).
_trapezoid = getattr(np, "trapezoid", None) or np.trapz

# Rank samples by descending score; a stable sort keeps the result reproducible.
order = np.argsort(-y_score, kind="mergesort")
y_sorted = y_true[order]
score_sorted = y_score[order]

P = max(1, int((y_true == 1).sum()))
N = max(1, int((y_true == 0).sum()))

# One curve point per DISTINCT score. No threshold can separate samples with equal
# scores, so emitting a point per sample made AUROC/AUPRC depend on the arbitrary
# order inside tie groups, which could inflate them. Keep only the last index of each run of tied scores.
if y_sorted.size:
    thr_idx = np.r_[np.flatnonzero(np.diff(score_sorted)), y_sorted.size - 1]
else:
    thr_idx = np.zeros(0, dtype=np.intp)

tp_c = np.cumsum(y_sorted == 1)[thr_idx]
fp_c = np.cumsum(y_sorted == 0)[thr_idx]

tpr = tp_c / P
fpr = fp_c / N

# ROC curve anchored at (0, 0) and (1, 1), integrated with the trapezoidal rule.
auroc = float(_trapezoid(
    np.concatenate([[0.0], tpr, [1.0]]),
    np.concatenate([[0.0], fpr, [1.0]]),
))

prec_curve = tp_c / np.maximum(1, tp_c + fp_c)
rec_curve  = tp_c / P

# PR curve anchored at (recall 0, precision 1) and extended to recall 1 with the
# last precision, integrated with the trapezoidal rule.
auprc = float(_trapezoid(
    np.concatenate([[1.0], prec_curve, [prec_curve[-1] if prec_curve.size else 1.0]]),
    np.concatenate([[0.0], rec_curve, [1.0]]),
))

print("==== Test metrics ====")
print(f"ACC    : {acc:.4f}")
print(f"PREC   : {precision:.4f}")
print(f"RECALL : {recall:.4f}")
print(f"F1     : {f1:.4f}")
print(f"AUROC  : {auroc:.4f}")
print(f"AUPRC  : {auprc:.4f}")
print("======================")


# ---- Per-sample prediction table: notebook preview + CSV export ----
seq_col, true_col, pred_col, prob_col = (
    list(test_seq_list),
    all_labels.numpy().astype(int).tolist(),
    preds.numpy().astype(int).tolist(),
    all_probs.numpy().tolist(),
)

columns = {
    "seq_13mer": seq_col,
    "label_true": true_col,
    "label_pred": pred_col,
    "prob_pos": prob_col,
}
df_out = pd.DataFrame(columns)

# `display` is not imported here; presumably supplied by the IPython/Jupyter kernel.
display(df_out.head(10))

out_path = "test_predictions_methynano.csv"
df_out.to_csv(out_path, index=False)
print("Saved prediction table to:", out_path)
