In [1]:
%reload_ext autoreload
%autoreload 2
import os
import sys
from pathlib import Path

from sklearn.model_selection import train_test_split

# Make the project modules in the parent directory importable.
sys.path.append("..")
from dataset import ManualFeatureDataset, ManualFeatureDataModule

## Data Paths

In [2]:
# Input locations. The defaults are the original machine-specific paths; set
# ICARE_FEATURES_DIR / ICARE_LABELS_CSV in the environment to run the notebook
# on another machine without editing this cell.
root_dir = Path(os.environ.get("ICARE_FEATURES_DIR", "/media/nvme1/icare-data/6h-features"))
labels_csv = Path(os.environ.get("ICARE_LABELS_CSV", "/home/bc299/icare/patient_data.csv"))

## Train/Test Split and Dataset Setup

In [3]:
# List the entries under root_dir (presumably one directory per patient).
# Path.iterdir() yields entries in arbitrary, filesystem-dependent order, and
# random_state only makes the split repeatable when the input order is stable,
# so the names are sorted first. Plain files are skipped so that only
# directories are treated as patient IDs.
all_patient_ids = sorted(dir_.name for dir_ in root_dir.iterdir() if dir_.is_dir())
# 70% train; the held-out 30% is divided 1/3 validation and 2/3 test,
# i.e. roughly 10% / 20% of all IDs.
train_ids, temp_ids = train_test_split(all_patient_ids, test_size=0.3, random_state=42)
val_ids, test_ids = train_test_split(temp_ids, test_size=2/3, random_state=42)

In [4]:
# Build a ManualFeatureDataset from the feature root, the labels CSV and the
# training-split IDs.
# NOTE(review): `dataset` is not referenced by any later cell visible here;
# the data module is what gets handed to the trainer.
dataset = ManualFeatureDataset(root_dir, labels_csv, train_ids)

In [5]:
# Create the data module and run its setup hook once so it can be inspected
# interactively before training.
# NOTE(review): the train/val/test ID lists computed above are not passed in,
# so the module presumably performs its own split — confirm in dataset.py.
data_module = ManualFeatureDataModule(
    root_dir=root_dir,
    labels_csv=labels_csv,
    batch_size=32,
)
data_module.setup()

## Train Model

In [6]:
import torch
import lightning.pytorch as pl
from torch.nn.functional import nll_loss
from model import BiLSTMClassifier
from sklearn.metrics import roc_auc_score, accuracy_score

In [7]:
class BiLSTMClassifierModule(pl.LightningModule):
    """Lightning wrapper that trains and evaluates a ``BiLSTMClassifier``.

    Training and validation log the batch loss and accuracy. Testing stores
    the predictions of every batch and, at the end of the test epoch, logs one
    accuracy value per 6h epoch as ``test_acc_<name>``; the batch index selects
    the epoch name (see ``TEST_EPOCH_NAMES``).

    NOTE(review): ``nll_loss`` expects log-probabilities, so this assumes the
    wrapped ``BiLSTMClassifier`` ends in a log-softmax — confirm in model.py.
    If it returns raw logits, ``cross_entropy`` is the matching loss.

    Args:
        input_size: Forwarded to ``BiLSTMClassifier``.
        hidden_size: Forwarded to ``BiLSTMClassifier``.
        num_layers: Forwarded to ``BiLSTMClassifier``.
        dropout: Forwarded to ``BiLSTMClassifier``.
        learning_rate: Learning rate of the Adam optimizer.
    """

    # Names of the 6h epochs ("12", "18", ..., "72"). Test batch ``i`` is
    # reported under ``TEST_EPOCH_NAMES[i]``.
    TEST_EPOCH_NAMES = tuple(str(name) for name in range(12, 72 + 1, 6))

    def __init__(self, input_size, hidden_size, num_layers, dropout, learning_rate):
        super().__init__()

        self.model = BiLSTMClassifier(input_size, hidden_size, num_layers, dropout)
        self.learning_rate = learning_rate
        self.save_hyperparameters()

        # Per-batch test results, consumed and cleared in on_test_epoch_end.
        self.test_step_outputs = []

    def forward(self, x):
        """Return the output of the wrapped model for ``x``."""
        return self.model(x)

    def _shared_step(self, batch):
        """Run the model on one ``(x, y)`` batch.

        Returns:
            Tuple ``(loss, preds, y)``: the NLL loss, the argmax class
            predictions along dim 1, and the target labels.
        """
        x, y = batch
        logits = self(x)
        loss = nll_loss(logits, y)
        preds = torch.argmax(logits, dim=1)
        return loss, preds, y

    def training_step(self, batch, batch_idx):
        """Log ``train_loss`` / ``train_acc`` and return the loss."""
        loss, preds, y = self._shared_step(batch)
        acc = (preds == y).float().mean()
        self.log("train_loss", loss)
        self.log("train_acc", acc)
        return loss

    def validation_step(self, batch, batch_idx):
        """Log ``val_loss`` / ``val_acc`` (both shown in the progress bar)."""
        loss, preds, y = self._shared_step(batch)
        acc = (preds == y).float().mean()
        self.log("val_loss", loss, prog_bar=True)
        self.log("val_acc", acc, prog_bar=True)

    # The test step will evaluate the performance for each 6h epoch,
    # assuming that the test dataloader yields one batch per 6h epoch and
    # that the batches arrive in epoch order.
    def test_step(self, batch, batch_idx):
        """Store predictions and labels of one test batch and return the loss.

        Raises:
            IndexError: If ``batch_idx`` has no entry in ``TEST_EPOCH_NAMES``,
                i.e. the test dataloader yields more batches than there are
                6h epochs.
        """
        if batch_idx >= len(self.TEST_EPOCH_NAMES):
            # Fail here with context instead of with a bare list-index error
            # at the end of the test epoch.
            raise IndexError(
                f"test batch index {batch_idx} has no matching 6h epoch; "
                f"expected at most {len(self.TEST_EPOCH_NAMES)} test batches"
            )
        loss, preds, y = self._shared_step(batch)
        # Keep only detached CPU copies so stored results do not stay on the
        # compute device until the epoch ends.
        self.test_step_outputs.append({
            "predictions": preds.detach().cpu(),
            "labels": y.detach().cpu(),
            "batch_idx": batch_idx
        })
        return loss

    def on_test_epoch_end(self):
        """Log ``test_acc_<name>`` for every 6h epoch that received samples."""
        epoch_names = self.TEST_EPOCH_NAMES
        # Store aggregated labels and predictions
        aggregated_labels = {name: [] for name in epoch_names}
        aggregated_preds = {name: [] for name in epoch_names}
        # Aggregate labels and predictions based on batch_idx (6h epochs)
        for output in self.test_step_outputs:
            epoch_name = epoch_names[output["batch_idx"]]
            aggregated_labels[epoch_name].extend(output["labels"].cpu().numpy())
            aggregated_preds[epoch_name].extend(output["predictions"].cpu().numpy())
        # Compute and log metrics for each 6h epoch
        for epoch_name in epoch_names:
            labels = aggregated_labels[epoch_name]
            if not labels:
                # No batch was seen for this epoch; accuracy is undefined, so
                # skip it rather than pass empty lists to accuracy_score.
                continue
            acc = accuracy_score(labels, aggregated_preds[epoch_name])
            self.log(f"test_acc_{epoch_name}", acc)
        # Free memory
        self.test_step_outputs.clear()

    def configure_optimizers(self):
        """Return an Adam optimizer over all parameters."""
        optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)
        return optimizer

In [8]:
from lightning.pytorch.loggers import WandbLogger
from lightning.pytorch import Trainer

# Seed the Python, NumPy and PyTorch random number generators so that
# stochastic parts of training (e.g. weight initialisation, dropout) are
# repeatable across runs.
pl.seed_everything(42, workers=True)

# Fresh data module for the training run; the trainer calls its setup hooks.
data_module = ManualFeatureDataModule(root_dir=root_dir, labels_csv=labels_csv, batch_size=32)
model = BiLSTMClassifierModule(
    input_size=8,
    hidden_size=128,
    num_layers=4,
    dropout=0.5,
    learning_rate=1e-5,
)
# Metrics are sent to Weights & Biases under this project / run name.
logger = WandbLogger(project="test-project", name="test-by-epoch")
trainer = Trainer(max_epochs=50, logger=logger)

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mbillchen0011[0m. Use [1m`wandb login --relogin`[0m to force relogin


GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


In [9]:
# Train for the configured number of epochs; the data module supplies the
# training and validation dataloaders.
trainer.fit(model, datamodule=data_module)

You are using a CUDA device ('NVIDIA GeForce RTX 4090') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name  | Type             | Params
-------------------------------------------
0 | model | BiLSTMClassifier | 1.3 M 
-------------------------------------------
1.3 M     Trainable params
0         Non-trainable params
1.3 M     Total params
5.310     Total estimated model params size (MB)


                                                                           

  rank_zero_warn(
  rank_zero_warn(


Epoch 49: 100%|██████████| 123/123 [00:00<00:00, 169.26it/s, v_num=ztvk, val_loss=0.629, val_acc=0.624]

`Trainer.fit` stopped: `max_epochs=50` reached.


Epoch 49: 100%|██████████| 123/123 [00:00<00:00, 165.23it/s, v_num=ztvk, val_loss=0.629, val_acc=0.624]


In [10]:
# Evaluate on the test dataloader; the module logs one accuracy value per
# 6h epoch (test_acc_12 ... test_acc_72).
trainer.test(model, datamodule=data_module)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
  rank_zero_warn(


Testing DataLoader 0: 100%|██████████| 11/11 [00:00<00:00, 320.05it/s]


[{'test_acc_12': 0.6764705882352942,
  'test_acc_18': 0.5980392156862745,
  'test_acc_24': 0.6372549019607843,
  'test_acc_30': 0.6764705882352942,
  'test_acc_36': 0.6078431372549019,
  'test_acc_42': 0.5686274509803921,
  'test_acc_48': 0.5490196078431373,
  'test_acc_54': 0.5392156862745098,
  'test_acc_60': 0.5490196078431373,
  'test_acc_66': 0.5784313725490197,
  'test_acc_72': 0.5980392156862745}]