# Model Processing

## Imports & General Settings 

In [20]:
import unittest
import os
import sys
import time
import pathlib

import matplotlib.pyplot as plt
import sklearn
import pandas as pd
import numpy as np
import torch
import torch.nn.functional as F
from torch import nn
from torch.utils.data import Dataset, DataLoader
from torchvision.transforms import ToPILImage
import PIL
from tqdm.notebook import trange, tqdm
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, confusion_matrix, average_precision_score, roc_curve, roc_auc_score

from matplotlib import pyplot

# Our imports
from data import WaveletTransform, AFECGDataset, SecondDataset
import dsp
from model.blocks import ConvNet, BRNN, SoftmaxAttention
from model.baseline import Baseline
from training import train, test
import utils


%matplotlib inline
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [21]:
# Shared TestCase instance; presumably kept around for its assert* helpers in
# ad-hoc sanity checks.
# NOTE(review): this rebinds `test`, hiding the `test` function imported from
# `training` in the imports cell -- confirm which of the two later cells expect.
test = unittest.TestCase()

# Default font size for every matplotlib figure in this notebook.
plt.rcParams['font.size'] = 12

# Prefer the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print('Using device:', device)

Using device: cpu


## Dataset creation

In [22]:
# Build the SecondDataset variant for the 'afdb' dataset name, then load it.
# '../data/files/' is passed to the constructor and './data' to load(); the
# recorded output reports the samples being loaded "from backup".
# NOTE(review): the exact role of each path is defined in data.py -- confirm there.
dataset_name = 'afdb'
dataset2 = SecondDataset(dataset_name, '../data/files/')
dataset2.load('./data')

Loaded 2300 samples from backup
torch.Size([2300])
torch.Size([2300, 20, 375])


In [23]:
# Take the first (sample, label) pair from dataset2 and display the sample tensor.
example, label = dataset2[0]
example

tensor([[0.5568, 0.5416, 0.5116,  ..., 0.0055, 0.0058, 0.0059],
        [0.5509, 0.5371, 0.5101,  ..., 0.0057, 0.0060, 0.0061],
        [0.5441, 0.5317, 0.5073,  ..., 0.0059, 0.0062, 0.0064],
        ...,
        [0.6811, 0.6756, 0.6658,  ..., 0.0035, 0.0045, 0.0055],
        [0.8220, 0.8225, 0.8183,  ..., 0.0065, 0.0072, 0.0079],
        [0.9911, 0.9980, 1.0000,  ..., 0.0162, 0.0161, 0.0160]],
       dtype=torch.float64)

In [None]:
# Build the AFECGDataset for the 'afdb' dataset name. Nothing is loaded yet;
# load() is called in the next cell.
# NOTE(review): '../data/files/' is presumably the directory with the raw
# records -- confirm against data.py.
dataset_name = 'afdb'
dataset = AFECGDataset(dataset_name, '../data/files/')

In [None]:
# Load the AFECGDataset contents with the loader's default arguments.
dataset.load()

In [None]:
# Fetch the first two (data, label) pairs of the dataset.
# NOTE(review): the variable names assume sample 0 is an AF recording and
# sample 1 a normal-sinus-rhythm one; the labels are not checked here -- confirm.
data_af, label_af = dataset[0]
data_nsr, label_nsr = dataset[1]

# Plot the first element of the second sample as a spectrogram.
t = data_nsr[0]
utils.show_spectrogram(t)

In [None]:
# Passed to BRNN(...) further down; presumably the number of spectrogram images
# per ECG sample -- TODO confirm against data.py.
images_per_sample = 20
total_data_size = len(dataset)
print("Total data size: ", total_data_size)
# Sum of the label tensor; this equals the number of AF samples if labels are
# 0/1 with 1 = AF (assumption -- confirm).
print("Patients with AF: ", dataset.labels.sum().item())

In [None]:
# data = [dataset[i][0] for i in range(total_data_size)]
# labels = [dataset[i][1] for i in range(total_data_size)]

### Example of one ECG sample

In [None]:
# samples, label = data[0], labels[0]
# print('P-signal: ', samples)
# print('Has AF: ', 'Yes' if label == 1 else 'No')

In [None]:
# to_wavelet = WaveletTransform(wavelet.Morlet(6), resample=20)
# t = to_wavelet(data[0][0])
# image_test = (t * 100 * 255).int() # Simple visualization test
# transforms.ToPILImage()(image_test).show()

## Wavelet Transform

In [None]:
# Number of samples to work with. It is currently set to the full dataset
# length; assign a smaller integer here to experiment on a subset.
# (An earlier note gives the full size as 1397 -- not verified in this notebook.)
data_size = len(dataset)

In [None]:
# x0 = x_train[0][0].float()
# encoder_cnn = ConvNet((375, 20))

# display(x0.unsqueeze(0).shape)
# h = encoder_cnn(x0.unsqueeze(0))
# print(h.shape)

# test.assertEqual(h.dim(), 2)
# test.assertSequenceEqual(h.shape, (1, 50))

## BRNN

In [None]:
# Instantiate the BRNN block and show its module representation.
# NOTE(review): the positional arguments are presumably (input size, hidden size,
# sequence length) -- confirm against model/blocks.py.
display(BRNN(50, 50, images_per_sample))

## Attention

Notations:

* $Y = \left[ y_1, \ldots, y_T \right]$ – the input matrix of size $\left( N \times T \right)$, where $N$ is the number of features in a single output vector of the BRNN and $T$ is the number of output vectors (one per sequence step)

* $w_\mathrm{att}$ – The parameters of the attention model, of size $\left( N \times 1 \right)$, where $N$ is the number of features in a single output vector of the BRNN

* $\alpha$ – The attention weights, given as $\alpha = \mathrm{softmax} \left( w_\mathrm{att}^T Y \right)$. The softmax is taken over the $T$ entries of $w_\mathrm{att}^T Y$, so $\alpha$ has size $\left( 1 \times T \right)$ and its entries sum to 1

* $h_\mathrm{att}$ – Output of the attention mechanism, given by $h_\mathrm{att} = Y \alpha^T$, of size $\left( N \times 1 \right)$, i.e. a vector of $N$ features.

In [None]:
# Instantiate the attention block so the cell displays its module representation.
# NOTE(review): 100 is presumably the BRNN feature size N (2 x 50 for the two
# directions) -- confirm against model/blocks.py.
SoftmaxAttention(100)

## Training

In [None]:
# Instantiate the baseline model with its default arguments.
# NOTE: `model` is rebound to a ConvNet + Linear stack in the CNN training cell below.
model = Baseline()

### CNN

In [None]:
# Show the first element of the first dataset sample as a spectrogram.
data1, label1 = dataset[0]
utils.show_spectrogram(data1[0])

In [28]:
# --- Train / held-out split -------------------------------------------------
# Fraction of dataset2 that is kept aside and never shown to train().
HELDOUT_FRACTION = 0.2
# Fixed seed so the split and the weight initialisation below are identical on
# every "Restart Kernel -> Run All"; without it each run trains and evaluates
# on different samples and the results cannot be compared.
SPLIT_SEED = 0

heldout = int(len(dataset2) * HELDOUT_FRACTION)

# random_split draws its permutation from torch's global RNG, so seeding that
# RNG immediately before the call makes the split deterministic.
torch.manual_seed(SPLIT_SEED)
train_dataset2, test_dataset2 = torch.utils.data.random_split(
    dataset2, [len(dataset2) - heldout, heldout]
)

# --- Model --------------------------------------------------------------------
# ConvNet feature extractor followed by a linear layer mapping 50 features to
# 2 outputs (presumably the AF / non-AF logits -- confirm against the labels).
model = nn.Sequential(
    ConvNet(size=(375, 20), batch=False),
    nn.Linear(50, 2)
)

# --- Training configuration ---------------------------------------------------
# Keys are consumed by training.train(); their exact handling lives in training.py.
config = dict(
    num_workers=8,
    batch_size=90,
    learning_rate=0.001,
    weight_decay=0.01,
    num_epochs=200,
    is_notebook=True
)

train(model, train_dataset2, config)

HBox(children=(FloatProgress(value=0.0, description='Epoch', max=200.0, style=ProgressStyle(description_width=…

HBox(children=(FloatProgress(value=0.0, description='Iteration', max=21.0, style=ProgressStyle(description_wid…

Batch:  torch.Size([90, 20, 375])
Loss:  tensor(0.6842, grad_fn=<NllLossBackward>)
Output: tensor([[-0.0801, -0.0773],
        [-0.0654, -0.1028],
        [-0.0593, -0.0967],
        [-0.0668, -0.0968],
        [-0.0805, -0.0898],
        [-0.0636, -0.0997],
        [-0.0599, -0.0793],
        [-0.0739, -0.0872],
        [-0.0683, -0.0959],
        [-0.0677, -0.1104],
        [-0.0891, -0.0909],
        [-0.0662, -0.1041],
        [-0.0678, -0.0966],
        [-0.0733, -0.1016],
        [-0.0680, -0.0848],
        [-0.0864, -0.0704],
        [-0.0674, -0.0961],
        [-0.0651, -0.0894],
        [-0.0754, -0.0790],
        [-0.0652, -0.0960],
        [-0.0670, -0.0928],
        [-0.0724, -0.0794],
        [-0.0785, -0.0740],
        [-0.0486, -0.1055],
        [-0.0660, -0.0996],
        [-0.0557, -0.1008],
        [-0.0686, -0.0975],
        [-0.0980, -0.0840],
        [-0.0852, -0.0678],
        [-0.0673, -0.0987],
        [-0.0721, -0.1035],
        [-0.0732, -0.0963],
        [-0.0

Output: tensor([[ 4.7499, -3.9507],
        [ 4.0517, -3.4035],
        [ 4.1695, -3.4938],
        [ 3.5567, -2.9902],
        [ 4.9019, -4.0829],
        [ 4.4128, -3.6907],
        [ 4.1672, -3.4965],
        [ 4.7104, -3.9255],
        [ 4.4058, -3.6888],
        [ 4.6935, -3.9073],
        [ 4.0742, -3.4384],
        [ 4.8533, -4.0303],
        [ 4.1337, -3.4694],
        [ 4.0322, -3.3936],
        [ 4.6808, -3.8924],
        [ 3.8271, -3.2441],
        [ 4.3928, -3.6691],
        [ 4.9087, -4.0753],
        [ 4.2519, -3.5732],
        [ 4.7570, -3.9768],
        [ 4.2447, -3.5507],
        [ 4.7115, -3.9215],
        [ 4.9390, -4.0994],
        [ 4.9544, -4.1129],
        [ 4.7737, -3.9743],
        [ 4.2837, -3.5980],
        [ 4.1222, -3.4630],
        [ 4.9971, -4.1495],
        [ 4.4445, -3.7206],
        [ 4.0653, -3.4093],
        [ 4.6636, -3.8923],
        [ 4.6489, -3.8861],
        [ 3.6769, -3.1174],
        [ 3.7471, -3.1523],
        [ 4.3140, -3.6189],
        [ 5.

Output: tensor([[ 1.9509, -1.7274],
        [ 2.4398, -2.1213],
        [ 1.9058, -1.6865],
        [ 1.9983, -1.7574],
        [ 1.8668, -1.6575],
        [ 2.6757, -2.3074],
        [ 2.2523, -1.9679],
        [ 1.8951, -1.6848],
        [ 1.7135, -1.5316],
        [ 1.7874, -1.5954],
        [ 2.0405, -1.8001],
        [ 1.8618, -1.6516],
        [ 1.9638, -1.7326],
        [ 1.6919, -1.5191],
        [ 0.8344, -0.8201],
        [ 2.1780, -1.9099],
        [ 1.9453, -1.7163],
        [ 2.2922, -1.9985],
        [ 1.7777, -1.5869],
        [ 1.8654, -1.6582],
        [ 2.0822, -1.8260],
        [ 2.0209, -1.7807],
        [ 2.0179, -1.7806],
        [ 1.6030, -1.4398],
        [ 1.9598, -1.7353],
        [ 2.4147, -2.1018],
        [ 1.7983, -1.6063],
        [ 2.3137, -2.0268],
        [ 1.4987, -1.3578],
        [ 2.0546, -1.8047],
        [ 1.9128, -1.6977],
        [ 2.3648, -2.0615],
        [ 2.1530, -1.8905],
        [ 2.2980, -2.0059],
        [ 2.1344, -1.8702],
        [ 1.

Output: tensor([[ 1.7329, -1.6160],
        [ 1.3896, -1.3385],
        [ 1.8935, -1.7623],
        [ 2.1485, -1.9707],
        [ 1.6447, -1.5528],
        [ 1.6437, -1.5504],
        [ 2.2802, -2.0624],
        [ 0.6590, -0.6871],
        [ 1.8468, -1.7015],
        [ 1.6582, -1.5588],
        [ 1.9986, -1.8410],
        [ 2.1628, -1.9879],
        [ 1.7544, -1.6362],
        [ 1.6901, -1.5982],
        [ 2.0896, -1.9345],
        [ 2.0659, -1.9109],
        [ 2.0055, -1.8494],
        [ 1.8766, -1.7498],
        [ 1.2979, -1.2589],
        [ 2.0295, -1.8748],
        [ 2.2757, -2.0776],
        [ 1.4775, -1.4137],
        [ 2.1096, -1.9400],
        [ 2.1970, -2.0108],
        [ 1.8411, -1.7171],
        [ 2.0444, -1.8842],
        [ 2.1180, -1.9433],
        [ 2.0092, -1.8474],
        [ 1.5925, -1.4895],
        [ 2.0922, -1.9295],
        [ 2.0123, -1.8555],
        [ 2.3064, -2.0868],
        [ 1.6411, -1.5344],
        [ 1.4327, -1.3673],
        [ 2.2864, -2.0829],
        [ 1.

HBox(children=(FloatProgress(value=0.0, description='Iteration', max=21.0, style=ProgressStyle(description_wid…

Batch:  torch.Size([90, 20, 375])
Loss:  tensor(0.0697, grad_fn=<NllLossBackward>)
Output: tensor([[ 1.3237, -1.3039],
        [ 1.4560, -1.4259],
        [ 1.2790, -1.2543],
        [ 1.4131, -1.3783],
        [ 1.2786, -1.2730],
        [ 1.5470, -1.5264],
        [ 1.3260, -1.2903],
        [ 1.5224, -1.4690],
        [ 1.4488, -1.4277],
        [ 1.3161, -1.2976],
        [ 1.1015, -1.0975],
        [ 1.3628, -1.3389],
        [ 1.5488, -1.5115],
        [ 0.7718, -0.8004],
        [ 1.3993, -1.3655],
        [ 1.4031, -1.3768],
        [ 1.6180, -1.5758],
        [ 1.3060, -1.3010],
        [ 1.3702, -1.3450],
        [ 1.4622, -1.4365],
        [ 1.4263, -1.3919],
        [ 1.2933, -1.2710],
        [ 1.3370, -1.3056],
        [ 0.7199, -0.7699],
        [ 1.4001, -1.3702],
        [ 1.2349, -1.2276],
        [ 1.3240, -1.2997],
        [ 1.5332, -1.5168],
        [ 1.3270, -1.3035],
        [ 1.6294, -1.5876],
        [ 1.4594, -1.4367],
        [ 1.3030, -1.2826],
        [ 1.5

Output: tensor([[ 2.5581, -2.1477],
        [ 2.2306, -1.8931],
        [ 2.2447, -1.9033],
        [ 2.0618, -1.7600],
        [ 2.4650, -2.0833],
        [ 2.3213, -1.9667],
        [ 2.4042, -2.0253],
        [ 2.7295, -2.2761],
        [ 2.5141, -2.1100],
        [ 2.6584, -2.2228],
        [ 2.0716, -1.7755],
        [ 2.6838, -2.2426],
        [ 2.3634, -1.9942],
        [ 2.3240, -1.9622],
        [ 2.4836, -2.0888],
        [ 2.3365, -1.9730],
        [ 2.3899, -2.0154],
        [ 2.5291, -2.1271],
        [ 2.2300, -1.8938],
        [ 2.7694, -2.3073],
        [ 2.2754, -1.9294],
        [ 2.4661, -2.0771],
        [ 2.7376, -2.2865],
        [ 2.5554, -2.1480],
        [ 2.4642, -2.0797],
        [ 2.5179, -2.1153],
        [ 2.2655, -1.9198],
        [ 2.5791, -2.1668],
        [ 2.4045, -2.0294],
        [ 2.1981, -1.8677],
        [ 2.7742, -2.3129],
        [ 2.5798, -2.1658],
        [ 2.0786, -1.7734],
        [ 2.1975, -1.8642],
        [ 2.2648, -1.9244],
        [ 2.

Output: tensor([[ 1.8602, -1.6856],
        [ 2.0236, -1.8246],
        [ 1.9079, -1.7146],
        [ 1.8245, -1.6488],
        [ 1.7889, -1.6237],
        [ 2.4310, -2.1493],
        [ 2.2538, -2.0012],
        [ 1.7522, -1.5989],
        [ 1.7600, -1.5945],
        [ 1.6775, -1.5340],
        [ 1.8370, -1.6676],
        [ 1.8322, -1.6530],
        [ 1.8596, -1.6767],
        [ 1.6602, -1.5152],
        [ 1.3165, -1.2259],
        [ 1.8967, -1.7170],
        [ 1.8404, -1.6671],
        [ 2.1863, -1.9484],
        [ 1.7236, -1.5724],
        [ 1.7642, -1.6057],
        [ 1.9226, -1.7280],
        [ 1.8521, -1.6744],
        [ 1.8721, -1.6954],
        [ 1.7994, -1.6219],
        [ 1.7723, -1.6132],
        [ 2.0306, -1.8301],
        [ 1.8146, -1.6461],
        [ 1.9705, -1.7816],
        [ 1.5882, -1.4511],
        [ 1.8879, -1.7043],
        [ 1.7622, -1.6030],
        [ 2.0140, -1.8163],
        [ 1.9321, -1.7476],
        [ 1.9747, -1.7827],
        [ 1.8892, -1.7088],
        [ 1.

Output: tensor([[ 1.7922, -1.6814],
        [ 1.4539, -1.4040],
        [ 1.7603, -1.6703],
        [ 1.8620, -1.7536],
        [ 1.5875, -1.5244],
        [ 1.5847, -1.5174],
        [ 2.2997, -2.0983],
        [ 1.1767, -1.1431],
        [ 1.9234, -1.7853],
        [ 1.6581, -1.5799],
        [ 1.9084, -1.7844],
        [ 1.8682, -1.7621],
        [ 1.6915, -1.6172],
        [ 1.6063, -1.5516],
        [ 1.7828, -1.7048],
        [ 1.8081, -1.7136],
        [ 1.8715, -1.7566],
        [ 1.7464, -1.6562],
        [ 1.3564, -1.3328],
        [ 1.8086, -1.7070],
        [ 1.9084, -1.7949],
        [ 1.4758, -1.4359],
        [ 1.9458, -1.8216],
        [ 1.8907, -1.7781],
        [ 1.8136, -1.7108],
        [ 1.8119, -1.7116],
        [ 1.8791, -1.7680],
        [ 2.0884, -1.9295],
        [ 1.6025, -1.5289],
        [ 1.8017, -1.7074],
        [ 1.8634, -1.7485],
        [ 2.3016, -2.1021],
        [ 1.7903, -1.6775],
        [ 1.4782, -1.4268],
        [ 2.1192, -1.9614],
        [ 1.

HBox(children=(FloatProgress(value=0.0, description='Iteration', max=21.0, style=ProgressStyle(description_wid…

Batch:  torch.Size([90, 20, 375])
Loss:  tensor(0.0660, grad_fn=<NllLossBackward>)
Output: tensor([[ 1.2955, -1.3059],
        [ 1.3506, -1.3622],
        [ 1.3539, -1.3490],
        [ 1.3914, -1.3919],
        [ 1.2283, -1.2552],
        [ 1.3302, -1.3646],
        [ 1.4494, -1.4295],
        [ 1.5674, -1.5372],
        [ 1.3479, -1.3653],
        [ 1.2648, -1.2828],
        [ 1.2538, -1.2532],
        [ 1.3035, -1.3154],
        [ 1.4029, -1.4117],
        [ 1.1728, -1.1705],
        [ 1.3998, -1.3914],
        [ 1.3371, -1.3464],
        [ 1.4376, -1.4450],
        [ 1.1923, -1.2365],
        [ 1.2940, -1.3046],
        [ 1.3662, -1.3815],
        [ 1.3975, -1.3949],
        [ 1.3598, -1.3565],
        [ 1.3202, -1.3155],
        [ 0.9475, -0.9894],
        [ 1.3280, -1.3381],
        [ 1.2021, -1.2280],
        [ 1.2817, -1.2911],
        [ 1.3657, -1.3929],
        [ 1.3764, -1.3746],
        [ 1.4204, -1.4322],
        [ 1.3369, -1.3573],
        [ 1.3000, -1.3080],
        [ 1.6

Output: tensor([[ 2.4296, -2.0965],
        [ 2.2252, -1.9375],
        [ 2.1802, -1.8994],
        [ 2.1688, -1.8864],
        [ 2.1855, -1.9080],
        [ 2.2655, -1.9707],
        [ 2.5274, -2.1778],
        [ 2.7206, -2.3282],
        [ 2.6429, -2.2698],
        [ 2.5929, -2.2265],
        [ 2.1783, -1.9101],
        [ 2.5780, -2.2133],
        [ 2.4102, -2.0821],
        [ 2.4005, -2.0760],
        [ 2.3354, -2.0210],
        [ 2.6774, -2.3026],
        [ 2.3188, -2.0081],
        [ 2.3289, -2.0155],
        [ 2.2622, -1.9698],
        [ 2.8950, -2.4707],
        [ 2.1631, -1.8854],
        [ 2.2922, -1.9870],
        [ 2.6258, -2.2523],
        [ 2.3216, -2.0099],
        [ 2.2891, -1.9857],
        [ 2.6313, -2.2613],
        [ 2.3173, -2.0114],
        [ 2.3765, -2.0546],
        [ 2.3443, -2.0332],
        [ 2.1849, -1.9033],
        [ 2.8383, -2.4264],
        [ 2.5428, -2.1884],
        [ 2.3366, -2.0339],
        [ 2.3202, -2.0112],
        [ 2.2751, -1.9791],
        [ 2.

Output: tensor([[ 1.8550, -1.7057],
        [ 1.8459, -1.7006],
        [ 1.9450, -1.7757],
        [ 1.7657, -1.6256],
        [ 1.7605, -1.6270],
        [ 2.3310, -2.0993],
        [ 2.2805, -2.0552],
        [ 1.7111, -1.5883],
        [ 1.8514, -1.6967],
        [ 1.6691, -1.5529],
        [ 1.7862, -1.6498],
        [ 1.9109, -1.7457],
        [ 1.8416, -1.6885],
        [ 1.6978, -1.5730],
        [ 1.6769, -1.5518],
        [ 1.7613, -1.6295],
        [ 1.7932, -1.6532],
        [ 2.1333, -1.9349],
        [ 1.7249, -1.5966],
        [ 1.7679, -1.6335],
        [ 1.8584, -1.7017],
        [ 1.8273, -1.6788],
        [ 1.8188, -1.6759],
        [ 2.0026, -1.8201],
        [ 1.7051, -1.5815],
        [ 1.8638, -1.7153],
        [ 1.8789, -1.7268],
        [ 1.8051, -1.6681],
        [ 1.8229, -1.6728],
        [ 1.8198, -1.6736],
        [ 1.7421, -1.6102],
        [ 1.8651, -1.7163],
        [ 1.8343, -1.6898],
        [ 1.8303, -1.6868],
        [ 1.8094, -1.6661],
        [ 1.

Output: tensor([[ 1.8132, -1.7192],
        [ 1.5319, -1.4891],
        [ 1.6652, -1.6029],
        [ 1.6967, -1.6267],
        [ 1.6359, -1.5838],
        [ 1.5532, -1.5059],
        [ 2.2793, -2.1021],
        [ 1.5962, -1.5314],
        [ 1.9645, -1.8396],
        [ 1.6671, -1.6023],
        [ 1.8322, -1.7357],
        [ 1.7096, -1.6403],
        [ 1.6578, -1.6079],
        [ 1.5641, -1.5279],
        [ 1.6029, -1.5616],
        [ 1.6555, -1.5961],
        [ 1.7622, -1.6787],
        [ 1.6496, -1.5876],
        [ 1.4664, -1.4445],
        [ 1.6771, -1.6084],
        [ 1.7085, -1.6380],
        [ 1.5410, -1.5097],
        [ 1.8012, -1.7130],
        [ 1.7141, -1.6415],
        [ 1.7893, -1.7046],
        [ 1.6798, -1.6125],
        [ 1.7059, -1.6373],
        [ 2.0938, -1.9533],
        [ 1.5720, -1.5186],
        [ 1.6238, -1.5701],
        [ 1.7625, -1.6779],
        [ 2.2652, -2.0918],
        [ 1.8509, -1.7511],
        [ 1.5067, -1.4687],
        [ 1.9849, -1.8633],
        [ 1.

HBox(children=(FloatProgress(value=0.0, description='Iteration', max=21.0, style=ProgressStyle(description_wid…

Batch:  torch.Size([90, 20, 375])
Loss:  tensor(0.0681, grad_fn=<NllLossBackward>)
Output: tensor([[ 1.2501, -1.2640],
        [ 1.2774, -1.2909],
        [ 1.3647, -1.3574],
        [ 1.3610, -1.3615],
        [ 1.2102, -1.2355],
        [ 1.1828, -1.2255],
        [ 1.4906, -1.4672],
        [ 1.5605, -1.5274],
        [ 1.2879, -1.3027],
        [ 1.2080, -1.2307],
        [ 1.3339, -1.3235],
        [ 1.2699, -1.2804],
        [ 1.3225, -1.3306],
        [ 1.3908, -1.3683],
        [ 1.3868, -1.3763],
        [ 1.3126, -1.3182],
        [ 1.3305, -1.3402],
        [ 1.1116, -1.1616],
        [ 1.2688, -1.2774],
        [ 1.2793, -1.2979],
        [ 1.3599, -1.3561],
        [ 1.3623, -1.3590],
        [ 1.3220, -1.3152],
        [ 1.1204, -1.1397],
        [ 1.2619, -1.2756],
        [ 1.1719, -1.2009],
        [ 1.2370, -1.2495],
        [ 1.2652, -1.2929],
        [ 1.3939, -1.3885],
        [ 1.3018, -1.3168],
        [ 1.2476, -1.2702],
        [ 1.2792, -1.2858],
        [ 1.6

Output: tensor([[ 2.4275, -2.1415],
        [ 2.2508, -2.0022],
        [ 2.1718, -1.9353],
        [ 2.2544, -1.9981],
        [ 2.0651, -1.8510],
        [ 2.2445, -1.9977],
        [ 2.6066, -2.2935],
        [ 2.7515, -2.4085],
        [ 2.7634, -2.4211],
        [ 2.5905, -2.2761],
        [ 2.1936, -1.9668],
        [ 2.5900, -2.2738],
        [ 2.4664, -2.1754],
        [ 2.4690, -2.1810],
        [ 2.3145, -2.0488],
        [ 2.8480, -2.4979],
        [ 2.3289, -2.0609],
        [ 2.3024, -2.0375],
        [ 2.2989, -2.0457],
        [ 2.9864, -2.6054],
        [ 2.1415, -1.9093],
        [ 2.2655, -2.0089],
        [ 2.5965, -2.2800],
        [ 2.2767, -2.0168],
        [ 2.2174, -1.9700],
        [ 2.6933, -2.3646],
        [ 2.3372, -2.0744],
        [ 2.3183, -2.0531],
        [ 2.3195, -2.0589],
        [ 2.2290, -1.9817],
        [ 2.9014, -2.5356],
        [ 2.5605, -2.2524],
        [ 2.4252, -2.1557],
        [ 2.4019, -2.1235],
        [ 2.3112, -2.0537],
        [ 2.

Output: tensor([[ 1.8463, -1.7171],
        [ 1.7733, -1.6565],
        [ 1.9277, -1.7851],
        [ 1.7222, -1.6087],
        [ 1.6929, -1.5901],
        [ 2.2734, -2.0761],
        [ 2.2882, -2.0870],
        [ 1.6719, -1.5730],
        [ 1.8825, -1.7444],
        [ 1.6151, -1.5277],
        [ 1.7625, -1.6485],
        [ 1.9441, -1.7961],
        [ 1.8004, -1.6750],
        [ 1.6583, -1.5612],
        [ 1.8037, -1.6822],
        [ 1.6819, -1.5804],
        [ 1.7637, -1.6468],
        [ 2.1127, -1.9402],
        [ 1.7065, -1.5991],
        [ 1.7502, -1.6384],
        [ 1.8157, -1.6871],
        [ 1.8187, -1.6912],
        [ 1.7809, -1.6626],
        [ 2.0925, -1.9203],
        [ 1.6578, -1.5598],
        [ 1.7947, -1.6744],
        [ 1.8737, -1.7442],
        [ 1.7124, -1.6079],
        [ 1.9110, -1.7708],
        [ 1.7831, -1.6619],
        [ 1.7183, -1.6090],
        [ 1.8041, -1.6822],
        [ 1.7918, -1.6714],
        [ 1.7686, -1.6523],
        [ 1.7784, -1.6578],
        [ 1.

Output: tensor([[ 1.7814, -1.6984],
        [ 1.4951, -1.4623],
        [ 1.6145, -1.5604],
        [ 1.6268, -1.5674],
        [ 1.6217, -1.5751],
        [ 1.5128, -1.4735],
        [ 2.2722, -2.1054],
        [ 1.7095, -1.6403],
        [ 1.9833, -1.8625],
        [ 1.6505, -1.5902],
        [ 1.7940, -1.7066],
        [ 1.6461, -1.5862],
        [ 1.5940, -1.5558],
        [ 1.5143, -1.4853],
        [ 1.4995, -1.4715],
        [ 1.5835, -1.5344],
        [ 1.7122, -1.6380],
        [ 1.5763, -1.5275],
        [ 1.4600, -1.4422],
        [ 1.6150, -1.5564],
        [ 1.6273, -1.5685],
        [ 1.5308, -1.5046],
        [ 1.7145, -1.6414],
        [ 1.6408, -1.5793],
        [ 1.7586, -1.6822],
        [ 1.6245, -1.5656],
        [ 1.5919, -1.5423],
        [ 2.0817, -1.9503],
        [ 1.5531, -1.5035],
        [ 1.5293, -1.4899],
        [ 1.7172, -1.6418],
        [ 2.2485, -2.0863],
        [ 1.8297, -1.7403],
        [ 1.4623, -1.4338],
        [ 1.9154, -1.8085],
        [ 1.

HBox(children=(FloatProgress(value=0.0, description='Iteration', max=21.0, style=ProgressStyle(description_wid…

Batch:  torch.Size([90, 20, 375])
Loss:  tensor(0.0722, grad_fn=<NllLossBackward>)
Output: tensor([[ 1.1915, -1.2118],
        [ 1.2301, -1.2445],
        [ 1.3515, -1.3451],
        [ 1.3243, -1.3269],
        [ 1.1728, -1.2018],
        [ 1.0701, -1.1212],
        [ 1.4858, -1.4646],
        [ 1.5504, -1.5166],
        [ 1.2525, -1.2651],
        [ 1.1488, -1.1772],
        [ 1.3603, -1.3474],
        [ 1.2524, -1.2611],
        [ 1.2823, -1.2887],
        [ 1.4621, -1.4369],
        [ 1.3665, -1.3569],
        [ 1.3023, -1.3048],
        [ 1.2730, -1.2830],
        [ 1.0229, -1.0806],
        [ 1.2500, -1.2589],
        [ 1.2164, -1.2375],
        [ 1.3356, -1.3311],
        [ 1.3304, -1.3319],
        [ 1.3343, -1.3258],
        [ 1.2165, -1.2206],
        [ 1.2134, -1.2291],
        [ 1.1101, -1.1471],
        [ 1.1934, -1.2099],
        [ 1.1823, -1.2142],
        [ 1.3827, -1.3783],
        [ 1.2387, -1.2543],
        [ 1.1958, -1.2183],
        [ 1.2538, -1.2605],
        [ 1.6

Output: tensor([[ 2.4723, -2.2161],
        [ 2.2868, -2.0660],
        [ 2.1850, -1.9792],
        [ 2.3342, -2.0992],
        [ 1.9961, -1.8248],
        [ 2.2336, -2.0226],
        [ 2.6797, -2.3942],
        [ 2.7987, -2.4913],
        [ 2.8890, -2.5676],
        [ 2.6169, -2.3388],
        [ 2.1835, -1.9900],
        [ 2.6449, -2.3600],
        [ 2.5266, -2.2633],
        [ 2.5268, -2.2671],
        [ 2.3446, -2.1094],
        [ 2.9643, -2.6378],
        [ 2.3713, -2.1316],
        [ 2.3396, -2.1035],
        [ 2.3287, -2.1049],
        [ 3.0579, -2.7118],
        [ 2.1502, -1.9490],
        [ 2.2914, -2.0651],
        [ 2.6011, -2.3249],
        [ 2.2990, -2.0702],
        [ 2.1844, -1.9765],
        [ 2.7437, -2.4482],
        [ 2.3518, -2.1220],
        [ 2.3123, -2.0841],
        [ 2.3070, -2.0835],
        [ 2.2923, -2.0679],
        [ 2.9717, -2.6395],
        [ 2.5960, -2.3215],
        [ 2.4575, -2.2187],
        [ 2.4762, -2.2219],
        [ 2.3621, -2.1307],
        [ 2.

Output: tensor([[ 1.8561, -1.7382],
        [ 1.7589, -1.6560],
        [ 1.9084, -1.7847],
        [ 1.7072, -1.6090],
        [ 1.6236, -1.5441],
        [ 2.2544, -2.0784],
        [ 2.3114, -2.1252],
        [ 1.6492, -1.5652],
        [ 1.9045, -1.7777],
        [ 1.5508, -1.4852],
        [ 1.7642, -1.6623],
        [ 1.9703, -1.8333],
        [ 1.7671, -1.6608],
        [ 1.6042, -1.5282],
        [ 1.8399, -1.7280],
        [ 1.6394, -1.5558],
        [ 1.7632, -1.6588],
        [ 2.1235, -1.9655],
        [ 1.7061, -1.6108],
        [ 1.7329, -1.6368],
        [ 1.7958, -1.6846],
        [ 1.8355, -1.7187],
        [ 1.7744, -1.6696],
        [ 2.1400, -1.9775],
        [ 1.6186, -1.5382],
        [ 1.7807, -1.6746],
        [ 1.8527, -1.7409],
        [ 1.6621, -1.5771],
        [ 1.9548, -1.8237],
        [ 1.7775, -1.6700],
        [ 1.7072, -1.6117],
        [ 1.7941, -1.6859],
        [ 1.7889, -1.6810],
        [ 1.7567, -1.6541],
        [ 1.7847, -1.6755],
        [ 1.

Output: tensor([[ 1.7080, -1.6409],
        [ 1.4006, -1.3841],
        [ 1.5630, -1.5185],
        [ 1.5766, -1.5263],
        [ 1.5621, -1.5259],
        [ 1.4506, -1.4221],
        [ 2.2511, -2.0974],
        [ 1.7043, -1.6422],
        [ 1.9776, -1.8646],
        [ 1.6147, -1.5618],
        [ 1.7529, -1.6759],
        [ 1.6010, -1.5493],
        [ 1.5024, -1.4781],
        [ 1.4459, -1.4271],
        [ 1.4061, -1.3913],
        [ 1.5238, -1.4846],
        [ 1.6695, -1.6047],
        [ 1.4882, -1.4552],
        [ 1.3966, -1.3886],
        [ 1.5608, -1.5120],
        [ 1.5726, -1.5233],
        [ 1.4552, -1.4407],
        [ 1.6345, -1.5768],
        [ 1.5890, -1.5370],
        [ 1.7058, -1.6412],
        [ 1.5828, -1.5316],
        [ 1.4783, -1.4479],
        [ 2.0447, -1.9266],
        [ 1.5275, -1.4832],
        [ 1.4472, -1.4211],
        [ 1.6748, -1.6091],
        [ 2.2195, -2.0712],
        [ 1.7663, -1.6920],
        [ 1.3749, -1.3607],
        [ 1.8512, -1.7598],
        [ 1.

HBox(children=(FloatProgress(value=0.0, description='Iteration', max=21.0, style=ProgressStyle(description_wid…

Batch:  torch.Size([90, 20, 375])
Loss:  tensor(0.0749, grad_fn=<NllLossBackward>)
Output: tensor([[ 1.1506, -1.1738],
        [ 1.2154, -1.2272],
        [ 1.3487, -1.3410],
        [ 1.3050, -1.3071],
        [ 1.1463, -1.1758],
        [ 0.9894, -1.0455],
        [ 1.4731, -1.4531],
        [ 1.5588, -1.5221],
        [ 1.2457, -1.2541],
        [ 1.1069, -1.1377],
        [ 1.3843, -1.3674],
        [ 1.2572, -1.2616],
        [ 1.2796, -1.2815],
        [ 1.4849, -1.4588],
        [ 1.3620, -1.3511],
        [ 1.3215, -1.3176],
        [ 1.2560, -1.2632],
        [ 0.9506, -1.0126],
        [ 1.2457, -1.2527],
        [ 1.1834, -1.2039],
        [ 1.3382, -1.3300],
        [ 1.3027, -1.3065],
        [ 1.3563, -1.3438],
        [ 1.2874, -1.2792],
        [ 1.1955, -1.2096],
        [ 1.0568, -1.0984],
        [ 1.1717, -1.1883],
        [ 1.1232, -1.1575],
        [ 1.3729, -1.3678],
        [ 1.2189, -1.2316],
        [ 1.1755, -1.1954],
        [ 1.2484, -1.2528],
        [ 1.6

Output: tensor([[ 2.5265, -2.2892],
        [ 2.3280, -2.1249],
        [ 2.2077, -2.0216],
        [ 2.4033, -2.1833],
        [ 1.9391, -1.7977],
        [ 2.2247, -2.0390],
        [ 2.7542, -2.4859],
        [ 2.8451, -2.5621],
        [ 3.0196, -2.7091],
        [ 2.6461, -2.3928],
        [ 2.1804, -2.0089],
        [ 2.7063, -2.4415],
        [ 2.5850, -2.3401],
        [ 2.5800, -2.3395],
        [ 2.3871, -2.1709],
        [ 3.0707, -2.7594],
        [ 2.4206, -2.1991],
        [ 2.3922, -2.1736],
        [ 2.3648, -2.1595],
        [ 3.1278, -2.8044],
        [ 2.1655, -1.9850],
        [ 2.3305, -2.1231],
        [ 2.6092, -2.3613],
        [ 2.3371, -2.1276],
        [ 2.1594, -1.9794],
        [ 2.7892, -2.5167],
        [ 2.3709, -2.1630],
        [ 2.3192, -2.1155],
        [ 2.2999, -2.1023],
        [ 2.3605, -2.1498],
        [ 3.0397, -2.7301],
        [ 2.6322, -2.3809],
        [ 2.4848, -2.2667],
        [ 2.5460, -2.3076],
        [ 2.4233, -2.2072],
        [ 2.

Output: tensor([[ 1.8746, -1.7639],
        [ 1.7612, -1.6670],
        [ 1.8931, -1.7831],
        [ 1.7010, -1.6138],
        [ 1.5586, -1.4974],
        [ 2.2432, -2.0831],
        [ 2.3368, -2.1614],
        [ 1.6352, -1.5616],
        [ 1.9268, -1.8079],
        [ 1.4943, -1.4449],
        [ 1.7766, -1.6822],
        [ 2.0039, -1.8733],
        [ 1.7403, -1.6483],
        [ 1.5551, -1.4950],
        [ 1.8588, -1.7547],
        [ 1.6102, -1.5393],
        [ 1.7705, -1.6746],
        [ 2.1424, -1.9943],
        [ 1.7121, -1.6251],
        [ 1.7191, -1.6346],
        [ 1.7826, -1.6840],
        [ 1.8624, -1.7520],
        [ 1.7774, -1.6818],
        [ 2.1842, -2.0283],
        [ 1.5868, -1.5194],
        [ 1.7820, -1.6851],
        [ 1.8310, -1.7329],
        [ 1.6266, -1.5553],
        [ 1.9964, -1.8710],
        [ 1.7821, -1.6837],
        [ 1.7059, -1.6197],
        [ 1.7983, -1.6989],
        [ 1.7966, -1.6971],
        [ 1.7590, -1.6653],
        [ 1.8013, -1.6993],
        [ 1.

Output: tensor([[ 1.6391, -1.5844],
        [ 1.3130, -1.3088],
        [ 1.5213, -1.4832],
        [ 1.5400, -1.4956],
        [ 1.5114, -1.4820],
        [ 1.3971, -1.3762],
        [ 2.2309, -2.0876],
        [ 1.6805, -1.6252],
        [ 1.9755, -1.8677],
        [ 1.5865, -1.5384],
        [ 1.7205, -1.6507],
        [ 1.5704, -1.5237],
        [ 1.4244, -1.4101],
        [ 1.3877, -1.3759],
        [ 1.3271, -1.3220],
        [ 1.4781, -1.4454],
        [ 1.6372, -1.5788],
        [ 1.4095, -1.3883],
        [ 1.3425, -1.3405],
        [ 1.5204, -1.4780],
        [ 1.5341, -1.4908],
        [ 1.3880, -1.3813],
        [ 1.5643, -1.5185],
        [ 1.5514, -1.5055],
        [ 1.6590, -1.6031],
        [ 1.5537, -1.5073],
        [ 1.3769, -1.3612],
        [ 2.0095, -1.9017],
        [ 1.5071, -1.4664],
        [ 1.3804, -1.3634],
        [ 1.6429, -1.5836],
        [ 2.1932, -2.0557],
        [ 1.7058, -1.6433],
        [ 1.2947, -1.2913],
        [ 1.7948, -1.7153],
        [ 1.

HBox(children=(FloatProgress(value=0.0, description='Iteration', max=21.0, style=ProgressStyle(description_wid…

Batch:  torch.Size([90, 20, 375])
Loss:  tensor(0.0755, grad_fn=<NllLossBackward>)
Output: tensor([[ 1.1286, -1.1528],
        [ 1.2211, -1.2298],
        [ 1.3608, -1.3509],
        [ 1.3048, -1.3052],
        [ 1.1412, -1.1689],
        [ 0.9314, -0.9910],
        [ 1.4720, -1.4523],
        [ 1.5792, -1.5401],
        [ 1.2574, -1.2618],
        [ 1.0854, -1.1167],
        [ 1.4239, -1.4014],
        [ 1.2807, -1.2802],
        [ 1.2967, -1.2942],
        [ 1.5073, -1.4803],
        [ 1.3712, -1.3587],
        [ 1.3584, -1.3479],
        [ 1.2591, -1.2634],
        [ 0.9017, -0.9657],
        [ 1.2598, -1.2636],
        [ 1.1701, -1.1895],
        [ 1.3571, -1.3452],
        [ 1.2893, -1.2941],
        [ 1.3952, -1.3774],
        [ 1.3552, -1.3380],
        [ 1.1971, -1.2088],
        [ 1.0247, -1.0682],
        [ 1.1701, -1.1853],
        [ 1.0819, -1.1184],
        [ 1.3747, -1.3686],
        [ 1.2203, -1.2301],
        [ 1.1748, -1.1918],
        [ 1.2586, -1.2602],
        [ 1.7

Output: tensor([[ 2.5759, -2.3516],
        [ 2.3702, -2.1787],
        [ 2.2342, -2.0608],
        [ 2.4597, -2.2508],
        [ 1.8849, -1.7660],
        [ 2.2146, -2.0474],
        [ 2.8256, -2.5685],
        [ 2.8811, -2.6160],
        [ 3.1471, -2.8419],
        [ 2.6672, -2.4322],
        [ 2.1900, -2.0324],
        [ 2.7602, -2.5094],
        [ 2.6372, -2.4050],
        [ 2.6274, -2.4002],
        [ 2.4267, -2.2236],
        [ 3.1734, -2.8710],
        [ 2.4669, -2.2578],
        [ 2.4425, -2.2356],
        [ 2.4053, -2.2118],
        [ 3.1956, -2.8873],
        [ 2.1795, -2.0136],
        [ 2.3678, -2.1732],
        [ 2.6089, -2.3824],
        [ 2.3723, -2.1762],
        [ 2.1315, -1.9727],
        [ 2.8292, -2.5731],
        [ 2.3927, -2.1996],
        [ 2.3228, -2.1370],
        [ 2.2946, -2.1155],
        [ 2.4254, -2.2234],
        [ 3.0965, -2.8032],
        [ 2.6602, -2.4260],
        [ 2.5182, -2.3136],
        [ 2.6100, -2.3822],
        [ 2.4830, -2.2769],
        [ 2.

Output: tensor([[ 1.8960, -1.7904],
        [ 1.7646, -1.6775],
        [ 1.8796, -1.7805],
        [ 1.6970, -1.6181],
        [ 1.5016, -1.4546],
        [ 2.2249, -2.0789],
        [ 2.3535, -2.1877],
        [ 1.6271, -1.5611],
        [ 1.9502, -1.8369],
        [ 1.4510, -1.4132],
        [ 1.7929, -1.7037],
        [ 2.0441, -1.9170],
        [ 1.7187, -1.6376],
        [ 1.5181, -1.4694],
        [ 1.8793, -1.7809],
        [ 1.5859, -1.5249],
        [ 1.7772, -1.6880],
        [ 2.1567, -2.0169],
        [ 1.7190, -1.6383],
        [ 1.7110, -1.6349],
        [ 1.7691, -1.6807],
        [ 1.8900, -1.7841],
        [ 1.7801, -1.6917],
        [ 2.2319, -2.0799],
        [ 1.5623, -1.5045],
        [ 1.7837, -1.6942],
        [ 1.8130, -1.7255],
        [ 1.5944, -1.5342],
        [ 2.0463, -1.9232],
        [ 1.7877, -1.6964],
        [ 1.7105, -1.6308],
        [ 1.8024, -1.7101],
        [ 1.8032, -1.7105],
        [ 1.7622, -1.6755],
        [ 1.8181, -1.7216],
        [ 1.

Output: tensor([[ 1.5767, -1.5343],
        [ 1.2373, -1.2449],
        [ 1.4818, -1.4517],
        [ 1.5057, -1.4688],
        [ 1.4688, -1.4470],
        [ 1.3496, -1.3372],
        [ 2.2051, -2.0735],
        [ 1.6529, -1.6054],
        [ 1.9726, -1.8715],
        [ 1.5615, -1.5198],
        [ 1.6892, -1.6280],
        [ 1.5419, -1.5019],
        [ 1.3596, -1.3548],
        [ 1.3355, -1.3319],
        [ 1.2550, -1.2603],
        [ 1.4353, -1.4106],
        [ 1.6059, -1.5555],
        [ 1.3365, -1.3277],
        [ 1.2996, -1.3039],
        [ 1.4839, -1.4492],
        [ 1.4987, -1.4629],
        [ 1.3343, -1.3355],
        [ 1.4959, -1.4628],
        [ 1.5163, -1.4781],
        [ 1.6152, -1.5689],
        [ 1.5265, -1.4867],
        [ 1.2818, -1.2811],
        [ 1.9719, -1.8756],
        [ 1.4851, -1.4502],
        [ 1.3185, -1.3115],
        [ 1.6126, -1.5611],
        [ 2.1627, -2.0373],
        [ 1.6488, -1.5984],
        [ 1.2241, -1.2315],
        [ 1.7377, -1.6711],
        [ 1.

HBox(children=(FloatProgress(value=0.0, description='Iteration', max=21.0, style=ProgressStyle(description_wid…

Batch:  torch.Size([90, 20, 375])
Loss:  tensor(0.0747, grad_fn=<NllLossBackward>)
Output: tensor([[ 1.1199, -1.1437],
        [ 1.2389, -1.2438],
        [ 1.3838, -1.3704],
        [ 1.3186, -1.3159],
        [ 1.1527, -1.1765],
        [ 0.8895, -0.9508],
        [ 1.4824, -1.4612],
        [ 1.6064, -1.5643],
        [ 1.2807, -1.2804],
        [ 1.0792, -1.1092],
        [ 1.4759, -1.4466],
        [ 1.3160, -1.3098],
        [ 1.3241, -1.3169],
        [ 1.5367, -1.5066],
        [ 1.3887, -1.3739],
        [ 1.4034, -1.3863],
        [ 1.2725, -1.2736],
        [ 0.8733, -0.9367],
        [ 1.2865, -1.2859],
        [ 1.1697, -1.1871],
        [ 1.3848, -1.3689],
        [ 1.2877, -1.2917],
        [ 1.4455, -1.4212],
        [ 1.4250, -1.3986],
        [ 1.2101, -1.2186],
        [ 1.0106, -1.0533],
        [ 1.1821, -1.1944],
        [ 1.0537, -1.0912],
        [ 1.3868, -1.3783],
        [ 1.2330, -1.2394],
        [ 1.1865, -1.1998],
        [ 1.2784, -1.2766],
        [ 1.7

Output: tensor([[ 2.6160, -2.4021],
        [ 2.4093, -2.2263],
        [ 2.2593, -2.0958],
        [ 2.5014, -2.3023],
        [ 1.8321, -1.7313],
        [ 2.2015, -2.0496],
        [ 2.8883, -2.6399],
        [ 2.9054, -2.6553],
        [ 3.2640, -2.9623],
        [ 2.6784, -2.4586],
        [ 2.2057, -2.0578],
        [ 2.8030, -2.5638],
        [ 2.6812, -2.4593],
        [ 2.6669, -2.4503],
        [ 2.4584, -2.2659],
        [ 3.2676, -2.9717],
        [ 2.5059, -2.3065],
        [ 2.4843, -2.2868],
        [ 2.4449, -2.2600],
        [ 3.2591, -2.9619],
        [ 2.1897, -2.0354],
        [ 2.3986, -2.2142],
        [ 2.5993, -2.3907],
        [ 2.3998, -2.2145],
        [ 2.0995, -1.9583],
        [ 2.8614, -2.6185],
        [ 2.4127, -2.2309],
        [ 2.3188, -2.1479],
        [ 2.2885, -2.1239],
        [ 2.4828, -2.2874],
        [ 3.1393, -2.8597],
        [ 2.6774, -2.4577],
        [ 2.5533, -2.3582],
        [ 2.6648, -2.4454],
        [ 2.5344, -2.3358],
        [ 2.

Output: tensor([[ 1.9199, -1.8178],
        [ 1.7672, -1.6858],
        [ 1.8686, -1.7779],
        [ 1.6927, -1.6207],
        [ 1.4527, -1.4171],
        [ 2.2042, -2.0700],
        [ 2.3657, -2.2081],
        [ 1.6243, -1.5639],
        [ 1.9753, -1.8660],
        [ 1.4184, -1.3893],
        [ 1.8112, -1.7257],
        [ 2.0889, -1.9633],
        [ 1.7011, -1.6288],
        [ 1.4900, -1.4496],
        [ 1.9030, -1.8077],
        [ 1.5640, -1.5112],
        [ 1.7828, -1.6991],
        [ 2.1690, -2.0360],
        [ 1.7253, -1.6498],
        [ 1.7086, -1.6386],
        [ 1.7545, -1.6746],
        [ 1.9157, -1.8135],
        [ 1.7811, -1.6989],
        [ 2.2821, -2.1322],
        [ 1.5437, -1.4934],
        [ 1.7839, -1.7006],
        [ 1.8004, -1.7208],
        [ 1.5646, -1.5135],
        [ 2.1008, -1.9783],
        [ 1.7933, -1.7076],
        [ 1.7198, -1.6446],
        [ 1.8049, -1.7187],
        [ 1.8081, -1.7211],
        [ 1.7646, -1.6837],
        [ 1.8348, -1.7426],
        [ 1.

Output: tensor([[ 1.5293, -1.4961],
        [ 1.1827, -1.1982],
        [ 1.4506, -1.4267],
        [ 1.4782, -1.4475],
        [ 1.4401, -1.4234],
        [ 1.3141, -1.3080],
        [ 2.1837, -2.0619],
        [ 1.6359, -1.5941],
        [ 1.9766, -1.8804],
        [ 1.5452, -1.5083],
        [ 1.6657, -1.6113],
        [ 1.5205, -1.4859],
        [ 1.3130, -1.3149],
        [ 1.2946, -1.2968],
        [ 1.1945, -1.2083],
        [ 1.3999, -1.3818],
        [ 1.5816, -1.5375],
        [ 1.2771, -1.2777],
        [ 1.2735, -1.2814],
        [ 1.4569, -1.4281],
        [ 1.4706, -1.4410],
        [ 1.3009, -1.3066],
        [ 1.4369, -1.4142],
        [ 1.4881, -1.4562],
        [ 1.5827, -1.5434],
        [ 1.5057, -1.4713],
        [ 1.2004, -1.2118],
        [ 1.9418, -1.8546],
        [ 1.4674, -1.4375],
        [ 1.2668, -1.2678],
        [ 1.5902, -1.5448],
        [ 2.1380, -2.0224],
        [ 1.6042, -1.5630],
        [ 1.1703, -1.1855],
        [ 1.6879, -1.6320],
        [ 1.

HBox(children=(FloatProgress(value=0.0, description='Iteration', max=21.0, style=ProgressStyle(description_wid…

Batch:  torch.Size([90, 20, 375])
Loss:  tensor(0.0733, grad_fn=<NllLossBackward>)
Output: tensor([[ 1.1166, -1.1397],
        [ 1.2603, -1.2616],
        [ 1.4101, -1.3933],
        [ 1.3393, -1.3331],
        [ 1.1714, -1.1910],
        [ 0.8570, -0.9192],
        [ 1.5001, -1.4763],
        [ 1.6346, -1.5897],
        [ 1.3070, -1.3026],
        [ 1.0807, -1.1088],
        [ 1.5318, -1.4958],
        [ 1.3537, -1.3424],
        [ 1.3527, -1.3415],
        [ 1.5696, -1.5361],
        [ 1.4075, -1.3906],
        [ 1.4466, -1.4240],
        [ 1.2882, -1.2866],
        [ 0.8581, -0.9199],
        [ 1.3166, -1.3118],
        [ 1.1751, -1.1902],
        [ 1.4135, -1.3940],
        [ 1.2919, -1.2947],
        [ 1.4977, -1.4674],
        [ 1.4909, -1.4579],
        [ 1.2264, -1.2320],
        [ 1.0063, -1.0474],
        [ 1.1993, -1.2087],
        [ 1.0334, -1.0712],
        [ 1.4031, -1.3919],
        [ 1.2486, -1.2520],
        [ 1.2032, -1.2131],
        [ 1.3006, -1.2957],
        [ 1.7

Output: tensor([[ 2.6457, -2.4410],
        [ 2.4429, -2.2670],
        [ 2.2807, -2.1250],
        [ 2.5284, -2.3388],
        [ 1.7820, -1.6962],
        [ 2.1866, -2.0471],
        [ 2.9401, -2.6993],
        [ 2.9201, -2.6828],
        [ 3.3662, -3.0674],
        [ 2.6811, -2.4746],
        [ 2.2226, -2.0822],
        [ 2.8344, -2.6053],
        [ 2.7173, -2.5040],
        [ 2.6979, -2.4901],
        [ 2.4818, -2.2985],
        [ 3.3491, -3.0583],
        [ 2.5362, -2.3454],
        [ 2.5164, -2.3273],
        [ 2.4804, -2.3025],
        [ 3.3165, -3.0280],
        [ 2.1958, -2.0513],
        [ 2.4220, -2.2463],
        [ 2.5829, -2.3897],
        [ 2.4194, -2.2435],
        [ 2.0653, -1.9391],
        [ 2.8849, -2.6534],
        [ 2.4282, -2.2559],
        [ 2.3076, -2.1495],
        [ 2.2807, -2.1281],
        [ 2.5313, -2.3416],
        [ 3.1698, -2.9020],
        [ 2.6840, -2.4771],
        [ 2.5859, -2.3983],
        [ 2.7088, -2.4969],
        [ 2.5733, -2.3825],
        [ 2.

Output: tensor([[ 1.9435, -1.8440],
        [ 1.7688, -1.6923],
        [ 1.8587, -1.7747],
        [ 1.6862, -1.6202],
        [ 1.4144, -1.3872],
        [ 2.1822, -2.0579],
        [ 2.3727, -2.2220],
        [ 1.6252, -1.5688],
        [ 1.9988, -1.8927],
        [ 1.3962, -1.3728],
        [ 1.8286, -1.7459],
        [ 2.1330, -2.0083],
        [ 1.6863, -1.6211],
        [ 1.4696, -1.4356],
        [ 1.9261, -1.8337],
        [ 1.5468, -1.5003],
        [ 1.7863, -1.7073],
        [ 2.1787, -2.0512],
        [ 1.7298, -1.6586],
        [ 1.7091, -1.6437],
        [ 1.7387, -1.6664],
        [ 1.9367, -1.8378],
        [ 1.7801, -1.7032],
        [ 2.3316, -2.1830],
        [ 1.5314, -1.4868],
        [ 1.7832, -1.7051],
        [ 1.7927, -1.7190],
        [ 1.5369, -1.4934],
        [ 2.1521, -2.0297],
        [ 1.7982, -1.7171],
        [ 1.7292, -1.6576],
        [ 1.8060, -1.7248],
        [ 1.8112, -1.7291],
        [ 1.7662, -1.6902],
        [ 1.8498, -1.7611],
        [ 1.

Output: tensor([[ 1.4998, -1.4721],
        [ 1.1487, -1.1686],
        [ 1.4294, -1.4092],
        [ 1.4599, -1.4329],
        [ 1.4267, -1.4119],
        [ 1.2924, -1.2897],
        [ 2.1702, -2.0549],
        [ 1.6311, -1.5931],
        [ 1.9889, -1.8949],
        [ 1.5387, -1.5041],
        [ 1.6522, -1.6018],
        [ 1.5072, -1.4756],
        [ 1.2863, -1.2906],
        [ 1.2674, -1.2726],
        [ 1.1487, -1.1676],
        [ 1.3744, -1.3602],
        [ 1.5662, -1.5260],
        [ 1.2335, -1.2401],
        [ 1.2657, -1.2736],
        [ 1.4398, -1.4146],
        [ 1.4521, -1.4261],
        [ 1.2892, -1.2954],
        [ 1.3904, -1.3748],
        [ 1.4692, -1.4412],
        [ 1.5628, -1.5275],
        [ 1.4935, -1.4623],
        [ 1.1359, -1.1555],
        [ 1.9223, -1.8408],
        [ 1.4565, -1.4297],
        [ 1.2276, -1.2336],
        [ 1.5775, -1.5358],
        [ 2.1225, -2.0136],
        [ 1.5750, -1.5392],
        [ 1.1349, -1.1545],
        [ 1.6492, -1.6007],
        [ 1.

HBox(children=(FloatProgress(value=0.0, description='Iteration', max=21.0, style=ProgressStyle(description_wid…

Batch:  torch.Size([90, 20, 375])
Loss:  tensor(0.0721, grad_fn=<NllLossBackward>)
Output: tensor([[ 1.1131, -1.1358],
        [ 1.2783, -1.2774],
        [ 1.4336, -1.4142],
        [ 1.3595, -1.3507],
        [ 1.1884, -1.2048],
        [ 0.8282, -0.8916],
        [ 1.5177, -1.4920],
        [ 1.6581, -1.6113],
        [ 1.3296, -1.3225],
        [ 1.0833, -1.1102],
        [ 1.5817, -1.5411],
        [ 1.3866, -1.3716],
        [ 1.3764, -1.3626],
        [ 1.5998, -1.5635],
        [ 1.4211, -1.4033],
        [ 1.4799, -1.4541],
        [ 1.3003, -1.2972],
        [ 0.8486, -0.9089],
        [ 1.3425, -1.3348],
        [ 1.1800, -1.1938],
        [ 1.4371, -1.4154],
        [ 1.2959, -1.2978],
        [ 1.5440, -1.5091],
        [ 1.5492, -1.5100],
        [ 1.2394, -1.2434],
        [ 1.0049, -1.0446],
        [ 1.2145, -1.2219],
        [ 1.0151, -1.0538],
        [ 1.4168, -1.4038],
        [ 1.2608, -1.2626],
        [ 1.2183, -1.2257],
        [ 1.3190, -1.3121],
        [ 1.8

Output: tensor([[ 2.6685, -2.4711],
        [ 2.4726, -2.3017],
        [ 2.2996, -2.1497],
        [ 2.5457, -2.3635],
        [ 1.7367, -1.6625],
        [ 2.1720, -2.0421],
        [ 2.9842, -2.7487],
        [ 2.9304, -2.7030],
        [ 3.4545, -3.1575],
        [ 2.6802, -2.4840],
        [ 2.2396, -2.1041],
        [ 2.8587, -2.6376],
        [ 2.7487, -2.5417],
        [ 2.7237, -2.5226],
        [ 2.4989, -2.3228],
        [ 3.4215, -3.1338],
        [ 2.5603, -2.3761],
        [ 2.5404, -2.3579],
        [ 2.5123, -2.3392],
        [ 3.3703, -3.0876],
        [ 2.2001, -2.0629],
        [ 2.4402, -2.2713],
        [ 2.5645, -2.3834],
        [ 2.4335, -2.2650],
        [ 2.0324, -1.9180],
        [ 2.9042, -2.6809],
        [ 2.4414, -2.2761],
        [ 2.2927, -2.1452],
        [ 2.2736, -2.1302],
        [ 2.5728, -2.3871],
        [ 3.1929, -2.9341],
        [ 2.6852, -2.4885],
        [ 2.6164, -2.4337],
        [ 2.7449, -2.5385],
        [ 2.6038, -2.4177],
        [ 2.

Output: tensor([[ 1.9672, -1.8692],
        [ 1.7697, -1.6970],
        [ 1.8519, -1.7726],
        [ 1.6800, -1.6182],
        [ 1.3841, -1.3630],
        [ 2.1616, -2.0455],
        [ 2.3780, -2.2330],
        [ 1.6297, -1.5757],
        [ 2.0229, -1.9183],
        [ 1.3818, -1.3620],
        [ 1.8463, -1.7652],
        [ 2.1765, -2.0515],
        [ 1.6751, -1.6146],
        [ 1.4557, -1.4258],
        [ 1.9513, -1.8596],
        [ 1.5310, -1.4895],
        [ 1.7900, -1.7143],
        [ 2.1879, -2.0648],
        [ 1.7336, -1.6655],
        [ 1.7144, -1.6512],
        [ 1.7234, -1.6567],
        [ 1.9547, -1.8580],
        [ 1.7780, -1.7053],
        [ 2.3791, -2.2306],
        [ 1.5230, -1.4821],
        [ 1.7817, -1.7078],
        [ 1.7906, -1.7203],
        [ 1.5121, -1.4743],
        [ 2.2027, -2.0782],
        [ 1.8031, -1.7254],
        [ 1.7409, -1.6711],
        [ 1.8062, -1.7292],
        [ 1.8137, -1.7353],
        [ 1.7674, -1.6951],
        [ 1.8652, -1.7787],
        [ 1.

Output: tensor([[ 1.4877, -1.4566],
        [ 1.1361, -1.1508],
        [ 1.4195, -1.3955],
        [ 1.4517, -1.4218],
        [ 1.4285, -1.4067],
        [ 1.2850, -1.2781],
        [ 2.1658, -2.0494],
        [ 1.6421, -1.5970],
        [ 2.0098, -1.9108],
        [ 1.5431, -1.5034],
        [ 1.6497, -1.5962],
        [ 1.5030, -1.4674],
        [ 1.2774, -1.2756],
        [ 1.2541, -1.2539],
        [ 1.1174, -1.1337],
        [ 1.3595, -1.3423],
        [ 1.5610, -1.5179],
        [ 1.2060, -1.2107],
        [ 1.2748, -1.2736],
        [ 1.4339, -1.4054],
        [ 1.4436, -1.4147],
        [ 1.2981, -1.2950],
        [ 1.3574, -1.3416],
        [ 1.4606, -1.4297],
        [ 1.5560, -1.5167],
        [ 1.4908, -1.4560],
        [ 1.0886, -1.1083],
        [ 1.9147, -1.8308],
        [ 1.4542, -1.4240],
        [ 1.2011, -1.2048],
        [ 1.5755, -1.5305],
        [ 2.1173, -2.0069],
        [ 1.5614, -1.5226],
        [ 1.1183, -1.1338],
        [ 1.6224, -1.5738],
        [ 1.

HBox(children=(FloatProgress(value=0.0, description='Iteration', max=21.0, style=ProgressStyle(description_wid…

Batch:  torch.Size([90, 20, 375])
Loss:  tensor(0.0707, grad_fn=<NllLossBackward>)
Output: tensor([[ 1.1119, -1.1352],
        [ 1.2955, -1.2936],
        [ 1.4571, -1.4367],
        [ 1.3807, -1.3707],
        [ 1.2057, -1.2210],
        [ 0.8045, -0.8703],
        [ 1.5380, -1.5114],
        [ 1.6804, -1.6326],
        [ 1.3512, -1.3425],
        [ 1.0887, -1.1156],
        [ 1.6297, -1.5850],
        [ 1.4172, -1.4003],
        [ 1.3975, -1.3825],
        [ 1.6313, -1.5935],
        [ 1.4336, -1.4158],
        [ 1.5078, -1.4800],
        [ 1.3117, -1.3082],
        [ 0.8446, -0.9057],
        [ 1.3663, -1.3579],
        [ 1.1865, -1.2000],
        [ 1.4592, -1.4363],
        [ 1.3025, -1.3042],
        [ 1.5871, -1.5497],
        [ 1.6050, -1.5617],
        [ 1.2522, -1.2556],
        [ 1.0074, -1.0472],
        [ 1.2292, -1.2362],
        [ 1.0016, -1.0414],
        [ 1.4304, -1.4167],
        [ 1.2724, -1.2736],
        [ 1.2339, -1.2403],
        [ 1.3370, -1.3293],
        [ 1.8

Output: tensor([[ 2.6853, -2.4961],
        [ 2.4993, -2.3333],
        [ 2.3135, -2.1702],
        [ 2.5559, -2.3830],
        [ 1.6882, -1.6263],
        [ 2.1571, -2.0368],
        [ 3.0253, -2.7954],
        [ 2.9409, -2.7232],
        [ 3.5358, -3.2407],
        [ 2.6789, -2.4930],
        [ 2.2567, -2.1243],
        [ 2.8781, -2.6657],
        [ 2.7778, -2.5774],
        [ 2.7454, -2.5509],
        [ 2.5100, -2.3420],
        [ 3.4955, -3.2100],
        [ 2.5786, -2.4018],
        [ 2.5561, -2.3817],
        [ 2.5426, -2.3728],
        [ 3.4251, -3.1472],
        [ 2.1999, -2.0710],
        [ 2.4535, -2.2923],
        [ 2.5475, -2.3790],
        [ 2.4423, -2.2823],
        [ 1.9992, -1.8974],
        [ 2.9234, -2.7087],
        [ 2.4545, -2.2957],
        [ 2.2761, -2.1394],
        [ 2.2668, -2.1323],
        [ 2.6108, -2.4296],
        [ 3.2166, -2.9665],
        [ 2.6843, -2.4984],
        [ 2.6479, -2.4674],
        [ 2.7755, -2.5755],
        [ 2.6316, -2.4514],
        [ 2.

Output: tensor([[ 1.9874, -1.8951],
        [ 1.7675, -1.7018],
        [ 1.8436, -1.7760],
        [ 1.6674, -1.6219],
        [ 1.3545, -1.3481],
        [ 2.1441, -2.0340],
        [ 2.3834, -2.2420],
        [ 1.6319, -1.5873],
        [ 2.0420, -1.9459],
        [ 1.3655, -1.3619],
        [ 1.8587, -1.7858],
        [ 2.2146, -2.0957],
        [ 1.6609, -1.6156],
        [ 1.4365, -1.4266],
        [ 1.9672, -1.8902],
        [ 1.5126, -1.4813],
        [ 1.7917, -1.7241],
        [ 2.1974, -2.0783],
        [ 1.7329, -1.6743],
        [ 1.7151, -1.6672],
        [ 1.7050, -1.6516],
        [ 1.9671, -1.8777],
        [ 1.7733, -1.7077],
        [ 2.4182, -2.2771],
        [ 1.5120, -1.4839],
        [ 1.7779, -1.7109],
        [ 1.7839, -1.7282],
        [ 1.4869, -1.4600],
        [ 2.2386, -2.1314],
        [ 1.8049, -1.7353],
        [ 1.7498, -1.6900],
        [ 1.8041, -1.7335],
        [ 1.8142, -1.7423],
        [ 1.7657, -1.7005],
        [ 1.8770, -1.7982],
        [ 1.

Output: tensor([[ 1.4780, -1.4532],
        [ 1.1255, -1.1463],
        [ 1.4118, -1.3929],
        [ 1.4439, -1.4191],
        [ 1.4294, -1.4119],
        [ 1.2768, -1.2747],
        [ 2.1639, -2.0549],
        [ 1.6487, -1.6111],
        [ 2.0309, -1.9360],
        [ 1.5453, -1.5102],
        [ 1.6486, -1.6003],
        [ 1.5006, -1.4695],
        [ 1.2696, -1.2732],
        [ 1.2405, -1.2462],
        [ 1.0879, -1.1111],
        [ 1.3458, -1.3340],
        [ 1.5570, -1.5189],
        [ 1.1813, -1.1929],
        [ 1.2794, -1.2823],
        [ 1.4299, -1.4059],
        [ 1.4359, -1.4119],
        [ 1.3065, -1.3068],
        [ 1.3287, -1.3204],
        [ 1.4526, -1.4268],
        [ 1.5535, -1.5194],
        [ 1.4881, -1.4578],
        [ 1.0451, -1.0734],
        [ 1.9104, -1.8332],
        [ 1.4502, -1.4247],
        [ 1.1757, -1.1855],
        [ 1.5742, -1.5340],
        [ 2.1154, -2.0123],
        [ 1.5503, -1.5177],
        [ 1.1020, -1.1231],
        [ 1.5997, -1.5588],
        [ 1.

HBox(children=(FloatProgress(value=0.0, description='Iteration', max=21.0, style=ProgressStyle(description_wid…

Batch:  torch.Size([90, 20, 375])
Loss:  tensor(0.0707, grad_fn=<NllLossBackward>)
Output: tensor([[ 1.1015, -1.1273],
        [ 1.3012, -1.3008],
        [ 1.4694, -1.4496],
        [ 1.3908, -1.3815],
        [ 1.2143, -1.2287],
        [ 0.7756, -0.8451],
        [ 1.5482, -1.5221],
        [ 1.6909, -1.6452],
        [ 1.3602, -1.3530],
        [ 1.0860, -1.1140],
        [ 1.6637, -1.6179],
        [ 1.4357, -1.4188],
        [ 1.4061, -1.3929],
        [ 1.6493, -1.6114],
        [ 1.4353, -1.4198],
        [ 1.5240, -1.4967],
        [ 1.3126, -1.3115],
        [ 0.8339, -0.8954],
        [ 1.3793, -1.3707],
        [ 1.1839, -1.1996],
        [ 1.4693, -1.4480],
        [ 1.2997, -1.3036],
        [ 1.6161, -1.5771],
        [ 1.6450, -1.5997],
        [ 1.2536, -1.2590],
        [ 1.0021, -1.0427],
        [ 1.2325, -1.2406],
        [ 0.9807, -1.0247],
        [ 1.4326, -1.4208],
        [ 1.2734, -1.2769],
        [ 1.2402, -1.2478],
        [ 1.3441, -1.3378],
        [ 1.8

Output: tensor([[ 2.6990, -2.5104],
        [ 2.5222, -2.3565],
        [ 2.3274, -2.1842],
        [ 2.5619, -2.3899],
        [ 1.6481, -1.5917],
        [ 2.1445, -2.0274],
        [ 3.0603, -2.8301],
        [ 2.9500, -2.7342],
        [ 3.6043, -3.3061],
        [ 2.6772, -2.4938],
        [ 2.2727, -2.1422],
        [ 2.8943, -2.6824],
        [ 2.8040, -2.6033],
        [ 2.7635, -2.5705],
        [ 2.5189, -2.3518],
        [ 3.5571, -3.2701],
        [ 2.5937, -2.4173],
        [ 2.5689, -2.3947],
        [ 2.5679, -2.3993],
        [ 3.4741, -3.1949],
        [ 2.2016, -2.0740],
        [ 2.4654, -2.3046],
        [ 2.5323, -2.3670],
        [ 2.4504, -2.2910],
        [ 1.9712, -1.8731],
        [ 2.9395, -2.7258],
        [ 2.4663, -2.3085],
        [ 2.2623, -2.1288],
        [ 2.2607, -2.1293],
        [ 2.6438, -2.4616],
        [ 3.2354, -2.9872],
        [ 2.6819, -2.4985],
        [ 2.6742, -2.4960],
        [ 2.8002, -2.6002],
        [ 2.6536, -2.4741],
        [ 2.

Output: tensor([[ 2.0078, -1.9148],
        [ 1.7665, -1.7028],
        [ 1.8404, -1.7720],
        [ 1.6594, -1.6137],
        [ 1.3339, -1.3274],
        [ 2.1275, -2.0222],
        [ 2.3866, -2.2486],
        [ 1.6387, -1.5931],
        [ 2.0619, -1.9643],
        [ 1.3597, -1.3524],
        [ 1.8726, -1.7993],
        [ 2.2520, -2.1291],
        [ 1.6523, -1.6066],
        [ 1.4272, -1.4134],
        [ 1.9866, -1.9048],
        [ 1.4983, -1.4690],
        [ 1.7953, -1.7284],
        [ 2.2065, -2.0891],
        [ 1.7340, -1.6760],
        [ 1.7228, -1.6721],
        [ 1.6908, -1.6380],
        [ 1.9787, -1.8897],
        [ 1.7695, -1.7062],
        [ 2.4556, -2.3112],
        [ 1.5074, -1.4784],
        [ 1.7754, -1.7106],
        [ 1.7831, -1.7263],
        [ 1.4674, -1.4418],
        [ 2.2761, -2.1631],
        [ 1.8088, -1.7397],
        [ 1.7633, -1.7009],
        [ 1.8027, -1.7344],
        [ 1.8158, -1.7455],
        [ 1.7654, -1.7020],
        [ 1.8903, -1.8111],
        [ 1.

Output: tensor([[ 1.4482, -1.4840],
        [ 1.0957, -1.1786],
        [ 1.3853, -1.4204],
        [ 1.4187, -1.4438],
        [ 1.4037, -1.4559],
        [ 1.2486, -1.3026],
        [ 2.1415, -2.0883],
        [ 1.6261, -1.6655],
        [ 2.0283, -1.9897],
        [ 1.5237, -1.5481],
        [ 1.6282, -1.6326],
        [ 1.4794, -1.5002],
        [ 1.2370, -1.3123],
        [ 1.2053, -1.2734],
        [ 1.0413, -1.1216],
        [ 1.3145, -1.3545],
        [ 1.5347, -1.5475],
        [ 1.1407, -1.2067],
        [ 1.2580, -1.3308],
        [ 1.4086, -1.4341],
        [ 1.4110, -1.4366],
        [ 1.2878, -1.3631],
        [ 1.2855, -1.3281],
        [ 1.4273, -1.4514],
        [ 1.5307, -1.5550],
        [ 1.4665, -1.4876],
        [ 0.9861, -1.0708],
        [ 1.8856, -1.8663],
        [ 1.4260, -1.4527],
        [ 1.1353, -1.1951],
        [ 1.5544, -1.5658],
        [ 2.0933, -2.0463],
        [ 1.5186, -1.5462],
        [ 1.0662, -1.1473],
        [ 1.5612, -1.5718],
        [ 1.

HBox(children=(FloatProgress(value=0.0, description='Iteration', max=21.0, style=ProgressStyle(description_wid…

Batch:  torch.Size([90, 20, 375])
Loss:  tensor(0.0706, grad_fn=<NllLossBackward>)
Output: tensor([[ 1.0922, -1.1214],
        [ 1.3061, -1.3076],
        [ 1.4807, -1.4616],
        [ 1.4012, -1.3907],
        [ 1.2216, -1.2382],
        [ 0.7524, -0.8219],
        [ 1.5602, -1.5301],
        [ 1.7020, -1.6537],
        [ 1.3686, -1.3618],
        [ 1.0848, -1.1145],
        [ 1.6929, -1.6506],
        [ 1.4517, -1.4372],
        [ 1.4130, -1.4016],
        [ 1.6651, -1.6278],
        [ 1.4367, -1.4230],
        [ 1.5364, -1.5123],
        [ 1.3143, -1.3137],
        [ 0.8265, -0.8877],
        [ 1.3895, -1.3841],
        [ 1.1835, -1.1983],
        [ 1.4783, -1.4583],
        [ 1.2991, -1.3020],
        [ 1.6394, -1.6036],
        [ 1.6785, -1.6375],
        [ 1.2543, -1.2621],
        [ 0.9986, -1.0418],
        [ 1.2343, -1.2455],
        [ 0.9651, -1.0076],
        [ 1.4360, -1.4216],
        [ 1.2751, -1.2796],
        [ 1.2478, -1.2551],
        [ 1.3504, -1.3458],
        [ 1.8

Output: tensor([[ 2.7061, -2.5228],
        [ 2.5402, -2.3776],
        [ 2.3369, -2.1975],
        [ 2.5609, -2.3951],
        [ 1.6099, -1.5613],
        [ 2.1311, -2.0200],
        [ 3.0878, -2.8612],
        [ 2.9545, -2.7449],
        [ 3.6612, -3.3643],
        [ 2.6718, -2.4950],
        [ 2.2878, -2.1594],
        [ 2.9036, -2.6973],
        [ 2.8247, -2.6280],
        [ 2.7764, -2.5878],
        [ 2.5222, -2.3603],
        [ 3.6101, -3.3247],
        [ 2.6024, -2.4309],
        [ 2.5747, -2.4057],
        [ 2.5890, -2.4232],
        [ 3.5158, -3.2395],
        [ 2.2006, -2.0778],
        [ 2.4722, -2.3161],
        [ 2.5143, -2.3570],
        [ 2.4538, -2.2995],
        [ 1.9431, -1.8527],
        [ 2.9501, -2.7415],
        [ 2.4746, -2.3210],
        [ 2.2468, -2.1197],
        [ 2.2524, -2.1269],
        [ 2.6710, -2.4913],
        [ 3.2483, -3.0060],
        [ 2.6749, -2.4986],
        [ 2.6981, -2.5220],
        [ 2.8172, -2.6214],
        [ 2.6698, -2.4940],
        [ 2.

Output: tensor([[ 2.0262, -1.9343],
        [ 1.7654, -1.7045],
        [ 1.8369, -1.7710],
        [ 1.6534, -1.6053],
        [ 1.3170, -1.3120],
        [ 2.1083, -2.0134],
        [ 2.3849, -2.2558],
        [ 1.6464, -1.6006],
        [ 2.0796, -1.9818],
        [ 1.3572, -1.3479],
        [ 1.8858, -1.8122],
        [ 2.2844, -2.1603],
        [ 1.6452, -1.5990],
        [ 1.4216, -1.4037],
        [ 2.0032, -1.9179],
        [ 1.4858, -1.4591],
        [ 1.7987, -1.7333],
        [ 2.2125, -2.1012],
        [ 1.7350, -1.6779],
        [ 1.7324, -1.6786],
        [ 1.6771, -1.6260],
        [ 1.9883, -1.9004],
        [ 1.7651, -1.7053],
        [ 2.4883, -2.3436],
        [ 1.5053, -1.4759],
        [ 1.7726, -1.7111],
        [ 1.7837, -1.7264],
        [ 1.4490, -1.4273],
        [ 2.3114, -2.1897],
        [ 1.8122, -1.7446],
        [ 1.7766, -1.7132],
        [ 1.8007, -1.7360],
        [ 1.8168, -1.7497],
        [ 1.7651, -1.7043],
        [ 1.9031, -1.8239],
        [ 1.

Output: tensor([[ 1.4783, -1.4655],
        [ 1.1289, -1.1588],
        [ 1.4108, -1.4033],
        [ 1.4412, -1.4281],
        [ 1.4466, -1.4405],
        [ 1.2753, -1.2846],
        [ 2.1696, -2.0754],
        [ 1.6783, -1.6505],
        [ 2.0784, -1.9927],
        [ 1.5581, -1.5343],
        [ 1.6577, -1.6206],
        [ 1.5078, -1.4874],
        [ 1.2765, -1.2906],
        [ 1.2305, -1.2489],
        [ 1.0507, -1.0879],
        [ 1.3325, -1.3337],
        [ 1.5606, -1.5341],
        [ 1.1535, -1.1777],
        [ 1.3063, -1.3177],
        [ 1.4352, -1.4217],
        [ 1.4337, -1.4213],
        [ 1.3461, -1.3541],
        [ 1.2908, -1.2972],
        [ 1.4497, -1.4357],
        [ 1.5644, -1.5408],
        [ 1.4938, -1.4747],
        [ 0.9814, -1.0268],
        [ 1.9148, -1.8512],
        [ 1.4511, -1.4373],
        [ 1.1432, -1.1670],
        [ 1.5835, -1.5543],
        [ 2.1222, -2.0333],
        [ 1.5451, -1.5258],
        [ 1.0906, -1.1232],
        [ 1.5706, -1.5450],
        [ 1.

HBox(children=(FloatProgress(value=0.0, description='Iteration', max=21.0, style=ProgressStyle(description_wid…

Batch:  torch.Size([90, 20, 375])
Loss:  tensor(0.0712, grad_fn=<NllLossBackward>)
Output: tensor([[ 1.0779, -1.1132],
        [ 1.3039, -1.3105],
        [ 1.4830, -1.4700],
        [ 1.4014, -1.3989],
        [ 1.2215, -1.2436],
        [ 0.7220, -0.8033],
        [ 1.5579, -1.5404],
        [ 1.7000, -1.6623],
        [ 1.3683, -1.3678],
        [ 1.0777, -1.1132],
        [ 1.7146, -1.6725],
        [ 1.4609, -1.4494],
        [ 1.4120, -1.4059],
        [ 1.6695, -1.6394],
        [ 1.4298, -1.4230],
        [ 1.5410, -1.5209],
        [ 1.3077, -1.3145],
        [ 0.8110, -0.8819],
        [ 1.3929, -1.3913],
        [ 1.1743, -1.1978],
        [ 1.4792, -1.4647],
        [ 1.2884, -1.3013],
        [ 1.6555, -1.6213],
        [ 1.7076, -1.6633],
        [ 1.2487, -1.2616],
        [ 0.9905, -1.0391],
        [ 1.2303, -1.2456],
        [ 0.9403, -0.9948],
        [ 1.4265, -1.4236],
        [ 1.2693, -1.2804],
        [ 1.2479, -1.2615],
        [ 1.3497, -1.3499],
        [ 1.9

Output: tensor([[ 2.7119, -2.5318],
        [ 2.5559, -2.3957],
        [ 2.3457, -2.2081],
        [ 2.5590, -2.3966],
        [ 1.5751, -1.5328],
        [ 2.1193, -2.0129],
        [ 3.1124, -2.8881],
        [ 2.9594, -2.7543],
        [ 3.7115, -3.4154],
        [ 2.6674, -2.4953],
        [ 2.3024, -2.1763],
        [ 2.9120, -2.7091],
        [ 2.8442, -2.6498],
        [ 2.7872, -2.6025],
        [ 2.5245, -2.3657],
        [ 3.6584, -3.3749],
        [ 2.6094, -2.4407],
        [ 2.5788, -2.4126],
        [ 2.6082, -2.4451],
        [ 3.5549, -3.2808],
        [ 2.1999, -2.0802],
        [ 2.4781, -2.3247],
        [ 2.4984, -2.3467],
        [ 2.4567, -2.3052],
        [ 1.9180, -1.8327],
        [ 2.9599, -2.7549],
        [ 2.4835, -2.3321],
        [ 2.2336, -2.1108],
        [ 2.2444, -2.1237],
        [ 2.6958, -2.5173],
        [ 3.2602, -3.0226],
        [ 2.6686, -2.4973],
        [ 2.7198, -2.5469],
        [ 2.8315, -2.6384],
        [ 2.6841, -2.5108],
        [ 2.

Output: tensor([[ 2.0419, -1.9536],
        [ 1.7630, -1.7068],
        [ 1.8328, -1.7719],
        [ 1.6450, -1.6011],
        [ 1.3008, -1.3007],
        [ 2.0906, -2.0041],
        [ 2.3830, -2.2615],
        [ 1.6527, -1.6103],
        [ 2.0942, -2.0000],
        [ 1.3549, -1.3487],
        [ 1.8962, -1.8263],
        [ 2.3125, -2.1910],
        [ 1.6364, -1.5948],
        [ 1.4147, -1.4001],
        [ 2.0154, -1.9337],
        [ 1.4722, -1.4508],
        [ 1.8010, -1.7397],
        [ 2.2179, -2.1123],
        [ 1.7341, -1.6813],
        [ 1.7402, -1.6896],
        [ 1.6629, -1.6172],
        [ 1.9951, -1.9114],
        [ 1.7598, -1.7053],
        [ 2.5146, -2.3739],
        [ 1.5018, -1.4763],
        [ 1.7688, -1.7123],
        [ 1.7827, -1.7300],
        [ 1.4319, -1.4156],
        [ 2.3397, -2.2204],
        [ 1.8139, -1.7505],
        [ 1.7886, -1.7277],
        [ 1.7977, -1.7380],
        [ 1.8166, -1.7542],
        [ 1.7634, -1.7073],
        [ 1.9137, -1.8378],
        [ 1.

Output: tensor([[ 1.4868, -1.4764],
        [ 1.1403, -1.1715],
        [ 1.4178, -1.4123],
        [ 1.4468, -1.4357],
        [ 1.4624, -1.4579],
        [ 1.2817, -1.2929],
        [ 2.1789, -2.0888],
        [ 1.7003, -1.6752],
        [ 2.1063, -2.0221],
        [ 1.5697, -1.5478],
        [ 1.6687, -1.6338],
        [ 1.5184, -1.4997],
        [ 1.2885, -1.3044],
        [ 1.2331, -1.2535],
        [ 1.0418, -1.0814],
        [ 1.3330, -1.3368],
        [ 1.5691, -1.5446],
        [ 1.1492, -1.1760],
        [ 1.3262, -1.3383],
        [ 1.4448, -1.4329],
        [ 1.4396, -1.4291],
        [ 1.3743, -1.3820],
        [ 1.2814, -1.2911],
        [ 1.4554, -1.4434],
        [ 1.5779, -1.5559],
        [ 1.5029, -1.4857],
        [ 0.9606, -1.0102],
        [ 1.9242, -1.8639],
        [ 1.4580, -1.4462],
        [ 1.1360, -1.1626],
        [ 1.5943, -1.5670],
        [ 2.1322, -2.0471],
        [ 1.5509, -1.5346],
        [ 1.0939, -1.1284],
        [ 1.5646, -1.5428],
        [ 1.

HBox(children=(FloatProgress(value=0.0, description='Iteration', max=21.0, style=ProgressStyle(description_wid…

Batch:  torch.Size([90, 20, 375])
Loss:  tensor(0.0714, grad_fn=<NllLossBackward>)
Output: tensor([[ 1.0677, -1.1080],
        [ 1.3039, -1.3147],
        [ 1.4882, -1.4797],
        [ 1.4046, -1.4069],
        [ 1.2254, -1.2516],
        [ 0.7001, -0.7872],
        [ 1.5609, -1.5494],
        [ 1.7022, -1.6705],
        [ 1.3704, -1.3742],
        [ 1.0748, -1.1143],
        [ 1.7376, -1.6978],
        [ 1.4715, -1.4636],
        [ 1.4127, -1.4111],
        [ 1.6776, -1.6526],
        [ 1.4264, -1.4249],
        [ 1.5478, -1.5320],
        [ 1.3046, -1.3162],
        [ 0.8014, -0.8770],
        [ 1.3971, -1.3998],
        [ 1.1695, -1.1978],
        [ 1.4823, -1.4722],
        [ 1.2835, -1.3018],
        [ 1.6725, -1.6417],
        [ 1.7345, -1.6931],
        [ 1.2455, -1.2628],
        [ 0.9874, -1.0399],
        [ 1.2279, -1.2475],
        [ 0.9237, -0.9841],
        [ 1.4229, -1.4261],
        [ 1.2667, -1.2824],
        [ 1.2519, -1.2693],
        [ 1.3513, -1.3556],
        [ 1.9

Output: tensor([[ 2.7129, -2.5396],
        [ 2.5677, -2.4130],
        [ 2.3502, -2.2180],
        [ 2.5536, -2.3982],
        [ 1.5406, -1.5062],
        [ 2.1064, -2.0072],
        [ 3.1301, -2.9124],
        [ 2.9605, -2.7634],
        [ 3.7526, -3.4622],
        [ 2.6602, -2.4960],
        [ 2.3144, -2.1935],
        [ 2.9153, -2.7196],
        [ 2.8586, -2.6704],
        [ 2.7942, -2.6165],
        [ 2.5227, -2.3705],
        [ 3.7000, -3.4228],
        [ 2.6112, -2.4491],
        [ 2.5777, -2.4179],
        [ 2.6231, -2.4660],
        [ 3.5866, -3.3192],
        [ 2.1966, -2.0828],
        [ 2.4801, -2.3326],
        [ 2.4806, -2.3373],
        [ 2.4555, -2.3101],
        [ 1.8922, -1.8148],
        [ 2.9657, -2.7678],
        [ 2.4887, -2.3435],
        [ 2.2191, -2.1032],
        [ 2.2338, -2.1206],
        [ 2.7161, -2.5424],
        [ 3.2673, -3.0382],
        [ 2.6599, -2.4965],
        [ 2.7375, -2.5709],
        [ 2.8405, -2.6539],
        [ 2.6944, -2.5273],
        [ 2.

Output: tensor([[ 2.0560, -1.9711],
        [ 1.7607, -1.7089],
        [ 1.8296, -1.7719],
        [ 1.6381, -1.5947],
        [ 1.2875, -1.2889],
        [ 2.0733, -1.9967],
        [ 2.3798, -2.2677],
        [ 1.6599, -1.6188],
        [ 2.1074, -2.0157],
        [ 1.3553, -1.3475],
        [ 1.9062, -1.8384],
        [ 2.3383, -2.2186],
        [ 1.6287, -1.5885],
        [ 1.4111, -1.3939],
        [ 2.0276, -1.9451],
        [ 1.4600, -1.4423],
        [ 1.8037, -1.7455],
        [ 2.2221, -2.1235],
        [ 1.7334, -1.6834],
        [ 1.7496, -1.6975],
        [ 1.6511, -1.6081],
        [ 2.0009, -1.9208],
        [ 1.7549, -1.7051],
        [ 2.5367, -2.3995],
        [ 1.4996, -1.4750],
        [ 1.7653, -1.7135],
        [ 1.7828, -1.7314],
        [ 1.4172, -1.4044],
        [ 2.3666, -2.2441],
        [ 1.8156, -1.7554],
        [ 1.8007, -1.7404],
        [ 1.7946, -1.7398],
        [ 1.8163, -1.7584],
        [ 1.7619, -1.7099],
        [ 1.9238, -1.8502],
        [ 1.

KeyboardInterrupt: 

### Baseline

In [None]:
# Train the baseline model on a random 80/20 train / held-out split of `dataset`.

# Seed torch's global RNG before splitting: random_split draws its permutation
# from it, so without a seed the held-out set (and every metric computed on it
# in the Testing section) changes on each Restart & Run All. Any torch-based
# random initialisation that follows is made repeatable as well.
torch.manual_seed(0)

heldout = int(len(dataset) * 0.2)
train_dataset, test_dataset = torch.utils.data.random_split(dataset, [len(dataset) - heldout, heldout])

model = Baseline(add_brnn=False)
# Hyper-parameters handed to train() below and to the evaluation call in the Testing section.
config = dict(
    num_workers=8,
    batch_size=90,
    learning_rate=0.001,
    weight_decay=0.01,
    num_epochs=200,
    is_notebook=True
)

train(model, train_dataset, config)

## Testing

In [None]:
# The settings cell at the top of the notebook rebinds `test` to a
# unittest.TestCase instance, shadowing the evaluation function imported from
# `training`. Calling `test(...)` here would raise "TypeError: 'TestCase' object
# is not callable", so import the function under a distinct name.
from training import test as evaluate_model

# Ground-truth labels of the held-out samples, selected by the split's indices.
y_true = dataset.labels[test_dataset.indices]
y_pred, test_acc = evaluate_model(model, test_dataset, config)

# Held-out set size, then the sum of its labels
# (the number of positive samples, assuming 0/1 labels - TODO confirm).
print(len(test_dataset))
print(y_true.sum().item())

In [None]:
# Classification report as a DataFrame, transposed so each report entry is a row.
results = pd.DataFrame(classification_report(y_true, y_pred, zero_division=0, output_dict=True)).transpose()

# Unpacking four values requires a 2x2 confusion matrix, i.e. a binary problem
# in which both classes occur in y_true / y_pred.
tn, fp, fn, tp = confusion_matrix(y_true, y_pred).ravel()
specificity = tn / (tn+fp)

# The bare name `metrics` is never imported in this notebook (only individual
# functions from sklearn.metrics are), so `metrics.roc_curve` raised NameError.
# Use the directly imported `roc_curve` and reach `auc` through the `sklearn`
# package, whose `metrics` submodule is already loaded by the top-level imports.
fpr, tpr, thresholds = roc_curve(y_true, y_pred)
auc_score = sklearn.metrics.auc(fpr, tpr)

In [None]:
# Render the classification report table, then print the scalar metrics under it.
display(results)
for _label, _value in (("Specificity:", specificity), ("AUC:", auc_score)):
    print(_label, _value)

In [None]:
# ROC AUC and ROC curve of the baseline predictions on the held-out labels.
lr_auc = roc_auc_score(y_true, y_pred)
print('ROC AUC=%.3f' % (lr_auc,))

# Only the curve coordinates are plotted; the thresholds are discarded.
lr_fpr, lr_tpr, _ = roc_curve(y_true, y_pred)
pyplot.plot(lr_fpr, lr_tpr, label='Baseline model', marker='.')
pyplot.xlabel('False Positive Rate')
pyplot.ylabel('True Positive Rate')
pyplot.legend()
pyplot.show()

In [None]:
# Average precision summarises the precision-recall curve as a single score.
pr_auc = average_precision_score(y_true, y_pred)
# Print the PR AUC itself; `specificity` belongs to the earlier metrics cell.
print("PR AUC:", pr_auc)