# Model Processing

## Imports & General Settings 

In [1]:
import unittest
import os
import sys
import time
import pathlib

import matplotlib.pyplot as plt
import sklearn
import pandas as pd
from sklearn.metrics import classification_report
from sklearn.metrics import confusion_matrix
from sklearn import metrics
import numpy as np
import torch
import torch.nn.functional as F
from torch import nn
from torch.utils.data import Dataset, DataLoader
from torchvision.transforms import ToPILImage
import PIL
from tqdm.notebook import trange, tqdm
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split
from sklearn.metrics import roc_curve
from sklearn.metrics import roc_auc_score
from matplotlib import pyplot

# Our imports
from data import WaveletTransform, AFECGDataset
import dsp
from model.blocks import ConvNet, BRNN, SoftmaxAttention
from model.baseline import Baseline
from training import train, test
import utils


%matplotlib inline
%load_ext autoreload
%autoreload 2

In [2]:
# Assertion helper for ad-hoc inline shape checks. It is deliberately NOT named
# `test`: that name is bound to `training.test` (imported above), which the Testing
# section calls to evaluate the model; rebinding it to a TestCase made that call
# fail with "'TestCase' object is not callable".
test_case = unittest.TestCase()
plt.rcParams.update({'font.size': 12})
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('Using device:', device)

Using device: cpu


## Dataset creation

In [3]:
# Dataset wrapper for the 'afdb' recordings (presumably the MIT-BIH Atrial
# Fibrillation Database); the record files live under the relative path below.
data_root = '../data/files/'
dataset_name = 'afdb'
dataset = AFECGDataset(dataset_name, data_root)

In [4]:
# Prepare every sample of the dataset. The recorded output shows it walking the
# records and preparing 1397 samples, and reports skipping files "which had a
# backup" (presumably an on-disk cache of already-processed samples — confirm in data.py).
dataset.load()

23it [00:00, 217.71it/s]

Preparing 1397 samples


1397it [00:06, 212.71it/s]


Elapsed time: 6569.844961166382 ms
Skipped 1397 files which had a backup


In [None]:
# Peek at the first two samples and plot the first image of the second one.
# NOTE(review): the af/nsr names assume dataset[0] is an AF sample and dataset[1]
# a normal-sinus-rhythm one — TODO confirm by inspecting label_af / label_nsr.
(data_af, label_af), (data_nsr, label_nsr) = dataset[0], dataset[1]

t = data_nsr[0]
utils.show_spectrogram(t)

In [None]:
# Images per sample; passed to the BRNN constructor in the BRNN section
# (presumably the sequence length T — TODO confirm against model.blocks.BRNN).
images_per_sample = 20
total_data_size = len(dataset)
print("Total data size: ", total_data_size)
# Labels are 1 for AF, so their sum counts AF *samples*. Each patient record
# yields many samples (1397 samples from 23 records in the load output), so this
# is not a patient count.
print("Samples with AF: ", dataset.labels.sum().item())

In [None]:
# data = [dataset[i][0] for i in range(total_data_size)]
# labels = [dataset[i][1] for i in range(total_data_size)]

### Example of one ECG sample

In [None]:
# samples, label = data[0], labels[0]
# print('P-signal: ', samples)
# print('Has AF: ', 'Yes' if label == 1 else 'No')

In [None]:
# to_wavelet = WaveletTransform(wavelet.Morlet(6), resample=20)
# t = to_wavelet(data[0][0])
# image_test = (t * 100 * 255).int() # Simple visualization test
# transforms.ToPILImage()(image_test).show()

## Wavelet Transform

In [None]:
# Number of samples to work with; currently the full dataset
# (1397 samples when this notebook was last run).
# NOTE(review): `data_size` is not referenced by any later cell in this notebook.
data_size = len(dataset)

In [None]:
# x0 = x_train[0][0].float()
# encoder_cnn = ConvNet((375, 20))

# display(x0.unsqueeze(0).shape)
# h = encoder_cnn(x0.unsqueeze(0))
# print(h.shape)

# test.assertEqual(h.dim(), 2)
# test.assertSequenceEqual(h.shape, (1, 50))

## BRNN

In [None]:
# Build a BRNN only to inspect its layer structure. The two 50s presumably are the
# input and hidden feature sizes — TODO confirm against model.blocks.BRNN.
brnn_preview = BRNN(50, 50, images_per_sample)
display(brnn_preview)

## Attention

Notations:

* $Y = \left[ y_1, \ldots, y_T \right]$ – the input matrix of size $\left( N \times T \right)$, where $N$ is the number of features in a single output vector of the BRNN and $T$ is the number of time steps (i.e. the number of BRNN output vectors)

* $w_\mathrm{att}$ – the parameters of the attention model, of size $\left( N \times 1 \right)$, where $N$ is the number of features in a single output vector of the BRNN

* $\alpha$ – the attention weights, given as $\alpha = \mathrm{softmax} \left( w_\mathrm{att}^T Y \right)$. The softmax normalizes over the $T$ time steps (the entries of $\alpha$ are positive and sum to 1), so the output size of $\alpha$ is $\left( 1 \times T \right)$

* $h_\mathrm{att}$ – Output of the attention mechanism, given by $h_\mathrm{att} = Y \alpha^T$, of size $\left( N \times 1 \right)$, i.e. a vector of $N$ features.

In [None]:
# Instantiate the attention block to display its structure. The 100 is presumably
# N, the BRNN output feature count (e.g. 2 x 50 for a bidirectional RNN) — TODO confirm.
attention_preview = SoftmaxAttention(100)
attention_preview

## Baseline model

In [None]:
# Fresh baseline model for the forward-pass smoke test in the Training section.
# NOTE: the training cell constructs its own Baseline(), so this instance is not the one that gets trained.
model = Baseline()

### Training 

In [None]:
# Smoke test of the data pipeline and the model's forward pass.
data1, label1 = dataset[0]
data2, label2 = dataset[1]
# Batch two real samples (scaled by 10, as in the original exploration);
# stacking equals concatenating the unsqueezed samples along a new dim 0.
batch_data = torch.stack([data1, data2], dim=0) * 10
print(batch_data.shape)

# Forward a random batch of the shape used for training (batch_size=90) to check the
# model accepts it. It gets its own name so the real batch above is not silently
# overwritten. Autograd is disabled because no gradients are needed here.
random_batch = torch.rand((90, 20, 20, 375))
with torch.no_grad():
    output = model(random_batch)

In [None]:
# Plot the first two images of the first sample.
for image_index in (0, 1):
    utils.show_spectrogram(data1[image_index])

In [11]:
# Seed torch's global RNG so the 80/20 train/held-out split is reproducible across
# runs. This presumably also makes the Baseline weight initialisation and DataLoader
# shuffling repeatable, since both draw from the same RNG.
SPLIT_SEED = 0
torch.manual_seed(SPLIT_SEED)

heldout = int(len(dataset) * 0.2)
train_dataset, test_dataset = torch.utils.data.random_split(dataset, [len(dataset) - heldout, heldout])

model = Baseline()
config = dict(
    num_workers=8,        # DataLoader worker processes
    batch_size=90,
    learning_rate=0.005,
    weight_decay=0.02,
    num_epochs=200,
    is_notebook=True,     # notebook-style progress bars
)

train(model, train_dataset, config)

HBox(children=(FloatProgress(value=0.0, description='Epoch', max=200.0, style=ProgressStyle(description_width=…

HBox(children=(FloatProgress(value=0.0, description='Iteration', max=13.0, style=ProgressStyle(description_wid…

Loss:  tensor(0.7649, grad_fn=<NllLossBackward>)
Correct: 13
Loss:  tensor(0.4910, grad_fn=<NllLossBackward>)
Correct: 74
Loss:  tensor(0.4689, grad_fn=<NllLossBackward>)
Correct: 76
Loss:  tensor(0.4134, grad_fn=<NllLossBackward>)
Correct: 81
Loss:  tensor(0.4357, grad_fn=<NllLossBackward>)
Correct: 79
Loss:  tensor(0.4802, grad_fn=<NllLossBackward>)
Correct: 75
Loss:  tensor(0.4252, grad_fn=<NllLossBackward>)
Correct: 80
Loss:  tensor(0.4471, grad_fn=<NllLossBackward>)
Correct: 78
Loss:  tensor(0.3929, grad_fn=<NllLossBackward>)
Correct: 83
Loss:  tensor(0.4360, grad_fn=<NllLossBackward>)
Correct: 79
Loss:  tensor(0.4708, grad_fn=<NllLossBackward>)
Correct: 76
Loss:  tensor(0.4589, grad_fn=<NllLossBackward>)
Correct: 77
Loss:  tensor(0.3922, grad_fn=<NllLossBackward>)
Correct: 35

Epoch [1/200], Accuracy: 81.04%


HBox(children=(FloatProgress(value=0.0, description='Iteration', max=13.0, style=ProgressStyle(description_wid…

Loss:  tensor(0.4672, grad_fn=<NllLossBackward>)
Correct: 77
Loss:  tensor(0.4912, grad_fn=<NllLossBackward>)
Correct: 74
Loss:  tensor(0.4712, grad_fn=<NllLossBackward>)
Correct: 76
Loss:  tensor(0.4178, grad_fn=<NllLossBackward>)
Correct: 81
Loss:  tensor(0.4380, grad_fn=<NllLossBackward>)
Correct: 79
Loss:  tensor(0.4831, grad_fn=<NllLossBackward>)
Correct: 75
Loss:  tensor(0.4301, grad_fn=<NllLossBackward>)
Correct: 80
Loss:  tensor(0.4492, grad_fn=<NllLossBackward>)
Correct: 78
Loss:  tensor(0.3984, grad_fn=<NllLossBackward>)
Correct: 83
Loss:  tensor(0.4375, grad_fn=<NllLossBackward>)
Correct: 79
Loss:  tensor(0.4736, grad_fn=<NllLossBackward>)
Correct: 76
Loss:  tensor(0.4621, grad_fn=<NllLossBackward>)
Correct: 77
Loss:  tensor(0.3922, grad_fn=<NllLossBackward>)
Correct: 35

Epoch [2/200], Accuracy: 86.76%


HBox(children=(FloatProgress(value=0.0, description='Iteration', max=13.0, style=ProgressStyle(description_wid…

Loss:  tensor(0.4746, grad_fn=<NllLossBackward>)
Correct: 77
Loss:  tensor(0.4918, grad_fn=<NllLossBackward>)
Correct: 74
Loss:  tensor(0.4739, grad_fn=<NllLossBackward>)
Correct: 76
Loss:  tensor(0.4205, grad_fn=<NllLossBackward>)
Correct: 81
Loss:  tensor(0.4394, grad_fn=<NllLossBackward>)
Correct: 79
Loss:  tensor(0.4843, grad_fn=<NllLossBackward>)
Correct: 75
Loss:  tensor(0.4326, grad_fn=<NllLossBackward>)
Correct: 80
Loss:  tensor(0.4501, grad_fn=<NllLossBackward>)
Correct: 78
Loss:  tensor(0.4004, grad_fn=<NllLossBackward>)
Correct: 83
Loss:  tensor(0.4383, grad_fn=<NllLossBackward>)
Correct: 79
Loss:  tensor(0.4745, grad_fn=<NllLossBackward>)
Correct: 76
Loss:  tensor(0.4634, grad_fn=<NllLossBackward>)
Correct: 77
Loss:  tensor(0.3922, grad_fn=<NllLossBackward>)
Correct: 35

Epoch [3/200], Accuracy: 86.76%


HBox(children=(FloatProgress(value=0.0, description='Iteration', max=13.0, style=ProgressStyle(description_wid…

Loss:  tensor(0.4774, grad_fn=<NllLossBackward>)
Correct: 77
Loss:  tensor(0.4921, grad_fn=<NllLossBackward>)
Correct: 74
Loss:  tensor(0.4747, grad_fn=<NllLossBackward>)
Correct: 76
Loss:  tensor(0.4219, grad_fn=<NllLossBackward>)
Correct: 81
Loss:  tensor(0.4400, grad_fn=<NllLossBackward>)
Correct: 79
Loss:  tensor(0.4851, grad_fn=<NllLossBackward>)
Correct: 75
Loss:  tensor(0.4337, grad_fn=<NllLossBackward>)
Correct: 80
Loss:  tensor(0.4507, grad_fn=<NllLossBackward>)
Correct: 78
Loss:  tensor(0.4017, grad_fn=<NllLossBackward>)
Correct: 83
Loss:  tensor(0.4388, grad_fn=<NllLossBackward>)
Correct: 79
Loss:  tensor(0.4752, grad_fn=<NllLossBackward>)
Correct: 76
Loss:  tensor(0.4641, grad_fn=<NllLossBackward>)
Correct: 77
Loss:  tensor(0.3923, grad_fn=<NllLossBackward>)
Correct: 35

Epoch [4/200], Accuracy: 86.76%


HBox(children=(FloatProgress(value=0.0, description='Iteration', max=13.0, style=ProgressStyle(description_wid…

Loss:  tensor(0.4800, grad_fn=<NllLossBackward>)
Correct: 77
Loss:  tensor(0.4923, grad_fn=<NllLossBackward>)
Correct: 74
Loss:  tensor(0.4749, grad_fn=<NllLossBackward>)
Correct: 76
Loss:  tensor(0.4234, grad_fn=<NllLossBackward>)
Correct: 81
Loss:  tensor(0.4405, grad_fn=<NllLossBackward>)
Correct: 79
Loss:  tensor(0.4858, grad_fn=<NllLossBackward>)
Correct: 75
Loss:  tensor(0.4348, grad_fn=<NllLossBackward>)
Correct: 80
Loss:  tensor(0.4513, grad_fn=<NllLossBackward>)
Correct: 78
Loss:  tensor(0.4031, grad_fn=<NllLossBackward>)
Correct: 83
Loss:  tensor(0.4393, grad_fn=<NllLossBackward>)
Correct: 79
Loss:  tensor(0.4761, grad_fn=<NllLossBackward>)
Correct: 76
Loss:  tensor(0.4649, grad_fn=<NllLossBackward>)
Correct: 77
Loss:  tensor(0.3923, grad_fn=<NllLossBackward>)
Correct: 35

Epoch [5/200], Accuracy: 86.76%


HBox(children=(FloatProgress(value=0.0, description='Iteration', max=13.0, style=ProgressStyle(description_wid…

Loss:  tensor(0.4829, grad_fn=<NllLossBackward>)
Correct: 77
Loss:  tensor(0.4925, grad_fn=<NllLossBackward>)
Correct: 74
Loss:  tensor(0.4750, grad_fn=<NllLossBackward>)
Correct: 76
Loss:  tensor(0.4251, grad_fn=<NllLossBackward>)
Correct: 81
Loss:  tensor(0.4411, grad_fn=<NllLossBackward>)
Correct: 79
Loss:  tensor(0.4867, grad_fn=<NllLossBackward>)
Correct: 75
Loss:  tensor(0.4359, grad_fn=<NllLossBackward>)
Correct: 80
Loss:  tensor(0.4521, grad_fn=<NllLossBackward>)
Correct: 78
Loss:  tensor(0.4048, grad_fn=<NllLossBackward>)
Correct: 83
Loss:  tensor(0.4401, grad_fn=<NllLossBackward>)
Correct: 79
Loss:  tensor(0.4771, grad_fn=<NllLossBackward>)
Correct: 76
Loss:  tensor(0.4658, grad_fn=<NllLossBackward>)
Correct: 77
Loss:  tensor(0.3923, grad_fn=<NllLossBackward>)
Correct: 35

Epoch [6/200], Accuracy: 86.76%


HBox(children=(FloatProgress(value=0.0, description='Iteration', max=13.0, style=ProgressStyle(description_wid…

Loss:  tensor(0.4862, grad_fn=<NllLossBackward>)
Correct: 77
Loss:  tensor(0.4929, grad_fn=<NllLossBackward>)
Correct: 74
Loss:  tensor(0.4751, grad_fn=<NllLossBackward>)
Correct: 76
Loss:  tensor(0.4271, grad_fn=<NllLossBackward>)
Correct: 81
Loss:  tensor(0.4421, grad_fn=<NllLossBackward>)
Correct: 79
Loss:  tensor(0.4878, grad_fn=<NllLossBackward>)
Correct: 75
Loss:  tensor(0.4374, grad_fn=<NllLossBackward>)
Correct: 80
Loss:  tensor(0.4532, grad_fn=<NllLossBackward>)
Correct: 78
Loss:  tensor(0.4066, grad_fn=<NllLossBackward>)
Correct: 83
Loss:  tensor(0.4411, grad_fn=<NllLossBackward>)
Correct: 79
Loss:  tensor(0.4782, grad_fn=<NllLossBackward>)
Correct: 76
Loss:  tensor(0.4669, grad_fn=<NllLossBackward>)
Correct: 77
Loss:  tensor(0.3924, grad_fn=<NllLossBackward>)
Correct: 35

Epoch [7/200], Accuracy: 86.76%


HBox(children=(FloatProgress(value=0.0, description='Iteration', max=13.0, style=ProgressStyle(description_wid…

Loss:  tensor(0.4890, grad_fn=<NllLossBackward>)
Correct: 77
Loss:  tensor(0.4934, grad_fn=<NllLossBackward>)
Correct: 74
Loss:  tensor(0.4757, grad_fn=<NllLossBackward>)
Correct: 76
Loss:  tensor(0.4289, grad_fn=<NllLossBackward>)
Correct: 81
Loss:  tensor(0.4432, grad_fn=<NllLossBackward>)
Correct: 79
Loss:  tensor(0.4887, grad_fn=<NllLossBackward>)
Correct: 75
Loss:  tensor(0.4389, grad_fn=<NllLossBackward>)
Correct: 80
Loss:  tensor(0.4542, grad_fn=<NllLossBackward>)
Correct: 78
Loss:  tensor(0.4083, grad_fn=<NllLossBackward>)
Correct: 83
Loss:  tensor(0.4420, grad_fn=<NllLossBackward>)
Correct: 79
Loss:  tensor(0.4792, grad_fn=<NllLossBackward>)
Correct: 76
Loss:  tensor(0.4679, grad_fn=<NllLossBackward>)
Correct: 77
Loss:  tensor(0.3924, grad_fn=<NllLossBackward>)
Correct: 35

Epoch [8/200], Accuracy: 86.76%


HBox(children=(FloatProgress(value=0.0, description='Iteration', max=13.0, style=ProgressStyle(description_wid…

Loss:  tensor(0.4915, grad_fn=<NllLossBackward>)
Correct: 77
Loss:  tensor(0.4938, grad_fn=<NllLossBackward>)
Correct: 74
Loss:  tensor(0.4763, grad_fn=<NllLossBackward>)
Correct: 76
Loss:  tensor(0.4303, grad_fn=<NllLossBackward>)
Correct: 81
Loss:  tensor(0.4442, grad_fn=<NllLossBackward>)
Correct: 79
Loss:  tensor(0.4895, grad_fn=<NllLossBackward>)
Correct: 75
Loss:  tensor(0.4401, grad_fn=<NllLossBackward>)
Correct: 80
Loss:  tensor(0.4550, grad_fn=<NllLossBackward>)
Correct: 78
Loss:  tensor(0.4097, grad_fn=<NllLossBackward>)
Correct: 83
Loss:  tensor(0.4426, grad_fn=<NllLossBackward>)
Correct: 79
Loss:  tensor(0.4800, grad_fn=<NllLossBackward>)
Correct: 76
Loss:  tensor(0.4685, grad_fn=<NllLossBackward>)
Correct: 77
Loss:  tensor(0.3925, grad_fn=<NllLossBackward>)
Correct: 35

Epoch [9/200], Accuracy: 86.76%


HBox(children=(FloatProgress(value=0.0, description='Iteration', max=13.0, style=ProgressStyle(description_wid…

Loss:  tensor(0.4935, grad_fn=<NllLossBackward>)
Correct: 77
Loss:  tensor(0.4940, grad_fn=<NllLossBackward>)
Correct: 74
Loss:  tensor(0.4768, grad_fn=<NllLossBackward>)
Correct: 76
Loss:  tensor(0.4313, grad_fn=<NllLossBackward>)
Correct: 81
Loss:  tensor(0.4447, grad_fn=<NllLossBackward>)
Correct: 79
Loss:  tensor(0.4901, grad_fn=<NllLossBackward>)
Correct: 75
Loss:  tensor(0.4409, grad_fn=<NllLossBackward>)
Correct: 80
Loss:  tensor(0.4554, grad_fn=<NllLossBackward>)
Correct: 78
Loss:  tensor(0.4106, grad_fn=<NllLossBackward>)
Correct: 83
Loss:  tensor(0.4427, grad_fn=<NllLossBackward>)
Correct: 79
Loss:  tensor(0.4806, grad_fn=<NllLossBackward>)
Correct: 76
Loss:  tensor(0.4687, grad_fn=<NllLossBackward>)
Correct: 77
Loss:  tensor(0.3925, grad_fn=<NllLossBackward>)
Correct: 35

Epoch [10/200], Accuracy: 86.76%


HBox(children=(FloatProgress(value=0.0, description='Iteration', max=13.0, style=ProgressStyle(description_wid…

Loss:  tensor(0.4954, grad_fn=<NllLossBackward>)
Correct: 77
Loss:  tensor(0.4941, grad_fn=<NllLossBackward>)
Correct: 74
Loss:  tensor(0.4772, grad_fn=<NllLossBackward>)
Correct: 76
Loss:  tensor(0.4320, grad_fn=<NllLossBackward>)
Correct: 81
Loss:  tensor(0.4450, grad_fn=<NllLossBackward>)
Correct: 79
Loss:  tensor(0.4905, grad_fn=<NllLossBackward>)
Correct: 75
Loss:  tensor(0.4413, grad_fn=<NllLossBackward>)
Correct: 80
Loss:  tensor(0.4556, grad_fn=<NllLossBackward>)
Correct: 78
Loss:  tensor(0.4114, grad_fn=<NllLossBackward>)
Correct: 83
Loss:  tensor(0.4426, grad_fn=<NllLossBackward>)
Correct: 79
Loss:  tensor(0.4811, grad_fn=<NllLossBackward>)
Correct: 76
Loss:  tensor(0.4687, grad_fn=<NllLossBackward>)
Correct: 77
Loss:  tensor(0.3926, grad_fn=<NllLossBackward>)
Correct: 35

Epoch [11/200], Accuracy: 86.76%


HBox(children=(FloatProgress(value=0.0, description='Iteration', max=13.0, style=ProgressStyle(description_wid…

Loss:  tensor(0.4968, grad_fn=<NllLossBackward>)
Correct: 77
Loss:  tensor(0.4942, grad_fn=<NllLossBackward>)
Correct: 74
Loss:  tensor(0.4775, grad_fn=<NllLossBackward>)
Correct: 76
Loss:  tensor(0.4326, grad_fn=<NllLossBackward>)
Correct: 81
Loss:  tensor(0.4451, grad_fn=<NllLossBackward>)
Correct: 79
Loss:  tensor(0.4909, grad_fn=<NllLossBackward>)
Correct: 75
Loss:  tensor(0.4414, grad_fn=<NllLossBackward>)
Correct: 80
Loss:  tensor(0.4557, grad_fn=<NllLossBackward>)
Correct: 78
Loss:  tensor(0.4119, grad_fn=<NllLossBackward>)
Correct: 83
Loss:  tensor(0.4425, grad_fn=<NllLossBackward>)
Correct: 79
Loss:  tensor(0.4815, grad_fn=<NllLossBackward>)
Correct: 76
Loss:  tensor(0.4686, grad_fn=<NllLossBackward>)
Correct: 77
Loss:  tensor(0.3927, grad_fn=<NllLossBackward>)
Correct: 35

Epoch [12/200], Accuracy: 86.76%


HBox(children=(FloatProgress(value=0.0, description='Iteration', max=13.0, style=ProgressStyle(description_wid…

Loss:  tensor(0.4977, grad_fn=<NllLossBackward>)
Correct: 77
Loss:  tensor(0.4942, grad_fn=<NllLossBackward>)
Correct: 74
Loss:  tensor(0.4777, grad_fn=<NllLossBackward>)
Correct: 76
Loss:  tensor(0.4331, grad_fn=<NllLossBackward>)
Correct: 81
Loss:  tensor(0.4450, grad_fn=<NllLossBackward>)
Correct: 79
Loss:  tensor(0.4912, grad_fn=<NllLossBackward>)
Correct: 75
Loss:  tensor(0.4414, grad_fn=<NllLossBackward>)
Correct: 80
Loss:  tensor(0.4559, grad_fn=<NllLossBackward>)
Correct: 78
Loss:  tensor(0.4122, grad_fn=<NllLossBackward>)
Correct: 83
Loss:  tensor(0.4425, grad_fn=<NllLossBackward>)
Correct: 79
Loss:  tensor(0.4817, grad_fn=<NllLossBackward>)
Correct: 76
Loss:  tensor(0.4687, grad_fn=<NllLossBackward>)
Correct: 77
Loss:  tensor(0.3928, grad_fn=<NllLossBackward>)
Correct: 35

Epoch [13/200], Accuracy: 86.76%


HBox(children=(FloatProgress(value=0.0, description='Iteration', max=13.0, style=ProgressStyle(description_wid…

Loss:  tensor(0.4980, grad_fn=<NllLossBackward>)
Correct: 77
Loss:  tensor(0.4943, grad_fn=<NllLossBackward>)
Correct: 74
Loss:  tensor(0.4780, grad_fn=<NllLossBackward>)
Correct: 76
Loss:  tensor(0.4335, grad_fn=<NllLossBackward>)
Correct: 81
Loss:  tensor(0.4449, grad_fn=<NllLossBackward>)
Correct: 79
Loss:  tensor(0.4914, grad_fn=<NllLossBackward>)
Correct: 75
Loss:  tensor(0.4414, grad_fn=<NllLossBackward>)
Correct: 80
Loss:  tensor(0.4560, grad_fn=<NllLossBackward>)
Correct: 78
Loss:  tensor(0.4125, grad_fn=<NllLossBackward>)
Correct: 83
Loss:  tensor(0.4425, grad_fn=<NllLossBackward>)
Correct: 79
Loss:  tensor(0.4818, grad_fn=<NllLossBackward>)
Correct: 76
Loss:  tensor(0.4688, grad_fn=<NllLossBackward>)
Correct: 77
Loss:  tensor(0.3928, grad_fn=<NllLossBackward>)
Correct: 35

Epoch [14/200], Accuracy: 86.76%


HBox(children=(FloatProgress(value=0.0, description='Iteration', max=13.0, style=ProgressStyle(description_wid…

Loss:  tensor(0.4981, grad_fn=<NllLossBackward>)
Correct: 77
Loss:  tensor(0.4944, grad_fn=<NllLossBackward>)
Correct: 74
Loss:  tensor(0.4784, grad_fn=<NllLossBackward>)
Correct: 76
Loss:  tensor(0.4338, grad_fn=<NllLossBackward>)
Correct: 81
Loss:  tensor(0.4448, grad_fn=<NllLossBackward>)
Correct: 79
Loss:  tensor(0.4916, grad_fn=<NllLossBackward>)
Correct: 75
Loss:  tensor(0.4415, grad_fn=<NllLossBackward>)
Correct: 80
Loss:  tensor(0.4561, grad_fn=<NllLossBackward>)
Correct: 78
Loss:  tensor(0.4127, grad_fn=<NllLossBackward>)
Correct: 83
Loss:  tensor(0.4424, grad_fn=<NllLossBackward>)
Correct: 79
Loss:  tensor(0.4819, grad_fn=<NllLossBackward>)
Correct: 76
Loss:  tensor(0.4689, grad_fn=<NllLossBackward>)
Correct: 77
Loss:  tensor(0.3929, grad_fn=<NllLossBackward>)
Correct: 35

Epoch [15/200], Accuracy: 86.76%


HBox(children=(FloatProgress(value=0.0, description='Iteration', max=13.0, style=ProgressStyle(description_wid…

Loss:  tensor(0.4977, grad_fn=<NllLossBackward>)
Correct: 77
Loss:  tensor(0.4945, grad_fn=<NllLossBackward>)
Correct: 74
Loss:  tensor(0.4788, grad_fn=<NllLossBackward>)
Correct: 76
Loss:  tensor(0.4340, grad_fn=<NllLossBackward>)
Correct: 81
Loss:  tensor(0.4447, grad_fn=<NllLossBackward>)
Correct: 79
Loss:  tensor(0.4917, grad_fn=<NllLossBackward>)
Correct: 75
Loss:  tensor(0.4415, grad_fn=<NllLossBackward>)
Correct: 80
Loss:  tensor(0.4562, grad_fn=<NllLossBackward>)
Correct: 78
Loss:  tensor(0.4128, grad_fn=<NllLossBackward>)
Correct: 83
Loss:  tensor(0.4424, grad_fn=<NllLossBackward>)
Correct: 79
Loss:  tensor(0.4821, grad_fn=<NllLossBackward>)
Correct: 76
Loss:  tensor(0.4688, grad_fn=<NllLossBackward>)
Correct: 77
Loss:  tensor(0.3930, grad_fn=<NllLossBackward>)
Correct: 35

Epoch [16/200], Accuracy: 86.76%


HBox(children=(FloatProgress(value=0.0, description='Iteration', max=13.0, style=ProgressStyle(description_wid…

Loss:  tensor(0.4976, grad_fn=<NllLossBackward>)
Correct: 77
Loss:  tensor(0.4945, grad_fn=<NllLossBackward>)
Correct: 74
Loss:  tensor(0.4792, grad_fn=<NllLossBackward>)
Correct: 76
Loss:  tensor(0.4342, grad_fn=<NllLossBackward>)
Correct: 81
Loss:  tensor(0.4447, grad_fn=<NllLossBackward>)
Correct: 79
Loss:  tensor(0.4919, grad_fn=<NllLossBackward>)
Correct: 75
Loss:  tensor(0.4415, grad_fn=<NllLossBackward>)
Correct: 80
Loss:  tensor(0.4562, grad_fn=<NllLossBackward>)
Correct: 78
Loss:  tensor(0.4130, grad_fn=<NllLossBackward>)
Correct: 83
Loss:  tensor(0.4423, grad_fn=<NllLossBackward>)
Correct: 79
Loss:  tensor(0.4822, grad_fn=<NllLossBackward>)
Correct: 76
Loss:  tensor(0.4689, grad_fn=<NllLossBackward>)
Correct: 77
Loss:  tensor(0.3931, grad_fn=<NllLossBackward>)
Correct: 35

Epoch [17/200], Accuracy: 86.76%


HBox(children=(FloatProgress(value=0.0, description='Iteration', max=13.0, style=ProgressStyle(description_wid…

Loss:  tensor(0.4972, grad_fn=<NllLossBackward>)
Correct: 77
Loss:  tensor(0.4946, grad_fn=<NllLossBackward>)
Correct: 74
Loss:  tensor(0.4797, grad_fn=<NllLossBackward>)




Traceback (most recent call last):
  File "/Users/miki/opt/miniconda3/envs/cs236781-project/lib/python3.7/multiprocessing/queues.py", line 242, in _feed
    send_bytes(obj)
  File "/Users/miki/opt/miniconda3/envs/cs236781-project/lib/python3.7/multiprocessing/connection.py", line 200, in send_bytes
    self._send_bytes(m[offset:offset + size])
  File "/Users/miki/opt/miniconda3/envs/cs236781-project/lib/python3.7/multiprocessing/connection.py", line 404, in _send_bytes
    self._send(header + buf)
  File "/Users/miki/opt/miniconda3/envs/cs236781-project/lib/python3.7/multiprocessing/connection.py", line 368, in _send
    n = write(self._handle, buf)
BrokenPipeError: [Errno 32] Broken pipe
Traceback (most recent call last):
  File "/Users/miki/opt/miniconda3/envs/cs236781-project/lib/python3.7/multiprocessing/queues.py", line 242, in _feed
    send_bytes(obj)
  File "/Users/miki/opt/miniconda3/envs/cs236781-project/lib/python3.7/multiprocessing/connection.py", line 200, in send_bytes
  

KeyboardInterrupt: 

## Testing

In [None]:
# Ground-truth labels of the held-out split, in the order the split stores them,
# then evaluate the trained model on that split.
heldout_indices = test_dataset.indices
y_true = dataset.labels[heldout_indices]
y_pred, test_acc = test(model, test_dataset, config)

# Held-out size and number of AF (label 1) samples in it.
print(len(test_dataset))
print(dataset.labels[heldout_indices].sum().item())

In [None]:
# Per-class precision/recall/F1 (plus summary rows) as a DataFrame.
results = pd.DataFrame(classification_report(y_true, y_pred, zero_division=0, output_dict=True)).transpose()

# Pin labels to [0, 1] so the matrix is always 2x2 and unpacks into four values,
# even if the held-out labels and predictions happen to contain only one class.
tn, fp, fn, tp = confusion_matrix(y_true, y_pred, labels=[0, 1]).ravel()

# Specificity = TN / (TN + FP); undefined (NaN) when there are no negatives.
negatives = tn + fp
specificity = tn / negatives if negatives else float('nan')

# NOTE(review): y_pred holds hard class labels (it is valid input to
# classification_report/confusion_matrix), so this ROC has a single operating
# point. Per-sample scores/probabilities would give a meaningful curve and AUC.
fpr, tpr, thresholds = metrics.roc_curve(y_true, y_pred)
auc_score = metrics.auc(fpr, tpr)

In [None]:
# Report table first, then the scalar metrics that the report does not include.
display(results)
for metric_label, metric_value in (("Specificity:", specificity), ("AUC:", auc_score)):
    print(metric_label, metric_value)

In [None]:
# ROC of the baseline model on the held-out split.
# NOTE(review): y_pred holds hard class labels, so the curve has a single interior
# point; per-sample scores would give a proper ROC.
roc_auc = roc_auc_score(y_true, y_pred)
print('ROC AUC=%.3f' % (roc_auc))
baseline_fpr, baseline_tpr, _ = roc_curve(y_true, y_pred)

fig, ax = plt.subplots()
ax.plot(baseline_fpr, baseline_tpr, marker='.', label='Baseline model')
# Chance-level reference: a random classifier lies on the diagonal.
ax.plot([0, 1], [0, 1], linestyle='--', color='gray', label='Chance')
ax.set(title='ROC curve (held-out split)', xlabel='False Positive Rate', ylabel='True Positive Rate')
ax.legend()
plt.show()

In [None]:
from sklearn.metrics import average_precision_score

# Area under the precision-recall curve (average precision) on the held-out split.
# Previously this printed `specificity` under the "PR AUC" label.
pr_auc = average_precision_score(y_true, y_pred)
print("PR AUC:", pr_auc)