# setup

In [1]:
import torch
import torch.nn as nn
import os

from tqdm import tqdm

In [2]:
# When the kernel is started inside a directory named 'notebooks', step up one
# level (presumably so the project-local imports below resolve from the
# repository root — confirm against the repo layout).
# os.path.basename is separator-agnostic, unlike splitting the path on '/'.
if os.path.basename(os.getcwd()) == 'notebooks':
    os.chdir('..')

# from configs.baseline import LoadDataConfig, Downstream_cnn_args
from configs.fake import LoadDataConfig, Downstream_cnn_args
from data.load_data import LoadData
from models.baseline import ResnetBaseline
from utils import get_inputs

# init

In [3]:
# Label string for this notebook's model; no other cell shown here reads it.
model_label = "moe"

In [4]:
# Instantiate both configuration objects with their default constructors:
# one for data loading, one for the ResNet experts.
loader_config, resnet_config = LoadDataConfig(), Downstream_cnn_args()

In [5]:
# Build the project data loader from the attributes stored on the config instance.
dataloader = LoadData(**vars(loader_config))

# Use the GPU when torch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')



In [6]:
# Location of the serialized gate network that torch.load reads further down.
gate_name = 'gate'
gate_path = 'output/{}.pt'.format(gate_name)

# Number of expert networks combined by the gate in the draft below.
n_experts = 3

# draft: manual mixture-of-experts combination (gate-weighted sum of expert logits)

In [7]:
# Load the serialized gate network directly onto the selected device.
# Without map_location, a checkpoint saved on a GPU cannot be restored on a
# CPU-only machine, and a CPU checkpoint would sit on a different device than
# the `ecg` batch that is fed to it later.
# NOTE(review): torch.load unpickles arbitrary objects — only load trusted files.
gate = torch.load(gate_path, map_location=device)

# One newly constructed ResNet expert per mixture component.
experts = [ResnetBaseline(**vars(resnet_config)) for _ in range(n_experts)]

In [8]:
# Number of output classes, read from the ResNet config's instance attributes.
num_classes = vars(resnet_config)['n_classes']

In [9]:
# Training-split loader; the next cell takes a batch from it.
train_dl = dataloader.get_train_dataloader()

In [10]:
# Pull a single batch to experiment with. Taking the first item of the
# iterator directly avoids wrapping a one-iteration loop in tqdm (which left a
# never-completed progress bar in the output) and fails loudly with
# StopIteration on an empty loader instead of leaving `ecg`/`label` undefined.
with torch.no_grad():
    batch = next(iter(train_dl))
    raw, exam_id, label = batch
    ecg = get_inputs(raw).to(device)
    label = label.to(device).float()

  0%|          | 0/2 [00:00<?, ?it/s]


In [11]:
# Gate scores for the batch, squashed element-wise into (0, 1).
# The module is called directly rather than through .forward() so that any
# registered hooks run, as the torch.nn.Module documentation recommends.
g = torch.sigmoid(gate(ecg))

In [12]:
# Display the gate output's shape together with its values.
g.shape, g

(torch.Size([2, 3]),
 tensor([[1.7541e-04, 5.5993e-05, 9.9973e-01],
         [1.8848e-03, 3.6716e-01, 6.2015e-01]], device='cuda:0',
        grad_fn=<SigmoidBackward>))

In [13]:
# Run every expert on the same batch. Each module is moved to `device` and
# then called directly (not via .forward()) so registered hooks are honoured.
logits = [expert.to(device)(ecg) for expert in experts]

In [14]:
# Display the shape of one expert's output followed by all expert outputs.
logits[0].shape, logits

(torch.Size([2, 6]),
 [tensor([[ 0.4856,  0.8151,  0.9983,  0.4227, -0.6620, -0.6500],
          [ 0.1341,  0.3821,  0.3847,  1.0778, -1.5524, -0.0611]],
         device='cuda:0', grad_fn=<AddmmBackward>),
  tensor([[-0.5103,  0.3201,  1.0194,  0.5512,  0.0159,  0.2435],
          [-0.2882,  0.4040, -0.0243, -0.0891,  0.4693, -0.3349]],
         device='cuda:0', grad_fn=<AddmmBackward>),
  tensor([[ 0.5113,  0.2070, -0.6829, -0.3501,  0.1526, -0.6137],
          [ 0.2547, -0.5813, -0.4381, -0.2162,  0.7637, -0.2244]],
         device='cuda:0', grad_fn=<AddmmBackward>)])

In [15]:
import numpy as np

In [16]:
# Drop the autograd graph and move everything to host memory so the following
# cells work on plain CPU tensors (the torch port of the earlier NumPy draft).
g = g.detach().cpu()
logits = [expert_out.detach().cpu() for expert_out in logits]

In [17]:
# Insert a singleton axis at position 1 (torch counterpart of np.expand_dims),
# turning the (2, 3) gate output shown above into (2, 1, 3).
g_expanded = torch.unsqueeze(g, 1)
g_expanded.shape, g_expanded

(torch.Size([2, 1, 3]),
 tensor([[[1.7541e-04, 5.5993e-05, 9.9973e-01]],
 
         [[1.8848e-03, 3.6716e-01, 6.2015e-01]]]))

In [18]:
# Repeat the gate weights along the new axis once per class (torch counterpart
# of np.tile); the first and last dimensions keep their existing sizes.
g_tiled = g_expanded.expand(g_expanded.size(0), num_classes, g_expanded.size(-1))
g_tiled.shape, g_tiled

(torch.Size([2, 6, 3]),
 tensor([[[1.7541e-04, 5.5993e-05, 9.9973e-01],
          [1.7541e-04, 5.5993e-05, 9.9973e-01],
          [1.7541e-04, 5.5993e-05, 9.9973e-01],
          [1.7541e-04, 5.5993e-05, 9.9973e-01],
          [1.7541e-04, 5.5993e-05, 9.9973e-01],
          [1.7541e-04, 5.5993e-05, 9.9973e-01]],
 
         [[1.8848e-03, 3.6716e-01, 6.2015e-01],
          [1.8848e-03, 3.6716e-01, 6.2015e-01],
          [1.8848e-03, 3.6716e-01, 6.2015e-01],
          [1.8848e-03, 3.6716e-01, 6.2015e-01],
          [1.8848e-03, 3.6716e-01, 6.2015e-01],
          [1.8848e-03, 3.6716e-01, 6.2015e-01]]]))

In [19]:
# Stack the per-expert outputs along a new trailing axis so the result lines
# up with g_tiled. For 2-D expert outputs the trailing axis is axis 2.
logits_transposed = torch.stack(logits, dim=-1)
logits_transposed.shape, logits_transposed

(torch.Size([2, 6, 3]),
 tensor([[[ 0.4856, -0.5103,  0.5113],
          [ 0.8151,  0.3201,  0.2070],
          [ 0.9983,  1.0194, -0.6829],
          [ 0.4227,  0.5512, -0.3501],
          [-0.6620,  0.0159,  0.1526],
          [-0.6500,  0.2435, -0.6137]],
 
         [[ 0.1341, -0.2882,  0.2547],
          [ 0.3821,  0.4040, -0.5813],
          [ 0.3847, -0.0243, -0.4381],
          [ 1.0778, -0.0891, -0.2162],
          [-1.5524,  0.4693,  0.7637],
          [-0.0611, -0.3349, -0.2244]]]))

In [20]:
# Gate-weighted sum over the expert axis (the last axis of both operands).
weighted_logits = g_tiled * logits_transposed
yhat = weighted_logits.sum(dim=-1)
yhat.shape, yhat

(torch.Size([2, 6]),
 tensor([[ 0.5112,  0.2071, -0.6824, -0.3499,  0.1524, -0.6136],
         [ 0.0524, -0.2114, -0.2799, -0.1647,  0.6430, -0.2623]]))

In [21]:
# Reference implementation: recompute the gate-weighted sum with explicit
# Python loops so the vectorised result above can be checked against it.
# Separate *_np names keep the cell re-runnable; rebinding `g`/`logits` to
# ndarrays made a second run fail because ndarray has no .numpy().
g_np = g.numpy()
logits_np = [logit.numpy() for logit in logits]

# Loop bounds come from the data itself rather than from the configured batch
# size / expert count, so a short batch cannot index out of range.
yhat_for = np.zeros(shape=logits_np[0].shape)
for i in range(yhat_for.shape[0]):
    for j in range(len(logits_np)):
        yhat_for[i, :] += g_np[i, j] * logits_np[j][i, :]
yhat_for.shape, yhat_for

((2, 6),
 array([[ 0.51124154,  0.2070664 , -0.68243882, -0.34988201,  0.15244258,
         -0.61364974],
        [ 0.05243565, -0.21141804, -0.27988988, -0.16473723,  0.64298703,
         -0.26226373]]))

In [22]:
# Vectorised and loop-based results must agree. assert_allclose is given the
# same tolerances np.isclose uses by default (and the same NaN handling), but
# reports the mismatching elements when the check fails.
np.testing.assert_allclose(yhat.numpy(), yhat_for, rtol=1e-5, atol=1e-8, equal_nan=False)

# moe: check the packaged ResnetMoE model against the manual computation

In [23]:
# Inactive draft of ResnetMoE, kept commented out; the cells below use the
# class imported from models.moe instead.
# NOTE(review): this draft feeds the raw gate output into the weighted sum,
# whereas the manual checks in this notebook apply torch.sigmoid to the gate
# output first — confirm which behaviour models/moe.py implements before
# reviving this code.
# class ResnetMoE(nn.Module):
#     def __init__(self, gate_path, resnet_config, n_experts):
#         super().__init__()

#         self.gate = torch.load(gate_path)
#         self.experts = nn.ModuleList()
#         for _ in range(n_experts):
#             self.experts.append(ResnetBaseline(**resnet_config.__dict__))
#         self.num_classes = resnet_config.__dict__['n_classes']


#     def forward(self, x):
#         g = self.gate.forward(x)
#         logits = [expert.forward(x) for expert in self.experts]

#         g = g.unsqueeze(1)
#         g = g.expand(-1, self.num_classes, -1)
#         logits = torch.stack(logits, dim = 2)
#         logits = torch.sum(g * logits, dim = 2)

#         return logits

In [24]:
from configs.moe import MoE_cnn_args
from models.moe import ResnetMoE

# Default mixture-of-experts configuration; its attributes become the
# keyword arguments of ResnetMoE in the next cell.
moe_config = MoE_cnn_args()

In [25]:
# Construct the mixture-of-experts model from the config's instance attributes
# and place it on the selected device in one step.
model = ResnetMoE(**vars(moe_config)).to(device)

In [26]:
# Report whether the first parameter of the whole model, of its gate, and of
# its first expert lives on a CUDA device.
tuple(next(module.parameters()).is_cuda for module in (model, model.gate, model.experts[0]))

(True, True, True)

In [27]:
# Inference pass of the full model on the sample batch: evaluation mode, no
# gradient tracking. The module is called directly rather than via .forward()
# so that any registered hooks run.
model.eval()
with torch.no_grad():
    yhat = model(ecg)

In [28]:
# Display the model output for the sample batch.
yhat

tensor([[ 0.0447, -0.0198,  0.0166,  0.0671, -0.0603,  0.0254],
        [ 0.0086, -0.0147,  0.0017,  0.0396, -0.0532,  0.0260]],
       device='cuda:0')

In [29]:
# Recompute the model's ingredients by hand: sigmoid-squashed gate scores and
# the raw output of every expert. Modules are called directly (not through
# .forward()) so hooks run, and gradient tracking is disabled because the
# results are only compared numerically afterwards.
with torch.no_grad():
    g = torch.sigmoid(model.gate(ecg))
    logits = [expert(ecg) for expert in model.experts]

In [30]:
# Loop-based reference for the model output: sum over experts of
# gate weight * expert output, accumulated row by row.
# Separate *_np names keep the cell re-runnable; rebinding `g`/`logits` to
# ndarrays made a second run fail because ndarray has no .cpu().
g_np = g.cpu().detach().numpy()
logits_np = [logit.cpu().detach().numpy() for logit in logits]

# Bounds are taken from the arrays themselves: the number of experts is the
# model's own (not the draft constant n_experts) and the row count is the
# actual batch length (not the configured batch size).
yhat_for = np.zeros(shape=logits_np[0].shape)
for i in range(yhat_for.shape[0]):
    for j in range(len(logits_np)):
        yhat_for[i, :] += g_np[i, j] * logits_np[j][i, :]
yhat_for.shape, yhat_for

((2, 6),
 array([[ 0.0447257 , -0.01983287,  0.01664077,  0.06708844, -0.06032221,
          0.02544975],
        [ 0.00858808, -0.01473155,  0.00167097,  0.03960469, -0.05318132,
          0.02595899]]))

In [31]:
# The model's forward pass must match the hand-computed mixture. Tolerances
# and NaN handling mirror np.isclose's defaults; assert_allclose additionally
# reports which elements differ when the check fails.
np.testing.assert_allclose(yhat.cpu().numpy(), yhat_for, rtol=1e-5, atol=1e-8, equal_nan=False)