# Imports

In [None]:
from getpass import getpass

token = getpass('Enter your GitHub personal access token: ')
name = getpass('Enter your GitHub name: ')

!git clone https://{token}@github.com/{name}/comp_med.git
#%cd comp_med


Enter your GitHub personal access token: ··········
Enter your GitHub name: ··········
Cloning into 'comp_med'...
remote: Enumerating objects: 179, done.[K
remote: Counting objects: 100% (179/179), done.[K
remote: Compressing objects: 100% (145/145), done.[K
remote: Total 179 (delta 105), reused 81 (delta 32), pack-reused 0 (from 0)[K
Receiving objects: 100% (179/179), 1.02 MiB | 4.25 MiB/s, done.
Resolving deltas: 100% (105/105), done.


In [None]:
#!rm -r comp_med

## Make Code Deterministic for Reproducibility

In [3]:
import os
import random
import numpy as np
import torch
def set_seed(SEED):
  os.environ["PYTHONHASHSEED"] = str(SEED)
  random.seed(SEED)
  np.random.seed(SEED)
  torch.manual_seed(SEED)
  torch.cuda.manual_seed_all(SEED)

  torch.backends.cudnn.deterministic = True
  torch.backends.cudnn.benchmark = False

In [4]:
!pip install wfdb

Collecting wfdb
  Downloading wfdb-4.3.0-py3-none-any.whl.metadata (3.8 kB)
Collecting pandas>=2.2.3 (from wfdb)
  Downloading pandas-2.2.3-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (89 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m89.9/89.9 kB[0m [31m8.9 MB/s[0m eta [36m0:00:00[0m
Downloading wfdb-4.3.0-py3-none-any.whl (163 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m163.8/163.8 kB[0m [31m15.8 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading pandas-2.2.3-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (13.1 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m13.1/13.1 MB[0m [31m33.2 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: pandas, wfdb
  Attempting uninstall: pandas
    Found existing installation: pandas 2.2.2
    Uninstalling pandas-2.2.2:
      Successfully uninstalled pandas-2.2.2
[31mERROR: pip's dependency resolver does not currently take into accoun

In [5]:
from google.colab import drive
drive.mount('/content/drive')


Mounted at /content/drive


In [6]:
import numpy as np
import torch
import torch.optim as optim
from tqdm import tqdm
import torch.nn as nn
import json
from torch.utils.data import TensorDataset, DataLoader
from sklearn.metrics import (
    accuracy_score,
    precision_score,
    recall_score,
    f1_score,
    confusion_matrix,
)

#my imports
from comp_med.models.attentionCNN import CNN_1D
from comp_med.models.oldCNN import CNN_2D
from comp_med.data.preprocessing import get_dataloaders


## Save Function

In [42]:
def save_file(object_to_save = None, save_name=None, save_dir = '/content/drive/MyDrive/logs'):
  print(object_to_save)
  os.makedirs(save_dir, exist_ok=True)
  out_path = os.path.join(save_dir, save_name)
  with open(out_path, 'w') as f:
      json.dump(object_to_save, f, indent=2)

## Training Helpers

In [7]:
def train(model, criterion, optimizer, train_loader, val_loader, epochs, device="cpu", return_loss=False):
  model.to(device)
  loss_tracker = []
  for epoch in range(epochs):
    pbar = tqdm(train_loader, desc=f"Train the model in epoch {epoch}...")
    loss = 0
    for x,y in pbar:
      optimizer.zero_grad()
      x,y = x.to(device),y.to(device)
      out = model(x)
      out = torch.sigmoid(out)
      loss = criterion(out,y)
      loss.backward()
      optimizer.step()
      loss = loss.item()
      pbar.set_description(f"Current loss in epoch {epoch} is {loss}")
      loss_tracker.append(loss)
    #start validation
    acc = eval(model, val_loader, device)
    print(f"Acc on val in epoch {epoch} is: {acc}")
  if return_loss:
    return loss_tracker

def eval(model, data_loader, device="cpu", all_metrics=False):
    model.to(device)
    #start validation
    model.eval()
    preds = []
    labels = []
    with torch.no_grad():
      for x,y in data_loader:
        x,y = x.to(device), y.to(device)
        out = model(x)
        preds.append((torch.sigmoid(out) > 0.5).float())
        labels.append(y)

    model.train()
    preds = torch.cat(preds).cpu().numpy().ravel()
    labels = torch.cat(labels).cpu().numpy().ravel()
    acc = accuracy_score(labels, preds)
    prec = precision_score(labels, preds)
    sens = recall_score(labels, preds)
    tn, fp, fn, tp = confusion_matrix(labels, preds).ravel()
    spec = tn/(tn + fp)
    f1 = f1_score(labels, preds)
    if all_metrics:
      return {"acc":acc, "prec":prec, "sens":sens, "spec":spec, "f1":f1}
    else:
      return acc

# Verification Dummy Task

In [8]:
# create a dummy dataset to verify the architecture works
batch_size = 256
length = 5
samples_per_sec = 1000
num_leads = 12
num_samples = 500

t = np.arange(0, length, 1/samples_per_sec)
base_freq = 1

X = np.zeros((num_samples, num_leads, t.size))
labels = np.zeros((num_samples,))

for i in range(num_samples):
  label = i % 2
  for l in range(num_leads):
    amplitude = l+1
    X[i,l] = amplitude * np.sin( 2*np.pi * base_freq * t)
  if label == 1:
    #get a random lead
    lead_idx = np.random.randint(0, num_leads)
    #doulbe the frequency when label is 1
    X[i, lead_idx] = (lead_idx+1) * np.sin( 2*np.pi * base_freq*2 * t)

  #add some noise
  noise = np.random.normal(0, 0.1, size=X[i].shape)
  X[i] = X[i]+noise
  labels[i]= label

X_t = torch.from_numpy(X).float()
y_t = torch.from_numpy(labels).unsqueeze(1).float()

len = X_t.shape[0]
permutation = torch.randperm(len)
train_idx = permutation[:400]
test_idx =permutation[400:]

X_train, y_train = X_t[train_idx], y_t[train_idx]
X_test, y_test = X_t[test_idx] ,y_t[test_idx]
print(X_t.shape)
print(X_test.shape)
print(X_train.shape)
train_dataset = TensorDataset(X_train, y_train)
test_dataset = TensorDataset(X_test, y_test)

train_loader = DataLoader(train_dataset, batch_size = batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size = batch_size)

torch.Size([500, 12, 5000])
torch.Size([100, 12, 5000])
torch.Size([400, 12, 5000])


In [9]:
epochs = 25
lr = 1e-3
device = "cuda" if torch.cuda.is_available() else "cpu"

model = CNN_1D()
optimizer = optim.Adam(model.parameters(), lr=lr)
criterion = nn.BCELoss()
train(model, criterion, optimizer, train_loader, train_loader ,epochs, device)
print("Fianl acc on test is: ", eval(model, test_loader, device))

Current loss in epoch 0 is 0.7298088073730469: 100%|██████████| 2/2 [00:01<00:00,  1.03it/s]
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Acc on val in epoch 0 is: 0.4775


Current loss in epoch 1 is 0.40963423252105713: 100%|██████████| 2/2 [00:00<00:00,  7.58it/s]
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Acc on val in epoch 1 is: 0.4775


Current loss in epoch 2 is 0.12633655965328217: 100%|██████████| 2/2 [00:00<00:00,  7.53it/s]
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Acc on val in epoch 2 is: 0.4775


Current loss in epoch 3 is 0.05722875893115997: 100%|██████████| 2/2 [00:00<00:00,  7.60it/s]
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Acc on val in epoch 3 is: 0.4775


Current loss in epoch 4 is 0.03046264313161373: 100%|██████████| 2/2 [00:00<00:00,  7.58it/s]
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Acc on val in epoch 4 is: 0.4775


Current loss in epoch 5 is 0.028351323679089546: 100%|██████████| 2/2 [00:00<00:00,  7.57it/s]


Acc on val in epoch 5 is: 0.6325


Current loss in epoch 6 is 0.019088348373770714: 100%|██████████| 2/2 [00:00<00:00,  7.55it/s]


Acc on val in epoch 6 is: 0.6825


Current loss in epoch 7 is 0.016611231490969658: 100%|██████████| 2/2 [00:00<00:00,  7.54it/s]


Acc on val in epoch 7 is: 0.595


Current loss in epoch 8 is 0.013602720573544502: 100%|██████████| 2/2 [00:00<00:00,  7.52it/s]


Acc on val in epoch 8 is: 0.745


Current loss in epoch 9 is 0.012127259746193886: 100%|██████████| 2/2 [00:00<00:00,  7.56it/s]


Acc on val in epoch 9 is: 0.9575


Current loss in epoch 10 is 0.009833808057010174: 100%|██████████| 2/2 [00:00<00:00,  7.57it/s]


Acc on val in epoch 10 is: 1.0


Current loss in epoch 11 is 0.010784504935145378: 100%|██████████| 2/2 [00:00<00:00,  7.54it/s]


Acc on val in epoch 11 is: 1.0


Current loss in epoch 12 is 0.008468355052173138: 100%|██████████| 2/2 [00:00<00:00,  7.49it/s]


Acc on val in epoch 12 is: 1.0


Current loss in epoch 13 is 0.008336307480931282: 100%|██████████| 2/2 [00:00<00:00,  7.50it/s]


Acc on val in epoch 13 is: 1.0


Current loss in epoch 14 is 0.006228478625416756: 100%|██████████| 2/2 [00:00<00:00,  7.52it/s]


Acc on val in epoch 14 is: 1.0


Current loss in epoch 15 is 0.005703845992684364: 100%|██████████| 2/2 [00:00<00:00,  7.56it/s]


Acc on val in epoch 15 is: 1.0


Current loss in epoch 16 is 0.006311472505331039: 100%|██████████| 2/2 [00:00<00:00,  7.57it/s]


Acc on val in epoch 16 is: 1.0


Current loss in epoch 17 is 0.005264902487397194: 100%|██████████| 2/2 [00:00<00:00,  7.52it/s]


Acc on val in epoch 17 is: 1.0


Current loss in epoch 18 is 0.004639027174562216: 100%|██████████| 2/2 [00:00<00:00,  7.56it/s]


Acc on val in epoch 18 is: 1.0


Current loss in epoch 19 is 0.003905044635757804: 100%|██████████| 2/2 [00:00<00:00,  7.48it/s]


Acc on val in epoch 19 is: 1.0


Current loss in epoch 20 is 0.0038583772256970406: 100%|██████████| 2/2 [00:00<00:00,  7.53it/s]


Acc on val in epoch 20 is: 1.0


Current loss in epoch 21 is 0.005719500593841076: 100%|██████████| 2/2 [00:00<00:00,  7.52it/s]


Acc on val in epoch 21 is: 1.0


Current loss in epoch 22 is 0.00314587471075356: 100%|██████████| 2/2 [00:00<00:00,  7.46it/s]


Acc on val in epoch 22 is: 1.0


Current loss in epoch 23 is 0.0037121449131518602: 100%|██████████| 2/2 [00:00<00:00,  7.53it/s]


Acc on val in epoch 23 is: 1.0


Current loss in epoch 24 is 0.003171222750097513: 100%|██████████| 2/2 [00:00<00:00,  7.55it/s]


Acc on val in epoch 24 is: 1.0
Fianl acc on test is:  1.0


# Convergence Analysis of the Model

In [10]:
train_loader, val_loader, test_loader = get_dataloaders("/content/drive/MyDrive/ptbdb",preprocessed_data_path="/content/drive/MyDrive/ptbdb/preprocessed_data.pt", train_ratio=0.6, val_ratio=0.1)

After filtering, we got: 228 records. Healthy: 80, Disease: 148
Patients: train: 120 | val: 20 | test: 60
Load data from given path


In [11]:
epochs = 20
lr = 1e-3
device = "cuda" if torch.cuda.is_available() else "cpu"

In [12]:
set_seed(0)
model = CNN_1D()
optimizer = optim.Adam(model.parameters(), lr=lr)
criterion = nn.BCELoss()
losses = train(model, criterion, optimizer, train_loader, val_loader ,epochs, device, return_loss=True)
print("Fianl acc on test is: ", eval(model, test_loader, device))

Current loss in epoch 0 is 0.10174188017845154: 100%|██████████| 23/23 [00:05<00:00,  4.25it/s]


Acc on val in epoch 0 is: 0.6530782029950083


Current loss in epoch 1 is 0.0340830460190773: 100%|██████████| 23/23 [00:05<00:00,  4.25it/s]


Acc on val in epoch 1 is: 0.6589018302828619


Current loss in epoch 2 is 0.01336760725826025: 100%|██████████| 23/23 [00:05<00:00,  4.26it/s]


Acc on val in epoch 2 is: 0.7396006655574043


Current loss in epoch 3 is 0.004478707443922758: 100%|██████████| 23/23 [00:05<00:00,  4.24it/s]


Acc on val in epoch 3 is: 0.7138103161397671


Current loss in epoch 4 is 0.0085354745388031: 100%|██████████| 23/23 [00:05<00:00,  4.26it/s]


Acc on val in epoch 4 is: 0.7695507487520798


Current loss in epoch 5 is 0.005845550913363695: 100%|██████████| 23/23 [00:05<00:00,  4.12it/s]


Acc on val in epoch 5 is: 0.7321131447587355


Current loss in epoch 6 is 0.0013205426512286067: 100%|██████████| 23/23 [00:05<00:00,  4.23it/s]


Acc on val in epoch 6 is: 0.8452579034941764


Current loss in epoch 7 is 0.03889819607138634: 100%|██████████| 23/23 [00:05<00:00,  4.21it/s]


Acc on val in epoch 7 is: 0.7612312811980033


Current loss in epoch 8 is 0.045491546392440796: 100%|██████████| 23/23 [00:05<00:00,  4.23it/s]


Acc on val in epoch 8 is: 0.8202995008319468


Current loss in epoch 9 is 0.04615143686532974: 100%|██████████| 23/23 [00:05<00:00,  4.21it/s]


Acc on val in epoch 9 is: 0.8094841930116472


Current loss in epoch 10 is 0.01300874538719654: 100%|██████████| 23/23 [00:05<00:00,  4.20it/s]


Acc on val in epoch 10 is: 0.7936772046589018


Current loss in epoch 11 is 0.0006851897342130542: 100%|██████████| 23/23 [00:05<00:00,  4.15it/s]


Acc on val in epoch 11 is: 0.7013311148086523


Current loss in epoch 12 is 0.0017312794225290418: 100%|██████████| 23/23 [00:05<00:00,  4.20it/s]


Acc on val in epoch 12 is: 0.697171381031614


Current loss in epoch 13 is 0.0003950475074816495: 100%|██████████| 23/23 [00:05<00:00,  4.21it/s]


Acc on val in epoch 13 is: 0.699667221297837


Current loss in epoch 14 is 0.001493047340773046: 100%|██████████| 23/23 [00:05<00:00,  4.20it/s]


Acc on val in epoch 14 is: 0.7204658901830283


Current loss in epoch 15 is 0.0004947695997543633: 100%|██████████| 23/23 [00:05<00:00,  4.21it/s]


Acc on val in epoch 15 is: 0.7154742096505824


Current loss in epoch 16 is 0.00027321375091560185: 100%|██████████| 23/23 [00:05<00:00,  4.22it/s]


Acc on val in epoch 16 is: 0.7071547420965059


Current loss in epoch 17 is 0.0002552543010096997: 100%|██████████| 23/23 [00:05<00:00,  4.23it/s]


Acc on val in epoch 17 is: 0.7054908485856906


Current loss in epoch 18 is 0.00023154637892730534: 100%|██████████| 23/23 [00:05<00:00,  4.19it/s]


Acc on val in epoch 18 is: 0.7063227953410982


Current loss in epoch 19 is 0.00016317047993652523: 100%|██████████| 23/23 [00:05<00:00,  4.23it/s]


Acc on val in epoch 19 is: 0.7054908485856906
Fianl acc on test is:  0.9325691166554282


In [13]:
print(losses)

[0.6905691623687744, 0.786510705947876, 0.7059726715087891, 0.6001118421554565, 0.47714635729789734, 0.4209638833999634, 0.3082520067691803, 0.2662844657897949, 0.2587848901748657, 0.28281304240226746, 0.1781378537416458, 0.19684909284114838, 0.1342707872390747, 0.15646979212760925, 0.11259046941995621, 0.11482617259025574, 0.13019491732120514, 0.13390281796455383, 0.10805634409189224, 0.08212548494338989, 0.09083904325962067, 0.09195514023303986, 0.10174188017845154, 0.12215659022331238, 0.09455606341362, 0.06379201263189316, 0.08316996693611145, 0.0641225129365921, 0.08017149567604065, 0.06020168587565422, 0.07326832413673401, 0.07447246462106705, 0.048916351050138474, 0.038726530969142914, 0.03138140216469765, 0.049165062606334686, 0.06500668078660965, 0.03210264444351196, 0.03014521114528179, 0.04445832967758179, 0.02848118357360363, 0.03116218000650406, 0.024265948683023453, 0.021523119881749153, 0.021911637857556343, 0.0340830460190773, 0.017026258632540703, 0.030514370650053024,

# Sensitivity Analysis Towards Data Perturbation
We add random noise $\mathcal{N}(0,\sigma^2)$ to the input of the model and evaluate how robust the model is.

In [14]:
trials = 3
epochs = 20
lr = 1e-3
seeds = [0,1,2]

In [15]:
def eval_with_perturbation(model, data_loader, std_levels=[1e-3,5e-3,1e-2,5e-2,1e-1,5e-1,1], device="cpu"):
    accs = []
    for std in std_levels:
      model.to(device)
      #start validation
      model.eval()
      correct = 0
      total = 0
      with torch.no_grad():
        for x,y in data_loader:
          x,y = x.to(device), y.to(device)
          # add noise
          x = x + torch.randn_like(x) * std
          out = model(x)
          preds = (torch.sigmoid(out) > 0.5).float()
          correct += (preds == y).sum().item()
          total += y.shape[0]
      model.train()
      acc = correct / total
      accs.append(acc)
    return accs

In [16]:
trial_to_acc = {}
for trial,seed in zip(range(trials),seeds):
  set_seed(seed)
  model = CNN_1D()
  optimizer = optim.Adam(model.parameters(), lr=lr)
  criterion = nn.BCELoss()
  losses = train(model, criterion, optimizer, train_loader, val_loader ,epochs, device)
  accs = eval_with_perturbation(model, test_loader, device=device)
  trial_to_acc[trial] = accs

Current loss in epoch 0 is 0.10174188017845154: 100%|██████████| 23/23 [00:05<00:00,  4.24it/s]


Acc on val in epoch 0 is: 0.6530782029950083


Current loss in epoch 1 is 0.0340830460190773: 100%|██████████| 23/23 [00:05<00:00,  4.24it/s]


Acc on val in epoch 1 is: 0.6589018302828619


Current loss in epoch 2 is 0.01336760725826025: 100%|██████████| 23/23 [00:05<00:00,  4.22it/s]


Acc on val in epoch 2 is: 0.7396006655574043


Current loss in epoch 3 is 0.004478707443922758: 100%|██████████| 23/23 [00:05<00:00,  4.22it/s]


Acc on val in epoch 3 is: 0.7138103161397671


Current loss in epoch 4 is 0.0085354745388031: 100%|██████████| 23/23 [00:05<00:00,  4.24it/s]


Acc on val in epoch 4 is: 0.7695507487520798


Current loss in epoch 5 is 0.005845550913363695: 100%|██████████| 23/23 [00:05<00:00,  4.20it/s]


Acc on val in epoch 5 is: 0.7321131447587355


Train the model in epoch 6...:   0%|          | 0/23 [00:00<?, ?it/s]


KeyboardInterrupt: 

In [None]:
save_file(object_to_save=trial_to_acc, save_name="noise_perturbation.json")

# Sensetivity Analysis Towards Hyperparameters
We evaluate the sensitivity towards hyperparameters. We execute every experiment 3 times and report the average

In [None]:
seeds = [0,0,0] #always use the same random seed for every run

## Kernel Size
We evaluate different kernel sizes  $[50,100,200]$

In [17]:
epochs = 20
lr = 1e-3
device = "cuda" if torch.cuda.is_available() else "cpu"

In [18]:
kernel_sizes = [25,50,200]
kernel_to_acc = {}
seeds = [0,0,0]
for kernel_size, seed in zip(kernel_sizes,seeds):
  set_seed(seed)
  model = CNN_1D(kernel_size=kernel_size)
  optimizer = optim.Adam(model.parameters(), lr=lr)
  criterion = nn.BCELoss()
  losses = train(model, criterion, optimizer, train_loader, val_loader ,epochs, device, return_loss=False)
  acc=eval(model, test_loader, device)
  print("Fianl acc on test is: ",acc )
  kernel_to_acc[kernel_size] = acc

Current loss in epoch 0 is 0.36201754212379456:  61%|██████    | 14/23 [00:03<00:02,  4.13it/s]


KeyboardInterrupt: 

In [None]:

save_file(object_to_save=kernel_to_acc, save_name="kernel_size.json")

{25: 0.9086311530681052, 50: 0.9265003371544167, 200: 0.9369521240728254}


## Stride
We evaluate different strides $[25,50,100]$

In [19]:
strides = [25,100,200]
stride_to_acc = {}
seeds = [0,0,0]
for stride, seed in zip(strides,seeds):
  set_seed(seed)
  model = CNN_1D(stride = stride)
  optimizer = optim.Adam(model.parameters(), lr=lr)
  criterion = nn.BCELoss()
  losses = train(model, criterion, optimizer, train_loader, val_loader ,epochs, device, return_loss=False)
  acc=eval(model, test_loader, device)
  print("Fianl acc on test is: ",acc )
  stride_to_acc[stride] = acc

Current loss in epoch 0 is 0.21232058107852936:  48%|████▊     | 11/23 [00:04<00:05,  2.21it/s]


KeyboardInterrupt: 

In [None]:

save_file(object_to_save=stride_to_acc, save_name="strides.json")

{25: 0.897167902899528, 100: 0.9416722859069454, 200: 0.914025623735671}


## Attention Heads
We evaluate different amount of attention heads $[2,4,8]$

In [20]:
attention_heads = [1,2,10]
head_to_acc = {}
seeds = [0,0,0]
for heads,seed in zip(attention_heads,seeds):
  set_seed(seed)
  model = CNN_1D(attn_heads=heads)
  optimizer = optim.Adam(model.parameters(), lr=lr)
  criterion = nn.BCELoss()
  losses = train(model, criterion, optimizer, train_loader, val_loader ,epochs, device, return_loss=False)
  acc=eval(model, test_loader, device)
  print("Fianl acc on test is: ",acc )
  head_to_acc[heads] = acc

Current loss in epoch 0 is 0.22072342038154602:  52%|█████▏    | 12/23 [00:02<00:02,  5.08it/s]


KeyboardInterrupt: 

In [None]:

save_file(object_to_save=head_to_acc, save_name="attention_heads.json")

{10: 0.9204315576534052, 1: 0.928186109238031, 2: 0.928186109238831}


## Filter Size
We investigate different filter sizes $[10,20,40]$

In [21]:
filter_sizes = [4,12,40]
filter_to_acc = {}
seeds = [0,0,0]
for filter,seed in zip(filter_sizes,seeds):
  set_seed(seed)
  model = CNN_1D(filters=filter)
  optimizer = optim.Adam(model.parameters(), lr=lr)
  criterion = nn.BCELoss()
  losses = train(model, criterion, optimizer, train_loader, val_loader ,epochs, device, return_loss=False)
  acc=eval(model, test_loader, device)
  print("Fianl acc on test is: ",acc )
  filter_to_acc[filter] = acc

Current loss in epoch 0 is 0.5765625834465027:  74%|███████▍  | 17/23 [00:02<00:00,  7.31it/s]


KeyboardInterrupt: 

In [None]:

save_file(object_to_save=filter_to_acc, save_name="filters.json")

{4: 0.899527983816588, 12: 0.9359406608226568, 40: 0.9221173297370195}


## 3-Lead Ablation

In [23]:
train_loader, val_loader, test_loader = get_dataloaders("/content/drive/MyDrive/ptbdb",preprocessed_data_path="/content/drive/MyDrive/ptbdb/preprocessed_data_3leads.pt", desired_leads=['i','ii','iii'],train_ratio=0.6, val_ratio=0.1)

After filtering, we got: 228 records. Healthy: 80, Disease: 148
Patients: train: 120 | val: 20 | test: 60
Load data from given path


In [24]:
trial_to_acc = {}
seeds = [0,1,2]
trials = 3
lr = 0.001
epochs = 20
device = "cuda" if torch.cuda.is_available() else "cpu"
for trial,seed in zip(range(trials),seeds):
  set_seed(seed)
  model = CNN_1D(num_leads = 3)
  optimizer = optim.Adam(model.parameters(), lr=lr)
  criterion = nn.BCELoss()
  losses = train(model, criterion, optimizer, train_loader, val_loader ,epochs, device)
  accs = eval(model, test_loader, device=device)
  trial_to_acc[trial] = accs

Current loss in epoch 0 is 0.21026206016540527: 100%|██████████| 23/23 [00:02<00:00,  9.16it/s]


Acc on val in epoch 0 is: 0.6530782029950083


Current loss in epoch 1 is 0.18960320949554443: 100%|██████████| 23/23 [00:02<00:00,  9.24it/s]


Acc on val in epoch 1 is: 0.6530782029950083


Current loss in epoch 2 is 0.06585603207349777: 100%|██████████| 23/23 [00:02<00:00,  9.22it/s]


Acc on val in epoch 2 is: 0.6821963394342762


Current loss in epoch 3 is 0.040490951389074326: 100%|██████████| 23/23 [00:02<00:00,  9.24it/s]


Acc on val in epoch 3 is: 0.7096505823627288


Current loss in epoch 4 is 0.02664928510785103: 100%|██████████| 23/23 [00:02<00:00,  9.28it/s]


Acc on val in epoch 4 is: 0.7346089850249584


Current loss in epoch 5 is 0.01232950109988451: 100%|██████████| 23/23 [00:02<00:00,  9.29it/s]


Acc on val in epoch 5 is: 0.8352745424292846


Current loss in epoch 6 is 0.00760186742991209: 100%|██████████| 23/23 [00:02<00:00,  9.21it/s]


Acc on val in epoch 6 is: 0.8494176372712147


Current loss in epoch 7 is 0.006488289684057236: 100%|██████████| 23/23 [00:02<00:00,  9.24it/s]


Acc on val in epoch 7 is: 0.7645590682196339


Current loss in epoch 8 is 0.00986521691083908: 100%|██████████| 23/23 [00:02<00:00,  9.24it/s]


Acc on val in epoch 8 is: 0.829450915141431


Current loss in epoch 9 is 0.0047140903770923615: 100%|██████████| 23/23 [00:02<00:00,  9.23it/s]


Acc on val in epoch 9 is: 0.8610648918469218


Current loss in epoch 10 is 0.0034715929068624973: 100%|██████████| 23/23 [00:02<00:00,  9.24it/s]


Acc on val in epoch 10 is: 0.8386023294509152


Current loss in epoch 11 is 0.002896765945479274: 100%|██████████| 23/23 [00:02<00:00,  9.20it/s]


Acc on val in epoch 11 is: 0.8469217970049917


Current loss in epoch 12 is 0.0026915855705738068: 100%|██████████| 23/23 [00:02<00:00,  9.28it/s]


Acc on val in epoch 12 is: 0.8502495840266223


Current loss in epoch 13 is 0.004025511909276247: 100%|██████████| 23/23 [00:02<00:00,  9.22it/s]


Acc on val in epoch 13 is: 0.8768718801996672


Current loss in epoch 14 is 0.002844566246494651: 100%|██████████| 23/23 [00:02<00:00,  9.26it/s]


Acc on val in epoch 14 is: 0.8652246256239601


Current loss in epoch 15 is 0.0017299213213846087: 100%|██████████| 23/23 [00:02<00:00,  9.19it/s]


Acc on val in epoch 15 is: 0.867720465890183


Current loss in epoch 16 is 0.0016327719204127789: 100%|██████████| 23/23 [00:02<00:00,  9.26it/s]


Acc on val in epoch 16 is: 0.8718801996672213


Current loss in epoch 17 is 0.0023760602343827486: 100%|██████████| 23/23 [00:02<00:00,  9.24it/s]


Acc on val in epoch 17 is: 0.8843594009983361


Current loss in epoch 18 is 0.002012839773669839: 100%|██████████| 23/23 [00:02<00:00,  9.23it/s]


Acc on val in epoch 18 is: 0.846089850249584


Current loss in epoch 19 is 0.001412057550624013: 100%|██████████| 23/23 [00:02<00:00,  9.25it/s]


Acc on val in epoch 19 is: 0.8635607321131448


Current loss in epoch 0 is 0.29373472929000854: 100%|██████████| 23/23 [00:02<00:00,  9.22it/s]


Acc on val in epoch 0 is: 0.6530782029950083


Current loss in epoch 1 is 0.10629494488239288: 100%|██████████| 23/23 [00:02<00:00,  9.29it/s]


Acc on val in epoch 1 is: 0.6530782029950083


Current loss in epoch 2 is 0.049200646579265594: 100%|██████████| 23/23 [00:02<00:00,  9.25it/s]


Acc on val in epoch 2 is: 0.6871880199667221


Current loss in epoch 3 is 0.04368804767727852: 100%|██████████| 23/23 [00:02<00:00,  9.28it/s]


Acc on val in epoch 3 is: 0.8494176372712147


Current loss in epoch 4 is 0.024073699489235878: 100%|██████████| 23/23 [00:02<00:00,  9.16it/s]


Acc on val in epoch 4 is: 0.8860232945091514


Current loss in epoch 5 is 0.016197917982935905: 100%|██████████| 23/23 [00:02<00:00,  9.25it/s]


Acc on val in epoch 5 is: 0.6980033277870217


Current loss in epoch 6 is 0.008453449234366417: 100%|██████████| 23/23 [00:02<00:00,  9.26it/s]


Acc on val in epoch 6 is: 0.8635607321131448


Current loss in epoch 7 is 0.010017390362918377: 100%|██████████| 23/23 [00:02<00:00,  9.24it/s]


Acc on val in epoch 7 is: 0.8402662229617305


Current loss in epoch 8 is 0.016455333679914474: 100%|██████████| 23/23 [00:02<00:00,  9.21it/s]


Acc on val in epoch 8 is: 0.802828618968386


Current loss in epoch 9 is 0.010801819153130054: 100%|██████████| 23/23 [00:02<00:00,  9.18it/s]


Acc on val in epoch 9 is: 0.8053244592346089


Current loss in epoch 10 is 0.013165338896214962: 100%|██████████| 23/23 [00:02<00:00,  9.23it/s]


Acc on val in epoch 10 is: 0.9276206322795341


Current loss in epoch 11 is 0.007458628620952368: 100%|██████████| 23/23 [00:02<00:00,  9.22it/s]


Acc on val in epoch 11 is: 0.872712146422629


Current loss in epoch 12 is 0.011933313682675362: 100%|██████████| 23/23 [00:02<00:00,  9.24it/s]


Acc on val in epoch 12 is: 0.8186356073211315


Current loss in epoch 13 is 0.0036012723576277494: 100%|██████████| 23/23 [00:02<00:00,  9.14it/s]


Acc on val in epoch 13 is: 0.9043261231281198


Current loss in epoch 14 is 0.003511768765747547: 100%|██████████| 23/23 [00:02<00:00,  9.20it/s]


Acc on val in epoch 14 is: 0.872712146422629


Current loss in epoch 15 is 0.0015689053107053041: 100%|██████████| 23/23 [00:02<00:00,  9.20it/s]


Acc on val in epoch 15 is: 0.8951747088186356


Current loss in epoch 16 is 0.004451868124306202: 100%|██████████| 23/23 [00:02<00:00,  9.23it/s]


Acc on val in epoch 16 is: 0.8635607321131448


Current loss in epoch 17 is 0.0011909096501767635: 100%|██████████| 23/23 [00:02<00:00,  9.16it/s]


Acc on val in epoch 17 is: 0.9059900166389351


Current loss in epoch 18 is 0.0015271517913788557: 100%|██████████| 23/23 [00:02<00:00,  9.16it/s]


Acc on val in epoch 18 is: 0.8926788685524126


Current loss in epoch 19 is 0.0012027017073705792: 100%|██████████| 23/23 [00:02<00:00,  9.21it/s]


Acc on val in epoch 19 is: 0.889351081530782


Current loss in epoch 0 is 0.3383757770061493: 100%|██████████| 23/23 [00:02<00:00,  9.23it/s]


Acc on val in epoch 0 is: 0.6530782029950083


Current loss in epoch 1 is 0.15729448199272156: 100%|██████████| 23/23 [00:02<00:00,  9.15it/s]


Acc on val in epoch 1 is: 0.6281198003327787


Current loss in epoch 2 is 0.06048832833766937: 100%|██████████| 23/23 [00:02<00:00,  9.14it/s]


Acc on val in epoch 2 is: 0.6306156405990017


Current loss in epoch 3 is 0.03662187233567238: 100%|██████████| 23/23 [00:02<00:00,  9.19it/s]


Acc on val in epoch 3 is: 0.6713810316139767


Current loss in epoch 4 is 0.025121383368968964: 100%|██████████| 23/23 [00:02<00:00,  9.22it/s]


Acc on val in epoch 4 is: 0.7038269550748752


Current loss in epoch 5 is 0.013403444550931454: 100%|██████████| 23/23 [00:02<00:00,  9.25it/s]


Acc on val in epoch 5 is: 0.8851913477537438


Current loss in epoch 6 is 0.012103383429348469: 100%|██████████| 23/23 [00:02<00:00,  9.20it/s]


Acc on val in epoch 6 is: 0.7703826955074875


Current loss in epoch 7 is 0.006069548428058624: 100%|██████████| 23/23 [00:02<00:00,  9.23it/s]


Acc on val in epoch 7 is: 0.8302828618968386


Current loss in epoch 8 is 0.004905888345092535: 100%|██████████| 23/23 [00:02<00:00,  9.25it/s]


Acc on val in epoch 8 is: 0.8327787021630616


Current loss in epoch 9 is 0.005028172396123409: 100%|██████████| 23/23 [00:02<00:00,  9.22it/s]


Acc on val in epoch 9 is: 0.8286189683860233


Current loss in epoch 10 is 0.07164298743009567: 100%|██████████| 23/23 [00:02<00:00,  9.18it/s]


Acc on val in epoch 10 is: 0.8344425956738769


Current loss in epoch 11 is 0.07983115315437317: 100%|██████████| 23/23 [00:02<00:00,  9.17it/s]


Acc on val in epoch 11 is: 0.7113144758735441


Current loss in epoch 12 is 0.01433236338198185: 100%|██████████| 23/23 [00:02<00:00,  9.22it/s]


Acc on val in epoch 12 is: 0.8635607321131448


Current loss in epoch 13 is 0.004497188609093428: 100%|██████████| 23/23 [00:02<00:00,  9.19it/s]


Acc on val in epoch 13 is: 0.831946755407654


Current loss in epoch 14 is 0.003128932323306799: 100%|██████████| 23/23 [00:02<00:00,  9.23it/s]


Acc on val in epoch 14 is: 0.870216306156406


Current loss in epoch 15 is 0.002377311233431101: 100%|██████████| 23/23 [00:02<00:00,  9.19it/s]


Acc on val in epoch 15 is: 0.8876871880199667


Current loss in epoch 16 is 0.005788018926978111: 100%|██████████| 23/23 [00:02<00:00,  9.22it/s]


Acc on val in epoch 16 is: 0.872712146422629


Current loss in epoch 17 is 0.0017737727612257004: 100%|██████████| 23/23 [00:02<00:00,  9.26it/s]


Acc on val in epoch 17 is: 0.891846921797005


Current loss in epoch 18 is 0.0030544346664100885: 100%|██████████| 23/23 [00:02<00:00,  9.21it/s]


Acc on val in epoch 18 is: 0.8926788685524126


Current loss in epoch 19 is 0.001591993379406631: 100%|██████████| 23/23 [00:02<00:00,  9.22it/s]


Acc on val in epoch 19 is: 0.8968386023294509


In [None]:

save_file(object_to_save=trial_to_acc, save_name="3_lead_experiment.json")

{0: 0.9207687120701281, 1: 0.9238031018206339, 2: 0.9177343223196224}


In [28]:

print("Average is: ", (trial_to_acc[0] + trial_to_acc[1] + trial_to_acc[2]) / 3)

Average is:  0.9275118004045853


In [29]:
train_loader, val_loader, test_loader = get_dataloaders("/content/drive/MyDrive/ptbdb",preprocessed_data_path="/content/drive/MyDrive/ptbdb/preprocessed_data.pt",train_ratio=0.6, val_ratio=0.1)

After filtering, we got: 228 records. Healthy: 80, Disease: 148
Patients: train: 120 | val: 20 | test: 60
Load data from given path


In [30]:
trial_to_acc = {}
seeds = [0,1,2]
trials = 3
lr = 0.001
epochs = 20
device = "cuda" if torch.cuda.is_available() else "cpu"
for trial,seed in zip(range(trials),seeds):
  set_seed(seed)
  model = CNN_1D(num_leads = 12)
  optimizer = optim.Adam(model.parameters(), lr=lr)
  criterion = nn.BCELoss()
  losses = train(model, criterion, optimizer, train_loader, val_loader ,epochs, device)
  accs = eval(model, test_loader, device=device)
  trial_to_acc[trial] = accs

Current loss in epoch 0 is 0.10174188017845154: 100%|██████████| 23/23 [00:05<00:00,  4.21it/s]


Acc on val in epoch 0 is: 0.6530782029950083


Current loss in epoch 1 is 0.06020168587565422:  30%|███       | 7/23 [00:01<00:03,  4.08it/s]


KeyboardInterrupt: 

In [31]:

save_file(object_to_save=trial_to_acc, save_name="12_lead_experiment.json")

{}


In [32]:
print("Average is: ", (trial_to_acc[0] + trial_to_acc[1] + trial_to_acc[2]) / 3)

KeyError: 0

# 5-Fold Cross Validation
In this section we compare our model to the 2D CNN model on 5-Fold Cross Validation

In [33]:
seeds = [0,0,0]
#my imports
from comp_med.data.preprocessing import get_record_paths, filter_records, split_patients

#get patient data and split on my own
path="/content/drive/MyDrive/ptbdb"
records = get_record_paths(path)
filtered_records = filter_records(records)
#now obtain k-fold split
train_folds, test_folds = split_patients(records, k_fold=5)

After filtering, we got: 228 records. Healthy: 80, Disease: 148
k_fold cv has no validation set


In [34]:
epochs = 20
lr = 1e-3
device = "cuda" if torch.cuda.is_available() else "cpu"

In [35]:
k_fold_acc = {}
for i, train_ids, test_ids in zip(range(5),train_folds,test_folds):
  set_seed(0)
  train_loader,val_loader, test_loader = get_dataloaders("/content/drive/MyDrive/ptbdb",preprocessed_data_path="/content/drive/MyDrive/ptbdb/preprocessed_data.pt",train_ids=train_ids, test_ids = test_ids, val_ids=[],train_ratio=0.8, val_ratio=0)
  model = CNN_1D()
  optimizer = optim.Adam(model.parameters(), lr=lr)
  criterion = nn.BCELoss()
  losses = train(model, criterion, optimizer, train_loader, train_loader ,epochs, device, return_loss=False) #train loader as val because we don't have val laoder
  metrics=eval(model, test_loader, device, all_metrics=True)
  print("Fianl metrics on test is: ",metrics )
  k_fold_acc[i] = metrics
  del model, optimizer, criterion
  del train_loader, val_loader, test_loader
  torch.cuda.empty_cache()

After filtering, we got: 228 records. Healthy: 80, Disease: 148
Load data from given path


Current loss in epoch 0 is 0.13805902004241943: 100%|██████████| 31/31 [00:07<00:00,  4.26it/s]


Acc on val in epoch 0 is: 0.6649870801033592


Current loss in epoch 1 is 0.03806470334529877: 100%|██████████| 31/31 [00:07<00:00,  4.25it/s]


Acc on val in epoch 1 is: 0.9984496124031008


Current loss in epoch 2 is 0.006284705363214016: 100%|██████████| 31/31 [00:07<00:00,  4.22it/s]


Acc on val in epoch 2 is: 0.9828165374677003


Current loss in epoch 3 is 0.06308766454458237: 100%|██████████| 31/31 [00:07<00:00,  4.15it/s]


Acc on val in epoch 3 is: 0.9957364341085271


Current loss in epoch 4 is 0.006218987982720137: 100%|██████████| 31/31 [00:07<00:00,  4.17it/s]


Acc on val in epoch 4 is: 0.9981912144702842


Current loss in epoch 5 is 0.009393931366503239: 100%|██████████| 31/31 [00:07<00:00,  4.20it/s]


Acc on val in epoch 5 is: 0.9998708010335917


Current loss in epoch 6 is 0.01288131158798933: 100%|██████████| 31/31 [00:07<00:00,  4.20it/s]


Acc on val in epoch 6 is: 0.9943152454780362


Current loss in epoch 7 is 0.08377961069345474: 100%|██████████| 31/31 [00:07<00:00,  4.06it/s]


Acc on val in epoch 7 is: 0.9996124031007751


Current loss in epoch 8 is 0.0014467876171693206: 100%|██████████| 31/31 [00:07<00:00,  4.18it/s]


Acc on val in epoch 8 is: 0.9968992248062015


Current loss in epoch 9 is 0.0014182310551404953: 100%|██████████| 31/31 [00:07<00:00,  4.18it/s]


Acc on val in epoch 9 is: 0.9998708010335917


Current loss in epoch 10 is 0.0002068759931717068: 100%|██████████| 31/31 [00:07<00:00,  4.18it/s]


Acc on val in epoch 10 is: 0.9998708010335917


Current loss in epoch 11 is 0.07233093678951263: 100%|██████████| 31/31 [00:07<00:00,  4.20it/s]


Acc on val in epoch 11 is: 0.9748062015503876


Current loss in epoch 12 is 0.03987053409218788: 100%|██████████| 31/31 [00:07<00:00,  4.21it/s]


Acc on val in epoch 12 is: 0.9425064599483204


Current loss in epoch 13 is 0.0021505102049559355: 100%|██████████| 31/31 [00:07<00:00,  4.21it/s]


Acc on val in epoch 13 is: 0.9998708010335917


Current loss in epoch 14 is 0.00033421520492993295: 100%|██████████| 31/31 [00:07<00:00,  4.21it/s]


Acc on val in epoch 14 is: 1.0


Current loss in epoch 15 is 0.00513937184587121: 100%|██████████| 31/31 [00:07<00:00,  4.18it/s]


Acc on val in epoch 15 is: 1.0


Current loss in epoch 16 is 0.004501892253756523: 100%|██████████| 31/31 [00:07<00:00,  4.18it/s]


Acc on val in epoch 16 is: 1.0


Current loss in epoch 17 is 0.00018135250138584524: 100%|██████████| 31/31 [00:07<00:00,  4.20it/s]


Acc on val in epoch 17 is: 1.0


Current loss in epoch 18 is 0.0001603780110599473: 100%|██████████| 31/31 [00:07<00:00,  4.20it/s]


Acc on val in epoch 18 is: 1.0


Current loss in epoch 19 is 8.76116901054047e-05: 100%|██████████| 31/31 [00:07<00:00,  4.20it/s]


Acc on val in epoch 19 is: 1.0
Fianl metrics on test is:  {'acc': 0.9406663623916021, 'prec': 0.906576980568012, 'sens': 0.9958949096880131, 'spec': np.float64(0.8715313463514902), 'f1': 0.9491392801251957}
After filtering, we got: 228 records. Healthy: 80, Disease: 148
Load data from given path


Current loss in epoch 0 is 0.09422256052494049: 100%|██████████| 33/33 [00:07<00:00,  4.20it/s]


Acc on val in epoch 0 is: 0.6146966400773507


Current loss in epoch 1 is 0.01314585655927658: 100%|██████████| 33/33 [00:07<00:00,  4.20it/s]


Acc on val in epoch 1 is: 0.9978245105148659


Current loss in epoch 2 is 0.02754082717001438: 100%|██████████| 33/33 [00:07<00:00,  4.17it/s]


Acc on val in epoch 2 is: 0.9990331157843848


Current loss in epoch 3 is 0.0067322347313165665: 100%|██████████| 33/33 [00:07<00:00,  4.17it/s]


Acc on val in epoch 3 is: 0.991781484167271


Current loss in epoch 4 is 0.005830859299749136: 100%|██████████| 33/33 [00:07<00:00,  4.14it/s]


Acc on val in epoch 4 is: 0.9757070340826686


Current loss in epoch 5 is 0.0021013692021369934: 100%|██████████| 33/33 [00:07<00:00,  4.19it/s]


Acc on val in epoch 5 is: 0.9977036499879139


Current loss in epoch 6 is 0.00403080927208066: 100%|██████████| 33/33 [00:07<00:00,  4.18it/s]


Acc on val in epoch 6 is: 0.9980662315687696


Current loss in epoch 7 is 0.0039741480723023415: 100%|██████████| 33/33 [00:07<00:00,  4.20it/s]


Acc on val in epoch 7 is: 0.9991539763113367


Current loss in epoch 8 is 0.0007364654447883368: 100%|██████████| 33/33 [00:07<00:00,  4.18it/s]


Acc on val in epoch 8 is: 1.0


Current loss in epoch 9 is 0.00019686459563672543: 100%|██████████| 33/33 [00:07<00:00,  4.20it/s]


Acc on val in epoch 9 is: 1.0


Current loss in epoch 10 is 0.00014797848416492343: 100%|██████████| 33/33 [00:07<00:00,  4.22it/s]


Acc on val in epoch 10 is: 1.0


Current loss in epoch 11 is 0.00011316651216475293: 100%|██████████| 33/33 [00:07<00:00,  4.20it/s]


Acc on val in epoch 11 is: 1.0


Current loss in epoch 12 is 0.0001240532728843391: 100%|██████████| 33/33 [00:08<00:00,  4.09it/s]


Acc on val in epoch 12 is: 1.0


Current loss in epoch 13 is 0.00011658339644782245: 100%|██████████| 33/33 [00:07<00:00,  4.20it/s]


Acc on val in epoch 13 is: 1.0


Current loss in epoch 14 is 0.00019878061721101403: 100%|██████████| 33/33 [00:07<00:00,  4.15it/s]


Acc on val in epoch 14 is: 1.0


Current loss in epoch 15 is 0.0008450507884845138: 100%|██████████| 33/33 [00:07<00:00,  4.17it/s]


Acc on val in epoch 15 is: 1.0


Current loss in epoch 16 is 9.349943138659e-05: 100%|██████████| 33/33 [00:07<00:00,  4.17it/s]


Acc on val in epoch 16 is: 1.0


Current loss in epoch 17 is 5.2680181397590786e-05: 100%|██████████| 33/33 [00:07<00:00,  4.19it/s]


Acc on val in epoch 17 is: 1.0


Current loss in epoch 18 is 0.0008098270045593381: 100%|██████████| 33/33 [00:07<00:00,  4.18it/s]


Acc on val in epoch 18 is: 1.0


Current loss in epoch 19 is 5.075513035990298e-05: 100%|██████████| 33/33 [00:07<00:00,  4.14it/s]


Acc on val in epoch 19 is: 1.0
Fianl metrics on test is:  {'acc': 0.8823174411587206, 'prec': 0.9555345316934721, 'sens': 0.8721934369602763, 'spec': np.float64(0.905811623246493), 'f1': 0.9119638826185101}
After filtering, we got: 228 records. Healthy: 80, Disease: 148
Load data from given path


Current loss in epoch 0 is 0.057710837572813034: 100%|██████████| 31/31 [00:07<00:00,  4.10it/s]


Acc on val in epoch 0 is: 0.6223094454292226


Current loss in epoch 1 is 0.011198568157851696: 100%|██████████| 31/31 [00:07<00:00,  4.12it/s]


Acc on val in epoch 1 is: 0.8971891618131173


Current loss in epoch 2 is 0.029252244159579277: 100%|██████████| 31/31 [00:07<00:00,  4.12it/s]


Acc on val in epoch 2 is: 0.9922765257027095


Current loss in epoch 3 is 0.013498978689312935: 100%|██████████| 31/31 [00:07<00:00,  4.10it/s]


Acc on val in epoch 3 is: 0.9969612560141808


Current loss in epoch 4 is 0.0024352092295885086: 100%|██████████| 31/31 [00:07<00:00,  4.09it/s]


Acc on val in epoch 4 is: 0.9994935426690301


Current loss in epoch 5 is 0.0020600049756467342: 100%|██████████| 31/31 [00:07<00:00,  4.11it/s]


Acc on val in epoch 5 is: 1.0


Current loss in epoch 6 is 0.000935537158511579: 100%|██████████| 31/31 [00:07<00:00,  4.13it/s]


Acc on val in epoch 6 is: 1.0


Current loss in epoch 7 is 0.00046167417895048857:  84%|████████▍ | 26/31 [00:06<00:01,  4.06it/s]


KeyboardInterrupt: 

In [43]:
k_fold_acc = {0: {'acc': 0.9265175718849841, 'prec': 0.8923533778767632, 'sens': 0.986863711001642, 'spec': np.float64(0.8509763617677287), 'f1': 0.9372319688109162}, 1: {'acc': 0.8871454435727217, 'prec': 0.981169474727453, 'sens': 0.8549222797927462, 'spec': np.float64(0.9619238476953907), 'f1': 0.9137055837563451}, 2: {'acc': 0.9739301524840138, 'prec': 0.9678832116788321, 'sens': 0.9932584269662922, 'spec': np.float64(0.9369627507163324), 'f1': 0.9804066543438078}, 3: {'acc': 0.9199594731509625, 'prec': 0.8916841369671559, 'sens': 0.9976544175136826, 'spec': np.float64(0.7769784172661871), 'f1': 0.9416974169741698}, 4: {'acc': 0.9315992292870906, 'prec': 0.9608333333333333, 'sens': 0.9238782051282052, 'spec': np.float64(0.9432367149758454), 'f1': 0.9419934640522876}}
save_file(object_to_save=k_fold_acc, save_name="k_fold_oldcnn_2.json")

{0: {'acc': 0.9265175718849841, 'prec': 0.8923533778767632, 'sens': 0.986863711001642, 'spec': np.float64(0.8509763617677287), 'f1': 0.9372319688109162}, 1: {'acc': 0.8871454435727217, 'prec': 0.981169474727453, 'sens': 0.8549222797927462, 'spec': np.float64(0.9619238476953907), 'f1': 0.9137055837563451}, 2: {'acc': 0.9739301524840138, 'prec': 0.9678832116788321, 'sens': 0.9932584269662922, 'spec': np.float64(0.9369627507163324), 'f1': 0.9804066543438078}, 3: {'acc': 0.9199594731509625, 'prec': 0.8916841369671559, 'sens': 0.9976544175136826, 'spec': np.float64(0.7769784172661871), 'f1': 0.9416974169741698}, 4: {'acc': 0.9315992292870906, 'prec': 0.9608333333333333, 'sens': 0.9238782051282052, 'spec': np.float64(0.9432367149758454), 'f1': 0.9419934640522876}}


In [44]:
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


# Approximate the Compute and Memory Requirements

In [37]:
!pip install ptflops

Collecting ptflops
  Downloading ptflops-0.7.4-py3-none-any.whl.metadata (9.4 kB)
Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch>=2.0->ptflops)
  Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch>=2.0->ptflops)
  Downloading nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.4.127 (from torch>=2.0->ptflops)
  Downloading nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cudnn-cu12==9.1.0.70 (from torch>=2.0->ptflops)
  Downloading nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cublas-cu12==12.4.5.8 (from torch>=2.0->ptflops)
  Downloading nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cufft-cu12==11.2.1.3 (from torch>=2.0->ptflops)
  Downloading nvidia_

In [39]:
from ptflops import get_model_complexity_info
model_1 = CNN_2D()
model_2 = CNN_1D()
model_1.to(device)
model_2.to(device)
macs, params = get_model_complexity_info(model_1, (12, 5000), as_strings=True,print_per_layer_stat=False, verbose=False)
print(f"The 2D Conv based model requires {macs} macs and has {params} params")

macs, params = get_model_complexity_info(model_2, (12, 5000), as_strings=True,print_per_layer_stat=False, verbose=False)
print(f"The 1D Conv based model requires {macs} macs and has {params} params")


The 2D Conv based model requires 7.25 GMac macs and has 3.68 M params
The 1D Conv based model requires 132.71 MMac macs and has 1.24 M params
