In [1]:
import tensorflow as tf
import numpy as np
import mlp.tf_utils as utils
from mlp.data_providers import AugmentedCIFAR10DataProvider, AugmentedCIFAR100DataProvider, CIFAR100DataProvider, CIFAR10DataProvider
from mlp.image_transforms import random_flip, random_crop, center_crop, random_flip_small
from mlp.Conv_models import ConvModel, TwoTaskConvModel, TwoTaskConvModelSoftSharing
import matplotlib.pyplot as plt
import seaborn as sns
%matplotlib inline

# Single fixed seed so the NumPy-driven randomness in this notebook is repeatable.
seed = 12345
# Shared generator handed to the image transforms used further down.
rng = np.random.RandomState(seed=seed)

In [2]:
def combined_transformer(inputs, rng):
    """Augment `inputs` by chaining `random_crop` and `random_flip_small`.

    The same `rng` is passed to both transforms, crop first and flip second,
    and the result of the flip is returned.
    """
    cropped = random_crop(inputs, rng)
    flipped = random_flip_small(cropped, rng)
    return flipped

In [3]:
# Hyperparameters shared by every model constructed in the training cell below.

# NOTE(review): the last entry is `_`, IPython's "last cell output" variable, so
# its value depends on kernel execution history rather than on anything written
# in this notebook. Replace it with the intended explicit value — TODO confirm
# what that value should be.
layer_dims = [3, 64, 128, 256, 512, _]
batch_size = 128

# Passed only to the two-task (MTL) models; presumably the layer index at which
# the two task branches split — confirm against mlp.Conv_models.
bifurcation_point = 3
joint_loss = False
# Four-entry per-layer switches; presumably one entry per convolutional layer —
# confirm against mlp.Conv_models.
max_pools = [True, True, False, True]
lrns = [True, True, False, False]
lrn_alphas = [0.01, 0.01, 0.01, 0.01]
batch_norms = [True, True, True, True]
# Forwarded as the models' L1 / L2 arguments.
l1 = False
l2 = 5e-4
# Learning rates and the matching schedule. The training log shows 0.0002 being
# set before epoch 1 and 2e-05 right after epoch 40, so each schedule entry
# appears to be the epoch after which the corresponding rate takes effect.
learning_rates = [2e-4, 2e-5, 2e-6]
lrn_schedule = [-1, 40, 80]
optimizer_params = [0.5, 0.9999]  # forwarded as `optimiser_params`; presumably Adam beta1/beta2 — TODO confirm
activation = tf.nn.relu
dropout = True
epochs = 1000
error = 'soft_max_cross_entropy'
image_size = 24  # forwarded as `input_image_size`; presumably the side length after cropping — TODO confirm
logs_dir = "tf-log"
# Passed only to TwoTaskConvModelSoftSharing; presumably the weight of the
# soft-sharing penalty — confirm against mlp.Conv_models.
soft_loss = 1e-1

In [4]:
# Main experiment cell. For each data ratio (currently only 1.0) it:
#   1. builds an SGD multi-task model on CIFAR-100 + CIFAR-10 (constructed but NOT trained),
#   2. builds and trains an Adam multi-task model on augmented CIFAR-100 (fine + coarse labels),
#   3. builds and trains a single-task ConvModel baseline on the same sampled data.
# NOTE(review): the enumerate index `i` is never used inside this cell.
# NOTE(review): `utils.sample_data` is opaque from here. It presumably sub-samples a
# provider to `ratio` of its data and returns the chosen indices, with the optional
# third argument re-using previously chosen indices — confirm against mlp.tf_utils.
for i, ratio in enumerate([1.]):
    optimiser = "SGD"  # only consumed by the MTL_SGD model constructed just below
    
    ########### MTL CLASSIFIER SGD #################
    # Start from an empty default TensorFlow graph so models do not share state.
    tf.reset_default_graph()
    model_name = "MTL_SGD"
    print("MODEL:- " + model_name)
    
    # Main task: CIFAR-100 without augmentation.
    main_train_data = CIFAR100DataProvider(which_set='train', batch_size=batch_size, shuffle_order=False)
    main_valid_data = CIFAR100DataProvider(which_set='valid', batch_size=batch_size, shuffle_order=False)

    # Auxiliary task: CIFAR-10 without augmentation.
    aux_train_data = CIFAR10DataProvider(which_set='train', batch_size=batch_size, shuffle_order=False)
    aux_valid_data = CIFAR10DataProvider(which_set='valid', batch_size=batch_size, shuffle_order=False)
    
    # Sample the main providers, then apply the returned indices to the auxiliary ones.
    # The `_1` / `_2` results are not used afterwards.
    m_t_idx = utils.sample_data(main_train_data, ratio)
    m_v_idx = utils.sample_data(main_valid_data, ratio)
    _1 = utils.sample_data(aux_train_data, ratio, m_t_idx)
    _2 = utils.sample_data(aux_valid_data, ratio, m_v_idx)
    
    model = TwoTaskConvModelSoftSharing(conv_layer_dims=layer_dims, main_train_data=main_train_data, main_valid_data=main_valid_data,
                         aux_train_data=aux_train_data, aux_valid_data=aux_valid_data, L1=l1, L2=l2,
                         batch_size=batch_size, learning_rates=learning_rates, learning_rate_schedule=lrn_schedule,
                         optimiser=optimiser, max_pools=max_pools, lin_response_norms=lrns,
                         lin_response_alphas=lrn_alphas, batch_norms=batch_norms, bifurcation_point=bifurcation_point,
                         optimiser_params=optimizer_params, activation=activation, dropout=dropout,
                         epochs=epochs, error=error, input_image_size=image_size, name=model_name, joint_loss=joint_loss,
                                       soft_loss=soft_loss)
    
    # NOTE(review): training is disabled for this model, so the section above only
    # loads data and constructs an object that is overwritten below. The captured
    # output confirms it: "MODEL:- MTL_SGD" is followed directly by the next model.
    #model.create_network()
    #model.initialize_network(logs_dir)
    #model.train_model()
    
    ########### MTL CLASSIFIER ADAM #################
    
    optimiser = "Adam"  # also picked up by the baseline ConvModel further down
    tf.reset_default_graph()
    model_name = "MTL_ADAM_AUG"
    print("MODEL:- " + model_name)
    
    # Main task: CIFAR-100 with crop + flip augmentation on the training set only.
    main_train_data = AugmentedCIFAR100DataProvider(which_set='train', transformer=combined_transformer,
                                                    batch_size=batch_size, shuffle_order=False)
    main_valid_data = CIFAR100DataProvider(which_set='valid', batch_size=batch_size, shuffle_order=False)
    # Validation inputs are center-cropped once, up front (presumably to match
    # `image_size` — confirm against mlp.image_transforms).
    main_valid_data.inputs = center_crop(main_valid_data.inputs, rng)
    
    # Auxiliary task: the same CIFAR-100 images, requested with use_coarse_targets=True.
    aux_train_data = AugmentedCIFAR100DataProvider(which_set='train', use_coarse_targets=True, 
                                                   transformer=combined_transformer,
                                                   batch_size=batch_size, shuffle_order=False)
    aux_valid_data = CIFAR100DataProvider(which_set='valid', use_coarse_targets=True,
                                                   batch_size=batch_size, shuffle_order=False)
    aux_valid_data.inputs = center_crop(aux_valid_data.inputs, rng)
    
    # These indices replace the ones from the SGD section and are the ones the
    # baseline below re-uses.
    m_t_idx = utils.sample_data(main_train_data, ratio)
    m_v_idx = utils.sample_data(main_valid_data, ratio)
    _1 = utils.sample_data(aux_train_data, ratio, m_t_idx)
    _2 = utils.sample_data(aux_valid_data, ratio, m_v_idx)
    
    model = TwoTaskConvModelSoftSharing(conv_layer_dims=layer_dims, main_train_data=main_train_data, main_valid_data=main_valid_data,
                         aux_train_data=aux_train_data, aux_valid_data=aux_valid_data, L1=l1, L2=l2,
                         batch_size=batch_size, learning_rates=learning_rates, learning_rate_schedule=lrn_schedule,
                         optimiser=optimiser, max_pools=max_pools, lin_response_norms=lrns,
                         lin_response_alphas=lrn_alphas, batch_norms=batch_norms, bifurcation_point=bifurcation_point,
                         optimiser_params=optimizer_params, activation=activation, dropout=dropout,
                         epochs=epochs, error=error, input_image_size=image_size, name=model_name, joint_loss=joint_loss,
                                       soft_loss=soft_loss)
    
    # Build, initialise (with `logs_dir`) and train the Adam multi-task model.
    model.create_network()
    model.initialize_network(logs_dir)
    model.train_model()
    
    ########### CIFAR100 BASELINE ADAM #####################
    tf.reset_default_graph()
    model_name = "baseline_with_ADAM_AUG"
    print("MODEL:- " + model_name)
    
    # Same augmentation and validation cropping as the multi-task main task.
    train_data = AugmentedCIFAR100DataProvider(which_set='train', transformer=combined_transformer,
                                                    batch_size=batch_size, shuffle_order=False)
    valid_data = CIFAR100DataProvider(which_set='valid', batch_size=batch_size, shuffle_order=False)
    valid_data.inputs = center_crop(valid_data.inputs, rng)
    
    # update baseline data providers to have same samples 
    # (re-uses the indices from the Adam MTL section; `_3` / `_4` are not used afterwards)
    _3 = utils.sample_data(train_data, ratio, m_t_idx)
    _4 = utils.sample_data(valid_data, ratio, m_v_idx)
    
    # Single-task baseline; `optimiser` is still "Adam" from the section above.
    model_b = ConvModel(conv_layer_dims=layer_dims, train_data=train_data, valid_data=valid_data,
                 batch_size=batch_size, learning_rates=learning_rates, learning_rate_schedule=lrn_schedule,
                 optimiser=optimiser, L1=l1, L2=l2,
                 max_pools=max_pools, lin_response_norms=lrns, lin_response_alphas=lrn_alphas, batch_norms=batch_norms,
                 optimiser_params=optimizer_params, activation=activation, dropout=dropout,
                 epochs=epochs, error=error, input_image_size=image_size, name=model_name)
    model_b.create_network()
    model_b.initialize_network(logs_dir)
    model_b.train_model()

MODEL:- MTL_SGD
MODEL:- MTL_ADAM_AUG
Setting up model...
Initializing network...
Training model...
New learning rate:  0.0002
Epoch finished:  1
Epoch finished:  2
Epoch finished:  3
Epoch finished:  4
Epoch finished:  5
Epoch finished:  6
Epoch finished:  7
Epoch finished:  8
Epoch finished:  9
Epoch finished:  10
Epoch finished:  11
Epoch finished:  12
Epoch finished:  13
Epoch finished:  14
Epoch finished:  15
Epoch finished:  16
Epoch finished:  17
Epoch finished:  18
Epoch finished:  19
Epoch finished:  20
Epoch finished:  21
Epoch finished:  22
Epoch finished:  23
Epoch finished:  24
Epoch finished:  25
Epoch finished:  26
Epoch finished:  27
Epoch finished:  28
Epoch finished:  29
Epoch finished:  30
Epoch finished:  31
Epoch finished:  32
Epoch finished:  33
Epoch finished:  34
Epoch finished:  35
Epoch finished:  36
Epoch finished:  37
Epoch finished:  38
Epoch finished:  39
Epoch finished:  40
New learning rate:  2e-05
Epoch finished:  41
Epoch finished:  42
Epoch finished:  

Epoch finished:  388
Epoch finished:  389
Epoch finished:  390
Epoch finished:  391
Epoch finished:  392
Epoch finished:  393
Epoch finished:  394
Epoch finished:  395
Epoch finished:  396
Epoch finished:  397
Epoch finished:  398
Epoch finished:  399
Epoch finished:  400
Epoch finished:  401
Epoch finished:  402
Epoch finished:  403
Epoch finished:  404
Epoch finished:  405
Epoch finished:  406
Epoch finished:  407
Epoch finished:  408
Epoch finished:  409
Epoch finished:  410
Epoch finished:  411
Epoch finished:  412
Epoch finished:  413
Epoch finished:  414
Epoch finished:  415
Epoch finished:  416
Epoch finished:  417
Epoch finished:  418
Epoch finished:  419
Epoch finished:  420
Epoch finished:  421
Epoch finished:  422
Epoch finished:  423
Epoch finished:  424
Epoch finished:  425
Epoch finished:  426
Epoch finished:  427
Epoch finished:  428
Epoch finished:  429
Epoch finished:  430
Epoch finished:  431
Epoch finished:  432
Epoch finished:  433
Epoch finished:  434
Epoch finishe

Epoch finished:  779
Epoch finished:  780
Epoch finished:  781
Epoch finished:  782
Epoch finished:  783
Epoch finished:  784
Epoch finished:  785
Epoch finished:  786
Epoch finished:  787
Epoch finished:  788
Epoch finished:  789
Epoch finished:  790
Epoch finished:  791
Epoch finished:  792
Epoch finished:  793
Epoch finished:  794
Epoch finished:  795
Epoch finished:  796
Epoch finished:  797
Epoch finished:  798
Epoch finished:  799
Epoch finished:  800
Epoch finished:  801
Epoch finished:  802
Epoch finished:  803
Epoch finished:  804
Epoch finished:  805
Epoch finished:  806
Epoch finished:  807
Epoch finished:  808
Epoch finished:  809
Epoch finished:  810
Epoch finished:  811
Epoch finished:  812
Epoch finished:  813
Epoch finished:  814
Epoch finished:  815
Epoch finished:  816
Epoch finished:  817
Epoch finished:  818
Epoch finished:  819
Epoch finished:  820
Epoch finished:  821
Epoch finished:  822
Epoch finished:  823
Epoch finished:  824
Epoch finished:  825
Epoch finishe

Epoch finished:  165
Epoch finished:  166
Epoch finished:  167
Epoch finished:  168
Epoch finished:  169
Epoch finished:  170
Epoch finished:  171
Epoch finished:  172
Epoch finished:  173
Epoch finished:  174
Epoch finished:  175
Epoch finished:  176
Epoch finished:  177
Epoch finished:  178
Epoch finished:  179
Epoch finished:  180
Epoch finished:  181
Epoch finished:  182
Epoch finished:  183
Epoch finished:  184
Epoch finished:  185
Epoch finished:  186
Epoch finished:  187
Epoch finished:  188
Epoch finished:  189
Epoch finished:  190
Epoch finished:  191
Epoch finished:  192
Epoch finished:  193
Epoch finished:  194
Epoch finished:  195
Epoch finished:  196
Epoch finished:  197
Epoch finished:  198
Epoch finished:  199
Epoch finished:  200
Epoch finished:  201
Epoch finished:  202
Epoch finished:  203
Epoch finished:  204
Epoch finished:  205
Epoch finished:  206
Epoch finished:  207
Epoch finished:  208
Epoch finished:  209
Epoch finished:  210
Epoch finished:  211
Epoch finishe

Epoch finished:  556
Epoch finished:  557
Epoch finished:  558
Epoch finished:  559
Epoch finished:  560
Epoch finished:  561
Epoch finished:  562
Epoch finished:  563
Epoch finished:  564
Epoch finished:  565
Epoch finished:  566
Epoch finished:  567
Epoch finished:  568
Epoch finished:  569
Epoch finished:  570
Epoch finished:  571
Epoch finished:  572
Epoch finished:  573
Epoch finished:  574
Epoch finished:  575
Epoch finished:  576
Epoch finished:  577
Epoch finished:  578
Epoch finished:  579
Epoch finished:  580
Epoch finished:  581
Epoch finished:  582
Epoch finished:  583
Epoch finished:  584
Epoch finished:  585
Epoch finished:  586
Epoch finished:  587
Epoch finished:  588
Epoch finished:  589
Epoch finished:  590
Epoch finished:  591
Epoch finished:  592
Epoch finished:  593
Epoch finished:  594
Epoch finished:  595
Epoch finished:  596
Epoch finished:  597
Epoch finished:  598
Epoch finished:  599
Epoch finished:  600
Epoch finished:  601
Epoch finished:  602
Epoch finishe

Epoch finished:  947
Epoch finished:  948
Epoch finished:  949
Epoch finished:  950
Epoch finished:  951
Epoch finished:  952
Epoch finished:  953
Epoch finished:  954
Epoch finished:  955
Epoch finished:  956
Epoch finished:  957
Epoch finished:  958
Epoch finished:  959
Epoch finished:  960
Epoch finished:  961
Epoch finished:  962
Epoch finished:  963
Epoch finished:  964
Epoch finished:  965
Epoch finished:  966
Epoch finished:  967
Epoch finished:  968
Epoch finished:  969
Epoch finished:  970
Epoch finished:  971
Epoch finished:  972
Epoch finished:  973
Epoch finished:  974
Epoch finished:  975
Epoch finished:  976
Epoch finished:  977
Epoch finished:  978
Epoch finished:  979
Epoch finished:  980
Epoch finished:  981
Epoch finished:  982
Epoch finished:  983
Epoch finished:  984
Epoch finished:  985
Epoch finished:  986
Epoch finished:  987
Epoch finished:  988
Epoch finished:  989
Epoch finished:  990
Epoch finished:  991
Epoch finished:  992
Epoch finished:  993
Epoch finishe

In [5]:
# Show the optimiser parameters stored on `model`. If the cells were run in order,
# `model` is the last TwoTaskConvModelSoftSharing assigned in the training cell
# (the baseline lives in `model_b`).
model.optimiser_params

[0.5, 0.9999]