# setup

In [1]:
import torch
import torch.nn as nn
import os

from tqdm import tqdm

In [2]:
# If the kernel was started inside a `notebooks` folder, step up one level —
# presumably so the project-local imports and relative paths below resolve
# from the repository root (confirm against the project layout).
# basename() honours the platform's path separator, unlike splitting on '/'.
if os.path.basename(os.getcwd()) == 'notebooks':
    os.chdir('..')

# from configs.baseline import LoadDataConfig, Downstream_cnn_args
from configs.fake import LoadDataConfig, Downstream_cnn_args
from data.load_data import LoadData
from models.baseline import ResnetBaseline
from utils import get_inputs

# init

In [3]:
# Label for this model variant. It is not referenced again in the visible part
# of the notebook — presumably meant for naming outputs; confirm before removing.
model_label = 'moe'

In [4]:
# Configuration objects from `configs.fake`; their instance attributes are
# unpacked as keyword arguments for the data loader and the ResNet experts.
loader_config = LoadDataConfig()
resnet_config = Downstream_cnn_args()

In [5]:
# Build the project data loader from every attribute stored on the config object.
dataloader = LoadData(**vars(loader_config))

# Prefer the GPU when CUDA is available; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')



In [6]:
# Number of experts combined by the gate, and the file the gate is loaded from.
n_experts = 3
gate_path = 'output/{}.pt'.format('gate')

# draft

In [7]:
# Load the saved gate directly onto `device`: it then sits on the same device
# as the inputs moved there below, and loading also works without CUDA.
# NOTE: torch.load unpickles arbitrary objects — only load trusted files.
gate = torch.load(gate_path, map_location=device)
# One newly constructed expert per gate output, all built from the same config.
experts = [ResnetBaseline(**resnet_config.__dict__) for _ in range(n_experts)]

In [8]:
# Width of each expert's output, read from the ResNet config's instance attributes.
num_classes = vars(resnet_config)['n_classes']

In [9]:
# Training dataloader; the cell below pulls a single batch from it.
train_dl = dataloader.get_train_dataloader()

In [10]:
# Fetch only the first batch for the experiments below. next(iter(...)) fails
# immediately on an empty loader instead of silently leaving the names undefined.
with torch.no_grad():
    raw, exam_id, label = next(iter(train_dl))
    ecg = get_inputs(raw).to(device)
    label = label.to(device).float()

  0%|          | 0/2 [00:00<?, ?it/s]


In [11]:
# Invoke the module as a callable so registered nn.Module hooks run;
# calling .forward() directly bypasses them.
g = gate(ecg)

In [12]:
# Inspect the gate output: one row per sample, one column per expert.
(g.shape, g)

(torch.Size([2, 3]),
 tensor([[-0.7242, -0.2769,  0.3864],
         [-0.7259, -0.1610,  0.1848]], device='cuda:0', grad_fn=<AddmmBackward>))

In [13]:
# Move each expert to `device` and invoke it as a callable so registered
# nn.Module hooks run; calling .forward() directly bypasses them.
logits = [expert.to(device)(ecg) for expert in experts]

In [14]:
# Inspect the per-expert outputs: a list with one logits tensor per expert.
(logits[0].shape, logits)

(torch.Size([2, 6]),
 [tensor([[ 0.4104, -0.3844, -0.1688,  0.2827,  0.5176,  0.3112],
          [ 0.3104, -0.1623,  0.4761,  0.5980, -0.0701,  0.2303]],
         device='cuda:0', grad_fn=<AddmmBackward>),
  tensor([[ 0.2689,  0.5607,  0.3637, -0.4390,  0.4115,  0.2966],
          [ 0.6865,  0.1736,  0.7279, -0.1427, -0.0597, -0.5079]],
         device='cuda:0', grad_fn=<AddmmBackward>),
  tensor([[ 0.0603,  0.0088, -0.6210, -0.0704,  0.0960,  0.7776],
          [ 0.2196,  0.9104, -0.3917,  1.0485, -0.0217,  0.3171]],
         device='cuda:0', grad_fn=<AddmmBackward>)])

In [15]:
import numpy as np

In [16]:
# Drop the autograd graph and bring everything back to the CPU, staying in torch.
# (NumPy route, for reference: append .numpy() to each expression below.)
g = g.detach().cpu()
logits = [expert_out.detach().cpu() for expert_out in logits]

In [17]:
# Insert a singleton axis at position 1 (NumPy: np.expand_dims(g, axis=1)).
g_expanded = g[:, None]
(g_expanded.shape, g_expanded)

(torch.Size([2, 1, 3]),
 tensor([[[-0.7242, -0.2769,  0.3864]],
 
         [[-0.7259, -0.1610,  0.1848]]]))

In [18]:
# Repeat the gate scores along the class axis
# (NumPy: np.tile(g_expanded, (1, num_classes, 1))).
g_tiled = g_expanded.expand(g_expanded.size(0), num_classes, g_expanded.size(2))
(g_tiled.shape, g_tiled)

(torch.Size([2, 6, 3]),
 tensor([[[-0.7242, -0.2769,  0.3864],
          [-0.7242, -0.2769,  0.3864],
          [-0.7242, -0.2769,  0.3864],
          [-0.7242, -0.2769,  0.3864],
          [-0.7242, -0.2769,  0.3864],
          [-0.7242, -0.2769,  0.3864]],
 
         [[-0.7259, -0.1610,  0.1848],
          [-0.7259, -0.1610,  0.1848],
          [-0.7259, -0.1610,  0.1848],
          [-0.7259, -0.1610,  0.1848],
          [-0.7259, -0.1610,  0.1848],
          [-0.7259, -0.1610,  0.1848]]]))

In [19]:
# Put the expert axis last so it lines up with the gate's last axis
# (NumPy: np.transpose(logits, axes=(1, 2, 0))).
logits_transposed = torch.cat(
    [expert_out.unsqueeze(2) for expert_out in logits],
    dim=2,
)
(logits_transposed.shape, logits_transposed)

(torch.Size([2, 6, 3]),
 tensor([[[ 0.4104,  0.2689,  0.0603],
          [-0.3844,  0.5607,  0.0088],
          [-0.1688,  0.3637, -0.6210],
          [ 0.2827, -0.4390, -0.0704],
          [ 0.5176,  0.4115,  0.0960],
          [ 0.3112,  0.2966,  0.7776]],
 
         [[ 0.3104,  0.6865,  0.2196],
          [-0.1623,  0.1736,  0.9104],
          [ 0.4761,  0.7279, -0.3917],
          [ 0.5980, -0.1427,  1.0485],
          [-0.0701, -0.0597, -0.0217],
          [ 0.2303, -0.5079,  0.3171]]]))

In [20]:
# Gate-weighted sum over the expert axis
# (NumPy: np.sum(g_tiled * logits_transposed, axis=2)).
weighted_logits = g_tiled * logits_transposed
yhat = weighted_logits.sum(dim=2)
(yhat.shape, yhat)

(torch.Size([2, 6]),
 tensor([[-0.3484,  0.1266, -0.2184, -0.1104, -0.4517, -0.0070],
         [-0.2952,  0.2581, -0.5352, -0.2173,  0.0565, -0.0267]]))

In [21]:
# Explicit-loop reference used to validate the vectorised result.
# The NumPy copies get their own names so `g` / `logits` stay tensors and
# this cell can be re-run without error.
g_np = g.numpy()
logits_np = [logit.numpy() for logit in logits]

yhat_for = np.zeros(shape = logits_np[0].shape)
# Loop over the rows actually present in this batch rather than the configured
# batch size, which a smaller or larger batch would not match.
for i in range(yhat_for.shape[0]):
    for j in range(n_experts):
        yhat_for[i, :] += g_np[i, j] * logits_np[j][i, :]
yhat_for.shape, yhat_for

((2, 6),
 array([[-0.34836203,  0.12657817, -0.21842196, -0.11037884, -0.45170977,
         -0.00698128],
        [-0.29522881,  0.25810405, -0.53520443, -0.21725466,  0.05647964,
         -0.02674845]]))

In [22]:
# Same tolerances as np.isclose's defaults; on a mismatch assert_allclose
# reports which entries differ and by how much.
np.testing.assert_allclose(
    yhat.detach().cpu().numpy(), yhat_for, rtol=1e-5, atol=1e-8, equal_nan=False
)

# moe

In [23]:
class ResnetMoE(nn.Module):
    """Mixture of ResNet experts whose outputs are combined by a loaded gate.

    For every sample and class, ``forward`` returns the sum over experts of
    ``gate output * expert logit``. The gate output is used as-is; this module
    applies no normalisation to it.
    """

    def __init__(self, gate_path, resnet_config, n_experts, map_location='cpu'):
        """Load the gate and build the experts.

        Args:
            gate_path: Path of the serialized gate module read with ``torch.load``.
            resnet_config: Object whose instance attributes are passed as keyword
                arguments to ``ResnetBaseline``; must include ``n_classes``.
            n_experts: Number of ``ResnetBaseline`` experts to construct.
            map_location: Forwarded to ``torch.load``. The default ``'cpu'`` lets
                the file load on machines without CUDA; move the whole model
                afterwards with ``.to(device)``.
        """
        super().__init__()

        expert_kwargs = vars(resnet_config)

        # NOTE: torch.load unpickles arbitrary objects — only load trusted files.
        self.gate = torch.load(gate_path, map_location=map_location)
        self.experts = nn.ModuleList(
            [ResnetBaseline(**expert_kwargs) for _ in range(n_experts)]
        )
        self.num_classes = expert_kwargs['n_classes']

    def forward(self, x):
        """Return the gate-weighted sum of the expert logits.

        Args:
            x: Input batch, passed unchanged to the gate and to every expert.

        Returns:
            Tensor of shape (batch, n_classes), given a gate output of shape
            (batch, n_experts) and expert outputs of shape (batch, n_classes).
        """
        # Call sub-modules as callables (not .forward) so registered hooks run.
        gate_weights = self.gate(x)
        # Stack experts on a trailing axis: (batch, n_classes, n_experts).
        expert_logits = torch.stack([expert(x) for expert in self.experts], dim=2)

        # (batch, 1, n_experts) broadcasts over the class axis, so no explicit
        # expand to num_classes is needed.
        return (gate_weights.unsqueeze(1) * expert_logits).sum(dim=2)

In [24]:
# Build the mixture model and place all of its sub-modules on `device` in one step.
model = ResnetMoE(gate_path, resnet_config, n_experts).to(device)

In [25]:
# Check that the model, its gate and its first expert all hold CUDA parameters.
tuple(
    next(module.parameters()).is_cuda
    for module in (model, model.gate, model.experts[0])
)

(True, True, True)

In [26]:
# Inference pass: eval mode, no autograd graph. The model is invoked as a
# callable so registered nn.Module hooks run (model.forward would bypass them).
model.eval()
with torch.no_grad():
    yhat = model(ecg)

In [27]:
# Display the mixture output computed by the model above.
yhat

tensor([[-0.0043,  0.0012,  0.0036, -0.0449, -0.0240, -0.0502],
        [ 0.0017,  0.0003, -0.0101, -0.0131,  0.0006, -0.0136]],
       device='cuda:0')

In [28]:
# Recompute the gate and expert outputs separately for the loop-based check.
# These values are only compared, never back-propagated, so no autograd graph
# is built; modules are invoked as callables so registered hooks run.
with torch.no_grad():
    g = model.gate(ecg)
    logits = [expert(ecg) for expert in model.experts]

In [29]:
# Explicit-loop reference for the model's output.
# The NumPy copies get their own names so `g` / `logits` stay tensors and
# this cell can be re-run without error.
g_np = g.cpu().detach().numpy()
logits_np = [logit.cpu().detach().numpy() for logit in logits]

yhat_for = np.zeros(shape = logits_np[0].shape)
# Loop over the rows actually present in this batch rather than the configured
# batch size, which a smaller or larger batch would not match.
for i in range(yhat_for.shape[0]):
    for j in range(n_experts):
        yhat_for[i, :] += g_np[i, j] * logits_np[j][i, :]
yhat_for.shape, yhat_for

((2, 6),
 array([[-0.00433787,  0.0012323 ,  0.00359915, -0.04489214, -0.02397117,
         -0.05022467],
        [ 0.00168308,  0.000309  , -0.01007694, -0.01313952,  0.00055843,
         -0.01358158]]))

In [30]:
# Same tolerances as np.isclose's defaults; on a mismatch assert_allclose
# reports which entries differ and by how much.
np.testing.assert_allclose(
    yhat.detach().cpu().numpy(), yhat_for, rtol=1e-5, atol=1e-8, equal_nan=False
)