

# ✅ WHAT YOU MUST UPDATE IF TRAINING SCRIPT CHANGES

Any time you change preprocessing or architecture in the **training script**, you must copy over the same edits into this inference script in these sections:

### ✅ 1️⃣ 7-mer encoding

```python
def encode_drach_compact(...)
def one_hot_base(...)
```

### ✅ 2️⃣ Numeric feature selection

```python
numeric_feats = g[[ ... ]]
```

### ✅ 3️⃣ Model architecture

```python
class MultiHeadAttentionPool(...)
class AttentionMIL_v2(...)
```

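Also update the constructor arguments (`read_dim`, `site_dim`, `hidden_dim`, `n_heads`, `dropout`) everywhere the model is instantiated, if any of them changed.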

---


In [11]:
# =========================
# 0. Imports & Setup
# =========================
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader, Sampler
from sklearn.metrics import roc_auc_score, average_precision_score
import pickle
import time
import torch.nn.functional as F
import os, re


# Reproducibility
# A single seed is shared by NumPy and PyTorch. NumPy's global RNG is the one
# used by the dataset's optional read subsampling (np.random.choice).
RNG = 42
np.random.seed(RNG)
torch.manual_seed(RNG)
# Prefer the GPU when one is visible to PyTorch, otherwise fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

Using device: cuda


In [12]:
# =========================
# 1. Load dataset.csv
# =========================
# `file_name` is also read later by run_inference() to name the output CSVs.
file_path = "Raw File/"
file_name = "dataset0.csv"

print(f"Loading {file_path}{file_name}")
reads_df = pd.read_csv(f"{file_path}{file_name}")

# Drop the unnamed first column if present (presumably a saved index column —
# TODO confirm); errors='ignore' makes this a no-op when it is absent.
columns_to_drop = ['Unnamed: 0']
reads_df = reads_df.drop(columns_to_drop, axis=1, errors='ignore')

# Rename the raw column names to the names the rest of this script expects.
reads_df = reads_df.rename(columns={
    'ID': 'transcript_id',
    'POS': 'transcript_position',
    'SEQ': '7mer'
})
# Number of rows (reads) sharing the same (transcript_id, transcript_position).
# NOTE(review): 'n_reads' is not referenced again in the visible part of this script.
reads_df['n_reads'] = reads_df.groupby(['transcript_id', 'transcript_position']).transform('size')
# =========================
# 2. 7-mer Encoding (MUST MATCH TRAINING)
# =========================
print("Encoding 7mers")
def encode_drach_compact(seq):
    """
    Encode a DRACH-centred 7-mer as a compact 16-dimensional one-hot vector.

    Only the informative positions are encoded, each against the alphabet it
    is allowed to take:

        index 0 -> A,C,G,T (4 dims)
        index 1 -> A,G,T   (3 dims, "D")
        index 2 -> A,G     (2 dims, "R")
        index 3 -> skipped (always A)
        index 4 -> skipped (always C)
        index 5 -> A,C,T   (3 dims, "H")
        index 6 -> A,C,G,T (4 dims)

    A base that is not in its position's alphabet contributes an all-zero
    segment (see one_hot_base).

    Args:
        seq: indexable sequence of bases with at least 7 elements.

    Returns:
        np.ndarray of dtype float32 and length 16.
    """
    # (index into seq, alphabet for that index), in output order.
    layout = (
        (0, ['A', 'C', 'G', 'T']),
        (1, ['A', 'G', 'T']),        # D
        (2, ['A', 'G']),             # R
        (5, ['A', 'C', 'T']),        # H
        (6, ['A', 'C', 'G', 'T']),
    )

    bits = []
    for position, alphabet in layout:
        bits.extend(one_hot_base(seq[position], alphabet))

    return np.array(bits, dtype=np.float32)

def one_hot_base(base, allowed):
    """
    Return a one-hot list for `base` over the alphabet `allowed`.

    The result has len(allowed) entries. The entry at the first position
    where `base` occurs in `allowed` is 1; if `base` does not occur at all,
    every entry is 0.
    """
    encoded = [0] * len(allowed)
    try:
        hot_index = allowed.index(base)
    except ValueError:
        # Base outside the alphabet: leave the vector all-zero.
        return encoded
    encoded[hot_index] = 1
    return encoded

# Encode the 7-mer of every row; each cell of '7mer_emb' then holds the
# float32 vector returned by encode_drach_compact.
reads_df['7mer_emb'] = reads_df['7mer'].apply(encode_drach_compact)

# =========================
# 3. Dataset Class (NO LABELS NEEDED)
# =========================
class MILReadDatasetInference(Dataset):
    """
    Label-free multiple-instance dataset for inference.

    Each item is one "bag": all rows of `reads_df` that share the same
    (transcript_id, transcript_position). An item is the tuple
    (read_matrix, site_vector, transcript_id, transcript_position).
    """

    # Column order matters: the delta features and the site-level
    # aggregations index into the numeric matrix by position.
    _NUMERIC_COLS = [
        'PreTime', 'PreSD', 'PreMean',
        'InTime', 'InSD', 'InMean',
        'PostTime', 'PostSD', 'PostMean',
    ]
    # Positions (within _NUMERIC_COLS) of the Pre/In/Post column of each family.
    _FEATURE_GROUPS = {"Time": [0, 3, 6], "SD": [1, 4, 7], "Mean": [2, 5, 8]}
    # Supported aggregation names for agg_config.
    _SITE_STATS = {
        "min": np.min,
        "max": np.max,
        "mean": np.mean,
        "25": lambda vals: np.percentile(vals, 25),
        "75": lambda vals: np.percentile(vals, 75),
    }

    def __init__(self, reads_df, n_reads_per_site=None, agg_config=None):
        """
        reads_df: DataFrame with read-level rows:
                  ['transcript_id','transcript_position','7mer_emb',
                   'PreTime','PreSD','PreMean','InTime','InSD','InMean','PostTime','PostSD','PostMean',...]
        n_reads_per_site: int or None
            - int: sample at most this many reads per site
            - None: use all reads
        agg_config: dict for aggregation at site level, e.g.
            {
              "Time": ["min","max","mean","25","75"],
              "SD": ["mean"],
              "Mean": ["mean"]
            }

        Raises:
            ValueError: if agg_config names an aggregation that is not one of
                "min", "max", "mean", "25", "75". (Previously such names were
                skipped silently, which changed the site vector's length
                without any warning.)
        """

        self.n_reads_per_site = n_reads_per_site
        self.groups = reads_df.groupby(['transcript_id', 'transcript_position'])
        self.bags = list(self.groups.groups.keys())
        self.reads_df = reads_df

        # -----------------------------
        # Feature toggle switches
        # -----------------------------
        # Set an entry to False to exclude that feature block.
        self.read_feature_flags = {
            "numeric": True,   # PreTime..PostMean
            "7mer": True,      # 7mer embedding at read-level
            "delta": True,     # In-Pre and Post-In differences
        }
        self.site_feature_flags = {
            "numeric_aggs": True,  # aggregated Time/SD/Mean stats
            "7mer": True,          # site-level 7mer embedding
        }

        # Default aggregation if not passed
        self.agg_config = agg_config or {
            "Time": ["min", "max", "mean", "25", "75"],
            "SD": ["mean"],
            "Mean": ["mean"]
        }
        for feat_type in self._FEATURE_GROUPS:
            for stat in self.agg_config.get(feat_type, []):
                if stat not in self._SITE_STATS:
                    raise ValueError(
                        f"Unknown aggregation {stat!r} for {feat_type!r}; "
                        f"expected one of {sorted(self._SITE_STATS)}"
                    )

        # Bag lengths: number of rows per site, taken from the groupby sizes
        # rather than by materialising every group's DataFrame.
        self.bag_lengths = self.groups.size().to_dict()

        # -----------------------------
        # Infer dimensions
        # -----------------------------
        # Fetching the first bag sets self.read_dim / self.site_dim.
        self[0]
        print(f"✅ Dataset initialized: read_dim = {self.read_dim}, site_dim = {self.site_dim}")

    def __len__(self):
        """Number of bags (distinct sites)."""
        return len(self.bags)

    def __getitem__(self, idx):
        """
        Build the features of bag number `idx`.

        Returns:
            (read_matrix, site_vector, transcript_id, transcript_position)
            where read_matrix is a float32 tensor with one row per (possibly
            subsampled) read and site_vector is a 1-D float32 tensor.
        """
        tid, pos = self.bags[idx]
        g = self.groups.get_group((tid, pos))

        # The raw numeric matrix is always extracted: the delta features and
        # the site-level aggregations are derived from it even when the
        # "numeric" read-level block itself is switched off.
        numeric_feats = g[self._NUMERIC_COLS].values.astype(np.float32)

        # -----------------------------
        # Read-level matrix
        # -----------------------------
        read_parts = []

        if self.read_feature_flags["numeric"]:
            read_parts.append(numeric_feats)

        if self.read_feature_flags["7mer"]:
            kmer_emb_read = np.vstack(list(g['7mer_emb'].values)).astype(np.float32)
            read_parts.append(kmer_emb_read)

        if self.read_feature_flags["delta"]:
            # For each family: (In - Pre) then (Post - In), in Time, SD, Mean order.
            deltas = []
            for pre_idx, in_idx, post_idx in self._FEATURE_GROUPS.values():
                deltas.append(numeric_feats[:, in_idx] - numeric_feats[:, pre_idx])
                deltas.append(numeric_feats[:, post_idx] - numeric_feats[:, in_idx])
            read_parts.append(np.stack(deltas, axis=1))  # shape (n_reads, 6)

        bag_read_level = np.concatenate(read_parts, axis=1)

        # Random subsampling (only the read-level matrix; site-level
        # statistics below are still computed over all reads of the site).
        if self.n_reads_per_site is not None and bag_read_level.shape[0] > self.n_reads_per_site:
            idxs = np.random.choice(bag_read_level.shape[0], self.n_reads_per_site, replace=False)
            bag_read_level = bag_read_level[idxs]

        # -----------------------------
        # Site-level vector
        # -----------------------------
        site_parts = []

        if self.site_feature_flags["numeric_aggs"]:
            site_aggs = []
            for feat_type, idx_list in self._FEATURE_GROUPS.items():
                stats = self.agg_config.get(feat_type, [])
                for col_idx in idx_list:
                    vals = numeric_feats[:, col_idx]
                    for stat in stats:
                        site_aggs.append(self._SITE_STATS[stat](vals))
            site_parts.append(np.array(site_aggs, dtype=np.float32))

        if self.site_feature_flags["7mer"]:
            # The first read's embedding is used as the site-level embedding.
            first_kmer = np.asarray(g['7mer_emb'].iloc[0], dtype=np.float32).ravel()
            site_parts.append(first_kmer)

        bag_site_level = np.concatenate(site_parts, axis=0).astype(np.float32)

        # Store dims once (for printing at init)
        if not hasattr(self, "read_dim"):
            self.read_dim = bag_read_level.shape[1]
            self.site_dim = bag_site_level.shape[0]
            print(f"Read-level dim: {self.read_dim} | Site-level dim: {self.site_dim}")

        return (torch.tensor(bag_read_level, dtype=torch.float32),
                torch.tensor(bag_site_level, dtype=torch.float32),
                tid, pos)


# Create dataset & dataloader
# n_reads_per_site is left at None, so every read of a site is used.
inference_ds = MILReadDatasetInference(reads_df)
# One bag per batch: run_inference() reads tid[0] and pos.item(), which only
# works with a single site per batch. shuffle=False keeps the dataset's bag order.
inference_loader = torch.utils.data.DataLoader(inference_ds, batch_size=1, shuffle=False)




Loading Raw File/dataset0.csv
Encoding 7mers
Read-level dim: 31 | Site-level dim: 37
✅ Dataset initialized: read_dim = 31, site_dim = 37


In [13]:
# =========================
# 4. Define Model (MUST MATCH TRAINING)
# =========================
class MultiHeadAttentionPool(nn.Module):
    """
    Attention pooling over the instance axis with several independent heads.

    Every head produces one softmax weighting over the N instances of a bag
    and uses it to form a weighted sum of the instance embeddings. The
    per-head sums are concatenated along the feature axis.
    """

    def __init__(self, hidden_dim, n_heads=4):
        super().__init__()
        self.n_heads = n_heads
        # Shared projection followed by one scalar score per head.
        self.proj = nn.Linear(hidden_dim, hidden_dim)
        self.head_score = nn.Linear(hidden_dim, n_heads)

    def forward(self, H):
        """
        Args:
            H: instance embeddings, shape (B, N, hidden_dim).

        Returns:
            (pooled, attn) with pooled of shape (B, hidden_dim * n_heads)
            and attn of shape (B, N, n_heads).
        """
        head_logits = self.head_score(torch.tanh(self.proj(H)))  # (B, N, n_heads)
        attn = torch.softmax(head_logits, dim=1)                 # normalise across instances

        # Weighted sum of instances, one (B, hidden_dim) tensor per head.
        per_head = [
            (attn[..., head:head + 1] * H).sum(dim=1)
            for head in range(self.n_heads)
        ]
        return torch.cat(per_head, dim=1), attn

class AttentionMIL_v2(nn.Module):
    """
    Attention-based MIL classifier with a read branch and a site branch.

    Reads are encoded individually, pooled with multi-head attention and
    projected back to `hidden_dim`; the site vector is encoded separately.
    Both representations are concatenated and mapped to one logit per bag.
    """

    def __init__(self, read_dim, site_dim, hidden_dim=128, n_heads=4, dropout=0.2):
        super().__init__()

        # Per-read encoder: two Linear/LayerNorm/GELU stages, then dropout.
        read_layers = [
            nn.Linear(read_dim, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
        ]
        self.read_encoder = nn.Sequential(*read_layers)

        # Site encoder: a single Linear/LayerNorm/GELU stage, then dropout.
        site_layers = [
            nn.Linear(site_dim, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
        ]
        self.site_encoder = nn.Sequential(*site_layers)

        self.pool = MultiHeadAttentionPool(hidden_dim, n_heads=n_heads)
        # Maps the concatenated head outputs back down to hidden_dim.
        self.pool_proj = nn.Linear(n_heads * hidden_dim, hidden_dim)

        # Head over [pooled reads ; site] -> one logit.
        head_layers = [
            nn.Linear(2 * hidden_dim, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, 1),
        ]
        self.classifier = nn.Sequential(*head_layers)

    def forward(self, bag_read_level, bag_site_level):
        """
        Args:
            bag_read_level: read features, shape (B, N, read_dim).
            bag_site_level: site features, shape (B, site_dim).

        Returns:
            (logits, attn): logits flattened to 1-D, and the attention
            weights returned by the pooling layer.
        """
        read_hidden = self.read_encoder(bag_read_level)      # (B, N, hidden_dim)
        pooled_heads, attn = self.pool(read_hidden)          # (B, hidden_dim * n_heads)
        bag_repr = self.pool_proj(pooled_heads)              # (B, hidden_dim)
        site_repr = self.site_encoder(bag_site_level)        # (B, hidden_dim)
        logits = self.classifier(torch.cat((bag_repr, site_repr), dim=-1))
        return logits.view(-1), attn


# Feature dimensions are taken from the dataset so the model matches its inputs.
read_dim = inference_ds.read_dim
site_dim = inference_ds.site_dim
# NOTE(review): no checkpoint is loaded here — this `model` only has freshly
# initialised weights, and run_inference() below builds and loads its own
# models instead of using it. The message printed below is therefore misleading.
model = AttentionMIL_v2(read_dim=read_dim, site_dim=site_dim, hidden_dim=128, dropout=0.2).to(device)
model.eval()
print("Loaded trained model weights")

# =========================
# 5. Inference & Output
# =========================
def run_inference(model_path, loader,
                  mode="both", top_k=5, device="cpu"):
    """
    Score every bag in `loader` with saved checkpoints and write CSVs.

    Checkpoints are files in `model_path` whose names match
    ``epoch<N>_valpr<float>.pth``; they are ranked by the PR-AUC embedded in
    the file name (ties broken by file name so the choice is deterministic).

    Args:
        model_path: directory containing the checkpoint files.
        loader: DataLoader over a dataset exposing `read_dim` and `site_dim`
            and yielding (reads, site, transcript_id, position) with exactly
            one bag per batch (tid[0] / pos.item() are used below).
        mode: "single" (best checkpoint only), "topk" (average of the top_k
            checkpoints) or "both".
        top_k: number of checkpoints averaged in the ensemble.
        device: device the models and inputs are moved to.

    Returns:
        dict with key "single" and/or "ensemble", each a DataFrame with
        columns transcript_id, transcript_position, score. A key is absent
        when no rows were produced for it.

    Raises:
        ValueError: on an unknown `mode`, on top_k < 1 when an ensemble is
            requested, or when no checkpoint file is found.

    Side effects:
        Writes CSV files under "Results/" (created if missing). The output
        file names embed the module-level `file_name`.
    """
    use_single = mode in ("single", "both")
    use_ensemble = mode in ("topk", "both")
    if not (use_single or use_ensemble):
        raise ValueError(f"Unknown mode {mode!r}; expected 'single', 'topk' or 'both'.")
    if use_ensemble and top_k < 1:
        raise ValueError(f"top_k must be >= 1 for ensemble inference, got {top_k}.")

    # --- Parse saved models ---
    pattern = re.compile(r"epoch\d+_valpr(\d+\.\d+)\.pth")
    models = []
    for f in os.listdir(model_path):
        match = pattern.match(f)
        if match:
            models.append((f, float(match.group(1))))
    if not models:
        raise ValueError("No valid model files found in directory.")

    # --- Sort descending by PR-AUC (file name as deterministic tie-break) ---
    models.sort(key=lambda item: (-item[1], item[0]))

    best_name = models[0][0] if use_single else None
    ensemble_names = [name for name, _ in models[:top_k]] if use_ensemble else []

    def _load_checkpoint(name):
        # Architecture hyperparameters must match the training script.
        net = AttentionMIL_v2(
            read_dim=loader.dataset.read_dim,
            site_dim=loader.dataset.site_dim,
            hidden_dim=128,
            dropout=0.2,
            n_heads=4
        ).to(device)
        # NOTE: torch.load unpickles the file; only load checkpoints you trust.
        state = torch.load(os.path.join(model_path, name), map_location=device)
        net.load_state_dict(state)
        net.eval()
        return net

    # Load each required checkpoint exactly once; in "both" mode the best
    # model is also the first ensemble member.
    needed = list(dict.fromkeys(([best_name] if use_single else []) + ensemble_names))
    loaded = {name: _load_checkpoint(name) for name in needed}

    # Create the output directory up front so a missing folder cannot
    # discard the results after the whole inference pass has run.
    os.makedirs("Results", exist_ok=True)

    # --- Inference (no labels) ---
    single_rows, ens_rows = [], []
    with torch.no_grad():
        for bag_read_level, bag_site_level, tid, pos in loader:
            bag_read_level = bag_read_level.to(device)
            bag_site_level = bag_site_level.to(device)

            # One forward pass per distinct checkpoint for this bag.
            probs = {}
            for name, net in loaded.items():
                out, _ = net(bag_read_level, bag_site_level)
                probs[name] = torch.sigmoid(out).item()

            site_id = tid[0]
            site_pos = pos.item()

            # Single best model
            if use_single:
                single_rows.append({
                    'transcript_id': site_id,
                    'transcript_position': site_pos,
                    'score': probs[best_name]
                })

            # Ensemble: plain mean of the member probabilities
            if use_ensemble:
                preds = [probs[name] for name in ensemble_names]
                ens_rows.append({
                    'transcript_id': site_id,
                    'transcript_position': site_pos,
                    'score': sum(preds) / len(preds)
                })

    # --- Save outputs ---
    outputs = {}
    if single_rows:
        df_single = pd.DataFrame(single_rows)
        df_single.to_csv(f"Results/best_single_inference_of_{file_name}", index=False)
        print("Saved best_single_inference.csv")
        outputs["single"] = df_single

    if ens_rows:
        df_ens = pd.DataFrame(ens_rows)
        df_ens.to_csv(f"Results/top_{top_k}_ensemble_inference_of_{file_name}", index=False)
        print(f"Saved top_{top_k}_ensemble_inference.csv")
        outputs["ensemble"] = df_ens

    return outputs


Loaded trained model weights


In [14]:
# Directory scanned by run_inference() for epoch*_valpr*.pth checkpoints.
model_path = "Models/"
# mode="both": score with the best single checkpoint and with the top-5 ensemble.
results = run_inference(model_path, mode="both", top_k=5, loader=inference_loader, device=device)
# Preview the first rows of each prediction table.
print(results["single"].head())
print(results["ensemble"].head())

Saved best_single_inference.csv
Saved top_5_ensemble_inference.csv
     transcript_id  transcript_position     score
0  ENST00000000233                  244  0.220471
1  ENST00000000233                  261  0.279924
2  ENST00000000233                  316  0.081724
3  ENST00000000233                  332  0.336998
4  ENST00000000233                  368  0.205862
     transcript_id  transcript_position     score
0  ENST00000000233                  244  0.269454
1  ENST00000000233                  261  0.314499
2  ENST00000000233                  316  0.100107
3  ENST00000000233                  332  0.346790
4  ENST00000000233                  368  0.290033


In [15]:
# Completion marker for the notebook run.
print("Done")

Done
