

# ✅ WHAT YOU MUST UPDATE IF TRAINING SCRIPT CHANGES

Any time you change preprocessing or architecture in the **training script**, you must copy over the same edits into this inference script in these sections:

### ✅ 1️⃣ 7-mer encoding

```python
def encode_drach_compact(...)   # and its helper: def one_hot_base(...)
```

### ✅ 2️⃣ Numeric feature selection

```python
numeric_feats = g[[ ... ]]
```

### ✅ 3️⃣ Model architecture

```python
class AttentionMIL(...)
```

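Also update the `input_dim` and `hidden_dim` arguments passed to `AttentionMIL(...)` if they changed. `input_dim` must equal the number of numeric feature columns plus the width of the 7-mer encoding (currently 9 + 16 = 25).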

---


In [3]:
# =========================
# 0. Imports & Setup
# =========================
import os
import pickle
import time

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from sklearn.metrics import roc_auc_score, average_precision_score
from torch.utils.data import Dataset, DataLoader, Sampler

# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print("Using device:", device)

Using device: cuda


In [10]:
# =========================
# 1. Load dataset.csv
# =========================
file_path = "Raw File/"
file_name = "dataset2.csv"

print(f"Loading {file_path}{file_name}")

# 'Unnamed: 0' is dropped when present; errors='ignore' tolerates its absence.
columns_to_drop = ['Unnamed: 0']

# Read the CSV, discard the unwanted column and switch to the column names
# used by the rest of this script.
reads_df = (
    pd.read_csv(f"{file_path}{file_name}")
    .drop(columns=columns_to_drop, errors='ignore')
    .rename(columns={
        'ID': 'transcript_id',
        'POS': 'transcript_position',
        'SEQ': '7mer',
    })
)

# Size of the (transcript_id, transcript_position) group each row belongs to.
reads_df['n_reads'] = (
    reads_df
    .groupby(['transcript_id', 'transcript_position'])
    .transform('size')
)
# =========================
# 2. 7-mer Encoding (MUST MATCH TRAINING)
# =========================
print("Encoding 7mers")
def encode_drach_compact(seq):
    """
    Encode a 7-mer as a compact 16-dimensional one-hot vector.

    Each encoded position gets one slot per base listed for it:
    - index 0: A, C, G, T -> 4 dims
    - index 1: A, G, T    -> 3 dims (D)
    - index 2: A, G       -> 2 dims (R)
    - index 3: skipped    -> 0 dims (treated as always A)
    - index 4: skipped    -> 0 dims (treated as always C)
    - index 5: A, C, T    -> 3 dims (H)
    - index 6: A, C, G, T -> 4 dims

    A base that is not listed for its position leaves that position's slots
    at zero (see one_hot_base).

    Returns:
        np.ndarray of dtype float32 with the slots concatenated in index order.
    """
    # (index into seq, bases that receive a slot). Indices 3 and 4 are left
    # out on purpose: they contribute nothing to the vector.
    layout = (
        (0, ['A', 'C', 'G', 'T']),
        (1, ['A', 'G', 'T']),          # D
        (2, ['A', 'G']),               # R
        (5, ['A', 'C', 'T']),          # H
        (6, ['A', 'C', 'G', 'T']),
    )

    flat = []
    for index, alphabet in layout:
        flat += one_hot_base(seq[index], alphabet)

    return np.array(flat, dtype=np.float32)

def one_hot_base(base, allowed):
    """Return a list of 0/1 ints, one per entry of `allowed`, with a single 1
    at the position of `base`; all zeros when `base` is not in `allowed`."""
    width = len(allowed)
    if base not in allowed:
        return [0] * width
    hot = allowed.index(base)
    return [int(slot == hot) for slot in range(width)]

# Encode every row's 7-mer once up front; each cell of '7mer_emb' holds the
# float32 vector returned by encode_drach_compact for that row.
reads_df['7mer_emb'] = reads_df['7mer'].apply(encode_drach_compact)

# =========================
# 3. Dataset Class (NO LABELS NEEDED)
# =========================
class MILReadDatasetInference(Dataset):
    """Label-free dataset that yields one bag of reads per site.

    A site is a unique (transcript_id, transcript_position) pair; its bag is a
    2-D float32 tensor with one row per read, made of the nine numeric signal
    columns followed by that read's '7mer_emb' vector.
    """

    def __init__(self, reads_df, n_reads_per_site=None, seed=None):
        """
        reads_df: DataFrame of read-level features with the columns
                  'transcript_id', 'transcript_position', '7mer_emb',
                  'PreTime', 'PreSD', 'PreMean', 'InTime', 'InSD', 'InMean',
                  'PostTime', 'PostSD', 'PostMean'.
        n_reads_per_site: int or None
            - int: maximum number of reads per site (randomly sampled
              without replacement when a site has more)
            - None: use all reads
        seed: non-negative int or None
            - None: sub-sampling draws from NumPy's global random state, so
              the sampled reads can differ between runs
            - int: sub-sampling is reproducible; a given site always yields
              the same reads for the same seed, regardless of access order
        """
        self.n_reads_per_site = n_reads_per_site
        self.seed = seed
        # group by site (transcript_id, transcript_position)
        self.groups = reads_df.groupby(['transcript_id','transcript_position'])
        self.bags = list(self.groups.groups.keys())
        self.reads_df = reads_df

    def __len__(self):
        """Number of sites (bags)."""
        return len(self.bags)

    def __getitem__(self, idx):
        """Return (bag tensor, transcript_id, transcript_position) for site `idx`."""
        tid, pos = self.bags[idx]
        g = self.groups.get_group((tid,pos))

        # numeric features (selection and order MUST MATCH TRAINING)
        numeric_feats = g[['PreTime','PreSD','PreMean',
                           'InTime','InSD','InMean',
                           'PostTime','PostSD','PostMean']].values.astype(np.float32)

        # k-mer embedding: one row per read
        kmer_emb = np.stack(g['7mer_emb'].values)

        # concatenate numeric + embedding
        bag = np.concatenate([numeric_feats, kmer_emb], axis=1)

        # ------------------------
        # Handle n_reads_per_site
        # ------------------------
        if self.n_reads_per_site is not None and bag.shape[0] > self.n_reads_per_site:
            if self.seed is None:
                # historical behaviour: global RNG, not reproducible
                indices = np.random.choice(bag.shape[0], self.n_reads_per_site, replace=False)
            else:
                # A generator keyed on (seed, site index) makes the sample
                # depend only on the seed and the site, not on which items
                # were fetched before. The modulo maps negative indices onto
                # the same site they address in self.bags.
                rng = np.random.default_rng([self.seed, int(idx) % len(self.bags)])
                indices = rng.choice(bag.shape[0], self.n_reads_per_site, replace=False)
            bag = bag[indices]

        return torch.tensor(bag), tid, pos

# Create dataset & dataloader
# n_reads_per_site is left at its default (None), so every read of a site is
# used. shuffle=False keeps the sites in dataset order; batch_size=1 yields
# one site per step (presumably because bags differ in length).
inference_ds = MILReadDatasetInference(reads_df=reads_df)
inference_loader = DataLoader(dataset=inference_ds, batch_size=1, shuffle=False)




Loading Raw File/dataset2.csv
Encoding 7mers


In [12]:
# =========================
# 4. Define Model (MUST MATCH TRAINING)
# =========================
class AttentionMIL(nn.Module):
    """Attention-pooled multiple-instance classifier.

    Every instance (row) of a bag is encoded by a two-layer MLP, an attention
    head scores each encoded instance, the scores are soft-maxed across the
    bag, and the attention-weighted sum of encodings is classified with a
    sigmoid output.

    Submodule names (`instance_encoder`, `attention`, `classifier`) and their
    layer order define the checkpoint keys, so they must stay as they are.
    """

    def __init__(self, input_dim=25, hidden_dim=64):
        """
        input_dim: number of features per instance.
        hidden_dim: width of the instance encoding and of the attention head.
        """
        super().__init__()

        encoder_layers = [
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
        ]
        self.instance_encoder = nn.Sequential(*encoder_layers)

        scorer_layers = [
            nn.Linear(hidden_dim, hidden_dim),
            nn.Tanh(),
            nn.Linear(hidden_dim, 1),
        ]
        self.attention = nn.Sequential(*scorer_layers)

        self.classifier = nn.Linear(hidden_dim, 1)

    def forward(self, bag):
        """Score one bag.

        bag: tensor whose first dimension indexes instances
             (assumed shape (n_instances, input_dim); confirm against caller).

        Returns (probability, attention_weights); the weights are normalised
        over the first dimension.
        """
        encoded = self.instance_encoder(bag)
        # normalise the per-instance scores across the bag (dim 0)
        weights = self.attention(encoded).softmax(dim=0)
        # attention-weighted sum collapses the bag into a single vector
        pooled = (weights * encoded).sum(dim=0)
        probability = self.classifier(pooled).sigmoid()
        return probability, weights

# Build the network with the training-time dimensions and place it on the
# target device before restoring the checkpoint into it.
model = AttentionMIL(input_dim=25, hidden_dim=64)
model = model.to(device)

# NOTE(review): torch.load unpickles the file; only load checkpoints that
# come from a trusted source.
model.load_state_dict(
    torch.load("models/epoch38_valpr0.42.pth", map_location=device)
)

# Switch the module to evaluation mode for inference.
model.eval()
print("Loaded trained model weights")

# =========================
# 5. Inference & Output
# =========================
output_file_name = f"results_of_{file_name}"
output_file_path = "Results/"

# Create the output folder BEFORE the inference pass: to_csv does not create
# missing directories, and failing only at the very end would throw away
# every score computed below.
os.makedirs(output_file_path, exist_ok=True)

output_rows = []

# Scoring only: no gradients are needed.
with torch.no_grad():
    for bag, tid, pos in inference_loader:
        # The loader uses batch_size=1, so index 0 is the single bag (and the
        # single id / position) of this step.
        bag = bag[0].to(device)
        out, _ = model(bag)
        output_rows.append({
            'transcript_id': tid[0],
            'transcript_position': pos.item(),
            'score': out.item()
        })

# Explicit columns keep the CSV header present even when there are no rows.
output_df = pd.DataFrame(
    output_rows,
    columns=['transcript_id', 'transcript_position', 'score'],
)
output_df.to_csv(f"{output_file_path}{output_file_name}", index=False)

print(f"Saved predictions of {file_name} to {output_file_path}{output_file_name}")


Loaded trained model weights
Saved predictions of dataset2.csv to Results/results_of_dataset2.csv
