In [1]:
import os
import time
import pickle
import random
import numpy as np
import pandas as pd
from tqdm import tqdm
from sklearn.metrics import log_loss, roc_auc_score

import torch
import torch.nn as nn
from torch.autograd import Variable
from torch.utils.data import DataLoader
from torch.nn import CrossEntropyLoss, MSELoss
from torch.nn.modules.loss import _WeightedLoss
import torch.nn.functional as F

# Show up to 100 columns / rows when a DataFrame is displayed.
pd.set_option('display.max_columns', 100)
pd.set_option('display.max_rows', 100)

# Directory containing the competition CSV files.
DATA_PATH = '../jane-street-market-prediction/'

# Training hyper-parameters.
# GPU_NUM = 8
BATCH_SIZE = 256# * GPU_NUM
EPOCHS = 200           # maximum number of epochs per model
LEARNING_RATE = 1e-3   # Adam learning rate
WEIGHT_DECAY = 1e-5    # Adam weight decay
EARLYSTOP_NUM = 5      # patience handed to EarlyStopping
NFOLDS = 5

# TRAIN gates the training-time preprocessing below (label creation, feature means,
# cross features); when False the cached feature means are loaded from CACHE_PATH.
TRAIN = True
# Directory where cached statistics (.npy) and model checkpoints (.pth) are written.
CACHE_PATH = './'

# Full training table; loaded unconditionally, even when TRAIN is False.
train = pd.read_csv(f'{DATA_PATH}/train.csv')

def save_pickle(dic, save_path):
    """Pickle ``dic`` into the file at ``save_path``, overwriting any existing file.

    Args:
        dic: any picklable object (typically a dict).
        save_path: destination file path.
    """
    payload = pickle.dumps(dic)
    with open(save_path, mode='wb') as handle:
        handle.write(payload)

def load_pickle(load_path):
    """Return the object stored in the pickle file at ``load_path``.

    Only call this on files you produced yourself: unpickling can execute
    arbitrary code.
    """
    with open(load_path, mode='rb') as handle:
        return pickle.load(handle)

class EarlyStopping:
    """Track a validation score, checkpoint the best model and flag when to stop.

    Each call compares the new score with the best one seen so far. A score equal
    to or better than the best resets the counter and saves ``model.state_dict()``
    to ``model_path``; a worse score increments the counter, and once it reaches
    ``patience`` the ``early_stop`` flag is set.

    Args:
        patience: number of consecutive non-improving calls before ``early_stop``.
        mode: ``"max"`` if higher scores are better, ``"min"`` if lower are better.
        delta: stored on the instance but not used by ``__call__`` (the improvement
            margin is disabled there).
    """

    def __init__(self, patience=7, mode="max", delta=0.001):
        self.patience = patience
        self.counter = 0
        self.mode = mode
        self.best_score = None
        self.early_stop = False
        self.delta = delta
        # np.inf rather than np.Inf: the capitalised alias was removed in NumPy 2.0.
        self.val_score = np.inf if self.mode == "min" else -np.inf

    def __call__(self, epoch_score, model, model_path):
        # Internally always maximise: scores are negated in "min" mode.
        score = -1.0 * epoch_score if self.mode == "min" else np.copy(epoch_score)

        if self.best_score is None:
            self.best_score = score
            self.save_checkpoint(epoch_score, model, model_path)
        elif score < self.best_score:  # + self.delta
            self.counter += 1
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            self.best_score = score
            self.save_checkpoint(epoch_score, model, model_path)
            self.counter = 0

    def save_checkpoint(self, epoch_score, model, model_path):
        """Save ``model.state_dict()`` unless the score is non-finite; record the score."""
        # np.isfinite rejects +/-inf and NaN. The previous membership test
        # (`in [-np.inf, np.inf, -np.nan, np.nan]`) let computed NaNs through,
        # because NaN compares unequal to everything, including np.nan.
        if np.isfinite(epoch_score):
            torch.save(model.state_dict(), model_path)
        self.val_score = epoch_score

def seed_everything(seed=42):
    """Seed Python, NumPy and PyTorch RNGs and put cuDNN in deterministic mode.

    Args:
        seed: value used for every random number generator.
    """
    random.seed(seed)
    # Only affects Python processes started after this point; the running
    # interpreter's hash seed is fixed at startup.
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Seed every visible GPU, not just the current one.
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # cuDNN autotuning may pick different kernels run to run; disable it so the
    # deterministic flag above actually yields reproducible results.
    torch.backends.cudnn.benchmark = False
seed_everything(seed=42)

# The 130 anonymised raw feature columns.
feat_cols = [f'feature_{n}' for n in range(130)]
# One binary label per response horizon.
target_cols = ['action', 'action_1', 'action_2', 'action_3', 'action_4']

if TRAIN:
    # Keep only rows with date > 85 and renumber the index from 0.
    train = train.loc[train.date > 85].reset_index(drop=True)
    # Label is 1 when the matching resp column is strictly positive, else 0.
    resp_names = ('resp', 'resp_1', 'resp_2', 'resp_3', 'resp_4')
    for resp_name, label_name in zip(resp_names, target_cols):
        train[label_name] = (train[resp_name] > 0).astype('int')

if TRAIN:
    # Column means of the (date-filtered) training frame, computed once and reused.
    col_means = train.mean()
    # Cache only the raw-feature means, in feat_cols order. The commented-out
    # original computed them over train[feat_cols]; averaging every column also
    # pulled in dates, weights, resp_* and labels, misaligning the vector that
    # the `else` branch reloads. Presumably the cache is used to fill NaNs in
    # feature vectors at inference time -- confirm against the consumer.
    f_mean = col_means[feat_cols].values
    np.save(f'{CACHE_PATH}/f_mean_online.npy', f_mean)
    # Fill missing values in every column with that column's mean.
    train.fillna(col_means, inplace=True)
else:
    f_mean = np.load(f'{CACHE_PATH}/f_mean_online.npy')

##### Making features
# https://www.kaggle.com/lucasmorin/running-algos-fe-for-fast-inference/data
# eda:https://www.kaggle.com/carlmcbrideellis/jane-street-eda-of-day-0-and-feature-importance
# his example:https://www.kaggle.com/gracewan/plot-model
def fillna_npwhere_njit(array, values):
    """Return ``array`` with NaN entries replaced by the matching entries of ``values``.

    Despite the name, no JIT decorator is applied here. A single ``sum`` is used as
    a cheap probe; when it is not NaN the input array object is returned as is.

    Args:
        array: float NumPy array that may contain NaNs.
        values: replacement values, broadcastable against ``array``.
    """
    if not np.isnan(array.sum()):
        return array
    return np.where(np.isnan(array), values, array)

class RunningEWMean:
    """Exponentially weighted running mean of a vector, carrying values forward over NaNs.

    Each ``push`` first replaces NaNs in the new vector with the last pushed
    (already filled) value at that position, which starts as zeros. It then
    updates ``s = alpha * x + (1 - alpha) * s``.

    Args:
        WIN_SIZE: window size; the smoothing factor is ``alpha = 2 / (WIN_SIZE + 1)``.
        n_size: vector length used for the zero initial state and the NaN carry-forward buffer.
        lt_mean: optional initial value for the running mean (used instead of zeros).
    """

    def __init__(self, WIN_SIZE=20, n_size=1, lt_mean=None):
        if lt_mean is not None:
            self.s = lt_mean
        else:
            self.s = np.zeros(n_size)
        self.past_value = np.zeros(n_size)
        self.alpha = 2 / (WIN_SIZE + 1)

    def clear(self):
        """Reset the running mean to zeros, keeping its shape.

        The previous ``self.s = 0`` made ``get_mean()`` return a scalar instead of
        a vector after a reset.
        """
        self.s = np.zeros_like(self.s, dtype=float)

    def push(self, x):
        """Fold the vector ``x`` into the running mean (NaNs carried forward first)."""
        x = fillna_npwhere_njit(x, self.past_value)
        self.past_value = x
        self.s = self.alpha * x + (1 - self.alpha) * self.s

    def get_mean(self):
        """Return the current running mean."""
        return self.s

# Model inputs: the raw features followed by two hand-crafted cross features.
# Defined unconditionally because MarketDataset and Model read all_feat_cols even
# when TRAIN is False (it used to exist only inside the TRAIN branch).
all_feat_cols = [col for col in feat_cols] + ['cross_41_42_43', 'cross_1_2']

if TRAIN:
    train['cross_41_42_43'] = train['feature_41'] + train['feature_42'] + train['feature_43']
    # Small offset avoids dividing by an exactly-zero feature_2.
    train['cross_1_2'] = train['feature_1'] / (train['feature_2'] + 1e-5)

##### Model&Data fnc
class SmoothBCEwLogits(_WeightedLoss):
    """Binary cross-entropy on logits with symmetric label smoothing.

    Targets are mapped ``t -> t * (1 - smoothing) + 0.5 * smoothing`` before the loss.

    Args:
        weight: optional per-element rescaling weight passed to
            ``F.binary_cross_entropy_with_logits``.
        reduction: ``'mean'``, ``'sum'`` or ``'none'`` (element-wise losses).
        smoothing: smoothing amount in ``[0, 1)``.
    """

    def __init__(self, weight=None, reduction='mean', smoothing=0.0):
        # _WeightedLoss already registers `weight` and stores `reduction`.
        super().__init__(weight=weight, reduction=reduction)
        self.smoothing = smoothing

    @staticmethod
    def _smooth(targets: torch.Tensor, n_labels: int, smoothing=0.0):
        """Return smoothed targets; ``n_labels`` is accepted but unused."""
        assert 0 <= smoothing < 1
        with torch.no_grad():
            targets = targets * (1.0 - smoothing) + 0.5 * smoothing
        return targets

    def forward(self, inputs, targets):
        targets = SmoothBCEwLogits._smooth(targets, inputs.size(-1), self.smoothing)
        # reduction='none' so that self.reduction below is what actually applies.
        # Previously the functional call's default 'mean' already collapsed the loss
        # to a scalar, so 'sum' and 'none' silently returned the mean.
        loss = F.binary_cross_entropy_with_logits(inputs, targets, self.weight, reduction='none')

        if self.reduction == 'sum':
            loss = loss.sum()
        elif self.reduction == 'mean':
            loss = loss.mean()

        return loss

class MarketDataset:
    """Map-style dataset returning float32 feature/label tensors for one DataFrame row.

    Args:
        df: source DataFrame.
        feature_cols: columns used as model inputs; defaults to ``all_feat_cols``.
        label_cols: columns used as targets; defaults to ``target_cols``.
    """

    def __init__(self, df, feature_cols=None, label_cols=None):
        # Defaults resolved at call time so MarketDataset(df) behaves as before.
        if feature_cols is None:
            feature_cols = all_feat_cols
        if label_cols is None:
            label_cols = target_cols
        self.features = df[feature_cols].values
        self.label = df[label_cols].values.reshape(-1, len(label_cols))

    def __len__(self):
        return len(self.label)

    def __getitem__(self, idx):
        return {
            'features': torch.tensor(self.features[idx], dtype=torch.float),
            'label': torch.tensor(self.label[idx], dtype=torch.float)
        }


class Model(nn.Module):
    """MLP with skip concatenations that outputs one logit per target.

    With ``block_i = dropout_i(LeakyReLU(batch_norm_i(dense_i(.))))``::

        x   = dropout0(batch_norm0(input))
        h1  = block_1(x)
        h2  = block_2([x, h1])
        h3  = block_3([h1, h2])
        h4  = block_4([h2, h3])
        out = dense5([h3, h4])

    Args:
        n_features: input width; defaults to ``len(all_feat_cols)``.
        n_targets: output width; defaults to ``len(target_cols)``.
        hidden_size: width of every hidden layer.
        dropout_rate: dropout probability after the input and every hidden layer.
    """

    def __init__(self, n_features=None, n_targets=None, hidden_size=128, dropout_rate=0.2):
        super(Model, self).__init__()
        # Resolved at construction so Model() still sizes itself from the module-level lists.
        if n_features is None:
            n_features = len(all_feat_cols)
        if n_targets is None:
            n_targets = len(target_cols)

        self.batch_norm0 = nn.BatchNorm1d(n_features)
        self.dropout0 = nn.Dropout(dropout_rate)

        self.dense1 = nn.Linear(n_features, hidden_size)
        self.batch_norm1 = nn.BatchNorm1d(hidden_size)
        self.dropout1 = nn.Dropout(dropout_rate)

        self.dense2 = nn.Linear(hidden_size + n_features, hidden_size)
        self.batch_norm2 = nn.BatchNorm1d(hidden_size)
        self.dropout2 = nn.Dropout(dropout_rate)

        self.dense3 = nn.Linear(hidden_size + hidden_size, hidden_size)
        self.batch_norm3 = nn.BatchNorm1d(hidden_size)
        self.dropout3 = nn.Dropout(dropout_rate)

        self.dense4 = nn.Linear(hidden_size + hidden_size, hidden_size)
        self.batch_norm4 = nn.BatchNorm1d(hidden_size)
        self.dropout4 = nn.Dropout(dropout_rate)

        self.dense5 = nn.Linear(hidden_size + hidden_size, n_targets)

        # Only LeakyReLU is used in forward(). The others remain registered submodules
        # (PReLU owns a learnable weight), so removing them would change the
        # state_dict keys and break loading of existing checkpoints.
        self.Relu = nn.ReLU(inplace=True)
        self.PReLU = nn.PReLU()
        self.LeakyReLU = nn.LeakyReLU(negative_slope=0.01, inplace=True)
        self.RReLU = nn.RReLU()

    def forward(self, x):
        x = self.batch_norm0(x)
        x = self.dropout0(x)

        x1 = self.dense1(x)
        x1 = self.batch_norm1(x1)
        x1 = self.LeakyReLU(x1)
        x1 = self.dropout1(x1)

        x = torch.cat([x, x1], 1)

        x2 = self.dense2(x)
        x2 = self.batch_norm2(x2)
        x2 = self.LeakyReLU(x2)
        x2 = self.dropout2(x2)

        x = torch.cat([x1, x2], 1)

        x3 = self.dense3(x)
        x3 = self.batch_norm3(x3)
        x3 = self.LeakyReLU(x3)
        x3 = self.dropout3(x3)

        x = torch.cat([x2, x3], 1)

        x4 = self.dense4(x)
        x4 = self.batch_norm4(x4)
        x4 = self.LeakyReLU(x4)
        x4 = self.dropout4(x4)

        x = torch.cat([x3, x4], 1)

        x = self.dense5(x)

        return x

def train_fn(model, optimizer, scheduler, loss_fn, dataloader, device):
    """Run one training epoch and return the average of the per-batch losses.

    Each batch is a mapping with ``'features'`` and ``'label'`` tensors. When
    ``scheduler`` is truthy it is stepped after every optimizer step.
    """
    model.train()
    running_total = 0

    for batch in dataloader:
        optimizer.zero_grad()
        inputs = batch['features'].to(device)
        targets = batch['label'].to(device)
        batch_loss = loss_fn(model(inputs), targets)
        batch_loss.backward()
        optimizer.step()
        if scheduler:
            scheduler.step()
        running_total += batch_loss.item()

    return running_total / len(dataloader)

def inference_fn(model, dataloader, device):
    """Return sigmoid probabilities for every batch, stacked into one NumPy array.

    The model is switched to eval mode and run without gradient tracking; the
    result is reshaped to ``(-1, len(target_cols))``.
    """
    model.eval()
    with torch.no_grad():
        batch_probs = [
            model(batch['features'].to(device)).sigmoid().cpu().numpy()
            for batch in dataloader
        ]
    stacked = np.concatenate(batch_probs)
    return stacked.reshape(-1, len(target_cols))

def utility_score_bincount(date, weight, resp, action):
    """Compute the utility score of a set of trade decisions.

    Daily PnL is ``p_i = sum(weight * resp * action)`` per date (via ``np.bincount``,
    so ``date`` must hold non-negative ints). With ``t = sum(p) / sqrt(sum(p**2)) *
    sqrt(250 / n_days)`` the utility is ``clip(t, 0, 6) * sum(p)``.

    Returns 0.0 when every daily PnL is zero (e.g. no trades, or empty input), where
    the formula is 0/0; previously that produced NaN or raised ZeroDivisionError.
    """
    count_i = len(np.unique(date))
    Pi = np.bincount(date, weight * resp * action)
    denom = np.sqrt(np.sum(Pi ** 2))
    if denom == 0:
        return 0.0
    t = np.sum(Pi) / denom * np.sqrt(250 / count_i)
    return np.clip(t, 0, 6) * np.sum(Pi)

In [2]:
# Index the frame by its integer `date` column as datetimes (t1 is built from this
# index below). pd.to_datetime reads the integers as nanoseconds since the epoch,
# e.g. day 86 -> 1970-01-01 00:00:00.000000086, so only the ordering is meaningful.
train.index = pd.to_datetime(train.date)


In [3]:
# Show the first rows after filtering, labelling and feature engineering.
train.head()

Unnamed: 0_level_0,date,weight,resp_1,resp_2,resp_3,resp_4,resp,feature_0,feature_1,feature_2,feature_3,feature_4,feature_5,feature_6,feature_7,feature_8,feature_9,feature_10,feature_11,feature_12,feature_13,feature_14,feature_15,feature_16,feature_17,feature_18,feature_19,feature_20,feature_21,feature_22,feature_23,feature_24,feature_25,feature_26,feature_27,feature_28,feature_29,feature_30,feature_31,feature_32,feature_33,feature_34,feature_35,feature_36,feature_37,feature_38,feature_39,feature_40,feature_41,feature_42,...,feature_88,feature_89,feature_90,feature_91,feature_92,feature_93,feature_94,feature_95,feature_96,feature_97,feature_98,feature_99,feature_100,feature_101,feature_102,feature_103,feature_104,feature_105,feature_106,feature_107,feature_108,feature_109,feature_110,feature_111,feature_112,feature_113,feature_114,feature_115,feature_116,feature_117,feature_118,feature_119,feature_120,feature_121,feature_122,feature_123,feature_124,feature_125,feature_126,feature_127,feature_128,feature_129,ts_id,action,action_1,action_2,action_3,action_4,cross_41_42_43,cross_1_2
date,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1,Unnamed: 22_level_1,Unnamed: 23_level_1,Unnamed: 24_level_1,Unnamed: 25_level_1,Unnamed: 26_level_1,Unnamed: 27_level_1,Unnamed: 28_level_1,Unnamed: 29_level_1,Unnamed: 30_level_1,Unnamed: 31_level_1,Unnamed: 32_level_1,Unnamed: 33_level_1,Unnamed: 34_level_1,Unnamed: 35_level_1,Unnamed: 36_level_1,Unnamed: 37_level_1,Unnamed: 38_level_1,Unnamed: 39_level_1,Unnamed: 40_level_1,Unnamed: 41_level_1,Unnamed: 42_level_1,Unnamed: 43_level_1,Unnamed: 44_level_1,Unnamed: 45_level_1,Unnamed: 46_level_1,Unnamed: 47_level_1,Unnamed: 48_level_1,Unnamed: 49_level_1,Unnamed: 50_level_1,Unnamed: 51_level_1,Unnamed: 52_level_1,Unnamed: 53_level_1,Unnamed: 54_level_1,Unnamed: 55_level_1,Unnamed: 56_level_1,Unnamed: 57_level_1,Unnamed: 58_level_1,Unnamed: 59_level_1,Unnamed: 60_level_1,Unnamed: 61_level_1,Unnamed: 62_level_1,Unnamed: 63_level_1,Unnamed: 64_level_1,Unnamed: 65_level_1,Unnamed: 66_level_1,Unnamed: 67_level_1,Unnamed: 68_level_1,Unnamed: 69_level_1,Unnamed: 70_level_1,Unnamed: 71_level_1,Unnamed: 72_level_1,Unnamed: 73_level_1,Unnamed: 74_level_1,Unnamed: 75_level_1,Unnamed: 76_level_1,Unnamed: 77_level_1,Unnamed: 78_level_1,Unnamed: 79_level_1,Unnamed: 80_level_1,Unnamed: 81_level_1,Unnamed: 82_level_1,Unnamed: 83_level_1,Unnamed: 84_level_1,Unnamed: 85_level_1,Unnamed: 86_level_1,Unnamed: 87_level_1,Unnamed: 88_level_1,Unnamed: 89_level_1,Unnamed: 90_level_1,Unnamed: 91_level_1,Unnamed: 92_level_1,Unnamed: 93_level_1,Unnamed: 94_level_1,Unnamed: 95_level_1,Unnamed: 96_level_1,Unnamed: 97_level_1,Unnamed: 98_level_1,Unnamed: 99_level_1,Unnamed: 
100_level_1,Unnamed: 101_level_1
1970-01-01 00:00:00.000000086,86,0.859516,-0.003656,-0.005449,-0.017403,-0.028896,-0.021435,1,3.151305,5.467693,-0.164505,-0.189219,0.663966,0.988896,0.035757,0.026819,2.184804,3.278742,0.069354,0.060242,1.471544,2.39143,1.640887,3.938759,0.110741,0.114943,2.361346,4.71164,0.167676,0.177819,1.958027,4.069699,2.535238,4.813858,0.12584,0.16081,-0.194392,-0.336857,0.207271,0.252054,-0.073242,-0.131142,-0.197839,-0.288336,0.341815,0.599994,-0.202268,-0.471068,-0.405654,0.05244,...,-1.64186,-2.060506,0.273385,-1.515613,0.315019,-1.746285,-1.086886,-3.38807,0.401428,1.607253,0.420753,0.199793,0.969663,-2.434601,0.244741,0.227364,0.320794,-0.633981,-0.122468,-4.349793,0.392752,0.322244,0.417838,-0.458309,-0.03274,-3.018269,0.262441,-1.219454,0.288197,-2.608786,-1.611309,-2.724954,0.241291,0.254985,2.433699,4.282284,1.621115,4.33103,2.55322,3.799011,2.642943,3.998054,527894,0,0,0,0,0,-1.702478,0.576349
1970-01-01 00:00:00.000000086,86,0.0,-0.009107,-0.013542,-0.022222,-0.032522,-0.026394,1,2.249176,2.618401,-0.304355,-0.276975,-0.035921,-0.036215,0.035757,0.026819,3.354857,3.040463,0.069354,0.060242,2.36505,2.376956,2.337125,3.438553,0.110741,0.114943,3.041641,4.165903,0.167676,0.177819,2.889146,4.174374,3.234317,4.276899,0.12584,0.16081,-1.644735,-2.479335,0.207271,0.252054,-1.321317,-1.491122,-2.478752,-2.496164,0.396227,0.435508,-0.248213,-0.439213,-0.993568,3.075146,...,-1.64186,-2.579694,0.273385,-1.515613,0.315019,-1.746285,-1.086886,-4.781603,0.401428,2.557578,0.420753,0.665543,1.704761,-1.965635,0.244741,-0.079505,0.320794,-0.857492,-0.512759,-4.546557,0.392752,1.275872,0.417838,-0.054892,0.872509,-3.120828,0.262441,-1.881751,0.288197,-3.280218,-2.261787,-3.617442,0.241291,0.254985,2.053416,-0.493276,1.661974,-1.082122,2.427706,-0.756115,2.210572,-0.639075,527895,0,0,0,0,0,2.098619,0.858985
1970-01-01 00:00:00.000000086,86,0.590949,0.000347,-0.000376,-0.004051,-0.007995,-0.004743,-1,-0.365888,0.824004,-0.293208,-0.416391,-0.599185,-0.99733,0.035757,0.026819,-0.86933,0.174646,0.069354,0.060242,-2.376733,-2.602154,-0.580833,0.145479,0.110741,0.114943,-0.440224,-0.943834,0.167676,0.177819,-1.842764,-3.478558,-0.506549,-1.058953,0.12584,0.16081,0.539967,1.481719,0.207271,0.252054,0.533328,1.164644,0.958275,1.93693,-0.550514,-0.9267,0.055286,0.153123,1.277755,-2.542437,...,-0.449674,-2.954607,0.273385,0.092611,0.315019,0.372403,-0.066319,-2.740989,0.401428,-1.338859,0.420753,-0.018706,-0.52289,-2.132602,0.244741,-0.94019,0.320794,1.03441,-0.590374,-2.151733,0.392752,-2.4713,0.417838,-0.385969,-2.290683,-3.531129,0.262441,-1.673329,0.288197,1.017174,-1.059342,-1.723941,0.241291,0.254985,-0.702873,4.038753,-0.789767,4.133183,-1.207878,3.402796,-0.92829,3.511141,527896,0,1,0,0,0,-3.338676,-0.444031
1970-01-01 00:00:00.000000086,86,0.172997,0.000168,0.000333,-0.002375,-0.003064,0.001527,1,1.514607,0.596214,0.324062,0.15473,0.845069,0.521491,0.035757,0.026819,0.310387,-0.379196,0.069354,0.060242,0.866451,0.148476,0.197457,-0.516572,0.110741,0.114943,1.025831,0.704435,0.167676,0.177819,1.691567,1.379021,1.111965,0.682265,0.12584,0.16081,-0.635982,-0.525029,0.207271,0.252054,-0.458078,-0.246643,-0.916675,-0.48224,0.590027,0.381223,-0.03372,-0.019842,-0.368249,1.269972,...,-1.64186,-3.581284,0.273385,-1.515613,0.315019,-1.746285,-1.086886,-4.438488,0.401428,-0.140773,0.420753,-0.762597,-0.409249,-2.698973,0.244741,-0.252248,0.320794,-0.906413,-0.748366,-3.765935,0.392752,-2.338233,0.417838,-1.568599,-2.851826,-4.75762,0.262441,-2.294113,0.288197,-3.416992,-2.645002,-2.973197,0.241291,0.254985,2.304354,1.530169,3.596848,4.613493,4.51611,3.341374,2.635798,1.535235,527897,1,1,1,0,0,2.698305,2.540332
1970-01-01 00:00:00.000000086,86,0.0,0.000503,0.000589,-0.001587,-0.002665,-0.000139,-1,-1.158576,-0.146579,-0.035525,-0.008082,-0.152708,-0.251737,0.035757,0.026819,-0.420522,0.883579,0.069354,0.060242,-1.256021,-0.94303,-0.052071,1.562831,0.110741,0.114943,-0.351357,-0.788583,0.167676,0.177819,-0.282124,-0.737988,-0.388627,-0.867531,0.12584,0.16081,0.793271,2.338712,0.207271,0.252054,1.828729,4.012945,1.53447,3.294522,-0.193584,-0.327048,0.195077,0.648519,3.548811,-2.674746,...,-0.145292,-2.477622,0.273385,0.245049,0.315019,1.245978,0.3452,-2.195447,0.401428,-0.353913,0.420753,0.585657,0.138084,-1.916957,0.244741,-0.262477,0.320794,2.252213,1.328353,-1.874398,0.392752,-1.951431,0.417838,0.21278,-0.854314,-3.008724,0.262441,-1.004068,0.288197,2.417244,0.80685,-1.460248,0.241291,0.254985,-1.364269,2.006926,-1.237922,2.124396,-2.005723,1.716914,-1.646484,1.795211,527898,0,1,1,0,0,-1.624481,7.904661


In [4]:
# Count how many rows share each combination of the five binary targets.
train[target_cols].value_counts()

action  action_1  action_2  action_3  action_4
1       1         1         1         1           493027
0       0         0         0         0           479572
        1         1         0         0           136574
1       0         0         1         1           134593
0       1         1         1         0            90531
1       0         0         0         1            86953
                  1         1         1            74088
0       1         0         0         0            71471
1       1         1         1         0            33437
0       0         0         0         1            32760
                  1         0         0            26409
1       1         0         1         1            25624
0       0         0         1         1            23334
1       1         1         0         0            23319
0       0         0         1         0            18691
1       1         1         0         1            18689
0       1         1         1         1  

In [5]:
import itertools as itt

def cpcv_generator(t_span, n, k, verbose=True):
    """Build the combinatorial purged cross-validation (CPCV) layout.

    The ``t_span`` observations are cut into ``n`` contiguous groups of
    ``t_span // n`` rows (leftover rows join the last group). Every one of the
    C(n, k) combinations of ``k`` groups is the test set of one "simulation". Test
    groups from different simulations are then chained into
    ``C(n, k) * k // n`` backtest paths, each covering all ``n`` groups once.

    Args:
        t_span: number of observations.
        n: number of groups; must satisfy ``1 <= n <= t_span``.
        k: number of test groups per simulation.
        verbose: print the number of simulations and paths.

    Returns:
        is_test: bool array ``(t_span, C(n, k))``; column ``s`` marks the test rows of
            simulation ``s``.
        paths: float array ``(t_span, n_paths)``; ``paths[row, p]`` is the simulation
            whose test predictions cover ``row`` on path ``p``.
        path_folds: float array ``(n, n_paths)``; ``path_folds[g, p]`` is the
            simulation that supplies group ``g`` on path ``p``.

    Raises:
        ValueError: if ``n`` is not in ``[1, t_span]`` (group size would be zero).
    """
    if not 1 <= n <= t_span:
        raise ValueError(f"need 1 <= n <= t_span, got n={n}, t_span={t_span}")

    # Group id per row. np.minimum folds *every* leftover row into the last group;
    # the old `group_num == n` fix missed ids above n when t_span % n > t_span // n.
    group_num = np.minimum(np.arange(t_span) // (t_span // n), n - 1)

    # All combinations of k test groups, one per simulation.
    test_groups = np.array(list(itt.combinations(np.arange(n), k))).reshape(-1, k)
    C_nk = len(test_groups)
    n_paths = C_nk * k // n

    if verbose:
        print('n_sim:', C_nk)
        print('n_paths:', n_paths)

    # is_test_group[g, s] / is_test[row, s]: group g / row is in the test set of simulation s.
    is_test_group = np.full((n, C_nk), fill_value=False)
    is_test = np.full((t_span, C_nk), fill_value=False)
    for sim, groups in enumerate(test_groups):
        # Works for any k (the previous `i, j = pair` unpacking assumed k == 2).
        is_test_group[groups, sim] = True
        is_test[np.isin(group_num, groups), sim] = True

    # For each path, give every group the earliest simulation that tests it and has
    # not been used by an earlier path; each group is tested in exactly n_paths sims.
    path_folds = np.full((n, n_paths), fill_value=np.nan)
    for p in range(n_paths):
        for g in range(n):
            sim = is_test_group[g, :].argmax()
            path_folds[g, p] = sim
            is_test_group[g, sim] = False

    # Expand the group-level assignment to every row of the group.
    paths = np.full((t_span, n_paths), fill_value=np.nan)
    for p in range(n_paths):
        for g in range(n):
            paths[group_num == g, p] = int(path_folds[g, p])

    return (is_test, paths, path_folds)

In [6]:
# AFML, snippet 7.1
from tqdm import tqdm

def purge(t1, test_times):
    """Remove from ``t1`` every training observation whose span overlaps a test interval.

    Implements the three overlap cases of AFML snippet 7.1: the observation starts
    inside a test interval, ends inside it, or envelops it.

    Args:
        t1: Series mapping each observation start time (index) to its end time (values).
        test_times: Series mapping each test-interval start (index) to its end (values).

    Returns:
        ``t1`` with duplicate values dropped and all overlapping observations removed;
        the remaining index is the set of usable training times.
    """
    train_ = t1.copy(deep=True).drop_duplicates()
    # Series.items(): Series.iteritems() was removed in pandas 2.0.
    for start, end in tqdm(test_times.items(), total=len(test_times)):
        starts_within = train_[(start <= train_.index) & (train_.index <= end)].index
        ends_within = train_[(start <= train_) & (train_ <= end)].index
        envelops = train_[(train_.index <= start) & (end <= train_)].index
        train_ = train_.drop(starts_within.union(ends_within).union(envelops))
    return train_

# AFML, snippet 7.2
def embargo_(times, pct_embargo):
    """Shift every end time forward by ``step = int(len(times) * pct_embargo)`` observations.

    Position ``i`` maps to the value at position ``i + step``; the last ``step``
    positions map to the final value (AFML snippet 7.2). Prints the step size.

    Args:
        times: Series of end times indexed by start times (the index may repeat).
        pct_embargo: fraction of observations to embargo after each one.

    Returns:
        Series with the same index as ``times`` holding the embargoed end times.
    """
    step = int(times.shape[0] * pct_embargo)  # more complicated logic if needed to use a time delta
    print('step:', step)
    if step == 0:
        # No embargo: each time maps to itself. Built from `.values` so a repeated
        # index is never used for reindexing (the old branch referenced an undefined
        # name `test_times`).
        return pd.Series(times.values, index=times.index)
    shifted = pd.Series(times.iloc[step:].values, index=times.index[:-step])
    tail = pd.Series(times.iloc[-1], index=times.index[-step:])
    # pd.concat replaces Series.append, which was removed in pandas 2.0.
    return pd.concat([shifted, tail])

def embargo(test_times, t1, pct_embargo=0.01): # done before purging
    """Return the embargoed end times of the test observations.

    Args:
        test_times: Series whose index holds the test observation start times.
        t1: Series of end times for all observations, indexed by start time.
        pct_embargo: fraction of observations embargoed, passed to ``embargo_``.

    Returns:
        Rows of the embargoed ``t1`` selected by the index of ``test_times``; every
        matching row is returned when an index label repeats.
    """
    t1_embargo = embargo_(t1, pct_embargo)
    return t1_embargo.loc[test_times.index]

In [7]:
# Re-display the first rows before building the CPCV inputs.
train.head()

Unnamed: 0_level_0,date,weight,resp_1,resp_2,resp_3,resp_4,resp,feature_0,feature_1,feature_2,feature_3,feature_4,feature_5,feature_6,feature_7,feature_8,feature_9,feature_10,feature_11,feature_12,feature_13,feature_14,feature_15,feature_16,feature_17,feature_18,feature_19,feature_20,feature_21,feature_22,feature_23,feature_24,feature_25,feature_26,feature_27,feature_28,feature_29,feature_30,feature_31,feature_32,feature_33,feature_34,feature_35,feature_36,feature_37,feature_38,feature_39,feature_40,feature_41,feature_42,...,feature_88,feature_89,feature_90,feature_91,feature_92,feature_93,feature_94,feature_95,feature_96,feature_97,feature_98,feature_99,feature_100,feature_101,feature_102,feature_103,feature_104,feature_105,feature_106,feature_107,feature_108,feature_109,feature_110,feature_111,feature_112,feature_113,feature_114,feature_115,feature_116,feature_117,feature_118,feature_119,feature_120,feature_121,feature_122,feature_123,feature_124,feature_125,feature_126,feature_127,feature_128,feature_129,ts_id,action,action_1,action_2,action_3,action_4,cross_41_42_43,cross_1_2
date,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1,Unnamed: 22_level_1,Unnamed: 23_level_1,Unnamed: 24_level_1,Unnamed: 25_level_1,Unnamed: 26_level_1,Unnamed: 27_level_1,Unnamed: 28_level_1,Unnamed: 29_level_1,Unnamed: 30_level_1,Unnamed: 31_level_1,Unnamed: 32_level_1,Unnamed: 33_level_1,Unnamed: 34_level_1,Unnamed: 35_level_1,Unnamed: 36_level_1,Unnamed: 37_level_1,Unnamed: 38_level_1,Unnamed: 39_level_1,Unnamed: 40_level_1,Unnamed: 41_level_1,Unnamed: 42_level_1,Unnamed: 43_level_1,Unnamed: 44_level_1,Unnamed: 45_level_1,Unnamed: 46_level_1,Unnamed: 47_level_1,Unnamed: 48_level_1,Unnamed: 49_level_1,Unnamed: 50_level_1,Unnamed: 51_level_1,Unnamed: 52_level_1,Unnamed: 53_level_1,Unnamed: 54_level_1,Unnamed: 55_level_1,Unnamed: 56_level_1,Unnamed: 57_level_1,Unnamed: 58_level_1,Unnamed: 59_level_1,Unnamed: 60_level_1,Unnamed: 61_level_1,Unnamed: 62_level_1,Unnamed: 63_level_1,Unnamed: 64_level_1,Unnamed: 65_level_1,Unnamed: 66_level_1,Unnamed: 67_level_1,Unnamed: 68_level_1,Unnamed: 69_level_1,Unnamed: 70_level_1,Unnamed: 71_level_1,Unnamed: 72_level_1,Unnamed: 73_level_1,Unnamed: 74_level_1,Unnamed: 75_level_1,Unnamed: 76_level_1,Unnamed: 77_level_1,Unnamed: 78_level_1,Unnamed: 79_level_1,Unnamed: 80_level_1,Unnamed: 81_level_1,Unnamed: 82_level_1,Unnamed: 83_level_1,Unnamed: 84_level_1,Unnamed: 85_level_1,Unnamed: 86_level_1,Unnamed: 87_level_1,Unnamed: 88_level_1,Unnamed: 89_level_1,Unnamed: 90_level_1,Unnamed: 91_level_1,Unnamed: 92_level_1,Unnamed: 93_level_1,Unnamed: 94_level_1,Unnamed: 95_level_1,Unnamed: 96_level_1,Unnamed: 97_level_1,Unnamed: 98_level_1,Unnamed: 99_level_1,Unnamed: 
100_level_1,Unnamed: 101_level_1
1970-01-01 00:00:00.000000086,86,0.859516,-0.003656,-0.005449,-0.017403,-0.028896,-0.021435,1,3.151305,5.467693,-0.164505,-0.189219,0.663966,0.988896,0.035757,0.026819,2.184804,3.278742,0.069354,0.060242,1.471544,2.39143,1.640887,3.938759,0.110741,0.114943,2.361346,4.71164,0.167676,0.177819,1.958027,4.069699,2.535238,4.813858,0.12584,0.16081,-0.194392,-0.336857,0.207271,0.252054,-0.073242,-0.131142,-0.197839,-0.288336,0.341815,0.599994,-0.202268,-0.471068,-0.405654,0.05244,...,-1.64186,-2.060506,0.273385,-1.515613,0.315019,-1.746285,-1.086886,-3.38807,0.401428,1.607253,0.420753,0.199793,0.969663,-2.434601,0.244741,0.227364,0.320794,-0.633981,-0.122468,-4.349793,0.392752,0.322244,0.417838,-0.458309,-0.03274,-3.018269,0.262441,-1.219454,0.288197,-2.608786,-1.611309,-2.724954,0.241291,0.254985,2.433699,4.282284,1.621115,4.33103,2.55322,3.799011,2.642943,3.998054,527894,0,0,0,0,0,-1.702478,0.576349
1970-01-01 00:00:00.000000086,86,0.0,-0.009107,-0.013542,-0.022222,-0.032522,-0.026394,1,2.249176,2.618401,-0.304355,-0.276975,-0.035921,-0.036215,0.035757,0.026819,3.354857,3.040463,0.069354,0.060242,2.36505,2.376956,2.337125,3.438553,0.110741,0.114943,3.041641,4.165903,0.167676,0.177819,2.889146,4.174374,3.234317,4.276899,0.12584,0.16081,-1.644735,-2.479335,0.207271,0.252054,-1.321317,-1.491122,-2.478752,-2.496164,0.396227,0.435508,-0.248213,-0.439213,-0.993568,3.075146,...,-1.64186,-2.579694,0.273385,-1.515613,0.315019,-1.746285,-1.086886,-4.781603,0.401428,2.557578,0.420753,0.665543,1.704761,-1.965635,0.244741,-0.079505,0.320794,-0.857492,-0.512759,-4.546557,0.392752,1.275872,0.417838,-0.054892,0.872509,-3.120828,0.262441,-1.881751,0.288197,-3.280218,-2.261787,-3.617442,0.241291,0.254985,2.053416,-0.493276,1.661974,-1.082122,2.427706,-0.756115,2.210572,-0.639075,527895,0,0,0,0,0,2.098619,0.858985
1970-01-01 00:00:00.000000086,86,0.590949,0.000347,-0.000376,-0.004051,-0.007995,-0.004743,-1,-0.365888,0.824004,-0.293208,-0.416391,-0.599185,-0.99733,0.035757,0.026819,-0.86933,0.174646,0.069354,0.060242,-2.376733,-2.602154,-0.580833,0.145479,0.110741,0.114943,-0.440224,-0.943834,0.167676,0.177819,-1.842764,-3.478558,-0.506549,-1.058953,0.12584,0.16081,0.539967,1.481719,0.207271,0.252054,0.533328,1.164644,0.958275,1.93693,-0.550514,-0.9267,0.055286,0.153123,1.277755,-2.542437,...,-0.449674,-2.954607,0.273385,0.092611,0.315019,0.372403,-0.066319,-2.740989,0.401428,-1.338859,0.420753,-0.018706,-0.52289,-2.132602,0.244741,-0.94019,0.320794,1.03441,-0.590374,-2.151733,0.392752,-2.4713,0.417838,-0.385969,-2.290683,-3.531129,0.262441,-1.673329,0.288197,1.017174,-1.059342,-1.723941,0.241291,0.254985,-0.702873,4.038753,-0.789767,4.133183,-1.207878,3.402796,-0.92829,3.511141,527896,0,1,0,0,0,-3.338676,-0.444031
1970-01-01 00:00:00.000000086,86,0.172997,0.000168,0.000333,-0.002375,-0.003064,0.001527,1,1.514607,0.596214,0.324062,0.15473,0.845069,0.521491,0.035757,0.026819,0.310387,-0.379196,0.069354,0.060242,0.866451,0.148476,0.197457,-0.516572,0.110741,0.114943,1.025831,0.704435,0.167676,0.177819,1.691567,1.379021,1.111965,0.682265,0.12584,0.16081,-0.635982,-0.525029,0.207271,0.252054,-0.458078,-0.246643,-0.916675,-0.48224,0.590027,0.381223,-0.03372,-0.019842,-0.368249,1.269972,...,-1.64186,-3.581284,0.273385,-1.515613,0.315019,-1.746285,-1.086886,-4.438488,0.401428,-0.140773,0.420753,-0.762597,-0.409249,-2.698973,0.244741,-0.252248,0.320794,-0.906413,-0.748366,-3.765935,0.392752,-2.338233,0.417838,-1.568599,-2.851826,-4.75762,0.262441,-2.294113,0.288197,-3.416992,-2.645002,-2.973197,0.241291,0.254985,2.304354,1.530169,3.596848,4.613493,4.51611,3.341374,2.635798,1.535235,527897,1,1,1,0,0,2.698305,2.540332
1970-01-01 00:00:00.000000086,86,0.0,0.000503,0.000589,-0.001587,-0.002665,-0.000139,-1,-1.158576,-0.146579,-0.035525,-0.008082,-0.152708,-0.251737,0.035757,0.026819,-0.420522,0.883579,0.069354,0.060242,-1.256021,-0.94303,-0.052071,1.562831,0.110741,0.114943,-0.351357,-0.788583,0.167676,0.177819,-0.282124,-0.737988,-0.388627,-0.867531,0.12584,0.16081,0.793271,2.338712,0.207271,0.252054,1.828729,4.012945,1.53447,3.294522,-0.193584,-0.327048,0.195077,0.648519,3.548811,-2.674746,...,-0.145292,-2.477622,0.273385,0.245049,0.315019,1.245978,0.3452,-2.195447,0.401428,-0.353913,0.420753,0.585657,0.138084,-1.916957,0.244741,-0.262477,0.320794,2.252213,1.328353,-1.874398,0.392752,-1.951431,0.417838,0.21278,-0.854314,-3.008724,0.262441,-1.004068,0.288197,2.417244,0.80685,-1.460248,0.241291,0.254985,-1.364269,2.006926,-1.237922,2.124396,-2.005723,1.716914,-1.646484,1.795211,527898,0,1,1,0,0,-1.624481,7.904661


In [8]:
# Prediction and evaluation times used for purging and embargoing.
# t1 maps each observation's start time (index) to the time its outcome is known
# (values). Here both are the observation's own date-derived timestamp, so each
# observation is treated as resolved at the moment it is made (no holding period).
t1_ = train.index
t1 = pd.Series(t1_[:], index=t1_[:]) # t1 is both the trade time and the event time
t1.head() # start and end times coincide

date
1970-01-01 00:00:00.000000086   1970-01-01 00:00:00.000000086
1970-01-01 00:00:00.000000086   1970-01-01 00:00:00.000000086
1970-01-01 00:00:00.000000086   1970-01-01 00:00:00.000000086
1970-01-01 00:00:00.000000086   1970-01-01 00:00:00.000000086
1970-01-01 00:00:00.000000086   1970-01-01 00:00:00.000000086
Name: date, dtype: datetime64[ns]

In [9]:
# CPCV layout: num_groups contiguous groups of rows, num_groups_test of which are
# held out in each simulation. With 2 test groups there are
# C(n, 2) * 2 / n = n - 1 backtest paths, hence num_groups = num_paths + 1.
num_paths = 5
num_groups_test = 2
num_groups = num_paths + 1 
num_ticks = len(train)
is_test, paths, _ = cpcv_generator(num_ticks, num_groups, num_groups_test)

n_sim: 15
n_paths: 5


In [10]:
# One simulation (one trained model) per column of is_test,
# i.e. C(num_groups, num_groups_test) of them.
num_sim = is_test.shape[1] # num of simulations needed to generate all backtest paths
print(num_sim)

15


In [None]:
def _pick_device():
    """Return ``cuda:0`` when CUDA is available, otherwise the CPU device."""
    return torch.device("cuda:0" if torch.cuda.is_available() else "cpu")


def _score_predictions(valid, valid_pred):
    """Score per-target probabilities against a validation frame.

    Args:
        valid: validation rows; must contain ``target_cols``, ``date``, ``weight``, ``resp``.
        valid_pred: probabilities with one column per entry of ``target_cols``.

    Returns:
        ``(utility, auc, logloss)``. The utility uses the median probability across
        targets, thresholded at 0.5, as the trade / no-trade action.
    """
    auc = roc_auc_score(valid[target_cols].values, valid_pred)
    logloss = log_loss(valid[target_cols].values, valid_pred)
    action = np.where(np.median(valid_pred, axis=1) >= 0.5, 1, 0).astype(int)
    utility = utility_score_bincount(date=valid.date.values, weight=valid.weight.values,
                                     resp=valid.resp.values, action=action)
    return utility, auc, logloss


def run():
    """Train one model per CPCV simulation and report out-of-sample scores.

    For every simulation (column of ``is_test``):
    - the test dates are embargoed and purged from the training set;
    - a fresh ``Model`` is trained with early stopping on validation AUC, and its
      best checkpoint is written to ``{CACHE_PATH}/online_model{sim}.pth``;
    - that checkpoint is re-scored on the simulation's own held-out rows.

    The mean of these out-of-sample scores is printed at the end.
    """
    torch.multiprocessing.freeze_support()
    # Fall back to the CPU instead of failing when no GPU is present.
    device = _pick_device()
    oos_scores = []

    for _fold in range(num_sim):
        test_idx = is_test[:, _fold]
        test_times = t1.loc[test_idx].drop_duplicates()

        # embargo
        test_times_embargoed = embargo(test_times, t1, pct_embargo=0.01).drop_duplicates()

        # purge
        train_times = purge(t1, test_times_embargoed)

        valid = train.loc[test_times.index, :]
        train_set = MarketDataset(train.loc[train_times.index, :])
        train_loader = DataLoader(train_set, batch_size=BATCH_SIZE, shuffle=True, num_workers=1)
        # Reuse the frame selected above instead of repeating the same .loc.
        valid_set = MarketDataset(valid)
        valid_loader = DataLoader(valid_set, batch_size=BATCH_SIZE, shuffle=False, num_workers=1)
        start_time = time.time()

        print(len(train_set))
        print(len(valid_set))

        print(f'Fold: {_fold}')
        seed_everything(seed=42+_fold)
        torch.cuda.empty_cache()
        model = Model()
        model.to(device)

        optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)
        scheduler = None
        loss_fn = SmoothBCEwLogits(smoothing=0.005)

        model_weights = f"{CACHE_PATH}/online_model{_fold}.pth"
        es = EarlyStopping(patience=EARLYSTOP_NUM, mode="max")
        for epoch in range(EPOCHS):
            train_loss = train_fn(model, optimizer, scheduler, loss_fn, train_loader, device)

            valid_pred = inference_fn(model, valid_loader, device)
            valid_u_score, valid_auc, _ = _score_predictions(valid, valid_pred)
            print(f"FOLD{_fold} EPOCH:{epoch:3} train_loss={train_loss:.5f} "
                    f"valid_u_score={valid_u_score:.5f} valid_auc={valid_auc:.5f} "
                    f"time: {(time.time() - start_time) / 60:.2f}min")
            es(valid_auc, model, model_path=model_weights)
            if es.early_stop:
                print("Early stopping")
                break

        # Score the best checkpoint on this simulation's own held-out rows. Averaging
        # several simulations' models on one validation set would be in-sample: in
        # CPCV only the simulation that held those groups out has not trained on them.
        if os.path.exists(model_weights):
            model.load_state_dict(torch.load(model_weights, map_location=device))
        valid_pred = inference_fn(model, valid_loader, device)
        oos_scores.append(_score_predictions(valid, valid_pred))

    if oos_scores:
        mean_u, mean_auc, mean_logloss = np.mean(np.asarray(oos_scores, dtype=float), axis=0)
        print(f'{len(oos_scores)} models mean out-of-sample valid score: {mean_u}\t'
              f'auc_score: {mean_auc:.4f}\tlogloss_score:{mean_logloss:.4f}')

if __name__ == '__main__':
    run()

step: 18625


100%|███████████████████████████████████████████████████████████████████████████████| 159/159 [00:00<00:00, 649.58it/s]


Fold0:
