# setup

In [1]:
import torch
import torch.nn as nn
import os

In [2]:
# When the kernel starts inside the notebooks/ directory, step up one level
# (presumably so the project-local imports that follow resolve from the
# repository root -- confirm against the project layout).
# os.path.basename honours the platform's path separator, whereas splitting
# on '/' never matches on Windows-style paths.
if os.path.basename(os.getcwd()) == 'notebooks':
    os.chdir('..')

# from configs.baseline import LoadDataConfig
from configs.fake import LoadDataConfig
from configs.moe import MoE_cnn_args
from data.load_data import LoadData
# from models.moe import ResnetMoE
from utils import train, eval, plot_log, export

# init

In [3]:
# Label for the model variant handled in this notebook.
# NOTE(review): not referenced by any of the visible cells -- confirm where it is consumed.
model_label = "moe"

In [4]:
# Instantiate the two configuration objects. Their instance attributes are
# unpacked as keyword arguments further down: loader_config feeds LoadData,
# moe_config feeds ResnetMoE.
# LoadDataConfig currently comes from configs.fake (the configs.baseline
# import is commented out in the setup cell).
loader_config = LoadDataConfig()
moe_config = MoE_cnn_args()

In [5]:
# Build the data-loading wrapper from the attributes stored on the config.
dataloader = LoadData(**vars(loader_config))
# model = ResnetMoE(**moe_config.__dict__)
# model = torch.load('output/pretrained_moe.pt')

# Prefer the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# NOTE(review): EPOCHS is not used in the visible cells; the reason for the
# extra "+ 1" is not shown here.
EPOCHS = 5 + 1



In [6]:
# from configs.baseline import Downstream_cnn_args
# from models.baseline import ResnetBaseline

# resnet_config = Downstream_cnn_args()
# expert = ResnetBaseline(**resnet_config.__dict__)
# gate = torch.load('output/gate.pt')

In [7]:
# from collections import OrderedDict

# key_transformation = []
# for key in expert.state_dict().keys():
#     key_transformation.append(key)

In [8]:
# backbone = nn.Sequential(*list(gate.children())[:-1])

# state_dict = backbone.state_dict()
# new_state_dict = OrderedDict()

# for i, (key, value) in enumerate(state_dict.items()):
#     new_key = key_transformation[i]
#     new_state_dict[new_key] = value

In [9]:
# log = expert.load_state_dict(new_state_dict, strict = False)
# assert log.missing_keys == ['linear.weight', 'linear.bias']

In [10]:
from models.baseline import ResnetBaseline
from collections import OrderedDict

class ResnetMoE(nn.Module):
    """Mixture-of-experts classifier: a pretrained gate weighting ResnetBaseline experts.

    The gate is loaded from disk. Every expert is a fresh ``ResnetBaseline``
    initialised from the gate's layers (all of the gate's children except the
    last one); only the expert's ``linear.weight`` / ``linear.bias`` are left
    at their default initialisation.

    Args:
        gate_path: Path to a whole module saved with ``torch.save``.
        resnet_config: Object whose ``__dict__`` holds the keyword arguments
            for ``ResnetBaseline``; it must contain ``n_classes``.
        n_experts: Number of experts to build. The gate's output is used as
            one weight per expert, so its last dimension is expected to match
            this value (assumption -- confirm against the saved gate).
    """

    def __init__(self, gate_path, resnet_config, n_experts):
        super().__init__()

        # SECURITY: torch.load unpickles the file, which can execute arbitrary
        # code -- only load checkpoints from a trusted source.
        # map_location='cpu' lets a gate saved on a GPU be restored on a
        # CPU-only machine; callers move the full model with .to(device).
        # NOTE(review): on PyTorch >= 2.6 torch.load defaults to
        # weights_only=True, which rejects pickled full modules; pass
        # weights_only=False there if this call starts failing.
        self.gate = torch.load(gate_path, map_location='cpu')
        backbone = self.generate_backbone(resnet_config)

        self.experts = nn.ModuleList()
        for _ in range(n_experts):
            expert = ResnetBaseline(**resnet_config.__dict__)
            # strict=False because the expert's final linear layer has no
            # counterpart in the re-keyed gate weights.
            log = expert.load_state_dict(backbone, strict = False)
            assert log.missing_keys == ['linear.weight', 'linear.bias'], (
                f'unexpected missing keys when loading expert: {log.missing_keys}'
            )
            self.experts.append(expert)
        self.num_classes = resnet_config.__dict__['n_classes']

    def forward(self, x):
        """Return the gate-weighted sum of the experts' outputs.

        The gate output is squashed with a sigmoid (independent weights per
        expert, not a softmax) and multiplied with the stacked expert outputs.
        Assumes the gate returns ``(batch, n_experts)`` and each expert
        returns ``(batch, n_classes)`` -- confirm against ResnetBaseline.
        """
        # Call the modules rather than .forward() so registered forward
        # hooks are honoured.
        g = torch.sigmoid(self.gate(x))
        logits = [expert(x) for expert in self.experts]

        # (batch, n_experts) -> (batch, n_classes, n_experts)
        g = g.unsqueeze(1).expand(-1, self.num_classes, -1)
        # list of per-expert outputs -> one tensor with experts on dim 2
        logits = torch.stack(logits, dim = 2)

        # Weighted sum over the expert dimension.
        return torch.sum(g * logits, dim = 2)

    def generate_backbone(self, resnet_config):
        """Re-key the gate's weights with ResnetBaseline's state-dict names.

        All children of the gate except the last are wrapped in an
        ``nn.Sequential``; its state-dict entries are renamed *by position*
        to the keys of a freshly built ``ResnetBaseline``. This relies on
        both state dicts listing their tensors in the same order.

        Returns:
            OrderedDict mapping ResnetBaseline key names to the gate's tensors.

        Raises:
            IndexError: if the truncated gate has more state-dict entries
                than ``ResnetBaseline``.
        """
        target_keys = list(ResnetBaseline(**resnet_config.__dict__).state_dict().keys())

        backbone = nn.Sequential(*list(self.gate.children())[:-1])

        new_state_dict = OrderedDict()
        for i, (_, value) in enumerate(backbone.state_dict().items()):
            new_state_dict[target_keys[i]] = value

        return new_state_dict

# draft

In [14]:
from utils import get_inputs

In [11]:
# Build the mixture-of-experts model from the attributes stored on moe_config.
model = ResnetMoE(**vars(moe_config))

In [12]:
# Move the model to the device selected in the init section.
model = model.to(device)

In [13]:
# Obtain the training loader; it is iterated batch-wise in the next cell.
train_dl = dataloader.get_train_dataloader()

In [15]:
# Pull a single batch to smoke-test the forward pass.
# next(iter(...)) raises StopIteration on an empty loader, instead of silently
# leaving `ecg` undefined (or stale from an earlier run) as a for/break would.
batch = next(iter(train_dl))
raw, exam_id, label = batch
ecg = get_inputs(raw).to(device)

In [16]:
# Call the module itself rather than .forward() so registered hooks run.
logits = model(ecg)

In [18]:
# Sanity check: ResnetMoE.forward sums over the expert dimension, so the
# result should be (batch, n_classes).
logits.shape

torch.Size([2, 6])