# Imports

In [1]:
from getpass import getpass

token = getpass('Enter your GitHub personal access token: ')
name = getpass('Enter your GitHub name: ')
#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

!git clone https://{token}@github.com/{name}/comp_med.git
#%cd comp_med




In [None]:
#!rm -r comp_med

## Make Code Deterministic for Reproducibility

In [4]:
import os
import random
import numpy as np
import torch
def set_seed(SEED):
  os.environ["PYTHONHASHSEED"] = str(SEED)
  random.seed(SEED)
  np.random.seed(SEED)
  torch.manual_seed(SEED)
  torch.cuda.manual_seed_all(SEED)

  torch.backends.cudnn.deterministic = True
  torch.backends.cudnn.benchmark = False

In [11]:
!pip install wfdb



In [5]:
from google.colab import drive
drive.mount('/content/drive')


Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [6]:
import numpy as np
import torch
import torch.optim as optim
from tqdm import tqdm
import torch.nn as nn
from torch.utils.data import TensorDataset, DataLoader
from sklearn.metrics import (
    accuracy_score,
    precision_score,
    recall_score,
    f1_score,
    confusion_matrix,
)

#my imports
from comp_med.models.attentionCNN import CNN_1D
from comp_med.models.oldCNN import CNN_2D
from comp_med.data.preprocessing import get_dataloaders


## Training Helpers

In [7]:
def train(model, criterion, optimizer, train_loader, val_loader, epochs, device="cpu", return_loss=False):
  model.to(device)
  loss_tracker = []
  for epoch in range(epochs):
    pbar = tqdm(train_loader, desc=f"Train the model in epoch {epoch}...")
    loss = 0
    for x,y in pbar:
      optimizer.zero_grad()
      x,y = x.to(device),y.to(device)
      out = model(x)
      out = torch.sigmoid(out)
      loss = criterion(out,y)
      loss.backward()
      optimizer.step()
      loss = loss.item()
      pbar.set_description(f"Current loss in epoch {epoch} is {loss}")
      loss_tracker.append(loss)
    #start validation
    acc = eval(model, val_loader, device)
    print(f"Acc on val in epoch {epoch} is: {acc}")
  if return_loss:
    return loss_tracker

def eval(model, data_loader, device="cpu", all_metrics=False):
    model.to(device)
    #start validation
    model.eval()
    preds = []
    labels = []
    with torch.no_grad():
      for x,y in data_loader:
        x,y = x.to(device), y.to(device)
        out = model(x)
        preds.append((torch.sigmoid(out) > 0.5).float())
        labels.append(y)

    model.train()
    preds = torch.cat(preds).cpu().numpy().ravel()
    labels = torch.cat(labels).cpu().numpy().ravel()
    acc = accuracy_score(labels, preds)
    prec = precision_score(labels, preds)
    sens = recall_score(labels, preds)
    tn, fp, fn, tp = confusion_matrix(labels, preds).ravel()
    spec = tn/(tn + fp)
    f1 = f1_score(labels, preds)
    if all_metrics:
      return {"acc":acc, "prec":prec, "sens":sens, "spec":spec, "f1":f1}
    else:
      return acc

# Verification Dummy Task

In [None]:
# create a dummy dataset to verify the architecture works
batch_size = 256
length = 5
samples_per_sec = 1000
num_leads = 12
num_samples = 500

t = np.arange(0, length, 1/samples_per_sec)
base_freq = 1

X = np.zeros((num_samples, num_leads, t.size))
labels = np.zeros((num_samples,))

for i in range(num_samples):
  label = i % 2
  for l in range(num_leads):
    amplitude = l+1
    X[i,l] = amplitude * np.sin( 2*np.pi * base_freq * t)
  if label == 1:
    #get a random lead
    lead_idx = np.random.randint(0, num_leads)
    #doulbe the frequency when label is 1
    X[i, lead_idx] = (lead_idx+1) * np.sin( 2*np.pi * base_freq*2 * t)

  #add some noise
  noise = np.random.normal(0, 0.1, size=X[i].shape)
  X[i] = X[i]+noise
  labels[i]= label

X_t = torch.from_numpy(X).float()
y_t = torch.from_numpy(labels).unsqueeze(1).float()

len = X_t.shape[0]
permutation = torch.randperm(len)
train_idx = permutation[:400]
test_idx =permutation[400:]

X_train, y_train = X_t[train_idx], y_t[train_idx]
X_test, y_test = X_t[test_idx] ,y_t[test_idx]
print(X_t.shape)
print(X_test.shape)
print(X_train.shape)
train_dataset = TensorDataset(X_train, y_train)
test_dataset = TensorDataset(X_test, y_test)

train_loader = DataLoader(train_dataset, batch_size = batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size = batch_size)

torch.Size([500, 12, 5000])
torch.Size([100, 12, 5000])
torch.Size([400, 12, 5000])


In [None]:
epochs = 25
lr = 1e-3
device = "cuda" if torch.cuda.is_available() else "cpu"

model = CNN_1D()
optimizer = optim.Adam(model.parameters(), lr=lr)
criterion = nn.BCELoss()
train(model, criterion, optimizer, train_loader, train_loader ,epochs, device)
print("Fianl acc on test is: ", eval(model, test_loader, device))

Current loss in epoch 0 is 0.5104804635047913: 100%|██████████| 2/2 [00:02<00:00,  1.25s/it]


Acc on val in epoch 0 is: 0.49


Current loss in epoch 1 is 0.3203796446323395: 100%|██████████| 2/2 [00:01<00:00,  1.13it/s]


Acc on val in epoch 1 is: 0.51


Current loss in epoch 2 is 0.08878125250339508: 100%|██████████| 2/2 [00:01<00:00,  1.26it/s]


Acc on val in epoch 2 is: 0.51


Current loss in epoch 3 is 0.04321501776576042: 100%|██████████| 2/2 [00:01<00:00,  1.83it/s]


Acc on val in epoch 3 is: 0.51


Current loss in epoch 4 is 0.02822541631758213: 100%|██████████| 2/2 [00:01<00:00,  1.40it/s]


Acc on val in epoch 4 is: 0.51


Current loss in epoch 5 is 0.025208203122019768: 100%|██████████| 2/2 [00:01<00:00,  1.38it/s]


Acc on val in epoch 5 is: 0.51


Current loss in epoch 6 is 0.03727706894278526: 100%|██████████| 2/2 [00:01<00:00,  1.27it/s]


Acc on val in epoch 6 is: 0.51


Current loss in epoch 7 is 0.015179564245045185: 100%|██████████| 2/2 [00:01<00:00,  1.82it/s]


Acc on val in epoch 7 is: 0.51


Current loss in epoch 8 is 0.01463332585990429: 100%|██████████| 2/2 [00:01<00:00,  1.62it/s]


Acc on val in epoch 8 is: 0.51


Current loss in epoch 9 is 0.010161933489143848: 100%|██████████| 2/2 [00:01<00:00,  1.82it/s]


Acc on val in epoch 9 is: 0.51


Current loss in epoch 10 is 0.009517444297671318: 100%|██████████| 2/2 [00:01<00:00,  1.83it/s]


Acc on val in epoch 10 is: 0.51


Current loss in epoch 11 is 0.006419841665774584: 100%|██████████| 2/2 [00:01<00:00,  1.83it/s]


Acc on val in epoch 11 is: 0.51


Current loss in epoch 12 is 0.006302207242697477: 100%|██████████| 2/2 [00:01<00:00,  1.83it/s]


Acc on val in epoch 12 is: 0.51


Current loss in epoch 13 is 0.005004946608096361: 100%|██████████| 2/2 [00:01<00:00,  1.65it/s]


Acc on val in epoch 13 is: 0.51


Current loss in epoch 14 is 0.004329374060034752: 100%|██████████| 2/2 [00:01<00:00,  1.46it/s]


Acc on val in epoch 14 is: 0.51


Current loss in epoch 15 is 0.003817749908193946: 100%|██████████| 2/2 [00:01<00:00,  1.84it/s]


Acc on val in epoch 15 is: 0.51


Current loss in epoch 16 is 0.0048565189354121685: 100%|██████████| 2/2 [00:01<00:00,  1.82it/s]


Acc on val in epoch 16 is: 0.51


Current loss in epoch 17 is 0.003029455663636327: 100%|██████████| 2/2 [00:01<00:00,  1.78it/s]


Acc on val in epoch 17 is: 0.51


Current loss in epoch 18 is 0.002934318035840988: 100%|██████████| 2/2 [00:01<00:00,  1.83it/s]


Acc on val in epoch 18 is: 0.5525


Current loss in epoch 19 is 0.002483918098732829: 100%|██████████| 2/2 [00:01<00:00,  1.82it/s]


Acc on val in epoch 19 is: 0.775


Current loss in epoch 20 is 0.0036703546065837145: 100%|██████████| 2/2 [00:01<00:00,  1.83it/s]


Acc on val in epoch 20 is: 0.945


Current loss in epoch 21 is 0.002393413567915559: 100%|██████████| 2/2 [00:01<00:00,  1.78it/s]


Acc on val in epoch 21 is: 0.9775


Current loss in epoch 22 is 0.0022414394188672304: 100%|██████████| 2/2 [00:01<00:00,  1.50it/s]


Acc on val in epoch 22 is: 1.0


Current loss in epoch 23 is 0.002042328007519245: 100%|██████████| 2/2 [00:01<00:00,  1.67it/s]


Acc on val in epoch 23 is: 1.0


Current loss in epoch 24 is 0.0017815360333770514: 100%|██████████| 2/2 [00:01<00:00,  1.83it/s]


Acc on val in epoch 24 is: 1.0
Fianl acc on test is:  1.0


# Convergence Analysis of the Model

In [16]:
train_loader, val_loader, test_loader = get_dataloaders("/content/drive/MyDrive/ptbdb",preprocessed_data_path="/content/drive/MyDrive/ptbdb/preprocessed_data.pt", train_ratio=0.6, val_ratio=0.1)

After filtering, we got: 228 records. Healthy: 80, Disease: 148
Patients: train: 120 | val: 20 | test: 60
Load data from given path


In [17]:
epochs = 20
lr = 1e-3
device = "cuda" if torch.cuda.is_available() else "cpu"

In [18]:
set_seed(0)
model = CNN_1D()
optimizer = optim.Adam(model.parameters(), lr=lr)
criterion = nn.BCELoss()
losses = train(model, criterion, optimizer, train_loader, val_loader ,epochs, device, return_loss=True)
print("Fianl acc on test is: ", eval(model, test_loader, device))

Current loss in epoch 0 is 0.08360694348812103: 100%|██████████| 23/23 [00:05<00:00,  4.09it/s]


Acc on val in epoch 0 is: 0.6530782029950083


Current loss in epoch 1 is 0.04628802090883255: 100%|██████████| 23/23 [00:03<00:00,  6.06it/s]


Acc on val in epoch 1 is: 0.6530782029950083


Current loss in epoch 2 is 0.007345775607973337: 100%|██████████| 23/23 [00:03<00:00,  6.02it/s]


Acc on val in epoch 2 is: 0.6946755407653911


Current loss in epoch 3 is 0.031864557415246964: 100%|██████████| 23/23 [00:03<00:00,  6.09it/s]


Acc on val in epoch 3 is: 0.7487520798668885


Current loss in epoch 4 is 0.0037259787786751986: 100%|██████████| 23/23 [00:03<00:00,  6.07it/s]


Acc on val in epoch 4 is: 0.78369384359401


Current loss in epoch 5 is 0.0017308311071246862: 100%|██████████| 23/23 [00:03<00:00,  6.06it/s]


Acc on val in epoch 5 is: 0.6938435940099834


Current loss in epoch 6 is 0.004814693238586187: 100%|██████████| 23/23 [00:03<00:00,  6.11it/s]


Acc on val in epoch 6 is: 0.7254575707154742


Current loss in epoch 7 is 0.004150109365582466: 100%|██████████| 23/23 [00:03<00:00,  6.06it/s]


Acc on val in epoch 7 is: 0.8394342762063228


Current loss in epoch 8 is 0.020144077017903328: 100%|██████████| 23/23 [00:03<00:00,  6.07it/s]


Acc on val in epoch 8 is: 0.8302828618968386


Current loss in epoch 9 is 0.015454555861651897: 100%|██████████| 23/23 [00:03<00:00,  6.05it/s]


Acc on val in epoch 9 is: 0.8128119800332779


Current loss in epoch 10 is 0.003554391674697399: 100%|██████████| 23/23 [00:03<00:00,  5.98it/s]


Acc on val in epoch 10 is: 0.8302828618968386


Current loss in epoch 11 is 0.0005613495013676584: 100%|██████████| 23/23 [00:03<00:00,  6.08it/s]


Acc on val in epoch 11 is: 0.6589018302828619


Current loss in epoch 12 is 0.0037457516882568598: 100%|██████████| 23/23 [00:03<00:00,  6.05it/s]


Acc on val in epoch 12 is: 0.6763727121464226


Current loss in epoch 13 is 0.00023705829516984522: 100%|██████████| 23/23 [00:03<00:00,  6.01it/s]


Acc on val in epoch 13 is: 0.6613976705490848


Current loss in epoch 14 is 0.0008414376643486321: 100%|██████████| 23/23 [00:03<00:00,  6.03it/s]


Acc on val in epoch 14 is: 0.6589018302828619


Current loss in epoch 15 is 0.0008028267184272408: 100%|██████████| 23/23 [00:03<00:00,  6.04it/s]


Acc on val in epoch 15 is: 0.6638935108153078


Current loss in epoch 16 is 0.00034881947794929147: 100%|██████████| 23/23 [00:03<00:00,  6.02it/s]


Acc on val in epoch 16 is: 0.6589018302828619


Current loss in epoch 17 is 0.00023410480935126543: 100%|██████████| 23/23 [00:03<00:00,  6.07it/s]


Acc on val in epoch 17 is: 0.6622296173044925


Current loss in epoch 18 is 0.00023087805311661214: 100%|██████████| 23/23 [00:03<00:00,  6.05it/s]


Acc on val in epoch 18 is: 0.6638935108153078


Current loss in epoch 19 is 0.0001931805891217664: 100%|██████████| 23/23 [00:03<00:00,  6.00it/s]


Acc on val in epoch 19 is: 0.6638935108153078
Fianl acc on test is:  0.9214430209035739


In [19]:
print(losses)

[0.6905696988105774, 0.7873849868774414, 0.7033773064613342, 0.6144259572029114, 0.4938927888870239, 0.43385016918182373, 0.3073195219039917, 0.30743545293807983, 0.24546369910240173, 0.26762712001800537, 0.18329425156116486, 0.19517819583415985, 0.1483273208141327, 0.14521941542625427, 0.11643799394369125, 0.11115504801273346, 0.1288827359676361, 0.12901932001113892, 0.11040644347667694, 0.0793788731098175, 0.09725043177604675, 0.08397062867879868, 0.08360694348812103, 0.12333345413208008, 0.07704699039459229, 0.051855579018592834, 0.06411096453666687, 0.06638966500759125, 0.1060364693403244, 0.054874926805496216, 0.07386410236358643, 0.07193392515182495, 0.06380335986614227, 0.03590027242898941, 0.036552757024765015, 0.052600957453250885, 0.09102331101894379, 0.06065509840846062, 0.028976725414395332, 0.03565126657485962, 0.04196413606405258, 0.02770964615046978, 0.026211634278297424, 0.021278470754623413, 0.0255604051053524, 0.04628802090883255, 0.023396994918584824, 0.0163365788757

# Sensitivity Analysis Towards Data Perturbation
We add random noise $\mathcal{N}(0,\sigma^2)$ to the input of the model and evaluate how robust the model is.

In [None]:
trials = 3
epochs = 20
lr = 1e-3
seeds = [0,1,2]

In [None]:
def eval_with_perturbation(model, data_loader, std_levels=[1e-3,5e-3,1e-2,5e-2,1e-1,5e-1,1], device="cpu"):
    accs = []
    for std in std_levels:
      model.to(device)
      #start validation
      model.eval()
      correct = 0
      total = 0
      with torch.no_grad():
        for x,y in data_loader:
          x,y = x.to(device), y.to(device)
          # add noise
          x = x + torch.randn_like(x) * std
          out = model(x)
          preds = (torch.sigmoid(out) > 0.5).float()
          correct += (preds == y).sum().item()
          total += y.shape[0]
      model.train()
      acc = correct / total
      accs.append(acc)
    return accs

In [None]:
trial_to_acc = {}
for trial,seed in zip(range(trials),seeds):
  set_seed(seed)
  model = CNN_1D()
  optimizer = optim.Adam(model.parameters(), lr=lr)
  criterion = nn.BCELoss()
  losses = train(model, criterion, optimizer, train_loader, val_loader ,epochs, device)
  accs = eval_with_perturbation(model, test_loader, device=device)
  trial_to_acc[trial] = accs

Current loss in epoch 0 is 0.2759646475315094: 100%|██████████| 23/23 [00:18<00:00,  1.26it/s]


Acc on val in epoch 0 is: 0.6530782029950083


Current loss in epoch 1 is 0.06597258150577545: 100%|██████████| 23/23 [00:18<00:00,  1.24it/s]


Acc on val in epoch 1 is: 0.6530782029950083


Current loss in epoch 2 is 0.022524302825331688: 100%|██████████| 23/23 [00:21<00:00,  1.07it/s]


Acc on val in epoch 2 is: 0.8261231281198004


Current loss in epoch 3 is 0.01326372753828764: 100%|██████████| 23/23 [00:21<00:00,  1.09it/s]


Acc on val in epoch 3 is: 0.737936772046589


Current loss in epoch 4 is 0.012736589647829533: 100%|██████████| 23/23 [00:17<00:00,  1.32it/s]


Acc on val in epoch 4 is: 0.7229617304492513


Current loss in epoch 5 is 0.011166000738739967: 100%|██████████| 23/23 [00:17<00:00,  1.33it/s]


Acc on val in epoch 5 is: 0.7229617304492513


Current loss in epoch 6 is 0.020346665754914284: 100%|██████████| 23/23 [00:17<00:00,  1.35it/s]


Acc on val in epoch 6 is: 0.7412645590682196


Current loss in epoch 7 is 0.0030898726545274258: 100%|██████████| 23/23 [00:17<00:00,  1.28it/s]


Acc on val in epoch 7 is: 0.7662229617304492


Current loss in epoch 8 is 0.0032272387761622667: 100%|██████████| 23/23 [00:17<00:00,  1.31it/s]


Acc on val in epoch 8 is: 0.7520798668885191


Current loss in epoch 9 is 0.0012734277406707406: 100%|██████████| 23/23 [00:18<00:00,  1.24it/s]


Acc on val in epoch 9 is: 0.6838602329450915


Current loss in epoch 10 is 0.0006510601961053908: 100%|██████████| 23/23 [00:17<00:00,  1.34it/s]


Acc on val in epoch 10 is: 0.6805324459234608


Current loss in epoch 11 is 0.0004503272648435086: 100%|██████████| 23/23 [00:18<00:00,  1.22it/s]


Acc on val in epoch 11 is: 0.7063227953410982


Current loss in epoch 12 is 0.0004752559179905802: 100%|██████████| 23/23 [00:17<00:00,  1.33it/s]


Acc on val in epoch 12 is: 0.6938435940099834


Current loss in epoch 13 is 0.001056078472174704: 100%|██████████| 23/23 [00:18<00:00,  1.26it/s]


Acc on val in epoch 13 is: 0.6905158069883528


Current loss in epoch 14 is 0.009285333566367626: 100%|██████████| 23/23 [00:17<00:00,  1.33it/s]


Acc on val in epoch 14 is: 0.71630615640599


Current loss in epoch 15 is 0.030553344637155533: 100%|██████████| 23/23 [00:17<00:00,  1.33it/s]


Acc on val in epoch 15 is: 0.6247920133111481


Current loss in epoch 16 is 0.010342328809201717: 100%|██████████| 23/23 [00:17<00:00,  1.30it/s]


Acc on val in epoch 16 is: 0.7088186356073212


Current loss in epoch 17 is 0.003859503660351038: 100%|██████████| 23/23 [00:17<00:00,  1.33it/s]


Acc on val in epoch 17 is: 0.7004991680532446


Current loss in epoch 18 is 0.002098959404975176: 100%|██████████| 23/23 [00:17<00:00,  1.29it/s]


Acc on val in epoch 18 is: 0.6946755407653911


Current loss in epoch 19 is 0.0008093470241874456: 100%|██████████| 23/23 [00:17<00:00,  1.33it/s]


Acc on val in epoch 19 is: 0.6963394342762064


Current loss in epoch 0 is 0.1733430027961731: 100%|██████████| 23/23 [00:17<00:00,  1.30it/s]


Acc on val in epoch 0 is: 0.6530782029950083


Current loss in epoch 1 is 0.04414393752813339: 100%|██████████| 23/23 [00:17<00:00,  1.33it/s]


Acc on val in epoch 1 is: 0.6530782029950083


Current loss in epoch 2 is 0.025462323799729347: 100%|██████████| 23/23 [00:17<00:00,  1.33it/s]


Acc on val in epoch 2 is: 0.7146422628951747


Current loss in epoch 3 is 0.00439727958291769: 100%|██████████| 23/23 [00:17<00:00,  1.32it/s]


Acc on val in epoch 3 is: 0.718801996672213


Current loss in epoch 4 is 0.0047921729274094105: 100%|██████████| 23/23 [00:17<00:00,  1.33it/s]


Acc on val in epoch 4 is: 0.7329450915141431


Current loss in epoch 5 is 0.005066166631877422: 100%|██████████| 23/23 [00:17<00:00,  1.29it/s]


Acc on val in epoch 5 is: 0.7678868552412645


Current loss in epoch 6 is 0.023793648928403854: 100%|██████████| 23/23 [00:17<00:00,  1.33it/s]


Acc on val in epoch 6 is: 0.7296173044925125


Current loss in epoch 7 is 0.007282329723238945: 100%|██████████| 23/23 [00:17<00:00,  1.29it/s]


Acc on val in epoch 7 is: 0.8618968386023295


Current loss in epoch 8 is 0.0014109826879575849: 100%|██████████| 23/23 [00:17<00:00,  1.33it/s]


Acc on val in epoch 8 is: 0.78369384359401


Current loss in epoch 9 is 0.0006989935645833611: 100%|██████████| 23/23 [00:18<00:00,  1.24it/s]


Acc on val in epoch 9 is: 0.7479201331114809


Current loss in epoch 10 is 0.0031582913361489773: 100%|██████████| 23/23 [00:17<00:00,  1.33it/s]


Acc on val in epoch 10 is: 0.7753743760399334


Current loss in epoch 11 is 0.0004954234464094043: 100%|██████████| 23/23 [00:17<00:00,  1.32it/s]


Acc on val in epoch 11 is: 0.7279534109816972


Current loss in epoch 12 is 0.0003886868944391608: 100%|██████████| 23/23 [00:17<00:00,  1.33it/s]


Acc on val in epoch 12 is: 0.7371048252911814


Current loss in epoch 13 is 0.00047050093417055905: 100%|██████████| 23/23 [00:17<00:00,  1.33it/s]


Acc on val in epoch 13 is: 0.7279534109816972


Current loss in epoch 14 is 0.00020281130855437368: 100%|██████████| 23/23 [00:17<00:00,  1.30it/s]


Acc on val in epoch 14 is: 0.7279534109816972


Current loss in epoch 15 is 0.0004683499282691628: 100%|██████████| 23/23 [00:17<00:00,  1.34it/s]


Acc on val in epoch 15 is: 0.7312811980033278


Current loss in epoch 16 is 0.0003271755704190582: 100%|██████████| 23/23 [00:17<00:00,  1.30it/s]


Acc on val in epoch 16 is: 0.7287853577371048


Current loss in epoch 17 is 0.0001620682014618069: 100%|██████████| 23/23 [00:17<00:00,  1.34it/s]


Acc on val in epoch 17 is: 0.7237936772046589


Current loss in epoch 18 is 0.0002761706418823451: 100%|██████████| 23/23 [00:17<00:00,  1.29it/s]


Acc on val in epoch 18 is: 0.7279534109816972


Current loss in epoch 19 is 0.0004573499027173966: 100%|██████████| 23/23 [00:17<00:00,  1.32it/s]


Acc on val in epoch 19 is: 0.7296173044925125


Current loss in epoch 0 is 0.21527308225631714: 100%|██████████| 23/23 [00:17<00:00,  1.33it/s]


Acc on val in epoch 0 is: 0.651414309484193


Current loss in epoch 1 is 0.07182791084051132: 100%|██████████| 23/23 [00:17<00:00,  1.28it/s]


Acc on val in epoch 1 is: 0.6530782029950083


Current loss in epoch 2 is 0.019832590594887733: 100%|██████████| 23/23 [00:17<00:00,  1.33it/s]


Acc on val in epoch 2 is: 0.7279534109816972


Current loss in epoch 3 is 0.012981760315597057: 100%|██████████| 23/23 [00:17<00:00,  1.30it/s]


Acc on val in epoch 3 is: 0.7462562396006656


Current loss in epoch 4 is 0.016201091930270195: 100%|██████████| 23/23 [00:17<00:00,  1.34it/s]


Acc on val in epoch 4 is: 0.71630615640599


Current loss in epoch 5 is 0.0023507464211434126: 100%|██████████| 23/23 [00:17<00:00,  1.31it/s]


Acc on val in epoch 5 is: 0.7354409317803661


Current loss in epoch 6 is 0.0014131300849840045: 100%|██████████| 23/23 [00:17<00:00,  1.33it/s]


Acc on val in epoch 6 is: 0.7271214642262895


Current loss in epoch 7 is 0.0018411250784993172: 100%|██████████| 23/23 [00:17<00:00,  1.29it/s]


Acc on val in epoch 7 is: 0.7179700499168054


Current loss in epoch 8 is 0.005120887421071529: 100%|██████████| 23/23 [00:17<00:00,  1.31it/s]


Acc on val in epoch 8 is: 0.7171381031613977


Current loss in epoch 9 is 0.010517876595258713: 100%|██████████| 23/23 [00:17<00:00,  1.33it/s]


Acc on val in epoch 9 is: 0.781198003327787


Current loss in epoch 10 is 0.039786726236343384: 100%|██████████| 23/23 [00:17<00:00,  1.28it/s]


Acc on val in epoch 10 is: 0.7254575707154742


Current loss in epoch 11 is 0.0044457814656198025: 100%|██████████| 23/23 [00:17<00:00,  1.34it/s]


Acc on val in epoch 11 is: 0.7262895174708819


Current loss in epoch 12 is 0.019021809101104736: 100%|██████████| 23/23 [00:17<00:00,  1.31it/s]


Acc on val in epoch 12 is: 0.7171381031613977


Current loss in epoch 13 is 0.0013172096805647016: 100%|██████████| 23/23 [00:17<00:00,  1.33it/s]


Acc on val in epoch 13 is: 0.7662229617304492


Current loss in epoch 14 is 0.00031679010135121644: 100%|██████████| 23/23 [00:17<00:00,  1.32it/s]


Acc on val in epoch 14 is: 0.7237936772046589


Current loss in epoch 15 is 0.0008165273466147482: 100%|██████████| 23/23 [00:17<00:00,  1.33it/s]


Acc on val in epoch 15 is: 0.7287853577371048


Current loss in epoch 16 is 0.0006517958827316761: 100%|██████████| 23/23 [00:17<00:00,  1.31it/s]


Acc on val in epoch 16 is: 0.7279534109816972


Current loss in epoch 17 is 0.0002550899516791105: 100%|██████████| 23/23 [00:17<00:00,  1.31it/s]


Acc on val in epoch 17 is: 0.7204658901830283


Current loss in epoch 18 is 0.0002837718930095434: 100%|██████████| 23/23 [00:17<00:00,  1.33it/s]


Acc on val in epoch 18 is: 0.7204658901830283


Current loss in epoch 19 is 0.0007111789891496301: 100%|██████████| 23/23 [00:17<00:00,  1.29it/s]


Acc on val in epoch 19 is: 0.7221297836938436


In [None]:
import json
print(trial_to_acc)
save_dir = '/content/drive/MyDrive/logs'
os.makedirs(save_dir, exist_ok=True)

out_path = os.path.join(save_dir, 'noise_perturbation.json')
with open(out_path, 'w') as f:
    json.dump(trial_to_acc, f, indent=2)

{0: [0.9271746459878625, 0.9268374915711396, 0.9254888739042482, 0.8472690492245448, 0.7579231287929872, 0.6567768037761295, 0.6567768037761295], 1: [0.9197572488199596, 0.9190829399865138, 0.9153742414025624, 0.7801753202966959, 0.7107215104517869, 0.662508428860418, 0.6456507080242752], 2: [0.9146999325691166, 0.914025623735671, 0.9136884693189481, 0.8978422117329737, 0.8297370195549562, 0.46999325691166555, 0.4851652056641942]}


# Sensetivity Analysis Towards Hyperparameters
We evaluate the sensitivity towards hyperparameters. We execute every experiment 3 times and report the average

In [None]:
seeds = [0,0,0] #always use the same random seed for every run

## Kernel Size
We evaluate different kernel sizes  $[50,100,200]$

In [None]:
epochs = 20
lr = 1e-3
device = "cuda" if torch.cuda.is_available() else "cpu"

In [None]:
kernel_sizes = [25,50,200]
kernel_to_acc = {}
seeds = [0,0,0]
for kernel_size, seed in zip(kernel_sizes,seeds):
  set_seed(seed)
  model = CNN_1D(kernel_size=kernel_size)
  optimizer = optim.Adam(model.parameters(), lr=lr)
  criterion = nn.BCELoss()
  losses = train(model, criterion, optimizer, train_loader, val_loader ,epochs, device, return_loss=False)
  acc=eval(model, test_loader, device)
  print("Fianl acc on test is: ",acc )
  kernel_to_acc[kernel_size] = acc

Current loss in epoch 0 is 0.285277396440506: 100%|██████████| 23/23 [00:16<00:00,  1.37it/s]


Acc on val in epoch 0 is: 0.6522462562396006


Current loss in epoch 1 is 0.11314776539802551: 100%|██████████| 23/23 [00:16<00:00,  1.37it/s]


Acc on val in epoch 1 is: 0.6530782029950083


Current loss in epoch 2 is 0.04671460762619972: 100%|██████████| 23/23 [00:17<00:00,  1.32it/s]


Acc on val in epoch 2 is: 0.6547420965058236


Current loss in epoch 3 is 0.01095372624695301: 100%|██████████| 23/23 [00:16<00:00,  1.38it/s]


Acc on val in epoch 3 is: 0.7154742096505824


Current loss in epoch 4 is 0.006582955364137888: 100%|██████████| 23/23 [00:16<00:00,  1.36it/s]


Acc on val in epoch 4 is: 0.7346089850249584


Current loss in epoch 5 is 0.004983148537576199: 100%|██████████| 23/23 [00:17<00:00,  1.33it/s]


Acc on val in epoch 5 is: 0.7903494176372712


Current loss in epoch 6 is 0.00730926962569356: 100%|██████████| 23/23 [00:16<00:00,  1.37it/s]


Acc on val in epoch 6 is: 0.7695507487520798


Current loss in epoch 7 is 0.003046886995434761: 100%|██████████| 23/23 [00:17<00:00,  1.34it/s]


Acc on val in epoch 7 is: 0.7970049916805324


Current loss in epoch 8 is 0.0011102607240900397: 100%|██████████| 23/23 [00:16<00:00,  1.38it/s]


Acc on val in epoch 8 is: 0.8186356073211315


Current loss in epoch 9 is 0.004396279342472553: 100%|██████████| 23/23 [00:17<00:00,  1.34it/s]


Acc on val in epoch 9 is: 0.8435940099833611


Current loss in epoch 10 is 0.0013283052248880267: 100%|██████████| 23/23 [00:16<00:00,  1.38it/s]


Acc on val in epoch 10 is: 0.7354409317803661


Current loss in epoch 11 is 0.0010981113882735372: 100%|██████████| 23/23 [00:16<00:00,  1.37it/s]


Acc on val in epoch 11 is: 0.8577371048252912


Current loss in epoch 12 is 0.0028541921637952328: 100%|██████████| 23/23 [00:17<00:00,  1.35it/s]


Acc on val in epoch 12 is: 0.8128119800332779


Current loss in epoch 13 is 0.011072671972215176: 100%|██████████| 23/23 [00:16<00:00,  1.38it/s]


Acc on val in epoch 13 is: 0.8111480865224625


Current loss in epoch 14 is 0.012327490374445915: 100%|██████████| 23/23 [00:17<00:00,  1.33it/s]


Acc on val in epoch 14 is: 0.8402662229617305


Current loss in epoch 15 is 0.004778989125043154: 100%|██████████| 23/23 [00:17<00:00,  1.32it/s]


Acc on val in epoch 15 is: 0.7820299500831946


Current loss in epoch 16 is 0.001938458881340921: 100%|██████████| 23/23 [00:16<00:00,  1.36it/s]


Acc on val in epoch 16 is: 0.8211314475873545


Current loss in epoch 17 is 0.004275916609913111: 100%|██████████| 23/23 [00:16<00:00,  1.37it/s]


Acc on val in epoch 17 is: 0.7287853577371048


Current loss in epoch 18 is 0.0013328121276572347: 100%|██████████| 23/23 [00:16<00:00,  1.38it/s]


Acc on val in epoch 18 is: 0.6980033277870217


Current loss in epoch 19 is 0.0016551954904571176: 100%|██████████| 23/23 [00:17<00:00,  1.32it/s]


Acc on val in epoch 19 is: 0.7296173044925125
Fianl acc on test is:  0.9086311530681052


Current loss in epoch 0 is 0.2722465991973877: 100%|██████████| 23/23 [00:16<00:00,  1.36it/s]


Acc on val in epoch 0 is: 0.6530782029950083


Current loss in epoch 1 is 0.0361960269510746: 100%|██████████| 23/23 [00:17<00:00,  1.33it/s]


Acc on val in epoch 1 is: 0.6530782029950083


Current loss in epoch 2 is 0.05597604811191559: 100%|██████████| 23/23 [00:16<00:00,  1.36it/s]


Acc on val in epoch 2 is: 0.6597337770382695


Current loss in epoch 3 is 0.012563235126435757: 100%|██████████| 23/23 [00:17<00:00,  1.32it/s]


Acc on val in epoch 3 is: 0.721297836938436


Current loss in epoch 4 is 0.002993552479892969: 100%|██████████| 23/23 [00:17<00:00,  1.30it/s]


Acc on val in epoch 4 is: 0.7354409317803661


Current loss in epoch 5 is 0.0027533865068107843: 100%|██████████| 23/23 [00:16<00:00,  1.36it/s]


Acc on val in epoch 5 is: 0.7262895174708819


Current loss in epoch 6 is 0.005254688207060099: 100%|██████████| 23/23 [00:16<00:00,  1.35it/s]


Acc on val in epoch 6 is: 0.7537437603993344


Current loss in epoch 7 is 0.000913757539819926: 100%|██████████| 23/23 [00:16<00:00,  1.36it/s]


Acc on val in epoch 7 is: 0.8502495840266223


Current loss in epoch 8 is 0.004065567161887884: 100%|██████████| 23/23 [00:17<00:00,  1.31it/s]


Acc on val in epoch 8 is: 0.7504159733777038


Current loss in epoch 9 is 0.0005767935654148459: 100%|██████████| 23/23 [00:16<00:00,  1.36it/s]


Acc on val in epoch 9 is: 0.7462562396006656


Current loss in epoch 10 is 0.0005166514893062413: 100%|██████████| 23/23 [00:17<00:00,  1.34it/s]


Acc on val in epoch 10 is: 0.7354409317803661


Current loss in epoch 11 is 0.0003470609663054347: 100%|██████████| 23/23 [00:16<00:00,  1.36it/s]


Acc on val in epoch 11 is: 0.7371048252911814


Current loss in epoch 12 is 0.0007830221438780427: 100%|██████████| 23/23 [00:16<00:00,  1.36it/s]


Acc on val in epoch 12 is: 0.7371048252911814


Current loss in epoch 13 is 0.0004543823597487062: 100%|██████████| 23/23 [00:17<00:00,  1.32it/s]


Acc on val in epoch 13 is: 0.7321131447587355


Current loss in epoch 14 is 0.002776517765596509: 100%|██████████| 23/23 [00:17<00:00,  1.30it/s]


Acc on val in epoch 14 is: 0.762063227953411


Current loss in epoch 15 is 0.00047701687435619533: 100%|██████████| 23/23 [00:17<00:00,  1.33it/s]


Acc on val in epoch 15 is: 0.740432612312812


Current loss in epoch 16 is 0.0018600716721266508: 100%|██████████| 23/23 [00:16<00:00,  1.36it/s]


Acc on val in epoch 16 is: 0.7454242928452579


Current loss in epoch 17 is 0.00048589883954264224: 100%|██████████| 23/23 [00:17<00:00,  1.35it/s]


Acc on val in epoch 17 is: 0.7445923460898503


Current loss in epoch 18 is 0.0003300511743873358: 100%|██████████| 23/23 [00:17<00:00,  1.33it/s]


Acc on val in epoch 18 is: 0.7512479201331115


Current loss in epoch 19 is 0.00019852684636134654: 100%|██████████| 23/23 [00:16<00:00,  1.36it/s]


Acc on val in epoch 19 is: 0.740432612312812
Fianl acc on test is:  0.9265003371544167


Current loss in epoch 0 is 0.31284400820732117: 100%|██████████| 23/23 [00:18<00:00,  1.28it/s]


Acc on val in epoch 0 is: 0.6530782029950083


Current loss in epoch 1 is 0.14599554240703583: 100%|██████████| 23/23 [00:18<00:00,  1.26it/s]


Acc on val in epoch 1 is: 0.670549084858569


Current loss in epoch 2 is 0.04862445965409279: 100%|██████████| 23/23 [00:17<00:00,  1.28it/s]


Acc on val in epoch 2 is: 0.6788685524126455


Current loss in epoch 3 is 0.039830077439546585: 100%|██████████| 23/23 [00:18<00:00,  1.22it/s]


Acc on val in epoch 3 is: 0.6755407653910149


Current loss in epoch 4 is 0.022801276296377182: 100%|██████████| 23/23 [00:17<00:00,  1.28it/s]


Acc on val in epoch 4 is: 0.7054908485856906


Current loss in epoch 5 is 0.007334355264902115: 100%|██████████| 23/23 [00:18<00:00,  1.24it/s]


Acc on val in epoch 5 is: 0.8061564059900166


Current loss in epoch 6 is 0.002421768382191658: 100%|██████████| 23/23 [00:17<00:00,  1.28it/s]


Acc on val in epoch 6 is: 0.742928452579035


Current loss in epoch 7 is 0.006066690664738417: 100%|██████████| 23/23 [00:18<00:00,  1.26it/s]


Acc on val in epoch 7 is: 0.7337770382695508


Current loss in epoch 8 is 0.009301567450165749: 100%|██████████| 23/23 [00:17<00:00,  1.28it/s]


Acc on val in epoch 8 is: 0.7437603993344426


Current loss in epoch 9 is 0.0037967024836689234: 100%|██████████| 23/23 [00:18<00:00,  1.24it/s]


Acc on val in epoch 9 is: 0.8069883527454242


Current loss in epoch 10 is 0.0012351460754871368: 100%|██████████| 23/23 [00:17<00:00,  1.28it/s]


Acc on val in epoch 10 is: 0.7678868552412645


Current loss in epoch 11 is 0.0008467560983262956: 100%|██████████| 23/23 [00:18<00:00,  1.26it/s]


Acc on val in epoch 11 is: 0.7504159733777038


Current loss in epoch 12 is 0.003960408270359039: 100%|██████████| 23/23 [00:18<00:00,  1.24it/s]


Acc on val in epoch 12 is: 0.7920133111480865


Current loss in epoch 13 is 0.000559267878998071: 100%|██████████| 23/23 [00:18<00:00,  1.25it/s]


Acc on val in epoch 13 is: 0.762063227953411


Current loss in epoch 14 is 0.0003068592050112784: 100%|██████████| 23/23 [00:17<00:00,  1.28it/s]


Acc on val in epoch 14 is: 0.8336106489184693


Current loss in epoch 15 is 0.00047006766544654965: 100%|██████████| 23/23 [00:18<00:00,  1.26it/s]


Acc on val in epoch 15 is: 0.7579034941763727


Current loss in epoch 16 is 0.00029452776652760804: 100%|██████████| 23/23 [00:17<00:00,  1.28it/s]


Acc on val in epoch 16 is: 0.7512479201331115


Current loss in epoch 17 is 0.0003287572762928903: 100%|██████████| 23/23 [00:18<00:00,  1.25it/s]


Acc on val in epoch 17 is: 0.7928452579034941


Current loss in epoch 18 is 0.0003940437745768577: 100%|██████████| 23/23 [00:18<00:00,  1.27it/s]


Acc on val in epoch 18 is: 0.7321131447587355


Current loss in epoch 19 is 0.0001844937796704471: 100%|██████████| 23/23 [00:18<00:00,  1.25it/s]


Acc on val in epoch 19 is: 0.7695507487520798
Fianl acc on test is:  0.9369521240728254


In [None]:
import json
print(kernel_to_acc)
save_dir = '/content/drive/MyDrive/logs'
os.makedirs(save_dir, exist_ok=True)

out_path = os.path.join(save_dir, 'kernel_size.json')
with open(out_path, 'w') as f:
    json.dump(kernel_to_acc, f, indent=2)

{25: 0.9086311530681052, 50: 0.9265003371544167, 200: 0.9369521240728254}


## Stride
We evaluate different strides $[25,50,100]$

In [None]:
strides = [25,100,200]
stride_to_acc = {}
seeds = [0,0,0]
for stride, seed in zip(strides,seeds):
  set_seed(seed)
  model = CNN_1D(stride = stride)
  optimizer = optim.Adam(model.parameters(), lr=lr)
  criterion = nn.BCELoss()
  losses = train(model, criterion, optimizer, train_loader, val_loader ,epochs, device, return_loss=False)
  acc=eval(model, test_loader, device)
  print("Fianl acc on test is: ",acc )
  stride_to_acc[stride] = acc

Current loss in epoch 0 is 0.2615770995616913: 100%|██████████| 23/23 [00:22<00:00,  1.02it/s]


Acc on val in epoch 0 is: 0.6530782029950083


Current loss in epoch 1 is 0.0642012432217598: 100%|██████████| 23/23 [00:23<00:00,  1.01s/it]


Acc on val in epoch 1 is: 0.6530782029950083


Current loss in epoch 2 is 0.027889633551239967: 100%|██████████| 23/23 [00:22<00:00,  1.03it/s]


Acc on val in epoch 2 is: 0.7221297836938436


Current loss in epoch 3 is 0.007395664695650339: 100%|██████████| 23/23 [00:22<00:00,  1.03it/s]


Acc on val in epoch 3 is: 0.697171381031614


Current loss in epoch 4 is 0.002081871498376131: 100%|██████████| 23/23 [00:22<00:00,  1.01it/s]


Acc on val in epoch 4 is: 0.6880199667221298


Current loss in epoch 5 is 0.0022306323517113924: 100%|██████████| 23/23 [00:22<00:00,  1.01it/s]


Acc on val in epoch 5 is: 0.7004991680532446


Current loss in epoch 6 is 0.0016594124026596546: 100%|██████████| 23/23 [00:22<00:00,  1.02it/s]


Acc on val in epoch 6 is: 0.6805324459234608


Current loss in epoch 7 is 0.001728530740365386: 100%|██████████| 23/23 [00:22<00:00,  1.01it/s]


Acc on val in epoch 7 is: 0.6697171381031614


Current loss in epoch 8 is 0.0007776033016853034: 100%|██████████| 23/23 [00:23<00:00,  1.02s/it]


Acc on val in epoch 8 is: 0.6763727121464226


Current loss in epoch 9 is 0.0007106707198545337: 100%|██████████| 23/23 [00:22<00:00,  1.03it/s]


Acc on val in epoch 9 is: 0.6880199667221298


Current loss in epoch 10 is 0.000515115330927074: 100%|██████████| 23/23 [00:22<00:00,  1.03it/s]


Acc on val in epoch 10 is: 0.6913477537437605


Current loss in epoch 11 is 0.00044172746129333973: 100%|██████████| 23/23 [00:22<00:00,  1.01it/s]


Acc on val in epoch 11 is: 0.6755407653910149


Current loss in epoch 12 is 0.0006389443879015744: 100%|██████████| 23/23 [00:22<00:00,  1.02it/s]


Acc on val in epoch 12 is: 0.6821963394342762


Current loss in epoch 13 is 0.0008271566475741565: 100%|██████████| 23/23 [00:22<00:00,  1.02it/s]


Acc on val in epoch 13 is: 0.6797004991680532


Current loss in epoch 14 is 0.0005115771200507879: 100%|██████████| 23/23 [00:22<00:00,  1.01it/s]


Acc on val in epoch 14 is: 0.6830282861896838


Current loss in epoch 15 is 0.0006475409027189016: 100%|██████████| 23/23 [00:22<00:00,  1.01it/s]


Acc on val in epoch 15 is: 0.6913477537437605


Current loss in epoch 16 is 0.0001809743553167209: 100%|██████████| 23/23 [00:22<00:00,  1.01it/s]


Acc on val in epoch 16 is: 0.6821963394342762


Current loss in epoch 17 is 0.00020174095698166639: 100%|██████████| 23/23 [00:22<00:00,  1.03it/s]


Acc on val in epoch 17 is: 0.6938435940099834


Current loss in epoch 18 is 0.0013993592001497746: 100%|██████████| 23/23 [00:22<00:00,  1.01it/s]


Acc on val in epoch 18 is: 0.6913477537437605


Current loss in epoch 19 is 0.0002052297058980912: 100%|██████████| 23/23 [00:22<00:00,  1.02it/s]


Acc on val in epoch 19 is: 0.6913477537437605
Fianl acc on test is:  0.897167902899528


Current loss in epoch 0 is 0.36496517062187195: 100%|██████████| 23/23 [00:14<00:00,  1.57it/s]


Acc on val in epoch 0 is: 0.6530782029950083


Current loss in epoch 1 is 0.12065209448337555: 100%|██████████| 23/23 [00:14<00:00,  1.56it/s]


Acc on val in epoch 1 is: 0.6530782029950083


Current loss in epoch 2 is 0.033669427037239075: 100%|██████████| 23/23 [00:14<00:00,  1.56it/s]


Acc on val in epoch 2 is: 0.7029950083194676


Current loss in epoch 3 is 0.015550722368061543: 100%|██████████| 23/23 [00:14<00:00,  1.55it/s]


Acc on val in epoch 3 is: 0.8202995008319468


Current loss in epoch 4 is 0.01356946025043726: 100%|██████████| 23/23 [00:14<00:00,  1.54it/s]


Acc on val in epoch 4 is: 0.7121464226289518


Current loss in epoch 5 is 0.010781428776681423: 100%|██████████| 23/23 [00:15<00:00,  1.52it/s]


Acc on val in epoch 5 is: 0.6605657237936772


Current loss in epoch 6 is 0.007762079127132893: 100%|██████████| 23/23 [00:14<00:00,  1.57it/s]


Acc on val in epoch 6 is: 0.6946755407653911


Current loss in epoch 7 is 0.006634601391851902: 100%|██████████| 23/23 [00:14<00:00,  1.56it/s]


Acc on val in epoch 7 is: 0.7071547420965059


Current loss in epoch 8 is 0.011425206437706947: 100%|██████████| 23/23 [00:15<00:00,  1.53it/s]


Acc on val in epoch 8 is: 0.7612312811980033


Current loss in epoch 9 is 0.009706802666187286: 100%|██████████| 23/23 [00:14<00:00,  1.57it/s]


Acc on val in epoch 9 is: 0.7870216306156406


Current loss in epoch 10 is 0.000983488280326128: 100%|██████████| 23/23 [00:14<00:00,  1.56it/s]


Acc on val in epoch 10 is: 0.7104825291181365


Current loss in epoch 11 is 0.0010055230231955647: 100%|██████████| 23/23 [00:14<00:00,  1.57it/s]


Acc on val in epoch 11 is: 0.7229617304492513


Current loss in epoch 12 is 0.0005215037381276488: 100%|██████████| 23/23 [00:15<00:00,  1.51it/s]


Acc on val in epoch 12 is: 0.6855241264559068


Current loss in epoch 13 is 0.0020373936276882887: 100%|██████████| 23/23 [00:14<00:00,  1.56it/s]


Acc on val in epoch 13 is: 0.740432612312812


Current loss in epoch 14 is 0.002047174144536257: 100%|██████████| 23/23 [00:16<00:00,  1.37it/s]


Acc on val in epoch 14 is: 0.7271214642262895


Current loss in epoch 15 is 0.0012425773311406374: 100%|██████████| 23/23 [00:17<00:00,  1.33it/s]


Acc on val in epoch 15 is: 0.6988352745424293


Current loss in epoch 16 is 0.000568256713449955: 100%|██████████| 23/23 [00:15<00:00,  1.49it/s]


Acc on val in epoch 16 is: 0.6930116472545758


Current loss in epoch 17 is 0.00017893826588988304: 100%|██████████| 23/23 [00:14<00:00,  1.56it/s]


Acc on val in epoch 17 is: 0.6980033277870217


Current loss in epoch 18 is 0.0009913257090374827: 100%|██████████| 23/23 [00:17<00:00,  1.31it/s]


Acc on val in epoch 18 is: 0.7046589018302829


Current loss in epoch 19 is 0.0004225285956636071: 100%|██████████| 23/23 [00:16<00:00,  1.43it/s]


Acc on val in epoch 19 is: 0.6880199667221298
Fianl acc on test is:  0.9416722859069454


Current loss in epoch 0 is 0.4813193678855896: 100%|██████████| 23/23 [00:14<00:00,  1.60it/s]


Acc on val in epoch 0 is: 0.6530782029950083


Current loss in epoch 1 is 0.19951847195625305: 100%|██████████| 23/23 [00:13<00:00,  1.67it/s]


Acc on val in epoch 1 is: 0.7054908485856906


Current loss in epoch 2 is 0.1580645889043808: 100%|██████████| 23/23 [00:13<00:00,  1.65it/s]


Acc on val in epoch 2 is: 0.8435940099833611


Current loss in epoch 3 is 0.10229385644197464: 100%|██████████| 23/23 [00:13<00:00,  1.67it/s]


Acc on val in epoch 3 is: 0.8752079866888519


Current loss in epoch 4 is 0.039562374353408813: 100%|██████████| 23/23 [00:13<00:00,  1.68it/s]


Acc on val in epoch 4 is: 0.8128119800332779


Current loss in epoch 5 is 0.06025538593530655: 100%|██████████| 23/23 [00:14<00:00,  1.63it/s]


Acc on val in epoch 5 is: 0.8494176372712147


Current loss in epoch 6 is 0.026563892140984535: 100%|██████████| 23/23 [00:14<00:00,  1.54it/s]


Acc on val in epoch 6 is: 0.8136439267886856


Current loss in epoch 7 is 0.02018957957625389: 100%|██████████| 23/23 [00:14<00:00,  1.64it/s]


Acc on val in epoch 7 is: 0.8502495840266223


Current loss in epoch 8 is 0.02385101653635502: 100%|██████████| 23/23 [00:13<00:00,  1.65it/s]


Acc on val in epoch 8 is: 0.802828618968386


Current loss in epoch 9 is 0.021404588595032692: 100%|██████████| 23/23 [00:13<00:00,  1.67it/s]


Acc on val in epoch 9 is: 0.8111480865224625


Current loss in epoch 10 is 0.027701333165168762: 100%|██████████| 23/23 [00:13<00:00,  1.68it/s]


Acc on val in epoch 10 is: 0.7970049916805324


Current loss in epoch 11 is 0.004371295683085918: 100%|██████████| 23/23 [00:13<00:00,  1.68it/s]


Acc on val in epoch 11 is: 0.8452579034941764


Current loss in epoch 12 is 0.011026022024452686: 100%|██████████| 23/23 [00:13<00:00,  1.68it/s]


Acc on val in epoch 12 is: 0.7662229617304492


Current loss in epoch 13 is 0.03512285649776459: 100%|██████████| 23/23 [00:13<00:00,  1.68it/s]


Acc on val in epoch 13 is: 0.7895174708818635


Current loss in epoch 14 is 0.0059165265411138535: 100%|██████████| 23/23 [00:13<00:00,  1.65it/s]


Acc on val in epoch 14 is: 0.8419301164725458


Current loss in epoch 15 is 0.01606622524559498: 100%|██████████| 23/23 [00:13<00:00,  1.68it/s]


Acc on val in epoch 15 is: 0.8227953410981698


Current loss in epoch 16 is 0.003681400790810585: 100%|██████████| 23/23 [00:13<00:00,  1.68it/s]


Acc on val in epoch 16 is: 0.8352745424292846


Current loss in epoch 17 is 0.001791688147932291: 100%|██████████| 23/23 [00:13<00:00,  1.68it/s]


Acc on val in epoch 17 is: 0.8352745424292846


Current loss in epoch 18 is 0.004027295392006636: 100%|██████████| 23/23 [00:14<00:00,  1.62it/s]


Acc on val in epoch 18 is: 0.8186356073211315


Current loss in epoch 19 is 0.005929342471063137: 100%|██████████| 23/23 [00:13<00:00,  1.67it/s]


Acc on val in epoch 19 is: 0.8153078202995009
Fianl acc on test is:  0.914025623735671


In [None]:
import json
print(stride_to_acc)
save_dir = '/content/drive/MyDrive/logs'
os.makedirs(save_dir, exist_ok=True)

out_path = os.path.join(save_dir, 'strides.json')
with open(out_path, 'w') as f:
    json.dump(stride_to_acc, f, indent=2)

{25: 0.897167902899528, 100: 0.9416722859069454, 200: 0.914025623735671}


## Attention Heads
We evaluate different amount of attention heads $[2,4,8]$

In [None]:
attention_heads = [1,2,10]
head_to_acc = {}
seeds = [0,0,0]
for heads,seed in zip(attention_heads,seeds):
  set_seed(seed)
  model = CNN_1D(attn_heads=heads)
  optimizer = optim.Adam(model.parameters(), lr=lr)
  criterion = nn.BCELoss()
  losses = train(model, criterion, optimizer, train_loader, val_loader ,epochs, device, return_loss=False)
  acc=eval(model, test_loader, device)
  print("Fianl acc on test is: ",acc )
  head_to_acc[heads] = acc

Current loss in epoch 0 is 0.3440678119659424: 100%|██████████| 23/23 [00:17<00:00,  1.32it/s]


Acc on val in epoch 0 is: 0.6530782029950083


Current loss in epoch 1 is 0.11173176765441895: 100%|██████████| 23/23 [00:16<00:00,  1.39it/s]


Acc on val in epoch 1 is: 0.653910149750416


Current loss in epoch 2 is 0.04350460693240166: 100%|██████████| 23/23 [00:15<00:00,  1.45it/s]


Acc on val in epoch 2 is: 0.697171381031614


Current loss in epoch 3 is 0.015563981607556343: 100%|██████████| 23/23 [00:15<00:00,  1.44it/s]


Acc on val in epoch 3 is: 0.7021630615640599


Current loss in epoch 4 is 0.0032094609923660755: 100%|██████████| 23/23 [00:16<00:00,  1.36it/s]


Acc on val in epoch 4 is: 0.7287853577371048


Current loss in epoch 5 is 0.0038007108960300684: 100%|██████████| 23/23 [00:15<00:00,  1.45it/s]


Acc on val in epoch 5 is: 0.7678868552412645


Current loss in epoch 6 is 0.005219881888478994: 100%|██████████| 23/23 [00:16<00:00,  1.43it/s]


Acc on val in epoch 6 is: 0.7420965058236273


Current loss in epoch 7 is 0.002107198117300868: 100%|██████████| 23/23 [00:16<00:00,  1.43it/s]


Acc on val in epoch 7 is: 0.8044925124792013


Current loss in epoch 8 is 0.001911291852593422: 100%|██████████| 23/23 [00:15<00:00,  1.46it/s]


Acc on val in epoch 8 is: 0.697171381031614


Current loss in epoch 9 is 0.0024983584880828857: 100%|██████████| 23/23 [00:16<00:00,  1.38it/s]


Acc on val in epoch 9 is: 0.6896838602329451


Current loss in epoch 10 is 0.0015707690035924315: 100%|██████████| 23/23 [00:15<00:00,  1.46it/s]


Acc on val in epoch 10 is: 0.7004991680532446


Current loss in epoch 11 is 0.00046750972978770733: 100%|██████████| 23/23 [00:15<00:00,  1.46it/s]


Acc on val in epoch 11 is: 0.7171381031613977


Current loss in epoch 12 is 0.0004043747903779149: 100%|██████████| 23/23 [00:16<00:00,  1.44it/s]


Acc on val in epoch 12 is: 0.7312811980033278


Current loss in epoch 13 is 0.0005742821376770735: 100%|██████████| 23/23 [00:15<00:00,  1.44it/s]


Acc on val in epoch 13 is: 0.7653910149750416


Current loss in epoch 14 is 0.0003803216095548123: 100%|██████████| 23/23 [00:15<00:00,  1.45it/s]


Acc on val in epoch 14 is: 0.7579034941763727


Current loss in epoch 15 is 0.00044977484503760934: 100%|██████████| 23/23 [00:15<00:00,  1.45it/s]


Acc on val in epoch 15 is: 0.7712146422628952


Current loss in epoch 16 is 0.00017851062875706702: 100%|██████████| 23/23 [00:15<00:00,  1.44it/s]


Acc on val in epoch 16 is: 0.7412645590682196


Current loss in epoch 17 is 0.00017923199629876763: 100%|██████████| 23/23 [00:15<00:00,  1.46it/s]


Acc on val in epoch 17 is: 0.7312811980033278


Current loss in epoch 18 is 0.000854405399877578: 100%|██████████| 23/23 [00:15<00:00,  1.46it/s]


Acc on val in epoch 18 is: 0.7454242928452579


Current loss in epoch 19 is 0.00013067018880974501: 100%|██████████| 23/23 [00:16<00:00,  1.43it/s]


Acc on val in epoch 19 is: 0.7321131447587355
Fianl acc on test is:  0.928186109238031


Current loss in epoch 0 is 0.2936450242996216: 100%|██████████| 23/23 [00:16<00:00,  1.41it/s]


Acc on val in epoch 0 is: 0.6530782029950083


Current loss in epoch 1 is 0.07999500632286072: 100%|██████████| 23/23 [00:16<00:00,  1.38it/s]


Acc on val in epoch 1 is: 0.6530782029950083


Current loss in epoch 2 is 0.021914461627602577: 100%|██████████| 23/23 [00:16<00:00,  1.41it/s]


Acc on val in epoch 2 is: 0.7354409317803661


Current loss in epoch 3 is 0.007459661923348904: 100%|██████████| 23/23 [00:16<00:00,  1.40it/s]


Acc on val in epoch 3 is: 0.7512479201331115


Current loss in epoch 4 is 0.0060965861193835735: 100%|██████████| 23/23 [00:16<00:00,  1.39it/s]


Acc on val in epoch 4 is: 0.781198003327787


Current loss in epoch 5 is 0.0062764654867351055: 100%|██████████| 23/23 [00:16<00:00,  1.41it/s]


Acc on val in epoch 5 is: 0.7653910149750416


Current loss in epoch 6 is 0.003917504567652941: 100%|██████████| 23/23 [00:16<00:00,  1.36it/s]


Acc on val in epoch 6 is: 0.6988352745424293


Current loss in epoch 7 is 0.0019382293103262782: 100%|██████████| 23/23 [00:16<00:00,  1.42it/s]


Acc on val in epoch 7 is: 0.8410981697171381


Current loss in epoch 8 is 0.0010780591983348131: 100%|██████████| 23/23 [00:16<00:00,  1.42it/s]


Acc on val in epoch 8 is: 0.7237936772046589


Current loss in epoch 9 is 0.000613441807217896: 100%|██████████| 23/23 [00:16<00:00,  1.36it/s]


Acc on val in epoch 9 is: 0.8169717138103162


Current loss in epoch 10 is 0.0005306230741553009: 100%|██████████| 23/23 [00:16<00:00,  1.36it/s]


Acc on val in epoch 10 is: 0.7237936772046589


Current loss in epoch 11 is 0.0003077685250900686: 100%|██████████| 23/23 [00:16<00:00,  1.40it/s]


Acc on val in epoch 11 is: 0.740432612312812


Current loss in epoch 12 is 0.0004090499714948237: 100%|██████████| 23/23 [00:16<00:00,  1.41it/s]


Acc on val in epoch 12 is: 0.7312811980033278


Current loss in epoch 13 is 0.0005303589277900755: 100%|██████████| 23/23 [00:16<00:00,  1.42it/s]


Acc on val in epoch 13 is: 0.7420965058236273


Current loss in epoch 14 is 0.0004623527347575873: 100%|██████████| 23/23 [00:16<00:00,  1.38it/s]


Acc on val in epoch 14 is: 0.7670549084858569


Current loss in epoch 15 is 0.000686073035467416: 100%|██████████| 23/23 [00:16<00:00,  1.42it/s]


Acc on val in epoch 15 is: 0.7712146422628952


Current loss in epoch 16 is 0.00023685467022005469: 100%|██████████| 23/23 [00:16<00:00,  1.42it/s]


Acc on val in epoch 16 is: 0.7387687188019967


Current loss in epoch 17 is 0.0002370279689785093: 100%|██████████| 23/23 [00:16<00:00,  1.37it/s]


Acc on val in epoch 17 is: 0.7354409317803661


Current loss in epoch 18 is 0.001152569311670959: 100%|██████████| 23/23 [00:16<00:00,  1.42it/s]


Acc on val in epoch 18 is: 0.7562396006655574


Current loss in epoch 19 is 0.00032759379246272147: 100%|██████████| 23/23 [00:16<00:00,  1.41it/s]


Acc on val in epoch 19 is: 0.7371048252911814
Fianl acc on test is:  0.928186109238031


Current loss in epoch 0 is 0.6317203044891357:  39%|███▉      | 9/23 [00:09<00:14,  1.05s/it]


KeyboardInterrupt: 

In [None]:
import json
print(head_to_acc)
save_dir = '/content/drive/MyDrive/logs'
os.makedirs(save_dir, exist_ok=True)

out_path = os.path.join(save_dir, 'attention_heads.json')
with open(out_path, 'w') as f:
    json.dump(head_to_acc, f, indent=2)

{10: 0.9204315576534052, 1: 0.928186109238031, 2: 0.928186109238831}


## Filter Size
We investigate different filter sizes $[10,20,40]$

In [None]:
filter_sizes = [4,12,40]
filter_to_acc = {}
seeds = [0,0,0]
for filter,seed in zip(filter_sizes,seeds):
  set_seed(seed)
  model = CNN_1D(filters=filter)
  optimizer = optim.Adam(model.parameters(), lr=lr)
  criterion = nn.BCELoss()
  losses = train(model, criterion, optimizer, train_loader, val_loader ,epochs, device, return_loss=False)
  acc=eval(model, test_loader, device)
  print("Fianl acc on test is: ",acc )
  filter_to_acc[filter] = acc

Current loss in epoch 0 is 0.6357710957527161: 100%|██████████| 23/23 [00:13<00:00,  1.72it/s]


Acc on val in epoch 0 is: 0.6530782029950083


Current loss in epoch 1 is 0.4964757561683655: 100%|██████████| 23/23 [00:13<00:00,  1.75it/s]


Acc on val in epoch 1 is: 0.6530782029950083


Current loss in epoch 2 is 0.4929906725883484: 100%|██████████| 23/23 [00:13<00:00,  1.74it/s]


Acc on val in epoch 2 is: 0.6605657237936772


Current loss in epoch 3 is 0.30943265557289124: 100%|██████████| 23/23 [00:13<00:00,  1.73it/s]


Acc on val in epoch 3 is: 0.6447587354409318


Current loss in epoch 4 is 0.19491508603096008: 100%|██████████| 23/23 [00:13<00:00,  1.74it/s]


Acc on val in epoch 4 is: 0.6896838602329451


Current loss in epoch 5 is 0.16057467460632324: 100%|██████████| 23/23 [00:13<00:00,  1.75it/s]


Acc on val in epoch 5 is: 0.7021630615640599


Current loss in epoch 6 is 0.13362939655780792: 100%|██████████| 23/23 [00:13<00:00,  1.75it/s]


Acc on val in epoch 6 is: 0.6896838602329451


Current loss in epoch 7 is 0.043126605451107025: 100%|██████████| 23/23 [00:13<00:00,  1.75it/s]


Acc on val in epoch 7 is: 0.7271214642262895


Current loss in epoch 8 is 0.028560791164636612: 100%|██████████| 23/23 [00:13<00:00,  1.76it/s]


Acc on val in epoch 8 is: 0.6888519134775375


Current loss in epoch 9 is 0.03880002722144127: 100%|██████████| 23/23 [00:13<00:00,  1.74it/s]


Acc on val in epoch 9 is: 0.6622296173044925


Current loss in epoch 10 is 0.02018074505031109: 100%|██████████| 23/23 [00:13<00:00,  1.73it/s]


Acc on val in epoch 10 is: 0.6946755407653911


Current loss in epoch 11 is 0.012879421003162861: 100%|██████████| 23/23 [00:13<00:00,  1.75it/s]


Acc on val in epoch 11 is: 0.6564059900166389


Current loss in epoch 12 is 0.02564888820052147: 100%|██████████| 23/23 [00:14<00:00,  1.63it/s]


Acc on val in epoch 12 is: 0.6697171381031614


Current loss in epoch 13 is 0.021129729226231575: 100%|██████████| 23/23 [00:13<00:00,  1.74it/s]


Acc on val in epoch 13 is: 0.6530782029950083


Current loss in epoch 14 is 0.015136078000068665: 100%|██████████| 23/23 [00:13<00:00,  1.75it/s]


Acc on val in epoch 14 is: 0.7046589018302829


Current loss in epoch 15 is 0.007713708095252514: 100%|██████████| 23/23 [00:13<00:00,  1.75it/s]


Acc on val in epoch 15 is: 0.6921797004991681


Current loss in epoch 16 is 0.025841359049081802: 100%|██████████| 23/23 [00:13<00:00,  1.75it/s]


Acc on val in epoch 16 is: 0.6905158069883528


Current loss in epoch 17 is 0.04530082270503044: 100%|██████████| 23/23 [00:13<00:00,  1.74it/s]


Acc on val in epoch 17 is: 0.6788685524126455


Current loss in epoch 18 is 0.034756850451231: 100%|██████████| 23/23 [00:13<00:00,  1.75it/s]


Acc on val in epoch 18 is: 0.7129783693843594


Current loss in epoch 19 is 0.02776520699262619: 100%|██████████| 23/23 [00:13<00:00,  1.74it/s]


Acc on val in epoch 19 is: 0.7121464226289518
Fianl acc on test is:  0.899527983816588


Current loss in epoch 0 is 0.45733365416526794: 100%|██████████| 23/23 [00:15<00:00,  1.50it/s]


Acc on val in epoch 0 is: 0.6530782029950083


Current loss in epoch 1 is 0.17012158036231995: 100%|██████████| 23/23 [00:14<00:00,  1.53it/s]


Acc on val in epoch 1 is: 0.6530782029950083


Current loss in epoch 2 is 0.07485511898994446: 100%|██████████| 23/23 [00:14<00:00,  1.54it/s]


Acc on val in epoch 2 is: 0.8427620632279534


Current loss in epoch 3 is 0.052813589572906494: 100%|██████████| 23/23 [00:15<00:00,  1.53it/s]


Acc on val in epoch 3 is: 0.7271214642262895


Current loss in epoch 4 is 0.03208640217781067: 100%|██████████| 23/23 [00:15<00:00,  1.46it/s]


Acc on val in epoch 4 is: 0.7770382695507487


Current loss in epoch 5 is 0.012048076838254929: 100%|██████████| 23/23 [00:15<00:00,  1.53it/s]


Acc on val in epoch 5 is: 0.7287853577371048


Current loss in epoch 6 is 0.020850766450166702: 100%|██████████| 23/23 [00:15<00:00,  1.52it/s]


Acc on val in epoch 6 is: 0.7329450915141431


Current loss in epoch 7 is 0.0033205158542841673: 100%|██████████| 23/23 [00:15<00:00,  1.49it/s]


Acc on val in epoch 7 is: 0.7312811980033278


Current loss in epoch 8 is 0.00545938964933157: 100%|██████████| 23/23 [00:15<00:00,  1.53it/s]


Acc on val in epoch 8 is: 0.7312811980033278


Current loss in epoch 9 is 0.003315345384180546: 100%|██████████| 23/23 [00:15<00:00,  1.53it/s]


Acc on val in epoch 9 is: 0.7470881863560732


Current loss in epoch 10 is 0.003214108757674694: 100%|██████████| 23/23 [00:15<00:00,  1.52it/s]


Acc on val in epoch 10 is: 0.7204658901830283


Current loss in epoch 11 is 0.009133346378803253: 100%|██████████| 23/23 [00:15<00:00,  1.48it/s]


Acc on val in epoch 11 is: 0.7354409317803661


Current loss in epoch 12 is 0.009414403699338436: 100%|██████████| 23/23 [00:14<00:00,  1.54it/s]


Acc on val in epoch 12 is: 0.7121464226289518


Current loss in epoch 13 is 0.009670902974903584: 100%|██████████| 23/23 [00:14<00:00,  1.54it/s]


Acc on val in epoch 13 is: 0.7246256239600666


Current loss in epoch 14 is 0.002175457077100873: 100%|██████████| 23/23 [00:15<00:00,  1.53it/s]


Acc on val in epoch 14 is: 0.7387687188019967


Current loss in epoch 15 is 0.0006408029585145414: 100%|██████████| 23/23 [00:15<00:00,  1.47it/s]


Acc on val in epoch 15 is: 0.8286189683860233


Current loss in epoch 16 is 0.0011359284399077296: 100%|██████████| 23/23 [00:15<00:00,  1.46it/s]


Acc on val in epoch 16 is: 0.7579034941763727


Current loss in epoch 17 is 0.0031511769630014896: 100%|██████████| 23/23 [00:15<00:00,  1.52it/s]


Acc on val in epoch 17 is: 0.7262895174708819


Current loss in epoch 18 is 0.002833228325471282: 100%|██████████| 23/23 [00:15<00:00,  1.51it/s]


Acc on val in epoch 18 is: 0.7088186356073212


Current loss in epoch 19 is 0.004000222776085138: 100%|██████████| 23/23 [00:15<00:00,  1.52it/s]


Acc on val in epoch 19 is: 0.7063227953410982
Fianl acc on test is:  0.9359406608226568


Current loss in epoch 0 is 0.15395379066467285: 100%|██████████| 23/23 [00:23<00:00,  1.02s/it]


Acc on val in epoch 0 is: 0.6522462562396006


Current loss in epoch 1 is 0.03555653244256973: 100%|██████████| 23/23 [00:23<00:00,  1.02s/it]


Acc on val in epoch 1 is: 0.6505823627287853


Current loss in epoch 2 is 0.029251182451844215: 100%|██████████| 23/23 [00:23<00:00,  1.01s/it]


Acc on val in epoch 2 is: 0.8111480865224625


Current loss in epoch 3 is 0.011541858315467834: 100%|██████████| 23/23 [00:23<00:00,  1.02s/it]


Acc on val in epoch 3 is: 0.7312811980033278


Current loss in epoch 4 is 0.011085176840424538: 100%|██████████| 23/23 [00:24<00:00,  1.04s/it]


Acc on val in epoch 4 is: 0.7354409317803661


Current loss in epoch 5 is 0.02316027507185936: 100%|██████████| 23/23 [00:23<00:00,  1.01s/it]


Acc on val in epoch 5 is: 0.781198003327787


Current loss in epoch 6 is 0.0168097335845232: 100%|██████████| 23/23 [00:23<00:00,  1.00s/it]


Acc on val in epoch 6 is: 0.8327787021630616


Current loss in epoch 7 is 0.01727220229804516: 100%|██████████| 23/23 [00:23<00:00,  1.01s/it]


Acc on val in epoch 7 is: 0.7237936772046589


Current loss in epoch 8 is 0.00043371482752263546: 100%|██████████| 23/23 [00:23<00:00,  1.00s/it]


Acc on val in epoch 8 is: 0.6988352745424293


Current loss in epoch 9 is 0.001255993265658617: 100%|██████████| 23/23 [00:23<00:00,  1.00s/it]


Acc on val in epoch 9 is: 0.718801996672213


Current loss in epoch 10 is 0.0003631789004430175: 100%|██████████| 23/23 [00:23<00:00,  1.02s/it]


Acc on val in epoch 10 is: 0.7287853577371048


Current loss in epoch 11 is 0.0012931539677083492: 100%|██████████| 23/23 [00:23<00:00,  1.02s/it]


Acc on val in epoch 11 is: 0.7271214642262895


Current loss in epoch 12 is 0.00015081548190210015: 100%|██████████| 23/23 [00:23<00:00,  1.03s/it]


Acc on val in epoch 12 is: 0.71630615640599


Current loss in epoch 13 is 0.00023561742273159325: 100%|██████████| 23/23 [00:23<00:00,  1.02s/it]


Acc on val in epoch 13 is: 0.7154742096505824


Current loss in epoch 14 is 0.0001658176479395479: 100%|██████████| 23/23 [00:23<00:00,  1.02s/it]


Acc on val in epoch 14 is: 0.718801996672213


Current loss in epoch 15 is 0.00012074138066964224: 100%|██████████| 23/23 [00:23<00:00,  1.01s/it]


Acc on val in epoch 15 is: 0.718801996672213


Current loss in epoch 16 is 0.00023554930521640927: 100%|██████████| 23/23 [00:23<00:00,  1.02s/it]


Acc on val in epoch 16 is: 0.7154742096505824


Current loss in epoch 17 is 9.694276377558708e-05: 100%|██████████| 23/23 [00:23<00:00,  1.02s/it]


Acc on val in epoch 17 is: 0.7146422628951747


Current loss in epoch 18 is 0.0005030892789363861: 100%|██████████| 23/23 [00:23<00:00,  1.01s/it]


Acc on val in epoch 18 is: 0.7154742096505824


Current loss in epoch 19 is 7.990912126842886e-05: 100%|██████████| 23/23 [00:23<00:00,  1.03s/it]


Acc on val in epoch 19 is: 0.718801996672213
Fianl acc on test is:  0.9221173297370195


In [None]:
import json
print(filter_to_acc)
save_dir = '/content/drive/MyDrive/logs'
os.makedirs(save_dir, exist_ok=True)

out_path = os.path.join(save_dir, 'filters.json')
with open(out_path, 'w') as f:
    json.dump(filter_to_acc, f, indent=2)

{4: 0.899527983816588, 12: 0.9359406608226568, 40: 0.9221173297370195}


## 3-Lead Ablation

In [8]:
train_loader, val_loader, test_loader = get_dataloaders("/content/drive/MyDrive/ptbdb",save_path="/content/drive/MyDrive/ptbdb/preprocessed_data_3leads.pt", desired_leads=['i','ii','iii'],train_ratio=0.6, val_ratio=0.1)

After filtering, we got: 228 records. Healthy: 80, Disease: 148
Patients: train: 120 | val: 20 | test: 60
No data path given -> create dataset
Data saved at:  /content/drive/MyDrive/ptbdb/preprocessed_data_3leads.pt


In [16]:
trial_to_acc = {}
seeds = [0,1,2]
trials = 3
lr = 0.001
epochs = 20
device = "cuda" if torch.cuda.is_available() else "cpu"
for trial,seed in zip(range(trials),seeds):
  set_seed(seed)
  model = CNN_1D(num_leads = 3)
  optimizer = optim.Adam(model.parameters(), lr=lr)
  criterion = nn.BCELoss()
  losses = train(model, criterion, optimizer, train_loader, val_loader ,epochs, device)
  accs = eval(model, test_loader, device=device)
  trial_to_acc[trial] = accs

Current loss in epoch 0 is 0.21896103024482727: 100%|██████████| 23/23 [00:01<00:00, 11.90it/s]


Acc on val in epoch 0 is: 0.6530782029950083


Current loss in epoch 1 is 0.20157793164253235: 100%|██████████| 23/23 [00:01<00:00, 12.21it/s]


Acc on val in epoch 1 is: 0.6530782029950083


Current loss in epoch 2 is 0.06580746173858643: 100%|██████████| 23/23 [00:01<00:00, 12.24it/s]


Acc on val in epoch 2 is: 0.6713810316139767


Current loss in epoch 3 is 0.04865026846528053: 100%|██████████| 23/23 [00:01<00:00, 12.20it/s]


Acc on val in epoch 3 is: 0.7079866888519135


Current loss in epoch 4 is 0.02782403863966465: 100%|██████████| 23/23 [00:01<00:00, 12.21it/s]


Acc on val in epoch 4 is: 0.7246256239600666


Current loss in epoch 5 is 0.013031884096562862: 100%|██████████| 23/23 [00:01<00:00, 12.17it/s]


Acc on val in epoch 5 is: 0.7562396006655574


Current loss in epoch 6 is 0.007594364229589701: 100%|██████████| 23/23 [00:01<00:00, 12.25it/s]


Acc on val in epoch 6 is: 0.8785357737104825


Current loss in epoch 7 is 0.007080988958477974: 100%|██████████| 23/23 [00:01<00:00, 12.14it/s]


Acc on val in epoch 7 is: 0.8219633943427621


Current loss in epoch 8 is 0.009694814682006836: 100%|██████████| 23/23 [00:01<00:00, 12.12it/s]


Acc on val in epoch 8 is: 0.8527454242928453


Current loss in epoch 9 is 0.004365976434201002: 100%|██████████| 23/23 [00:01<00:00, 12.07it/s]


Acc on val in epoch 9 is: 0.8469217970049917


Current loss in epoch 10 is 0.0036120370496064425: 100%|██████████| 23/23 [00:01<00:00, 12.04it/s]


Acc on val in epoch 10 is: 0.8419301164725458


Current loss in epoch 11 is 0.00295893638394773: 100%|██████████| 23/23 [00:01<00:00, 12.19it/s]


Acc on val in epoch 11 is: 0.8277870216306157


Current loss in epoch 12 is 0.0026678743306547403: 100%|██████████| 23/23 [00:01<00:00, 12.15it/s]


Acc on val in epoch 12 is: 0.8369384359400999


Current loss in epoch 13 is 0.004496629815548658: 100%|██████████| 23/23 [00:01<00:00, 12.20it/s]


Acc on val in epoch 13 is: 0.8552412645590682


Current loss in epoch 14 is 0.0027323716785758734: 100%|██████████| 23/23 [00:01<00:00, 12.13it/s]


Acc on val in epoch 14 is: 0.8560732113144759


Current loss in epoch 15 is 0.001749908784404397: 100%|██████████| 23/23 [00:01<00:00, 12.05it/s]


Acc on val in epoch 15 is: 0.8136439267886856


Current loss in epoch 16 is 0.0016082213260233402: 100%|██████████| 23/23 [00:01<00:00, 11.90it/s]


Acc on val in epoch 16 is: 0.8494176372712147


Current loss in epoch 17 is 0.0024539739824831486: 100%|██████████| 23/23 [00:01<00:00, 12.16it/s]


Acc on val in epoch 17 is: 0.8594009983361065


Current loss in epoch 18 is 0.0019597725477069616: 100%|██████████| 23/23 [00:01<00:00, 12.18it/s]


Acc on val in epoch 18 is: 0.8477537437603994


Current loss in epoch 19 is 0.0013562391977757215: 100%|██████████| 23/23 [00:01<00:00, 12.18it/s]


Acc on val in epoch 19 is: 0.8535773710482529


Current loss in epoch 0 is 0.2719905972480774: 100%|██████████| 23/23 [00:01<00:00, 12.18it/s]


Acc on val in epoch 0 is: 0.6530782029950083


Current loss in epoch 1 is 0.10489430278539658: 100%|██████████| 23/23 [00:01<00:00, 12.12it/s]


Acc on val in epoch 1 is: 0.6530782029950083


Current loss in epoch 2 is 0.05391469970345497: 100%|██████████| 23/23 [00:01<00:00, 12.15it/s]


Acc on val in epoch 2 is: 0.7004991680532446


Current loss in epoch 3 is 0.05212898924946785: 100%|██████████| 23/23 [00:01<00:00, 12.16it/s]


Acc on val in epoch 3 is: 0.6955074875207987


Current loss in epoch 4 is 0.02481275610625744: 100%|██████████| 23/23 [00:01<00:00, 11.88it/s]


Acc on val in epoch 4 is: 0.8094841930116472


Current loss in epoch 5 is 0.014755472540855408: 100%|██████████| 23/23 [00:01<00:00, 11.76it/s]


Acc on val in epoch 5 is: 0.8111480865224625


Current loss in epoch 6 is 0.015217168256640434: 100%|██████████| 23/23 [00:01<00:00, 11.86it/s]


Acc on val in epoch 6 is: 0.9068219633943427


Current loss in epoch 7 is 0.011839642189443111: 100%|██████████| 23/23 [00:01<00:00, 12.03it/s]


Acc on val in epoch 7 is: 0.8227953410981698


Current loss in epoch 8 is 0.024915961548686028: 100%|██████████| 23/23 [00:01<00:00, 12.07it/s]


Acc on val in epoch 8 is: 0.7221297836938436


Current loss in epoch 9 is 0.005242537707090378: 100%|██████████| 23/23 [00:01<00:00, 12.04it/s]


Acc on val in epoch 9 is: 0.8760399334442596


Current loss in epoch 10 is 0.0033706859685480595: 100%|██████████| 23/23 [00:01<00:00, 11.98it/s]


Acc on val in epoch 10 is: 0.7445923460898503


Current loss in epoch 11 is 0.002800785470753908: 100%|██████████| 23/23 [00:01<00:00, 11.98it/s]


Acc on val in epoch 11 is: 0.8885191347753744


Current loss in epoch 12 is 0.0030057039111852646: 100%|██████████| 23/23 [00:01<00:00, 11.88it/s]


Acc on val in epoch 12 is: 0.889351081530782


Current loss in epoch 13 is 0.002147951163351536: 100%|██████████| 23/23 [00:01<00:00, 12.00it/s]


Acc on val in epoch 13 is: 0.889351081530782


Current loss in epoch 14 is 0.0020853590685874224: 100%|██████████| 23/23 [00:01<00:00, 12.02it/s]


Acc on val in epoch 14 is: 0.8943427620632279


Current loss in epoch 15 is 0.0015853388467803597: 100%|██████████| 23/23 [00:02<00:00, 11.17it/s]


Acc on val in epoch 15 is: 0.8885191347753744


Current loss in epoch 16 is 0.002559979911893606: 100%|██████████| 23/23 [00:01<00:00, 12.08it/s]


Acc on val in epoch 16 is: 0.8843594009983361


Current loss in epoch 17 is 0.0012994555290788412: 100%|██████████| 23/23 [00:01<00:00, 12.10it/s]


Acc on val in epoch 17 is: 0.9009983361064892


Current loss in epoch 18 is 0.0015441946452483535: 100%|██████████| 23/23 [00:01<00:00, 11.99it/s]


Acc on val in epoch 18 is: 0.8851913477537438


Current loss in epoch 19 is 0.0012956848368048668: 100%|██████████| 23/23 [00:01<00:00, 12.06it/s]


Acc on val in epoch 19 is: 0.8976705490848585


Current loss in epoch 0 is 0.35472366213798523: 100%|██████████| 23/23 [00:01<00:00, 12.06it/s]


Acc on val in epoch 0 is: 0.653910149750416


Current loss in epoch 1 is 0.14254462718963623: 100%|██████████| 23/23 [00:01<00:00, 12.18it/s]


Acc on val in epoch 1 is: 0.6480865224625624


Current loss in epoch 2 is 0.06554996967315674: 100%|██████████| 23/23 [00:01<00:00, 12.12it/s]


Acc on val in epoch 2 is: 0.6572379367720466


Current loss in epoch 3 is 0.041088685393333435: 100%|██████████| 23/23 [00:01<00:00, 12.17it/s]


Acc on val in epoch 3 is: 0.7753743760399334


Current loss in epoch 4 is 0.03351463004946709: 100%|██████████| 23/23 [00:01<00:00, 12.08it/s]


Acc on val in epoch 4 is: 0.7229617304492513


Current loss in epoch 5 is 0.014234542846679688: 100%|██████████| 23/23 [00:01<00:00, 11.94it/s]


Acc on val in epoch 5 is: 0.8311148086522463


Current loss in epoch 6 is 0.016712317243218422: 100%|██████████| 23/23 [00:01<00:00, 12.19it/s]


Acc on val in epoch 6 is: 0.8078202995008319


Current loss in epoch 7 is 0.0074056824669241905: 100%|██████████| 23/23 [00:01<00:00, 12.08it/s]


Acc on val in epoch 7 is: 0.8777038269550749


Current loss in epoch 8 is 0.0056455316953361034: 100%|██████████| 23/23 [00:01<00:00, 12.13it/s]


Acc on val in epoch 8 is: 0.7537437603993344


Current loss in epoch 9 is 0.00635896623134613: 100%|██████████| 23/23 [00:01<00:00, 12.12it/s]


Acc on val in epoch 9 is: 0.718801996672213


Current loss in epoch 10 is 0.013146613724529743: 100%|██████████| 23/23 [00:01<00:00, 12.06it/s]


Acc on val in epoch 10 is: 0.8169717138103162


Current loss in epoch 11 is 0.05486049875617027: 100%|██████████| 23/23 [00:01<00:00, 12.02it/s]


Acc on val in epoch 11 is: 0.7628951747088186


Current loss in epoch 12 is 0.012332609854638577: 100%|██████████| 23/23 [00:01<00:00, 12.07it/s]


Acc on val in epoch 12 is: 0.7004991680532446


Current loss in epoch 13 is 0.002792987274006009: 100%|██████████| 23/23 [00:01<00:00, 12.03it/s]


Acc on val in epoch 13 is: 0.7229617304492513


Current loss in epoch 14 is 0.002807741519063711: 100%|██████████| 23/23 [00:01<00:00, 12.06it/s]


Acc on val in epoch 14 is: 0.8186356073211315


Current loss in epoch 15 is 0.0021398747339844704: 100%|██████████| 23/23 [00:01<00:00, 12.09it/s]


Acc on val in epoch 15 is: 0.8352745424292846


Current loss in epoch 16 is 0.004944419022649527: 100%|██████████| 23/23 [00:01<00:00, 11.97it/s]


Acc on val in epoch 16 is: 0.8169717138103162


Current loss in epoch 17 is 0.0014884758275002241: 100%|██████████| 23/23 [00:01<00:00, 12.07it/s]


Acc on val in epoch 17 is: 0.8103161397670549


Current loss in epoch 18 is 0.0028516140300780535: 100%|██████████| 23/23 [00:01<00:00, 12.07it/s]


Acc on val in epoch 18 is: 0.831946755407654


Current loss in epoch 19 is 0.00129319925326854: 100%|██████████| 23/23 [00:01<00:00, 12.09it/s]


Acc on val in epoch 19 is: 0.8311148086522463


In [17]:
import json
print(trial_to_acc)
save_dir = '/content/drive/MyDrive/logs'
os.makedirs(save_dir, exist_ok=True)

out_path = os.path.join(save_dir, '3_lead_experiment.json')
with open(out_path, 'w') as f:
    json.dump(trial_to_acc, f, indent=2)

{0: 0.9207687120701281, 1: 0.9238031018206339, 2: 0.9177343223196224}


In [24]:

print("Average is: ", (trial_to_acc[0] + trial_to_acc[1] + trial_to_acc[2]) / 3)

Average is:  0.9207687120701281


In [18]:
train_loader, val_loader, test_loader = get_dataloaders("/content/drive/MyDrive/ptbdb",preprocessed_data_path="/content/drive/MyDrive/ptbdb/preprocessed_data.pt",train_ratio=0.6, val_ratio=0.1)

After filtering, we got: 228 records. Healthy: 80, Disease: 148
Patients: train: 120 | val: 20 | test: 60
Load data from given path


In [22]:
trial_to_acc = {}
seeds = [0,1,2]
trials = 3
lr = 0.001
epochs = 20
device = "cuda" if torch.cuda.is_available() else "cpu"
for trial,seed in zip(range(trials),seeds):
  set_seed(seed)
  model = CNN_1D(num_leads = 12)
  optimizer = optim.Adam(model.parameters(), lr=lr)
  criterion = nn.BCELoss()
  losses = train(model, criterion, optimizer, train_loader, val_loader ,epochs, device)
  accs = eval(model, test_loader, device=device)
  trial_to_acc[trial] = accs

Current loss in epoch 0 is 0.08360694348812103: 100%|██████████| 23/23 [00:03<00:00,  6.74it/s]


Acc on val in epoch 0 is: 0.6530782029950083


Current loss in epoch 1 is 0.04628802090883255: 100%|██████████| 23/23 [00:03<00:00,  6.81it/s]


Acc on val in epoch 1 is: 0.6530782029950083


Current loss in epoch 2 is 0.007345775607973337: 100%|██████████| 23/23 [00:03<00:00,  6.80it/s]


Acc on val in epoch 2 is: 0.6946755407653911


Current loss in epoch 3 is 0.031864557415246964: 100%|██████████| 23/23 [00:03<00:00,  6.82it/s]


Acc on val in epoch 3 is: 0.7487520798668885


Current loss in epoch 4 is 0.0037259787786751986: 100%|██████████| 23/23 [00:03<00:00,  6.45it/s]


Acc on val in epoch 4 is: 0.78369384359401


Current loss in epoch 5 is 0.0017308311071246862: 100%|██████████| 23/23 [00:03<00:00,  6.74it/s]


Acc on val in epoch 5 is: 0.6938435940099834


Current loss in epoch 6 is 0.004814693238586187: 100%|██████████| 23/23 [00:03<00:00,  6.81it/s]


Acc on val in epoch 6 is: 0.7254575707154742


Current loss in epoch 7 is 0.004150109365582466: 100%|██████████| 23/23 [00:03<00:00,  6.82it/s]


Acc on val in epoch 7 is: 0.8394342762063228


Current loss in epoch 8 is 0.020144077017903328: 100%|██████████| 23/23 [00:03<00:00,  6.80it/s]


Acc on val in epoch 8 is: 0.8302828618968386


Current loss in epoch 9 is 0.015454555861651897: 100%|██████████| 23/23 [00:03<00:00,  6.83it/s]


Acc on val in epoch 9 is: 0.8128119800332779


Current loss in epoch 10 is 0.003554391674697399: 100%|██████████| 23/23 [00:03<00:00,  6.83it/s]


Acc on val in epoch 10 is: 0.8302828618968386


Current loss in epoch 11 is 0.0005613495013676584: 100%|██████████| 23/23 [00:03<00:00,  6.82it/s]


Acc on val in epoch 11 is: 0.6589018302828619


Current loss in epoch 12 is 0.0037457516882568598: 100%|██████████| 23/23 [00:03<00:00,  6.84it/s]


Acc on val in epoch 12 is: 0.6763727121464226


Current loss in epoch 13 is 0.00023705829516984522: 100%|██████████| 23/23 [00:03<00:00,  6.82it/s]


Acc on val in epoch 13 is: 0.6613976705490848


Current loss in epoch 14 is 0.0008414376643486321: 100%|██████████| 23/23 [00:03<00:00,  6.57it/s]


Acc on val in epoch 14 is: 0.6589018302828619


Current loss in epoch 15 is 0.0008028267184272408: 100%|██████████| 23/23 [00:03<00:00,  6.69it/s]


Acc on val in epoch 15 is: 0.6638935108153078


Current loss in epoch 16 is 0.00034881947794929147: 100%|██████████| 23/23 [00:03<00:00,  6.81it/s]


Acc on val in epoch 16 is: 0.6589018302828619


Current loss in epoch 17 is 0.00023410480935126543: 100%|██████████| 23/23 [00:03<00:00,  6.74it/s]


Acc on val in epoch 17 is: 0.6622296173044925


Current loss in epoch 18 is 0.00023087805311661214: 100%|██████████| 23/23 [00:03<00:00,  6.81it/s]


Acc on val in epoch 18 is: 0.6638935108153078


Current loss in epoch 19 is 0.0001931805891217664: 100%|██████████| 23/23 [00:03<00:00,  6.82it/s]


Acc on val in epoch 19 is: 0.6638935108153078


Current loss in epoch 0 is 0.1295752227306366: 100%|██████████| 23/23 [00:03<00:00,  6.84it/s]


Acc on val in epoch 0 is: 0.6530782029950083


Current loss in epoch 1 is 0.03063192404806614: 100%|██████████| 23/23 [00:03<00:00,  6.81it/s]


Acc on val in epoch 1 is: 0.6555740432612313


Current loss in epoch 2 is 0.01530728954821825: 100%|██████████| 23/23 [00:03<00:00,  6.82it/s]


Acc on val in epoch 2 is: 0.7246256239600666


Current loss in epoch 3 is 0.007078195922076702: 100%|██████████| 23/23 [00:03<00:00,  6.82it/s]


Acc on val in epoch 3 is: 0.6955074875207987


Current loss in epoch 4 is 0.013541890308260918: 100%|██████████| 23/23 [00:03<00:00,  6.75it/s]


Acc on val in epoch 4 is: 0.7071547420965059


Current loss in epoch 5 is 0.016276802867650986: 100%|██████████| 23/23 [00:03<00:00,  6.82it/s]


Acc on val in epoch 5 is: 0.7279534109816972


Current loss in epoch 6 is 0.0023197655100375414: 100%|██████████| 23/23 [00:03<00:00,  6.83it/s]


Acc on val in epoch 6 is: 0.7978369384359401


Current loss in epoch 7 is 0.003851060289889574: 100%|██████████| 23/23 [00:03<00:00,  6.78it/s]


Acc on val in epoch 7 is: 0.8801996672212978


Current loss in epoch 8 is 0.03566395118832588: 100%|██████████| 23/23 [00:03<00:00,  6.82it/s]


Acc on val in epoch 8 is: 0.8851913477537438


Current loss in epoch 9 is 0.009812265634536743: 100%|██████████| 23/23 [00:03<00:00,  6.83it/s]


Acc on val in epoch 9 is: 0.8119800332778702


Current loss in epoch 10 is 0.0006557105225510895: 100%|██████████| 23/23 [00:03<00:00,  6.82it/s]


Acc on val in epoch 10 is: 0.7254575707154742


Current loss in epoch 11 is 0.000840689055621624: 100%|██████████| 23/23 [00:03<00:00,  6.82it/s]


Acc on val in epoch 11 is: 0.7221297836938436


Current loss in epoch 12 is 0.0003726537979673594: 100%|██████████| 23/23 [00:03<00:00,  6.81it/s]


Acc on val in epoch 12 is: 0.7104825291181365


Current loss in epoch 13 is 0.013589262031018734: 100%|██████████| 23/23 [00:03<00:00,  6.83it/s]


Acc on val in epoch 13 is: 0.7861896838602329


Current loss in epoch 14 is 0.012108939699828625: 100%|██████████| 23/23 [00:03<00:00,  6.82it/s]


Acc on val in epoch 14 is: 0.8810316139767055


Current loss in epoch 15 is 0.006050512194633484: 100%|██████████| 23/23 [00:03<00:00,  6.63it/s]


Acc on val in epoch 15 is: 0.8452579034941764


Current loss in epoch 16 is 0.006849335506558418: 100%|██████████| 23/23 [00:03<00:00,  6.84it/s]


Acc on val in epoch 16 is: 0.7762063227953411


Current loss in epoch 17 is 0.006331088021397591: 100%|██████████| 23/23 [00:03<00:00,  6.79it/s]


Acc on val in epoch 17 is: 0.8818635607321131


Current loss in epoch 18 is 0.0007789129158481956: 100%|██████████| 23/23 [00:03<00:00,  6.83it/s]


Acc on val in epoch 18 is: 0.7479201331114809


Current loss in epoch 19 is 0.0008773405570536852: 100%|██████████| 23/23 [00:03<00:00,  6.82it/s]


Acc on val in epoch 19 is: 0.8818635607321131


Current loss in epoch 0 is 0.22653231024742126: 100%|██████████| 23/23 [00:03<00:00,  6.82it/s]


Acc on val in epoch 0 is: 0.6522462562396006


Current loss in epoch 1 is 0.07319629192352295: 100%|██████████| 23/23 [00:03<00:00,  6.81it/s]


Acc on val in epoch 1 is: 0.6980033277870217


Current loss in epoch 2 is 0.026495659723877907: 100%|██████████| 23/23 [00:03<00:00,  6.84it/s]


Acc on val in epoch 2 is: 0.846089850249584


Current loss in epoch 3 is 0.009397050365805626: 100%|██████████| 23/23 [00:03<00:00,  6.81it/s]


Acc on val in epoch 3 is: 0.6871880199667221


Current loss in epoch 4 is 0.04287714138627052: 100%|██████████| 23/23 [00:03<00:00,  6.83it/s]


Acc on val in epoch 4 is: 0.7479201331114809


Current loss in epoch 5 is 0.007469009142369032: 100%|██████████| 23/23 [00:03<00:00,  6.83it/s]


Acc on val in epoch 5 is: 0.7113144758735441


Current loss in epoch 6 is 0.007316102739423513: 100%|██████████| 23/23 [00:03<00:00,  6.75it/s]


Acc on val in epoch 6 is: 0.8344425956738769


Current loss in epoch 7 is 0.003638891503214836: 100%|██████████| 23/23 [00:03<00:00,  6.83it/s]


Acc on val in epoch 7 is: 0.71630615640599


Current loss in epoch 8 is 0.037474166601896286: 100%|██████████| 23/23 [00:03<00:00,  6.43it/s]


Acc on val in epoch 8 is: 0.8103161397670549


Current loss in epoch 9 is 0.010849053971469402: 100%|██████████| 23/23 [00:03<00:00,  6.81it/s]


Acc on val in epoch 9 is: 0.7029950083194676


Current loss in epoch 10 is 0.0015870861243456602: 100%|██████████| 23/23 [00:03<00:00,  6.78it/s]


Acc on val in epoch 10 is: 0.7420965058236273


Current loss in epoch 11 is 0.0020609914790838957: 100%|██████████| 23/23 [00:03<00:00,  6.82it/s]


Acc on val in epoch 11 is: 0.7262895174708819


Current loss in epoch 12 is 0.0005561269936151803: 100%|██████████| 23/23 [00:03<00:00,  6.83it/s]


Acc on val in epoch 12 is: 0.7196339434276207


Current loss in epoch 13 is 0.00040279937093146145: 100%|██████████| 23/23 [00:03<00:00,  6.82it/s]


Acc on val in epoch 13 is: 0.7221297836938436


Current loss in epoch 14 is 0.0020644289907068014: 100%|██████████| 23/23 [00:03<00:00,  6.77it/s]


Acc on val in epoch 14 is: 0.7262895174708819


Current loss in epoch 15 is 0.002141268691048026: 100%|██████████| 23/23 [00:03<00:00,  6.84it/s]


Acc on val in epoch 15 is: 0.7204658901830283


Current loss in epoch 16 is 0.0002005507267313078: 100%|██████████| 23/23 [00:03<00:00,  6.81it/s]


Acc on val in epoch 16 is: 0.718801996672213


Current loss in epoch 17 is 0.0002809906145557761: 100%|██████████| 23/23 [00:03<00:00,  6.79it/s]


Acc on val in epoch 17 is: 0.7146422628951747


Current loss in epoch 18 is 0.00013717552064917982: 100%|██████████| 23/23 [00:03<00:00,  6.82it/s]


Acc on val in epoch 18 is: 0.7138103161397671


Current loss in epoch 19 is 0.00023848038108553737: 100%|██████████| 23/23 [00:03<00:00,  6.84it/s]


Acc on val in epoch 19 is: 0.71630615640599


In [23]:
import json
print(trial_to_acc)
save_dir = '/content/drive/MyDrive/logs'
os.makedirs(save_dir, exist_ok=True)

out_path = os.path.join(save_dir, '12_lead_experiment.json')
with open(out_path, 'w') as f:
    json.dump(trial_to_acc, f, indent=2)

{0: 0.9214430209035739, 1: 0.9325691166554282, 2: 0.9460552933243426}


In [26]:

print("Average is: ", (trial_to_acc[0] + trial_to_acc[1] + trial_to_acc[2]) / 3)

Average is:  0.9333558102944481


# 5-Fold Cross Validation
In this section we compare our model to the 2D CNN model on 5-Fold Cross Validation

In [20]:
seeds = [0,0,0]
#my imports
from comp_med.data.preprocessing import get_record_paths, filter_records, split_patients

#get patient data and split on my own
path="/content/drive/MyDrive/ptbdb"
records = get_record_paths(path)
filtered_records = filter_records(records)
#now obtain k-fold split
train_folds, test_folds = split_patients(records, k_fold=5)

After filtering, we got: 228 records. Healthy: 80, Disease: 148
k_fold cv has no validation set


In [21]:
epochs = 20
lr = 1e-3
device = "cuda" if torch.cuda.is_available() else "cpu"

In [24]:
k_fold_acc = {}
for i, train_ids, test_ids in zip(range(5),train_folds,test_folds):
  set_seed(0)
  train_loader,val_loader, test_loader = get_dataloaders("/content/drive/MyDrive/ptbdb",preprocessed_data_path="/content/drive/MyDrive/ptbdb/preprocessed_data.pt",train_ids=train_ids, test_ids = test_ids, val_ids=[],train_ratio=0.8, val_ratio=0)
  model = CNN_1D()
  optimizer = optim.Adam(model.parameters(), lr=lr)
  criterion = nn.BCELoss()
  losses = train(model, criterion, optimizer, train_loader, train_loader ,epochs, device, return_loss=False) #train loader as val because we don't have val laoder
  metrics=eval(model, test_loader, device, all_metrics=True)
  print("Fianl metrics on test is: ",metrics )
  k_fold_acc[i] = metrics
  del model, optimizer, criterion
  del train_loader, val_loader, test_loader
  torch.cuda.empty_cache()

After filtering, we got: 228 records. Healthy: 80, Disease: 148
Load data from given path


Current loss in epoch 0 is 0.22659553587436676: 100%|██████████| 31/31 [00:05<00:00,  5.86it/s]


Acc on val in epoch 0 is: 0.6487080103359173


Current loss in epoch 1 is 0.15353482961654663: 100%|██████████| 31/31 [00:05<00:00,  5.99it/s]


Acc on val in epoch 1 is: 0.9607235142118863


Current loss in epoch 2 is 0.057142093777656555: 100%|██████████| 31/31 [00:05<00:00,  5.96it/s]


Acc on val in epoch 2 is: 0.9998708010335917


Current loss in epoch 3 is 0.014143585227429867: 100%|██████████| 31/31 [00:05<00:00,  5.90it/s]


Acc on val in epoch 3 is: 0.9988372093023256


Current loss in epoch 4 is 0.005974425468593836: 100%|██████████| 31/31 [00:05<00:00,  5.97it/s]


Acc on val in epoch 4 is: 0.9962532299741602


Current loss in epoch 5 is 0.007371600717306137: 100%|██████████| 31/31 [00:05<00:00,  5.91it/s]


Acc on val in epoch 5 is: 0.9945736434108527


Current loss in epoch 6 is 0.0014881030656397343: 100%|██████████| 31/31 [00:05<00:00,  5.99it/s]


Acc on val in epoch 6 is: 0.9998708010335917


Current loss in epoch 7 is 0.012833620421588421: 100%|██████████| 31/31 [00:05<00:00,  5.90it/s]


Acc on val in epoch 7 is: 0.8090439276485788


Current loss in epoch 8 is 0.002956148236989975: 100%|██████████| 31/31 [00:05<00:00,  5.91it/s]


Acc on val in epoch 8 is: 0.9790697674418605


Current loss in epoch 9 is 0.010070526041090488: 100%|██████████| 31/31 [00:05<00:00,  5.98it/s]


Acc on val in epoch 9 is: 0.9989664082687338


Current loss in epoch 10 is 0.0010090130381286144: 100%|██████████| 31/31 [00:05<00:00,  5.96it/s]


Acc on val in epoch 10 is: 1.0


Current loss in epoch 11 is 0.00014885059499647468: 100%|██████████| 31/31 [00:05<00:00,  5.95it/s]


Acc on val in epoch 11 is: 1.0


Current loss in epoch 12 is 0.009283722378313541: 100%|██████████| 31/31 [00:05<00:00,  5.96it/s]


Acc on val in epoch 12 is: 1.0


Current loss in epoch 13 is 0.00033112725941464305: 100%|██████████| 31/31 [00:05<00:00,  5.96it/s]


Acc on val in epoch 13 is: 1.0


Current loss in epoch 14 is 0.0002602368185762316: 100%|██████████| 31/31 [00:05<00:00,  5.97it/s]


Acc on val in epoch 14 is: 1.0


Current loss in epoch 15 is 0.0008254482527263463: 100%|██████████| 31/31 [00:05<00:00,  5.92it/s]


Acc on val in epoch 15 is: 1.0


Current loss in epoch 16 is 0.0037790737114846706: 100%|██████████| 31/31 [00:05<00:00,  5.96it/s]


Acc on val in epoch 16 is: 1.0


Current loss in epoch 17 is 0.08014591783285141: 100%|██████████| 31/31 [00:05<00:00,  5.96it/s]


Acc on val in epoch 17 is: 0.993798449612403


Current loss in epoch 18 is 0.034321606159210205: 100%|██████████| 31/31 [00:05<00:00,  6.00it/s]


Acc on val in epoch 18 is: 0.9002583979328166


Current loss in epoch 19 is 0.001096317428164184: 100%|██████████| 31/31 [00:05<00:00,  5.96it/s]


Acc on val in epoch 19 is: 0.9767441860465116
Fianl metrics on test is:  {'acc': 0.8128708352350524, 'prec': 0.7481572481572482, 'sens': 1.0, 'spec': np.float64(0.578622816032888), 'f1': 0.8559381588193956}
After filtering, we got: 228 records. Healthy: 80, Disease: 148
Load data from given path


Current loss in epoch 0 is 0.030367756262421608: 100%|██████████| 33/33 [00:05<00:00,  5.92it/s]


Acc on val in epoch 0 is: 0.6822576746434614


Current loss in epoch 1 is 0.021808981895446777: 100%|██████████| 33/33 [00:05<00:00,  5.94it/s]


Acc on val in epoch 1 is: 0.996495044718395


Current loss in epoch 2 is 0.01605135016143322: 100%|██████████| 33/33 [00:05<00:00,  5.95it/s]


Acc on val in epoch 2 is: 0.9989122552574329


Current loss in epoch 3 is 0.01141443196684122: 100%|██████████| 33/33 [00:05<00:00,  5.93it/s]


Acc on val in epoch 3 is: 0.9990331157843848


Current loss in epoch 4 is 0.0025430992245674133: 100%|██████████| 33/33 [00:05<00:00,  5.92it/s]


Acc on val in epoch 4 is: 0.9910563210055596


Current loss in epoch 5 is 0.027844253927469254: 100%|██████████| 33/33 [00:05<00:00,  5.97it/s]


Acc on val in epoch 5 is: 0.9847715736040609


Current loss in epoch 6 is 0.0010147879365831614: 100%|██████████| 33/33 [00:05<00:00,  5.93it/s]


Acc on val in epoch 6 is: 0.9998791394730481


Current loss in epoch 7 is 0.0013588623842224479: 100%|██████████| 33/33 [00:05<00:00,  5.95it/s]


Acc on val in epoch 7 is: 1.0


Current loss in epoch 8 is 0.002337454352527857: 100%|██████████| 33/33 [00:05<00:00,  5.96it/s]


Acc on val in epoch 8 is: 1.0


Current loss in epoch 9 is 0.007948371581733227: 100%|██████████| 33/33 [00:05<00:00,  5.96it/s]


Acc on val in epoch 9 is: 1.0


Current loss in epoch 10 is 0.00044998538214713335: 100%|██████████| 33/33 [00:05<00:00,  5.96it/s]


Acc on val in epoch 10 is: 0.9991539763113367


Current loss in epoch 11 is 0.000556230777874589: 100%|██████████| 33/33 [00:05<00:00,  5.96it/s]


Acc on val in epoch 11 is: 1.0


Current loss in epoch 12 is 0.00013319257413968444: 100%|██████████| 33/33 [00:05<00:00,  5.97it/s]


Acc on val in epoch 12 is: 1.0


Current loss in epoch 13 is 0.00014057611406315118: 100%|██████████| 33/33 [00:05<00:00,  5.93it/s]


Acc on val in epoch 13 is: 1.0


Current loss in epoch 14 is 9.838840924203396e-05: 100%|██████████| 33/33 [00:05<00:00,  5.95it/s]


Acc on val in epoch 14 is: 1.0


Current loss in epoch 15 is 0.0002547555195633322: 100%|██████████| 33/33 [00:05<00:00,  5.96it/s]


Acc on val in epoch 15 is: 1.0


Current loss in epoch 16 is 0.00012499088188633323: 100%|██████████| 33/33 [00:05<00:00,  5.94it/s]


Acc on val in epoch 16 is: 1.0


Current loss in epoch 17 is 7.803235348546878e-05: 100%|██████████| 33/33 [00:05<00:00,  5.94it/s]


Acc on val in epoch 17 is: 1.0


Current loss in epoch 18 is 8.251326653407887e-05: 100%|██████████| 33/33 [00:05<00:00,  5.95it/s]


Acc on val in epoch 18 is: 1.0


Current loss in epoch 19 is 0.00046260113595053554: 100%|██████████| 33/33 [00:05<00:00,  5.96it/s]


Acc on val in epoch 19 is: 1.0
Fianl metrics on test is:  {'acc': 0.8829209414604707, 'prec': 0.9555765595463138, 'sens': 0.8730569948186528, 'spec': np.float64(0.905811623246493), 'f1': 0.9124548736462094}
After filtering, we got: 228 records. Healthy: 80, Disease: 148
Load data from given path


Current loss in epoch 0 is 0.08599117398262024: 100%|██████████| 31/31 [00:05<00:00,  5.79it/s]


Acc on val in epoch 0 is: 0.6238288174221321


Current loss in epoch 1 is 0.031801074743270874: 100%|██████████| 31/31 [00:05<00:00,  5.82it/s]


Acc on val in epoch 1 is: 0.9940491263611041


Current loss in epoch 2 is 0.0036728400737047195: 100%|██████████| 31/31 [00:05<00:00,  5.87it/s]


Acc on val in epoch 2 is: 0.960369713851608


Current loss in epoch 3 is 0.0556333027780056: 100%|██████████| 31/31 [00:05<00:00,  5.87it/s]


Acc on val in epoch 3 is: 0.9763231197771588


Current loss in epoch 4 is 0.01603630930185318: 100%|██████████| 31/31 [00:05<00:00,  5.86it/s]


Acc on val in epoch 4 is: 0.9975943276778931


Current loss in epoch 5 is 0.0017513945931568742: 100%|██████████| 31/31 [00:05<00:00,  5.87it/s]


Acc on val in epoch 5 is: 0.9998733856672576


Current loss in epoch 6 is 0.0005390570731833577: 100%|██████████| 31/31 [00:05<00:00,  5.86it/s]


Acc on val in epoch 6 is: 1.0


Current loss in epoch 7 is 0.0004692900402005762: 100%|██████████| 31/31 [00:05<00:00,  5.85it/s]


Acc on val in epoch 7 is: 1.0


Current loss in epoch 8 is 0.000429670384619385: 100%|██████████| 31/31 [00:05<00:00,  5.86it/s]


Acc on val in epoch 8 is: 1.0


Current loss in epoch 9 is 0.0004370200040284544: 100%|██████████| 31/31 [00:05<00:00,  5.82it/s]


Acc on val in epoch 9 is: 1.0


Current loss in epoch 10 is 0.00042947212932631373: 100%|██████████| 31/31 [00:05<00:00,  5.66it/s]


Acc on val in epoch 10 is: 1.0


Current loss in epoch 11 is 0.00029918045038357377: 100%|██████████| 31/31 [00:05<00:00,  5.83it/s]


Acc on val in epoch 11 is: 1.0


Current loss in epoch 12 is 0.0001858683244790882: 100%|██████████| 31/31 [00:05<00:00,  5.71it/s]


Acc on val in epoch 12 is: 1.0


Current loss in epoch 13 is 0.00019524437084328383: 100%|██████████| 31/31 [00:05<00:00,  5.86it/s]


Acc on val in epoch 13 is: 1.0


Current loss in epoch 14 is 0.00013815969577990472: 100%|██████████| 31/31 [00:05<00:00,  5.88it/s]


Acc on val in epoch 14 is: 1.0


Current loss in epoch 15 is 0.0002885018475353718: 100%|██████████| 31/31 [00:05<00:00,  5.83it/s]


Acc on val in epoch 15 is: 1.0


Current loss in epoch 16 is 0.000148935301695019: 100%|██████████| 31/31 [00:05<00:00,  5.86it/s]


Acc on val in epoch 16 is: 1.0


Current loss in epoch 17 is 9.922177559928969e-05: 100%|██████████| 31/31 [00:05<00:00,  5.78it/s]


Acc on val in epoch 17 is: 1.0


Current loss in epoch 18 is 0.00010614503116812557: 100%|██████████| 31/31 [00:05<00:00,  5.82it/s]


Acc on val in epoch 18 is: 1.0


Current loss in epoch 19 is 9.87760431598872e-05: 100%|██████████| 31/31 [00:05<00:00,  5.84it/s]


Acc on val in epoch 19 is: 1.0
Fianl metrics on test is:  {'acc': 0.927201180521397, 'prec': 0.9797898140662894, 'sens': 0.9078651685393259, 'spec': np.float64(0.9641833810888252), 'f1': 0.942457231726283}
After filtering, we got: 228 records. Healthy: 80, Disease: 148
Load data from given path


Current loss in epoch 0 is 0.039605241268873215: 100%|██████████| 32/32 [00:05<00:00,  5.89it/s]


Acc on val in epoch 0 is: 0.685057182355159


Current loss in epoch 1 is 0.00593223050236702: 100%|██████████| 32/32 [00:05<00:00,  6.00it/s]


Acc on val in epoch 1 is: 0.9566419504838507


Current loss in epoch 2 is 0.03454827144742012: 100%|██████████| 32/32 [00:05<00:00,  5.97it/s]


Acc on val in epoch 2 is: 0.9988689204474048


Current loss in epoch 3 is 0.044482577592134476: 100%|██████████| 32/32 [00:05<00:00,  6.01it/s]


Acc on val in epoch 3 is: 0.9904486615558628


Current loss in epoch 4 is 0.012418564409017563: 100%|██████████| 32/32 [00:05<00:00,  5.99it/s]


Acc on val in epoch 4 is: 0.9950986552720875


Current loss in epoch 5 is 0.048366695642471313: 100%|██████████| 32/32 [00:05<00:00,  5.95it/s]


Acc on val in epoch 5 is: 0.9977378408948095


Current loss in epoch 6 is 0.02006375603377819: 100%|██████████| 32/32 [00:05<00:00,  5.98it/s]


Acc on val in epoch 6 is: 0.9969837878597462


Current loss in epoch 7 is 0.0058974530547857285: 100%|██████████| 32/32 [00:05<00:00,  5.96it/s]


Acc on val in epoch 7 is: 0.9899459595324871


Current loss in epoch 8 is 0.0008137910044752061: 100%|██████████| 32/32 [00:05<00:00,  5.96it/s]


Acc on val in epoch 8 is: 0.9998743244941561


Current loss in epoch 9 is 0.00038162534474395216: 100%|██████████| 32/32 [00:05<00:00,  5.95it/s]


Acc on val in epoch 9 is: 1.0


Current loss in epoch 10 is 0.010380364954471588: 100%|██████████| 32/32 [00:05<00:00,  5.95it/s]


Acc on val in epoch 10 is: 1.0


Current loss in epoch 11 is 0.004143791273236275: 100%|██████████| 32/32 [00:05<00:00,  5.98it/s]


Acc on val in epoch 11 is: 0.9994972979766243


Current loss in epoch 12 is 0.0002971053763758391: 100%|██████████| 32/32 [00:05<00:00,  6.00it/s]


Acc on val in epoch 12 is: 0.9987432449415609


Current loss in epoch 13 is 0.05712864547967911: 100%|██████████| 32/32 [00:05<00:00,  6.01it/s]


Acc on val in epoch 13 is: 0.9994972979766243


Current loss in epoch 14 is 0.02248665690422058: 100%|██████████| 32/32 [00:05<00:00,  5.97it/s]


Acc on val in epoch 14 is: 0.998617569435717


Current loss in epoch 15 is 0.0018943536560982466: 100%|██████████| 32/32 [00:05<00:00,  5.97it/s]


Acc on val in epoch 15 is: 0.9992459469649365


Current loss in epoch 16 is 0.008066062815487385: 100%|██████████| 32/32 [00:05<00:00,  5.99it/s]


Acc on val in epoch 16 is: 1.0


Current loss in epoch 17 is 0.00034711294574663043: 100%|██████████| 32/32 [00:05<00:00,  5.99it/s]


Acc on val in epoch 17 is: 1.0


Current loss in epoch 18 is 0.004350549541413784: 100%|██████████| 32/32 [00:05<00:00,  6.01it/s]


Acc on val in epoch 18 is: 1.0


Current loss in epoch 19 is 0.028981279581785202: 100%|██████████| 32/32 [00:05<00:00,  5.94it/s]


Acc on val in epoch 19 is: 0.9940932512253362
Fianl metrics on test is:  {'acc': 0.7502532928064843, 'prec': 0.7217832957110609, 'sens': 1.0, 'spec': np.float64(0.2906474820143885), 'f1': 0.8384136348738118}
After filtering, we got: 228 records. Healthy: 80, Disease: 148
Load data from given path


Current loss in epoch 0 is 0.17844301462173462:  71%|███████   | 22/31 [00:04<00:01,  5.29it/s]


KeyboardInterrupt: 

In [23]:
import json
print(k_fold_acc)
save_dir = '/content/drive/MyDrive/logs'
os.makedirs(save_dir, exist_ok=True)

out_path = os.path.join(save_dir, 'k_fold_oldcnn_2.json')
with open(out_path, 'w') as f:
    json.dump(k_fold_acc, f, indent=2)

{0: {'acc': 0.9265175718849841, 'prec': 0.8923533778767632, 'sens': 0.986863711001642, 'spec': np.float64(0.8509763617677287), 'f1': 0.9372319688109162}, 1: {'acc': 0.8871454435727217, 'prec': 0.981169474727453, 'sens': 0.8549222797927462, 'spec': np.float64(0.9619238476953907), 'f1': 0.9137055837563451}, 2: {'acc': 0.9739301524840138, 'prec': 0.9678832116788321, 'sens': 0.9932584269662922, 'spec': np.float64(0.9369627507163324), 'f1': 0.9804066543438078}, 3: {'acc': 0.9199594731509625, 'prec': 0.8916841369671559, 'sens': 0.9976544175136826, 'spec': np.float64(0.7769784172661871), 'f1': 0.9416974169741698}, 4: {'acc': 0.9315992292870906, 'prec': 0.9608333333333333, 'sens': 0.9238782051282052, 'spec': np.float64(0.9432367149758454), 'f1': 0.9419934640522876}}


# Approximate the Compute and Memory Requirements

In [None]:
!pip install ptflops

Collecting ptflops
  Downloading ptflops-0.7.4-py3-none-any.whl.metadata (9.4 kB)
Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch>=2.0->ptflops)
  Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch>=2.0->ptflops)
  Downloading nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.4.127 (from torch>=2.0->ptflops)
  Downloading nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cudnn-cu12==9.1.0.70 (from torch>=2.0->ptflops)
  Downloading nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cublas-cu12==12.4.5.8 (from torch>=2.0->ptflops)
  Downloading nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cufft-cu12==11.2.1.3 (from torch>=2.0->ptflops)
  Downloading nvidia_

In [None]:
from ptflops import get_model_complexity_info
model_1 = CNN_2D()
model_2 = CNN_1D()
model_1.to(device)
model_2.to(device)
macs, params = get_model_complexity_info(model_1, (12, 5000), as_strings=True,print_per_layer_stat=False, verbose=False)
print(f"The 2D Conv based model requires {macs} macs and has {params} params")

macs, params = get_model_complexity_info(model_2, (12, 5000), as_strings=True,print_per_layer_stat=False, verbose=False)
print(f"The 1D Conv based model requires {macs} macs and has {params} params")


The 2D Conv based model requires 7.25 GMac macs and has 3.68 M params
The 1D Conv based model requires 132.71 MMac macs and has 1.24 M params
