In [1]:
import os
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
import torch.nn.functional as F
from glob import glob

import pandas as pd
import numpy as np
import pytorch_lightning as pl

In [2]:
# Pick the compute device once: the GPU when CUDA is usable, otherwise the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
print('Using {} device'.format(device))

Using cuda device


In [3]:
class SegmentDataset(Dataset):
    """Summary-statistic features for interictal / preictal segment files.

    For every patient the constructor globs the segment parquet files under
    ``data/segments/{scaled|raw}/{patient}/{interictal|preictal}/`` and stores
    the file paths together with a label (0.0 = interictal, 1.0 = preictal).
    Files are only read lazily in ``__getitem__``.

    Args:
        patients: Iterable of patient identifiers used to build the glob paths.
        sampling: ``None`` keeps every file. ``"undersampling"`` randomly drops
            files (without replacement) from the larger class of each patient
            until both classes have the size of the smaller one.
            ``"oversampling"`` randomly draws (with replacement) from the
            smaller class of each patient until it has the size of the larger
            one; a class with no files at all is left empty.
        scaled: Selects the ``scaled`` file tree when true, ``raw`` otherwise.
        seed: Optional integer seed. When given, resampling uses a private
            ``np.random.RandomState(seed)`` and is reproducible. When ``None``
            the global ``np.random`` state is used, as before.

    Raises:
        ValueError: If ``sampling`` is not one of the three accepted values.
    """

    def __init__(self, patients, sampling=None, scaled=True, seed=None):
        if sampling not in [None, "undersampling", "oversampling"]:
            raise ValueError("Sampling must be one of None, undersampling or oversampling")
        # Both the numpy module and RandomState expose the same ``choice`` API.
        rng = np.random if seed is None else np.random.RandomState(seed)
        variant = "scaled" if scaled else "raw"
        infix = "_scaled" if scaled else ""

        self.segment_files = []
        self.labels = []
        for patient in patients:
            interictal_files = self._find_segments(patient, "interictal", variant, infix)
            preictal_files = self._find_segments(patient, "preictal", variant, infix)
            interictal_files, preictal_files = self._balance(
                interictal_files, preictal_files, sampling, rng
            )
            self.segment_files.extend(interictal_files + preictal_files)
            self.labels.extend([0.0] * len(interictal_files) + [1.0] * len(preictal_files))

    @staticmethod
    def _find_segments(patient, state, variant, infix):
        """Return the sorted segment files of one patient and one state."""
        pattern = f"data/segments/{variant}/{patient}/{state}/{patient}_{state}{infix}_segment_*.parquet"
        # glob order depends on the filesystem; sort so a seed gives the same subset.
        return sorted(glob(pattern))

    @staticmethod
    def _balance(interictal_files, preictal_files, sampling, rng):
        """Resample the two file lists to equal length according to ``sampling``."""
        if sampling is None or len(interictal_files) == len(preictal_files):
            return interictal_files, preictal_files
        if sampling == "undersampling":
            target = min(len(interictal_files), len(preictal_files))
            with_replacement = False
        else:
            target = max(len(interictal_files), len(preictal_files))
            with_replacement = True

        def resample(files):
            # Leave a class alone when it already has the target size, or when
            # it is empty (there is nothing to draw from).
            if len(files) == target or not files:
                return files
            return rng.choice(files, size=target, replace=with_replacement).tolist()

        return resample(interictal_files), resample(preictal_files)

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        """Return ``(features, label)`` for the segment at ``idx``.

        ``features`` is a float32 tensor holding the per-column means followed
        by the per-column standard deviations of the segment, so its length is
        twice the number of columns in the parquet file. ``label`` is the
        stored Python float (0.0 or 1.0).
        """
        segment_df = pd.read_parquet(self.segment_files[idx]).fillna(0)
        stats = np.concatenate([segment_df.mean(), segment_df.std()]).astype(np.float32)
        # std() of a single-row segment (and mean() of an empty one) is NaN;
        # zero it so one degenerate segment cannot turn the loss into NaN.
        stats[np.isnan(stats)] = 0.0
        return torch.from_numpy(stats), self.labels[idx]
        

In [4]:
class LogRegClassifier(pl.LightningModule):
    """Binary classifier producing one logit per sample.

    Despite the name, the network is Linear -> ReLU -> Linear (one hidden
    layer), not a plain logistic regression. It is trained with
    ``binary_cross_entropy_with_logits``, so ``forward`` returns raw logits
    and no sigmoid is applied.

    Args:
        in_features: Size of each input feature vector (default 12).
        hidden_features: Width of the hidden layer (default 10).
    """

    def __init__(self, in_features=12, hidden_features=10):
        super().__init__()

        self.linear_1 = nn.Linear(in_features, hidden_features)
        self.linear_2 = nn.Linear(hidden_features, 1)
        self.relu = nn.ReLU()

    def forward(self, x):
        """Map ``x`` of shape ``(..., in_features)`` to logits of shape ``(...)``."""
        hidden = self.relu(self.linear_1(x))
        # Drop only the trailing size-1 output dimension. A bare squeeze()
        # would also collapse a batch of one sample to a 0-d tensor, which
        # no longer matches the (1,) target and makes the loss raise.
        return self.linear_2(hidden).squeeze(-1)

    def _shared_step(self, batch, log_name):
        """Compute the BCE-with-logits loss of one batch and log it as ``log_name``."""
        x, y = batch
        logits = self(x)
        # Python-float labels are collated to float64; match the logits' dtype.
        loss = F.binary_cross_entropy_with_logits(logits, y.to(logits.dtype))
        self.log(log_name, loss)
        return loss

    def training_step(self, batch, batch_idx):
        return self._shared_step(batch, "Training Loss")

    def validation_step(self, batch, batch_idx):
        return self._shared_step(batch, "Validation Loss")

    def configure_optimizers(self):
        """Use Adam with its default hyperparameters and no LR scheduler."""
        return torch.optim.Adam(self.parameters())


In [5]:
# Seed Python, NumPy and PyTorch before anything stochastic happens, so the
# random undersampling below and the model initialisation are reproducible
# under Restart & Run All.
pl.seed_everything(42)

# Patient-level split: the held-out patient never contributes training data.
train_patients = ["MSEL_00172", "MSEL_00501", "MSEL_01097", "MSEL_01575", "MSEL_01808", "MSEL_01838"]
test_patients = ["MSEL_01842"]

train_data = SegmentDataset(train_patients, sampling="undersampling")
# NOTE(review): the held-out set is class-balanced by undersampling too, so
# metrics on it do not reflect the natural class ratio; confirm this is intended.
test_data = SegmentDataset(test_patients, sampling="undersampling")

train_dataloader = DataLoader(train_data, batch_size=128, shuffle=True)
# Evaluation data is not shuffled: order does not matter for the loss and
# Lightning warns about shuffled validation loaders.
test_dataloader = DataLoader(test_data, batch_size=128, shuffle=False)
print(f"Training samples:{len(train_data)}")
print(f"Testing samples:{len(test_data)}")


Training samples:4320
Testinging samples:480


In [6]:
model = LogRegClassifier()
# Request a GPU only when one was detected above; a hard-coded gpus=1 makes
# the Trainer fail on CPU-only machines.
trainer = pl.Trainer(gpus=1 if device == "cuda" else 0, max_epochs=50, log_every_n_steps=1)
# The held-out patient's loader serves as the validation set during fitting.
trainer.fit(model, train_dataloader, test_dataloader)

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name     | Type   | Params
------------------------------------
0 | linear_1 | Linear | 130   
1 | linear_2 | Linear | 11    
2 | relu     | ReLU   | 0     
------------------------------------
141       Trainable params
0         Non-trainable params
141       Total params
0.001     Total estimated model params size (MB)


Validation sanity check:   0%|          | 0/2 [00:00<?, ?it/s]

  rank_zero_warn(
  rank_zero_warn(


Epoch 0:   0%|          | 0/38 [00:00<?, ?it/s] 

  rank_zero_warn(


Epoch 49: 100%|██████████| 38/38 [00:15<00:00,  2.46it/s, loss=0.657, v_num=36]
