# Model Processing

## Imports & General Settings 

In [1]:
import unittest
import os
import sys
import time
import pathlib

import matplotlib.pyplot as plt
import sklearn
import pandas as pd
import numpy as np
import torch
import torch.nn.functional as F
from torch import nn
from torch.utils.data import Dataset, DataLoader
from torchvision.transforms import ToPILImage
import PIL
from tqdm.notebook import trange, tqdm
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, confusion_matrix, average_precision_score, roc_curve, roc_auc_score

from matplotlib import pyplot

# Our imports
from data import WaveletTransform, AFECGDataset
import dsp
from model.blocks import ConvNet, BRNN, SoftmaxAttention
from model.baseline import Baseline
from training import train, test
import utils


%matplotlib inline
%load_ext autoreload
%autoreload 2

In [2]:
# Helper exposing unittest-style assertions for ad-hoc sanity checks.
# It is deliberately NOT named `test`: that name is the evaluation function
# imported from `training`, and the "Testing" section calls it as
# `test(model, test_dataset, config)`. Rebinding `test` to a TestCase instance
# here made that later call fail.
test_case = unittest.TestCase()

# Global matplotlib font size for every figure in the notebook.
plt.rcParams.update({'font.size': 12})

# Use CUDA when PyTorch reports it as available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('Using device:', device)

Using device: cpu


## Dataset creation

In [3]:
# Dataset identifier and the directory (relative to the notebook's working
# directory) handed to AFECGDataset as its second argument.
# NOTE(review): 'afdb' presumably refers to the MIT-BIH Atrial Fibrillation
# Database — confirm against the `data` module.
dataset_name = 'afdb'
dataset = AFECGDataset(dataset_name, '../data/files/')

In [4]:
# Populate the dataset object. In the recorded run this prepared 1397 samples
# and reported that all 1397 files were skipped because a backup already existed.
dataset.load()

36it [00:00, 358.35it/s]

Preparing 1397 samples


1397it [00:05, 235.85it/s]


Elapsed time: 5925.168991088867 ms
Skipped 1397 files which had a backup


In [None]:
# Take the first two dataset items as (data, label) pairs.
# NOTE(review): the `_af` / `_nsr` suffixes assume item 0 is an AF sample and
# item 1 a normal-sinus-rhythm sample; the labels are not checked here — confirm.
data_af, label_af = dataset[0]
data_nsr, label_nsr = dataset[1]

# Visualise the first element of the second item.
# NOTE(review): presumably a single spectrogram image of that sample — confirm.
t = data_nsr[0]
utils.show_spectrogram(t)

In [None]:
# NOTE(review): presumably the number of spectrogram images that make up one
# sample — confirm against AFECGDataset.
images_per_sample = 20

# Dataset size and the sum of its labels.
# NOTE(review): labels appear to be stored per dataset item, so the second
# figure likely counts AF-labelled samples rather than patients — confirm.
total_data_size = len(dataset)
af_label_sum = dataset.labels.sum().item()

print("Total data size: ", total_data_size)
print("Patients with AF: ", af_label_sum)

In [None]:
# data = [dataset[i][0] for i in range(total_data_size)]
# labels = [dataset[i][1] for i in range(total_data_size)]

### Example of one ECG sample

In [None]:
# samples, label = data[0], labels[0]
# print('P-signal: ', samples)
# print('Has AF: ', 'Yes' if label == 1 else 'No')

In [None]:
# to_wavelet = WaveletTransform(wavelet.Morlet(6), resample=20)
# t = to_wavelet(data[0][0])
# image_test = (t * 100 * 255).int() # Simple visualization test
# transforms.ToPILImage()(image_test).show()

##  Wavelet Transform

In [None]:
# Number of dataset items to work with; currently the full dataset
# (1397 samples in the recorded run).
# NOTE(review): `data_size` is not referenced by any later cell in this notebook.
data_size = len(dataset)

In [None]:
# x0 = x_train[0][0].float()
# encoder_cnn = ConvNet((375, 20))

# display(x0.unsqueeze(0).shape)
# h = encoder_cnn(x0.unsqueeze(0))
# print(h.shape)

# test.assertEqual(h.dim(), 2)
# test.assertSequenceEqual(h.shape, (1, 50))

## BRNN

In [None]:
# Instantiate a BRNN and display its repr as a quick structural check.
# Depends on `images_per_sample` defined in the dataset-inspection cell above.
# NOTE(review): the two 50s are presumably the input and hidden feature sizes —
# confirm against model.blocks.BRNN.
display(BRNN(50, 50, images_per_sample))

## Attention

Notations:

* $Y = \left[ y_1, \ldots, y_T \right]$ – the input matrix of size $\left( N \times T \right)$, where $N$ is the number of features in a single output vector of the BRNN

* $w_\mathrm{att}$ – The parameters of the attention model, of size $\left( N \times 1 \right)$, where $N$ is the number of features in a single output vector of the BRNN

* $\alpha$ – The attention weights, given as $\alpha = \mathrm{softmax} \left( w_\mathrm{att}^T Y \right)$. The softmax is taken across the $T$ entries of $w_\mathrm{att}^T Y$ (one per column of $Y$), so $\alpha$ has size $\left( 1 \times T \right)$ and its entries sum to 1

* $h_\mathrm{att}$ – Output of the attention mechanism, given by $h_\mathrm{att} = Y \alpha^T$, of size $\left( N \times 1 \right)$, i.e. a vector of $N$ features.

In [None]:
# Build a SoftmaxAttention module; as the cell's last expression, its repr is shown.
# NOTE(review): 100 is presumably N, the feature size of a single BRNN output
# vector in the notation above — confirm against model.blocks.SoftmaxAttention.
SoftmaxAttention(100)

## Training

In [None]:
# Baseline model built with its default arguments.
# Note: `model` is rebound by both training cells further down (CNN and Baseline).
model = Baseline()

### CNN

In [None]:
# Show the first element of dataset item 0 as a spectrogram.
data1, label1 = dataset[0]
utils.show_spectrogram(data1[0])

In [13]:
# Hold out 20% of the dataset (rounded down) for evaluation; train on the rest.
heldout = int(len(dataset) * 0.2)
split_sizes = [len(dataset) - heldout, heldout]
train_dataset, test_dataset = torch.utils.data.random_split(dataset, split_sizes)

# ConvNet encoder followed by a linear layer mapping 50 features to 2 outputs.
model = nn.Sequential(
    ConvNet(size=(375, 20)),
    nn.Linear(50, 2),
)

# Settings handed to `train`.
config = {
    'num_workers': 8,
    'batch_size': 90,
    'learning_rate': 0.001,
    'weight_decay': 0.01,
    'num_epochs': 200,
    'is_notebook': True,
}

train(model, train_dataset, config)

HBox(children=(FloatProgress(value=0.0, description='Epoch', max=200.0, style=ProgressStyle(description_width=…

HBox(children=(FloatProgress(value=0.0, description='Iteration', max=13.0, style=ProgressStyle(description_wid…

torch.Size([90, 20, 375])
Loss:  tensor(0.6827, grad_fn=<NllLossBackward>)
Output: tensor([[-0.0524, -0.0870],
        [-0.0494, -0.0965],
        [-0.0521, -0.0965],
        [-0.0475, -0.0901],
        [-0.0344, -0.0886],
        [-0.0489, -0.0961],
        [-0.0342, -0.0994],
        [-0.0527, -0.0960],
        [-0.0449, -0.1113],
        [-0.0332, -0.0958],
        [-0.0374, -0.1049],
        [-0.0506, -0.0961],
        [-0.0495, -0.0823],
        [-0.0415, -0.0967],
        [-0.0545, -0.1032],
        [-0.0478, -0.0911],
        [-0.0516, -0.0960],
        [-0.0472, -0.1069],
        [-0.0602, -0.0915],
        [-0.0448, -0.0836],
        [-0.0465, -0.1102],
        [-0.0451, -0.1063],
        [-0.0509, -0.0968],
        [-0.0606, -0.1057],
        [-0.0450, -0.0968],
        [-0.0423, -0.0958],
        [-0.0494, -0.0953],
        [-0.0401, -0.0594],
        [-0.0492, -0.1093],
        [-0.0495, -0.0908],
        [-0.0473, -0.0848],
        [-0.0401, -0.1038],
        [-0.0468, -0.

Output: tensor([[ 0.2912, -0.7031],
        [ 0.2938, -0.6840],
        [ 0.2448, -0.6383],
        [ 0.2363, -0.6187],
        [ 0.1519, -0.4346],
        [ 0.2376, -0.6190],
        [ 0.2075, -0.5687],
        [ 0.2971, -0.7233],
        [ 0.2971, -0.7134],
        [ 0.2350, -0.6182],
        [ 0.2592, -0.6635],
        [ 0.2383, -0.6124],
        [ 0.3148, -0.7518],
        [ 0.2103, -0.5760],
        [ 0.2443, -0.6327],
        [ 0.2942, -0.7191],
        [ 0.3120, -0.7140],
        [ 0.2896, -0.7058],
        [ 0.2604, -0.6533],
        [ 0.2964, -0.7233],
        [ 0.3028, -0.7210],
        [ 0.2560, -0.6521],
        [ 0.2415, -0.6330],
        [ 0.2164, -0.5534],
        [ 0.2960, -0.6914],
        [ 0.2841, -0.7003],
        [ 0.1764, -0.4752],
        [ 0.1178, -0.3867],
        [ 0.2153, -0.5781],
        [ 0.3041, -0.7343],
        [ 0.2797, -0.6950],
        [ 0.3107, -0.7447],
        [ 0.2534, -0.6528],
        [ 0.1915, -0.5233],
        [ 0.2470, -0.6411],
        [ 0.

Output: tensor([[ 0.3973, -0.8354],
        [ 0.3627, -0.7891],
        [ 0.3894, -0.8320],
        [ 0.2799, -0.6466],
        [ 0.4035, -0.8532],
        [ 0.3480, -0.7640],
        [ 0.3643, -0.7898],
        [ 0.3219, -0.7166],
        [ 0.3471, -0.7633],
        [ 0.4040, -0.8540],
        [ 0.3702, -0.7975],
        [ 0.4193, -0.8747],
        [ 0.3877, -0.8187],
        [ 0.3851, -0.8176],
        [ 0.3876, -0.8273],
        [ 0.3297, -0.7377],
        [ 0.3307, -0.7326],
        [ 0.3351, -0.7421],
        [ 0.2443, -0.5952],
        [ 0.3549, -0.7692],
        [ 0.3811, -0.8071],
        [ 0.2762, -0.6433],
        [ 0.3523, -0.7684],
        [ 0.3333, -0.7412],
        [ 0.3916, -0.8344],
        [ 0.3249, -0.7219],
        [ 0.3622, -0.7855],
        [ 0.3876, -0.8283],
        [ 0.2387, -0.5844],
        [ 0.2278, -0.5612],
        [ 0.2819, -0.6536],
        [ 0.3956, -0.8364],
        [ 0.3882, -0.8192],
        [ 0.3959, -0.8406],
        [ 0.4025, -0.8511],
        [ 0.

HBox(children=(FloatProgress(value=0.0, description='Iteration', max=13.0, style=ProgressStyle(description_wid…

torch.Size([90, 20, 375])
Loss:  tensor(0.5859, grad_fn=<NllLossBackward>)
Output: tensor([[ 0.1767, -0.4596],
        [ 0.2510, -0.5666],
        [ 0.2802, -0.6115],
        [ 0.1840, -0.4687],
        [ 0.1824, -0.4586],
        [ 0.2729, -0.6004],
        [ 0.2631, -0.5843],
        [ 0.2788, -0.6099],
        [ 0.3257, -0.6723],
        [ 0.2659, -0.5897],
        [ 0.3028, -0.6434],
        [ 0.2790, -0.6100],
        [ 0.1653, -0.4351],
        [ 0.2164, -0.5194],
        [ 0.2462, -0.5596],
        [ 0.2294, -0.5374],
        [ 0.2752, -0.6047],
        [ 0.2798, -0.6115],
        [ 0.2115, -0.5076],
        [ 0.2528, -0.5726],
        [ 0.3187, -0.6648],
        [ 0.2688, -0.5941],
        [ 0.2791, -0.6094],
        [ 0.1596, -0.4272],
        [ 0.2405, -0.5473],
        [ 0.2382, -0.5445],
        [ 0.2560, -0.5766],
        [ 0.1624, -0.4310],
        [ 0.2769, -0.6023],
        [ 0.1836, -0.4652],
        [ 0.2013, -0.4929],
        [ 0.2789, -0.6043],
        [ 0.2807, -0.

Output: tensor([[ 0.1995, -0.4403],
        [ 0.2459, -0.5054],
        [ 0.1763, -0.4101],
        [ 0.1765, -0.4116],
        [ 0.1355, -0.3492],
        [ 0.1843, -0.4219],
        [ 0.1545, -0.3790],
        [ 0.1980, -0.4386],
        [ 0.2262, -0.4815],
        [ 0.1711, -0.4034],
        [ 0.1852, -0.4226],
        [ 0.1753, -0.4071],
        [ 0.2018, -0.4439],
        [ 0.1615, -0.3870],
        [ 0.1794, -0.4157],
        [ 0.1946, -0.4336],
        [ 0.2371, -0.4908],
        [ 0.1982, -0.4395],
        [ 0.1985, -0.4400],
        [ 0.1970, -0.4377],
        [ 0.2162, -0.4662],
        [ 0.1832, -0.4171],
        [ 0.1715, -0.4031],
        [ 0.1730, -0.4008],
        [ 0.2435, -0.5039],
        [ 0.1938, -0.4341],
        [ 0.1678, -0.3941],
        [ 0.0986, -0.2999],
        [ 0.1736, -0.4039],
        [ 0.2013, -0.4435],
        [ 0.1853, -0.4218],
        [ 0.1997, -0.4406],
        [ 0.1779, -0.4106],
        [ 0.1464, -0.3646],
        [ 0.1787, -0.4125],
        [ 0.

Output: tensor([[ 0.4285, -0.7988],
        [ 0.3867, -0.7443],
        [ 0.3963, -0.7551],
        [ 0.3171, -0.6377],
        [ 0.3975, -0.7548],
        [ 0.3686, -0.7156],
        [ 0.3891, -0.7457],
        [ 0.3598, -0.7016],
        [ 0.3698, -0.7186],
        [ 0.4016, -0.7607],
        [ 0.3867, -0.7439],
        [ 0.4236, -0.7937],
        [ 0.4342, -0.8076],
        [ 0.3897, -0.7416],
        [ 0.3947, -0.7514],
        [ 0.3546, -0.6979],
        [ 0.3767, -0.7287],
        [ 0.3630, -0.7095],
        [ 0.2853, -0.5957],
        [ 0.3880, -0.7407],
        [ 0.4243, -0.7915],
        [ 0.3146, -0.6358],
        [ 0.3945, -0.7536],
        [ 0.3566, -0.6983],
        [ 0.3959, -0.7538],
        [ 0.3684, -0.7158],
        [ 0.3885, -0.7463],
        [ 0.3930, -0.7495],
        [ 0.2802, -0.5881],
        [ 0.2305, -0.5093],
        [ 0.3073, -0.6253],
        [ 0.4072, -0.7696],
        [ 0.4361, -0.8113],
        [ 0.4195, -0.7906],
        [ 0.3990, -0.7575],
        [ 0.

HBox(children=(FloatProgress(value=0.0, description='Iteration', max=13.0, style=ProgressStyle(description_wid…

torch.Size([90, 20, 375])
Loss:  tensor(0.5832, grad_fn=<NllLossBackward>)
Output: tensor([[ 0.2006, -0.4406],
        [ 0.2767, -0.5359],
        [ 0.2853, -0.5438],
        [ 0.2157, -0.4596],
        [ 0.2421, -0.4932],
        [ 0.2822, -0.5406],
        [ 0.3246, -0.6067],
        [ 0.2835, -0.5419],
        [ 0.3634, -0.6479],
        [ 0.2993, -0.5690],
        [ 0.3459, -0.6312],
        [ 0.2856, -0.5447],
        [ 0.2114, -0.4514],
        [ 0.2495, -0.5062],
        [ 0.2979, -0.5701],
        [ 0.2514, -0.5037],
        [ 0.2824, -0.5404],
        [ 0.3192, -0.5983],
        [ 0.2388, -0.4880],
        [ 0.2876, -0.5569],
        [ 0.3508, -0.6332],
        [ 0.2980, -0.5652],
        [ 0.2839, -0.5421],
        [ 0.2082, -0.4520],
        [ 0.2854, -0.5488],
        [ 0.2870, -0.5492],
        [ 0.2769, -0.5372],
        [ 0.2068, -0.4487],
        [ 0.3148, -0.5855],
        [ 0.2118, -0.4518],
        [ 0.2397, -0.4900],
        [ 0.3218, -0.5939],
        [ 0.2906, -0.

Output: tensor([[ 0.2113, -0.4038],
        [ 0.3044, -0.5337],
        [ 0.1964, -0.3878],
        [ 0.1979, -0.3907],
        [ 0.1914, -0.3861],
        [ 0.2091, -0.4068],
        [ 0.1836, -0.3725],
        [ 0.2060, -0.3952],
        [ 0.2616, -0.4760],
        [ 0.1966, -0.3911],
        [ 0.2040, -0.3966],
        [ 0.2048, -0.3993],
        [ 0.2044, -0.3922],
        [ 0.1916, -0.3821],
        [ 0.2053, -0.4008],
        [ 0.2024, -0.3900],
        [ 0.2831, -0.5002],
        [ 0.2110, -0.4042],
        [ 0.2381, -0.4427],
        [ 0.2049, -0.3938],
        [ 0.2351, -0.4395],
        [ 0.2028, -0.3937],
        [ 0.1908, -0.3801],
        [ 0.2180, -0.4183],
        [ 0.2987, -0.5271],
        [ 0.2068, -0.3982],
        [ 0.2351, -0.4486],
        [ 0.1230, -0.3007],
        [ 0.2154, -0.4143],
        [ 0.2075, -0.3975],
        [ 0.1961, -0.3839],
        [ 0.2039, -0.3913],
        [ 0.1961, -0.3846],
        [ 0.1721, -0.3550],
        [ 0.1998, -0.3911],
        [ 0.

Output: tensor([[ 0.4516, -0.7908],
        [ 0.3925, -0.7138],
        [ 0.3926, -0.7093],
        [ 0.3345, -0.6298],
        [ 0.3873, -0.7004],
        [ 0.3747, -0.6863],
        [ 0.4011, -0.7235],
        [ 0.3896, -0.7073],
        [ 0.3782, -0.6926],
        [ 0.3930, -0.7083],
        [ 0.3850, -0.7046],
        [ 0.4195, -0.7471],
        [ 0.4691, -0.8158],
        [ 0.3865, -0.6985],
        [ 0.3917, -0.7071],
        [ 0.3636, -0.6739],
        [ 0.4065, -0.7340],
        [ 0.3727, -0.6869],
        [ 0.3057, -0.5942],
        [ 0.4105, -0.7340],
        [ 0.4577, -0.7983],
        [ 0.3335, -0.6297],
        [ 0.4227, -0.7536],
        [ 0.3644, -0.6718],
        [ 0.3914, -0.7074],
        [ 0.3902, -0.7100],
        [ 0.3970, -0.7188],
        [ 0.3890, -0.7039],
        [ 0.3045, -0.5916],
        [ 0.2281, -0.4792],
        [ 0.3254, -0.6174],
        [ 0.4111, -0.7358],
        [ 0.4709, -0.8196],
        [ 0.4342, -0.7704],
        [ 0.3909, -0.7062],
        [ 0.

HBox(children=(FloatProgress(value=0.0, description='Iteration', max=13.0, style=ProgressStyle(description_wid…

torch.Size([90, 20, 375])
Loss:  tensor(0.5815, grad_fn=<NllLossBackward>)
Output: tensor([[ 0.2112, -0.4296],
        [ 0.2855, -0.5161],
        [ 0.2807, -0.5042],
        [ 0.2285, -0.4502],
        [ 0.2842, -0.5226],
        [ 0.2802, -0.5050],
        [ 0.3674, -0.6299],
        [ 0.2791, -0.5025],
        [ 0.3922, -0.6477],
        [ 0.3133, -0.5554],
        [ 0.3748, -0.6332],
        [ 0.2815, -0.5058],
        [ 0.2364, -0.4585],
        [ 0.2616, -0.4925],
        [ 0.3316, -0.5836],
        [ 0.2566, -0.4799],
        [ 0.2786, -0.5019],
        [ 0.3414, -0.5937],
        [ 0.2448, -0.4669],
        [ 0.3010, -0.5436],
        [ 0.3730, -0.6258],
        [ 0.3134, -0.5525],
        [ 0.2796, -0.5031],
        [ 0.2358, -0.4659],
        [ 0.3109, -0.5521],
        [ 0.3192, -0.5590],
        [ 0.2816, -0.5110],
        [ 0.2322, -0.4595],
        [ 0.3388, -0.5830],
        [ 0.2239, -0.4425],
        [ 0.2546, -0.4807],
        [ 0.3524, -0.5994],
        [ 0.2887, -0.

Output: tensor([[ 0.2115, -0.3787],
        [ 0.3483, -0.5636],
        [ 0.2029, -0.3720],
        [ 0.2012, -0.3687],
        [ 0.2284, -0.4090],
        [ 0.2144, -0.3883],
        [ 0.1985, -0.3687],
        [ 0.2052, -0.3693],
        [ 0.2854, -0.4798],
        [ 0.2113, -0.3876],
        [ 0.2098, -0.3794],
        [ 0.2219, -0.3980],
        [ 0.2012, -0.3629],
        [ 0.2061, -0.3777],
        [ 0.2177, -0.3921],
        [ 0.2015, -0.3640],
        [ 0.3193, -0.5200],
        [ 0.2136, -0.3827],
        [ 0.2658, -0.4532],
        [ 0.2041, -0.3675],
        [ 0.2431, -0.4242],
        [ 0.2090, -0.3776],
        [ 0.1972, -0.3646],
        [ 0.2487, -0.4347],
        [ 0.3404, -0.5536],
        [ 0.2096, -0.3765],
        [ 0.2872, -0.4944],
        [ 0.1330, -0.2943],
        [ 0.2401, -0.4223],
        [ 0.2059, -0.3702],
        [ 0.1977, -0.3613],
        [ 0.2016, -0.3634],
        [ 0.2017, -0.3678],
        [ 0.1761, -0.3339],
        [ 0.2081, -0.3777],
        [ 0.

Output: tensor([[ 0.4676, -0.7896],
        [ 0.3891, -0.6883],
        [ 0.3850, -0.6780],
        [ 0.3396, -0.6179],
        [ 0.3764, -0.6655],
        [ 0.3736, -0.6643],
        [ 0.4060, -0.7085],
        [ 0.4105, -0.7145],
        [ 0.3798, -0.6744],
        [ 0.3829, -0.6743],
        [ 0.3738, -0.6703],
        [ 0.4119, -0.7148],
        [ 0.4932, -0.8246],
        [ 0.3803, -0.6697],
        [ 0.3851, -0.6774],
        [ 0.3640, -0.6551],
        [ 0.4251, -0.7378],
        [ 0.3731, -0.6673],
        [ 0.3143, -0.5888],
        [ 0.4251, -0.7322],
        [ 0.4814, -0.8072],
        [ 0.3419, -0.6222],
        [ 0.4443, -0.7588],
        [ 0.3649, -0.6521],
        [ 0.3836, -0.6762],
        [ 0.3987, -0.7002],
        [ 0.3960, -0.6952],
        [ 0.3815, -0.6731],
        [ 0.3177, -0.5907],
        [ 0.2223, -0.4568],
        [ 0.3361, -0.6125],
        [ 0.4118, -0.7157],
        [ 0.4943, -0.8273],
        [ 0.4432, -0.7598],
        [ 0.3815, -0.6733],
        [ 0.

HBox(children=(FloatProgress(value=0.0, description='Iteration', max=13.0, style=ProgressStyle(description_wid…

torch.Size([90, 20, 375])
Loss:  tensor(0.5801, grad_fn=<NllLossBackward>)
Output: tensor([[ 0.2164, -0.4240],
        [ 0.2908, -0.5070],
        [ 0.2761, -0.4822],
        [ 0.2342, -0.4439],
        [ 0.3168, -0.5478],
        [ 0.2770, -0.4850],
        [ 0.3985, -0.6499],
        [ 0.2747, -0.4809],
        [ 0.4159, -0.6564],
        [ 0.3204, -0.5476],
        [ 0.3948, -0.6381],
        [ 0.2769, -0.4839],
        [ 0.2502, -0.4614],
        [ 0.2660, -0.4824],
        [ 0.3563, -0.5978],
        [ 0.2568, -0.4645],
        [ 0.2741, -0.4801],
        [ 0.3560, -0.5937],
        [ 0.2429, -0.4486],
        [ 0.3054, -0.5328],
        [ 0.3907, -0.6281],
        [ 0.3236, -0.5480],
        [ 0.2752, -0.4816],
        [ 0.2531, -0.4751],
        [ 0.3272, -0.5561],
        [ 0.3418, -0.5690],
        [ 0.2822, -0.4953],
        [ 0.2492, -0.4684],
        [ 0.3561, -0.5865],
        [ 0.2298, -0.4370],
        [ 0.2611, -0.4730],
        [ 0.3755, -0.6094],
        [ 0.2854, -0.

Output: tensor([[ 0.2101, -0.3643],
        [ 0.3794, -0.5858],
        [ 0.2053, -0.3629],
        [ 0.2005, -0.3532],
        [ 0.2537, -0.4235],
        [ 0.2149, -0.3751],
        [ 0.2073, -0.3672],
        [ 0.2038, -0.3555],
        [ 0.3017, -0.4852],
        [ 0.2224, -0.3900],
        [ 0.2126, -0.3705],
        [ 0.2337, -0.4005],
        [ 0.1987, -0.3480],
        [ 0.2141, -0.3752],
        [ 0.2260, -0.3895],
        [ 0.2001, -0.3504],
        [ 0.3455, -0.5369],
        [ 0.2148, -0.3719],
        [ 0.2855, -0.4638],
        [ 0.2027, -0.3536],
        [ 0.2477, -0.4166],
        [ 0.2115, -0.3686],
        [ 0.2000, -0.3563],
        [ 0.2721, -0.4505],
        [ 0.3699, -0.5737],
        [ 0.2107, -0.3657],
        [ 0.3234, -0.5248],
        [ 0.1365, -0.2857],
        [ 0.2557, -0.4283],
        [ 0.2038, -0.3557],
        [ 0.1980, -0.3496],
        [ 0.1998, -0.3493],
        [ 0.2042, -0.3592],
        [ 0.1748, -0.3168],
        [ 0.2124, -0.3712],
        [ 0.

Loss:  tensor(0.5367, grad_fn=<NllLossBackward>)
Output: tensor([[ 0.4777, -0.7877],
        [ 0.3834, -0.6680],
        [ 0.3777, -0.6558],
        [ 0.3394, -0.6052],
        [ 0.3674, -0.6419],
        [ 0.3701, -0.6468],
        [ 0.4071, -0.6959],
        [ 0.4232, -0.7165],
        [ 0.3789, -0.6604],
        [ 0.3744, -0.6510],
        [ 0.3614, -0.6422],
        [ 0.4046, -0.6916],
        [ 0.5078, -0.8275],
        [ 0.3744, -0.6495],
        [ 0.3788, -0.6567],
        [ 0.3606, -0.6387],
        [ 0.4352, -0.7364],
        [ 0.3704, -0.6510],
        [ 0.3168, -0.5811],
        [ 0.4337, -0.7292],
        [ 0.4964, -0.8111],
        [ 0.3455, -0.6146],
        [ 0.4601, -0.7632],
        [ 0.3629, -0.6369],
        [ 0.3763, -0.6542],
        [ 0.4003, -0.6881],
        [ 0.3914, -0.6752],
        [ 0.3744, -0.6515],
        [ 0.3240, -0.5862],
        [ 0.2154, -0.4384],
        [ 0.3404, -0.6049],
        [ 0.4120, -0.7023],
        [ 0.5083, -0.8292],
        [ 0.4480, -

HBox(children=(FloatProgress(value=0.0, description='Iteration', max=13.0, style=ProgressStyle(description_wid…

torch.Size([90, 20, 375])
Loss:  tensor(0.5788, grad_fn=<NllLossBackward>)
Output: tensor([[ 0.2197, -0.4218],
        [ 0.2959, -0.5050],
        [ 0.2742, -0.4720],
        [ 0.2381, -0.4415],
        [ 0.3422, -0.5691],
        [ 0.2761, -0.4758],
        [ 0.4222, -0.6668],
        [ 0.2731, -0.4710],
        [ 0.4356, -0.6675],
        [ 0.3256, -0.5447],
        [ 0.4092, -0.6432],
        [ 0.2750, -0.4735],
        [ 0.2588, -0.4633],
        [ 0.2686, -0.4768],
        [ 0.3759, -0.6115],
        [ 0.2567, -0.4562],
        [ 0.2722, -0.4698],
        [ 0.3672, -0.5965],
        [ 0.2397, -0.4358],
        [ 0.3071, -0.5256],
        [ 0.4056, -0.6345],
        [ 0.3319, -0.5482],
        [ 0.2735, -0.4716],
        [ 0.2651, -0.4821],
        [ 0.3390, -0.5609],
        [ 0.3585, -0.5783],
        [ 0.2830, -0.4878],
        [ 0.2625, -0.4771],
        [ 0.3700, -0.5926],
        [ 0.2340, -0.4349],
        [ 0.2658, -0.4700],
        [ 0.3935, -0.6199],
        [ 0.2843, -0.

Loss:  tensor(0.5657, grad_fn=<NllLossBackward>)
Output: tensor([[ 0.2098, -0.3571],
        [ 0.4023, -0.6027],
        [ 0.2068, -0.3584],
        [ 0.2000, -0.3443],
        [ 0.2728, -0.4350],
        [ 0.2149, -0.3673],
        [ 0.2140, -0.3684],
        [ 0.2035, -0.3490],
        [ 0.3139, -0.4907],
        [ 0.2319, -0.3951],
        [ 0.2151, -0.3671],
        [ 0.2435, -0.4054],
        [ 0.1980, -0.3413],
        [ 0.2198, -0.3750],
        [ 0.2330, -0.3908],
        [ 0.1999, -0.3441],
        [ 0.3652, -0.5506],
        [ 0.2168, -0.3679],
        [ 0.3006, -0.4736],
        [ 0.2024, -0.3472],
        [ 0.2515, -0.4137],
        [ 0.2135, -0.3644],
        [ 0.2023, -0.3528],
        [ 0.2916, -0.4656],
        [ 0.3915, -0.5886],
        [ 0.2123, -0.3614],
        [ 0.3485, -0.5445],
        [ 0.1385, -0.2798],
        [ 0.2666, -0.4335],
        [ 0.2030, -0.3486],
        [ 0.1989, -0.3445],
        [ 0.1995, -0.3430],
        [ 0.2066, -0.3561],
        [ 0.1738, -

Loss:  tensor(0.5369, grad_fn=<NllLossBackward>)
Output: tensor([[ 0.4847, -0.7848],
        [ 0.3778, -0.6512],
        [ 0.3712, -0.6383],
        [ 0.3370, -0.5927],
        [ 0.3604, -0.6238],
        [ 0.3658, -0.6317],
        [ 0.4066, -0.6845],
        [ 0.4312, -0.7150],
        [ 0.3777, -0.6491],
        [ 0.3675, -0.6331],
        [ 0.3500, -0.6188],
        [ 0.3982, -0.6731],
        [ 0.5173, -0.8267],
        [ 0.3693, -0.6336],
        [ 0.3735, -0.6406],
        [ 0.3564, -0.6246],
        [ 0.4410, -0.7324],
        [ 0.3671, -0.6371],
        [ 0.3168, -0.5726],
        [ 0.4393, -0.7251],
        [ 0.5067, -0.8117],
        [ 0.3469, -0.6069],
        [ 0.4727, -0.7662],
        [ 0.3602, -0.6239],
        [ 0.3700, -0.6367],
        [ 0.3992, -0.6758],
        [ 0.3861, -0.6580],
        [ 0.3683, -0.6344],
        [ 0.3270, -0.5800],
        [ 0.2078, -0.4214],
        [ 0.3418, -0.5961],
        [ 0.4121, -0.6917],
        [ 0.5173, -0.8275],
        [ 0.4508, -

HBox(children=(FloatProgress(value=0.0, description='Iteration', max=13.0, style=ProgressStyle(description_wid…

torch.Size([90, 20, 375])
Loss:  tensor(0.5776, grad_fn=<NllLossBackward>)
Output: tensor([[ 0.2218, -0.4207],
        [ 0.3011, -0.5062],
        [ 0.2740, -0.4669],
        [ 0.2413, -0.4407],
        [ 0.3632, -0.5872],
        [ 0.2764, -0.4712],
        [ 0.4413, -0.6808],
        [ 0.2731, -0.4661],
        [ 0.4523, -0.6784],
        [ 0.3300, -0.5439],
        [ 0.4204, -0.6477],
        [ 0.2746, -0.4682],
        [ 0.2648, -0.4649],
        [ 0.2703, -0.4732],
        [ 0.3920, -0.6237],
        [ 0.2567, -0.4512],
        [ 0.2718, -0.4645],
        [ 0.3765, -0.6000],
        [ 0.2360, -0.4259],
        [ 0.3079, -0.5205],
        [ 0.4185, -0.6417],
        [ 0.3389, -0.5499],
        [ 0.2734, -0.4666],
        [ 0.2740, -0.4872],
        [ 0.3482, -0.5653],
        [ 0.3713, -0.5859],
        [ 0.2841, -0.4838],
        [ 0.2736, -0.4850],
        [ 0.3816, -0.5990],
        [ 0.2370, -0.4336],
        [ 0.2699, -0.4695],
        [ 0.4083, -0.6295],
        [ 0.2846, -0.

Loss:  tensor(0.5639, grad_fn=<NllLossBackward>)
Output: tensor([[ 0.2103, -0.3534],
        [ 0.4206, -0.6160],
        [ 0.2081, -0.3560],
        [ 0.2007, -0.3397],
        [ 0.2891, -0.4457],
        [ 0.2154, -0.3627],
        [ 0.2200, -0.3711],
        [ 0.2040, -0.3460],
        [ 0.3238, -0.4957],
        [ 0.2404, -0.4009],
        [ 0.2176, -0.3661],
        [ 0.2522, -0.4113],
        [ 0.1986, -0.3385],
        [ 0.2246, -0.3758],
        [ 0.2396, -0.3939],
        [ 0.2005, -0.3413],
        [ 0.3812, -0.5620],
        [ 0.2194, -0.3671],
        [ 0.3130, -0.4824],
        [ 0.2031, -0.3444],
        [ 0.2548, -0.4126],
        [ 0.2154, -0.3626],
        [ 0.2044, -0.3514],
        [ 0.3087, -0.4799],
        [ 0.4084, -0.6003],
        [ 0.2143, -0.3601],
        [ 0.3672, -0.5581],
        [ 0.1402, -0.2763],
        [ 0.2751, -0.4380],
        [ 0.2033, -0.3451],
        [ 0.2002, -0.3424],
        [ 0.2003, -0.3404],
        [ 0.2094, -0.3557],
        [ 0.1742, -

Loss:  tensor(0.5373, grad_fn=<NllLossBackward>)
Output: tensor([[ 0.4906, -0.7815],
        [ 0.3729, -0.6368],
        [ 0.3659, -0.6234],
        [ 0.3341, -0.5810],
        [ 0.3547, -0.6088],
        [ 0.3615, -0.6181],
        [ 0.4059, -0.6742],
        [ 0.4373, -0.7121],
        [ 0.3768, -0.6396],
        [ 0.3620, -0.6181],
        [ 0.3397, -0.5984],
        [ 0.3927, -0.6570],
        [ 0.5245, -0.8242],
        [ 0.3649, -0.6197],
        [ 0.3693, -0.6271],
        [ 0.3526, -0.6122],
        [ 0.4454, -0.7277],
        [ 0.3640, -0.6248],
        [ 0.3159, -0.5640],
        [ 0.4438, -0.7207],
        [ 0.5148, -0.8107],
        [ 0.3474, -0.5992],
        [ 0.4837, -0.7682],
        [ 0.3575, -0.6122],
        [ 0.3646, -0.6218],
        [ 0.3975, -0.6640],
        [ 0.3812, -0.6428],
        [ 0.3632, -0.6199],
        [ 0.3286, -0.5732],
        [ 0.2000, -0.4052],
        [ 0.3421, -0.5872],
        [ 0.4125, -0.6826],
        [ 0.5240, -0.8242],
        [ 0.4531, -

HBox(children=(FloatProgress(value=0.0, description='Iteration', max=13.0, style=ProgressStyle(description_wid…

torch.Size([90, 20, 375])
Loss:  tensor(0.5766, grad_fn=<NllLossBackward>)
Output: tensor([[ 0.2237, -0.4200],
        [ 0.3062, -0.5082],
        [ 0.2750, -0.4643],
        [ 0.2441, -0.4403],
        [ 0.3808, -0.6022],
        [ 0.2777, -0.4691],
        [ 0.4571, -0.6923],
        [ 0.2743, -0.4638],
        [ 0.4665, -0.6878],
        [ 0.3340, -0.5438],
        [ 0.4294, -0.6512],
        [ 0.2755, -0.4654],
        [ 0.2694, -0.4659],
        [ 0.2717, -0.4706],
        [ 0.4056, -0.6338],
        [ 0.2573, -0.4480],
        [ 0.2727, -0.4618],
        [ 0.3842, -0.6029],
        [ 0.2327, -0.4179],
        [ 0.3084, -0.5165],
        [ 0.4297, -0.6481],
        [ 0.3449, -0.5517],
        [ 0.2745, -0.4642],
        [ 0.2809, -0.4908],
        [ 0.3558, -0.5689],
        [ 0.3816, -0.5919],
        [ 0.2857, -0.4817],
        [ 0.2834, -0.4919],
        [ 0.3913, -0.6045],
        [ 0.2394, -0.4322],
        [ 0.2740, -0.4700],
        [ 0.4206, -0.6376],
        [ 0.2858, -0.

Loss:  tensor(0.5622, grad_fn=<NllLossBackward>)
Output: tensor([[ 0.2118, -0.3519],
        [ 0.4358, -0.6271],
        [ 0.2096, -0.3549],
        [ 0.2025, -0.3380],
        [ 0.3037, -0.4560],
        [ 0.2166, -0.3603],
        [ 0.2257, -0.3743],
        [ 0.2055, -0.3451],
        [ 0.3322, -0.5002],
        [ 0.2482, -0.4066],
        [ 0.2203, -0.3663],
        [ 0.2603, -0.4173],
        [ 0.2002, -0.3378],
        [ 0.2291, -0.3774],
        [ 0.2460, -0.3977],
        [ 0.2021, -0.3406],
        [ 0.3947, -0.5717],
        [ 0.2227, -0.3680],
        [ 0.3238, -0.4901],
        [ 0.2046, -0.3437],
        [ 0.2580, -0.4125],
        [ 0.2176, -0.3622],
        [ 0.2066, -0.3511],
        [ 0.3243, -0.4933],
        [ 0.4224, -0.6097],
        [ 0.2168, -0.3605],
        [ 0.3816, -0.5679],
        [ 0.1417, -0.2739],
        [ 0.2823, -0.4421],
        [ 0.2044, -0.3438],
        [ 0.2020, -0.3419],
        [ 0.2020, -0.3398],
        [ 0.2125, -0.3568],
        [ 0.1762, -

Loss:  tensor(0.5376, grad_fn=<NllLossBackward>)
Output: tensor([[ 0.4955, -0.7774],
        [ 0.3689, -0.6241],
        [ 0.3616, -0.6104],
        [ 0.3315, -0.5700],
        [ 0.3504, -0.5958],
        [ 0.3578, -0.6056],
        [ 0.4054, -0.6646],
        [ 0.4418, -0.7079],
        [ 0.3764, -0.6309],
        [ 0.3577, -0.6051],
        [ 0.3308, -0.5802],
        [ 0.3881, -0.6427],
        [ 0.5301, -0.8203],
        [ 0.3613, -0.6074],
        [ 0.3661, -0.6154],
        [ 0.3496, -0.6012],
        [ 0.4487, -0.7223],
        [ 0.3615, -0.6137],
        [ 0.3148, -0.5554],
        [ 0.4473, -0.7156],
        [ 0.5214, -0.8082],
        [ 0.3476, -0.5915],
        [ 0.4934, -0.7693],
        [ 0.3552, -0.6015],
        [ 0.3604, -0.6087],
        [ 0.3958, -0.6530],
        [ 0.3769, -0.6291],
        [ 0.3592, -0.6072],
        [ 0.3297, -0.5663],
        [ 0.1922, -0.3897],
        [ 0.3420, -0.5784],
        [ 0.4134, -0.6743],
        [ 0.5291, -0.8195],
        [ 0.4549, -

HBox(children=(FloatProgress(value=0.0, description='Iteration', max=13.0, style=ProgressStyle(description_wid…

torch.Size([90, 20, 375])




Traceback (most recent call last):
  File "/Users/miki/opt/miniconda3/envs/cs236781-project/lib/python3.7/multiprocessing/queues.py", line 242, in _feed
    send_bytes(obj)
  File "/Users/miki/opt/miniconda3/envs/cs236781-project/lib/python3.7/multiprocessing/connection.py", line 200, in send_bytes
    self._send_bytes(m[offset:offset + size])
  File "/Users/miki/opt/miniconda3/envs/cs236781-project/lib/python3.7/multiprocessing/connection.py", line 404, in _send_bytes
    self._send(header + buf)
  File "/Users/miki/opt/miniconda3/envs/cs236781-project/lib/python3.7/multiprocessing/connection.py", line 368, in _send
    n = write(self._handle, buf)
BrokenPipeError: [Errno 32] Broken pipe


KeyboardInterrupt: 

### Baseline

In [None]:
# Draw a fresh random 80/20 split; `test_dataset` from this cell is what the
# "Testing" section evaluates on.
heldout = int(len(dataset) * 0.2)
split_sizes = [len(dataset) - heldout, heldout]
train_dataset, test_dataset = torch.utils.data.random_split(dataset, split_sizes)

# Baseline model constructed with `add_brnn=False`.
model = Baseline(add_brnn=False)

# Settings handed to `train` (same values as the CNN training cell).
config = {
    'num_workers': 8,
    'batch_size': 90,
    'learning_rate': 0.001,
    'weight_decay': 0.01,
    'num_epochs': 200,
    'is_notebook': True,
}

train(model, train_dataset, config)

## Testing

In [None]:
# Labels of the held-out items, selected through the index list that
# `random_split` stored on the subset.
y_true = dataset.labels[test_dataset.indices]

# Run the evaluation function imported from `training`.
y_pred, test_acc = test(model, test_dataset, config)

# Size of the held-out set, followed by the sum of its labels.
print(len(test_dataset))
print(y_true.sum().item())

In [None]:
# Classification report as a DataFrame, one row per report entry.
# `zero_division=0` reports 0 instead of warning when a class is never predicted.
results = pd.DataFrame(
    classification_report(y_true, y_pred, zero_division=0, output_dict=True)
).transpose()

# Specificity = TN / (TN + FP), read off the flattened confusion matrix.
tn, fp, fn, tp = confusion_matrix(y_true, y_pred).ravel()
specificity = tn / (tn + fp)

# ROC curve and the area under it. The name `metrics` is never imported in this
# notebook, so `metrics.roc_curve` / `metrics.auc` raised NameError. `roc_curve`
# is imported directly, and `auc` is reached through the `sklearn` package,
# whose `metrics` submodule is already loaded by the imports cell.
fpr, tpr, thresholds = roc_curve(y_true, y_pred)
auc_score = sklearn.metrics.auc(fpr, tpr)

In [None]:
# Summarise the evaluation: the classification-report table, then the
# specificity and ROC AUC computed in the previous cell.
display(results)
print("Specificity:", specificity)
print("AUC:", auc_score)

In [None]:
# ROC AUC of the predictions against the held-out labels.
lr_auc = roc_auc_score(y_true, y_pred)
print('ROC AUC=%.3f' % lr_auc)

# Points of the ROC curve; the thresholds are not needed for the plot.
lr_fpr, lr_tpr, _ = roc_curve(y_true, y_pred)

# Draw on the current Axes explicitly instead of via the pyplot state machine.
roc_ax = pyplot.gca()
roc_ax.plot(lr_fpr, lr_tpr, marker='.', label='Baseline model')
roc_ax.set_xlabel('False Positive Rate')
roc_ax.set_ylabel('True Positive Rate')
roc_ax.legend()
pyplot.show()

In [None]:
# Area under the precision-recall curve, summarised as average precision.
pr_auc = average_precision_score(y_true, y_pred)

# Print the value just computed; the cell previously printed `specificity`
# under the "PR AUC:" label.
print("PR AUC:", pr_auc)