# Imports

In [1]:
from getpass import getpass

token = getpass('Enter your GitHub personal access token: ')
name = getpass('Enter your GitHub name: ')
#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

!git clone https://{token}@github.com/{name}/comp_med.git
#%cd comp_med


Enter your GitHub personal access token: ··········
Enter your GitHub name: ··········
Cloning into 'comp_med'...
remote: Enumerating objects: 77, done.[K
remote: Counting objects: 100% (77/77), done.[K
remote: Compressing objects: 100% (64/64), done.[K
remote: Total 77 (delta 41), reused 32 (delta 12), pack-reused 0 (from 0)[K
Receiving objects: 100% (77/77), 32.31 KiB | 1.70 MiB/s, done.
Resolving deltas: 100% (41/41), done.


In [2]:
!git config --global user.email "vierling.lukas@gmailcom"
!git config --global user.name  "lukasVierling"
!git add .
!git commit -m "My latest changes from Colab"
!git push origin main


fatal: not a git repository (or any of the parent directories): .git
fatal: not a git repository (or any of the parent directories): .git
fatal: not a git repository (or any of the parent directories): .git


## Make Code Deterministic for Reproducibility

In [1]:
import os
import random
import numpy as np
import torch
os.environ["PYTHONHASHSEED"] = str(42)
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)

torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False


In [2]:
!pip install wfdb



In [3]:
from google.colab import drive
drive.mount('/content/drive')


Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [6]:
!touch /content/comp_med/__init__.py
!touch /content/comp_med/data/__init__.py
!touch /content/comp_med/models/__init__.py


In [4]:
import numpy as np
import torch
import torch.optim as optim
from tqdm import tqdm
import torch.nn as nn
from torch.utils.data import TensorDataset, DataLoader

#my imports
from comp_med.models.attentionCNN import CNN_1D
from comp_med.models.oldCNN import CNN_2D
from comp_med.data.preprocessing import get_dataloaders


## Training Helpers

In [5]:
def train(model, criterion, optimizer, train_loader, val_loader, epochs, device="cpu", return_loss=False):
  model.to(device)
  loss_tracker = []
  for epoch in range(epochs):
    pbar = tqdm(train_loader, desc=f"Train the model in epoch {epoch}...")
    loss = 0
    for x,y in pbar:
      optimizer.zero_grad()
      x,y = x.to(device),y.to(device)
      out = model(x)
      out = torch.sigmoid(out)
      loss = criterion(out,y)
      loss.backward()
      optimizer.step()
      loss = loss.item()
      pbar.set_description(f"Current loss in epoch {epoch} is {loss}")
      loss_tracker.append(loss)
    #start validation
    acc = eval(model, val_loader, device)
    print(f"Acc on val in epoch {epoch} is: {acc}")
  if return_loss:
    return loss_tracker

def eval(model, data_loader, device="cpu"):
    model.to(device)
    #start validation
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
      for x,y in data_loader:
        x,y = x.to(device), y.to(device)
        out = model(x)
        preds = (torch.sigmoid(out) > 0.5).float()
        correct += (preds == y).sum().item()
        total += y.shape[0]
    model.train()
    acc = correct / total
    return acc

# Verification Dummy Task

In [6]:
# create a dummy dataset to verify the architecture works
batch_size = 256
length = 5
samples_per_sec = 1000
num_leads = 12
num_samples = 500

t = np.arange(0, length, 1/samples_per_sec)
base_freq = 1

X = np.zeros((num_samples, num_leads, t.size))
labels = np.zeros((num_samples,))

for i in range(num_samples):
  label = i % 2
  for l in range(num_leads):
    amplitude = l+1
    X[i,l] = amplitude * np.sin( 2*np.pi * base_freq * t)
  if label == 1:
    #get a random lead
    lead_idx = np.random.randint(0, num_leads)
    #doulbe the frequency when label is 1
    X[i, lead_idx] = (lead_idx+1) * np.sin( 2*np.pi * base_freq*2 * t)

  #add some noise
  noise = np.random.normal(0, 0.1, size=X[i].shape)
  X[i] = X[i]+noise
  labels[i]= label

X_t = torch.from_numpy(X).float()
y_t = torch.from_numpy(labels).unsqueeze(1).float()

len = X_t.shape[0]
permutation = torch.randperm(len)
train_idx = permutation[:400]
test_idx =permutation[400:]

X_train, y_train = X_t[train_idx], y_t[train_idx]
X_test, y_test = X_t[test_idx] ,y_t[test_idx]
print(X_t.shape)
print(X_test.shape)
print(X_train.shape)
train_dataset = TensorDataset(X_train, y_train)
test_dataset = TensorDataset(X_test, y_test)

train_loader = DataLoader(train_dataset, batch_size = batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size = batch_size)

torch.Size([500, 12, 5000])
torch.Size([100, 12, 5000])
torch.Size([400, 12, 5000])


In [7]:
epochs = 25
lr = 1e-3
device = "cuda" if torch.cuda.is_available() else "cpu"

model = CNN_1D()
optimizer = optim.Adam(model.parameters(), lr=lr)
criterion = nn.BCELoss()
train(model, criterion, optimizer, train_loader, train_loader ,epochs, device)
print("Fianl acc on test is: ", eval(model, test_loader, device))

TypeError: LeadProcessingBlock1D.__init__() got an unexpected keyword argument 'dilation'

# Convergence Analysis of the Model

In [9]:
train_loader, val_loader, test_loader = get_dataloaders("/content/drive/MyDrive/ptbdb",preprocessed_data_path="/content/drive/MyDrive/ptbdb/preprocessed_data.pt", train_ratio=0.6, val_ratio=0.1)

After filtering, we got: 228 records. Healthy: 80, Disease: 148
Patients: train: 120 | val: 20 | test: 60
Load data from given path


In [21]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

class LeadProcessingBlock1D(nn.Module):
    """1D residual block over time, but *per‑lead* (no cross‑lead mixing)."""
    def __init__(self, num_leads, filters_per_lead, dilation):
        super().__init__()
        channels = num_leads * filters_per_lead
        # each conv is grouped by lead → no mixing between leads here
        self.conv1 = nn.Conv1d(
            channels, channels, kernel_size=3,
            padding=dilation, dilation=dilation,
            groups=num_leads,  # <— depthwise per lead
            bias=False
        )
        self.bn1   = nn.BatchNorm1d(channels)
        self.relu1 = nn.ReLU(inplace=True)

        self.conv2 = nn.Conv1d(
            channels, channels, kernel_size=3,
            padding=dilation, dilation=dilation,
            groups=num_leads,  # <— same here
            bias=False
        )
        self.bn2   = nn.BatchNorm1d(channels)
        self.relu2 = nn.ReLU(inplace=True)

    def forward(self, x):
        identity = x
        out = self.relu1(self.bn1(self.conv1(x)))
        out = self.relu2(self.bn2(self.conv2(out)))
        return out + identity

class LeadMixingBlock(nn.Module):
    def __init__(self, num_leads, filters_per_lead, num_heads):
        super().__init__()
        self.num_leads = num_leads
        self.filters_per_lead = filters_per_lead
        self.attn = nn.MultiheadAttention(embed_dim = filters_per_lead, num_heads = num_heads, batch_first=True)

    def forward(self,x):
        B,C,T = x.shape
        L = self.num_leads
        F = self.filters_per_lead
        #reshape the input to get the leads  (C=FT)
        out = x.view(B, L, F, T)

        # B*T attention probelms
        out = out.permute(0,3,1,2).reshape(B*T, L, F)

        attn_out, _ = self.attn(out, out, out)
        out = attn_out.reshape(B,T,L,F).permute(0,2,3,1)

        return out.reshape(B,C,T)

class CNN_1D_new(nn.Module):
    def __init__(self, num_leads=12, hidden_channels=64, filters_1d = 20, kernel_1d = 100, stride_1d = 50, attn_heads = 4, res_dilations = [1,2,4]):
        super().__init__()
        #encoder
        hidden_channels = num_leads * filters_1d
        self.encoder = nn.Sequential(
            nn.Conv1d(num_leads, hidden_channels, kernel_size=kernel_1d, stride=stride_1d, groups=num_leads,bias=False),
            nn.ReLU(inplace=True)
        )
        self.layers = nn.ModuleList()
        self.LeadBlocks = nn.Sequential(
            LeadProcessingBlock1D(num_leads,filters_1d, dilation=res_dilations[0]),
            LeadMixingBlock(num_leads, filters_1d, attn_heads),
            LeadProcessingBlock1D(num_leads,filters_1d, dilation=res_dilations[1]),
            LeadMixingBlock(num_leads, filters_1d, attn_heads),
            LeadProcessingBlock1D(num_leads,filters_1d, dilation=res_dilations[2]),
            LeadMixingBlock(num_leads, filters_1d, attn_heads),
        )
        self.last_conv = nn.Sequential(nn.Conv1d(hidden_channels, hidden_channels, kernel_size=3, padding=16, dilation=16, bias=False),
                                       nn.BatchNorm1d(hidden_channels),
                                       nn.ReLU(inplace=True))
        self.pooling = nn.AdaptiveAvgPool1d((1))
        self.fc = nn.Linear(hidden_channels, 1)
    def forward(self, x):
        B = x.size(0)
        out = self.encoder(x)
        out = self.LeadBlocks(out)
        out = self.last_conv(out)
        out = self.pooling(out)
        #flatten
        out = out.view(B,-1)
        out = self.fc(out)
        return out



In [22]:
epochs = 20
lr = 1e-3
device = "cuda" if torch.cuda.is_available() else "cpu"

In [23]:
model = CNN_1D_new()
optimizer = optim.Adam(model.parameters(), lr=lr)
criterion = nn.BCELoss()
losses = train(model, criterion, optimizer, train_loader, val_loader ,epochs, device, return_loss=True)
print("Fianl acc on test is: ", eval(model, test_loader, device))

Current loss in epoch 0 is 0.281154990196228: 100%|██████████| 23/23 [00:17<00:00,  1.33it/s]


Acc on val in epoch 0 is: 0.6530782029950083


Current loss in epoch 1 is 0.05130443722009659: 100%|██████████| 23/23 [00:17<00:00,  1.30it/s]


Acc on val in epoch 1 is: 0.6530782029950083


Current loss in epoch 2 is 0.05822877585887909: 100%|██████████| 23/23 [00:16<00:00,  1.37it/s]


Acc on val in epoch 2 is: 0.6905158069883528


Current loss in epoch 3 is 0.009757733903825283: 100%|██████████| 23/23 [00:17<00:00,  1.34it/s]


Acc on val in epoch 3 is: 0.6888519134775375


Current loss in epoch 4 is 0.005402702372521162: 100%|██████████| 23/23 [00:16<00:00,  1.36it/s]


Acc on val in epoch 4 is: 0.6605657237936772


Current loss in epoch 5 is 0.010494672693312168: 100%|██████████| 23/23 [00:18<00:00,  1.27it/s]


Acc on val in epoch 5 is: 0.6589018302828619


Current loss in epoch 6 is 0.006840330082923174: 100%|██████████| 23/23 [00:16<00:00,  1.37it/s]


Acc on val in epoch 6 is: 0.64891846921797


Current loss in epoch 7 is 0.0009816238889470696: 100%|██████████| 23/23 [00:16<00:00,  1.36it/s]


Acc on val in epoch 7 is: 0.6472545757071547


Current loss in epoch 8 is 0.0021680293139070272: 100%|██████████| 23/23 [00:17<00:00,  1.32it/s]


Acc on val in epoch 8 is: 0.6722129783693843


Current loss in epoch 9 is 0.0009207280236296356: 100%|██████████| 23/23 [00:16<00:00,  1.36it/s]


Acc on val in epoch 9 is: 0.6672212978369384


Current loss in epoch 10 is 0.0012772877234965563: 100%|██████████| 23/23 [00:17<00:00,  1.31it/s]


Acc on val in epoch 10 is: 0.6672212978369384


Current loss in epoch 11 is 0.0008239531307481229: 100%|██████████| 23/23 [00:16<00:00,  1.36it/s]


Acc on val in epoch 11 is: 0.651414309484193


Current loss in epoch 12 is 0.0007958798087202013: 100%|██████████| 23/23 [00:17<00:00,  1.33it/s]


Acc on val in epoch 12 is: 0.6663893510815307


Current loss in epoch 13 is 0.00042275903979316354: 100%|██████████| 23/23 [00:17<00:00,  1.34it/s]


Acc on val in epoch 13 is: 0.6580698835274542


Current loss in epoch 14 is 0.00037440776941366494: 100%|██████████| 23/23 [00:17<00:00,  1.33it/s]


Acc on val in epoch 14 is: 0.653910149750416


Current loss in epoch 15 is 0.000299061561236158: 100%|██████████| 23/23 [00:17<00:00,  1.31it/s]


Acc on val in epoch 15 is: 0.6630615640599001


Current loss in epoch 16 is 0.0002748648985289037: 100%|██████████| 23/23 [00:17<00:00,  1.35it/s]


Acc on val in epoch 16 is: 0.651414309484193


Current loss in epoch 17 is 0.0005037491209805012: 100%|██████████| 23/23 [00:18<00:00,  1.25it/s]


Acc on val in epoch 17 is: 0.6589018302828619


Current loss in epoch 18 is 0.00029927529976703227: 100%|██████████| 23/23 [00:18<00:00,  1.23it/s]


Acc on val in epoch 18 is: 0.6530782029950083


Current loss in epoch 19 is 0.00031406240304932: 100%|██████████| 23/23 [00:19<00:00,  1.19it/s]


Acc on val in epoch 19 is: 0.6622296173044925
Fianl acc on test is:  0.8968307484828051


# Sensitivity Analysis Towards Data Perturbation
We add random noise $\mathcal{N}(0,\sigma^2)$ to the input of the model and evaluate how robust the model is.

In [13]:
trials = 3
epochs = 20
lr = 1e-3

In [14]:
def eval_with_perturbation(model, data_loader, std_levels=[1e-3,5e-3,1e-2,5e-2,1e-1,5e-1,1], device="cpu"):
    accs = []
    for std in std_levels:
      model.to(device)
      #start validation
      model.eval()
      correct = 0
      total = 0
      with torch.no_grad():
        for x,y in data_loader:
          x,y = x.to(device), y.to(device)
          # add noise
          x = x + torch.randn_like(x) * std
          out = model(x)
          preds = (torch.sigmoid(out) > 0.5).float()
          correct += (preds == y).sum().item()
          total += y.shape[0]
      model.train()
      acc = correct / total
      accs.append(acc)
    return accs

In [15]:
trail_to_acc = {}
for trial in range(trials):
  model = CNN_1D()
  optimizer = optim.Adam(model.parameters(), lr=lr)
  criterion = nn.BCELoss()
  losses = train(model, criterion, optimizer, train_loader, val_loader ,epochs, device)
  accs = eval_with_perturbation(model, test_loader, device=device)
  trail_to_acc[trial] = accs

Current loss in epoch 0 is 0.12262696772813797: 100%|██████████| 23/23 [00:08<00:00,  2.81it/s]


Acc on val in epoch 0 is: 0.6530782029950083


Current loss in epoch 1 is 0.08027421683073044: 100%|██████████| 23/23 [00:08<00:00,  2.63it/s]


Acc on val in epoch 1 is: 0.6772046589018302


Current loss in epoch 2 is 0.018351132050156593: 100%|██████████| 23/23 [00:08<00:00,  2.72it/s]


Acc on val in epoch 2 is: 0.802828618968386


Current loss in epoch 3 is 0.09371186047792435: 100%|██████████| 23/23 [00:08<00:00,  2.79it/s]


Acc on val in epoch 3 is: 0.8494176372712147


Current loss in epoch 4 is 0.024837573990225792: 100%|██████████| 23/23 [00:08<00:00,  2.79it/s]


Acc on val in epoch 4 is: 0.8660565723793677


Current loss in epoch 5 is 0.02017902210354805: 100%|██████████| 23/23 [00:08<00:00,  2.64it/s]


Acc on val in epoch 5 is: 0.7495840266222962


Current loss in epoch 6 is 0.012972411699593067: 100%|██████████| 23/23 [00:08<00:00,  2.74it/s]


Acc on val in epoch 6 is: 0.8386023294509152


Current loss in epoch 7 is 0.0013552390737459064: 100%|██████████| 23/23 [00:08<00:00,  2.64it/s]


Acc on val in epoch 7 is: 0.8053244592346089


Current loss in epoch 8 is 0.0015145022189244628: 100%|██████████| 23/23 [00:08<00:00,  2.71it/s]


Acc on val in epoch 8 is: 0.7454242928452579


Current loss in epoch 9 is 0.0008705118088982999: 100%|██████████| 23/23 [00:08<00:00,  2.72it/s]


Acc on val in epoch 9 is: 0.8610648918469218


Current loss in epoch 10 is 0.00030955762485973537: 100%|██████████| 23/23 [00:08<00:00,  2.75it/s]


Acc on val in epoch 10 is: 0.7762063227953411


Current loss in epoch 11 is 0.002030485076829791: 100%|██████████| 23/23 [00:08<00:00,  2.72it/s]


Acc on val in epoch 11 is: 0.8535773710482529


Current loss in epoch 12 is 0.0003312973421998322: 100%|██████████| 23/23 [00:08<00:00,  2.68it/s]


Acc on val in epoch 12 is: 0.8161397670549085


Current loss in epoch 13 is 0.00026927702128887177: 100%|██████████| 23/23 [00:08<00:00,  2.64it/s]


Acc on val in epoch 13 is: 0.7362728785357737


Current loss in epoch 14 is 0.0005367140984162688: 100%|██████████| 23/23 [00:08<00:00,  2.73it/s]


Acc on val in epoch 14 is: 0.7762063227953411


Current loss in epoch 15 is 0.0002542502770666033: 100%|██████████| 23/23 [00:08<00:00,  2.61it/s]


Acc on val in epoch 15 is: 0.8302828618968386


Current loss in epoch 16 is 0.0006185718812048435: 100%|██████████| 23/23 [00:08<00:00,  2.68it/s]


Acc on val in epoch 16 is: 0.7878535773710482


Current loss in epoch 17 is 0.00015538703883066773: 100%|██████████| 23/23 [00:08<00:00,  2.70it/s]


Acc on val in epoch 17 is: 0.7995008319467554


Current loss in epoch 18 is 0.00019522913498803973: 100%|██████████| 23/23 [00:08<00:00,  2.68it/s]


Acc on val in epoch 18 is: 0.7886855241264559


Current loss in epoch 19 is 0.0001964727562153712: 100%|██████████| 23/23 [00:08<00:00,  2.64it/s]


Acc on val in epoch 19 is: 0.7795341098169717


Current loss in epoch 0 is 0.18173611164093018: 100%|██████████| 23/23 [00:08<00:00,  2.64it/s]


Acc on val in epoch 0 is: 0.6530782029950083


Current loss in epoch 1 is 0.04137532785534859: 100%|██████████| 23/23 [00:08<00:00,  2.69it/s]


Acc on val in epoch 1 is: 0.7063227953410982


Current loss in epoch 2 is 0.015295527875423431: 100%|██████████| 23/23 [00:08<00:00,  2.67it/s]


Acc on val in epoch 2 is: 0.8660565723793677


Current loss in epoch 3 is 0.006899348925799131: 100%|██████████| 23/23 [00:08<00:00,  2.66it/s]


Acc on val in epoch 3 is: 0.7537437603993344


Current loss in epoch 4 is 0.00298992614261806: 100%|██████████| 23/23 [00:08<00:00,  2.67it/s]


Acc on val in epoch 4 is: 0.8768718801996672


Current loss in epoch 5 is 0.0013794316910207272: 100%|██████████| 23/23 [00:08<00:00,  2.68it/s]


Acc on val in epoch 5 is: 0.8818635607321131


Current loss in epoch 6 is 0.0023187820333987474: 100%|██████████| 23/23 [00:08<00:00,  2.63it/s]


Acc on val in epoch 6 is: 0.7479201331114809


Current loss in epoch 7 is 0.001055543078109622: 100%|██████████| 23/23 [00:08<00:00,  2.63it/s]


Acc on val in epoch 7 is: 0.831946755407654


Current loss in epoch 8 is 0.004138501361012459: 100%|██████████| 23/23 [00:08<00:00,  2.66it/s]


Acc on val in epoch 8 is: 0.7462562396006656


Current loss in epoch 9 is 0.005504134111106396: 100%|██████████| 23/23 [00:08<00:00,  2.68it/s]


Acc on val in epoch 9 is: 0.8036605657237936


Current loss in epoch 10 is 0.001053712796419859: 100%|██████████| 23/23 [00:08<00:00,  2.65it/s]


Acc on val in epoch 10 is: 0.8760399334442596


Current loss in epoch 11 is 0.001840145792812109: 100%|██████████| 23/23 [00:08<00:00,  2.64it/s]


Acc on val in epoch 11 is: 0.8718801996672213


Current loss in epoch 12 is 0.0013623950071632862: 100%|██████████| 23/23 [00:08<00:00,  2.69it/s]


Acc on val in epoch 12 is: 0.7870216306156406


Current loss in epoch 13 is 0.00019080640049651265: 100%|██████████| 23/23 [00:08<00:00,  2.65it/s]


Acc on val in epoch 13 is: 0.8519134775374376


Current loss in epoch 14 is 0.0003993217833340168: 100%|██████████| 23/23 [00:08<00:00,  2.63it/s]


Acc on val in epoch 14 is: 0.7886855241264559


Current loss in epoch 15 is 0.0003118351742159575: 100%|██████████| 23/23 [00:08<00:00,  2.63it/s]


Acc on val in epoch 15 is: 0.8352745424292846


Current loss in epoch 16 is 0.0006800913834013045: 100%|██████████| 23/23 [00:08<00:00,  2.67it/s]


Acc on val in epoch 16 is: 0.7562396006655574


Current loss in epoch 17 is 0.00027768511790782213: 100%|██████████| 23/23 [00:08<00:00,  2.64it/s]


Acc on val in epoch 17 is: 0.8036605657237936


Current loss in epoch 18 is 0.00014769539120607078: 100%|██████████| 23/23 [00:08<00:00,  2.64it/s]


Acc on val in epoch 18 is: 0.7703826955074875


Current loss in epoch 19 is 0.00011683280172292143: 100%|██████████| 23/23 [00:08<00:00,  2.66it/s]


Acc on val in epoch 19 is: 0.7495840266222962


Current loss in epoch 0 is 0.09961763024330139: 100%|██████████| 23/23 [00:08<00:00,  2.67it/s]


Acc on val in epoch 0 is: 0.6530782029950083


Current loss in epoch 1 is 0.03181261569261551: 100%|██████████| 23/23 [00:08<00:00,  2.63it/s]


Acc on val in epoch 1 is: 0.7470881863560732


Current loss in epoch 2 is 0.0077316961251199245: 100%|██████████| 23/23 [00:08<00:00,  2.62it/s]


Acc on val in epoch 2 is: 0.6647254575707154


Current loss in epoch 3 is 0.009215709753334522: 100%|██████████| 23/23 [00:08<00:00,  2.68it/s]


Acc on val in epoch 3 is: 0.7229617304492513


Current loss in epoch 4 is 0.007076606620103121: 100%|██████████| 23/23 [00:08<00:00,  2.66it/s]


Acc on val in epoch 4 is: 0.7254575707154742


Current loss in epoch 5 is 0.0023069048766046762: 100%|██████████| 23/23 [00:08<00:00,  2.64it/s]


Acc on val in epoch 5 is: 0.6672212978369384


Current loss in epoch 6 is 0.0025211982429027557: 100%|██████████| 23/23 [00:08<00:00,  2.67it/s]


Acc on val in epoch 6 is: 0.7504159733777038


Current loss in epoch 7 is 0.000864020548760891: 100%|██████████| 23/23 [00:08<00:00,  2.68it/s]


Acc on val in epoch 7 is: 0.7221297836938436


Current loss in epoch 8 is 0.0006833648076280951: 100%|██████████| 23/23 [00:08<00:00,  2.64it/s]


Acc on val in epoch 8 is: 0.7129783693843594


Current loss in epoch 9 is 0.0005257856682874262: 100%|██████████| 23/23 [00:08<00:00,  2.63it/s]


Acc on val in epoch 9 is: 0.7071547420965059


Current loss in epoch 10 is 0.0005303008947521448:  13%|█▎        | 3/23 [00:01<00:08,  2.49it/s]


KeyboardInterrupt: 

In [81]:
print(trial_to_acc)

[0.8985165205664194, 0.8961564396493594, 0.8958192852326365, 0.8924477410654079, 0.8739042481456507, 0.7073499662845584, 0.3509777478084963]


# Sensetivity Analysis Towards Hyperparameters
We evaluate the sensitivity towards hyperparameters. We execute every experiment 3 times and report the average

In [None]:
train_loader, val_loader, test_loader = get_dataloaders("/content/drive/MyDrive/ptbdb",preprocessed_data_path="/content/drive/MyDrive/ptbdb/preprocessed_data.pt", train_ratio=0.6, val_ratio=0.1)

## Kernel Size
We evaluate different learning rates $lr = \{ 10^{-2}, 10^{-3}, 10^{-4}\}$

In [None]:
epochs = 20
lr = 1e-3
device = "cuda" if torch.cuda.is_available() else "cpu"

In [None]:
kernel_sizes = [50,100,200]
accs = []
for kernel_size in kernel_sizes:
  model = CNN_1D(kernel_1d = kernel_size)
  optimizer = optim.Adam(model.parameters(), lr=lr)
  criterion = nn.BCELoss()
  losses = train(model, criterion, optimizer, train_loader, val_loader ,epochs, device, return_loss=True)
  acc=eval(model, test_loader, device)
  print("Fianl acc on test is: ",acc )
  accs.append(acc)

## Dilation
We performance when changing dilation

In [None]:
dilations = []
accs = []
for lr in lrs:
  model = CNN_1D()
  optimizer = optim.Adam(model.parameters(), lr=lr)
  criterion = nn.BCELoss()
  losses = train(model, criterion, optimizer, train_loader, val_loader ,epochs, device, return_loss=True)
  acc=eval(model, test_loader, device)
  print("Fianl acc on test is: ",acc )
  accs.append(acc)

## Attention Heads
We evaluate the performance when changing the number of attention heads

In [None]:
attention_heads = []
accs = []
for lr in lrs:
  model = CNN_1D()
  optimizer = optim.Adam(model.parameters(), lr=lr)
  criterion = nn.BCELoss()
  losses = train(model, criterion, optimizer, train_loader, val_loader ,epochs, device, return_loss=True)
  acc=eval(model, test_loader, device)
  print("Fianl acc on test is: ",acc )
  accs.append(acc)

## Stride
We investigate different Kernel Size / Stride combinations

In [None]:
kernel_size_stride_pairs = []
accs = []
for lr in lrs:
  model = CNN_1D()
  optimizer = optim.Adam(model.parameters(), lr=lr)
  criterion = nn.BCELoss()
  losses = train(model, criterion, optimizer, train_loader, val_loader ,epochs, device, return_loss=True)
  acc=eval(model, test_loader, device)
  print("Fianl acc on test is: ",acc )
  accs.append(acc)