## Initialization  

In [1]:
%load_ext autoreload
%autoreload 2
import os 
import sys
sys.path.insert(0, './src')
import time
import argparse
import yaml
import types
import copy, pprint
from time import sleep
from datetime import datetime
import numpy  as np
import torch  
import wandb
import pandas as pd
from utils.notebook_modules import initialize, init_dataloaders, init_environment, init_wandb, \
                                   training_prep, disp_dataloader_info,disp_info_1, \
                                   warmup_phase, weight_policy_training, disp_gpu_info

from utils.util import (print_separator, print_heading, timestring, print_loss, load_from_pickle) #, print_underline, 
#                       print_dbg, get_command_line_args ) 

# Notebook-wide display configuration: a shared pretty-printer plus print widths
# for numpy, torch and pandas, so wide metric tables stay on one line.
pp = pprint.PrettyPrinter(indent=4)
np.set_printoptions(linewidth=150, edgeitems=3, nanstr='nan', infstr='inf')
torch.set_printoptions(linewidth=132, precision=6)
pd.set_option("display.width", 132)
# Reference signature of torch.set_printoptions and optional debugging helpers:
# torch.set_printoptions(precision=None, threshold=None, edgeitems=None, linewidth=None, profile=None, sci_mode=None)
# sys.path.insert(0, '/home/kbardool/kusanagi/AdaSparseChem/src')
# print(sys.path)
# disp_gpu_info() 
# Notebook name that Weights & Biases reads from the environment.
os.environ["WANDB_NOTEBOOK_NAME"] = "Adashare_Training.ipynb"


## Create Environment

### Parse Input Args  - Read YAML config file

In [2]:
# RESUME_MODEL_CKPT = 'model_train_ep_25_seed_0088'

## For RESTARTING
##
# input_args = " --config yamls/chembl_3task_train.yaml " \
#              " --resume " \
#              " --exp_id      330i85cg" \
#              " --exp_name    0308_1204" \
#              " --exp_desc    Train with dropout 0.5" \
#              " --seed_idx    0 "\
#              " --batch_size  128" \
#              " --lambda_sparsity  0.01"\
#              " --lambda_sharing   0.01" 
## get command line arguments

In [3]:
## Arguments for starting a fresh (non-resumed) run.
## The adjacent string literals are concatenated into one command-line style
## string; each fragment carries its own leading blank so the options stay separated.
## NOTE(review): exp_desc advertises "6 lyrs" while --hidden_size passes a single
## width (1600); confirm the description matches the intended architecture.
input_args = (
    " --config yamls/chembl_3task_train.yaml "
    " --exp_desc    6 lyrs,dropout 0.5, weight 105 bch/ep policy 105 bch/ep "
    " --warmup_epochs       350 "
    " --hidden_size         1600 "
    " --tail_hidden_size    1600 "
    " --seed_idx             0"
    " --batch_size         128"
    " --task_lr          0.001"
    " --backbone_lr      0.001"
    " --policy_lr        0.001"
    " --lambda_sparsity   0.02"
    " --lambda_sharing    0.01"
    " --folder_sfx    noplcy"
)
# Alternative layer layouts kept for reference:
#              " --hidden_size   100 100 100 100 100 100" \
#              " --tail_hidden_size  100 " \


In [4]:
# Parse the argument string into the run configuration. `opt` is indexed like a
# nested dict by later cells (e.g. opt['train']['resume']) and `ns` carries run
# state as attributes (e.g. ns.current_epoch).
# build_folders=True presumably creates the experiment output folders -- TODO confirm.
opt, ns = initialize(input_args, build_folders = True)


  command line parms : 
------------------------
 config...................  yamls/chembl_3task_train.yaml
 exp_id...................  None
 exp_name.................  None
 folder_sfx...............  noplcy
 exp_desc.................  6 lyrs,dropout 0.5, weight 105 bch/ep policy 105 bch/ep
 hidden_sizes.............  [1600]
 tail_hidden_size.........  1600
 warmup_epochs............  350
 training_epochs..........  None
 seed_idx.................  0
 batch_size...............  128
 backbone_lr..............  0.001
 task_lr..................  0.001
 policy_lr................  0.001
 decay_lr_rate............  None
 decay_lr_freq............  None
 lambda_sparsity..........  0.02
 lambda_sharing...........  0.01
 gpu_ids..................  [0]
 resume...................  False
 cpu......................  False



##################################################
################### READ YAML ####################
##################################################


 log_dir            

### Setup Dataloader and Model  

In [5]:
# Build the data loaders from the run options.
dldrs = init_dataloaders(opt)

# Print a summary of the loaders (dataset shapes and split sizes).
disp_dataloader_info(dldrs)

# Create the training environment in training mode with policy learning switched off.
environ = init_environment(ns, opt, is_train = True, policy_learning = False, display_cfg = False)

# ********************************************************************
# **************** define optimizer and schedulers *******************
# ********************************************************************                                
# Both are configured with policy_learning=False, matching the environment above.
environ.define_optimizer(policy_learning=False)
environ.define_scheduler(policy_learning=False)


##################################################
############### CREATE DATALOADERS ###############
##################################################

 trainset.y_class                                   :  [(13331, 5), (13331, 5), (13331, 5)] 
 trainset1.y_class                                  :  [(13331, 5), (13331, 5), (13331, 5)] 
 trainset2.y_class                                  :  [(13331, 5), (13331, 5), (13331, 5)] 
 valset.y_class                                     :  [(4137, 5), (4137, 5), (4137, 5)]  
 testset.y_class                                    :  [(920, 5), (920, 5), (920, 5)]  
                                 
 size of training set 0 (warm up)                   :  13331 
 size of training set 1 (network parms)             :  13331 
 size of training set 2 (policy weights)            :  13331 
 size of validation set                             :  4137 
 size of test set                                   :  920 
                               Total           

In [6]:
# environ.optimizers['weights'].param_groups[0]

###  Weights and Biases Initialization 

In [7]:
# Start the Weights & Biases run for this experiment; the run handle is then
# read back from ns.wandb_run to report the project, run id and run name.
init_wandb(ns, opt, environment = environ)

print(f" PROJECT NAME: {ns.wandb_run.project}\n"
      f" RUN ID      : {ns.wandb_run.id} \n"
      f" RUN NAME    : {ns.wandb_run.name}") 

[34m[1mwandb[0m: Currently logged in as: [33mkbardool[0m (use `wandb login --relogin` to force relogin)


14lubvbu 0329_0458_noplcy AdaSparseChem


 PROJECT NAME: AdaSparseChem
 RUN ID      : 14lubvbu 
 RUN NAME    : 0329_0458_noplcy
 PROJECT NAME: AdaSparseChem
 RUN ID      : 14lubvbu 
 RUN NAME    : 0329_0458_noplcy


In [8]:
# ns.wandb_run.finish()

### Initiate / Resume Training Prep

In [9]:
# Resume from a saved checkpoint when opt['train']['resume'] is set; otherwise
# only announce that a fresh training run is starting.
if opt['train']['resume']:
    # NOTE(review): both names are empty placeholders and must be filled in by hand
    # before resuming; presumably the loaders below cannot locate a file otherwise -- confirm.
    RESUME_MODEL_CKPT = ""
    RESUME_METRICS_CKPT = ""    
    print(opt['train']['which_iter'])
    print(opt['paths']['checkpoint_dir'])
    print(RESUME_MODEL_CKPT)
    # opt['train']['resume'] = True
    # opt['train']['which_iter'] = 'warmup_ep_40_seed_0088'
    print_separator('Resume training')
    # Restore the model state; the call returns the iteration and epoch stored in the checkpoint.
    loaded_iter, loaded_epoch = environ.load_checkpoint(RESUME_MODEL_CKPT, path = opt['paths']['checkpoint_dir'], verbose = True)
    print(loaded_iter, loaded_epoch)    
#     current_iter = environ.load_checkpoint(opt['train']['which_iter'])
    environ.networks['mtl-net'].reset_logits()
    # `val_metrics` is bound only on this resume path; it does not exist after a fresh start.
    val_metrics = load_from_pickle(opt['paths']['checkpoint_dir'], RESUME_METRICS_CKPT)
    # training_prep(ns, opt, environ, dldrs, epoch = loaded_epoch, iter = loaded_iter )

else:
    print_separator('Initiate Training ')

##################################################
############### Initiate Training  ###############
##################################################


### Training Preparation

In [10]:
# Prepare the run state for training from a fresh start (no epoch/iter passed,
# unlike the commented-out resume call in the previous cell).
training_prep(ns, opt, environ, dldrs )

 cuda available [0]
 set print_freq to length of train loader: 105
 set eval_iters to length of val loader  : 33


In [11]:
# Print a summary of the current run configuration and training state.
disp_info_1(ns, opt, environ)


 Num_blocks                : 1                                

 batch size                : 128 
 batches/ Weight trn epoch : 105 
 batches/ Policy trn epoch : 105                                 

 Print Frequency           : -1 
 Config Val Frequency      : 500 
 Config Val Iterations     : -1 
 Val iterations            : 33 
 which_iter                : warmup 
 train_resume              : False                                 
 
 fix BN parms              : False 
 Task LR                   : 0.001 
 Backbone LR               : 0.001                                 

 Sharing  regularization   : 0.01 
 Sparsity regularization   : 0.02 
 Task     regularization   : 1                                 

 Current epoch             : 0  
 Warm-up epochs            : 350 
 Training epochs           : 250


In [12]:
# Print the experiment settings text returned by environ.disp_for_excel().
print(environ.disp_for_excel())


    folder: 1600x1_0329_0458_plr0.001_sp0.02_sh0.01_lr0.001_noplcy
    layers: 1 [1600] 
    
    middle dropout         : 0.5
    last dropout           : 0.5
    diff_sparsity_weights  : False
    skip_layer             : 0
    is_curriculum          : False
    curriculum_speed       : 3
    
    task_lr                : 0.001
    backbone_lr            : 0.001
    decay_lr_rate          : 0.75
    decay_lr_freq          : 40
    
    policy_lr              : 0.001
    policy_decay_lr_rate   : 0.75
    policy_decay_lr_freq   : 50
    lambda_sparsity        : 0.02
    lambda_sharing         : 0.01
    lambda_tasks           : 1
    
    Gumbel init_temp       : 4
    Gumbel decay_temp      : 0.965
    Gumbel decay_temp_freq : 16
    Logit init_method      : random
    Logit init_neg_logits  : None
    Logit hard_sampling    : False
    Warm-up epochs         : 350
    training epochs        : 250
    Data split ratios      : [0.725, 0.225, 0.05]



## Warmup Training

In [13]:
# Optional overrides for a shortened warm-up run (left disabled):
# environ.display_trained_policy(ns.current_epoch,out=sys.stdout)
# ns.stop_epoch_warmup = 10
# ns.warmup_epochs = 10
# ns.check_for_improvment_wait = 0
# Show how many warm-up epochs are planned and which epoch range they cover.
print(ns.warmup_epochs, ns.current_epoch)
print_heading(f" Last Epoch: {ns.current_epoch}   # of warm-up epochs to do:  {ns.warmup_epochs} - Run epochs {ns.current_epoch+1} to {ns.current_epoch + ns.warmup_epochs}", verbose = True)

350 0
--------------------------------------------------------------------------
 Last Epoch: 0   # of warm-up epochs to do:  350 - Run epochs 1 to 350
-------------------------------------------------------------------------- 



In [14]:
# Run the warm-up phase. The commented variant shows the optional `epochs`
# argument for a shorter run; without it the heading printed by the previous cell
# (ns.warmup_epochs epochs) describes the planned range.
# warmup_phase(ns,opt, environ, dldrs, epochs = 25)
warmup_phase(ns,opt, environ, dldrs)

--------------------------------------------------------------------------
 Last Epoch: 0   # of warm-up epochs to do:  350 - Run epochs 1 to 350
-------------------------------------------------------------------------- 

Epoch | BckBone LR   Heads LR  Policy LR Gumbl Temp |  trn loss     trn spar     trn shar   trn ttl |   bceloss  avg prec    aucroc     aucpr |  val loss     val spar     val shar    val ttl |  time |        
    1 |   1.00e-03   1.00e-03   1.00e-03  4.000e+00 |    9.8691   4.1555e-02   1.9403e-05    9.9107 |   0.67669   0.65897   0.65717   0.65871 |   10.1496   4.1555e-02   1.9403e-05    10.1912 |  46.8 |
Previous best_epoch:     0   best iter:     0,   best_value: 0.00000
New      best_epoch:     1   best iter:   105,   best_value: 0.65897
    2 |   1.00e-03   1.00e-03   1.00e-03  4.000e+00 |    9.5726   4.1555e-02   1.9403e-05    9.6142 |   0.65873   0.68739   0.68286   0.68721 |    9.8788   4.1555e-02   1.9403e-05     9.9204 |  48.5 |        
Previous best_epoch:

   23 |   1.00e-03   1.00e-03   1.00e-03  4.000e+00 |    5.1086   4.1555e-02   1.9403e-05    5.1502 |   0.55225   0.80759   0.80655   0.80751 |    8.2928   4.1555e-02   1.9403e-05     8.3344 |  45.2 |        
Previous best_epoch:    22   best iter:  2310,   best_value: 0.80599
New      best_epoch:    23   best iter:  2415,   best_value: 0.80759
   24 |   1.00e-03   1.00e-03   1.00e-03  4.000e+00 |    4.3138   4.1555e-02   1.9403e-05    4.3554 |   0.54752   0.80930   0.80806   0.80922 |    8.2133   4.1555e-02   1.9403e-05     8.2549 |  46.4 |        
Previous best_epoch:    23   best iter:  2415,   best_value: 0.80759
New      best_epoch:    24   best iter:  2520,   best_value: 0.80930
   25 |   1.00e-03   1.00e-03   1.00e-03  4.000e+00 |    3.8359   4.1555e-02   1.9403e-05    3.8775 |   0.54834   0.81094   0.80963   0.81085 |    8.2249   4.1555e-02   1.9403e-05     8.2665 |  45.1 |        
Previous best_epoch:    24   best iter:  2520,   best_value: 0.80930
New      best_epoch:    25  

   46 |   1.00e-03   1.00e-03   1.00e-03  4.000e+00 |    3.4770   4.1555e-02   1.9403e-05    3.5186 |   0.59043   0.82706   0.82664   0.82699 |    8.8640   4.1555e-02   1.9403e-05     8.9056 |  46.1 |        
Previous best_epoch:    45   best iter:  4725,   best_value: 0.82683
New      best_epoch:    46   best iter:  4830,   best_value: 0.82706
   47 |   1.00e-03   1.00e-03   1.00e-03  4.000e+00 |    3.3765   4.1555e-02   1.9403e-05    3.4180 |   0.59101   0.82766   0.82700   0.82759 |    8.8687   4.1555e-02   1.9403e-05     8.9103 |  46.6 |        
Previous best_epoch:    46   best iter:  4830,   best_value: 0.82706
New      best_epoch:    47   best iter:  4935,   best_value: 0.82766
   48 |   1.00e-03   1.00e-03   1.00e-03  4.000e+00 |    2.9140   4.1555e-02   1.9403e-05    2.9555 |   0.59881   0.82846   0.82781   0.82838 |    8.9854   4.1555e-02   1.9403e-05     9.0270 |  47.0 |        
Previous best_epoch:    47   best iter:  4935,   best_value: 0.82766
New      best_epoch:    48  

   77 |   7.50e-04   7.50e-04   1.00e-03  4.000e+00 |    1.3626   4.1555e-02   1.9403e-05    1.4042 |   0.72813   0.82834   0.82776   0.82826 |   10.8989   4.1555e-02   1.9403e-05    10.9405 |  46.4 |        
Epoch    77: reducing learning rate of group 0 to 7.5000e-04.
   78 |   7.50e-04   7.50e-04   7.50e-04  4.000e+00 |    1.3274   4.1555e-02   1.9403e-05    1.3690 |   0.72873   0.82876   0.82815   0.82868 |   10.8982   4.1555e-02   1.9403e-05    10.9397 |  47.3 |        
   79 |   7.50e-04   7.50e-04   7.50e-04  4.000e+00 |    1.6056   4.1555e-02   1.9403e-05    1.6472 |   0.73286   0.82880   0.82793   0.82872 |   10.9903   4.1555e-02   1.9403e-05    11.0319 |  46.9 |        
   80 |   7.50e-04   7.50e-04   7.50e-04  4.000e+00 |    1.5573   4.1555e-02   1.9403e-05    1.5988 |   0.73942   0.82859   0.82780   0.82852 |   11.0750   4.1555e-02   1.9403e-05    11.1166 |  48.0 |        
   81 |   7.50e-04   7.50e-04   7.50e-04  4.000e+00 |    1.4888   4.1555e-02   1.9403e-05    1.5304 | 

  114 |   5.63e-04   5.63e-04   7.50e-04  4.000e+00 |    0.8418   4.1555e-02   1.9403e-05    0.8834 |   0.87335   0.82664   0.82580   0.82656 |   13.0944   4.1555e-02   1.9403e-05    13.1360 |  48.3 |        
  115 |   5.63e-04   5.63e-04   7.50e-04  4.000e+00 |    0.7038   4.1555e-02   1.9403e-05    0.7454 |   0.87209   0.82672   0.82578   0.82664 |   13.0370   4.1555e-02   1.9403e-05    13.0786 |  46.1 |        
  116 |   5.63e-04   5.63e-04   7.50e-04  4.000e+00 |    0.8204   4.1555e-02   1.9403e-05    0.8620 |   0.87561   0.82681   0.82579   0.82674 |   13.1968   4.1555e-02   1.9403e-05    13.2384 |  47.2 |        
  117 |   5.63e-04   5.63e-04   7.50e-04  4.000e+00 |    0.8884   4.1555e-02   1.9403e-05    0.9299 |   0.87462   0.82676   0.82602   0.82668 |   13.1670   4.1555e-02   1.9403e-05    13.2086 |  48.6 |        
  118 |   5.63e-04   5.63e-04   7.50e-04  4.000e+00 |    1.1316   4.1555e-02   1.9403e-05    1.1732 |   0.88240   0.82663   0.82567   0.82656 |   13.2455   4.1555e-

Epoch | BckBone LR   Heads LR  Policy LR Gumbl Temp |  trn loss     trn spar     trn shar   trn ttl |   bceloss  avg prec    aucroc     aucpr |  val loss     val spar     val shar    val ttl |  time |        
  151 |   5.63e-04   5.63e-04   5.63e-04  4.000e+00 |    0.5912   4.1555e-02   1.9403e-05    0.6327 |   0.96657   0.82517   0.82430   0.82509 |   14.4999   4.1555e-02   1.9403e-05    14.5415 |  50.4 |
  152 |   5.63e-04   5.63e-04   5.63e-04  4.000e+00 |    0.7657   4.1555e-02   1.9403e-05    0.8073 |   0.96670   0.82533   0.82448   0.82526 |   14.4791   4.1555e-02   1.9403e-05    14.5206 |  48.0 |        
  153 |   5.63e-04   5.63e-04   5.63e-04  4.000e+00 |    0.4771   4.1555e-02   1.9403e-05    0.5187 |   0.96987   0.82525   0.82431   0.82517 |   14.5784   4.1555e-02   1.9403e-05    14.6199 |  46.7 |        
  154 |   5.63e-04   5.63e-04   5.63e-04  4.000e+00 |    0.5815   4.1555e-02   1.9403e-05    0.6231 |   0.96966   0.82496   0.82423   0.82488 |   14.5166   4.1555e-02   1.9

  187 |   4.22e-04   4.22e-04   5.63e-04  4.000e+00 |    0.5399   4.1555e-02   1.9403e-05    0.5815 |   1.03127   0.82421   0.82337   0.82413 |   15.4300   4.1555e-02   1.9403e-05    15.4716 |  46.5 |        
  188 |   4.22e-04   4.22e-04   5.63e-04  4.000e+00 |    0.3201   4.1555e-02   1.9403e-05    0.3616 |   1.03139   0.82398   0.82322   0.82390 |   15.4474   4.1555e-02   1.9403e-05    15.4890 |  47.4 |        
  189 |   4.22e-04   4.22e-04   5.63e-04  4.000e+00 |    0.5813   4.1555e-02   1.9403e-05    0.6228 |   1.03358   0.82403   0.82323   0.82395 |   15.5147   4.1555e-02   1.9403e-05    15.5562 |  46.0 |        
Epoch   189: reducing learning rate of group 0 to 4.2188e-04.
  190 |   4.22e-04   4.22e-04   4.22e-04  4.000e+00 |    0.4899   4.1555e-02   1.9403e-05    0.5315 |   1.03417   0.82395   0.82316   0.82388 |   15.5902   4.1555e-02   1.9403e-05    15.6318 |  47.3 |        
  191 |   4.22e-04   4.22e-04   4.22e-04  4.000e+00 |    0.5708   4.1555e-02   1.9403e-05    0.6124 | 

  224 |   3.16e-04   3.16e-04   4.22e-04  4.000e+00 |    0.3619   4.1555e-02   1.9403e-05    0.4034 |   1.08369   0.82344   0.82263   0.82336 |   16.2951   4.1555e-02   1.9403e-05    16.3367 |  47.6 |        
  225 |   3.16e-04   3.16e-04   4.22e-04  4.000e+00 |    0.3893   4.1555e-02   1.9403e-05    0.4309 |   1.08637   0.82333   0.82255   0.82325 |   16.2770   4.1555e-02   1.9403e-05    16.3186 |  46.9 |        
Epoch | BckBone LR   Heads LR  Policy LR Gumbl Temp |  trn loss     trn spar     trn shar   trn ttl |   bceloss  avg prec    aucroc     aucpr |  val loss     val spar     val shar    val ttl |  time |        
  226 |   3.16e-04   3.16e-04   4.22e-04  4.000e+00 |    0.3523   4.1555e-02   1.9403e-05    0.3939 |   1.08169   0.82351   0.82275   0.82343 |   16.2517   4.1555e-02   1.9403e-05    16.2933 |  46.9 |
  227 |   3.16e-04   3.16e-04   4.22e-04  4.000e+00 |    0.4420   4.1555e-02   1.9403e-05    0.4836 |   1.08799   0.82351   0.82273   0.82343 |   16.3089   4.1555e-02   1.9

  260 |   2.37e-04   2.37e-04   3.16e-04  4.000e+00 |    0.2311   4.1555e-02   1.9403e-05    0.2727 |   1.11860   0.82283   0.82205   0.82275 |   16.8015   4.1555e-02   1.9403e-05    16.8431 |  48.4 |        
  261 |   2.37e-04   2.37e-04   3.16e-04  4.000e+00 |    0.3546   4.1555e-02   1.9403e-05    0.3962 |   1.12035   0.82276   0.82205   0.82268 |   16.8001   4.1555e-02   1.9403e-05    16.8417 |  45.8 |        
  262 |   2.37e-04   2.37e-04   3.16e-04  4.000e+00 |    0.4139   4.1555e-02   1.9403e-05    0.4555 |   1.12096   0.82283   0.82206   0.82275 |   16.7874   4.1555e-02   1.9403e-05    16.8290 |  47.8 |        
  263 |   2.37e-04   2.37e-04   3.16e-04  4.000e+00 |    0.4977   4.1555e-02   1.9403e-05    0.5392 |   1.12094   0.82286   0.82209   0.82278 |   16.8410   4.1555e-02   1.9403e-05    16.8826 |  46.5 |        
  264 |   2.37e-04   2.37e-04   3.16e-04  4.000e+00 |    0.6473   4.1555e-02   1.9403e-05    0.6888 |   1.12199   0.82285   0.82211   0.82277 |   16.7899   4.1555e-

  298 |   1.78e-04   1.78e-04   3.16e-04  4.000e+00 |    0.3717   4.1555e-02   1.9403e-05    0.4132 |   1.14520   0.82255   0.82182   0.82247 |   17.1644   4.1555e-02   1.9403e-05    17.2060 |  48.1 |        
  299 |   1.78e-04   1.78e-04   3.16e-04  4.000e+00 |    0.4038   4.1555e-02   1.9403e-05    0.4454 |   1.14920   0.82259   0.82182   0.82251 |   17.2732   4.1555e-02   1.9403e-05    17.3148 |  46.5 |        
  300 |   1.78e-04   1.78e-04   3.16e-04  4.000e+00 |    0.2379   4.1555e-02   1.9403e-05    0.2795 |   1.14675   0.82267   0.82191   0.82259 |   17.1464   4.1555e-02   1.9403e-05    17.1880 |  47.6 |        
Epoch | BckBone LR   Heads LR  Policy LR Gumbl Temp |  trn loss     trn spar     trn shar   trn ttl |   bceloss  avg prec    aucroc     aucpr |  val loss     val spar     val shar    val ttl |  time |        
  301 |   1.78e-04   1.78e-04   3.16e-04  4.000e+00 |    0.2998   4.1555e-02   1.9403e-05    0.3414 |   1.14931   0.82257   0.82176   0.82249 |   17.2307   4.1555e-

  334 |   1.78e-04   1.78e-04   2.37e-04  4.000e+00 |    0.2831   4.1555e-02   1.9403e-05    0.3247 |   1.16434   0.82261   0.82176   0.82253 |   17.4673   4.1555e-02   1.9403e-05    17.5089 |  47.3 |        
  335 |   1.78e-04   1.78e-04   2.37e-04  4.000e+00 |    0.3729   4.1555e-02   1.9403e-05    0.4145 |   1.16682   0.82247   0.82163   0.82239 |   17.4884   4.1555e-02   1.9403e-05    17.5299 |  45.8 |        
  336 |   1.78e-04   1.78e-04   2.37e-04  4.000e+00 |    0.2166   4.1555e-02   1.9403e-05    0.2582 |   1.16628   0.82244   0.82157   0.82236 |   17.5791   4.1555e-02   1.9403e-05    17.6207 |  48.3 |        
  337 |   1.78e-04   1.78e-04   2.37e-04  4.000e+00 |    0.3769   4.1555e-02   1.9403e-05    0.4185 |   1.16719   0.82247   0.82163   0.82239 |   17.4964   4.1555e-02   1.9403e-05    17.5380 |  48.0 |        
  338 |   1.78e-04   1.78e-04   2.37e-04  4.000e+00 |    0.3440   4.1555e-02   1.9403e-05    0.3855 |   1.16978   0.82248   0.82161   0.82240 |   17.5224   4.1555e-

In [15]:
# warmup_phase(ns,opt, environ, dldrs, epochs = 25)
# Display the best epoch, iteration and metric value tracked in the run state.
ns.best_epoch, ns.best_iter, ns.best_value

(58, 6090, 0.8293974608888801)

In [16]:
# ns.wandb_run.finish()

# Close the Weights & Biases run that was opened for this experiment.
# NOTE(review): the Weight & Policy Training cells further down execute after this
# call; confirm whether a new W&B run has to be initialised before running them.
ns.wandb_run.finish()

# environ.losses

# environ.val_metrics




VBox(children=(Label(value='3.867 MB of 3.867 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

0,1
epoch,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
train_time,▄▃▂▄▁▃▁▄▅▅▅▆▅▄▆▄▄▅▄▃▃▄▃▅▆▅▆▇▅▃▅▃▅▅▅▄▅▄▆█

0,1
epoch,350.0
train_time,49.3322


#### Display parameters

In [None]:
# Report each configured (initial) learning rate beside the rate currently held by
# the optimizer that owns those parameters.
#   - environ.optimizers['alphas']  : policy optimizer (the optimizer-inspection cell
#     later in this notebook labels it "Policy").
#   - environ.optimizers['weights'] : group 0 = task heads; group 1 is assumed to be the
#     backbone -- NOTE(review): confirm the group order used by environ.define_optimizer.
# Previously the backbone label read the 'alphas' optimizer and the policy label read
# weights group 1, so the "current LR" values were reported under the wrong names.
# `.6f` spells out the precision that the former `4f` spec produced by default.
print( f" Backbone Initial LR         : {environ.opt['train']['backbone_lr']:.6f}      current LR : {environ.optimizers['weights'].param_groups[1]['lr']} \n"
       f" Tasks    Initial LR         : {environ.opt['train']['task_lr']:.6f}      current LR : {environ.optimizers['weights'].param_groups[0]['lr']}    \n"
       f" Policy   Initial LR         : {environ.opt['train']['policy_lr']:.6f}      current LR : {environ.optimizers['alphas'].param_groups[0]['lr']}  \n")
# Regularisation weights and the current Gumbel temperature settings.
print( f" Sparsity regularization     : {environ.opt['train']['lambda_sparsity']}\n"
       f" Sharing  regularization     : {environ.opt['train']['lambda_sharing']} \n\n"
       f" Tasks    regularization     : {environ.opt['train']['lambda_tasks']}   \n"
       f" Gumbel Temp                 : {environ.gumbel_temperature:.4f}         \n"
       f" Gumbel Temp decay           : {environ.opt['train']['decay_temp_freq']}")

In [None]:
# environ.opt['train']['policy_lr'] = 0.01
# opt['train']['policy_lr']         = 0.01
# environ.opt['train']['lambda_sparsity'] = 0.1
# environ.opt['train']['lambda_sharing']  = 0.01
# environ.opt['train']['lambda_tasks']    = 1.0
# environ.opt['train']['decay_temp_freq'] = 2
# print(environ.optimizers['alphas'].param_groups)
# print(environ.optimizers['weights'].param_groups)
# print('current lr: ', environ.optimizers['alphas'].param_groups[0]['lr'],)
# print('current lr: ', environ.optimizers['weights'].param_groups[0]['lr'])
# print('current lr: ', environ.optimizers['weights'].param_groups[1]['lr'])

## Weight & Policy Training

### Weight/Policy Training Preparation

In [None]:
# ns.flag_warmup = True
# num_train_layers = None 
# environ.opt['is_curriculum'] = True
# environ.opt['curriculum_speed'] = 4
# ns.num_train_layers = None

In [None]:
# One-time switch from warm-up to alternating weight/policy training: runs only
# while ns.flag_warmup is still set and clears it afterwards.
if ns.flag_warmup:
    # NOTE(review): the heading announces an optimizer/scheduler switch to
    # policy_learning = True, but the two define_* calls below are commented out,
    # so the optimizers defined during setup stay in use -- confirm this is intended.
    print_heading( f"** {timestring()} \n"
                   f"** Training epoch: {ns.current_epoch} iter: {ns.current_iter}   flag: {ns.flag} \n"
                   f"** Set optimizer and scheduler to policy_learning = True (Switch weight optimizer from ADAM to SGD)\n"
                   f"** Switch from Warm Up training to Alternate training Weights & Policy \n"
                   f"** Take checkpoint and block gradient flow through Policy net", verbose=True)
#     environ.define_optimizer(policy_learning=True)
#     environ.define_scheduler(policy_learning=True)
    ns.flag_warmup = False
    # Start the alternation with a weight-update step.
    ns.flag = 'update_weights'
    environ.fix_alpha()
    environ.free_weights(opt['fix_BN'])

In [None]:
# ns.training_epochs = 250
# environ.display_trained_policy(ns.current_epoch)
# environ.display_trained_logits(ns.current_epoch)

In [None]:
# Status report before weight/policy training: run-state counters, the latest
# validation metrics, and a heading with the planned epoch range and key hyper-parameters.
print(f"ns.current_epoch           : {ns.current_epoch}")
print(f"ns.training_epochs         : {ns.training_epochs} \n") 
print(f"ns.current_iters           : {ns.current_iter}")  
print(f"Batches in weight epoch    : {ns.stop_iter_w}")
print(f"Batches in policy epoch    : {ns.stop_iter_a}")
print(f"num_train_layers           : {ns.num_train_layers}")
print()
print_loss(environ.val_metrics, title = f"[e] Last ep:{ns.current_epoch}  it:{ns.current_iter}")
print()

# The epoch range shown is current_epoch+1 .. current_epoch+training_epochs.
print_heading(f" Last Epoch Completed : {ns.current_epoch}       # of epochs to run:  {ns.training_epochs} -->  epochs {ns.current_epoch+1} to {ns.training_epochs + ns.current_epoch}"
              f"\n policy_learning rate : {environ.opt['train']['policy_lr']} "
              f"\n lambda_sparsity      : {environ.opt['train']['lambda_sparsity']}"
              f"\n lambda_sharing       : {environ.opt['train']['lambda_sharing']}"
              f"\n curriculum training  : {opt['is_curriculum']}     cirriculum speed: {opt['curriculum_speed']}     num_training_layers : {ns.num_train_layers}", 
              verbose = True)

### Weight/Policy Training

In [None]:
# Run the weight/policy training phase. The commented variant shows the optional
# `epochs` argument for limiting the number of epochs.
# weight_policy_training(ns, opt, environ, dldrs, epochs = 100)
weight_policy_training(ns, opt, environ, dldrs)

In [None]:
# Display the best epoch, iteration and metric value tracked in the run state.
ns.best_epoch, ns.best_iter, ns.best_value

### Close WandB run

In [None]:
# Close the active Weights & Biases run via the module-level API.
wandb.finish()

In [None]:
# ns.best_epoch = 0
# from utils.notebook_modules import wrapup_phase
# wrapup_phase(ns, opt, environ)

In [None]:
# environ.opt['train']['policy_lr']       = 0.002
# environ.opt['train']['lambda_sparsity'] = 0.05
# environ.opt['train']['lambda_sharing']  = 0.01
# environ.opt['train']['lambda_tasks']    = 1.0
# # environ.opt['train']['decay_temp_freq'] = 2

In [None]:
# Configured learning rates, read from the environment's options.
print( f" Backbone Learning Rate      : {environ.opt['train']['backbone_lr']}\n"
       f" Tasks    Learning Rate      : {environ.opt['train']['task_lr']}\n"
       f" Policy   Learning Rate      : {environ.opt['train']['policy_lr']}\n")

# Regularisation weights and the current Gumbel temperature settings.
print( f" Sparsity regularization     : {environ.opt['train']['lambda_sparsity']}\n"
       f" Sharing  regularization     : {environ.opt['train']['lambda_sharing']} \n\n"
       f" Tasks    regularization     : {environ.opt['train']['lambda_tasks']}   \n"
       f" Gumbel Temp                 : {environ.gumbel_temperature:.4f}         \n" 
       f" Gumbel Temp decay           : {environ.opt['train']['decay_temp_freq']}\n") 

# Progress counters kept in the run-state namespace.
print( f" current_iters               : {ns.current_iter}   \n"
       f" current_epochs              : {ns.current_epoch}  \n" 
       f" train_total_epochs          : {ns.training_epochs}\n" 
       f" stop_epoch_training         : {ns.stop_epoch_training}")

## Post Training Stuff

In [None]:
# pp.pprint(environ.losses)
# pp.pprint(environ.val_metrics)
# Show the layer count held by the environment next to the one held by the 'mtl-net' network.
environ.num_layers, environ.networks['mtl-net'].num_layers

In [None]:
# pp.pprint(environ.val_metrics['total'])

In [None]:
# print_loss(environ.val_metrics, title = f"[Final] ep:{current_epoch}  it:{current_iter}",)
# environ.display_trained_policy(current_epoch)
# environ.display_trained_logits(current_epoch)
# environ.log_file.flush()

In [None]:
# model_label   = 'model_train_ep_%d_seed_%04d' % (current_epoch, opt['random_seed'])
# metrics_label = 'metrics_train_ep_%d_seed_%04d.pickle' % (current_epoch, opt['random_seed'])
# environ.save_checkpoint(model_label, current_iter, current_epoch) 
# save_to_pickle(environ.val_metrics, environ.opt['paths']['checkpoint_dir'], metrics_label)
# print_loss(environ.val_metrics, title = f"[Final] ep:{current_epoch}  it:{current_iter}",)
# environ.display_trained_policy(current_epoch,out=[sys.stdout, environ.log_file])
# environ.display_trained_logits(current_epoch)
# environ.log_file.flush()

In [None]:
# print_loss(current_iter, environ.losses, title = f"[e] Policy training epoch:{current_epoch}    iter:")
# print()
# print_loss(current_iter, trn_losses, title = f"[e] Policy training epoch:{current_epoch}    iter:")
# print()
# print_loss(current_iter, environ.val_metrics, title = f"[e] Policy training epoch:{current_epoch}    iter:")

In [None]:
# environ.losses
# environ.val_metrics

In [None]:
# environ.batch_data
# print_metrics_cr(current_epoch, time.time() - start_time, trn_losses, environ.val_metrics, 0, out=[sys.stdout])
# environ.display_parameters()

# with np.printoptions(edgeitems=3, infstr='inf', linewidth=150, nanstr='nan', precision=7, formatter={'float': lambda x: f"{x:12.5e}"}):
#     environ.print_logit_grads('gradients')

# environ_params = environ.get_task_specific_parameters()
# environ_params = environ.get_arch_parameters()
# environ_params = environ.get_backbone_parameters()
# print(environ_params)
# for param in environ_params:
#     print(param.grad.shape, '\n', param.grad)
#     print(param)

In [None]:
# Show the trained logits and the trained policy for the current epoch.
environ.display_trained_logits(ns.current_epoch)
environ.display_trained_policy(ns.current_epoch)

In [None]:
# Show the test-time and train-time sampled policies, both with hard sampling enabled.
environ.display_test_sample_policy(ns.current_epoch, hard_sampling = True)
environ.display_train_sample_policy(ns.current_epoch, hard_sampling = True)

In [None]:
# environ.define_optimizer(policy_learning=True)

In [None]:
# Dump the full configuration of both optimizers ('alphas' and 'weights').
print(environ.optimizers['alphas'])
print(environ.optimizers['weights'])

In [None]:
# Initial vs. current learning rate per optimizer parameter group.
# The two 'Weights' lines cover param groups 0 and 1 of the 'weights' optimizer;
# NOTE(review): which group holds task heads vs. backbone is not visible here -- confirm.
print('Policy  initial_lr : ', environ.optimizers['alphas'].param_groups[0]['initial_lr'], 'lr : ',environ.optimizers['alphas'].param_groups[0]['lr'])
print('Weights initial_lr : ', environ.optimizers['weights'].param_groups[0]['initial_lr'], 'lr : ',environ.optimizers['weights'].param_groups[0]['lr'])
print('Weights initial_lr : ', environ.optimizers['weights'].param_groups[1]['initial_lr'], 'lr : ',environ.optimizers['weights'].param_groups[1]['lr'])

In [None]:
# Evaluates to True when the wandb module holds no active run.
wandb.run is None

In [None]:
# opt['exp_instance'] = '0218_1358'     
# folder_name=  f"{opt['exp_instance']}_bs{opt['train']['batch_size']:03d}_{opt['train']['decay_lr_rate']:3.2f}_{opt['train']['decay_lr_freq']}"
# print()
# opt['exp_instance'] = datetime.now().strftime("%m%d_%H%M")
# opt['exp_description'] = f"No Alternating Weight/Policy - training all done with both weights and policy"
# folder_name=  f"{opt['exp_instance']}_bs{opt['train']['batch_size']:03d}_{opt['train']['decay_lr_rate']:3.2f}_{opt['train']['decay_lr_freq']}"

In [None]:
# wandb.finish()

In [None]:
# Fetch the environment's current state; the meaning of the argument 0 is not
# visible in this notebook -- TODO confirm against environ.get_current_state.
p = environ.get_current_state(0)

In [None]:
# Pretty-print whatever `p` currently holds (set by the preceding cell).
pp.pprint(p)

### Post Warm-up Training stuff

In [None]:
# Pretty-print the validation metrics held by the environment.
pp.pprint(environ.val_metrics)

In [None]:
# Display the architecture parameters exposed by the 'mtl-net' network.
environ.networks['mtl-net'].arch_parameters()

In [None]:
# Print three views of the policy in turn: a soft-sampled policy, the policy
# probabilities, and the raw policy logits. `p` is reused for each result.
p = environ.get_sample_policy(hard_sampling = False)
print(p)
p = environ.get_policy_prob()
print(p)
p = environ.get_policy_logits()
print(p)

# p = environ.get_current_policy()
# print(p)

In [None]:
# Minimal demonstration of sampling a binary decision from softmax probabilities.
# `softmax` is not imported anywhere in this notebook, so the original call raised
# NameError on a fresh kernel; the probabilities are computed with numpy instead.
a = np.exp([0.0, 1])
a = a / a.sum()          # normalise so the two probabilities sum to 1
print(a)
# Draw 1 with probability a[0] and 0 with probability a[1].
sampled = np.random.choice((1, 0), p=a)
print(sampled)

In [None]:
# Show the 'weights' optimizer configuration and the last learning rate(s)
# reported by its scheduler.
print(environ.optimizers['weights'])
print(environ.schedulers['weights'].get_last_lr())

In [None]:
# Inspect the structure of environ.losses: top-level keys, the keys under
# 'task1', then the full contents.
print('losses.keys      : ', environ.losses.keys())
print('losses[task]keys : ', environ.losses['task1'].keys())
pp.pprint(environ.losses)

In [None]:
# Inspect environ.val_metrics: its keys, the types stored under 'aggregated' and
# under task1's 'classification_agg', then the full contents.
print( environ.val_metrics.keys())
# pp.pprint(val_metrics)
print(type(environ.val_metrics['aggregated']))
print()
print(type(environ.val_metrics['task1']['classification_agg']))
print()
pp.pprint(environ.val_metrics)

In [None]:
# import pickle
# with open("val_metrics.pkl", mode= 'wb') as f:
#         pickle.dump(val_metrics, f)
    
# with open('val_metrics.pkl', 'rb') as f:    
#     tst_val_metrics = pickle.load(f)

In [None]:
# print(environ.input.shape) 
# a = getattr(environ, 'task1_pred')
# yc_data = environ.batch['task1_data']
# print(yc_data.shape)
# yc_ind = environ.batch['task1_ind']
# print(yc_ind.shape)
# yc_hat_all = getattr(environ, 'task1_pred')
# print(yc_hat_all.shape)
# yc_hat  = yc_hat_all[yc_ind[0], yc_ind[1]]
# print(yc_hat_all.shape, yc_hat.shape)

# 
# environ.losses
# loss = {}
# for key in environ.losses.keys():
#     loss[key] = {}
#     for subkey, v in environ.losses[key].items():
#         print(f" key:  {key}   subkey: {subkey} ")
#         if isinstance(v, torch.Tensor):
#             loss[key][subkey] = v.data
#             print(f" Tensor  -  key:  {key}   subkey: {subkey}           value type: {type(v)}  value: {v:.4f}")
#         else:
#             loss[key][subkey] = v
#             print(f" integer -  key:  {key}   subkey: {subkey}           value type: {type(v)}  value: {v:.4f}")
# pp.pprint(tst_val_metrics)             

In [None]:
# print('metrics.keys: ', environ.metrics.keys())
# print('metrics[task].keys: ', environ.metrics['task1'].keys())
# pp.pprint(environ.metrics['task1'])
# pp.pprint(environ.losses['task1']['total'])

In [None]:
# title='Iteration'
# for t_id, _ in enumerate(environ.tasks):
#     task_key = f"task{t_id+1}"
# #     print_heading(f"{title}  {current_iter}  {task_key} : {val_metrics[task_key]['classification_agg']}", verbose = True)

#     for key, _  in val_metrics[task_key]['classification_agg'].items():
#         print('%s/%-20s'%(task_key, key), val_metrics[task_key]['classification_agg'][key], current_iter)
#         print(f"{task_key:s}/{key:20s}", val_metrics[task_key]['classification_agg'][key], current_iter)
#         print()
#             # print_current_errors(os.path.join(self.log_dir, 'loss.txt'), current_iter,key, loss[key], time.time() - start_time)

In [None]:
# environ.print_loss(current_iter, start_time, metrics = val_metrics['loss'], verbose=True)
# print(opt['lambdas'])
# p = (opt['lambdas'][0] * environ.losses['tasks']['task1'])
# print(p)

# environ.print_val_metrics(current_iter, start_time, val_metrics , title='validation', verbose=True)    

In [None]:
# print(current_iter)
# print_metrics_cr(current_iter, t1 - t0, None, val_metrics , True)
# environ.print_val_metrics(current_iter, start_time, val_metrics, title='validation', verbose = True)

In [None]:
# Summarise a `val_metrics` dict: first the key layout, then the per-task aggregated
# classification loss, the total loss, the training time and the epoch.
# NOTE(review): this cell reads a bare `val_metrics`, while neighbouring cells read
# `environ.val_metrics` — confirm `val_metrics` is bound in the kernel before running.
print(f" val_metric keys               : {val_metrics.keys()}")
print(f" loss keys                     : {val_metrics['loss'].keys()}")
print(f" task1 keys                    : {val_metrics['task1'].keys()}")
print(f" task1 classification keys     : {val_metrics['task1']['classification'].keys()}")
print(f" task1 classification_agg keys : {val_metrics['task1']['classification_agg'].keys()}")
print()
# `.5f` / `.2f` set the number of decimals. The previous `:5f` / `:2f` only set a
# minimum field width, which never takes effect for a float rendered with 6 decimals.
print(f" task1                       : {val_metrics['task1']['classification_agg']['loss']:.5f}")
print(f" task2                       : {val_metrics['task2']['classification_agg']['loss']:.5f}")
print(f" task3                       : {val_metrics['task3']['classification_agg']['loss']:.5f}")
print(f" loss                        : {val_metrics['loss']['total']:.5f}")
print(f" train_time                  : {val_metrics['train_time']:.2f}")
print(f" epoch                       : {val_metrics['epoch']}")


### Post Weight + Policy Training Stuff 

In [None]:
# Display the `layer_config` attribute of the 'mtl-net' backbone.
environ.networks['mtl-net'].backbone.layer_config

In [None]:
# Toy tensors for experimenting with the sparsity loss: an all-ones and an
# all-zeros integer target with one entry per block.
num_blocks = 6
num_policy_layers = 6
gt = torch.ones(num_blocks, dtype=torch.long)
gt0 = torch.zeros(num_blocks, dtype=torch.long)
print(gt)
print(gt0)

# Linearly increasing per-layer weights: 1/n, 2/n, ..., n/n.
loss_weights = torch.arange(1, num_policy_layers + 1).float() / num_policy_layers
print(loss_weights)

In [None]:
# Scratch re-computation of the sparsity loss outside the training loop.
# Needs `task_key`, `logits`, `gt`, `gt0`, `num_blocks`, `num_policy_layers` and `print_dbg`
# to be bound in the kernel already.
# NOTE(review): `print_dbg` is commented out of the `utils.util` import at the top of the
# notebook — confirm it is imported somewhere before running this cell.
if environ.opt['diff_sparsity_weights'] and not environ.opt['is_sharing']:
    print(' cond 1')
    ## Assign higher weights to higher layers 
    loss_weights = ((torch.arange(0, num_policy_layers, 1) + 1).float() / num_policy_layers)
    # The stray ")" that used to trail this message has been removed.
    print(f"{task_key} sparsity error:  {2 * (loss_weights[-num_blocks:] * environ.cross_entropy2(logits[-num_blocks:], gt)).mean()}")
    print_dbg(f" loss_weights :  {loss_weights}", verbose = True)
    print_dbg(f" cross_entropy:  {environ.cross_entropy2(logits[-num_blocks:], gt)}  ", verbose = True)
    # There is no `self` at notebook scope; the losses dict lives on `environ`.
    print_dbg(f" loss[sparsity][{task_key}]: {environ.losses['sparsity'][task_key] } ", verbose = True)

else:
    # Cross-entropy of the last `num_blocks` logit rows against the all-ones target.
    print('\n cond 2')
    print_dbg(f"Compute CrossEntropyLoss between \n Logits   : \n{logits[-num_blocks:]} \n and gt: \n{gt} \n", verbose = True)
    print(f"{task_key} sparsity error:  {environ.cross_entropy_sparsity(logits[-num_blocks:], gt)}")
    
    # Last logit row only: once against target 1 (gt) and once against target 0 (gt0).
    print('\n cond 2')
    print_dbg(f"Compute CrossEntropyLoss between Logits      : {logits[-1:]}  and gt: {gt[-1]} ", verbose = True)
    print(f"{task_key} sparsity error:  {environ.cross_entropy_sparsity(logits[-1:], gt[-1:])} \n")
    print_dbg(f"Compute CrossEntropyLoss between Logits      : {logits[-1:]}  and gt: {gt0[-1]} ", verbose = True)
    print(f"{task_key} sparsity error:  {environ.cross_entropy_sparsity(logits[-1:], gt0[-1:])} \n")
    
    # First logit row only: once against target 1 (gt) and once against target 0 (gt0).
    print('\n cond 3')    
    print_dbg(f"Compute CrossEntropyLoss between Logits   : {logits[0:1]}  and gt: {gt[0:1]} ", verbose = True)
    print(f"{task_key} sparsity error:  {environ.cross_entropy_sparsity(logits[0:1], gt[0:1])} \n")
    print_dbg(f"Compute CrossEntropyLoss between Logits   : {logits[0:1]}  and gt: {gt0[0:1]} ", verbose = True)
    print(f"{task_key} sparsity error:  {environ.cross_entropy_sparsity(logits[0:1], gt0[0:1])} \n")
        
        

In [None]:
# flag = 'update_w'
# environ.fix_alpha
# environ.free_w(opt['fix_BN'])

# Mark the current phase as 'update_alpha' and call `fix_weights()` then `free_alpha()` on `environ`.
# NOTE(review): a commented-out cell further down calls `environ.fix_w()` instead of
# `environ.fix_weights()` — confirm which method name the environment class defines.
flag = 'update_alpha'
environ.fix_weights()
environ.free_alpha()

In [None]:
# Display the `num_layers` attribute of the 'mtl-net' network.
environ.networks['mtl-net'].num_layers

In [None]:
# Show the iteration/epoch counters, raise `train_total_epochs` by 5, then show them again.
# NOTE: not idempotent — every re-run of this cell adds another 5 epochs.
print(f"current_iters         : {current_iter}")  
print(f"current_epochs           : {current_epoch}") 
print(f"train_total_epochs    : {train_total_epochs}") 

train_total_epochs += 5

print(f"current_iters         : {current_iter}")  
print(f"current_epochs           : {current_epoch}") 
print(f"train_total_epochs    : {train_total_epochs}") 

In [None]:
# print_metrics_cr(current_epoch, time.time() - t0, None, environ.val_metrics , num_prints)      

# num_prints += 1
# t0 = time.time()

# # Take check point
# environ.save_checkpoint('latest', current_iter)
# environ.train()
# #-------------------------------------------------------
# # END validation process
# #-------------------------------------------------------       
# flag = 'update_alpha'
# environ.fix_w()
# environ.free_alpha()

In [None]:
# dilation = 2
# kernel_size = np.asarray((3, 3))
# upsampled_kernel_size = (kernel_size - 1) * (dilation - 1) + kernel_size
# print(upsampled_kernel_size)

In [None]:
# environ.optimizers['weights'].param_groups[0]
# for param_group in optimizer.param_groups:
#     return param_group['lr']

In [None]:
# Display what the 'weights' scheduler's `get_last_lr()` returns.
environ.schedulers['weights'].get_last_lr()

In [None]:
# Print every optimizer held by `environ` and collect each one's `state_dict()` into
# `current_state`, keyed by the optimizer's name, then pretty-print the collection.
current_state = {}
for k, v in environ.optimizers.items():
    print(f'state dict for {k} = {v}')
    current_state[k] = v.state_dict()
pp.pprint(current_state)

In [None]:
# Print every scheduler held by `environ` followed by its `state_dict()`.
# NOTE(review): `current_state` is reset to an empty dict here but never filled in this
# cell, so anything previously stored under that name is lost — confirm this is intended.
current_state = {}
for k, v in environ.schedulers.items():
    print(f'state dict for {k} = {v}')
    print(v.state_dict())

### Losses and Metrics

In [None]:
# Bind `trn_losses` to the same object as `environ.losses` (plain assignment, no copy is made).
trn_losses = environ.losses

In [None]:
# Report the training losses and validation metrics for the current epoch, passing the
# seconds elapsed since `start_time`.
# NOTE(review): `print_metrics_cr` is not among the imports visible at the top of the
# notebook — confirm where it is defined.
print_metrics_cr(current_epoch, time.time() - start_time, trn_losses, environ.val_metrics , num_prints)      

In [None]:
# print_metrics_cr(current_epoch, time.time() - start_time, trn_losses, environ.val_metrics , num_prints)      

In [None]:
# pp.pprint(environ.losses)
# Pretty-print the object currently bound to `trn_losses`.
pp.pprint(trn_losses)

In [None]:
# Pretty-print the validation metrics currently stored on `environ`.
pp.pprint(environ.val_metrics)

In [None]:
# environ.opt['train']['Lambda_sharing'] = 0.5
# opt['train']['Lambda_sharing'] = 0.5

# environ.opt['train']['policy_lr'] = 0.001
# opt['train']['policy_lr'] = 0.001

In [None]:
# Show the top-level keys of `environ.losses`, then the complete structure.
# The keys are printed explicitly: a bare expression that is not the last line of a
# cell produces no output, so the previous `environ.losses.keys()` line was a no-op.
print(environ.losses.keys())
pp.pprint(environ.losses)

In [None]:
# Fetch what `environ.get_loss_dict()` returns, print its keys and then the complete structure.
tmp = environ.get_loss_dict()
print(tmp.keys())
pp.pprint(tmp)

In [None]:
# Print the options that select the sparsity-loss branch, plus the sharing/sparsity
# lambdas and the policy learning rate.
# The last six values are read from both `environ.opt` and the notebook-level `opt`,
# so the two can be compared side by side.
print(opt['diff_sparsity_weights'])
print(opt['is_sharing'])
print(opt['diff_sparsity_weights'] and not opt['is_sharing'])
print(environ.opt['train']['Lambda_sharing'])
print(opt['train']['Lambda_sharing'])
print(environ.opt['train']['Lambda_sparsity'])
print(opt['train']['Lambda_sparsity'])
print(environ.opt['train']['policy_lr'])
print(opt['train']['policy_lr'])

### Policy / Logit stuff

In [None]:
from scipy.special          import softmax

In [None]:
# Raise the printed precision of numpy arrays and torch tensors to 8 digits for the
# logit/policy inspection cells that follow.
np.set_printoptions(precision=8,edgeitems=3, infstr='inf', linewidth=150, nanstr='nan')
torch.set_printoptions(precision=8,linewidth=132)

#### `get_task_logits(n)` Get logits for task group n

In [None]:
# Fetch and print what `environ.get_task_logits(1)` returns.
task_logits = environ.get_task_logits(1)
print(task_logits)

#### `get_arch_parameters()`: Get last used logits from network

In [None]:
# Fetch and print what `environ.get_arch_parameters()` returns.
# NOTE(review): `optim` is imported here but not used within this cell.
import torch.optim as optim
arch_parameters      = environ.get_arch_parameters()
print(arch_parameters)

In [None]:
# Fetch and print what `environ.get_arch_parameters()` returns.
# NOTE(review): this cell is an exact copy of the one directly above it and could be removed.
import torch.optim as optim
arch_parameters      = environ.get_arch_parameters()
print(arch_parameters)

#### `get_policy_logits()`:  Get Policy Logits - returns same as `get_arch_parameters()`

In [None]:
# Fetch the policy logits into `logs` and print each element followed by a blank line.
logs = environ.get_policy_logits()
for i in logs:
    print(i, '\n')
# probs = softmax(logs, axis= -1)
# for i in probs:
#     print(i, '\n')

#### `get_policy_prob()` : Gets the softmax of the logits

In [None]:
# Fetch what `environ.get_policy_prob()` returns into `policy_softmaxs` and print each
# element followed by a blank line.
policy_softmaxs = environ.get_policy_prob()
for i in policy_softmaxs:
    print(i, '\n')

#### `get_sample_policy( hard_sampling = False)` : Calls test_sample_policy of network with random choices based on softmax of logits

In [None]:
# Sample a policy with hard_sampling=False and print, row by row, the logits, the
# sampled policy and the policy probabilities next to each other.
# The outer loop walks the three sequences in parallel; the inner loop walks their rows.
policy_softmaxs = environ.get_policy_prob()
policies,logits = environ.get_sample_policy(hard_sampling = False)

for l, p, s in zip(logits, policies, policy_softmaxs) :
    for  l_row, p_row, s_row in zip(l, p, s):
        print( l_row,'\t', p_row, '\t', s_row)
    print('\n')

#### `get_sample_policy( hard_sampling = True)` : Calls test_sample_policy of network using ARGMAX of logits

In [None]:
# Sample a policy with hard_sampling=True and print, row by row, the logits, the
# sampled policy and the policy probabilities next to each other.
# The zip order differs from the soft-sampling cell, but the printed column order
# (logits, policy, probabilities) is the same.
policy_softmaxs = environ.get_policy_prob()
hard_policies, logits = environ.get_sample_policy(hard_sampling = True)

for p,l,s in zip(hard_policies, logits, policy_softmaxs) :
    for  p_row, l_row, s_row in zip(p, l, s):
        print( l_row,'\t', p_row, '\t', s_row)
    print('\n')

#### Print the hard-sampled policy of each task, layer by layer

In [None]:
# Tabulate the hard-sampled policy of the three tasks, one row per layer (rows numbered from 1).
print(" Layer    task 1      task 2      task 3")
print(" -----    ------      ------      ------")
for idx, (l1, l2, l3) in enumerate(zip(hard_policies[0], hard_policies[1], hard_policies[2]),1):
    print(f"   {idx}      {l1}       {l2}       {l3}")

# The legend is printed once, after the table; it used to sit inside the loop body
# and was therefore repeated after every row.
print("\n\n where [p1  p2]:  p1: layer is selected    p2: layer is not selected")

In [None]:
def display_trained_policy(iter):
    """Print the per-layer policy probabilities and the derived 0/1 selection for three tasks.

    Reads the notebook-level `environ` and uses the first three entries of
    `environ.get_policy_prob()`; each entry is iterated row by row and every row is
    indexed at positions 0 and 1.

    Args:
        iter: value shown in the title line. The name shadows the builtin `iter`;
            it is kept unchanged so existing keyword calls keep working.

    Returns:
        None. The table is written to stdout.
    """
    policy_softmaxs = environ.get_policy_prob()
    # select = 1 when the first probability of a pair is the larger one, else 0.
    policy_argmaxs = 1 - np.argmax(policy_softmaxs, axis=-1)

    print(f"  Trained policies at iteration: {iter} ")
    print("                   task 1                           task 2                         task 3        ")
    print(" Layer       softmax        select          softmax        select          softmax        select   ")
    print(" -----    ---------------   ------       ---------------   ------       ---------------   ------   ")

    rows = zip(policy_softmaxs[0], policy_softmaxs[1], policy_softmaxs[2],
               policy_argmaxs[0], policy_argmaxs[1], policy_argmaxs[2])
    for idx, (l1, l2, l3, p1, p2, p3) in enumerate(rows, 1):
        print(f"   {idx}      {l1[0]:.4f}   {l1[1]:.4f}   {p1:4d}    {l2[0]:11.4f}   {l2[1]:.4f}   {p2:4d}    {l3[0]:11.4f}   {l3[1]:.4f}   {p3:4d}")

    print()
# print(f"\n\n where [p1  p2]:  p1: layer is selected    p2: layer is not selected")

In [None]:
# Print the policy table; the argument 5 only appears in the table's title line.
display_trained_policy(5)

In [None]:
# Print, per layer, the two policy probabilities of each of the three tasks followed by
# task 3's hard policy, then a legend.
# NOTE(review): the header rows list more columns than each data row prints, and
# `h1` / `h2` are unpacked but never shown — confirm the intended layout.
print(f"                        POLICIES (SOFTMAX)                                       task 3          ")
print(f" Layer    task1              task2            task3 softmax         softmax         argmax         softmax         argmax   ")
print(f" -----    -------------     -------------     -------------   ------   ")
for idx, (l1,l2,l3, h1,h2,h3) in enumerate(zip(policy_softmaxs[0], policy_softmaxs[1], policy_softmaxs[2],hard_policies[0], hard_policies[1], hard_policies[2]),1):
    print(f"   {idx}      {l1[0]:.4f} {l1[1]:.4f}     {l2[0]:.4f} {l2[1]:.4f}     {l3[0]:.4f} {l3[1]:.4f}    {h3}")
    
print(f"\n\n where [p1  p2]:  p1: layer is selected    p2: layer is not selected")

In [None]:
# print(policy_softmaxs[2], np.argmax(1-policy_softmaxs[2], axis = -1))
# Print the policy probabilities together with the index of the larger value along the last axis.
print(policy_softmaxs, np.argmax(policy_softmaxs, axis = -1))

#### `get_current_logits()` : Calls test_sample_policy of network using ARGMAX of logits

In [None]:
# Rebind `logits` to what `environ.get_current_logits()` returns and print each element
# followed by a blank line.
logits  = (environ.get_current_logits())
for i in logits:
    print(i ,'\n')

#### `get_current_policy()` : Calls test_sample_policy of network using ARGMAX of logits

In [None]:
# Fetch what `environ.get_current_policy()` returns into `pols` and print each element
# followed by a blank line.
pols  = (environ.get_current_policy())

for i in pols:
    print(i ,'\n')

#### `gumbel_softmax()`  

In [None]:
# Print numpy arrays and torch tensors with 8-digit precision; numpy additionally
# uses floatmode 'maxprec_equal'.
np.set_printoptions(precision=8,edgeitems=3, infstr='inf', linewidth=150, nanstr='nan', floatmode = 'maxprec_equal')
torch.set_printoptions(precision=8,linewidth=132)

In [None]:
# Draw three independent Gumbel-softmax samples (soft and hard) from the first entry of
# `logits` and print each logit row next to its two samples.
# Imported locally so the cell does not depend on `F` having been bound by another cell.
import torch.nn.functional as F

print(environ.temp)
# tau = environ.temp
tau = 1
# The tensor does not change between draws, so it is built once outside the loop.
logits_tensor = torch.tensor(logits[0])
for i in range(3): 
    # Sample soft categorical using reparametrization trick:
    gumbel_soft = F.gumbel_softmax(logits_tensor, tau=tau, hard=False).cpu().numpy() 

    # Sample hard categorical using "Straight-through" trick:
    gumbel_hard  = F.gumbel_softmax(logits_tensor, tau=tau, hard=True).cpu().numpy()
    
    # Pair the samples with the rows they were drawn from (`logits[0]`). The loop
    # previously read `lgts`, a name this cell never assigns.
    for l, gs, gh in zip(logits[0], gumbel_soft, gumbel_hard):
        print(f"   {l}   \t {gs}            \t {gh}")
#     print(lgts)
#     print(gumbel_soft)
#     print(gumbel_hard)
    print()

In [None]:
# For every entry of `logits`: print it, then one soft and one hard Gumbel-softmax sample (tau=1).
# NOTE(review): `F` is not imported in this cell; presumably `torch.nn.functional` —
# confirm it is imported before this cell runs.
# The loop variable `lgts` stays bound to the last entry of `logits` after the loop finishes.
for lgts in logits:
    logits_tensor = torch.tensor(lgts)
    print(lgts)
    # Sample soft categorical using reparametrization trick:
    gumbel_soft = F.gumbel_softmax(logits_tensor, tau=1, hard=False)
    print(gumbel_soft)

    # Sample hard categorical using "Straight-through" trick:
    gumbel_hard  = F.gumbel_softmax(logits_tensor, tau=1, hard=True)
    print(gumbel_hard)
    print()

In [None]:
# Convert `logs` to probabilities along axis 1, inspect the result, and draw one 0/1
# decision from its first row.
# The visible notebook imports only the function (`from scipy.special import softmax`),
# never the `scipy` package, so `scipy.special.softmax(...)` raised NameError. The
# function is imported here and called directly.
from scipy.special import softmax

smax = softmax(logs, axis =1)
# smax = np.array( 
# [[0.46973792, 0.530262  ],
#  [0.45025694, 0.549743  ],
#  [0.4443086 , 0.5556915 ],
#  [0.4138397 , 0.58616036],
#  [0.4140113 , 0.5859887 ],
#  [0.42114905, 0.57885087]])

print(smax.shape)
print(smax)
print(smax[0])
print(smax[0].sum())
# With p=smax[0], the value 1 is drawn with probability smax[0][0] and the value 0 with smax[0][1].
print(np.random.choice((1,0), p =smax[0]))

In [None]:
# Hard-coded 6x2 array of example logit values bound to `logs`.
# NOTE: this replaces whatever an earlier cell stored in `logs`.
logs = np.array(
[[0.33064184, 0.42053092],
 [0.3532089 , 0.52056104],
 [0.3888512 , 0.5680909 ],
 [0.42039296, 0.694217  ],
 [0.4519742 , 0.73311865],
 [0.48401102, 0.7522658 ]],
)