# setup

In [1]:
import os

# Step up one directory when the notebook is launched from a folder named 'notebooks'
# (presumably so the project-local imports below resolve from the repository root).
# os.path.basename honours the platform's path separator, unlike splitting on '/'.
if os.path.basename(os.getcwd()) == 'notebooks':
    os.chdir('..')

In [14]:
import torch
import torch.nn as nn
import torch.nn.functional as F

from tqdm import tqdm
from torch.utils.data import DataLoader

In [3]:
from dataloaders.code15text import CODE, CODEsplit
from models.baseline import ResnetBaseline
from models.join import JoinText
from utils import get_inputs

# init

In [9]:
# Use CUDA when PyTorch reports it as available, otherwise stay on the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

database = CODE()
# JoinText wraps a 6-class ResnetBaseline as its signal model; move it to `device` right away.
model = JoinText(signal_model = ResnetBaseline(n_classes = 6)).to(device)

# run: load one training batch and forward it through the model

In [5]:
# Training split of the database, served in shuffled mini-batches by 6 worker processes.
trn_ds = CODEsplit(database, database.trn_metadata)
trn_loader = DataLoader(
    dataset = trn_ds,
    batch_size = 42,
    shuffle = True,
    num_workers = 6,
)

In [6]:
# Grab a single batch for inspection. next(iter(...)) raises StopIteration immediately
# on an empty loader instead of silently leaving `batch` undefined for the cells below.
batch = next(iter(trn_loader))
batch['x'].shape, batch['y'].shape, batch['h'].shape

(torch.Size([42, 4096, 12]), torch.Size([42, 6]), torch.Size([42, 768]))

In [7]:
# Signal: keep the raw tensor around and let get_inputs build the model input on `device`.
raw = batch['x']
ecg = get_inputs(raw, device = device)

# Labels and text features: move to `device` and cast to float32 in one chain each.
label = batch['y'].to(device).float()
text_features = batch['h'].to(device).float()

In [10]:
# Forward pass for inspection only: evaluation mode, no gradient tracking.
model.eval()
with torch.no_grad():
    output = model(ecg, text_features)
    logits, signal_embedding, text_embedding = (
        output['logits'],
        output['signal_embedding'],
        output['text_embedding'],
    )

In [11]:
# Shapes of the three model outputs, in the order logits / signal / text.
tuple(tensor.shape for tensor in (logits, signal_embedding, text_embedding))

(torch.Size([42, 6]), torch.Size([42, 1280]), torch.Size([42, 1280]))

# l1: joint loss (BCE on the logits + L1 between signal and text embeddings)

In [15]:
# Classification term on its own: BCE computed directly on the raw logits.
F.binary_cross_entropy_with_logits(input = logits, target = label)

tensor(0.6996, device='cuda:0')

In [16]:
# Alignment term on its own: L1 loss between the two embeddings.
F.l1_loss(input = signal_embedding, target = text_embedding)

tensor(0.1424, device='cuda:0')

In [19]:
def join_l1(input, target, alpha = 5.56, beta = 1.23):
    """Weighted sum of a classification loss and an embedding-alignment loss.

    Args:
        input: mapping that provides 'logits' and 'signal_embedding'.
        target: mapping that provides 'label' and 'text_embedding'.
        alpha: weight applied to the BCE-with-logits term.
        beta: weight applied to the L1 term.

    Returns:
        alpha * BCE(input['logits'], target['label'])
        + beta * L1(input['signal_embedding'], target['text_embedding']).
    """
    classification_term = F.binary_cross_entropy_with_logits(
        input = input['logits'],
        target = target['label'],
    )
    alignment_term = F.l1_loss(
        input = input['signal_embedding'],
        target = target['text_embedding'],
    )
    return alpha * classification_term + beta * alignment_term

In [20]:
# Joint loss on the inspected batch, using the default alpha / beta weights.
join_l1(
    input = {'logits': logits, 'signal_embedding': signal_embedding},
    target = {'label': label, 'text_embedding': text_embedding},
)

tensor(4.0649, device='cuda:0')