In [50]:
from collections import defaultdict
import sys
sys.path.append('..')

import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
from torch.utils.data import DataLoader

from data.synthetic_data import SyntheticDataset
from model import MLP, SharedBottom, OMOE, MMOE, CGC
from utils import loss

In [2]:
# --- Synthetic data ---------------------------------------------------------
num_data = 10000                         # samples generated per dataset
input_dim = 100                          # input feature dimension
task_corr_list = [0.25, 0.5, 0.75, 1]    # one dataset (and one set of models) per task correlation

test_ratio = 0.1
test_size = int(num_data * test_ratio)
train_size = num_data - test_size

# --- Architecture -----------------------------------------------------------
n_tasks = 2
n_experts = 8                 # experts in OMoE / MMoE
n_shared_experts = 2          # CGC experts shared by all tasks
n_task_experts = [3, 3]       # CGC task-specific experts, one entry per task
assert len(n_task_experts) == n_tasks, 'n_task_experts needs one entry per task'
expert_size = 16
tower_size = 8
# Weight count of the MoE expert layers plus the task towers; used as a
# common budget when sizing the baselines.
total_param_size = (input_dim * expert_size * n_experts) + (expert_size * tower_size * n_tasks)
# Chosen so shared_size * (input_dim + tower_size * n_tasks) ~= total_param_size,
# presumably matching the shared-bottom layer + towers to that budget.
shared_size = round(total_param_size / (input_dim + tower_size * n_tasks))

# --- Optimisation -----------------------------------------------------------
learning_rate = 0.001
n_epochs = 10
mb_size = 10
num_mb = num_data // mb_size  # minibatches in one pass over num_data samples

# --- Reproducibility --------------------------------------------------------
# torch's global RNG drives random_split, DataLoader shuffling and default
# weight init; numpy is seeded too in case SyntheticDataset draws from it
# (TODO confirm which RNG it uses).
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)

In [35]:
mtl_name_list = ['shared_bottom', 'omoe', 'mmoe', 'cgc']


def init_model(model_type):
    """Build a freshly initialised model and an Adam optimiser for it.

    Args:
        model_type: ``'single'`` for the single-task MLP baseline, or one of
            the multi-task architectures listed in ``mtl_name_list``
            (``'shared_bottom'``, ``'omoe'``, ``'mmoe'``, ``'cgc'``).

    Returns:
        ``(model, optimizer)`` where the optimiser is ``torch.optim.Adam`` over
        the model's parameters with ``lr=learning_rate``.

    Raises:
        ValueError: if ``model_type`` is not one of the names above.
    """
    # Builders are deferred (lambdas) so only the requested model is constructed.
    builders = {
        # Hidden width derived from total_param_size, presumably so the
        # baseline's first layer has a weight count comparable to the MoE models.
        'single': lambda: MLP(
            input_size=input_dim,
            hidden_size=round(total_param_size / input_dim),
        ),
        'shared_bottom': lambda: SharedBottom(
            input_size=input_dim,
            shared_size=shared_size,
            tower_size=tower_size,
            num_tasks=n_tasks,
        ),
        'omoe': lambda: OMOE(
            input_size=input_dim,
            expert_size=expert_size,
            tower_size=tower_size,
            num_tasks=n_tasks,
            num_experts=n_experts,
        ),
        'mmoe': lambda: MMOE(
            input_size=input_dim,
            expert_size=expert_size,
            tower_size=tower_size,
            num_tasks=n_tasks,
            num_experts=n_experts,
        ),
        'cgc': lambda: CGC(
            input_size=input_dim,
            expert_size=expert_size,
            tower_size=tower_size,
            num_tasks=n_tasks,
            num_shared_experts=n_shared_experts,
            num_task_experts=n_task_experts,
        ),
    }
    if model_type not in builders:
        # Fail loudly instead of hitting an unbound `model` further down.
        raise ValueError(
            f'Unknown model_type {model_type!r}; expected one of {list(builders)}'
        )

    model = builders[model_type]()
    optims = torch.optim.Adam(model.parameters(), lr=learning_rate)
    return model, optims


In [None]:
mse_loss = nn.MSELoss()
mtl_loss = loss.MultiTaskLoss()

model_dict   = dict()
history_dict = dict()
for task_corr in task_corr_list:
    print(f'Create dataset with task correlation {task_corr}..')
    dataset = SyntheticDataset(num_data, input_dim, task_corr=task_corr)
    train, test = torch.utils.data.random_split(
        dataset, [train_size, test_size]
    )

    print(f'Initialize model..')
    model_st1,  optims_st1  = init_model('single')
    model_st2,  optims_st2  = init_model('single')
    model_sb,   optims_sb   = init_model('shared_bottom')
    model_omoe, optims_omoe = init_model('omoe')
    model_mmoe, optims_mmoe = init_model('mmoe')
    model_cgc,  optims_cgc  = init_model('cgc')

    history = defaultdict(list)
    for it in range(n_epochs):
        print(f'Start {it}-th epoch..')

        cost_st1  = 0
        cost_st2  = 0
        cost_sb   = 0
        cost_omoe = 0
        cost_mmoe = 0
        cost_cgc  = 0

        dataloader = DataLoader(train.dataset, batch_size=mb_size, shuffle=True)
        for data in dataloader:
            X, y   = data[0], data[1]
            y_st1, y_st2 = y[0], y[1]

            yhat_st1 = model_st1(X)
            loss_st1 = mse_loss(yhat_st1, y_st1.view(-1,1))
            optims_st1.zero_grad()
            loss_st1.backward()
            optims_st1.step()

            yhat_st2 = model_st2(X)
            loss_st2 = mse_loss(yhat_st2, y_st2.view(-1,1))
            optims_st2.zero_grad()
            loss_st2.backward()
            optims_st2.step()

            yhat_sb = model_sb(X)
            loss_sb = mtl_loss(yhat_sb, y)
            optims_sb.zero_grad()
            loss_sb.backward()
            optims_sb.step()

            yhat_omoe = model_omoe(X)
            loss_omoe = mtl_loss(yhat_omoe, y)
            optims_omoe.zero_grad()
            loss_omoe.backward()
            optims_omoe.step()

            yhat_mmoe = model_mmoe(X)
            loss_mmoe = mtl_loss(yhat_mmoe, y)
            optims_mmoe.zero_grad()
            loss_mmoe.backward()
            optims_mmoe.step()

            yhat_cgc = model_cgc(X)
            loss_cgc = mtl_loss(yhat_cgc, y)
            optims_cgc.zero_grad()
            loss_cgc.backward()
            optims_cgc.step()

            cost_st1  += (loss_st1  / num_mb)
            cost_st2  += (loss_st2  / num_mb)
            cost_sb   += (loss_sb   / num_mb)
            cost_omoe += (loss_omoe / num_mb)
            cost_mmoe += (loss_mmoe / num_mb)
            cost_cgc  += (loss_cgc  / num_mb)
        
        history['single_task1'].append(cost_st1.item())
        history['single_task2'].append(cost_st2.item())
        history['shared_bottom'].append(cost_sb.item())
        history['omoe'].append(cost_omoe.item())
        history['mmoe'].append(cost_mmoe.item())
        history['cgc'].append(cost_cgc.item())
        
    history_dict[task_corr] = history
    model_dict[task_corr] = {
        'single_task1' : model_st1, 
        'single_task2' : model_st2, 
        'shared_bottom': model_sb, 
        'omoe': model_omoe, 
        'mmoe': model_mmoe, 
        'cgc' : model_cgc
    }

In [None]:
fig, axs = plt.subplots(1, len(task_corr_list), sharey=True, figsize=(15, 3))
plt.suptitle('MTL Model Loss', fontsize=15, y=1.15)
for idx, task_corr in enumerate(task_corr_list):
    history = history_dict[task_corr]
    model_names = list(history.keys())

    ax = axs[idx]
    for mtl in mtl_name_list:
        ax.plot(history[mtl], label=mtl)
    ax.set_title(f'Correlation {task_corr}')

for ax in axs.flat:
    ax.set(xlabel='epoch', ylabel='loss')
    ax.grid(True)

plt.legend(loc=(-2.3, 1.13), ncol=len(model_names))

In [None]:
X_test = test.dataset.X
y1_test = test.dataset.y1.view(-1,1)
y2_test = test.dataset.y2.view(-1,1)

fig, axs = plt.subplots(1, len(task_corr_list), sharey=True, figsize=(15, 3))
plt.suptitle('MTL gain on Synthetic Data', fontsize=15, y=1.15)
labels = ['task1', 'task2']
width = 0.35
x = np.arange(len(labels))

for idx, task_corr in enumerate(task_corr_list):
    models = model_dict[task_corr]
    mtl_gain1_list = []
    mtl_gain2_list = []

    with torch.no_grad():
        yhat_st1 = models['single_task1'](X_test)
        yhat_st2 = models['single_task2'](X_test)
        mse_st1 = mse_loss(yhat_st1, y1_test)
        mse_st2 = mse_loss(yhat_st2, y2_test)

        for mtl in mtl_name_list:
            yhat1, yhat2 = models[mtl](X_test)
            mse1 = mse_loss(yhat1, y1_test)
            mse2 = mse_loss(yhat2, y2_test)
            mtl_gain1_list.append(mse_st1 - mse1)
            mtl_gain2_list.append(mse_st2 - mse2)

    ax = axs[idx]
    for i, m in enumerate(mtl_name_list):
        rects = ax.bar(x - width/2, [mtl_gain1_list[i], mtl_gain2_list[i]], width, label=m)
    ax.set_title(f'Correlation {task_corr}')

for ax in axs.flat:
    ax.set(ylabel='MTL Gain')
    ax.grid(True)
    ax.set_xticks(x)
    ax.set_xticklabels(labels)

plt.legend(loc=(-2.3, 1.13), ncol=len(model_names))
