## Initialization  

### Initialization

In [1]:
%load_ext autoreload
%autoreload 2

import os 
import sys
sys.path.insert(0, './src')
print(sys.path)
import time
import argparse
import yaml
import copy, pprint
from time import sleep
from datetime import datetime

import numpy  as np
import torch  
import wandb
import pandas as pd
import torch.nn.functional as F
from torch.utils.data import DataLoader 
import scipy.sparse
from scipy.special import softmax
 
from GPUtil import showUtilization as gpu_usage
# from tqdm.notebook import trange, tqdm
from tqdm     import tqdm,trange 

from envs.sparsechem_env_dev import SparseChemEnv_Dev
from utils.sparsechem_utils import load_sparse, load_task_weights, class_fold_counts, fold_and_transform_inputs, print_metrics_cr
from dataloaders.chembl_dataloader_dev import ClassRegrSparseDataset_v3, ClassRegrSparseDataset, InfiniteDataLoader
from utils.util import ( makedir, print_separator, create_path, print_yaml, print_loss, should, 
                         fix_random_seed, read_yaml, timestring, print_heading, print_dbg, 
                         print_underline, write_config_report, display_config, get_command_line_args, is_notebook)

print(' Cuda is available  : ', torch.cuda.is_available())
print(' CUDA device count  : ', torch.cuda.device_count())
print(' CUDA current device: ', torch.cuda.current_device())
print(' GPU Processes      : \n', torch.cuda.list_gpu_processes())
print()

for i in range(torch.cuda.device_count()):
    print(f" Device : cuda:{i}")
    print('   name:       ', torch.cuda.get_device_name())
    print('   capability: ', torch.cuda.get_device_capability())
    print('   properties: ', torch.cuda.get_device_properties(i))
    ## current GPU memory usage by tensors in bytes for a given device
    print('   Allocated : ', torch.cuda.memory_allocated(i) ) 
    ## current GPU memory managed by caching allocator in bytes for a given device, in previous PyTorch versions the command was torch.cuda.memory_cached
    print('   Reserved  : ', torch.cuda.memory_reserved(i) )   
    print()

gpu_usage()                             

pp = pprint.PrettyPrinter(indent=4)
np.set_printoptions(edgeitems=3, infstr='inf', linewidth=150, nanstr='nan')
torch.set_printoptions(precision=6, linewidth=132)
# torch.set_printoptions(precision=None, threshold=None, edgeitems=None, linewidth=None, profile=None, sci_mode=None)
pd.options.display.width = 132

os.environ["WANDB_NOTEBOOK_NAME"] = "Adashare_Training_Resume.ipynb"

['./src', '/home/kbardool/kusanagi/AdaSparseChem', '/home/kbardool/miniconda3/envs/pyt-gpu/lib/python39.zip', '/home/kbardool/miniconda3/envs/pyt-gpu/lib/python3.9', '/home/kbardool/miniconda3/envs/pyt-gpu/lib/python3.9/lib-dynload', '', '/home/kbardool/miniconda3/envs/pyt-gpu/lib/python3.9/site-packages', '/home/kbardool/miniconda3/envs/pyt-gpu/lib/python3.9/site-packages/IPython/extensions', '/home/kbardool/.ipython']
 Cuda is available  :  True
 CUDA device count  :  1
 CUDA current device:  0
 GPU Processes      : 
 GPU:0
no processes are running

 Device : cuda:0
   name:        NVIDIA GeForce GTX 970M
   capability:  (5, 2)
   properties:  _CudaDeviceProperties(name='NVIDIA GeForce GTX 970M', major=5, minor=2, total_memory=3071MB, multi_processor_count=10)
   Allocated :  0
   Reserved  :  0

| ID | GPU  | MEM |
-------------------
|  0 | nan% |  1% |


## Create Environment

### Parse Input Args

In [None]:
input_args = " --config yamls/chembl_3task_train.yaml " \
             " --exp_id      14oarkpu" \
             " --exp_name    0304_1549" \
             " --exp_desc    Train with dropout 0.5" \
             " --seed_idx    0 "\
             " --batch_size  128".split()
# get command line arguments
args = get_command_line_args(input_args)

    
print_underline(' command line parms : ', True)
for key, val in vars(args).items():
    print(f" {key:.<25s}  {val}")

### Read Configuration File

In [None]:
# ********************************************************************
# ****************** create folders and print options ****************
# ********************************************************************
print_separator('READ YAML')

opt, gpu_ids = read_yaml(args)


fix_random_seed(opt["random_seed"])
    
# opt['exp_instance'] = datetime.now().strftime("%m%d_%H%M")
# opt['exp_instance'] = '0218_1358'     
# opt['exp_description'] = f"Retrain phase for 0218_1358"
# folder_name=  f"{opt['exp_instance']}_bs{opt['train']['batch_size']:03d}_{opt['train']['decay_lr_rate']:3.2f}_{opt['train']['decay_lr_freq']}"

create_path(opt)    
print()
print_heading(f" experiment name       : {opt['exp_name']} \n"
              f" experiment id         : {opt['exp_id']} \n"
              f" folder_name           : {opt['exp_folder']} \n"
              f" experiment description: {opt['exp_description']}\n"
              f" Random seeds          : {opt['seed_list']}\n"
              f" Random  seed used     : {opt['random_seed']} \n"
              f" log folder            : {opt['paths']['log_dir']}\n"
              f" checkpoint folder     : {opt['paths']['checkpoint_dir']}", verbose = True)
print(f" Gpu ids: {gpu_ids}     seed index: {args.seed_idx}      policy_iter:{opt['train']['policy_iter']}")
print(opt['dataload']['x_split_ratios'])

config_filename = 'run_config_seed_%04d.txt' % (opt['random_seed'])
write_config_report(opt, filename = config_filename)    
# display_config(opt)
best_results = {}

### Setup Dataloader and Model  

In [6]:
# ********************************************************************
# ******************** Prepare the dataloaders ***********************
# ********************************************************************
# load the dataloader
print_separator('CREATE DATALOADERS')

trainset0 = ClassRegrSparseDataset_v3(opt, split_ratios = opt['dataload']['x_split_ratios'], ratio_index = 0, verbose = False)
trainset1 = trainset0
trainset2 = trainset0
# trainset1 = ClassRegrSparseDataset_v3(opt, split_ratios = opt['dataload']['x_split_ratios'], ratio_index = 1)
# trainset2 = ClassRegrSparseDataset_v3(opt, split_ratios = opt['dataload']['x_split_ratios'], ratio_index = 2)
valset    = ClassRegrSparseDataset_v3(opt, split_ratios = opt['dataload']['x_split_ratios'], ratio_index = 1)
testset   = ClassRegrSparseDataset_v3(opt, split_ratios = opt['dataload']['x_split_ratios'], ratio_index = 2)

warmup_trn_loader = InfiniteDataLoader(trainset0 , batch_size=opt['train']['batch_size'], num_workers = 2, pin_memory=True, collate_fn=trainset0.collate, shuffle=True)
weight_trn_loader = InfiniteDataLoader(trainset1 , batch_size=opt['train']['batch_size'], num_workers = 2, pin_memory=True, collate_fn=trainset1.collate, shuffle=True)
policy_trn_loader = InfiniteDataLoader(trainset2 , batch_size=opt['train']['batch_size'], num_workers = 2, pin_memory=True, collate_fn=trainset2.collate, shuffle=True)
val_loader        = InfiniteDataLoader(valset    , batch_size=opt['train']['batch_size'], num_workers = 1, pin_memory=True, collate_fn=valset.collate  , shuffle=True)
test_loader       = InfiniteDataLoader(testset   , batch_size=32                        , num_workers = 1, pin_memory=True, collate_fn=testset.collate  , shuffle=True)

opt['train']['weight_iter_alternate'] = opt['train'].get('weight_iter_alternate' , len(weight_trn_loader))
opt['train']['alpha_iter_alternate']  = opt['train'].get('alpha_iter_alternate'  , len(policy_trn_loader))

# ********************************************************************
# ********************Create the environment *************************
# ********************************************************************
# create the model and the pretrain model
print_separator('CREATE THE ENVIRONMENT')
environ = SparseChemEnv_Dev(log_dir          = opt['paths']['log_dir'], 
                            checkpoint_dir   = opt['paths']['checkpoint_dir'], 
                            exp_name         = opt['exp_name'],
                            tasks_num_class  = opt['tasks_num_class'], 
                            init_neg_logits  = opt['train']['init_neg_logits'], 
                            device           = gpu_ids[0],
                            init_temperature = opt['train']['init_temp'], 
                            temperature_decay= opt['train']['decay_temp'], 
                            is_train         = True,
                            opt              = opt, 
                            verbose          = False)
environ.define_optimizer(policy_learning=False)
environ.define_scheduler(policy_learning=False)

cfg = environ.print_configuration()
# print(cfg)
write_config_report(opt, cfg, filename = config_filename, mode = 'a')

##################################################
############### CREATE DATALOADERS ###############
##################################################
##################################################
############# CREATE THE ENVIRONMENT #############
##################################################
-------------------------------------------------------
* SparseChemEnv_Dev  Initializtion - verbose: False
------------------------------------------------------- 

 device is  cuda:0
--------------------------------------------------------
* SparseChemEnv_Dev environment successfully created
-------------------------------------------------------- 



###  Weights and Biases Initialization 

In [7]:
print(opt['exp_id'], opt['exp_name'], opt['project_name']) # , opt['exp_instance'])
# opt['exp_id'] = wandb.util.generate_id()

14oarkpu 0304_1549 AdaSparseChem


In [8]:
run = wandb.init(project=opt['project_name'], entity="kbardool", resume="allow", id = opt['exp_id'], name = opt['exp_name'])
 
print(f"PROJECT NAME: {wandb.run.project} RUN ID:  {wandb.run.id}    RUN NAME: {wandb.run.name}") 

# # wandb.init(id='1q3dt2en' )
# assert wandb.run is None, "Run is still running"

wandb.config = opt.copy()
wandb.watch(environ.networks['mtl-net'], log='all', log_freq=10)     ###  Weights and Biases Initialization 

[34m[1mwandb[0m: Currently logged in as: [33mkbardool[0m (use `wandb login --relogin` to force relogin)


PROJECT NAME: AdaSparseChem RUN ID:  14oarkpu    RUN NAME: 0304_1549


[]

In [9]:
# wandb_run_name = opt['exp_instance']
# wandb_run_id = wandb.util.generate_id()
# print(wandb.run.id, wandb.run.name) 
# print(opt['exp_uid'], opt['exp_instance'])
# run = wandb.init(project="AdaSparseChem", entity="kbardool", resume="allow", id = opt['exp_id'], name = opt['exp_instance'])

# wandb.config = opt.copy()
# wandb.watch(environ.networks['mtl-net'], log='all', log_freq=10)
environ.device

device(type='cuda', index=0)

###  Weights and Biases Initialization 

In [37]:
wandb_run_id = wandb.util.generate_id()
print(wandb_run_id)
 
run = wandb.init(project="AdaSparseChem", entity="kbardool", resume="allow", id = wandb_run_id)

print(wandb.run.id, wandb.run.name) 

[34m[1mwandb[0m: Currently logged in as: [33mkbardool[0m (use `wandb login --relogin` to force relogin)


3ely9bo8


3ely9bo8 splendid-durian-18


In [38]:
wandb.config = opt.copy()
wandb.watch(environ.networks['mtl-net'], log='all', log_freq=10)

[]

In [12]:
# wandb.finish()

## Initiate / Resume Training Prep

In [10]:
print(opt['train']['which_iter'])

a = environ.load_checkpoint('warmup_ep_40_seed_0088', path = '../experiments/AdaSparseChem/50x6_0304_1549_plr0.01_sp0.0001_sh0.01/')

warmup
=> loading snapshot from ../experiments/AdaSparseChem/50x6_0304_1549_plr0.01_sp0.0001_sh0.01/warmup_ep_40_seed_0088_model.pth.tar
   Loading to GPU cuda:0


In [11]:
opt['train']['resume'] = True
opt['train']['which_iter'] = 'warmup_ep_40_seed_0088'
if opt['train']['resume']:
    print_separator('Resume training')
    current_iter = environ.load_checkpoint(opt['train']['which_iter'])
    environ.networks['mtl-net'].reset_logits()
else:
    print_separator('Initiate Training ')

##################################################
################ Resume training #################
##################################################
=> loading snapshot from ../experiments/AdaSparseChem/50x6_0304_1549_plr0.01_sp0.0001_sh0.01/warmup_ep_40_seed_0088_model.pth.tar
   Loading to GPU cuda:0


###  Load Saved Checkpoint

In [13]:
current_iter = environ.load_checkpoint('latest_weights_policy')

print('Evaluating the snapshot saved at %d iter' % current_iter)

=> loading snapshot from ../experiments/AdaSparseChem/50x6_0218_1358_plr0.01_sp0.01_sh1.0/latest_weights_policy_model.pth.tar
Loading to CPU
  networks -  network:  mtl-net
  load snapshot - network:  mtl-net
    network mtl-net - item task1_logits
    network mtl-net - item task2_logits
    network mtl-net - item task3_logits
    network mtl-net - item backbone.Input_linear.weight
    network mtl-net - item backbone.Input_linear.bias
    network mtl-net - item backbone.blocks.0.0.linear.weight
    network mtl-net - item backbone.blocks.0.0.linear.bias
    network mtl-net - item backbone.blocks.1.0.linear.weight
    network mtl-net - item backbone.blocks.1.0.linear.bias
    network mtl-net - item backbone.blocks.2.0.linear.weight
    network mtl-net - item backbone.blocks.2.0.linear.bias
    network mtl-net - item backbone.blocks.3.0.linear.weight
    network mtl-net - item backbone.blocks.3.0.linear.bias
    network mtl-net - item backbone.blocks.4.0.linear.weight
    network mtl-net 

In [15]:
current_epoch  = 262

## Training Preparation

### Training Preparation

In [16]:
# if opt['train']['resume']:
#     print_separator('Resume training')
#     current_iter = environ.load(opt['train']['which_iter'])
#     environ.networks['mtl-net'].reset_logits()
# else:
#     print_separator('Initiate Training ')

if torch.cuda.is_available():
    print(' cuda available', gpu_ids)   
    environ.cuda(gpu_ids)
else:
    print(' cuda not available')
    environ.cpu()

if opt['train']['print_freq'] == -1:
    print(f" set print_freq to length of train loader: {len(warmup_trn_loader)}")
    opt['train']['print_freq']    = len(warmup_trn_loader)

if opt['train']['val_iters'] == -1:
    print(f" set eval_iters to length of val loader  : {len(val_loader)}")
    eval_iters    = len(val_loader)    
else:
    eval_iters    = opt['train']['val_iters']


# opt['train']['weight_iter_alternate'] = len(warmup_trn_loader)
# opt['train']['alpha_iter_alternate']  = len(warmup_trn_loader)
stop_iter_w = opt['train']['weight_iter_alternate']
stop_iter_a = opt['train']['alpha_iter_alternate'] 
    
# Fix Alpha -     
flag           = 'update_w'
environ.fix_alpha()
environ.free_weights(opt['fix_BN'])

 
# current_iter_w = 0 
# current_iter_a = 0
# best_value     = 0 
# best_iter      = 0
p_epoch        = 0
w_epoch        = 0

# best_metrics   = None
# flag_warmup    = True
num_prints     = 0
num_blocks     = sum(environ.networks['mtl-net'].layers)

warm_up_epochs     = opt['train']['warm_up_epochs']
train_total_epochs = opt['train']['training_epochs']
curriculum_speed   = opt['curriculum_speed'] 

stop_epoch_warmup  = current_epoch + warm_up_epochs

 cuda available [0]
 set eval_iters to length of val loader  : 36


In [17]:
print(f" opt['train']['weight_iter_alternate']   {opt['train']['weight_iter_alternate']}")
print(f" opt['train']['alpha_iter_alternate']    {opt['train']['alpha_iter_alternate']}")

 opt['train']['weight_iter_alternate']   108
 opt['train']['alpha_iter_alternate']    108


In [18]:
print(f"\n trainset.y_class                       :  {[ i.shape  for i in trainset.y_class_list]}",
      f"\n trainset1.y_class                      :  {[ i.shape  for i in trainset1.y_class_list]}",
      f"\n trainset2.y_class                      :  {[ i.shape  for i in trainset2.y_class_list]}",
      f"\n valset.y_class                         :  {[ i.shape  for i in valset.y_class_list  ]} ",
      f"\n                                ",
      f'\n size of training set 0 (warm up)       :  {len(trainset)}',
      f'\n size of training set 1 (network parms) :  {len(trainset1)}',
      f'\n size of training set 2 (policy weights):  {len(trainset2)}',
      f'\n size of validation set                 :  {len(valset)}',
      f'\n                               Total    :  {len(trainset)+len(trainset1)+len(trainset2)+len(valset)}',
      f"\n                                ",
      f"\n batch size                             :  {opt['train']['batch_size']}",
      f"\n                                ",
      f"\n # batches training 0 (warm up)         :  {len(warmup_trn_loader)}",
      f"\n # batches training 1 (network parms)   :  {len(weight_trn_loader)}",
      f"\n # batches training 2 (policy weights)  :  {len(policy_trn_loader)}",
      f"\n # batches validation dataset           :  {len(val_loader)}",
      f"\n                                ",
      f"\n Weight iter alternate                  :  {opt['train']['weight_iter_alternate'] }",
      f"\n Alpha  iter alternate                  :  {opt['train']['alpha_iter_alternate'] }")


 trainset.y_class                       :  [(13791, 5), (13791, 5), (13791, 5)] 
 trainset1.y_class                      :  [(18, 5), (18, 5), (18, 5)] 
 trainset2.y_class                      :  [(18, 5), (18, 5), (18, 5)] 
 valset.y_class                         :  [(4561, 5), (4561, 5), (4561, 5)]  
                                 
 size of training set 0 (warm up)       :  13791 
 size of training set 1 (network parms) :  18 
 size of training set 2 (policy weights):  18 
 size of validation set                 :  4561 
                               Total    :  18388 
                                 
 batch size                             :  128 
                                 
 # batches training 0 (warm up)         :  108 
 # batches training 1 (network parms)   :  108 
 # batches training 2 (policy weights)  :  108 
 # batches validation dataset           :  36 
                                 
 Weight iter alternate                  : 108 
 Alpha  iter alternate        

In [19]:
print(f"\n experiment name           : {opt['exp_name']}",
      f"\n experiment description    : {opt['exp_description']}",
      f"                                \n"
      f"\n Network[mtl_net].layers   : {environ.networks['mtl-net'].layers}",
      f"\n Num_blocks                : {sum(environ.networks['mtl-net'].layers)}"    
      f"                                \n"
      f"\n batch size                : {opt['train']['batch_size']}",    
      f"\n Total iterations          : {opt['train']['total_iters']}",
      f"\n Warm-up iterations        : {opt['train']['warm_up_iters']}",
      f"\n Warm-up epochs            : {opt['train']['warm_up_epochs']}",
      f"\n Warm-up stop              : {stop_epoch_warmup}",
      f"\n train_total_epochs        : {train_total_epochs}",
      f"                                \n"
      f"\n Print Frequency           : {opt['train']['print_freq']}",
      f"\n Validation Frequency      : {opt['train']['val_freq']}",
      f"\n Validation Iterations     : {opt['train']['val_iters']}",
      f"\n eval_iters                : {eval_iters}",
      f"\n which_iter                : {opt['train']['which_iter']}",
      f"\n train_resume              : {opt['train']['resume']}",
      f"                                \n",                     
      f"\n Length warmup_trn_loader  : {len(warmup_trn_loader)}",
      f"\n Length val_loader         : {len(val_loader)}",
      f"\n stop_iter_w               : {stop_iter_w}",
      f"                                \n",
      f"\n fix BN parms              : {opt['fix_BN']}",    
      f"\n Backbone LR               : {opt['train']['backbone_lr']}",
      f"\n Backbone LR               : {opt['train']['task_lr']   }",     
      f"                                \n"
      f"\n Sharing  regularization   : {opt['train']['lambda_sharing']}",    
      f"\n Sparsity regularization   : {opt['train']['lambda_sparsity']}",  
      f"\n Task     regularization   : {opt['train']['lambda_tasks']}",
      f"\n Last Epoch                : {current_epoch} ",
      f"\n # of warm-up epochs to do : {warm_up_epochs}")


 experiment name           : SparseChem 
 experiment description    : No Alternating Weight/Policy - training all done with both weights and policy                                 

 Network[mtl_net].layers   : [1, 1, 1, 1, 1, 1] 
 Num_blocks                : 6                                

 batch size                : 128 
 Total iterations          : 25000 
 Warm-up iterations        : None 
 Warm-up epochs            : 1 
 Warm-up stop              : 263 
 train_total_epochs        : 50                                 

 Print Frequency           : 108 
 Validation Frequency      : 500 
 Validation Iterations     : -1 
 eval_iters                : 36 
 which_iter                : warmup 
 train_resume              : False                                 
 
 Length warmup_trn_loader  : 108 
 Length val_loader         : 36 
 stop_iter_w               : 108                                 
 
 fix BN parms              : False 
 Backbone LR               : 0.001 
 Backbone LR       

In [20]:
print(f"\n folder: {opt['exp_folder']}",
      f"\n layers: {opt['hidden_sizes']}",    
      f"                               \n",
      f"\n diff_sparsity_weights  : {opt['diff_sparsity_weights']}",
      f"\n skip_layer             : {opt['skip_layer']}",
      f"\n is_curriculum          : {opt['is_curriculum']}",
      f"\n curriculum_speed       : {opt['curriculum_speed']}",
      f"                              \n",    
      f"\n decay_lr_rate          : {opt['train']['decay_lr_rate']}",      
      f"\n decay_lr_freq          : {opt['train']['decay_lr_freq']}",     
      f"\n policy_decay_lr_rate   : {opt['train']['policy_decay_lr_rate']}",      
      f"\n policy_decay_lr_freq   : {opt['train']['policy_decay_lr_freq']}", 
      f"                              \n",    
      f"\n policy_lr              : {opt['train']['policy_lr']}", 
      f"\n lambda_sparsity        : {opt['train']['lambda_sparsity']}",      
      f"\n lambda_sharing         : {opt['train']['lambda_sharing']}", 
      f"                              \n",    
      f"\n lambda_tasks           : {opt['train']['lambda_tasks']}",  
      f"\n init_temp              : {opt['train']['init_temp']}",
      f"\n decay_temp             : {opt['train']['decay_temp']}",    
      f"\n decay_temp_freq        : {opt['train']['decay_temp_freq']}",   
      f"\n init_method            : {opt['train']['init_method']}", 
      f"\n init_neg_logits        : {opt['train']['init_neg_logits']}",    
      f"\n hard_sampling          : {opt['train']['hard_sampling']}",
      f"\n Warm-up epochs         : {opt['train']['warm_up_epochs']}",
      f"\n training epochs        : {opt['train']['training_epochs']}")


 folder: 50x6_0218_1358_plr0.01_sp0.01_sh1.0 
 layers: [50, 50, 50, 50, 50, 50]                                
 
 diff_sparsity_weights  : False 
 skip_layer             : 0 
 is_curriculum          : False 
 curriculum_speed       : 1                               
 
 decay_lr_rate          : 0.85 
 decay_lr_freq          : 2000 
 policy_decay_lr_rate   : 0.85 
 policy_decay_lr_freq   : 2200                               
 
 policy_lr              : 0.01 
 lambda_sparsity        : 0.01 
 lambda_sharing         : 1.0                               
 
 lambda_tasks           : 1 
 init_temp              : 4 
 decay_temp             : 0.965 
 decay_temp_freq        : 2 
 init_method            : random 
 init_neg_logits        : None 
 hard_sampling          : True 
 Warm-up epochs         : 1 
 training epochs        : 50


## Resume Weight/Policy Training

### Training Preparation

In [21]:
print_heading( f"** {timestring()} - Training iteration {current_iter}   flag: {flag} \n"
               f"** Set optimizer and scheduler to policy_learning = True (Switch weight optimizer from ADAM to SGD)\n"
               f"** Switch from Warm Up training to Alternate training Weights & Policy \n"
               f"** Take checkpoint and block gradient flow through Policy net", verbose=True)

environ.define_optimizer(policy_learning=True)
environ.define_scheduler(policy_learning=True)

flag_warmup = False
flag = 'update_w'
environ.fix_alpha()
environ.free_weights(opt['fix_BN'])
    
leave      = False
verbose    = False    

------------------------------------------------------------------------------------------------------------------------
** 2022-02-21 18:10:29:638057 - Training iteration 53357   flag: update_w 
** Set optimizer and scheduler to policy_learning = True (Switch weight optimizer from ADAM to SGD)
** Switch from Warm Up training to Alternate training Weights & Policy 
** Take checkpoint and block gradient flow through Policy net
------------------------------------------------------------------------------------------------------------------------ 



### Resume Training

In [15]:
current_epoch  = 262

In [43]:
train_total_epochs = 10
stop_epoch_training = current_epoch +train_total_epochs


print(f"current_epoch          : {current_epoch}") 
print(f"current_iters          : {current_iter}")  
print(f"train_total_epochs     : {train_total_epochs}") 
print(f"stop_epoch_training    : {stop_epoch_training}")
print(f"Batches in weight epoch (stop_iter_w): {stop_iter_w}")
print(f"Batches in policy epoch (stop_iter_a): {stop_iter_a}")
print()

current_epoch          : 302
current_iters          : 61997
train_total_epochs     : 10
stop_epoch_training    : 312


In [None]:
# print_loss(environ.val_metrics, title = f"[e] Last epoch:{current_epoch}  it:{current_iter}")

environ.display_trained_policy(current_epoch,out=[sys.stdout, environ.log_file])
# environ.display_trained_logits(current_epoch)

In [None]:
print_heading(f" Last Epoch Completed: {current_epoch}   # of epochs to do:  {train_total_epochs} -  epochs {current_epoch+1} to {stop_epoch_training}"
              f"\n policy_lr           : {opt['train']['policy_lr']}"
              f"\n lambda_sparsity     : {opt['train']['lambda_sparsity']}"
              f"\n lambda_sharing      : {opt['train']['lambda_sharing']}", verbose = True)

line_count = 0

In [None]:
while current_epoch < stop_epoch_training:
    current_epoch+=1
    #-----------------------------------------------------------------------------------------------------------
    # Set number of layers to train based on cirriculum_speed and p_epoch (number of epochs of policy training)
    # e.g., When curriculum_speed == 3, num_train_layers is incremented  after every 3 policy training epochs
    #-----------------------------------------------------------------------------------------------------------
    num_train_layers = (p_epoch // opt['curriculum_speed']) + 1  if opt['is_curriculum'] else None

    #-----------------------------------------
    # Train & Update the network weights
    #-----------------------------------------
    if flag == 'update_w':
        start_time = time.time()

        with trange(+1, stop_iter_w+1 , initial = 0, total = stop_iter_w, position=0,
                     leave= leave, desc=f"Epoch {current_epoch} weight training") as t_weights :
            
            for current_iter_w in t_weights:    
                current_iter += 1
                environ.train()
                batch = next(weight_trn_loader)
                environ.set_inputs(batch, weight_trn_loader.dataset.input_size)
 
                environ.optimize(opt['lambdas'], 
                                 is_policy=opt['policy'], 
                                 flag=flag, 
                                 num_train_layers=num_train_layers,
                                 hard_sampling=opt['train']['hard_sampling'],
                                 verbose = False)

                t_weights.set_postfix({'it' : current_iter, 
                                       'Lss': f"{environ.losses['losses']['total'].item():.4f}" , 
                                       'Spr': f"{environ.losses['sparsity']['total'].item():.4e}",  
                                       'Shr': f"{environ.losses['sharing']['total'].item():.4e}"})  

        trn_losses = environ.losses
        environ.print_trn_metrics(current_epoch, current_iter, start_time, title = f"[Weight Trn]", to_display = False)
        wandb.log(environ.losses)
        
        
        #--------------------------------------------------------------------
        # validation process (here current_iter_w and stop_iter_w are equal)
        #--------------------------------------------------------------------
        val_metrics = environ.evaluate(val_loader,  
                                       is_policy=opt['policy'],
                                       policy_sampling = 'eval',
                                       num_train_layers=num_train_layers, 
                                       hard_sampling=opt['train']['hard_sampling'],
                                       eval_iters = eval_iters, progress = True, leave = leave, verbose = False)  

        environ.print_val_metrics(current_epoch, current_iter, start_time, title = f"[Weight Val]", verbose = False)
        print_metrics_cr(current_epoch, time.time() - start_time, trn_losses, environ.val_metrics, line_count, out=[sys.stdout, environ.log_file]) 
        line_count +=1

        #------------------------------------------------------------------------ 
        #  Save Best Checkpoint Code (saved below and in sparsechem_env_dev.py)
        #----------------------------------------------------------------------- 
        # Take check point:     environ.save_checkpoint('latest_weights', current_iter)
        #-----------------------------------------------------------------------
        # END validation process 
        #-----------------------------------------------------------------------
        flag = 'update_alpha'
        environ.fix_weights()
        environ.free_alpha()
        
#         environ.display_trained_policy(current_epoch,out=[sys.stdout, environ.log_file])
#         environ.display_trained_logits(current_epoch)
        
#-----------------------------------------
# Policy Training  
#-----------------------------------------
    if flag == 'update_alpha':
        start_time = time.time()        

        with trange( +1, stop_iter_a+1 , initial = 0, total = stop_iter_a,  position=0,
                     leave= leave, desc=f"Epoch {current_epoch} policy training") as t_policy :
            for current_iter_a in t_policy:    
                current_iter += 1
                batch = next(policy_trn_loader)

                environ.set_inputs(batch, policy_trn_loader.dataset.input_size)

                environ.optimize(opt['lambdas'], is_policy=opt['policy'], 
                                 flag=flag, num_train_layers=num_train_layers,
                                 hard_sampling=opt['train']['hard_sampling'], verbose = False)
                
                t_policy.set_postfix({'it' : current_iter,
                                      'Lss': f"{environ.losses['losses']['total'].item():.4f}",
                                      'Spr': f"{environ.losses['sparsity']['total'].item():.4e}",
                                      'Shr': f"{environ.losses['sharing']['total'].item():.4e}"})
#                                       ,'lyrs': f"{num_train_layers}"})    
#                                       ,'row_ids':f"{batch['row_id'][0]}-{batch['row_id'][-1]}"})
#                 if current_iter % 100 == 0 :
#                     wandb.log(environ.losses)

        # print loss results (here current_iter_w and stop_iter_w are equal)
        trn_losses = environ.losses
        environ.print_trn_metrics(current_epoch, current_iter, start_time, title = f"[Policy Trn]")
        wandb.log(environ.losses)
        
        #--------------------------------------------------------------------
        # validation process (here current_iter_a and stop_iter_a are equal)
        #--------------------------------------------------------------------        
        val_metrics = environ.evaluate(val_loader, 
                                       is_policy=opt['policy'],
                                       policy_sampling = 'eval',
                                       num_train_layers=num_train_layers, 
                                       hard_sampling=opt['train']['hard_sampling'],
                                       eval_iters = eval_iters, progress = True, leave = False, verbose = False)  

        environ.print_val_metrics(current_epoch, current_iter, start_time, title = f"[Policy Val]", verbose = False)
        print_metrics_cr(current_epoch, time.time() - start_time, trn_losses, environ.val_metrics, line_count, out=[sys.stdout, environ.log_file])      
        line_count +=1
        #-----------------------------------------------------------------------
        # END validation process 
        #-----------------------------------------------------------------------        
        
        p_epoch += 1        
        if should(p_epoch, opt['train']['decay_temp_freq']):
            environ.decay_temperature()
            print(f" decay gumbel softmax to {environ.gumbel_temperature}")
        
        flag = 'update_w'
        environ.fix_alpha()
        environ.free_weights(opt['fix_BN'])
        
        environ.display_trained_policy(current_epoch,out=[sys.stdout, environ.log_file])
#         environ.display_trained_logits(current_epoch)        
#         print_loss(current_epoc, current_iter, environ.val_metrics, title = f"[Policy trn]  ep:{current_epoch}   it:{current_iter}")
    
    #-----------------------------------------
    # End Policy Training  
    #----------------------------------------- 
    if should(current_epoch, 5):
        environ.save_checkpoint('model_latest_weights_policy', current_iter)        
        print_loss(environ.val_metrics, title = f"\n[e] Policy training epoch:{current_epoch}  it:{current_iter}")
        environ.display_trained_policy(current_epoch,out=[sys.stdout, environ.log_file])
        environ.log_file.flush()
        line_count = 0

In [44]:
model_label   = 'model_warmup_ep_%d_seed_%04d' % (current_epoch, opt['random_seed'])
metrics_label = 'metrics_warmup_ep_%d_seed_%04d.pickle' % (current_epoch, opt['random_seed'])
environ.save_checkpoint(model_label, current_iter, current_epoch) 
save_to_pickle(environ.val_metrics, environ.opt['paths']['checkpoint_dir'], metrics_label)
print_loss(environ.val_metrics, title = f"[Final] ep:{current_epoch}  it:{current_iter}",)
environ.display_trained_policy(current_epoch,out=[sys.stdout, environ.log_file])
environ.display_trained_logits(current_epoch)
environ.log_file.flush()


 epch: 302   softmax       sel        softmax        sel        softmax        sel 
 -----  -----------------  ---    -----------------  ---    -----------------  --- 
   1    0.7059     0.2941   1     0.6700     0.3300   1     0.7576     0.2424   1
   2    0.9023     0.0977   1     0.9229     0.0771   1     0.9109     0.0891   1
   3    0.9175     0.0825   1     0.9436     0.0564   1     0.9395     0.0605   1
   4    0.7662     0.2338   1     0.8105     0.1895   1     0.7340     0.2660   1
   5    0.6771     0.3229   1     0.9040     0.0960   1     0.8325     0.1675   1
   6    0.4055     0.5945   0     0.3965     0.6035   0     0.6284     0.3716   1


------------------------------------------------------------------------------------------------------------------------
 Last Epoch Completed: 302   # of epochs to do:  10 -  epochs 303 to 312
 policy_lr           : 0.01
 lambda_sparsity     : 0.1
 lambda_sharing      : 0.001
-----------------------------------------------------------

                                                                                                                        

Epoch | BckBone LR   Heads LR  Policy LR Gumbl Temp |  trn loss     trn spar     trn shar trn total |   logloss   bceloss    aucroc     aucpr |  val loss     val spar     val shar  val total |  time |
  303 |   5.22e-04   5.22e-04   6.14e-03  1.056e-01 |    1.0073   5.1581e+00   3.3268e-02    6.1987 |   0.00017   0.77321   0.80570   0.80939 |   11.5919   5.1581e+00   3.3268e-02    16.7833 |  30.8 |


                                                                                                                        

  303 |   5.22e-04   5.22e-04   5.22e-03  1.056e-01 |    1.0719   4.3419e+00   5.1179e-02    5.4650 |   0.00017   0.77025   0.80573   0.80949 |   11.5467   4.3372e+00   5.1426e-02    15.9353 |  37.7 |

 epch: 303   softmax       sel        softmax        sel        softmax        sel 
 -----  -----------------  ---    -----------------  ---    -----------------  --- 
   1    0.6056     0.3944   1     0.5626     0.4374   1     0.6671     0.3329   1
   2    0.8708     0.1292   1     0.8891     0.1109   1     0.8866     0.1134   1
   3    0.8956     0.1044   1     0.9269     0.0731   1     0.9239     0.0761   1
   4    0.6771     0.3229   1     0.7268     0.2732   1     0.6526     0.3474   1
   5    0.5993     0.4007   1     0.8720     0.1280   1     0.7642     0.2358   1
   6    0.2841     0.7159   0     0.3163     0.6837   0     0.4919     0.5081   0




                                                                                                                        

  304 |   5.22e-04   5.22e-04   5.22e-03  1.056e-01 |    0.9896   4.3372e+00   5.1426e-02    5.3782 |   0.00017   0.77122   0.80601   0.81001 |   11.5611   4.3372e+00   5.1426e-02    15.9496 |  31.6 |


                                                                                                                        

  304 |   5.22e-04   5.22e-04   5.22e-03  1.056e-01 |    0.9935   3.7162e+00   3.2332e-02    4.7421 |   0.00017   0.77089   0.80596   0.80998 |   11.5561   3.7114e+00   3.2704e-02    15.3003 |  39.8 |
 decay gumbel softmax to 0.10194672193578354

 epch: 304   softmax       sel        softmax        sel        softmax        sel 
 -----  -----------------  ---    -----------------  ---    -----------------  --- 
   1    0.5212     0.4788   1     0.4724     0.5276   0     0.5757     0.4243   1
   2    0.8304     0.1696   1     0.8572     0.1428   1     0.8406     0.1594   1
   3    0.8767     0.1233   1     0.9092     0.0908   1     0.9028     0.0972   1
   4    0.5822     0.4178   1     0.6519     0.3481   1     0.5745     0.4255   1
   5    0.5675     0.4325   1     0.8218     0.1782   1     0.7162     0.2838   1
   6    0.2269     0.7731   0     0.2488     0.7512   0     0.3976     0.6024   0




                                                                                                                        

  305 |   5.22e-04   5.22e-04   5.22e-03  1.019e-01 |    1.0162   3.7114e+00   3.2704e-02    4.7603 |   0.00017   0.77212   0.80595   0.80969 |   11.5751   3.7114e+00   3.2704e-02    15.3192 |  37.4 |


                                                                                                                        

  305 |   5.22e-04   5.22e-04   5.22e-03  1.019e-01 |    0.9923   3.3011e+00   8.0778e-02    4.3743 |   0.00017   0.77083   0.80602   0.80965 |   11.5562   3.2988e+00   8.2866e-02    14.9378 |  40.4 |

 epch: 305   softmax       sel        softmax        sel        softmax        sel 
 -----  -----------------  ---    -----------------  ---    -----------------  --- 
   1    0.4658     0.5342   0     0.3960     0.6040   0     0.5274     0.4726   1
   2    0.7968     0.2032   1     0.8342     0.1658   1     0.7975     0.2025   1
   3    0.8750     0.1250   1     0.8764     0.1236   1     0.8839     0.1161   1
   4    0.5266     0.4734   1     0.5744     0.4256   1     0.5038     0.4962   1
   5    0.5236     0.4764   1     0.7995     0.2005   1     0.6579     0.3421   1
   6    0.1895     0.8105   0     0.2106     0.7894   0     0.3496     0.6504   0



[e] Policy training epoch:305  it:62645 -  Total Loss: 14.9378     
Task: 11.5562   Sparsity: 3.29875e+00    Sharing: 8.28657e-02 

 ep

                                                                                                                        

Epoch | BckBone LR   Heads LR  Policy LR Gumbl Temp |  trn loss     trn spar     trn shar trn total |   logloss   bceloss    aucroc     aucpr |  val loss     val spar     val shar  val total |  time |
  306 |   5.22e-04   5.22e-04   5.22e-03  1.019e-01 |    1.0568   3.2988e+00   8.2866e-02    4.4384 |   0.00017   0.77271   0.80601   0.80984 |   11.5839   3.2988e+00   8.2866e-02    14.9655 |  34.2 |


                                                                                                                        

  306 |   5.22e-04   5.22e-04   5.22e-03  1.019e-01 |    1.1142   3.0225e+00   9.1317e-02    4.2280 |   0.00017   0.76019   0.80631   0.81010 |   11.3956   3.0198e+00   9.0519e-02    14.5058 |  42.0 |
 decay gumbel softmax to 0.0983785866680311

 epch: 306   softmax       sel        softmax        sel        softmax        sel 
 -----  -----------------  ---    -----------------  ---    -----------------  --- 
   1    0.4155     0.5845   0     0.3501     0.6499   0     0.4987     0.5013   0
   2    0.7660     0.2340   1     0.8107     0.1893   1     0.7836     0.2164   1
   3    0.8601     0.1399   1     0.8537     0.1463   1     0.8683     0.1317   1
   4    0.4876     0.5124   0     0.5524     0.4476   1     0.4584     0.5416   0
   5    0.4891     0.5109   0     0.7700     0.2300   1     0.6040     0.3960   1
   6    0.1713     0.8287   0     0.1838     0.8162   0     0.2974     0.7026   0




                                                                                                                        

  307 |   5.22e-04   5.22e-04   5.22e-03  9.838e-02 |    1.0681   3.0198e+00   9.0519e-02    4.1784 |   0.00017   0.75930   0.80622   0.80982 |   11.3820   3.0198e+00   9.0519e-02    14.4923 |  30.8 |


                                                                                                                        

  307 |   5.22e-04   5.22e-04   5.22e-03  9.838e-02 |    0.9268   2.7388e+00   8.1611e-02    3.7472 |   0.00017   0.75930   0.80622   0.80982 |   11.3820   2.7364e+00   8.1994e-02    14.2004 |  37.9 |

 epch: 307   softmax       sel        softmax        sel        softmax        sel 
 -----  -----------------  ---    -----------------  ---    -----------------  --- 
   1    0.3674     0.6326   0     0.3366     0.6634   0     0.4422     0.5578   0
   2    0.7340     0.2660   1     0.7842     0.2158   1     0.7422     0.2578   1
   3    0.8358     0.1642   1     0.8214     0.1786   1     0.8529     0.1471   1
   4    0.4523     0.5477   0     0.5112     0.4888   1     0.4181     0.5819   0
   5    0.4346     0.5654   0     0.7511     0.2489   1     0.5703     0.4297   1
   6    0.1453     0.8547   0     0.1670     0.8330   0     0.2609     0.7391   0




                                                                                                                        

  308 |   5.22e-04   5.22e-04   5.22e-03  9.838e-02 |    1.1077   2.7364e+00   8.1994e-02    3.9260 |   0.00017   0.75998   0.80630   0.81005 |   11.3930   2.7364e+00   8.1994e-02    14.2114 |  32.3 |


                                                                                                                        

  308 |   5.22e-04   5.22e-04   5.22e-03  9.838e-02 |    0.9756   2.5267e+00   7.5562e-02    3.5779 |   0.00017   0.75306   0.80656   0.81036 |   11.2888   2.5245e+00   7.4913e-02    13.8882 |  37.0 |
 decay gumbel softmax to 0.09493533613465001

 epch: 308   softmax       sel        softmax        sel        softmax        sel 
 -----  -----------------  ---    -----------------  ---    -----------------  --- 
   1    0.3287     0.6713   0     0.2966     0.7034   0     0.4001     0.5999   0
   2    0.7100     0.2900   1     0.7591     0.2409   1     0.7019     0.2981   1
   3    0.8138     0.1862   1     0.7925     0.2075   1     0.8365     0.1635   1
   4    0.4327     0.5673   0     0.4848     0.5152   0     0.3831     0.6169   0
   5    0.4287     0.5713   0     0.7349     0.2651   1     0.5367     0.4633   1
   6    0.1322     0.8678   0     0.1582     0.8418   0     0.2334     0.7666   0




                                                                                                                        

  309 |   4.44e-04   4.44e-04   5.22e-03  9.494e-02 |    1.0159   2.5245e+00   7.4913e-02    3.6153 |   0.00016   0.75288   0.80655   0.81011 |   11.2855   2.5245e+00   7.4913e-02    13.8849 |  29.7 |


                                                                                                                        

  309 |   4.44e-04   4.44e-04   5.22e-03  9.494e-02 |    0.9990   2.3066e+00   1.0192e-01    3.4075 |   0.00016   0.74546   0.80658   0.81000 |   11.1739   2.3050e+00   1.0269e-01    13.5816 |  36.4 |

 epch: 309   softmax       sel        softmax        sel        softmax        sel 
 -----  -----------------  ---    -----------------  ---    -----------------  --- 
   1    0.3144     0.6856   0     0.2685     0.7315   0     0.3610     0.6390   0
   2    0.6906     0.3094   1     0.7067     0.2933   1     0.6973     0.3027   1
   3    0.7914     0.2086   1     0.7635     0.2365   1     0.8095     0.1905   1
   4    0.3974     0.6026   0     0.4701     0.5299   0     0.3507     0.6493   0
   5    0.3875     0.6125   0     0.6879     0.3121   1     0.4900     0.5100   0
   6    0.1257     0.8743   0     0.1449     0.8551   0     0.2117     0.7883   0




                                                                                                                        

  310 |   4.44e-04   4.44e-04   5.22e-03  9.494e-02 |    1.0287   2.3050e+00   1.0269e-01    3.4363 |   0.00016   0.74749   0.80663   0.81030 |   11.2045   2.3050e+00   1.0269e-01    13.6122 |  30.9 |


                                                                                                                        

  310 |   4.44e-04   4.44e-04   5.22e-03  9.494e-02 |    1.0759   2.1969e+00   1.3323e-01    3.4060 |   0.00016   0.74749   0.80663   0.81030 |   11.2045   2.1964e+00   1.3459e-01    13.5355 |  35.7 |
 decay gumbel softmax to 0.09161259936993725

 epch: 310   softmax       sel        softmax        sel        softmax        sel 
 -----  -----------------  ---    -----------------  ---    -----------------  --- 
   1    0.2815     0.7185   0     0.2464     0.7536   0     0.3279     0.6721   0
   2    0.7040     0.2960   1     0.6866     0.3134   1     0.7017     0.2983   1
   3    0.7762     0.2238   1     0.7366     0.2634   1     0.7809     0.2191   1
   4    0.4207     0.5793   0     0.4355     0.5645   0     0.3122     0.6878   0
   5    0.3803     0.6197   0     0.6785     0.3215   1     0.4692     0.5308   0
   6    0.1232     0.8768   0     0.1404     0.8596   0     0.1889     0.8111   0



[e] Policy training epoch:310  it:63725 -  Total Loss: 13.5355     
Task: 11.2045   Sparsi

                                                                                                                        

Epoch | BckBone LR   Heads LR  Policy LR Gumbl Temp |  trn loss     trn spar     trn shar trn total |   logloss   bceloss    aucroc     aucpr |  val loss     val spar     val shar  val total |  time |
  311 |   4.44e-04   4.44e-04   5.22e-03  9.161e-02 |    1.1216   2.1964e+00   1.3459e-01    3.4526 |   0.00016   0.74868   0.80647   0.81007 |   11.2217   2.1964e+00   1.3459e-01    13.5527 |  31.5 |


                                                                                                                        

  311 |   4.44e-04   4.44e-04   5.22e-03  9.161e-02 |    1.1043   2.1418e+00   1.0264e-01    3.3487 |   0.00017   0.75625   0.80640   0.81015 |   11.3355   2.1418e+00   1.0289e-01    13.5801 |  35.9 |

 epch: 311   softmax       sel        softmax        sel        softmax        sel 
 -----  -----------------  ---    -----------------  ---    -----------------  --- 
   1    0.2592     0.7408   0     0.2319     0.7681   0     0.3013     0.6987   0
   2    0.6769     0.3231   1     0.6976     0.3024   1     0.6788     0.3212   1
   3    0.7638     0.2362   1     0.7424     0.2576   1     0.7904     0.2096   1
   4    0.4188     0.5812   0     0.4206     0.5794   0     0.2844     0.7156   0
   5    0.3599     0.6401   0     0.6617     0.3383   1     0.5015     0.4985   1
   6    0.1102     0.8898   0     0.1332     0.8668   0     0.1777     0.8223   0




                                                                                                                        

  312 |   4.44e-04   4.44e-04   5.22e-03  9.161e-02 |    1.1877   2.1418e+00   1.0289e-01    3.4324 |   0.00017   0.75629   0.80658   0.81043 |   11.3367   2.1418e+00   1.0289e-01    13.5814 |  31.0 |


                                                                                                                        

  312 |   4.44e-04   4.44e-04   5.22e-03  9.161e-02 |    1.1739   2.0732e+00   1.6473e-01    3.4118 |   0.00016   0.74864   0.80664   0.81035 |   11.2217   2.0733e+00   1.6242e-01    13.4574 |  35.7 |
 decay gumbel softmax to 0.08840615839198944

 epch: 312   softmax       sel        softmax        sel        softmax        sel 
 -----  -----------------  ---    -----------------  ---    -----------------  --- 
   1    0.2571     0.7429   0     0.2193     0.7807   0     0.2710     0.7290   0
   2    0.6807     0.3193   1     0.7054     0.2946   1     0.6950     0.3050   1
   3    0.7350     0.2650   1     0.7467     0.2533   1     0.7745     0.2255   1
   4    0.3836     0.6164   0     0.3935     0.6065   0     0.2727     0.7273   0
   5    0.3529     0.6471   0     0.6274     0.3726   1     0.4970     0.5030   0
   6    0.1009     0.8991   0     0.1219     0.8781   0     0.1719     0.8281   0


[Final] ep:312  it:64157 -  Total Loss: 13.4574     
Task: 11.2217   Sparsity: 2.07326e+00 

In [36]:
print_loss(environ.val_metrics, title = f"[Final] ep:{current_epoch}  it:{current_iter}")
# environ.display_trained_policy(current_epoch)
# environ.display_trained_logits(current_epoch)
# environ.log_file.flush()

[Final] ep:287  it:58757 -  Total Loss: 15.5793     
Task: 11.2231   Sparsity: 4.35058e+00    Sharing: 5.58868e-03 


In [246]:
print( f" Backbone Learning Rate      : {environ.opt['train']['backbone_lr']}\n"
       f" Tasks    Learning Rate      : {environ.opt['train']['task_lr']}\n"
       f" Policy   Learning Rate      : {environ.opt['train']['policy_lr']}\n")


print( f" Sparsity regularization     : {environ.opt['train']['lambda_sparsity']}\n"
       f" Sharing  regularization     : {environ.opt['train']['lambda_sharing']} \n"
       f" Tasks    regularization     : {environ.opt['train']['lambda_tasks']}   \n"
       f" Gumbel Temp                 : {environ.gumbel_temperature:.4f}         \n" #
       f" Gumbel Temp decay           : {environ.opt['train']['decay_temp_freq']}") #

print( f" current_iters               : {current_iter}")  
print( f" current_epochs              : {current_epoch}") 
print( f" train_total_epochs          : {train_total_epochs}") 
print( f" stop_epoch_training         : {stop_epoch_training}")

 Backbone Learning Rate      : 0.001
 Tasks    Learning Rate      : 0.001
 Policy   Learning Rate      : 0.01

 Sparsity regularization     : 0.0
 Sharing  regularization     : 0.0001 
 Tasks    regularization     : 1.0   
 Gumbel Temp                 : 0.7234         
 Gumbel Temp decay           : 2


In [247]:
environ.opt['train']['lambda_sparsity'] = 0.0000
environ.opt['train']['lambda_sharing']  = 0.001
environ.opt['train']['lambda_tasks']    = 1.0
environ.opt['train']['decay_temp_freq'] = 2

In [250]:
print( f" Sparsity regularization     : {environ.opt['train']['lambda_sparsity']}\n"
       f" Sharing  regularization     : {environ.opt['train']['lambda_sharing']} \n"
       f" Tasks    regularization     : {environ.opt['train']['lambda_tasks']}   \n"
       f" Gumbel Temp                 : {environ.gumbel_temperature:.4f}         \n" #
       f" Gumbel Temp decay           : {environ.opt['train']['decay_temp_freq']}\n") #


print( f" current_iters               : {current_iter}")  
print( f" current_epochs              : {current_epoch}") 
print( f" train_total_epochs          : {train_total_epochs}") 
print( f" stop_epoch_training         : {stop_epoch_training}")

 Sparsity regularization     : 0.0
 Sharing  regularization     : 0.001 
 Tasks    regularization     : 1.0   
 Gumbel Temp                 : 0.7234         
 Gumbel Temp decay           : 2

 current_iters               : 38669
 current_epochs              : 192
 train_total_epochs          : 5
 stop_epoch_training         : 192


In [17]:
# train_total_epochs = 50
stop_epoch_training = current_epoch + train_total_epochs     
print(f"current_iters         : {current_iter}")  
print(f"current_epochs        : {current_epoch}") 
print(f"train_total_epochs    : {train_total_epochs}") 
print(f"stop_epoch_training   : {stop_epoch_training}")

current_iters         : 10908
current_epochs        : 51
train_total_epochs    : 50
stop_epoch_training   : 101


In [17]:
# pp.pprint(environ.losses)
# pp.pprint(environ.val_metrics)

In [23]:
# print_loss(current_iter, environ.losses, title = f"[e] Policy training epoch:{current_epoch}    iter:")
# print()
# print_loss(current_iter, trn_losses, title = f"[e] Policy training epoch:{current_epoch}    iter:")
# print()
# print_loss(current_iter, environ.val_metrics, title = f"[e] Policy training epoch:{current_epoch}    iter:")

In [24]:
# print_metrics_cr(current_epoch, time.time() - start_time, trn_losses, environ.val_metrics, 0, out=[sys.stdout])

In [2]:
# environ.losses
# environ.val_metrics
# environ.batch_data
# environ.display_parameters()
# environ_params = environ.get_task_specific_parameters()
# environ_params = environ.get_arch_parameters()
# environ_params = environ.get_backbone_parameters()
# print(environ_params)
# for param in environ_params:
#     print(param.grad.shape, '\n', param.grad)
#     print(param)

In [1]:
# print(environ.optimizers['alphas'])
# print(environ.optimizers['weights'])
# print(environ.optimizers['weights'].param_groups)

In [None]:
with np.printoptions(edgeitems=3, infstr='inf', linewidth=150, nanstr='nan', precision=7, formatter={'float': lambda x: f"{x:12.5e}"}):
    environ.print_logit_grads('gradients')


In [136]:
environ.display_trained_logits(current_epoch)
environ.display_trained_policy(current_epoch)


 epch: 312   logits        sel          logits       sel         logits        sel 
 -----  -----------------  ---    ----------------   ---    ----------------   --- 
   1   -0.4585     0.6026   0    -0.4601     0.8096   0    -0.4597     0.5298   0
   2   -0.1734    -0.9305   1    -0.1718    -1.0447   1    -0.1615    -0.9853   1
   3   -0.1602    -1.1804   1    -0.1241    -1.2051   1    -0.1391    -1.3729   1
   4   -0.3442     0.1300   0    -0.3490     0.0837   0    -0.3420     0.6391   0
   5   -0.2803     0.3260   0    -0.2921    -0.8131   1    -0.2676    -0.2554   0
   6   -0.4716     1.7158   0    -0.4718     1.5023   0    -0.4718     1.1004   0



 epch: 312   softmax       sel        softmax        sel        softmax        sel 
 -----  -----------------  ---    -----------------  ---    -----------------  --- 
   1    0.2571     0.7429   0     0.2193     0.7807   0     0.2710     0.7290   0
   2    0.6807     0.3193   1     0.7054     0.2946   1     0.6950     0.3050   1
   3

In [123]:
environ.display_test_sample_policy(current_epoch, hard_sampling = True)
environ.display_train_sample_policy(current_epoch, hard_sampling = True)

 Sample Policy (Testing mode - hard_sampling: True) 
 312 epochs  logits         sel        logits         sel         logits         sel 
 -----   ----------------  -----    ----------------  -----    ---------------   ----- 
   1    -0.4585    0.6026  [0 1]   -0.4601    0.8096  [0 1]   -0.4597    0.5298  [0 1]
   2    -0.1734   -0.9305  [1 0]   -0.1718   -1.0447  [1 0]   -0.1615   -0.9853  [1 0]
   3    -0.1602   -1.1804  [1 0]   -0.1241   -1.2051  [1 0]   -0.1391   -1.3729  [1 0]
   4    -0.3442    0.1300  [0 1]   -0.3490    0.0837  [0 1]   -0.3420    0.6391  [0 1]
   5    -0.2803    0.3260  [0 1]   -0.2921   -0.8131  [1 0]   -0.2676   -0.2554  [0 1]
   6    -0.4716    1.7158  [0 1]   -0.4718    1.5023  [0 1]   -0.4718    1.1004  [0 1]


 Sample Policy (Training mode - hard_sampling: True) 
 312 epochs    logits          gumbel                logits           gumbel               logits             gumbel 
 -----   ----------------------------------     -----------------------------

In [117]:
environ.display_test_sample_policy(current_epoch, hard_sampling = False)
environ.display_train_sample_policy(current_epoch, hard_sampling = False)

 Sample Policy (Testing mode - hard_sampling: False) 
 312 epochs  logits         sel        logits         sel         logits         sel 
 -----   ----------------  -----    ----------------  -----    ---------------   ----- 
   1    -0.4585    0.6026  [0 1]   -0.4601    0.8096  [0 1]   -0.4597    0.5298  [1 0]
   2    -0.1734   -0.9305  [1 0]   -0.1718   -1.0447  [0 1]   -0.1615   -0.9853  [1 0]
   3    -0.1602   -1.1804  [1 0]   -0.1241   -1.2051  [1 0]   -0.1391   -1.3729  [1 0]
   4    -0.3442    0.1300  [1 0]   -0.3490    0.0837  [0 1]   -0.3420    0.6391  [1 0]
   5    -0.2803    0.3260  [0 1]   -0.2921   -0.8131  [1 0]   -0.2676   -0.2554  [0 1]
   6    -0.4716    1.7158  [1 0]   -0.4718    1.5023  [0 1]   -0.4718    1.1004  [0 1]


 Sample Policy (Training mode - hard_sampling: False) 
 312 epochs    logits          gumbel                logits           gumbel               logits             gumbel 
 -----   ----------------------------------     ---------------------------

In [137]:
environ.opt['train']['hard_sampling']

True

In [12]:
# wandb.finish()

## Post Training Stuff

In [2]:
wandb.finish()

NameError: name 'wandb' is not defined

In [9]:
# 
p = environ.get_current_state(0)

In [11]:
pp.pprint(p)

{   'alphas': {   'param_groups': [   {   'amsgrad': False,
                                          'betas': (0.9, 0.999),
                                          'eps': 1e-08,
                                          'initial_lr': 0.01,
                                          'lr': 0.01,
                                          'params': [0, 1, 2],
                                          'weight_decay': 0.0005}],
                  'state': {}},
    'iter': 0,
    'mtl-net': OrderedDict([   (   'task1_logits',
                                   tensor([[ 0.001544,  0.002121],
        [ 0.000495,  0.001111],
        [ 0.000426, -0.000401],
        [ 0.000426, -0.000470],
        [ 0.000497,  0.000899],
        [ 0.002776, -0.000226]])),
                               (   'task2_logits',
                                   tensor([[ 0.000482, -0.001070],
        [-0.000505, -0.001076],
        [ 0.000701, -0.000132],
        [-0.000171, -0.001668],
        [ 0.000262, -0.000324]

### Post Warm-up Training stuff

In [15]:
pp.pprint(environ.val_metrics)

{   'aggregated': {   'auc_pr': 0.5683542231355692,
                      'avg_prec_score': 0.5686955149511226,
                      'bceloss': 0.6940069357554117,
                      'f1_max': 0.6753742658177068,
                      'kappa': 0.08318090860174578,
                      'kappa_max': 0.11385237142207795,
                      'logloss': tensor(0.0002, device='cuda:0', dtype=torch.float64),
                      'p_f1_max': 0.29707588851451877,
                      'p_kappa_max': 0.48534467220306404,
                      'roc_auc_score': 0.5679242272843051,
                      'sc_loss': tensor(0.2891, device='cuda:0', dtype=torch.float64)},
    'epoch': 1,
    'loss': {   'task1': 3.5098743182605396,
                'task2': 3.4508372449718867,
                'task3': 3.4480013587242775,
                'total': 10.408712921956704},
    'loss_mean': {   'task1': 0.701974863652108,
                     'task2': 0.6901674489943774,
                     'task3': 0.

In [16]:
environ.networks['mtl-net'].arch_parameters()

[Parameter containing:
 tensor([[0.5000, 0.5000],
         [0.5000, 0.5000],
         [0.5000, 0.5000]], device='cuda:0'),
 Parameter containing:
 tensor([[0.5000, 0.5000],
         [0.5000, 0.5000],
         [0.5000, 0.5000]], device='cuda:0'),
 Parameter containing:
 tensor([[0.5000, 0.5000],
         [0.5000, 0.5000],
         [0.5000, 0.5000]], device='cuda:0')]

In [31]:
p = environ.get_sample_policy(hard_sampling = False)
print(p)
p = environ.get_policy_prob()
print(p)
p = environ.get_policy_logits()
print(p)

# p = environ.get_current_policy()
# print(p)

([tensor([[1, 0],
        [0, 1],
        [1, 0],
        [1, 0],
        [0, 1],
        [0, 1]], device='cuda:0'), tensor([[0, 1],
        [1, 0],
        [1, 0],
        [1, 0],
        [1, 0],
        [0, 1]], device='cuda:0'), tensor([[1, 0],
        [1, 0],
        [0, 1],
        [1, 0],
        [1, 0],
        [0, 1]], device='cuda:0')], [array([[0.5, 0.5],
       [0.5, 0.5],
       [0.5, 0.5],
       [0.5, 0.5],
       [0.5, 0.5],
       [0.5, 0.5]], dtype=float32), array([[0.5, 0.5],
       [0.5, 0.5],
       [0.5, 0.5],
       [0.5, 0.5],
       [0.5, 0.5],
       [0.5, 0.5]], dtype=float32), array([[0.5, 0.5],
       [0.5, 0.5],
       [0.5, 0.5],
       [0.5, 0.5],
       [0.5, 0.5],
       [0.5, 0.5]], dtype=float32)])
[array([[0.5, 0.5],
       [0.5, 0.5],
       [0.5, 0.5],
       [0.5, 0.5],
       [0.5, 0.5],
       [0.5, 0.5]], dtype=float32), array([[0.5, 0.5],
       [0.5, 0.5],
       [0.5, 0.5],
       [0.5, 0.5],
       [0.5, 0.5],
       [0.5, 0.5]], dtype=floa

In [42]:
a = softmax([0.0, 1])
print(a)
sampled = np.random.choice((1, 0), p=a)
print(sampled)

[0.26894142 0.73105858]
1


In [20]:
print(environ.optimizers['weights'])
print(environ.schedulers['weights'].get_last_lr())

Adam (
Parameter Group 0
    amsgrad: False
    betas: (0.5, 0.999)
    eps: 1e-08
    initial_lr: 0.01
    lr: 0.01
    weight_decay: 0.0001

Parameter Group 1
    amsgrad: False
    betas: (0.5, 0.999)
    eps: 1e-08
    initial_lr: 0.01
    lr: 0.01
    weight_decay: 0.0001
)
[0.01, 0.01]


In [17]:
print('losses.keys      : ', environ.losses.keys())
print('losses[task]keys : ', environ.losses['task1'].keys())
pp.pprint(environ.losses)

losses.keys      :  dict_keys(['parms', 'losses', 'losses_mean', 'sparsity', 'sharing', 'total', 'total_mean', 'task1', 'task2', 'task3'])
losses[task]keys :  dict_keys(['cls_loss', 'cls_loss_mean'])
{   'losses': {   'task1': tensor(3.5704, device='cuda:0', dtype=torch.float64),
                  'task2': tensor(3.3313, device='cuda:0', dtype=torch.float64),
                  'task3': tensor(3.3721, device='cuda:0', dtype=torch.float64),
                  'total': tensor(10.2737, device='cuda:0', dtype=torch.float64)},
    'losses_mean': {   'task1': tensor(0.7141, device='cuda:0', dtype=torch.float64),
                       'task2': tensor(0.6663, device='cuda:0', dtype=torch.float64),
                       'task3': tensor(0.6744, device='cuda:0', dtype=torch.float64),
                       'total': tensor(2.0547, device='cuda:0', dtype=torch.float64)},
    'parms': {'gumbel_temp': 5, 'lr_0': 0.01, 'lr_1': 0.01, 'train_layers': 0},
    'sharing': {'total': tensor(0., device='cuda:

In [20]:
print( environ.val_metrics.keys())
# pp.pprint(val_metrics)
print(type(environ.val_metrics['aggregated']))
print()
print(type(environ.val_metrics['task1']['classification_agg']))
print()
pp.pprint(environ.val_metrics)

dict_keys(['loss', 'loss_mean', 'task1', 'task2', 'task3', 'aggregated', 'train_time', 'epoch'])
<class 'dict'>

<class 'dict'>

{   'aggregated': {   'auc_pr': 0.8163426647572675,
                      'avg_prec_score': 0.8164138615123665,
                      'bceloss': 0.7729340473810831,
                      'f1_max': 0.7574111413660235,
                      'kappa': 0.4593877004350045,
                      'kappa_max': 0.4722816344851592,
                      'logloss': tensor(0.0002, device='cuda:0', dtype=torch.float64),
                      'p_f1_max': 0.1918067594369252,
                      'p_kappa_max': 0.545021508137385,
                      'roc_auc_score': 0.8101998985724014,
                      'sc_loss': tensor(0.3217, device='cuda:0', dtype=torch.float64)},
    'epoch': 4000,
    'loss': {   'task1': 3.990690022259455,
                'task2': 3.7726694706664947,
                'task3': 3.817710672775008,
                'total': 11.581070165700957},
    'l

In [20]:
# import pickle
# with open("val_metrics.pkl", mode= 'wb') as f:
#         pickle.dump(val_metrics, f)
    
# with open('val_metrics.pkl', 'rb') as f:    
#     tst_val_metrics = pickle.load(f)

In [21]:
# print(environ.input.shape) 
# a = getattr(environ, 'task1_pred')
# yc_data = environ.batch['task1_data']
# print(yc_data.shape)
# yc_ind = environ.batch['task1_ind']
# print(yc_ind.shape)
# yc_hat_all = getattr(environ, 'task1_pred')
# print(yc_hat_all.shape)
# yc_hat  = yc_hat_all[yc_ind[0], yc_ind[1]]
# print(yc_hat_all.shape, yc_hat.shape)

# 
# environ.losses
# loss = {}
# for key in environ.losses.keys():
#     loss[key] = {}
#     for subkey, v in environ.losses[key].items():
#         print(f" key:  {key}   subkey: {subkey} ")
#         if isinstance(v, torch.Tensor):
#             loss[key][subkey] = v.data
#             print(f" Tensor  -  key:  {key}   subkey: {subkey}           value type: {type(v)}  value: {v:.4f}")
#         else:
#             loss[key][subkey] = v
#             print(f" integer -  key:  {key}   subkey: {subkey}           value type: {type(v)}  value: {v:.4f}")
# pp.pprint(tst_val_metrics)             

In [22]:
# print('metrics.keys: ', environ.metrics.keys())
# print('metrics[task].keys: ', environ.metrics['task1'].keys())
# pp.pprint(environ.metrics['task1'])
# pp.pprint(environ.losses['task1']['total'])

In [23]:
# title='Iteration'
# for t_id, _ in enumerate(environ.tasks):
#     task_key = f"task{t_id+1}"
# #     print_heading(f"{title}  {current_iter}  {task_key} : {val_metrics[task_key]['classification_agg']}", verbose = True)

#     for key, _  in val_metrics[task_key]['classification_agg'].items():
#         print('%s/%-20s'%(task_key, key), val_metrics[task_key]['classification_agg'][key], current_iter)
#         print(f"{task_key:s}/{key:20s}", val_metrics[task_key]['classification_agg'][key], current_iter)
#         print()
#             # print_current_errors(os.path.join(self.log_dir, 'loss.txt'), current_iter,key, loss[key], time.time() - start_time)

In [32]:
# environ.print_loss(current_iter, start_time, metrics = val_metrics['loss'], verbose=True)
# print(opt['lambdas'])
# p = (opt['lambdas'][0] * environ.losses['tasks']['task1'])
# print(p)

# environ.print_val_metrics(current_iter, start_time, val_metrics , title='validation', verbose=True)    

In [30]:
# print(current_iter)
# print_metrics_cr(current_iter, t1 - t0, None, val_metrics , True)
# environ.print_val_metrics(current_iter, start_time, val_metrics, title='validation', verbose = True)

In [21]:
print(f" val_metric keys               : {val_metrics.keys()}")
print(f" loss keys                     : {val_metrics['loss'].keys()}")
print(f" task1 keys                    : {val_metrics['task1'].keys()}")
print(f" task1 classification keys     : {val_metrics['task1']['classification'].keys()}")
print(f" task1 classification_agg keys : {val_metrics['task1']['classification_agg'].keys()}")
print()
print(f" task1                       : {val_metrics['task1']['classification_agg']['loss']:5f}")
print(f" task2                       : {val_metrics['task2']['classification_agg']['loss']:5f}")
print(f" task3                       : {val_metrics['task3']['classification_agg']['loss']:5f}")
print(f" loss                        : {val_metrics['loss']['total']:5f}")
print(f" train_time                  : {val_metrics['train_time']:2f}")
print(f" epoch                       : {val_metrics['epoch']}")


AttributeError: 'NoneType' object has no attribute 'keys'

### Post Weight + Policy Training Stuff 

In [178]:
environ.networks['mtl-net'].backbone.layer_config

[1, 1, 1, 1, 1, 1]

In [63]:
num_blocks = 6
num_policy_layers = 6
gt =  torch.ones((num_blocks)).long()
gt0 =  torch.zeros((num_blocks)).long()
print(gt)
print(gt0)

loss_weights = ((torch.arange(0, num_policy_layers, 1) + 1).float() / num_policy_layers)
print(loss_weights)

tensor([1, 1, 1, 1, 1, 1])
tensor([0, 0, 0, 0, 0, 0])
tensor([0.1667, 0.3333, 0.5000, 0.6667, 0.8333, 1.0000])


In [77]:
if environ.opt['diff_sparsity_weights'] and not environ.opt['is_sharing']:
    print(' cond 1')
    ## Assign higher weights to higher layers 
    loss_weights = ((torch.arange(0, num_policy_layers, 1) + 1).float() / num_policy_layers)
    print(f"{task_key} sparsity error:  {2 * (loss_weights[-num_blocks:] * environ.cross_entropy2(logits[-num_blocks:], gt)).mean()})")
    print_dbg(f" loss_weights :  {loss_weights}", verbose = True)
    print_dbg(f" cross_entropy:  {environ.cross_entropy2(logits[-num_blocks:], gt)}  ", verbose = True)
    print_dbg(f" loss[sparsity][{task_key}]: {self.losses['sparsity'][task_key] } ", verbose = True)

else:
    print('\n cond 2')
    print_dbg(f"Compute CrossEntropyLoss between \n Logits   : \n{logits[-num_blocks:]} \n and gt: \n{gt} \n", verbose = True)
    print(f"{task_key} sparsity error:  {environ.cross_entropy_sparsity(logits[-num_blocks:], gt)}")
    
    print('\n cond 2')
    print_dbg(f"Compute CrossEntropyLoss between Logits      : {logits[-1:]}  and gt: {gt[-1]} ", verbose = True)
    print(f"{task_key} sparsity error:  {environ.cross_entropy_sparsity(logits[-1:], gt[-1:])} \n")
    print_dbg(f"Compute CrossEntropyLoss between Logits      : {logits[-1:]}  and gt: {gt0[-1]} ", verbose = True)
    print(f"{task_key} sparsity error:  {environ.cross_entropy_sparsity(logits[-1:], gt0[-1:])} \n")
    
    print('\n cond 3')    
    print_dbg(f"Compute CrossEntropyLoss between Logits   : {logits[0:1]}  and gt: {gt[0:1]} ", verbose = True)
    print(f"{task_key} sparsity error:  {environ.cross_entropy_sparsity(logits[0:1], gt[0:1])} \n")
    print_dbg(f"Compute CrossEntropyLoss between Logits   : {logits[0:1]}  and gt: {gt0[0:1]} ", verbose = True)
    print(f"{task_key} sparsity error:  {environ.cross_entropy_sparsity(logits[0:1], gt0[0:1])} \n")
        
        

cond 2
Compute CrossEntropyLoss between 
 Logits   : 
tensor([[0.3306, 0.4518],
        [0.3532, 0.5529],
        [0.3888, 0.6125],
        [0.4204, 0.7685],
        [0.4520, 0.7994],
        [0.4840, 0.8021]]) 
 and gt: 
tensor([1, 1, 1, 1, 1, 1]) 

task1_logits sparsity error:  0.5725929141044617

 cond 2
Compute CrossEntropyLoss between Logits      : tensor([[0.4840, 0.8021]])  and gt: 1 
task1_logits sparsity error:  0.5467103123664856 

Compute CrossEntropyLoss between Logits      : tensor([[0.4840, 0.8021]])  and gt: 0 
task1_logits sparsity error:  0.864768385887146 


 cond 3
Compute CrossEntropyLoss between Logits   : tensor([[0.3306, 0.4518]])  and gt: tensor([1]) 
task1_logits sparsity error:  0.634384036064148 

Compute CrossEntropyLoss between Logits   : tensor([[0.3306, 0.4518]])  and gt: tensor([0]) 
task1_logits sparsity error:  0.7555801868438721 



In [24]:
# flag = 'update_w'
# environ.fix_alpha
# environ.free_w(opt['fix_BN'])

flag = 'update_alpha'
environ.fix_weights()
environ.free_alpha()

In [29]:
environ.networks['mtl-net'].num_layers

6

In [25]:
print(f"current_iters         : {current_iter}")  
print(f"current_epochs           : {current_epoch}") 
print(f"train_total_epochs    : {train_total_epochs}") 

train_total_epochs += 5

print(f"current_iters         : {current_iter}")  
print(f"current_epochs           : {current_epoch}") 
print(f"train_total_epochs    : {train_total_epochs}") 

current_iters         : 6580
curr_epochs           : 60
train_total_epochs    : 60


In [25]:
# print_metrics_cr(current_epoch, time.time() - t0, None, environ.val_metrics , num_prints)      

# num_prints += 1
# t0 = time.time()

# # Take check point
# environ.save_checkpoint('latest', current_iter)
# environ.train()
# #-------------------------------------------------------
# # END validation process
# #-------------------------------------------------------       
# flag = 'update_alpha'
# environ.fix_w()
# environ.free_alpha()

Epoch | logloss bceloss  aucroc   aucpr  f1_max| t1 loss t2 loss t3 lossttl loss|tr_time|
1     | 0.00020 0.89942 0.81016 0.81627 0.75738|  4.6591  4.3649  4.4500 13.4740| 333.4|

In [71]:
# dilation = 2
# kernel_size = np.asarray((3, 3))
# upsampled_kernel_size = (kernel_size - 1) * (dilation - 1) + kernel_size
# print(upsampled_kernel_size)

In [30]:
# environ.optimizers['weights'].param_groups[0]
# for param_group in optimizer.param_groups:
#     return param_group['lr']

In [31]:
environ.schedulers['weights'].get_last_lr()

[1e-05, 1e-05]

In [86]:
current_state = {}
for k, v in environ.optimizers.items():
    print(f'state dict for {k} = {v}')
    current_state[k] = v.state_dict()
pp.pprint(current_state)

state dict for weights = SGD (
Parameter Group 0
    dampening: 0
    lr: 0.0001
    momentum: 0.9
    nesterov: False
    weight_decay: 0.0001

Parameter Group 1
    dampening: 0
    lr: 0.0001
    momentum: 0.9
    nesterov: False
    weight_decay: 0.0001
)
state dict for alphas = Adam (
Parameter Group 0
    amsgrad: False
    betas: (0.9, 0.999)
    eps: 1e-08
    lr: 0.0001
    weight_decay: 0.0005
)
{   'alphas': {   'param_groups': [   {   'amsgrad': False,
                                          'betas': (0.9, 0.999),
                                          'eps': 1e-08,
                                          'lr': 0.0001,
                                          'params': [0, 1, 2],
                                          'weight_decay': 0.0005}],
                  'state': {   0: {   'exp_avg': tensor([[ 0.0607, -0.0007],
        [-0.0428, -0.0069],
        [-0.1218,  0.0138],
        [ 0.0086,  0.0238]], device='cuda:0'),
                                      'exp_

In [88]:
current_state = {}
for k, v in environ.schedulers.items():
    print(f'state dict for {k} = {v}')
    print(v.state_dict())

state dict for weights = <torch.optim.lr_scheduler.StepLR object at 0x7f90c01c0ca0>
{'step_size': 4000, 'gamma': 0.5, 'base_lrs': [0.0001, 0.0001], 'last_epoch': 9100, '_step_count': 9101, 'verbose': False, '_get_lr_called_within_step': False, '_last_lr': [2.5e-05, 2.5e-05]}


### Losses and Metrics

In [27]:
trn_losses = environ.losses

In [39]:
print_metrics_cr(current_epoch, time.time() - start_time, trn_losses, environ.val_metrics , num_prints)      

Epoch |  trn loss     trn spar     trn shar    trn total |   logloss   bceloss    aucroc     aucpr |  val loss     val spar     val shar    val total |tr_time |
12    |    7.3148   1.0717e+00   2.6055e-03       8.3891 |   0.00021   0.97756   0.68488   0.68055 |   14.6682   1.0160e+00   2.1285e-03      15.6863 |  227.5 |

In [27]:
# print_metrics_cr(current_epoch, time.time() - start_time, trn_losses, environ.val_metrics , num_prints)      

In [26]:
# pp.pprint(environ.losses)
pp.pprint(trn_losses)

{   'losses': {   'task1': tensor(2.450312, device='cuda:0', dtype=torch.float64, grad_fn=<DivBackward0>),
                  'task2': tensor(2.467013, device='cuda:0', dtype=torch.float64, grad_fn=<DivBackward0>),
                  'task3': tensor(2.687793, device='cuda:0', dtype=torch.float64, grad_fn=<DivBackward0>),
                  'total': tensor(7.605118, device='cuda:0', dtype=torch.float64, grad_fn=<AddBackward0>)},
    'losses_mean': {   'task1': tensor(0.490062, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>),
                       'task2': tensor(0.493403, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>),
                       'task3': tensor(0.537559, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>),
                       'total': tensor(1.521024, device='cuda:0', dtype=torch.float64, grad_fn=<AddBackward0>)},
    'parms': {   'gumbel_temp': 2.6244000000000005,
                 'lr_0': 0.01,
                 'lr_1': 0.01,
        

In [38]:
pp.pprint(environ.val_metrics)

{   'aggregated': {   'auc_pr': 0.6805474562906618,
                      'avg_prec_score': 0.6807816533341606,
                      'bceloss': 0.9775612473487854,
                      'f1_max': 0.6969378498338702,
                      'kappa': 0.2500278499434187,
                      'kappa_max': 0.29669160426755137,
                      'logloss': tensor(0.000214, device='cuda:0', dtype=torch.float64),
                      'p_f1_max': 0.2160851760388027,
                      'p_kappa_max': 0.5119348396857579,
                      'roc_auc_score': 0.6848840436108541,
                      'sc_loss': tensor(0.407449, device='cuda:0', dtype=torch.float64)},
    'epoch': 12,
    'loss': {   'task1': 3.91353933321791,
                'task2': 4.751992322450581,
                'task3': 6.002632442335283,
                'total': 14.668164098003771},
    'loss_mean': {   'task1': 0.782707866643582,
                     'task2': 0.9503984644901166,
                     'task3': 1.20

In [187]:
# environ.opt['train']['Lambda_sharing'] = 0.5
# opt['train']['Lambda_sharing'] = 0.5

# environ.opt['train']['policy_lr'] = 0.001
# opt['train']['policy_lr'] = 0.001

In [38]:
environ.losses.keys()
pp.pprint(environ.losses)

{   'losses': {   'task1': tensor(2.3789, device='cuda:0', dtype=torch.float64, grad_fn=<DivBackward0>),
                  'task2': tensor(2.1948, device='cuda:0', dtype=torch.float64, grad_fn=<DivBackward0>),
                  'task3': tensor(2.3476, device='cuda:0', dtype=torch.float64, grad_fn=<DivBackward0>),
                  'total': tensor(6.9212, device='cuda:0', dtype=torch.float64, grad_fn=<AddBackward0>)},
    'losses_mean': {   'task1': tensor(0.4758, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>),
                       'task2': tensor(0.4390, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>),
                       'task3': tensor(0.4695, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>),
                       'total': tensor(1.3842, device='cuda:0', dtype=torch.float64, grad_fn=<AddBackward0>)},
    'parms': {'gumbel_temp': 2.6330653896376153, 'lr_0': 0.005, 'lr_1': 0.005},
    'sharing': {   'total': tensor(2.1338e-06, device='cud

In [79]:
tmp = environ.get_loss_dict()
print(tmp.keys())
pp.pprint(tmp)

dict_keys(['task1', 'task2', 'task3', 'tasks', 'total', 'sharing', 'sparsity'])
{   'sharing': {'total': tensor(0.0002, device='cuda:0')},
    'sparsity': {   'task1_logits': tensor(0.5147, device='cuda:0'),
                    'task2_logits': tensor(0.5195, device='cuda:0'),
                    'task3_logits': tensor(0.4978, device='cuda:0'),
                    'total': tensor(0.0766, device='cuda:0')},
    'task1': {'total': tensor(3.3657, device='cuda:0', dtype=torch.float64)},
    'task2': {'total': tensor(3.5906, device='cuda:0', dtype=torch.float64)},
    'task3': {'total': tensor(3.2182, device='cuda:0', dtype=torch.float64)},
    'tasks': {   'task1': tensor(3.3657, device='cuda:0', dtype=torch.float64),
                 'task2': tensor(3.5906, device='cuda:0', dtype=torch.float64),
                 'task3': tensor(3.2182, device='cuda:0', dtype=torch.float64),
                 'total': tensor(10.1745, device='cuda:0', dtype=torch.float64)},
    'total': {'total': tensor(10.25

In [188]:
print(opt['diff_sparsity_weights'])
print(opt['is_sharing'])
print(opt['diff_sparsity_weights'] and not opt['is_sharing'])
print(environ.opt['train']['Lambda_sharing'])
print(opt['train']['Lambda_sharing'])
print(environ.opt['train']['Lambda_sparsity'])
print(opt['train']['Lambda_sparsity'])
print(environ.opt['train']['policy_lr'])
print(opt['train']['policy_lr'])

True
True
False
0.5
0.5
0.05
0.05
0.001
0.001


### Policy / Logit stuff

In [93]:
from scipy.special          import softmax

In [33]:
np.set_printoptions(precision=8,edgeitems=3, infstr='inf', linewidth=150, nanstr='nan')
torch.set_printoptions(precision=8,linewidth=132)

#### `get_task_logits(n)` Get logits for task group n

In [128]:
task_logits = environ.get_task_logits(1)
print(task_logits)

Parameter containing:
tensor([[-0.00035114, -0.06397165],
        [ 0.00056738, -0.03663344],
        [ 0.00056098, -0.02617791],
        [-0.00044851, -0.07137010],
        [ 0.00013184, -0.05879313],
        [ 0.00079021, -0.05743587]], device='cuda:0')


#### `get_arch_parameters()`: Get last used logits from network

In [34]:
import torch.optim as optim
arch_parameters      = environ.get_arch_parameters()
print(arch_parameters)

[Parameter containing:
tensor([[-0.00035120, -0.06617914],
        [ 0.00056736, -0.04341661],
        [ 0.00056091, -0.01096974],
        [-0.00044879, -0.01083876],
        [ 0.00013163,  0.00874004],
        [ 0.00079006, -0.00861552]], device='cuda:0'), Parameter containing:
tensor([[-0.00035114, -0.06397165],
        [ 0.00056738, -0.03663344],
        [ 0.00056098, -0.02617791],
        [-0.00044851, -0.07137010],
        [ 0.00013184, -0.05879313],
        [ 0.00079021, -0.05743587]], device='cuda:0'), Parameter containing:
tensor([[-0.00035016, -0.06321616],
        [ 0.00056696, -0.03072025],
        [ 0.00056129, -0.01022454],
        [-0.00044983, -0.00021709],
        [ 0.00013071,  0.00484093],
        [ 0.00078938, -0.02230957]], device='cuda:0')]


In [25]:
import torch.optim as optim
arch_parameters      = environ.get_arch_parameters()
print(arch_parameters)

[Parameter containing:
tensor([[ 1.873275e-03, -5.276022e-01],
        [ 2.345233e-03, -2.740704e-01],
        [ 2.614364e-03,  1.604760e-02],
        [ 2.143114e-04,  2.198091e-02],
        [ 4.191113e-04,  5.969038e-02],
        [ 2.007700e-03,  3.544179e-02]], device='cuda:0'), Parameter containing:
tensor([[ 1.873281e-03, -4.892288e-01],
        [ 2.345207e-03, -2.255457e-01],
        [ 2.614349e-03, -2.191145e-01],
        [ 2.143144e-04, -3.354620e-01],
        [ 4.190930e-04, -3.310193e-01],
        [ 2.007697e-03, -2.532191e-01]], device='cuda:0'), Parameter containing:
tensor([[ 1.873283e-03, -6.248206e-01],
        [ 2.345208e-03, -2.149665e-01],
        [ 2.614360e-03, -1.423603e-01],
        [ 2.143196e-04, -1.089546e-01],
        [ 4.191188e-04, -7.532501e-02],
        [ 2.007698e-03, -1.407905e-01]], device='cuda:0')]


#### `get_policy_logits()`:  Get Policy Logits - returns same as `get_arch_parameters()`

In [26]:
logs = environ.get_policy_logits()
for i in logs:
    print(i, '\n')
# probs = softmax(logs, axis= -1)
# for i in probs:
#     print(i, '\n')

[[ 1.8732749e-03 -5.2760220e-01]
 [ 2.3452332e-03 -2.7407044e-01]
 [ 2.6143640e-03  1.6047601e-02]
 [ 2.1431143e-04  2.1980910e-02]
 [ 4.1911125e-04  5.9690382e-02]
 [ 2.0077003e-03  3.5441790e-02]] 

[[ 1.8732806e-03 -4.8922884e-01]
 [ 2.3452067e-03 -2.2554573e-01]
 [ 2.6143489e-03 -2.1911447e-01]
 [ 2.1431442e-04 -3.3546203e-01]
 [ 4.1909298e-04 -3.3101928e-01]
 [ 2.0076970e-03 -2.5321913e-01]] 

[[ 1.8732828e-03 -6.2482059e-01]
 [ 2.3452076e-03 -2.1496648e-01]
 [ 2.6143598e-03 -1.4236034e-01]
 [ 2.1431963e-04 -1.0895463e-01]
 [ 4.1911882e-04 -7.5325012e-02]
 [ 2.0076977e-03 -1.4079048e-01]] 



#### `get_policy_prob()` : Gets the softmax of the logits

In [27]:
policy_softmaxs = environ.get_policy_prob()
for i in policy_softmaxs:
    print(i, '\n')

[[0.6293608  0.37063923]
 [0.5686673  0.43133274]
 [0.49664173 0.5033582 ]
 [0.4945586  0.5054415 ]
 [0.48518652 0.5148135 ]
 [0.49164233 0.50835776]] 

[[0.62036604 0.379634  ]
 [0.55672747 0.44327256]
 [0.5552062  0.44479376]
 [0.58313996 0.4168601 ]
 [0.58210933 0.41789067]
 [0.5634626  0.4365374 ]] 

[[0.6517394  0.34826055]
 [0.5541151  0.44588488]
 [0.5361803  0.46381968]
 [0.5272652  0.47273484]
 [0.518927   0.48107296]
 [0.535639   0.464361  ]] 



#### `get_sample_policy( hard_sampling = False)` : Calls test_sample_policy of network with random choices based on softmax of logits

In [382]:
policy_softmaxs = environ.get_policy_prob()
policies,logits = environ.get_sample_policy(hard_sampling = False)

for l, p, s in zip(logits, policies, policy_softmaxs) :
    for  l_row, p_row, s_row in zip(l, p, s):
        print( l_row,'\t', p_row, '\t', s_row)
    print('\n')

[-0.00035120 -0.06617914] 	 [0 1] 	 [0.51645106 0.48354897]
[ 0.00056736 -0.04341661] 	 [0 1] 	 [0.51099426 0.48900577]
[ 0.00056091 -0.01096974] 	 [1 0] 	 [0.50288266 0.49711737]
[-0.00044879 -0.01083876] 	 [0 1] 	 [0.50259751 0.49740252]
[0.00013163 0.00874004] 	 [0 1] 	 [0.49784794 0.50215214]
[ 0.00079006 -0.00861552] 	 [0 1] 	 [0.50235140 0.49764863]


[-0.00035114 -0.06397165] 	 [0 1] 	 [0.51589972 0.48410025]
[ 0.00056738 -0.03663344] 	 [1 0] 	 [0.50929916 0.49070087]
[ 0.00056098 -0.02617791] 	 [1 0] 	 [0.5066843 0.4933157]
[-0.00044851 -0.07137010] 	 [0 1] 	 [0.51772296 0.48227707]
[ 0.00013184 -0.05879313] 	 [0 1] 	 [0.514727 0.485273]
[ 0.00079021 -0.05743587] 	 [0 1] 	 [0.51455247 0.48544762]


[-0.00035016 -0.06321616] 	 [0 1] 	 [0.51571137 0.48428872]
[ 0.00056696 -0.03072025] 	 [1 0] 	 [0.50782120 0.49217883]
[ 0.00056129 -0.01022454] 	 [1 0] 	 [0.50269639 0.49730355]
[-0.00044983 -0.00021709] 	 [0 1] 	 [0.4999418 0.5000582]
[0.00013071 0.00484093] 	 [1 0] 	 [0.49882248 

#### `get_sample_policy( hard_sampling = True)` : Calls test_sample_policy of network using ARGMAX of logits

In [131]:
policy_softmaxs = environ.get_policy_prob()
hard_policies, logits = environ.get_sample_policy(hard_sampling = True)

for p,l,s in zip(hard_policies, logits, policy_softmaxs) :
    for  p_row, l_row, s_row in zip(p, l, s):
        print( l_row,'\t', p_row, '\t', s_row)
    print('\n')

[-0.0003512  -0.06617914] 	 [1 0] 	 [0.51645106 0.48354897]
[ 0.00056736 -0.04341661] 	 [1 0] 	 [0.51099426 0.48900577]
[ 0.00056091 -0.01096974] 	 [1 0] 	 [0.50288266 0.49711737]
[-0.00044879 -0.01083876] 	 [1 0] 	 [0.5025975  0.49740252]
[0.00013163 0.00874004] 	 [0 1] 	 [0.49784794 0.50215214]
[ 0.00079006 -0.00861552] 	 [1 0] 	 [0.5023514  0.49764863]


[-0.00035114 -0.06397165] 	 [1 0] 	 [0.5158997  0.48410025]
[ 0.00056738 -0.03663344] 	 [1 0] 	 [0.50929916 0.49070087]
[ 0.00056098 -0.02617791] 	 [1 0] 	 [0.5066843 0.4933157]
[-0.00044851 -0.0713701 ] 	 [1 0] 	 [0.51772296 0.48227707]
[ 0.00013184 -0.05879313] 	 [1 0] 	 [0.514727 0.485273]
[ 0.00079021 -0.05743587] 	 [1 0] 	 [0.5145525  0.48544762]


[-0.00035016 -0.06321616] 	 [1 0] 	 [0.51571137 0.48428872]
[ 0.00056696 -0.03072025] 	 [1 0] 	 [0.5078212  0.49217883]
[ 0.00056129 -0.01022454] 	 [1 0] 	 [0.5026964  0.49730355]
[-0.00044983 -0.00021709] 	 [0 1] 	 [0.4999418 0.5000582]
[0.00013071 0.00484093] 	 [0 1] 	 [0.49882248 

#### Print

In [135]:
print(f" Layer    task 1      task 2      task 3")
print(f" -----    ------      ------      ------")
for idx, (l1, l2, l3) in enumerate(zip(hard_policies[0], hard_policies[1], hard_policies[2]),1):
    print(f"   {idx}      {l1}       {l2}       {l3}")
    

    print(f"\n\n where [p1  p2]:  p1: layer is selected    p2: layer is not selected")

 Layer    task 1      task 2      task 3
 -----    ------      ------      ------
   1      [1 0]       [1 0]       [1 0]
   2      [1 0]       [1 0]       [1 0]
   3      [1 0]       [1 0]       [1 0]
   4      [1 0]       [1 0]       [0 1]
   5      [0 1]       [1 0]       [0 1]
   6      [1 0]       [1 0]       [1 0]


 where [p1  p2]:  p1: layer is selected    p2: layer is not selected


In [402]:
def display_trained_policy(iter):

    policy_softmaxs = environ.get_policy_prob()
    policy_argmaxs = 1-np.argmax(policy_softmaxs, axis = -1)
    print(f"  Trained polcies at iteration: {iter} ")
    print(f"                   task 1                           task 2                         task 3        ")
    print(f" Layer       softmax        select          softmax        select          softmax        select   ")
    print(f" -----    ---------------   ------       ---------------   ------       ---------------   ------   ")
    for idx, (l1,l2,l3,  p1,p2,p3) in enumerate(zip(policy_softmaxs[0], policy_softmaxs[1], policy_softmaxs[2], policy_argmaxs[0], policy_argmaxs[1], policy_argmaxs[2]),1):
        print(f"   {idx}      {l1[0]:.4f}   {l1[1]:.4f}   {p1:4d}    {l2[0]:11.4f}   {l2[1]:.4f}   {p2:4d}    {l3[0]:11.4f}   {l3[1]:.4f}   {p3:4d}")

    print()
# print(f"\n\n where [p1  p2]:  p1: layer is selected    p2: layer is not selected")

In [405]:
display_trained_policy(5)

  Trained polcies at iteration: 5 
                   task 1                           task 2                         task 3        
 Layer       softmax        select          softmax        select          softmax        select   
 -----    ---------------   ------       ---------------   ------       ---------------   ------   
   1      0.5165   0.4835      1         0.5159   0.4841      1         0.5157   0.4843      1
   2      0.5110   0.4890      1         0.5093   0.4907      1         0.5078   0.4922      1
   3      0.5029   0.4971      1         0.5067   0.4933      1         0.5027   0.4973      1
   4      0.5026   0.4974      1         0.5177   0.4823      1         0.4999   0.5001      0
   5      0.4978   0.5022      0         0.5147   0.4853      1         0.4988   0.5012      0
   6      0.5024   0.4976      1         0.5146   0.4854      1         0.5058   0.4942      1



In [355]:
print(f"                        POLICIES (SOFTMAX)                                       task 3          ")
print(f" Layer    task1              task2            task3 softmax         softmax         argmax         softmax         argmax   ")
print(f" -----    -------------     -------------     -------------   ------   ")
for idx, (l1,l2,l3, h1,h2,h3) in enumerate(zip(policy_softmaxs[0], policy_softmaxs[1], policy_softmaxs[2],hard_policies[0], hard_policies[1], hard_policies[2]),1):
    print(f"   {idx}      {l1[0]:.4f} {l1[1]:.4f}     {l2[0]:.4f} {l2[1]:.4f}     {l3[0]:.4f} {l3[1]:.4f}    {h3}")
    
print(f"\n\n where [p1  p2]:  p1: layer is selected    p2: layer is not selected")

                        POLICIES (SOFTMAX)                                       task 3          
 Layer    task1              task2            task3 softmax         softmax         argmax         softmax         argmax   
 -----    -------------     -------------     -------------   ------   
   1      0.5165 0.4835     0.5159 0.4841     0.5157 0.4843    [1 0]
   2      0.5110 0.4890     0.5093 0.4907     0.5078 0.4922    [1 0]
   3      0.5029 0.4971     0.5067 0.4933     0.5027 0.4973    [1 0]
   4      0.5026 0.4974     0.5177 0.4823     0.4999 0.5001    [0 1]
   5      0.4978 0.5022     0.5147 0.4853     0.4988 0.5012    [0 1]
   6      0.5024 0.4976     0.5146 0.4854     0.5058 0.4942    [1 0]


 where [p1  p2]:  p1: layer is selected    p2: layer is not selected


In [349]:
# print(policy_softmaxs[2], np.argmax(1-policy_softmaxs[2], axis = -1))
print(policy_softmaxs, np.argmax(policy_softmaxs, axis = -1))

[array([[0.51645106, 0.48354897],
       [0.51099426, 0.48900577],
       [0.50288266, 0.49711737],
       [0.50259751, 0.49740252],
       [0.49784794, 0.50215214],
       [0.50235140, 0.49764863]], dtype=float32), array([[0.51589972, 0.48410025],
       [0.50929916, 0.49070087],
       [0.50668430, 0.49331570],
       [0.51772296, 0.48227707],
       [0.51472700, 0.48527300],
       [0.51455247, 0.48544762]], dtype=float32), array([[0.51571137, 0.48428872],
       [0.50782120, 0.49217883],
       [0.50269639, 0.49730355],
       [0.49994180, 0.50005817],
       [0.49882248, 0.50117755],
       [0.50577444, 0.49422547]], dtype=float32)] [[0 0 0 0 1 0]
 [0 0 0 0 0 0]
 [0 0 0 1 1 0]]


#### `get_current_logits()` : Calls test_sample_policy of network using ARGMAX of logits

In [149]:
logits  = (environ.get_current_logits())
for i in logits:
    print(i ,'\n')

[[-0.0003512  -0.06617914]
 [ 0.00056736 -0.04341661]
 [ 0.00056091 -0.01096974]
 [-0.00044879 -0.01083876]
 [ 0.00013163  0.00874004]
 [ 0.00079006 -0.00861552]] 

[[-0.00035114 -0.06397165]
 [ 0.00056738 -0.03663344]
 [ 0.00056098 -0.02617791]
 [-0.00044851 -0.0713701 ]
 [ 0.00013184 -0.05879313]
 [ 0.00079021 -0.05743587]] 

[[-0.00035016 -0.06321616]
 [ 0.00056696 -0.03072025]
 [ 0.00056129 -0.01022454]
 [-0.00044983 -0.00021709]
 [ 0.00013071  0.00484093]
 [ 0.00078938 -0.02230957]] 



#### `get_current_policy()` : Calls test_sample_policy of network using ARGMAX of logits

In [107]:
pols  = (environ.get_current_policy())

for i in pols:
    print(i ,'\n')

[[1 0]
 [1 0]
 [0 1]
 [1 0]
 [1 0]
 [1 0]] 

[[0 1]
 [1 0]
 [1 0]
 [1 0]
 [0 1]
 [0 1]] 

[[0 1]
 [0 1]
 [0 1]
 [0 1]
 [1 0]
 [0 1]] 



#### `gumbel_softmax()`  

In [170]:
np.set_printoptions(precision=8,edgeitems=3, infstr='inf', linewidth=150, nanstr='nan', floatmode = 'maxprec_equal')
torch.set_printoptions(precision=8,linewidth=132)

In [319]:
print(environ.temp)
# tau = environ.temp
tau = 1
for i in range(3): 
    logits_tensor = torch.tensor(logits[0])
    # Sample soft categorical using reparametrization trick:
    gumbel_soft = F.gumbel_softmax(logits_tensor, tau=tau, hard=False).cpu().numpy() 

    # Sample hard categorical using "Straight-through" trick:
    gumbel_hard  = F.gumbel_softmax(logits_tensor, tau=tau, hard=True).cpu().numpy()
    
    for l, gs, gh in zip(lgts, gumbel_soft, gumbel_hard):
        print(f"   {l}   \t {gs}            \t {gh}")
#     print(lgts)
#     print(gumbel_soft)
#     print(gumbel_hard)
    print()

0.0001180506617226113
   [-0.00035016 -0.06321616]   	 [0.17155227 0.82844770]            	 [1. 0.]
   [ 0.00056696 -0.03072025]   	 [0.04782803 0.95217192]            	 [1. 0.]
   [ 0.00056129 -0.01022454]   	 [0.52678031 0.47321963]            	 [0. 1.]
   [-0.00044983 -0.00021709]   	 [0.74226642 0.25773358]            	 [0. 1.]
   [0.00013071 0.00484093]   	 [0.81233245 0.18766758]            	 [1. 0.]
   [ 0.00078938 -0.02230957]   	 [0.76270294 0.23729712]            	 [1. 0.]

   [-0.00035016 -0.06321616]   	 [0.01975500 0.98024493]            	 [0. 1.]
   [ 0.00056696 -0.03072025]   	 [0.33801472 0.66198522]            	 [1. 0.]
   [ 0.00056129 -0.01022454]   	 [0.2644645 0.7355355]            	 [0. 1.]
   [-0.00044983 -0.00021709]   	 [0.08984101 0.91015899]            	 [1. 0.]
   [0.00013071 0.00484093]   	 [0.17066659 0.82933342]            	 [1. 0.]
   [ 0.00078938 -0.02230957]   	 [0.74648136 0.25351864]            	 [0. 1.]

   [-0.00035016 -0.06321616]   	 [0.5077298 0.

In [151]:
for lgts in logits:
    logits_tensor = torch.tensor(lgts)
    print(lgts)
    # Sample soft categorical using reparametrization trick:
    gumbel_soft = F.gumbel_softmax(logits_tensor, tau=1, hard=False)
    print(gumbel_soft)

    # Sample hard categorical using "Straight-through" trick:
    gumbel_hard  = F.gumbel_softmax(logits_tensor, tau=1, hard=True)
    print(gumbel_hard)
    print()

[[-0.0003512  -0.06617914]
 [ 0.00056736 -0.04341661]
 [ 0.00056091 -0.01096974]
 [-0.00044879 -0.01083876]
 [ 0.00013163  0.00874004]
 [ 0.00079006 -0.00861552]]
tensor([[0.22969657, 0.77030343],
        [0.47433791, 0.52566212],
        [0.60556847, 0.39443150],
        [0.00808809, 0.99191189],
        [0.99667323, 0.00332679],
        [0.56034184, 0.43965816]])
tensor([[0., 1.],
        [1., 0.],
        [0., 1.],
        [0., 1.],
        [1., 0.],
        [1., 0.]])

[[-0.00035114 -0.06397165]
 [ 0.00056738 -0.03663344]
 [ 0.00056098 -0.02617791]
 [-0.00044851 -0.0713701 ]
 [ 0.00013184 -0.05879313]
 [ 0.00079021 -0.05743587]]
tensor([[0.29727638, 0.70272362],
        [0.65075004, 0.34924990],
        [0.83831531, 0.16168469],
        [0.72130281, 0.27869719],
        [0.87410325, 0.12589674],
        [0.53555954, 0.46444049]])
tensor([[0., 1.],
        [1., 0.],
        [0., 1.],
        [1., 0.],
        [0., 1.],
        [1., 0.]])

[[-0.00035016 -0.06321616]
 [ 0.00056696 -0.

In [33]:
smax = scipy.special.softmax(logs, axis =1)
# smax = np.array( 
# [[0.46973792, 0.530262  ],
#  [0.45025694, 0.549743  ],
#  [0.4443086 , 0.5556915 ],
#  [0.4138397 , 0.58616036],
#  [0.4140113 , 0.5859887 ],
#  [0.42114905, 0.57885087]])

print(smax.shape)
print(smax)
print(smax[0])
print(smax[0].sum())
print(np.random.choice((1,0), p =smax[0]))

(6, 2)
[[0.47754285 0.52245715]
 [0.45825934 0.54174066]
 [0.45530966 0.54469034]
 [0.43196854 0.56803146]
 [0.43017322 0.56982678]
 [0.43333559 0.56666441]]
[0.47754285 0.52245715]
0.9999999999999998
0


In [142]:
logs = np.array(
[[0.33064184, 0.42053092],
 [0.3532089 , 0.52056104],
 [0.3888512 , 0.5680909 ],
 [0.42039296, 0.694217  ],
 [0.4519742 , 0.73311865],
 [0.48401102, 0.7522658 ]],
)