## Initialization  

In [1]:
%load_ext autoreload
%autoreload 2
import os 
import sys
sys.path.insert(0, './src')
import time
import argparse
import yaml
import types
import copy, pprint
from time import sleep
from datetime import datetime
import numpy  as np
import torch  
import wandb
import pandas as pd
from utils.notebook_modules import initialize, init_dataloaders, init_environment, init_wandb, \
                                   training_prep, disp_dataloader_info,disp_info_1, \
                                   warmup_phase, weight_policy_training, disp_gpu_info

from utils.util import (print_separator, print_heading, timestring, print_loss) #, print_underline, load_from_pickle,
#                       print_dbg, get_command_line_args ) 

# Notebook-wide settings: the notebook name exported for wandb, plus wide
# console output for numpy / torch / pandas.
os.environ["WANDB_NOTEBOOK_NAME"] = "Adashare_Training.ipynb"

pp = pprint.PrettyPrinter(indent=4)

np.set_printoptions(
    edgeitems=3,
    linewidth=150,
    infstr='inf',
    nanstr='nan',
)
torch.set_printoptions(linewidth=132, precision=6)
pd.set_option('display.width', 132)

# Optional diagnostics, left disabled:
# torch.set_printoptions(precision=None, threshold=None, edgeitems=None, linewidth=None, profile=None, sci_mode=None)
# sys.path.insert(0, '/home/kbardool/kusanagi/AdaSparseChem/src')
# print(sys.path)
# disp_gpu_info()


## Create Environment

### Parse Input Args  - Read YAML config file

In [2]:
# RESUME_MODEL_CKPT = 'model_train_ep_25_seed_0088'

## For RESTARTING
##
# input_args = " --config yamls/chembl_3task_train.yaml " \
#              " --resume " \
#              " --exp_id      330i85cg" \
#              " --exp_name    0308_1204" \
#              " --exp_desc    Train with dropout 0.5" \
#              " --seed_idx    0 "\
#              " --batch_size  128" \
#              " --lambda_sparsity  0.01"\
#              " --lambda_sharing   0.01" 
## get command line arguments

In [4]:
# Arguments for starting a fresh (non-resumed) run.
# Adjacent string literals are concatenated into one command-line style
# string; every fragment begins with a space so options stay separated.
input_args = (
    " --config yamls/chembl_3task_train.yaml "
    " --exp_desc    6x100 lyrs,dropout 0.5, weight 105 bch/ep policy 105 bch/ep "
    " --seed_idx            0"
    " --batch_size        128"
    " --task_lr          0.001"
    " --backbone_lr      0.001"
    " --policy_lr        0.01"
    " --lambda_sparsity  0.02"
    " --lambda_sharing   0.01"
)

                         

In [5]:
# Process `input_args` and set up the run. Per the captured output this prints
# the parsed command-line parameters, reads the YAML config and creates the
# log folder. The returned `opt` (indexed like a dict) and `ns` (attribute
# access) are used by the later cells.
opt, ns = initialize(input_args, build_folders = True)


  command line parms : 
------------------------
 config...................  yamls/chembl_3task_train.yaml
 exp_id...................  None
 exp_name.................  None
 folder_sfx...............  None
 exp_desc.................  6x100 lyrs,dropout 0.5, weight 105 bch/ep policy 105 bch/ep
 seed_idx.................  0
 batch_size...............  128
 backbone_lr..............  0.001
 task_lr..................  0.001
 policy_lr................  0.01
 decay_lr_rate............  None
 decay_lr_freq............  None
 lambda_sparsity..........  0.02
 lambda_sharing...........  0.01
 gpu_ids..................  [0]
 resume...................  False
 cpu......................  False



##################################################
################### READ YAML ####################
##################################################


 log_dir              create folder:  ../experiments/AdaSparseChem/100x6_0324_1359_plr0.01_sp0.02_sh0.01_lr0.001
 result_dir           folder exists:  .

### Setup Dataloader and Model  

In [6]:
# Build the data loaders and print a summary of them.
dldrs = init_dataloaders(opt)
disp_dataloader_info(dldrs)

# Create the training environment (policy learning disabled at this stage).
environ = init_environment(
    ns,
    opt,
    is_train=True,
    policy_learning=False,
    display_cfg=False,
)

# Optimizers first, then schedulers, both with policy_learning=False.
environ.define_optimizer(policy_learning=False)
environ.define_scheduler(policy_learning=False)


##################################################
############### CREATE DATALOADERS ###############
##################################################

 trainset.y_class                                   :  [(13331, 5), (13331, 5), (13331, 5)] 
 trainset1.y_class                                  :  [(13331, 5), (13331, 5), (13331, 5)] 
 trainset2.y_class                                  :  [(13331, 5), (13331, 5), (13331, 5)] 
 valset.y_class                                     :  [(4137, 5), (4137, 5), (4137, 5)]  
 testset.y_class                                    :  [(920, 5), (920, 5), (920, 5)]  
                                 
 size of training set 0 (warm up)                   :  13331 
 size of training set 1 (network parms)             :  13331 
 size of training set 2 (policy weights)            :  13331 
 size of validation set                             :  4137 
 size of test set                                   :  920 
                               Total           

In [7]:
# Learning rates held by the optimizers: the first 'alphas' group
# (initial and current), then weights groups 0 and 1.
alphas_group = environ.optimizers['alphas'].param_groups[0]
print(f" Initial alphas LR    : {alphas_group['initial_lr']}   Current LR: {alphas_group['lr']}")

for group_idx in (0, 1):
    print(f" Current LR: {environ.optimizers['weights'].param_groups[group_idx]['lr']}")

 Initial alphas LR    : 0.01   Current LR: 0.01
 Current LR: 0.001
 Current LR: 0.001


In [8]:
# Display the first parameter group of the 'weights' optimizer.
# NOTE(review): this renders the full parameter tensors (see the long output);
# consider showing only the hyper-parameters such as 'lr' to keep it short.
environ.optimizers['weights'].param_groups[0]

{'params': [Parameter containing:
  tensor([[ 0.054305,  0.073865, -0.079888,  0.079335,  0.040790,  0.061847, -0.047831,  0.095938, -0.037845,  0.097123,  0.002437,
            0.035050, -0.008040, -0.046561, -0.088728,  0.020019, -0.045145,  0.098016, -0.078546,  0.003124, -0.087069,  0.049226,
            0.083186,  0.073277,  0.049136,  0.004473, -0.035649,  0.049682,  0.098840, -0.059525, -0.024014, -0.045792, -0.096087,
           -0.064727,  0.088230, -0.066036,  0.045381,  0.040653, -0.011709,  0.072606,  0.004199,  0.031642, -0.074879,  0.074629,
           -0.034470,  0.025022,  0.047658, -0.017332,  0.039687,  0.096168,  0.078396,  0.054945, -0.022197,  0.028004, -0.006620,
            0.072863,  0.080161, -0.070892, -0.066525,  0.059054,  0.040922,  0.082037, -0.080134, -0.058053,  0.088985, -0.054591,
           -0.093486,  0.020917,  0.086113,  0.025513,  0.095772,  0.044221, -0.088343,  0.071039,  0.031189, -0.077064, -0.039500,
            0.056883, -0.025926, -0.021946

###  Weights and Biases Initialization 

In [9]:
# Start the Weights & Biases run, then report how it was registered.
init_wandb(ns, opt, environment=environ)

wandb_run = ns.wandb_run
print(
    f" PROJECT NAME: {wandb_run.project}\n"
    f" RUN ID      : {wandb_run.id} \n"
    f" RUN NAME    : {wandb_run.name}"
)

[34m[1mwandb[0m: Currently logged in as: [33mkbardool[0m (use `wandb login --relogin` to force relogin)


65xmlssn 0324_1359 AdaSparseChem


 PROJECT NAME: AdaSparseChem
 RUN ID      : 65xmlssn 
 RUN NAME    : 0324_1359
 PROJECT NAME: AdaSparseChem
 RUN ID      : 65xmlssn 
 RUN NAME    : 0324_1359


In [17]:
# ns.wandb_run.finish()

### Initiate / Resume Training Prep

In [10]:
# Resume from a checkpoint when opt['train']['resume'] is set, otherwise
# announce a fresh training run.
if opt['train']['resume']:
    # The checkpoint labels are not part of `opt`; they must be assigned by
    # hand in an earlier cell. Fail up front with an explicit message instead
    # of a bare NameError halfway through the resume sequence.
    _missing = [name for name in ('RESUME_MODEL_CKPT', 'RESUME_METRICS_CKPT')
                if name not in globals()]
    if _missing:
        raise NameError(
            f"Resume requested but {', '.join(_missing)} not defined; "
            f"assign the checkpoint label(s) in an earlier cell before resuming."
        )

    # load_from_pickle is not in the notebook's active imports; it is assumed
    # to live in utils.util (it appears in the commented-out part of that
    # import) -- TODO confirm.
    from utils.util import load_from_pickle

    print(opt['train']['which_iter'])
    print(opt['paths']['checkpoint_dir'])
    print(RESUME_MODEL_CKPT)

    print_separator('Resume training')
    loaded_iter, loaded_epoch = environ.load_checkpoint(RESUME_MODEL_CKPT, path = opt['paths']['checkpoint_dir'], verbose = True)
    print(loaded_iter, loaded_epoch)
    # Alternative previously used: environ.load_checkpoint(opt['train']['which_iter'])
    environ.networks['mtl-net'].reset_logits()
    val_metrics = load_from_pickle(opt['paths']['checkpoint_dir'], RESUME_METRICS_CKPT)
    # training_prep(ns, opt, environ, dldrs, epoch = loaded_epoch, iter = loaded_iter )
else:
    print_separator('Initiate Training ')

##################################################
############### Initiate Training  ###############
##################################################


### Training Preparation

In [11]:
# Final preparation before training. Per the captured output it reports CUDA
# availability and sets print_freq / eval_iters from the train and validation
# loader lengths.
training_prep(ns, opt, environ, dldrs )

 cuda available [0]
 set print_freq to length of train loader: 105
 set eval_iters to length of val loader  : 33


In [12]:
# Print a summary of the run settings (block count, batch sizes, validation
# frequency, learning rates, regularization weights, epoch counts).
disp_info_1(ns, opt, environ)


 Num_blocks                : 6                                

 batch size                : 128 
 batches/ Weight trn epoch : 105 
 batches/ Policy trn epoch : 105                                 

 Print Frequency           : -1 
 Config Val Frequency      : 500 
 Config Val Iterations     : -1 
 Val iterations            : 33 
 which_iter                : warmup 
 train_resume              : False                                 
 
 fix BN parms              : False 
 Task LR                   : 0.001 
 Backbone LR               : 0.001                                 

 Sharing  regularization   : 0.01 
 Sparsity regularization   : 0.02 
 Task     regularization   : 1                                 

 Current epoch             : 0  
 Warm-up epochs            : 75 
 Training epochs           : 125


In [13]:
# Print the text summary of hyper-parameters returned by disp_for_excel()
# (presumably formatted for pasting into a spreadsheet -- confirm).
print(environ.disp_for_excel())


    folder: 100x6_0324_1359_plr0.01_sp0.02_sh0.01_lr0.001
    layers: 6 [100, 100, 100, 100, 100, 100] 
    
    middle dropout         : 0.5
    last dropout           : 0.5
    diff_sparsity_weights  : False
    skip_layer             : 0
    is_curriculum          : False
    curriculum_speed       : 3
    
    task_lr                : 0.001
    backbone_lr            : 0.001
    decay_lr_rate          : 0.75
    decay_lr_freq          : 40
    
    policy_lr              : 0.01
    policy_decay_lr_rate   : 0.75
    policy_decay_lr_freq   : 50
    lambda_sparsity        : 0.02
    lambda_sharing         : 0.01
    lambda_tasks           : 1
    
    Gumbel init_temp       : 4
    Gumbel decay_temp      : 0.965
    Gumbel decay_temp_freq : 16
    Logit init_method      : random
    Logit init_neg_logits  : None
    Logit hard_sampling    : False
    Warm-up epochs         : 75
    training epochs        : 125
    Data split ratios      : [0.725, 0.225, 0.05]



## Warmup Training

In [14]:
# Preview of the warm-up schedule: number of warm-up epochs and the epoch
# range they will cover, starting after the last completed epoch.
# Disabled overrides / diagnostics:
# environ.display_trained_policy(ns.current_epoch,out=sys.stdout)
# ns.stop_epoch_warmup = 10
# ns.warmup_epochs = 10
print(ns.warmup_epochs, ns.current_epoch)

warmup_first_epoch = ns.current_epoch + 1
warmup_last_epoch = ns.current_epoch + ns.warmup_epochs
print_heading(
    f" Last Epoch: {ns.current_epoch}   # of warm-up epochs to do:  {ns.warmup_epochs}"
    f" - Run epochs {warmup_first_epoch} to {warmup_last_epoch}",
    verbose=True,
)

75 0
------------------------------------------------------------------------
 Last Epoch: 0   # of warm-up epochs to do:  75 - Run epochs 1 to 75
------------------------------------------------------------------------ 



In [15]:
# Run the warm-up phase. Called without `epochs`; the captured heading shows
# it then covers ns.warmup_epochs epochs (75 here).
# Variant with an explicit epoch count:
# warmup_phase(ns,opt, environ, dldrs, epochs = 25)
warmup_phase(ns,opt, environ, dldrs)

------------------------------------------------------------------------
 Last Epoch: 0   # of warm-up epochs to do:  75 - Run epochs 1 to 75
------------------------------------------------------------------------ 

Epoch | BckBone LR   Heads LR  Policy LR Gumbl Temp |  trn loss     trn spar     trn shar   trn ttl |   bceloss  avg prec    aucroc     aucpr |  val loss     val spar     val shar    val ttl |  time |
    1 |   1.00e-03   1.00e-03   1.00e-02  4.000e+00 |   10.5212   4.1594e-02   2.1694e-04   10.5630 |   0.68903   0.57057   0.58071   0.57014 |   10.3354   4.1594e-02   2.1694e-04    10.3772 |   7.4 |
    2 |   1.00e-03   1.00e-03   1.00e-02  4.000e+00 |   10.0123   4.1594e-02   2.1694e-04   10.0541 |   0.68283   0.60611   0.61873   0.60573 |   10.2434   4.1594e-02   2.1694e-04    10.2852 |   7.6 |
    3 |   1.00e-03   1.00e-03   1.00e-02  4.000e+00 |    9.9372   4.1594e-02   2.1694e-04    9.9790 |   0.67496   0.62836   0.63553   0.62800 |   10.1262   4.1594e-02   2.1694e-04 

   39 |   1.00e-03   1.00e-03   1.00e-02  4.000e+00 |    5.0632   4.1594e-02   2.1694e-04    5.1050 |   0.54292   0.80062   0.80044   0.80051 |    8.1424   4.1594e-02   2.1694e-04     8.1842 |   7.1 |
   40 |   1.00e-03   1.00e-03   1.00e-02  4.000e+00 |    5.1329   4.1594e-02   2.1694e-04    5.1747 |   0.54445   0.80245   0.80214   0.80235 |    8.1705   4.1594e-02   2.1694e-04     8.2123 |   7.9 |
   41 |   1.00e-03   1.00e-03   1.00e-02  4.000e+00 |    5.8953   4.1594e-02   2.1694e-04    5.9371 |   0.54543   0.80368   0.80294   0.80358 |    8.1779   4.1594e-02   2.1694e-04     8.2197 |   7.1 |
   42 |   1.00e-03   1.00e-03   1.00e-02  4.000e+00 |    5.7173   4.1594e-02   2.1694e-04    5.7591 |   0.54211   0.80368   0.80321   0.80358 |    8.1289   4.1594e-02   2.1694e-04     8.1707 |   7.2 |
   43 |   1.00e-03   1.00e-03   1.00e-02  4.000e+00 |    4.3632   4.1594e-02   2.1694e-04    4.4050 |   0.54141   0.80681   0.80588   0.80670 |    8.1226   4.1594e-02   2.1694e-04     8.1644 |   7

In [None]:
# warmup_phase(ns,opt, environ, dldrs, epochs = 25)

In [15]:
# ns.wandb_run.finish()

In [16]:
# ns.wandb_run.finish()

In [17]:
# Display the most recent training losses held by the environment
# (a nested dict of parameters and per-task / sparsity / sharing terms).
environ.losses

{'parms': {'gumbel_temp': 4,
  'train_layers': 0,
  'lr_0': 0.001,
  'lr_1': 0.001,
  'policy_lr': 0.01,
  'lambda_sparsity': 0.02,
  'lambda_sharing': 0.01,
  'lambda_tasks': 1},
 'task': {'total': tensor(8.673099, device='cuda:0', dtype=torch.float64),
  'task1': tensor(2.684705, device='cuda:0', dtype=torch.float64),
  'task2': tensor(2.954407, device='cuda:0', dtype=torch.float64),
  'task3': tensor(3.033987, device='cuda:0', dtype=torch.float64)},
 'task_mean': {'total': tensor(1.734620, device='cuda:0', dtype=torch.float64),
  'task1': tensor(0.536941, device='cuda:0', dtype=torch.float64),
  'task2': tensor(0.590881, device='cuda:0', dtype=torch.float64),
  'task3': tensor(0.606797, device='cuda:0', dtype=torch.float64)},
 'sparsity': {'total': tensor(0.041594, device='cuda:0'),
  'task1': tensor(0.013862, device='cuda:0'),
  'task2': tensor(0.013873, device='cuda:0'),
  'task3': tensor(0.013859, device='cuda:0')},
 'sharing': {'total': tensor(0.000217, device='cuda:0')},
 'tota

In [18]:
# Display the most recent validation metrics held by the environment.
# NOTE(review): the saved output shows lambda_sparsity 0.03 while this run was
# configured with 0.02, so the stored output looks stale -- re-run to confirm.
environ.val_metrics

{'parms': {'gumbel_temp': 4,
  'train_layers': 0,
  'lr_0': 0.001,
  'lr_1': 0.001,
  'policy_lr': 0.01,
  'lambda_sparsity': 0.03,
  'lambda_sharing': 0.01,
  'lambda_tasks': 1},
 'task': {'total': 8.197446869225011,
  'task1': 2.6417922839840293,
  'task2': 2.756808388999536,
  'task3': 2.7988461962414473},
 'task_mean': {'total': 1.6394893738450027,
  'task1': 0.5283584567968057,
  'task2': 0.5513616777999073,
  'task3': 0.5597692392482894},
 'sparsity': {'total': 0.06234907731413841,
  'task1': 0.02078135870397091,
  'task2': 0.020778220146894455,
  'task3': 0.02078949846327305},
 'sharing': {'total': 0.00022678218374494463},
 'total': {'total': 8.260022728722895,
  'total_mean': 1.702065233342886,
  'task': 8.197446869225011,
  'policy': 0.06257585949788336},
 'task1': {'classification': {'_type': 'table-file',
   'path': 'media/table/classification_149_3227c0901b92f61b7164.table.json',
   'sha256': '3227c0901b92f61b71642e159ca3228ea4dddf7e124440b8ba79493c19f8e1a2',
   'size': 105

#### Display parameters

In [18]:
# Report the configured learning rates, then the regularization weights and
# Gumbel temperature settings, read from the environment's training options.
train_opts = environ.opt['train']

lr_report = (
    f" Backbone Learning Rate      : {train_opts['backbone_lr']}\n"
    f" Tasks    Learning Rate      : {train_opts['task_lr']}\n"
    f" Policy   Learning Rate      : {train_opts['policy_lr']}\n"
)
print(lr_report)

# The last row is labelled "decay" but shows the 'decay_temp_freq' option.
reg_report = (
    f" Sparsity regularization     : {train_opts['lambda_sparsity']}\n"
    f" Sharing  regularization     : {train_opts['lambda_sharing']} \n\n"
    f" Tasks    regularization     : {train_opts['lambda_tasks']}   \n"
    f" Gumbel Temp                 : {environ.gumbel_temperature:.4f}         \n"
    f" Gumbel Temp decay           : {train_opts['decay_temp_freq']}"
)
print(reg_report)

 Backbone Learning Rate      : 0.001
 Tasks    Learning Rate      : 0.001
 Policy   Learning Rate      : 0.01

 Sparsity regularization     : 0.02
 Sharing  regularization     : 0.01 

 Tasks    regularization     : 1   
 Gumbel Temp                 : 4.0000         
 Gumbel Temp decay           : 16


In [19]:
# Manual overrides tried at this point (left disabled):
# environ.opt['train']['policy_lr'] = 0.01
# opt['train']['policy_lr']         = 0.01
# environ.opt['train']['lambda_sparsity'] = 0.1
# environ.opt['train']['lambda_sharing']  = 0.01
# environ.opt['train']['lambda_tasks']    = 1.0
# environ.opt['train']['decay_temp_freq'] = 2

# Dump the policy ('alphas') parameter groups, then the learning rates of
# the alphas group and of weights groups 0 and 1.
policy_groups = environ.optimizers['alphas'].param_groups
print(policy_groups)
# print(environ.optimizers['weights'].param_groups)

print('initial lr: ', policy_groups[0]['initial_lr'], 'current lr: ', policy_groups[0]['lr'])
for group_idx in (0, 1):
    print('current lr: ', environ.optimizers['weights'].param_groups[group_idx]['lr'])

[{'params': [Parameter containing:
tensor([[ 9.265881e-04, -1.235903e-03],
        [ 5.588943e-05,  6.319766e-04],
        [-6.584334e-04,  3.764827e-04],
        [-7.700642e-04, -1.022177e-04],
        [-1.066945e-03,  5.982257e-04],
        [ 1.315012e-03,  2.184580e-04]], device='cuda:0'), Parameter containing:
tensor([[ 7.603760e-04,  3.758761e-04],
        [ 2.285245e-03, -5.118694e-05],
        [ 3.820481e-04, -1.171452e-03],
        [ 2.224697e-03,  3.282016e-04],
        [-4.551781e-04, -1.086531e-03],
        [-8.322308e-04, -2.218975e-04]], device='cuda:0'), Parameter containing:
tensor([[ 0.001423,  0.000156],
        [ 0.000209, -0.000176],
        [ 0.000879,  0.000876],
        [-0.001247,  0.002500],
        [-0.000269, -0.001266],
        [ 0.000376,  0.001677]], device='cuda:0')], 'lr': 0.01, 'betas': (0.9, 0.999), 'eps': 1e-08, 'weight_decay': 0.0005, 'amsgrad': False, 'initial_lr': 0.01}]
initial lr:  0.01 current lr:  0.01
current lr:  0.001
current lr:  0.001


## Weight & Policy Training

### Weight/Policy Training Preparation

In [None]:
# ns.flag_warmup = True

In [20]:
# One-time switch from warm-up to alternating weight / policy training.
# Runs only while ns.flag_warmup is still set; the flag is cleared below so
# re-running the cell is a no-op.
if ns.flag_warmup:
    # NOTE(review): the heading reads ns.flag before it is assigned below, and
    # it announces an optimizer/scheduler switch and a checkpoint that this
    # cell does not perform itself (the define_* calls are commented out).
    # Confirm whether those steps happen elsewhere.
    print_heading( f"** {timestring()} \n"
                   f"** Training epoch: {ns.current_epoch} iter: {ns.current_iter}   flag: {ns.flag} \n"
                   f"** Set optimizer and scheduler to policy_learning = True (Switch weight optimizer from ADAM to SGD)\n"
                   f"** Switch from Warm Up training to Alternate training Weights & Policy \n"
                   f"** Take checkpoint and block gradient flow through Policy net", verbose=True)
#     environ.define_optimizer(policy_learning=True)
#     environ.define_scheduler(policy_learning=True)
    ns.flag_warmup = False
    # Next phase flag is set to 'update_w'.
    ns.flag = 'update_w'
    # Presumably freezes the policy parameters and unfreezes the network
    # weights (inferred from the method names) -- confirm in the environment.
    environ.fix_alpha()
    environ.free_weights(opt['fix_BN'])

------------------------------------------------------------------------------------------------------------------------
** 2022-03-24 14:21:59:113811 
** Training epoch: 75 iter: 7875   flag: update_w 
** Set optimizer and scheduler to policy_learning = True (Switch weight optimizer from ADAM to SGD)
** Switch from Warm Up training to Alternate training Weights & Policy 
** Take checkpoint and block gradient flow through Policy net
------------------------------------------------------------------------------------------------------------------------ 



In [21]:
# num_train_layers = None 
# environ.opt['is_curriculum'] = True
# environ.opt['curriculum_speed'] = 4
# ns.num_train_layers = None

In [22]:
# Progress counters before weight/policy training starts.
print(ns.current_epoch, ns.training_epochs)

# (printed label, attribute of ns, text appended after the value)
progress_rows = (
    ("ns.current_epoch           ", "current_epoch", ""),
    ("ns.current_iters           ", "current_iter", " \n"),
    ("ns.training_epochs         ", "training_epochs", ""),
    # ("ns.stop_epoch_training     ", "stop_epoch_training", ""),   # disabled
    ("Batches in weight epoch    ", "stop_iter_w", ""),
    ("Batches in policy epoch    ", "stop_iter_a", ""),
    ("num_train_layers           ", "num_train_layers", ""),
)
for row_label, attr_name, row_tail in progress_rows:
    print(f"{row_label}: {getattr(ns, attr_name)}{row_tail}")
print()

75 125
ns.current_epoch           : 75
ns.current_iters           : 7875 

ns.training_epochs         : 125
Batches in weight epoch    : 105
Batches in policy epoch    : 105
num_train_layers           : None



In [23]:
# Print the validation loss summary, titled with the last completed epoch
# and iteration.
# ns.training_epochs = 75
print_loss(environ.val_metrics, title = f"[e] Last ep:{ns.current_epoch}  it:{ns.current_iter}")
# Optional policy / logits displays, left disabled:
# environ.display_trained_policy(ns.current_epoch)
# environ.display_trained_logits(ns.current_epoch)

[e] Last ep:75  it:7875 -  Total Loss: 8.3767     
Task: 8.3349   Sparsity: 4.15940e-02    Sharing: 2.16943e-04 


In [34]:
# Manual override of the run state before continuing training.
# NOTE(review): this cell was executed out of order (In [34]) and hard-codes
# the epoch, so a Restart & Run All will not reproduce the saved outputs.
# NOTE(review): the flag value here is 'update_weights' while the warm-up
# switch cell uses 'update_w' -- confirm which spelling the training loop expects.
ns.current_epoch = 200
ns.flag = 'update_weights'

In [31]:
# Heading announcing the next training stretch: epoch range, policy learning
# rate, regularization weights and curriculum settings.
first_epoch = ns.current_epoch + 1
last_epoch = ns.training_epochs + ns.current_epoch
train_cfg = environ.opt['train']

heading_lines = [
    f" Last Epoch Completed : {ns.current_epoch}       # of epochs to run:  {ns.training_epochs} -->  epochs {first_epoch} to {last_epoch}",
    f" policy_learning rate : {train_cfg['policy_lr']} ",
    f" lambda_sparsity      : {train_cfg['lambda_sparsity']}",
    f" lambda_sharing       : {train_cfg['lambda_sharing']}",
    f" curriculum training  : {opt['is_curriculum']}     cirriculum speed: {opt['curriculum_speed']}     num_training_layers : {ns.num_train_layers}",
]
print_heading("\n".join(heading_lines), verbose=True)

------------------------------------------------------------------------------------------------------------------------
 Last Epoch Completed : 200       # of epochs to run:  125 -->  epochs 201 to 325
 policy_learning rate : 0.01 
 lambda_sparsity      : 0.02
 lambda_sharing       : 0.01
 curriculum training  : False     cirriculum speed: 3     num_training_layers : 6
------------------------------------------------------------------------------------------------------------------------ 



### Weight/Policy Training

In [None]:
# Run the alternating weight / policy training. Called without `epochs`; the
# captured heading shows it then runs ns.training_epochs epochs (125 here).
# Variant with an explicit epoch count:
# weight_policy_training(ns, opt, environ, dldrs, epochs = 100)
weight_policy_training(ns, opt, environ, dldrs)

------------------------------------------------------------------------------------------------------------------------
 Last Epoch Completed : 200   # of epochs to run:  125 -->  epochs 201 to 325    
 policy_learning rate : 0.01      
 lambda_sparsity      : 0.02
 lambda_sharing       : 0.01 
 curriculum training  : False     cirriculum speed: 3     num_training_layers : 6
------------------------------------------------------------------------------------------------------------------------ 

Epoch | BckBone LR   Heads LR  Policy LR Gumbl Temp |  trn loss     trn spar     trn shar   trn ttl |   bceloss  avg prec    aucroc     aucpr |  val loss     val spar     val shar    val ttl |  time |
  201 |   4.22e-04   4.22e-04   5.62e-03  3.117e+00 |    1.3567   3.8673e-02   1.6323e-04    1.3955 |   0.73789   0.81544   0.81764   0.81535 |   11.0494   3.8673e-02   1.6323e-04    11.0882 |  12.0 |
  201 |   4.22e-04   4.22e-04   5.62e-03  3.117e+00 |    1.2647   3.8714e-02   1.6241e-04    1.3

  215 |   4.22e-04   4.22e-04   5.62e-03  3.008e+00 |    1.3631   4.2233e-02   1.6102e-04    1.4055 |   0.75782   0.81630   0.81881   0.81621 |   11.3832   4.2233e-02   1.6102e-04    11.4256 |  12.3 |
  215 |   4.22e-04   4.22e-04   5.62e-03  3.008e+00 |    1.3013   4.1847e-02   1.8554e-04    1.3433 |   0.78512   0.81399   0.81765   0.81386 |   11.7575   4.1841e-02   2.1063e-04    11.7996 |  11.3 |

[e] Policy training epoch:215  it:37275 -  Total Loss: 11.7996     
Task: 11.7575   Sparsity: 4.18410e-02    Sharing: 2.10629e-04 

 epch: 215   softmax      s        softmax       s        softmax       s
 -----  ----------------- -    ----------------- -    ----------------- - 
   1    0.1528    0.8472  0    0.1625    0.8375  0    0.1570    0.8430  0
   2    0.7492    0.2508  1    0.6817    0.3183  1    0.6153    0.3847  1
   3    0.8219    0.1781  1    0.7077    0.2923  1    0.6777    0.3223  1
   4    0.4187    0.5813  0    0.4846    0.5154  0    0.5642    0.4358  1
   5    0.2599    0.

  228 |   4.22e-04   4.22e-04   4.22e-03  2.903e+00 |    0.8637   4.2895e-02   8.8387e-05    0.9067 |   0.79459   0.81484   0.81789   0.81474 |   11.8592   4.2883e-02   7.6284e-05    11.9022 |  10.2 |
  229 |   4.22e-04   4.22e-04   4.22e-03  2.903e+00 |    1.1089   4.2883e-02   7.6284e-05    1.1519 |   0.76352   0.81580   0.81828   0.81570 |   11.4181   4.2883e-02   7.6284e-05    11.4611 |  12.0 |
  229 |   4.22e-04   4.22e-04   4.22e-03  2.903e+00 |    1.0978   4.2491e-02   1.1286e-04    1.1404 |   0.77165   0.81569   0.81855   0.81559 |   11.5788   4.2504e-02   1.4889e-04    11.6215 |  10.7 |
  230 |   4.22e-04   4.22e-04   4.22e-03  2.903e+00 |    1.2092   4.2504e-02   1.4889e-04    1.2519 |   0.77727   0.81420   0.81710   0.81408 |   11.6253   4.2504e-02   1.4889e-04    11.6680 |  11.5 |
  230 |   4.22e-04   4.22e-04   4.22e-03  2.903e+00 |    1.3824   4.2975e-02   1.8721e-04    1.4255 |   0.79653   0.81651   0.81844   0.81639 |   12.0055   4.2971e-02   1.9572e-04    12.0487 |  10

  241 |   3.16e-04   3.16e-04   4.22e-03  2.801e+00 |    0.8267   4.2967e-02   1.3918e-04    0.8698 |   0.80057   0.81559   0.81822   0.81549 |   12.0208   4.2961e-02   1.7167e-04    12.0640 |  10.8 |
  242 |   3.16e-04   3.16e-04   4.22e-03  2.801e+00 |    1.0138   4.2961e-02   1.7167e-04    1.0569 |   0.78315   0.81456   0.81776   0.81443 |   11.7588   4.2961e-02   1.7167e-04    11.8019 |  12.2 |
  242 |   3.16e-04   3.16e-04   4.22e-03  2.801e+00 |    0.9942   4.3042e-02   1.4200e-04    1.0374 |   0.77396   0.81571   0.81884   0.81562 |   11.6145   4.3038e-02   1.5484e-04    11.6577 |  10.5 |
  243 |   3.16e-04   3.16e-04   4.22e-03  2.801e+00 |    1.4627   4.3038e-02   1.5484e-04    1.5059 |   0.80558   0.81530   0.81808   0.81519 |   12.0410   4.3038e-02   1.5484e-04    12.0842 |  11.8 |
  243 |   3.16e-04   3.16e-04   4.22e-03  2.801e+00 |    1.1254   4.3821e-02   1.0224e-04    1.1693 |   0.80373   0.81576   0.81791   0.81567 |   12.0476   4.3828e-02   9.7723e-05    12.0915 |  10

Epoch | BckBone LR   Heads LR  Policy LR Gumbl Temp |  trn loss     trn spar     trn shar   trn ttl |   bceloss  avg prec    aucroc     aucpr |  val loss     val spar     val shar    val ttl |  time |
  256 |   3.16e-04   3.16e-04   4.22e-03  2.703e+00 |    0.8449   4.2513e-02   1.5793e-04    0.8876 |   0.77928   0.81691   0.81858   0.81680 |   11.7135   4.2513e-02   1.5793e-04    11.7562 |  11.9 |
  256 |   3.16e-04   3.16e-04   4.22e-03  2.703e+00 |    1.3677   4.2864e-02   1.2642e-04    1.4107 |   0.78421   0.81470   0.81795   0.81460 |   11.7818   4.2877e-02   1.3460e-04    11.8248 |  10.7 |
  257 |   3.16e-04   3.16e-04   4.22e-03  2.703e+00 |    1.3725   4.2877e-02   1.3460e-04    1.4155 |   0.81067   0.81283   0.81581   0.81272 |   12.1666   4.2877e-02   1.3460e-04    12.2096 |  12.3 |
  257 |   3.16e-04   3.16e-04   4.22e-03  2.703e+00 |    1.1888   4.3333e-02   1.6920e-04    1.2323 |   0.78161   0.81530   0.81743   0.81519 |   11.7129   4.3344e-02   1.5742e-04    11.7564 |  11

Epoch | BckBone LR   Heads LR  Policy LR Gumbl Temp |  trn loss     trn spar     trn shar   trn ttl |   bceloss  avg prec    aucroc     aucpr |  val loss     val spar     val shar    val ttl |  time |
  271 |   3.16e-04   3.16e-04   4.22e-03  2.608e+00 |    0.9363   4.3554e-02   1.2463e-04    0.9800 |   0.81745   0.81338   0.81718   0.81326 |   12.2906   4.3554e-02   1.2463e-04    12.3343 |  11.9 |
  271 |   3.16e-04   3.16e-04   4.22e-03  2.608e+00 |    1.2569   4.4423e-02   1.7043e-04    1.3015 |   0.81238   0.81573   0.81835   0.81563 |   12.1939   4.4441e-02   1.4291e-04    12.2385 |  11.4 |
  272 |   3.16e-04   3.16e-04   4.22e-03  2.608e+00 |    1.2777   4.4441e-02   1.4291e-04    1.3223 |   0.81725   0.81431   0.81687   0.81421 |   12.2935   4.4441e-02   1.4291e-04    12.3381 |  12.0 |
  272 |   3.16e-04   3.16e-04   4.22e-03  2.608e+00 |    1.2140   4.4415e-02   1.5816e-04    1.2586 |   0.82726   0.81361   0.81678   0.81351 |   12.4332   4.4417e-02   1.5478e-04    12.4778 |  10


[e] Policy training epoch:285  it:51975 -  Total Loss: 12.3895     
Task: 12.3444   Sparsity: 4.49968e-02    Sharing: 1.70983e-04 

 epch: 285   softmax      s        softmax       s        softmax       s
 -----  ----------------- -    ----------------- -    ----------------- - 
   1    0.1560    0.8440  0    0.1653    0.8347  0    0.1738    0.8262  0
   2    0.7916    0.2084  1    0.6981    0.3019  1    0.6493    0.3507  1
   3    0.8457    0.1543  1    0.7140    0.2860  1    0.6494    0.3506  1
   4    0.4764    0.5236  0    0.5638    0.4362  1    0.5815    0.4185  1
   5    0.2660    0.7340  0    0.4203    0.5797  0    0.3524    0.6476  0
   6    0.3575    0.6425  0    0.3366    0.6634  0    0.3852    0.6148  0


Epoch | BckBone LR   Heads LR  Policy LR Gumbl Temp |  trn loss     trn spar     trn shar   trn ttl |   bceloss  avg prec    aucroc     aucpr |  val loss     val spar     val shar    val ttl |  time |
  286 |   2.37e-04   2.37e-04   3.16e-03  2.517e+00 |    0.9471   4.499

  299 |   2.37e-04   2.37e-04   3.16e-03  2.517e+00 |    1.2525   4.4915e-02   1.1928e-04    1.2975 |   0.82225   0.81537   0.81775   0.81525 |   12.3681   4.4926e-02   1.2180e-04    12.4131 |  10.7 |
 decay gumbel softmax to 2.4290831317510966
  300 |   2.37e-04   2.37e-04   3.16e-03  2.429e+00 |    0.8786   4.4926e-02   1.2180e-04    0.9236 |   0.83591   0.81444   0.81757   0.81433 |   12.5156   4.4926e-02   1.2180e-04    12.5606 |  12.3 |
  300 |   2.37e-04   2.37e-04   3.16e-03  2.429e+00 |    0.9432   4.5336e-02   1.0694e-04    0.9887 |   0.82994   0.81616   0.81846   0.81604 |   12.4635   4.5354e-02   8.2510e-05    12.5089 |  11.4 |

[e] Policy training epoch:300  it:55125 -  Total Loss: 12.5089     
Task: 12.4635   Sparsity: 4.53539e-02    Sharing: 8.25096e-05 

 epch: 300   softmax      s        softmax       s        softmax       s
 -----  ----------------- -    ----------------- -    ----------------- - 
   1    0.1408    0.8592  0    0.1769    0.8231  0    0.1653    0.8347 

  313 |   2.37e-04   2.37e-04   3.16e-03  2.429e+00 |    0.9519   4.7111e-02   1.2893e-04    0.9991 |   0.84172   0.81482   0.81766   0.81471 |   12.6633   4.7111e-02   1.2893e-04    12.7106 |  11.8 |
  313 |   2.37e-04   2.37e-04   3.16e-03  2.429e+00 |    0.9391   4.7281e-02   1.3497e-04    0.9865 |   0.85318   0.81427   0.81705   0.81417 |   12.8010   4.7290e-02   1.5580e-04    12.8485 |  11.0 |
  314 |   2.37e-04   2.37e-04   3.16e-03  2.429e+00 |    0.7924   4.7290e-02   1.5580e-04    0.8398 |   0.82778   0.81669   0.81867   0.81656 |   12.4372   4.7290e-02   1.5580e-04    12.4847 |  12.4 |
  314 |   2.37e-04   2.37e-04   3.16e-03  2.429e+00 |    0.9189   4.7680e-02   9.6948e-05    0.9666 |   0.86262   0.81502   0.81756   0.81490 |   12.9082   4.7689e-02   1.0395e-04    12.9559 |  10.4 |
  315 |   2.37e-04   2.37e-04   3.16e-03  2.429e+00 |    1.0508   4.7689e-02   1.0395e-04    1.0986 |   0.84580   0.81567   0.81731   0.81558 |   12.7106   4.7689e-02   1.0395e-04    12.7584 |  13

In [21]:
# Display the best epoch, its iteration and the best value tracked in `ns`.
# NOTE(review): the saved output reports epoch 152, which lies outside the
# 201-325 range logged above, so it may come from an earlier execution.
ns.best_epoch, ns.best_iter, ns.best_value

(152, 23940, 0.8217469868335131)

### Close WandB run

In [31]:
# Close the active Weights & Biases run; the output shows the final upload
# and the run summary.
wandb.finish()




VBox(children=(Label(value='7.157 MB of 7.157 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

0,1
epoch,▁▁▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▆▆▇▇▇▇▇▇███
train_time,▄▁▂▂▃▁▅▅▆▅▄▅▅█▅▅▄▄▄▅▅▄▅▅▅▄▄▅▆▆▆▆▅▄▄▅▅▅▄▅

0,1
epoch,325.0
train_time,9.89502


In [None]:
# ns.best_epoch = 0
# from utils.notebook_modules import wrapup_phase
# wrapup_phase(ns, opt, environ)

In [None]:
# Post-training report of learning rates, regularization weights and Gumbel
# temperature settings, followed by the notebook-level decay_temp_freq option.
train_opts = environ.opt['train']

lr_report = (
    f" Backbone Learning Rate      : {train_opts['backbone_lr']}\n"
    f" Tasks    Learning Rate      : {train_opts['task_lr']}\n"
    f" Policy   Learning Rate      : {train_opts['policy_lr']}\n"
)
print(lr_report)

# The last row is labelled "decay" but shows the 'decay_temp_freq' option.
reg_report = (
    f" Sparsity regularization     : {train_opts['lambda_sparsity']}\n"
    f" Sharing  regularization     : {train_opts['lambda_sharing']} \n\n"
    f" Tasks    regularization     : {train_opts['lambda_tasks']}   \n"
    f" Gumbel Temp                 : {environ.gumbel_temperature:.4f}         \n"
    f" Gumbel Temp decay           : {train_opts['decay_temp_freq']}"
)
print(reg_report)

# Same option read through the notebook's own `opt` for comparison.
print(opt['train']['decay_temp_freq'])


In [None]:
# environ.opt['train']['policy_lr']       = 0.002
# environ.opt['train']['lambda_sparsity'] = 0.05
# environ.opt['train']['lambda_sharing']  = 0.01
# environ.opt['train']['lambda_tasks']    = 1.0
# # environ.opt['train']['decay_temp_freq'] = 2
 

In [None]:
# Report the three learning rates stored in the environment's training options.
print( f" Backbone Learning Rate      : {environ.opt['train']['backbone_lr']}\n"
       f" Tasks    Learning Rate      : {environ.opt['train']['task_lr']}\n"
       f" Policy   Learning Rate      : {environ.opt['train']['policy_lr']}\n")


# Report the regularization weights plus the current Gumbel temperature and its decay setting.
print( f" Sparsity regularization     : {environ.opt['train']['lambda_sparsity']}\n"
       f" Sharing  regularization     : {environ.opt['train']['lambda_sharing']} \n\n"
       f" Tasks    regularization     : {environ.opt['train']['lambda_tasks']}   \n"
       f" Gumbel Temp                 : {environ.gumbel_temperature:.4f}         \n" #
       f" Gumbel Temp decay           : {environ.opt['train']['decay_temp_freq']}") #

# Progress counters kept on the `ns` namespace object.
print()
print( f" current_iters               : {ns.current_iter}")  
print( f" current_epochs              : {ns.current_epoch}") 
print( f" train_total_epochs          : {ns.training_epochs}") 
print( f" stop_epoch_training         : {ns.stop_epoch_training}")

## Post Training Stuff

In [61]:
# pp.pprint(environ.losses)
# pp.pprint(environ.val_metrics)
environ.num_layers, environ.networks['mtl-net'].num_layers

(6, 6)

In [None]:
# pp.pprint(environ.val_metrics['total'])

In [None]:
# print_loss(environ.val_metrics, title = f"[Final] ep:{current_epoch}  it:{current_iter}",)
# environ.display_trained_policy(current_epoch)
# environ.display_trained_logits(current_epoch)
# environ.log_file.flush()

In [None]:
# model_label   = 'model_train_ep_%d_seed_%04d' % (current_epoch, opt['random_seed'])
# metrics_label = 'metrics_train_ep_%d_seed_%04d.pickle' % (current_epoch, opt['random_seed'])
# environ.save_checkpoint(model_label, current_iter, current_epoch) 
# save_to_pickle(environ.val_metrics, environ.opt['paths']['checkpoint_dir'], metrics_label)
# print_loss(environ.val_metrics, title = f"[Final] ep:{current_epoch}  it:{current_iter}",)
# environ.display_trained_policy(current_epoch,out=[sys.stdout, environ.log_file])
# environ.display_trained_logits(current_epoch)
# environ.log_file.flush()

In [None]:
# print_loss(current_iter, environ.losses, title = f"[e] Policy training epoch:{current_epoch}    iter:")
# print()
# print_loss(current_iter, trn_losses, title = f"[e] Policy training epoch:{current_epoch}    iter:")
# print()
# print_loss(current_iter, environ.val_metrics, title = f"[e] Policy training epoch:{current_epoch}    iter:")

In [None]:
# environ.losses
# environ.val_metrics

In [None]:
# environ.batch_data
# print_metrics_cr(current_epoch, time.time() - start_time, trn_losses, environ.val_metrics, 0, out=[sys.stdout])
# environ.display_parameters()

# with np.printoptions(edgeitems=3, infstr='inf', linewidth=150, nanstr='nan', precision=7, formatter={'float': lambda x: f"{x:12.5e}"}):
#     environ.print_logit_grads('gradients')

# environ_params = environ.get_task_specific_parameters()
# environ_params = environ.get_arch_parameters()
# environ_params = environ.get_backbone_parameters()
# print(environ_params)
# for param in environ_params:
#     print(param.grad.shape, '\n', param.grad)
#     print(param)

In [None]:
environ.display_trained_logits(ns.current_epoch)
environ.display_trained_policy(ns.current_epoch)

In [None]:
environ.display_test_sample_policy(ns.current_epoch, hard_sampling = True)
environ.display_train_sample_policy(ns.current_epoch, hard_sampling = True)

In [None]:
# environ.define_optimizer(policy_learning=True)

In [None]:
print(environ.optimizers['alphas'])
print(environ.optimizers['weights'])

In [None]:
print('Policy  initial_lr : ', environ.optimizers['alphas'].param_groups[0]['initial_lr'], 'lr : ',environ.optimizers['alphas'].param_groups[0]['lr'])
print('Weights initial_lr : ', environ.optimizers['weights'].param_groups[0]['initial_lr'], 'lr : ',environ.optimizers['weights'].param_groups[0]['lr'])
print('Weights initial_lr : ', environ.optimizers['weights'].param_groups[1]['initial_lr'], 'lr : ',environ.optimizers['weights'].param_groups[1]['lr'])

In [None]:
wandb.run is None

In [None]:
# opt['exp_instance'] = '0218_1358'     
# folder_name=  f"{opt['exp_instance']}_bs{opt['train']['batch_size']:03d}_{opt['train']['decay_lr_rate']:3.2f}_{opt['train']['decay_lr_freq']}"
# print()
# opt['exp_instance'] = datetime.now().strftime("%m%d_%H%M")
# opt['exp_description'] = f"No Alternating Weight/Policy - training all done with both weights and policy"
# folder_name=  f"{opt['exp_instance']}_bs{opt['train']['batch_size']:03d}_{opt['train']['decay_lr_rate']:3.2f}_{opt['train']['decay_lr_freq']}"

In [None]:
# wandb.finish()

In [None]:
# 
p = environ.get_current_state(0)

In [None]:
pp.pprint(p)

### Post Warm-up Training stuff

In [None]:
pp.pprint(environ.val_metrics)

In [None]:
environ.networks['mtl-net'].arch_parameters()

In [None]:
# Print, in turn, the result of three policy accessors on the environment.
# `p` is rebound each time, so only the logits result survives this cell.
p = environ.get_sample_policy(hard_sampling = False)
print(p)
p = environ.get_policy_prob()
print(p)
p = environ.get_policy_logits()
print(p)

# p = environ.get_current_policy()
# print(p)

In [None]:
# `softmax` is otherwise only imported much further down (in the Policy / Logit
# section), so a top-to-bottom run failed here with NameError. Import it locally.
from scipy.special import softmax

# Probability vector obtained from the two logits [0, 1].
a = softmax([0.0, 1])
print(a)
# Draw a single value from (1, 0), using `a` as the selection probabilities.
sampled = np.random.choice((1, 0), p=a)
print(sampled)

In [None]:
print(environ.optimizers['weights'])
print(environ.schedulers['weights'].get_last_lr())

In [None]:
print('losses.keys      : ', environ.losses.keys())
print('losses[task]keys : ', environ.losses['task1'].keys())
pp.pprint(environ.losses)

In [None]:
print( environ.val_metrics.keys())
# pp.pprint(val_metrics)
print(type(environ.val_metrics['aggregated']))
print()
print(type(environ.val_metrics['task1']['classification_agg']))
print()
pp.pprint(environ.val_metrics)

In [None]:
# import pickle
# with open("val_metrics.pkl", mode= 'wb') as f:
#         pickle.dump(val_metrics, f)
    
# with open('val_metrics.pkl', 'rb') as f:    
#     tst_val_metrics = pickle.load(f)

In [None]:
# print(environ.input.shape) 
# a = getattr(environ, 'task1_pred')
# yc_data = environ.batch['task1_data']
# print(yc_data.shape)
# yc_ind = environ.batch['task1_ind']
# print(yc_ind.shape)
# yc_hat_all = getattr(environ, 'task1_pred')
# print(yc_hat_all.shape)
# yc_hat  = yc_hat_all[yc_ind[0], yc_ind[1]]
# print(yc_hat_all.shape, yc_hat.shape)

# 
# environ.losses
# loss = {}
# for key in environ.losses.keys():
#     loss[key] = {}
#     for subkey, v in environ.losses[key].items():
#         print(f" key:  {key}   subkey: {subkey} ")
#         if isinstance(v, torch.Tensor):
#             loss[key][subkey] = v.data
#             print(f" Tensor  -  key:  {key}   subkey: {subkey}           value type: {type(v)}  value: {v:.4f}")
#         else:
#             loss[key][subkey] = v
#             print(f" integer -  key:  {key}   subkey: {subkey}           value type: {type(v)}  value: {v:.4f}")
# pp.pprint(tst_val_metrics)             

In [None]:
# print('metrics.keys: ', environ.metrics.keys())
# print('metrics[task].keys: ', environ.metrics['task1'].keys())
# pp.pprint(environ.metrics['task1'])
# pp.pprint(environ.losses['task1']['total'])

In [None]:
# title='Iteration'
# for t_id, _ in enumerate(environ.tasks):
#     task_key = f"task{t_id+1}"
# #     print_heading(f"{title}  {current_iter}  {task_key} : {val_metrics[task_key]['classification_agg']}", verbose = True)

#     for key, _  in val_metrics[task_key]['classification_agg'].items():
#         print('%s/%-20s'%(task_key, key), val_metrics[task_key]['classification_agg'][key], current_iter)
#         print(f"{task_key:s}/{key:20s}", val_metrics[task_key]['classification_agg'][key], current_iter)
#         print()
#             # print_current_errors(os.path.join(self.log_dir, 'loss.txt'), current_iter,key, loss[key], time.time() - start_time)

In [None]:
# environ.print_loss(current_iter, start_time, metrics = val_metrics['loss'], verbose=True)
# print(opt['lambdas'])
# p = (opt['lambdas'][0] * environ.losses['tasks']['task1'])
# print(p)

# environ.print_val_metrics(current_iter, start_time, val_metrics , title='validation', verbose=True)    

In [None]:
# print(current_iter)
# print_metrics_cr(current_iter, t1 - t0, None, val_metrics , True)
# environ.print_val_metrics(current_iter, start_time, val_metrics, title='validation', verbose = True)

In [None]:
# Summarize the latest validation metrics: available keys, then per-task and total loss.
# The bare name `val_metrics` is never assigned in the visible cells; the other cells
# read the metrics from `environ.val_metrics`, so bind it from there explicitly.
# NOTE(review): the 'loss' / 'train_time' / 'epoch' keys are assumed to exist -- confirm.
val_metrics = environ.val_metrics

print(f" val_metric keys               : {val_metrics.keys()}")
print(f" loss keys                     : {val_metrics['loss'].keys()}")
print(f" task1 keys                    : {val_metrics['task1'].keys()}")
print(f" task1 classification keys     : {val_metrics['task1']['classification'].keys()}")
print(f" task1 classification_agg keys : {val_metrics['task1']['classification_agg'].keys()}")
print()
# `:5f` / `:2f` set a minimum *width* (a no-op for these values); precision was intended.
print(f" task1                       : {val_metrics['task1']['classification_agg']['loss']:.5f}")
print(f" task2                       : {val_metrics['task2']['classification_agg']['loss']:.5f}")
print(f" task3                       : {val_metrics['task3']['classification_agg']['loss']:.5f}")
print(f" loss                        : {val_metrics['loss']['total']:.5f}")
print(f" train_time                  : {val_metrics['train_time']:.2f}")
print(f" epoch                       : {val_metrics['epoch']}")


### Post Weight + Policy Training Stuff 

In [None]:
environ.networks['mtl-net'].backbone.layer_config

In [None]:
# Build the constant tensors used by the sparsity-loss experiments below:
# an all-ones and an all-zeros integer target vector, plus linearly increasing
# per-layer weights (1/N, 2/N, ..., 1).
num_blocks = 6
num_policy_layers = 6

gt = torch.ones(num_blocks, dtype=torch.long)
gt0 = torch.zeros(num_blocks, dtype=torch.long)
print(gt)
print(gt0)

loss_weights = torch.arange(1, num_policy_layers + 1).float() / num_policy_layers
print(loss_weights)

In [None]:
# Debug cell: recompute the sparsity cross-entropy terms by hand for inspection.
# Uses `num_blocks`, `num_policy_layers`, `gt` and `gt0` from the preceding cell.
# NOTE(review): `task_key`, `logits` and `print_dbg` are not assigned/imported in the
# cells visible here (`print_dbg` is commented out of the utils.util import in the
# first cell) -- they must be defined before this cell can run.
if environ.opt['diff_sparsity_weights'] and not environ.opt['is_sharing']:
    print(' cond 1')
    ## Assign higher weights to higher layers 
    loss_weights = ((torch.arange(0, num_policy_layers, 1) + 1).float() / num_policy_layers)
    print(f"{task_key} sparsity error:  {2 * (loss_weights[-num_blocks:] * environ.cross_entropy2(logits[-num_blocks:], gt)).mean()})")
    print_dbg(f" loss_weights :  {loss_weights}", verbose = True)
    print_dbg(f" cross_entropy:  {environ.cross_entropy2(logits[-num_blocks:], gt)}  ", verbose = True)
    # Was `self.losses`: there is no `self` at notebook top level, so read the
    # losses from the environment object like every other cell does.
    print_dbg(f" loss[sparsity][{task_key}]: {environ.losses['sparsity'][task_key] } ", verbose = True)

else:
    # Cross-entropy of all block logits against the all-ones target.
    print('\n cond 2')
    print_dbg(f"Compute CrossEntropyLoss between \n Logits   : \n{logits[-num_blocks:]} \n and gt: \n{gt} \n", verbose = True)
    print(f"{task_key} sparsity error:  {environ.cross_entropy_sparsity(logits[-num_blocks:], gt)}")
    
    # Last layer only, against target 1 and then target 0.
    print('\n cond 2')
    print_dbg(f"Compute CrossEntropyLoss between Logits      : {logits[-1:]}  and gt: {gt[-1]} ", verbose = True)
    print(f"{task_key} sparsity error:  {environ.cross_entropy_sparsity(logits[-1:], gt[-1:])} \n")
    print_dbg(f"Compute CrossEntropyLoss between Logits      : {logits[-1:]}  and gt: {gt0[-1]} ", verbose = True)
    print(f"{task_key} sparsity error:  {environ.cross_entropy_sparsity(logits[-1:], gt0[-1:])} \n")
    
    # First layer only, against target 1 and then target 0.
    print('\n cond 3')    
    print_dbg(f"Compute CrossEntropyLoss between Logits   : {logits[0:1]}  and gt: {gt[0:1]} ", verbose = True)
    print(f"{task_key} sparsity error:  {environ.cross_entropy_sparsity(logits[0:1], gt[0:1])} \n")
    print_dbg(f"Compute CrossEntropyLoss between Logits   : {logits[0:1]}  and gt: {gt0[0:1]} ", verbose = True)
    print(f"{task_key} sparsity error:  {environ.cross_entropy_sparsity(logits[0:1], gt0[0:1])} \n")
        
        

In [None]:
# Alternative (disabled): weight-update phase.
# flag = 'update_w'
# environ.fix_alpha
# environ.free_w(opt['fix_BN'])

# Select the policy ("alpha") update phase.
# Presumably fix_weights() freezes the network weights and free_alpha() makes the
# policy logits trainable -- confirm against the environment implementation.
flag = 'update_alpha'
environ.fix_weights()
environ.free_alpha()

In [None]:
environ.networks['mtl-net'].num_layers

In [None]:
# Show the training counters, extend the epoch budget by 5, then show them again.
# The bare names `current_iter` / `current_epoch` / `train_total_epochs` are not
# assigned in the visible cells; the live counters are kept on `ns` (the status
# cell earlier prints `ns.training_epochs` under the "train_total_epochs" label).
# NOTE: not idempotent -- every execution of this cell adds another 5 epochs.
print(f"current_iters         : {ns.current_iter}")  
print(f"current_epochs           : {ns.current_epoch}") 
print(f"train_total_epochs    : {ns.training_epochs}") 

ns.training_epochs += 5

print(f"current_iters         : {ns.current_iter}")  
print(f"current_epochs           : {ns.current_epoch}") 
print(f"train_total_epochs    : {ns.training_epochs}") 

In [None]:
# print_metrics_cr(current_epoch, time.time() - t0, None, environ.val_metrics , num_prints)      

# num_prints += 1
# t0 = time.time()

# # Take check point
# environ.save_checkpoint('latest', current_iter)
# environ.train()
# #-------------------------------------------------------
# # END validation process
# #-------------------------------------------------------       
# flag = 'update_alpha'
# environ.fix_w()
# environ.free_alpha()

In [None]:
# dilation = 2
# kernel_size = np.asarray((3, 3))
# upsampled_kernel_size = (kernel_size - 1) * (dilation - 1) + kernel_size
# print(upsampled_kernel_size)

In [None]:
# environ.optimizers['weights'].param_groups[0]
# for param_group in optimizer.param_groups:
#     return param_group['lr']

In [None]:
environ.schedulers['weights'].get_last_lr()

In [None]:
# Collect the state_dict of every optimizer held by the environment, keyed by
# the optimizer's name, printing each optimizer on the way and the full dict at the end.
current_state = {}
for k, v in environ.optimizers.items():
    print(f'state dict for {k} = {v}')
    current_state[k] = v.state_dict()
pp.pprint(current_state)

In [None]:
# Print every LR scheduler held by the environment together with its state_dict.
# (The original reset `current_state = {}` here without ever filling it, which only
# discarded the optimizer snapshot collected by the optimizer-dump cell.)
for k, v in environ.schedulers.items():
    print(f'state dict for {k} = {v}')
    print(v.state_dict())

### Losses and Metrics

In [None]:
trn_losses = environ.losses

In [None]:
# Print one metrics row from the training losses and the latest validation metrics.
# NOTE(review): `print_metrics_cr` is not among the names imported in the first cell, and
# `current_epoch`, `start_time`, `num_prints` are not assigned in the visible cells
# (the live counters appear to be kept on `ns`) -- confirm before running.
print_metrics_cr(current_epoch, time.time() - start_time, trn_losses, environ.val_metrics , num_prints)      

In [None]:
# print_metrics_cr(current_epoch, time.time() - start_time, trn_losses, environ.val_metrics , num_prints)      

In [None]:
# pp.pprint(environ.losses)
pp.pprint(trn_losses)

In [None]:
pp.pprint(environ.val_metrics)

In [None]:
# environ.opt['train']['Lambda_sharing'] = 0.5
# opt['train']['Lambda_sharing'] = 0.5

# environ.opt['train']['policy_lr'] = 0.001
# opt['train']['policy_lr'] = 0.001

In [None]:
environ.losses.keys()
pp.pprint(environ.losses)

In [None]:
tmp = environ.get_loss_dict()
print(tmp.keys())
pp.pprint(tmp)

In [None]:
# Compare the notebook-level `opt` with the copy held by the environment for the
# sparsity/sharing switches, regularization weights and policy learning rate.
print(opt['diff_sparsity_weights'])
print(opt['is_sharing'])
print(opt['diff_sparsity_weights'] and not opt['is_sharing'])
# Keys are lower-case (`lambda_sharing`, `lambda_sparsity`), matching the command-line
# flags and the status cells; the capitalised spelling used here did not match them.
print(environ.opt['train']['lambda_sharing'])
print(opt['train']['lambda_sharing'])
print(environ.opt['train']['lambda_sparsity'])
print(opt['train']['lambda_sparsity'])
print(environ.opt['train']['policy_lr'])
print(opt['train']['policy_lr'])

### Policy / Logit stuff

In [None]:
from scipy.special          import softmax

In [None]:
np.set_printoptions(precision=8,edgeitems=3, infstr='inf', linewidth=150, nanstr='nan')
torch.set_printoptions(precision=8,linewidth=132)

#### `get_task_logits(n)` Get logits for task group n

In [None]:
task_logits = environ.get_task_logits(1)
print(task_logits)

#### `get_arch_parameters()`: Get last used logits from network

In [None]:
import torch.optim as optim
arch_parameters      = environ.get_arch_parameters()
print(arch_parameters)

In [None]:
import torch.optim as optim
arch_parameters      = environ.get_arch_parameters()
print(arch_parameters)

#### `get_policy_logits()`:  Get Policy Logits - returns same as `get_arch_parameters()`

In [None]:
# Fetch the policy logits from the environment and print each element separately.
# NOTE: `logs` is rebound to a hard-coded array by a cell further down.
logs = environ.get_policy_logits()
for i in logs:
    print(i, '\n')
# probs = softmax(logs, axis= -1)
# for i in probs:
#     print(i, '\n')

#### `get_policy_prob()` : Gets the softmax of the logits

In [None]:
policy_softmaxs = environ.get_policy_prob()
for i in policy_softmaxs:
    print(i, '\n')

#### `get_sample_policy( hard_sampling = False)` : Calls test_sample_policy of network with random choices based on softmax of logits

In [None]:
# Sample a policy without hard sampling and print it next to the logits and
# softmax probabilities, one row at a time.
policy_softmaxs = environ.get_policy_prob()
# The call result is unpacked as (policies, logits); this rebinds the notebook-level `logits`.
policies,logits = environ.get_sample_policy(hard_sampling = False)

# Outer loop walks the three sequences in parallel (presumably one entry per task -- confirm);
# the inner loop prints logits row, sampled policy row and softmax row side by side.
for l, p, s in zip(logits, policies, policy_softmaxs) :
    for  l_row, p_row, s_row in zip(l, p, s):
        print( l_row,'\t', p_row, '\t', s_row)
    print('\n')

#### `get_sample_policy( hard_sampling = True)` : Calls test_sample_policy of network using ARGMAX of logits

In [None]:
# Same table as the soft-sampling cell, but with hard_sampling=True.
policy_softmaxs = environ.get_policy_prob()
# `hard_policies` is reused by the table-printing cells below; `logits` is rebound here.
hard_policies, logits = environ.get_sample_policy(hard_sampling = True)

# Printed column order is: logits row, hard policy row, softmax row.
for p,l,s in zip(hard_policies, logits, policy_softmaxs) :
    for  p_row, l_row, s_row in zip(p, l, s):
        print( l_row,'\t', p_row, '\t', s_row)
    print('\n')

#### Print the hard-sampled policies per task

In [None]:
# Tabulate the hard-sampled policy of the three tasks, one row per layer.
# Relies on `hard_policies` produced by the hard-sampling cell above.
print(f" Layer    task 1      task 2      task 3")
print(f" -----    ------      ------      ------")
for idx, (l1, l2, l3) in enumerate(zip(hard_policies[0], hard_policies[1], hard_policies[2]),1):
    print(f"   {idx}      {l1}       {l2}       {l3}")

# Legend: printed once after the table. It was indented into the loop body,
# which repeated it after every row.
print(f"\n\n where [p1  p2]:  p1: layer is selected    p2: layer is not selected")

In [None]:
def display_trained_policy(iter):
    """Print the trained policy probabilities and the resulting selection per layer.

    Reads the per-task softmax probabilities from the notebook-level ``environ``
    via ``environ.get_policy_prob()`` and prints one table row per layer for
    exactly three tasks: the two softmax values followed by a 0/1 ``select`` flag.

    Args:
        iter: Value shown in the title line (the name shadows the builtin, but it
            is kept because it is part of the call signature).
    """
    probs = environ.get_policy_prob()
    # select = 1 when the first softmax entry is the larger one, 0 otherwise.
    selected = 1 - np.argmax(probs, axis=-1)

    print(f"  Trained polcies at iteration: {iter} ")
    print(f"                   task 1                           task 2                         task 3        ")
    print(f" Layer       softmax        select          softmax        select          softmax        select   ")
    print(f" -----    ---------------   ------       ---------------   ------       ---------------   ------   ")

    per_layer = zip(probs[0], probs[1], probs[2], selected[0], selected[1], selected[2])
    for layer_no, (sm1, sm2, sm3, sel1, sel2, sel3) in enumerate(per_layer, start=1):
        print(f"   {layer_no}      {sm1[0]:.4f}   {sm1[1]:.4f}   {sel1:4d}    {sm2[0]:11.4f}   {sm2[1]:.4f}   {sel2:4d}    {sm3[0]:11.4f}   {sm3[1]:.4f}   {sel3:4d}")

    print()
# print(f"\n\n where [p1  p2]:  p1: layer is selected    p2: layer is not selected")

In [None]:
display_trained_policy(5)

In [None]:
# Tabulate the softmax probabilities of the three tasks per layer, followed by the
# hard policy of task 3. Relies on `policy_softmaxs` and `hard_policies` from earlier cells.
# NOTE(review): the header lines do not line up with the four columns actually printed,
# and `h1` / `h2` are unpacked but never shown -- confirm whether all three hard
# policies were meant to be displayed.
print(f"                        POLICIES (SOFTMAX)                                       task 3          ")
print(f" Layer    task1              task2            task3 softmax         softmax         argmax         softmax         argmax   ")
print(f" -----    -------------     -------------     -------------   ------   ")
for idx, (l1,l2,l3, h1,h2,h3) in enumerate(zip(policy_softmaxs[0], policy_softmaxs[1], policy_softmaxs[2],hard_policies[0], hard_policies[1], hard_policies[2]),1):
    print(f"   {idx}      {l1[0]:.4f} {l1[1]:.4f}     {l2[0]:.4f} {l2[1]:.4f}     {l3[0]:.4f} {l3[1]:.4f}    {h3}")
    
print(f"\n\n where [p1  p2]:  p1: layer is selected    p2: layer is not selected")

In [None]:
# print(policy_softmaxs[2], np.argmax(1-policy_softmaxs[2], axis = -1))
print(policy_softmaxs, np.argmax(policy_softmaxs, axis = -1))

#### `get_current_logits()` : Calls test_sample_policy of network using ARGMAX of logits

In [None]:
logits  = (environ.get_current_logits())
for i in logits:
    print(i ,'\n')

#### `get_current_policy()` : Calls test_sample_policy of network using ARGMAX of logits

In [None]:
pols  = (environ.get_current_policy())

for i in pols:
    print(i ,'\n')

#### `gumbel_softmax()`  

In [None]:
np.set_printoptions(precision=8,edgeitems=3, infstr='inf', linewidth=150, nanstr='nan', floatmode = 'maxprec_equal')
torch.set_printoptions(precision=8,linewidth=132)

In [None]:
# Draw three soft and three hard Gumbel-softmax samples from the first entry of
# `logits` and print them row by row next to the logits they were drawn from.
# `F` is not imported anywhere in the visible notebook, so import it here.
import torch.nn.functional as F

# NOTE(review): the status cells read `environ.gumbel_temperature`; confirm that
# `environ.temp` still exists.
print(environ.temp)
# tau = environ.temp
tau = 1
# Loop-invariant: the tensor is the same for every draw.
logits_tensor = torch.tensor(logits[0])
for i in range(3): 
    # Sample soft categorical using reparametrization trick:
    gumbel_soft = F.gumbel_softmax(logits_tensor, tau=tau, hard=False).cpu().numpy() 

    # Sample hard categorical using "Straight-through" trick:
    gumbel_hard  = F.gumbel_softmax(logits_tensor, tau=tau, hard=True).cpu().numpy()
    
    # Pair each sample row with the logits row it came from. The original zipped
    # over `lgts`, which is only bound by the loop in the *following* cell (and
    # would then hold the last entry of `logits`, not the first).
    for l, gs, gh in zip(logits[0], gumbel_soft, gumbel_hard):
        print(f"   {l}   \t {gs}            \t {gh}")
    print()

In [None]:
# For every entry of `logits`, print the logits followed by one soft and one hard
# Gumbel-softmax sample (tau = 1).
# `F` is not imported anywhere in the visible notebook, so import it here.
import torch.nn.functional as F

for lgts in logits:
    logits_tensor = torch.tensor(lgts)
    print(lgts)
    # Sample soft categorical using reparametrization trick:
    gumbel_soft = F.gumbel_softmax(logits_tensor, tau=1, hard=False)
    print(gumbel_soft)

    # Sample hard categorical using "Straight-through" trick:
    gumbel_hard  = F.gumbel_softmax(logits_tensor, tau=1, hard=True)
    print(gumbel_hard)
    print()

In [None]:
# Row-wise softmax of `logs`, then one random draw from (1, 0) using the first row
# as the selection probabilities.
# The module name `scipy` is never bound in this notebook (only the `softmax`
# function is imported, and only further up in the Policy / Logit section), so
# `scipy.special.softmax` raised NameError. Import the function directly.
from scipy.special import softmax

# NOTE(review): `np.random.choice(..., p=smax[0])` needs `smax[0]` to be 1-D, i.e. `logs`
# must be the 2-D array assigned in the cell *below*, not the value returned by
# `environ.get_policy_logits()` -- run that cell first, or move it above this one.
smax = softmax(logs, axis =1)
# smax = np.array( 
# [[0.46973792, 0.530262  ],
#  [0.45025694, 0.549743  ],
#  [0.4443086 , 0.5556915 ],
#  [0.4138397 , 0.58616036],
#  [0.4140113 , 0.5859887 ],
#  [0.42114905, 0.57885087]])

print(smax.shape)
print(smax)
print(smax[0])
print(smax[0].sum())
print(np.random.choice((1,0), p =smax[0]))

In [None]:
# Hard-coded 6x2 array of example logits (presumably one row per layer -- confirm),
# used as input for the manual softmax experiment.
logs = np.array([
    [0.33064184, 0.42053092],
    [0.3532089, 0.52056104],
    [0.3888512, 0.5680909],
    [0.42039296, 0.694217],
    [0.4519742, 0.73311865],
    [0.48401102, 0.7522658],
])