In [1]:
%load_ext autoreload
# Re-import edited modules (e.g. the project modules under ../genev2/) before each cell runs.
%autoreload 2

In [2]:
import copy
import itertools
import os
import pickle
import random
import sys
from collections import Counter
from datetime import datetime
from glob import glob
from os.path import join
from typing import Optional, List, Dict, Union

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from tqdm import tqdm

sys.path.append("../genev2/")
from ProteinEmbedding import ProteinEmbeddings
from GraphEmbeddings import GraphEmbeddings
from DNAEmbeddings import DNAEmbeddings
from IsolateDB import IsolateDB
from graph import build_graph_from_edges, laplacian_eigenvectors
from Dataset import ViromeDataset, NoisySubset
from Transformer import TransformerClassifier
from Loss import IntegratedMaskedCrossEntropyLoss
from Train import train_val_loop

import h5py
import sqlite3
import faiss
import cudf
import cupy as cp
from cuml.cluster import AgglomerativeClustering, HDBSCAN
from cuml import KMeans
from cuml.manifold import UMAP
from cuml.decomposition import PCA
from cuml.preprocessing import StandardScaler

import math
import torch
import wandb
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import Subset
from torch.optim import AdamW
from torch.utils.data import DataLoader
from transformers import get_linear_schedule_with_warmup, get_cosine_schedule_with_warmup
from sklearn.metrics import precision_score, recall_score, average_precision_score

### Load file paths

In [3]:
# Every input lives under one project data directory; each path below is that
# directory joined with a file name.
DATA_DIR = "/projects/m000151/khoa/repos/prophage/data"

cluster_path               = join(DATA_DIR, "cluster_map.csv")
embeddings_path            = join(DATA_DIR, "protein_embeddings.h5")
virus_path                 = join(DATA_DIR, "virome_db.csv")
genome_path                = join(DATA_DIR, "genome_db_new.csv")
isolate_path               = join(DATA_DIR, "isolate_db_full.sqlite")
vclust_path                = join(DATA_DIR, "cluster_ani30_qcov50_rcov50_set_cover.tsv")
edges_path                 = join(DATA_DIR, "ecoli_ppi_edges.txt")
ecoli_embeddings           = join(DATA_DIR, "ecoli.esmc_embeddings.pkl")
dna_embeddings             = join(DATA_DIR, "dna_features.h5")
dna_embeddings_path        = join(DATA_DIR, "compressed_dnafeat.h5")
protein_embeddings_path    = join(DATA_DIR, "compressed_protein_embeddings.h5")
protein_embeddings_og_path = join(DATA_DIR, "protein_embeddings.h5")
graph_embeddings_path      = join(DATA_DIR, "graph_embeddings.h5")
graph_protein_cluster_map  = join(DATA_DIR, "protein_cluster_graph_map.npy")

### Prepare embeddings and datasets


In [4]:
# Feature stores that are handed to ViromeDataset further down.
proteinEmbeddings = ProteinEmbeddings(embeddings_path=protein_embeddings_path, cluster_map_path=cluster_path)
dnaEmbeddings = DNAEmbeddings(embeddings_path=dna_embeddings_path)

# The .npy file holds a pickled Python object in a single-element array; `.item()` unwraps it.
# allow_pickle=True can execute code from the file, so only load it from a trusted location.
graph_tokens = np.load(graph_protein_cluster_map, allow_pickle=True).item()
graphEmbeddings = GraphEmbeddings(graph_embeddings_path, cluster_path, graph_tokens)

# Isolate database backed by the .sqlite file at `isolate_path`.
db = IsolateDB(isolate_path)

In [None]:
# Build the full dataset, then split it into train / val / test subsets.
full_dataset = ViromeDataset(db, proteinEmbeddings, dnaEmbeddings, graphEmbeddings, max_sequence_length=10_000)

# `train_indices`, `val_indices` and `test_indices` are not created anywhere in this
# notebook, so on a fresh kernel this cell would raise NameError. Reuse them if an
# earlier step (e.g. a cluster-aware split built from `vclust_path`) defined them;
# otherwise fall back to a seeded random 80/10/10 split so Restart & Run All works
# and is reproducible.
# NOTE(review): a purely random split may place closely related sequences in both
# train and test — replace with the intended split if one exists.
if not all(name in globals() for name in ("train_indices", "val_indices", "test_indices")):
    SPLIT_SEED = 42
    _shuffled = np.random.default_rng(SPLIT_SEED).permutation(len(full_dataset)).tolist()
    _n_holdout = len(_shuffled) // 10
    val_indices   = _shuffled[:_n_holdout]
    test_indices  = _shuffled[_n_holdout:2 * _n_holdout]
    train_indices = _shuffled[2 * _n_holdout:]
    print(f"WARNING: no predefined split found; using a seeded random split "
          f"(train={len(train_indices)}, val={len(val_indices)}, test={len(test_indices)})")

train_dataset = NoisySubset(full_dataset, train_indices)
val_dataset   = Subset(full_dataset, val_indices)
test_dataset  = Subset(full_dataset, test_indices)

In [None]:
# 1) Hyper-parameters
NUM_EPOCHS    = 40   # NOTE(review): not read anywhere in this notebook; training uses wandb.config.epochs
BATCH_SIZE    = 16
MAX_SEQ_LEN   = 10_000
GRAD_CLIP     = 1.0  # NOTE(review): not read anywhere in this notebook; training uses wandb.config.max_grad_norm
SEED          = 42
DEVICE        = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


def make_loader(dataset, shuffle: bool = False, generator: Optional[torch.Generator] = None) -> DataLoader:
    """Build a single-process DataLoader with the notebook-wide BATCH_SIZE.

    Args:
        dataset: the (sub)dataset to iterate over.
        shuffle: reshuffle sample order every epoch.
        generator: RNG used for shuffling; pass a seeded one for reproducible order.

    Memory pinning is only requested when DEVICE is CUDA; it speeds up host-to-device
    copies there and has no benefit on a CPU-only machine.
    """
    return DataLoader(
        dataset,
        batch_size=BATCH_SIZE,
        shuffle=shuffle,
        generator=generator,
        num_workers=0,
        pin_memory=DEVICE.type == "cuda",
    )


# Shuffle the training set every epoch — without it each epoch sees the identical
# batch order. A seeded generator keeps that order reproducible across runs.
train_loader = make_loader(train_dataset, shuffle=True, generator=torch.Generator().manual_seed(SEED))
val_loader   = make_loader(val_dataset)
test_loader  = make_loader(test_dataset)

### Train loop

In [None]:
# Start a W&B run; every hyper-parameter below is read back from `wandb.config`.
wandb.init(
    project="virome-transformer",
    name="model_integrated_loss",
    config={
        "track_embed_dim": 8,
        "strand_embed_dim": 2,
        "homologs_embed_dim": 1,
        "embed_dim": 256,
        "num_heads": 4,
        "ffn_dim": 256,
        "num_layers": 2,
        "dropout": 0.2,
        "lr": 1e-4,
        "batch_size": train_loader.batch_size,
        "epochs": 30,
        "max_grad_norm": 0.1,
        "warmup_ratio": 0.05,
        "protein_embed_dim_shrink": 128,
        "dna_embed_dim_shrink": 128,
        "graph_embed_dim_shrink": 16,
        "unk_bias_ratio": 0.95,
    },
)

# 2) Build model, loss, optimizer
model = TransformerClassifier(
    type_vocab_size    = len(full_dataset.type_categories),
    biotype_vocab_size = len(full_dataset.biotype_categories),
    strand_vocab_size  = 2,
    track_embed_dim    = wandb.config.track_embed_dim,
    strand_embed_dim   = wandb.config.strand_embed_dim,
    homologs_embed_dim = wandb.config.homologs_embed_dim,
    protein_embed_dim  = proteinEmbeddings.embedding_dim,
    dna_embed_dim      = dnaEmbeddings.embedding_dim,
    graph_embed_dim    = graphEmbeddings.embedding_dim,
    protein_embed_dim_shrink  = wandb.config.protein_embed_dim_shrink,
    dna_embed_dim_shrink      = wandb.config.dna_embed_dim_shrink,
    graph_embed_dim_shrink    = wandb.config.graph_embed_dim_shrink,
    embed_dim          = wandb.config.embed_dim,
    num_heads          = wandb.config.num_heads,
    ffn_dim            = wandb.config.ffn_dim,
    num_layers         = wandb.config.num_layers,
    dropout            = wandb.config.dropout,
    max_seq_len        = MAX_SEQ_LEN,   # same 10_000 limit used when building the dataset
).to(DEVICE)

loss_fn = IntegratedMaskedCrossEntropyLoss(unk_bias_ratio=wandb.config.unk_bias_ratio)

# Weight decay applies to weight matrices only. Biases and normalization weights are
# 1-D, so they are selected by dimensionality (plus any name containing "bias"): the
# substring filter "LayerNorm.weight" only matches norm layers whose attribute is
# literally named `LayerNorm` — torch.nn.TransformerEncoderLayer, for instance, calls
# them `norm1` / `norm2`, so their weights would silently be decayed.
decay_params, no_decay_params = [], []
for param_name, param in model.named_parameters():
    if param.ndim < 2 or "bias" in param_name:
        no_decay_params.append(param)
    else:
        decay_params.append(param)
params = [
    {"params": decay_params,    "weight_decay": 1e-5},  # weight matrices / embeddings
    {"params": no_decay_params, "weight_decay": 0.0},   # biases & norm layers
]
optimizer = torch.optim.AdamW(params, lr=wandb.config.lr)

# 3) Linear warm-up followed by a half-cycle cosine decay over all optimizer steps.
total_steps = len(train_loader) * wandb.config.epochs
warmup_steps = int(total_steps * wandb.config.warmup_ratio)
scheduler = get_cosine_schedule_with_warmup(
    optimizer,
    num_warmup_steps=warmup_steps,
    num_training_steps=total_steps,
    num_cycles=0.5,
)

# Checkpoints go to a per-run directory named after the architecture and start time.
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
file_name = f"model_d{wandb.config.embed_dim}_h{wandb.config.num_heads}_l{wandb.config.num_layers}_integrated_losss_{timestamp}"
output_dir = f"/projects/m000151/khoa/repos/prophage/outs/ckpts/{file_name}"
os.makedirs(output_dir, exist_ok=True)

# 4) Train/validate with W&B logging, gradient clipping and the warm-up scheduler.
# On an error, close the W&B run as failed instead of leaving it open; a manual
# KeyboardInterrupt is not caught so the run stays usable for inspection.
try:
    history = train_val_loop(
        model,
        train_loader,
        val_loader,
        optimizer,
        loss_fn,
        DEVICE,
        num_epochs=wandb.config.epochs,
        checkpoint_prefix=f"{output_dir}/epoch_",
        use_wandb=True,
        max_grad_norm=wandb.config.max_grad_norm,
        scheduler=scheduler,
    )
except Exception:
    wandb.finish(exit_code=1)
    raise


### Evaluate model

In [None]:
# Keep a handle on the freshly trained model, then close the W&B run.
# NOTE(review): in recent wandb versions `wandb.finish()` resets `wandb.config`, so later
# cells should not read hyper-parameters from it.
best_model = model
wandb.finish()

In [None]:
# Rebuild the architecture and load trained weights from a checkpoint.
# Cloning the in-memory `model` reproduces its exact hyper-parameters without touching
# `wandb.config`, which is no longer readable after `wandb.finish()` has run.
CHECKPOINT_PATH = "model_d256_h4_l2_integrated_losss_30.pt"  # NOTE(review): not under `output_dir`; confirm this is the intended run

best_model = copy.deepcopy(model).to(DEVICE)
# map_location lets a checkpoint saved on GPU be restored on whatever DEVICE is available.
checkpoint = torch.load(CHECKPOINT_PATH, map_location=DEVICE)
best_model.load_state_dict(checkpoint["model_state_dict"])
print(f"Loaded weights from {CHECKPOINT_PATH}")

In [None]:
def _plot_class_metrics(train, val, epochs, start, suffix, title):
    """Draw one 2x2 figure (loss, PR-AUC, precision/recall, residuals) for the metrics keyed `*_{suffix}`."""
    x = epochs[start:]

    def tail(split, key):
        # Per-epoch values of `key` from `start` onwards, as an array.
        return np.asarray(split[key])[start:]

    fig, axes = plt.subplots(2, 2, figsize=(14, 10))
    fig.suptitle(title, fontsize=16)

    # Loss (shared by both classes)
    sns.lineplot(x=x, y=tail(train, 'loss'), ax=axes[0, 0], label='Train Loss')
    sns.lineplot(x=x, y=tail(val, 'loss'),   ax=axes[0, 0], label='Val Loss')
    axes[0, 0].set_title('Loss over Epochs')

    # PR-AUC
    sns.lineplot(x=x, y=tail(train, f'ap_{suffix}'), ax=axes[0, 1], label='Train PR-AUC')
    sns.lineplot(x=x, y=tail(val, f'ap_{suffix}'),   ax=axes[0, 1], label='Val PR-AUC')
    axes[0, 1].set_title('PR-AUC')

    # Precision & recall
    sns.lineplot(x=x, y=tail(train, f'prec_{suffix}'), ax=axes[1, 0], label='Train Precision')
    sns.lineplot(x=x, y=tail(val, f'prec_{suffix}'),   ax=axes[1, 0], label='Val Precision')
    sns.lineplot(x=x, y=tail(train, f'rec_{suffix}'),  ax=axes[1, 0], label='Train Recall')
    sns.lineplot(x=x, y=tail(val, f'rec_{suffix}'),    ax=axes[1, 0], label='Val Recall')
    axes[1, 0].set_title('Precision & Recall')

    # Residuals: pred_* minus true_* per epoch
    sns.lineplot(x=x, y=tail(train, f'pred_{suffix}') - tail(train, f'true_{suffix}'),
                 ax=axes[1, 1], label='Train Residuals')
    sns.lineplot(x=x, y=tail(val, f'pred_{suffix}') - tail(val, f'true_{suffix}'),
                 ax=axes[1, 1], label='Val Residuals')
    axes[1, 1].set_title('Residuals')

    fig.tight_layout()
    plt.show()


def plot_history(history, start: int = 0):
    """Plot per-epoch training/validation curves for the bacteria and virus classes.

    Draws two 2x2 figures — "Bacteria Metrics" from the `*_neg` entries and
    "Virus Metrics" from the `*_pos` entries — each showing loss, PR-AUC,
    precision & recall, and residuals (`pred_*` minus `true_*`).

    Args:
        history: dict with 'train' and 'val' entries (as returned by train_val_loop),
            each a dict of per-epoch lists keyed by 'loss', 'ap_neg'/'ap_pos',
            'prec_neg'/'prec_pos', 'rec_neg'/'rec_pos', 'pred_neg'/'pred_pos'
            and 'true_neg'/'true_pos'.
        start: index of the first epoch to plot (e.g. to skip noisy early epochs).
            Clamped so that at least the last epoch is always shown.
    """
    sns.set(style="whitegrid")
    train = history['train']
    val   = history['val']
    n_epochs = len(train['loss'])
    epochs = list(range(1, n_epochs + 1))

    # A fixed offset at or beyond the run length slices every series to nothing
    # (e.g. start=30 with a 30-epoch run), producing blank plots.
    start = max(0, min(start, n_epochs - 1))

    _plot_class_metrics(train, val, epochs, start, "neg", "Bacteria Metrics")
    _plot_class_metrics(train, val, epochs, start, "pos", "Virus Metrics")

# Learning curves (bacteria and virus metrics) for the training run above.
plot_history(history=history)

In [None]:
def evaluate_on_test(model, test_loader, loss_fn, device="cuda"):
    """
    Run `model` over every batch of `test_loader` and collect per-token outputs.

    Works for any loader (train/val/test) whose batches carry the keys read below.
    The loss is computed only as a NaN sanity check.

    Args:
        model: classifier returning logits of shape (B, L, C); class 1 is "positive".
        test_loader: iterable of batch dicts (e.g. a DataLoader).
        loss_fn: callable (logits, labels, mask) -> scalar tensor.
        device: device the batch tensors are moved to; must match the model's device.

    Returns:
        probabilities: np.ndarray[float] of shape (N, L)  # P(class=1) per token
        true_labels:   np.ndarray        of shape (N, L)  # labels as stored in the batch
        masks:         np.ndarray[int]   of shape (N, L)  # `padding_mask` as 0/1; 1 marks padding
        record_ids:    list of length N with each sample's `record_id`
        For an empty loader the arrays have shape (0, 0) and record_ids is [].

    Raises:
        RuntimeError: if the loss is NaN for any batch.
        ValueError: if batches differ in sequence length (they could not be stacked).
    """
    model.eval()
    all_probs = []
    all_trues = []
    all_masks = []
    record_ids = []
    seq_len = None

    with torch.no_grad():
        for batch_idx, batch in enumerate(test_loader):
            # Move inputs to device
            inputs = {
                'type_ids':    batch['type_track'].to(device),
                'biotype_ids': batch['biotype_track'].to(device),
                'strand_ids':  batch['strand_track'].to(device),
                'homologs':    batch['homologs_track'].to(device),
                'protein_emb': batch['protein_embeddings'].to(device),
                'dna_emb':     batch['dna_embeddings'].to(device),
                'graph_emb':   batch['graph_embeddings'].to(device),
                'mask':        batch['padding_mask'].to(device)
            }
            labels = batch['labels'].to(device)  # (B, L)

            # Forward pass; the loss only guards against NaNs.
            logits = model(
                inputs['type_ids'], inputs['biotype_ids'], inputs['strand_ids'],
                inputs['homologs'], inputs['protein_emb'], inputs['dna_emb'],
                inputs['graph_emb'], mask=inputs['mask']
            )
            test_loss = loss_fn(logits, labels, inputs['mask'])
            if torch.isnan(test_loss):
                raise RuntimeError(f"NaN in test loss at batch {batch_idx}")

            # Softmax over classes and keep P(class=1): (B, L, C) -> (B, L).
            # .float() so half-precision logits still convert to numpy.
            prob_pos = F.softmax(logits, dim=-1)[:, :, 1].float()

            prob_np   = prob_pos.cpu().numpy()
            labels_np = labels.cpu().numpy()
            mask_np   = inputs['mask'].cpu().numpy().astype(int)

            # All batches must share one sequence length so they can be stacked.
            if seq_len is None:
                seq_len = prob_np.shape[1]
            elif prob_np.shape[1] != seq_len:
                raise ValueError(f"Unexpected seq length: got {prob_np.shape[1]}, expected {seq_len}")

            all_probs.append(prob_np)
            all_trues.append(labels_np)
            all_masks.append(mask_np)
            record_ids.extend(batch["record_id"])

    if not all_probs:
        # np.vstack([]) raises; return empty (0, 0) arrays instead.
        return (np.empty((0, 0)), np.empty((0, 0), dtype=int),
                np.empty((0, 0), dtype=int), record_ids)

    # Stack batches -> (N_total, L)
    probabilities = np.vstack(all_probs)
    true_labels   = np.vstack(all_trues)
    masks         = np.vstack(all_masks)

    return probabilities, true_labels, masks, record_ids




In [None]:

def visualize_token_predictions(predictions, true_labels, masks, record_ids,
                                num_samples=5, title="", figname="", sort_order = "descending"):
    """
    Overlay true vs. predicted token labels for up to `num_samples` sequences.

    mask == 1 marks tokens to ignore (padding); only mask == 0 positions are plotted,
    wherever they occur in the sequence. Samples are ranked by their count of valid
    tokens — longest first for sort_order="descending", shortest first for any other
    value — and each subplot spans exactly that sample's valid tokens.

    Args:
        predictions, true_labels, masks: arrays of shape (N, L).
        record_ids: per-sample identifiers shown in the subplot titles.
        num_samples: maximum number of sequences to draw.
        title: figure title.
        figname: if non-empty, the figure is also saved to f"{figname}.png".
        sort_order: "descending" (default) or any other value for ascending.
    """
    valid = np.asarray(masks) == 0
    valid_counts = valid.sum(axis=1)

    ranking = np.argsort(valid_counts)
    if sort_order == "descending":
        ranking = ranking[::-1]
    selected = ranking[:num_samples]
    if len(selected) == 0:
        # plt.subplots(0, 1) would raise; nothing to draw.
        print("No samples to plot.")
        return

    fig, axes = plt.subplots(len(selected), 1,
                             figsize=(12, 2 * len(selected)),
                             sharex=False, squeeze=False)
    axes = axes[:, 0]

    for ax, idx in zip(axes, selected):
        n_valid = int(valid_counts[idx])
        x = np.arange(n_valid)

        # Pick valid positions by mask instead of assuming padding is trailing.
        true_vals = np.asarray(true_labels[idx])[valid[idx]]
        pred_vals = np.asarray(predictions[idx])[valid[idx]]

        ax.plot(x, true_vals, label='True label',
                alpha=0.5, marker='o', linestyle='-')
        ax.plot(x, pred_vals, label='Predicted label',
                alpha=0.5, linestyle='-')

        # Avoid a degenerate (0, 0) or inverted x-range for 0/1-token samples.
        ax.set_xlim(0, max(n_valid - 1, 1))
        ax.set_ylim(-1.1, 1.1)
        ax.set_yticks([-1, 0, 1])
        ax.set_title(f"Sample {record_ids[idx]} (valid tokens={n_valid})")
        ax.legend(loc='upper right')

    fig.suptitle(title, fontsize=16)
    axes[-1].set_xlabel('Token position')
    fig.tight_layout(rect=[0, 0.03, 1, 0.95])
    if figname:
        fig.savefig(f"{figname}.png")
    plt.show()




In [None]:
# Score every split with the selected model. Use the configured DEVICE rather than a
# hard-coded "cuda", so this also runs on CPU-only machines.
# NOTE(review): train_dataset is a NoisySubset, so train-split scores are presumably
# computed on noised inputs — confirm that is intended.
test_preds,  test_trues,  test_masks,  test_record_ids  = evaluate_on_test(best_model, test_loader,  loss_fn, device=DEVICE)
val_preds,   val_trues,   val_masks,   val_record_ids   = evaluate_on_test(best_model, val_loader,   loss_fn, device=DEVICE)
train_preds, train_trues, train_masks, train_record_ids = evaluate_on_test(best_model, train_loader, loss_fn, device=DEVICE)

In [None]:
# Persist each split's per-token outputs to the working directory, tagged with the run name.
split_outputs = [
    ("test",  test_preds,  test_trues,  test_masks,  test_record_ids),
    ("train", train_preds, train_trues, train_masks, train_record_ids),
    ("val",   val_preds,   val_trues,   val_masks,   val_record_ids),
]
for split_name, split_preds, split_trues, split_masks, split_ids in split_outputs:
    np.savez(
        f"{split_name}_out_{file_name}.npz",
        prediction=split_preds,
        label=split_trues,
        mask=split_masks,
        record_id=split_ids,
    )


In [None]:
# Plot the per-token P(class=1) track of every long sequence whose peak probability
# crosses a threshold, one panel per sequence, and save the grid to disk.
# `preds` / `masks` / `record_ids` are not defined anywhere in this notebook (they only
# existed as leftover kernel state), so the test-split outputs are bound explicitly.
# NOTE(review): switch to val_* / train_* if a different split was intended.
MIN_REGION_TOKENS = 500   # sequences with this many valid tokens or fewer are skipped
MIN_PEAK_PROB     = 0.5   # sequences whose max P(class=1) stays below this are skipped
GRID_FIGNAME      = "grid_highprob_0602.png"

region_preds, region_masks, region_record_ids = test_preds, test_masks, test_record_ids

# 1) Gather the valid-token probability tracks that pass both criteria.
regions = []  # (probabilities over valid tokens, record_id)
for pred, mask, rid in zip(region_preds, region_masks, region_record_ids):
    seq = pred[mask == 0]   # mask == 1 marks padding
    if seq.size <= MIN_REGION_TOKENS or seq.max() < MIN_PEAK_PROB:
        continue
    regions.append((seq, rid))

n = len(regions)
if n == 0:
    print("No high-prob regions found.")
else:
    # 2) Near-square grid: ceil(sqrt(n)) columns and as many rows as needed.
    ncols = int(math.ceil(math.sqrt(n)))
    nrows = int(math.ceil(n / ncols))

    # 3) One small panel per region.
    fig, axes = plt.subplots(nrows, ncols, figsize=(4*ncols, 3*nrows), squeeze=False)
    for ax, (track, rid) in zip(axes.flat, regions):
        ax.plot(track)
        ax.set_title(rid, fontsize='small')
        ax.set_xlabel("Position")
        ax.set_ylabel("Probability")

    # 4) Hide the unused trailing panels.
    for ax in axes.flat[n:]:
        ax.set_visible(False)

    fig.tight_layout()
    fig.suptitle(f"{n} High-Probability Regions", y=1.02)
    # The grid is only written to disk; closing it keeps it out of the notebook output.
    fig.savefig(GRID_FIGNAME, bbox_inches='tight')
    plt.close(fig)

    print("Plotted", n, "regions in a", nrows, "×", ncols, "grid.")


In [None]:
# Validation predictions, shortest sequences first. The data here is the validation
# split, so the title says so (it previously read "Testset").
visualize_token_predictions(
    val_preds, val_trues, val_masks, val_record_ids,
    num_samples=100,
    title="Validation set: True vs Predicted Labels",
    figname=f"{file_name}_val_prediction_ascending",
    sort_order="ascending",
)

In [None]:
def _collect_probabilities_by_group(model, loader, device, num_batches):
    """Return a DataFrame (TrueGroup, ProbPos) with one row per token from the first `num_batches` batches."""
    model.eval()
    group_chunks, prob_chunks = [], []
    with torch.no_grad():
        # islice stops early (instead of raising StopIteration) when the loader is shorter.
        for batch in itertools.islice(loader, num_batches):
            inputs = {
                'type_ids':    batch['type_track'].to(device),
                'biotype_ids': batch['biotype_track'].to(device),
                'strand_ids':  batch['strand_track'].to(device),
                'homologs':    batch['homologs_track'].to(device),
                'protein_emb': batch['protein_embeddings'].to(device),
                'dna_emb':     batch['dna_embeddings'].to(device),
                'graph_emb':   batch['graph_embeddings'].to(device),
                'mask':        batch['padding_mask'].to(device)
            }
            labels = batch['labels'].to(device)

            logits = model(
                inputs['type_ids'], inputs['biotype_ids'], inputs['strand_ids'],
                inputs['homologs'], inputs['protein_emb'], inputs['dna_emb'], inputs['graph_emb'],
                mask=inputs['mask']
            )

            # Vectorized on CPU: one transfer per batch instead of a device sync per token
            # from calling .item() inside a Python loop.
            prob_pos    = F.softmax(logits, dim=-1)[..., 1].reshape(-1).float().cpu().numpy()
            flat_labels = labels.reshape(-1).cpu().numpy()
            is_padding  = inputs['mask'].reshape(-1).cpu().numpy().astype(bool)

            # Padding takes precedence over the label; otherwise 1 / 0 / anything else.
            group = np.where(
                is_padding, 'Padding',
                np.where(flat_labels == 1, 'Positive (1)',
                         np.where(flat_labels == 0, 'Negative (0)', 'Unknown (-1)')))
            group_chunks.append(group)
            prob_chunks.append(prob_pos)

    if not group_chunks:
        return pd.DataFrame({'TrueGroup': pd.Series(dtype=object), 'ProbPos': pd.Series(dtype=float)})
    return pd.DataFrame({'TrueGroup': np.concatenate(group_chunks),
                         'ProbPos': np.concatenate(prob_chunks)})


def plot_probability_boxplots(
    model,
    loader,
    device,
    num_batches: int = 1,
    title: str = ""
):
    """
    Strip-plot the predicted positive-class probability of every token, grouped by true label.

    Groups: Negative (0), Positive (1), Unknown (any other label), and Padding (tokens
    whose padding mask is set, regardless of label). Despite the function name, the
    plot is a seaborn strip plot with one translucent point per token.

    Args:
        model: trained TransformerClassifier
        loader: DataLoader
        device: torch device
        num_batches: how many leading batches to include (fewer if the loader is shorter)
        title: plot title; a default title is used when empty
    """
    df = _collect_probabilities_by_group(model, loader, device, num_batches)

    sns.set(style="whitegrid")
    fig, ax = plt.subplots(figsize=(8, 6))
    order = ['Negative (0)', 'Positive (1)', 'Unknown (-1)', 'Padding']
    sns.stripplot(x='TrueGroup', y='ProbPos', data=df, order=order, alpha=0.1, ax=ax)
    ax.set_title(title or 'Predicted Positive Probability by True Label Group')
    ax.set_xlabel('True Label Group')
    ax.set_ylabel('Predicted Probability (class=1)')
    ax.tick_params(axis='x', labelrotation=30)
    fig.tight_layout()
    plt.show()

In [None]:
# Spot-check how predicted probabilities separate by true label on a few training batches.
plot_probability_boxplots(
    model=best_model,
    loader=train_loader,
    device=DEVICE,
    num_batches=3,
    title="Train Probabilities",
)