In [1]:
import os
from pathlib import Path

import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
import torcheval.metrics as metrics
from jaxtyping import Float
from sklearn.utils.class_weight import compute_class_weight
from torch import Tensor
from torch.utils.data import DataLoader, random_split
from torch.utils.tensorboard import SummaryWriter
from tqdm.notebook import tqdm

from chbmit.chbmit import FilteredCMP

In [2]:
# Run on the GPU when CUDA is present, otherwise fall back to the CPU.
if torch.cuda.is_available():
  device = torch.device('cuda')
else:
  device = torch.device('cpu')

In [3]:
class MultiHead(nn.Module):
  """AlexNet-style shared trunk with three heads.

  Outputs of ``forward``:
    0: classification_head -> 2 logits.
    1: regression_head     -> 1 value (unpacked as the predicted mean by the trainer).
    2: variance_head       -> 1 value (unpacked as the predicted std by the trainer).

  The ``Linear(12544, ...)`` size assumes input of shape (batch, 23, 256, 256):
  256 -conv11/s4-> 62 -pool-> 31 -pool-> 15 -pool-> 7, and 256 * 7 * 7 = 12544.
  """

  def __init__(self, ):
    super().__init__()
    # A ReLU follows every conv and linear layer. Without them the stacked 3x3
    # convs (and the trunk Linear feeding each head's first Linear) would
    # collapse into single linear maps, leaving max-pooling as the only nonlinearity.
    self.shared_layers = nn.Sequential(
      nn.Conv2d(23, 96, (11, 11), 4),
      nn.ReLU(),
      nn.MaxPool2d(2),
      nn.Conv2d(96, 384, (5, 5), padding='same'),
      nn.ReLU(),
      nn.MaxPool2d(2),
      nn.Conv2d(384, 384, (3, 3), padding='same'),
      nn.ReLU(),
      nn.Conv2d(384, 384, (3, 3), padding='same'),
      nn.ReLU(),
      nn.Conv2d(384, 256, (3, 3), padding='same'),
      nn.ReLU(),
      nn.MaxPool2d(2),
      nn.Flatten(),
      nn.Linear(12544, 4096),
      nn.ReLU(),
      nn.Linear(4096, 4096),
      nn.ReLU(),
    )
    self.classification_head = self._make_head(2)
    self.regression_head = self._make_head(1)
    self.variance_head = self._make_head(1)

  @staticmethod
  def _make_head(out_features):
    """Build a head: two hidden ReLU layers of width 1000, then a linear output layer."""
    return nn.Sequential(
      nn.Linear(4096, 1000),
      nn.ReLU(),
      nn.Linear(1000, 1000),
      nn.ReLU(),
      nn.Linear(1000, out_features),
    )

  def forward(self, data):
    """Run the shared trunk once and feed its features to all three heads.

    Args:
      data: image batch, expected shape (batch, 23, 256, 256).

    Returns:
      Tuple ``(class_logits, regression_out, variance_out)`` of shapes
      (batch, 2), (batch, 1), (batch, 1).
    """
    features = self.shared_layers(data)
    return (
      self.classification_head(features),
      self.regression_head(features),
      # Detached: gradients flowing through the variance head stop at the trunk.
      self.variance_head(features.detach()),
    )


In [17]:
class MHTrainer():
  """Train and evaluate MultiHead on CHB-MIT windows for classification and regression.

  One label tensor drives both tasks. Entries equal to ``inf`` become class 0 and
  are excluded from the regression loss. Every other entry becomes class 1 and is
  a regression target. NOTE(review): presumably ``inf`` marks windows with no
  upcoming seizure (FilteredCMP(..., regression=True)) — confirm against chbmit.
  """

  # Channels expected on axis 1 of every batch; must match MultiHead's first Conv2d.
  N_CHANNELS = 23
  # Spectrograms are resized to this size; MultiHead's Linear(12544, ...) assumes 256x256.
  IMAGE_SIZE = (256, 256)

  def __init__(self, data_path: str, batch_size: int, learing_rate: float, n_fft: int, seed=None):
    """Build the model, data split, loaders, optimiser, weighted loss and TensorBoard writer.

    Args:
      data_path: directory handed to FilteredCMP.
      batch_size: batch size for the training and validation loaders.
      learing_rate: AdamW learning rate (misspelt name kept so keyword callers keep working).
      n_fft: FFT size used by ``process`` for the STFT.
      seed: optional seed for the 80/20 train/validation split; None keeps it random.
    """
    self.n_fft = n_fft
    self.global_step = 0
    self.init_model()
    self.dataset = FilteredCMP(60, 10, data_path, sop=3600, sph=0, regression=True)
    split_kwargs = {} if seed is None else {'generator': torch.Generator().manual_seed(seed)}
    self.train_data, self.val_data = random_split(self.dataset, [.8, .2], **split_kwargs)
    # Reshuffle the training split every epoch; validation order does not matter.
    self.DataLoader = DataLoader(self.train_data, batch_size=batch_size, shuffle=True)
    self.ValLoader = DataLoader(self.val_data, batch_size=batch_size)
    self.optim = torch.optim.AdamW(self.model.parameters(), lr=learing_rate)
    # Creates self.classification_critierion with balanced class weights. The
    # criterion must not be re-created afterwards, or the weights are lost.
    self.init_class_weights()
    self.writer = SummaryWriter()

  def init_model(self,):
    """Create a fresh MultiHead on the module-level ``device``."""
    self.model = MultiHead().to(device)

  @staticmethod
  def _to_class_labels(label):
    """Return a new long tensor: 0 where ``label == inf``, 1 elsewhere (input is not mutated)."""
    return (label != float('inf')).long()

  @staticmethod
  def _target_mask(label):
    """Boolean mask of samples that carry a regression target (``label != inf``)."""
    return label != float('inf')

  @staticmethod
  def _to_device(data, label):
    """Move a (data, label) batch onto the module-level ``device``."""
    return data.to(device), label.to(device)

  @torch.no_grad()
  def init_class_weights(self,):
    """Compute balanced class weights from the training split and install the weighted loss.

    Only labels are inspected, so no forward pass is run. The weight of class c is
    ``total / (2 * count_c)``, using the same label binarisation as training.

    Returns:
      Float tensor ``[weight_class0, weight_class1]`` on ``device``.

    Raises:
      ValueError: if the training split does not contain both classes.
    """
    pos = 0
    neg = 0
    for _, label in tqdm(self.DataLoader, leave=False):
      classes = self._to_class_labels(label)
      n_pos = int(classes.sum())
      pos += n_pos
      neg += classes.numel() - n_pos
    if pos == 0 or neg == 0:
      raise ValueError(f'training split needs both classes, got pos={pos}, neg={neg}')
    total = pos + neg
    weights = torch.tensor([total / (2 * neg), total / (2 * pos)], dtype=torch.float32, device=device)
    self.classification_critierion = nn.CrossEntropyLoss(weight=weights)
    return weights

  def process(self, data):
    """Turn raw EEG windows into per-channel STFT magnitude images.

    Args:
      data: tensor reshapeable to (batch * 23, samples). Presumably (batch, 23, samples);
        confirm against FilteredCMP.

    Returns:
      Tensor of shape (batch, 23, 256, 256): Hann-windowed STFT magnitudes resized
      with F.interpolate's default (nearest) mode.
    """
    batch_size = data.size(0)
    signals = data.reshape(batch_size * self.N_CHANNELS, -1)
    # Build the window on the signal's device/dtype so GPU batches work.
    window = torch.hann_window(self.n_fft, device=signals.device, dtype=signals.dtype)
    magnitude = torch.stft(signals, n_fft=self.n_fft, window=window, return_complex=True).abs()
    magnitude = magnitude.reshape(batch_size, self.N_CHANNELS, -1, magnitude.size(-1))
    return F.interpolate(magnitude, self.IMAGE_SIZE)

  def _step(self, loss, tag):
    """Log ``loss`` under ``tag``, back-propagate, update weights and advance ``global_step``."""
    self.writer.add_scalar(tag, loss.item(), self.global_step)
    self.optim.zero_grad()
    loss.backward()
    self.optim.step()
    self.global_step += 1

  def train_epoch_classification(self):
    """Run one epoch of class-weighted cross-entropy training on the classification head."""
    self.model.train()
    for data, label in tqdm(self.DataLoader, leave=False,):
      data, label = self._to_device(data, label)
      pred = self.model(self.process(data))[0]
      loss = self.classification_critierion(pred, self._to_class_labels(label))
      self._step(loss, 'loss/train_class')

  def train_epoch_regression(self):
    """Run one epoch of heteroskedastic (Gaussian NLL) training on the mean/variance heads.

    Batches without any finite target are skipped. Their loss would be the mean of
    an empty tensor (NaN), and back-propagating it would corrupt the weights.
    """
    self.model.train()
    for data, label in tqdm(self.DataLoader, leave=False,):
      if not self._target_mask(label).any():
        continue
      data, label = self._to_device(data, label)
      mean, std = self.model(self.process(data))[1:]
      loss = self.heteroskedastic_criterion(mean, std, label)
      self._step(loss, 'loss/train_regression')

  @torch.no_grad()
  def validate_classification(self):
    """Log unweighted validation cross-entropy and the AUPRC of P(class 1)."""
    self.model.eval()
    avg_loss = metrics.Mean(device=device)
    auprc = metrics.BinaryAUPRC(device=device)
    for data, label in tqdm(self.ValLoader, leave=False,):
      data, label = self._to_device(data, label)
      pred = self.model(self.process(data))[0]
      classes = self._to_class_labels(label)
      avg_loss.update(F.cross_entropy(pred, classes, reduction='none'))
      # Rank by P(class 1); the raw class-1 logit alone ignores the class-0 logit.
      auprc.update(pred.softmax(dim=-1)[..., 1], classes)
    self.writer.add_scalar('loss/val_class', avg_loss.compute().item(), self.global_step)
    self.writer.add_scalar('metrics/auprc', auprc.compute().item(), self.global_step)

  @torch.no_grad()
  def validate_regression(self):
    """Log per-sample validation NLL and the RMSE over samples with a finite target."""
    self.model.eval()
    avg_loss = metrics.Mean(device=device)
    squared_error = metrics.Mean(device=device)
    for data, label in tqdm(self.ValLoader, leave=False,):
      data, label = self._to_device(data, label)
      mean, std = self.model(self.process(data))[1:]
      # Per-sample losses: a batch with no finite target adds nothing instead of NaN.
      avg_loss.update(self.heteroskedastic_criterion(mean, std, label, reduce=False))
      mask = self._target_mask(label)
      squared_error.update((label[mask] - mean.reshape(label.shape)[mask]) ** 2)
    self.writer.add_scalar('loss/val_regression', avg_loss.compute().item(), self.global_step)
    self.writer.add_scalar('metrics/RMSE', squared_error.compute().item() ** 0.5, self.global_step)

  def heteroskedastic_criterion(self, mean, std, label, reduce=True, eps=1e-6):
    """Gaussian negative log-likelihood over the samples whose label is not ``inf``.

    Args:
      mean: predicted means with ``label.numel()`` elements (e.g. head output (B, 1)).
      std: predicted standard deviations, same size as ``mean``; only ``std ** 2`` is used.
      label: regression targets of shape (B,); entries equal to ``inf`` are ignored.
      reduce: if True, return the scalar mean over kept samples, else the per-sample losses.
      eps: lower bound on the variance, so a zero ``std`` cannot yield inf/NaN.

    Returns:
      A scalar tensor (NaN if no target is finite) or a 1-D tensor of per-sample losses.
    """
    mask = self._target_mask(label)
    # Align (B, 1) head outputs with (B,) labels. Otherwise the subtraction
    # broadcasts to a (k, k) matrix.
    mean = mean.reshape(label.shape)[mask]
    var = std.reshape(label.shape)[mask].pow(2).clamp_min(eps)
    target = label[mask]
    loss = (target - mean) ** 2 / (2 * var) + 0.5 * torch.log(2 * torch.pi * var)
    if reduce:
      loss = loss.mean()
    return loss

  def train_classification(self, epochs: int = 10):
    """Train the classification head for ``epochs`` epochs, validating both heads after each."""
    for _ in range(epochs):
      self.train_epoch_classification()
      self.validate_classification()
      self.validate_regression()
      self.writer.flush()

  def train_regression(self, epochs: int = 10):
    """Train the regression heads for ``epochs`` epochs, validating both heads after each."""
    for _ in range(epochs):
      self.train_epoch_regression()
      self.validate_regression()
      self.validate_classification()
      self.writer.flush()

In [19]:
# Root of the CHB-MIT download. Override it with the CHBMIT_DIR environment
# variable rather than hard-coding a machine-specific absolute path.
CHBMIT_DIR = Path(os.environ.get('CHBMIT_DIR', 'physionet.org/files/chbmit/1.0.0'))
trainer = MHTrainer(str(CHBMIT_DIR / 'chb01'), batch_size=25, learing_rate=1e-4, n_fft=128)

  0%|          | 0/2663 [00:00<?, ?it/s]

  0%|          | 0/75 [00:00<?, ?it/s]

KeyboardInterrupt: 

In [25]:
# Build the dataset on the small fixture under tests/ as a quick smoke test.
TEST_DATA_DIR = 'tests/test_data'
d = FilteredCMP(
  60, 10, TEST_DATA_DIR,
  sop=3600, sph=0, regression=True,
)

  0%|          | 0/111 [00:00<?, ?it/s]

In [9]:
# MHTrainer has no `train` method; its entry points are train_classification
# and train_regression.
trainer.train_classification()

