<a href="https://colab.research.google.com/github/invictus125/cs598-final-project/blob/main/intraoperative_hypotension_ryan_testing.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
# Update from github
!git clone https://github.com/invictus125/cs598-final-project
!cd /content/cs598-final-project && git pull

fatal: destination path 'cs598-final-project' already exists and is not an empty directory.
Already up to date.


Link to paper: https://journals.plos.org/plosone/article?id=10.1371/journal.pone.0272055

In [1]:
# Imports
!pip install torcheval
!pip install vitaldb
# import sys
# sys.path.append('/content/cs598-final-project')
# from src.model import WaveformResNet, IntraoperativeHypotensionModel
# from src.train import train



In [2]:
# TEST CELL - REMOVE

import torch
import math
from torch import nn


class EncoderBlock(nn.Module):
  def __init__(self, dim_in, kernel_size=15, stride=1):
    super(EncoderBlock, self).__init__()
    padding = math.floor(kernel_size / 2.0)
    self.conv = nn.Conv1d(1, 1, kernel_size, stride, padding=padding)
    self.mp = nn.MaxPool1d(kernel_size, stride, padding)
    self.fc = nn.Linear(dim_in, dim_in)


  def forward(self, x):
    x_hat = self.conv(x)
    x_hat = self.mp(x_hat)
    return self.fc(x_hat)


class ResidualBlock(nn.Module):
  def __init__(
    self,
    in_channels,
    out_channels,
    size_down,
    kernel_size,
    stride=1
  ):
    super(ResidualBlock, self).__init__()

    self.size_down = size_down
    self.in_channels = in_channels
    self.out_channels = out_channels

    padding = math.floor(kernel_size / 2.0)

    self.bn1 = nn.BatchNorm1d(in_channels)
    self.act1 = nn.ReLU()
    self.do = nn.Dropout()
    self.conv1 = nn.Conv1d(in_channels, in_channels, kernel_size, stride, padding)
    self.bn2 = nn.BatchNorm1d(in_channels)
    self.act2 = nn.ReLU()
    self.conv2 = nn.Conv1d(in_channels, out_channels, kernel_size, stride, padding)
    self.mp = nn.MaxPool1d(kernel_size, padding=padding, stride=2)
    self.conv_for_input = nn.Conv1d(in_channels, out_channels, kernel_size, stride, padding)


  def forward(self, x):
    x_hat = self.bn1(x)
    x_hat = self.act1(x_hat)
    x_hat = self.do(x_hat)
    x_hat = self.conv1(x_hat)
    x_hat = self.bn2(x_hat)
    x_hat = self.act2(x_hat)
    x_hat = self.conv2(x_hat)

    # Adjust dimensions of input if needed for the skip connection
    x_input = None
    if self.in_channels != self.out_channels:
      x_input = self.conv_for_input(x)
    else:
      x_input = x

    x_hat = x_hat + x_input

    if self.size_down:
      x_hat = self.mp(x_hat)

    return x_hat


class FlattenAndLinearBlock(nn.Module):
  def __init__(self, dim_in, dim_out):
    super(FlattenAndLinearBlock, self).__init__()
    self.fc = nn.Linear(dim_in, dim_out)


  def forward(self, x):
    x_hat = torch.flatten(x, start_dim=1, end_dim=-1)
    x_hat = self.fc(x_hat)

    return x_hat


class WaveformResNet(nn.Module):
  def __init__(
    self,
    input_shape,
    output_size,
    data_type
  ):
    super(WaveformResNet, self).__init__()
    self.encoder = EncoderBlock(input_shape, 15, 1)
    self.res_in_dim = input_shape
    self.output_size = output_size
    self.data_type = data_type

    if data_type not in ['abp', 'ecg', 'eeg']:
      raise ValueError('Invalid data type. Must be one of [abp, ecg, eeg]')

    # Set up configurations for residual blocks
    residual_configs = []
    linear_block_input_length = -1
    if data_type in ['abp', 'ecg']:
      residual_configs = [
        {
          'kernel_size': 15,
          'in_channels': 1,
          'out_channels': 2,
          'size_down': True,
        },
        {
          'kernel_size': 15,
          'in_channels': 2,
          'out_channels': 2,
          'size_down': False,
        },
        {
          'kernel_size': 15,
          'in_channels': 2,
          'out_channels': 2,
          'size_down': True,
        },
        {
          'kernel_size': 15,
          'in_channels': 2,
          'out_channels': 2,
          'size_down': False,
        },
        {
          'kernel_size': 15,
          'in_channels': 2,
          'out_channels': 2,
          'size_down': True,
        },
        {
          'kernel_size': 15,
          'in_channels': 2,
          'out_channels': 4,
          'size_down': False,
        },
        {
          'kernel_size': 7,
          'in_channels': 4,
          'out_channels': 4,
          'size_down': True,
        },
        {
          'kernel_size': 7,
          'in_channels': 4,
          'out_channels': 4,
          'size_down': False,
        },
        {
          'kernel_size': 7,
          'in_channels': 4,
          'out_channels': 4,
          'size_down': True,
        },
        {
          'kernel_size': 7,
          'in_channels': 4,
          'out_channels': 6,
          'size_down': False,
        },
        {
          'kernel_size': 7,
          'in_channels': 6,
          'out_channels': 6,
          'size_down': True,
        },
        {
          'kernel_size': 7,
          'in_channels': 6,
          'out_channels': 6,
          'size_down': False,
        },
      ]
      linear_block_input_length = 469 * 6
    else:
      residual_configs = [
        {
          'kernel_size': 7,
          'in_channels': 1,
          'out_channels': 2,
          'size_down': True,
        },
        {
          'kernel_size': 7,
          'in_channels': 2,
          'out_channels': 2,
          'size_down': False,
        },
        {
          'kernel_size': 7,
          'in_channels': 2,
          'out_channels': 2,
          'size_down': True,
        },
        {
          'kernel_size': 7,
          'in_channels': 2,
          'out_channels': 2,
          'size_down': False,
        },
        {
          'kernel_size': 7,
          'in_channels': 2,
          'out_channels': 2,
          'size_down': True,
        },
        {
          'kernel_size': 7,
          'in_channels': 2,
          'out_channels': 4,
          'size_down': False,
        },
        {
          'kernel_size': 3,
          'in_channels': 4,
          'out_channels': 4,
          'size_down': True,
        },
        {
          'kernel_size': 3,
          'in_channels': 4,
          'out_channels': 4,
          'size_down': False,
        },
        {
          'kernel_size': 3,
          'in_channels': 4,
          'out_channels': 4,
          'size_down': True,
        },
        {
          'kernel_size': 3,
          'in_channels': 4,
          'out_channels': 6,
          'size_down': False,
        },
        {
          'kernel_size': 3,
          'in_channels': 6,
          'out_channels': 6,
          'size_down': True,
        },
        {
          'kernel_size': 3,
          'in_channels': 6,
          'out_channels': 6,
          'size_down': False,
        },
      ]
      linear_block_input_length = 120 * 6

    self.residuals = []
    # Build residuals
    for i in range(12):
      self.residuals.append(
        ResidualBlock(
          size_down=residual_configs[i]['size_down'],
          in_channels=residual_configs[i]['in_channels'],
          out_channels=residual_configs[i]['out_channels'],
          kernel_size=residual_configs[i]['kernel_size'],
        )
      )

    self.fl_ln = FlattenAndLinearBlock(linear_block_input_length, output_size)


  def forward(self, x):
    x_hat = self.encoder(x)

    for i in range(len(self.residuals)):
      x_hat = self.residuals[i](x_hat)

    return self.fl_ln(x_hat)


  def get_output_size(self):
    return self.output_size


class IntraoperativeHypotensionModel(nn.Module):
  def __init__(
    self,
    ecg_resnet,
    abp_resnet,
    eeg_resnet
  ):
    super(IntraoperativeHypotensionModel, self).__init__()

    self.ecg = ecg_resnet
    self.abp = abp_resnet
    self.eeg = eeg_resnet

    self.fc_input_length = 0

    if self.ecg is not None:
      self.fc_input_length += self.ecg.get_output_size()

    if self.abp is not None:
      self.fc_input_length += self.abp.get_output_size()

    if self.eeg is not None:
      self.fc_input_length += self.eeg.get_output_size()

    if self.fc_input_length == 0:
      raise 'No resnet blocks provided, unable to build model'

    self.fc1 = nn.Linear(self.fc_input_length, 16)
    self.fc2 = nn.Linear(16, 1)
    self.act = nn.Sigmoid()

    self.seq = nn.Sequential(
      self.fc1,
      self.fc2,
      self.act
    )


  def forward(self, abp, ecg, eeg):
    ecg_o = torch.Tensor([])
    abp_o = torch.Tensor([])
    eeg_o = torch.Tensor([])

    if self.ecg is not None:
      ecg_o = self.ecg(ecg)

    if self.abp is not None:
      abp_o = self.abp(abp)

    if self.eeg is not None:
      eeg_o = self.eeg(eeg)

    resnet_output = torch.concat([ecg_o, abp_o, eeg_o], dim=1)

    prediction = self.seq(resnet_output)

    return prediction


In [3]:
import torch
from torch.optim import Adam
from torch.nn import BCELoss
import numpy as np
from torcheval.metrics import BinaryAUROC, BinaryAUPRC, BinaryRecall


def _binary_specificity(test, target):
  # TN / TN + FP
  pinned = torch.where(test >= 0.5, 1.0, 0.0)
  pos = torch.where(pinned > 0, 1.0, 0.0)
  neg = torch.where(pinned < 1, 1.0, 0.0)
  gt_pos = torch.where(target > 0, 1.0, 0.0)
  gt_neg = torch.where(target < 1, 1.0, 0.0)
  tn = neg + gt_neg
  tn = torch.sum(torch.where(tn > 1, 1.0, 0.0), dtype=torch.float)
  fp = pos + gt_neg
  fp = torch.sum(torch.where(fp > 1, 1.0, 0.0), dtype=torch.float)

  return (tn / (tn + fp))


def _train_one_epoch(
  model,
  train_data,
  optimizer,
  criterion
):
  model.train()
  loss_history = []
  for data in train_data:
    optimizer.zero_grad()
    y_hat = model(data.abp, data.ecg, data.eeg)
    loss = criterion(torch.squeeze(y_hat, dim=-1), data.y)
    loss.backward()
    optimizer.step()
    loss_history.append(loss.item())

  return loss_history


def _eval_model(
  model,
  eval_data,
  dataset_name
):
  model.eval()

  auroc = []
  auprc = []
  sensitivity = []
  specificity = []

  f_auroc = BinaryAUROC()
  f_auprc = BinaryAUPRC()
  f_sensitivity = BinaryRecall()
  for data in eval_data:
    y_hat = model(data.abp, data.ecg, data.eeg)
    y_hat = y_hat.squeeze(-1)
    y_hat_long = torch.where(y_hat >= 0.5, 1.0, 0.0).long()
    target_long = data.y.long()

    f_auroc.update(y_hat, data.y)
    f_auprc.update(y_hat, data.y)
    f_sensitivity.update(y_hat_long, target_long)

    auroc.append(f_auroc.compute())
    auprc.append(f_auprc.compute())
    sensitivity.append(f_sensitivity.compute())
    specificity.append(_binary_specificity(y_hat, data.y))

  m_auroc = np.mean(auroc)
  m_auprc = np.mean(auprc)
  m_sensitivity = np.mean(sensitivity)
  m_specificity = np.mean(specificity)

  print(f'    {dataset_name} data metrics:')
  print(f'        AUROC: {m_auroc}')
  print(f'        AUPRC: {m_auprc}')
  print(f'        Sensitivity: {m_sensitivity}')
  print(f'        Specificity: {m_specificity}')

  return m_auroc, m_auprc, m_sensitivity, m_specificity


def train(
  model,
  train_data_handle,
  test_data_handle,
  learning_rate=0.0001,
  epochs=100,
  suspend_train_epochs_threshold=5
):
  """Trains an IntraoperativeHypotensionModel using the given learning rate for
  the given number of epochs

  model: the IntraoperativeHypotensionModel to train
  train_data_handle: the dataset we will train on
  test_data_handle: the dataset we will use for evaluation
  learning_rate: the learning rate to use with the Adam optimizer
  epochs: the number of epochs to train for
  suspend_train_epochs_threshold: training will be suspended if the loss does
    not improve for this number of epochs
  """
  if model is None or train_data_handle is None or test_data_handle is None:
    raise ValueError(
      'model, train_data_handle, and test_data_handle are required for training'
    )

  criterion = BCELoss()
  optimizer = Adam(model.parameters(), lr=learning_rate)

  overall_loss_history = []
  consecutive_epochs_without_improvement = 0
  for epoch in range(epochs):
    print('====================================')
    print(f'     Epoch #{epoch + 1}')
    print('====================================')
    loss_history = _train_one_epoch(
      model,
      train_data_handle,
      optimizer,
      criterion
    )
    _eval_model(model, train_data_handle, 'Train')
    # Not using performance metrics yet in this function.
    # Potential TODO: stop training once desired performance is reached (TBD)
    performance = _eval_model(model, test_data_handle, 'Test')

    if epoch > 0:
      mean_loss = np.mean(loss_history)
      overall_loss_history.append(mean_loss)
      loss_change = overall_loss_history[epoch - 1] - mean_loss
      if loss_change < 0.1:
        consecutive_epochs_without_improvement += 1
      else:
        consecutive_epochs_without_improvement = 0

    if consecutive_epochs_without_improvement >= suspend_train_epochs_threshold:
      print(f'Training stopping after {epoch+1} epochs.')
      print(f'Loss did not change for {suspend_train_epochs_threshold} epochs')
      break



In [4]:
# Put together model as the paper describes (all resnets)
abp_resnet = WaveformResNet(
    input_shape=30000,
    output_size=32,
    data_type='abp'
)

ecg_resnet = WaveformResNet(
    input_shape=30000,
    output_size=32,
    data_type='ecg'
)

eeg_resnet = WaveformResNet(
    input_shape=7680,
    output_size=32,
    data_type='eeg'
)

model = IntraoperativeHypotensionModel(
    abp_resnet=abp_resnet,
    ecg_resnet=ecg_resnet,
    eeg_resnet=eeg_resnet
)

In [None]:
import sys
sys.path.append('/content/cs598-final-project')
from src.data import get_data

real_data = get_data(15, max_num_samples=5)

In [6]:
# TEST CELL - REMOVE

# from src.train import train

class TestData():
  def __init__(self, abp, ecg, eeg, y):
    self.abp = abp
    self.ecg = ecg
    self.eeg = eeg
    self.y = y

abp = real_data[0].unsqueeze(1)

batch = []
batch.append(TestData(abp, torch.randn((5,1,30000)), torch.randn((5,1,7680)), torch.Tensor([0,0,0,0,0])))
batch.append(TestData(torch.randn((5,1,30000)), torch.randn((5,1,30000)), torch.randn((5,1,7680)), torch.Tensor([1,1,1,1,1])))

eval_batch = []
eval_batch.append(TestData(abp, torch.randn((5,1,30000)), torch.randn((5,1,7680)), torch.Tensor([0,0,0,0,0])))
eval_batch.append(TestData(torch.randn((5,1,30000)), torch.randn((5,1,30000)), torch.randn((5,1,7680)), torch.Tensor([1,1,1,1,1])))


train(model, batch, eval_batch, epochs=1)

     Epoch #1




    Train data metrics:
        AUROC: 0.75
        AUPRC: 0.5
        Sensitivity: 0.5
        Specificity: nan




    Test data metrics:
        AUROC: 0.75
        AUPRC: 0.5
        Sensitivity: 0.5
        Specificity: nan
