In [1]:
# from IPython.core.display import display, HTML
# display(HTML("<style>.container { width:98% !important; }</style>"))
%load_ext autoreload
%autoreload 2

import os 
import time
import argparse
import yaml
from tqdm import tqdm, tqdm_notebook, trange
# import tqdm.notebook.trange as tnrange
import copy, pprint
import numpy as np
import torch
import pandas as pd
import torch.nn.functional as F
from torch.utils.data import DataLoader 
import scipy.sparse
from time import sleep
from scipy.special import softmax
 
from datetime import datetime
from GPUtil import showUtilization as gpu_usage
 # from tqdm import trange, tqdm
from tqdm.notebook import trange, tqdm

from dev.sparsechem_utils_dev import load_sparse, load_task_weights, class_fold_counts, fold_and_transform_inputs
from dev.sparsechem_utils_dev import print_metrics_cr
from dev.chembl_dataloader_dev import ClassRegrSparseDataset_v3, ClassRegrSparseDataset, InfiniteDataLoader
from utils.util import ( makedir, print_separator, create_path, print_yaml, print_yaml2, should, 
                         fix_random_seed, read_yaml_from_input, timestring, print_heading, print_dbg, 
                         print_underline, write_parms_report, get_command_line_args)
from dev.sparsechem_env_dev import SparseChemEnv_Dev
from dev.train_dev import evaluate

# torch.set_printoptions(precision=None, threshold=None, edgeitems=None, linewidth=None, profile=None, sci_mode=None)
pp = pprint.PrettyPrinter(indent=4)
np.set_printoptions(edgeitems=3, infstr='inf', linewidth=150, nanstr='nan')
torch.set_printoptions( linewidth=132)



In [2]:
print(' Cuda is available  : ', torch.cuda.is_available())
print(' CUDA device count  : ', torch.cuda.device_count())
print(' CUDA current device: ', torch.cuda.current_device())
print(' GPU Processes      : ', torch.cuda.list_gpu_processes())
print()

for i in range(torch.cuda.device_count()):
    print(f" Device : cuda:{i}")
    print('   name:       ', torch.cuda.get_device_name())
    print('   capability: ', torch.cuda.get_device_capability())
    print('   properties: ', torch.cuda.get_device_properties(i))
    ## current GPU memory usage by tensors in bytes for a given device
    print('   Allocated : ', torch.cuda.memory_allocated(i) ) 
    ## current GPU memory managed by caching allocator in bytes for a given device, in previous PyTorch versions the command was torch.cuda.memory_cached
    print('   Reserved  : ', torch.cuda.memory_reserved(i) )   
    print()

gpu_usage()                             

 Cuda is available  :  True
 CUDA device count  :  1
 CUDA current device:  0
 GPU Processes      :  GPU:0
no processes are running

 Device : cuda:0
   name:        NVIDIA GeForce GTX 970M
   capability:  (5, 2)
   properties:  _CudaDeviceProperties(name='NVIDIA GeForce GTX 970M', major=5, minor=2, total_memory=3071MB, multi_processor_count=10)
   Allocated :  0
   Reserved  :  0

| ID | GPU  | MEM |
-------------------
|  0 | nan% |  1% |


## Read yaml config file

In [3]:
input_args = " --config yamls/adashare/chembl_2task.yml --cpu --batch_size 09999".split()
# get command line arguments
args = get_command_line_args(input_args)
print(args)

print()

if args.exp_instance is None:
    args.exp_instance = datetime.now().strftime("%m%d_%H%M")
    
print(args.exp_instance)
print

 command line parms :  {'config': 'yamls/adashare/chembl_2task.yml', 'exp_instance': None, 'exp_ids': [0], 'batch_size': 9999, 'backbone_lr': None, 'task_lr': None, 'decay_lr_rate': None, 'decay_lr_freq': None, 'gpus': [0], 'cpu': True}
Namespace(config='yamls/adashare/chembl_2task.yml', exp_instance=None, exp_ids=[0], batch_size=9999, backbone_lr=None, task_lr=None, decay_lr_rate=None, decay_lr_freq=None, gpus=[0], cpu=True)

0126_0734


<function print>

In [4]:
print_separator('READ YAML')

opt, gpu_ids, _ = read_yaml_from_input(args)

fix_random_seed(opt["seed"][0])

opt['exp_description'] = f"Run network warmup  for 100 iters, then alternating policy/weights modify curriculum speed from 20 to 3  \n"

# folder_name=  f"{opt['exp_instance']}_bs{opt['train']['batch_size']:03d}_{opt['train']['decay_lr_rate']:3.2f}_{opt['train']['decay_lr_freq']}"

print_heading(f" experiment name       : {opt['exp_name']} \n"
              f" experiment instance   : {opt['exp_instance']} \n"
              f" folder_name           : {opt['paths']['exp_folder']} \n"
              f" experiment description: {opt['exp_description']}", verbose = True)

##################################################
####################READ YAML#####################
##################################################
------------------------------------------------------------------------------------------------------------------------
 experiment name       : SparseChem 
 experiment instance   : 0126_0734 
 folder_name           : 0126_0734_bs128_lr0.01_dr0.50_df2000 
 experiment description: Run network warmup  for 100 iters, then alternating policy/weights modify curriculum speed from 20 to 3  

------------------------------------------------------------------------------------------------------------------------ 



In [5]:
# print(opt['exp_instance'])
# print(opt['paths']['exp_folder'])

In [6]:
# for line in lines: 
create_path(opt)    

# print yaml on the screen
for line in print_yaml2(opt):
    print(line)

write_parms_report(opt)    

log_dir        =  os.path.join(opt['paths']['log_dir'], opt['paths']['exp_folder'])
checkpoint_dir =  os.path.join(opt['paths']['checkpoint_dir'], opt['paths']['exp_folder'])

 Create folder ../experiments/AdaSparseChem/0126_0734_bs128_lr0.01_dr0.50_df2000
            exp_name : SparseChem
        exp_instance : 0126_0734
     exp_description : Run network warmup  for 100 iters, then alternating policy/weights modify curriculum speed from 20 to 3  

                seed : [88, 45, 50, 100, 44, 48, 2048, 2222, 9999]
            backbone : SparseChem
       backbone_orig : ResNet18
          orig_tasks : ['seg', 'sn']
               tasks : ['class', 'class', 'class']
     tasks_num_class : [5, 5, 5]
             lambdas : [1, 1, 1]
        policy_model : task-specific
             verbose : False
     input_size_freq : None
          input_size : 32000
        hidden_sizes : [25, 25, 25, 25, 25, 25]
    tail_hidden_size : 25
 first_non_linearity : relu
middle_non_linearity : relu
      middle_dropout : 0.2
  last_non_linearity : relu
        last_dropout : 0.2
   class_output_size : None
    regr_output_size : None
              policy : True
     init_neg_lo

## Chembl Dataloader V3

In [7]:
trainset  = ClassRegrSparseDataset_v3(opt, split_ratios = opt['dataload']['x_split_ratios'], ratio_index = 0, verbose = False)
trainset1 = ClassRegrSparseDataset_v3(opt, split_ratios = opt['dataload']['x_split_ratios'], ratio_index = 1)
trainset2 = ClassRegrSparseDataset_v3(opt, split_ratios = opt['dataload']['x_split_ratios'], ratio_index = 2)
valset    = ClassRegrSparseDataset_v3(opt, split_ratios = opt['dataload']['x_split_ratios'], ratio_index = 3)


train_loader  = InfiniteDataLoader(trainset , batch_size=opt['train']['batch_size'], num_workers = 2, pin_memory=True, collate_fn=trainset.collate, shuffle=False)
val_loader    = InfiniteDataLoader(valset   , batch_size=opt['train']['batch_size'], num_workers = 1, pin_memory=True, collate_fn=valset.collate  , shuffle=False)
train1_loader = InfiniteDataLoader(trainset1, batch_size=opt['train']['batch_size'], num_workers = 2, pin_memory=True, collate_fn=trainset1.collate, shuffle=False)
train2_loader = InfiniteDataLoader(trainset2, batch_size=opt['train']['batch_size'], num_workers = 2, pin_memory=True, collate_fn=trainset2.collate, shuffle=False)


opt['train']['weight_iter_alternate'] = opt['train'].get('weight_iter_alternate', len(train1_loader))
opt['train']['alpha_iter_alternate']  = opt['train'].get('alpha_iter_alternate'  , len(train2_loader))

## Create Environment

In [8]:

print(f" trainset.y_class                       :  {[ i.shape  for i in trainset.y_class_list]}")
print(f" trainset1.y_class                      :  {[ i.shape  for i in trainset1.y_class_list]}")
print(f" trainset2.y_class                      :  {[ i.shape  for i in trainset2.y_class_list]}")
print(f" valset.y_class                         :  {[ i.shape  for i in valset.y_class_list  ]} ")
print()

print(f' size of training set 0 (warm up)       :  {len(trainset)}')
print(f' size of training set 1 (network parms) :  {len(trainset1)}')
print(f' size of training set 2 (policy weights):  {len(trainset2)}')
print(f' size of validation set                 :  {len(valset)}')
print(f'                               Total    :  {len(trainset)+len(trainset1)+len(trainset2)+len(valset)}')
print()
print(f" batch size                             :  {opt['train']['batch_size']}")
print()
print(f" # batches training 0 (warm up)         :  {len(train_loader)}")
print(f" # batches training 1 (network parms)   :  {len(train1_loader)}")
print(f" # batches training 2 (policy weights)  :  {len(train2_loader)}")
print(f" # batches validation dataset           :  {len(val_loader)}")
print()
print(
    f"\n batch size                             : {opt['train']['batch_size']}", 
    f"\n backbone                               : {opt['backbone']}",
    f"\n paths.log_dir                          : {opt['paths']['log_dir']}", 
    f"\n paths.checkpoint_dir                   : {opt['paths']['checkpoint_dir']}", 
    f"\n experiment name                        : {opt['exp_name']}",
    f"\n tasks_num_class                        : {opt['tasks_num_class'],}",
    f"\n Hidden sizes                           : {opt['hidden_sizes']}",     
    f"\n init_neg_logits                        : {opt['init_neg_logits'],}",
    f"\n device id                              : {gpu_ids[0]}",
    f"\n init temp                              : {opt['train']['init_temp'],}",
    f"\n decay temp                             : {opt['train']['decay_temp']}",
    f"\n fix BN parms                           : {opt['fix_BN']}",
    f"\n skip_layer                             : {opt['skip_layer']}",
    f"\n train.init_method                      : {opt['train']['init_method']}",
    f"\n Total iterations                       : {opt['train']['total_iters']}",
    f"\n Warm-up iterations                     : {opt['train']['warm_up_iters']}",
    f"\n Print Frequency                        : {opt['train']['print_freq']}",
    f"\n Validation Frequency                   : {opt['train']['val_freq']} \n",
    f"\n Weight iter alternate                  : {opt['train']['weight_iter_alternate'] }",
    f"\n Alpha  iter alternate                  : {opt['train']['alpha_iter_alternate'] }")
# print('\n\n Opt file \n ------------ \n')
# pp.pprint(opt)

 trainset.y_class                       :  [(13791, 5), (13791, 5), (13791, 5)]
 trainset1.y_class                      :  [(18, 5), (18, 5), (18, 5)]
 trainset2.y_class                      :  [(18, 5), (18, 5), (18, 5)]
 valset.y_class                         :  [(4561, 5), (4561, 5), (4561, 5)] 

 size of training set 0 (warm up)       :  13791
 size of training set 1 (network parms) :  18
 size of training set 2 (policy weights):  18
 size of validation set                 :  4561
                               Total    :  18388

 batch size                             :  128

 # batches training 0 (warm up)         :  108
 # batches training 1 (network parms)   :  1
 # batches training 2 (policy weights)  :  1
 # batches validation dataset           :  36


 batch size                             : 128 
 backbone                               : SparseChem 
 paths.log_dir                          : ../experiments/AdaSparseChem 
 paths.checkpoint_dir                   : ../experimen

### Create model


In [9]:
environ = SparseChemEnv_Dev(log_dir = log_dir, 
                            checkpoint_dir = checkpoint_dir, 
                            exp_name = opt['exp_name'],
                            tasks_num_class = opt['tasks_num_class'], 
                            init_neg_logits = opt['init_neg_logits'], 
                            device = gpu_ids[0],
                            init_temperature = opt['train']['init_temp'], 
                            temperature_decay= opt['train']['decay_temp'], 
                            is_train=True,
                            opt=opt, 
                            verbose = False)

cfg = environ.print_configuration()
write_parms_report(opt, cfg, mode = 'a')

-------------------------------------------------------
* SparseChemEnv_Dev  Initializtion - verbose: False
------------------------------------------------------- 

------------------------------------------------------------
SparseChemEnv_Dev.super() init()  Start - verbose: False
------------------------------------------------------------ 

 log_dir        :  ../experiments/AdaSparseChem/0126_0734_bs128_lr0.01_dr0.50_df2000 
 checkpoint_dir :  ../experiments/AdaSparseChem/0126_0734_bs128_lr0.01_dr0.50_df2000 
 exp_name       :  SparseChem 
 tasks_num_class:  [5, 5, 5] 
 device         :  cuda:0 
 device id      :  0 
 dataset        :  Chembl_23_mini 
 tasks          :  ['class', 'class', 'class'] 

--------------------------------------------------
SparseChemEnv_Dev.super() init()  end
-------------------------------------------------- 

 is_train       :  True 
 init_neg_logits:  None 
 init temp      :  5 
 decay temp     :  0.965 
 input_size     :  32000 
 normalize loss :  No

In [10]:
environ.networks['mtl-net']

MTL3_Dev(
  (backbone): SparseChem_Backbone(
    (Input_linear): SparseLinear(in_features=32000, out_features=25, bias=True)
    (blocks): ModuleList(
      (0): ModuleList(
        (0): SparseChemBlock(
          (linear): Linear(in_features=25, out_features=25, bias=True)
          (non_linear): ReLU()
          (dropout): Dropout(p=0.2, inplace=False)
        )
      )
      (1): ModuleList(
        (0): SparseChemBlock(
          (linear): Linear(in_features=25, out_features=25, bias=True)
          (non_linear): ReLU()
          (dropout): Dropout(p=0.2, inplace=False)
        )
      )
      (2): ModuleList(
        (0): SparseChemBlock(
          (linear): Linear(in_features=25, out_features=25, bias=True)
          (non_linear): ReLU()
          (dropout): Dropout(p=0.2, inplace=False)
        )
      )
      (3): ModuleList(
        (0): SparseChemBlock(
          (linear): Linear(in_features=25, out_features=25, bias=True)
          (non_linear): ReLU()
          (dropout): D

## Warmup Training

### Training Preparation

In [11]:
print( 
    f"\n tasks_num_class         : {opt['tasks_num_class'],}",
    f"\n init_neg_logits         : {opt['init_neg_logits'],}",
    f"\n device id               : {gpu_ids[0]}",
    f"\n init temp               : {opt['train']['init_temp'],}",
    f"\n decay temp              : {opt['train']['decay_temp']}",
    f"\n fix BN parms            : {opt['fix_BN']}",
    f"\n skip_layer              : {opt['skip_layer']}",
    f"\n"
    f"\n train.init_method       : {opt['train']['init_method']}",
    f"\n Total iterations        : {opt['train']['total_iters']}",
    f"\n Warm-up iterations      : {opt['train']['warm_up_iters']}",
    f"\n Print Frequency         : {opt['train']['print_freq']}",
    f"\n Validation Frequency    : {opt['train']['val_freq']}",
    f"\n Weight iter alternate   : {opt['train']['weight_iter_alternate'] }",
    f"\n Alpha  iter alternate   : {opt['train']['alpha_iter_alternate'] }",
    f"\n Network[mtl_net].layers : {environ.networks['mtl-net'].layers}",
    f"\n Num_blocks              : {sum(environ.networks['mtl-net'].layers)}")


 tasks_num_class         : ([5, 5, 5],) 
 init_neg_logits         : (None,) 
 device id               : 0 
 init temp               : (5,) 
 decay temp              : 0.965 
 fix BN parms            : False 
 skip_layer              : 0 

 train.init_method       : equal 
 Total iterations        : 25000 
 Warm-up iterations      : 25000 
 Print Frequency         : -1 
 Validation Frequency    : 500 
 Weight iter alternate   : 1 
 Alpha  iter alternate   : 1 
 Network[mtl_net].layers : [1, 1, 1, 1, 1, 1] 
 Num_blocks              : 6


In [12]:
environ.define_optimizer(policy_learning=False)
environ.define_scheduler(policy_learning=False)
# Fix Alpha - 
environ.fix_alpha()
environ.free_weights(opt['fix_BN'])

if opt['train']['resume']:
    print('Resume training')
    current_iter = environ.load(opt['train']['which_iter'])
    environ.networks['mtl-net'].reset_logits()
else:
    print('Initiate Training ')


if torch.cuda.is_available():
    print('cuda available', gpu_ids)   
    environ.cuda(gpu_ids)
else:
    print('cuda not available')
    environ.cpu()
print('\n')

current_iter   = 0
current_iter_w = 0 
current_iter_a = 0
flag         = 'update_w'
best_value   = 0 
best_iter    = 0
p_epoch      = 0
best_metrics = None
flag_warmup  = True
eval_iter    = -1
num_prints   = 0

num_blocks = sum(environ.networks['mtl-net'].layers)

if opt['train']['print_freq'] == -1:
    print(f" set print_freq to length of train loader: {len(train_loader)}")
    opt['train']['print_freq']    = len(train_loader)

Initiate Training 
cuda available [0]


 set print_freq to length of train loader: 108


In [13]:
opt['train']['warm_up_iters'] = 100
# opt['train']['print_freq'] =1
# opt['train']['weight_iter_alternate'] = 10R1
# opt['train']['alpha_iter_alternate']  = 10
# opt['train']['warm_up_iters'] = 200
# opt['train']['val_freq']   = 50
# opt['train']['val_freq']      = 500

print(f"which_iter          : {opt['train']['which_iter']}")
print(f"train_resume        : {opt['train']['resume']}")
print()
print(f"Length train_loader : {len(train_loader)}")
print(f"Length val_loader   : {len(val_loader)}")
print()
print(f"total_iters         : {opt['train']['total_iters']}")  
print(f"warm_up_iters       : {opt['train']['warm_up_iters']}")   
print()
print(f"val_freq            : {opt['train']['val_freq']     }")      
print(f"print_freq          : {opt['train']['print_freq']  }")
print()
print(f"batch_size          : {opt['train']['batch_size']   }")         
print(f"Backbone LR         : {opt['train']['backbone_lr']   }")        
print(f"LR decay rate       : {opt['train']['decay_lr_rate']   }")        
print(f"LR decay frequency  : {opt['train']['decay_lr_freq']   }")        
print()
print(f" output folder      : {opt['paths']['exp_folder']}")

which_iter          : warmup
train_resume        : False

Length train_loader : 108
Length val_loader   : 36

total_iters         : 25000
warm_up_iters       : 100

val_freq            : 500
print_freq          : 108

batch_size          : 128
Backbone LR         : 0.01
LR decay rate       : 0.5
LR decay frequency  : 2000

 output folder      : 0126_0734_bs128_lr0.01_dr0.50_df2000


### Warm-up Training

In [14]:
# print(current_iter)
#     stop_iter  = current_iter + 1000

In [14]:
stop_iter  = current_iter + opt['train']['warm_up_iters']
print(f" Last iteration: {current_iter}  # of warm-up iterations to do:{opt['train']['warm_up_iters']} - Run  from {current_iter+1} to {stop_iter}")


 Last iteration: 0  # of warm-up iterations to do:100 - Run  from 1 to 100


In [15]:
##---------------------------------------------------------------     
## part one: warm up
##---------------------------------------------------------------
num_prints = 0
print(f" Last iteration: {current_iter}  # of warm-up iterations to do:{opt['train']['warm_up_iters']} - Run  from {current_iter+1} to {stop_iter}")
t0 = time.time()

with trange(current_iter+1, stop_iter+1 , initial = current_iter, total = stop_iter, position=0, leave= True, desc="training") as t_warmup :
    
    for current_iter in t_warmup:
        start_time = time.time()
        environ.train()
        batch = next(train_loader)    
    
#         print_heading(f" {timestring()} - WARMUP Training iter {current_iter}/{opt['train']['warm_up_iters']} ", verbose = False)

        environ.set_inputs(batch, train_loader.dataset.input_size)

        environ.optimize(opt['lambdas'], 
                         is_policy=False, 
                         flag='update_w', 
                         verbose = False)
        
        t_warmup.set_postfix({'curr_iter':current_iter, 
                              'Loss': f"{environ.losses['total']['total'].item():.4f}" , 
                              'row_ids':f"{batch['row_id'][0]}-{batch['row_id'][-1]}"})

        if should(current_iter, opt['train']['print_freq']):
#             environ.print_loss(current_iter, start_time, verbose = False)
            environ.print_loss(current_iter, start_time, title = f"[c]Warmup training : {curr_epoch} iteration:", verbose = True)
        ##--------------------------------------------------------------- 
        # validation
        ##--------------------------------------------------------------- 
        if should(current_iter, opt['train']['val_freq']):
            environ.print_loss(current_iter, start_time, title = f"[e]Weight training epoch:{curr_epoch} iteration:", verbose = True)
#             print_dbg(f"**  {timestring()}  START VALIDATION iteration: {current_iter}    Validation freq {opt['train']['val_freq']}") 

            num_seg_class = opt['tasks_num_class'][opt['tasks'].index('seg')] if 'seg' in opt['tasks'] else -1
            val_metrics = environ.evaluate(
                                   val_loader, 
                                   opt['tasks'], 
                                   is_policy=False, 
                                   num_train_layers=None,
                                   eval_iter = eval_iter, 
                                   progress=True,
                                   leave = False,
                                   verbose = False)

            environ.print_metrics(current_iter, start_time, title='validation')
            environ.save_checkpoint('warmup', current_iter)
            
            print_metrics_cr(current_iter, time.time() - t0, None, environ.val_metrics, num_prints)
            num_prints += 1            
            t0 = time.time()
            print()     

 Last iteration: 0  # of warm-up iterations to do:100 - Run  from 1 to 100


training:   0%|          | 0/100 [00:00<?, ?it/s]

 Task losses:  20.0159     mean:   4.0032 
 Task losses:  16.2607     mean:   3.2521 
 Task losses:  12.4230     mean:   2.4846 
 Task losses:  11.1655     mean:   2.2331 
 Task losses:  11.4116     mean:   2.2823 
 Task losses:  11.3344     mean:   2.2669 
 Task losses:  11.0627     mean:   2.2125 
 Task losses:  11.1395     mean:   2.2279 
 Task losses:  10.8323     mean:   2.1665 
 Task losses:  10.3662     mean:   2.0732 
 Task losses:  10.2680     mean:   2.0536 
 Task losses:  11.2267     mean:   2.2453 
 Task losses:  11.6958     mean:   2.3392 
 Task losses:  11.0142     mean:   2.2028 
 Task losses:  10.6298     mean:   2.1260 
 Task losses:  11.5142     mean:   2.3028 
 Task losses:  10.8070     mean:   2.1614 
 Task losses:  10.5107     mean:   2.1021 
 Task losses:  10.7061     mean:   2.1412 
 Task losses:  10.5138     mean:   2.1028 
 Task losses:   9.9856     mean:   1.9971 
 Task losses:   9.7149     mean:   1.9430 
 Task losses:  11.3587     mean:   2.2717 
 Task losse

### Post Warm-up Training stuff

In [19]:
print(environ.optimizers['weights'])
print(environ.schedulers['weights'].get_last_lr())

Adam (
Parameter Group 0
    amsgrad: False
    betas: (0.5, 0.999)
    eps: 1e-08
    initial_lr: 0.01
    lr: 0.01
    weight_decay: 0.0001

Parameter Group 1
    amsgrad: False
    betas: (0.5, 0.999)
    eps: 1e-08
    initial_lr: 0.01
    lr: 0.01
    weight_decay: 0.0001
)
[0.01, 0.01]


In [19]:
print('losses.keys      : ', environ.losses.keys())
print('losses[task]keys : ', environ.losses['task1'].keys())
pp.pprint(environ.losses)

losses.keys      :  dict_keys(['parms', 'task1', 'task2', 'task3'])
losses[task]keys :  dict_keys(['cls_loss', 'cls_loss_mean'])
{   'parms': {'gumbel_temp': 5, 'lr_0': 0.0001, 'lr_1': 0.0001},
    'task1': {   'cls_loss': tensor(3.5276, device='cuda:0', dtype=torch.float64),
                 'cls_loss_mean': tensor(0.7055, device='cuda:0', dtype=torch.float64)},
    'task2': {   'cls_loss': tensor(2.6970, device='cuda:0', dtype=torch.float64),
                 'cls_loss_mean': tensor(0.5394, device='cuda:0', dtype=torch.float64)},
    'task3': {   'cls_loss': tensor(4.1007, device='cuda:0', dtype=torch.float64),
                 'cls_loss_mean': tensor(0.8201, device='cuda:0', dtype=torch.float64)}}


In [20]:
print( environ.val_metrics.keys())
# pp.pprint(val_metrics)
print(type(environ.val_metrics['aggregated']))
print()
print(type(environ.val_metrics['task1']['classification_agg']))
print()
pp.pprint(environ.val_metrics)

dict_keys(['loss', 'loss_mean', 'task1', 'task2', 'task3', 'aggregated', 'train_time', 'epoch'])
<class 'dict'>

<class 'dict'>

{   'aggregated': {   'auc_pr': 0.8163426647572675,
                      'avg_prec_score': 0.8164138615123665,
                      'bceloss': 0.7729340473810831,
                      'f1_max': 0.7574111413660235,
                      'kappa': 0.4593877004350045,
                      'kappa_max': 0.4722816344851592,
                      'logloss': tensor(0.0002, device='cuda:0', dtype=torch.float64),
                      'p_f1_max': 0.1918067594369252,
                      'p_kappa_max': 0.545021508137385,
                      'roc_auc_score': 0.8101998985724014,
                      'sc_loss': tensor(0.3217, device='cuda:0', dtype=torch.float64)},
    'epoch': 4000,
    'loss': {   'task1': 3.990690022259455,
                'task2': 3.7726694706664947,
                'task3': 3.817710672775008,
                'total': 11.581070165700957},
    'l

In [20]:
# import pickle
# with open("val_metrics.pkl", mode= 'wb') as f:
#         pickle.dump(val_metrics, f)
    
# with open('val_metrics.pkl', 'rb') as f:    
#     tst_val_metrics = pickle.load(f)

In [21]:
# print(environ.input.shape) 
# a = getattr(environ, 'task1_pred')
# yc_data = environ.batch['task1_data']
# print(yc_data.shape)
# yc_ind = environ.batch['task1_ind']
# print(yc_ind.shape)
# yc_hat_all = getattr(environ, 'task1_pred')
# print(yc_hat_all.shape)
# yc_hat  = yc_hat_all[yc_ind[0], yc_ind[1]]
# print(yc_hat_all.shape, yc_hat.shape)

# 
# environ.losses
# loss = {}
# for key in environ.losses.keys():
#     loss[key] = {}
#     for subkey, v in environ.losses[key].items():
#         print(f" key:  {key}   subkey: {subkey} ")
#         if isinstance(v, torch.Tensor):
#             loss[key][subkey] = v.data
#             print(f" Tensor  -  key:  {key}   subkey: {subkey}           value type: {type(v)}  value: {v:.4f}")
#         else:
#             loss[key][subkey] = v
#             print(f" integer -  key:  {key}   subkey: {subkey}           value type: {type(v)}  value: {v:.4f}")
# pp.pprint(tst_val_metrics)             

In [22]:
# print('metrics.keys: ', environ.metrics.keys())
# print('metrics[task].keys: ', environ.metrics['task1'].keys())
# pp.pprint(environ.metrics['task1'])
# pp.pprint(environ.losses['task1']['total'])

In [23]:
# title='Iteration'
# for t_id, _ in enumerate(environ.tasks):
#     task_key = f"task{t_id+1}"
# #     print_heading(f"{title}  {current_iter}  {task_key} : {val_metrics[task_key]['classification_agg']}", verbose = True)

#     for key, _  in val_metrics[task_key]['classification_agg'].items():
#         print('%s/%-20s'%(task_key, key), val_metrics[task_key]['classification_agg'][key], current_iter)
#         print(f"{task_key:s}/{key:20s}", val_metrics[task_key]['classification_agg'][key], current_iter)
#         print()
#             # print_current_errors(os.path.join(self.log_dir, 'loss.txt'), current_iter,key, loss[key], time.time() - start_time)

In [32]:
# environ.print_loss(current_iter, start_time, metrics = val_metrics['loss'], verbose=True)
# print(opt['lambdas'])
# p = (opt['lambdas'][0] * environ.losses['tasks']['task1'])
# print(p)

# environ.print_metrics(current_iter, start_time, val_metrics , title='validation', verbose=True)    

In [30]:
# print(current_iter)
# print_metrics_cr(current_iter, t1 - t0, None, val_metrics , True)
# environ.print_metrics(current_iter, start_time, val_metrics, title='validation', verbose = True)

In [21]:
print(f" val_metric keys               : {val_metrics.keys()}")
print(f" loss keys                     : {val_metrics['loss'].keys()}")
print(f" task1 keys                    : {val_metrics['task1'].keys()}")
print(f" task1 classification keys     : {val_metrics['task1']['classification'].keys()}")
print(f" task1 classification_agg keys : {val_metrics['task1']['classification_agg'].keys()}")
print()
print(f" task1                       : {val_metrics['task1']['classification_agg']['loss']:5f}")
print(f" task2                       : {val_metrics['task2']['classification_agg']['loss']:5f}")
print(f" task3                       : {val_metrics['task3']['classification_agg']['loss']:5f}")
print(f" loss                        : {val_metrics['loss']['total']:5f}")
print(f" train_time                  : {val_metrics['train_time']:2f}")
print(f" epoch                       : {val_metrics['epoch']}")


AttributeError: 'NoneType' object has no attribute 'keys'

## Weight & Policy Training

In [16]:
pp.pprint(environ.optimizers)
# pp.pprint(environ.schedulers['weights'])

print_heading(f"** {timestring()} - Training current iteration {current_iter}  warmup iters: {opt['train']['warm_up_iters']}   flag: {flag} ", verbose = True)    

if flag_warmup:
    print_heading(f"** Set optimizer and scheduler to policy_learning = True", verbose = True)
    environ.define_optimizer(policy_learning=True)
    environ.define_scheduler(policy_learning=True)
    flag_warmup = False

if current_iter == opt['train']['warm_up_iters']:
    print_heading(f"** Switch from Warm Up training to Alternate training Weights & Policy \n"
                  f"   Take checkpoint and block gradient flow through Policy net", verbose=True)
    environ.save_checkpoint('warmup', current_iter)
    environ.fix_alpha()
    

{   'alphas': Adam (
Parameter Group 0
    amsgrad: False
    betas: (0.9, 0.999)
    eps: 1e-08
    lr: 0.0001
    weight_decay: 0.0005
),
    'weights': Adam (
Parameter Group 0
    amsgrad: False
    betas: (0.5, 0.999)
    eps: 1e-08
    initial_lr: 0.01
    lr: 0.01
    weight_decay: 0.0001

Parameter Group 1
    amsgrad: False
    betas: (0.5, 0.999)
    eps: 1e-08
    initial_lr: 0.01
    lr: 0.01
    weight_decay: 0.0001
)}
<torch.optim.lr_scheduler.StepLR object at 0x7fa368890ac0>
-------------------------------------------------------------------------------------------------------
** 2022-01-26 07:34:49:324696 - Training current iteration 100  warmup iters: 100   flag: update_w 
------------------------------------------------------------------------------------------------------- 

------------------------------------------------------------
** Set optimizer and scheduler to policy_learning = True
------------------------------------------------------------ 

--------------

In [17]:
# batch_enumerator1 = enumerate(train1_loader,1)  
# batch_enumerator2 = enumerate(train2_loader,1)  
current_iter_w = 0 
current_iter_a = 0 
batch_idx_a = 0 
batch_idx_w = 0 
curr_epoch  = 0 
num_prints  = 0
train_total_epochs = 10
curriculum_speed = opt['curriculum_speed']  

stop_iter_w = len(train_loader)
stop_iter_a = len(train_loader)

print(f"stop_iter_w :                      {stop_iter_w}")
print(f"stop_iter_a :                      {stop_iter_a}")
print()         
print(f"weight_iter_alternate:             {opt['train']['weight_iter_alternate']}")
print(f"alpha_iter_alternate :             {opt['train']['alpha_iter_alternate']}")
print()
print(f"opt['train']['print_freq']         {opt['train']['print_freq']}")
print(f"opt['train']['hard_sampling']      {opt['train']['hard_sampling']}")
print(f"opt['policy']                      {opt['policy']}")
print(f"opt['tasks']                       {opt['tasks']}")
print(f"opt['fix_BN']                      {opt['fix_BN']}" )
print()
print(f"total_iters                        {opt['train']['total_iters']}")  
print(f"current_iter                       {current_iter  }")
print()
print(f"current_iter_w                     {current_iter_w}")
print(f"current_iter_a                     {current_iter_a}")
print(f"batch_idx_w                        {batch_idx_w}")
print()
print(f"curriculum_speed                   {curriculum_speed}")
print(f"curr_epochs                        {curr_epoch}") 
print(f"train_total_epochs                 {train_total_epochs}") 
print()
print(f"flag                               {flag          }")

stop_iter_w :                      108
stop_iter_a :                      108

weight_iter_alternate:             1
alpha_iter_alternate :             1

opt['train']['print_freq']         108
opt['train']['hard_sampling']      False
opt['policy']                      True
opt['tasks']                       ['class', 'class', 'class']
opt['fix_BN']                      False

total_iters                        25000
current_iter                       100

current_iter_w                     0
current_iter_a                     0
batch_idx_w                        0

curriculum_speed                   3
curr_epochs                        0
train_total_epochs                 10

flag                               update_w


In [19]:
# curr_epoch = 0
# train_total_epochs = 55
# train_total_epochs = 40

# curr_epoch = 25

In [18]:
print(current_iter_a, stop_iter_a)
print(current_iter_w, stop_iter_w)
print(opt['policy'])

0 108
0 108
True


In [193]:
print(f"current_iters         : {current_iter}")  
print(f"curr_epochs           : {curr_epoch}") 
print(f"train_total_epochs    : {train_total_epochs}") 

current_iters         : 19540
curr_epochs           : 90
train_total_epochs    : 90


In [194]:
train_total_epochs += 30
# train_total_epochs += 2

In [195]:
print(f"current_iters         : {current_iter}")  
print(f"curr_epochs           : {curr_epoch}") 
print(f"train_total_epochs    : {train_total_epochs}") 

current_iters         : 19540
curr_epochs           : 90
train_total_epochs    : 120


In [196]:
num_prints = 0 
verbose = False
t = tqdm(initial = curr_epoch, total=train_total_epochs, desc=f" Alternate Weight/Policy training")


 Alternate Weight/Policy training:  75%|#######5  | 90/120 [00:00<?, ?it/s]

In [197]:
t0 = time.time()
while curr_epoch < train_total_epochs:
    curr_epoch+=1
    t.update(1)

    #-----------------------------------------
    # Train & Update the network weights
    #-----------------------------------------
    if flag == 'update_w':
        current_iter_w  = 0 
#         stop_iter_w =   opt['train']['weight_iter_alternate']

        with trange(+1, stop_iter_w+1 , initial = current_iter_w, total = stop_iter_w, 
                     position=0, leave= False, desc=f"Epoch {curr_epoch} weight training") as t_weights :
            
            for current_iter_w in t_weights:    
                current_iter += 1

                start_time = time.time()
                environ.train()
                
                batch = next(train_loader)
#                 batch = next(train1_loader)
                
                environ.set_inputs(batch, train1_loader.dataset.input_size)

                ##----------------------------------------------------------------------
                ## Set number of layers to train based on cirriculum_speed 
                ## and p_epoch (number of epochs of policy training)
                ## When curriculum_speed == 3, a num_train_layers is incremented 
                ## after completion of every 3 policy training epochs
                ##----------------------------------------------------------------------
                if opt['is_curriculum']:
                    num_train_layers = p_epoch // opt['curriculum_speed'] + 1
                else:
                    num_train_layers = None

                environ.optimize(opt['lambdas'], 
                                 is_policy=opt['policy'], 
                                 flag=flag, 
                                 num_train_layers=num_train_layers,
                                 hard_sampling=opt['train']['hard_sampling'],
                                 verbose = False)

                t_weights.set_postfix({'iter': current_iter, 'Loss': f"{environ.losses['total']['total'].item():.4f}" , 
                                       'row_ids':f"{batch['row_id'][0]}-{batch['row_id'][-1]}"})
                
#                 if should(current_iter, opt['train']['print_freq']):
#                     environ.print_loss(current_iter, start_time, title = f"[c]Weight training epoch:{curr_epoch} iteration:", verbose = True)

        #-------------------------------------------------------
        # validation process 
        #------------------------------------------------------- 
#       if should(current_iter_w, opt['train']['weight_iter_alternate']): 
        if (current_iter_w >= stop_iter_w):
#             print(f" current_iter_w: {current_iter_w}   stop_iter_w: {stop_iter_w}   (are equal)")        
            environ.print_loss(current_iter, start_time, title = f"[e]Weight training epoch:{curr_epoch} iteration:", verbose = True)
            environ.eval()

            val_metrics = environ.evaluate(val_loader, 
                                           opt['tasks'], 
                                           is_policy=opt['policy'],
                                           num_train_layers=num_train_layers, 
                                           hard_sampling=opt['train']['hard_sampling'],
                                           eval_iter = eval_iter,
                                           progress = True, 
                                           leave = False, 
                                           verbose = False)  
    
            environ.print_metrics(current_iter, start_time, val_metrics, title = f"[v]Weight training epoch:{curr_epoch} iteration:", verbose = False)
            print_metrics_cr(curr_epoch, time.time() - t0, None, environ.val_metrics , num_prints)      

            num_prints += 1
            t0 = time.time()
            
            # Take check point
            environ.save_checkpoint('latest', current_iter)
            #-----------------------------------------------------------------------------------------------------------------------
            #
            #            #----------------------------------------------------------------------------------------------
            #            # if number of iterations completed after the warm up phase is greater than the number of 
            #            # (weight/policy alternations) x (cirriculum speed) x (number of layers to be policy trained)
            #            #
            #            # check metrics for improvement, and issue a checkpoint if necessary
            #            #----------------------------------------------------------------------------------------------
            # 
            #             if current_iter - opt['train']['warm_up_iters'] >= num_blocks * opt['curriculum_speed'] * \
            #                     (opt['train']['weight_iter_alternate'] + opt['train']['alpha_iter_alternate']):
            #                 new_value = 0
            #                 print(f"  {current_iter - opt['train']['warm_up_iters']} IS GREATER THAN "
            #                        f" {num_blocks * opt['curriculum_speed'] * (opt['train']['weight_iter_alternate'] + opt['train']['alpha_iter_alternate'])} -- "
            #                        f"  evaluate progress and make checkpoint if necessary." )            
            # 
            #                 ## compare validation metrics against reference metrics.
            #                 
            #                 for k in refer_metrics.keys():
            #                     if k in val_metrics.keys():
            #                         for kk in val_metrics[k].keys():
            #                             if not kk in refer_metrics[k].keys():
            #                                 continue
            #                             if (k == 'sn' and kk in ['Angle Mean', 'Angle Median']) or (
            #                                     k == 'depth' and not kk.startswith('sigma')) or (kk == 'err'):
            #                                 value = refer_metrics[k][kk] / val_metrics[k][kk]
            #                             else:
            #                                 value = val_metrics[k][kk] / refer_metrics[k][kk]
            #                             value = value / len(list(set(val_metrics[k].keys()) & set(refer_metrics[k].keys())))
            #                             new_value += value
            # 
            #                 print('Best Value %.4f  New value: %.4f' % new_value)
            # 
            #                 ## if results have improved, save these results and issue a checkpoint
            # 
            #                 if (new_value > best_value):
            #                     print('Previous best iter: %d, best_value: %.4f' % (best_iter, best_value), best_metrics)
            #                     best_value = new_value
            #                     best_metrics = val_metrics
            #                     best_iter = current_iter
            #                     environ.save_checkpoint('best', current_iter)
            #                     print('New      best iter: %d, best_value: %.4f' % (best_iter, best_value), best_metrics)                         
            #                     print('Best Value %.4f  New value: %.4f' % new_value)
            #
            #-----------------------------------------------------------------------------------------------------------------------

            environ.train()
            #-------------------------------------------------------
            # END validation process
            #-------------------------------------------------------       
            flag = 'update_alpha'
            environ.fix_weights()
            environ.free_alpha()
    #-------------------------------------------------------
    # end weight training iteration
    #-------------------------------------------------------               

    # When we want to separate the two loops
    # while curr_epoch <= train_total_epochs:
    #     curr_epoch+=1
    #     t.update(1)

    #-----------------------------------------
    # Train & Update the  policy 
    #-----------------------------------------
    if flag == 'update_alpha':
        current_iter_a = 0

        with trange( +1, stop_iter_a+1 , initial = 0, total = stop_iter_a, 
                     position=0, leave= False, desc=f"Epoch {curr_epoch} policy training") as t_policy :
            for current_iter_a in t_policy:    
                current_iter += 1

                batch = next(train_loader)
#                 batch = next(train2_loader)
                
                environ.set_inputs(batch, train2_loader.dataset.input_size)

                if opt['is_curriculum']:
                    num_train_layers = (p_epoch // opt['curriculum_speed']) + 1
                else:
                    num_train_layers = None

                print_dbg(f" num_train_layers  : {num_train_layers}", verbose = False)

                environ.optimize(opt['lambdas'], 
                                 is_policy=opt['policy'], 
                                 flag=flag, 
                                 num_train_layers=num_train_layers,
                                 hard_sampling=opt['train']['hard_sampling'],
                                 verbose = False)
                
                t_policy.set_postfix({'iteration': current_iter, 'Loss': f"{environ.losses['total']['total'].item():.4f}" , 
                                      'row_ids':f"{batch['row_id'][0]}-{batch['row_id'][-1]}"})
                
                if should(current_iter, opt['train']['print_freq']):
                    environ.print_loss(current_iter, start_time, title = f"[c]Policy training epoch:{curr_epoch} iteration:", verbose=False)
#                     environ.visual_policy(current_iter)

        if( current_iter_a >= stop_iter_a):            
#             print(f" current_iter_a: {current_iter_a}   stop_iter_a: {stop_iter_a}   (are equal)")
            environ.print_loss(current_iter, start_time, title = f"[e]Policy training epoch:{curr_epoch} iteration:", verbose=True)
            flag = 'update_w'
            p_epoch += 1
            environ.fix_alpha()
            environ.free_weights(opt['fix_BN'])
            if should(p_epoch, opt['train']['decay_temp_freq']):
                print(f"[e]Policy training epoch:{p_epoch} decay gumbel temp" )
                environ.decay_temperature()

            # print the distribution
            print_dbg(np.concatenate(environ.get_policy_prob(), axis=-1), verbose = False)
            

            print_dbg(f"** p_epoch incremented: {p_epoch}")

Epoch 91 weight training:   0%|          | 0/108 [00:00<?, ?it/s]

 current_iter_w: 108   stop_iter_w: 108   (are equal)
[e]Weight training epoch:91 iteration:  19648 -  Total Loss: 4.0531     Task Loss: 4.0531  


validation:   0%|          | 0/36 [00:00<?, ?it/s]

Epoch | logloss bceloss  aucroc   aucpr  f1_max| t1 loss t2 loss t3 lossttl loss|tr_time|
91    | 0.00018 0.81026 0.78620 0.78616 0.74446|  3.7568  3.8431  4.5338 12.1337|   17.9|

Epoch 91 policy training:   0%|          | 0/108 [00:00<?, ?it/s]

 Task losses:   3.4930     mean:   0.6986     Sharing: 5.28842e-05     Sparsity: 8.59965e-02     Total:   3.5790     mean:   0.7846
 Task losses:   8.1971     mean:   1.6394     Sharing: 6.96386e-03     Sparsity: 8.59097e-02     Total:   8.2900     mean:   1.7323
 Task losses:  10.8667     mean:   2.1733     Sharing: 9.13756e-03     Sparsity: 8.58560e-02     Total:  10.9617     mean:   2.2683
 Task losses:  12.5174     mean:   2.5035     Sharing: 7.44592e-03     Sparsity: 8.58332e-02     Total:  12.6107     mean:   2.5968
 Task losses:  14.6937     mean:   2.9387     Sharing: 7.15319e-03     Sparsity: 8.58324e-02     Total:  14.7867     mean:   3.0317
 Task losses:   9.9347     mean:   1.9869     Sharing: 4.90936e-03     Sparsity: 8.58311e-02     Total:  10.0254     mean:   2.0777
 Task losses:   4.2904     mean:   0.8581     Sharing: 4.70061e-03     Sparsity: 8.58262e-02     Total:   4.3810     mean:   0.9486
 Task losses:   4.0279     mean:   0.8056     Sharing: 5.64806e-03     Spars

 Task losses:   3.3976     mean:   0.6795     Sharing: 8.25192e-04     Sparsity: 8.52355e-02     Total:   3.4836     mean:   0.7656
 Task losses:   3.9599     mean:   0.7920     Sharing: 9.42578e-04     Sparsity: 8.52253e-02     Total:   4.0460     mean:   0.8781
 Task losses:   3.0342     mean:   0.6068     Sharing: 7.77841e-04     Sparsity: 8.52170e-02     Total:   3.1202     mean:   0.6928
 Task losses:   3.4731     mean:   0.6946     Sharing: 4.68254e-04     Sparsity: 8.52085e-02     Total:   3.5587     mean:   0.7803
 Task losses:   3.5824     mean:   0.7165     Sharing: 7.29750e-04     Sparsity: 8.52003e-02     Total:   3.6683     mean:   0.8024
 Task losses:   3.6943     mean:   0.7389     Sharing: 1.16270e-03     Sparsity: 8.51955e-02     Total:   3.7807     mean:   0.8252
 Task losses:   3.8926     mean:   0.7785     Sharing: 9.03770e-04     Sparsity: 8.51912e-02     Total:   3.9787     mean:   0.8646
 Task losses:   4.3986     mean:   0.8797     Sharing: 6.52651e-04     Spars

Epoch 92 weight training:   0%|          | 0/108 [00:00<?, ?it/s]

 current_iter_w: 108   stop_iter_w: 108   (are equal)
[e]Weight training epoch:92 iteration:  19864 -  Total Loss: 4.4422     Task Loss: 4.4422  


validation:   0%|          | 0/36 [00:00<?, ?it/s]

92    | 0.00020 0.91082 0.78427 0.78384 0.74331|  4.3326  4.0616  5.2774 13.6716|   42.9|

Epoch 92 policy training:   0%|          | 0/108 [00:00<?, ?it/s]

 Task losses:   3.9582     mean:   0.7916     Sharing: 9.46611e-04     Sparsity: 8.50822e-02     Total:   4.0442     mean:   0.8777
 Task losses:   8.9895     mean:   1.7979     Sharing: 1.18768e-03     Sparsity: 8.50743e-02     Total:   9.0758     mean:   1.8842
 Task losses:  11.5034     mean:   2.3007     Sharing: 1.44217e-03     Sparsity: 8.50693e-02     Total:  11.5899     mean:   2.3872
 Task losses:  12.5421     mean:   2.5084     Sharing: 1.39930e-03     Sparsity: 8.50643e-02     Total:  12.6286     mean:   2.5949
 Task losses:  13.9752     mean:   2.7950     Sharing: 1.20323e-03     Sparsity: 8.50573e-02     Total:  14.0615     mean:   2.8813
 Task losses:  11.9963     mean:   2.3993     Sharing: 1.04434e-03     Sparsity: 8.50630e-02     Total:  12.0824     mean:   2.4854
 Task losses:   5.7401     mean:   1.1480     Sharing: 9.94851e-04     Sparsity: 8.50784e-02     Total:   5.8262     mean:   1.2341
 Task losses:   5.3169     mean:   1.0634     Sharing: 1.10982e-03     Spars

 Task losses:   5.0089     mean:   1.0018     Sharing: 5.51775e-04     Sparsity: 8.52005e-02     Total:   5.0947     mean:   1.0875
 Task losses:   5.6864     mean:   1.1373     Sharing: 7.72059e-04     Sparsity: 8.52015e-02     Total:   5.7724     mean:   1.2233
 Task losses:   4.5102     mean:   0.9020     Sharing: 7.91614e-04     Sparsity: 8.52038e-02     Total:   4.5962     mean:   0.9880
 Task losses:   5.3126     mean:   1.0625     Sharing: 6.45543e-04     Sparsity: 8.52052e-02     Total:   5.3984     mean:   1.1484
 Task losses:   5.0537     mean:   1.0107     Sharing: 6.39975e-04     Sparsity: 8.52057e-02     Total:   5.1395     mean:   1.0966
 Task losses:   4.9952     mean:   0.9990     Sharing: 5.08696e-04     Sparsity: 8.52063e-02     Total:   5.0809     mean:   1.0848
 Task losses:   5.2975     mean:   1.0595     Sharing: 7.66019e-04     Sparsity: 8.52062e-02     Total:   5.3834     mean:   1.1455
 Task losses:   5.1592     mean:   1.0318     Sharing: 8.29279e-04     Spars

Epoch 93 weight training:   0%|          | 0/108 [00:00<?, ?it/s]

 current_iter_w: 108   stop_iter_w: 108   (are equal)
[e]Weight training epoch:93 iteration:  20080 -  Total Loss: 4.0194     Task Loss: 4.0194  


validation:   0%|          | 0/36 [00:00<?, ?it/s]

93    | 0.00019 0.84544 0.79107 0.78947 0.74750|  4.3039  3.6779  4.6830 12.6647|   44.8|

Epoch 93 policy training:   0%|          | 0/108 [00:00<?, ?it/s]

 Task losses:   3.7680     mean:   0.7536     Sharing: 9.95532e-04     Sparsity: 8.50866e-02     Total:   3.8541     mean:   0.8397
 Task losses:   9.2225     mean:   1.8445     Sharing: 1.00714e-03     Sparsity: 8.50817e-02     Total:   9.3086     mean:   1.9306
 Task losses:  11.3085     mean:   2.2617     Sharing: 1.17184e-03     Sparsity: 8.50745e-02     Total:  11.3948     mean:   2.3480
 Task losses:  14.1282     mean:   2.8256     Sharing: 1.09108e-03     Sparsity: 8.50669e-02     Total:  14.2143     mean:   2.9118
 Task losses:  14.9331     mean:   2.9866     Sharing: 8.20979e-04     Sparsity: 8.50641e-02     Total:  15.0190     mean:   3.0725
 Task losses:  11.9365     mean:   2.3873     Sharing: 8.76546e-04     Sparsity: 8.50605e-02     Total:  12.0225     mean:   2.4732
 Task losses:   5.2326     mean:   1.0465     Sharing: 1.20133e-03     Sparsity: 8.50613e-02     Total:   5.3189     mean:   1.1328
 Task losses:   4.0462     mean:   0.8092     Sharing: 1.21717e-03     Spars

 Task losses:   4.4062     mean:   0.8812     Sharing: 1.40105e-03     Sparsity: 8.50174e-02     Total:   4.4926     mean:   0.9677
 Task losses:   4.0680     mean:   0.8136     Sharing: 1.33399e-03     Sparsity: 8.50098e-02     Total:   4.1544     mean:   0.8999
 Task losses:   4.5665     mean:   0.9133     Sharing: 1.14398e-03     Sparsity: 8.50045e-02     Total:   4.6526     mean:   0.9994
 Task losses:   4.6400     mean:   0.9280     Sharing: 1.01873e-03     Sparsity: 8.49992e-02     Total:   4.7260     mean:   1.0140
 Task losses:   4.9214     mean:   0.9843     Sharing: 6.98149e-04     Sparsity: 8.49932e-02     Total:   5.0071     mean:   1.0700
 Task losses:   5.1073     mean:   1.0215     Sharing: 6.71158e-04     Sparsity: 8.49861e-02     Total:   5.1929     mean:   1.1071
 Task losses:   5.2369     mean:   1.0474     Sharing: 7.74408e-04     Sparsity: 8.49784e-02     Total:   5.3227     mean:   1.1331
 Task losses:   5.1274     mean:   1.0255     Sharing: 1.07644e-03     Spars

Epoch 94 weight training:   0%|          | 0/108 [00:00<?, ?it/s]

 current_iter_w: 108   stop_iter_w: 108   (are equal)
[e]Weight training epoch:94 iteration:  20296 -  Total Loss: 3.7114     Task Loss: 3.7114  


validation:   0%|          | 0/36 [00:00<?, ?it/s]

94    | 0.00019 0.84785 0.78638 0.78732 0.74435|  3.9070  4.4191  4.3834 12.7095|   43.0|

Epoch 94 policy training:   0%|          | 0/108 [00:00<?, ?it/s]

 Task losses:   3.0687     mean:   0.6137     Sharing: 5.85988e-04     Sparsity: 8.46551e-02     Total:   3.1539     mean:   0.6990
 Task losses:   5.8543     mean:   1.1709     Sharing: 8.37058e-04     Sparsity: 8.46464e-02     Total:   5.9398     mean:   1.2563
 Task losses:   6.7002     mean:   1.3400     Sharing: 8.35871e-04     Sparsity: 8.46375e-02     Total:   6.7856     mean:   1.4255
 Task losses:   7.4760     mean:   1.4952     Sharing: 5.12972e-04     Sparsity: 8.46263e-02     Total:   7.5611     mean:   1.5803
 Task losses:   8.0323     mean:   1.6065     Sharing: 9.73612e-04     Sparsity: 8.46156e-02     Total:   8.1179     mean:   1.6920
 Task losses:   5.9114     mean:   1.1823     Sharing: 1.06111e-03     Sparsity: 8.46049e-02     Total:   5.9971     mean:   1.2679
 Task losses:   4.1043     mean:   0.8209     Sharing: 9.25745e-04     Sparsity: 8.45932e-02     Total:   4.1898     mean:   0.9064
 Task losses:   3.7366     mean:   0.7473     Sharing: 6.19814e-04     Spars

 Task losses:   3.8281     mean:   0.7656     Sharing: 6.99247e-04     Sparsity: 8.43425e-02     Total:   3.9131     mean:   0.8507
 Task losses:   3.6487     mean:   0.7297     Sharing: 7.51808e-04     Sparsity: 8.43405e-02     Total:   3.7338     mean:   0.8148
 Task losses:   3.5857     mean:   0.7171     Sharing: 9.05360e-04     Sparsity: 8.43381e-02     Total:   3.6709     mean:   0.8024
 Task losses:   4.1016     mean:   0.8203     Sharing: 1.05065e-03     Sparsity: 8.43354e-02     Total:   4.1870     mean:   0.9057
 Task losses:   3.8492     mean:   0.7698     Sharing: 8.70456e-04     Sparsity: 8.43326e-02     Total:   3.9344     mean:   0.8551
 Task losses:   4.3238     mean:   0.8648     Sharing: 8.17880e-04     Sparsity: 8.43300e-02     Total:   4.4090     mean:   0.9499
 Task losses:   4.4298     mean:   0.8860     Sharing: 1.04327e-03     Sparsity: 8.43269e-02     Total:   4.5152     mean:   0.9713
 Task losses:   5.2976     mean:   1.0595     Sharing: 1.08108e-03     Spars

Epoch 95 weight training:   0%|          | 0/108 [00:00<?, ?it/s]

 current_iter_w: 108   stop_iter_w: 108   (are equal)
[e]Weight training epoch:95 iteration:  20512 -  Total Loss: 4.2876     Task Loss: 4.2876  


validation:   0%|          | 0/36 [00:00<?, ?it/s]

95    | 0.00020 0.89327 0.78751 0.78998 0.74533|  3.7913  4.4640  5.1188 13.3741|   43.6|

Epoch 95 policy training:   0%|          | 0/108 [00:00<?, ?it/s]

 Task losses:   3.7746     mean:   0.7549     Sharing: 1.20285e-03     Sparsity: 8.41903e-02     Total:   3.8600     mean:   0.8403
 Task losses:   5.8392     mean:   1.1678     Sharing: 1.33109e-03     Sparsity: 8.41851e-02     Total:   5.9247     mean:   1.2534
 Task losses:   6.0047     mean:   1.2009     Sharing: 1.27830e-03     Sparsity: 8.41781e-02     Total:   6.0901     mean:   1.2864
 Task losses:   6.6186     mean:   1.3237     Sharing: 1.09242e-03     Sparsity: 8.41666e-02     Total:   6.7039     mean:   1.4090
 Task losses:   7.5628     mean:   1.5126     Sharing: 7.85589e-04     Sparsity: 8.41550e-02     Total:   7.6478     mean:   1.5975
 Task losses:   6.4685     mean:   1.2937     Sharing: 6.33642e-04     Sparsity: 8.41432e-02     Total:   6.5533     mean:   1.3785
 Task losses:   4.4994     mean:   0.8999     Sharing: 1.02503e-03     Sparsity: 8.41314e-02     Total:   4.5845     mean:   0.9850
 Task losses:   3.9303     mean:   0.7861     Sharing: 1.27233e-03     Spars

 Task losses:   3.8732     mean:   0.7746     Sharing: 1.12573e-03     Sparsity: 8.39768e-02     Total:   3.9583     mean:   0.8597
 Task losses:   3.6350     mean:   0.7270     Sharing: 1.10237e-03     Sparsity: 8.39719e-02     Total:   3.7200     mean:   0.8121
 Task losses:   3.5055     mean:   0.7011     Sharing: 1.32541e-03     Sparsity: 8.39668e-02     Total:   3.5908     mean:   0.7864
 Task losses:   4.0102     mean:   0.8020     Sharing: 1.09463e-03     Sparsity: 8.39620e-02     Total:   4.0953     mean:   0.8871
 Task losses:   3.6051     mean:   0.7210     Sharing: 9.72425e-04     Sparsity: 8.39571e-02     Total:   3.6900     mean:   0.8059
 Task losses:   4.2718     mean:   0.8544     Sharing: 9.36757e-04     Sparsity: 8.39522e-02     Total:   4.3567     mean:   0.9392
 Task losses:   4.4117     mean:   0.8823     Sharing: 1.12059e-03     Sparsity: 8.39463e-02     Total:   4.4968     mean:   0.9674
 Task losses:   4.1778     mean:   0.8356     Sharing: 9.23251e-04     Spars

Epoch 96 weight training:   0%|          | 0/108 [00:00<?, ?it/s]

 current_iter_w: 108   stop_iter_w: 108   (are equal)
[e]Weight training epoch:96 iteration:  20728 -  Total Loss: 4.5063     Task Loss: 4.5063  


validation:   0%|          | 0/36 [00:00<?, ?it/s]

96    | 0.00018 0.83624 0.78492 0.78517 0.74291|  3.6511  4.0906  4.7815 12.5232|   42.5|

Epoch 96 policy training:   0%|          | 0/108 [00:00<?, ?it/s]

 Task losses:   3.3224     mean:   0.6645     Sharing: 8.09977e-04     Sparsity: 8.36984e-02     Total:   3.4069     mean:   0.7490
 Task losses:   5.8843     mean:   1.1769     Sharing: 1.25872e-03     Sparsity: 8.36920e-02     Total:   5.9692     mean:   1.2618
 Task losses:   6.3538     mean:   1.2708     Sharing: 1.00536e-03     Sparsity: 8.36848e-02     Total:   6.4385     mean:   1.3555
 Task losses:   7.1709     mean:   1.4342     Sharing: 1.06171e-03     Sparsity: 8.36765e-02     Total:   7.2556     mean:   1.5189
 Task losses:   7.6255     mean:   1.5251     Sharing: 1.22727e-03     Sparsity: 8.36679e-02     Total:   7.7104     mean:   1.6100
 Task losses:   6.1162     mean:   1.2232     Sharing: 1.01544e-03     Sparsity: 8.36578e-02     Total:   6.2009     mean:   1.3079
 Task losses:   4.2099     mean:   0.8420     Sharing: 6.90048e-04     Sparsity: 8.36478e-02     Total:   4.2943     mean:   0.9263
 Task losses:   3.6145     mean:   0.7229     Sharing: 9.60668e-04     Spars

 Task losses:   3.5285     mean:   0.7057     Sharing: 1.26601e-03     Sparsity: 8.31375e-02     Total:   3.6129     mean:   0.7901
 Task losses:   4.4005     mean:   0.8801     Sharing: 1.16292e-03     Sparsity: 8.31299e-02     Total:   4.4848     mean:   0.9644
 Task losses:   2.8890     mean:   0.5778     Sharing: 7.22090e-04     Sparsity: 8.31222e-02     Total:   2.9728     mean:   0.6616
 Task losses:   3.5187     mean:   0.7037     Sharing: 8.03386e-04     Sparsity: 8.31146e-02     Total:   3.6026     mean:   0.7876
 Task losses:   3.4588     mean:   0.6918     Sharing: 7.33187e-04     Sparsity: 8.31063e-02     Total:   3.5426     mean:   0.7756
 Task losses:   3.4281     mean:   0.6856     Sharing: 7.13617e-04     Sparsity: 8.30981e-02     Total:   3.5119     mean:   0.7694
 Task losses:   3.5305     mean:   0.7061     Sharing: 7.93278e-04     Sparsity: 8.30896e-02     Total:   3.6143     mean:   0.7900
 Task losses:   3.2962     mean:   0.6592     Sharing: 7.85366e-04     Spars

Epoch 97 weight training:   0%|          | 0/108 [00:00<?, ?it/s]

 current_iter_w: 108   stop_iter_w: 108   (are equal)
[e]Weight training epoch:97 iteration:  20944 -  Total Loss: 4.8378     Task Loss: 4.8378  


validation:   0%|          | 0/36 [00:00<?, ?it/s]

97    | 0.00018 0.80685 0.79127 0.79292 0.74530|  3.7145  4.2813  4.0950 12.0908|   46.3|

Epoch 97 policy training:   0%|          | 0/108 [00:00<?, ?it/s]

 Task losses:   3.7509     mean:   0.7502     Sharing: 1.07665e-03     Sparsity: 8.28006e-02     Total:   3.8348     mean:   0.8341
 Task losses:   6.9258     mean:   1.3852     Sharing: 1.15331e-03     Sparsity: 8.27941e-02     Total:   7.0097     mean:   1.4691
 Task losses:   7.6414     mean:   1.5283     Sharing: 7.95851e-04     Sparsity: 8.27875e-02     Total:   7.7250     mean:   1.6119
 Task losses:   9.3072     mean:   1.8614     Sharing: 9.75276e-04     Sparsity: 8.27861e-02     Total:   9.3910     mean:   1.9452
 Task losses:   8.3766     mean:   1.6753     Sharing: 9.59272e-04     Sparsity: 8.27845e-02     Total:   8.4603     mean:   1.7591
 Task losses:   8.5163     mean:   1.7033     Sharing: 9.78177e-04     Sparsity: 8.27909e-02     Total:   8.6001     mean:   1.7870
 Task losses:   5.0745     mean:   1.0149     Sharing: 7.00861e-04     Sparsity: 8.27961e-02     Total:   5.1580     mean:   1.0984
 Task losses:   4.1417     mean:   0.8283     Sharing: 7.28061e-04     Spars

 Task losses:   3.4193     mean:   0.6839     Sharing: 6.32798e-04     Sparsity: 8.24285e-02     Total:   3.5024     mean:   0.7669
 Task losses:   4.0929     mean:   0.8186     Sharing: 5.32776e-04     Sparsity: 8.24197e-02     Total:   4.1759     mean:   0.9015
 Task losses:   2.9908     mean:   0.5982     Sharing: 4.91117e-04     Sparsity: 8.24105e-02     Total:   3.0737     mean:   0.6811
 Task losses:   3.5965     mean:   0.7193     Sharing: 6.61766e-04     Sparsity: 8.24018e-02     Total:   3.6796     mean:   0.8024
 Task losses:   4.0730     mean:   0.8146     Sharing: 6.74392e-04     Sparsity: 8.23929e-02     Total:   4.1560     mean:   0.8977
 Task losses:   4.4032     mean:   0.8806     Sharing: 5.46381e-04     Sparsity: 8.23844e-02     Total:   4.4862     mean:   0.9636
 Task losses:   4.2797     mean:   0.8559     Sharing: 8.94000e-04     Sparsity: 8.23765e-02     Total:   4.3630     mean:   0.9392
 Task losses:   3.6841     mean:   0.7368     Sharing: 9.94508e-04     Spars

Epoch 98 weight training:   0%|          | 0/108 [00:00<?, ?it/s]

 current_iter_w: 108   stop_iter_w: 108   (are equal)
[e]Weight training epoch:98 iteration:  21160 -  Total Loss: 3.7264     Task Loss: 3.7264  


validation:   0%|          | 0/36 [00:00<?, ?it/s]

98    | 0.00019 0.86661 0.78437 0.78895 0.74265|  3.6955  4.9663  4.3294 12.9913|   44.2|

Epoch 98 policy training:   0%|          | 0/108 [00:00<?, ?it/s]

 Task losses:   4.0419     mean:   0.8084     Sharing: 7.18882e-04     Sparsity: 8.21789e-02     Total:   4.1248     mean:   0.8913
 Task losses:   8.3626     mean:   1.6725     Sharing: 7.88113e-04     Sparsity: 8.21719e-02     Total:   8.4456     mean:   1.7555
 Task losses:  10.8716     mean:   2.1743     Sharing: 6.78360e-04     Sparsity: 8.21642e-02     Total:  10.9545     mean:   2.2572
 Task losses:  11.7638     mean:   2.3528     Sharing: 1.00910e-03     Sparsity: 8.21584e-02     Total:  11.8470     mean:   2.4359
 Task losses:  12.3059     mean:   2.4612     Sharing: 9.82597e-04     Sparsity: 8.21551e-02     Total:  12.3891     mean:   2.5443
 Task losses:   9.3325     mean:   1.8665     Sharing: 8.64094e-04     Sparsity: 8.21543e-02     Total:   9.4155     mean:   1.9495
 Task losses:   5.2485     mean:   1.0497     Sharing: 9.01967e-04     Sparsity: 8.21523e-02     Total:   5.3315     mean:   1.1328
 Task losses:   5.0151     mean:   1.0030     Sharing: 9.05171e-04     Spars

 Task losses:   4.4813     mean:   0.8963     Sharing: 9.36811e-04     Sparsity: 8.26008e-02     Total:   4.5648     mean:   0.9798
 Task losses:   5.5044     mean:   1.1009     Sharing: 9.67577e-04     Sparsity: 8.25938e-02     Total:   5.5880     mean:   1.1845
 Task losses:   3.6626     mean:   0.7325     Sharing: 8.23552e-04     Sparsity: 8.25879e-02     Total:   3.7460     mean:   0.8159
 Task losses:   4.1104     mean:   0.8221     Sharing: 7.17928e-04     Sparsity: 8.25819e-02     Total:   4.1937     mean:   0.9054
 Task losses:   3.9909     mean:   0.7982     Sharing: 7.63993e-04     Sparsity: 8.25772e-02     Total:   4.0742     mean:   0.8815
 Task losses:   4.8579     mean:   0.9716     Sharing: 7.55524e-04     Sparsity: 8.25721e-02     Total:   4.9412     mean:   1.0549
 Task losses:   5.4040     mean:   1.0808     Sharing: 8.22489e-04     Sparsity: 8.25660e-02     Total:   5.4874     mean:   1.1642
 Task losses:   3.9320     mean:   0.7864     Sharing: 8.10425e-04     Spars

Epoch 99 weight training:   0%|          | 0/108 [00:00<?, ?it/s]

 current_iter_w: 108   stop_iter_w: 108   (are equal)
[e]Weight training epoch:99 iteration:  21376 -  Total Loss: 4.6281     Task Loss: 4.6281  


validation:   0%|          | 0/36 [00:00<?, ?it/s]

99    | 0.00019 0.86339 0.79073 0.79221 0.74475|  4.7722  4.3071  3.8614 12.9407|   43.0|

Epoch 99 policy training:   0%|          | 0/108 [00:00<?, ?it/s]

 Task losses:   4.6270     mean:   0.9254     Sharing: 8.21814e-04     Sparsity: 8.22942e-02     Total:   4.7101     mean:   1.0085
 Task losses:   7.7048     mean:   1.5410     Sharing: 1.18157e-03     Sparsity: 8.22877e-02     Total:   7.7882     mean:   1.6244
 Task losses:   7.9556     mean:   1.5911     Sharing: 1.20430e-03     Sparsity: 8.22818e-02     Total:   8.0391     mean:   1.6746
 Task losses:   9.1295     mean:   1.8259     Sharing: 8.41598e-04     Sparsity: 8.22742e-02     Total:   9.2126     mean:   1.9090
 Task losses:   8.3333     mean:   1.6667     Sharing: 7.13517e-04     Sparsity: 8.22656e-02     Total:   8.4163     mean:   1.7496
 Task losses:   9.0973     mean:   1.8195     Sharing: 9.35555e-04     Sparsity: 8.22580e-02     Total:   9.1805     mean:   1.9027
 Task losses:   4.8888     mean:   0.9778     Sharing: 7.97714e-04     Sparsity: 8.22510e-02     Total:   4.9719     mean:   1.0608
 Task losses:   4.2293     mean:   0.8459     Sharing: 6.19397e-04     Spars

 Task losses:   4.4612     mean:   0.8922     Sharing: 9.83362e-04     Sparsity: 8.21393e-02     Total:   4.5444     mean:   0.9754
 Task losses:   5.5442     mean:   1.1088     Sharing: 6.80392e-04     Sparsity: 8.21354e-02     Total:   5.6270     mean:   1.1917
 Task losses:   3.7352     mean:   0.7470     Sharing: 6.23062e-04     Sparsity: 8.21315e-02     Total:   3.8180     mean:   0.8298
 Task losses:   4.8199     mean:   0.9640     Sharing: 8.51884e-04     Sparsity: 8.21268e-02     Total:   4.9029     mean:   1.0470
 Task losses:   3.8497     mean:   0.7699     Sharing: 1.08212e-03     Sparsity: 8.21218e-02     Total:   3.9329     mean:   0.8531
 Task losses:   4.1854     mean:   0.8371     Sharing: 8.52172e-04     Sparsity: 8.21165e-02     Total:   4.2683     mean:   0.9200
 Task losses:   4.0929     mean:   0.8186     Sharing: 7.41656e-04     Sparsity: 8.21109e-02     Total:   4.1758     mean:   0.9014
 Task losses:   4.0369     mean:   0.8074     Sharing: 8.84026e-04     Spars

Epoch 100 weight training:   0%|          | 0/108 [00:00<?, ?it/s]

 current_iter_w: 108   stop_iter_w: 108   (are equal)
[e]Weight training epoch:100 iteration:  21592 -  Total Loss: 4.8387     Task Loss: 4.8387  


validation:   0%|          | 0/36 [00:00<?, ?it/s]

100   | 0.00020 0.93685 0.78437 0.78287 0.74405|  4.4805  4.8983  4.6455 14.0243|   45.7|

Epoch 100 policy training:   0%|          | 0/108 [00:00<?, ?it/s]

 Task losses:   4.4450     mean:   0.8890     Sharing: 8.61471e-04     Sparsity: 8.21090e-02     Total:   4.5280     mean:   0.9720
 Task losses:   9.1878     mean:   1.8376     Sharing: 9.71004e-04     Sparsity: 8.21058e-02     Total:   9.2709     mean:   1.9206
 Task losses:  11.6705     mean:   2.3341     Sharing: 1.20269e-03     Sparsity: 8.21034e-02     Total:  11.7538     mean:   2.4174
 Task losses:  12.6444     mean:   2.5289     Sharing: 1.16858e-03     Sparsity: 8.21079e-02     Total:  12.7277     mean:   2.6122
 Task losses:  14.6914     mean:   2.9383     Sharing: 8.27049e-04     Sparsity: 8.21132e-02     Total:  14.7743     mean:   3.0212
 Task losses:  10.3994     mean:   2.0799     Sharing: 9.17474e-04     Sparsity: 8.21198e-02     Total:  10.4825     mean:   2.1629
 Task losses:   6.2193     mean:   1.2439     Sharing: 1.22700e-03     Sparsity: 8.21259e-02     Total:   6.3026     mean:   1.3272
 Task losses:   4.8700     mean:   0.9740     Sharing: 1.24199e-03     Spars

 Task losses:   5.1636     mean:   1.0327     Sharing: 9.72634e-04     Sparsity: 8.18596e-02     Total:   5.2464     mean:   1.1156
 Task losses:   6.8712     mean:   1.3742     Sharing: 1.12185e-03     Sparsity: 8.18506e-02     Total:   6.9542     mean:   1.4572
 Task losses:   4.4700     mean:   0.8940     Sharing: 1.24921e-03     Sparsity: 8.18418e-02     Total:   4.5531     mean:   0.9771
 Task losses:   4.8284     mean:   0.9657     Sharing: 1.08606e-03     Sparsity: 8.18330e-02     Total:   4.9113     mean:   1.0486
 Task losses:   5.1914     mean:   1.0383     Sharing: 7.19900e-04     Sparsity: 8.18245e-02     Total:   5.2739     mean:   1.1208
 Task losses:   5.1497     mean:   1.0299     Sharing: 1.01633e-03     Sparsity: 8.18156e-02     Total:   5.2325     mean:   1.1128
 Task losses:   5.1120     mean:   1.0224     Sharing: 1.36180e-03     Sparsity: 8.18069e-02     Total:   5.1952     mean:   1.1056
 Task losses:   5.9411     mean:   1.1882     Sharing: 1.20807e-03     Spars

Epoch 101 weight training:   0%|          | 0/108 [00:00<?, ?it/s]

 current_iter_w: 108   stop_iter_w: 108   (are equal)
[e]Weight training epoch:101 iteration:  21808 -  Total Loss: 5.5946     Task Loss: 5.5946  


validation:   0%|          | 0/36 [00:00<?, ?it/s]

101   | 0.00019 0.87156 0.78725 0.79004 0.74314|  3.9025  4.2482  4.9081 13.0588|   42.5|

Epoch 101 policy training:   0%|          | 0/108 [00:00<?, ?it/s]

 Task losses:   4.1601     mean:   0.8320     Sharing: 1.09674e-03     Sparsity: 8.16352e-02     Total:   4.2428     mean:   0.9147
 Task losses:   9.4747     mean:   1.8949     Sharing: 8.21466e-04     Sparsity: 8.16301e-02     Total:   9.5571     mean:   1.9774
 Task losses:  11.0111     mean:   2.2022     Sharing: 7.27182e-04     Sparsity: 8.16287e-02     Total:  11.0935     mean:   2.2846
 Task losses:  12.6995     mean:   2.5399     Sharing: 9.04307e-04     Sparsity: 8.16266e-02     Total:  12.7820     mean:   2.6224
 Task losses:  11.9265     mean:   2.3853     Sharing: 8.78563e-04     Sparsity: 8.16244e-02     Total:  12.0090     mean:   2.4678
 Task losses:  10.1910     mean:   2.0382     Sharing: 8.75389e-04     Sparsity: 8.16226e-02     Total:  10.2735     mean:   2.1207
 Task losses:   5.1352     mean:   1.0270     Sharing: 6.33265e-04     Sparsity: 8.16215e-02     Total:   5.2174     mean:   1.1093
 Task losses:   4.1763     mean:   0.8353     Sharing: 7.45629e-04     Spars

 Task losses:   4.4492     mean:   0.8898     Sharing: 6.69410e-04     Sparsity: 8.13775e-02     Total:   4.5313     mean:   0.9719
 Task losses:   5.0914     mean:   1.0183     Sharing: 9.65501e-04     Sparsity: 8.13746e-02     Total:   5.1737     mean:   1.1006
 Task losses:   4.0752     mean:   0.8150     Sharing: 1.08667e-03     Sparsity: 8.13716e-02     Total:   4.1577     mean:   0.8975
 Task losses:   4.8531     mean:   0.9706     Sharing: 1.33260e-03     Sparsity: 8.13701e-02     Total:   4.9358     mean:   1.0533
 Task losses:   4.7712     mean:   0.9542     Sharing: 1.18672e-03     Sparsity: 8.13684e-02     Total:   4.8537     mean:   1.0368
 Task losses:   5.1639     mean:   1.0328     Sharing: 6.97161e-04     Sparsity: 8.13664e-02     Total:   5.2460     mean:   1.1148
 Task losses:   5.8070     mean:   1.1614     Sharing: 6.70224e-04     Sparsity: 8.13639e-02     Total:   5.8890     mean:   1.2434
 Task losses:   5.1561     mean:   1.0312     Sharing: 9.50396e-04     Spars

Epoch 102 weight training:   0%|          | 0/108 [00:00<?, ?it/s]

 current_iter_w: 108   stop_iter_w: 108   (are equal)
[e]Weight training epoch:102 iteration:  22024 -  Total Loss: 3.7359     Task Loss: 3.7359  


validation:   0%|          | 0/36 [00:00<?, ?it/s]

102   | 0.00019 0.84758 0.78752 0.78717 0.74562|  4.2292  3.9030  4.5807 12.7129|   43.5|

Epoch 102 policy training:   0%|          | 0/108 [00:00<?, ?it/s]

 Task losses:   3.3565     mean:   0.6713     Sharing: 8.78170e-04     Sparsity: 8.13084e-02     Total:   3.4387     mean:   0.7535
 Task losses:   7.3708     mean:   1.4742     Sharing: 8.75130e-04     Sparsity: 8.13058e-02     Total:   7.4530     mean:   1.5563
 Task losses:  11.1368     mean:   2.2274     Sharing: 4.19224e-04     Sparsity: 8.12994e-02     Total:  11.2185     mean:   2.3091
 Task losses:  10.3907     mean:   2.0781     Sharing: 6.54946e-04     Sparsity: 8.12876e-02     Total:  10.4726     mean:   2.1601
 Task losses:  13.3828     mean:   2.6766     Sharing: 8.76461e-04     Sparsity: 8.12716e-02     Total:  13.4650     mean:   2.7587
 Task losses:   7.8359     mean:   1.5672     Sharing: 8.75687e-04     Sparsity: 8.12508e-02     Total:   7.9181     mean:   1.6493
 Task losses:   4.4522     mean:   0.8904     Sharing: 7.40131e-04     Sparsity: 8.12290e-02     Total:   4.5341     mean:   0.9724
 Task losses:   4.1412     mean:   0.8282     Sharing: 9.71983e-04     Spars

 Task losses:   3.9833     mean:   0.7967     Sharing: 1.34855e-03     Sparsity: 8.07872e-02     Total:   4.0654     mean:   0.8788
 Task losses:   3.7994     mean:   0.7599     Sharing: 9.99639e-04     Sparsity: 8.07799e-02     Total:   3.8811     mean:   0.8417
 Task losses:   3.4955     mean:   0.6991     Sharing: 5.63815e-04     Sparsity: 8.07726e-02     Total:   3.5768     mean:   0.7804
 Task losses:   4.3904     mean:   0.8781     Sharing: 1.16497e-03     Sparsity: 8.07672e-02     Total:   4.4724     mean:   0.9600
 Task losses:   5.3371     mean:   1.0674     Sharing: 1.34793e-03     Sparsity: 8.07605e-02     Total:   5.4192     mean:   1.1495
 Task losses:   4.9900     mean:   0.9980     Sharing: 1.12594e-03     Sparsity: 8.07541e-02     Total:   5.0719     mean:   1.0799
 Task losses:   5.2456     mean:   1.0491     Sharing: 9.43904e-04     Sparsity: 8.07478e-02     Total:   5.3273     mean:   1.1308
 Task losses:   5.0081     mean:   1.0016     Sharing: 9.94975e-04     Spars

Epoch 103 weight training:   0%|          | 0/108 [00:00<?, ?it/s]

 current_iter_w: 108   stop_iter_w: 108   (are equal)
[e]Weight training epoch:103 iteration:  22240 -  Total Loss: 3.8568     Task Loss: 3.8568  


validation:   0%|          | 0/36 [00:00<?, ?it/s]

103   | 0.00019 0.87391 0.78622 0.79087 0.74204|  4.5421  3.8349  4.7219 13.0989|   42.4|

Epoch 103 policy training:   0%|          | 0/108 [00:00<?, ?it/s]

 Task losses:   4.2372     mean:   0.8474     Sharing: 1.52316e-03     Sparsity: 8.04470e-02     Total:   4.3192     mean:   0.9294
 Task losses:   8.4455     mean:   1.6891     Sharing: 1.43190e-03     Sparsity: 8.04405e-02     Total:   8.5274     mean:   1.7710
 Task losses:  10.0679     mean:   2.0136     Sharing: 1.38201e-03     Sparsity: 8.04357e-02     Total:  10.1497     mean:   2.0954
 Task losses:  11.1888     mean:   2.2378     Sharing: 1.30903e-03     Sparsity: 8.04335e-02     Total:  11.2706     mean:   2.3195
 Task losses:  11.5330     mean:   2.3066     Sharing: 1.34289e-03     Sparsity: 8.04312e-02     Total:  11.6148     mean:   2.3884
 Task losses:  10.3756     mean:   2.0751     Sharing: 1.41573e-03     Sparsity: 8.04285e-02     Total:  10.4574     mean:   2.1570
 Task losses:   4.7048     mean:   0.9410     Sharing: 1.42401e-03     Sparsity: 8.04298e-02     Total:   4.7866     mean:   1.0228
 Task losses:   4.1344     mean:   0.8269     Sharing: 1.22823e-03     Spars

 Task losses:   4.2551     mean:   0.8510     Sharing: 1.18944e-03     Sparsity: 7.99309e-02     Total:   4.3362     mean:   0.9321
 Task losses:   4.3149     mean:   0.8630     Sharing: 1.07324e-03     Sparsity: 7.99253e-02     Total:   4.3959     mean:   0.9440
 Task losses:   3.8798     mean:   0.7760     Sharing: 6.36339e-04     Sparsity: 7.99201e-02     Total:   3.9603     mean:   0.8565
 Task losses:   4.7812     mean:   0.9562     Sharing: 8.93002e-04     Sparsity: 7.99145e-02     Total:   4.8620     mean:   1.0370
 Task losses:   4.9807     mean:   0.9961     Sharing: 9.89825e-04     Sparsity: 7.99091e-02     Total:   5.0616     mean:   1.0770
 Task losses:   5.0063     mean:   1.0013     Sharing: 9.42220e-04     Sparsity: 7.99034e-02     Total:   5.0872     mean:   1.0821
 Task losses:   5.0777     mean:   1.0155     Sharing: 7.72655e-04     Sparsity: 7.98974e-02     Total:   5.1583     mean:   1.0962
 Task losses:   5.1381     mean:   1.0276     Sharing: 5.67069e-04     Spars

Epoch 104 weight training:   0%|          | 0/108 [00:00<?, ?it/s]

 current_iter_w: 108   stop_iter_w: 108   (are equal)
[e]Weight training epoch:104 iteration:  22456 -  Total Loss: 4.5576     Task Loss: 4.5576  


validation:   0%|          | 0/36 [00:00<?, ?it/s]

104   | 0.00018 0.81779 0.78394 0.78760 0.74120|  3.6731  4.3750  4.2135 12.2616|   41.9|

Epoch 104 policy training:   0%|          | 0/108 [00:00<?, ?it/s]

 Task losses:   3.4041     mean:   0.6808     Sharing: 6.50709e-04     Sparsity: 7.95896e-02     Total:   3.4844     mean:   0.7611
 Task losses:   5.6915     mean:   1.1383     Sharing: 7.12931e-04     Sparsity: 7.95842e-02     Total:   5.7718     mean:   1.2186
 Task losses:   7.5775     mean:   1.5155     Sharing: 7.40478e-04     Sparsity: 7.95749e-02     Total:   7.6578     mean:   1.5958
 Task losses:   6.6792     mean:   1.3358     Sharing: 7.20481e-04     Sparsity: 7.95771e-02     Total:   6.7595     mean:   1.4161
 Task losses:   7.1678     mean:   1.4336     Sharing: 6.73950e-04     Sparsity: 7.95788e-02     Total:   7.2480     mean:   1.5138
 Task losses:   6.8238     mean:   1.3648     Sharing: 6.92169e-04     Sparsity: 7.95800e-02     Total:   6.9041     mean:   1.4450
 Task losses:   3.9758     mean:   0.7952     Sharing: 5.95644e-04     Sparsity: 7.95836e-02     Total:   4.0560     mean:   0.8753
 Task losses:   3.5648     mean:   0.7130     Sharing: 5.19072e-04     Spars

 Task losses:   3.7859     mean:   0.7572     Sharing: 9.36776e-04     Sparsity: 7.94821e-02     Total:   3.8663     mean:   0.8376
 Task losses:   3.6483     mean:   0.7297     Sharing: 8.26637e-04     Sparsity: 7.94818e-02     Total:   3.7286     mean:   0.8100
 Task losses:   2.9111     mean:   0.5822     Sharing: 9.38058e-04     Sparsity: 7.94808e-02     Total:   2.9916     mean:   0.6626
 Task losses:   3.6320     mean:   0.7264     Sharing: 9.92383e-04     Sparsity: 7.94809e-02     Total:   3.7125     mean:   0.8069
 Task losses:   4.3803     mean:   0.8761     Sharing: 1.02717e-03     Sparsity: 7.94803e-02     Total:   4.4608     mean:   0.9566
 Task losses:   4.0451     mean:   0.8090     Sharing: 9.20778e-04     Sparsity: 7.94794e-02     Total:   4.1255     mean:   0.8894
 Task losses:   4.1916     mean:   0.8383     Sharing: 9.38202e-04     Sparsity: 7.94781e-02     Total:   4.2720     mean:   0.9187
 Task losses:   3.9677     mean:   0.7935     Sharing: 1.06413e-03     Spars

KeyboardInterrupt: 

In [None]:
[e]Policy training epoch:103 iteration:  11116 -  Total Loss: 11.3055     Task Loss: 11.2034  Policy Losses:  Sparsity: 0.1021      Sharing: 4.42266e-06 
[c]Policy training epoch:103 iteration:  11016 -  Total Loss: 11.2994     Task Loss: 11.1965  Policy Losses:  Sparsity: 0.1029      Sharing: 5.42402e-06 
[c]Policy training epoch:104 iteration:  11124 -  Total Loss: 11.2998     Task Loss: 11.1978  Policy Losses:  Sparsity: 0.1020      Sharing: 6.46710e-06                            
[e]Policy training epoch:104 iteration:  11224 -  Total Loss: 11.3005     Task Loss: 11.1994  Policy Losses:  Sparsity: 0.1012      Sharing: 7.33137e-06        
[c]Policy training epoch:105 iteration:  11232 -  Total Loss: 11.3018     Task Loss: 11.2007  Policy Losses:  Sparsity: 0.1011      Sharing: 8.82745e-06         
[e]Policy training epoch:105 iteration:  11332 -  Total Loss: 11.2967     Task Loss: 11.1964  Policy Losses:  Sparsity: 0.1003      Sharing: 3.92199e-06       
[c]Policy training epoch:106 iteration:  11340 -  Total Loss: 11.3006     Task Loss: 11.2004  Policy Losses:  Sparsity: 0.1002      Sharing: 5.03659e-06     
[e]Policy training epoch:106 iteration:  11440 -  Total Loss: 11.2970     Task Loss: 11.1976  Policy Losses:  Sparsity: 0.0994      Sharing: 6.71148e-06


In [None]:
[c]Policy training epoch:61 iteration:  6588 -  Total Loss: 9.6820     Task Loss: 9.5780  Policy Losses:  Sparsity: 0.1039      Sharing: 1.73569e-05 
[e]Policy training epoch:61 iteration:  6688 -  Total Loss: 9.6665     Task Loss: 9.5635  Policy Losses:  Sparsity: 0.1030      Sharing: 3.66569e-06 
[c]Policy training epoch:62 iteration:  6696 -  Total Loss: 9.6642     Task Loss: 9.5612  Policy Losses:  Sparsity: 0.1030      Sharing: 7.31945e-06 
[e]Policy training epoch:62 iteration:  6796 -  Total Loss: 9.6599     Task Loss: 9.5578  Policy Losses:  Sparsity: 0.1021      Sharing: 1.82986e-06 
[e]Policy training epoch:63 iteration:  6904 -  Total Loss: 9.6682     Task Loss: 9.5670  Policy Losses:  Sparsity: 0.1012      Sharing: 9.05991e-06 
[c]Policy training epoch:64 iteration:  6912 -  Total Loss: 9.6548     Task Loss: 9.5537  Policy Losses:  Sparsity: 0.1011      Sharing: 9.95398e-06 
[e]Policy training epoch:64 iteration:  7012 -  Total Loss: 9.6678     Task Loss: 9.5675  Policy Losses:  Sparsity: 0.1003      Sharing: 5.03063e-06
[c]Policy training epoch:65 iteration:  7020 -  Total Loss: 9.6578     Task Loss: 9.5576  Policy Losses:  Sparsity: 0.1002      Sharing: 2.67029e-06
[e]Policy training epoch:65 iteration:  7120 -  Total Loss: 9.6335     Task Loss: 9.5341  Policy Losses:  Sparsity: 0.0994      Sharing: 5.96642e-06 
[c]Policy training epoch:66 iteration:  7128 -  Total Loss: 9.6501     Task Loss: 9.5507  Policy Losses:  Sparsity: 0.0993      Sharing: 4.99487e-06  
[e]Policy training epoch:66 iteration:  7228 -  Total Loss: 9.6556     Task Loss: 9.5571  Policy Losses:  Sparsity: 0.0985      Sharing: 3.95775e-06 

In [35]:
print_dbg(np.concatenate(environ.get_policy_prob(), axis=-1), verbose = True)
print(num_train_layers)
print(opt['is_curriculum'])
print(p_epoch, opt['curriculum_speed'], (p_epoch // opt['curriculum_speed']), (p_epoch // opt['curriculum_speed'])  + 1)

[[0.46973792 0.530262   0.4775429  0.5224572  0.4868897  0.5131103 ]
 [0.45025694 0.549743   0.45825934 0.54174066 0.47107342 0.5289266 ]
 [0.4443086  0.5556915  0.45530966 0.5446904  0.45748708 0.5425128 ]
 [0.4138397  0.58616036 0.43196857 0.5680315  0.42434993 0.5756501 ]
 [0.4140113  0.5859887  0.43017322 0.5698268  0.4313186  0.56868154]
 [0.42114905 0.57885087 0.4333356  0.5666644  0.4339512  0.56604874]]
20
True
60 3 20 21


### Post Training stuff

In [187]:
# environ.opt['train']['Lambda_sharing'] = 0.5
# opt['train']['Lambda_sharing'] = 0.5

environ.opt['train']['policy_lr'] = 0.001
opt['train']['policy_lr'] = 0.001

In [188]:
print(opt['diff_sparsity_weights'])
print(opt['is_sharing'])
print(opt['diff_sparsity_weights'] and not opt['is_sharing'])
print(environ.opt['train']['Lambda_sharing'])
print(opt['train']['Lambda_sharing'])
print(environ.opt['train']['Lambda_sparsity'])
print(opt['train']['Lambda_sparsity'])
print(environ.opt['train']['policy_lr'])
print(opt['train']['policy_lr'])

True
True
False
0.5
0.5
0.05
0.05
0.001
0.001


In [191]:
import torch.optim as optim
arch_parameters      = environ.get_arch_parameters()
print(arch_parameters)

[Parameter containing:
tensor([[0.3342, 0.4445],
        [0.3510, 0.5543],
        [0.3844, 0.6266],
        [0.4167, 0.8072],
        [0.4499, 0.8514],
        [0.4822, 0.8833]], device='cuda:0'), Parameter containing:
tensor([[0.3342, 0.4169],
        [0.3510, 0.5310],
        [0.3845, 0.6013],
        [0.4167, 0.7355],
        [0.4499, 0.7866],
        [0.4822, 0.8466]], device='cuda:0'), Parameter containing:
tensor([[0.3342, 0.3876],
        [0.3510, 0.4803],
        [0.3845, 0.5804],
        [0.4167, 0.7637],
        [0.4499, 0.7782],
        [0.4822, 0.8540]], device='cuda:0')]


In [192]:
environ.optimizers['alphas'] = optim.Adam(arch_parameters, lr=environ.opt['train']['policy_lr'], weight_decay=5*1e-4)

#### Sample Policies

In [170]:
logs = environ.get_policy_logits()
for i in logs:
    print(i, '\n')

[[0.33060804 0.45180422]
 [0.3532204  0.5528529 ]
 [0.38884717 0.6125409 ]
 [0.4204008  0.76851547]
 [0.45199737 0.7994047 ]
 [0.48402908 0.8020872 ]] 

[[0.33064184 0.42053092]
 [0.3532089  0.52056104]
 [0.3888512  0.5680909 ]
 [0.42039296 0.694217  ]
 [0.4519742  0.73311865]
 [0.48401102 0.7522658 ]] 

[[0.33058274 0.38303584]
 [0.3532048  0.46904057]
 [0.38887233 0.5593353 ]
 [0.42041987 0.7253615 ]
 [0.45201355 0.7284871 ]
 [0.48402616 0.7497743 ]] 



In [169]:
pols = environ.sample_policy(hard_sampling = False)
for i in pols:
    print(i, '\n')

tensor([[0, 1],
        [0, 1],
        [1, 0],
        [0, 1],
        [0, 1],
        [1, 0]], device='cuda:0') 

tensor([[1, 0],
        [1, 0],
        [1, 0],
        [0, 1],
        [0, 1],
        [0, 1]], device='cuda:0') 

tensor([[0, 1],
        [0, 1],
        [1, 0],
        [0, 1],
        [1, 0],
        [1, 0]], device='cuda:0') 



In [177]:
pols  = (environ.get_current_policy())
for i in pols:
    print(i ,'\n')

[[0.12809551 0.8719045 ]
 [0.4568862  0.5431138 ]
 [0.4389936  0.5610064 ]
 [0.4739472  0.5260528 ]
 [0.44630152 0.5536985 ]
 [0.37613583 0.6238642 ]] 

[[0.2449313  0.7550687 ]
 [0.41031924 0.58968073]
 [0.3977335  0.60226655]
 [0.8489379  0.15106209]
 [0.3237851  0.6762149 ]
 [0.64922446 0.35077554]] 

[[0.41088596 0.589114  ]
 [0.32043332 0.6795667 ]
 [0.11188911 0.8881109 ]
 [0.27037877 0.72962123]
 [0.8158148  0.18418525]
 [0.5808661  0.41913387]] 



In [176]:
pols = environ.networks['mtl-net'].test_sample_policy(hard_sampling = False, verbose = True)
print(type(pols))
# for i in pols:
#     print(i, '\n')

 MTL3_Dev test_sample_policy() START -  hard_sampling: False

 task1 logits
--------------
 [[0.33060804 0.45180422]
 [0.3532204  0.5528529 ]
 [0.38884717 0.6125409 ]
 [0.4204008  0.76851547]
 [0.45199737 0.7994047 ]
 [0.48402908 0.8020872 ]]

  task1 softmax:
-----------------
 [[0.46973792 0.530262  ]
 [0.45025694 0.549743  ]
 [0.4443086  0.5556915 ]
 [0.4138397  0.58616036]
 [0.4140113  0.5859887 ]
 [0.42114905 0.57885087]]

 task 1 sampled policy :
-------------------------
 tensor([[0, 1],
        [1, 0],
        [1, 0],
        [1, 0],
        [0, 1],
        [1, 0]], device='cuda:0')


 task2 logits
--------------
 [[0.33064184 0.42053092]
 [0.3532089  0.52056104]
 [0.3888512  0.5680909 ]
 [0.42039296 0.694217  ]
 [0.4519742  0.73311865]
 [0.48401102 0.7522658 ]]

  task2 softmax:
-----------------
 [[0.4775429  0.5224572 ]
 [0.45825934 0.54174066]
 [0.45530966 0.5446904 ]
 [0.43196857 0.5680315 ]
 [0.43017322 0.5698268 ]
 [0.4333356  0.5666644 ]]

 task 2 sampled policy :
-----

In [134]:
pols = environ.networks['mtl-net'].test_sample_policy(hard_sampling = True, verbose = True)
print(type(pols))
for i in pols:
    print(i, '\n')

 MTL3_Dev test_sample_policy() START -  hard_sampling: True

 task1 logits
--------------
 [[0.33060804 0.45180422]
 [0.3532204  0.5528529 ]
 [0.38884717 0.6125409 ]
 [0.4204008  0.76851547]
 [0.45199737 0.7994047 ]
 [0.48402908 0.8020872 ]]

 task1 argmax /hard_sampled policy
-----------------------------------
 tensor([[0, 1],
        [0, 1],
        [0, 1],
        [0, 1],
        [0, 1],
        [0, 1]], device='cuda:0')


 task2 logits
--------------
 [[0.33064184 0.42053092]
 [0.3532089  0.52056104]
 [0.3888512  0.5680909 ]
 [0.42039296 0.694217  ]
 [0.4519742  0.73311865]
 [0.48401102 0.7522658 ]]

 task2 argmax /hard_sampled policy
-----------------------------------
 tensor([[0, 1],
        [0, 1],
        [0, 1],
        [0, 1],
        [0, 1],
        [0, 1]], device='cuda:0')


 task3 logits
--------------
 [[0.33058274 0.38303584]
 [0.3532048  0.46904057]
 [0.38887233 0.5593353 ]
 [0.42041987 0.7253615 ]
 [0.45201355 0.7284871 ]
 [0.48402616 0.7497743 ]]

 task3 argmax /ha

In [122]:
logs = np.array(
[[0.33064184, 0.42053092],
 [0.3532089 , 0.52056104],
 [0.3888512 , 0.5680909 ],
 [0.42039296, 0.694217  ],
 [0.4519742 , 0.73311865],
 [0.48401102, 0.7522658 ]],
)

smax = scipy.special.softmax(logs, axis =1)
# smax = np.array( 
# [[0.46973792, 0.530262  ],
#  [0.45025694, 0.549743  ],
#  [0.4443086 , 0.5556915 ],
#  [0.4138397 , 0.58616036],
#  [0.4140113 , 0.5859887 ],
#  [0.42114905, 0.57885087]])

print(smax.shape)
print(smax)
print(smax[0])
print(smax[0].sum())
print(np.random.choice((1,0), p =smax[0]))

(6, 2)
[[0.47754285 0.52245715]
 [0.45825934 0.54174066]
 [0.45530966 0.54469034]
 [0.43196854 0.56803146]
 [0.43017322 0.56982678]
 [0.43333559 0.56666441]]
[0.47754285 0.52245715]
0.9999999999999998


ValueError: 'p' must be 1-dimensional

In [115]:
task_key = 'task1_logits'
logits = getattr(environ.networks['mtl-net'], task_key).cpu()
print(logits.shape,logits)

torch.Size([6, 2]) tensor([[0.3306, 0.4518],
        [0.3532, 0.5529],
        [0.3888, 0.6125],
        [0.4204, 0.7685],
        [0.4520, 0.7994],
        [0.4840, 0.8021]])


#### Sparsity Error 

In [178]:
environ.networks['mtl-net'].backbone.layer_config

[1, 1, 1, 1, 1, 1]

In [63]:
num_blocks = 6
num_policy_layers = 6
gt =  torch.ones((num_blocks)).long()
gt0 =  torch.zeros((num_blocks)).long()
print(gt)
print(gt0)

loss_weights = ((torch.arange(0, num_policy_layers, 1) + 1).float() / num_policy_layers)
print(loss_weights)

tensor([1, 1, 1, 1, 1, 1])
tensor([0, 0, 0, 0, 0, 0])
tensor([0.1667, 0.3333, 0.5000, 0.6667, 0.8333, 1.0000])


In [77]:
if environ.opt['diff_sparsity_weights'] and not environ.opt['is_sharing']:
    print(' cond 1')
    ## Assign higher weights to higher layers 
    loss_weights = ((torch.arange(0, num_policy_layers, 1) + 1).float() / num_policy_layers)
    print(f"{task_key} sparsity error:  {2 * (loss_weights[-num_blocks:] * environ.cross_entropy2(logits[-num_blocks:], gt)).mean()})")
    print_dbg(f" loss_weights :  {loss_weights}", verbose = True)
    print_dbg(f" cross_entropy:  {environ.cross_entropy2(logits[-num_blocks:], gt)}  ", verbose = True)
    print_dbg(f" loss[sparsity][{task_key}]: {self.losses['sparsity'][task_key] } ", verbose = True)

else:
    print('\n cond 2')
    print_dbg(f"Compute CrossEntropyLoss between \n Logits   : \n{logits[-num_blocks:]} \n and gt: \n{gt} \n", verbose = True)
    print(f"{task_key} sparsity error:  {environ.cross_entropy_sparsity(logits[-num_blocks:], gt)}")
    
    print('\n cond 2')
    print_dbg(f"Compute CrossEntropyLoss between Logits      : {logits[-1:]}  and gt: {gt[-1]} ", verbose = True)
    print(f"{task_key} sparsity error:  {environ.cross_entropy_sparsity(logits[-1:], gt[-1:])} \n")
    print_dbg(f"Compute CrossEntropyLoss between Logits      : {logits[-1:]}  and gt: {gt0[-1]} ", verbose = True)
    print(f"{task_key} sparsity error:  {environ.cross_entropy_sparsity(logits[-1:], gt0[-1:])} \n")
    
    print('\n cond 3')    
    print_dbg(f"Compute CrossEntropyLoss between Logits   : {logits[0:1]}  and gt: {gt[0:1]} ", verbose = True)
    print(f"{task_key} sparsity error:  {environ.cross_entropy_sparsity(logits[0:1], gt[0:1])} \n")
    print_dbg(f"Compute CrossEntropyLoss between Logits   : {logits[0:1]}  and gt: {gt0[0:1]} ", verbose = True)
    print(f"{task_key} sparsity error:  {environ.cross_entropy_sparsity(logits[0:1], gt0[0:1])} \n")
        
        

cond 2
Compute CrossEntropyLoss between 
 Logits   : 
tensor([[0.3306, 0.4518],
        [0.3532, 0.5529],
        [0.3888, 0.6125],
        [0.4204, 0.7685],
        [0.4520, 0.7994],
        [0.4840, 0.8021]]) 
 and gt: 
tensor([1, 1, 1, 1, 1, 1]) 

task1_logits sparsity error:  0.5725929141044617

 cond 2
Compute CrossEntropyLoss between Logits      : tensor([[0.4840, 0.8021]])  and gt: 1 
task1_logits sparsity error:  0.5467103123664856 

Compute CrossEntropyLoss between Logits      : tensor([[0.4840, 0.8021]])  and gt: 0 
task1_logits sparsity error:  0.864768385887146 


 cond 3
Compute CrossEntropyLoss between Logits   : tensor([[0.3306, 0.4518]])  and gt: tensor([1]) 
task1_logits sparsity error:  0.634384036064148 

Compute CrossEntropyLoss between Logits   : tensor([[0.3306, 0.4518]])  and gt: tensor([0]) 
task1_logits sparsity error:  0.7555801868438721 



In [24]:
# flag = 'update_w'
# environ.fix_alpha
# environ.free_w(opt['fix_BN'])

flag = 'update_alpha'
environ.fix_weights()
environ.free_alpha()

In [29]:
environ.networks['mtl-net'].num_layers

6

In [25]:
print(f"current_iters         : {current_iter}")  
print(f"curr_epochs           : {curr_epoch}") 
print(f"train_total_epochs    : {train_total_epochs}") 

train_total_epochs += 5

print(f"current_iters         : {current_iter}")  
print(f"curr_epochs           : {curr_epoch}") 
print(f"train_total_epochs    : {train_total_epochs}") 

current_iters         : 6580
curr_epochs           : 60
train_total_epochs    : 60


In [25]:
# print_metrics_cr(curr_epoch, time.time() - t0, None, environ.val_metrics , num_prints)      

# num_prints += 1
# t0 = time.time()

# # Take check point
# environ.save_checkpoint('latest', current_iter)
# environ.train()
# #-------------------------------------------------------
# # END validation process
# #-------------------------------------------------------       
# flag = 'update_alpha'
# environ.fix_w()
# environ.free_alpha()

Epoch | logloss bceloss  aucroc   aucpr  f1_max| t1 loss t2 loss t3 lossttl loss|tr_time|
1     | 0.00020 0.89942 0.81016 0.81627 0.75738|  4.6591  4.3649  4.4500 13.4740| 333.4|

In [179]:
environ.get_policy_prob()

[array([[0.46973792, 0.530262  ],
        [0.45025694, 0.549743  ],
        [0.4443086 , 0.5556915 ],
        [0.4138397 , 0.58616036],
        [0.4140113 , 0.5859887 ],
        [0.42114905, 0.57885087]], dtype=float32),
 array([[0.4775429 , 0.5224572 ],
        [0.45825934, 0.54174066],
        [0.45530966, 0.5446904 ],
        [0.43196857, 0.5680315 ],
        [0.43017322, 0.5698268 ],
        [0.4333356 , 0.5666644 ]], dtype=float32),
 array([[0.4868897 , 0.5131103 ],
        [0.47107342, 0.5289266 ],
        [0.45748708, 0.5425128 ],
        [0.42434993, 0.5756501 ],
        [0.4313186 , 0.56868154],
        [0.4339512 , 0.56604874]], dtype=float32)]

In [59]:
environ.temp

0.1418081981208155

In [71]:
# dilation = 2
# kernel_size = np.asarray((3, 3))
# upsampled_kernel_size = (kernel_size - 1) * (dilation - 1) + kernel_size
# print(upsampled_kernel_size)

In [30]:
# environ.optimizers['weights'].param_groups[0]
# for param_group in optimizer.param_groups:
#     return param_group['lr']

In [31]:
environ.schedulers['weights'].get_last_lr()

[1e-05, 1e-05]

In [38]:
environ.losses.keys()
pp.pprint(environ.losses)

{   'losses': {   'task1': tensor(2.3789, device='cuda:0', dtype=torch.float64, grad_fn=<DivBackward0>),
                  'task2': tensor(2.1948, device='cuda:0', dtype=torch.float64, grad_fn=<DivBackward0>),
                  'task3': tensor(2.3476, device='cuda:0', dtype=torch.float64, grad_fn=<DivBackward0>),
                  'total': tensor(6.9212, device='cuda:0', dtype=torch.float64, grad_fn=<AddBackward0>)},
    'losses_mean': {   'task1': tensor(0.4758, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>),
                       'task2': tensor(0.4390, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>),
                       'task3': tensor(0.4695, device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>),
                       'total': tensor(1.3842, device='cuda:0', dtype=torch.float64, grad_fn=<AddBackward0>)},
    'parms': {'gumbel_temp': 2.6330653896376153, 'lr_0': 0.005, 'lr_1': 0.005},
    'sharing': {   'total': tensor(2.1338e-06, device='cud

In [79]:
tmp = environ.get_loss_dict()
print(tmp.keys())
pp.pprint(tmp)

dict_keys(['task1', 'task2', 'task3', 'tasks', 'total', 'sharing', 'sparsity'])
{   'sharing': {'total': tensor(0.0002, device='cuda:0')},
    'sparsity': {   'task1_logits': tensor(0.5147, device='cuda:0'),
                    'task2_logits': tensor(0.5195, device='cuda:0'),
                    'task3_logits': tensor(0.4978, device='cuda:0'),
                    'total': tensor(0.0766, device='cuda:0')},
    'task1': {'total': tensor(3.3657, device='cuda:0', dtype=torch.float64)},
    'task2': {'total': tensor(3.5906, device='cuda:0', dtype=torch.float64)},
    'task3': {'total': tensor(3.2182, device='cuda:0', dtype=torch.float64)},
    'tasks': {   'task1': tensor(3.3657, device='cuda:0', dtype=torch.float64),
                 'task2': tensor(3.5906, device='cuda:0', dtype=torch.float64),
                 'task3': tensor(3.2182, device='cuda:0', dtype=torch.float64),
                 'total': tensor(10.1745, device='cuda:0', dtype=torch.float64)},
    'total': {'total': tensor(10.25

In [81]:
tmp['task1']['total']+tmp['task2']['total']+tmp['task3']['total']+tmp['sharing']['total']+tmp['sparsity']['total']

tensor(10.2513, device='cuda:0', dtype=torch.float64)

In [83]:
current_iter

17713

In [86]:
current_state = {}
for k, v in environ.optimizers.items():
    print(f'state dict for {k} = {v}')
    current_state[k] = v.state_dict()
pp.pprint(current_state)

state dict for weights = SGD (
Parameter Group 0
    dampening: 0
    lr: 0.0001
    momentum: 0.9
    nesterov: False
    weight_decay: 0.0001

Parameter Group 1
    dampening: 0
    lr: 0.0001
    momentum: 0.9
    nesterov: False
    weight_decay: 0.0001
)
state dict for alphas = Adam (
Parameter Group 0
    amsgrad: False
    betas: (0.9, 0.999)
    eps: 1e-08
    lr: 0.0001
    weight_decay: 0.0005
)
{   'alphas': {   'param_groups': [   {   'amsgrad': False,
                                          'betas': (0.9, 0.999),
                                          'eps': 1e-08,
                                          'lr': 0.0001,
                                          'params': [0, 1, 2],
                                          'weight_decay': 0.0005}],
                  'state': {   0: {   'exp_avg': tensor([[ 0.0607, -0.0007],
        [-0.0428, -0.0069],
        [-0.1218,  0.0138],
        [ 0.0086,  0.0238]], device='cuda:0'),
                                      'exp_

In [88]:
current_state = {}
for k, v in environ.schedulers.items():
    print(f'state dict for {k} = {v}')
    print(v.state_dict())

state dict for weights = <torch.optim.lr_scheduler.StepLR object at 0x7f90c01c0ca0>
{'step_size': 4000, 'gamma': 0.5, 'base_lrs': [0.0001, 0.0001], 'last_epoch': 9100, '_step_count': 9101, 'verbose': False, '_get_lr_called_within_step': False, '_last_lr': [2.5e-05, 2.5e-05]}


### Warm-up:  validation - Dev

In [None]:
# validation
if should(current_iter, opt['train']['val_freq']):
    print(f"**  {timestring()}  START VALIDATION iteration: {current_iter} ")    

    environ.eval()     # set to evaluation mode (train = False)
    num_seg_class = opt['tasks_num_class'][opt['tasks'].index('seg')] if 'seg' in opt['tasks'] else -1
    val_metrics = eval_dev(environ, 
                          val_loader, 
                          opt['tasks'], 
                          policy=False, 
                          num_train_layers=None, 
                          eval_iter = 4)

In [None]:
val_metrics.keys()
val_metrics['loss']

In [None]:
for i in val_metrics:
    print(f'\n {i} \n -----------------')
    print(val_metrics[i])

In [None]:
    for t_id, task in enumerate(environ.tasks):
        task_key = f"task{t_id+1}"    
        environ.print_loss(current_iter, start_time, val_metrics[task_key]["classification_agg"], title='validation')
    

In [None]:
    environ.save_checkpoint('latest', current_iter)

    print(f"** {timestring()} - END VALIDATION iteration:  {current_iter} ")                
    environ.train()    # set to training mode (train = True)

In [None]:
# for i in val_metrics.keys():
#     print(i, type(val_metrics[i]))
#     for k in val_metrics[i].keys():
#         print(i,k, type(val_metrics[i][k]))
#         if isinstance(val_metrics[i][k], pd.core.series.Series):
#             print(f"val_metrics[{i}][{k}] is a series")
#         elif isinstance(val_metrics[i][k], pd.core.frame.DataFrame):
#             print(f"val_metrics[{i}][{k}] is a dataframe")        

# s = val_metrics['task1']['classification_agg']
# print(s)
# print(s.to_dict())

### Weight Training Dev

#### Weight training - prep

In [None]:
# arch_parms = environ.networks['mtl-net'].named_parameters()
# print(arch_parms)
# for name, parm in arch_parms:
#     print(name, '    ',parm.requires_grad)
# print_underline('MTL3_Dev Policys', verbose = True)
# for i in   environ.networks['mtl-net'].policys:
#     print(i)

In [None]:
print_heading(f"** {timestring()} - Training current iteration {current_iter}  flag: {flag} ", verbose = True)    

current_iter_w = 0 
current_iter_a = 0 
batch_idx_a = 0 
batch_idx_w = 0 

if flag_warmup:
    print_heading(f"** Set optimizer and scheduler to policy_learning = True", verbose = True)
    environ.define_optimizer(policy_learning=True)
    environ.define_scheduler(policy_learning=True)
    flag_warmup = False

if current_iter == opt['train']['warm_up_iters']:
    print_heading(f"** Switch from Warm Up training to Alternate training Weights & Policy \n"
                  f"   Take checkpoint and block gradient flow through Policy net", verbose=True)
    environ.save_checkpoint('warmup', current_iter)
    environ.fix_alpha()
    
# batch_enumerator1 = enumerate(train1_loader,1)  
# batch_enumerator2 = enumerate(train2_loader,1)  

train_total_epochs = 10

#### Weight training - main 

In [None]:
print(f"opt['train']['print_freq']         {opt['train']['print_freq']}")
print(f"opt['train']['hard_sampling']      {opt['train']['hard_sampling']}")
print(f"opt['policy']                      {opt['policy']}")
print(f"opt['tasks']                       {opt['tasks']}")
print(f"weight_iter_alternate:             {opt['train']['weight_iter_alternate']}")
print(f"alpha_iter_alternate :             {opt['train']['alpha_iter_alternate']}")
print(f"current_iter                       {current_iter  }")
print(f"current_iter_w                     {current_iter_w}")
print(f"current_iter_a                     {current_iter_a}")
print(f"batch_idx_w                        {batch_idx_w}")
print(f"flag                               {flag          }")
print(f"train_total_epochs                 {train_total_epochs}") 

In [None]:
##---------------------------------------------------------------     
## Weight / Policy Training
##--------------------------------------------------------------- 
# stop_iter = current_iter_w +  opt['train']['weight_iter_alternate']
# print(f" Current Weight iteration {current_iter_w} - Run  from {current_iter_w+1} to {stop_iter+1}")


In [None]:
from tqdm.notebook import trange, tqdm

In [None]:
# with tnrange(start_iter_t , stop_iter_t  , initial = start_iter_t , total = stop_iter_t, position=0, leave= True, desc="master") as t :
# with tqdm_notebook(total=train_total_epochs) as t:

In [None]:
curr_epoch = 0
main_iter_ctr = 0 
verbose = False
t = tqdm(total=train_total_epochs, desc=f" Alternate Weight/Policy training")

while curr_epoch < train_total_epochs:
    curr_epoch+=1
    t.update(1)

    #-----------------------------------------
    # Train & Update the network weights
    #-----------------------------------------
    if flag == 'update_w':
        current_iter_w  = 0 
        stop_iter_w =   opt['train']['weight_iter_alternate']

        with trange(+1, stop_iter_w+1 , initial = current_iter_w, total = stop_iter_w, 
                     position=0, leave= False, desc=f"Epoch {curr_epoch} weight training") as t_weights :
            for current_iter_w in t_weights:    
                current_iter += 1

                start_time = time.time()
                environ.train()
                
#                 if batch_idx_w == len(train1_loader):
#                     print_dbg(f"  Reenumerate train1_loader -  index_w: {batch_idx_w}   len(train1_loader) = {len(train1_loader)} ", verbose)
#                     batch_enumerator1 = enumerate(train1_loader,1)    
                    
                batch = next(train1_loader)
                environ.set_inputs(batch, train1_loader.dataset.input_size)

                ##----------------------------------------------------------------------
                ## Set number of layers to train based on cirriculum_speed 
                ## and p_epoch (number of epochs of policy training)
                ## When curriculum_speed == 3, a num_train_layers is incremented 
                ## after completion of every 3 policy training epochs
                ##----------------------------------------------------------------------
                if opt['is_curriculum']:
                    num_train_layers = p_epoch // opt['curriculum_speed'] + 1
                else:
                    num_train_layers = None


#                 print_heading(f"{timestring()} CALL ENVIRON.OPTIMIZE()    current_iter: {current_iter}     flag: {flag}\n"
#                       f"{' ':10s} current_iter_w: {current_iter_w}  batch_idx_w:{batch_idx_w}   weight_iter_alternate: {opt['train']['weight_iter_alternate']}\n"
#                       f"{' ':10s} current_iter_a: {current_iter_a}  batch_idx_a:{batch_idx_a}   alpha_iter_alternate : {opt['train']['alpha_iter_alternate']} \n"                          
#                       f"{' ':10s} is_policy: {opt['policy']}     p_epoch: {p_epoch}       num_train_layers: {num_train_layers}", verbose = False) 

                environ.optimize(opt['lambdas'], 
                                 is_policy=opt['policy'], 
                                 flag=flag, 
                                 num_train_layers=num_train_layers,
                                 hard_sampling=opt['train']['hard_sampling'],
                                 verbose = False)

                t_weights.set_postfix({'iteration': current_iter, 'Loss': f"{environ.losses['total']['total'].item():.4f}" , 
                                       'row_ids':f"{batch['row_id'][0]}-{batch['row_id'][-1]}"})
                
                if should(current_iter, opt['train']['print_freq']):
                    environ.print_loss(current_iter, start_time, title = "Weight training iteration", verbose = True)
                    environ.resize_results()

#                 print_heading(f"{timestring()} - CONTINUE WEIGHT TRAINING   current_iter: {current_iter}\n"
#                   f" current_iter_w: {current_iter_w}  batch_idx_w:{batch_idx_w}   weight_iter_alternate: {opt['train']['weight_iter_alternate']}\n"
#                   f" current_iter_a: {current_iter_a}  batch_idx_a:{batch_idx_a}   alpha_iter_alternate : {opt['train']['alpha_iter_alternate']}",
#                   verbose = False)        

        #-------------------------------------------------------
        # validation process
        #------------------------------------------------------- 

#         if should(current_iter_w, opt['train']['weight_iter_alternate']): 

        if (current_iter_w >= stop_iter_w):
            environ.eval()
            print_dbg("++ Weight Training Validation  and then Switch to update_alpha", verbose = False)

            val_metrics = eval_dev(environ, 
                                  val_loader, 
                                  opt['tasks'], 
                                  policy=opt['policy'],
                                  num_train_layers=num_train_layers, 
                                  hard_sampling=opt['train']['hard_sampling'],
                                  eval_iter = -1)        

            if (verbose):
                for t_id, task in enumerate(environ.tasks):
                    task_key = f"task{t_id+1}"    
                    environ.print_metrics(current_iter, start_time, val_metrics[task_key]["classification_agg"], title='validation', verbose = verbose)        

            environ.save_checkpoint('latest', current_iter)

            #----------------------------------------------------------------------------------------------
            # if number of iterations completed after the warm up phase is greater than the number of 
            # (weight/policy alternations) x (cirriculum speed) x (number of layers to be policy trained)
            #
            # check metrics for improvement, and issue a checkpoint if necessary
            #----------------------------------------------------------------------------------------------



            if current_iter - opt['train']['warm_up_iters'] >= num_blocks * opt['curriculum_speed'] * \
                    (opt['train']['weight_iter_alternate'] + opt['train']['alpha_iter_alternate']):
                new_value = 0
#                 print_heading(f"  evaluate progress and make checkpoint if necessary." , verbose = True)
#                 print(f" current iter                                 : {current_iter} \n"
#                       f" opt['train']['warm_up_iters']                : {opt['train']['warm_up_iters']} \n"
#                       f" num_blocks                                   : {num_blocks} \n"
#                       f" opt['curriculum_speed']                      : {opt['curriculum_speed']}\n"
#                       f" opt['train']['weight_iter_alternate']        : {opt['train']['weight_iter_alternate']}\n"
#                       f" opt['train']['alpha_iter_alternate']         : {opt['train']['alpha_iter_alternate']}\n"
#                       f" alpha_iter_alternate + weight_iter_alternate : {opt['train']['weight_iter_alternate'] + opt['train']['alpha_iter_alternate']}\n"
#                       f" num_blks * curriculum_speed * (alpha_alternate + weight_alternate): "
#                       f" {num_blocks * opt['curriculum_speed'] * (opt['train']['weight_iter_alternate'] + opt['train']['alpha_iter_alternate'])} \n"

                print(f"  {current_iter - opt['train']['warm_up_iters']} IS GREATER THAN "
                       f" {num_blocks * opt['curriculum_speed'] * (opt['train']['weight_iter_alternate'] + opt['train']['alpha_iter_alternate'])} -- "
                       f"  evaluate progress and make checkpoint if necessary." )            
#               ## compare validation metrics against reference metrics.

#                 for k in refer_metrics.keys():
#                     if k in val_metrics.keys():
#                         for kk in val_metrics[k].keys():
#                             if not kk in refer_metrics[k].keys():
#                                 continue
#                             if (k == 'sn' and kk in ['Angle Mean', 'Angle Median']) or (
#                                     k == 'depth' and not kk.startswith('sigma')) or (kk == 'err'):
#                                 value = refer_metrics[k][kk] / val_metrics[k][kk]
#                             else:
#                                 value = val_metrics[k][kk] / refer_metrics[k][kk]
#                             value = value / len(list(set(val_metrics[k].keys()) & set(refer_metrics[k].keys())))
#                             new_value += value

#                 print('Best Value %.4f  New value: %.4f' % new_value)

                # if results have improved, save these results and issue a checkpoint

#                 if (new_value > best_value):
#                     print('Previous best iter: %d, best_value: %.4f' % (best_iter, best_value), best_metrics)
#                     best_value = new_value
#                     best_metrics = val_metrics
#                     best_iter = current_iter
#                     environ.save_checkpoint('best', current_iter)
#                     print('New      best iter: %d, best_value: %.4f' % (best_iter, best_value), best_metrics)                         
#                     print('Best Value %.4f  New value: %.4f' % new_value)

                # if results have improved, save these results and issue a checkpoint   

            environ.train()
            #-------------------------------------------------------
            # END validation process
            #-------------------------------------------------------       
            print_heading(f"{timestring()} - SWITCH TO ALPHA TRAINING    current_iter: {current_iter}\n"
              f" current_iter_w: {current_iter_w}  batch_idx_w:{batch_idx_w}   weight_iter_alternate: {opt['train']['weight_iter_alternate']}\n"
              f" current_iter_a: {current_iter_a}  batch_idx_a:{batch_idx_a}   alpha_iter_alternate : {opt['train']['alpha_iter_alternate']}",
              verbose = False)       
            flag = 'update_alpha'
            environ.fix_w()
            environ.free_alpha()
        #-------------------------------------------------------
        # end validation process
        #-------------------------------------------------------               


    #-----------------------------------------
    # Train & Update the  policy 
    #-----------------------------------------
    if flag == 'update_alpha':
        current_iter_a = 0
        stop_iter_a = opt['train']['alpha_iter_alternate']

        with trange( +1, stop_iter_a+1 , initial = 0, total = stop_iter_a, 
                     position=0, leave= False, desc=f"Epoch {curr_epoch} policy training") as t_policy :
            for current_iter_a in t_policy:    
                current_iter += 1

#                 batch_idx_a, batch = next(batch_enumerator2)
                batch = next(train2_loader)
                environ.set_inputs(batch, train2_loader.dataset.input_size)

#                 if batch_idx_a == len(train2_loader):
#                     print_dbg(f" Re-enumerate train2_loader  batch_idx_a: {batch_idx_a}   len(train2_loader) = {len(train2_loader)}", verbose=False)                
#                     batch_enumerator2 = enumerate(train2_loader,1)        

#                 print_heading(f"{timestring()} - ENVIRON.OPTIMIZE()    flag: {flag}    current_iter: {current_iter}   \n"
#                               f" current_iter_w: {current_iter_w}  batch_idx_w:{batch_idx_w}   weight_iter_alternate: {opt['train']['weight_iter_alternate']}\n"
#                               f" current_iter_a: {current_iter_a}  batch_idx_a:{batch_idx_a}   alpha_iter_alternate : {opt['train']['alpha_iter_alternate']} \n"
#                               f" is_policy: {opt['policy']}   num_train_layers: {num_train_layers}  hard_sampling: {opt['train']['hard_sampling']}\n"
#                               f" is_curriculum: {opt['is_curriculum']}     curriculum_speed: {opt['curriculum_speed']}   p_epoch: {p_epoch}"
#                               , verbose = False) 

                if opt['is_curriculum']:
                    num_train_layers = (p_epoch // opt['curriculum_speed']) + 1
                else:
                    num_train_layers = None

                print_dbg(f" num_train_layers  : {num_train_layers}", verbose = False)


                environ.optimize(opt['lambdas'], 
                                 is_policy=opt['policy'], 
                                 flag=flag, 
                                 num_train_layers=num_train_layers,
                                 hard_sampling=opt['train']['hard_sampling'],
                                 verbose = False)
                
                t_policy.set_postfix({'iteration': current_iter, 'Loss': f"{environ.losses['total']['total'].item():.4f}" , 
                                      'row_ids':f"{batch['row_id'][0]}-{batch['row_id'][-1]}"})
                
                if should(current_iter, opt['train']['print_freq']):
                    environ.print_loss(current_iter, start_time, title = "Policy training iteration", verbose=True)
                    environ.resize_results()
                    # environ.visual_policy(current_iter)

#                 print_heading(f"{timestring()} - CONTINUE ALPHA TRAINING    current_iter: {current_iter}\n"
#                               f"{' ':15s} current_iter_w: {current_iter_w}  batch_idx_w:{batch_idx_w}   weight_iter_alternate: {opt['train']['weight_iter_alternate']}\n"
#                               f"{' ':15s} current_iter_a: {current_iter_a}  batch_idx_a:{batch_idx_a}   alpha_iter_alternate : {opt['train']['alpha_iter_alternate']} ", 
#                               verbose = False )      

        ## if (current_iter_a % alpha_iter_alternate) == 0 
#         if should(current_iter_a, opt['train']['alpha_iter_alternate']):
#         print(f" policy loop ended - current_iter_a: {current_iter_a}   stop_iter_a: {stop_iter_a}")
        if( current_iter_a >= stop_iter_a):            
#             print_heading(f"{timestring()} - SWITCH TO WEIGHT TRAINING  urrent_iter: {current_iter}\n"
#                           f"{' ':15s} current_iter_w: {current_iter_w}  batch_idx_w:{batch_idx_w}   weight_iter_alternate: {opt['train']['weight_iter_alternate']}\n"
#                           f"{' ':15s} current_iter_a: {current_iter_a}  batch_idx_a:{batch_idx_a}   alpha_iter_alternate : {opt['train']['alpha_iter_alternate']} ",
#                           verbose = False )       

            flag = 'update_w'
            environ.fix_alpha()
            environ.free_w(opt['fix_BN'])
            environ.decay_temperature()

            # print the distribution
            print_dbg(np.concatenate(environ.get_policy_prob(), axis=-1), verbose = False)
            
            p_epoch += 1
            print_dbg(f"** p_epoch incremented: {p_epoch}")

In [None]:
print(f"{opt['train']['Lambda_sharing']:.5e}")

In [None]:
print('Previous best iter: %d, best_value: %.4f' % (best_iter, best_value))
print(best_metrics)
best_value = new_value
best_metrics = val_metrics
best_iter = current_iter
environ.save_checkpoint('best', current_iter)
print('New best iter : %d, best_value: %.4f \n' % (best_iter, best_value))                         
print(best_metrics)

In [None]:
# environ.losses['tasks'] = {'total' : torch.tensor(0.0, device  = environ.device, dtype=torch.float64)}
# environ.device

# print(val_metrics)
pp.pprint(environ.losses)
# environ.print_loss_2(current_iter, start_time, verbose=True)

### Policy Training 

In [None]:
# print(f" current iter                                 : {current_iter} \n"
#       f" opt['train']['warm_up_iters']                : {opt['train']['warm_up_iters']} \n"
#       f" num_blocks                                   : {num_blocks} \n"
#       f" opt['curriculum_speed']                      : {opt['curriculum_speed']}\n"
#       f" opt['train']['weight_iter_alternate']        : {opt['train']['weight_iter_alternate']}\n"
#       f" opt['train']['alpha_iter_alternate']         : {opt['train']['alpha_iter_alternate']}\n"
#       f" alpha_iter_alternate + weight_iter_alternate : {opt['train']['weight_iter_alternate'] + opt['train']['alpha_iter_alternate']}\n"
#       f" num_blocks * curriculum_speed * (alpha_iter_alternate + weight_iter_alternate): \
#           {num_blocks * opt['curriculum_speed'] * (opt['train']['weight_iter_alternate'] + opt['train']['alpha_iter_alternate'])} \n"
#       f" IF {current_iter - opt['train']['warm_up_iters']} IS GREATER THAN  ??"
#       f" {num_blocks * opt['curriculum_speed'] * (opt['train']['weight_iter_alternate'] + opt['train']['alpha_iter_alternate'])}")


In [None]:
# print(f" task1_logits: {environ.networks['mtl-net'].task1_logits} \n")
# print(f" task2_logits: {environ.networks['mtl-net'].task2_logits} \n")
# print(f" task3_logits: {environ.networks['mtl-net'].task3_logits} \n")

In [None]:
# print(current_iter_a , opt['train']['alpha_iter_alternate'],flag)

In [None]:
##---------------------------------------------------------------     
## part one: warm up
##--------------------------------------------------------------- 
# print(current_iter_a , opt['train']['alpha_iter_alternate'],flag)
# stop_iter = current_iter_a +  opt['train']['alpha_iter_alternate']
# print(f" Run iteration {current_iter_a+1} to {stop_iter+1}")

In [None]:
# print(current_iter_a, stop_iter, flag)
# print(current_iter_a , opt['train']['alpha_iter_alternate'],flag)

In [None]:


if flag == 'update_alpha':

    stop_iter = current_iter_a +  opt['train']['alpha_iter_alternate']
    print(f" Current Alpha iteration {current_iter_a} - Run  from {current_iter_a+1} to {stop_iter+1}")
    
    with tnrange(current_iter_a+1, stop_iter+1 , initial = current_iter_a+1, total = stop_iter+1, position=0, leave= True, desc="weight training") as t :
        for current_iter_a in t:    
            current_iter += 1
 
            batch_idx_a, batch = next(batch_enumerator2)
            environ.set_inputs(batch, train2_loader.dataset.input_size)

            if batch_idx_a == len(train2_loader):
                print_dbg(f" Re-enumerate train2_loader  batch_idx_a: {batch_idx_a}   len(train2_loader) = {len(train2_loader)}", verbose=False)                
                batch_enumerator2 = enumerate(train2_loader,1)        
                  
            print_heading(f"{timestring()} - ENVIRON.OPTIMIZE()    flag: {flag}    current_iter: {current_iter}   \n"
                          f" current_iter_w: {current_iter_w}  batch_idx_w:{batch_idx_w}   weight_iter_alternate: {opt['train']['weight_iter_alternate']}\n"
                          f" current_iter_a: {current_iter_a}  batch_idx_a:{batch_idx_a}   alpha_iter_alternate : {opt['train']['alpha_iter_alternate']} \n"
                          f" is_policy: {opt['policy']}   num_train_layers: {num_train_layers}  hard_sampling: {opt['train']['hard_sampling']}\n"
                          f" is_curriculum: {opt['is_curriculum']}     curriculum_speed: {opt['curriculum_speed']}   p_epoch: {p_epoch}"
                          , verbose = False) 
    
            if opt['is_curriculum']:
                num_train_layers = (p_epoch // opt['curriculum_speed']) + 1
            else:
                num_train_layers = None

            print_dbg(f" num_train_layers  : {num_train_layers}", verbose = False)


            environ.optimize(opt['lambdas'], 
                             is_policy=opt['policy'], 
                             flag=flag, 
                             num_train_layers=num_train_layers,
                             hard_sampling=opt['train']['hard_sampling'],
                             verbose = False)

            if should(current_iter, opt['train']['print_freq']):
                environ.print_loss_2(current_iter, start_time, verbose=True)
                environ.resize_results()
                # environ.visual_policy(current_iter)

            print_heading(f"{timestring()} - CONTINUE ALPHA TRAINING    current_iter: {current_iter}\n"
                          f"{' ':15s} current_iter_w: {current_iter_w}  batch_idx_w:{batch_idx_w}   weight_iter_alternate: {opt['train']['weight_iter_alternate']}\n"
                          f"{' ':15s} current_iter_a: {current_iter_a}  batch_idx_a:{batch_idx_a}   alpha_iter_alternate : {opt['train']['alpha_iter_alternate']} ", 
                          verbose = False )      
    
    ## if (current_iter_a % alpha_iter_alternate) == 0 
    if should(current_iter_a, opt['train']['alpha_iter_alternate']):
        print_dbg(f"** Switch training to update_weight")                
        print_heading(f"{timestring()} - SWITCH TO WEIGHT TRAINING  urrent_iter: {current_iter}\n"
                      f"{' ':15s} current_iter_w: {current_iter_w}  batch_idx_w:{batch_idx_w}   weight_iter_alternate: {opt['train']['weight_iter_alternate']}\n"
                      f"{' ':15s} current_iter_a: {current_iter_a}  batch_idx_a:{batch_idx_a}   alpha_iter_alternate : {opt['train']['alpha_iter_alternate']} ",
                      verbose = True )       
        
        flag = 'update_w'
        environ.fix_alpha()
        environ.free_w(opt['fix_BN'])
        environ.decay_temperature()

        # print the distribution
        dists = environ.get_policy_prob()

        print(np.concatenate(dists, axis=-1))
        p_epoch += 1
        print(f"** p_epoch incremented: {p_epoch}")




In [None]:
print(current_iter)

In [None]:
print(f" task1_logits: \n {environ.networks['mtl-net'].task1_logits.detach().cpu().numpy()} \n")
print(f" task2_logits: \n {environ.networks['mtl-net'].task2_logits.detach().cpu().numpy()} \n")
print(f" task3_logits: \n {environ.networks['mtl-net'].task3_logits.detach().cpu().numpy()} \n")
print(f" task1 softmax: \n {softmax(environ.networks['mtl-net'].task1_logits.detach().cpu().numpy(), axis = -1)} \n")
print(f" task2 softmax: \n {softmax(environ.networks['mtl-net'].task2_logits.detach().cpu().numpy(), axis = -1)} \n")
print(f" task3 softmax: \n {softmax(environ.networks['mtl-net'].task3_logits.detach().cpu().numpy(), axis = -1)} \n")

In [None]:
# for i in [1,2,3]:
#     task_pred = f"task{i}_pred"
#     task_logits = f"task{i}_logits"
#     policy_attr = f"policy{i}"
#     logits_attr = f"logit{i}"
#     print_heading(f"{task_pred}")
#     print(getattr(environ, task_pred))
#     print(policy_attr)
#     print(getattr(environ, policy_attr)) 
#     print(logits_attr)
#     print(getattr(environ, logits_attr)) 
#     print(task_logits)
#     print(getattr(environ.networks['mtl-net'], task_logits)) 

In [None]:
import tqdm.notebook

In [None]:
current_iter    = 2174
current_iter_a  = 348
current_iter_w  = 348

## Load previously saved model

In [6]:
input_args = " --config yamls/adashare/chembl_2task.yml --cpu --batch_size 09999".split()
# get command line arguments
args = get_command_line_args(input_args)
print(args)

print()

if args.exp_instance is None:
    args.exp_instance = datetime.now().strftime("%m%d_%H%M")
    
print(args.exp_instance, args.config)

 command line parms :  {'config': 'yamls/adashare/chembl_2task.yml', 'exp_instance': None, 'exp_ids': [0], 'batch_size': 9999, 'backbone_lr': None, 'task_lr': None, 'decay_lr_rate': None, 'decay_lr_freq': None, 'gpus': [0], 'cpu': True}
Namespace(config='yamls/adashare/chembl_2task.yml', exp_instance=None, exp_ids=[0], batch_size=9999, backbone_lr=None, task_lr=None, decay_lr_rate=None, decay_lr_freq=None, gpus=[0], cpu=True)

0119_1230 yamls/adashare/chembl_2task.yml


In [7]:
print_separator('READ YAML')
opt, gpu_ids, exp_ids =read_yaml_from_input(args)
print(gpu_ids, exp_ids,  opt['train']['policy_iter'])

##################################################
####################READ YAML#####################
##################################################
[0] [0] best


In [None]:
current_iter = environ.load_checkpoint('latest')

print('Evaluating the snapshot saved at %d iter' % current_iter)

In [None]:
opt['train']['weight_iter_alternate'] = opt['train'].get('weight_iter_alternate', len(train1_loader))
opt['train']['alpha_iter_alternate'] = opt['train'].get('alpha_iter_alternate'  , len(train2_loader))

print(opt['train']['weight_iter_alternate'], opt['train']['alpha_iter_alternate'])

## Softmax & Gumbel Softmax

###  Softmax, LogSoftMax, NegLogLikelihood and Cross Entropy

In [None]:
from torch import nn
# print(nn.CrossEntropyLoss.__doc__)
loss = nn.CrossEntropyLoss(reduction ='none')
# i1 = torch.randn(3, 5, requires_grad=True)
# t1 = torch.empty(3, dtype=torch.long).random_(5)

# print(i1)
# print(i1)
# print(t1)
# output = loss(i1, t1)
# print(output, output.sum(), output.mean())

# i2 = torch.randn(1, 2, requires_grad=True)
i0 = torch.tensor([[0.0, 1.0]], dtype=torch.float)
i1 = torch.tensor([[1.0, 0.0]], dtype=torch.float)
i2 = torch.tensor([[0.5, 0.5]], dtype=torch.float)
i3 = torch.tensor([[0.4656, 0.5388]], dtype=torch.float)
sm = nn.Softmax(dim =-1)
lsm = nn.LogSoftmax(dim = -1)
nll = nn.NLLLoss(reduction='none')

t1 = torch.tensor([1], dtype=torch.int64)
t0 = torch.tensor([0], dtype=torch.int64)
t2 = torch.tensor([2], dtype=torch.int64)
print('i0     : ', i0)
print('sm(i0) : ', sm(i0))
print('lsm(i0): ', lsm(i0))
print()
print('i1     : ', i1)
print('sm(i1) : ', sm(i1))
print('lsm(i1): ', lsm(i1))
print()
print('i2     : ', i2)
print('sm(i2) : ', sm(i2))
print('lsm(i2): ', lsm(i2))
print()

print('t0: ',t0)
print('t1: ',t1)
print()
output1 = loss(i0, t0)
output2 = nll(lsm(i0), t0)
print('loss [0,1] and [0] : ', output1)
print('nll between lsm(i0): ', output2)
print()
output1 = loss(i0, t1)
output2 = nll(lsm(i0), t1)
print('loss [0,1] and [1] : ', output1)
print('nll between lsm(i0): ', output2)
print()

output1 = loss(i1, t0)
output2 = nll(lsm(i1), t0)
print('loss [1,0] and [0] : ', output1)
print('nll between lsm(i1): ', output2)
print()

output1 = loss(i1, t1)
output2 = nll(lsm(i1), t1)
print('loss [1,0] and [1] : ', output1)
print('nll between lsm(i1): ', output2)
print()

output1 = loss(i2, t0)
output2 = nll(lsm(i2), t0)
print('loss [0.5, 0.5] and [0] : ', output1)
print('nll between lsm(i1)   and [0] : ', output2)
print()

output1 = loss(i2, t1)
output2 = nll(lsm(i2), t1)
print('loss [0.5, 0.5] and [1] : ', output1)
print('nll between lsm(i1)   and [1] : ', output2)
print()


### Gumbel Softmax

In [None]:
a = torch.tensor([[0.5000, 0.5000],
        [0.5000, 0.5000],
        [0.5000, 0.5000],
        [0.5000, 0.5000]], device='cuda:0', requires_grad=True) 

b = torch.tensor([0.5000, 0.5000,  0.5000, 0.5000], device='cuda:0', requires_grad=True) 
print(b.shape)
c = torch.tensor([[0.000, 0.000,  0.000, 0.000]], device='cuda:0', requires_grad=True) 
print(c.shape)
d = torch.tensor([[0.5000], [0.5000],  [0.5000], [0.5000]], device='cuda:0', requires_grad=True) 
print(d.shape)

In [None]:
print(i0)
print(i1)
print(i2)

In [None]:
temp  = 2.5
print(F.gumbel_softmax( i1, temp, hard=False))
print(F.gumbel_softmax( i1, temp, hard=False))
print(F.gumbel_softmax( i1, temp, hard=False))
print()

print(F.gumbel_softmax( i2, temp, hard=False))
print(F.gumbel_softmax( i2, temp, hard=False))
print(F.gumbel_softmax( i2, temp, hard=False))
print()

print(F.gumbel_softmax( i3, temp, hard=False))
print(F.gumbel_softmax( i3, temp, hard=False))
print(F.gumbel_softmax( i3, temp, hard=False))
# print(F.gumbel_softmax( d, 5, hard=False))

In [None]:
logits = torch.randn(20, 32)

In [None]:
print(logits[:2,:])

In [None]:
tmp = -torch.empty_like(logits, memory_format=torch.legacy_contiguous_format)
print(tmp[0])

In [None]:
print(tmp[0])
tmp1 = tmp.exponential_()
print(tmp1[0])

In [None]:
tmp2 = tmp1.log()
print(tmp2[0])

In [None]:
 gumbels = (
        -torch.empty_like(logits, memory_format=torch.legacy_contiguous_format).exponential_().log()
    ) 

In [None]:
print(gumbels.shape)
print(gumbels[0])

In [None]:
# Sample soft categorical using reparametrization trick:
gumbel_soft = F.gumbel_softmax(logits, tau=1, hard=False)

# Sample hard categorical using "Straight-through" trick:
gumbel_hard  = F.gumbel_softmax(logits, tau=1, hard=True)

In [None]:
print(logits.shape)
print(logits[0])
print(np.argmax(logits[0]))
print('\n')

print(gumbel_soft.shape)
print(gumbel_soft[0])
print(np.argmax(gumbel_soft[0]))
print('\n')

print(gumbel_hard.shape)
print(gumbel_hard[0])
print(np.argmax(gumbel_hard[0]))

In [None]:
gumbel_soft.sum(axis=1)

In [None]:
tau = 1

In [None]:
a = torch.empty_like(logits, memory_format=torch.legacy_contiguous_format)

In [None]:
a[:2]

fill tensor `a` with elements drawn from exponential distribution

In [None]:
a_e = a.exponential_()

In [None]:
a_e[:2]

draw natural log `ln()` on elements of a_e

In [None]:
a_e_l = a_e.log()

In [None]:
a_e_l[:2]

Neg log

In [None]:
a_el_neg = -a_e_l

a_el_neg[:2]

In [None]:
logits[:2]

In [None]:
gumbels = (logits + a_el_neg) / tau 

gumbels[:2]

In [None]:
dim = -1
gumbels.shape

In [None]:
 y_soft = gumbels.softmax(dim)

In [None]:
y_soft.shape

In [None]:
y_soft[:2]

In [None]:
index = y_soft.max(dim, keepdim=True)
print(index[0].T)
print(index[1].T)
y_hard = torch.zeros_like(logits, memory_format=torch.legacy_contiguous_format).scatter_(dim, index[1], 1.0)

In [None]:
np.argmax(y_hard,axis=1)

In [None]:
tmp_d= [0,1,0]
for i in range(10):
    sampled = np.random.choice((2, 1, 0), p=tmp_d)
    print(sampled)

## Scratch Pad

In [3]:
# from numba import cuda

# cuda_device = 0 

# def free_gpu_cache(cuda_device):
#     print("Initial GPU Usage")    
#     gpu_usage()                             
#     print("GPU Usage after emptying the cache")
#     gpu_usage()
#     print("CUDA empty cache")
#     torch.cuda.empty_cache()
#     print("Close and reopen device")
#     cuda.select_device(cuda_device)
#     print("Close device")    
#     cuda.close()
#     print("Reopen device")    
#     cuda.select_device(cuda_device)
#     print("GPU Usage after closing and reopening")
#     gpu_usage()

# free_gpu_cache(0)

In [7]:
# def print_separator(text, total_len=50):
#     print('#' * total_len)
#     left_width = (total_len - len(text))//2
#     right_width = total_len - len(text) - left_width
#     print("#" * left_width + text + "#" * right_width)
#     print('#' * total_len)

# def print_dbg(text, verbose = False):
#     if verbose:
#         print(text)

# @debug_off
# def print_heading(text,  verbose = False):
#     len_ttl = max(len(text)+4, 50)
#     if verbose:
#         print('-' * len_ttl)
#         print(f" {text}")
#         # left_width = (total_len - len(text))//2
#         # right_width = total_len - len(text) - left_width
#         # print("#" * left_width + text + "#" * right_width)
#         print('-' * len_ttl,'\n')

# print_heading("hello_kevin", verbose=False)

### Chembl Data feed

In [5]:
# dataroot = opt['dataload']['dataroot']
# ecfp     = load_sparse(dataroot, opt['dataload']['x'])

# total_input = ecfp.shape[0]
# ranges      = (np.cumsum([0]+opt['dataload']['x_split_ratios'])* total_input).astype(np.int32)


# idx_train  = np.arange(ranges[0], ranges[1])
# idx_train1 = np.arange(ranges[1], ranges[2])
# idx_train2 = np.arange(ranges[2], ranges[3])
# idx_val    = np.arange(ranges[3], ranges[4])

# print(f" Total input    :  {total_input}   Cummulative dataset sizes: {ranges}")
# print(f" Ranges         :  {ranges}")
# print()
# print(f" X Dataset      :  {os.path.join(opt['dataload']['dataroot'], opt['dataload']['x'])}")
# print(f" y Dataset      :  {os.path.join(opt['dataload']['dataroot'], opt['dataload']['y_tasks'][0])}")
# print(f" Folding Dataset:  {os.path.join(opt['dataload']['dataroot'], opt['dataload']['folding'])}")
# print(f" Weights_class  :  {opt['dataload']['weights_class']}")
# print()
# print(f' idx_train    dataset size: {len(idx_train)  :6d}  - rows: {(idx_train)} ')
# print(f' idx_train1   dataset size: {len(idx_train1) :6d}  - rows: {(idx_train1)} ')
# print(f' idx_train2   dataset size: {len(idx_train2) :6d}  - rows: {(idx_train2)} ')
# print(f' val_train    dataset size: {len(idx_val)    :6d}  - rows: {(idx_val)} ')

### Test dataloader output

In [14]:
# val_batch_idx, val_batch = next(val_enumerator)
# print(type(val_batch['row_id']),val_batch['row_id'][0], val_batch['row_id'][-1] )

# ctr = 0
# for i in range(100):
#     val_batch_1 = next(val_loader)
#     print(' iteration: ', ctr,' len: ', len(val_batch_1['row_id']),'start: [', val_batch_1['row_id'][0],   val_batch_1['row_id'][-1],']' )
#     ctr += 1

# ctr = 0    
# for val_batch_1 in iter(val_loader):
# #     val_batch_1 = next(val_loader)
#     print(' iteration: ', ctr,' len: ', len(val_batch_1['row_id']),'start: [', val_batch_1['row_id'][0],   val_batch_1['row_id'][-1],']' )
#     ctr += 1    
#     if ctr == 105:
#         break

In [15]:
# val_batch_1 = next(val_iterator)


In [16]:
#  batch_idx, batch = next(batch_enumerator)

# print(batch.keys())
# print(batch['x_ind'].shape)
# print(type(batch['batch_size']))
# for i in batch.keys():
#     if not isinstance(batch[i], int):
#         print(i, batch[i].shape)

# task0_Y =  torch.sparse_coo_tensor(
#         batch["task0_ind"],
#         batch["task0_data"],
#         size = [batch["batch_size"], 5]).to("cpu", non_blocking=True).to_dense().numpy()

# print(task0_Y)

In [17]:
# print(f" train_loader: dataset input size       :  {train_loader.dataset.input_size}")
# print(f" train_loader: class output size        :  {train_loader.dataset.class_output_size}")
# print()
# print(f" size of training set 0 (warm up)       :  {len(trainset)}")
# print(f" size of training set 1 (network parms) :  {len(trainset1)}")
# print(f" size of training set 2 (policy weights):  {len(trainset2)}")
# print(f" size of validation set                 :  {len(valset)}")
# print(f"                                Total   :  {len(trainset)+len(trainset1)+len(trainset2)+len(valset)}")

# print(f" batch size       : {opt['train']['batch_size']}")
# print(f' len train_loader : {len(train_loader)}')
# print(f' len train1_loader: {len(train1_loader)}')
# print(f' len train2_loader: {len(train2_loader)}')
# print(f' len val_loader   : {len(val_loader)}')

### tqdm

In [None]:
from tqdm.notebook import tnrange

In [None]:
curr_iter_t  = 0
curr_iter_a  = 0
curr_iter_w  = 0
stop_iter_t  = 0
stop_iter_w  = 0 
stop_iter_a  = 0
total_weight_epochs = 0
total_policy_epochs = 0 
train_total_iters = 8
weight_iter_alternate = 17
alpha_iter_alternate = 17

In [None]:
print(curr_iter_t, stop_iter_t, flag, train_total_iters,opt['train']['print_freq'] )
start_iter_t = curr_iter_t
stop_iter_t = curr_iter_t +  train_total_iters 
print(f" Current iteration {curr_iter_t} - Run  from {start_iter_t} to {stop_iter_t}")

print(curr_iter_w, weight_iter_alternate , flag)
stop_iter_w = curr_iter_w +  weight_iter_alternate 
print(f" Current Weight iteration {curr_iter_w} - Run  from {curr_iter_w+1} to {stop_iter_w}")


print(curr_iter_a ,  alpha_iter_alternate ,flag)
stop_iter_a = curr_iter_a +  alpha_iter_alternate 
print(f" Current alpha iteration {curr_iter_a} - Run  from {curr_iter_a+1} to {stop_iter_a}")

In [None]:
# del t, t_w, t_a
main_iter_ctr = 0 
with tnrange(start_iter_t , stop_iter_t  , initial = start_iter_t , total = stop_iter_t, position=0, leave= True, desc="master") as t :
    for curr_t in t:
        
        with  tnrange(0, weight_iter_alternate , initial = 0, total = weight_iter_alternate, 
                      position=1, leave= False, desc=f"epoch {curr_t} weight training") as t_w :
            for curr_w in t_w:    
                sleep(0.35)
                main_iter_ctr += 1
                curr_iter_w  = curr_w
                t.set_postfix({'epoch': f"{curr_t}/{train_total_iters}", 'main_iter_ctr': main_iter_ctr})
                t_w.set_postfix({'weight training epoch': curr_t, 'batch #': curr_iter_w})

            print(f"** Epoch {curr_t}/{train_total_iters} weight training complete - Loss: "
                  f"curr_w:{curr_w}    curr_iter_w:{curr_iter_w}  curr_t:{curr_t}  main_iter_ctr:{main_iter_ctr}" )
                 
        
        with  tnrange(0, alpha_iter_alternate  , initial = 0, total = alpha_iter_alternate , 
                      position=2, leave= False, desc=f"epoch {curr_t} policy training") as t_a :
            for curr_a in t_a:    
                sleep(0.35)
                main_iter_ctr += 1                
                curr_iter_a = curr_a
                t.set_postfix({'epoch': f"{curr_t}/{train_total_iters}", 'main_iter_ctr':main_iter_ctr})
                t_a.set_postfix({'policy training epoch': curr_t, 'batch #': curr_iter_a})            
                
            print(f"** Epoch {curr_t}/{train_total_iters} policy training complete - Loss: "
                  f"curr_w:{curr_w}    curr_iter_w:{curr_iter_w}  curr_t:{curr_t}  main_iter_ctr:{main_iter_ctr}" )
        

In [None]:
curr_iter_t  = 0
curr_iter_a  = 0
curr_iter_w  = 0
stop_iter_t  = 0
stop_iter_w  = 0 
stop_iter_a  = 0
total_weight_epochs = 0
total_policy_epochs = 0 
train_total_iters = 100
train_total_epochs = 10
weight_iter_alternate = 17
alpha_iter_alternate = 17

print(curr_iter_t, stop_iter_t, flag, train_total_iters,opt['train']['print_freq'] )
start_iter_t = curr_iter_t
stop_iter_t = curr_iter_t +  train_total_iters 
print(f" Current iteration {curr_iter_t} - Run  from {start_iter_t} to {stop_iter_t}")

print(curr_iter_w, weight_iter_alternate , flag)
stop_iter_w = curr_iter_w +  weight_iter_alternate 
print(f" Current Weight iteration {curr_iter_w} - Run  from {curr_iter_w+1} to {stop_iter_w}")


print(curr_iter_a ,  alpha_iter_alternate ,flag)
stop_iter_a = curr_iter_a +  alpha_iter_alternate 
print(f" Current alpha iteration {curr_iter_a} - Run  from {curr_iter_a+1} to {stop_iter_a}")

In [None]:
# del t, t_w, t_a
curr_epoch = 0
main_iter_ctr = 0 
# with tnrange(start_iter_t , stop_iter_t  , initial = start_iter_t , total = stop_iter_t, position=0, leave= True, desc="master") as t :
# with tqdm_notebook(total=train_total_epochs) as t:
t = tqdm_notebook(total=train_total_epochs)

while curr_epoch < train_total_epochs:
    curr_epoch+=1
    t.update(1)

    #-----------------------------------------
    # Train & Update the network weights
    #-----------------------------------------        
    with  tnrange(0, weight_iter_alternate , initial = 0, total = weight_iter_alternate, 
                  position=1, leave= False, desc=f"epoch {curr_epoch} weight training") as t_w :
        for curr_w in t_w:    
            sleep(0.35)
            main_iter_ctr += 1
            curr_iter_w  = curr_w

            t.set_postfix({'epoch': f"{curr_epoch}/{train_total_epochs}", 'main_iter_ctr': main_iter_ctr})
            t_w.set_postfix({'weight training epoch': curr_epoch, 'batch #': curr_iter_w})

        tqdm.write(f"** Epoch {curr_epoch}/{train_total_epochs} weight training complete - Loss: "
              f"curr_w:{curr_w}    curr_iter_w:{curr_iter_w}  curr_epoch:{curr_epoch}  main_iter_ctr:{main_iter_ctr}" )

    #-----------------------------------------
    # Train & Update the  policy 
    #-----------------------------------------        
    with  tnrange(0, alpha_iter_alternate  , initial = 0, total = alpha_iter_alternate , 
                  position=2, leave= False, desc=f"epoch {curr_epoch} policy training") as t_a :
        for curr_a in t_a:    
            sleep(0.35)
            main_iter_ctr += 1                
            curr_iter_a = curr_a

            t.set_postfix({'epoch': f"{curr_epoch}/{train_total_epochs}", 'main_iter_ctr':main_iter_ctr})
            t_a.set_postfix({'policy training epoch': curr_epoch, 'batch #': curr_iter_a})            

        tqdm.write(f"** Epoch {curr_epoch}/{train_total_epochs} policy training complete - Loss: "
              f"curr_w:{curr_w}    curr_iter_w:{curr_iter_w}  curr_epoch:{curr_epoch}  main_iter_ctr:{main_iter_ctr}" )


In [None]:
# with tnrange(start_iter, stop_iter , initial = current_iter_w, total = stop_iter,  position=0, leave= True, desc="training") as t:
#     for current_iter_w in t:
#         print(current_iter_w)
#         current_iter_w += 1
#         print(current_iter_w)        

In [None]:
# start = current_iter
# end = current_iter + opt['train']['warm_up_iters']
# curr_range = range(start,end)
# print(start, end)

# for i in tqdm.notebook.tnrange(start, end, initial = start, total = end):
#     sleep(0.25)
#     current_iter += 1
# #     print(i)
#     pass

# print(current_iter)

In [None]:
# start = current_iter
# end = current_iter + opt['train']['warm_up_iters']
# curr_range = range(start,end)
# print(start, end)

# for i in tqdm.notebook.tqdm_notebook(cur_range, initial = start, total = end, disable=False, position=0, desc = "validation"):
#     current_iter += 1
#     pass

# print(current_iter)

In [None]:
# from tqdm import trange
# from time import sleep

# for i in trange(40, desc='1st loop', position=0, leave = False):
#     sleep(1.1)
#     for j in trange(5, desc='2nd loop', position =1, leave = False):
#         sleep(0.01)
#         for k in trange(50, desc='3rd loop', position =0,leave=False):
#             sleep(0.01)

In [None]:
# with tqdm(batch_enumerator, leave=False, disable=False) as t:
# with tqdm(total=10, bar_format="{postfix[0]} {postfix[1][value]:>8.2g}", postfix=["Batch", dict(value=0)]) as t:
# with trange(opt['train']['warm_up_iters'], bar_format="{postfix[0]} {postfix[1][value]:>8.2g}", postfix=["Batch", dict(value=0)]) as t:
# with trange(opt['train']['warm_up_iters']) as t:

#     for current_iter in t:
#         batch_idx, batch = next(batch_enumerator)
#         ran = random.randint(1, 100)
#         start_time = time.time()

#         environ.train()

#         print_heading(f" {timestring()} - WARMUP Training iter {current_iter}/{opt['train']['warm_up_iters']}    batch_idx: {batch_idx}"    
#                       f"    Warm-up iters: {opt['train']['warm_up_iters']}"
#                       f"    Validation freq:  {opt['train']['val_freq']}", verbose = False)

#         if batch_idx == len(train_loader) :
#     #         print_heading(f" ******* {timestring()}  re-enumerate train_loader() *******")
#             batch_enumerator = enumerate(train_loader,1)   

#         t.set_postfix({'batch_idx': batch_idx, 'num_vowels': ran})


In [None]:
# import tqdm.notebook

In [None]:
a = np.array([[1,2,3,4,5,6,7,8,9,10],[11,12,13,14,15,16,17,18,19,20]])
print(a)

In [None]:
print(a[:,::-1])
print()
print(a[::-1])

In [None]:
a = 4
b = 4
c = 1

0 // b + c

### Scipy Sparse

In [133]:
import scipy.sparse
row = np.array([5])
col = np.array([0])
data = np.array([6])
y_class = scipy.sparse.csr_matrix((data, (row, col)), shape=(15,1))
y_2 = scipy.sparse.csr_matrix((15,0))
print(f"Created y_class # dims: {y_class.ndim}    shape: {y_class.shape}")
print(f"Created y_2 # dims: {y_2.ndim}    shape: {y_2.shape}")

# y_class[5]= 2

Created y_class # dims: 2    shape: (15, 1)
Created y_2 # dims: 2    shape: (15, 0)


In [None]:
print(y_class.toarray().T)
print(y_2.toarray())

In [None]:
y_class[8,0]= 999
y_class.toarray().squeeze().shape

### folding step by step

In [None]:
x_dev = copy.copy(x_file)
print(x_dev.shape, type(x_dev))

idx = x_dev.nonzero()
print(idx[0][:82])
print(idx[1][:82])
print(x_dev[0].sum(), x_dev[1].sum())
print(x_dev[0,0:-1])

In [None]:
folding_size = 30
print(f" fold - folding_size:{folding_size}" )

## collapse x into folding_size columns
idx = x_dev.nonzero()
folded = idx[1] % folding_size

print(folded[:82])

In [None]:
x_fold = scipy.sparse.csr_matrix((x_dev.data, (idx[0], folded)), shape=(x_dev.shape[0], folding_size))
print(x_fold.shape)

In [None]:
x_fold.sum_duplicates()

In [None]:
print(x_fold)

In [None]:
print(type(y_files[0]), type(y_class[0]))

### Torch tensor manipulations

In [175]:
# a = np.random.rand(4,3)
# print(a.shape)
# print(a)
# rows = [1,3,2,0,2]
# cols = [0,0,1,2,2]
# a1 = a[rows,cols]
# print(a1.shape)

In [176]:
# a = torch.randn(64, 3, 16, 16)
# print(a.shape, a.view(-1).shape, a.permute(0,2,3,1).shape)
# b = torch.randint(0,2, [64,3,16,16])
# print(b.shape, b.view(-1).shape, b.permute(0,2,3,1).shape)

In [None]:
input = torch.randn(3, 5, 2)

In [None]:
input

In [None]:
input.contiguous().view(-1).shape

### using eval()

In [None]:
print(opt['dataload']['y_tasks'])

In [None]:
task1 = [1,2,3]
task2 = None
task3 = {'4': 'Kevin', '5':'Bardool'}

for i in [1,2,3]:
    print('task{:d}'.format(i))
    if eval('task{:d}'.format(i)) is None:
        print('task{:d} :  has not been defined '.format(i))
#         exec_str = 'task{:d} =  np.random.rand({:d},{:d})'.format(i,3,2)
        exec_str = 'y_task{:d} = scipy.sparse.csr_matrix(({:d}, {:d})) '.format(i,3,2)
        print(exec_str)
        exec(exec_str)
        print('task{:d} : '.format(i), eval('task{:d}'.format(i)))
        print(eval('type(task{:d})'.format(i)))
    else:
        print('task{:d} : '.format(i), eval('task{:d}'.format(i)))
        
print(f"Created task{i} shape        : {eval('len(task{:d})'.format(i))}")
print(len(task3))

In [None]:
a = np.random.rand(3,2)
print(a)

In [None]:
print(eval('len(task{:d})'.format(i)))

## Load datasets, perform folding

In [None]:
## Verify presence of Y label data
if (opt['dataload']['y_tasks'] is None) and (opt['dataload']['y_regr'] is None):
   raise ValueError("No label data specified, please add --y_class and/or --y_regr.")

print(os.path.join(opt['dataload']['dataroot'], opt['dataload']['x']))  
print(os.path.join(opt['dataload']['dataroot'], opt['dataload']['folding']))
for fl in opt['dataload']['y_tasks']:
    print(os.path.join(opt['dataload']['dataroot'], fl))

#### Load X data file

In [None]:
##
## Load data files 
##
# ecfp     = sc.load_sparse(args.x)
# y_class  = sc.load_sparse(args.y_class)
# y_regr   = sc.load_sparse(args.y_regr)
# y_censor = sc.load_sparse(args.y_censor)

dataroot = opt['dataload']['dataroot']

ecfp     = load_sparse(dataroot, opt['dataload']['x'])
# x_file   = copy.copy(ecfp)

print(f" Input    {opt['dataload']['x']} - type : {type(ecfp)} shape : {ecfp.shape}")
# print(f" Input    {opt['dataload']['x']} - type : {type(x_file)} shape : {x_file.shape}")


#### Load Y label files

In [None]:
for i,( y_task, y_type )in enumerate(zip(opt['dataload']['y_tasks'],opt['sc_tasks']),1):
    print(full_path := os.path.join(dataroot, y_task ))
    tmp = load_sparse(full_path)
#     np.load(full_path, allow_pickle=True) 
#     print(type(tmp), tmp.shape)
#     tmp_sparse = scipy.sparse.csr_matrix(tmp)
    print(type(tmp_sparse), tmp_sparse.shape)
    print('indicies: ', len(tmp_sparse.__dict__['indices']), tmp_sparse.__dict__['indices'])
    print('indptr  : ', len(tmp_sparse.__dict__['indptr']) , tmp_sparse.__dict__['indptr'])
    print('data    : ', len(tmp_sparse.__dict__['data'])   , tmp_sparse.__dict__['data'])    

In [None]:
# y_class  = load_sparse(dataroot, opt['dataload']['y_tasks'][0])
# print(f" Input     - type : {type(y_class)} shape : {y_class.shape}")

# y_regr  = load_sparse(dataroot, opt['dataload']['y_tasks'][1])
# print(f" Input     - type : {type(y_regr)} shape : {y_regr.shape}")

##
## Load Y label files 
##
y_files=[]

for i,( y_task, y_type )in enumerate(zip(opt['dataload']['y_tasks'],opt['sc_tasks']),1):
    y_tmp = load_sparse(dataroot,  y_task)
    print(f" y_task:{i}  task type: {y_type:5s}  dataset: {y_task} - type : {type(y_tmp)} shape : {y_tmp.shape}")
    ## Get number of positive / neg and total for each classes
    num_pos    = np.array((y_tmp == +1).sum(0)).flatten()
    num_neg    = np.array((y_tmp == -1).sum(0)).flatten()
    num_class  = np.array((y_tmp !=  0).sum(0)).flatten()
    if (num_class != num_pos + num_neg).any():
        raise ValueError("For classification all y values (--y_class/--y) must be 1 or -1.")
    else:
        y_files.append(y_tmp)

y_class = copy.copy(y_files[0])
# y_regr = copy.copy(y_files[1])

#### Load folding file 

In [None]:
##
## load folding file
##
folding_file = os.path.join(dataroot,opt['dataload']['folding'])
folding  = np.load(folding_file)
print(f" Folding  {folding_file} - type : {type(folding)} shape : {folding.shape}")
print(f"          {folding[:20]}")

assert ecfp.shape[0] == folding.shape[0], "x and folding must have same number of rows" 

#### Load Y censor file

In [None]:
# y_censor = load_sparse(dataroot, opt['dataload']['y_censor'])
# if y_censor is not None:
#     print(f" Input     - type : {type(y_censor)} shape : {y_censor.shape}") 

##
## Load Y censor file
##

# y_censor = load_sparse(dataroot, opt['dataload']['y_censor'])
# if y_censor is None:
#     y_censor = scipy.sparse.csr_matrix(y_regr.shape)
#     vprint(f" y_sensor is {opt['dataload']['y_censor']}   Created y_censor shape       : {y_censor.shape}")
    
# y_censor_shape = y_censor.shape if y_censor is not None else "n/a"
# print(f" y_censor  - type : {type(y_censor)}  shape: {y_censor_shape}")

In [None]:
# if (y_regr is None) and (y_censor is not None):
#     raise ValueError("y_censor provided please also provide --y_regr.")

# # if y_class is None:
# #     y_class = scipy.sparse.csr_matrix((ecfp.shape[0], 0))
# #     vprint(f"Created y_class shape        : {y_class.shape}")

# if y_regr is None:
#     y_regr  = scipy.sparse.csr_matrix((ecfp.shape[0], 0))
#     vprint(f"Created y_regr shape         : {y_regr.shape}")

#### Input folding & transformation

In [None]:
print(f"args.fold_inputs : {opt['dataload']['fold_inputs']} \t\t  transform: {opt['dataload']['input_transform']}\n")
print(repr(ecfp))
ecfp = fold_and_transform_inputs(ecfp, folding_size=opt['dataload']['fold_inputs'], transform=opt['dataload']['input_transform'])
print(repr(ecfp))

print(type(ecfp), ecfp.shape)


####  Loading weights files for tasks

In [None]:
# num_regr   = np.bincount(y_regr.indices, minlength=y_regr.shape[1])
print(' Classification weights: ',opt['dataload']['weights_class'])
tasks_class = load_task_weights(opt['dataload']['weights_class'], y=y_class[0], label="y_class")
# tasks_regr  = load_task_weights(opt['dataload']['weights_regr'] , y=y_regr , label="y_regr")

print(tasks_class)
print(tasks_class.training_weight.shape)

In [None]:
print(y_files[0].shape, y_files[1].shape, y_class.shape)

In [None]:
if tasks_class.aggregation_weight is None:
    '''
    fold classes 
    '''
    ## using min_samples rule
    fold_pos, fold_neg = class_fold_counts(y_class, folding)
    n = opt['dataload']['min_samples_class']
    tasks_class.aggregation_weight = ((fold_pos >= n).all(0) & (fold_neg >= n)).all(0).astype(np.float64)
    print(f" tasks_class.aggregation_weight WAS NOT passed ")
    print(f" min_samples_class: opt['dataload']['min_samples_class']")
    print(f" Class fold counts: \n  fold_pos:\n{fold_pos}  \n\n  fold_neg:\n{fold_neg}") 
else:
    print(f"  tasks_class.aggregation_weight passed ")
    
print(f" tasks_class.aggregation_weight.shape: {tasks_class.aggregation_weight.shape} \n {tasks_class.aggregation_weight}")

In [None]:
# if tasks_regr.aggregation_weight is None:
#     if y_censor.nnz == 0:
#         y_regr2 = y_regr.copy()
#         y_regr2.data[:] = 1
#     else:
#         ## only counting uncensored data
#         y_regr2      = y_censor.copy()
#         y_regr2.data = (y_regr2.data == 0).astype(np.int32)
  
#     fold_regr, _ = sc.class_fold_counts(y_regr2, folding)
#     del y_regr2
#     tasks_regr.aggregation_weight = (fold_regr >= args.min_samples_regr).all(0).astype(np.float64)

In [None]:
##
## Display dataset dimensions 
##
print(f"Input dimension      : {ecfp.shape[1]}")
print(f"#samples             : {ecfp.shape[0]}")
print(f"#classification tasks: {y_class[0].shape[1]}")
print(f"Using {(tasks_class.aggregation_weight > 0).sum()} classification tasks for calculating aggregated metrics (AUCROC, F1_max, etc).")

# vprint(f"#regression tasks    : {y_regr.shape[1]}")
# vprint(f"Using {(tasks_regr.aggregation_weight > 0).sum()} regression tasks for calculating metrics (RMSE, Rsquared, correlation).")
# print(ecfp[18387,:10].toarray())

####  Compute batch size

In [None]:
print(f" batch_ratio        : {opt['batch_ratio']}")
print(f" internal_batch_max : {opt['internal_batch_max']}")

# batch_size  = int(np.ceil(opt['batch_ratio'] * idx_tr.shape[0]))
# num_int_batches = 1
# print(f" batch_ratio * # idx_tr:   {opt['batch_ratio']} * {idx_tr.shape[0]} = {opt['batch_ratio'] * idx_tr.shape[0]}")


# if opt['internal_batch_max'] is not None:
#     if opt['internal_batch_max'] < batch_size:
#         num_int_batches = int(np.ceil(batch_size / opt['internal_batch_max']))
#         print(f"\n\n internal_batch_max: {opt['internal_batch_max']}   batch_size: {batch_size}")
#         print(f" batch_size / internal_batch_max: {batch_size / opt['internal_batch_max']}   num_int_batches: {num_int_batches}")
#         batch_size      = int(np.ceil(batch_size / num_int_batches))
#         print(f" batch_size / num_int_batches: {batch_size / num_int_batches}   modified batch_size: {batch_size}")
        

batch_size = 320 
print(f" batch size:   {batch_size}")

#### Separate test dataset

In [None]:
print(f"opt['dataload']['fold_te'] : {opt['dataload']['fold_te'] }")
print(f"opt['dataload']['fold_va'] : {opt['dataload']['fold_va'] }")

In [None]:
if opt['dataload']['fold_te'] is not None and opt['dataload']['fold_te'] >= 0:
    ## removing test data
    print(f" Remove test data")
    assert opt['dataload']['fold_te'] != opt['dataload']['fold_va'], "fold_va and fold_te must not be equal."
    keep    = (folding != args.fold_te)
    ecfp    = ecfp[keep]
    y_class = y_class[keep]
    y_regr  = y_regr[keep]
    y_censor= y_censor[keep]
    folding = folding[keep]

In [None]:
prop = (np.cumsum([0.3, 0.3, 0.3, 0.1])* ecfp.shape[0]+1).astype(np.int32)
print(prop, prop.astype(np.int32))

#### Separate train, train1, train2, and validation  dataset

In [None]:
fold_va = opt['dataload']['fold_va']
fold_va = 0
fold_train1 = 1
fold_train2 = 2

idx_val    = np.where(folding == fold_va)[0]
idx_train  = np.where(folding == fold_train1)[0]
idx_train1 = np.where(folding == fold_train2)[0]
idx_train2 = np.where(folding >  fold_train2)[0]



In [None]:
dataroot = opt['dataload']['dataroot']
ecfp     = load_sparse(dataroot, opt['dataload']['x'])

total_input = ecfp.shape[0]
ranges      = (np.cumsum([0.3, 0.3, 0.1, 0.3])* total_input).astype(np.int32)
print(total_input, '     ', ranges)

idx_train  = np.arange(ranges[0])
idx_train1 = np.arange(ranges[0], ranges[1])
idx_train2 = np.arange(ranges[1], ranges[2])
idx_val    = np.arange(ranges[2], ranges[-1])

print( f' idx_train   len: {len(idx_train) :6d}  - {(idx_train)} ')
print( f' idx_train1  len: {len(idx_train1):6d}  - {(idx_train1)}')
print( f' idx_train2  len: {len(idx_train2):6d}  - {(idx_train2)}')
print( f' idx_val     len: {len(idx_val)   :6d}  - {(idx_val)}   ')

In [None]:
# y_class_tr = y_class[idx_train]
# y_class_va = y_class[idx_va]

# y_regr_tr  = y_regr[idx_tr]
# y_regr_va  = y_regr[idx_va]

# y_censor_tr = y_censor[idx_tr]
# y_censor_va = y_censor[idx_va]

# num_pos_va  = np.array((y_class_va == +1).sum(0)).flatten()
# num_neg_va  = np.array((y_class_va == -1).sum(0)).flatten()
# num_regr_va = np.bincount(y_regr_va.indices, minlength=y_regr.shape[1])