In [1]:
import torch
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from torch.utils.data.sampler import SubsetRandomSampler, SequentialSampler
from sklearn.model_selection import train_test_split
from sklearn.metrics import f1_score
import torch.nn as nn
import torch.nn.functional as F
import pandas as pd
import numpy as np
import time
import random
import pytorch_lightning as pl
import torchmetrics
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from pytorch_lightning.callbacks import ModelCheckpoint
import os

class NeonatalVentilationDataset(pl.LightningDataModule):
    def __init__(self, train_index, train_wf, test_index, test_wf, target, batch_size, fixed_len = None, oversample=True, center=True):
        self.df = pd.read_csv(train_index)
        self.waveforms = pd.read_csv(train_wf).iloc[:,1:]
        self.test_df = pd.read_csv(test_index)
        self.test_waveforms = pd.read_csv(test_wf).iloc[:,1:]
        self.target = target
        self.train_subset = None
        self.val_subset = None
        self.test_subset = None
        self.batch_size = batch_size
        self.fixed_len = fixed_len
        self.oversample = oversample
        self.center = center
        self.xs = [None] * self.df.shape[0]
        self.train_ind = None
        self.val_ind = None
        self.test_ind = None
        
    def __len__(self):
        return self.df.shape[0]
    
    def __getitem__(self, idx):
        s = self.df["Sample"].iat[idx]
        i = self.df["original_index"].iat[idx]
        if self.xs[idx] is None:
            w = self.waveforms.loc[(self.waveforms["Sample"] == s) & (self.waveforms["original_index"] == i)].iloc[:,:2]
            if self.center:
                w = w.interpolate().apply(lambda x : (x - np.mean(x)) / np.std(x))
            else:
                w = w.interpolate().apply(lambda x : x / np.std(x))
            if self.fixed_len != None:
                if w.shape[0] > self.fixed_len:
                    w = w.iloc[:self.fixed_len,:].T
                else:
                    pad = pd.DataFrame(np.zeros((self.fixed_len - w.shape[0],2)))
                    pad.columns = w.columns 
                    w = pd.concat((w, pad)).T
            self.xs[idx] = torch.tensor(w.to_numpy().astype(np.float32))
        return torch.tensor(self.df.iloc[idx, :(self.df.shape[1]-3)].to_numpy().astype(np.float32)), self.xs[idx], torch.tensor(self.df[self.target].iat[idx].astype(np.float32))
    
    def get_train_test_samplers(self, test_size, seed=42):
        train, val, _, _ = train_test_split(range(self.df.shape[0]), range(self.df.shape[0]), test_size=test_size, random_state=seed, stratify=self.df[self.target])
        if self.oversample:
            train_df = self.df.iloc[train, :]
            target_n = train_df.loc[train_df[self.target] == 1].shape[0]
            off_target_n = train_df.loc[train_df[self.target] == 0].shape[0]
            if target_n > off_target_n:
                diff = target_n - off_target_n
                new_indices = range(self.df.shape[0], self.df.shape[0] + diff)
                self.df = pd.concat([self.df, train_df[train_df[self.target] == 0].sample(diff, replace=True)])
            else:
                diff = off_target_n - target_n
                new_indices = range(self.df.shape[0], self.df.shape[0] + diff)
                self.df = pd.concat([self.df, train_df[train_df[self.target] == 1].sample(diff, replace=True)])
            train = train + list(new_indices)
        self.train_ind = train
        self.val_ind = val
        self.test_ind = list(range(self.df.shape[0], self.df.shape[0] + self.test_df.shape[0]))
        self.df = pd.concat([self.df, self.test_df]).reset_index().iloc[:,2:].dropna(axis=1)
        self.waveforms = pd.concat([self.waveforms, self.test_waveforms])
        self.xs = [None] * self.df.shape[0]
        return (SubsetRandomSampler(train), SequentialSampler(val))
    
    def train_dataloader(self, seed=42):
        if self.train_subset is None:
            self.train_subset, self.val_subset = self.get_train_test_samplers(0.2, seed=seed)
            self.test_subset = SequentialSampler(self.test_ind)
        return DataLoader(self, batch_size=self.batch_size, sampler=self.train_subset, pin_memory=True)
    
    def val_dataloader(self, seed=42):
        if self.val_subset is None:
            self.train_subset, self.val_subset = self.get_train_test_samplers(0.2, seed=seed)
            self.test_subset = SequentialSampler(self.test_ind)
        return DataLoader(self, batch_size=self.batch_size, sampler=self.val_subset, pin_memory=True)
    
    def test_dataloader(self, seed=42):
        if self.train_subset is None:
            self.train_subset, self.val_subset = self.get_train_test_samplers(0.2, seed=seed)
            self.test_subset = SequentialSampler(self.test_ind)
        return DataLoader(self, batch_size=self.batch_size, sampler=self.test_subset, pin_memory=True)



In [2]:
def conv_len(lin, pad, dil, ker, stride):
    return int((lin + 2 * pad - dil * (ker - 1) - 1) / stride + 1)

def pool_len(lin, pad, ker, stride):
    return int((lin + 2 * pad - ker) / stride + 1)

class AbnormalBreathDetectorCNN(pl.LightningModule):
    def __init__(self, learning_rate, k_sizes, dilatations, cout, fixed_len, dropout, pool_kern=None):
        super(AbnormalBreathDetectorCNN, self).__init__()
        self.save_hyperparameters()
        self.k_sizes = k_sizes
        self.dilatations = dilatations
        self.c_out_n = cout
        self.fixed_len = fixed_len
        self.dropout = torch.nn.Dropout(dropout)
        self.pool_kern = pool_kern
        self.conv_modules = nn.ModuleList([nn.Conv1d(2, self.c_out_n, k, dilation=d) for d in self.dilatations for k in self.k_sizes])
        #self.conv_modules = nn.ModuleList([nn.Conv1d(2, self.c_out_n, k) for k in self.k_sizes])
        if self.pool_kern is not None:
            self.pooling = nn.AvgPool1d(self.pool_kern)
            self.out_sizes = [pool_len(conv_len(self.fixed_len, 0, d, k, 1), 0, self.pool_kern, self.pool_kern) for d in self.dilatations for k in self.k_sizes]
        else:
            self.out_sizes = [conv_len(self.fixed_len, 0, d, k, 1) for d in self.dilatations for k in self.k_sizes]
        self.fc = torch.nn.Linear(sum(self.out_sizes) * self.c_out_n, 1)
        self.relu = torch.nn.ReLU()
        self.lr = learning_rate
        self.epoch = 0
        
    def forward(self, xs):
        if self.pool_kern is not None:
            convs = [self.dropout(self.pooling(self.relu(m(xs)))) for m in self.conv_modules]
        else:
            convs = [self.dropout(self.relu(m(xs))) for m in self.conv_modules]
        convs = torch.cat([convs[i].view(-1, self.out_sizes[i] * self.c_out_n) for i in range(len(self.out_sizes))], dim=1)
        out = self.fc(convs)
        pred = torch.sigmoid(out)
        return pred
    
    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=self.lr)
        lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, "max", patience=5)
        return {"optimizer" : optimizer, "lr_scheduler" : lr_scheduler, "monitor" : "Validation_F1_log"}
    
    def training_step(self, batch, batch_idx):
        _, x, y = batch
        y_hat = self(x)
        loss = F.binary_cross_entropy(y_hat.flatten(), y)
        return {"loss" : loss, "preds" : y_hat.flatten(), "labels" : y}
    
    def training_epoch_end(self, training_step_outputs):
        avg_loss = torch.stack([x["loss"] for x in training_step_outputs]).mean()
        preds = torch.cat([x["preds"] for x in training_step_outputs])
        targets = torch.cat([x["labels"] for x in training_step_outputs])
        print("Epoch {} train".format(self.epoch))
        confmat = torchmetrics.ConfusionMatrix(num_classes=2)
        print(confmat(preds.cpu(), targets.int().cpu()))
        f1score = torchmetrics.F1Score(num_classes=1)
        f1_epoch = f1score(preds.cpu(), targets.int().cpu())
        recall = torchmetrics.Recall(num_classes=1)
        recall_epoch = recall(preds.cpu(), targets.int().cpu())
        specificity = torchmetrics.Specificity(num_classes=1)
        specificity_epoch = specificity(preds.cpu(), targets.int().cpu())
        auroc = torchmetrics.AUROC(pos_label=1)
        auroc_epoch = auroc(preds.cpu(), targets.int().cpu())
        self.logger.experiment.add_scalar("Train_Loss", avg_loss.item(), self.epoch)
        self.logger.experiment.add_scalar("Train_F1", f1_epoch, self.epoch)
        self.logger.experiment.add_scalar("Train_recall", recall_epoch, self.epoch)
        self.logger.experiment.add_scalar("Train_specificity", specificity_epoch, self.epoch)
        self.logger.experiment.add_scalar("Train_AUROC", auroc_epoch, self.epoch)
    
    def validation_step(self, batch, batch_idx):
        _, x, y = batch
        y_hat = self(x)
        loss = F.binary_cross_entropy(y_hat.flatten(), y)
        return {"loss" : loss, "preds" : y_hat.flatten(), "labels" : y}
    
    def validation_epoch_end(self, validation_step_outputs):
        avg_loss = torch.stack([x["loss"] for x in validation_step_outputs]).mean()
        preds = torch.cat([x["preds"] for x in validation_step_outputs])
        targets = torch.cat([x["labels"] for x in validation_step_outputs])
        print("Epoch {} validation".format(self.epoch))
        confmat = torchmetrics.ConfusionMatrix(num_classes=2)
        print(confmat(preds.cpu(), targets.int().cpu()))
        f1score = torchmetrics.F1Score(num_classes=1)
        f1_epoch = f1score(preds.cpu(), targets.int().cpu())
        recall = torchmetrics.Recall(num_classes=1)
        recall_epoch = recall(preds.cpu(), targets.int().cpu())
        specificity = torchmetrics.Specificity(num_classes=1)
        specificity_epoch = specificity(preds.cpu(), targets.int().cpu())
        auroc = torchmetrics.AUROC(pos_label=1)
        auroc_epoch = auroc(preds.cpu(), targets.int().cpu())
        self.logger.experiment.add_scalar("Validation_Loss", avg_loss.item(), self.epoch)
        self.logger.experiment.add_scalar("Validation_F1", f1_epoch, self.epoch)
        self.logger.experiment.add_scalar("Validation_recall", recall_epoch, self.epoch)
        self.logger.experiment.add_scalar("Validation_specificity", specificity_epoch, self.epoch)
        self.logger.experiment.add_scalar("Validation_AUROC", auroc_epoch, self.epoch)
        self.log("Validation_F1_log", f1_epoch)
        self.epoch += 1
    
    def test_step(self, batch, batch_idx):
        _, x, y = batch
        y_hat = self(x)
        loss = F.binary_cross_entropy(y_hat.flatten(), y)
        return {"loss" : loss, "preds" : y_hat.flatten(), "labels" : y}
    
    def test_epoch_end(self, validation_step_outputs):
        avg_loss = torch.stack([x["loss"] for x in validation_step_outputs]).mean()
        preds = torch.cat([x["preds"] for x in validation_step_outputs])
        targets = torch.cat([x["labels"] for x in validation_step_outputs])
        print("Test Results")
        confmat = torchmetrics.ConfusionMatrix(num_classes=2)
        print(confmat(preds.cpu(), targets.int().cpu()))
        f1score = torchmetrics.F1Score(num_classes=1)
        f1_epoch = f1score(preds.cpu(), targets.int().cpu())
        recall = torchmetrics.Recall(num_classes=1)
        recall_epoch = recall(preds.cpu(), targets.int().cpu())
        specificity = torchmetrics.Specificity(num_classes=1)
        specificity_epoch = specificity(preds.cpu(), targets.int().cpu())
        auroc = torchmetrics.AUROC(pos_label=1)
        auroc_epoch = auroc(preds.cpu(), targets.int().cpu())
        self.logger.experiment.add_scalar("Test_Loss", avg_loss.item(), self.epoch)
        self.logger.experiment.add_scalar("Test_F1", f1_epoch, self.epoch)
        self.logger.experiment.add_scalar("Test_Recall", recall_epoch, self.epoch)
        self.logger.experiment.add_scalar("Test_Specificity", specificity_epoch, self.epoch)
        self.logger.experiment.add_scalar("Test_AUROC", auroc_epoch, self.epoch)

In [None]:
breath_max_len = 300
batch_size = 64
max_epoch = 100
oversampling = True
dropout = 0
seed = 641
asynchronies = ["early trigger", "late trigger", "failed trigger", "multiple trigger", "early cycling", "late cycling", "expiratory work", "splinting"]
kernal_sizes = [5,11,21]
dilatation_sizes = [1]
pooling = [False]
lrs = [0.001]

for l in lrs:
    for p in pooling:
        for a in asynchronies:
            data = NeonatalVentilationDataset("../data/train/combined.csv", "../data/train/waveforms.csv",
                                              "../data/test/combined.csv", "../data/test/waveforms.csv",
                                              a, batch_size, breath_max_len, oversampling, center=True)
            stem = "CNN_lr{}_kern{}_dil{}_filters128_{}_{}".format(l, "-".join([str(x) for x in kernal_sizes]), "-".join([str(x) for x in dilatation_sizes]), "pooled" if p else "no_pool", a)
            if os.path.isdir(stem):
                continue
            tb_logger = pl.loggers.TensorBoardLogger(save_dir=os.getcwd(), version=stem)
            checkpoint_cb = ModelCheckpoint(save_top_k = 1, 
                                            monitor="Validation_F1_log", 
                                            mode="max", 
                                            dirpath="model_checkpoints",
                                            filename=stem + "{epoch:02d}-{Validation_F1_log:.2f}")
            earlystopping_cb = EarlyStopping(monitor="Validation_F1_log", mode="max", patience=10)
            if p:
                model = AbnormalBreathDetectorCNN(l, kernal_sizes, dilatation_sizes, 128, breath_max_len, dropout, pool_kern = 3)
            else:
                model = AbnormalBreathDetectorCNN(l, kernal_sizes, dilatation_sizes, 128, breath_max_len, dropout)
            trainer = pl.Trainer(track_grad_norm=2, max_epochs=max_epoch, accelerator="gpu", devices=1, 
                                 callbacks=[earlystopping_cb, checkpoint_cb], logger=tb_logger)
            trainer.fit(model=model, train_dataloaders=data.train_dataloader(seed), val_dataloaders=data.val_dataloader(seed))
            trainer.test(model=model, dataloaders=data.test_dataloader(seed))

In [3]:
loaded_model = AbnormalBreathDetectorCNN.load_from_checkpoint("model_checkpoints/CNN_lr0.001_kern5-11-21_dil1_filters128_no_pool_failed triggerepoch=37-Validation_F1_log=0.91.ckpt") 
#                                                       learning_rate = 0.001, k_sizes = [5,11,21], dilatations = [1], cout = 128, fixed_len = 300, dropout = 0.5)
data = NeonatalVentilationDataset("../data/train/combined.csv", "../data/train/waveforms.csv",
                                              "../data/test/combined.csv", "../data/test/waveforms.csv",
                                              "failed trigger", 64, 300, True, center=True)
i = 0
y_hat = []
ys = []
val_set = iter(data.test_dataloader(641))
_, x, y = next(val_set)
while x is not None:
    y_hat.append(loaded_model(x.cuda()).flatten())
    ys.append(y.flatten())
    _, x, y = next(val_set, (None, None, None))
y_hat = torch.cat(y_hat).cpu()
ys = torch.cat(ys).cpu()
confmat = torchmetrics.ConfusionMatrix(task="binary", num_classes=2)
print(confmat(y_hat, ys.int()))

Lightning automatically upgraded your loaded checkpoint from v1.6.0 to v2.0.2. To apply the upgrade to your files permanently, run `python -m pytorch_lightning.utilities.upgrade_checkpoint --file C:\Users\David Chong\Desktop\backup\Ventilation_Asynchrony_Classifier_Paper\model_dev\model_checkpoints\CNN_lr0.001_kern5-11-21_dil1_filters128_no_pool_failed triggerepoch=37-Validation_F1_log=0.91.ckpt`


tensor([[2364,   11],
        [  10,  115]])


In [4]:
false_neg = (ys == 1) & (torch.round(y_hat) == 0)
false_pos = (ys == 0) & (torch.round(y_hat) == 1)
asynchronies = ["early trigger", "late trigger", "failed trigger", "multiple trigger", "early cycling", "late cycling", "expiratory work", "splinting", "work shifting", "artefact", "unassisted breath", "no asynchrony"]
data.df.iloc[data.test_ind,:].reset_index().loc[false_pos.numpy()][asynchronies].sum()

early trigger        2
late trigger         0
failed trigger       3
multiple trigger     0
early cycling        0
late cycling         0
expiratory work      6
splinting            0
work shifting        0
artefact             0
unassisted breath    1
no asynchrony        3
dtype: int64

In [5]:
data.df.iloc[data.test_ind,:].reset_index().loc[false_neg.numpy()][asynchronies].sum()

early trigger        0
late trigger         0
failed trigger       0
multiple trigger     1
early cycling        0
late cycling         3
expiratory work      6
splinting            0
work shifting        0
artefact             1
unassisted breath    0
no asynchrony        2
dtype: int64