# setup

In [1]:
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
import os

from tqdm import tqdm

In [2]:
# The notebook lives in the notebooks folder; move up to the project root so the
# local `data`, `configs` and `utils` packages imported below can be resolved.
# Guarded so that re-running this cell does not keep walking further up the tree.
if 'PROJECT_ROOT' not in globals():
    os.chdir('..')
    PROJECT_ROOT = os.getcwd()

from data.load_data import LoadData
from configs.gate import LoadDataConfig
# from configs.baseline import LoadDataConfig
from utils import get_inputs

# init: config, data loader and device

In [3]:
# LoadDataConfig is imported from configs.gate; switch that import to
# configs.baseline to load data with the baseline configuration instead.
loader_config = LoadDataConfig()

In [4]:
# Every attribute of the config object is forwarded to LoadData as a keyword argument.
dataloader = LoadData(**vars(loader_config))
# Use the GPU when one is visible to torch, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# data: split sizes and one sample training batch

In [5]:
# One loader per split, created in train -> val -> test order.
train_dl, val_dl, test_dl = (
    dataloader.get_train_dataloader(),
    dataloader.get_val_dataloader(),
    dataloader.get_test_dataloader(),
)

In [6]:
# Number of examples per split (train, val, test).
tuple(split_dl.dataset_size for split_dl in (train_dl, val_dl, test_dl))

(293728, 34775, 17276)

In [7]:
# Grab a single training batch to inspect shapes and values in the cells below.
# Unlike a `for ...: break` loop, this fails loudly when the loader is empty
# instead of silently leaving raw/ecg/label undefined (or stale from an earlier run).
train_batch = next(iter(train_dl), None)
if train_batch is None:
    raise RuntimeError('train_dl yielded no batches')
raw, exam_id, label = train_batch
# get_inputs converts the raw batch into the model input; in the saved outputs raw is
# (128, 4096, 12) while ecg is (128, 12, 2560) -- see get_inputs for the exact transform.
ecg = get_inputs(raw).to(device)
label = label.to(device).float()

In [8]:
# Raw batch exactly as yielded by train_dl, before get_inputs.
(raw.shape, raw)

(torch.Size([128, 4096, 12]),
 tensor([[[ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
          [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
          [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
          ...,
          [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
          [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
          [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000]],
 
         [[ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
          [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
          [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
          ...,
          [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
          [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
          [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000]],
 
         [[ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000]

In [9]:
# Model input produced by get_inputs; its shape differs from raw (appears
# channels-first and shorter in time -- confirm against get_inputs).
(ecg.shape, ecg)

(torch.Size([128, 12, 2560]),
 tensor([[[ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ..., -1.9702e-01,
           -1.9544e-01, -2.1698e-01],
          [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ..., -1.5979e-01,
           -1.9544e-01, -2.0546e-01],
          [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  3.7229e-02,
            2.8250e-19,  1.1518e-02],
          ...,
          [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ..., -5.5427e-02,
           -6.8403e-02, -6.9504e-02],
          [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ..., -1.8855e-01,
           -1.9544e-01, -1.9638e-01],
          [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ..., -1.2524e-01,
           -1.3681e-01, -1.3644e-01]],
 
         [[ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ..., -9.5852e-02,
           -9.7718e-02, -1.0607e-01],
          [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ..., -8.4687e-02,
           -8.7947e-02, -9.6648e-02],
          [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  1.1165e-02,
            9.7718e-03,  9.

In [10]:
# Exam identifiers for the batch; left as returned by the loader (not moved to device).
(exam_id.shape, exam_id)

((128,),
 array([3172506,  824319,  652737, 3199595,   18148, 1463563, 1175532,
        2654895,   79264,  801107, 3211241, 1285238, 2743211, 1630791,
        1499926, 1463771, 1481569, 3169198,  989421,  467853,   86393,
        1539302,  122023,  140925,  532209,  617443,  898489,  197210,
        2825990,  432213, 2668984, 1370693,   66958,  686823, 1577156,
         728647,  216801,   89693,   45289, 3618947,  659567,  594210,
         141054, 1577303, 3186816,  476817,  810932,  252111,  989631,
         837762,  622994,  411018,  696855, 1178174, 1016170,  622876,
        1658897, 3167924, 2658317, 1630955, 1339754,   90606, 1446726,
         903665, 1140243,  446125, 4270982, 3204706,  188635,  383852,
        1504657, 1155899,  480556,  440161,  292542,  367715, 2516781,
        4203670,  785939,  652654,  165093,  964656, 1272654, 1668236,
        1266379,  956795, 1293393, 2888861,   15571, 2726021,  586972,
        2681431, 1426188, 1520863, 2754004, 3140535,  922947,  82657

In [11]:
# Float labels on device; the saved output suggests one-hot rows over 3 classes -- confirm.
(label.shape, label)

(torch.Size([128, 3]),
 tensor([[0., 0., 1.],
         [0., 0., 1.],
         [0., 0., 1.],
         [0., 1., 0.],
         [0., 1., 0.],
         [0., 0., 1.],
         [0., 0., 1.],
         [0., 0., 1.],
         [0., 0., 1.],
         [0., 0., 1.],
         [0., 1., 0.],
         [0., 0., 1.],
         [0., 0., 1.],
         [1., 0., 0.],
         [0., 1., 0.],
         [0., 0., 1.],
         [0., 0., 1.],
         [0., 1., 0.],
         [0., 1., 0.],
         [0., 0., 1.],
         [0., 0., 1.],
         [0., 0., 1.],
         [0., 0., 1.],
         [0., 0., 1.],
         [0., 0., 1.],
         [0., 0., 1.],
         [0., 0., 1.],
         [0., 0., 1.],
         [0., 0., 1.],
         [0., 0., 1.],
         [0., 0., 1.],
         [0., 1., 0.],
         [0., 0., 1.],
         [0., 0., 1.],
         [1., 0., 0.],
         [0., 1., 0.],
         [0., 0., 1.],
         [0., 0., 1.],
         [0., 0., 1.],
         [0., 0., 1.],
         [0., 0., 1.],
         [1., 0., 0.],
         [0