# setup

In [51]:
import torch
import torch.nn as nn
import os

In [52]:
# Run from the repository root so the local packages (configs, data, models, utils) import.
# os.path.basename is separator-aware, unlike split('/'), so this also works on Windows.
if os.path.basename(os.getcwd()) == 'notebooks':
    os.chdir('..')

# from configs.baseline import LoadDataConfig
from configs.fake import LoadDataConfig
from configs.moe import MoE_cnn_args
from data.load_data import LoadData
from models.moe import ResnetMoE
from utils import train, eval, plot_log, export

# init

In [53]:
# Tag handed to `export` in the training loop (presumably used to name checkpoints).
model_label = "moe"

In [54]:
# Default configuration objects for the data loader and the MoE model.
loader_config, moe_config = LoadDataConfig(), MoE_cnn_args()

In [55]:
# Checkpoint to fine-tune; set LOAD_PRETRAINED = False to start from a fresh ResnetMoE.
PRETRAINED_PATH = 'output/pretrained_moe.pt'
LOAD_PRETRAINED = True

# Defined first so the checkpoint can be mapped straight onto it.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

dataloader = LoadData(**loader_config.__dict__)
if LOAD_PRETRAINED:
    # torch.load unpickles a whole nn.Module here, so only load trusted checkpoints.
    # map_location lets a GPU-saved checkpoint open on a CPU-only machine.
    # NOTE(review): torch>=2.6 defaults to weights_only=True, which rejects
    # full-module pickles; pass weights_only=False there (or save a state_dict).
    model = torch.load(PRETRAINED_PATH, map_location=device)
else:
    model = ResnetMoE(**moe_config.__dict__)

EPOCHS = 5



# train

In [56]:
# Fixed crop window applied to every signal in "non_zero" mode.
SIGNAL_CROP_LEN = 2560       # samples kept per lead
SIGNAL_NON_ZERO_START = 571  # preferred first sample of the window

def get_inputs_conjugado(batch, label, apply = "non_zero", device = "cuda"):
    """Crop a signal batch and derive the 3-way gate targets from its labels.

    Args:
        batch: signal tensor shaped (B, C, L) or (B, L, C); it is permuted to
            (B, C, L) whenever dim 1 is larger than dim 2.
        label: (B, K) multi-label array/tensor. Columns 0-2 are grouped as
            "block", columns 3+ as "rhythm"; a row with no non-zero entry is
            "normal".
        apply: "non_zero" crops every signal to SIGNAL_CROP_LEN samples; any
            other value passes the signals through unchanged.
        device: device the returned tensors are moved to.

    Returns:
        (signals, superlabel). In "non_zero" mode signals is a fresh
        (B, C, SIGNAL_CROP_LEN) tensor of the default float dtype; superlabel
        is a (B, 3) float tensor ordered (block, rhythm, normal).
    """
    # (B, C, L)
    if batch.shape[1] > batch.shape[2]:
        batch = batch.permute(0, 2, 1)

    B, n_leads, signal_len = batch.shape

    if apply == "non_zero":
        # Start at SIGNAL_NON_ZERO_START, but shift left so the window never runs
        # past the end of the signal. The start depends only on signal_len, so it
        # is shared by the whole batch and a single slice replaces per-sample loops.
        start = max(0, min(SIGNAL_NON_ZERO_START, signal_len - SIGNAL_CROP_LEN))
        n_copy = min(SIGNAL_CROP_LEN, signal_len - start)
        # Signals shorter than the window are right-padded with zeros.
        transformed_data = torch.zeros(B, n_leads, SIGNAL_CROP_LEN)
        transformed_data[:, :, :n_copy] = batch[:, :, start:start + n_copy]
    else:
        transformed_data = batch

    # A label entry counts as positive when non-zero (works for bool/int/float).
    positive = torch.as_tensor(label) != 0
    block = positive[:, :3].any(dim=1)
    rhythm = positive[:, 3:].any(dim=1)
    normal = ~positive.any(dim=1)
    superlabel = torch.stack((block, rhythm, normal), dim=1)

    return transformed_data.to(device), superlabel.to(device).float()

In [57]:
from tqdm import tqdm

def train_conjugado(model, loader, optimizer, criterion, device = "cuda"):
    """Run one training epoch with a joint expert + gate objective.

    For every batch the loss is criterion(model(ecg), label) plus
    criterion(model.gate(ecg), superlabel), where ecg and superlabel come
    from get_inputs_conjugado.

    Args:
        model: module exposing a ``gate`` submodule.
        loader: iterable yielding (raw_signal, exam_id, label) batches.
        optimizer: optimizer over ``model``'s parameters.
        criterion: loss applied to both heads.
        device: device the signals, labels and superlabels are moved to.

    Returns:
        List of per-batch total loss values as Python floats.
    """
    log = []
    model.train()
    for raw, _exam_id, label in tqdm(loader):
        # Forward `device`: get_inputs_conjugado defaults to "cuda" and would
        # otherwise ignore the caller's device (and fail on CPU-only machines).
        ecg, superlabel = get_inputs_conjugado(raw, label, device=device)
        label = label.to(device).float()

        # Call the modules rather than .forward() so registered hooks run.
        logits = model(ecg)
        # NOTE(review): if model's forward already evaluates the gate this is a
        # second gate pass, and a logits-based criterion assumes the gate returns
        # raw logits (not softmax probabilities) — confirm against ResnetMoE.
        gate_logits = model.gate(ecg)
        loss = criterion(logits, label) + criterion(gate_logits, superlabel)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        log.append(loss.item())
    return log

In [58]:
# Move the model first so the optimizer is built over the on-device parameters.
model = model.to(device)
# Multi-label objective on raw logits, optimized with AdamW at lr 3e-4.
optimizer, criterion = (
    torch.optim.AdamW(params=model.parameters(), lr=3e-4),
    nn.BCEWithLogitsLoss(),
)

In [59]:
# Per-epoch history: one (train_log, val_log) pair per epoch, kept after the loop.
log = []
for epoch in range(EPOCHS):
    train_dl, val_dl = dataloader.get_train_dataloader(), dataloader.get_val_dataloader()

    train_log = train_conjugado(model, train_dl, optimizer, criterion, device=device)
    # `eval` here is utils.eval (it shadows the builtin).
    # NOTE(review): it may score only the expert head, without the gate term used
    # in training, so train/val curves may not be directly comparable — confirm.
    val_log = eval(model, val_dl, criterion, device)
    log.append((train_log, val_log))

    plot_log(train_log, val_log, epoch = epoch)
    export(model, model_label, epoch)

  0%|          | 0/2 [00:00<?, ?it/s]

100%|██████████| 2/2 [00:00<00:00, 17.89it/s]
100%|██████████| 1/1 [00:00<00:00, 102.55it/s]


exporting partial model at epoch 0


100%|██████████| 2/2 [00:00<00:00, 19.48it/s]
100%|██████████| 1/1 [00:00<00:00, 104.44it/s]
  0%|          | 0/2 [00:00<?, ?it/s]

exporting partial model at epoch 1


100%|██████████| 2/2 [00:00<00:00, 18.66it/s]
100%|██████████| 1/1 [00:00<00:00, 103.06it/s]
  0%|          | 0/2 [00:00<?, ?it/s]

exporting partial model at epoch 2


100%|██████████| 2/2 [00:00<00:00, 17.25it/s]
100%|██████████| 1/1 [00:00<00:00, 63.29it/s]


exporting partial model at epoch 3


100%|██████████| 2/2 [00:00<00:00, 18.64it/s]
100%|██████████| 1/1 [00:00<00:00, 104.39it/s]


exporting partial model at epoch 4
