# Model Processing

## Imports & General Settings 

In [None]:
import unittest
import os
import sys
import time
import pathlib

import matplotlib.pyplot as plt
import sklearn
import pandas as pd
import numpy as np
import torch
import torch.nn.functional as F
from torch import nn
from torch.utils.data import Dataset, DataLoader
from torchvision.transforms import ToPILImage
import PIL
from tqdm.notebook import trange, tqdm
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, confusion_matrix, average_precision_score, roc_curve, roc_auc_score

from matplotlib import pyplot

# Our imports
from data import WaveletTransform, AFECGDataset, SecondDataset, WrapperDataset
import dsp
from model.blocks import ConvNet, BRNN, SoftmaxAttention
from model.baseline import Baseline
from training import train, test
import utils


%matplotlib inline
%load_ext autoreload
%autoreload 2

In [None]:
# Assertion helper for ad-hoc shape/value checks inside the notebook.
# It is bound to `test_case` rather than `test` on purpose: the name `test` is
# the evaluation function imported from `training`, and the Testing section
# calls it as `test(model, dataset, config)`. Rebinding `test` here would make
# that call hit a unittest.TestCase instance instead.
test_case = unittest.TestCase()
plt.rcParams.update({'font.size': 12})
# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('Using device:', device)

## Dataset creation

In [None]:
# Dataset identifier, passed as the first argument to SecondDataset.
dataset_name = 'afdb'
dataset2 = SecondDataset(dataset_name, '../data/files/')
# The value returned by load() is kept as `class_weights`.
# NOTE(review): load() is pointed at './data' while the constructor received
# '../data/files/' — confirm that both paths are intended.
class_weights = dataset2.load('./data')

In [None]:
# Take the first (sample, label) pair and show the sample as the cell output.
example, label = dataset2[0]
example

In [None]:
# Summary of dataset2: number of items and the sum of its label tensor.
# Labels are compared against 0 and 1 elsewhere in this notebook, so the sum
# counts the positive (label == 1) entries — assuming labels are strictly 0/1.
total_data_size = len(dataset2)
af_count = dataset2.labels.sum().item()
print("Total data size: ", total_data_size)
print("Patients with AF: ", af_count)

In [None]:
# Same dataset identifier and root path as used for `dataset2`, this time
# wrapped by AFECGDataset.
dataset_name = 'afdb'
dataset = AFECGDataset(dataset_name, '../data/files/')

In [None]:
# Load the dataset contents. Unlike `dataset2.load('./data')`, no path is
# passed here and the return value is not kept.
dataset.load()

In [None]:
# Grab the first two items of the dataset.
# NOTE(review): the variable names imply item 0 is an AF sample and item 1 a
# normal-sinus-rhythm sample — this is not checked here; confirm via the labels.
data_af, label_af = dataset[0]
data_nsr, label_nsr = dataset[1]

# Plot the first element of the second item.
t = data_nsr[0]
utils.show_spectrogram(t)

In [None]:
# `images_per_sample` is reused further down when the BRNN is instantiated.
images_per_sample = 20

# Summary of `dataset`: number of items and the sum of its label tensor
# (the count of positive entries, assuming labels are strictly 0/1).
total_data_size = len(dataset)
af_count = dataset.labels.sum().item()
print("Total data size: ", total_data_size)
print("Patients with AF: ", af_count)

In [None]:
# data = [dataset[i][0] for i in range(total_data_size)]
# labels = [dataset[i][1] for i in range(total_data_size)]

### Example of one ECG sample

In [None]:
# samples, label = data[0], labels[0]
# print('P-signal: ', samples)
# print('Has AF: ', 'Yes' if label == 1 else 'No')

In [None]:
# to_wavelet = WaveletTransform(wavelet.Morlet(6), resample=20)
# t = to_wavelet(data[0][0])
# image_test = (t * 100 * 255).int() # Simple visualization test
# transforms.ToPILImage()(image_test).show()

## Wavelet Transform

In [None]:
# Use every item of `dataset`; assign a smaller number here to work on a subset.
# (The full dataset was noted as 1397 items when this cell was written — the
# value below is always taken from the dataset itself.)
# NOTE(review): `data_size` is not read by any other cell visible in this notebook.
data_size = len(dataset)

In [None]:
# x0 = x_train[0][0].float()
# encoder_cnn = ConvNet((375, 20))

# display(x0.unsqueeze(0).shape)
# h = encoder_cnn(x0.unsqueeze(0))
# print(h.shape)

# test.assertEqual(h.dim(), 2)
# test.assertSequenceEqual(h.shape, (1, 50))

## BRNN

In [None]:
# Instantiate a BRNN and render its repr as the cell output.
# `images_per_sample` (set to 20 above) is passed as the third positional argument.
# NOTE(review): the arguments are presumably (input size, hidden size, sequence
# length) — confirm against model.blocks.BRNN.
display(BRNN(50, 50, images_per_sample))

## Attention

Notations:

* $Y = \left[ y_1, \ldots, y_T \right]$ – the input matrix of size $\left( N \times T \right)$, where $N$ is the number of features in a single output vector of the BRNN

* $w_\mathrm{att}$ – The parameters of the attention model, of size $\left( N \times 1 \right)$, where $N$ is the number of features in a single output vector of the BRNN

* $\alpha$ – The attention weights, given as $\alpha = \mathrm{softmax} \left( w_\mathrm{att}^T Y \right)$. The softmax is taken across the $T$ time steps (it is not element-wise), so $\alpha$ has size $\left( 1 \times T \right)$ and its entries sum to 1

* $h_\mathrm{att}$ – Output of the attention mechanism, given by $h_\mathrm{att} = Y \alpha^T$, of size $\left( N \times 1 \right)$, i.e. a vector of $N$ features.

In [None]:
# Instantiate the attention module; as the last expression of the cell, its
# repr is shown as the output.
# NOTE(review): 100 is presumably the feature count N from the notation above
# (e.g. 2 x 50 for a bidirectional RNN with hidden size 50) — confirm against
# model.blocks.SoftmaxAttention.
SoftmaxAttention(100)

## Training

In [None]:
# Baseline model with its default constructor arguments.
# `model` is rebound by the CNN and Baseline training cells further down.
model = Baseline()

### CNN

In [None]:
# Plot the first element of the first dataset item as a visual sanity check
# of what is fed to the network.
data1, label1 = dataset[0]
utils.show_spectrogram(data1[0])

In [None]:
# Hold out 20% of dataset2 (rounded down) for testing; the rest is for training.
# random_split is called without an explicit generator, so the partition
# depends on the global torch RNG state.
heldout = int(len(dataset2) * 0.2)
split_sizes = [len(dataset2) - heldout, heldout]
train_dataset2, test_dataset2 = torch.utils.data.random_split(dataset2, split_sizes)

In [None]:
# Build a reduced training pool: every label == 1 sample plus only the first
# 100 label == 0 samples.
positive_mask = dataset2.labels == 1
negative_mask = dataset2.labels == 0

data_pos, labels_pos = dataset2.samples[positive_mask], dataset2.labels[positive_mask]
data_neg, labels_neg = dataset2.samples[negative_mask][:100], dataset2.labels[negative_mask][:100]

# Positives first, negatives after; samples and labels stay aligned by position.
data = torch.cat([data_pos, data_neg])
labels = torch.cat([labels_pos, labels_neg])

In [None]:
# Per-sample weights could be derived with
#   torch.tensor(class_weights)[train_dataset2.indices]
# but weighting is currently switched off (class_weights=None below).

# CNN encoder followed by a linear layer with 2 outputs.
# Assumes ConvNet emits 50 features per sample, matching nn.Linear(50, 2).
cnn_encoder = ConvNet(size=(375, 20), batch=False)
classifier_head = nn.Linear(50, 2)
model = nn.Sequential(cnn_encoder, classifier_head)

config = {
    'num_workers': 8,
    'batch_size': 90,
    'learning_rate': 0.001,
    'weight_decay': 0.01,
    'class_weights': None,
    'num_epochs': 200,
    'is_notebook': True,
}

# Train on the reduced pool (`data`, `labels`) built in the previous cell.
# NOTE(review): that pool is drawn from all of dataset2, not from
# train_dataset2 — confirm it does not overlap with test_dataset2.
train(model, WrapperDataset(data, labels), config)

### Baseline

In [None]:
# Hold out 20% of `dataset` (rounded down) for testing; train on the rest.
heldout = int(len(dataset) * 0.2)
split_sizes = [len(dataset) - heldout, heldout]
train_dataset, test_dataset = torch.utils.data.random_split(dataset, split_sizes)

# Baseline model with its BRNN stage disabled.
model = Baseline(add_brnn=False)

# NOTE(review): unlike the CNN configuration, no 'class_weights' entry is set
# here — confirm that `train` tolerates its absence.
config = {
    'num_workers': 8,
    'batch_size': 90,
    'learning_rate': 0.001,
    'weight_decay': 0.01,
    'num_epochs': 200,
    'is_notebook': True,
}

train(model, train_dataset, config)

## Testing

In [None]:
# Ground-truth labels of the held-out part of dataset2, selected through the
# indices stored on the random_split subset.
held_out_indices = test_dataset2.indices
y_true = dataset2.labels[held_out_indices]

# Evaluates whichever `model` and `config` were assigned most recently.
y_pred, test_acc = test(model, test_dataset2, config)

# Size of the held-out set and the sum of its labels.
print(len(test_dataset2))
print(y_true.sum().item())

In [None]:
# Per-class precision/recall/F1 as a DataFrame: one row per class plus the
# aggregate rows produced by classification_report.
results = pd.DataFrame(
    classification_report(y_true, y_pred, zero_division=0, output_dict=True)
).transpose()

# For binary labels, ravel() flattens the 2x2 confusion matrix to (tn, fp, fn, tp).
tn, fp, fn, tp = confusion_matrix(y_true, y_pred).ravel()
specificity = tn / (tn + fp)

# The name `metrics` is never imported in this notebook, so `metrics.roc_curve`
# and `metrics.auc` raised NameError. `roc_curve` is imported directly, and
# `auc` is reached through the already-imported `sklearn` package (the
# `sklearn.metrics` submodule is loaded by the `from sklearn.metrics import ...`
# line at the top).
# NOTE(review): if `y_pred` holds hard class labels rather than scores, the ROC
# curve has a single operating point — confirm what `test` returns.
fpr, tpr, thresholds = roc_curve(y_true, y_pred)
auc_score = sklearn.metrics.auc(fpr, tpr)

In [None]:
# Render the classification-report table, then the two scalar metrics
# computed in the previous cell.
display(results)
for metric_label, metric_value in (("Specificity:", specificity), ("AUC:", auc_score)):
    print(metric_label, metric_value)

In [None]:
# Area under the ROC curve, reported to three decimals.
lr_auc = roc_auc_score(y_true, y_pred)
print('ROC AUC=%.3f' % (lr_auc))

# Points of the ROC curve; the thresholds are not needed for the plot.
lr_fpr, lr_tpr, _ = roc_curve(y_true, y_pred)

# `plt` and `pyplot` are the same matplotlib.pyplot module, so this draws on
# the same current figure.
plt.plot(lr_fpr, lr_tpr, marker='.', label='Baseline model')
plt.xlabel('False Positive Rate')
plt.ylabel('True Positive Rate')
plt.legend()
plt.show()

In [None]:
# Average precision summarises the precision-recall curve in a single number.
pr_auc = average_precision_score(y_true, y_pred)
# Print the value just computed; this line used to print `specificity` under
# the "PR AUC:" label.
print("PR AUC:", pr_auc)