In [12]:
# =========================
# 0. Imports & Setup
# =========================
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader, Sampler
from sklearn.metrics import roc_auc_score, average_precision_score
import pickle
import time
import torch.nn.functional as F
import os, re


# Reproducibility
RNG = 42
np.random.seed(RNG)
torch.manual_seed(RNG)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

Using device: cuda


In [None]:
# =========================
# 1. Load and Preparing datasets
# =========================
print("Starting: 1. Load and Prepare datasets")
reads_df = pd.read_csv("Raw File/dataset0.csv")

columns_to_drop = ['Unnamed: 0']
reads_df = reads_df.drop(columns_to_drop, axis=1, errors='ignore')

# sorry cherron i cannot with the colnames
reads_df = reads_df.rename(columns={
    'ID': 'transcript_id',
    'POS': 'transcript_position',
    'SEQ': '7mer'
})
reads_df['n_reads'] = reads_df.groupby(['transcript_id', 'transcript_position']).transform('size')


print("Ending: 1. Load and Prepare datasets")
# =========================
# 2. Preprocessing: 7-mer embedding
# =========================
print("Starting: 2. Preprocessing: 7-mer embedding")

def encode_drach_compact(seq):
    """
    Compact one-hot encoding of a 7-mer centered on a DRACH motif.
    Positions:
    - 0: full one-hot (A,C,G,T) â†’ 4 dims
    - 1: D (A,G,T) â†’ 3 dims
    - 2: R (A,G)   â†’ 2 dims
    - 3: A (fixed) â†’ 0 dims
    - 4: C (fixed) â†’ 0 dims
    - 5: H (A,C,T) â†’ 3 dims
    - 6: full one-hot (A,C,G,T) â†’ 4 dims
    Total: 16-dimensional vector
    """
    encoding = []

    base = seq[0]
    encoding.extend(one_hot_base(base, ['A', 'C', 'G', 'T']))

    base = seq[1]
    encoding.extend(one_hot_base(base, ['A', 'G', 'T']))  # D

    base = seq[2]
    encoding.extend(one_hot_base(base, ['A', 'G']))       # R

    # skip position 3 (always A)
    # skip position 4 (always C)

    base = seq[5]
    encoding.extend(one_hot_base(base, ['A', 'C', 'T']))  # H

    base = seq[6]
    encoding.extend(one_hot_base(base, ['A', 'C', 'G', 'T']))

    return np.array(encoding, dtype=np.float32)

def one_hot_base(base, allowed):
    """One-hot encode base using only allowed bases."""
    vec = [0] * len(allowed)
    if base in allowed:
        vec[allowed.index(base)] = 1
    return vec

reads_df['7mer_emb'] = reads_df['7mer'].apply(encode_drach_compact)
print("Ending: 2. Preprocessing: 7-mer embedding")


# =========================
# 3. Assign split bins
# =========================
print("Starting: 3. Assign split bins")

def assign_set_type_by_gene(reads_df, split_ratios={'Train': 0.8, 'Val': 0.1, 'Test': 0.1}, random_state=42):
    """
    Assigns each row in reads_df a 'set_type' of Train, Val, or Test,
    ensuring all rows with the same gene_id are in the same set,
    and total number of rows (not just genes) in each set matches desired ratios.
    Label distribution is approximately balanced using a greedy strategy.
    """

    # Step 1: Get stats per gene
    gene_stats = (
        reads_df
        .groupby('gene_id')['label']
        .value_counts()
        .unstack(fill_value=0)
        .rename(columns={0: 'label_0', 1: 'label_1'})
        .reset_index()
    )
    gene_stats['total'] = gene_stats['label_0'] + gene_stats['label_1']

    # Shuffle genes for randomness
    gene_stats = gene_stats.sample(frac=1, random_state=random_state).reset_index(drop=True)

    # Step 2: Overall label distribution and target row counts
    total_rows = gene_stats['total'].sum()
    total_label_1 = gene_stats['label_1'].sum()
    overall_pos_rate = total_label_1 / total_rows

    target_rows = {k: total_rows * split_ratios[k] for k in split_ratios}

    # Step 3: Initialize bins
    bins = {
        'Train': {'genes': [], 'label_0': 0, 'label_1': 0, 'total': 0},
        'Val': {'genes': [], 'label_0': 0, 'label_1': 0, 'total': 0},
        'Test': {'genes': [], 'label_0': 0, 'label_1': 0, 'total': 0},
    }

    def pick_bin():
        # Find the bin with the biggest gap between current and target row count
        diffs = {k: target_rows[k] - bins[k]['total'] for k in bins}
        # Choose the bin that needs rows the most
        return max(diffs, key=diffs.get)

    # Step 4: Assign genes to bins to match row targets and label balance
    for _, row in gene_stats.iterrows():
        chosen_bin = pick_bin()
        bins[chosen_bin]['genes'].append(row['gene_id'])
        bins[chosen_bin]['label_0'] += row['label_0']
        bins[chosen_bin]['label_1'] += row['label_1']
        bins[chosen_bin]['total'] += row['total']

    # Step 5: Map gene_id â†’ set_type
    gene_to_set = {}
    for set_name, bin_data in bins.items():
        for gene_id in bin_data['genes']:
            gene_to_set[gene_id] = set_name

    reads_df['set_type'] = reads_df['gene_id'].map(gene_to_set)

    return reads_df


reads_df = assign_set_type_by_gene(reads_df)

set_counts = reads_df['set_type'].value_counts()
print("ðŸ“Š Number of rows in each set:")
for set_name, count in set_counts.items():
    print(f"  - {set_name}: {count} rows")

# Print label distribution per set (normalized)
label_distributions = reads_df.groupby('set_type')['label'].value_counts(normalize=True).unstack()

print("\nðŸ“ˆ Label distribution (percentage of label 0 and 1) in each set:")
for set_name in label_distributions.index:
    label_0_pct = label_distributions.loc[set_name].get(0, 0) * 100
    label_1_pct = label_distributions.loc[set_name].get(1, 0) * 100
    print(f"  - {set_name}:")
    print(f"      â€¢ Label 0: {label_0_pct:.2f}%")
    print(f"      â€¢ Label 1: {label_1_pct:.2f}%")

print("Ending: 3. Assign split bins")



In [11]:
# print("Column Data Types:")
# print(reads_df.dtypes)

# # Display the number of rows
# print("\nNumber of Rows:", len(reads_df))

In [12]:
# file_path = "Dataset/"
# file_name = "processed_dataset.parquet"
# reads_df.to_parquet(f"{file_path}{file_name}", index=False)
# print(f"Saved processed dataset to {file_path}{file_name}")

In [13]:
# =========================
# 4. Dataset class
# =========================
print("Starting: 4. Dataset class")
class MILReadDataset(Dataset):
    def __init__(self, reads_df, n_reads_per_site=None, agg_config=None):
        """
        reads_df: DataFrame with read-level rows:
                  ['transcript_id','transcript_position','7mer_emb','label',
                   'PreTime','PreSD','PreMean','InTime','InSD','InMean','PostTime','PostSD','PostMean',...]
        n_reads_per_site: int or None
            - int: sample at most this many reads per site
            - None: use all reads
        agg_config: dict for aggregation at site level, e.g.
            {
              "Time": ["min","max","mean","25","75"],
              "SD": ["mean"],
              "Mean": ["mean"]
            }
        """

        self.n_reads_per_site = n_reads_per_site
        self.groups = reads_df.groupby(['transcript_id', 'transcript_position'])
        self.bags = list(self.groups.groups.keys())
        self.reads_df = reads_df
        # self.use_delta = True  # <--- toggle delta features on read levels

        # -----------------------------
        # Feature toggle switches
        # -----------------------------
        # Comment/uncomment entries to include/exclude features
        self.read_feature_flags = {
            "numeric": True,   # PreTime..PostMean
            "7mer": True,      # 7mer embedding at read-level
            "delta": True,
        }
        self.site_feature_flags = {
            "numeric_aggs": True,  # aggregated Time/SD/Mean stats
            "7mer": True,          # site-level 7mer embedding
        }

        # Default aggregation if not passed
        self.agg_config = agg_config or {
            "Time": ["min", "max", "mean", "25", "75"],
            "SD": ["mean"],
            "Mean": ["mean"]
        }

        # Bag lengths
        self.bag_lengths = {k: len(v) for k, v in self.groups}

        # Labels
        self.bag_labels = {}
        for k in self.bags:
            g = self.groups.get_group(k)
            self.bag_labels[k] = int(g['label'].iloc[0])

        # -----------------------------
        # Infer dimensions
        # -----------------------------
        dummy_bag = self[0]
        _, _, _, _, _ = dummy_bag
        print(f"âœ… Dataset initialized: read_dim = {self.read_dim}, site_dim = {self.site_dim}")

    def __len__(self):
        return len(self.bags)

    def __getitem__(self, idx):
        tid, pos = self.bags[idx]
        g = self.groups.get_group((tid, pos))

        # -----------------------------
        # Read-level matrix
        # -----------------------------
        read_parts = []

        if self.read_feature_flags["numeric"]:
            numeric_feats = g[[
                'PreTime','PreSD','PreMean',
                'InTime','InSD','InMean',
                'PostTime','PostSD','PostMean'
            ]].values.astype(np.float32)
            read_parts.append(numeric_feats)

        if self.read_feature_flags["7mer"]:
            kmer_list = list(g['7mer_emb'].values)
            kmer_emb_read = np.vstack(kmer_list).astype(np.float32)
            read_parts.append(kmer_emb_read)


        if self.read_feature_flags["delta"]:
            deltas = []
            # Time deltas
            deltas.append(numeric_feats[:, 3] - numeric_feats[:, 0])  # InTime - PreTime
            deltas.append(numeric_feats[:, 6] - numeric_feats[:, 3])  # PostTime - InTime
            # SD deltas
            deltas.append(numeric_feats[:, 4] - numeric_feats[:, 1])  # InSD - PreSD
            deltas.append(numeric_feats[:, 7] - numeric_feats[:, 4])  # PostSD - InSD
            # Mean deltas
            deltas.append(numeric_feats[:, 5] - numeric_feats[:, 2])  # InMean - PreMean
            deltas.append(numeric_feats[:, 8] - numeric_feats[:, 5])  # PostMean - InMean

            delta_feats = np.stack(deltas, axis=1)  # shape (n_reads, 6)
            read_parts.append(delta_feats)
        
        bag_read_level = np.concatenate(read_parts, axis=1)

        # Random subsampling
        if self.n_reads_per_site is not None and bag_read_level.shape[0] > self.n_reads_per_site:
            idxs = np.random.choice(bag_read_level.shape[0], self.n_reads_per_site, replace=False)
            bag_read_level = bag_read_level[idxs]

        # -----------------------------
        # Site-level vector
        # -----------------------------
        site_parts = []

        if self.site_feature_flags["numeric_aggs"]:
            site_aggs = []
            groups = {"Time": [0, 3, 6], "SD": [1, 4, 7], "Mean": [2, 5, 8]}
            for feat_type, idx_list in groups.items():
                stats = self.agg_config.get(feat_type, [])
                for col_idx in idx_list:
                    vals = numeric_feats[:, col_idx].astype(np.float32)
                    for stat in stats:
                        if stat == "min": site_aggs.append(np.min(vals))
                        elif stat == "max": site_aggs.append(np.max(vals))
                        elif stat == "mean": site_aggs.append(np.mean(vals))
                        elif stat == "25": site_aggs.append(np.percentile(vals, 25))
                        elif stat == "75": site_aggs.append(np.percentile(vals, 75))
            site_parts.append(np.array(site_aggs, dtype=np.float32))

        if self.site_feature_flags["7mer"]:
            first_kmer = np.asarray(g['7mer_emb'].iloc[0], dtype=np.float32).ravel()
            site_parts.append(first_kmer)
            

        bag_site_level = np.concatenate(site_parts, axis=0).astype(np.float32)

        # -----------------------------
        # Label + bookkeeping
        # -----------------------------
        label = int(g['label'].iloc[0])

        # Store dims once (for printing at init)
        if not hasattr(self, "read_dim"):
            self.read_dim = bag_read_level.shape[1]
            self.site_dim = bag_site_level.shape[0]
            print(f"Read-level dim: {self.read_dim} | Site-level dim: {self.site_dim}")

        return (torch.tensor(bag_read_level, dtype=torch.float32),
                torch.tensor(bag_site_level, dtype=torch.float32),
                torch.tensor(label, dtype=torch.float32),
                tid, pos)

    
print("Ending: 4. Dataset class")

# =========================
# 5. Imbalanced sampler
# =========================
print("Starting: 5. Samplers and Collate")
class ImbalancedBagSampler(Sampler):
    """
    Oversamples positive bags to balance classes, with rebalancing every epoch.
    """
    def __init__(self, dataset, pos_neg_ratio=1.0):
        """
        dataset: your MILReadDataset
        balance_ratio: ratio of positive to negative samples after oversampling.
                       - 1.0 = fully balanced (default)
                       - 0.5 = positives are half as many as negatives
                       - >1.0 = more positives than negatives
        """
        self.dataset = dataset
        self.pos_neg_ratio = pos_neg_ratio
        # Pre-cache labels to avoid repeated dataset access
        self.bags = self.dataset.bags
        self.labels = np.array([self.dataset.bag_labels[k] for k in self.bags], dtype=np.int64)
        self.pos_idx = np.where(self.labels == 1)[0]
        self.neg_idx = np.where(self.labels == 0)[0]

        if len(self.pos_idx) == 0:
            raise ValueError("No positive bags found in dataset; pos_neg_ratio sampling not possible.")
        
    def sample_indices(self):
        n_pos_target = int(len(self.neg_idx) * self.pos_neg_ratio)
        sampled_pos = np.random.choice(self.pos_idx, size=n_pos_target, replace=True)
        combined = np.concatenate([self.neg_idx, sampled_pos])
        np.random.shuffle(combined)
        return combined
    def __iter__(self):
        return iter(self.sample_indices())

    def __len__(self):
        # length in terms of number of sampled bag indices
        return len(self.neg_idx) + int(len(self.neg_idx) * self.pos_neg_ratio)
    
class BucketBatchSampler(Sampler):
    """
    Groups bags of similar lengths (based on n_reads) into buckets,
    shuffles bucket order each epoch, and yields random batches.
    Please at the least set bucket_size value to be multiple of batch_size
    """
    def __init__(self, dataset, base_sampler, batch_size=4, bucket_size=200):
        self.dataset = dataset
        self.base_sampler = base_sampler
        self.batch_size = batch_size
        self.bucket_size = bucket_size

    def __iter__(self):
        # Get sampled indices from ImbalancedBagSampler
        indices = self.base_sampler.sample_indices()
        lengths = np.array([self.dataset.bag_lengths[self.dataset.bags[i]] for i in indices])
        sorted_idx = indices[np.argsort(lengths)]

        # Split into buckets of similar lengths
        buckets = [sorted_idx[i:i+self.bucket_size] for i in range(0, len(sorted_idx), self.bucket_size)]
        np.random.shuffle(buckets)

        lst_of_batch = []
        for bucket in buckets:
            np.random.shuffle(bucket)
            for i in range(0, len(bucket), self.batch_size):
                batch = bucket[i:i+self.batch_size]
                lst_of_batch.append(batch)
        np.random.shuffle(lst_of_batch)

        for batch in lst_of_batch:
            yield batch.tolist()

    def __len__(self):
        total = len(self.base_sampler.neg_idx) + int(len(self.base_sampler.neg_idx) * self.base_sampler.pos_neg_ratio)
        return max(1, total // self.batch_size)

# Collate function for minimal padding
def collate_fn(batch):
    """
    batch: list of tuples (bag_read_level, bag_site_level, label, tid, pos)
    returns:
      padded_feats: (B, max_len, feat_dim)
      site_level_tensor: (B, site_dim)
      labels: (B,)
      tids, positions: lists
    """
    bag_read_level_list, bag_site_level_list, labels, tids, positions = zip(*batch)
    
    # pad read-level bags
    max_len = max(x.shape[0] for x in bag_read_level_list)
    feat_dim = bag_read_level_list[0].shape[1]
    padded_feats = torch.zeros(len(bag_read_level_list), max_len, feat_dim, dtype=torch.float32)
    for i, nf in enumerate(bag_read_level_list):
        padded_feats[i, :nf.shape[0], :] = nf

    # stack site-level vectors (they must be same length across batch)
    site_level_tensor = torch.stack([torch.as_tensor(x, dtype=torch.float32) for x in bag_site_level_list])

    labels = torch.stack(labels)

    return padded_feats, site_level_tensor, labels, tids, positions


print("Ending: 5. Samplers and Collate")


Starting: 4. Dataset class
Ending: 4. Dataset class
Starting: 5. Samplers and Collate
Ending: 5. Samplers and Collate


In [3]:
# =========================
# 6. Split Train/Val/Test
# =========================
print("Starting: 6. Split Train/Val/Test")

print("Loading in processed_dataset.parquet")
reads_df = pd.read_parquet("Dataset/processed_dataset.parquet")

# Now split by set_type
train_df = reads_df[reads_df['set_type'] == 'Train']
val_df   = reads_df[reads_df['set_type'] == 'Val']
test_df  = reads_df[reads_df['set_type'] == 'Test']


agg_config = { "Time": ["min", "max", "mean", "25", "75"]
              , "SD": ["mean"]
              , "Mean": ["mean"] 
              } 

train_ds = MILReadDataset(train_df, n_reads_per_site=20, agg_config=agg_config) 
val_ds = MILReadDataset(val_df, n_reads_per_site=20, agg_config=agg_config) 
test_ds = MILReadDataset(test_df, n_reads_per_site=20, agg_config=agg_config)

#train_loader is inside Epoch Loop so as to randomise the positve bags that are oversampled
val_loader   = DataLoader(val_ds, batch_size=1, shuffle=False)
test_loader  = DataLoader(test_ds, batch_size=1, shuffle=False)
print("Ending: 6. Split Train/Val/Test")

Starting: 6. Split Train/Val/Test
Loading in processed_dataset.parquet
Read-level dim: 31 | Site-level dim: 37
âœ… Dataset initialized: read_dim = 31, site_dim = 37
Read-level dim: 31 | Site-level dim: 37
âœ… Dataset initialized: read_dim = 31, site_dim = 37
Read-level dim: 31 | Site-level dim: 37
âœ… Dataset initialized: read_dim = 31, site_dim = 37
Ending: 6. Split Train/Val/Test


In [None]:
# reads_df = pd.read_parquet("Dataset/processed_dataset.parquet")


# agg_config = { "Time": ["min", "max", "mean", "25", "75"]
#               , "SD": ["mean"]
#               , "Mean": ["mean"] 
#               } 


# test_ds = MILReadDataset(test_df, n_reads_per_site=20, agg_config=agg_config)
# test_loader  = DataLoader(test_ds, batch_size=1, shuffle=False)

Read-level dim: 31 | Site-level dim: 37
âœ… Dataset initialized: read_dim = 31, site_dim = 37


In [15]:
# =========================
# 7. Attention MIL Model
# =========================
print("Starting: 7. Attention MIL Model")
class MultiHeadAttentionPool(nn.Module):
    def __init__(self, hidden_dim, n_heads=4):
        super().__init__()
        self.n_heads = n_heads
        # project to head space then scalar score per head
        self.proj = nn.Linear(hidden_dim, hidden_dim)
        self.head_score = nn.Linear(hidden_dim, n_heads)  # outputs (batch, n_instances, n_heads)
    
    def forward(self, H):  # H: (B, N, hidden_dim)
        # Optionally nonlinearity
        S = torch.tanh(self.proj(H))            # (B, N, hidden_dim)
        scores = self.head_score(S)             # (B, N, n_heads)
        attn = torch.softmax(scores, dim=1)     # softmax over instances per head
        # attn: (B, N, n_heads). compute per-head pooled vectors:
        # transpose H to (B, hidden_dim, N) to do matmul
        pooled = []
        for h in range(self.n_heads):
            a = attn[..., h].unsqueeze(-1)      # (B, N, 1)
            m = torch.sum(a * H, dim=1)         # (B, hidden_dim)
            pooled.append(m)
        # concat head outputs
        M = torch.cat(pooled, dim=1)           # (B, hidden_dim * n_heads)
        return M, attn                         # attn shape (B, N, n_heads)

class AttentionMIL_v2(nn.Module):
    def __init__(self, read_dim, site_dim, hidden_dim=128, n_heads=4, dropout=0.2):
        super().__init__()
        self.read_encoder = nn.Sequential(
            nn.Linear(read_dim, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout)
        )
        self.site_encoder = nn.Sequential(
            nn.Linear(site_dim, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout)
        )
        self.pool = MultiHeadAttentionPool(hidden_dim, n_heads=n_heads)
        # project pooled concat (hidden_dim * n_heads) back to hidden_dim
        self.pool_proj = nn.Linear(hidden_dim * n_heads, hidden_dim)
        self.classifier = nn.Sequential(
            nn.Linear(hidden_dim*2, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, 1)
        )

    def forward(self, bag_read_level, bag_site_level):
        # bag_read_level: (B,N,read_dim)
        H = self.read_encoder(bag_read_level)   # (B,N,hidden_dim)
        M_concat, attn = self.pool(H)           # (B, hidden_dim * n_heads)
        M = self.pool_proj(M_concat)            # (B, hidden_dim)
        site = self.site_encoder(bag_site_level) # (B, hidden_dim)
        combined = torch.cat([M, site], dim=-1) # (B, hidden_dim*2)
        out = self.classifier(combined).view(-1)
        return out, attn


class FocalLoss(nn.Module):
    def __init__(self, alpha=1.0, gamma=2.0, reduction="mean"):
        super(FocalLoss, self).__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.reduction = reduction

    def forward(self, logits, targets):
        # BCE with logits
        bce_loss = F.binary_cross_entropy_with_logits(
            logits, targets.float(), reduction="none"
        )
        # Get probabilities
        probs = torch.sigmoid(logits)
        pt = probs * targets + (1 - probs) * (1 - targets)  # p_t

        # Focal loss factor
        focal_factor = (1 - pt) ** self.gamma

        # Apply alpha and focal scaling
        loss = self.alpha * focal_factor * bce_loss

        if self.reduction == "mean":
            return loss.mean()
        elif self.reduction == "sum":
            return loss.sum()
        else:
            return loss
        
def predict_probs(model, bag_read_level, bag_site_level):
    logits, attn = model(bag_read_level, bag_site_level)
    return torch.sigmoid(logits), attn

# --------------------------
# Define model + optimizer
# --------------------------
# Get read and site dimensions from dataset
read_dim = train_ds.read_dim
site_dim = train_ds.site_dim

print(f"Automatically detected read_dim={read_dim}, site_dim={site_dim}")

# Initialize model using the inferred dimensions
model = AttentionMIL_v2(read_dim=read_dim, site_dim=site_dim, hidden_dim=128, dropout=0.2).to(device)

# change the lr to have higher penalisation
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

print("Ending: 7. Attention MIL Model")


Starting: 7. Attention MIL Model
Automatically detected read_dim=31, site_dim=37
Ending: 7. Attention MIL Model


In [16]:
# =========================
# 8. Training Loop
# =========================
print("Starting: 8. Training Loop")

# ---- Hyperparameters ----
n_epochs = 300
early_stop_patience = 20  # stop if no PR-AUC improvement for defined epochs
best_val_pr = 0.0
patience_counter = 0

# ---- Sampling & Weighting Config ----
oversample_ratio = 0.5    # oversample positives to 50% of negatives
effective_ratio  = 1.0    # treat pos:neg equally in loss weighting

# Count raw labels
pos_count = sum(label.item() for _, _, label, _, _ in train_ds)
neg_count = len(train_ds) - pos_count

# Effective counts
effective_pos_count = int(neg_count * oversample_ratio)
effective_neg_count = neg_count

print(f"Original positives: {pos_count}, negatives: {neg_count}")
print(f"Effective positives (oversampling {oversample_ratio*100:.0f}%): {effective_pos_count}")

# Loss weighting
pos_weight_value = (effective_neg_count / max(effective_pos_count, 1)) * effective_ratio
criterion = FocalLoss(alpha=pos_weight_value, gamma=2.0)

# Sampler
sampler = ImbalancedBagSampler(train_ds, pos_neg_ratio=oversample_ratio)


# ---- Tracking Metrics ----
metrics_dict = {
    'epoch': [],
    'train_loss': [],
    'val_roc_auc': [],
    'val_pr_auc': [],
    'epoch_time_sec': [],
    'avg_time_per_bag': []
}

for epoch in range(n_epochs):
    epoch_start = time.time()

    # Rebuild dataloader and batch order each epoch with a fresh sampler
    # Please reduce just use bucket_size that is a multiple of the batch_size
    batch_sampler = BucketBatchSampler(train_ds, sampler, batch_size=20, bucket_size=200)
    train_loader = DataLoader(train_ds, batch_sampler=batch_sampler, collate_fn=collate_fn)

    model.train()
    total_loss = 0.0
    n_bags = 0
    for bag_read_level, bag_site_level, label, _, _ in train_loader:
        n_bags += 1
        bag_read_level = bag_read_level.to(device)
        bag_site_level = bag_site_level.to(device)
        label = label.to(device)

        optimizer.zero_grad()
        outs, attns = model(bag_read_level, bag_site_level)
        
        loss = criterion(outs, label)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

        if n_bags % 50000 == 0:
            elapsed = time.time() - epoch_start
            # print(f"  Processed {n_bags} bags, avg time per bag: {elapsed / n_bags:.4f}s")

    total_loss /= len(train_loader)

    # Validation
    model.eval()
    val_preds = []
    val_labels = []

    with torch.no_grad():
        for bag_read_level, bag_site_level, label, _, _ in val_loader:
            bag_read_level = bag_read_level.to(device)
            bag_site_level = bag_site_level.to(device)
            label = label.to(device).view(-1)

            # Forward pass
            out, _ = predict_probs(model, bag_read_level, bag_site_level)

            # Convert predictions and labels to list for metrics calculation
            val_preds.extend(out.detach().cpu().numpy())
            val_labels.extend(label.detach().cpu().numpy())
            
    # ---- Compute Metrics ----        
    val_auc = roc_auc_score(val_labels, val_preds)
    val_pr  = average_precision_score(val_labels, val_preds)

    # End timer
    epoch_time = time.time() - epoch_start
    avg_time_per_bag = epoch_time / n_bags

    # Log
    metrics_dict['epoch'].append(epoch+1)
    metrics_dict['train_loss'].append(total_loss)
    metrics_dict['val_roc_auc'].append(val_auc)
    metrics_dict['val_pr_auc'].append(val_pr)
    metrics_dict['epoch_time_sec'].append(epoch_time)
    metrics_dict['avg_time_per_bag'].append(avg_time_per_bag)

    print(f"Epoch {epoch+1}/{n_epochs}: "
          f"TrainLoss={total_loss:.4f}, "
          f"ValROC-AUC={val_auc:.4f}, "
          f"ValPR-AUC={val_pr:.4f}, "
          f"EpochTime={epoch_time:.2f}s,"
          f"AvgTimePerBag={avg_time_per_bag:.4f}s")

    # ---- Early Stopping & Model Saving ----
    if val_pr > best_val_pr:
        best_val_pr = val_pr
        patience_counter = 0
        torch.save(model.state_dict(), f"Models/epoch{epoch+1}_valpr{val_pr:.2f}.pth")
        model_name = f"epoch{epoch+1}_valpr{val_pr:.2f}.pth"
        print(f"Saved best model (ValPR-AUC improved to {val_pr:.4f}) at epoch {epoch+1}")
        
    else:
        patience_counter += 1
        print(f"No improvement in PR-AUC for {patience_counter} epoch(s).")

        if patience_counter >= early_stop_patience:
            print(f"Early stopping at epoch {epoch+1} (no PR-AUC improvement for {early_stop_patience} epochs).")
            break

print("Ending: 8. Training Loop")




Starting: 8. Training Loop


Original positives: 4382.0, negatives: 91455.0
Effective positives (oversampling 50%): 45727
Epoch 1/300: TrainLoss=0.2333, ValROC-AUC=0.8882, ValPR-AUC=0.3790, EpochTime=336.47s,AvgTimePerBag=0.0490s
Saved best model (ValPR-AUC improved to 0.3790) at epoch 1
Epoch 2/300: TrainLoss=0.2147, ValROC-AUC=0.8900, ValPR-AUC=0.3676, EpochTime=364.28s,AvgTimePerBag=0.0531s
No improvement in PR-AUC for 1 epoch(s).
Epoch 3/300: TrainLoss=0.2096, ValROC-AUC=0.8883, ValPR-AUC=0.3915, EpochTime=363.05s,AvgTimePerBag=0.0529s
Saved best model (ValPR-AUC improved to 0.3915) at epoch 3
Epoch 4/300: TrainLoss=0.2075, ValROC-AUC=0.8837, ValPR-AUC=0.3659, EpochTime=366.40s,AvgTimePerBag=0.0534s
No improvement in PR-AUC for 1 epoch(s).
Epoch 5/300: TrainLoss=0.2045, ValROC-AUC=0.8840, ValPR-AUC=0.3897, EpochTime=381.94s,AvgTimePerBag=0.0557s
No improvement in PR-AUC for 2 epoch(s).
Epoch 6/300: TrainLoss=0.2025, ValROC-AUC=0.8871, ValPR-AUC=0.3948, EpochTime=395.94s,AvgTimePerBag=0.0577s
Saved best model (

### Get the training set metrics

In [None]:
# Assuming metrics_dict already exists
# Convert the dictionary to a DataFrame
df = pd.DataFrame(metrics_dict)

# Define the file path for CSV
file_path = 'metrics.csv'

# Save the DataFrame to a CSV file
df.to_csv(file_path, index=False)

print(f"Dictionary saved to {file_path} as CSV.")

In [16]:
def load_and_eval_models(model_path, 
                         mode="both", 
                         top_k=5, 
                         device="cpu", 
                         test_loader=None):
    """
    Load models (single best, top-k ensemble) and evaluate on test set.
    Saves predictions separately for each mode.
    """

    # --- Parse saved models ---
    pattern = r"epoch\d+_valpr(\d+\.\d+)\.pth"
    models = []
    for f in os.listdir(model_path):
        match = re.match(pattern, f)
        if match:
            pr_auc = float(match.group(1))
            models.append((f, pr_auc))
    if not models:
        raise ValueError("No valid model files found in directory.")

    # --- Sort descending by PR-AUC ---
    models.sort(key=lambda x: x[1], reverse=True)

    results = {}

    # --- Load single best ---
    if mode in ["single", "both"]:
        best_model_file, best_pr = models[0]
        model = AttentionMIL_v2(
            read_dim=test_loader.dataset.read_dim,
            site_dim=test_loader.dataset.site_dim,
            hidden_dim=128,
            dropout=0.2,
            n_heads=4
        ).to(device)
        model.load_state_dict(torch.load(f"{model_path}/{best_model_file}", map_location=device))
        model.eval()
        results["single_model"] = model
        results["single_name"] = best_model_file

    # --- Load top-k ensemble ---
    if mode in ["topk", "both"]:
        top_models = models[:top_k]
        ensemble = []
        for m_name, pr in top_models:
            m = AttentionMIL_v2(
                read_dim=test_loader.dataset.read_dim,
                site_dim=test_loader.dataset.site_dim,
                hidden_dim=128,
                dropout=0.2,
                n_heads=4
            ).to(device)
            m.load_state_dict(torch.load(f"{model_path}/{m_name}", map_location=device))
            m.eval()
            ensemble.append(m)
        results["ensemble_models"] = ensemble
        results["ensemble_names"] = [m[0] for m in top_models]

    # --- Evaluation ---
    if test_loader is not None:
        y_true, y_pred_single, y_pred_ens = [], [], []
        single_rows, ens_rows = [], []

        with torch.no_grad():
            for bag_read_level, bag_site_level, label, tid, pos in test_loader:
                bag_read_level = bag_read_level.to(device)
                bag_site_level = bag_site_level.to(device)
                label = label.to(device).view(-1)

                # Single best model
                if "single_model" in results:
                    out, _ = results["single_model"](bag_read_level, bag_site_level)
                    prob = torch.sigmoid(out).item()
                    y_pred_single.append(prob)
                    single_rows.append({
                        'transcript_id': tid[0],
                        'transcript_position': pos.item(),
                        'score': prob
                    })

                # Ensemble
                if "ensemble_models" in results:
                    preds = []
                    for m in results["ensemble_models"]:
                        out, _ = m(bag_read_level, bag_site_level)
                        preds.append(torch.sigmoid(out).item())
                    avg_prob = sum(preds) / len(preds)
                    y_pred_ens.append(avg_prob)
                    ens_rows.append({
                        'transcript_id': tid[0],
                        'transcript_position': pos.item(),
                        'score': avg_prob
                    })

                y_true.append(label.item())

        # Compute and save results
        if y_pred_single:
            roc = roc_auc_score(y_true, y_pred_single)
            pr  = average_precision_score(y_true, y_pred_single)
            print(f"[Single Best] ROC-AUC={roc:.4f}, PR-AUC={pr:.4f}")
            pd.DataFrame(single_rows).to_csv("Results/best_single_test_data_output.csv", index=False)

        if y_pred_ens:
            roc = roc_auc_score(y_true, y_pred_ens)
            pr  = average_precision_score(y_true, y_pred_ens)
            print(f"[Top-{top_k} Ensemble] ROC-AUC={roc:.4f}, PR-AUC={pr:.4f}")
            pd.DataFrame(ens_rows).to_csv(f"Results/top_{top_k}_ensemble_test_data_output.csv", index=False)


In [19]:
model_path = "Models/"

load_and_eval_models(model_path = model_path,mode = "both", top_k = 5,test_loader = test_loader)

[Single Best] ROC-AUC=0.9006, PR-AUC=0.4615
[Top-5 Ensemble] ROC-AUC=0.8972, PR-AUC=0.4618


In [20]:
model_path = "Models/"

load_and_eval_models(model_path = model_path,mode = "both", top_k = 3,test_loader = test_loader)

[Single Best] ROC-AUC=0.9092, PR-AUC=0.4593
[Top-3 Ensemble] ROC-AUC=0.9070, PR-AUC=0.4529


---