# nb-model_hparam-neural-process

In [1]:
import sys
import os
from os import sep
from os.path import dirname, realpath, exists
from pathlib import Path
from functools import partial
import random
import logging

def get_cwd(fname, subdir, crunch_dir=realpath(Path.home()) +sep +'crunch' +sep):
    """
    Build the path string of a file inside the crunch project by plain concatenation.

    No separator is inserted between the pieces, so `crunch_dir` and `subdir` have to
    carry their own trailing separator.

    Args:
        fname: file name, appended last.
        subdir: sub-directory string placed between the project directory and `fname`.
        crunch_dir: project directory string; defaults to the 'crunch' directory under
            the user's home directory (evaluated once, when the function is defined).

    Returns:
        The concatenation crunch_dir + subdir + fname.
    """
    parent_dir = crunch_dir +subdir
    return parent_dir +fname

def fix_path(cwd):
    """
    Adjust argv and the import path so the notebook behaves like a script in crunch.

    Overwrites sys.argv[0] with `cwd`, then appends the absolute path of the parent of
    the current working directory to sys.path unless it is already listed there.

    Args:
        cwd: value to store in sys.argv[0].
    """
    sys.argv[0] = cwd
    parent_dir = os.path.abspath(os.pardir)
    if parent_dir in sys.path:
        # Already importable; avoid adding a duplicate entry.
        return
    sys.path.append(parent_dir)

# File name and sub-directory used to build the path that fix_path stores in sys.argv[0].
# NOTE(review): this names the 'xg-model-neural-process-clf-ddir' notebook, while this
# notebook's title is 'nb-model_hparam-neural-process' — confirm the name is intentional.
fname = 'nb-model_xg-model-neural-process-clf-ddir.ipynb'
dir_name = 'model'
# fix_path also appends the parent of the working directory to sys.path; presumably the
# project imports further down rely on that — keep this call ahead of them.
fix_path(get_cwd(fname, dir_name +sep))

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import torch
import pytorch_lightning as pl
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from pytorch_lightning import loggers as pl_loggers
from verification.batch_norm import BatchNormVerificationCallback
from verification.batch_gradient import BatchGradientVerificationCallback

from ipywidgets import interact, interactive, fixed
from IPython.display import display

# Raise pandas' display limits: show up to 100 rows and 50 columns before truncating.
pd.set_option("display.max_rows", 100)
pd.set_option("display.max_columns", 50)

from common_util import MODEL_DIR, load_json, dump_json, rectify_json, load_df, str_now, makedir_if_not_exists, is_valid, isnt, compose, pd_split_ternary_to_binary, df_del_midx_level, midx_intersect, pd_get_midx_level, pd_rows, df_midx_restack
from model.common import ASSETS, DATASET_DIR, XG_PROCESS_DIR, XG_DATA_DIR, XG_DIR, PYTORCH_MODELS_DIR, TRAIN_RATIO, EXPECTED_NUM_HOURS
from model.pl_xgdm import XGDataModule
from model.pl_np import NPModel
from model.np_util2 import AttentiveNP
# NOTE(review): wildcard import — it is not visible here which names recon.viz adds to
# the notebook namespace; prefer importing the needed names explicitly.
from recon.viz import *
# Send root-logger output to stdout (so it appears in cell output) at DEBUG verbosity.
logging.basicConfig(stream=sys.stdout, level=logging.DEBUG)

CRITICAL:root:script location: /home/kev/crunch/model/nb-model_xg-model-neural-process-clf-ddir.ipynb
CRITICAL:root:using project dir: /home/kev/crunch/


Prune the xg data down to the data of interest to use in further experiments.

In [2]:
%autoreload 1
%aimport model.np_util2, model.train_util, model.pl_np, model.pl_generic, model.metrics_util, model.model_util, model.pl_xgdm

In [4]:
# Asset selected for single-asset inspection.
# NOTE(review): later cells use `asset_name` as a loop variable (`for asset_name in ASSETS`),
# which rebinds this value — re-check it before relying on it further down.
asset_name = 'dow_jones'

In [6]:
# Location of the Optuna trials table for the currently selected asset.
# os.path.join is used instead of os.sep.join so the result is well-formed whether or
# not MODEL_DIR already ends with a path separator (other cells concatenate MODEL_DIR
# directly, which suggests it does).
path = os.path.join(MODEL_DIR, 'olog-02062021', 'NPModel_AttentiveNP', asset_name,
                    '1996_2018_ddir_h_pba_h', 'val_f1.0', 'trials.csv')

In [13]:
# Load every asset's Optuna trials table into a dict keyed by asset name.
# - os.path.join keeps the path well-formed whether or not MODEL_DIR ends with a
#   separator (os.sep.join would double it when it does).
# - The comprehension keeps its loop variable local, so the notebook-level
#   `asset_name` and `path` are not rebound as a side effect of loading.
optuna_results = {
    asset: load_df(
        os.path.join(MODEL_DIR, 'olog-02062021', 'NPModel_AttentiveNP', asset,
                     '1996_2018_ddir_h_pba_h', 'val_f1.0', 'trials.csv'),
        data_format='csv')
    for asset in ASSETS
}

In [53]:
# Number of leading rows of each asset's trials table summarised by the cells below
# (they call .head(topn)). Presumably the tables are stored best-trial-first — the
# displayed table is ordered by descending `value` — TODO confirm.
topn = 5

In [None]:
# Directory of a study under MODEL_DIR/exp-log/<asset>/<np_type>/<model_name>/<dm.name>/.
# NOTE(review): this cell has never been executed (In [None]) and `np_type`, `model_name`
# and `dm` are not defined anywhere in the visible notebook — as written it would raise
# NameError on a fresh Restart & Run All. Define them first or remove the cell.
study_dir = MODEL_DIR +sep.join(['exp-log', asset_name, np_type, model_name, dm.name]) +sep

In [52]:
# Show the trials table of the first asset in ASSETS (pandas truncates long tables
# according to the display.max_rows option).
optuna_results[ASSETS[0]]

Unnamed: 0_level_0,value,datetime_start,datetime_complete,duration,params_batch_size,params_ffn_de_global_dropout,params_ffn_de_input_dropout,params_ffn_de_output_dropout,params_lat_encoder_cat_before_rt,params_lat_encoder_class_agg,params_lat_encoder_dist_type,params_lat_encoder_latent_size,params_mha_rt_dropout,params_mha_rt_num_heads,params_sample_latent_post,params_sample_latent_prior,params_sample_out,params_stcn_ft_dilation_power,params_stcn_ft_input_dropout,params_stcn_ft_kernel_sizes,params_stcn_ft_output_dropout,params_train_target_overlap,system_attrs_completed_rung_0,system_attrs_completed_rung_1,system_attrs_completed_rung_2,system_attrs_completed_rung_3,system_attrs_completed_rung_4,state
number,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1,Unnamed: 22_level_1,Unnamed: 23_level_1,Unnamed: 24_level_1,Unnamed: 25_level_1,Unnamed: 26_level_1,Unnamed: 27_level_1,Unnamed: 28_level_1
207,0.577062,2021-05-26 23:21:30.835713,2021-05-26 23:26:00.251094,0 days 00:04:29.415381,64,0.10,0.36,0.18,False,False,normal,none,0.54,1,True,False,False,2,0.04,8,0.45,0,0.560966,0.572253,0.578176,0.554955,0.577062,COMPLETE
415,0.567410,2021-05-27 03:01:26.525607,2021-05-27 03:05:50.556929,0 days 00:04:24.031322,64,0.57,0.17,0.02,True,True,beta,none,0.27,1,True,True,False,3,0.35,8,0.53,8,0.554010,0.555762,,,,COMPLETE
316,0.565238,2021-05-27 01:21:16.777931,2021-05-27 01:28:02.776623,0 days 00:06:45.998692,64,0.42,0.25,0.06,True,True,beta,none,0.29,10,True,False,False,2,0.39,8,0.60,8,0.542203,0.563780,,,,COMPLETE
327,0.562580,2021-05-27 01:38:54.475482,2021-05-27 01:43:08.872137,0 days 00:04:14.396655,64,0.40,0.22,0.05,True,True,beta,none,0.29,1,True,False,False,2,0.37,8,0.60,8,0.556460,0.564513,,,,COMPLETE
774,0.562463,2021-06-01 13:13:59.376530,2021-06-01 13:17:57.208931,0 days 00:03:57.832401,64,0.47,0.24,0.08,False,False,normal,none,0.16,1,True,False,False,1,0.28,9,0.15,0,0.568022,0.562463,0.562463,,,COMPLETE
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
930,0.314525,2021-06-01 15:19:37.126322,2021-06-01 15:25:11.733988,0 days 00:05:34.607666,64,0.42,0.31,0.08,True,True,normal,1024,0.02,10,True,True,False,1,0.16,9,0.19,8,0.564607,0.563495,0.314525,,,PRUNED
445,0.314525,2021-05-27 03:30:00.961056,2021-05-27 03:31:16.568369,0 days 00:01:15.607313,64,0.50,0.26,0.08,True,True,normal,1024,0.08,1,True,False,False,2,0.37,8,0.56,8,0.314525,,,,,PRUNED
806,0.314525,2021-06-01 13:40:59.223194,2021-06-01 13:41:05.335922,0 days 00:00:06.112728,64,0.60,0.16,0.06,False,True,beta,1024,0.31,1,True,False,False,3,0.13,9,0.32,8,0.314525,,,,,PRUNED
277,0.314525,2021-05-27 00:44:34.537239,2021-05-27 00:44:42.998408,0 days 00:00:08.461169,64,0.15,0.37,0.05,False,False,normal,none,0.27,10,True,False,True,3,0.01,7,0.52,8,0.314525,,,,,PRUNED


### Dropout

In [39]:
# Columns of the first asset's trials table whose name contains both 'ft_' and 'dropout'.
first_asset_columns = optuna_results[ASSETS[0]].columns
drop_cols = [name for name in first_asset_columns if 'ft_' in name and 'dropout' in name]

In [42]:
# Per asset: median of the selected dropout columns over the first `topn` trials.
for asset_name in ASSETS:
    leading_trials = optuna_results[asset_name].head(topn)
    dropout_medians = leading_trials.loc[:, drop_cols].median()
    print(asset_name, ':', dropout_medians)

sp_500 : params_stcn_ft_input_dropout     0.35
params_stcn_ft_output_dropout    0.53
dtype: float64
russell_2000 : params_stcn_ft_input_dropout     0.05
params_stcn_ft_output_dropout    0.42
dtype: float64
nasdaq_100 : params_stcn_ft_input_dropout     0.18
params_stcn_ft_output_dropout    0.24
dtype: float64
dow_jones : params_stcn_ft_input_dropout     0.16
params_stcn_ft_output_dropout    0.40
dtype: float64


In [55]:
# Per asset: descriptive statistics of the train/target overlap parameter over the
# first `topn` trials.
# NOTE(review): despite its name, `drop_cols` holds the overlap column(s) here and this
# assignment rebinds the list built earlier for the dropout summary; a dedicated name
# would avoid order-dependent results when cells are re-run.
first_asset_columns = optuna_results[ASSETS[0]].columns
drop_cols = [name for name in first_asset_columns if 'params_train_target_overlap' in name]
for asset_name in ASSETS:
    overlap_stats = optuna_results[asset_name].head(topn).loc[:, drop_cols].describe()
    print(asset_name, ':', overlap_stats)

sp_500 :        params_train_target_overlap
count                      5.00000
mean                       4.80000
std                        4.38178
min                        0.00000
25%                        0.00000
50%                        8.00000
75%                        8.00000
max                        8.00000
russell_2000 :        params_train_target_overlap
count                     5.000000
mean                      1.600000
std                       3.577709
min                       0.000000
25%                       0.000000
50%                       0.000000
75%                       0.000000
max                       8.000000
nasdaq_100 :        params_train_target_overlap
count                      5.00000
mean                      12.80000
std                        4.38178
min                        8.00000
25%                        8.00000
50%                       16.00000
75%                       16.00000
max                       16.00000
dow_jones :       