<a href="https://colab.research.google.com/github/invictus125/cs598-final-project/blob/main/intraoperative_hypotension_ryan_testing.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
# Update from github
!git clone https://github.com/invictus125/cs598-final-project
!cd /content/cs598-final-project && git pull

Cloning into 'cs598-final-project'...
remote: Enumerating objects: 82, done.[K
remote: Counting objects: 100% (82/82), done.[K
remote: Compressing objects: 100% (62/62), done.[K
remote: Total 82 (delta 37), reused 52 (delta 13), pack-reused 0[K
Receiving objects: 100% (82/82), 22.14 KiB | 11.07 MiB/s, done.
Resolving deltas: 100% (37/37), done.
Already up to date.


Link to paper: https://journals.plos.org/plosone/article?id=10.1371/journal.pone.0272055

In [5]:
# Imports
!pip install torcheval
import sys
sys.path.append('/content/cs598-final-project')
from src.model import WaveformResNet, IntraoperativeHypotensionModel
from src.train import train



In [6]:
# TEST CELL - REMOVE

import torch
from torch import nn


class EncoderBlock(nn.Module):
  def __init__(self, dim_in, kernel_size=15, stride=1):
    super(EncoderBlock, self).__init__()
    self.conv = nn.Conv1d(1, 1, kernel_size, stride, padding=7)
    self.fc = nn.Linear(dim_in, dim_in)


  def forward(self, x):
    x_hat = torch.flatten(self.conv(x))
    return self.fc(x_hat)


class ResidualBlock(nn.Module):
  def __init__(
    self,
    in_channels,
    out_channels,
    size_down,
    kernel_size,
    stride=1
  ):
    super(ResidualBlock, self).__init__()
    self.bn1 = nn.BatchNorm1d(in_channels)
    self.act1 = nn.ReLU()
    self.do = nn.Dropout()
    self.conv1 = nn.Conv1d(in_channels, in_channels, kernel_size, stride)
    self.bn2 = nn.BatchNorm1d(in_channels)
    self.act2 = nn.ReLU()
    self.conv2 = nn.Conv1d(in_channels, out_channels, kernel_size, stride)

    self.seq = nn.Sequential(
      self.bn1,
      self.act1,
      self.do,
      self.conv1,
      self.bn2,
      self.act2,
      self.conv2
    )


  def forward(self, x):
    x_hat = self.seq(x)

    # Concat raw input with convolved for skip connection
    return torch.cat([x, x_hat], 1)


class FlattenAndLinearBlock(nn.Module):
  def __init__(self, dim_in, dim_out):
    super(FlattenAndLinearBlock, self).__init__()
    self.fc = nn.Linear(dim_in, dim_out)


  def forward(self, x):
    x_hat = torch.flatten(x)
    x_hat = self.fc(x_hat)

    return x_hat


class WaveformResNet(nn.Module):
  def __init__(
    self,
    input_shape,
    output_size,
    data_type
  ):
    super(WaveformResNet, self).__init__()
    self.encoder = EncoderBlock(input_shape, 15, 1)
    self.res_in_dim = input_shape
    self.output_size = output_size

    if data_type not in ['abp', 'ecg', 'eeg']:
      raise ValueError('Invalid data type. Must be one of [abp, ecg, eeg]')

    # Set up configurations for residual blocks
    residual_configs = []
    linear_block_input_length = -1
    if data_type in ['abp', 'ecg']:
      residual_configs = [
        {
          'kernel_size': 15,
          'in_channels': 1,
          'out_channels': 2,
          'size_down': True,
        },
        {
          'kernel_size': 15,
          'in_channels': 2,
          'out_channels': 2,
          'size_down': False,
        },
        {
          'kernel_size': 15,
          'in_channels': 2,
          'out_channels': 2,
          'size_down': True,
        },
        {
          'kernel_size': 15,
          'in_channels': 2,
          'out_channels': 2,
          'size_down': False,
        },
        {
          'kernel_size': 15,
          'in_channels': 2,
          'out_channels': 2,
          'size_down': True,
        },
        {
          'kernel_size': 15,
          'in_channels': 2,
          'out_channels': 4,
          'size_down': False,
        },
        {
          'kernel_size': 7,
          'in_channels': 4,
          'out_channels': 4,
          'size_down': True,
        },
        {
          'kernel_size': 7,
          'in_channels': 4,
          'out_channels': 4,
          'size_down': False,
        },
        {
          'kernel_size': 7,
          'in_channels': 4,
          'out_channels': 4,
          'size_down': True,
        },
        {
          'kernel_size': 7,
          'in_channels': 4,
          'out_channels': 6,
          'size_down': False,
        },
        {
          'kernel_size': 7,
          'in_channels': 4,
          'out_channels': 6,
          'size_down': True,
        },
        {
          'kernel_size': 7,
          'in_channels': 6,
          'out_channels': 6,
          'size_down': False,
        },
      ]
      linear_block_input_length = 496
    else:
      residual_configs = [
        {
          'kernel_size': 7,
          'in_channels': 1,
          'out_channels': 2,
          'size_down': True,
        },
        {
          'kernel_size': 7,
          'in_channels': 2,
          'out_channels': 2,
          'size_down': False,
        },
        {
          'kernel_size': 7,
          'in_channels': 2,
          'out_channels': 2,
          'size_down': True,
        },
        {
          'kernel_size': 7,
          'in_channels': 2,
          'out_channels': 2,
          'size_down': False,
        },
        {
          'kernel_size': 7,
          'in_channels': 2,
          'out_channels': 2,
          'size_down': True,
        },
        {
          'kernel_size': 7,
          'in_channels': 2,
          'out_channels': 4,
          'size_down': False,
        },
        {
          'kernel_size': 3,
          'in_channels': 4,
          'out_channels': 4,
          'size_down': True,
        },
        {
          'kernel_size': 3,
          'in_channels': 4,
          'out_channels': 4,
          'size_down': False,
        },
        {
          'kernel_size': 3,
          'in_channels': 4,
          'out_channels': 4,
          'size_down': True,
        },
        {
          'kernel_size': 3,
          'in_channels': 4,
          'out_channels': 6,
          'size_down': False,
        },
        {
          'kernel_size': 3,
          'in_channels': 6,
          'out_channels': 6,
          'size_down': True,
        },
        {
          'kernel_size': 3,
          'in_channels': 6,
          'out_channels': 6,
          'size_down': False,
        },
      ]
      linear_block_input_length = 120

    self.residuals = []
    # Build residuals
    for i in range(12):
      self.residuals.append(
        ResidualBlock(
          size_down=residual_configs[i]['size_down'],
          in_channels=residual_configs[i]['in_channels'],
          out_channels=residual_configs[i]['out_channels'],
          kernel_size=residual_configs[i]['kernel_size'],
        )
      )

    self.fl_ln = FlattenAndLinearBlock(linear_block_input_length, output_size)


  def forward(self, x):
    x_hat = self.encoder(x).unsqueeze(0)

    for i in range(len(self.residuals)):
      x_hat = self.residuals[i](x_hat)

    return self.fl_ln(x_hat)


  def get_output_size(self):
    return self.output_size


class IntraoperativeHypotensionModel(nn.Module):
  def __init__(
    self,
    ecg_resnet,
    abp_resnet,
    eeg_resnet
  ):
    super(IntraoperativeHypotensionModel, self).__init__()

    self.ecg = ecg_resnet
    self.abp = abp_resnet
    self.eeg = eeg_resnet

    self.fc_input_length = 0

    if self.ecg is not None:
      self.fc_input_length += self.ecg.get_output_size()

    if self.abp is not None:
      self.fc_input_length += self.abp.get_output_size()

    if self.eeg is not None:
      self.fc_input_length += self.eeg.get_output_size()

    if self.fc_input_length == 0:
      raise 'No resnet blocks provided, unable to build model'

    self.fc1 = nn.Linear(self.fc_input_length, 16)
    self.fc2 = nn.Linear(16, 1)
    self.act = nn.Sigmoid()

    self.seq = nn.Sequential(
      self.fc1,
      self.fc2,
      self.act
    )


  def forward(self, abp, ecg, eeg):
    ecg_o = torch.Tensor([])
    abp_o = torch.Tensor([])
    eeg_o = torch.Tensor([])

    if self.ecg is not None:
      ecg_o = self.ecg(ecg)

    if self.abp is not None:
      abp_o = self.abp(abp)

    if self.eeg is not None:
      eeg_o = self.eeg(eeg)

    fc_in = torch.flatten(torch.concat([ecg_o, abp_o, eeg_o]))

    prediction = self.seq(fc_in)

    return prediction


rb = ResidualBlock(2, 2, True, 15, 1)
flb = FlattenAndLinearBlock(496, 32)

In [7]:
# Put together model as the paper describes (all resnets)
abp_resnet = WaveformResNet(
    input_shape=30000,
    output_size=32,
    data_type='abp'
)

ecg_resnet = WaveformResNet(
    input_shape=30000,
    output_size=32,
    data_type='ecg'
)

eeg_resnet = WaveformResNet(
    input_shape=7680,
    output_size=32,
    data_type='eeg'
)

model = IntraoperativeHypotensionModel(
    abp_resnet=abp_resnet,
    ecg_resnet=ecg_resnet,
    eeg_resnet=eeg_resnet
)

In [8]:
# TEST CELL - REMOVE

from src.train import train

ecg = torch.randn((10, 30000))
abp = torch.randn((10, 30000))
eeg = torch.randn((10, 7680))

class TestData():
  def __init__(self, abp, ecg, eeg, y):
    self.abp = abp
    self.ecg = ecg
    self.eeg = eeg
    self.y = y

batch = []
batch.append(TestData(torch.randn((1,30000)), torch.randn((1,30000)), torch.randn((1,7680)), 0))
batch.append(TestData(torch.randn((1,30000)), torch.randn((1,30000)), torch.randn((1,7680)), 1))
batch.append(TestData(torch.randn((1,30000)), torch.randn((1,30000)), torch.randn((1,7680)), 0))
batch.append(TestData(torch.randn((1,30000)), torch.randn((1,30000)), torch.randn((1,7680)), 1))
batch.append(TestData(torch.randn((1,30000)), torch.randn((1,30000)), torch.randn((1,7680)), 0))
batch.append(TestData(torch.randn((1,30000)), torch.randn((1,30000)), torch.randn((1,7680)), 1))
batch.append(TestData(torch.randn((1,30000)), torch.randn((1,30000)), torch.randn((1,7680)), 0))
batch.append(TestData(torch.randn((1,30000)), torch.randn((1,30000)), torch.randn((1,7680)), 1))
batch.append(TestData(torch.randn((1,30000)), torch.randn((1,30000)), torch.randn((1,7680)), 0))
batch.append(TestData(torch.randn((1,30000)), torch.randn((1,30000)), torch.randn((1,7680)), 1))

eval_batch = []
eval_batch.append(TestData(torch.randn((1,30000)), torch.randn((1,30000)), torch.randn((1,7680)), 0))
eval_batch.append(TestData(torch.randn((1,30000)), torch.randn((1,30000)), torch.randn((1,7680)), 1))
eval_batch.append(TestData(torch.randn((1,30000)), torch.randn((1,30000)), torch.randn((1,7680)), 0))
eval_batch.append(TestData(torch.randn((1,30000)), torch.randn((1,30000)), torch.randn((1,7680)), 1))
eval_batch.append(TestData(torch.randn((1,30000)), torch.randn((1,30000)), torch.randn((1,7680)), 0))
eval_batch.append(TestData(torch.randn((1,30000)), torch.randn((1,30000)), torch.randn((1,7680)), 1))
eval_batch.append(TestData(torch.randn((1,30000)), torch.randn((1,30000)), torch.randn((1,7680)), 0))
eval_batch.append(TestData(torch.randn((1,30000)), torch.randn((1,30000)), torch.randn((1,7680)), 1))
eval_batch.append(TestData(torch.randn((1,30000)), torch.randn((1,30000)), torch.randn((1,7680)), 0))
eval_batch.append(TestData(torch.randn((1,30000)), torch.randn((1,30000)), torch.randn((1,7680)), 1))

train(model, batch, eval_batch)

     Epoch #1


ValueError: Expected more than 1 value per channel when training, got input size torch.Size([1, 30000])