# setup

In [1]:
import os

# When the kernel starts inside notebooks/, step up one directory, presumably so
# the project-local imports below (dataloaders, models, utils) resolve from the
# repository root. os.path.basename is separator-agnostic, unlike split('/'),
# which never matches on Windows paths.
if os.path.basename(os.getcwd()) == 'notebooks':
    os.chdir('..')

In [2]:
import torch
import torch.nn as nn

from tqdm import tqdm
from torch.utils.data import DataLoader

In [3]:
from dataloaders.code15text import CODE, CODEsplit
from models.baseline import ResnetBaseline
from utils import get_inputs

# init

In [4]:
# Seed torch's global RNG before any random work (model weight init in this
# cell, DataLoader shuffling later) so Restart & Run All is reproducible.
SEED = 42
# Width of the signal model's logits; matches batch['y'], which has 6 columns.
N_CLASSES = 6

torch.manual_seed(SEED)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

database = CODE()
signal_model = ResnetBaseline(n_classes = N_CLASSES)

# class: JoinText — projects signal and text features to a shared width

In [5]:
class JoinText(nn.Module):
    """Wrap a signal model and linearly project its embedding and an external
    text-feature tensor to a common output width.

    Args:
        signal_model: module called as ``signal_model(signal)``; its result is
            indexed with the keys ``'signal_embedding'`` and ``'logits'``.
        signal_in_chanels: input width of the signal projection ``W_s``.
        text_in_chanels: input width of the text projection ``W_t``.
        out_chanels: output width shared by ``W_s`` and ``W_t``.

    The ``*_chanels`` spelling is kept on purpose: it is the constructor's
    public keyword interface.
    """

    def __init__(
        self,
        signal_model,
        signal_in_chanels = 1280,
        text_in_chanels = 768,
        out_chanels = 1280,
    ):
        super().__init__()
        # Registered as a submodule, so .to()/.eval() on JoinText reach it too.
        self.signal_model = signal_model

        # Configured widths, kept as attributes for later inspection.
        self.signal_in_chanels = signal_in_chanels
        self.text_in_chanels = text_in_chanels
        self.out_chanels = out_chanels

        # W_s is built before W_t so a given RNG state yields the same weights.
        self.W_s = nn.Linear(signal_in_chanels, out_chanels)
        self.W_t = nn.Linear(text_in_chanels, out_chanels)

    def forward(self, signal, text_features):
        """Run the signal model, then project both modalities.

        Args:
            signal: passed unchanged to ``self.signal_model``.
            text_features: tensor whose last dimension is ``text_in_chanels``
                (presumably precomputed text embeddings — TODO confirm).

        Returns:
            dict with ``'logits'`` (the signal model's logits, untouched),
            ``'signal_embedding'`` (``W_s`` applied to the signal model's
            embedding) and ``'text_embedding'`` (``W_t`` applied to
            ``text_features``).
        """
        signal_out = self.signal_model(signal)
        return {
            'logits': signal_out['logits'],
            'signal_embedding': self.W_s(signal_out['signal_embedding']),
            'text_embedding': self.W_t(text_features),
        }

# draft: smoke-test JoinText on a single training batch

In [6]:
# Dataset built from database.trn_metadata, served in shuffled batches of 42
# by 6 DataLoader worker processes.
trn_ds = CODEsplit(database, database.trn_metadata)
trn_loader = DataLoader(
    trn_ds,
    batch_size = 42,
    num_workers = 6,
    shuffle = True,
)

In [7]:
# Pull a single batch for the smoke test. next(iter(...)) is the idiomatic form
# of `for batch in loader: break`; on an empty loader it raises StopIteration
# instead of leaving `batch` undefined or silently stale from an earlier run.
batch = next(iter(trn_loader))
batch['x'].shape, batch['y'].shape, batch['h'].shape

(torch.Size([42, 4096, 12]), torch.Size([42, 6]), torch.Size([42, 768]))

In [11]:
# Prepare one batch on `device`: the raw signal goes through the project helper
# get_inputs, while labels and text features are moved and cast to float32.
raw = batch['x']
ecg = get_inputs(raw, device = device)

label = batch['y'].to(device).float()
text_features = batch['h'].to(device).float()

In [9]:
# Wrap the signal model from the init cell; nn.Module.to returns the module
# itself and also moves the wrapped signal_model (a registered submodule).
model = JoinText(signal_model = signal_model).to(device)

In [12]:
# Inference-only forward pass: eval() switches model (and its signal_model
# submodule) to eval mode, and no_grad() disables autograd tracking.
model.eval()
with torch.no_grad():
    output = model(ecg, text_features)
    logits, signal_embedding, text_embedding = (
        output[key] for key in ('logits', 'signal_embedding', 'text_embedding')
    )

In [13]:
# Fail loudly on Run All if the joint-embedding contract breaks: both projections
# emit model.out_chanels features and every output shares the batch dimension.
assert signal_embedding.shape == (logits.shape[0], model.out_chanels), signal_embedding.shape
assert text_embedding.shape == signal_embedding.shape, (text_embedding.shape, signal_embedding.shape)
logits.shape, signal_embedding.shape, text_embedding.shape

(torch.Size([42, 6]), torch.Size([42, 1280]), torch.Size([42, 1280]))

In [9]:
# NOTE(review): dead cell. JoinText stores signal_model as a submodule, so
# `model.to(device)` earlier in this notebook already moves it. Delete before sharing.
# signal_model = signal_model.to(device)

In [10]:
# NOTE(review): dead cell. It is a standalone signal_model forward pass, but the
# JoinText forward pass earlier in this notebook already calls signal_model and
# returns its 'logits' unchanged. Delete before sharing.
# signal_model.eval()
# with torch.no_grad():
#     output = signal_model(ecg)
#     logits = output['logits']
#     signal_embedding = output['signal_embedding']

In [11]:
# NOTE(review): this cell contains only comments, so it produces no output. The
# output shown below is stale, left from before it was commented out, and will
# not reappear on Restart & Run All. Delete the cell and its output.
# logits.shape, signal_embedding.shape

(torch.Size([42, 6]), torch.Size([42, 1280]))