# nb-model_xg-model-data-fi

In [1]:
import sys
import os
from os import sep
from os.path import dirname, realpath
from pathlib import Path
from collections import OrderedDict
from functools import partial, reduce
import logging

def get_cwd(fname, subdir, crunch_dir=realpath(Path.home()) +sep +'crunch' +sep):
    """
    Convenience function to make a directory string for the current file based on inputs.
    Jupyter Notebook in Anaconda invokes the Python interpreter in Anaconda's subdirectory
    which is why changing sys.argv[0] is necessary. In the future a better way to do this
    should be preferred..
    """
    return crunch_dir +subdir +fname

def fix_path(cwd):
    """
    Convenience function to fix argv and python path so that jupyter notebook can run the same as
    any script in crunch.
    """
    sys.argv[0] = cwd
    module_path = os.path.abspath(os.path.join('..'))
    if module_path not in sys.path:
        sys.path.append(module_path)

fname = 'nb-model_xg-model-data-fi.ipynb'
dir_name = 'model'
fix_path(get_cwd(fname, dir_name +sep))

import numpy as np
import pandas as pd
#import matplotlib.pyplot as plt
from dask import delayed, compute
from torch.utils.data import TensorDataset, DataLoader
from torch.utils.data.dataset import Dataset as TorchDataset
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils import weight_norm
import torchfunc
from torchmeta.utils.data import BatchMetaDataLoader
import pytorch_lightning as pl

from ipywidgets import interact, interactive, fixed
from IPython.display import display

pd.set_option("display.max_rows", 100)
pd.set_option('display.max_columns', 50)

from common_util import MODEL_DIR, RECON_DIR, JSON_SFX_LEN, DT_CAL_DAILY_FREQ, pd_to_np, pairwise, df_midx_restack, compose, is_type, df_rows_in_year, list_all_eq, remove_dups_list, NestedDefaultDict, set_loglevel, search_df, chained_filter, get_variants, load_df, dump_df, load_json, gb_transpose, pd_common_index_rows, filter_cols_below, inner_join, outer_join, ser_shift, list_get_dict, window_iter, benchmark
from common_util import pd_split_ternary_to_binary, np_value_counts, isnt, window_iter, all_eq, np_assert_identical_len_dim, pd_idx_rename, midx_get_level, pd_rows, midx_intersect, pd_get_midx_level, pd_common_idx_rows, midx_split, pd_midx_to_arr, window_iter, np_at_least_nd, np_is_ndim, identity_fn
from model.common import DATASET_DIR, XG_PROCESS_DIR, XG_DATA_DIR, XG_DIR, PYTORCH_MODELS_DIR, ERROR_CODE, TEST_RATIO, VAL_RATIO, EXPECTED_NUM_HOURS, default_dataset
from model.common import PYTORCH_ACT_MAPPING, PYTORCH_OPT_MAPPING, PYTORCH_SCH_MAPPING, PYTORCH_LOSS_MAPPING
from model.xg_util import xgload
from model.preproc_util import temporal_preproc_3d, stride_preproc_3d
from model.train_util import pd_to_np_tvt, batchify
#from model.pl_util import TCNModel
logging.basicConfig(stream=sys.stdout, level=logging.DEBUG)

CRITICAL:root:script location: /home/kev/crunch/model/nb-model_xg-model-data-fi.ipynb
CRITICAL:root:using project dir: /home/kev/crunch/


Prune the xg data down to the data of interest to use in further experiments.

## Load Data

In [2]:
assets = ['sp_500', 'russell_2000', 'nasdaq_100', 'dow_jones']
chosen_asset = assets[0]

In [3]:
f = xgload(XG_DATA_DIR +'features' +sep)
l = xgload(XG_DATA_DIR +'labels' +sep)
t = xgload(XG_DATA_DIR +'targets' +sep)

In [4]:
print('num f: {}'.format(len(list(f))))
print('num l: {}'.format(len(list(l))))
print('num t: {}'.format(len(list(t))))

num f: 2520
num l: 1008
num t: 1504


### ddir / dret

In [5]:
ddir_pba_hoc = {a: list(l.childkeys([a, 'ddir', 'ddir', 'pba_hoc_hdxret_ddir'])) for a in assets}
ddir_vol_hoc = {a: list(l.childkeys([a, 'ddir', 'ddir', 'vol_hoc_hdxret_ddir'])) for a in assets}

In [6]:
dret_pba_hoc = {a: list(t.childkeys([a, 'dret', 'dret', 'pba_hoc_hdxret_dret'])) for a in assets}
dret_vol_hoc = {a: list(t.childkeys([a, 'dret', 'dret', 'vol_hoc_hdxret_dret'])) for a in assets}

### ddir1 / dret1

In [7]:
groups = ['lin', 'log']
fmt3, fmt4 = '{}_{}', '{}_hdxret1_{}'

In [8]:
e = 'ddir1'
b = 'pba_hoc'; ddir1_pba_hoc = {a: {g: list(l.childkeys([a, e, fmt3.format(e, g), fmt4.format(b, e)])) for g in groups} for a in assets}
b = 'pba_hlh'; ddir1_pba_hlh = {a: {g: list(l.childkeys([a, e, fmt3.format(e, g), fmt4.format(b, e)])) for g in groups} for a in assets}
b = 'vol_hoc'; ddir1_vol_hoc = {a: {g: list(l.childkeys([a, e, fmt3.format(e, g), fmt4.format(b, e)])) for g in groups} for a in assets}
b = 'vol_hlh'; ddir1_vol_hlh = {a: {g: list(l.childkeys([a, e, fmt3.format(e, g), fmt4.format(b, e)])) for g in groups} for a in assets}

In [9]:
e = 'dret1'
b = 'pba_hoc'; dret1_pba_hoc = {a: {g: list(t.childkeys([a, e, fmt3.format(e, g), fmt4.format(b, e)])) for g in groups} for a in assets}
b = 'pba_hlh'; dret1_pba_hlh = {a: {g: list(t.childkeys([a, e, fmt3.format(e, g), fmt4.format(b, e)])) for g in groups} for a in assets}
b = 'vol_hoc'; dret1_vol_hoc = {a: {g: list(t.childkeys([a, e, fmt3.format(e, g), fmt4.format(b, e)])) for g in groups} for a in assets}
b = 'vol_hlh'; dret1_vol_hlh = {a: {g: list(t.childkeys([a, e, fmt3.format(e, g), fmt4.format(b, e)])) for g in groups} for a in assets}

### ddir2/dret2

In [10]:
scalars = ['0.5', '1', '2']
stats = ['avg', 'std', 'mad', 'max', 'min']
fmt4, fmt5 = '{}_hdxret2_{}', '{}_hdxret2({}*{},1)_{}'

In [11]:
e = 'ddir2'
b = 'pba_hoc'; ddir2_pba_hoc = {a: {d: [[a, e, e, fmt4.format(b, e), fmt5.format(b, c, d, e)] for c in scalars] for d in stats} for a in assets}
b = 'pba_hlh'; ddir2_pba_hlh = {a: {d: [[a, e, e, fmt4.format(b, e), fmt5.format(b, c, d, e)] for c in scalars] for d in stats} for a in assets}
b = 'vol_hoc'; ddir2_vol_hoc = {a: {d: [[a, e, e, fmt4.format(b, e), fmt5.format(b, c, d, e)] for c in scalars] for d in stats} for a in assets}
b = 'vol_hlh'; ddir2_vol_hlh = {a: {d: [[a, e, e, fmt4.format(b, e), fmt5.format(b, c, d, e)] for c in scalars] for d in stats} for a in assets}

In [12]:
e = 'dret2'
b = 'pba_hoc'; dret2_pba_hoc = {a: {d: [[a, e, e, fmt4.format(b, e), fmt5.format(b, c, d, e)] for c in scalars] for d in stats} for a in assets}
b = 'pba_hlh'; dret2_pba_hlh = {a: {d: [[a, e, e, fmt4.format(b, e), fmt5.format(b, c, d, e)] for c in scalars] for d in stats} for a in assets}
b = 'vol_hoc'; dret2_vol_hoc = {a: {d: [[a, e, e, fmt4.format(b, e), fmt5.format(b, c, d, e)] for c in scalars] for d in stats} for a in assets}
b = 'vol_hlh'; dret2_vol_hlh = {a: {d: [[a, e, e, fmt4.format(b, e), fmt5.format(b, c, d, e)] for c in scalars] for d in stats} for a in assets}

### dxfbdir1 / dxfbret1

In [13]:
groups = ['lin', 'log']
fmt3, fmt4 = '{}_{}', '{}_hdxcret1_{}'

In [14]:
e = 'dxfbdir1'
b = 'pba_hoc'; dxfbdir1_pba_hoc = {a: {g: list(l.childkeys([a, e, fmt3.format(e, g), fmt4.format(b, e)])) for g in groups} for a in assets}
b = 'pba_hlh'; dxfbdir1_pba_hlh = {a: {g: list(l.childkeys([a, e, fmt3.format(e, g), fmt4.format(b, e)])) for g in groups} for a in assets}
b = 'vol_hoc'; dxfbdir1_vol_hoc = {a: {g: list(l.childkeys([a, e, fmt3.format(e, g), fmt4.format(b, e)])) for g in groups} for a in assets}
b = 'vol_hlh'; dxfbdir1_vol_hlh = {a: {g: list(l.childkeys([a, e, fmt3.format(e, g), fmt4.format(b, e)])) for g in groups} for a in assets}

In [15]:
e = 'dxfbcret1'
#fmt3, fmt4 = '{}_{}', '{}_hdxcret1_{}'
b = 'pba_hoc'; dxfbcret1_pba_hoc = {a: {g: list(t.childkeys([a, e, fmt3.format(e, g), fmt4.format(b, e)])) for g in groups} for a in assets}
b = 'pba_hlh'; dxfbcret1_pba_hlh = {a: {g: list(t.childkeys([a, e, fmt3.format(e, g), fmt4.format(b, e)])) for g in groups} for a in assets}
b = 'vol_hoc'; dxfbcret1_vol_hoc = {a: {g: list(t.childkeys([a, e, fmt3.format(e, g), fmt4.format(b, e)])) for g in groups} for a in assets}
b = 'vol_hlh'; dxfbcret1_vol_hlh = {a: {g: list(t.childkeys([a, e, fmt3.format(e, g), fmt4.format(b, e)])) for g in groups} for a in assets}

In [16]:
e = 'dxfbval1'
b = 'pba_hoc'; dxfbval1_pba_hoc = {a: {g: list(t.childkeys([a, e, fmt3.format(e, g), fmt4.format(b, e)])) for g in groups} for a in assets}
b = 'pba_hlh'; dxfbval1_pba_hlh = {a: {g: list(t.childkeys([a, e, fmt3.format(e, g), fmt4.format(b, e)])) for g in groups} for a in assets}
b = 'vol_hoc'; dxfbval1_vol_hoc = {a: {g: list(t.childkeys([a, e, fmt3.format(e, g), fmt4.format(b, e)])) for g in groups} for a in assets}
b = 'vol_hlh'; dxfbval1_vol_hlh = {a: {g: list(t.childkeys([a, e, fmt3.format(e, g), fmt4.format(b, e)])) for g in groups} for a in assets}

### dxfbdir2 / dxfbcret2

In [17]:
scalars = ['0.5', '1', '2']
stats = ['avg', 'std', 'mad', 'max', 'min']
fmt4, fmt5 = '{}_hdxcret2_{}', '{}_hdxcret2({}*{},1)_{}'

In [18]:
e = 'dxfbdir2'
b = 'pba_hoc'; dxfbdir2_pba_hoc = {a: {d: [[a, e, e, fmt4.format(b, e), fmt5.format(b, c, d, e)] for c in scalars] for d in stats} for a in assets}
b = 'pba_hlh'; dxfbdir2_pba_hlh = {a: {d: [[a, e, e, fmt4.format(b, e), fmt5.format(b, c, d, e)] for c in scalars] for d in stats} for a in assets}
b = 'vol_hoc'; dxfbdir2_vol_hoc = {a: {d: [[a, e, e, fmt4.format(b, e), fmt5.format(b, c, d, e)] for c in scalars] for d in stats} for a in assets}
b = 'vol_hlh'; dxfbdir2_vol_hlh = {a: {d: [[a, e, e, fmt4.format(b, e), fmt5.format(b, c, d, e)] for c in scalars] for d in stats} for a in assets}

In [19]:
e = 'dxfbcret2'
b = 'pba_hoc'; dxfbcret2_pba_hoc = {a: {d: [[a, e, e, fmt4.format(b, e), fmt5.format(b, c, d, e)] for c in scalars] for d in stats} for a in assets}
b = 'pba_hlh'; dxfbcret2_pba_hlh = {a: {d: [[a, e, e, fmt4.format(b, e), fmt5.format(b, c, d, e)] for c in scalars] for d in stats} for a in assets}
b = 'vol_hoc'; dxfbcret2_vol_hoc = {a: {d: [[a, e, e, fmt4.format(b, e), fmt5.format(b, c, d, e)] for c in scalars] for d in stats} for a in assets}
b = 'vol_hlh'; dxfbcret2_vol_hlh = {a: {d: [[a, e, e, fmt4.format(b, e), fmt5.format(b, c, d, e)] for c in scalars] for d in stats} for a in assets}

In [20]:
e = 'dxfbval2'
b = 'pba_hoc'; dxfbval2_pba_hoc = {a: {d: [[a, e, e, fmt4.format(b, e), fmt5.format(b, c, d, e)] for c in scalars] for d in stats} for a in assets}
b = 'pba_hlh'; dxfbval2_pba_hlh = {a: {d: [[a, e, e, fmt4.format(b, e), fmt5.format(b, c, d, e)] for c in scalars] for d in stats} for a in assets}
b = 'vol_hoc'; dxfbval2_vol_hoc = {a: {d: [[a, e, e, fmt4.format(b, e), fmt5.format(b, c, d, e)] for c in scalars] for d in stats} for a in assets}
b = 'vol_hlh'; dxfbval2_vol_hlh = {a: {d: [[a, e, e, fmt4.format(b, e), fmt5.format(b, c, d, e)] for c in scalars] for d in stats} for a in assets}

### Process Labels/Targets

In [21]:
def get_lt(d, store, asset=chosen_asset, subset=None):
    """
    Return label or target data as a DataFrame.
    """
    kcs = d[asset][subset] if (subset) else d[asset]
    lt = pd.concat([store[kc] for kc in kcs], axis=1, keys=[kc[-1] for kc in kcs])
    lt.columns = lt.columns.droplevel(-1)
    return lt

#### PBA

In [22]:
ddir = get_lt(ddir_pba_hoc, l)
ddir1 = get_lt(ddir1_pba_hoc, l, subset='log')
dxfbdir1 = get_lt(dxfbdir1_pba_hoc, l, subset='log')

In [23]:
dret = get_lt(dret_pba_hoc, t)
dret1 = get_lt(dret1_pba_hoc, t, subset='log')
dxfbcret1 = get_lt(dxfbcret1_pba_hoc, t, subset='log')
dxfbval1 = get_lt(dxfbval1_pba_hoc, t, subset='log')

#### VOL

In [24]:
ddir_vol = get_lt(ddir_vol_hoc, l)
ddir1_vol = get_lt(ddir1_vol_hoc, l, subset='log')
dxfbdir1_vol = get_lt(dxfbdir1_vol_hoc, l, subset='log')

In [25]:
dret_vol = get_lt(dret_vol_hoc, t)
dret1_vol = get_lt(dret1_vol_hoc, t, subset='log')
dxfbcret1_vol = get_lt(dxfbcret1_vol_hoc, t, subset='log')
dxfbval1_vol = get_lt(dxfbval1_vol_hoc, t, subset='log')

### Features

In [26]:
flist = list(sorted(set([k[1] for k in f.childkeys([assets[0]])])))
daily = [axf for axf in flist if (axf[0]=='d')]
hourly =  [axf for axf in flist if (axf[0]=='h')]
display('daily: ' +str(daily))
display('hourly: ' +str(hourly))

"daily: ['dc', 'ddiff', 'dffd', 'dlogret', 'dohlca', 'dwrmx', 'dwrod', 'dwrpt', 'dwrxmx', 'dwrzn']"

"hourly: ['hdgau', 'hdmx', 'hdod', 'hdpt', 'hduni', 'hdzn', 'hohlca']"

## Pruning Features

### Remove Raw Data

In [27]:
def src_in(axefile, src):
    axefile = [axefile] if (is_type(axefile, str)) else axefile
    return any(any(src in val for val in fpath) for fpath in f.childkeys([assets[0], *axefile]))

def getkeys(kc):
    return f.childkeys(kc)

In [28]:
daily.remove('dc')
daily.remove('dohlca')
hourly.remove('hohlca')

### Select Subset
select hourly, pba and vol

In [29]:
hourly_dict = {}
for src in ['pba_vol', 'buzz', 'nonbuzz']:
    hourly_dict[src] = list(filter(partial(src_in, src=src), hourly))

In [30]:
asset_name = chosen_asset
pba_vol_hourly = {'pba': {}, 'vol': {}}

for axef in hourly:
    for axe_src in pba_vol_hourly.keys():
        kcs = f.childkeys([asset_name, axef])
        pba_vol_hourly[axe_src][axef] = list(filter(lambda kc: axe_src in kc[-1], kcs))
        #pba_vol_hourly[axe_src][axef] = list(filter(lambda kc: kc[-1].startswith(axe_src), kcs))

In [31]:
pba_vol_hourly['pba'].keys()
pba_len = pd.Series(index=pba_vol_hourly['pba'].keys(),data=map(len, pba_vol_hourly['pba'].values()), name='pba')
vol_len = pd.Series(index=pba_vol_hourly['vol'].keys(),data=map(len, pba_vol_hourly['vol'].values()), name='vol')
comb_len = pd.concat([pba_len, vol_len], axis=1)

In [32]:
comb_len

Unnamed: 0,pba,vol
hdgau,28,28
hdmx,7,7
hdod,7,7
hdpt,7,7
hduni,28,28
hdzn,7,7


Filter and rewrap to prepare for concatenation into new dataframes:

In [33]:
assert(all(comb_len.loc[:, 'pba']==comb_len.loc[:, 'vol']))
concat_kcs = {k: {} for k in pba_vol_hourly.keys()}

for k0, v0 in pba_vol_hourly.items():
    for k, v in v0.items():
        if (k in ('hdgau', 'hduni')):
            v = [kc for kc in v if ('8' in kc[-1])] # Filter out symbol sizes of 2, 3, 4 from hdgau and hduni
            #continue
        working_wrap, all_wraps = [v[0]], []
        for kc in v[1:]:
            if (working_wrap[0][-1].split('_')[:2]==kc[-1].split('_')[:2]):
                working_wrap.append(kc)
            else:
                all_wraps.append(working_wrap)
                working_wrap = [kc]
        all_wraps.append(working_wrap)
        concat_kcs[k0][k] = all_wraps

concatenate groups of dataframes of the same transform (ie hdgau), return type (lh, oc, or ohlca), and source (pba or vol):

In [34]:
concat_dfs = {k: {} for k in concat_kcs.keys()}
for k0, v0 in concat_kcs.items():
    for k, v in v0.items():
        print(k, len(v))
        concat_dfs[k0][k] = {}
        for g in v:
            #print([kc[-1] for kc in g])
            #ex = [f[kc] for kc in g]
            to_concat = []
            concat_name_keys = []
            for kc in g:
                gr_df = f[kc]
                #display(gr_df.index.levels[1])
                # Append last keychain key to 2nd (last) level of MultiIndex
                gr_df.index = gr_df.index.set_levels(['_'.join([sublvl, kc[-1]]) for sublvl in gr_df.index.levels[1]], level=1, inplace=False)
                to_concat.append(gr_df)
                concat_name_keys.extend(kc[-1].split('_'))
                #display(gr_df)
            
            # Disambiguate sub-dfs and set names
            concat_name = '_'.join(dict.fromkeys(concat_name_keys).keys())
            concat_dfs[k0][k][concat_name] = pd.concat(to_concat)
            print(concat_name)
        print()

hdgau 3
pba_hlh_hlogret_hdzn_hdgau(8)_hret_hspread
pba_hoc_hlogret_hdzn_hdgau(8)_hret_hspread
pba_hohlca_hdzn_hdgau(8)

hdmx 3
pba_hlh_hlogret_hdmx_hret_hspread
pba_hoc_hlogret_hdmx_hret_hspread
pba_hohlca_hdmx

hdod 3
pba_hlh_hlogret_hdod_hret_hspread
pba_hoc_hlogret_hdod_hret_hspread
pba_hohlca_hdod

hdpt 3
pba_hlh_hlogret_hdpt_hret_hspread
pba_hoc_hlogret_hdpt_hret_hspread
pba_hohlca_hdpt

hduni 3
pba_hlh_hlogret_hdmx_hduni(8)_hret_hspread
pba_hoc_hlogret_hdmx_hduni(8)_hret_hspread
pba_hohlca_hdmx_hduni(8)

hdzn 3
pba_hlh_hlogret_hdzn_hret_hspread
pba_hoc_hlogret_hdzn_hret_hspread
pba_hohlca_hdzn

hdgau 3
vol_hlh_hlogret_hdzn_hdgau(8)_hret_hspread
vol_hoc_hlogret_hdzn_hdgau(8)_hret_hspread
vol_hohlca_hdzn_hdgau(8)

hdmx 3
vol_hlh_hlogret_hdmx_hret_hspread
vol_hoc_hlogret_hdmx_hret_hspread
vol_hohlca_hdmx

hdod 3
vol_hlh_hlogret_hdod_hret_hspread
vol_hoc_hlogret_hdod_hret_hspread
vol_hohlca_hdod

hdpt 3
vol_hlh_hlogret_hdpt_hret_hspread
vol_hoc_hlogret_hdpt_hret_hspread
vol_hohlca_hd

## Select Data

### Choose Features and Join

In [35]:
# Manual way
#kc_end = ['dohlca']
#ft_all = {a: list(f.childkeys([a, *kc_end])) for a in assets}
#feat = ft_all[chosen_asset]
#chosen_f = pd.concat([f[feat[0]], f[feat[1]]])

In [36]:
#chosen_f = concat_dfs['pba']['hdmx']['pba_hohlca_hdmx']
chosen_f = concat_dfs['pba']['hdzn']['pba_hohlca_hdzn']
#kc_end = ['hduni']
#ft_all = {a: list(f.childkeys([a, *kc_end])) for a in assets}
#feat = ft_all[chosen_asset]
#chosen_f = f[feat[3]]
#chosen_f = pd.concat([f[feat[3]], f[feat[6]]])
#chosen_f = pd.concat([f[feat[3]], f[feat[6]]])

### Choose Labels/Targets and Process

In [37]:
chosen_l = pd_split_ternary_to_binary(ddir.replace(to_replace=-1, value=0))
chosen_t = pd_split_ternary_to_binary(dret)

In [38]:
# categorical (non-binary) loss broken with nll for now
#chosen_l = pd_split_ternary_to_binary(ddir1)
#chosen_t = pd_split_ternary_to_binary(dret1)

### Get Common Indexed Rows (Intersect First Level of MultiIndex)

In [39]:
year_interval = ('2009', '2018')
common_idx = midx_intersect(pd_get_midx_level(chosen_f), pd_get_midx_level(chosen_l), pd_get_midx_level(chosen_t))
common_idx = common_idx[(common_idx > year_interval[0]) & (common_idx < year_interval[1])]
feature_df, label_df, target_df = map(compose(partial(pd_rows, idx=common_idx), df_midx_restack), [chosen_f, chosen_l, chosen_t])
assert(all(feature_df.index.levels[0]==label_df.index.levels[0]))
assert(all(feature_df.index.levels[0]==target_df.index.levels[0]))

## PL TCN Model

In [40]:
from model.pl_tcn import TCNModel

In [46]:
m_params = {
	'window_size': 10,
	'num_blocks': 1,
	#'block_channels': [[30, 20, 10]],
	'block_channels': [[5]],
	'block_act': 'elu',
	'out_act': 'relu',
	'kernel_sizes': [3],
	'dilation_index': 'global',
	'global_dropout': .2,
	'no_dropout': [0],
	'out_shape': len(chosen_l.columns)*len(chosen_l.index.levels[1]) # directions * slots per direction
}
t_params = {
	'epochs': 200,
	'batch_size': 64,
	'loss': 'ce',
	'opt': {
		'name': 'adam',
		'kwargs': {
			'lr': .001
		}
	},
	'sch': {
		'name': 'rpl',
		'kwargs': {
			'mode': 'min',
			'factor': 0.1,
			'patience': 10,
			'threshold': 0.0001,
			'threshold_mode': 'rel',
			'cooldown': 0,
			'min_lr': 0
		}
	}
}

In [47]:
mdl = TCNModel(m_params, t_params, (feature_df, label_df, target_df))

In [48]:
mdl

TCNModel(
  (loss): CrossEntropyLoss()
  (clf): OutputLinear(
    (emb): TemporalConvNet(
      (convnet): Sequential(
        (RB_0): ResidualBlock(
          (net): Sequential(
            (TL_0_0): TemporalLayer1d(
              (layer): Sequential(
                (0): Conv1d(5, 40, kernel_size=(3,), stride=(1,), padding=(1,))
                (1): ELU(alpha=1.0)
                (2): Dropout(p=0, inplace=False)
              )
            )
          )
          (downsample): Conv1d(5, 40, kernel_size=(1,), stride=(1,))
          (out_act): ReLU()
        )
      )
    )
    (out): Linear(in_features=40, out_features=2, bias=True)
  )
)

In [49]:
trainer = pl.Trainer(max_nb_epochs=t_params['epochs'], gpus=1, amp_level='O1', use_amp=True)

In [50]:
trainer.fit(mdl)

Selected optimization level O1:  Insert automatic casts around Pytorch functions and Tensor methods.

Defaults for this optimization level are:
enabled                : True
opt_level              : O1
cast_model_type        : None
patch_torch_functions  : True
keep_batchnorm_fp32    : None
master_weights         : None
loss_scale             : dynamic
Processing user overrides (additional kwargs that are not None)...
After processing overrides, optimization options are:
enabled                : True
opt_level              : O1
cast_model_type        : None
patch_torch_functions  : True
keep_batchnorm_fp32    : None
master_weights         : None
loss_scale             : dynamic


HBox(children=(FloatProgress(value=0.0, description='Validation sanity check', layout=Layout(flex='2'), max=5.…



HBox(children=(FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), max=1.0), HTML(value='')), …

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=7.0, style=Prog…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=7.0, style=Prog…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=7.0, style=Prog…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=7.0, style=Prog…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=7.0, style=Prog…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=7.0, style=Prog…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=7.0, style=Prog…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=7.0, style=Prog…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=7.0, style=Prog…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=7.0, style=Prog…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=7.0, style=Prog…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=7.0, style=Prog…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=7.0, style=Prog…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=7.0, style=Prog…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=7.0, style=Prog…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=7.0, style=Prog…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=7.0, style=Prog…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=7.0, style=Prog…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=7.0, style=Prog…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=7.0, style=Prog…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=7.0, style=Prog…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=7.0, style=Prog…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=7.0, style=Prog…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=7.0, style=Prog…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=7.0, style=Prog…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=7.0, style=Prog…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=7.0, style=Prog…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=7.0, style=Prog…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=7.0, style=Prog…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=7.0, style=Prog…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=7.0, style=Prog…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=7.0, style=Prog…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=7.0, style=Prog…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=7.0, style=Prog…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=7.0, style=Prog…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=7.0, style=Prog…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=7.0, style=Prog…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=7.0, style=Prog…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=7.0, style=Prog…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=7.0, style=Prog…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=7.0, style=Prog…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=7.0, style=Prog…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=7.0, style=Prog…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=7.0, style=Prog…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=7.0, style=Prog…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=7.0, style=Prog…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=7.0, style=Prog…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=7.0, style=Prog…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=7.0, style=Prog…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=7.0, style=Prog…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=7.0, style=Prog…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=7.0, style=Prog…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=7.0, style=Prog…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=7.0, style=Prog…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=7.0, style=Prog…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=7.0, style=Prog…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=7.0, style=Prog…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=7.0, style=Prog…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=7.0, style=Prog…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=7.0, style=Prog…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=7.0, style=Prog…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=7.0, style=Prog…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=7.0, style=Prog…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=7.0, style=Prog…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=7.0, style=Prog…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=7.0, style=Prog…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=7.0, style=Prog…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=7.0, style=Prog…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=7.0, style=Prog…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=7.0, style=Prog…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=7.0, style=Prog…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=7.0, style=Prog…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=7.0, style=Prog…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=7.0, style=Prog…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=7.0, style=Prog…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=7.0, style=Prog…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=7.0, style=Prog…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=7.0, style=Prog…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=7.0, style=Prog…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=7.0, style=Prog…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=7.0, style=Prog…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=7.0, style=Prog…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=7.0, style=Prog…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=7.0, style=Prog…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=7.0, style=Prog…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=7.0, style=Prog…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=7.0, style=Prog…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=7.0, style=Prog…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=7.0, style=Prog…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=7.0, style=Prog…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=7.0, style=Prog…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=7.0, style=Prog…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=7.0, style=Prog…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=7.0, style=Prog…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=7.0, style=Prog…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=7.0, style=Prog…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=7.0, style=Prog…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=7.0, style=Prog…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=7.0, style=Prog…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=7.0, style=Prog…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=7.0, style=Prog…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=7.0, style=Prog…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=7.0, style=Prog…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=7.0, style=Prog…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=7.0, style=Prog…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=7.0, style=Prog…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=7.0, style=Prog…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=7.0, style=Prog…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=7.0, style=Prog…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=7.0, style=Prog…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=7.0, style=Prog…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=7.0, style=Prog…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=7.0, style=Prog…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=7.0, style=Prog…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=7.0, style=Prog…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=7.0, style=Prog…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=7.0, style=Prog…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=7.0, style=Prog…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=7.0, style=Prog…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=7.0, style=Prog…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=7.0, style=Prog…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=7.0, style=Prog…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=7.0, style=Prog…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=7.0, style=Prog…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=7.0, style=Prog…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=7.0, style=Prog…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=7.0, style=Prog…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=7.0, style=Prog…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=7.0, style=Prog…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=7.0, style=Prog…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=7.0, style=Prog…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=7.0, style=Prog…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=7.0, style=Prog…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=7.0, style=Prog…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=7.0, style=Prog…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=7.0, style=Prog…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=7.0, style=Prog…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=7.0, style=Prog…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=7.0, style=Prog…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=7.0, style=Prog…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=7.0, style=Prog…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=7.0, style=Prog…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=7.0, style=Prog…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=7.0, style=Prog…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=7.0, style=Prog…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=7.0, style=Prog…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=7.0, style=Prog…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=7.0, style=Prog…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=7.0, style=Prog…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=7.0, style=Prog…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=7.0, style=Prog…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=7.0, style=Prog…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=7.0, style=Prog…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=7.0, style=Prog…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=7.0, style=Prog…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=7.0, style=Prog…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=7.0, style=Prog…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=7.0, style=Prog…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=7.0, style=Prog…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=7.0, style=Prog…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=7.0, style=Prog…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=7.0, style=Prog…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=7.0, style=Prog…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=7.0, style=Prog…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=7.0, style=Prog…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=7.0, style=Prog…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=7.0, style=Prog…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=7.0, style=Prog…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=7.0, style=Prog…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=7.0, style=Prog…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=7.0, style=Prog…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=7.0, style=Prog…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=7.0, style=Prog…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=7.0, style=Prog…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=7.0, style=Prog…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=7.0, style=Prog…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=7.0, style=Prog…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=7.0, style=Prog…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=7.0, style=Prog…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=7.0, style=Prog…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=7.0, style=Prog…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=7.0, style=Prog…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=7.0, style=Prog…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=7.0, style=Prog…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=7.0, style=Prog…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=7.0, style=Prog…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=7.0, style=Prog…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=7.0, style=Prog…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=7.0, style=Prog…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=7.0, style=Prog…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=7.0, style=Prog…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=7.0, style=Prog…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=7.0, style=Prog…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=7.0, style=Prog…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=7.0, style=Prog…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=7.0, style=Prog…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=7.0, style=Prog…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=7.0, style=Prog…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=7.0, style=Prog…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=7.0, style=Prog…




1

## Debug Data (Runtime Transforms)

### Split into Train/Val/Test and Convert to Numpy Tensor

In [88]:
train_np, val_np, test_np = zip(*map(pd_to_np_tvt, (feature_df, label_df, target_df)))
shapes = np.asarray(tuple(map(lambda tvt: tuple(map(np.shape, tvt)), (train_np, val_np, test_np))))
assert all(np.array_equal(a[:, 1:], b[:, 1:]) for a, b in pairwise(shapes)), 'feature, label, target shapes must be identical across splits'
assert all(len(np.unique(mat.T[0, :]))==1 for mat in shapes), 'first dimension (N) must be identical length in each split for all (feature, label, and target) tensors'

### Runtime Preproc

In [89]:
params = {
    'loss': 'nll',
    'batch_size': 1,
    'window_size': 20,
}

In [90]:
def __preproc__(data, m_params, overlap=True):
    x, y, z = temporal_preproc_3d(data, window_size=m_params['window_size'], apply_idx=[0]) if (overlap) else stride_preproc_3d(data, window_size=m_params['window_size'])
    if (m_params['loss'] in ('ce', 'nll')):
        y_new = np.sum(y, axis=(1, 2), keepdims=False)		# Sum label matrices to scalar values
        if (y.shape[1] > 1):
            y_new += y.shape[1]								# Shift to range [0, C-1]
        y = y_new
    return (x, y, z)

#@pl.data_loader
def train_dataloader(t_params, flt):
    logging.info('train_dataloader called')
    return batchify(t_params, __preproc__(flt), False)

#@pl.data_loader
def val_dataloader(t_params, flt):
    logging.info('val_dataloader called')
    return batchify(t_params, __preproc__(flt), False)

#@pl.data_loader
def test_dataloader(t_params, flt):
    logging.info('test_dataloader called')
    return batchify(t_params, __preproc__(flt), False)

### Overlapping Episodes:

In [91]:
train_ol_np = __preproc__(train_np, params)
val_ol_np = __preproc__(val_np, params)
test_ol_np = __preproc__(test_np, params)
print(tuple(map(lambda tvt: tuple(map(np.shape, tvt)), (train_np, val_np, test_np)))[0])
print(tuple(map(np.shape, train_ol_np)))
print(np_value_counts(train_ol_np[1]))

((1359, 5, 8), (1359, 1, 2), (1359, 1, 2))
((1340, 5, 160), (1340,), (1340, 1, 2))
(array([0., 1.]), array([593, 747]))


### Non-Overlapping Episodes:

In [92]:
train_nol_np = __preproc__(train_np, params, False)
val_nol_np = __preproc__(val_np, params, False)
test_nol_np = __preproc__(test_np, params, False)
print(tuple(map(lambda tvt: tuple(map(np.shape, tvt)), (train_np, val_np, test_np)))[0])
print(tuple(map(np.shape, train_nol_np)))

((1359, 5, 8), (1359, 1, 2), (1359, 1, 2))
((67, 5, 160), (67, 2), (67, 1, 20, 2))


In [93]:
def batchify(params, data, shuffle_batches=False):
	"""
	Return a torch.DataLoader made from a tuple of numpy arrays.

	Args:
		params (dict): model parameters dictionary
		data (tuple): tuple of numpy arrays, features are the first element
		shuffle_batches (bool): whether or not to shuffle the batches

	Returns:
		torch.DataLoader
	"""
	f = torch.tensor(data[0], dtype=torch.float32, requires_grad=True)
	if (params['loss'] in ('bce', 'bcel', 'mae', 'mse')):
		l = [torch.tensor(d, dtype=torch.float32, requires_grad=False) for d in data[1:]]
	elif (params['loss'] in ('ce', 'nll')):
		l = [torch.tensor(d, dtype=torch.int64, requires_grad=False).squeeze() for d in data[1:]]
	ds = TensorDataset(f, *l)
	dl = DataLoader(ds, batch_size=params['batch_size'], shuffle=shuffle_batches)
	return ds

## Sklearn Tests

In [94]:
from sklearn.linear_model import LinearRegression, LogisticRegression, SGDClassifier, ElasticNet

In [95]:
x_train, y_train, z_train = train_ol_np
x_val, y_val, z_val = val_ol_np
x_test, y_test, z_test = test_ol_np

In [96]:
def np_collapse_last_two_dim(arr):
    return arr.reshape(arr.shape[0], arr.shape[1]*arr.shape[2])

In [97]:
x_train.reshape(1340, 800)

array([[-1.58067519, -1.15105818, -0.33156908, ..., -0.51929487,
        -0.76036019, -1.43855728],
       [-1.84158528, -0.10414461,  0.57745401, ..., -0.57928382,
        -1.1379013 ,  0.84857252],
       [ 1.64234671, -0.26223317, -1.58819638, ..., -0.11044341,
         1.28551912,  1.79793437],
       ...,
       [ 0.369901  ,  1.28237349,  0.43914589, ...,  0.4296448 ,
         0.71908972,  1.0004945 ],
       [-0.66945975, -1.85637533, -0.82125016, ...,  0.21710106,
         0.35351704,  1.02228811],
       [-1.63411621, -0.84232943, -0.57462642, ...,  1.59981899,
         0.96346278, -1.17287592]])

In [98]:
x_train_new = np_collapse_last_two_dim(x_train)
x_val_new = np_collapse_last_two_dim(x_val)
x_test_new = np_collapse_last_two_dim(x_test)

# Logistic Regression
clf = LogisticRegression(C=10**-2, l1_ratio=.9, penalty='elasticnet', random_state=0, solver='saga').fit(x_train_new, y_train)
sc = clf.score(x_val_new, y_val)
print(sc)
#scs.append(sc)

0.5092165898617511


In [67]:
# Few Shot Tests
# XXX - Need to update for code changes
tr_split = .75
w = params['window_size']
tr, ts = int(tr_split*w), (1-tr_split)*w
print(tr, ts)

scs = []
for x, y, z in zip(*train_nol_np):
    x, y = x.T, y
    print(x.shape)
    print(x[0:tr].shape)
    print(x[tr:].shape)
    break
    clf = LogisticRegression(C=10**-2, l1_ratio=.9, penalty='elasticnet', random_state=0, solver='saga').fit(x[0:tr], y[0:tr].squeeze())
    #clf = LinearRegression().fit(x[0:tr], y[0:tr].squeeze())
    #yh = clf.predict(x[tr:])
    #|yp = clf.predict_proba(x[tr:])
    sc = clf.score(x[tr:], y[tr:].squeeze())
    scs.append(sc)
    #print(sc)
    #print(x)
    #print(y.squeeze())
    #print(z.T)

15 5.0
(160, 5)
(15, 5)
(145, 5)


In [68]:
(np.array(scs).mean()-train_nol[1].squeeze().mean())*100

  {
  ret = ret.dtype.type(ret / rcount)


NameError: name 'train_nol' is not defined