In [1]:
from IPython.core.display import display, HTML
display(HTML("<style>.container { width:90% !important; }</style>"))
%load_ext autoreload
%autoreload 2

import os 
import time
import argparse
import yaml
from tqdm import tqdm, tqdm_notebook, trange
# import tqdm.notebook.trange as tnrange
import copy
import numpy as np
import torch
import pandas as pd
import torch.nn.functional as F
from torch.utils.data import DataLoader 
import scipy.sparse
from time import sleep
from scipy.special import softmax
from tqdm.notebook import trange, tqdm

np.set_printoptions(edgeitems=3, infstr='inf', linewidth=150, nanstr='nan')
import pprint 
pp = pprint.PrettyPrinter(indent=4)
# torch.set_printoptions(precision=None, threshold=None, edgeitems=None, linewidth=None, profile=None, sci_mode=None)
torch.set_printoptions( linewidth=132)

In [2]:
print(' Cuda is available  : ', torch.cuda.is_available())
print(' CUDA device count  : ', torch.cuda.device_count())
print(' CUDA current device: ', torch.cuda.current_device())

print(' GPU Processes : ', torch.cuda.list_gpu_processes())
print()

for i in range(torch.cuda.device_count()):
    print(' Device : ', i)
    print('   name:       ', torch.cuda.get_device_name())
    print('   capability: ', torch.cuda.get_device_capability())
    print('   properties: ', torch.cuda.get_device_properties(i))


 Cuda is available  :  True
 CUDA device count  :  1
 CUDA current device:  0
 GPU Processes :  pynvml module not found, please install pynvml

 Device :  0
   name:        NVIDIA GeForce GTX 970M
   capability:  (5, 2)
   properties:  _CudaDeviceProperties(name='NVIDIA GeForce GTX 970M', major=5, minor=2, total_memory=3071MB, multi_processor_count=10)


In [3]:
from GPUtil import showUtilization as gpu_usage
gpu_usage()                             

import torch
from GPUtil import showUtilization as gpu_usage
from numba import cuda

print(' Allocated : ', torch.cuda.memory_allocated("cuda:0") ) #returns you the current GPU memory usage by tensors in bytes for a given device
print(' Reserved  : ', torch.cuda.memory_reserved("cuda:0") )#returns you the current GPU memory managed by caching allocator in bytes for a given device, in previous PyTorch versions the command was torch.cuda.memory_cached
 

cuda_device = 0 

def free_gpu_cache(cuda_device):
    print("Initial GPU Usage")    
    gpu_usage()                             

    print("GPU Usage after emptying the cache")
    gpu_usage()
    
    print("CUDA empty cache")
    torch.cuda.empty_cache()

    print("Close and reopen device")
    cuda.select_device(cuda_device)
    print("Close device")    
    cuda.close()
    print("Reopen device")    
    cuda.select_device(cuda_device)

    print("GPU Usage after closing and reopening")
    gpu_usage()

            
free_gpu_cache(0)

| ID | GPU  | MEM |
-------------------
|  0 | nan% |  1% |
 Allocated :  0
 Reserved  :  0
Initial GPU Usage
| ID | GPU  | MEM |
-------------------
|  0 | nan% |  1% |
GPU Usage after emptying the cache
| ID | GPU  | MEM |
-------------------
|  0 | nan% |  1% |
CUDA empty cache
Close and reopen device
Close device
Reopen device
GPU Usage after closing and reopening
| ID | GPU  | MEM |
-------------------
|  0 | nan% |  3% |


In [4]:
from dev.sparsechem_utils import load_sparse, load_task_weights, class_fold_counts, fold_and_transform_inputs
from dev.chembl_dataloader_dev import ClassRegrSparseDataset_v3, ClassRegrSparseDataset
from utils.util import (makedir, print_separator, create_path, print_yaml, should, fix_random_seed, 
                        read_yaml_from_input, timestring, print_heading, print_dbg, print_underline)
from dev.sparsechem_env_dev import SparseChemEnv_Dev
from dev.train_dev import eval_dev

def vprint(s="", verbose = False):
    if verbose:
        print(s)

vprint(f"\nArgs : \n--------------")

 


In [5]:
# def print_separator(text, total_len=50):
#     print('#' * total_len)
#     left_width = (total_len - len(text))//2
#     right_width = total_len - len(text) - left_width
#     print("#" * left_width + text + "#" * right_width)
#     print('#' * total_len)

# def print_dbg(text, verbose = False):
#     if verbose:
#         print(text)

# @debug_off
# def print_heading(text,  verbose = False):
#     len_ttl = max(len(text)+4, 50)
#     if verbose:
#         print('-' * len_ttl)
#         print(f" {text}")
#         # left_width = (total_len - len(text))//2
#         # right_width = total_len - len(text) - left_width
#         # print("#" * left_width + text + "#" * right_width)
#         print('-' * len_ttl,'\n')

In [6]:
print_heading("hello_kevin", verbose=False)

## Read yaml config file

In [7]:
input_args = " --config yamls/adashare/chembl_2task.yml --cpu ".split()

print_separator('READ YAML')
opt, gpu_ids, _ = read_yaml_from_input(input_args)
fix_random_seed(opt["seed"][0])


##################################################
####################READ YAML#####################
##################################################
{'config': 'yamls/adashare/chembl_2task.yml', 'exp_ids': [0], 'gpus': [0], 'cpu': True}


In [8]:
from datetime import datetime
date_time = datetime.now().strftime("%Y%m%d_%H%M%S")
print(opt['exp_name'], date_time)

SparseChem 20211224_112853


In [9]:
opt['exp_name'] = date_time

In [10]:
create_path(opt)

 Create folder ../experiments/logs/SparseChem/20211224_112853
 Create folder ../experiments/results/SparseChem/20211224_112853
 Create folder ../experiments/checkpoints/SparseChem/20211224_112853


In [11]:
# print yaml on the screen

lines = print_yaml(opt)
with open(os.path.join(opt['paths']['log_dir'], opt['exp_name'], 'opt.txt'), 'w+') as f:
    f.writelines(lines)

for line in lines: print(line)
    


exp_name. : 20211224_112853
seed. : [88, 45, 50, 100, 44, 48, 2048, 2222, 9999]
backbone. : SparseChem
backbone_orig. : ResNet18
orig_tasks. : ['seg', 'sn']
tasks. : ['class', 'class', 'class']
tasks_num_class. : [5, 5, 5]
lambdas. : [1, 1, 1]
policy_model. : task-specific
verbose. : True
paths.
paths.log_dir. : ../experiments/logs/SparseChem
paths.result_dir. : ../experiments/results/SparseChem
paths.checkpoint_dir. : ../experiments/checkpoints/SparseChem
dataload.
dataload.dataset. : Chembl_23_mini
dataload.dataroot. : /home/kbardool/kusanagi/MLDatasets/chembl_23_mini
dataload.x. : chembl_23mini_x.npy
dataload.folding. : chembl_23mini_folds.npy
dataload.weights_class. : None
dataload.fold_inputs. : 32000
dataload.input_transform. : None
dataload.y_tasks. : ['chembl_23_adashare_y1_bin_sparse.npy', 'chembl_23_adashare_y2_bin_sparse.npy', 'chembl_23_adashare_y3_bin_sparse.npy']
dataload.y_censor. : None
dataload.fold_te. : None
dataload.crop_h. : 321
dataload.crop_w. : 321
dataload.min

## Chembl Dataloader V3

#### Chembl Data feed

In [12]:
dataroot = opt['dataload']['dataroot']
ecfp     = load_sparse(dataroot, opt['dataload']['x'])

total_input = ecfp.shape[0]
ranges      = (np.cumsum([0.3, 0.3, 0.3, 0.1])* total_input).astype(np.int32)
print(total_input, '     ', ranges)

idx_train  = np.arange(ranges[0])
idx_train1 = np.arange(ranges[0], ranges[1])
idx_train2 = np.arange(ranges[1], ranges[2])
idx_val    = np.arange(ranges[2], ranges[-1])


print(os.path.join(opt['dataload']['dataroot'], opt['dataload']['x']))
print(os.path.join(opt['dataload']['dataroot'], opt['dataload']['y_tasks'][0]))
print(os.path.join(opt['dataload']['dataroot'], opt['dataload']['folding']))
print(' weights_class: ',opt['dataload']['weights_class'])

print(f' idx_train    len: {len(idx_train)  :6d}  - {(idx_train)} ')
print(f' idx_train1   len: {len(idx_train1) :6d}  - {(idx_train1)} ')
print(f' idx_train2   len: {len(idx_train2) :6d}  - {(idx_train2)} ')
print(f' val_train    len: {len(idx_val)    :6d}  - {(idx_val)} ')


18388       [ 5516 11032 16549 18387]
/home/kbardool/kusanagi/MLDatasets/chembl_23_mini/chembl_23mini_x.npy
/home/kbardool/kusanagi/MLDatasets/chembl_23_mini/chembl_23_adashare_y1_bin_sparse.npy
/home/kbardool/kusanagi/MLDatasets/chembl_23_mini/chembl_23mini_folds.npy
 weights_class:  None
 idx_train    len:   5516  - [   0    1    2 ... 5513 5514 5515] 
 idx_train1   len:   5516  - [ 5516  5517  5518 ... 11029 11030 11031] 
 idx_train2   len:   5517  - [11032 11033 11034 ... 16546 16547 16548] 
 val_train    len:   1838  - [16549 16550 16551 ... 18384 18385 18386] 


In [13]:
class InfiniteDataLoader(DataLoader):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Initialize an iterator over the dataset.
        self.dataset_iterator = super().__iter__()

    def __iter__(self):
        return self

    def __next__(self):
        try:
            batch = next(self.dataset_iterator)
        except StopIteration:
            # Dataset exhausted, use a new fresh iterator.
            self.dataset_iterator = super().__iter__()
            batch = next(self.dataset_iterator)
        return batch

### Validation dataset

In [14]:
valset = ClassRegrSparseDataset_v3(opt, index = idx_val, verbose = True)
val_loader = InfiniteDataLoader(valset, batch_size=opt['train']['batch_size'], num_workers = 1, pin_memory=True, collate_fn=valset.collate, shuffle=False)






### Training dataset

In [15]:
trainset = ClassRegrSparseDataset_v3(opt, index = idx_train) 
train_loader = InfiniteDataLoader(trainset, batch_size=opt['train']['batch_size'], num_workers = 2, pin_memory=True, collate_fn=trainset.collate, shuffle=False)

trainset1 = ClassRegrSparseDataset_v3(opt, index = idx_train1)
train1_loader = InfiniteDataLoader(trainset1, batch_size=opt['train']['batch_size'], num_workers = 2, pin_memory=True, collate_fn=trainset1.collate, shuffle=False)

trainset2 = ClassRegrSparseDataset_v3(opt, index = idx_train2)
train2_loader = InfiniteDataLoader(trainset2, batch_size=opt['train']['batch_size'], num_workers = 2, pin_memory=True, collate_fn=trainset2.collate, shuffle=False)












#### Test dataloader output

In [16]:
# val_enumerator = enumerate(val_loader)
# val_iterator = iter(val_loader)

In [17]:
# val_batch_idx, val_batch = next(val_enumerator)
# print(type(val_batch['row_id']),val_batch['row_id'][0], val_batch['row_id'][-1] )

# ctr = 0
# for i in range(100):
#     val_batch_1 = next(val_loader)
#     print(' iteration: ', ctr,' len: ', len(val_batch_1['row_id']),'start: [', val_batch_1['row_id'][0],   val_batch_1['row_id'][-1],']' )
#     ctr += 1

# ctr = 0    
# for val_batch_1 in iter(val_loader):
# #     val_batch_1 = next(val_loader)
#     print(' iteration: ', ctr,' len: ', len(val_batch_1['row_id']),'start: [', val_batch_1['row_id'][0],   val_batch_1['row_id'][-1],']' )
#     ctr += 1    
#     if ctr == 105:
#         break

In [18]:
# val_batch_1 = next(val_iterator)


In [19]:
#  batch_idx, batch = next(batch_enumerator)

# print(batch.keys())
# print(batch['x_ind'].shape)
# print(type(batch['batch_size']))
# for i in batch.keys():
#     if not isinstance(batch[i], int):
#         print(i, batch[i].shape)

# task0_Y =  torch.sparse_coo_tensor(
#         batch["task0_ind"],
#         batch["task0_data"],
#         size = [batch["batch_size"], 5]).to("cpu", non_blocking=True).to_dense().numpy()

# print(task0_Y)

In [20]:
# print(f" train_loader: dataset input size       :  {train_loader.dataset.input_size}")
# print(f" train_loader: class output size        :  {train_loader.dataset.class_output_size}")
# print()
# print(f" size of training set 0 (warm up)       :  {len(trainset)}")
# print(f" size of training set 1 (network parms) :  {len(trainset1)}")
# print(f" size of training set 2 (policy weights):  {len(trainset2)}")
# print(f" size of validation set                 :  {len(valset)}")
# print(f"                                Total   :  {len(trainset)+len(trainset1)+len(trainset2)+len(valset)}")

# print(f" batch size       : {opt['train']['batch_size']}")
# print(f' len train_loader : {len(train_loader)}')
# print(f' len train1_loader: {len(train1_loader)}')
# print(f' len train2_loader: {len(train2_loader)}')
# print(f' len val_loader   : {len(val_loader)}')

## Create Environment

In [21]:
opt['train']['weight_iter_alternate'] = opt['train'].get('weight_iter_alternate', len(train1_loader))
opt['train']['alpha_iter_alternate'] = opt['train'].get('alpha_iter_alternate'  , len(train2_loader))

In [22]:
print(f" trainset.y_class     :  {[ i.shape  for i in trainset.y_class_list]}")
print(f" trainset1.y_class    :  {[ i.shape  for i in trainset1.y_class_list]}")
print(f" trainset2.y_class    :  {[ i.shape  for i in trainset2.y_class_list]}")
print(f" valset.y_class       :  {[ i.shape  for i in valset.y_class_list  ]} ")
print()
print(f" length train_loader  :  {len(train_loader)}")
print(f" length train1_loader :  {len(train1_loader)}")
print(f" length train2_loader :  {len(train2_loader)}")
print(f" length val_loader    :  {len(val_loader)}")

print()
print(f' size of training set 0 (warm up)       :  {len(trainset)}')
print(f' size of training set 1 (network parms) :  {len(trainset1)}')
print(f' size of training set 2 (policy weights):  {len(trainset2)}')
print(f' size of validation set                 :  {len(valset)}')
print(f'                               Total    :  {len(trainset)+len(trainset1)+len(trainset2)+len(valset)}')
print()
print(f' len train_loader set 0 (warm up)       :  {len(train_loader)}')
print(f' len train_loader set 1 (network parms) :  {len(train1_loader)}')
print(f' len train_loader set 2 (policy weights):  {len(train2_loader)}')
print(f' len val_loader                         :  {len(val_loader)}')
print()
# print(f' length of train_loader training set 0 (warm up) : {len(train_loader)}')
# print(f' len train1_loader: {len(train1_loader)}')
# print(f' len train2_loader: {len(train2_loader)}')


print(f" batch size                             : {opt['train']['batch_size']}")
print()

 trainset.y_class     :  [(5516, 5), (5516, 5), (5516, 5)]
 trainset1.y_class    :  [(5516, 5), (5516, 5), (5516, 5)]
 trainset2.y_class    :  [(5517, 5), (5517, 5), (5517, 5)]
 valset.y_class       :  [(1838, 5), (1838, 5), (1838, 5)] 

 length train_loader  :  87
 length train1_loader :  87
 length train2_loader :  87
 length val_loader    :  29

 size of training set 0 (warm up)       :  5516
 size of training set 1 (network parms) :  5516
 size of training set 2 (policy weights):  5517
 size of validation set                 :  1838
                               Total    :  18387

 len train_loader set 0 (warm up)       :  87
 len train_loader set 1 (network parms) :  87
 len train_loader set 2 (policy weights):  87
 len val_loader                         :  29

 batch size                             : 64



In [23]:
print( 
    f"\n backbone                : {opt['backbone']}",
    f"\n paths.log_dir           : {opt['paths']['log_dir']}", 
    f"\n paths.checkpoint_dir    : {opt['paths']['checkpoint_dir']}", 
    f"\n experiment name         : {opt['exp_name']}",
    f"\n tasks_num_class         : {opt['tasks_num_class'],}",
    f"\n Hidden sizes            : {opt['hidden_sizes']}",     
    f"\n init_neg_logits         : {opt['init_neg_logits'],}",
    f"\n device id               : {gpu_ids[0]}",
    f"\n init temp               : {opt['train']['init_temp'],}",
    f"\n decay temp              : {opt['train']['decay_temp']}",
    f"\n fix BN parms            : {opt['fix_BN']}",
    f"\n skip_layer              : {opt['skip_layer']}",
    f"\n train.init_method       : {opt['train']['init_method']}",
    f"\n Total iterations        : {opt['train']['total_iters']}",
    f"\n Warm-up iterations      : {opt['train']['warm_up_iters']}",
    f"\n Print Frequency         : {opt['train']['print_freq']}",
    f"\n Validation Frequency    : {opt['train']['val_freq']} \n",
    f"\n Weight iter alternate   : {opt['train']['weight_iter_alternate'] }",
    f"\n Alpha  iter alternate   : {opt['train']['alpha_iter_alternate'] }")
# print('\n\n Opt file \n ------------ \n')
# pp.pprint(opt)


 backbone                : SparseChem 
 paths.log_dir           : ../experiments/logs/SparseChem 
 paths.checkpoint_dir    : ../experiments/checkpoints/SparseChem 
 experiment name         : 20211224_112853 
 tasks_num_class         : ([5, 5, 5],) 
 Hidden sizes            : [41, 42, 43, 44] 
 init_neg_logits         : (None,) 
 device id               : 0 
 init temp               : (5,) 
 decay temp              : 0.965 
 fix BN parms            : False 
 skip_layer              : 0 
 train.init_method       : equal 
 Total iterations        : 200 
 Warm-up iterations      : 4000 
 Print Frequency         : 100 
 Validation Frequency    : 400 
 
 Weight iter alternate   : 87 
 Alpha  iter alternate   : 87


### Create model


In [24]:
# from dev.MTL2_Dev import MTL2_Dev
# from dev.blockdrop_env_dev import BlockDropEnv_Dev
# del environ
# print(num_int_)


In [25]:
environ = SparseChemEnv_Dev(log_dir = opt['paths']['log_dir'], 
                            checkpoint_dir = opt['paths']['checkpoint_dir'], 
                            exp_name = opt['exp_name'],
                            tasks_num_class = opt['tasks_num_class'], 
                            init_neg_logits = opt['init_neg_logits'], 
                            device = gpu_ids[0],
                            init_temperature = opt['train']['init_temp'], 
                            temperature_decay= opt['train']['decay_temp'], 
                            is_train=True,
                            opt=opt, 
                            verbose = False)

-------------------------------------------------------
 * SparseChemEnv_Dev  Initializtion - verbose: False
------------------------------------------------------- 

------------------------------------------------------------
 SparseChemEnv_Dev.super() init()  Start - verbose: False
------------------------------------------------------------ 

 log_dir        :  ../experiments/logs/SparseChem/20211224_112853 
 checkpoint_dir :  ../experiments/checkpoints/SparseChem/20211224_112853 
 exp_name       :  20211224_112853 
 tasks_num_class:  [5, 5, 5] 
 device         :  cuda:0 
 device id      :  0 
 dataset        :  Chembl_23_mini 
 tasks          :  ['class', 'class', 'class'] 

--------------------------------------------------
 SparseChemEnv_Dev.super() init()  end
-------------------------------------------------- 

 is_train       :  True 
 init_neg_logits:  None 
 init temp      :  5 
 decay temp     :  0.965 
 input_size     :  32000 
 normalize loss :  None 
 num_tasks      :  

## Training

### Training Preparation

In [26]:
# print(environ.get_arch_parameters())
# print()
# print(environ.get_task_specific_parameters())
# print()
# print(environ.get_backbone_parameters())

In [27]:
print( 
    f"\n backbone                : {opt['backbone']}",
    f"\n paths.log_dir           : {opt['paths']['log_dir']}", 
    f"\n paths.checkpoint_dir    : {opt['paths']['checkpoint_dir']}", 
    f"\n experiment name         : {opt['exp_name']}",
    f"\n tasks_num_class         : {opt['tasks_num_class'],}",
    f"\n init_neg_logits         : {opt['init_neg_logits'],}",
    f"\n device id               : {gpu_ids[0]}",
    f"\n init temp               : {opt['train']['init_temp'],}",
    f"\n decay temp              : {opt['train']['decay_temp']}",
    f"\n fix BN parms            : {opt['fix_BN']}",
    f"\n skip_layer              : {opt['skip_layer']}",
    f"\n train.init_method       : {opt['train']['init_method']}",
    f"\n Total iterations        : {opt['train']['total_iters']}",
    f"\n Warm-up iterations      : {opt['train']['warm_up_iters']}",
    f"\n Print Frequency         : {opt['train']['print_freq']}",
    f"\n Validation Frequency    : {opt['train']['val_freq']}",
    f"\n Weight iter alternate   : {opt['train']['weight_iter_alternate'] }",
    f"\n Alpha  iter alternate   : {opt['train']['alpha_iter_alternate'] }")


 backbone                : SparseChem 
 paths.log_dir           : ../experiments/logs/SparseChem 
 paths.checkpoint_dir    : ../experiments/checkpoints/SparseChem 
 experiment name         : 20211224_112853 
 tasks_num_class         : ([5, 5, 5],) 
 init_neg_logits         : (None,) 
 device id               : 0 
 init temp               : (5,) 
 decay temp              : 0.965 
 fix BN parms            : False 
 skip_layer              : 0 
 train.init_method       : equal 
 Total iterations        : 200 
 Warm-up iterations      : 4000 
 Print Frequency         : 100 
 Validation Frequency    : 400 
 Weight iter alternate   : 87 
 Alpha  iter alternate   : 87


In [28]:
# val_enumerator = enumerate(val_loader)
# batch_enumerator  = enumerate(train_loader,1)
# batch_enumerator1 = enumerate(train1_loader,1)
# batch_enumerator2 = enumerate(train2_loader,1)

In [29]:
from tqdm.notebook import trange, tqdm

In [30]:
environ.define_optimizer(policy_learning=False)
environ.define_scheduler(policy_learning=False)

In [31]:

current_iter   = 0
current_iter_w = 0 
current_iter_a = 0

flag         = 'update_w'
best_value   = 0 
best_iter    = 0
p_epoch      = 0
best_metrics = None
flag_warmup  = True
num_blocks = sum(environ.networks['mtl-net'].layers)

if opt['train']['resume']:
    print('Resume training')
    current_iter = environ.load(opt['train']['which_iter'])
    environ.networks['mtl-net'].reset_logits()
else:
    print('Initiate Training ')


if torch.cuda.is_available():
    print('cuda available', gpu_ids)   
    environ.cuda(gpu_ids)
else:
    print('cuda not available')
    environ.cpu()
print('\n')


print(f"which_iter   : {opt['train']['which_iter']}\n"
      f"train_resume :  {opt['train']['resume']}")

print(f"network[mtl_net].layers: {environ.networks['mtl-net'].layers}")
print(opt['backbone'], 'num_blocks: ', num_blocks)

# Fix Alpha - 
environ.fix_alpha()
environ.free_w(opt['fix_BN'])

Initiate Training 
cuda available [0]


which_iter   : warmup
train_resume :  False
network[mtl_net].layers: [1, 1, 1, 1]
SparseChem num_blocks:  4


### Warm-up Training

In [32]:
print(f"current_iter        : {current_iter}") 
print(f"warm_up_iters       : {opt['train']['warm_up_iters']}")              
print(f"val_freq            : {opt['train']['val_freq']     }")       
print(f"batch_size          : {opt['train']['batch_size']   }")         
print(f"total_iters         : {opt['train']['total_iters']}")  
print(f"warm_up_iters       : {opt['train']['warm_up_iters']}")
print(f"print_freq          : {opt['train']['print_freq']  }")
print(f"val_freq            : {opt['train']['val_freq']    }")
print(f"Length train_loader :  {len(train_loader)}")

current_iter        : 0
warm_up_iters       : 4000
val_freq            : 400
batch_size          : 64
total_iters         : 200
warm_up_iters       : 4000
print_freq          : 100
val_freq            : 400
Length train_loader :  87


In [33]:
opt['train']['warm_up_iters'] = 200
opt['train']['total_iters'] = 2000
opt['train']['val_freq']   = 50
eval_iter = 20
batch_idx = 0
# opt['train']['print_freq'] =1
# opt['train']['weight_iter_alternate'] = 10
# opt['train']['alpha_iter_alternate']  = 10


# opt['train']['warm_up_iters'] = 200
# opt['train']['val_freq']   = 50

In [34]:
 print(current_iter)

0


In [35]:
stop_iter  = current_iter +  opt['train']['warm_up_iters']
print(f" Current_iter: {current_iter}  warm-up iterations:{opt['train']['warm_up_iters']} - Run  from {current_iter+1} to {stop_iter+1}")


 Current_iter: 0  warm-up iterations:200 - Run  from 1 to 201


### Warm-Up Training Loop

In [36]:
# print_heading(f" {timestring()} - Training iteration {current_iter}  flag: {flag}  p_epoch: {p_epoch} ")    
# print(current_iter , , flag)

In [38]:
##---------------------------------------------------------------     
## part one: warm up
##--------------------------------------------------------------- 
stop_iter  = current_iter + opt['train']['warm_up_iters']
print(f" Current_iter: {current_iter}  warm-up iterations:{opt['train']['warm_up_iters']} - Run  from {current_iter+1} to {stop_iter+1}")

with trange(current_iter+1, stop_iter+1 , initial = current_iter, total = stop_iter, position=0, leave= True, desc="training") as t_warmup :
    
    for current_iter in t_warmup:
        start_time = time.time()
        environ.train()
        batch_idx += 1
        batch = next(train_loader)    
#         print_heading(f" {timestring()} - WARMUP Training iter {current_iter}/{opt['train']['warm_up_iters']}    batch_idx: {batch_idx}"    
#                       f"    Warm-up iters: {opt['train']['warm_up_iters']}"
#                       f"    Validation freq:  {opt['train']['val_freq']}", verbose = False)

        environ.set_inputs(batch,train_loader.dataset.input_size)

        environ.optimize(opt['lambdas'], is_policy=False, flag='update_w', verbose = False)
        
        t_warmup.set_postfix({'curr_iter':current_iter, 
                              'bch_idx': batch_idx, 
                              'Loss': f"{environ.losses['total']['total'].item():.4f}" , 
                              'row_ids':f"{batch['row_id'][0]}-{batch['row_id'][-1]}"})

        if should(current_iter, opt['train']['print_freq']):
            environ.print_loss(current_iter, start_time, verbose = True)

#         print(f"**  {timestring()}  iteration: {current_iter}  Complete - Loss: {environ.losses['total']['total']:.4f}" )

        ##--------------------------------------------------------------- 
        # validation
        ##--------------------------------------------------------------- 
        if should(current_iter, opt['train']['val_freq']):
            print_dbg(f"**  {timestring()}  START VALIDATION iteration: {current_iter}    Validation freq {opt['train']['val_freq']}") 

            num_seg_class = opt['tasks_num_class'][opt['tasks'].index('seg')] if 'seg' in opt['tasks'] else -1
            val_metrics = eval_dev(environ, 
                                  val_loader, 
                                  opt['tasks'], 
                                  policy=False, 
                                  num_train_layers=None,
                                  eval_iter = 20)

            environ.print_metrics(current_iter, start_time, val_metrics, title='validation')
            environ.save_checkpoint('warmup', current_iter)
            
            print_dbg(f"** {timestring()} - END VALIDATION iteration:  {current_iter}  validation loss: {val_metrics['loss']['total']:.4f}", verbose = True)                
print()     

 Current_iter: 200  warm-up iterations:200 - Run  from 201 to 401


training:  50%|#####     | 200/400 [00:00<?, ?it/s]

validation:   0%|          | 0/20 [00:00<?, ?it/s]

 ++ Validation - eval_iter:20    sum_task_loss: 69.8336   task_loss_avg: 3.4917
 ++ Validation - eval_iter:20    sum_task_loss: 68.7340   task_loss_avg: 3.4367
 ++ Validation - eval_iter:20    sum_task_loss: 68.7788   task_loss_avg: 3.4389
 ++ Validation - eval_iter:20    loss_sum : 207.3463  loss_sum_avg:10.3673 
** 2021-12-24 11:32:34:600296 - END VALIDATION iteration:  250  validation loss: 10.3673
Iteration  300 -  Total Loss: 10.1068     Task Loss: 10.1068  


validation:   0%|          | 0/20 [00:00<?, ?it/s]

 ++ Validation - eval_iter:20    sum_task_loss: 69.3143   task_loss_avg: 3.4657
 ++ Validation - eval_iter:20    sum_task_loss: 68.3496   task_loss_avg: 3.4175
 ++ Validation - eval_iter:20    sum_task_loss: 68.5352   task_loss_avg: 3.4268
 ++ Validation - eval_iter:20    loss_sum : 206.1991  loss_sum_avg:10.3100 
** 2021-12-24 11:32:57:455293 - END VALIDATION iteration:  300  validation loss: 10.3100


validation:   0%|          | 0/20 [00:00<?, ?it/s]

 ++ Validation - eval_iter:20    sum_task_loss: 69.2379   task_loss_avg: 3.4619
 ++ Validation - eval_iter:20    sum_task_loss: 68.3216   task_loss_avg: 3.4161
 ++ Validation - eval_iter:20    sum_task_loss: 68.5635   task_loss_avg: 3.4282
 ++ Validation - eval_iter:20    loss_sum : 206.1230  loss_sum_avg:10.3061 
** 2021-12-24 11:33:20:350911 - END VALIDATION iteration:  350  validation loss: 10.3061
Iteration  400 -  Total Loss: 9.5312     Task Loss: 9.5312  


validation:   0%|          | 0/20 [00:00<?, ?it/s]

 ++ Validation - eval_iter:20    sum_task_loss: 69.2979   task_loss_avg: 3.4649
 ++ Validation - eval_iter:20    sum_task_loss: 68.0641   task_loss_avg: 3.4032
 ++ Validation - eval_iter:20    sum_task_loss: 68.2750   task_loss_avg: 3.4137
 ++ Validation - eval_iter:20    loss_sum : 205.6369  loss_sum_avg:10.2818 
** 2021-12-24 11:33:43:052799 - END VALIDATION iteration:  400  validation loss: 10.2818



In [39]:
print(environ.metrics.keys())
# pp.pprint(environ.metrics['task1'])

print(environ.losses.keys())

pp.pprint(environ.losses)
pp.pprint(environ.losses['task1']['total'])


dict_keys(['task1', 'task2', 'task3'])
dict_keys(['task1', 'task2', 'task3'])
{   'task1': {'total': tensor(3.2914, device='cuda:0', dtype=torch.float64)},
    'task2': {'total': tensor(3.2416, device='cuda:0', dtype=torch.float64)},
    'task3': {'total': tensor(3.3270, device='cuda:0', dtype=torch.float64)}}
tensor(3.2914, device='cuda:0', dtype=torch.float64)


In [40]:
print(val_metrics.keys())
pp.pprint(val_metrics['loss'])
pp.pprint(val_metrics['task1']['classification_agg']['err'].cpu().numpy())
pp.pprint(val_metrics['task2']['classification_agg']['err'].cpu().numpy())
pp.pprint(val_metrics['task3']['classification_agg']['err'].cpu().numpy())
 

dict_keys(['loss', 'task1', 'task2', 'task3'])
{'total': 10.204687159220782}
array(3.29143871)
array(3.2415543)
array(3.32701487)


In [None]:
# environ.print_loss(current_iter, start_time, metrics = val_metrics['loss'], verbose=True)
# print(opt['lambdas'])
# p = (opt['lambdas'][0] * environ.losses['tasks']['task1'])
# print(p)

# environ.print_metrics(current_iter, start_time, val_metrics , title='validation', verbose=True)    

In [50]:
pp.pprint(environ.optimizers)
pp.pprint(environ.schedulers['weights'])

{   'alphas': Adam (
Parameter Group 0
    amsgrad: False
    betas: (0.9, 0.999)
    eps: 1e-08
    lr: 0.0001
    weight_decay: 0.0005
),
    'weights': SGD (
Parameter Group 0
    dampening: 0
    lr: 0.0001
    momentum: 0.9
    nesterov: False
    weight_decay: 0.0001

Parameter Group 1
    dampening: 0
    lr: 0.0001
    momentum: 0.9
    nesterov: False
    weight_decay: 0.0001
)}
<torch.optim.lr_scheduler.StepLR object at 0x7fc0b4037910>


In [39]:
print_heading(f"** {timestring()} - Training current iteration {current_iter}  flag: {flag} ", verbose = True)    

current_iter_w = 0 
current_iter_a = 0 
batch_idx_a = 0 
batch_idx_w = 0 


if flag_warmup:
    print_heading(f"** Set optimizer and scheduler to policy_learning = True", verbose = True)
    environ.define_optimizer(policy_learning=True)
    environ.define_scheduler(policy_learning=True)
    flag_warmup = False

if current_iter == opt['train']['warm_up_iters']:
    print_heading(f"** Switch from Warm Up training to Alternate training Weights & Policy \n"
                  f"   Take checkpoint and block gradient flow through Policy net", verbose=True)
    environ.save_checkpoint('warmup', current_iter)
    environ.fix_alpha()
    
# batch_enumerator1 = enumerate(train1_loader,1)  
# batch_enumerator2 = enumerate(train2_loader,1)  

train_total_epochs = 10

-----------------------------------------------------------------------------------
 ** 2021-12-24 11:33:58:258425 - Training current iteration 400  flag: update_w 
----------------------------------------------------------------------------------- 

------------------------------------------------------------
 ** Set optimizer and scheduler to policy_learning = True
------------------------------------------------------------ 



In [46]:
print(f"opt['train']['print_freq']         {opt['train']['print_freq']}")
print(f"opt['train']['hard_sampling']      {opt['train']['hard_sampling']}")
print(f"opt['policy']                      {opt['policy']}")
print(f"opt['tasks']                       {opt['tasks']}")
print(f"weight_iter_alternate:             {opt['train']['weight_iter_alternate']}")
print(f"alpha_iter_alternate :             {opt['train']['alpha_iter_alternate']}")
print(f"current_iter                       {current_iter  }")
print(f"current_iter_w                     {current_iter_w}")
print(f"current_iter_a                     {current_iter_a}")
print(f"batch_idx_w                        {batch_idx_w}")
print(f"flag                               {flag          }")
curr_epoch = 0
train_total_epochs = 30
print(f"curr_epochs                        {curr_epoch}") 
print(f"train_total_epochs                 {train_total_epochs}") 


opt['train']['print_freq']         100
opt['train']['hard_sampling']      False
opt['policy']                      True
opt['tasks']                       ['class', 'class', 'class']
weight_iter_alternate:             87
alpha_iter_alternate :             87
current_iter                       4750
current_iter_w                     87
current_iter_a                     87
batch_idx_w                        0
flag                               update_w
curr_epochs                        0
train_total_epochs                 30


In [47]:
curr_epoch = 0
main_iter_ctr = 0 
verbose = False
t = tqdm(total=train_total_epochs, desc=f" Alternate Weight/Policy training")

while curr_epoch < train_total_epochs:
    curr_epoch+=1
    t.update(1)

    #-----------------------------------------
    # Train & Update the network weights
    #-----------------------------------------
    if flag == 'update_w':
        current_iter_w  = 0 
        stop_iter_w =   opt['train']['weight_iter_alternate']

        with trange(+1, stop_iter_w+1 , initial = current_iter_w, total = stop_iter_w, 
                     position=0, leave= False, desc=f"Epoch {curr_epoch} weight training") as t_weights :
            for current_iter_w in t_weights:    
                current_iter += 1

                start_time = time.time()
                environ.train()
                
                batch = next(train1_loader)
                environ.set_inputs(batch, train1_loader.dataset.input_size)

                ##----------------------------------------------------------------------
                ## Set number of layers to train based on cirriculum_speed 
                ## and p_epoch (number of epochs of policy training)
                ## When curriculum_speed == 3, a num_train_layers is incremented 
                ## after completion of every 3 policy training epochs
                ##----------------------------------------------------------------------
                if opt['is_curriculum']:
                    num_train_layers = p_epoch // opt['curriculum_speed'] + 1
                else:
                    num_train_layers = None

                environ.optimize(opt['lambdas'], 
                                 is_policy=opt['policy'], 
                                 flag=flag, 
                                 num_train_layers=num_train_layers,
                                 hard_sampling=opt['train']['hard_sampling'],
                                 verbose = False)

                t_weights.set_postfix({'iteration': current_iter, 'Loss': f"{environ.losses['total']['total'].item():.4f}" , 
                                       'row_ids':f"{batch['row_id'][0]}-{batch['row_id'][-1]}"})
                
                if should(current_iter, opt['train']['print_freq']):
                    environ.print_loss(current_iter, start_time, title = f"Weight training epoch:{curr_epoch} iteration:", verbose = True)
                    environ.resize_results()

        #-------------------------------------------------------
        # validation process
        #------------------------------------------------------- 
#       if should(current_iter_w, opt['train']['weight_iter_alternate']): 
        if (current_iter_w >= stop_iter_w):
            environ.print_loss(current_iter, start_time, title = f"Weight training epoch:{curr_epoch} iteration:", verbose = True)
            environ.eval()

            val_metrics = eval_dev(environ, 
                                  val_loader, 
                                  opt['tasks'], 
                                  policy=opt['policy'],
                                  num_train_layers=num_train_layers, 
                                  hard_sampling=opt['train']['hard_sampling'],
                                  eval_iter = -1)        

            if (verbose):
                for t_id, task in enumerate(environ.tasks):
                    task_key = f"task{t_id+1}"    
                    environ.print_metrics(current_iter, start_time, val_metrics[task_key]["classification_agg"], title='validation', verbose = verbose)        

            environ.save_checkpoint('latest', current_iter)

            #----------------------------------------------------------------------------------------------
            # if number of iterations completed after the warm up phase is greater than the number of 
            # (weight/policy alternations) x (cirriculum speed) x (number of layers to be policy trained)
            #
            # check metrics for improvement, and issue a checkpoint if necessary
            #----------------------------------------------------------------------------------------------

            if current_iter - opt['train']['warm_up_iters'] >= num_blocks * opt['curriculum_speed'] * \
                    (opt['train']['weight_iter_alternate'] + opt['train']['alpha_iter_alternate']):
                new_value = 0
#                 print(f"  {current_iter - opt['train']['warm_up_iters']} IS GREATER THAN "
#                        f" {num_blocks * opt['curriculum_speed'] * (opt['train']['weight_iter_alternate'] + opt['train']['alpha_iter_alternate'])} -- "
#                        f"  evaluate progress and make checkpoint if necessary." )  
                
                ##------------------------------------------------------
                ## compare validation metrics against reference metrics.
                ##------------------------------------------------------
#                 for k in refer_metrics.keys():
#                     if k in val_metrics.keys():
#                         for kk in val_metrics[k].keys():
#                             if not kk in refer_metrics[k].keys():
#                                 continue
#                             if (k == 'sn' and kk in ['Angle Mean', 'Angle Median']) or (
#                                     k == 'depth' and not kk.startswith('sigma')) or (kk == 'err'):
#                                 value = refer_metrics[k][kk] / val_metrics[k][kk]
#                             else:
#                                 value = val_metrics[k][kk] / refer_metrics[k][kk]
#                             value = value / len(list(set(val_metrics[k].keys()) & set(refer_metrics[k].keys())))
#                             new_value += value

#                 print('Best Value %.4f  New value: %.4f' % new_value)

                # if results have improved, save these results and issue a checkpoint

#                 if (new_value > best_value):
#                     print('Previous best iter: %d, best_value: %.4f' % (best_iter, best_value), best_metrics)
#                     best_value = new_value
#                     best_metrics = val_metrics
#                     best_iter = current_iter
#                     environ.save_checkpoint('best', current_iter)
#                     print('New      best iter: %d, best_value: %.4f' % (best_iter, best_value), best_metrics)                         
#                     print('Best Value %.4f  New value: %.4f' % new_value)

                # if results have improved, save these results and issue a checkpoint   

            environ.train()
            #-------------------------------------------------------
            # END validation process
            #-------------------------------------------------------       
#             print_heading(f"{timestring()} - SWITCH TO ALPHA TRAINING    current_iter: {current_iter}\n"
#               f" current_iter_w: {current_iter_w}  batch_idx_w:{batch_idx_w}   weight_iter_alternate: {opt['train']['weight_iter_alternate']}\n"
#               f" current_iter_a: {current_iter_a}  batch_idx_a:{batch_idx_a}   alpha_iter_alternate : {opt['train']['alpha_iter_alternate']}",
#               verbose = False)       
            flag = 'update_alpha'
            environ.fix_w()
            environ.free_alpha()
        #-------------------------------------------------------
        # end validation process
        #-------------------------------------------------------               


    #-----------------------------------------
    # Train & Update the  policy 
    #-----------------------------------------
    if flag == 'update_alpha':
        current_iter_a = 0
        stop_iter_a = opt['train']['alpha_iter_alternate']

        with trange( +1, stop_iter_a+1 , initial = 0, total = stop_iter_a, 
                     position=0, leave= False, desc=f"Epoch {curr_epoch} policy training") as t_policy :
            for current_iter_a in t_policy:    
                current_iter += 1

                batch = next(train2_loader)
                environ.set_inputs(batch, train2_loader.dataset.input_size)

                if opt['is_curriculum']:
                    num_train_layers = (p_epoch // opt['curriculum_speed']) + 1
                else:
                    num_train_layers = None

                print_dbg(f" num_train_layers  : {num_train_layers}", verbose = False)

                environ.optimize(opt['lambdas'], 
                                 is_policy=opt['policy'], 
                                 flag=flag, 
                                 num_train_layers=num_train_layers,
                                 hard_sampling=opt['train']['hard_sampling'],
                                 verbose = False)
                
                t_policy.set_postfix({'iteration': current_iter, 'Loss': f"{environ.losses['total']['total'].item():.4f}" , 
                                      'row_ids':f"{batch['row_id'][0]}-{batch['row_id'][-1]}"})
                
                if should(current_iter, opt['train']['print_freq']):
                    environ.print_loss(current_iter, start_time, title = f"Policy training epoch:{curr_epoch} iteration:", verbose=True)
                    environ.resize_results()
                    # environ.visual_policy(current_iter)

        if( current_iter_a >= stop_iter_a):            
            environ.print_loss(current_iter, start_time, title = f"Policy training epoch:{curr_epoch} iteration:", verbose=True)
            flag = 'update_w'
            environ.fix_alpha()
            environ.free_w(opt['fix_BN'])
            environ.decay_temperature()

            # print the distribution
            print_dbg(np.concatenate(environ.get_policy_prob(), axis=-1), verbose = False)
            
            p_epoch += 1
            print_dbg(f"** p_epoch incremented: {p_epoch}")

 Alternate Weight/Policy training:   0%|          | 0/30 [00:00<?, ?it/s]

Epoch 1 weight training:   0%|          | 0/87 [00:00<?, ?it/s]

Weight training epoch:1 iteration:  4800 -  Total Loss: 10.3443     Task Loss: 10.3443  
Weight training epoch:1 iteration:  4837 -  Total Loss: 10.4821     Task Loss: 10.4821  


validation:   0%|          | 0/29 [00:00<?, ?it/s]

Epoch 1 policy training:   0%|          | 0/87 [00:00<?, ?it/s]

Policy training epoch:1 iteration:  4900 -  Total Loss: 10.4637     Task Loss: 10.3726  Policy Losses:  Sparsity: 0.0910      Sharing: 5.99474e-05 
Policy training epoch:1 iteration:  4924 -  Total Loss: 10.2198     Task Loss: 10.1289  Policy Losses:  Sparsity: 0.0908      Sharing: 8.92505e-05 


Epoch 2 weight training:   0%|          | 0/87 [00:00<?, ?it/s]

Weight training epoch:2 iteration:  5000 -  Total Loss: 10.3604     Task Loss: 10.3604  
Weight training epoch:2 iteration:  5011 -  Total Loss: 10.3584     Task Loss: 10.3584  


validation:   0%|          | 0/29 [00:00<?, ?it/s]

Epoch 2 policy training:   0%|          | 0/87 [00:00<?, ?it/s]

Policy training epoch:2 iteration:  5098 -  Total Loss: 10.2184     Task Loss: 10.1280  Policy Losses:  Sparsity: 0.0903      Sharing: 6.05732e-05 


Epoch 3 weight training:   0%|          | 0/87 [00:00<?, ?it/s]

Weight training epoch:3 iteration:  5100 -  Total Loss: 10.3918     Task Loss: 10.3918  
Weight training epoch:3 iteration:  5185 -  Total Loss: 10.4739     Task Loss: 10.4739  


validation:   0%|          | 0/29 [00:00<?, ?it/s]

Epoch 3 policy training:   0%|          | 0/87 [00:00<?, ?it/s]

Policy training epoch:3 iteration:  5200 -  Total Loss: 10.4660     Task Loss: 10.3757  Policy Losses:  Sparsity: 0.0902      Sharing: 8.31410e-05 
Policy training epoch:3 iteration:  5272 -  Total Loss: 10.2202     Task Loss: 10.1303  Policy Losses:  Sparsity: 0.0898      Sharing: 4.99636e-05 


Epoch 4 weight training:   0%|          | 0/87 [00:00<?, ?it/s]

Weight training epoch:4 iteration:  5300 -  Total Loss: 10.2392     Task Loss: 10.2392  
Weight training epoch:4 iteration:  5359 -  Total Loss: 10.4066     Task Loss: 10.4066  


validation:   0%|          | 0/29 [00:00<?, ?it/s]

Epoch 4 policy training:   0%|          | 0/87 [00:00<?, ?it/s]

Policy training epoch:4 iteration:  5400 -  Total Loss: 10.4815     Task Loss: 10.3918  Policy Losses:  Sparsity: 0.0896      Sharing: 3.52263e-05 
Policy training epoch:4 iteration:  5446 -  Total Loss: 10.1920     Task Loss: 10.1026  Policy Losses:  Sparsity: 0.0893      Sharing: 8.32677e-05 


Epoch 5 weight training:   0%|          | 0/87 [00:00<?, ?it/s]

Weight training epoch:5 iteration:  5500 -  Total Loss: 10.3851     Task Loss: 10.3851  
Weight training epoch:5 iteration:  5533 -  Total Loss: 10.3619     Task Loss: 10.3619  


validation:   0%|          | 0/29 [00:00<?, ?it/s]

Epoch 5 policy training:   0%|          | 0/87 [00:00<?, ?it/s]

Policy training epoch:5 iteration:  5600 -  Total Loss: 10.3649     Task Loss: 10.2759  Policy Losses:  Sparsity: 0.0890      Sharing: 2.87369e-05 
Policy training epoch:5 iteration:  5620 -  Total Loss: 10.1988     Task Loss: 10.1099  Policy Losses:  Sparsity: 0.0889      Sharing: 3.63886e-05 


Epoch 6 weight training:   0%|          | 0/87 [00:00<?, ?it/s]

Weight training epoch:6 iteration:  5700 -  Total Loss: 10.3132     Task Loss: 10.3132  
Weight training epoch:6 iteration:  5707 -  Total Loss: 10.4056     Task Loss: 10.4056  


validation:   0%|          | 0/29 [00:00<?, ?it/s]

Epoch 6 policy training:   0%|          | 0/87 [00:00<?, ?it/s]

Policy training epoch:6 iteration:  5794 -  Total Loss: 10.2053     Task Loss: 10.1168  Policy Losses:  Sparsity: 0.0884      Sharing: 4.96805e-05 


Epoch 7 weight training:   0%|          | 0/87 [00:00<?, ?it/s]

Weight training epoch:7 iteration:  5800 -  Total Loss: 10.3344     Task Loss: 10.3344  
Weight training epoch:7 iteration:  5881 -  Total Loss: 10.4343     Task Loss: 10.4343  


validation:   0%|          | 0/29 [00:00<?, ?it/s]

Epoch 7 policy training:   0%|          | 0/87 [00:00<?, ?it/s]

Policy training epoch:7 iteration:  5900 -  Total Loss: 10.4905     Task Loss: 10.4022  Policy Losses:  Sparsity: 0.0883      Sharing: 4.77433e-05 
Policy training epoch:7 iteration:  5968 -  Total Loss: 10.2470     Task Loss: 10.1591  Policy Losses:  Sparsity: 0.0879      Sharing: 4.71324e-05 


Epoch 8 weight training:   0%|          | 0/87 [00:00<?, ?it/s]

Weight training epoch:8 iteration:  6000 -  Total Loss: 10.4128     Task Loss: 10.4128  
Weight training epoch:8 iteration:  6055 -  Total Loss: 10.4556     Task Loss: 10.4556  


validation:   0%|          | 0/29 [00:00<?, ?it/s]

Epoch 8 policy training:   0%|          | 0/87 [00:00<?, ?it/s]

Policy training epoch:8 iteration:  6100 -  Total Loss: 10.4868     Task Loss: 10.3991  Policy Losses:  Sparsity: 0.0876      Sharing: 5.10141e-05 
Policy training epoch:8 iteration:  6142 -  Total Loss: 10.1770     Task Loss: 10.0895  Policy Losses:  Sparsity: 0.0874      Sharing: 5.91502e-05 


Epoch 9 weight training:   0%|          | 0/87 [00:00<?, ?it/s]

Weight training epoch:9 iteration:  6200 -  Total Loss: 10.3505     Task Loss: 10.3505  
Weight training epoch:9 iteration:  6229 -  Total Loss: 10.4509     Task Loss: 10.4509  


validation:   0%|          | 0/29 [00:00<?, ?it/s]

Epoch 9 policy training:   0%|          | 0/87 [00:00<?, ?it/s]

Policy training epoch:9 iteration:  6300 -  Total Loss: 10.4771     Task Loss: 10.3900  Policy Losses:  Sparsity: 0.0870      Sharing: 9.53376e-05 
Policy training epoch:9 iteration:  6316 -  Total Loss: 10.1940     Task Loss: 10.1070  Policy Losses:  Sparsity: 0.0870      Sharing: 5.18635e-05 


Epoch 10 weight training:   0%|          | 0/87 [00:00<?, ?it/s]

Weight training epoch:10 iteration:  6400 -  Total Loss: 10.4157     Task Loss: 10.4157  
Weight training epoch:10 iteration:  6403 -  Total Loss: 10.4430     Task Loss: 10.4430  


validation:   0%|          | 0/29 [00:00<?, ?it/s]

Epoch 10 policy training:   0%|          | 0/87 [00:00<?, ?it/s]

Policy training epoch:10 iteration:  6490 -  Total Loss: 10.2282     Task Loss: 10.1417  Policy Losses:  Sparsity: 0.0865      Sharing: 4.17605e-05 


Epoch 11 weight training:   0%|          | 0/87 [00:00<?, ?it/s]

Weight training epoch:11 iteration:  6500 -  Total Loss: 10.4742     Task Loss: 10.4742  
Weight training epoch:11 iteration:  6577 -  Total Loss: 10.4467     Task Loss: 10.4467  


validation:   0%|          | 0/29 [00:00<?, ?it/s]

Epoch 11 policy training:   0%|          | 0/87 [00:00<?, ?it/s]

Policy training epoch:11 iteration:  6600 -  Total Loss: 10.4873     Task Loss: 10.4009  Policy Losses:  Sparsity: 0.0864      Sharing: 6.02156e-05 
Policy training epoch:11 iteration:  6664 -  Total Loss: 10.2219     Task Loss: 10.1359  Policy Losses:  Sparsity: 0.0860      Sharing: 5.37261e-05 


Epoch 12 weight training:   0%|          | 0/87 [00:00<?, ?it/s]

Weight training epoch:12 iteration:  6700 -  Total Loss: 10.3150     Task Loss: 10.3150  
Weight training epoch:12 iteration:  6751 -  Total Loss: 10.4088     Task Loss: 10.4088  


validation:   0%|          | 0/29 [00:00<?, ?it/s]

Epoch 12 policy training:   0%|          | 0/87 [00:00<?, ?it/s]

Policy training epoch:12 iteration:  6800 -  Total Loss: 10.4325     Task Loss: 10.3466  Policy Losses:  Sparsity: 0.0858      Sharing: 8.94219e-05 
Policy training epoch:12 iteration:  6838 -  Total Loss: 10.1496     Task Loss: 10.0639  Policy Losses:  Sparsity: 0.0856      Sharing: 4.69163e-05 


Epoch 13 weight training:   0%|          | 0/87 [00:00<?, ?it/s]

Weight training epoch:13 iteration:  6900 -  Total Loss: 10.3839     Task Loss: 10.3839  
Weight training epoch:13 iteration:  6925 -  Total Loss: 10.3225     Task Loss: 10.3225  


validation:   0%|          | 0/29 [00:00<?, ?it/s]

Epoch 13 policy training:   0%|          | 0/87 [00:00<?, ?it/s]

Policy training epoch:13 iteration:  7000 -  Total Loss: 10.4876     Task Loss: 10.4023  Policy Losses:  Sparsity: 0.0852      Sharing: 5.40465e-05 
Policy training epoch:13 iteration:  7012 -  Total Loss: 10.1795     Task Loss: 10.0943  Policy Losses:  Sparsity: 0.0851      Sharing: 5.45904e-05 


Epoch 14 weight training:   0%|          | 0/87 [00:00<?, ?it/s]

Weight training epoch:14 iteration:  7099 -  Total Loss: 10.4584     Task Loss: 10.4584  


validation:   0%|          | 0/29 [00:00<?, ?it/s]

Epoch 14 policy training:   0%|          | 0/87 [00:00<?, ?it/s]

Policy training epoch:14 iteration:  7100 -  Total Loss: 10.5757     Task Loss: 10.4905  Policy Losses:  Sparsity: 0.0851      Sharing: 6.83367e-05 
Policy training epoch:14 iteration:  7186 -  Total Loss: 10.1511     Task Loss: 10.0664  Policy Losses:  Sparsity: 0.0847      Sharing: 5.76228e-05 


Epoch 15 weight training:   0%|          | 0/87 [00:00<?, ?it/s]

Weight training epoch:15 iteration:  7200 -  Total Loss: 10.4350     Task Loss: 10.4350  
Weight training epoch:15 iteration:  7273 -  Total Loss: 10.3049     Task Loss: 10.3049  


validation:   0%|          | 0/29 [00:00<?, ?it/s]

Epoch 15 policy training:   0%|          | 0/87 [00:00<?, ?it/s]

Policy training epoch:15 iteration:  7300 -  Total Loss: 10.4687     Task Loss: 10.3841  Policy Losses:  Sparsity: 0.0846      Sharing: 6.34938e-05 
Policy training epoch:15 iteration:  7360 -  Total Loss: 10.2308     Task Loss: 10.1465  Policy Losses:  Sparsity: 0.0842      Sharing: 8.51154e-05 


Epoch 16 weight training:   0%|          | 0/87 [00:00<?, ?it/s]

Weight training epoch:16 iteration:  7400 -  Total Loss: 10.3178     Task Loss: 10.3178  
Weight training epoch:16 iteration:  7447 -  Total Loss: 10.4329     Task Loss: 10.4329  


validation:   0%|          | 0/29 [00:00<?, ?it/s]

Epoch 16 policy training:   0%|          | 0/87 [00:00<?, ?it/s]

Policy training epoch:16 iteration:  7500 -  Total Loss: 10.5069     Task Loss: 10.4229  Policy Losses:  Sparsity: 0.0840      Sharing: 2.59131e-05 
Policy training epoch:16 iteration:  7534 -  Total Loss: 10.1672     Task Loss: 10.0833  Policy Losses:  Sparsity: 0.0838      Sharing: 8.43182e-05 


Epoch 17 weight training:   0%|          | 0/87 [00:00<?, ?it/s]

Weight training epoch:17 iteration:  7600 -  Total Loss: 10.3373     Task Loss: 10.3373  
Weight training epoch:17 iteration:  7621 -  Total Loss: 10.4282     Task Loss: 10.4282  


validation:   0%|          | 0/29 [00:00<?, ?it/s]

Epoch 17 policy training:   0%|          | 0/87 [00:00<?, ?it/s]

Policy training epoch:17 iteration:  7700 -  Total Loss: 10.4995     Task Loss: 10.4160  Policy Losses:  Sparsity: 0.0834      Sharing: 4.99710e-05 
Policy training epoch:17 iteration:  7708 -  Total Loss: 10.1474     Task Loss: 10.0640  Policy Losses:  Sparsity: 0.0834      Sharing: 3.43844e-05 


Epoch 18 weight training:   0%|          | 0/87 [00:00<?, ?it/s]

Weight training epoch:18 iteration:  7795 -  Total Loss: 10.4948     Task Loss: 10.4948  


validation:   0%|          | 0/29 [00:00<?, ?it/s]

Epoch 18 policy training:   0%|          | 0/87 [00:00<?, ?it/s]

Policy training epoch:18 iteration:  7800 -  Total Loss: 10.5333     Task Loss: 10.4499  Policy Losses:  Sparsity: 0.0833      Sharing: 4.71473e-05 
Policy training epoch:18 iteration:  7882 -  Total Loss: 10.1764     Task Loss: 10.0934  Policy Losses:  Sparsity: 0.0829      Sharing: 4.29526e-05 


Epoch 19 weight training:   0%|          | 0/87 [00:00<?, ?it/s]

Weight training epoch:19 iteration:  7900 -  Total Loss: 10.4252     Task Loss: 10.4252  
Weight training epoch:19 iteration:  7969 -  Total Loss: 10.4391     Task Loss: 10.4391  


validation:   0%|          | 0/29 [00:00<?, ?it/s]

Epoch 19 policy training:   0%|          | 0/87 [00:00<?, ?it/s]

Policy training epoch:19 iteration:  8000 -  Total Loss: 10.5044     Task Loss: 10.4215  Policy Losses:  Sparsity: 0.0828      Sharing: 5.36218e-05 
Policy training epoch:19 iteration:  8056 -  Total Loss: 10.2169     Task Loss: 10.1343  Policy Losses:  Sparsity: 0.0825      Sharing: 4.26099e-05 


Epoch 20 weight training:   0%|          | 0/87 [00:00<?, ?it/s]

Weight training epoch:20 iteration:  8100 -  Total Loss: 10.3818     Task Loss: 10.3818  
Weight training epoch:20 iteration:  8143 -  Total Loss: 10.4374     Task Loss: 10.4374  


validation:   0%|          | 0/29 [00:00<?, ?it/s]

Epoch 20 policy training:   0%|          | 0/87 [00:00<?, ?it/s]

Policy training epoch:20 iteration:  8200 -  Total Loss: 10.4586     Task Loss: 10.3763  Policy Losses:  Sparsity: 0.0822      Sharing: 6.64741e-05 
Policy training epoch:20 iteration:  8230 -  Total Loss: 10.1490     Task Loss: 10.0668  Policy Losses:  Sparsity: 0.0821      Sharing: 6.95288e-05 


Epoch 21 weight training:   0%|          | 0/87 [00:00<?, ?it/s]

Weight training epoch:21 iteration:  8300 -  Total Loss: 10.3546     Task Loss: 10.3546  
Weight training epoch:21 iteration:  8317 -  Total Loss: 10.3423     Task Loss: 10.3423  


validation:   0%|          | 0/29 [00:00<?, ?it/s]

Epoch 21 policy training:   0%|          | 0/87 [00:00<?, ?it/s]

Policy training epoch:21 iteration:  8400 -  Total Loss: 10.3516     Task Loss: 10.2698  Policy Losses:  Sparsity: 0.0817      Sharing: 8.78572e-05 
Policy training epoch:21 iteration:  8404 -  Total Loss: 10.2236     Task Loss: 10.1418  Policy Losses:  Sparsity: 0.0817      Sharing: 9.13367e-05 


Epoch 22 weight training:   0%|          | 0/87 [00:00<?, ?it/s]

Weight training epoch:22 iteration:  8491 -  Total Loss: 10.3814     Task Loss: 10.3814  


validation:   0%|          | 0/29 [00:00<?, ?it/s]

Epoch 22 policy training:   0%|          | 0/87 [00:00<?, ?it/s]

Policy training epoch:22 iteration:  8500 -  Total Loss: 10.4398     Task Loss: 10.3581  Policy Losses:  Sparsity: 0.0816      Sharing: 7.10785e-05 
Policy training epoch:22 iteration:  8578 -  Total Loss: 10.1638     Task Loss: 10.0825  Policy Losses:  Sparsity: 0.0813      Sharing: 4.48525e-05 


Epoch 23 weight training:   0%|          | 0/87 [00:00<?, ?it/s]

Weight training epoch:23 iteration:  8600 -  Total Loss: 10.3301     Task Loss: 10.3301  
Weight training epoch:23 iteration:  8665 -  Total Loss: 10.4598     Task Loss: 10.4598  


validation:   0%|          | 0/29 [00:00<?, ?it/s]

Epoch 23 policy training:   0%|          | 0/87 [00:00<?, ?it/s]

Policy training epoch:23 iteration:  8700 -  Total Loss: 10.4922     Task Loss: 10.4110  Policy Losses:  Sparsity: 0.0811      Sharing: 4.86225e-05 
Policy training epoch:23 iteration:  8752 -  Total Loss: 10.1523     Task Loss: 10.0714  Policy Losses:  Sparsity: 0.0809      Sharing: 5.57154e-05 


Epoch 24 weight training:   0%|          | 0/87 [00:00<?, ?it/s]

Weight training epoch:24 iteration:  8800 -  Total Loss: 10.3752     Task Loss: 10.3752  
Weight training epoch:24 iteration:  8839 -  Total Loss: 10.4587     Task Loss: 10.4587  


validation:   0%|          | 0/29 [00:00<?, ?it/s]

Epoch 24 policy training:   0%|          | 0/87 [00:00<?, ?it/s]

Policy training epoch:24 iteration:  8900 -  Total Loss: 10.5191     Task Loss: 10.4384  Policy Losses:  Sparsity: 0.0806      Sharing: 4.79445e-05 
Policy training epoch:24 iteration:  8926 -  Total Loss: 10.1637     Task Loss: 10.0832  Policy Losses:  Sparsity: 0.0805      Sharing: 3.72082e-05 


Epoch 25 weight training:   0%|          | 0/87 [00:00<?, ?it/s]

Weight training epoch:25 iteration:  9000 -  Total Loss: 10.3947     Task Loss: 10.3947  
Weight training epoch:25 iteration:  9013 -  Total Loss: 10.4423     Task Loss: 10.4423  


validation:   0%|          | 0/29 [00:00<?, ?it/s]

Epoch 25 policy training:   0%|          | 0/87 [00:00<?, ?it/s]

Policy training epoch:25 iteration:  9100 -  Total Loss: 10.0917     Task Loss: 10.0115  Policy Losses:  Sparsity: 0.0801      Sharing: 8.66130e-05 
Policy training epoch:25 iteration:  9100 -  Total Loss: 10.0917     Task Loss: 10.0115  Policy Losses:  Sparsity: 0.0801      Sharing: 8.66130e-05 


Epoch 26 weight training:   0%|          | 0/87 [00:00<?, ?it/s]

Weight training epoch:26 iteration:  9187 -  Total Loss: 10.3982     Task Loss: 10.3982  


validation:   0%|          | 0/29 [00:00<?, ?it/s]

Epoch 26 policy training:   0%|          | 0/87 [00:00<?, ?it/s]

Policy training epoch:26 iteration:  9200 -  Total Loss: 10.5662     Task Loss: 10.4861  Policy Losses:  Sparsity: 0.0800      Sharing: 5.70118e-05 
Policy training epoch:26 iteration:  9274 -  Total Loss: 10.1781     Task Loss: 10.0983  Policy Losses:  Sparsity: 0.0797      Sharing: 7.51391e-05 


Epoch 27 weight training:   0%|          | 0/87 [00:00<?, ?it/s]

Weight training epoch:27 iteration:  9300 -  Total Loss: 10.3665     Task Loss: 10.3665  
Weight training epoch:27 iteration:  9361 -  Total Loss: 10.4524     Task Loss: 10.4524  


validation:   0%|          | 0/29 [00:00<?, ?it/s]

Epoch 27 policy training:   0%|          | 0/87 [00:00<?, ?it/s]

Policy training epoch:27 iteration:  9400 -  Total Loss: 10.5266     Task Loss: 10.4470  Policy Losses:  Sparsity: 0.0795      Sharing: 5.79357e-05 
Policy training epoch:27 iteration:  9448 -  Total Loss: 10.1543     Task Loss: 10.0749  Policy Losses:  Sparsity: 0.0793      Sharing: 6.80909e-05 


Epoch 28 weight training:   0%|          | 0/87 [00:00<?, ?it/s]

Weight training epoch:28 iteration:  9500 -  Total Loss: 10.4233     Task Loss: 10.4233  
Weight training epoch:28 iteration:  9535 -  Total Loss: 10.5062     Task Loss: 10.5062  


validation:   0%|          | 0/29 [00:00<?, ?it/s]

Epoch 28 policy training:   0%|          | 0/87 [00:00<?, ?it/s]

Policy training epoch:28 iteration:  9600 -  Total Loss: 10.3582     Task Loss: 10.2792  Policy Losses:  Sparsity: 0.0790      Sharing: 3.50550e-05 
Policy training epoch:28 iteration:  9622 -  Total Loss: 10.1589     Task Loss: 10.0799  Policy Losses:  Sparsity: 0.0789      Sharing: 5.70640e-05 


Epoch 29 weight training:   0%|          | 0/87 [00:00<?, ?it/s]

Weight training epoch:29 iteration:  9700 -  Total Loss: 10.2570     Task Loss: 10.2570  
Weight training epoch:29 iteration:  9709 -  Total Loss: 10.4888     Task Loss: 10.4888  


validation:   0%|          | 0/29 [00:00<?, ?it/s]

Epoch 29 policy training:   0%|          | 0/87 [00:00<?, ?it/s]

Policy training epoch:29 iteration:  9796 -  Total Loss: 10.1821     Task Loss: 10.1035  Policy Losses:  Sparsity: 0.0785      Sharing: 7.33137e-05 


Epoch 30 weight training:   0%|          | 0/87 [00:00<?, ?it/s]

Weight training epoch:30 iteration:  9800 -  Total Loss: 10.4237     Task Loss: 10.4237  
Weight training epoch:30 iteration:  9883 -  Total Loss: 10.5041     Task Loss: 10.5041  


validation:   0%|          | 0/29 [00:00<?, ?it/s]

Epoch 30 policy training:   0%|          | 0/87 [00:00<?, ?it/s]

Policy training epoch:30 iteration:  9900 -  Total Loss: 10.4802     Task Loss: 10.4017  Policy Losses:  Sparsity: 0.0785      Sharing: 4.76837e-05 
Policy training epoch:30 iteration:  9970 -  Total Loss: 10.1590     Task Loss: 10.0808  Policy Losses:  Sparsity: 0.0781      Sharing: 6.90147e-05 


In [72]:
# num_train_layers = 2
# print(num_train_layers)

# logits = torch.rand((16,2)) 
# print(logits)

# if num_policy_layers is None:
#     num_policy_layers = logits.shape[0]
# else:
#     assert (num_policy_layers == logits.shape[0])
# print(num_policy_layers)

# num_blocks = min(num_train_layers, logits.shape[0])
# print(num_blocks)

# gt = torch.ones((num_blocks)).long() 
# print('gt: ', gt)
# gt_z = torch.zeros((num_blocks)).long() 
# print('gt: ', gt_z)

# loss_weights = ((torch.arange(0, num_policy_layers, 1) + 1).float() / num_policy_layers)
# print('loss_weights: ', loss_weights)

# t1 = loss_weights[-num_blocks:] 
# print(t1)

# t2 = logits[-num_blocks:]
# print(t2)

# ce = environ.cross_entropy_sparsity(logits[-num_blocks:], gt)
# print(ce)

# ce = environ.cross_entropy_sparsity(logits[-num_blocks:], gt_z)
# print(ce)

### Warm-up:  validation - Dev

In [None]:
# validation
if should(current_iter, opt['train']['val_freq']):
    print(f"**  {timestring()}  START VALIDATION iteration: {current_iter} ")    

    environ.eval()     # set to evaluation mode (train = False)
    num_seg_class = opt['tasks_num_class'][opt['tasks'].index('seg')] if 'seg' in opt['tasks'] else -1
    val_metrics = eval_dev(environ, 
                          val_loader, 
                          opt['tasks'], 
                          policy=False, 
                          num_train_layers=None, 
                          eval_iter = 4)

In [None]:
val_metrics.keys()
val_metrics['loss']

In [None]:
for i in val_metrics:
    print(f'\n {i} \n -----------------')
    print(val_metrics[i])

In [None]:
    for t_id, task in enumerate(environ.tasks):
        task_key = f"task{t_id+1}"    
        environ.print_loss(current_iter, start_time, val_metrics[task_key]["classification_agg"], title='validation')
    

In [None]:
    environ.save_checkpoint('latest', current_iter)

    print(f"** {timestring()} - END VALIDATION iteration:  {current_iter} ")                
    environ.train()    # set to training mode (train = True)

In [None]:
# for i in val_metrics.keys():
#     print(i, type(val_metrics[i]))
#     for k in val_metrics[i].keys():
#         print(i,k, type(val_metrics[i][k]))
#         if isinstance(val_metrics[i][k], pd.core.series.Series):
#             print(f"val_metrics[{i}][{k}] is a series")
#         elif isinstance(val_metrics[i][k], pd.core.frame.DataFrame):
#             print(f"val_metrics[{i}][{k}] is a dataframe")        

# s = val_metrics['task1']['classification_agg']
# print(s)
# print(s.to_dict())

### Load previously saved model

In [None]:
print_separator('READ YAML')
opt, gpu_ids, exp_ids =read_yaml_from_input(input_args)
print(gpu_ids, exp_ids,  opt['train']['policy_iter'])

current_iter = environ.load_checkpoint('latest')

print('Evaluating the snapshot saved at %d iter' % current_iter)

opt['train']['weight_iter_alternate'] = opt['train'].get('weight_iter_alternate', len(train1_loader))
opt['train']['alpha_iter_alternate'] = opt['train'].get('alpha_iter_alternate'  , len(train2_loader))

print(opt['train']['weight_iter_alternate'], opt['train']['alpha_iter_alternate'])

### Weight Training Dev

#### Weight training - prep

In [None]:
# arch_parms = environ.networks['mtl-net'].named_parameters()
# print(arch_parms)
# for name, parm in arch_parms:
#     print(name, '    ',parm.requires_grad)
# print_underline('MTL3_Dev Policys', verbose = True)
# for i in   environ.networks['mtl-net'].policys:
#     print(i)

In [None]:
print_heading(f"** {timestring()} - Training current iteration {current_iter}  flag: {flag} ", verbose = True)    

current_iter_w = 0 
current_iter_a = 0 
batch_idx_a = 0 
batch_idx_w = 0 

if flag_warmup:
    print_heading(f"** Set optimizer and scheduler to policy_learning = True", verbose = True)
    environ.define_optimizer(policy_learning=True)
    environ.define_scheduler(policy_learning=True)
    flag_warmup = False

if current_iter == opt['train']['warm_up_iters']:
    print_heading(f"** Switch from Warm Up training to Alternate training Weights & Policy \n"
                  f"   Take checkpoint and block gradient flow through Policy net", verbose=True)
    environ.save_checkpoint('warmup', current_iter)
    environ.fix_alpha()
    
# batch_enumerator1 = enumerate(train1_loader,1)  
# batch_enumerator2 = enumerate(train2_loader,1)  

train_total_epochs = 10

#### Weight training - main 

In [None]:
print(f"opt['train']['print_freq']         {opt['train']['print_freq']}")
print(f"opt['train']['hard_sampling']      {opt['train']['hard_sampling']}")
print(f"opt['policy']                      {opt['policy']}")
print(f"opt['tasks']                       {opt['tasks']}")
print(f"weight_iter_alternate:             {opt['train']['weight_iter_alternate']}")
print(f"alpha_iter_alternate :             {opt['train']['alpha_iter_alternate']}")
print(f"current_iter                       {current_iter  }")
print(f"current_iter_w                     {current_iter_w}")
print(f"current_iter_a                     {current_iter_a}")
print(f"batch_idx_w                        {batch_idx_w}")
print(f"flag                               {flag          }")
print(f"train_total_epochs                 {train_total_epochs}") 

In [None]:
##---------------------------------------------------------------     
## Weight / Policy Training
##--------------------------------------------------------------- 
# stop_iter = current_iter_w +  opt['train']['weight_iter_alternate']
# print(f" Current Weight iteration {current_iter_w} - Run  from {current_iter_w+1} to {stop_iter+1}")


In [None]:
from tqdm.notebook import trange, tqdm

In [None]:
# with tnrange(start_iter_t , stop_iter_t  , initial = start_iter_t , total = stop_iter_t, position=0, leave= True, desc="master") as t :
# with tqdm_notebook(total=train_total_epochs) as t:

In [None]:
curr_epoch = 0
main_iter_ctr = 0 
verbose = False
t = tqdm(total=train_total_epochs, desc=f" Alternate Weight/Policy training")

while curr_epoch < train_total_epochs:
    curr_epoch+=1
    t.update(1)

    #-----------------------------------------
    # Train & Update the network weights
    #-----------------------------------------
    if flag == 'update_w':
        current_iter_w  = 0 
        stop_iter_w =   opt['train']['weight_iter_alternate']

        with trange(+1, stop_iter_w+1 , initial = current_iter_w, total = stop_iter_w, 
                     position=0, leave= False, desc=f"Epoch {curr_epoch} weight training") as t_weights :
            for current_iter_w in t_weights:    
                current_iter += 1

                start_time = time.time()
                environ.train()
                
#                 if batch_idx_w == len(train1_loader):
#                     print_dbg(f"  Reenumerate train1_loader -  index_w: {batch_idx_w}   len(train1_loader) = {len(train1_loader)} ", verbose)
#                     batch_enumerator1 = enumerate(train1_loader,1)    
                    
                batch = next(train1_loader)
                environ.set_inputs(batch, train1_loader.dataset.input_size)

                ##----------------------------------------------------------------------
                ## Set number of layers to train based on cirriculum_speed 
                ## and p_epoch (number of epochs of policy training)
                ## When curriculum_speed == 3, a num_train_layers is incremented 
                ## after completion of every 3 policy training epochs
                ##----------------------------------------------------------------------
                if opt['is_curriculum']:
                    num_train_layers = p_epoch // opt['curriculum_speed'] + 1
                else:
                    num_train_layers = None


#                 print_heading(f"{timestring()} CALL ENVIRON.OPTIMIZE()    current_iter: {current_iter}     flag: {flag}\n"
#                       f"{' ':10s} current_iter_w: {current_iter_w}  batch_idx_w:{batch_idx_w}   weight_iter_alternate: {opt['train']['weight_iter_alternate']}\n"
#                       f"{' ':10s} current_iter_a: {current_iter_a}  batch_idx_a:{batch_idx_a}   alpha_iter_alternate : {opt['train']['alpha_iter_alternate']} \n"                          
#                       f"{' ':10s} is_policy: {opt['policy']}     p_epoch: {p_epoch}       num_train_layers: {num_train_layers}", verbose = False) 

                environ.optimize(opt['lambdas'], 
                                 is_policy=opt['policy'], 
                                 flag=flag, 
                                 num_train_layers=num_train_layers,
                                 hard_sampling=opt['train']['hard_sampling'],
                                 verbose = False)

                t_weights.set_postfix({'iteration': current_iter, 'Loss': f"{environ.losses['total']['total'].item():.4f}" , 
                                       'row_ids':f"{batch['row_id'][0]}-{batch['row_id'][-1]}"})
                
                if should(current_iter, opt['train']['print_freq']):
                    environ.print_loss(current_iter, start_time, title = "Weight training iteration", verbose = True)
                    environ.resize_results()

#                 print_heading(f"{timestring()} - CONTINUE WEIGHT TRAINING   current_iter: {current_iter}\n"
#                   f" current_iter_w: {current_iter_w}  batch_idx_w:{batch_idx_w}   weight_iter_alternate: {opt['train']['weight_iter_alternate']}\n"
#                   f" current_iter_a: {current_iter_a}  batch_idx_a:{batch_idx_a}   alpha_iter_alternate : {opt['train']['alpha_iter_alternate']}",
#                   verbose = False)        

        #-------------------------------------------------------
        # validation process
        #------------------------------------------------------- 

#         if should(current_iter_w, opt['train']['weight_iter_alternate']): 

        if (current_iter_w >= stop_iter_w):
            environ.eval()
            print_dbg("++ Weight Training Validation  and then Switch to update_alpha", verbose = False)

            val_metrics = eval_dev(environ, 
                                  val_loader, 
                                  opt['tasks'], 
                                  policy=opt['policy'],
                                  num_train_layers=num_train_layers, 
                                  hard_sampling=opt['train']['hard_sampling'],
                                  eval_iter = -1)        

            if (verbose):
                for t_id, task in enumerate(environ.tasks):
                    task_key = f"task{t_id+1}"    
                    environ.print_metrics(current_iter, start_time, val_metrics[task_key]["classification_agg"], title='validation', verbose = verbose)        

            environ.save_checkpoint('latest', current_iter)

            #----------------------------------------------------------------------------------------------
            # if number of iterations completed after the warm up phase is greater than the number of 
            # (weight/policy alternations) x (cirriculum speed) x (number of layers to be policy trained)
            #
            # check metrics for improvement, and issue a checkpoint if necessary
            #----------------------------------------------------------------------------------------------



            if current_iter - opt['train']['warm_up_iters'] >= num_blocks * opt['curriculum_speed'] * \
                    (opt['train']['weight_iter_alternate'] + opt['train']['alpha_iter_alternate']):
                new_value = 0
#                 print_heading(f"  evaluate progress and make checkpoint if necessary." , verbose = True)
#                 print(f" current iter                                 : {current_iter} \n"
#                       f" opt['train']['warm_up_iters']                : {opt['train']['warm_up_iters']} \n"
#                       f" num_blocks                                   : {num_blocks} \n"
#                       f" opt['curriculum_speed']                      : {opt['curriculum_speed']}\n"
#                       f" opt['train']['weight_iter_alternate']        : {opt['train']['weight_iter_alternate']}\n"
#                       f" opt['train']['alpha_iter_alternate']         : {opt['train']['alpha_iter_alternate']}\n"
#                       f" alpha_iter_alternate + weight_iter_alternate : {opt['train']['weight_iter_alternate'] + opt['train']['alpha_iter_alternate']}\n"
#                       f" num_blks * curriculum_speed * (alpha_alternate + weight_alternate): "
#                       f" {num_blocks * opt['curriculum_speed'] * (opt['train']['weight_iter_alternate'] + opt['train']['alpha_iter_alternate'])} \n"

                print(f"  {current_iter - opt['train']['warm_up_iters']} IS GREATER THAN "
                       f" {num_blocks * opt['curriculum_speed'] * (opt['train']['weight_iter_alternate'] + opt['train']['alpha_iter_alternate'])} -- "
                       f"  evaluate progress and make checkpoint if necessary." )            
#               ## compare validation metrics against reference metrics.

#                 for k in refer_metrics.keys():
#                     if k in val_metrics.keys():
#                         for kk in val_metrics[k].keys():
#                             if not kk in refer_metrics[k].keys():
#                                 continue
#                             if (k == 'sn' and kk in ['Angle Mean', 'Angle Median']) or (
#                                     k == 'depth' and not kk.startswith('sigma')) or (kk == 'err'):
#                                 value = refer_metrics[k][kk] / val_metrics[k][kk]
#                             else:
#                                 value = val_metrics[k][kk] / refer_metrics[k][kk]
#                             value = value / len(list(set(val_metrics[k].keys()) & set(refer_metrics[k].keys())))
#                             new_value += value

#                 print('Best Value %.4f  New value: %.4f' % new_value)

                # if results have improved, save these results and issue a checkpoint

#                 if (new_value > best_value):
#                     print('Previous best iter: %d, best_value: %.4f' % (best_iter, best_value), best_metrics)
#                     best_value = new_value
#                     best_metrics = val_metrics
#                     best_iter = current_iter
#                     environ.save_checkpoint('best', current_iter)
#                     print('New      best iter: %d, best_value: %.4f' % (best_iter, best_value), best_metrics)                         
#                     print('Best Value %.4f  New value: %.4f' % new_value)

                # if results have improved, save these results and issue a checkpoint   

            environ.train()
            #-------------------------------------------------------
            # END validation process
            #-------------------------------------------------------       
            print_heading(f"{timestring()} - SWITCH TO ALPHA TRAINING    current_iter: {current_iter}\n"
              f" current_iter_w: {current_iter_w}  batch_idx_w:{batch_idx_w}   weight_iter_alternate: {opt['train']['weight_iter_alternate']}\n"
              f" current_iter_a: {current_iter_a}  batch_idx_a:{batch_idx_a}   alpha_iter_alternate : {opt['train']['alpha_iter_alternate']}",
              verbose = False)       
            flag = 'update_alpha'
            environ.fix_w()
            environ.free_alpha()
        #-------------------------------------------------------
        # end validation process
        #-------------------------------------------------------               


    #-----------------------------------------
    # Train & Update the  policy 
    #-----------------------------------------
    if flag == 'update_alpha':
        current_iter_a = 0
        stop_iter_a = opt['train']['alpha_iter_alternate']

        with trange( +1, stop_iter_a+1 , initial = 0, total = stop_iter_a, 
                     position=0, leave= False, desc=f"Epoch {curr_epoch} policy training") as t_policy :
            for current_iter_a in t_policy:    
                current_iter += 1

#                 batch_idx_a, batch = next(batch_enumerator2)
                batch = next(train2_loader)
                environ.set_inputs(batch, train2_loader.dataset.input_size)

#                 if batch_idx_a == len(train2_loader):
#                     print_dbg(f" Re-enumerate train2_loader  batch_idx_a: {batch_idx_a}   len(train2_loader) = {len(train2_loader)}", verbose=False)                
#                     batch_enumerator2 = enumerate(train2_loader,1)        

#                 print_heading(f"{timestring()} - ENVIRON.OPTIMIZE()    flag: {flag}    current_iter: {current_iter}   \n"
#                               f" current_iter_w: {current_iter_w}  batch_idx_w:{batch_idx_w}   weight_iter_alternate: {opt['train']['weight_iter_alternate']}\n"
#                               f" current_iter_a: {current_iter_a}  batch_idx_a:{batch_idx_a}   alpha_iter_alternate : {opt['train']['alpha_iter_alternate']} \n"
#                               f" is_policy: {opt['policy']}   num_train_layers: {num_train_layers}  hard_sampling: {opt['train']['hard_sampling']}\n"
#                               f" is_curriculum: {opt['is_curriculum']}     curriculum_speed: {opt['curriculum_speed']}   p_epoch: {p_epoch}"
#                               , verbose = False) 

                if opt['is_curriculum']:
                    num_train_layers = (p_epoch // opt['curriculum_speed']) + 1
                else:
                    num_train_layers = None

                print_dbg(f" num_train_layers  : {num_train_layers}", verbose = False)


                environ.optimize(opt['lambdas'], 
                                 is_policy=opt['policy'], 
                                 flag=flag, 
                                 num_train_layers=num_train_layers,
                                 hard_sampling=opt['train']['hard_sampling'],
                                 verbose = False)
                
                t_policy.set_postfix({'iteration': current_iter, 'Loss': f"{environ.losses['total']['total'].item():.4f}" , 
                                      'row_ids':f"{batch['row_id'][0]}-{batch['row_id'][-1]}"})
                
                if should(current_iter, opt['train']['print_freq']):
                    environ.print_loss(current_iter, start_time, title = "Policy training iteration", verbose=True)
                    environ.resize_results()
                    # environ.visual_policy(current_iter)

#                 print_heading(f"{timestring()} - CONTINUE ALPHA TRAINING    current_iter: {current_iter}\n"
#                               f"{' ':15s} current_iter_w: {current_iter_w}  batch_idx_w:{batch_idx_w}   weight_iter_alternate: {opt['train']['weight_iter_alternate']}\n"
#                               f"{' ':15s} current_iter_a: {current_iter_a}  batch_idx_a:{batch_idx_a}   alpha_iter_alternate : {opt['train']['alpha_iter_alternate']} ", 
#                               verbose = False )      

        ## if (current_iter_a % alpha_iter_alternate) == 0 
#         if should(current_iter_a, opt['train']['alpha_iter_alternate']):
#         print(f" policy loop ended - current_iter_a: {current_iter_a}   stop_iter_a: {stop_iter_a}")
        if( current_iter_a >= stop_iter_a):            
#             print_heading(f"{timestring()} - SWITCH TO WEIGHT TRAINING  urrent_iter: {current_iter}\n"
#                           f"{' ':15s} current_iter_w: {current_iter_w}  batch_idx_w:{batch_idx_w}   weight_iter_alternate: {opt['train']['weight_iter_alternate']}\n"
#                           f"{' ':15s} current_iter_a: {current_iter_a}  batch_idx_a:{batch_idx_a}   alpha_iter_alternate : {opt['train']['alpha_iter_alternate']} ",
#                           verbose = False )       

            flag = 'update_w'
            environ.fix_alpha()
            environ.free_w(opt['fix_BN'])
            environ.decay_temperature()

            # print the distribution
            print_dbg(np.concatenate(environ.get_policy_prob(), axis=-1), verbose = False)
            
            p_epoch += 1
            print_dbg(f"** p_epoch incremented: {p_epoch}")

In [None]:
print(f"{opt['train']['Lambda_sharing']:.5e}")

In [None]:
print('Previous best iter: %d, best_value: %.4f' % (best_iter, best_value))
print(best_metrics)
best_value = new_value
best_metrics = val_metrics
best_iter = current_iter
environ.save_checkpoint('best', current_iter)
print('New best iter : %d, best_value: %.4f \n' % (best_iter, best_value))                         
print(best_metrics)

In [None]:
# environ.losses['tasks'] = {'total' : torch.tensor(0.0, device  = environ.device, dtype=torch.float64)}
# environ.device

# print(val_metrics)
pp.pprint(environ.losses)
# environ.print_loss_2(current_iter, start_time, verbose=True)

### Policy Training 

In [None]:
# print(f" current iter                                 : {current_iter} \n"
#       f" opt['train']['warm_up_iters']                : {opt['train']['warm_up_iters']} \n"
#       f" num_blocks                                   : {num_blocks} \n"
#       f" opt['curriculum_speed']                      : {opt['curriculum_speed']}\n"
#       f" opt['train']['weight_iter_alternate']        : {opt['train']['weight_iter_alternate']}\n"
#       f" opt['train']['alpha_iter_alternate']         : {opt['train']['alpha_iter_alternate']}\n"
#       f" alpha_iter_alternate + weight_iter_alternate : {opt['train']['weight_iter_alternate'] + opt['train']['alpha_iter_alternate']}\n"
#       f" num_blocks * curriculum_speed * (alpha_iter_alternate + weight_iter_alternate): \
#           {num_blocks * opt['curriculum_speed'] * (opt['train']['weight_iter_alternate'] + opt['train']['alpha_iter_alternate'])} \n"
#       f" IF {current_iter - opt['train']['warm_up_iters']} IS GREATER THAN  ??"
#       f" {num_blocks * opt['curriculum_speed'] * (opt['train']['weight_iter_alternate'] + opt['train']['alpha_iter_alternate'])}")


In [None]:
# print(f" task1_logits: {environ.networks['mtl-net'].task1_logits} \n")
# print(f" task2_logits: {environ.networks['mtl-net'].task2_logits} \n")
# print(f" task3_logits: {environ.networks['mtl-net'].task3_logits} \n")

In [None]:
# print(current_iter_a , opt['train']['alpha_iter_alternate'],flag)

In [None]:
##---------------------------------------------------------------     
## part one: warm up
##--------------------------------------------------------------- 
# print(current_iter_a , opt['train']['alpha_iter_alternate'],flag)
# stop_iter = current_iter_a +  opt['train']['alpha_iter_alternate']
# print(f" Run iteration {current_iter_a+1} to {stop_iter+1}")

In [None]:
# print(current_iter_a, stop_iter, flag)
# print(current_iter_a , opt['train']['alpha_iter_alternate'],flag)

In [None]:


if flag == 'update_alpha':

    stop_iter = current_iter_a +  opt['train']['alpha_iter_alternate']
    print(f" Current Alpha iteration {current_iter_a} - Run  from {current_iter_a+1} to {stop_iter+1}")
    
    with tnrange(current_iter_a+1, stop_iter+1 , initial = current_iter_a+1, total = stop_iter+1, position=0, leave= True, desc="weight training") as t :
        for current_iter_a in t:    
            current_iter += 1
 
            batch_idx_a, batch = next(batch_enumerator2)
            environ.set_inputs(batch, train2_loader.dataset.input_size)

            if batch_idx_a == len(train2_loader):
                print_dbg(f" Re-enumerate train2_loader  batch_idx_a: {batch_idx_a}   len(train2_loader) = {len(train2_loader)}", verbose=False)                
                batch_enumerator2 = enumerate(train2_loader,1)        
                  
            print_heading(f"{timestring()} - ENVIRON.OPTIMIZE()    flag: {flag}    current_iter: {current_iter}   \n"
                          f" current_iter_w: {current_iter_w}  batch_idx_w:{batch_idx_w}   weight_iter_alternate: {opt['train']['weight_iter_alternate']}\n"
                          f" current_iter_a: {current_iter_a}  batch_idx_a:{batch_idx_a}   alpha_iter_alternate : {opt['train']['alpha_iter_alternate']} \n"
                          f" is_policy: {opt['policy']}   num_train_layers: {num_train_layers}  hard_sampling: {opt['train']['hard_sampling']}\n"
                          f" is_curriculum: {opt['is_curriculum']}     curriculum_speed: {opt['curriculum_speed']}   p_epoch: {p_epoch}"
                          , verbose = False) 
    
            if opt['is_curriculum']:
                num_train_layers = (p_epoch // opt['curriculum_speed']) + 1
            else:
                num_train_layers = None

            print_dbg(f" num_train_layers  : {num_train_layers}", verbose = False)


            environ.optimize(opt['lambdas'], 
                             is_policy=opt['policy'], 
                             flag=flag, 
                             num_train_layers=num_train_layers,
                             hard_sampling=opt['train']['hard_sampling'],
                             verbose = False)

            if should(current_iter, opt['train']['print_freq']):
                environ.print_loss_2(current_iter, start_time, verbose=True)
                environ.resize_results()
                # environ.visual_policy(current_iter)

            print_heading(f"{timestring()} - CONTINUE ALPHA TRAINING    current_iter: {current_iter}\n"
                          f"{' ':15s} current_iter_w: {current_iter_w}  batch_idx_w:{batch_idx_w}   weight_iter_alternate: {opt['train']['weight_iter_alternate']}\n"
                          f"{' ':15s} current_iter_a: {current_iter_a}  batch_idx_a:{batch_idx_a}   alpha_iter_alternate : {opt['train']['alpha_iter_alternate']} ", 
                          verbose = False )      
    
    ## if (current_iter_a % alpha_iter_alternate) == 0 
    if should(current_iter_a, opt['train']['alpha_iter_alternate']):
        print_dbg(f"** Switch training to update_weight")                
        print_heading(f"{timestring()} - SWITCH TO WEIGHT TRAINING  urrent_iter: {current_iter}\n"
                      f"{' ':15s} current_iter_w: {current_iter_w}  batch_idx_w:{batch_idx_w}   weight_iter_alternate: {opt['train']['weight_iter_alternate']}\n"
                      f"{' ':15s} current_iter_a: {current_iter_a}  batch_idx_a:{batch_idx_a}   alpha_iter_alternate : {opt['train']['alpha_iter_alternate']} ",
                      verbose = True )       
        
        flag = 'update_w'
        environ.fix_alpha()
        environ.free_w(opt['fix_BN'])
        environ.decay_temperature()

        # print the distribution
        dists = environ.get_policy_prob()

        print(np.concatenate(dists, axis=-1))
        p_epoch += 1
        print(f"** p_epoch incremented: {p_epoch}")




In [None]:
print(current_iter)

In [None]:
print(f" task1_logits: \n {environ.networks['mtl-net'].task1_logits.detach().cpu().numpy()} \n")
print(f" task2_logits: \n {environ.networks['mtl-net'].task2_logits.detach().cpu().numpy()} \n")
print(f" task3_logits: \n {environ.networks['mtl-net'].task3_logits.detach().cpu().numpy()} \n")
print(f" task1 softmax: \n {softmax(environ.networks['mtl-net'].task1_logits.detach().cpu().numpy(), axis = -1)} \n")
print(f" task2 softmax: \n {softmax(environ.networks['mtl-net'].task2_logits.detach().cpu().numpy(), axis = -1)} \n")
print(f" task3 softmax: \n {softmax(environ.networks['mtl-net'].task3_logits.detach().cpu().numpy(), axis = -1)} \n")

In [None]:
# for i in [1,2,3]:
#     task_pred = f"task{i}_pred"
#     task_logits = f"task{i}_logits"
#     policy_attr = f"policy{i}"
#     logits_attr = f"logit{i}"
#     print_heading(f"{task_pred}")
#     print(getattr(environ, task_pred))
#     print(policy_attr)
#     print(getattr(environ, policy_attr)) 
#     print(logits_attr)
#     print(getattr(environ, logits_attr)) 
#     print(task_logits)
#     print(getattr(environ.networks['mtl-net'], task_logits)) 

In [None]:
import tqdm.notebook

In [None]:
current_iter    = 2174
current_iter_a  = 348
current_iter_w  = 348

## Softmax & Gumbel Softmax

###  Softmax, LogSoftMax, NegLogLikelihood and Cross Entropy

In [None]:
from torch import nn
# print(nn.CrossEntropyLoss.__doc__)
loss = nn.CrossEntropyLoss(reduction ='none')
# i1 = torch.randn(3, 5, requires_grad=True)
# t1 = torch.empty(3, dtype=torch.long).random_(5)

# print(i1)
# print(i1)
# print(t1)
# output = loss(i1, t1)
# print(output, output.sum(), output.mean())

# i2 = torch.randn(1, 2, requires_grad=True)
i0 = torch.tensor([[0.0, 1.0]], dtype=torch.float)
i1 = torch.tensor([[1.0, 0.0]], dtype=torch.float)
i2 = torch.tensor([[0.5, 0.5]], dtype=torch.float)
i3 = torch.tensor([[0.4656, 0.5388]], dtype=torch.float)
sm = nn.Softmax(dim =-1)
lsm = nn.LogSoftmax(dim = -1)
nll = nn.NLLLoss(reduction='none')

t1 = torch.tensor([1], dtype=torch.int64)
t0 = torch.tensor([0], dtype=torch.int64)
t2 = torch.tensor([2], dtype=torch.int64)
print('i0     : ', i0)
print('sm(i0) : ', sm(i0))
print('lsm(i0): ', lsm(i0))
print()
print('i1     : ', i1)
print('sm(i1) : ', sm(i1))
print('lsm(i1): ', lsm(i1))
print()
print('i2     : ', i2)
print('sm(i2) : ', sm(i2))
print('lsm(i2): ', lsm(i2))
print()

print('t0: ',t0)
print('t1: ',t1)
print()
output1 = loss(i0, t0)
output2 = nll(lsm(i0), t0)
print('loss [0,1] and [0] : ', output1)
print('nll between lsm(i0): ', output2)
print()
output1 = loss(i0, t1)
output2 = nll(lsm(i0), t1)
print('loss [0,1] and [1] : ', output1)
print('nll between lsm(i0): ', output2)
print()

output1 = loss(i1, t0)
output2 = nll(lsm(i1), t0)
print('loss [1,0] and [0] : ', output1)
print('nll between lsm(i1): ', output2)
print()

output1 = loss(i1, t1)
output2 = nll(lsm(i1), t1)
print('loss [1,0] and [1] : ', output1)
print('nll between lsm(i1): ', output2)
print()

output1 = loss(i2, t0)
output2 = nll(lsm(i2), t0)
print('loss [0.5, 0.5] and [0] : ', output1)
print('nll between lsm(i1)   and [0] : ', output2)
print()

output1 = loss(i2, t1)
output2 = nll(lsm(i2), t1)
print('loss [0.5, 0.5] and [1] : ', output1)
print('nll between lsm(i1)   and [1] : ', output2)
print()


### Gumbel Softmax

In [None]:
a = torch.tensor([[0.5000, 0.5000],
        [0.5000, 0.5000],
        [0.5000, 0.5000],
        [0.5000, 0.5000]], device='cuda:0', requires_grad=True) 

b = torch.tensor([0.5000, 0.5000,  0.5000, 0.5000], device='cuda:0', requires_grad=True) 
print(b.shape)
c = torch.tensor([[0.000, 0.000,  0.000, 0.000]], device='cuda:0', requires_grad=True) 
print(c.shape)
d = torch.tensor([[0.5000], [0.5000],  [0.5000], [0.5000]], device='cuda:0', requires_grad=True) 
print(d.shape)

In [None]:
print(i0)
print(i1)
print(i2)

In [None]:
temp  = 2.5
print(F.gumbel_softmax( i1, temp, hard=False))
print(F.gumbel_softmax( i1, temp, hard=False))
print(F.gumbel_softmax( i1, temp, hard=False))
print()

print(F.gumbel_softmax( i2, temp, hard=False))
print(F.gumbel_softmax( i2, temp, hard=False))
print(F.gumbel_softmax( i2, temp, hard=False))
print()

print(F.gumbel_softmax( i3, temp, hard=False))
print(F.gumbel_softmax( i3, temp, hard=False))
print(F.gumbel_softmax( i3, temp, hard=False))
# print(F.gumbel_softmax( d, 5, hard=False))

In [None]:
logits = torch.randn(20, 32)

In [None]:
print(logits[:2,:])

In [None]:
tmp = -torch.empty_like(logits, memory_format=torch.legacy_contiguous_format)
print(tmp[0])

In [None]:
print(tmp[0])
tmp1 = tmp.exponential_()
print(tmp1[0])

In [None]:
tmp2 = tmp1.log()
print(tmp2[0])

In [None]:
 gumbels = (
        -torch.empty_like(logits, memory_format=torch.legacy_contiguous_format).exponential_().log()
    ) 

In [None]:
print(gumbels.shape)
print(gumbels[0])

In [None]:
# Sample soft categorical using reparametrization trick:
gumbel_soft = F.gumbel_softmax(logits, tau=1, hard=False)

# Sample hard categorical using "Straight-through" trick:
gumbel_hard  = F.gumbel_softmax(logits, tau=1, hard=True)

In [None]:
print(logits.shape)
print(logits[0])
print(np.argmax(logits[0]))
print('\n')

print(gumbel_soft.shape)
print(gumbel_soft[0])
print(np.argmax(gumbel_soft[0]))
print('\n')

print(gumbel_hard.shape)
print(gumbel_hard[0])
print(np.argmax(gumbel_hard[0]))

In [None]:
gumbel_soft.sum(axis=1)

In [None]:
tau = 1

In [None]:
a = torch.empty_like(logits, memory_format=torch.legacy_contiguous_format)

In [None]:
a[:2]

fill tensor `a` with elements drawn from exponential distribution

In [None]:
a_e = a.exponential_()

In [None]:
a_e[:2]

draw natural log `ln()` on elements of a_e

In [None]:
a_e_l = a_e.log()

In [None]:
a_e_l[:2]

Neg log

In [None]:
a_el_neg = -a_e_l

a_el_neg[:2]

In [None]:
logits[:2]

In [None]:
gumbels = (logits + a_el_neg) / tau 

gumbels[:2]

In [None]:
dim = -1
gumbels.shape

In [None]:
 y_soft = gumbels.softmax(dim)

In [None]:
y_soft.shape

In [None]:
y_soft[:2]

In [None]:
index = y_soft.max(dim, keepdim=True)
print(index[0].T)
print(index[1].T)
y_hard = torch.zeros_like(logits, memory_format=torch.legacy_contiguous_format).scatter_(dim, index[1], 1.0)

In [None]:
np.argmax(y_hard,axis=1)

In [None]:
tmp_d= [0,1,0]
for i in range(10):
    sampled = np.random.choice((2, 1, 0), p=tmp_d)
    print(sampled)

## Scratch Pad

### tqdm

In [None]:
from tqdm.notebook import tnrange

In [None]:
curr_iter_t  = 0
curr_iter_a  = 0
curr_iter_w  = 0
stop_iter_t  = 0
stop_iter_w  = 0 
stop_iter_a  = 0
total_weight_epochs = 0
total_policy_epochs = 0 
train_total_iters = 8
weight_iter_alternate = 17
alpha_iter_alternate = 17

In [None]:
print(curr_iter_t, stop_iter_t, flag, train_total_iters,opt['train']['print_freq'] )
start_iter_t = curr_iter_t
stop_iter_t = curr_iter_t +  train_total_iters 
print(f" Current iteration {curr_iter_t} - Run  from {start_iter_t} to {stop_iter_t}")

print(curr_iter_w, weight_iter_alternate , flag)
stop_iter_w = curr_iter_w +  weight_iter_alternate 
print(f" Current Weight iteration {curr_iter_w} - Run  from {curr_iter_w+1} to {stop_iter_w}")


print(curr_iter_a ,  alpha_iter_alternate ,flag)
stop_iter_a = curr_iter_a +  alpha_iter_alternate 
print(f" Current alpha iteration {curr_iter_a} - Run  from {curr_iter_a+1} to {stop_iter_a}")

In [None]:
# del t, t_w, t_a
main_iter_ctr = 0 
with tnrange(start_iter_t , stop_iter_t  , initial = start_iter_t , total = stop_iter_t, position=0, leave= True, desc="master") as t :
    for curr_t in t:
        
        with  tnrange(0, weight_iter_alternate , initial = 0, total = weight_iter_alternate, 
                      position=1, leave= False, desc=f"epoch {curr_t} weight training") as t_w :
            for curr_w in t_w:    
                sleep(0.35)
                main_iter_ctr += 1
                curr_iter_w  = curr_w
                t.set_postfix({'epoch': f"{curr_t}/{train_total_iters}", 'main_iter_ctr': main_iter_ctr})
                t_w.set_postfix({'weight training epoch': curr_t, 'batch #': curr_iter_w})

            print(f"** Epoch {curr_t}/{train_total_iters} weight training complete - Loss: "
                  f"curr_w:{curr_w}    curr_iter_w:{curr_iter_w}  curr_t:{curr_t}  main_iter_ctr:{main_iter_ctr}" )
                 
        
        with  tnrange(0, alpha_iter_alternate  , initial = 0, total = alpha_iter_alternate , 
                      position=2, leave= False, desc=f"epoch {curr_t} policy training") as t_a :
            for curr_a in t_a:    
                sleep(0.35)
                main_iter_ctr += 1                
                curr_iter_a = curr_a
                t.set_postfix({'epoch': f"{curr_t}/{train_total_iters}", 'main_iter_ctr':main_iter_ctr})
                t_a.set_postfix({'policy training epoch': curr_t, 'batch #': curr_iter_a})            
                
            print(f"** Epoch {curr_t}/{train_total_iters} policy training complete - Loss: "
                  f"curr_w:{curr_w}    curr_iter_w:{curr_iter_w}  curr_t:{curr_t}  main_iter_ctr:{main_iter_ctr}" )
        

In [None]:
curr_iter_t  = 0
curr_iter_a  = 0
curr_iter_w  = 0
stop_iter_t  = 0
stop_iter_w  = 0 
stop_iter_a  = 0
total_weight_epochs = 0
total_policy_epochs = 0 
train_total_iters = 100
train_total_epochs = 10
weight_iter_alternate = 17
alpha_iter_alternate = 17

print(curr_iter_t, stop_iter_t, flag, train_total_iters,opt['train']['print_freq'] )
start_iter_t = curr_iter_t
stop_iter_t = curr_iter_t +  train_total_iters 
print(f" Current iteration {curr_iter_t} - Run  from {start_iter_t} to {stop_iter_t}")

print(curr_iter_w, weight_iter_alternate , flag)
stop_iter_w = curr_iter_w +  weight_iter_alternate 
print(f" Current Weight iteration {curr_iter_w} - Run  from {curr_iter_w+1} to {stop_iter_w}")


print(curr_iter_a ,  alpha_iter_alternate ,flag)
stop_iter_a = curr_iter_a +  alpha_iter_alternate 
print(f" Current alpha iteration {curr_iter_a} - Run  from {curr_iter_a+1} to {stop_iter_a}")

In [None]:
# del t, t_w, t_a
curr_epoch = 0
main_iter_ctr = 0 
# with tnrange(start_iter_t , stop_iter_t  , initial = start_iter_t , total = stop_iter_t, position=0, leave= True, desc="master") as t :
# with tqdm_notebook(total=train_total_epochs) as t:
t = tqdm_notebook(total=train_total_epochs)

while curr_epoch < train_total_epochs:
    curr_epoch+=1
    t.update(1)

    #-----------------------------------------
    # Train & Update the network weights
    #-----------------------------------------        
    with  tnrange(0, weight_iter_alternate , initial = 0, total = weight_iter_alternate, 
                  position=1, leave= False, desc=f"epoch {curr_epoch} weight training") as t_w :
        for curr_w in t_w:    
            sleep(0.35)
            main_iter_ctr += 1
            curr_iter_w  = curr_w

            t.set_postfix({'epoch': f"{curr_epoch}/{train_total_epochs}", 'main_iter_ctr': main_iter_ctr})
            t_w.set_postfix({'weight training epoch': curr_epoch, 'batch #': curr_iter_w})

        tqdm.write(f"** Epoch {curr_epoch}/{train_total_epochs} weight training complete - Loss: "
              f"curr_w:{curr_w}    curr_iter_w:{curr_iter_w}  curr_epoch:{curr_epoch}  main_iter_ctr:{main_iter_ctr}" )

    #-----------------------------------------
    # Train & Update the  policy 
    #-----------------------------------------        
    with  tnrange(0, alpha_iter_alternate  , initial = 0, total = alpha_iter_alternate , 
                  position=2, leave= False, desc=f"epoch {curr_epoch} policy training") as t_a :
        for curr_a in t_a:    
            sleep(0.35)
            main_iter_ctr += 1                
            curr_iter_a = curr_a

            t.set_postfix({'epoch': f"{curr_epoch}/{train_total_epochs}", 'main_iter_ctr':main_iter_ctr})
            t_a.set_postfix({'policy training epoch': curr_epoch, 'batch #': curr_iter_a})            

        tqdm.write(f"** Epoch {curr_epoch}/{train_total_epochs} policy training complete - Loss: "
              f"curr_w:{curr_w}    curr_iter_w:{curr_iter_w}  curr_epoch:{curr_epoch}  main_iter_ctr:{main_iter_ctr}" )


In [None]:
# with tnrange(start_iter, stop_iter , initial = current_iter_w, total = stop_iter,  position=0, leave= True, desc="training") as t:
#     for current_iter_w in t:
#         print(current_iter_w)
#         current_iter_w += 1
#         print(current_iter_w)        

In [None]:
# start = current_iter
# end = current_iter + opt['train']['warm_up_iters']
# curr_range = range(start,end)
# print(start, end)

# for i in tqdm.notebook.tnrange(start, end, initial = start, total = end):
#     sleep(0.25)
#     current_iter += 1
# #     print(i)
#     pass

# print(current_iter)

In [None]:
# start = current_iter
# end = current_iter + opt['train']['warm_up_iters']
# curr_range = range(start,end)
# print(start, end)

# for i in tqdm.notebook.tqdm_notebook(cur_range, initial = start, total = end, disable=False, position=0, desc = "validation"):
#     current_iter += 1
#     pass

# print(current_iter)

In [None]:
# from tqdm import trange
# from time import sleep

# for i in trange(40, desc='1st loop', position=0, leave = False):
#     sleep(1.1)
#     for j in trange(5, desc='2nd loop', position =1, leave = False):
#         sleep(0.01)
#         for k in trange(50, desc='3rd loop', position =0,leave=False):
#             sleep(0.01)

In [None]:
# with tqdm(batch_enumerator, leave=False, disable=False) as t:
# with tqdm(total=10, bar_format="{postfix[0]} {postfix[1][value]:>8.2g}", postfix=["Batch", dict(value=0)]) as t:
# with trange(opt['train']['warm_up_iters'], bar_format="{postfix[0]} {postfix[1][value]:>8.2g}", postfix=["Batch", dict(value=0)]) as t:
# with trange(opt['train']['warm_up_iters']) as t:

#     for current_iter in t:
#         batch_idx, batch = next(batch_enumerator)
#         ran = random.randint(1, 100)
#         start_time = time.time()

#         environ.train()

#         print_heading(f" {timestring()} - WARMUP Training iter {current_iter}/{opt['train']['warm_up_iters']}    batch_idx: {batch_idx}"    
#                       f"    Warm-up iters: {opt['train']['warm_up_iters']}"
#                       f"    Validation freq:  {opt['train']['val_freq']}", verbose = False)

#         if batch_idx == len(train_loader) :
#     #         print_heading(f" ******* {timestring()}  re-enumerate train_loader() *******")
#             batch_enumerator = enumerate(train_loader,1)   

#         t.set_postfix({'batch_idx': batch_idx, 'num_vowels': ran})


In [None]:
# import tqdm.notebook

In [None]:
a = np.array([[1,2,3,4,5,6,7,8,9,10],[11,12,13,14,15,16,17,18,19,20]])
print(a)

In [None]:
print(a[:,::-1])
print()
print(a[::-1])

In [None]:
a = 4
b = 4
c = 1

0 // b + c

### Scipy Sparse

In [None]:
import scipy.sparse
row = np.array([5])
col = np.array([0])
data = np.array([6])
y_class = scipy.sparse.csr_matrix((data, (row, col)), shape=(15,1))
y_2 = scipy.sparse.csr_matrix((15,0))
print(f"Created y_class # dims: {y_class.ndim}    shape: {y_class.shape}")
print(f"Created y_2 # dims: {y_2.ndim}    shape: {y_2.shape}")

# y_class[5]= 2

In [None]:
print(y_class.toarray().T)
print(y_2.toarray())

In [None]:
y_class[8,0]= 999
y_class.toarray().squeeze().shape

### folding step by step

In [None]:
x_dev = copy.copy(x_file)
print(x_dev.shape, type(x_dev))

idx = x_dev.nonzero()
print(idx[0][:82])
print(idx[1][:82])
print(x_dev[0].sum(), x_dev[1].sum())
print(x_dev[0,0:-1])

In [None]:
folding_size = 30
print(f" fold - folding_size:{folding_size}" )

## collapse x into folding_size columns
idx = x_dev.nonzero()
folded = idx[1] % folding_size

print(folded[:82])

In [None]:
x_fold = scipy.sparse.csr_matrix((x_dev.data, (idx[0], folded)), shape=(x_dev.shape[0], folding_size))
print(x_fold.shape)

In [None]:
x_fold.sum_duplicates()

In [None]:
print(x_fold)

In [None]:
print(type(y_files[0]), type(y_class[0]))

### Torch tensor manipulations

In [None]:
input = torch.randn(3, 5, 2)

In [None]:
input

In [None]:
input.contiguous().view(-1).shape

### using eval()

In [None]:
print(opt['dataload']['y_tasks'])

In [None]:
task1 = [1,2,3]
task2 = None
task3 = {'4': 'Kevin', '5':'Bardool'}

for i in [1,2,3]:
    print('task{:d}'.format(i))
    if eval('task{:d}'.format(i)) is None:
        print('task{:d} :  has not been defined '.format(i))
#         exec_str = 'task{:d} =  np.random.rand({:d},{:d})'.format(i,3,2)
        exec_str = 'y_task{:d} = scipy.sparse.csr_matrix(({:d}, {:d})) '.format(i,3,2)
        print(exec_str)
        exec(exec_str)
        print('task{:d} : '.format(i), eval('task{:d}'.format(i)))
        print(eval('type(task{:d})'.format(i)))
    else:
        print('task{:d} : '.format(i), eval('task{:d}'.format(i)))
        
print(f"Created task{i} shape        : {eval('len(task{:d})'.format(i))}")
print(len(task3))

In [None]:
a = np.random.rand(3,2)
print(a)

In [None]:
print(eval('len(task{:d})'.format(i)))

## Load datasets, perform folding

In [None]:
## Verify presence of Y label data
if (opt['dataload']['y_tasks'] is None) and (opt['dataload']['y_regr'] is None):
   raise ValueError("No label data specified, please add --y_class and/or --y_regr.")

print(os.path.join(opt['dataload']['dataroot'], opt['dataload']['x']))  
print(os.path.join(opt['dataload']['dataroot'], opt['dataload']['folding']))
for fl in opt['dataload']['y_tasks']:
    print(os.path.join(opt['dataload']['dataroot'], fl))

#### Load X data file

In [None]:
##
## Load data files 
##
# ecfp     = sc.load_sparse(args.x)
# y_class  = sc.load_sparse(args.y_class)
# y_regr   = sc.load_sparse(args.y_regr)
# y_censor = sc.load_sparse(args.y_censor)

dataroot = opt['dataload']['dataroot']

ecfp     = load_sparse(dataroot, opt['dataload']['x'])
# x_file   = copy.copy(ecfp)

print(f" Input    {opt['dataload']['x']} - type : {type(ecfp)} shape : {ecfp.shape}")
# print(f" Input    {opt['dataload']['x']} - type : {type(x_file)} shape : {x_file.shape}")


#### Load Y label files

In [None]:
for i,( y_task, y_type )in enumerate(zip(opt['dataload']['y_tasks'],opt['sc_tasks']),1):
    print(full_path := os.path.join(dataroot, y_task ))
    tmp = load_sparse(full_path)
#     np.load(full_path, allow_pickle=True) 
#     print(type(tmp), tmp.shape)
#     tmp_sparse = scipy.sparse.csr_matrix(tmp)
    print(type(tmp_sparse), tmp_sparse.shape)
    print('indicies: ', len(tmp_sparse.__dict__['indices']), tmp_sparse.__dict__['indices'])
    print('indptr  : ', len(tmp_sparse.__dict__['indptr']) , tmp_sparse.__dict__['indptr'])
    print('data    : ', len(tmp_sparse.__dict__['data'])   , tmp_sparse.__dict__['data'])    

In [None]:
# y_class  = load_sparse(dataroot, opt['dataload']['y_tasks'][0])
# print(f" Input     - type : {type(y_class)} shape : {y_class.shape}")

# y_regr  = load_sparse(dataroot, opt['dataload']['y_tasks'][1])
# print(f" Input     - type : {type(y_regr)} shape : {y_regr.shape}")

##
## Load Y label files 
##
y_files=[]

for i,( y_task, y_type )in enumerate(zip(opt['dataload']['y_tasks'],opt['sc_tasks']),1):
    y_tmp = load_sparse(dataroot,  y_task)
    print(f" y_task:{i}  task type: {y_type:5s}  dataset: {y_task} - type : {type(y_tmp)} shape : {y_tmp.shape}")
    ## Get number of positive / neg and total for each classes
    num_pos    = np.array((y_tmp == +1).sum(0)).flatten()
    num_neg    = np.array((y_tmp == -1).sum(0)).flatten()
    num_class  = np.array((y_tmp !=  0).sum(0)).flatten()
    if (num_class != num_pos + num_neg).any():
        raise ValueError("For classification all y values (--y_class/--y) must be 1 or -1.")
    else:
        y_files.append(y_tmp)

y_class = copy.copy(y_files[0])
# y_regr = copy.copy(y_files[1])

#### Load folding file 

In [None]:
##
## load folding file
##
folding_file = os.path.join(dataroot,opt['dataload']['folding'])
folding  = np.load(folding_file)
print(f" Folding  {folding_file} - type : {type(folding)} shape : {folding.shape}")
print(f"          {folding[:20]}")

assert ecfp.shape[0] == folding.shape[0], "x and folding must have same number of rows" 

#### Load Y censor file

In [None]:
# y_censor = load_sparse(dataroot, opt['dataload']['y_censor'])
# if y_censor is not None:
#     print(f" Input     - type : {type(y_censor)} shape : {y_censor.shape}") 

##
## Load Y censor file
##

# y_censor = load_sparse(dataroot, opt['dataload']['y_censor'])
# if y_censor is None:
#     y_censor = scipy.sparse.csr_matrix(y_regr.shape)
#     vprint(f" y_sensor is {opt['dataload']['y_censor']}   Created y_censor shape       : {y_censor.shape}")
    
# y_censor_shape = y_censor.shape if y_censor is not None else "n/a"
# print(f" y_censor  - type : {type(y_censor)}  shape: {y_censor_shape}")

In [None]:
# if (y_regr is None) and (y_censor is not None):
#     raise ValueError("y_censor provided please also provide --y_regr.")

# # if y_class is None:
# #     y_class = scipy.sparse.csr_matrix((ecfp.shape[0], 0))
# #     vprint(f"Created y_class shape        : {y_class.shape}")

# if y_regr is None:
#     y_regr  = scipy.sparse.csr_matrix((ecfp.shape[0], 0))
#     vprint(f"Created y_regr shape         : {y_regr.shape}")

#### Input folding & transformation

In [None]:
print(f"args.fold_inputs : {opt['dataload']['fold_inputs']} \t\t  transform: {opt['dataload']['input_transform']}\n")
print(repr(ecfp))
ecfp = fold_and_transform_inputs(ecfp, folding_size=opt['dataload']['fold_inputs'], transform=opt['dataload']['input_transform'])
print(repr(ecfp))

print(type(ecfp), ecfp.shape)


####  Loading weights files for tasks

In [None]:
# num_regr   = np.bincount(y_regr.indices, minlength=y_regr.shape[1])
print(' Classification weights: ',opt['dataload']['weights_class'])
tasks_class = load_task_weights(opt['dataload']['weights_class'], y=y_class[0], label="y_class")
# tasks_regr  = load_task_weights(opt['dataload']['weights_regr'] , y=y_regr , label="y_regr")

print(tasks_class)
print(tasks_class.training_weight.shape)

In [None]:
print(y_files[0].shape, y_files[1].shape, y_class.shape)

In [None]:
if tasks_class.aggregation_weight is None:
    '''
    fold classes 
    '''
    ## using min_samples rule
    fold_pos, fold_neg = class_fold_counts(y_class, folding)
    n = opt['dataload']['min_samples_class']
    tasks_class.aggregation_weight = ((fold_pos >= n).all(0) & (fold_neg >= n)).all(0).astype(np.float64)
    print(f" tasks_class.aggregation_weight WAS NOT passed ")
    print(f" min_samples_class: opt['dataload']['min_samples_class']")
    print(f" Class fold counts: \n  fold_pos:\n{fold_pos}  \n\n  fold_neg:\n{fold_neg}") 
else:
    print(f"  tasks_class.aggregation_weight passed ")
    
print(f" tasks_class.aggregation_weight.shape: {tasks_class.aggregation_weight.shape} \n {tasks_class.aggregation_weight}")

In [None]:
# if tasks_regr.aggregation_weight is None:
#     if y_censor.nnz == 0:
#         y_regr2 = y_regr.copy()
#         y_regr2.data[:] = 1
#     else:
#         ## only counting uncensored data
#         y_regr2      = y_censor.copy()
#         y_regr2.data = (y_regr2.data == 0).astype(np.int32)
  
#     fold_regr, _ = sc.class_fold_counts(y_regr2, folding)
#     del y_regr2
#     tasks_regr.aggregation_weight = (fold_regr >= args.min_samples_regr).all(0).astype(np.float64)

In [None]:
##
## Display dataset dimensions 
##
print(f"Input dimension      : {ecfp.shape[1]}")
print(f"#samples             : {ecfp.shape[0]}")
print(f"#classification tasks: {y_class[0].shape[1]}")
print(f"Using {(tasks_class.aggregation_weight > 0).sum()} classification tasks for calculating aggregated metrics (AUCROC, F1_max, etc).")

# vprint(f"#regression tasks    : {y_regr.shape[1]}")
# vprint(f"Using {(tasks_regr.aggregation_weight > 0).sum()} regression tasks for calculating metrics (RMSE, Rsquared, correlation).")
# print(ecfp[18387,:10].toarray())

####  Compute batch size

In [None]:
print(f" batch_ratio        : {opt['batch_ratio']}")
print(f" internal_batch_max : {opt['internal_batch_max']}")

# batch_size  = int(np.ceil(opt['batch_ratio'] * idx_tr.shape[0]))
# num_int_batches = 1
# print(f" batch_ratio * # idx_tr:   {opt['batch_ratio']} * {idx_tr.shape[0]} = {opt['batch_ratio'] * idx_tr.shape[0]}")


# if opt['internal_batch_max'] is not None:
#     if opt['internal_batch_max'] < batch_size:
#         num_int_batches = int(np.ceil(batch_size / opt['internal_batch_max']))
#         print(f"\n\n internal_batch_max: {opt['internal_batch_max']}   batch_size: {batch_size}")
#         print(f" batch_size / internal_batch_max: {batch_size / opt['internal_batch_max']}   num_int_batches: {num_int_batches}")
#         batch_size      = int(np.ceil(batch_size / num_int_batches))
#         print(f" batch_size / num_int_batches: {batch_size / num_int_batches}   modified batch_size: {batch_size}")
        

batch_size = 320 
print(f" batch size:   {batch_size}")

#### Separate test dataset

In [None]:
print(f"opt['dataload']['fold_te'] : {opt['dataload']['fold_te'] }")
print(f"opt['dataload']['fold_va'] : {opt['dataload']['fold_va'] }")

In [None]:
if opt['dataload']['fold_te'] is not None and opt['dataload']['fold_te'] >= 0:
    ## removing test data
    print(f" Remove test data")
    assert opt['dataload']['fold_te'] != opt['dataload']['fold_va'], "fold_va and fold_te must not be equal."
    keep    = (folding != args.fold_te)
    ecfp    = ecfp[keep]
    y_class = y_class[keep]
    y_regr  = y_regr[keep]
    y_censor= y_censor[keep]
    folding = folding[keep]

In [None]:
prop = (np.cumsum([0.3, 0.3, 0.3, 0.1])* ecfp.shape[0]+1).astype(np.int32)
print(prop, prop.astype(np.int32))

#### Separate train, train1, train2, and validation  dataset

In [None]:
fold_va = opt['dataload']['fold_va']
fold_va = 0
fold_train1 = 1
fold_train2 = 2

idx_val    = np.where(folding == fold_va)[0]
idx_train  = np.where(folding == fold_train1)[0]
idx_train1 = np.where(folding == fold_train2)[0]
idx_train2 = np.where(folding >  fold_train2)[0]



In [None]:
dataroot = opt['dataload']['dataroot']
ecfp     = load_sparse(dataroot, opt['dataload']['x'])

total_input = ecfp.shape[0]
ranges      = (np.cumsum([0.3, 0.3, 0.1, 0.3])* total_input).astype(np.int32)
print(total_input, '     ', ranges)

idx_train  = np.arange(ranges[0])
idx_train1 = np.arange(ranges[0], ranges[1])
idx_train2 = np.arange(ranges[1], ranges[2])
idx_val    = np.arange(ranges[2], ranges[-1])

print( f' idx_train   len: {len(idx_train) :6d}  - {(idx_train)} ')
print( f' idx_train1  len: {len(idx_train1):6d}  - {(idx_train1)}')
print( f' idx_train2  len: {len(idx_train2):6d}  - {(idx_train2)}')
print( f' idx_val     len: {len(idx_val)   :6d}  - {(idx_val)}   ')

In [None]:
# y_class_tr = y_class[idx_train]
# y_class_va = y_class[idx_va]

# y_regr_tr  = y_regr[idx_tr]
# y_regr_va  = y_regr[idx_va]

# y_censor_tr = y_censor[idx_tr]
# y_censor_va = y_censor[idx_va]

# num_pos_va  = np.array((y_class_va == +1).sum(0)).flatten()
# num_neg_va  = np.array((y_class_va == -1).sum(0)).flatten()
# num_regr_va = np.bincount(y_regr_va.indices, minlength=y_regr.shape[1])

## Chembl Dataloader V1

In [None]:
##
## Instantiate datasets
##
trainset  = ClassRegrSparseDataset(x=ecfp, y_class=y_class, y_regr=None, y_censor=None, indicies = idx_train)
trainset1 = ClassRegrSparseDataset(x=ecfp, y_class=y_class, y_regr=None, y_censor=None, indicies = idx_train1)
trainset2 = ClassRegrSparseDataset(x=ecfp, y_class=y_class, y_regr=None, y_censor=None, indicies = idx_train2)
valset    = ClassRegrSparseDataset(x=ecfp, y_class=y_class, y_regr=None, y_censor=None, indicies = idx_val)

In [None]:
input_size  = trainset.input_size
output_size = trainset.output_size

class_output_size = trainset.class_output_size
regr_output_size  = trainset.regr_output_size

print(f' trainset - input size       : {input_size:6d}       output_size     : {output_size:6d}')
print(f' trainset - class_output_size: {class_output_size:6d}       regr_output_size: {regr_output_size:6d}')

for i in [trainset, trainset1, trainset2, valset]:
    print(f' trainset - input size       : {i.input_size:6d}       output_size     : {i.output_size:6d}')
    print(f' trainset - class_output_size: {i.class_output_size:6d}       regr_output_size: {i.regr_output_size:6d}')
# dataset_tr.y_class.shape

In [None]:
# trainset.__dict__
batch_size = 320