# Multiple Instance Learning (MIL) for Read-Level RNA Site Classification
*A PyTorch implementation with bag-level training, imbalanced sampling, and efficient batching.*

## 1. Introduction

In this notebook, we implement a **Multiple Instance Learning (MIL)** framework for read-level RNA site classification.

- Each *bag* corresponds to a transcript site (defined by `transcript_id` and `transcript_position`).
- Each *instance* corresponds to a single sequencing read belonging to that site.
- Each bag has a **single binary label** (`0` or `1`), while individual reads are unlabeled.

We train the model to predict the bag label by aggregating per-read predictions using a pooling mechanism.
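
The bag/instance structure can be pictured with a small, dependency-free sketch. The toy rows and feature values below are invented for illustration and are not taken from the dataset; only the grouping key `(transcript_id, transcript_position)` follows the description above.

```python
from collections import defaultdict

# Toy read-level rows: (transcript_id, transcript_position, label, feature).
# All values here are made up for illustration.
reads = [
    ("tx1", 100, 1, 0.9),
    ("tx1", 100, 1, 0.2),
    ("tx1", 250, 0, 0.1),
    ("tx2", 40, 0, 0.3),
    ("tx2", 40, 0, 0.4),
    ("tx2", 40, 0, 0.2),
]

bags = defaultdict(list)   # bag key -> list of instance features
bag_labels = {}            # bag key -> single bag-level label

for tid, pos, label, feat in reads:
    key = (tid, pos)
    bags[key].append(feat)
    # Every read of a site repeats the site's label; no read has a label of its own.
    assert bag_labels.setdefault(key, label) == label

for key, feats in bags.items():
    print(key, "n_reads =", len(feats), "label =", bag_labels[key])
```

The model never sees a per-read target: only `bag_labels` enters the loss, after the per-read scores of a bag have been pooled.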

In [None]:
# =========================
# 1. Imports & Setup
# =========================
# === Standard Libraries ===
import os
import re
import time
import random
import pickle

# === Data Manipulation ===
import pandas as pd
import numpy as np

# === PyTorch ===
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, Sampler

# === Sklearn Metrics & Utilities ===
from sklearn.metrics import (
    roc_auc_score,
    average_precision_score,
    precision_recall_curve
)
from sklearn.preprocessing import StandardScaler

# === Visualization ===
import matplotlib.pyplot as plt

# Reproducibility
RNG = 42
np.random.seed(RNG)
torch.manual_seed(RNG)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

### Configuration & Parameters

This section defines all dataset paths, model hyperparameters, and training options.  
Adjust these values as needed, then simply **Run All** to execute the full pipeline.


In [None]:
# =========================
# 0. Configuration & Parameters
# =========================

# ---- General Settings ----
RNG = 42
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

# ---- Data Paths ----
DATA_DIR = "../Dataset"
DATA_FILE = "processed_dataset_lat2_432_epoch600_trained.parquet"
MODEL_DIR = "Models"
RESULT_DIR = "Results"

# ---- Data Sampling ----
N_READS_PER_SITE = 20
USE_DELTA = False
NUMERIC_SCALER = None   # None fits a new StandardScaler(); otherwise pass an already-fitted scaler object
KMER_SCALAR = None      # None fits a new StandardScaler(); otherwise pass an already-fitted scaler object
POS_NEG_RATIO = 0.5     # oversample positives to 50% of negatives
EFFECTIVE_RATIO = 1.0   # weighting ratio between pos:neg

# ---- Model Parameters ----
HIDDEN_DIM = 128
DROPOUT = 0.2
POOLING_METHOD = "noisy-or"  # ["noisy-or", "mean", "max"]

# ---- Optimizer & Training ----
LR = 1e-3
N_EPOCHS = 300
EARLY_STOP_PATIENCE = 20
BATCH_SIZE = 200
BUCKET_SIZE = 1000

# ---- Evaluation ----
SAVE_PR_CURVE = True
MODEL_SELECTION_METRIC = "valscore"  # options: valscore, valpr, rocauc

# Create output directories
os.makedirs(MODEL_DIR, exist_ok=True)
os.makedirs(RESULT_DIR, exist_ok=True)


## 2. Dataset Overview

Each read is represented by numeric features (signal statistics) and sequence-based embeddings (5-mer context).

**Columns:**
- `transcript_id`, `transcript_position` → define bag identity
- `label` → bag-level binary label
- Numeric features: `PreTime`, `PreSD`, `PreMean`, ..., `PostMean`
- K-mer embeddings: `Pre_5mer_*`, `In_5mer_*`, `Post_5mer_*`

The processed dataset is stored at `DATA_DIR/DATA_FILE`.

In [None]:
reads_df = pd.read_parquet(f"{DATA_DIR}/{DATA_FILE}")

## 3. Dataset Class: MILReadDataset

This class organizes reads into bags and scales their features, with an **optional delta feature toggle**.

**Key Features:**
- Groups reads by `(transcript_id, transcript_position)` to form bags.
- Applies `StandardScaler` normalization to numeric and k-mer features.
- Supports **optional delta computation** between pre/in/post signal windows:
  - Controlled by `use_delta` flag in the constructor.
- Returns for each bag:
  - `bag_read_level`: tensor of shape `(n_reads, feature_dim)`
  - `label`: scalar bag label
  - `tid`, `pos`: identifiers for tracking
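
The optional delta features can be sketched in NumPy as below. This is an illustration rather than the dataset class itself: it assumes the nine numeric columns are ordered as in `numeric_cols`, the helper name `add_delta_features` is hypothetical, and the six deltas are appended region by region, which is a different column order from the class.

```python
import numpy as np

# Column order assumed to match numeric_cols in this notebook.
numeric_cols = ["PreTime", "PreSD", "PreMean",
                "InTime", "InSD", "InMean",
                "PostTime", "PostSD", "PostMean"]

def add_delta_features(numeric_feats):
    """numeric_feats: (n_reads, 9) array -> (n_reads, 15) array with 6 deltas appended."""
    pre = numeric_feats[:, 0:3]
    mid = numeric_feats[:, 3:6]
    post = numeric_feats[:, 6:9]
    # In - Pre and Post - In, each for (Time, SD, Mean)
    deltas = np.concatenate([mid - pre, post - mid], axis=1)
    return np.concatenate([numeric_feats, deltas], axis=1)

toy = np.arange(18, dtype=np.float32).reshape(2, 9)  # two made-up reads
out = add_delta_features(toy)
print(out.shape)
```

Enabling the deltas therefore raises the per-read feature dimension by six, which is why `read_dim` is inferred from a sample bag rather than hard-coded.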

**Returned sample structure:**
```python
(
  torch.tensor(bag_read_level, dtype=torch.float32),
  torch.tensor(label, dtype=torch.float32),
  tid, pos
)
```

In [None]:
# =========================
# 3. Dataset class
# =========================
print("Starting: 3. Dataset class")

class MILReadDataset(Dataset):
    def __init__(self, reads_df, n_reads_per_site=None, numeric_scaler=None, kmer_scaler=None, use_delta=False):
        """
        reads_df: DataFrame with read-level rows:
                  ['transcript_id','transcript_position','label',
                   'PreTime','PreSD','PreMean','InTime','InSD','InMean',
                   'PostTime','PostSD','PostMean',
                   'Pre_5mer_*','In_5mer_*','Post_5mer_*']
        n_reads_per_site: int or None
            - int: max number of reads per bag (larger bags are randomly subsampled)
            - None: keep all reads
        numeric_scaler: fitted scaler for numeric features (if None, a StandardScaler is fit on reads_df)
        kmer_scaler: fitted scaler for k-mer embeddings (if None, a StandardScaler is fit on reads_df)
        use_delta: bool
            - If True, adds delta (difference) features between Pre/In/Post regions
        """
        self.n_reads_per_site = n_reads_per_site
        self.use_delta = use_delta
        self.groups = reads_df.groupby(['transcript_id', 'transcript_position'])
        self.bags = list(self.groups.groups.keys())
        self.reads_df = reads_df.copy()

        # -----------------------------
        # Define numeric and kmer columns
        # -----------------------------
        self.numeric_cols = [
            'PreTime','PreSD','PreMean','InTime','InSD','InMean','PostTime','PostSD','PostMean'
        ]
        self.kmer_cols = [col for col in reads_df.columns if col.startswith(('Pre_5mer_', 'In_5mer_', 'Post_5mer_'))]

        # -----------------------------
        # Fit or use provided scalers
        # -----------------------------
        if numeric_scaler is None:
            self.numeric_scaler = StandardScaler()
            self.reads_df[self.numeric_cols] = self.numeric_scaler.fit_transform(self.reads_df[self.numeric_cols])
        else:
            self.numeric_scaler = numeric_scaler
            self.reads_df[self.numeric_cols] = self.numeric_scaler.transform(self.reads_df[self.numeric_cols])

        if kmer_scaler is None:
            self.kmer_scaler = StandardScaler()
            self.reads_df[self.kmer_cols] = self.kmer_scaler.fit_transform(self.reads_df[self.kmer_cols])
        else:
            self.kmer_scaler = kmer_scaler
            self.reads_df[self.kmer_cols] = self.kmer_scaler.transform(self.reads_df[self.kmer_cols])

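        # -----------------------------
        # Bag lengths and labels
        # -----------------------------
        # Regroup on the scaled copy: the groupby created above wraps the caller's
        # unscaled frame, so without this __getitem__ would return raw features.
        self.groups = self.reads_df.groupby(['transcript_id', 'transcript_position'])
        self.bag_lengths = {k: len(v) for k, v in self.groups}
        self.bag_labels = {k: int(v['label'].iloc[0]) for k, v in self.groups}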

        # Infer read feature dimension
        dummy_bag = self[0]
        bag_feats, _, _, _ = dummy_bag
        self.read_dim = bag_feats.shape[1]
        print(f"✅ MILReadDataset initialized: read_dim={self.read_dim}")

    def __len__(self):
        return len(self.bags)

    def __getitem__(self, idx):
        tid, pos = self.bags[idx]
        g = self.groups.get_group((tid, pos))

        read_parts = []

        # Numeric features (already scaled)
        numeric_feats = g[self.numeric_cols].values.astype(np.float32)
        read_parts.append(numeric_feats)

        # Optional delta features
        if self.use_delta:
            deltas = []
            deltas.append(numeric_feats[:, 3] - numeric_feats[:, 0])  # InTime - PreTime
            deltas.append(numeric_feats[:, 6] - numeric_feats[:, 3])  # PostTime - InTime
            deltas.append(numeric_feats[:, 4] - numeric_feats[:, 1])  # InSD - PreSD
            deltas.append(numeric_feats[:, 7] - numeric_feats[:, 4])  # PostSD - InSD
            deltas.append(numeric_feats[:, 5] - numeric_feats[:, 2])  # InMean - PreMean
            deltas.append(numeric_feats[:, 8] - numeric_feats[:, 5])  # PostMean - InMean
            delta_feats = np.stack(deltas, axis=1).astype(np.float32)
            read_parts.append(delta_feats)

        # Embedded 5-mer features
        kmer_feats = g[self.kmer_cols].values.astype(np.float32)
        read_parts.append(kmer_feats)

        # Combine all read-level features
        bag_read_level = np.concatenate(read_parts, axis=1)

        # Random subsampling
        if self.n_reads_per_site is not None and bag_read_level.shape[0] > self.n_reads_per_site:
            idxs = np.random.choice(bag_read_level.shape[0], self.n_reads_per_site, replace=False)
            bag_read_level = bag_read_level[idxs]

        label = int(g['label'].iloc[0])

        return (torch.tensor(bag_read_level, dtype=torch.float32),
                torch.tensor(label, dtype=torch.float32),
                tid, pos)

print("Ending: 3. Dataset class")

## 4. Bucketed Batch Sampling and Collation

Bags vary in the number of reads. To improve computational efficiency:

- **BucketBatchSampler** sorts bags by length and groups them into "buckets" of similar length.
- **Within each bucket:** random shuffling ensures stochasticity.
- **Collate function:** dynamically pads all bags in a batch to the maximum bag length, minimizing wasted computation.

This approach mimics variable-length batching in NLP models.
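
The bucketing idea can be sketched without PyTorch. The function names and the random bag lengths below are hypothetical; the sketch only mirrors the sort, bucket, shuffle, and split steps described above and compares the number of padded read slots against batching in arbitrary order.

```python
import random

def bucketed_batches(lengths, batch_size, bucket_size, seed=0):
    """Return a list of batches (lists of bag indices) with similar bag lengths."""
    rng = random.Random(seed)
    order = sorted(range(len(lengths)), key=lambda i: lengths[i])
    buckets = [order[i:i + bucket_size] for i in range(0, len(order), bucket_size)]
    batches = []
    for bucket in buckets:
        rng.shuffle(bucket)                      # randomness inside a bucket
        for i in range(0, len(bucket), batch_size):
            batches.append(bucket[i:i + batch_size])
    rng.shuffle(batches)                         # randomness across batches
    return batches

def padding_waste(batches, lengths):
    """Number of padded (empty) read slots summed over all batches."""
    return sum(max(lengths[i] for i in b) * len(b) - sum(lengths[i] for i in b)
               for b in batches)

rng = random.Random(1)
lengths = [rng.randint(1, 100) for _ in range(400)]      # made-up bag sizes
bucketed = bucketed_batches(lengths, batch_size=20, bucket_size=100)
unsorted = [list(range(i, i + 20)) for i in range(0, 400, 20)]
print("padded slots, bucketed:", padding_waste(bucketed, lengths))
print("padded slots, unsorted:", padding_waste(unsorted, lengths))
```

Choosing `bucket_size` as a multiple of `batch_size` keeps every batch full except possibly those cut from the last bucket.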


In [None]:
# =========================
# 4. Samplers and Collate
# =========================
print("Starting: 4. Samplers and Collate")
class ImbalancedBagSampler(Sampler):
    """
    Oversamples positive bags to balance classes, with rebalancing every epoch.
    """
    def __init__(self, dataset, pos_neg_ratio=1.0):
        """
        dataset: your MILReadDataset
        pos_neg_ratio: ratio of positive to negative samples after oversampling.
                       - 1.0 = fully balanced (default)
                       - 0.5 = positives are half as many as negatives
                       - >1.0 = more positives than negatives
        """
        self.dataset = dataset
        self.pos_neg_ratio = pos_neg_ratio
        # Pre-cache labels to avoid repeated dataset access
        self.bags = self.dataset.bags
        self.labels = np.array([self.dataset.bag_labels[k] for k in self.bags], dtype=np.int64)
        self.pos_idx = np.where(self.labels == 1)[0]
        self.neg_idx = np.where(self.labels == 0)[0]

        if len(self.pos_idx) == 0:
            raise ValueError("No positive bags found in dataset; pos_neg_ratio sampling not possible.")
        
    def sample_indices(self):
        n_pos_target = int(len(self.neg_idx) * self.pos_neg_ratio)
        sampled_pos = np.random.choice(self.pos_idx, size=n_pos_target, replace=True)
        combined = np.concatenate([self.neg_idx, sampled_pos])
        np.random.shuffle(combined)
        return combined

    def __iter__(self):
        return iter(self.sample_indices())

    def __len__(self):
        # length in terms of number of sampled bag indices
        return len(self.neg_idx) + int(len(self.neg_idx) * self.pos_neg_ratio)
    
class BucketBatchSampler(Sampler):
    """
    Groups bags of similar lengths (based on n_reads) into buckets,
    shuffles bucket order each epoch, and yields random batches.
    Set bucket_size to a multiple of batch_size so that only the last
    bucket can produce a short batch.
    """
    def __init__(self, dataset, base_sampler, batch_size=50, bucket_size=200):
        self.dataset = dataset
        self.base_sampler = base_sampler
        self.batch_size = batch_size
        self.bucket_size = bucket_size

    def __iter__(self):
        # Get sampled indices from ImbalancedBagSampler
        indices = self.base_sampler.sample_indices()
        lengths = np.array([self.dataset.bag_lengths[self.dataset.bags[i]] for i in indices])
        sorted_idx = indices[np.argsort(lengths)]

        # Split into buckets of similar lengths
        buckets = [sorted_idx[i:i+self.bucket_size] for i in range(0, len(sorted_idx), self.bucket_size)]
        np.random.shuffle(buckets)

        lst_of_batch = []
        for bucket in buckets:
            np.random.shuffle(bucket)
            for i in range(0, len(bucket), self.batch_size):
                batch = bucket[i:i+self.batch_size]
                lst_of_batch.append(batch)
        np.random.shuffle(lst_of_batch)

        for batch in lst_of_batch:
            yield batch.tolist()

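    def __len__(self):
        # Count the batches __iter__ actually yields: each bucket is split
        # independently, so every bucket may end with one short batch.
        total = len(self.base_sampler)
        full_buckets, remainder = divmod(total, self.bucket_size)
        per_bucket = -(-self.bucket_size // self.batch_size)  # ceil division
        return full_buckets * per_bucket + -(-remainder // self.batch_size)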

# Collate function for minimal padding
def collate_fn(batch):
    """
    batch: list of tuples (bag_read_level, label, tid, pos)
    Returns:
      padded_feats: (B, max_len, feat_dim)
      labels: (B,)
      tids, positions
    """
    bag_read_level_list, labels, tids, positions = zip(*batch)

    max_len = max(x.shape[0] for x in bag_read_level_list)
    feat_dim = bag_read_level_list[0].shape[1]
    padded_feats = torch.zeros(len(bag_read_level_list), max_len, feat_dim, dtype=torch.float32)
    for i, nf in enumerate(bag_read_level_list):
        padded_feats[i, :nf.shape[0], :] = nf

    labels = torch.stack(labels)

    return padded_feats, labels, tids, positions

print("Ending: 4. Samplers and Collate")

## 5. Model Architecture: MILReadOnly

This model applies a *read-level encoder* followed by a *bag-level pooling mechanism*.

### **Components**
- **ReadEncoder:** 3-layer residual MLP with LayerNorm and GELU activations.
- **Bag Predictor:** Linear layer projecting read embeddings → per-read logits.

### **Pooling Options**
1. **Noisy-OR** – probabilistic aggregation:
   \[
   P(\text{bag positive}) = 1 - \prod_{i=1}^N (1 - P_i)
   \]
2. **Mean pooling** – average of read probabilities  
3. **Max pooling** – selects the most confident read probability
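
For bags with many confident reads, the product in the Noisy-OR formula can reach zero in floating point, so the bag probability saturates at exactly 1. One possible alternative, shown here as a plain-Python sketch and not as the notebook's implementation, works in log space using the identity log(1 - P_bag) = sum_i log(1 - P_i). The function names are hypothetical.

```python
import math

def noisy_or_logit(read_logits):
    """Bag logit under Noisy-OR, computed from per-read logits in log space.

    Assumes at least one read has a non-negligible probability, so that
    log(1 - P_bag) is strictly negative.
    """
    # log(1 - p_i) = -softplus(z_i), written in a form that does not overflow
    log_one_minus_p = [-(max(z, 0.0) + math.log1p(math.exp(-abs(z))))
                       for z in read_logits]
    s = sum(log_one_minus_p)               # log(1 - P_bag)
    log_p_bag = math.log(-math.expm1(s))   # log(1 - exp(s)) = log(P_bag)
    return log_p_bag - s

def noisy_or_prob(read_probs):
    """Direct product form of the Noisy-OR formula."""
    prod = 1.0
    for p in read_probs:
        prod *= (1.0 - p)
    return 1.0 - prod

probs = [0.1, 0.2, 0.3]
logits = [math.log(p / (1 - p)) for p in probs]
p_direct = noisy_or_prob(probs)                        # 1 - 0.9 * 0.8 * 0.7
p_from_logit = 1 / (1 + math.exp(-noisy_or_logit(logits)))
print(p_direct, p_from_logit)

# Fifty confident reads: the direct product rounds to 0, the log-space logit stays finite.
print(noisy_or_prob([1 / (1 + math.exp(-5.0))] * 50), noisy_or_logit([5.0] * 50))
```

Both routes agree on small bags; the log-space form keeps a usable gradient signal when the direct probability has already rounded to 1.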

### **Output Flow**
| Step | Shape | Description |
|------|--------|-------------|
| Reads | (B, N, read_dim) | Input reads per bag |
| Encoded reads | (B, N, hidden_dim) | Instance representations |
| Read logits | (B, N) | Read-level scores |
| Bag probability | (B,) | Aggregated bag prediction |
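
In batched training, shorter bags are zero-padded up to `N`, and pooling over all `N` slots lets the padded slots contribute to the bag probability. One possible refinement, sketched here in NumPy under the assumption that a boolean mask of real reads is carried alongside the features, is to exclude padded slots from each pooling rule. The `masked_pool` helper and the mask are hypothetical and not part of the notebook's model.

```python
import numpy as np

def masked_pool(read_probs, mask, method="noisy-or"):
    """
    read_probs: (B, N) per-read probabilities, including padded slots
    mask:       (B, N) boolean, True for real reads, False for padding
    returns:    (B,) bag probabilities that ignore padded slots
    """
    if method == "noisy-or":
        # A padded slot contributes a factor of 1 to the product.
        return 1.0 - np.prod(np.where(mask, 1.0 - read_probs, 1.0), axis=1)
    if method == "mean":
        return (read_probs * mask).sum(axis=1) / np.maximum(mask.sum(axis=1), 1)
    if method == "max":
        return np.where(mask, read_probs, -np.inf).max(axis=1)
    raise ValueError(f"unknown pooling method: {method}")

# Two bags padded to N=3; the second bag has only one real read.
read_probs = np.array([[0.1, 0.2, 0.3],
                       [0.5, 0.4, 0.4]])   # the 0.4 entries stand in for outputs on padding
mask = np.array([[True, True, True],
                 [True, False, False]])

for m in ("noisy-or", "mean", "max"):
    print(m, masked_pool(read_probs, mask, m))
```

With the mask applied, the second bag's prediction depends only on its single real read, whichever pooling rule is used.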


In [None]:
# =========================
# 5. MIL Model (read encoder + pooling)
# =========================
print("Starting: 5. MIL Model")
class ReadEncoder(nn.Module):
    def __init__(self, read_dim, hidden_dim=128, dropout=0.2):
        super().__init__()
        self.block1 = nn.Sequential(
            nn.Linear(read_dim, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout)
        )
        self.block2 = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout)
        )
        self.block3 = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout)
        )

    def forward(self, x):
        # Residual connections
        h = self.block1(x)
        h = self.block2(h) + h
        h = self.block3(h) + h
        return h


class MILReadOnly(nn.Module):
    def __init__(self, read_dim, hidden_dim=128, dropout=0.2, pooling="noisy-or"):
        """
        read_dim: dimension of per-read feature vector
        hidden_dim: hidden dim for read encoder
        dropout: dropout probability
        pooling: str, one of ["noisy-or", "mean", "max"]
        """
        super().__init__()
        assert pooling in ["noisy-or", "mean", "max"], \
            "pooling must be one of 'noisy-or', 'mean', 'max'"
        self.pooling_type = pooling

        # Per-read encoder
        self.read_encoder = ReadEncoder(read_dim, hidden_dim, dropout)


        # Bag-level predictor
        self.bag_predictor = nn.Linear(hidden_dim, 1)

    def forward(self, bag_read_level):
        """
        bag_read_level: (B, N, read_dim)
        Returns:
            bag_logits: (B,) logit of the pooled bag probability
            read_probs: (B, N) per-read probabilities (returned in place of attention weights)
        """
        B, N, _ = bag_read_level.shape

        # Encode reads
        read_feats = self.read_encoder(bag_read_level)  # (B, N, hidden_dim)

        # Per-read logits for pooling methods
        read_logits = self.bag_predictor(read_feats).squeeze(-1)  # (B, N)
        read_probs = torch.sigmoid(read_logits)

        if self.pooling_type == "noisy-or":
            bag_probs = 1 - torch.prod(1 - read_probs, dim=1)
        elif self.pooling_type == "mean":
            bag_probs = read_probs.mean(dim=1)
        elif self.pooling_type == "max":
            bag_probs, _ = read_probs.max(dim=1)

        attn_weights = read_probs  # just to return something consistent
        bag_logits = torch.log(bag_probs / (1 - bag_probs + 1e-7) + 1e-7)

        return bag_logits, attn_weights

# ---------------------------
# Focal Loss (optional alternative to BCEWithLogitsLoss)
# ---------------------------
class FocalLoss(nn.Module):
    def __init__(self, alpha=1.0, gamma=2.0, reduction="mean"):
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.reduction = reduction

    def forward(self, logits, targets):
        bce_loss = F.binary_cross_entropy_with_logits(logits, targets.float(), reduction="none")
        probs = torch.sigmoid(logits)
        pt = probs * targets + (1 - probs) * (1 - targets)
        focal_factor = (1 - pt) ** self.gamma
        loss = self.alpha * focal_factor * bce_loss
        if self.reduction == "mean":
            return loss.mean()
        elif self.reduction == "sum":
            return loss.sum()
        else:
            return loss
        
def predict_probs(model, bag_read_level):
    logits, read_probs = model(bag_read_level)
    return torch.sigmoid(logits), read_probs

print("Ending: 5. MIL Model")


## 6a. Split Train / Val / Test

This step loads the dataset from disk and splits it into Train, Validation, and Test subsets using the `set_type` column.

The splits are wrapped in `MILReadDataset` objects.

`train_loader` is rebuilt each epoch (inside the training loop) to support dynamic oversampling of positive bags.

`val_loader` and `test_loader` are initialized once with `batch_size=1`, so no padding or custom collate function is needed for them.
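
A common convention when scaling split datasets is to fit the scaling statistics on the training split only and to apply that same transform to validation and test data. The NumPy sketch below illustrates the idea with made-up features. For `MILReadDataset`, this would correspond to passing `train_ds.numeric_scaler` and `train_ds.kmer_scaler` when building the other splits, which is an assumption about usage rather than a documented requirement.

```python
import numpy as np

rng = np.random.default_rng(0)
train = rng.normal(loc=100.0, scale=5.0, size=(1000, 3))   # made-up training features
val = rng.normal(loc=100.0, scale=5.0, size=(200, 3))      # made-up validation features

# Fit the statistics on the training split only ...
mean, std = train.mean(axis=0), train.std(axis=0)

# ... and apply the same transform to every split.
train_scaled = (train - mean) / std
val_scaled = (val - mean) / std

print(train_scaled.mean(axis=0).round(3), val_scaled.mean(axis=0).round(3))
```

The validation means are close to, but not exactly, zero. That is expected: the transform encodes the training distribution, so a given raw value maps to the same scaled value in every split.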

In [None]:
# =========================
# 6a. Split Train/Val/Test 
# =========================
print("Starting: 6a. Split Train/Val/Test")

reads_df = pd.read_parquet(f"{DATA_DIR}/{DATA_FILE}")
print(f"Loading in {DATA_FILE}")

# Now split by set_type
train_df = reads_df[reads_df['set_type'] == 'Train']
val_df   = reads_df[reads_df['set_type'] == 'Val']
test_df  = reads_df[reads_df['set_type'] == 'Test']

train_ds = MILReadDataset(train_df, n_reads_per_site=N_READS_PER_SITE, numeric_scaler=NUMERIC_SCALER, kmer_scaler=KMER_SCALAR, use_delta=USE_DELTA)
# Reuse the scalers held by the training dataset so all splits share one feature scale
val_ds = MILReadDataset(val_df, n_reads_per_site=N_READS_PER_SITE, numeric_scaler=train_ds.numeric_scaler, kmer_scaler=train_ds.kmer_scaler, use_delta=USE_DELTA)
test_ds = MILReadDataset(test_df, n_reads_per_site=N_READS_PER_SITE, numeric_scaler=train_ds.numeric_scaler, kmer_scaler=train_ds.kmer_scaler, use_delta=USE_DELTA)

# train_loader is built inside the epoch loop so that the oversampled positive bags are re-drawn each epoch
val_loader   = DataLoader(val_ds, batch_size=1, shuffle=False)
test_loader  = DataLoader(test_ds, batch_size=1, shuffle=False)
print("Ending: 6a. Split Train/Val/Test")

## 6b. Define Model and Optimizer

In this step:

The `read_dim` (number of input features per read) is automatically inferred from the training dataset.

A `MILReadOnly` model is initialized using this `read_dim` together with `HIDDEN_DIM` (128 by default) and `DROPOUT` (0.2 by default).

An Adam optimizer is used with learning rate `LR` (1e-3 by default).

In [None]:
# =========================
# 6b. Define model + optimizer
# =========================
# Get the read feature dimension from the dataset
print("Starting: 6b. Define model + optimizer")

read_dim = train_ds.read_dim

print(f"Automatically detected read_dim={read_dim}")

# Initialize model using the inferred dimensions
model = MILReadOnly(read_dim=read_dim, hidden_dim=HIDDEN_DIM, dropout=DROPOUT, pooling=POOLING_METHOD).to(device)

# Adam optimizer with the learning rate set in the configuration (LR)
optimizer = torch.optim.Adam(model.parameters(), lr=LR)

print("Ending: 6b. Define model + optimizer")

## 7. Training Pipeline

Each epoch:
1. Rebuilds a new **oversampled training loader** using fresh positive sampling.
2. Trains over all sampled bags with a weighted binary cross-entropy loss on the bag logits.
3. Evaluates validation ROC-AUC and PR-AUC.
4. Applies **early stopping** when the sum of validation ROC-AUC and PR-AUC has not improved for `EARLY_STOP_PATIENCE` epochs.
5. Saves the model checkpoint each epoch.

**Key metrics:**
- Training loss
- Validation ROC-AUC
- Validation PR-AUC
- Epoch time and average time per bag
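
The early-stopping rule can be traced on its own with a few lines of plain Python. The score sequence is made up, and the helper name `early_stopping_epoch` is hypothetical; the comparison is strict, so an epoch that only ties the best score counts as no improvement.

```python
def early_stopping_epoch(scores, patience):
    """
    scores: per-epoch validation scores (here ROC-AUC + PR-AUC), higher is better
    Returns (stop_epoch, best_epoch), both 1-based.
    """
    best_score, best_epoch, waited = 0.0, 0, 0
    for epoch, score in enumerate(scores, start=1):
        if score > best_score:
            best_score, best_epoch, waited = score, epoch, 0
        else:
            waited += 1
            if waited >= patience:
                return epoch, best_epoch
    return len(scores), best_epoch

# Made-up combined scores for illustration.
scores = [1.10, 1.25, 1.31, 1.30, 1.29, 1.31, 1.28]
print(early_stopping_epoch(scores, patience=3))  # -> (6, 3)
```

Here training stops at epoch 6, while the best checkpoint is the one from epoch 3. Since a checkpoint is written every epoch, the best one can still be recovered afterwards from its file name.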


### **Handling Class Imbalance**

RNA modification sites are typically imbalanced, with far more negatives than positives.

We use **two strategies** to handle imbalance:

1. **Oversampling** (data-level correction):  
   Implemented via `ImbalancedBagSampler`, which resamples positive bags with replacement
   according to a configurable positive-to-negative ratio (`pos_neg_ratio`).

2. **Loss weighting** (loss-level correction):  
   The `BCEWithLogitsLoss` includes a `pos_weight` parameter to penalize misclassified positives more strongly.

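The two corrections multiply: with the default settings, oversampling raises the number of positive bags per epoch to half the number of negatives, and the resulting `pos_weight` of 2 doubles each positive's loss contribution, so both classes carry roughly equal total weight in the loss.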
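
The arithmetic behind the two corrections can be checked with a short sketch. The negative count is made up and the function name is hypothetical; the formulas follow the `POS_NEG_RATIO` and `EFFECTIVE_RATIO` settings used for the sampler and the loss weight.

```python
def imbalance_settings(n_neg, pos_neg_ratio, effective_ratio):
    """Mirror the arithmetic used for the sampler and the loss weight."""
    effective_pos = int(n_neg * pos_neg_ratio)          # positives drawn per epoch
    pos_weight = (n_neg / max(effective_pos, 1)) * effective_ratio
    # Total loss weight on positives relative to negatives after both corrections.
    weighted_pos_to_neg = effective_pos * pos_weight / n_neg
    return effective_pos, pos_weight, weighted_pos_to_neg

# Made-up class count; the ratios are this notebook's defaults.
print(imbalance_settings(n_neg=10_000, pos_neg_ratio=0.5, effective_ratio=1.0))
# -> (5000, 2.0, 1.0)
```

Raising `EFFECTIVE_RATIO` above 1 tilts the loss further toward positives without changing how many positive bags are sampled.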



In [None]:
# =========================
# 7. Training Loop
# =========================
print("Starting: 7. Training Loop")

# ---- Early-stopping state ----
best_val_score = 0.0
patience_counter = 0


# Count raw labels from the cached bag labels (avoids building every bag tensor)
pos_count = sum(train_ds.bag_labels.values())
neg_count = len(train_ds) - pos_count

# Effective counts
effective_pos_count = int(neg_count * POS_NEG_RATIO)
effective_neg_count = neg_count

print(f"Original positives: {pos_count}, negatives: {neg_count}")
print(f"Effective positives (oversampling {POS_NEG_RATIO*100:.0f}%): {effective_pos_count}")

# Loss weighting
pos_weight_value = (effective_neg_count / max(effective_pos_count, 1)) * EFFECTIVE_RATIO
pos_weight = torch.tensor(pos_weight_value).to(device)
criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weight)
# criterion = FocalLoss(alpha=pos_weight_value, gamma=2.0)

# Sampler
sampler = ImbalancedBagSampler(train_ds, pos_neg_ratio=POS_NEG_RATIO)

# ---- Tracking Metrics ----
metrics_dict = {
    'epoch': [],
    'train_loss': [],
    'val_roc_auc': [],
    'val_pr_auc': [],
    'epoch_time_sec': [],
    'avg_time_per_bag': []
}

for epoch in range(N_EPOCHS):
    epoch_start = time.time()

    # Rebuild the batch sampler and DataLoader each epoch so positives are re-drawn and batches reshuffled
    # Use a bucket_size that is a multiple of batch_size
    batch_sampler = BucketBatchSampler(train_ds, sampler, batch_size=BATCH_SIZE, bucket_size=BUCKET_SIZE)
    train_loader = DataLoader(train_ds, batch_sampler=batch_sampler, collate_fn=collate_fn)

    model.train()
    total_loss = 0.0
    n_bags = 0
    for bag_read_level, label, _, _ in train_loader:
        n_bags += label.size(0)  # count bags, not batches, so the per-bag timing below is accurate
        bag_read_level = bag_read_level.to(device)
        label = label.to(device).float()

        optimizer.zero_grad()
        outs, attns = model(bag_read_level)
        
        loss = criterion(outs, label)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

        # if n_bags % 50000 == 0:
        #     elapsed = time.time() - epoch_start
        #     print(f"  Processed {n_bags} bags, avg time per bag: {elapsed / n_bags:.4f}s")

    total_loss /= len(train_loader)

    # Validation
    model.eval()
    val_preds = []
    val_labels = []

    with torch.no_grad():
        for bag_read_level, label, _, _ in val_loader:
            bag_read_level = bag_read_level.to(device)
            label = label.to(device).view(-1)

            # Forward pass
            out, _ = predict_probs(model, bag_read_level)

            # Convert predictions and labels to list for metrics calculation
            val_preds.extend(out.detach().cpu().numpy())
            val_labels.extend(label.detach().cpu().numpy())
            
    # ---- Compute Metrics ----        
    val_auc = roc_auc_score(val_labels, val_preds)
    val_pr  = average_precision_score(val_labels, val_preds)

    # End timer
    epoch_time = time.time() - epoch_start
    avg_time_per_bag = epoch_time / n_bags

    # Log
    metrics_dict['epoch'].append(epoch+1)
    metrics_dict['train_loss'].append(total_loss)
    metrics_dict['val_roc_auc'].append(val_auc)
    metrics_dict['val_pr_auc'].append(val_pr)
    metrics_dict['epoch_time_sec'].append(epoch_time)
    metrics_dict['avg_time_per_bag'].append(avg_time_per_bag)

    print(f"Epoch {epoch+1}/{N_EPOCHS}: "
          f"TrainLoss={total_loss:.4f}, "
          f"ValROC-AUC={val_auc:.4f}, "
          f"ValPR-AUC={val_pr:.4f}, "
          f"EpochTime={epoch_time:.2f}s, "
          f"AvgTimePerBag={avg_time_per_bag:.4f}s")

    # ---- Early Stopping & Model Saving ----
    val_score = val_auc + val_pr  # combined metric

    torch.save(model.state_dict(), f"{MODEL_DIR}/epoch{epoch+1}_rocauc{val_auc:.3f}_valpr{val_pr:.3f}_valscore{val_score:.3f}.pth")
    print(f"Saved model (ValROC-AUC={val_auc:.4f}, ValPR-AUC={val_pr:.4f}, Score={val_score:.4f}) at epoch {epoch+1}")

    if val_score > best_val_score:
        best_val_score = val_score
        patience_counter = 0
    else:
        patience_counter += 1
        print(f"No improvement in combined score for {patience_counter} epoch(s).")

        if patience_counter >= EARLY_STOP_PATIENCE:
            print(f"Early stopping at epoch {epoch+1} (no improvement for {EARLY_STOP_PATIENCE} epochs).")
            break

print("Ending: 7. Training Loop")

Save the per-epoch training and validation metrics to CSV.

In [None]:
# Assuming metrics_dict already exists
# Convert the dictionary to a DataFrame
df = pd.DataFrame(metrics_dict)

# Define the file path for CSV
file_path = 'metrics.csv'

# Save the DataFrame to a CSV file
df.to_csv(file_path, index=False)

print(f"Dictionary saved to {file_path} as CSV.")

## 8. Evaluation Metrics

After training, the model is evaluated using:

- **ROC-AUC (Receiver Operating Characteristic Area Under Curve):**  
  Measures discrimination between positive and negative bags.
  
- **PR-AUC (Precision-Recall Area Under Curve):**  
  More sensitive to class imbalance, especially when positives are rare.

Both are computed at the **bag level** since labels are per bag.
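
To make the two metrics concrete, the sketch below computes them by hand on a tiny made-up set of bag scores. It assumes there are no tied scores in the average-precision calculation; the notebook itself relies on `roc_auc_score` and `average_precision_score` from scikit-learn, and these helper functions are illustrative stand-ins.

```python
def roc_auc(y_true, y_score):
    """Probability that a random positive outranks a random negative (ties count 0.5)."""
    pos = [s for y, s in zip(y_true, y_score) if y == 1]
    neg = [s for y, s in zip(y_true, y_score) if y == 0]
    wins = sum((p > n) + 0.5 * (p == n) for p in pos for n in neg)
    return wins / (len(pos) * len(neg))

def average_precision(y_true, y_score):
    """Mean of precision@k over the ranks k where a positive appears (assumes no tied scores)."""
    ranked = sorted(zip(y_score, y_true), reverse=True)
    hits, total = 0, 0.0
    for k, (_, y) in enumerate(ranked, start=1):
        if y == 1:
            hits += 1
            total += hits / k
    return total / hits

# Six made-up bags: two positives, four negatives.
y_true = [1, 0, 1, 0, 0, 0]
y_score = [0.9, 0.8, 0.7, 0.4, 0.3, 0.1]
print(roc_auc(y_true, y_score), average_precision(y_true, y_score))
```

One negative ranked above a positive costs ROC-AUC one of eight positive-negative pairs (0.875), while average precision drops to 5/6 because that negative lowers the precision at the second positive.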


In [None]:
def load_and_eval_model(
    model_path,
    device="cpu",
    test_loader=None,
    manual_model_file=None,
    is_full_test=False,
    save_pr_curve=False,
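    selection_metric="valscore"
):
    """
    Load a checkpoint from `model_path` and evaluate it on `test_loader`.

    Unless `manual_model_file` is given, the checkpoint with the highest
    `selection_metric` parsed from the file name is used (ties go to the later epoch).
    `test_loader` must use batch_size=1, since predictions are read with `.item()`.
    """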

    valid_metrics = {"valscore", "valpr", "rocauc"}
    if selection_metric not in valid_metrics:
        raise ValueError(f"Invalid selection_metric. Choose from {valid_metrics}")

    # --- Get model file ---
    if manual_model_file:
        model_file = manual_model_file
        print(f"Using manually specified model: {model_file}")
    else:
        pattern = r"epoch(\d+)_rocauc([\d.]+)_valpr([\d.]+)_valscore([\d.]+)\.pth"
        models = []

        for f in os.listdir(model_path):
            match = re.match(pattern, f)
            if match:
                epoch = int(match.group(1))
                rocauc = float(match.group(2))
                valpr = float(match.group(3))               
                valscore = float(match.group(4))

                metric_value = {
                    "valscore": valscore,
                    "valpr": valpr,
                    "rocauc": rocauc
                }[selection_metric]

                models.append((f, metric_value, epoch))

        if not models:
            raise ValueError("No valid model files found in directory.")

        # Sort by selected metric descending, then epoch descending
        models.sort(key=lambda x: (x[1], x[2]), reverse=True)
        model_file = models[0][0]

        print(f"Auto-selected best model by '{selection_metric}': {model_file} "
              f"({selection_metric} = {models[0][1]:.4f}, epoch = {models[0][2]})")

    # --- Load the model ---
    model = MILReadOnly(
        read_dim=test_loader.dataset.read_dim,
        hidden_dim=HIDDEN_DIM,
        dropout=DROPOUT,
        pooling=POOLING_METHOD
    ).to(device)

    model.load_state_dict(torch.load(os.path.join(model_path, model_file), map_location=device))
    model.eval()

    # --- Evaluation ---
    y_true = []
    y_pred = []
    output_rows = []

    with torch.no_grad():
        for bag_read_level, label, tid, pos in test_loader:
            bag_read_level = bag_read_level.to(device)
            label = label.to(device).view(-1)

            out, _ = model(bag_read_level)
            prob = torch.sigmoid(out).item()

            y_true.append(label.item())
            y_pred.append(prob)

            output_rows.append({
                "transcript_id": tid[0],
                "transcript_position": pos.item(),
                "score": prob
            })

    # --- Metrics ---
    roc = roc_auc_score(y_true, y_pred)
    pr  = average_precision_score(y_true, y_pred)

    print(f"[Evaluation] ROC-AUC={roc:.4f}, PR-AUC={pr:.4f}")

    # --- Save predictions ---
    model_name = os.path.splitext(model_file)[0]
    output_file = f"{RESULT_DIR}/{'full_test' if is_full_test else 'test'}_data_output_{model_name}.csv"
    pd.DataFrame(output_rows).to_csv(output_file, index=False)
    print(f"Saved predictions to: {output_file}")

    # --- Save PR Curve ---
    if save_pr_curve:
        precision, recall, _ = precision_recall_curve(y_true, y_pred)
        plt.figure(figsize=(8, 6))
        plt.plot(recall, precision, label=f'PR-AUC = {pr:.4f}', color='blue')
        plt.xlabel('Recall')
        plt.ylabel('Precision')
        plt.title('Precision-Recall Curve')
        plt.grid(True)
        plt.legend(loc='lower left')
        plt.tight_layout()

        pr_curve_file = f"{RESULT_DIR}/{'full_test' if is_full_test else 'test'}_pr_curve_{model_name}.png"
        plt.savefig(pr_curve_file)
        plt.close()
        print(f"Saved Precision-Recall curve to: {pr_curve_file}")


## 9. Saving and Inference

Each model is saved under the `MODEL_DIR` directory with the naming pattern `{MODEL_DIR}/epoch{epoch_number}_rocauc{val_auc:.3f}_valpr{val_pr:.3f}_valscore{val_score:.3f}.pth`.

For inference, use:
```python
# checkpoint_file is a placeholder for one of the file names saved during training
model.load_state_dict(torch.load(f"{MODEL_DIR}/{checkpoint_file}", map_location=device))
model.eval()
with torch.no_grad():
    probs, read_weights = predict_probs(model, bag_tensor)
```
`probs` = bag-level probability

`read_weights` = per-read probabilities, which can serve as rough per-read importance estimates (e.g., for interpretability)
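
Because the validation metrics are encoded in each file name, a checkpoint can be chosen without loading any weights. The sketch below parses names that follow the pattern above and picks the best one for a given metric. The file names are made up, and `pick_checkpoint` is a hypothetical helper that mirrors the auto-selection idea rather than reproducing `load_and_eval_model`.

```python
import re

PATTERN = re.compile(r"epoch(\d+)_rocauc([\d.]+)_valpr([\d.]+)_valscore([\d.]+)\.pth")

def pick_checkpoint(filenames, metric="valscore"):
    """Return the file name with the highest value of `metric`; ties go to the later epoch."""
    column = {"rocauc": 2, "valpr": 3, "valscore": 4}[metric]
    candidates = []
    for name in filenames:
        m = PATTERN.fullmatch(name)
        if m:
            candidates.append((float(m.group(column)), int(m.group(1)), name))
    if not candidates:
        raise ValueError("no checkpoint matches the naming pattern")
    return max(candidates)[2]

# Made-up file names that follow the pattern.
files = [
    "epoch1_rocauc0.810_valpr0.300_valscore1.110.pth",
    "epoch2_rocauc0.850_valpr0.290_valscore1.140.pth",
    "epoch3_rocauc0.840_valpr0.320_valscore1.160.pth",
    "notes.txt",
]
print(pick_checkpoint(files, "valscore"))
print(pick_checkpoint(files, "rocauc"))
```

Note that the metrics in the file names are rounded to three decimals, so two checkpoints that differ only beyond that precision are treated as tied.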

In [None]:
load_and_eval_model(
    MODEL_DIR,
    device=device,
    test_loader=test_loader,
    save_pr_curve=SAVE_PR_CURVE,
    selection_metric=MODEL_SELECTION_METRIC
)


In [None]:
# load_and_eval_model(
#     model_path=MODEL_DIR,
#     device="cuda",
#     test_loader=test_loader,
#     manual_model_file="epoch53_valpr0.42.pth",
#     save_pr_curve=SAVE_PR_CURVE,
#     selection_metric="valscore"
# )


In [None]:
print(f"Loading in {DATA_FILE} for testing on full data")
reads_df = pd.read_parquet(f"{DATA_DIR}/{DATA_FILE}")

# Reuse the scalers fitted on the training split; note that the full data also contains the training bags
test_full_ds = MILReadDataset(reads_df, n_reads_per_site=N_READS_PER_SITE, numeric_scaler=train_ds.numeric_scaler, kmer_scaler=train_ds.kmer_scaler, use_delta=USE_DELTA)

test_full_loader  = DataLoader(test_full_ds, batch_size=1, shuffle=False)


In [None]:
load_and_eval_model(
    model_path=MODEL_DIR,
    device=device,
    test_loader=test_full_loader,
    is_full_test= True,
    save_pr_curve=SAVE_PR_CURVE,
    selection_metric=MODEL_SELECTION_METRIC
)


---
## 10. Summary

✅ **This pipeline implements true Multiple Instance Learning:**
- Each bag (site) consists of multiple unlabeled reads (instances).
- The model outputs read-level probabilities, then pools them into one bag probability.
- The bag prediction is compared against the bag-level label during training.

✅ **Core Features**
| Feature | Description |
|----------|--------------|
| MIL formulation | Proper bag-instance hierarchy |
| Pooling | Noisy-OR / mean / max |
| Imbalance handling | Oversampling + weighted loss |
| Batching | Bucketed variable-length batching |
| Architecture | Read encoder + bag predictor |

This notebook demonstrates a full end-to-end MIL setup for biological read-level data.
