In [1]:
from collections import defaultdict
import sys
sys.path.append('..')

import matplotlib.pyplot as plt
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchsummary import summary

from data.synthetic_data import SyntheticDataset
from model import MLP, SharedBottom, OMOE, MMOE, CGC
from utils import loss

In [2]:
num_data = 10000
input_dim = 100
task_corr = 0.5

dataset = SyntheticDataset(num_data, input_dim, task_corr=task_corr)

In [3]:
test_ratio  = 0.1

test_size  = int(len(dataset) * test_ratio)
train_size = len(dataset) - test_size
train, test = torch.utils.data.random_split(
    dataset, [train_size, test_size]
)

In [4]:
n_tasks = 2
n_experts = 8
n_shared_experts = 2
n_task_experts = [3,3]
expert_size = 16
tower_size = 8
total_param_size = (input_dim * expert_size * n_experts) + (expert_size * tower_size * n_tasks)
shared_size = round(total_param_size / (input_dim + tower_size * n_tasks))

model_st1 = MLP(
    input_size=input_dim,
    hidden_size=round(total_param_size/input_dim)
)
model_st2 = MLP(
    input_size=input_dim,
    hidden_size=round(total_param_size/input_dim)
)
model_sb = SharedBottom(
    input_size=input_dim,
    shared_size=shared_size,
    tower_size=tower_size,
    num_tasks=n_tasks
)
model_omoe = OMOE(
    input_size=input_dim,
    expert_size=expert_size,
    tower_size=tower_size,
    num_tasks=n_tasks,
    num_experts=n_experts
)
model_mmoe = MMOE(
    input_size=input_dim,
    expert_size=expert_size,
    tower_size=tower_size,
    num_tasks=n_tasks,
    num_experts=n_experts
)
model_cgc = CGC(
    input_size=input_dim, 
    expert_size=expert_size, 
    tower_size=tower_size, 
    num_tasks=2, 
    num_shared_experts=n_shared_experts, 
    num_task_experts=n_task_experts
)

lr_list = [0.0001, 0.001, 0.01]
optims_st1  = [torch.optim.Adam(model_st1.parameters(),  lr=lr) for lr in lr_list]
optims_st2  = [torch.optim.Adam(model_st2.parameters(),  lr=lr) for lr in lr_list]
optims_sb   = [torch.optim.Adam(model_sb.parameters(),   lr=lr) for lr in lr_list]
optims_omoe = [torch.optim.Adam(model_omoe.parameters(), lr=lr) for lr in lr_list]
optims_mmoe = [torch.optim.Adam(model_mmoe.parameters(), lr=lr) for lr in lr_list]
optims_cgc  = [torch.optim.Adam(model_cgc.parameters(), lr=lr) for lr in lr_list]

mse_loss = nn.MSELoss()
mtl_loss = loss.MultiTaskLoss()

In [5]:
n_epochs = 10
mb_size = 10
num_mb  = num_data // mb_size

lr_idx = 1
for it in range(n_epochs):
    cost_st1, cost_st2, cost_sb, cost_omoe, cost_mmoe, cost_cgc = 0, 0, 0, 0, 0, 0

    dataloader = DataLoader(train.dataset, batch_size=mb_size, shuffle=True)
    for data in dataloader:
        X_mb, y_mb   = data[0], data[1]
        y1_mb, y2_mb = y_mb[0], y_mb[1]

        yhat_mb_st1 = model_st1(X_mb)
        loss_mb_st1 = mse_loss(yhat_mb_st1, y1_mb.view(-1,1))
        optims_st1[lr_idx].zero_grad()
        loss_mb_st1.backward()
        optims_st1[lr_idx].step()

        yhat_mb_st2 = model_st2(X_mb)
        loss_mb_st2 = mse_loss(yhat_mb_st2, y2_mb.view(-1,1))
        optims_st2[lr_idx].zero_grad()
        loss_mb_st2.backward()
        optims_st2[lr_idx].step()

        yhat_mb_sb = model_sb(X_mb)
        loss_mb_sb = mtl_loss(yhat_mb_sb, y_mb)
        optims_sb[lr_idx].zero_grad()
        loss_mb_sb.backward()
        optims_sb[lr_idx].step()

        yhat_mb_omoe = model_omoe(X_mb)
        loss_mb_omoe = mtl_loss(yhat_mb_omoe, y_mb)
        optims_omoe[lr_idx].zero_grad()
        loss_mb_omoe.backward()
        optims_omoe[lr_idx].step()

        yhat_mb_mmoe = model_mmoe(X_mb)
        loss_mb_mmoe = mtl_loss(yhat_mb_mmoe, y_mb)
        optims_mmoe[lr_idx].zero_grad()
        loss_mb_mmoe.backward()
        optims_mmoe[lr_idx].step()

        yhat_mb_cgc = model_cgc(X_mb)
        loss_mb_cgc = mtl_loss(yhat_mb_cgc, y_mb)
        optims_cgc[lr_idx].zero_grad()
        loss_mb_cgc.backward()
        optims_cgc[lr_idx].step()

        cost_st1 += (loss_mb_st1   / num_mb)
        cost_st2 += (loss_mb_st2   / num_mb)
        cost_sb  += (loss_mb_sb    / num_mb)
        cost_omoe += (loss_mb_omoe / num_mb)
        cost_mmoe += (loss_mb_mmoe / num_mb)
        cost_cgc += (loss_mb_cgc / num_mb)
    
    print(f'[{it}] st1 {cost_st1:.3f}; st2 {cost_st2:.3f}; sb: {cost_sb:.3f}; omoe: {cost_omoe:.3f}; mmoe: {cost_mmoe:.3f}; cgc: {cost_cgc:.3f}')

[0] st1 0.178; st2 0.171; sb: 0.389; omoe: 0.594; mmoe: 0.556; cgc: 0.498
[1] st1 0.145; st2 0.133; sb: 0.283; omoe: 0.486; mmoe: 0.442; cgc: 0.407
[2] st1 0.144; st2 0.132; sb: 0.278; omoe: 0.483; mmoe: 0.422; cgc: 0.385
[3] st1 0.143; st2 0.131; sb: 0.276; omoe: 0.459; mmoe: 0.395; cgc: 0.356
[4] st1 0.143; st2 0.131; sb: 0.275; omoe: 0.419; mmoe: 0.385; cgc: 0.345
[5] st1 0.143; st2 0.131; sb: 0.275; omoe: 0.401; mmoe: 0.381; cgc: 0.317
[6] st1 0.143; st2 0.131; sb: 0.275; omoe: 0.388; mmoe: 0.339; cgc: 0.291
[7] st1 0.143; st2 0.131; sb: 0.275; omoe: 0.374; mmoe: 0.312; cgc: 0.282
[8] st1 0.142; st2 0.131; sb: 0.274; omoe: 0.328; mmoe: 0.296; cgc: 0.278
[9] st1 0.142; st2 0.131; sb: 0.274; omoe: 0.297; mmoe: 0.288; cgc: 0.280


In [6]:
with torch.no_grad():
    yhat_st1 = model_st1(test.dataset.X)
    yhat_st2 = model_st2(test.dataset.X)
    yhat_sb  = model_sb(test.dataset.X)
    yhat_omoe = model_omoe(test.dataset.X)
    yhat_mmoe = model_mmoe(test.dataset.X)
    yhat_cgc = model_cgc(test.dataset.X)
    print(f'single task:   {mse_loss(yhat_st1, test.dataset.y1.view(-1,1)).item():.3f}, {mse_loss(yhat_st2, test.dataset.y2.view(-1,1)).item():.3f}')
    print(f'shared bottom: {mse_loss(yhat_sb[0], test.dataset.y1.view(-1,1)).item():.3f}, {mse_loss(yhat_sb[1], test.dataset.y2.view(-1,1)).item():.3f}')
    print(f'onegate moe:   {mse_loss(yhat_omoe[0], test.dataset.y1.view(-1,1)).item():.3f}, {mse_loss(yhat_omoe[1], test.dataset.y2.view(-1,1)).item():.3f}')
    print(f'multigate moe: {mse_loss(yhat_mmoe[0], test.dataset.y1.view(-1,1)).item():.3f}, {mse_loss(yhat_mmoe[1], test.dataset.y2.view(-1,1)).item():.3f}')
    print(f'cgc:           {mse_loss(yhat_cgc[0], test.dataset.y1.view(-1,1)).item():.3f}, {mse_loss(yhat_cgc[1], test.dataset.y2.view(-1,1)).item():.3f}')

single task:   0.143, 0.130
shared bottom: 0.143, 0.131
onegate moe:   0.153, 0.135
multigate moe: 0.146, 0.140
cgc:           0.147, 0.133
