In [4]:
%reload_ext autoreload
%autoreload 2
import os
import sys
from pathlib import Path

from sklearn.model_selection import train_test_split

# The project modules live one directory above this notebook; extend the
# search path before importing them.
sys.path.append("..")
from dataset import ManualFeatureDataset, ManualFeatureDataModule

## Data Paths

In [5]:
# Data locations. The defaults are the original machine-specific paths; set
# ICARE_FEATURES_DIR / ICARE_LABELS_CSV in the environment to run the notebook
# elsewhere without editing this cell.
root_dir = Path(os.environ.get("ICARE_FEATURES_DIR", "/media/nvme1/icare-data/6h-features"))
labels_csv = Path(os.environ.get("ICARE_LABELS_CSV", "/home/bc299/icare/patient_data.csv"))

## Train/Val/Test Split and Dataset Setup

In [6]:
# Path.iterdir() yields entries in arbitrary (filesystem-dependent) order, so
# sort the names first; otherwise the fixed random_state below does not
# guarantee the same split from one run or machine to the next.
all_patient_ids = sorted(dir_.name for dir_ in root_dir.iterdir())
# 70% train; the remaining 30% is divided 1:2 into validation (~10% of all
# patients) and test (~20%).
train_ids, temp_ids = train_test_split(all_patient_ids, test_size=0.3, random_state=42)
val_ids, test_ids = train_test_split(temp_ids, test_size=2/3, random_state=42)

In [7]:
# Dataset built from the feature root, the labels CSV and the training IDs.
# NOTE(review): presumably the third argument restricts the dataset to those
# patients — confirm in dataset.py.
dataset = ManualFeatureDataset(
    root_dir,
    labels_csv,
    train_ids,
)

In [12]:
# NOTE(review): the data module is not handed train_ids / val_ids / test_ids,
# so it presumably performs its own split internally — verify in dataset.py.
data_module = ManualFeatureDataModule(
    root_dir=root_dir, labels_csv=labels_csv, batch_size=32
)
data_module.setup()

In [15]:
# Sanity check: print the shape of every feature batch the test loader yields.
test_loader = data_module.test_dataloader()
for batch_idx, (x, y) in enumerate(test_loader):
    print(x.shape)

torch.Size([32, 11, 8, 144])
torch.Size([32, 11, 8, 144])
torch.Size([32, 11, 8, 144])
torch.Size([6, 11, 8, 144])


## Train Model

In [20]:
import torch
import lightning.pytorch as pl
from torch.optim import Adam
from torch.nn.functional import nll_loss
from model import BiLSTMClassifier

In [21]:
class BiLSTMClassifierModule(pl.LightningModule):
    """Lightning wrapper that trains and evaluates a ``BiLSTMClassifier``.

    Args:
        input_size: Forwarded to ``BiLSTMClassifier`` (1st positional argument).
        hidden_size: Forwarded to ``BiLSTMClassifier`` (2nd positional argument).
        num_layers: Forwarded to ``BiLSTMClassifier`` (3rd positional argument).
        dropout: Forwarded to ``BiLSTMClassifier`` (4th positional argument).
        learning_rate: Learning rate passed to the Adam optimizer.
    """

    def __init__(self, input_size, hidden_size, num_layers, dropout, learning_rate):
        super().__init__()
        # Record the constructor arguments in ``self.hparams`` so they are kept
        # with checkpoints and logged alongside the metrics.
        self.save_hyperparameters()

        self.model = BiLSTMClassifier(input_size, hidden_size, num_layers, dropout)
        self.learning_rate = learning_rate

    def forward(self, x):
        """Return the wrapped model's output for input batch ``x``."""
        return self.model(x)

    def _shared_step(self, batch):
        """Run one ``(x, y)`` batch through the model and return ``(loss, accuracy)``."""
        x, y = batch
        outputs = self(x)
        # NOTE(review): nll_loss expects log-probabilities, so this assumes
        # BiLSTMClassifier ends in a log_softmax — confirm in model.py; if it
        # returns raw logits, cross_entropy is the matching loss.
        loss = nll_loss(outputs, y)
        preds = torch.argmax(outputs, dim=1)
        acc = (preds == y).float().mean()
        return loss, acc

    def training_step(self, batch, batch_idx):
        """Compute and log the training loss/accuracy; return the loss for backprop."""
        loss, acc = self._shared_step(batch)
        self.log("train_loss", loss)
        self.log("train_acc", acc)
        return loss

    def validation_step(self, batch, batch_idx):
        """Compute the validation loss/accuracy and log both to the progress bar."""
        loss, acc = self._shared_step(batch)
        self.log("val_loss", loss, prog_bar=True)
        self.log("val_acc", acc, prog_bar=True)

    def test_step(self, batch, batch_idx):
        """Compute and log the test loss/accuracy so ``Trainer.test`` can be used."""
        loss, acc = self._shared_step(batch)
        self.log("test_loss", loss)
        self.log("test_acc", acc)

    def configure_optimizers(self):
        """Return an Adam optimizer over all parameters using ``self.learning_rate``."""
        return torch.optim.Adam(self.parameters(), lr=self.learning_rate)

In [None]:
# NOTE(review): this rebuilds the data module with the same arguments as the
# earlier cell, and setup() is not called on the new instance here.
data_module = ManualFeatureDataModule(root_dir, labels_csv, batch_size=32)
# NOTE(review): BiLSTMClassifierModule.__init__ constructs BiLSTMClassifier with
# four positional arguments (input_size, hidden_size, num_layers, dropout), so
# this zero-argument call will fail unless model.py supplies defaults — confirm.
# TODO: this probably should instantiate BiLSTMClassifierModule (with explicit
# hyperparameters) so the model can be handed to a Lightning Trainer.
model = BiLSTMClassifier()