In [65]:
import torch
import torch.nn as nn
import numpy as np
import pytorch_lightning as pl
import pandas as pd
from sklearn.metrics import roc_auc_score, auc, precision_recall_curve, confusion_matrix, accuracy_score
from torch.utils.data import DataLoader, TensorDataset
from sklearn.metrics import precision_recall_curve, auc
from pytorch_lightning.callbacks import ModelCheckpoint
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler, MinMaxScaler

In [31]:
train = pd.read_parquet("data/train_all_features.parquet").drop(columns = ["transcript_id", "transcript_position", "seq", "gene_id"])
test = pd.read_parquet("data/test_all_features.parquet").drop(columns = ["transcript_id", "transcript_position", "seq", "gene_id"])

In [45]:
scaler = StandardScaler()
# scaler = MinMaxScaler()

x_train = train.drop(columns = "label")
x_test = test.drop(columns = "label")

train_scaled_arr = scaler.fit_transform(x_train)
test_scaled_arr = scaler.transform(x_test)
train_scaled = pd.DataFrame(train_scaled_arr, columns=x_train.columns)
test_scaled = pd.DataFrame(test_scaled_arr, columns=x_test.columns)

train_scaled["label"] = train["label"].values
test_scaled["label"] = test["label"].values

In [78]:
class BidirectionalRNN(pl.LightningModule):
    def __init__(self, input_size, hidden_size, output_size, learning_rate=1e-3):
        super(BidirectionalRNN, self).__init__()
        self.hidden_size = hidden_size
        self.rnn = nn.RNN(input_size, hidden_size, batch_first=True, bidirectional=True)
        self.fc = nn.Linear(hidden_size * 2, output_size)
        self.learning_rate = learning_rate
        self.criterion = nn.BCEWithLogitsLoss()
        self.val_outputs = []
        
    def forward(self, x):
        # x: (batch_size, seq_len, input_size)
        rnn_out, _ = self.rnn(x)  # rnn_out: (batch_size, seq_len, hidden_size * 2)
        last_hidden = rnn_out[:, -1, :]  # (batch_size, hidden_size * 2)
        out = self.fc(last_hidden)  # (batch_size, output_size)
        return out
    
    def training_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x).squeeze(1)
        loss = self.criterion(logits, y)
        self.log('train_loss', loss)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x).squeeze(1)
        loss = self.criterion(logits, y)
        preds = torch.sigmoid(logits)
        self.log('val_loss', loss, prog_bar=True)

        # Save outputs to instance variable for use in on_validation_epoch_end
        self.val_outputs.append({'preds': preds, 'targets': y})
    
    def on_validation_epoch_end(self):
        # Gather predictions and targets across all validation batches
        all_preds = torch.cat([x['preds'] for x in self.val_outputs], dim=0).cpu().numpy()
        all_targets = torch.cat([x['targets'] for x in self.val_outputs], dim=0).cpu().numpy()
        
        # Calculate Precision-Recall AUC
        precision, recall, _ = precision_recall_curve(all_targets, all_preds)
        pr_auc = auc(recall, precision)
        
        # Log PR AUC
        self.log('val_pr_auc', pr_auc, prog_bar=True)

        # Clear the outputs for the next epoch
        self.val_outputs.clear()

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)
        return optimizer

# Load dataset
def load_dataset(df):
    n_steps = 3

    sequences = []
    for i in range(1, n_steps + 1):
        step_cols = [col for col in df.columns if col.endswith(f'_{i}')]
        step_data = df[step_cols].values
        sequences.append(step_data)
    
    sequences = np.stack(sequences, axis=1)  # (num_samples, n_steps, 3)
    
    labels = df['label'].values
    return sequences, labels

# Create datasets
def create_dataloaders(train_data, train_labels, val_data, val_labels, batch_size=32):
    train_dataset = TensorDataset(torch.tensor(train_data, dtype=torch.float32), torch.tensor(train_labels, dtype=torch.float32))
    val_dataset = TensorDataset(torch.tensor(val_data, dtype=torch.float32), torch.tensor(val_labels, dtype=torch.float32))
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
    return train_loader, val_loader

# Model training
def train_model(train_data, train_labels, val_data, val_labels, input_size, hidden_size, output_size):
    train_loader, val_loader = create_dataloaders(train_data, train_labels, val_data, val_labels)

    model = BidirectionalRNN(input_size=input_size, hidden_size=hidden_size, output_size=output_size)
    
    checkpoint_callback = ModelCheckpoint(
        monitor='val_pr_auc',
        mode='max',
        save_top_k=1,
        filename='best-checkpoint',
        verbose=True
    )

    trainer = pl.Trainer(max_epochs=5, callbacks=[checkpoint_callback], log_every_n_steps=10)
    trainer.fit(model, train_loader, val_loader)

    return model

full_train_data, full_train_labels = load_dataset(train_scaled)

train_data, val_data, train_labels, val_labels = train_test_split(
    full_train_data, full_train_labels, test_size=0.2, random_state=42, stratify=full_train_labels
)

# Model parameters
input_size = train_data.shape[2]  # Based on your dataset
hidden_size = 128
output_size = 1  # Binary classification

model = train_model(train_data, train_labels, val_data, val_labels, input_size, hidden_size, output_size)

Trainer will use only 1 of 10 GPUs because it is running inside an interactive / notebook environment. You may try to set `Trainer(devices=10)` but please note that multi-GPU inside interactive / notebook environments is considered experimental and unstable. Your mileage may vary.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3,4,5,6,7,8,9]

  | Name      | Type              | Params | Mode 
--------------------------------------------------------
0 | rnn       | RNN               | 44.8 K | train
1 | fc        | Linear            | 257    | train
2 | criterion | BCEWithLogitsLoss | 0      | train
--------------------------------------------------------
45.1 K    Trainable params
0         Non-trainable params
45.1 K    Total params
0.180     Total estimated model params size (MB)
3         Modules in train mode
0         Modules in eval mode


Sanity Checking: |                                        | 0/? [00:00<?, ?it/s]

/home/fangyu.hoo/miniconda3/envs/search3/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:424: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=63` in the `DataLoader` to improve performance.
/home/fangyu.hoo/miniconda3/envs/search3/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:424: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=63` in the `DataLoader` to improve performance.


Training: |                                               | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Epoch 0, global step 2421: 'val_pr_auc' reached 0.37105 (best 0.37105), saving model to '/mnt/ssfs/usr/fangyu.hoo/dsa4262/lightning_logs/version_5/checkpoints/best-checkpoint.ckpt' as top 1


Validation: |                                             | 0/? [00:00<?, ?it/s]

Epoch 1, global step 4842: 'val_pr_auc' reached 0.39967 (best 0.39967), saving model to '/mnt/ssfs/usr/fangyu.hoo/dsa4262/lightning_logs/version_5/checkpoints/best-checkpoint.ckpt' as top 1


Validation: |                                             | 0/? [00:00<?, ?it/s]

Epoch 2, global step 7263: 'val_pr_auc' was not in top 1


Validation: |                                             | 0/? [00:00<?, ?it/s]

Epoch 3, global step 9684: 'val_pr_auc' reached 0.42618 (best 0.42618), saving model to '/mnt/ssfs/usr/fangyu.hoo/dsa4262/lightning_logs/version_5/checkpoints/best-checkpoint.ckpt' as top 1


Validation: |                                             | 0/? [00:00<?, ?it/s]

Epoch 4, global step 12105: 'val_pr_auc' reached 0.43014 (best 0.43014), saving model to '/mnt/ssfs/usr/fangyu.hoo/dsa4262/lightning_logs/version_5/checkpoints/best-checkpoint.ckpt' as top 1
`Trainer.fit` stopped: `max_epochs=5` reached.


In [79]:
def create_test_dataloader(test_data, batch_size=32):
    test_dataset = TensorDataset(torch.tensor(test_data, dtype=torch.float32))
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
    return test_loader

test_data, test_labels = load_dataset(test_scaled)
test_loader = create_test_dataloader(test_data)

In [80]:
def predict_test_set(model, test_loader):
    model.eval()  # Set the model to evaluation mode
    all_logits = []
    with torch.no_grad():  # Disable gradient calculation for inference
        for batch in test_loader:
            x_test = batch[0]  # Extract the test features
            logits = model(x_test)  # Get the predictions/logits
            proba = torch.sigmoid(logits)
            all_logits.append(proba)
    
    # Concatenate all logits into a single tensor
    all_logits = torch.cat(all_logits, dim=0).cpu().numpy()
    return all_logits

test_logits = predict_test_set(model, test_loader)

In [81]:
print(f'roc auc: {round(roc_auc_score(test_labels, test_logits),4)}')
precision, recall, thresholds = precision_recall_curve(test_labels, test_logits)
print(f'pr auc: {round(auc(recall, precision),4)}')

roc auc: 0.9066
pr auc: 0.4584
