In [39]:
import torch
import torch.nn as nn
import numpy as np
import pytorch_lightning as pl
import pandas as pd
from sklearn.metrics import roc_auc_score, auc, precision_recall_curve, confusion_matrix, accuracy_score
from torch.utils.data import DataLoader, TensorDataset
from sklearn.metrics import precision_recall_curve, auc
from pytorch_lightning.callbacks import ModelCheckpoint
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler, MinMaxScaler

In [40]:
train = pd.read_parquet("data/train_all_features.parquet").drop(columns = ["transcript_id", "transcript_position", "seq", "gene_id"])
test = pd.read_parquet("data/test_all_features.parquet").drop(columns = ["transcript_id", "transcript_position", "seq", "gene_id"])

In [41]:
scaler = StandardScaler()
# scaler = MinMaxScaler()

x_train = train.drop(columns = "label")
x_test = test.drop(columns = "label")

train_scaled_arr = scaler.fit_transform(x_train)
test_scaled_arr = scaler.transform(x_test)
train_scaled = pd.DataFrame(train_scaled_arr, columns=x_train.columns)
test_scaled = pd.DataFrame(test_scaled_arr, columns=x_test.columns)

train_scaled["label"] = train["label"].values
test_scaled["label"] = test["label"].values

In [None]:
class BidirectionalRNN(pl.LightningModule):
    def __init__(self, input_size, hidden_size, output_size, learning_rate=1e-3):
        super(BidirectionalRNN, self).__init__()
        self.save_hyperparameters()
        self.hidden_size = hidden_size
        self.rnn = nn.RNN(input_size, hidden_size, batch_first=True, bidirectional=True)
        self.fc = nn.Linear(hidden_size * 2, output_size)
        self.learning_rate = learning_rate
        self.criterion = nn.BCEWithLogitsLoss()
        self.val_outputs = []
        
    def forward(self, x):
        # x: (batch_size, seq_len, input_size)
        rnn_out, _ = self.rnn(x)  # rnn_out: (batch_size, seq_len, hidden_size * 2)
        last_hidden = rnn_out[:, -1, :]  # (batch_size, hidden_size * 2)
        out = self.fc(last_hidden)  # (batch_size, output_size)
        return out
    
    def training_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x).squeeze(1)
        loss = self.criterion(logits, y)
        self.log('train_loss', loss)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x).squeeze(1)
        loss = self.criterion(logits, y)
        preds = torch.sigmoid(logits)
        self.log('val_loss', loss, prog_bar=True)

        # Save outputs to instance variable for use in on_validation_epoch_end
        self.val_outputs.append({'preds': preds, 'targets': y})
    
    def on_validation_epoch_end(self):
        # Gather predictions and targets across all validation batches
        all_preds = torch.cat([x['preds'] for x in self.val_outputs], dim=0).cpu().numpy()
        all_targets = torch.cat([x['targets'] for x in self.val_outputs], dim=0).cpu().numpy()
        
        # Calculate Precision-Recall AUC
        precision, recall, _ = precision_recall_curve(all_targets, all_preds)
        pr_auc = auc(recall, precision)

        # Calculate ROC AUC
        roc_auc = roc_auc_score(all_targets, all_preds)
        
        # Combine PR AUC and ROC AUC
        combined_metric = (pr_auc + roc_auc)/2
        
        # Log the combined metric
        self.log('combined_metric', combined_metric, prog_bar=True)

        # Log PR AUC and ROC AUC separately
        self.log('val_pr_auc', pr_auc, prog_bar=True)
        self.log('val_roc_auc', roc_auc, prog_bar=True)

        # Clear the outputs for the next epoch
        self.val_outputs.clear()

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)
        return optimizer

# Load dataset
def load_dataset(df):
    n_steps = 3

    sequences = []
    for i in range(1, n_steps + 1):
        step_cols = [col for col in df.columns if col.endswith(f'_{i}')]
        step_data = df[step_cols].values
        sequences.append(step_data)
    
    sequences = np.stack(sequences, axis=1)  # (num_samples, n_steps, 3)
    
    labels = df['label'].values
    return sequences, labels

# Create datasets
def create_dataloaders(train_data, train_labels, val_data, val_labels, batch_size=32):
    train_dataset = TensorDataset(torch.tensor(train_data, dtype=torch.float32), torch.tensor(train_labels, dtype=torch.float32))
    val_dataset = TensorDataset(torch.tensor(val_data, dtype=torch.float32), torch.tensor(val_labels, dtype=torch.float32))
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
    return train_loader, val_loader

# Model training
def train_model(train_data, train_labels, val_data, val_labels, input_size, hidden_size, output_size):
    train_loader, val_loader = create_dataloaders(train_data, train_labels, val_data, val_labels)

    model = BidirectionalRNN(input_size=input_size, hidden_size=hidden_size, output_size=output_size)
    
    checkpoint_callback = ModelCheckpoint(
        monitor='combined_metric',
        mode='max',
        save_top_k=1,
        filename='best-checkpoint',
        verbose=True
    )

    trainer = pl.Trainer(max_epochs=15, callbacks=[checkpoint_callback], log_every_n_steps=10)
    trainer.fit(model, train_loader, val_loader)

    best_model_path = checkpoint_callback.best_model_path
    best_model = BidirectionalRNN.load_from_checkpoint(best_model_path)

    return best_model

full_train_data, full_train_labels = load_dataset(train_scaled)

train_data, val_data, train_labels, val_labels = train_test_split(
    full_train_data, full_train_labels, test_size=0.2, random_state=42, stratify=full_train_labels
)

# Model parameters
input_size = train_data.shape[2]  # Based on your dataset
hidden_size = 32
output_size = 1  # Binary classification

model = train_model(train_data, train_labels, val_data, val_labels, input_size, hidden_size, output_size)

In [54]:
import pytorch_lightning as pl
import torch
import torch.nn as nn
import torch.optim as optim
from sklearn.metrics import precision_recall_curve, auc, roc_auc_score
from torch.utils.data import DataLoader, TensorDataset
from pytorch_lightning.callbacks import ModelCheckpoint

class MultitaskAutoencoderRNN(pl.LightningModule):
    def __init__(self, input_size, hidden_size, output_size, learning_rate=1e-3, alpha=0.5):
        super(MultitaskAutoencoderRNN, self).__init__()
        self.save_hyperparameters()
        self.hidden_size = hidden_size
        self.rnn = nn.RNN(input_size, hidden_size, num_layers = 2, batch_first=True, bidirectional=True)
        
        # Classification head
        self.fc_class = nn.Linear(hidden_size * 2, output_size)
        
        # Decoder for the autoencoding task
        self.fc_decoder = nn.Linear(hidden_size * 2, input_size)
        self.learning_rate = learning_rate
        
        # Loss functions
        self.criterion_class = nn.BCEWithLogitsLoss()
        self.criterion_reconstruction = nn.MSELoss()
        
        # Weights for multitask loss
        self.alpha = alpha
        self.val_outputs = []

    def forward(self, x):
        # x: (batch_size, seq_len, input_size)
        rnn_out, _ = self.rnn(x)  # rnn_out: (batch_size, seq_len, hidden_size * 2)
        last_hidden = rnn_out[:, -1, :]  # (batch_size, hidden_size * 2)
        
        # Classification output
        classification_output = self.fc_class(last_hidden)  # (batch_size, output_size)
        
        # Reconstruction output (using all hidden states, not just the last)
        reconstruction_output = self.fc_decoder(rnn_out)  # (batch_size, seq_len, input_size)
        
        return classification_output, reconstruction_output

    def training_step(self, batch, batch_idx):
        x, y = batch
        
        # Forward pass
        classification_logits, reconstruction_output = self(x)
        classification_logits = classification_logits.squeeze(1)
        
        # Compute the classification loss
        loss_class = self.criterion_class(classification_logits, y)
        
        # Compute the reconstruction loss
        loss_reconstruction = self.criterion_reconstruction(reconstruction_output, x)
        
        # Combine the losses
        loss = self.alpha * loss_class + (1 - self.alpha) * loss_reconstruction
        
        # Log losses
        self.log('train_loss', loss)
        self.log('train_class_loss', loss_class)
        self.log('train_reconstruction_loss', loss_reconstruction)
        
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        
        # Forward pass
        classification_logits, reconstruction_output = self(x)
        classification_logits = classification_logits.squeeze(1)
        
        # Compute the classification loss
        loss_class = self.criterion_class(classification_logits, y)
        
        # Compute the reconstruction loss
        loss_reconstruction = self.criterion_reconstruction(reconstruction_output, x)
        
        # Combine the losses
        loss = self.alpha * loss_class + (1 - self.alpha) * loss_reconstruction
        
        # Calculate metrics
        preds = torch.sigmoid(classification_logits)
        self.log('val_loss', loss, prog_bar=True)
        
        # Save outputs for PR AUC and ROC AUC calculation
        self.val_outputs.append({'preds': preds, 'targets': y})

    def on_validation_epoch_end(self):
        # Gather predictions and targets across all validation batches
        all_preds = torch.cat([x['preds'] for x in self.val_outputs], dim=0).cpu().numpy()
        all_targets = torch.cat([x['targets'] for x in self.val_outputs], dim=0).cpu().numpy()
        
        # Calculate Precision-Recall AUC
        precision, recall, _ = precision_recall_curve(all_targets, all_preds)
        pr_auc = auc(recall, precision)

        # Calculate ROC AUC
        roc_auc = roc_auc_score(all_targets, all_preds)
        
        # Combine PR AUC and ROC AUC
        combined_metric = (pr_auc + roc_auc) / 2
        
        # Log the combined metric
        self.log('combined_metric', combined_metric, prog_bar=True)

        # Log PR AUC and ROC AUC separately
        self.log('val_pr_auc', pr_auc, prog_bar=True)
        self.log('val_roc_auc', roc_auc, prog_bar=True)

        # Clear the outputs for the next epoch
        self.val_outputs.clear()

    def configure_optimizers(self):
        optimizer = optim.Adam(self.parameters(), lr=self.learning_rate)
        return optimizer

def load_dataset(df):
    n_steps = 3

    sequences = []
    for i in range(1, n_steps + 1):
        step_cols = [col for col in df.columns if col.endswith(f'_{i}')]
        step_data = df[step_cols].values
        sequences.append(step_data)
    
    sequences = np.stack(sequences, axis=1)  # (num_samples, n_steps, 3)
    
    labels = df['label'].values
    return sequences, labels

# Create datasets
def create_dataloaders(train_data, train_labels, val_data, val_labels, batch_size=32):
    train_dataset = TensorDataset(torch.tensor(train_data, dtype=torch.float32), torch.tensor(train_labels, dtype=torch.float32))
    val_dataset = TensorDataset(torch.tensor(val_data, dtype=torch.float32), torch.tensor(val_labels, dtype=torch.float32))
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
    return train_loader, val_loader

# Model training
def train_model(train_data, train_labels, val_data, val_labels, input_size, hidden_size, output_size):
    train_loader, val_loader = create_dataloaders(train_data, train_labels, val_data, val_labels)

    model = MultitaskAutoencoderRNN(input_size=input_size, hidden_size=hidden_size, output_size=output_size)
    
    checkpoint_callback = ModelCheckpoint(
        monitor='combined_metric',
        mode='max',
        save_top_k=1,
        filename='best-checkpoint',
        verbose=True
    )

    trainer = pl.Trainer(max_epochs=30, callbacks=[checkpoint_callback], log_every_n_steps=10)
    trainer.fit(model, train_loader, val_loader)

    best_model_path = checkpoint_callback.best_model_path
    best_model = MultitaskAutoencoderRNN.load_from_checkpoint(best_model_path)

    return best_model

full_train_data, full_train_labels = load_dataset(train_scaled)

train_data, val_data, train_labels, val_labels = train_test_split(
    full_train_data, full_train_labels, test_size=0.10, random_state=42, stratify=full_train_labels
)

# Model parameters
input_size = train_data.shape[2]  # Based on your dataset
hidden_size = 128
output_size = 1  # Binary classification

model = train_model(train_data, train_labels, val_data, val_labels, input_size, hidden_size, output_size)

Trainer will use only 1 of 10 GPUs because it is running inside an interactive / notebook environment. You may try to set `Trainer(devices=10)` but please note that multi-GPU inside interactive / notebook environments is considered experimental and unstable. Your mileage may vary.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3,4,5,6,7,8,9]

  | Name                     | Type              | Params | Mode 
-----------------------------------------------------------------------
0 | rnn                      | RNN               | 143 K  | train
1 | fc_class                 | Linear            | 257    | train
2 | fc_decoder               | Linear            | 11.6 K | train
3 | criterion_class          | BCEWithLogitsLoss | 0      | train
4 | criterion_reconstruction | MSELoss           | 0      | train
-----------------------------------------------------------------------
1

Sanity Checking: |                                        | 0/? [00:00<?, ?it/s]

/home/fangyu.hoo/miniconda3/envs/search3/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:424: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=63` in the `DataLoader` to improve performance.
/home/fangyu.hoo/miniconda3/envs/search3/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:424: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=63` in the `DataLoader` to improve performance.


Training: |                                               | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Epoch 0, global step 2724: 'combined_metric' reached 0.64692 (best 0.64692), saving model to '/mnt/ssfs/usr/fangyu.hoo/dsa4262/lightning_logs/version_29/checkpoints/best-checkpoint.ckpt' as top 1


Validation: |                                             | 0/? [00:00<?, ?it/s]

Epoch 1, global step 5448: 'combined_metric' reached 0.66571 (best 0.66571), saving model to '/mnt/ssfs/usr/fangyu.hoo/dsa4262/lightning_logs/version_29/checkpoints/best-checkpoint.ckpt' as top 1


Validation: |                                             | 0/? [00:00<?, ?it/s]

Epoch 2, global step 8172: 'combined_metric' reached 0.67123 (best 0.67123), saving model to '/mnt/ssfs/usr/fangyu.hoo/dsa4262/lightning_logs/version_29/checkpoints/best-checkpoint.ckpt' as top 1


Validation: |                                             | 0/? [00:00<?, ?it/s]

Epoch 3, global step 10896: 'combined_metric' reached 0.67889 (best 0.67889), saving model to '/mnt/ssfs/usr/fangyu.hoo/dsa4262/lightning_logs/version_29/checkpoints/best-checkpoint.ckpt' as top 1


Validation: |                                             | 0/? [00:00<?, ?it/s]

Epoch 4, global step 13620: 'combined_metric' reached 0.68197 (best 0.68197), saving model to '/mnt/ssfs/usr/fangyu.hoo/dsa4262/lightning_logs/version_29/checkpoints/best-checkpoint.ckpt' as top 1


Validation: |                                             | 0/? [00:00<?, ?it/s]

Epoch 5, global step 16344: 'combined_metric' was not in top 1


Validation: |                                             | 0/? [00:00<?, ?it/s]

Epoch 6, global step 19068: 'combined_metric' was not in top 1


Validation: |                                             | 0/? [00:00<?, ?it/s]

Epoch 7, global step 21792: 'combined_metric' was not in top 1


Validation: |                                             | 0/? [00:00<?, ?it/s]

Epoch 8, global step 24516: 'combined_metric' reached 0.68904 (best 0.68904), saving model to '/mnt/ssfs/usr/fangyu.hoo/dsa4262/lightning_logs/version_29/checkpoints/best-checkpoint.ckpt' as top 1


Validation: |                                             | 0/? [00:00<?, ?it/s]

Epoch 9, global step 27240: 'combined_metric' reached 0.68975 (best 0.68975), saving model to '/mnt/ssfs/usr/fangyu.hoo/dsa4262/lightning_logs/version_29/checkpoints/best-checkpoint.ckpt' as top 1


Validation: |                                             | 0/? [00:00<?, ?it/s]

Epoch 10, global step 29964: 'combined_metric' reached 0.69666 (best 0.69666), saving model to '/mnt/ssfs/usr/fangyu.hoo/dsa4262/lightning_logs/version_29/checkpoints/best-checkpoint.ckpt' as top 1


Validation: |                                             | 0/? [00:00<?, ?it/s]

Epoch 11, global step 32688: 'combined_metric' was not in top 1


Validation: |                                             | 0/? [00:00<?, ?it/s]

Epoch 12, global step 35412: 'combined_metric' was not in top 1


Validation: |                                             | 0/? [00:00<?, ?it/s]

Epoch 13, global step 38136: 'combined_metric' was not in top 1


Validation: |                                             | 0/? [00:00<?, ?it/s]

Epoch 14, global step 40860: 'combined_metric' was not in top 1


Validation: |                                             | 0/? [00:00<?, ?it/s]

Epoch 15, global step 43584: 'combined_metric' was not in top 1


Validation: |                                             | 0/? [00:00<?, ?it/s]

Epoch 16, global step 46308: 'combined_metric' was not in top 1


Validation: |                                             | 0/? [00:00<?, ?it/s]

Epoch 17, global step 49032: 'combined_metric' reached 0.70300 (best 0.70300), saving model to '/mnt/ssfs/usr/fangyu.hoo/dsa4262/lightning_logs/version_29/checkpoints/best-checkpoint.ckpt' as top 1


Validation: |                                             | 0/? [00:00<?, ?it/s]

Epoch 18, global step 51756: 'combined_metric' reached 0.70487 (best 0.70487), saving model to '/mnt/ssfs/usr/fangyu.hoo/dsa4262/lightning_logs/version_29/checkpoints/best-checkpoint.ckpt' as top 1


Validation: |                                             | 0/? [00:00<?, ?it/s]

Epoch 19, global step 54480: 'combined_metric' was not in top 1


Validation: |                                             | 0/? [00:00<?, ?it/s]

Epoch 20, global step 57204: 'combined_metric' was not in top 1


Validation: |                                             | 0/? [00:00<?, ?it/s]

Epoch 21, global step 59928: 'combined_metric' was not in top 1


Validation: |                                             | 0/? [00:00<?, ?it/s]

Epoch 22, global step 62652: 'combined_metric' was not in top 1


Validation: |                                             | 0/? [00:00<?, ?it/s]

Epoch 23, global step 65376: 'combined_metric' was not in top 1


Validation: |                                             | 0/? [00:00<?, ?it/s]

Epoch 24, global step 68100: 'combined_metric' was not in top 1


Validation: |                                             | 0/? [00:00<?, ?it/s]

Epoch 25, global step 70824: 'combined_metric' was not in top 1


Validation: |                                             | 0/? [00:00<?, ?it/s]

Epoch 26, global step 73548: 'combined_metric' was not in top 1


Validation: |                                             | 0/? [00:00<?, ?it/s]

Epoch 27, global step 76272: 'combined_metric' was not in top 1


Validation: |                                             | 0/? [00:00<?, ?it/s]

Epoch 28, global step 78996: 'combined_metric' was not in top 1


Validation: |                                             | 0/? [00:00<?, ?it/s]

Epoch 29, global step 81720: 'combined_metric' was not in top 1
`Trainer.fit` stopped: `max_epochs=30` reached.


In [55]:
def create_test_dataloader(test_data, batch_size=32):
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    test_dataset = TensorDataset(torch.tensor(test_data, dtype=torch.float32).to(device))
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
    return test_loader

test_data, test_labels = load_dataset(test_scaled)
test_loader = create_test_dataloader(test_data)

In [56]:
def predict_test_set(model, test_loader):
    model.eval()  # Set the model to evaluation mode
    all_logits = []
    with torch.no_grad():  # Disable gradient calculation for inference
        for batch in test_loader:
            x_test = batch[0]  # Extract the test features
            logits, _ = model(x_test)  # Get the predictions/logits
            proba = torch.sigmoid(logits)
            all_logits.append(proba)
    
    # Concatenate all logits into a single tensor
    all_logits = torch.cat(all_logits, dim=0).cpu().numpy()
    return all_logits

test_logits = predict_test_set(model, test_loader)

In [57]:
print(f'roc auc: {round(roc_auc_score(test_labels, test_logits),4)}')
precision, recall, thresholds = precision_recall_curve(test_labels, test_logits)
print(f'pr auc: {round(auc(recall, precision),4)}')

roc auc: 0.9153
pr auc: 0.4647
