In [1]:
import os
import torch
from torch import nn
from torch.utils.data import TensorDataset, DataLoader, random_split
import torch.nn.functional as F
from glob import glob

import pandas as pd
import numpy as np
import pytorch_lightning as pl
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from pytorch_lightning import loggers as pl_loggers

import torchmetrics

In [2]:
import warnings
warnings.filterwarnings('ignore')

In [28]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print('Using {} device'.format(device))

Using cuda device


In [4]:
interictal_data = torch.load("data/interictal.pt")
preictal_data = torch.load("data/preictal.pt")

In [5]:
test_patient = 'MSEL_00172'
val_patient = 'MSEL_00095'

In [6]:
interictal_train_X = torch.cat([patient_data for patient_id, patient_data in interictal_data.items() if patient_id not in [val_patient, test_patient]])
preictal_train_X = torch.cat([patient_data for patient_id, patient_data in preictal_data.items() if patient_id not in [val_patient, test_patient]])
# Undersample interictal data
interictal_train_X = interictal_train_X[::int(len(interictal_train_X)/len(preictal_train_X))]
train_X = torch.cat([interictal_train_X, preictal_train_X])

interictal_train_y = torch.zeros(interictal_train_X.shape[0])
preictal_train_y = torch.ones(preictal_train_X.shape[0])
train_y = torch.cat([interictal_train_y, preictal_train_y])

In [7]:
interictal_val_X = interictal_data[val_patient]
preictal_val_X = preictal_data[val_patient]
# Undersample interictal data
interictal_val_X = interictal_val_X[::int(len(interictal_val_X)/len(preictal_val_X))]
val_X = torch.cat([interictal_val_X, preictal_val_X])

interictal_val_y = torch.zeros(interictal_val_X.shape[0])
preictal_val_y = torch.ones(preictal_val_X.shape[0])
val_y = torch.cat([interictal_val_y, preictal_val_y])

In [8]:
interictal_test_X = interictal_data[test_patient]
preictal_test_X = preictal_data[test_patient]
# Undersample interictal data
# interictal_test_X = interictal_test_X[::int(len(interictal_test_X)/len(preictal_test_X))]
test_X = torch.cat([interictal_test_X, preictal_test_X])

interictal_test_y = torch.zeros(interictal_test_X.shape[0])
preictal_test_y = torch.ones(preictal_test_X.shape[0])
test_y = torch.cat([interictal_test_y, preictal_test_y])

In [9]:
train_dataset = TensorDataset(train_X, train_y)
val_dataset = TensorDataset(val_X, val_y)
test_dataset = TensorDataset(test_X, test_y)

train_dataloader = DataLoader(train_dataset, batch_size=256, shuffle=True)
val_dataloader = DataLoader(val_dataset, batch_size=256, shuffle=False,)
test_dataloader = DataLoader(test_dataset, batch_size=256, shuffle=True)

print(f"Training samples: {len(train_dataset)}")
print(f"Validation samples: {len(val_dataset)}")
print(f"Testing samples: {len(test_dataset)}")

Training samples: 12648
Validation samples: 479
Testing samples: 5834


In [3]:
class FocalLoss(nn.Module):
    def __init__(self, gamma=5):
        super().__init__()
        self.gamma = gamma
    
    def forward(self, p, y):
        p_t = y*p + (1-y)*(1-p)
        scale_factor = -(1-p_t)**self.gamma
        losses = scale_factor * torch.log(p_t)
        return losses.mean()

In [4]:
class LitSegmentClassifier(pl.LightningModule):
    def __init__(self):
        super().__init__()
    
        self.forward_lstm = nn.LSTM(input_size=6, hidden_size=500, batch_first=True)
        self.backward_lstm = nn.LSTM(input_size=6, hidden_size=500, batch_first=True)
        self.linear_stack = nn.Sequential(
            nn.Linear(1000, 500), 
            nn.ReLU(), 
            # nn.Dropout(0.7),
            nn.Linear(500, 200),
            nn.ReLU(), 
            # nn.Dropout(0.7),
            nn.Linear(200, 100),
            nn.ReLU(), 
            # nn.Dropout(0.7),
            nn.Linear(100, 1),
            nn.Sigmoid()
        )

        self.loss = FocalLoss()

        self.accuracy = torchmetrics.Accuracy(threshold=0.5)

        self.lr = 1e-3

        self.example_input_array = torch.rand((1, 120, 6))

    def forward(self, x):
        _, (forward_h, _)  = self.forward_lstm(x)
        _, (backward_h, _)  = self.backward_lstm(torch.flip(x, dims=[0]))
        h = torch.cat([forward_h[-1], backward_h[-1]], dim=1)
        output = self.linear_stack(h.squeeze())
        output = output.squeeze()
        return output
    
    def training_step(self, batch, batch_idx):
        x, y = batch
        pred = self(x)
        loss = self.loss(pred, y)
        self.log("Training Loss", loss)
        self.log('Training Accuracy', self.accuracy(pred, y.int()))
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        pred = self(x)
        loss = self.loss(pred, y)
        self.log("Validation Loss", loss)
        self.log('Validation Accuracy', self.accuracy(pred, y.int()))
        return loss

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=self.lr)
        lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, 0.99)
        return {"optimizer":optimizer, "lr_scheduler":lr_scheduler}         


In [74]:
model = LitSegmentClassifier()
# model = LitSegmentClassifier.load_from_checkpoint(r"lightning_logs\Model\lr_scheduler\version_1\checkpoints\epoch=199-step=9799.ckpt")

tb_logger = pl_loggers.TensorBoardLogger("lightning_logs/", name="Model/focal_loss", log_graph=True)
trainer = pl.Trainer(gpus=1, max_epochs=5000, logger=tb_logger)#, callbacks=[EarlyStopping(monitor="Validation Loss", patience=20)], auto_lr_find=True)
# trainer.tune(model, train_dataloader, val_dataloader)
# trainer.fit(model, train_dataloader, val_dataloader)
trainer.fit(model, test_dataloader)

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name          | Type       | Params | In sizes    | Out sizes                                  
---------------------------------------------------------------------------------------------------------
0 | forward_lstm  | LSTM       | 1.0 M  | [1, 120, 6] | [[1, 120, 500], [[1, 1, 500], [1, 1, 500]]]
1 | backward_lstm | LSTM       | 1.0 M  | [1, 120, 6] | [[1, 120, 500], [[1, 1, 500], [1, 1, 500]]]
2 | linear_stack  | Sequential | 620 K  | [1000]      | [1]                                        
3 | loss          | FocalLoss  | 0      | ?           | ?                                          
4 | accuracy      | Accuracy   | 0      | ?           | ?                                          
---------------------------------------------------------------------------------------------------------
2.7 M     Trainable params
0        

Epoch 4:  91%|█████████▏| 21/23 [02:05<00:11,  5.68s/it, loss=0.0141, v_num=44]
Epoch 10:  74%|███████▍  | 17/23 [01:36<00:32,  5.37s/it, loss=0.0125, v_num=45]
Epoch 4999: 100%|██████████| 23/23 [00:01<00:00, 17.75it/s, loss=0.00295, v_num=46]


In [75]:
len(interictal_test_X)/(len(preictal_test_X)+len(interictal_test_X))

0.859273225917038

In [76]:
len(preictal_test_X)

821

# Evaluation on a Test Patient

In [5]:
from evaluation import evaluate

In [6]:
test_patient = 'MSEL_00172'

In [7]:
model = LitSegmentClassifier.load_from_checkpoint(r"lightning_logs\Model\focal_loss\version_46\checkpoints\epoch=4999-step=114999.ckpt")

In [9]:
evaluate(model, test_patient, integration_windows=[300000, 600000, 900000], thresholds=np.round(np.arange(0.4, 0.7, 0.02),2), timer_duration=450000, detection_interval=60000)

Unnamed: 0,Integration Window,threshold,S,TiW,IoC,p
300000_0.4,300000,0.4,0.7,0.348215,0.388963,0.01312
300000_0.42,300000,0.42,0.7,0.318778,0.416073,0.007618
300000_0.44,300000,0.44,0.7,0.291794,0.440774,0.004381
300000_0.46,300000,0.46,0.7,0.257451,0.472012,0.001977
300000_0.48,300000,0.48,0.7,0.234024,0.493198,0.001069
300000_0.5,300000,0.5,0.7,0.228382,0.498286,0.000913
300000_0.52,300000,0.52,0.6,0.208758,0.415941,0.004127
300000_0.54,300000,0.54,0.6,0.187416,0.435067,0.002306
300000_0.56,300000,0.56,0.6,0.169876,0.450729,0.001347
300000_0.58,300000,0.58,0.6,0.159818,0.459687,0.000962


## Read test patients data

## Finding Seizure Start Times

## Making Predictions from Segments

## Metrics

Unnamed: 0,Integration Window,threshold,S,TiW,IoC,p
300000_0.4,300000,0.4,0.7,0.348215,0.388963,0.01312
300000_0.42,300000,0.42,0.7,0.318778,0.416073,0.007618
300000_0.44,300000,0.44,0.7,0.291794,0.440774,0.004381
300000_0.46,300000,0.46,0.7,0.257451,0.472012,0.001977
300000_0.48,300000,0.48,0.7,0.234024,0.493198,0.001069
300000_0.5,300000,0.5,0.7,0.228382,0.498286,0.000913
300000_0.52,300000,0.52,0.6,0.208758,0.415941,0.004127
300000_0.54,300000,0.54,0.6,0.187416,0.435067,0.002306
300000_0.56,300000,0.56,0.6,0.169876,0.450729,0.001347
300000_0.58,300000,0.58,0.6,0.159818,0.459687,0.000962
