In [1]:
import sys
import torch
import click
import json
import datetime
from timeit import default_timer as timer

import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
from torch.utils import data
import torchvision
import types

from tqdm import tqdm
from tensorboardX import SummaryWriter

# from models.gradient_scaler import MinNormElement
import losses
import datasets
import metrics
import model_selector
from min_norm_solvers import MinNormSolver, gradient_normalizers

# Number of epochs the training loop iterates over (`range(NUM_EPOCHS)`).
NUM_EPOCHS = 100

In [2]:
# Read the two JSON files this notebook is driven by.
with open('../configs.json') as configs_file:
    configs = json.load(configs_file)

with open('../mnist.json') as params_file:
    params = json.load(params_file)

# Build an experiment id of the form "key=value|key=value|...", leaving out
# every parameter whose name contains 'tasks' or 'scales'.
exp_identifier = '|'.join(
    '{}={}'.format(key, val)
    for (key, val) in params.items()
    if 'tasks' not in key and 'scales' not in key
)
params['exp_id'] = exp_identifier

In [3]:
# Replace the experiment id derived from the parameters with a fixed, short name.
params['exp_id']="test_mlt"

In [4]:
# Log to a run directory under runs/ named after the experiment id and the current time,
# so repeated executions of this cell do not share a log directory.
run_timestamp = datetime.datetime.now().strftime("%Y_%m_%d_%H_%M_%S")
writer = SummaryWriter(log_dir='runs/{}_{}'.format(params['exp_id'], run_timestamp))

In [5]:
# Training setup: data loaders, per-task losses and metrics, the model and its optimizer.
train_loader, train_dst, val_loader, val_dst = datasets.get_dataset(params, configs)

# Both are indexed by task name in the training loop (loss_fn[t], metric[t]).
loss_fn = losses.get_loss(params)
metric = metrics.get_metrics(params)

# `model` is iterated by key and indexed with 'rep' and with each task name further down.
model = model_selector.get_model(params)

# A single optimizer updates the parameters of every sub-model.
model_params = []
for m in model:
    model_params += model[m].parameters()

optimizer_name = params['optimizer']
if 'RMSprop' in optimizer_name:
    optimizer = torch.optim.RMSprop(model_params, lr=params['lr'])
elif 'Adam' in optimizer_name:
    optimizer = torch.optim.Adam(model_params, lr=params['lr'])
elif 'SGD' in optimizer_name:
    optimizer = torch.optim.SGD(model_params, lr=params['lr'], momentum=0.9)
else:
    # Fail here with a clear message instead of a NameError on `optimizer` later on.
    raise ValueError('Unsupported optimizer: {!r}'.format(optimizer_name))

# `tasks` are the tasks trained in this run; `all_tasks` lists every task of the dataset
# and is used to locate each task's target inside a batch.
tasks = params['tasks']
all_tasks = configs[params['dataset']]['all_tasks']
print('Starting training with parameters \n \t{} \n'.format(str(params)))

# The min-norm solver variant only matters when the algorithm is MGDA-based.
if 'mgda' in params['algorithm']:
    approximate_norm_solution = params['use_approximation']
    if approximate_norm_solution:
        print('Using approximate min-norm solver')
    else:
        print('Using full solver')
        
# Step decay of the learning rate: every LR_DECAY_EVERY epochs it is multiplied by LR_DECAY_FACTOR.
LR_DECAY_EVERY = 10
LR_DECAY_FACTOR = 0.85

n_iter = 0  # number of training batches processed so far; used as the logging step
loss_init = {}  # not read in this cell
for epoch in tqdm(range(NUM_EPOCHS)):
    start = timer()
    print('Epoch {} Started'.format(epoch))
    if (epoch+1) % LR_DECAY_EVERY == 0:
        for param_group in optimizer.param_groups:
            param_group['lr'] *= LR_DECAY_FACTOR
        print('Decayed the learning rate by a factor of {} at iteration {}'.format(LR_DECAY_FACTOR, n_iter))

    for m in model:
        model[m].train()

    for batch in train_loader:
        n_iter += 1
        # First member is always images
        images = batch[0].cuda()

        # Read the targets of the selected tasks; batch[i+1] is the target of all_tasks[i]
        labels = {}
        for i, t in enumerate(all_tasks):
            if t not in tasks:
                continue
            labels[t] = batch[i+1].cuda()

        # Scaling the loss functions based on the algorithm choice
        loss_data = {}
        grads = {}
        scale = {}
        mask = None
        masks = {}
        if 'mgda' in params['algorithm']:
            # Will use our MGDA_UB if approximate_norm_solution is True. Otherwise, will use MGDA

            if approximate_norm_solution:
                optimizer.zero_grad()
                # First compute representations (z). No graph is needed for this pass:
                # only gradients with respect to the representation itself are used below.
                with torch.no_grad():
                    rep, mask = model['rep'](images, mask)
                # As an approximate solution we only need gradients for input
                if isinstance(rep, list):
                    # This is a hack to handle psp-net
                    rep = rep[0]
                    rep_variable = [rep.detach().clone().requires_grad_(True)]
                    list_rep = True
                else:
                    rep_variable = rep.detach().clone().requires_grad_(True)
                    list_rep = False

                # Compute gradients of each loss function wrt z
                for t in tasks:
                    optimizer.zero_grad()
                    out_t, masks[t] = model[t](rep_variable, None)
                    loss = loss_fn[t](out_t, labels[t])
                    loss_data[t] = loss.item()
                    loss.backward()
                    rep_leaf = rep_variable[0] if list_rep else rep_variable
                    grads[t] = [rep_leaf.grad.detach().clone()]
                    # optimizer.zero_grad() does not cover this tensor, so reset it by hand
                    rep_leaf.grad.zero_()
            else:
                # This is MGDA
                for t in tasks:
                    # Compute gradients of each loss function wrt parameters
                    optimizer.zero_grad()
                    rep, mask = model['rep'](images, mask)
                    out_t, masks[t] = model[t](rep, None)
                    loss = loss_fn[t](out_t, labels[t])
                    loss_data[t] = loss.item()
                    loss.backward()
                    grads[t] = [param.grad.detach().clone()
                                for param in model['rep'].parameters()
                                if param.grad is not None]

            # Normalize all gradients, this is optional and not included in the paper.
            gn = gradient_normalizers(grads, loss_data, params['normalization_type'])
            for t in tasks:
                grads[t] = [grad / gn[t] for grad in grads[t]]

            # Frank-Wolfe iteration to compute scales.
            sol, min_norm = MinNormSolver.find_min_norm_element([grads[t] for t in tasks])
            for i, t in enumerate(tasks):
                scale[t] = float(sol[i])
        else:
            # Fixed per-task weights taken from the parameter file
            for t in tasks:
                masks[t] = None
                scale[t] = float(params['scales'][t])

        # Scaled back-propagation
        optimizer.zero_grad()
        rep, _ = model['rep'](images, mask)
        loss = None
        for t in tasks:
            out_t, _ = model[t](rep, masks[t])
            loss_t = loss_fn[t](out_t, labels[t])
            loss_data[t] = loss_t.item()
            weighted_loss_t = scale[t]*loss_t
            loss = weighted_loss_t if loss is None else loss + weighted_loss_t
        loss.backward()
        optimizer.step()

        writer.add_scalar('training_loss', loss.item(), n_iter)
        for t in tasks:
            writer.add_scalar('training_loss_{}'.format(t), loss_data[t], n_iter)

    for m in model:
        model[m].eval()

    tot_loss = {}
    tot_loss['all'] = 0.0
    met = {}  # not read in this cell
    for t in tasks:
        tot_loss[t] = 0.0
        met[t] = 0.0

    num_val_batches = 0
    # Validation needs no gradients; no_grad() replaces the removed `volatile=True` flag.
    with torch.no_grad():
        for batch_val in val_loader:
            val_images = batch_val[0].cuda()
            labels_val = {}

            for i, t in enumerate(all_tasks):
                if t not in tasks:
                    continue
                labels_val[t] = batch_val[i+1].cuda()

            val_rep, _ = model['rep'](val_images, None)
            for t in tasks:
                out_t_val, _ = model[t](val_rep, None)
                loss_t = loss_fn[t](out_t_val, labels_val[t])
                tot_loss['all'] += loss_t.item()
                tot_loss[t] += loss_t.item()
                metric[t].update(out_t_val, labels_val[t])
            num_val_batches+=1

    for t in tasks:
        writer.add_scalar('validation_loss_{}'.format(t), tot_loss[t]/num_val_batches, n_iter)
        metric_results = metric[t].get_result()
        for metric_key in metric_results:
            writer.add_scalar('metric_{}_{}'.format(metric_key, t), metric_results[metric_key], n_iter)
        metric[t].reset()
    # tot_loss['all'] is a sum of per-batch losses, so average it over batches like the per-task values
    writer.add_scalar('validation_loss', tot_loss['all']/num_val_batches, n_iter)

    if epoch % 3 == 0:
        # Checkpoint on epochs 0, 3, 6, ...; the stored epoch and the file name use epoch+1
        state = {'epoch': epoch+1,
                'model_rep': model['rep'].state_dict(),
                'optimizer_state' : optimizer.state_dict()}
        for t in tasks:
            key_name = 'model_{}'.format(t)
            state[key_name] = model[t].state_dict()

        torch.save(state, "saved_models/{}_{}_model.pkl".format(params['exp_id'], epoch+1))

    end = timer()
    print('Epoch ended in {}s'.format(end - start))

  0%|          | 0/100 [00:00<?, ?it/s]

Starting training with parameters 
 	{'optimizer': 'Adam', 'batch_size': 256, 'lr': 0.0005, 'dataset': 'mnist', 'tasks': ['L', 'R'], 'normalization_type': 'loss+', 'algorithm': 'mgda', 'use_approximation': False, 'scales': {'L': 0.5, 'R': 0.5}, 'exp_id': 'test_mlt'} 

Using full solver
Epoch 0 Started


  1%|          | 1/100 [00:16<26:24, 16.00s/it]

Epoch ended in 16.00476831011474s
Epoch 1 Started


  2%|▏         | 2/100 [00:31<26:05, 15.97s/it]

Epoch ended in 15.898206159938127s
Epoch 2 Started


  3%|▎         | 3/100 [00:47<25:37, 15.85s/it]

Epoch ended in 15.562254683114588s
Epoch 3 Started


  4%|▍         | 4/100 [01:03<25:20, 15.84s/it]

Epoch ended in 15.798544346820563s
Epoch 4 Started


  5%|▌         | 5/100 [01:19<25:02, 15.82s/it]

Epoch ended in 15.76646945020184s
Epoch 5 Started


  6%|▌         | 6/100 [01:34<24:44, 15.79s/it]

Epoch ended in 15.723986465018243s
Epoch 6 Started


  7%|▋         | 7/100 [01:50<24:29, 15.80s/it]

Epoch ended in 15.814092014916241s
Epoch 7 Started


  8%|▊         | 8/100 [02:06<24:22, 15.90s/it]

Epoch ended in 16.120038119144738s
Epoch 8 Started


  9%|▉         | 9/100 [02:22<24:16, 16.00s/it]

Epoch ended in 16.24445695709437s
Epoch 9 Started
Half the learning rate2115


 10%|█         | 10/100 [02:40<24:53, 16.59s/it]

Epoch ended in 17.967109113931656s
Epoch 10 Started


 11%|█         | 11/100 [02:57<24:41, 16.64s/it]

Epoch ended in 16.756754010915756s
Epoch 11 Started


 12%|█▏        | 12/100 [03:13<23:56, 16.33s/it]

Epoch ended in 15.59129299223423s
Epoch 12 Started


 13%|█▎        | 13/100 [03:28<23:23, 16.13s/it]

Epoch ended in 15.662181449122727s
Epoch 13 Started


 14%|█▍        | 14/100 [03:45<23:05, 16.12s/it]

Epoch ended in 16.078830441925675s
Epoch 14 Started


 15%|█▌        | 15/100 [03:59<22:16, 15.72s/it]

Epoch ended in 14.806297885254025s
Epoch 15 Started


 16%|█▌        | 16/100 [04:15<22:01, 15.73s/it]

Epoch ended in 15.729598675854504s
Epoch 16 Started


 17%|█▋        | 17/100 [04:31<21:56, 15.86s/it]

Epoch ended in 16.154383201617748s
Epoch 17 Started


 18%|█▊        | 18/100 [04:47<21:40, 15.86s/it]

Epoch ended in 15.878533512819558s
Epoch 18 Started


 19%|█▉        | 19/100 [05:03<21:24, 15.86s/it]

Epoch ended in 15.84704486373812s
Epoch 19 Started
Half the learning rate4465


 20%|██        | 20/100 [05:20<21:25, 16.07s/it]

Epoch ended in 16.543569652363658s
Epoch 20 Started


 21%|██        | 21/100 [05:36<21:07, 16.04s/it]

Epoch ended in 15.985974702984095s
Epoch 21 Started


 22%|██▏       | 22/100 [05:51<20:44, 15.96s/it]

Epoch ended in 15.768231051973999s
Epoch 22 Started


 23%|██▎       | 23/100 [06:07<20:26, 15.93s/it]

Epoch ended in 15.855333028826863s
Epoch 23 Started


 24%|██▍       | 24/100 [06:23<20:02, 15.83s/it]

Epoch ended in 15.5777299660258s
Epoch 24 Started


 25%|██▌       | 25/100 [06:38<19:42, 15.76s/it]

Epoch ended in 15.608155766967684s
Epoch 25 Started


 26%|██▌       | 26/100 [06:54<19:17, 15.65s/it]

Epoch ended in 15.370994047727436s
Epoch 26 Started


 27%|██▋       | 27/100 [07:09<19:01, 15.64s/it]

Epoch ended in 15.633977610617876s
Epoch 27 Started


 28%|██▊       | 28/100 [07:25<18:52, 15.73s/it]

Epoch ended in 15.941795939113945s
Epoch 28 Started


 29%|██▉       | 29/100 [07:41<18:36, 15.73s/it]

Epoch ended in 15.727257839869708s
Epoch 29 Started
Half the learning rate6815


 30%|███       | 30/100 [07:56<18:12, 15.61s/it]

Epoch ended in 15.31800963031128s
Epoch 30 Started


 31%|███       | 31/100 [08:12<17:51, 15.53s/it]

Epoch ended in 15.344443827867508s
Epoch 31 Started


 32%|███▏      | 32/100 [08:27<17:36, 15.53s/it]

Epoch ended in 15.530299884267151s
Epoch 32 Started


 33%|███▎      | 33/100 [08:43<17:22, 15.56s/it]

Epoch ended in 15.620857571717352s
Epoch 33 Started


 34%|███▍      | 34/100 [08:59<17:11, 15.63s/it]

Epoch ended in 15.78773481072858s
Epoch 34 Started


 35%|███▌      | 35/100 [09:14<16:53, 15.59s/it]

Epoch ended in 15.488583808299154s
Epoch 35 Started


 36%|███▌      | 36/100 [09:30<16:38, 15.61s/it]

Epoch ended in 15.643610467668623s
Epoch 36 Started


 37%|███▋      | 37/100 [09:46<16:25, 15.65s/it]

Epoch ended in 15.738160952925682s
Epoch 37 Started


 38%|███▊      | 38/100 [10:01<16:12, 15.69s/it]

Epoch ended in 15.798099190928042s
Epoch 38 Started


 39%|███▉      | 39/100 [10:17<15:51, 15.59s/it]

Epoch ended in 15.363095710985363s
Epoch 39 Started
Half the learning rate9165


 40%|████      | 40/100 [10:33<15:44, 15.75s/it]

Epoch ended in 16.101273177191615s
Epoch 40 Started


 41%|████      | 41/100 [10:48<15:20, 15.59s/it]

Epoch ended in 15.233121388591826s
Epoch 41 Started


 42%|████▏     | 42/100 [11:04<15:06, 15.64s/it]

Epoch ended in 15.732744091190398s
Epoch 42 Started


 43%|████▎     | 43/100 [11:19<14:52, 15.66s/it]

Epoch ended in 15.71235195081681s
Epoch 43 Started


 44%|████▍     | 44/100 [11:36<14:47, 15.85s/it]

Epoch ended in 16.303636104334146s
Epoch 44 Started


 45%|████▌     | 45/100 [11:51<14:26, 15.75s/it]

Epoch ended in 15.490891072899103s
Epoch 45 Started


 46%|████▌     | 46/100 [12:07<14:10, 15.76s/it]

Epoch ended in 15.774515850003809s
Epoch 46 Started


 47%|████▋     | 47/100 [12:23<13:56, 15.78s/it]

Epoch ended in 15.819985467009246s
Epoch 47 Started


 48%|████▊     | 48/100 [12:39<13:40, 15.78s/it]

Epoch ended in 15.794760087970644s
Epoch 48 Started


 49%|████▉     | 49/100 [12:57<13:57, 16.43s/it]

Epoch ended in 17.922455388121307s
Epoch 49 Started
Half the learning rate11515


 50%|█████     | 50/100 [13:12<13:30, 16.20s/it]

Epoch ended in 15.67321149026975s
Epoch 50 Started


 51%|█████     | 51/100 [13:28<13:03, 15.99s/it]

Epoch ended in 15.50180999096483s
Epoch 51 Started


 52%|█████▏    | 52/100 [13:44<12:47, 15.98s/it]

Epoch ended in 15.966197737958282s
Epoch 52 Started


 53%|█████▎    | 53/100 [13:59<12:18, 15.72s/it]

Epoch ended in 15.109480869956315s
Epoch 53 Started


 54%|█████▍    | 54/100 [14:14<11:52, 15.49s/it]

Epoch ended in 14.928702285978943s
Epoch 54 Started


 55%|█████▌    | 55/100 [14:30<11:40, 15.57s/it]

Epoch ended in 15.770480618346483s
Epoch 55 Started


 56%|█████▌    | 56/100 [14:46<11:39, 15.89s/it]

Epoch ended in 16.623234272934496s
Epoch 56 Started


 57%|█████▋    | 57/100 [15:02<11:21, 15.86s/it]

Epoch ended in 15.776146022137254s
Epoch 57 Started


 58%|█████▊    | 58/100 [15:18<11:04, 15.82s/it]

Epoch ended in 15.741465223021805s
Epoch 58 Started


 59%|█████▉    | 59/100 [15:34<10:59, 16.09s/it]

Epoch ended in 16.69982324866578s
Epoch 59 Started
Half the learning rate13865


 60%|██████    | 60/100 [15:50<10:33, 15.84s/it]

Epoch ended in 15.266133644618094s
Epoch 60 Started


 61%|██████    | 61/100 [16:05<10:16, 15.82s/it]

Epoch ended in 15.766564626712352s
Epoch 61 Started


 62%|██████▏   | 62/100 [16:21<10:00, 15.79s/it]

Epoch ended in 15.718567826785147s
Epoch 62 Started


 63%|██████▎   | 63/100 [16:37<09:43, 15.76s/it]

Epoch ended in 15.675336759071797s
Epoch 63 Started


 64%|██████▍   | 64/100 [16:52<09:25, 15.70s/it]

Epoch ended in 15.550255964044482s
Epoch 64 Started


 65%|██████▌   | 65/100 [17:08<09:06, 15.63s/it]

Epoch ended in 15.463217579759657s
Epoch 65 Started


 66%|██████▌   | 66/100 [17:23<08:50, 15.61s/it]

Epoch ended in 15.568036040756851s
Epoch 66 Started


 67%|██████▋   | 67/100 [17:39<08:35, 15.62s/it]

Epoch ended in 15.650450848974288s
Epoch 67 Started


 68%|██████▊   | 68/100 [17:55<08:19, 15.62s/it]

Epoch ended in 15.609448031987995s
Epoch 68 Started


 69%|██████▉   | 69/100 [18:11<08:08, 15.76s/it]

Epoch ended in 16.079026114195585s
Epoch 69 Started
Half the learning rate16215


 70%|███████   | 70/100 [18:27<07:52, 15.76s/it]

Epoch ended in 15.762025830801576s
Epoch 70 Started


 71%|███████   | 71/100 [18:42<07:34, 15.67s/it]

Epoch ended in 15.455291201826185s
Epoch 71 Started


 72%|███████▏  | 72/100 [18:58<07:20, 15.73s/it]

Epoch ended in 15.85220294399187s
Epoch 72 Started


 73%|███████▎  | 73/100 [19:14<07:06, 15.80s/it]

Epoch ended in 15.97300068102777s
Epoch 73 Started


 74%|███████▍  | 74/100 [19:30<06:50, 15.81s/it]

Epoch ended in 15.820833961945027s
Epoch 74 Started


 75%|███████▌  | 75/100 [19:45<06:31, 15.67s/it]

Epoch ended in 15.330484848935157s
Epoch 75 Started


 76%|███████▌  | 76/100 [20:01<06:17, 15.72s/it]

Epoch ended in 15.846316321752965s
Epoch 76 Started


 77%|███████▋  | 77/100 [20:17<06:04, 15.87s/it]

Epoch ended in 16.20971363224089s
Epoch 77 Started


 78%|███████▊  | 78/100 [20:33<05:51, 15.98s/it]

Epoch ended in 16.234544970095158s
Epoch 78 Started


 79%|███████▉  | 79/100 [20:49<05:35, 15.96s/it]

Epoch ended in 15.919654270168394s
Epoch 79 Started
Half the learning rate18565


 80%|████████  | 80/100 [21:06<05:25, 16.28s/it]

Epoch ended in 17.004878935869783s
Epoch 80 Started


 81%|████████  | 81/100 [21:22<05:05, 16.07s/it]

Epoch ended in 15.588367511052638s
Epoch 81 Started


 82%|████████▏ | 82/100 [21:38<04:48, 16.01s/it]

Epoch ended in 15.853535799775273s
Epoch 82 Started


 83%|████████▎ | 83/100 [21:53<04:28, 15.79s/it]

Epoch ended in 15.293893642257899s
Epoch 83 Started


 84%|████████▍ | 84/100 [22:09<04:12, 15.78s/it]

Epoch ended in 15.736170151270926s
Epoch 84 Started


 85%|████████▌ | 85/100 [22:25<03:57, 15.84s/it]

Epoch ended in 15.969654724933207s
Epoch 85 Started


 86%|████████▌ | 86/100 [22:41<03:42, 15.90s/it]

Epoch ended in 16.038098993711174s
Epoch 86 Started


 87%|████████▋ | 87/100 [22:57<03:29, 16.14s/it]

Epoch ended in 16.704358438029885s
Epoch 87 Started


 88%|████████▊ | 88/100 [23:13<03:12, 16.04s/it]

Epoch ended in 15.805752066895366s
Epoch 88 Started


 89%|████████▉ | 89/100 [23:30<02:58, 16.22s/it]

Epoch ended in 16.624304501339793s
Epoch 89 Started
Half the learning rate20915


 90%|█████████ | 90/100 [23:45<02:39, 15.91s/it]

Epoch ended in 15.181192980147898s
Epoch 90 Started


 91%|█████████ | 91/100 [24:01<02:22, 15.83s/it]

Epoch ended in 15.658204676117748s
Epoch 91 Started


 92%|█████████▏| 92/100 [24:16<02:05, 15.69s/it]

Epoch ended in 15.35512958234176s
Epoch 92 Started


 93%|█████████▎| 93/100 [24:32<01:49, 15.62s/it]

Epoch ended in 15.456271653994918s
Epoch 93 Started


 94%|█████████▍| 94/100 [24:47<01:33, 15.58s/it]

Epoch ended in 15.478736952878535s
Epoch 94 Started


 95%|█████████▌| 95/100 [25:03<01:18, 15.60s/it]

Epoch ended in 15.66035548131913s
Epoch 95 Started


 96%|█████████▌| 96/100 [25:18<01:02, 15.57s/it]

Epoch ended in 15.496790060773492s
Epoch 96 Started


 97%|█████████▋| 97/100 [25:34<00:46, 15.51s/it]

Epoch ended in 15.348755285143852s
Epoch 97 Started


 98%|█████████▊| 98/100 [25:49<00:31, 15.51s/it]

Epoch ended in 15.504949799273163s
Epoch 98 Started


 99%|█████████▉| 99/100 [26:05<00:15, 15.54s/it]

Epoch ended in 15.597728090826422s
Epoch 99 Started
Half the learning rate23265


100%|██████████| 100/100 [26:21<00:00, 15.70s/it]

Epoch ended in 16.085932726971805s



