# setup

In [12]:
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
import os

from tqdm import tqdm

In [13]:
# Run from the project root so the local `data` and `configs` packages import.
# Only step up when the cwd does not already contain them (e.g. the kernel was
# started in the notebooks folder), so re-running this cell is idempotent
# instead of climbing one more directory each time.
if not (os.path.isdir('data') and os.path.isdir('configs')):
    os.chdir('..')  # notebooks folder -> project root

from data.load_data import LoadData
from configs.gate import LoadDataConfig
# from configs.baseline import LoadDataConfig

# init: data loader and crop constants

In [14]:
# Loader configuration; which config module it comes from (gate vs. baseline) is
# selected by the `configs...` import toggle above.
loader_config = LoadDataConfig()

In [15]:
# Build the project data loader from every field of the config object,
# and run on the GPU when one is available.
dataloader = LoadData(**vars(loader_config))
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Fixed crop applied by get_inputs: window length (samples) and preferred start index.
# NOTE(review): the start presumably marks where leading zero padding ends -- confirm.
signal_crop_len = 2560
signal_non_zero_start = 571

# get inputs: reshape batches to (B, C, L) and crop them to a fixed-length window

In [16]:
def get_inputs(batch, apply="non_zero", device="cuda", crop_len=None, non_zero_start=None):
    """Put a batch of multi-lead signals in (B, C, L) layout and optionally crop it.

    Args:
        batch: 3-D tensor, either (B, C, L) or (B, L, C). When dim 1 is larger
            than dim 2 the batch is assumed to be (B, L, C) and is permuted.
        apply: "non_zero" crops every signal to a window of ``crop_len``
            samples starting at ``non_zero_start``; any other value returns
            the (possibly permuted) batch unchanged.
        device: device the result is moved to. The default "cuda" fails on
            CPU-only machines -- pass the notebook's ``device`` there.
        crop_len: window length; defaults to the global ``signal_crop_len``.
        non_zero_start: preferred window start; defaults to the global
            ``signal_non_zero_start``.

    Returns:
        For "non_zero": a new tensor of shape (B, C, crop_len) in torch's
        default float dtype. Otherwise: the (possibly permuted) input. Both
        are returned on ``device``.
    """
    # Layout heuristic: the signal axis is assumed longer than the lead axis.
    if batch.shape[1] > batch.shape[2]:
        batch = batch.permute(0, 2, 1)

    if apply != "non_zero":
        return batch.to(device)

    # Resolve defaults at call time so later edits to the globals are honoured.
    if crop_len is None:
        crop_len = signal_crop_len
    if non_zero_start is None:
        non_zero_start = signal_non_zero_start

    n_batch, n_leads, signal_len = batch.shape

    # The window start is the same for every sample, so compute it once. If the
    # window would run past the end of the signal, shift it back so it ends on
    # the last sample (never before index 0).
    start = non_zero_start
    if start + crop_len > signal_len:
        start = max(0, signal_len - crop_len)
    window = batch[:, :, start:start + crop_len]

    # Zero-initialised buffer: signals shorter than crop_len end up zero-padded
    # at the end instead of raising a shape mismatch.
    transformed_data = torch.zeros(n_batch, n_leads, crop_len, device=batch.device)
    transformed_data[:, :, :window.shape[-1]] = window
    return transformed_data.to(device)

# data

In [17]:
# One loader per split, created in train -> val -> test order.
train_dl, val_dl, test_dl = (
    dataloader.get_train_dataloader(),
    dataloader.get_val_dataloader(),
    dataloader.get_test_dataloader(),
)

In [18]:
# Number of examples in each split, as (train, val, test).
tuple(split_dl.dataset_size for split_dl in (train_dl, val_dl, test_dl))

(293728, 34775, 17276)

In [19]:
# Grab a single training batch to inspect its structure. `next(iter(...))` makes
# an empty loader fail loudly (StopIteration) instead of silently leaving `raw`,
# `exam_id` and `label` undefined -- or stale from an earlier run.
train_batch = next(iter(train_dl))
raw, exam_id, label = train_batch

In [20]:
# Raw batch as loaded. Its dim 1 is larger than dim 2 (see the shape below), so
# get_inputs would permute it to (B, C, L) before cropping.
raw.shape, raw

(torch.Size([128, 4096, 12]),
 tensor([[[ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
          [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
          [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
          ...,
          [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
          [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
          [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000]],
 
         [[ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
          [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
          [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
          ...,
          [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
          [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
          [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000]],
 
         [[ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000]

In [21]:
# Per-sample exam identifiers (a NumPy array, per the output below).
# NOTE(review): these may be real record ids -- consider clearing this output before sharing.
exam_id.shape, exam_id

((128,),
 array([3172506,  824319,  652737, 3199595,   18148, 1463563, 1175532,
        2654895,   79264,  801107, 3211241, 1285238, 2743211, 1630791,
        1499926, 1463771, 1481569, 3169198,  989421,  467853,   86393,
        1539302,  122023,  140925,  532209,  617443,  898489,  197210,
        2825990,  432213, 2668984, 1370693,   66958,  686823, 1577156,
         728647,  216801,   89693,   45289, 3618947,  659567,  594210,
         141054, 1577303, 3186816,  476817,  810932,  252111,  989631,
         837762,  622994,  411018,  696855, 1178174, 1016170,  622876,
        1658897, 3167924, 2658317, 1630955, 1339754,   90606, 1446726,
         903665, 1140243,  446125, 4270982, 3204706,  188635,  383852,
        1504657, 1155899,  480556,  440161,  292542,  367715, 2516781,
        4203670,  785939,  652654,  165093,  964656, 1272654, 1668236,
        1266379,  956795, 1293393, 2888861,   15571, 2726021,  586972,
        2681431, 1426188, 1520863, 2754004, 3140535,  922947,  82657

In [22]:
# Labels: one row of 3 entries per sample. The rows shown each contain a single 1
# (one-hot-looking) -- TODO confirm multi-class vs. multi-label encoding.
label.shape, label

(torch.Size([128, 3]),
 tensor([[0, 0, 1],
         [0, 0, 1],
         [0, 0, 1],
         [0, 1, 0],
         [0, 1, 0],
         [0, 0, 1],
         [0, 0, 1],
         [0, 0, 1],
         [0, 0, 1],
         [0, 0, 1],
         [0, 1, 0],
         [0, 0, 1],
         [0, 0, 1],
         [1, 0, 0],
         [0, 1, 0],
         [0, 0, 1],
         [0, 0, 1],
         [0, 1, 0],
         [0, 1, 0],
         [0, 0, 1],
         [0, 0, 1],
         [0, 0, 1],
         [0, 0, 1],
         [0, 0, 1],
         [0, 0, 1],
         [0, 0, 1],
         [0, 0, 1],
         [0, 0, 1],
         [0, 0, 1],
         [0, 0, 1],
         [0, 0, 1],
         [0, 1, 0],
         [0, 0, 1],
         [0, 0, 1],
         [1, 0, 0],
         [0, 1, 0],
         [0, 0, 1],
         [0, 0, 1],
         [0, 0, 1],
         [0, 0, 1],
         [0, 0, 1],
         [1, 0, 0],
         [0, 0, 1],
         [1, 0, 0],
         [0, 0, 1],
         [0, 0, 1],
         [0, 0, 1],
         [0, 0, 1],
         [0, 0, 1