<a href="https://colab.research.google.com/github/invictus125/cs598-final-project/blob/main/intraoperative_hypotension_TA_demo.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# A Reproduction of:
## Predicting intraoperative hypotension using deep learning with waveforms of arterial blood pressure, electroencephalogram, and electrocardiogram

Original paper by: Yong-Yeon Jo, Jong-Hwan Jang, Joon-myoung Kwon, Hyung-Chul Lee, Chul-Woo Jung, Seonjeong Byun, Han‐Gil Jeong

Reproduction project authored by
* Mark Bauer
  * mbauer553
  * markab5@illinois.edu
* Ryan David
  * victheone
  * invictus125
  * radavid2@illinois.edu

This project can be found on github https://github.com/invictus125/cs598-final-project.  

> Note that this project uses <b>VitalDB, an open biosignal dataset.  All users must agree to the Data Use Agreement below.</b>  If after reviewing the agreement you do not comply, please do not read on and close this window.
[Data Use Agreement](https://vitaldb.net/dataset/?query=overview&documentId=13qqajnNZzkN7NZ9aXnaQ-47NWy7kx-a6gbrcEsi-gak&sectionId=h.vcpgs1yemdb5)

## Introduction
Our project is to perform an approximate reproduction of a paper, which can be found [here](https://journals.plos.org/plosone/article?id=10.1371/journal.pone.0272055), on predicting hypotension during surgery from a combination of signals such as mean arterial blood pressure (ABP), electrocardiogram (ECG), and electroencephalogram (EEG) as opposed to ABP alone.  Predicting hypotension is important because it is correlated with many post operation complications, and is actionable.  Please read the original work if you are interested in more detail!

## Scope of reproducibility

As of the draft on 2024-04-14, we are using what we understand to be a very close replication of the model, with a smaller openly available data set.  Our model currently executes utilizing ABP data from the dataset for predictors and labels, but the EEG and ECG data is randomly created mock data.  The original work also examined more look ahead times.  

In [9]:
# install dependencies
!pip install torcheval
!pip install vitaldb



## Methodology - Data

Methodology - Data
Case: a surgery/operation

Track: data observed during a case, consisting of a device and type

As of the draft on 2024-04-14 we are getting only ABP data and labels looking ahead one minute.  

In [23]:
from functools import partial
from io import StringIO
from os import listdir, getcwd
from torch import FloatTensor, BoolTensor

import csv
import numpy as np
import requests
import vitaldb

SEGEMENT_LENGTH_SECONDS = 60

ABP_TRACK = 'SNUADC/ART'
ECG_TRACK = 'SNUADC/ECG_II'
EEG_TRACK = 'BIS/EEG1_WAV'

RELEVANT_TRACKS = [
    ABP_TRACK,
    ECG_TRACK,
    EEG_TRACK,
]

def _url_to_reader(url_string):
    response = requests.get(url_string)
    file = StringIO(response.text)
    return csv.DictReader(file, delimiter=',')

def get_unique_vals(key, iterable):
    return set(map(lambda item: item[key], iterable))

def case_filter(case):
    return float(case['age']) >= 18.0 and case['ane_type'] == 'General'

def _case_track_filter(case_id, case_dict):
    track_list = case_dict[case_id]['tracks']
    return (
        ABP_TRACK in track_list and
        ECG_TRACK in track_list and
        EEG_TRACK in track_list
    )

def _get_candidate_cases():
    cases_by_id = {}
    for case in _url_to_reader('https://api.vitaldb.net/cases'):
        if case_filter(case):
            case['tracks'] = {}
            cases_by_id[case['\ufeffcaseid']] = case

    track_list_reader = _url_to_reader('https://api.vitaldb.net/trks')

    for track in track_list_reader:
        case_id = track['caseid']
        if track['tname'] in RELEVANT_TRACKS:
            if cases_by_id.get(case_id):
                cases_by_id[case_id]['tracks'][track['tname']] = track['tid']

    case_track_filter = partial(_case_track_filter, case_dict=cases_by_id)

    return [case_id for case_id in filter(case_track_filter, cases_by_id.keys())]

def _get_candidate_cases_from_dir(dir_path):
    return [f.split('.vital')[0] for f in listdir(dir_path) if '.vital' in f]

def _download_vital_file(case_id):
    vf = vitaldb.VitalFile(int(case_id), RELEVANT_TRACKS)
    vf.to_vital(case_id+'.vital')

def _get_tracks_from_vital_file(path, tracks, sample_rate):
    vf = vitaldb.read_vital(path, tracks)
    return vf.to_numpy(tracks, sample_rate)

def validate_abp_segment(segment):

    return (
        not np.isnan(segment).any() and
        not (segment > 200).any() and
        not (segment < 30).any() and
        not ((np.max(segment) - np.min(segment)) < 30) and
        not (np.abs(np.diff(segment)) > 30).any() # abrupt changes are assumed to be noise
    )

def download_data(num_requested_cases):
    num_downloaded_cases = 0
    candidate_case_ids = _get_candidate_cases()

    np.random.shuffle(candidate_case_ids)
    for case_id in candidate_case_ids:
        print('Downloading case:', case_id)
        _download_vital_file(case_id)
        num_downloaded_cases = num_downloaded_cases + 1
        at_requested = num_downloaded_cases == num_requested_cases
        if at_requested:
            break

    if not at_requested:
        print('Requsted cases not reached but all available cases exhausted.  ')

def get_data(
    minutes_ahead,
    abp_and_ecg_sample_rate_per_second=500,
    eeg_sample_rate_per_second=128,
    max_num_samples=None,
    max_num_cases=None,
    from_dir=None,
):
    if from_dir is None:
        candidate_case_ids = _get_candidate_cases()
    else:
        candidate_case_ids = _get_candidate_cases_from_dir(from_dir)

    abps = []
    ecgs = []
    eegs = []
    hypotension_event_bools = []

    abp_data_in_two_seconds = 2 * abp_and_ecg_sample_rate_per_second

    at_max = False

    case_count = 0
    np.random.shuffle(candidate_case_ids)
    for case_id in candidate_case_ids:
        case_num_samples = 0
        case_num_events = 0

        print('Getting track data for case:', case_id)
        if from_dir is None:
            case_tracks = vitaldb.load_case(int(case_id), RELEVANT_TRACKS[0:2], 1/abp_and_ecg_sample_rate_per_second)
        else:
            case_tracks = _get_tracks_from_vital_file(f"{from_dir}/{case_id}.vital", RELEVANT_TRACKS[0:2], 1/abp_and_ecg_sample_rate_per_second)

        abp_track = case_tracks[:,0]
        # ecg_track = case_tracks[:,1]

        # eeg_track = vitaldb.load_case(int(case_id), RELEVANT_TRACKS[2], 1/eeg_sample_rate_per_second).flatten()

        for i in range(
            0,
            len(abp_track) - abp_and_ecg_sample_rate_per_second * (SEGEMENT_LENGTH_SECONDS + (1 + minutes_ahead) * SEGEMENT_LENGTH_SECONDS),
            10 * abp_and_ecg_sample_rate_per_second
        ):
            x_segment = abp_track[i:i + abp_and_ecg_sample_rate_per_second * SEGEMENT_LENGTH_SECONDS]
            y_segment_start = i + abp_and_ecg_sample_rate_per_second * (SEGEMENT_LENGTH_SECONDS + minutes_ahead * SEGEMENT_LENGTH_SECONDS)
            y_segement_end = i + abp_and_ecg_sample_rate_per_second * (SEGEMENT_LENGTH_SECONDS + (minutes_ahead + 1) * SEGEMENT_LENGTH_SECONDS)
            y_segment = abp_track[y_segment_start:y_segement_end]

            if validate_abp_segment(x_segment) and validate_abp_segment(y_segment):
                abps.append(x_segment)

                # 2 second moving average
                y_numerator = np.nancumsum(y_segment, dtype=np.float32)
                y_numerator[abp_data_in_two_seconds:] = y_numerator[abp_data_in_two_seconds:] - y_numerator[:-abp_data_in_two_seconds]
                y_moving_avg = y_numerator[abp_data_in_two_seconds - 1:] / abp_data_in_two_seconds

                is_hypotension_event = np.nanmax(y_moving_avg) < 65
                hypotension_event_bools.append(is_hypotension_event)
                case_num_samples = case_num_samples + 1
                if(is_hypotension_event):
                    case_num_events = case_num_events + 1

            at_max_samples = len(hypotension_event_bools) == max_num_samples
            if at_max_samples:
                break

        case_count = case_count + 1
        print(f"Statistics for case: {case_id}, {case_num_samples} total valid samples, {case_num_events} positive samples")

        if at_max_samples or case_count == max_num_cases:
            if at_max_samples:
                print('Max samples reached')
            else:
                print('Max cases reached')
            at_max = True
            break

    if not at_max:
        print('Max not reached but all available cases exhausted.  ')

    # Shuffle the samples
    abps = np.array(abps)
    ecgs = np.array(ecgs)
    eegs = np.array(eegs)
    hypotension_event_bools = np.array(hypotension_event_bools)
    shuffled_idx = np.array([i for i in range(0, len(hypotension_event_bools))]).astype(int)
    np.random.shuffle(shuffled_idx)
    if len(abps) > 0:
      abps = abps[shuffled_idx]
    if len(ecgs) > 0:
      ecgs = ecgs[shuffled_idx]
    if len(eegs) > 0:
      eegs = eegs[shuffled_idx]
    hypotension_event_bools = hypotension_event_bools[shuffled_idx]

    return (
        FloatTensor(abps).unsqueeze(1),
        FloatTensor(ecgs).unsqueeze(1),
        FloatTensor(eegs).unsqueeze(1),
        BoolTensor(hypotension_event_bools).float(),
    )


## Methodology - Model
Our model is an exact reproduction based on the description provided in the original paper.

There is a ResNet for each of the three waveform types we handle consisting of:

- A CNN encoder layer
- 12 residual blocks, each having two convolutions and two batch normalizations. Alternating blocks will halve the length of the data using a max pooling operation. Per the paper, we also added skip connections by summing the input into the output in each residual block.
- A fully connected output layer which flattens the channels prior to passing through a NN

The model is built such that we can provide one or more ResNets and it will adapt. This is so that we can experiment with varying combinations of input data.

Once the input is run through the ResNets, their output is concatenated and passed through a fully connected layer which ends with a sigmoid activation, producing the final prediction.


In [11]:
import torch
import math
from torch import nn


class EncoderBlock(nn.Module):
  def __init__(self, dim_in, kernel_size=15, stride=1):
    super(EncoderBlock, self).__init__()
    padding = math.floor(kernel_size / 2.0)
    self.conv = nn.Conv1d(1, 1, kernel_size, stride, padding=padding)
    self.mp = nn.MaxPool1d(kernel_size, stride, padding)
    self.fc = nn.Linear(dim_in, dim_in)
    torch.nn.init.normal_(self.fc.weight, mean=0.0, std=0.01)


  def forward(self, x):
    x_hat = self.conv(x)
    x_hat = self.mp(x_hat)
    return self.fc(x_hat)


class ResidualBlock(nn.Module):
  def __init__(
    self,
    in_channels,
    out_channels,
    size_down,
    kernel_size,
    stride=1
  ):
    super(ResidualBlock, self).__init__()

    self.size_down = size_down
    self.in_channels = in_channels
    self.out_channels = out_channels

    padding = math.floor(kernel_size / 2.0)

    self.bn1 = nn.BatchNorm1d(in_channels)
    self.act1 = nn.ReLU()
    self.do = nn.Dropout()
    self.conv1 = nn.Conv1d(in_channels, in_channels, kernel_size, stride, padding)
    self.bn2 = nn.BatchNorm1d(in_channels)
    self.act2 = nn.ReLU()
    self.conv2 = nn.Conv1d(in_channels, out_channels, kernel_size, stride, padding)
    self.mp = nn.MaxPool1d(kernel_size, padding=padding, stride=2)
    self.conv_for_input = nn.Conv1d(in_channels, out_channels, kernel_size, stride, padding)


  def forward(self, x):
    x_hat = self.bn1(x)
    x_hat = self.act1(x_hat)
    x_hat = self.do(x_hat)
    x_hat = self.conv1(x_hat)
    x_hat = self.bn2(x_hat)
    x_hat = self.act2(x_hat)
    x_hat = self.conv2(x_hat)

    # Adjust dimensions of input if needed for the skip connection
    x_input = None
    if self.in_channels != self.out_channels:
      x_input = self.conv_for_input(x)
    else:
      x_input = x

    x_hat = x_hat + x_input

    if self.size_down:
      x_hat = self.mp(x_hat)

    return x_hat


class FlattenAndLinearBlock(nn.Module):
  def __init__(self, dim_in, dim_out):
    super(FlattenAndLinearBlock, self).__init__()
    self.fc = nn.Linear(dim_in, dim_out)
    torch.nn.init.normal_(self.fc.weight, mean=0.0, std=0.01)


  def forward(self, x):
    x_hat = torch.flatten(x, start_dim=1, end_dim=-1)
    x_hat = self.fc(x_hat)

    return x_hat


class WaveformResNet(nn.Module):
  def __init__(
    self,
    input_shape,
    output_size,
    data_type
  ):
    super(WaveformResNet, self).__init__()
    self.encoder = EncoderBlock(input_shape, 15, 1)
    self.res_in_dim = input_shape
    self.output_size = output_size
    self.data_type = data_type

    if data_type not in ['abp', 'ecg', 'eeg']:
      raise ValueError('Invalid data type. Must be one of [abp, ecg, eeg]')

    # Set up configurations for residual blocks
    residual_configs = []
    linear_block_input_length = -1
    if data_type in ['abp', 'ecg']:
      residual_configs = [
        {
          'kernel_size': 15,
          'in_channels': 1,
          'out_channels': 2,
          'size_down': True,
        },
        {
          'kernel_size': 15,
          'in_channels': 2,
          'out_channels': 2,
          'size_down': False,
        },
        {
          'kernel_size': 15,
          'in_channels': 2,
          'out_channels': 2,
          'size_down': True,
        },
        {
          'kernel_size': 15,
          'in_channels': 2,
          'out_channels': 2,
          'size_down': False,
        },
        {
          'kernel_size': 15,
          'in_channels': 2,
          'out_channels': 2,
          'size_down': True,
        },
        {
          'kernel_size': 15,
          'in_channels': 2,
          'out_channels': 4,
          'size_down': False,
        },
        {
          'kernel_size': 7,
          'in_channels': 4,
          'out_channels': 4,
          'size_down': True,
        },
        {
          'kernel_size': 7,
          'in_channels': 4,
          'out_channels': 4,
          'size_down': False,
        },
        {
          'kernel_size': 7,
          'in_channels': 4,
          'out_channels': 4,
          'size_down': True,
        },
        {
          'kernel_size': 7,
          'in_channels': 4,
          'out_channels': 6,
          'size_down': False,
        },
        {
          'kernel_size': 7,
          'in_channels': 6,
          'out_channels': 6,
          'size_down': True,
        },
        {
          'kernel_size': 7,
          'in_channels': 6,
          'out_channels': 6,
          'size_down': False,
        },
      ]
      linear_block_input_length = 469 * 6
    else:
      residual_configs = [
        {
          'kernel_size': 7,
          'in_channels': 1,
          'out_channels': 2,
          'size_down': True,
        },
        {
          'kernel_size': 7,
          'in_channels': 2,
          'out_channels': 2,
          'size_down': False,
        },
        {
          'kernel_size': 7,
          'in_channels': 2,
          'out_channels': 2,
          'size_down': True,
        },
        {
          'kernel_size': 7,
          'in_channels': 2,
          'out_channels': 2,
          'size_down': False,
        },
        {
          'kernel_size': 7,
          'in_channels': 2,
          'out_channels': 2,
          'size_down': True,
        },
        {
          'kernel_size': 7,
          'in_channels': 2,
          'out_channels': 4,
          'size_down': False,
        },
        {
          'kernel_size': 3,
          'in_channels': 4,
          'out_channels': 4,
          'size_down': True,
        },
        {
          'kernel_size': 3,
          'in_channels': 4,
          'out_channels': 4,
          'size_down': False,
        },
        {
          'kernel_size': 3,
          'in_channels': 4,
          'out_channels': 4,
          'size_down': True,
        },
        {
          'kernel_size': 3,
          'in_channels': 4,
          'out_channels': 6,
          'size_down': False,
        },
        {
          'kernel_size': 3,
          'in_channels': 6,
          'out_channels': 6,
          'size_down': True,
        },
        {
          'kernel_size': 3,
          'in_channels': 6,
          'out_channels': 6,
          'size_down': False,
        },
      ]
      linear_block_input_length = 120 * 6

    self.residuals = []
    # Build residuals
    for i in range(12):
      self.residuals.append(
        ResidualBlock(
          size_down=residual_configs[i]['size_down'],
          in_channels=residual_configs[i]['in_channels'],
          out_channels=residual_configs[i]['out_channels'],
          kernel_size=residual_configs[i]['kernel_size'],
        )
      )

    self.fl_ln = FlattenAndLinearBlock(linear_block_input_length, output_size)


  def forward(self, x):
    # TODO: get encoder layer working properly and uncomment
    # x_hat = self.encoder(x)
    x_hat = x

    for i in range(len(self.residuals)):
      x_hat = self.residuals[i](x_hat)

    out = self.fl_ln(x_hat)

    return out


  def get_output_size(self):
    return self.output_size


class IntraoperativeHypotensionModel(nn.Module):
  def __init__(
    self,
    ecg_resnet=None,
    abp_resnet=None,
    eeg_resnet=None
  ):
    super(IntraoperativeHypotensionModel, self).__init__()

    self.ecg = ecg_resnet
    self.abp = abp_resnet
    self.eeg = eeg_resnet

    self.fc_input_length = 0

    if self.ecg is not None:
      self.fc_input_length += self.ecg.get_output_size()

    if self.abp is not None:
      self.fc_input_length += self.abp.get_output_size()

    if self.eeg is not None:
      self.fc_input_length += self.eeg.get_output_size()

    if self.fc_input_length == 0:
      raise 'No resnet blocks provided, unable to build model'

    self.fc1 = nn.Linear(self.fc_input_length, 16)
    self.fc2 = nn.Linear(16, 1)
    self.act = nn.Sigmoid()
    torch.nn.init.normal_(self.fc1.weight, mean=0.0, std=0.01)
    torch.nn.init.normal_(self.fc2.weight, mean=0.0, std=0.01)


  def forward(self, abp, ecg, eeg):
    ecg_o = torch.Tensor([])
    abp_o = torch.Tensor([])
    eeg_o = torch.Tensor([])

    if self.ecg is not None:
      ecg_o = self.ecg(ecg)

    if self.abp is not None:
      abp_o = self.abp(abp)

    if self.eeg is not None:
      eeg_o = self.eeg(eeg)

    resnet_output = torch.concat([ecg_o, abp_o, eeg_o], dim=1)

    intermediate = self.fc1(resnet_output)
    intermediate = self.fc2(intermediate)
    prediction = self.act(intermediate)

    return prediction


## Methodology - Training
Computational requirements:
- At least 50 GB of RAM
- GPU instance (we have been experimenting with an A100)

The training of this model is fairly straightforward. The paper suggested that we should use Adam as the optimizer and BCE as the loss function, so that is what we have done.

Each training epoch will also automatically run evaluation on both the train set and the validation set. See the evaluation methodology for details.

In [12]:
import torch
from torch.optim import Adam
from torch.nn import BCELoss
import numpy as np


def _extract_batch(data, batch_size, batch_number):
  start = batch_size * batch_number
  end = start + batch_size

  if start >= len(data[3]):
    return None

  return [
      data[0][start:end] if len(data[0]) > 0 else None,
      data[1][start:end] if len(data[1]) > 0 else None,
      data[2][start:end] if len(data[2]) > 0 else None,
      data[3][start:end]
  ]


def _train_one_epoch(
  model,
  train_data,
  optimizer,
  criterion,
  batch_size=32
):
  model.train()
  loss_history = []

  batch_num = 0
  batch = _extract_batch(train_data, batch_size, batch_num)
  while batch is not None:
    optimizer.zero_grad()
    abp = batch[0]
    ecg = batch[1]
    eeg = batch[2]
    y = batch[3]
    y_hat = model(abp, ecg, eeg)
    y_hat = torch.squeeze(y_hat, dim=-1)
    loss = criterion(y_hat, y)
    loss.backward()
    optimizer.step()
    batch_loss = loss.item()
    print(f'\tBatch {batch_num} loss: {batch_loss}')
    loss_history.append(batch_loss)

    batch_num += 1
    batch = _extract_batch(train_data, batch_size, batch_num)

  return loss_history


def train(
  model,
  train_data_handle,
  test_data_handle,
  learning_rate=0.0001,
  epochs=100,
  suspend_train_epochs_threshold=5,
  batch_size=32
):
  """Trains an IntraoperativeHypotensionModel using the given learning rate for
  the given number of epochs

  model: the IntraoperativeHypotensionModel to train
  train_data_handle: the dataset we will train on
  test_data_handle: the dataset we will use for evaluation
  learning_rate: the learning rate to use with the Adam optimizer
  epochs: the number of epochs to train for
  suspend_train_epochs_threshold: training will be suspended if the loss does
    not improve for this number of epochs
  """
  if model is None or train_data_handle is None or test_data_handle is None:
    raise ValueError(
      'model, train_data_handle, and test_data_handle are required for training'
    )

  criterion = BCELoss()
  optimizer = Adam(model.parameters(), lr=learning_rate)

  overall_loss_history = []
  consecutive_epochs_without_improvement = 0
  for epoch in range(epochs):
    print('====================================')
    print(f'     Epoch #{epoch + 1}')
    print('====================================')

    loss_history = _train_one_epoch(
      model,
      train_data_handle,
      optimizer,
      criterion,
      batch_size
    )
    eval_model(model, train_data_handle, 'Train', batch_size)
    # Not using performance metrics yet in this function.
    # Potential TODO: stop training once desired performance is reached (TBD)
    performance = eval_model(model, test_data_handle, 'Test', batch_size)

    if epoch > 0:
      mean_loss = np.mean(loss_history)
      overall_loss_history.append(mean_loss)
      loss_change = overall_loss_history[epoch - 1] - mean_loss
      if loss_change < 0.1:
        consecutive_epochs_without_improvement += 1
      else:
        consecutive_epochs_without_improvement = 0

    if consecutive_epochs_without_improvement >= suspend_train_epochs_threshold:
      print(f'Training stopping after {epoch+1} epochs.')
      print(f'Loss did not change for {suspend_train_epochs_threshold} epochs')
      break

##  Methodology - Evaluation
The original paper uses four metrics:
- AUROC
- AUPRC
- Sensitivity
- Specificity

We chose to use the torcheval library for our metrics, except for binary specificity which did not appear to be present in torcheval.

In light of that, we have implemented our own binary specificity. Regrettably, this does not yet work properly, but we will do our best to get it functional before our final submission.



In [26]:
import torch
from torcheval.metrics import BinaryAUROC, BinaryAUPRC, BinaryRecall


def _binary_specificity(test, target):
  # TN / TN + FP
  pinned = torch.where(test >= 0.5, 1.0, 0.0)
  pos = torch.where(pinned > 0, 1.0, 0.0)
  neg = torch.where(pinned < 1, 1.0, 0.0)
  gt_pos = torch.where(target > 0, 1.0, 0.0)
  gt_neg = torch.where(target < 1, 1.0, 0.0)
  tn = neg + gt_neg
  tn = torch.sum(torch.where(tn > 1, 1.0, 0.0), dtype=torch.float)
  fp = pos + gt_neg
  fp = torch.sum(torch.where(fp > 1, 1.0, 0.0), dtype=torch.float)

  return (tn / (tn + fp))


def eval_model(
  model,
  eval_data,
  dataset_name,
  batch_size=32
):
  model.eval()

  auroc = []
  auprc = []
  sensitivity = []
  specificity = []

  f_auroc = BinaryAUROC()
  f_auprc = BinaryAUPRC()
  f_sensitivity = BinaryRecall()

  batch_num = 0
  batch = _extract_batch(eval_data, batch_size, batch_num)
  while batch is not None:
    abp = batch[0]
    ecg = batch[1]
    eeg = batch[2]
    y = batch[3]

    y_hat = model(abp, ecg, eeg)
    y_hat = y_hat.squeeze(-1)

    y_hat_long = torch.where(y_hat >= 0.5, 1.0, 0.0).long()
    target_long = y.long()

    print(f'y_hat_long sum: {y_hat_long.sum()}, target_long sum: {target_long.sum()}')

    f_auroc.update(y_hat, y)
    f_auprc.update(y_hat, y)
    f_sensitivity.update(y_hat_long, target_long)

    auroc.append(f_auroc.compute())
    auprc.append(f_auprc.compute())
    sensitivity.append(f_sensitivity.compute())
    specificity.append(_binary_specificity(y_hat, y))

    batch_num += 1
    batch = _extract_batch(eval_data, batch_size, batch_num)

  print(sensitivity)
  m_auroc = np.mean(auroc)
  m_auprc = np.mean(auprc)
  m_sensitivity = np.mean(sensitivity)
  m_specificity = np.mean(specificity)

  print(f'    {dataset_name} data metrics:')
  print(f'        AUROC: {m_auroc}')
  print(f'        AUPRC: {m_auprc}')
  print(f'        Sensitivity: {m_sensitivity}')
  print(f'        Specificity: {m_specificity}')

  return m_auroc, m_auprc, m_sensitivity, m_specificity


### Results
Thus far we have the ability to build our model and obtain one type of test data in the form we need it for training and evaluation. Please execute the code cells below to see for yourself!

Our plans from here until the end of the project are:

- Make it possible to obtain the other two data types (ECG and EEG) and use them the same way we are able to use ABP
- Get our custom specificity function working properly
- Train and evaluate our model using a variety of samples
- Train and evaluate our model using varying combinations of the input types


# Train with only ABP

This next section will demonstrate training on only ABP data. It also demonstrates the scalability of our model, allowing us to only use certain resnets if we so choose.

In [27]:
# Put together a model using only ABP
abp_resnet = WaveformResNet(
    input_shape=30000,
    output_size=32,
    data_type='abp',
)

abp_model = IntraoperativeHypotensionModel(
    abp_resnet=abp_resnet
)

In [24]:
# Obtain ABP data to train on
test_set = get_data(3, max_num_samples=1000)

# Get 819
_download_vital_file('819')

train_set = get_data(3, from_dir='.')

Getting track data for case: 1750
Statistics for case: 1750, 0 total valid samples, 0 positive samples
Getting track data for case: 3704
Statistics for case: 3704, 1000 total valid samples, 94 positive samples
Max samples reached
Getting track data for case: 819
Statistics for case: 819, 737 total valid samples, 73 positive samples
Max not reached but all available cases exhausted.  


In [28]:
# Train the model on ABP data

# train_set = [
#     all_data[0][0:4000].unsqueeze(1),
#     all_data[1][0:4000].unsqueeze(1),
#     all_data[2][0:4000].unsqueeze(1),
#     all_data[3][0:4000]
# ]

# test_set = [
#     all_data[0][4000:].unsqueeze(1),
#     all_data[1][4000:].unsqueeze(1),
#     all_data[2][4000:].unsqueeze(1),
#     all_data[3][4000:]
# ]

# train(abp_model, train_set, test_set, batch_size=40, epochs=100, learning_rate=0.0001)

# TRAIN ON ONLY CASE 819

train(abp_model, train_set, test_set, batch_size=40)

# TEST WITH RANDOM DATA
# sample_size = 400
# train_set_r = [
#     torch.randn([sample_size, 1, 30000]),
#     torch.randn([sample_size, 1, 30000]),
#     torch.randn([sample_size, 1, 30000]),
#     torch.where(torch.rand([sample_size]) > 0.5, 1.0, 0.0),
# ]

# train(abp_model, train_set_r, train_set_r, batch_size=40, epochs=3)


     Epoch #1
	Batch 0 loss: 0.7868557572364807
	Batch 1 loss: 0.7844471335411072
	Batch 2 loss: 0.7544103860855103
	Batch 3 loss: 0.7547532320022583
	Batch 4 loss: 0.7390578985214233
	Batch 5 loss: 0.7296877503395081
	Batch 6 loss: 0.7051376700401306
	Batch 7 loss: 0.6989387273788452
	Batch 8 loss: 0.6848915219306946
	Batch 9 loss: 0.665812075138092
	Batch 10 loss: 0.6529740691184998
	Batch 11 loss: 0.6330965161323547
	Batch 12 loss: 0.615434467792511
	Batch 13 loss: 0.600544810295105
	Batch 14 loss: 0.5802755355834961
	Batch 15 loss: 0.5752568244934082
	Batch 16 loss: 0.5482537150382996
	Batch 17 loss: 0.5135014057159424
	Batch 18 loss: 0.5068650841712952
y_hat_long sum: 0, target_long sum: 4
y_hat_long sum: 0, target_long sum: 2
y_hat_long sum: 0, target_long sum: 6
y_hat_long sum: 0, target_long sum: 3
y_hat_long sum: 0, target_long sum: 4
y_hat_long sum: 0, target_long sum: 2
y_hat_long sum: 0, target_long sum: 9
y_hat_long sum: 0, target_long sum: 2
y_hat_long sum: 0, target_long

In [None]:
# DEBUG - check input data
# print(train_set[0].size())
for x in range(train_set[0].size()[0]):
  cur = train_set[0][x]
  nanbread = cur != cur
  sumval = nanbread.sum()
  if sumval > 0:
    print(f'sample {x} has {sumval} nan vals')
nanbread = train_set[0] != train_set[0]
print(nanbread.sum())

tensor(0)


In [None]:
# Put together model as the paper describes (with all resnets)
abp_resnet = WaveformResNet(
    input_shape=30000,
    output_size=32,
    data_type='abp'
)

ecg_resnet = WaveformResNet(
    input_shape=30000,
    output_size=32,
    data_type='ecg'
)

eeg_resnet = WaveformResNet(
    input_shape=7680,
    output_size=32,
    data_type='eeg'
)

model = IntraoperativeHypotensionModel(
    abp_resnet=abp_resnet,
    ecg_resnet=ecg_resnet,
    eeg_resnet=eeg_resnet
)

## References
Jo YY, Jang JH, Kwon Jm, Lee HC, Jung CW, et al. (2022) Predicting intraoperative hypotension using deep learning with waveforms of arterial blood pressure, electroencephalogram, and electrocardiogram: Retrospective study. PLOS ONE 17(8): e0272055. https://doi.org/10.1371/journal.pone.0272055

## Acknowledgements
* As mentioned in the introduction, this project leveraged the open [vitaldb dataset](https://vitaldb.net/dataset/), and without it would have been impossible in its current form.
* Significant inspiration was drawn from [vital db examples](https://github.com/vitaldb/examples)