In [1]:
# =========================
# 0. Imports & Setup
# =========================
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader, Sampler
from sklearn.metrics import roc_auc_score, average_precision_score
import pickle
import time

# Reproducibility
RNG = 42
np.random.seed(RNG)
torch.manual_seed(RNG)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

# =========================
# 1. Load and Preparing datasets
# =========================
print("Starting: 1. Load and Prepare datasets")
reads_df = pd.read_csv("Dataset/dataset.csv")

columns_to_drop = ['Unnamed: 0']
reads_df = reads_df.drop(columns_to_drop, axis=1, errors='ignore')

# sorry cherron i cannot with the colnames
reads_df = reads_df.rename(columns={
    'ID': 'transcript_id',
    'POS': 'transcript_position',
    'SEQ': '7mer'
})
reads_df['n_reads'] = reads_df.groupby(['transcript_id', 'transcript_position']).transform('size')


print("Ending: 1. Load and Prepare datasets")
# =========================
# 2. Preprocessing: 7-mer embedding
# =========================
print("Starting: 2. Preprocessing: 7-mer embedding")

def encode_drach_compact(seq):
    """
    Compact one-hot encoding of a 7-mer centered on a DRACH motif.
    Positions:
    - 0: full one-hot (A,C,G,T) â†’ 4 dims
    - 1: D (A,G,T) â†’ 3 dims
    - 2: R (A,G)   â†’ 2 dims
    - 3: A (fixed) â†’ 0 dims
    - 4: C (fixed) â†’ 0 dims
    - 5: H (A,C,T) â†’ 3 dims
    - 6: full one-hot (A,C,G,T) â†’ 4 dims
    Total: 16-dimensional vector
    """
    encoding = []

    base = seq[0]
    encoding.extend(one_hot_base(base, ['A', 'C', 'G', 'T']))

    base = seq[1]
    encoding.extend(one_hot_base(base, ['A', 'G', 'T']))  # D

    base = seq[2]
    encoding.extend(one_hot_base(base, ['A', 'G']))       # R

    # skip position 3 (always A)
    # skip position 4 (always C)

    base = seq[5]
    encoding.extend(one_hot_base(base, ['A', 'C', 'T']))  # H

    base = seq[6]
    encoding.extend(one_hot_base(base, ['A', 'C', 'G', 'T']))

    return np.array(encoding, dtype=np.float32)

def one_hot_base(base, allowed):
    """One-hot encode base using only allowed bases."""
    vec = [0] * len(allowed)
    if base in allowed:
        vec[allowed.index(base)] = 1
    return vec

reads_df['7mer_emb'] = reads_df['7mer'].apply(encode_drach_compact)
print("Ending: 2. Preprocessing: 7-mer embedding")


# =========================
# 3. Assign split bins
# =========================
print("Starting: 3. Assign split bins")

def assign_set_type_by_gene(reads_df, split_ratios={'Train': 0.8, 'Val': 0.1, 'Test': 0.1}, random_state=42):
    """
    Assigns each row in reads_df a 'set_type' of Train, Val, or Test,
    ensuring all rows with the same gene_id are in the same set,
    and total number of rows (not just genes) in each set matches desired ratios.
    Label distribution is approximately balanced using a greedy strategy.
    """

    # Step 1: Get stats per gene
    gene_stats = (
        reads_df
        .groupby('gene_id')['label']
        .value_counts()
        .unstack(fill_value=0)
        .rename(columns={0: 'label_0', 1: 'label_1'})
        .reset_index()
    )
    gene_stats['total'] = gene_stats['label_0'] + gene_stats['label_1']

    # Shuffle genes for randomness
    gene_stats = gene_stats.sample(frac=1, random_state=random_state).reset_index(drop=True)

    # Step 2: Overall label distribution and target row counts
    total_rows = gene_stats['total'].sum()
    total_label_1 = gene_stats['label_1'].sum()
    overall_pos_rate = total_label_1 / total_rows

    target_rows = {k: total_rows * split_ratios[k] for k in split_ratios}

    # Step 3: Initialize bins
    bins = {
        'Train': {'genes': [], 'label_0': 0, 'label_1': 0, 'total': 0},
        'Val': {'genes': [], 'label_0': 0, 'label_1': 0, 'total': 0},
        'Test': {'genes': [], 'label_0': 0, 'label_1': 0, 'total': 0},
    }

    def pick_bin():
        # Find the bin with the biggest gap between current and target row count
        diffs = {k: target_rows[k] - bins[k]['total'] for k in bins}
        # Choose the bin that needs rows the most
        return max(diffs, key=diffs.get)

    # Step 4: Assign genes to bins to match row targets and label balance
    for _, row in gene_stats.iterrows():
        chosen_bin = pick_bin()
        bins[chosen_bin]['genes'].append(row['gene_id'])
        bins[chosen_bin]['label_0'] += row['label_0']
        bins[chosen_bin]['label_1'] += row['label_1']
        bins[chosen_bin]['total'] += row['total']

    # Step 5: Map gene_id â†’ set_type
    gene_to_set = {}
    for set_name, bin_data in bins.items():
        for gene_id in bin_data['genes']:
            gene_to_set[gene_id] = set_name

    reads_df['set_type'] = reads_df['gene_id'].map(gene_to_set)

    return reads_df


reads_df = assign_set_type_by_gene(reads_df)

set_counts = reads_df['set_type'].value_counts()
print("ðŸ“Š Number of rows in each set:")
for set_name, count in set_counts.items():
    print(f"  - {set_name}: {count} rows")

# Print label distribution per set (normalized)
label_distributions = reads_df.groupby('set_type')['label'].value_counts(normalize=True).unstack()

print("\nðŸ“ˆ Label distribution (percentage of label 0 and 1) in each set:")
for set_name in label_distributions.index:
    label_0_pct = label_distributions.loc[set_name].get(0, 0) * 100
    label_1_pct = label_distributions.loc[set_name].get(1, 0) * 100
    print(f"  - {set_name}:")
    print(f"      â€¢ Label 0: {label_0_pct:.2f}%")
    print(f"      â€¢ Label 1: {label_1_pct:.2f}%")

print("Ending: 3. Assign split bins")



Using device: cuda
Starting: 1. Load and Prepare datasets
Ending: 1. Load and Prepare datasets
Starting: 2. Preprocessing: 7-mer embedding
Ending: 2. Preprocessing: 7-mer embedding
Starting: 3. Assign split bins
ðŸ“Š Number of rows in each set:
  - Train: 8820055 rows
  - Val: 1105069 rows
  - Test: 1101982 rows

ðŸ“ˆ Label distribution (percentage of label 0 and 1) in each set:
  - Test:
      â€¢ Label 0: 94.08%
      â€¢ Label 1: 5.92%
  - Train:
      â€¢ Label 0: 95.59%
      â€¢ Label 1: 4.41%
  - Val:
      â€¢ Label 0: 95.90%
      â€¢ Label 1: 4.10%
Ending: 3. Assign split bins


In [2]:
print("Column Data Types:")
print(reads_df.dtypes)

# Display the number of rows
print("\nNumber of Rows:", len(reads_df))

Column Data Types:
transcript_id           object
transcript_position      int64
7mer                    object
PreTime                float64
PreSD                  float64
PreMean                float64
InTime                 float64
InSD                   float64
InMean                 float64
PostTime               float64
PostSD                 float64
PostMean               float64
gene_id                 object
label                    int64
n_reads                  int64
7mer_emb                object
set_type                object
dtype: object

Number of Rows: 11027106


In [3]:
reads_df.to_parquet("Dataset/first_step_processed_dataset.parquet", index=False)
print("Saved processed dataset to Dataset/first_step_processed_dataset.parquet")


Saved processed dataset to Dataset/first_step_processed_dataset.parquet


In [4]:
# =========================
# 4. Dataset class
# =========================
print("Starting: 4. Dataset class")
class MILReadDataset(Dataset):
    def __init__(self, reads_df, n_reads_per_site=None):
        """
        reads_df: DataFrame of read-level features with columns like
                  ['transcript_id', 'transcript_position', '7mer_emb', 'dwell_-1', ...]
        n_reads_per_site: int or None
            - int: maximum number of reads per site (randomly sampled)
            - None: use all reads
        """
        self.n_reads_per_site = n_reads_per_site
        # group by site (transcript_id, transcript_position)
        self.groups = reads_df.groupby(['transcript_id','transcript_position'])
        self.bags = list(self.groups.groups.keys())
        self.reads_df = reads_df

    def __len__(self):
        return len(self.bags)

    def __getitem__(self, idx):
        tid, pos = self.bags[idx]
        g = self.groups.get_group((tid,pos))

        # numeric features
        numeric_feats = g[['PreTime','PreSD','PreMean',
                           'InTime','InSD','InMean',
                           'PostTime','PostSD','PostMean']].values.astype(np.float32)

        # k-mer embedding
        kmer_emb = np.stack(g['7mer_emb'].values)
        
        # concatenate numeric + embedding
        bag = np.concatenate([numeric_feats, kmer_emb], axis=1)

        # ------------------------
        # Handle n_reads_per_site
        # ------------------------
        if self.n_reads_per_site is not None and bag.shape[0] > self.n_reads_per_site:
            # randomly sample n_reads_per_site reads
            indices = np.random.choice(bag.shape[0], self.n_reads_per_site, replace=False)
            bag = bag[indices]

        # label for the bag
        label = g['label'].iloc[0]
        
        return torch.tensor(bag), torch.tensor(label, dtype=torch.float32), tid, pos
print("Ending: 4. Dataset class")

# =========================
# 5. Imbalanced sampler
# =========================
class ImbalancedBagSampler(Sampler):
    def __init__(self, dataset):
        labels = np.array([dataset[i][1].item() for i in range(len(dataset))])
        pos_idx = np.where(labels==1)[0]
        neg_idx = np.where(labels==0)[0]
        self.indices = np.concatenate([np.random.choice(pos_idx, size=len(neg_idx), replace=True), neg_idx])
        np.random.shuffle(self.indices)

    def __iter__(self):
        return iter(self.indices)

    def __len__(self):
        return len(self.indices)


Starting: 4. Dataset class
Ending: 4. Dataset class


In [5]:

# =========================
# 6. Split Train/Val/Test
# =========================
print("Starting: 6. Split Train/Val/Test")

print("Loading in first_step_processed_dataset.parquet")
reads_df = pd.read_parquet("Dataset/first_step_processed_dataset.parquet")

# Now split by set_type
train_df = reads_df[reads_df['set_type'] == 'Train']
val_df   = reads_df[reads_df['set_type'] == 'Val']
test_df  = reads_df[reads_df['set_type'] == 'Test']

train_ds = MILReadDataset(train_df, 20)
val_ds   = MILReadDataset(val_df, 20)
test_ds  = MILReadDataset(test_df, 20)

train_loader = DataLoader(train_ds, batch_size=1, sampler=ImbalancedBagSampler(train_ds))
val_loader   = DataLoader(val_ds, batch_size=1, shuffle=False)
test_loader  = DataLoader(test_ds, batch_size=1, shuffle=False)
print("Ending: 6. Split Train/Val/Test")

# =========================
# 7. Attention MIL Model
# =========================
print("Starting: 7. Attention MIL Model")
class AttentionMIL(nn.Module):
    def __init__(self, input_dim=25, hidden_dim=64):
        super().__init__()
        self.instance_encoder = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU()
        )
        self.attention = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.Tanh(),
            nn.Linear(hidden_dim, 1)
        )
        self.classifier = nn.Linear(hidden_dim, 1)

    def forward(self, bag):
        H = self.instance_encoder(bag)  
        A = torch.softmax(self.attention(H), dim=0)  
        M = torch.sum(A * H, dim=0)  
        out = torch.sigmoid(self.classifier(M))
        return out, A

model = AttentionMIL(input_dim=25, hidden_dim=64).to(device)
criterion = nn.BCELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
print("Ending: 7. Attention MIL Model")

# =========================
# 8. Training Loop
# =========================
print("Starting: 8. Training Loop")


# ---- Hyperparameters ----
n_epochs = 100
early_stop_patience = 10  # stop if no PR-AUC improvement for 5 epochs
best_val_pr = 0.0
patience_counter = 0

# ---- Tracking Metrics ----
metrics_dict = {
    'epoch': [],
    'train_loss': [],
    'val_roc_auc': [],
    'val_pr_auc': [],
    'epoch_time_sec': []
}

for epoch in range(n_epochs):
    start_time = time.time()

    model.train()
    total_loss = 0.0
    for bag, label, _, _ in train_loader:
        bag = bag[0].to(device)
        label = label.to(device)

        optimizer.zero_grad()
        out, attn = model(bag)
        loss = criterion(out, label)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
    total_loss /= len(train_loader)

    # Validation
    model.eval()
    val_preds = []
    val_labels = []
    with torch.no_grad():
        for bag, label, _, _ in val_loader:
            bag = bag[0].to(device)
            label = label.to(device)
            out, attn = model(bag)

            val_preds.append(out.item())
            val_labels.append(label.item())

    # ---- Compute Metrics ----        
    val_auc = roc_auc_score(val_labels, val_preds)
    val_pr  = average_precision_score(val_labels, val_preds)

    # End timer
    epoch_time = time.time() - start_time

    # Log
    metrics_dict['epoch'].append(epoch+1)
    metrics_dict['train_loss'].append(total_loss)
    metrics_dict['val_roc_auc'].append(val_auc)
    metrics_dict['val_pr_auc'].append(val_pr)
    metrics_dict['epoch_time_sec'].append(epoch_time)

    print(f"Epoch {epoch+1}/{n_epochs}: "
          f"TrainLoss={total_loss:.4f}, "
          f"ValROC-AUC={val_auc:.4f}, "
          f"ValPR-AUC={val_pr:.4f}, "
          f"Time={epoch_time:.2f}s")

    # ---- Early Stopping & Model Saving ----
    if val_pr > best_val_pr:
        best_val_pr = val_pr
        patience_counter = 0
        torch.save(model.state_dict(), f"Models/epoch{epoch+1}_valpr{val_pr:.2f}.pth")

        print(f"Saved best model (ValPR-AUC improved to {val_pr:.4f}) at epoch {epoch+1}")
        
    else:
        patience_counter += 1
        print(f"No improvement in PR-AUC for {patience_counter} epoch(s).")

        if patience_counter >= early_stop_patience:
            print(f"Early stopping at epoch {epoch+1} (no PR-AUC improvement for {early_stop_patience} epochs).")
            break

print("Ending: 8. Training Loop")




Starting: 6. Split Train/Val/Test
Loading in first_step_processed_dataset.parquet
Ending: 6. Split Train/Val/Test
Starting: 7. Attention MIL Model
Ending: 7. Attention MIL Model
Starting: 8. Training Loop
Epoch 1/100: TrainLoss=0.5031, ValROC-AUC=0.8686, ValPR-AUC=0.2742, Time=544.06s
Saved best model (ValPR-AUC improved to 0.2742) at epoch 1
Epoch 2/100: TrainLoss=0.4759, ValROC-AUC=0.8774, ValPR-AUC=0.2859, Time=452.43s
Saved best model (ValPR-AUC improved to 0.2859) at epoch 2
Epoch 3/100: TrainLoss=0.4630, ValROC-AUC=0.8790, ValPR-AUC=0.3119, Time=447.26s
Saved best model (ValPR-AUC improved to 0.3119) at epoch 3
Epoch 4/100: TrainLoss=0.4515, ValROC-AUC=0.8869, ValPR-AUC=0.3419, Time=453.23s
Saved best model (ValPR-AUC improved to 0.3419) at epoch 4
Epoch 5/100: TrainLoss=0.4459, ValROC-AUC=0.8941, ValPR-AUC=0.3710, Time=452.71s
Saved best model (ValPR-AUC improved to 0.3710) at epoch 5
Epoch 6/100: TrainLoss=0.4402, ValROC-AUC=0.8871, ValPR-AUC=0.3398, Time=448.96s
No improvement

In [7]:
# =========================
# 9. Testing & Output
# =========================
print("Starting: 9. Testing & Output")
model.load_state_dict(torch.load("Models/epoch38_valpr0.42.pth"))
model.eval()
output_rows = []
y_test_true = []
y_test_pred = []

with torch.no_grad():
    for bag, label, tid, pos in test_loader:
        bag = bag[0].to(device)
        label = label.to(device)
        out, attn = model(bag)
        output_rows.append({'transcript_id': tid[0], 'transcript_position': pos.item(), 'score': out.item()})
        y_test_true.append(label.item())
        y_test_pred.append(out.item())

output_df = pd.DataFrame(output_rows)
output_df.to_csv("output.csv", index=False)
print("Saved predictions to output.csv")
print("Ending: 9. Testing & Output")

# =========================
# 10. Evaluation Metrics
# =========================
print("Starting: 10. Evaluation Metrics")
test_roc_auc = roc_auc_score(y_test_true, y_test_pred)
test_pr_auc  = average_precision_score(y_test_true, y_test_pred)

metrics_df = pd.DataFrame([{'test_roc_auc': test_roc_auc, 'test_pr_auc': test_pr_auc}])
metrics_df.to_csv("evaluation_metrics.csv", index=False)
print(f"Test ROC AUC: {test_roc_auc:.4f}, Test PR AUC: {test_pr_auc:.4f}")
print("Saved evaluation metrics to evaluation_metrics.csv")
print("Ending: 10. Evaluation Metrics")

Starting: 9. Testing & Output
Saved predictions to output.csv
Ending: 9. Testing & Output
Starting: 10. Evaluation Metrics
Test ROC AUC: 0.8924, Test PR AUC: 0.4458
Saved evaluation metrics to evaluation_metrics.csv
Ending: 10. Evaluation Metrics


---

In [9]:
import pandas as pd

# Assuming metrics_dict already exists
# Convert the dictionary to a DataFrame
df = pd.DataFrame(metrics_dict)

# Define the file path for CSV
file_path = 'metrics.csv'

# Save the DataFrame to a CSV file
df.to_csv(file_path, index=False)

print(f"Dictionary saved to {file_path} as CSV.")


Dictionary saved to metrics.csv as CSV.
