In [1]:
import os
# Print the full path of every file available under the Kaggle input directory.
for root, _subdirs, names in os.walk('/kaggle/input'):
    for path in (os.path.join(root, name) for name in names):
        print(path)

/kaggle/input/riiid-test-answer-prediction/example_sample_submission.csv
/kaggle/input/riiid-test-answer-prediction/example_test.csv
/kaggle/input/riiid-test-answer-prediction/questions.csv
/kaggle/input/riiid-test-answer-prediction/train.csv
/kaggle/input/riiid-test-answer-prediction/lectures.csv
/kaggle/input/riiid-test-answer-prediction/riiideducation/competition.cpython-37m-x86_64-linux-gnu.so
/kaggle/input/riiid-test-answer-prediction/riiideducation/__init__.py
/kaggle/input/mamastan-gpu-v27/models/my_encoder_exp184_best.pth
/kaggle/input/mamastan-gpu-v27/models/my_encoder_exp218_best.pth
/kaggle/input/mamastan-gpu-v27/models/my_encoder_exp233_best.pth
/kaggle/input/mamastan-gpu-v27/models/my_encoder_exp166_best.pth
/kaggle/input/mamastan-gpu-v27/models/my_encoder_exp224_best.pth
/kaggle/input/mamastan-gpu-v27/models/my_encoder_exp219_best.pth
/kaggle/input/mamastan-gpu-v27/models/my_encoder_exp248_best.pth
/kaggle/input/mamastan-gpu-v27/models/my_encoder_exp162_best.pth
/kaggle/i

In [2]:
#libs
from copy import deepcopy
import pickle
import numpy as np
import pandas as pd
import os
from tqdm.notebook import tqdm
import types
import ast
import gc
def imports():
    """Yield ``(name, value)`` for every module and callable in this module's globals.

    A value that is both a module and callable is yielded twice; ``dict()``
    collapses the duplicate, which is how ``noglobal`` consumes the result.
    """
    for name, val in globals().items():
        # module imports
        if isinstance(val, types.ModuleType):
            yield name, val
        # functions / callables
        if hasattr(val, '__call__'):
            yield name, val
np.seterr(divide='ignore', invalid='ignore')
def noglobal(fn):
    """Decorator: rebuild ``fn`` so that it can only see modules and callables.

    The returned function shares ``fn``'s code object but gets a fresh globals
    dict built from ``imports()`` at decoration time, so plain global variables
    (DataFrames, arrays, constants, ...) are not reachable from inside it and
    must be passed in as arguments. Names defined after decoration are not
    visible either, because the dict is a snapshot.

    Positional defaults, keyword-only defaults and the closure of ``fn`` are
    carried over; building the function from code and globals alone would drop
    them and make calls that rely on a default fail with ``TypeError``.
    """
    isolated = types.FunctionType(fn.__code__, dict(imports()), fn.__name__,
                                  fn.__defaults__, fn.__closure__)
    isolated.__kwdefaults__ = fn.__kwdefaults__
    return isolated
import collections
from scipy.sparse import lil_matrix
import scipy.sparse
%load_ext Cython
from itertools import chain
from IPython.display import display, HTML
import lightgbm as lgb
from pprint import pprint
from sklearn.metrics import roc_auc_score
import warnings
warnings.filterwarnings('ignore', 'Mean of empty slice')

#for GPU
import torch
from torch import nn
import torch.utils.data as torchdata
import torch.nn.functional as F
import math
from collections import OrderedDict

In [3]:
#utils
class RiiidEnv:
    """Offline stand-in for the competition's iterate-then-predict loop.

    ``sn`` is a DataFrame indexed by group number with at least the columns
    ``row_id`` and ``content_type_id``. ``sub`` holds one prediction row
    (initialised to 0.5) for every row of ``sn`` whose ``content_type_id`` is 0.
    Unless ``iterate_wo_predict`` is true, ``predict`` must be called once per
    yielded group before the next group can be requested.
    """

    def __init__(self, sn, iterate_wo_predict = False):
        self.sn = sn
        question_rows = sn.loc[sn['content_type_id'] == 0]
        self.sub = question_rows[['row_id']].copy()
        self.sub['answered_correctly'] = 0.5
        self.can_yield = True
        self.iterate_wo_predict = iterate_wo_predict
        self.num_groups = self.sn.index.max() + 1

    def iter_test(self):
        """Yield ``(group_rows, submission_rows_or_None)`` for group 0..num_groups-1."""
        for group in range(self.num_groups):
            self.i = group
            # The previous group must have been answered via predict().
            assert self.can_yield
            if not self.iterate_wo_predict:
                self.can_yield = False

            has_questions = group in self.sub.index
            yield self.sn.loc[[group]], (self.sub.loc[[group]] if has_questions else None)

    def predict(self, my_sub):
        """Validate ``my_sub`` for the current group and store it in ``self.sub``."""
        assert my_sub['row_id'].dtype == 'int64'
        assert my_sub['answered_correctly'].dtype == 'float64'
        assert my_sub.index.name == 'group_num'
        assert my_sub.index.dtype == 'int64'

        if self.i not in self.sub.index:
            # Group without any question rows: nothing may be submitted.
            assert my_sub.shape[0] == 0
        else:
            assert np.all(my_sub['answered_correctly'] >= 0)
            assert np.all(my_sub['answered_correctly'] <= 1)
            assert np.all(my_sub.index == self.i)
            assert np.all(my_sub['row_id'] == self.sub.loc[[self.i]]['row_id'])
            self.sub.loc[[self.i]] = my_sub
        self.can_yield = True

@noglobal
def save_pickle(obj, path):
    """Serialise ``obj`` with pickle and write it to ``path`` (overwriting)."""
    with open(path, 'wb') as sink:
        pickle.dump(obj, file=sink)

@noglobal
def load_pickle(path):
    """Read the file at ``path`` and return the object unpickled from it."""
    # NOTE: unpickling can execute arbitrary code; only use on trusted files.
    with open(path, 'rb') as source:
        return pickle.load(source)

@noglobal
def encode(train, column_name):
    """Label-encode ``train[column_name]`` by order of first appearance.

    Returns a one-column DataFrame named ``column_name`` holding, for each row
    of ``train``, the position of its value in ``train[column_name].unique()``.
    Missing values are dropped from the lookup table before the left merge, so
    they come back as NaN.
    """
    uniques = pd.DataFrame(train[column_name].unique(), columns=[column_name])
    code_table = uniques.reset_index().dropna()
    merged = pd.merge(train[[column_name]], code_table, how='left', on=column_name)
    return merged[['index']].rename(columns={'index': column_name})

@noglobal
def update_user_map(tes, user_map_ref, n_users_ref):
    """Register users of ``tes`` that are not yet in ``user_map_ref``.

    Each unseen ``user_id`` is assigned the next free encoded id, starting at
    ``n_users_ref``, in order of first appearance in ``tes``.

    Returns ``(user_map_ref, new_user_count)``; ``user_map_ref`` is updated in
    place.
    """
    users = tes['user_id'].unique()
    keys = user_map_ref.keys()
    # dtype=bool matters for an empty batch: np.array([]) would be float64,
    # and indexing with a float array raises IndexError.
    is_new = np.array([user not in keys for user in users], dtype=bool)
    new_users = users[is_new]
    n_new_users = new_users.shape[0]
    if n_new_users > 0:
        user_map_ref[new_users] = np.arange(n_users_ref, n_users_ref + n_new_users)
    return user_map_ref, n_users_ref + n_new_users
    
@noglobal
def write_to_ref_map(path, ref_name, f_names):
    """Set ``ref_name -> f_names`` in the pickled mapping stored at ``path``.

    The mapping is loaded, updated and written back to the same file.
    """
    stored_map = load_pickle(path)
    stored_map[ref_name] = f_names
    save_pickle(stored_map, path)
    
class VectorizedDict():
    """Dictionary wrapper that sets and gets whole arrays of keys at once.

    ``d[keys] = values`` stores the pairs element-wise and ``d[keys]`` returns
    an array of the stored values, both via ``np.vectorize`` over the plain
    dict held in ``tr_dict``.
    """

    def __init__(self):
        self.tr_dict = dict()
        self.set_value = np.vectorize(self.tr_dict.__setitem__)
        self.get_value = np.vectorize(self.tr_dict.__getitem__)

    def keys(self):
        """Return the key view of the underlying dict."""
        return self.tr_dict.keys()

    def __setitem__(self, indices, values):
        # np.vectorize refuses size-0 inputs when otypes is not set; storing
        # nothing is a no-op, mirroring the empty-input guard in __getitem__.
        if np.size(indices) == 0:
            return
        self.set_value(indices, values)

    def __getitem__(self, indices):
        # Same size-0 limitation as above: answer with an empty int32 array.
        if indices.shape[0] == 0:
            return np.array([], dtype=np.int32)
        return self.get_value(indices)

In [4]:
%%cython 
import numpy as np
cimport numpy as np

cpdef np.ndarray[int] cget_memory_indices(np.ndarray task):
    # For a 2-D array `task`, build an int32 array `res` of the same shape.
    # Reading each row left to right as runs of equal consecutive values,
    # res[r, i] is the column index of the last element of the run that
    # precedes the run containing column i. Columns in the first run of a row
    # get the sentinel value task.shape[1] - 1.
    # NOTE(review): the declared return type is 1-D `np.ndarray[int]` while
    # `res` is 2-D — confirm this is intended.
    
    cdef Py_ssize_t n = task.shape[1]
    cdef np.ndarray[int, ndim = 2] res = np.zeros_like(task, dtype = np.int32)
    # tmp_counter holds the current column index for every row (after the += 1 below).
    cdef np.ndarray[int] tmp_counter = np.full(task.shape[0], -1, dtype = np.int32)
    # u_counter holds, per row, the value to emit for the current column.
    cdef np.ndarray[int] u_counter = np.full(task.shape[0], task.shape[1] - 1, dtype = np.int32)
    
    for i in range(n):
        res[:, i] = u_counter
        tmp_counter += 1
        if i != n - 1:
            # Rows whose value changes between column i and i + 1: column i
            # closes a run, so later columns of those rows point back to i.
            mask = (task[:, i] != task[:, i + 1])
            u_counter[mask] = tmp_counter[mask]
    return res

In [5]:
#func_gpu
@noglobal
def nn_online_get_content_id_history(r_ref, tes, user_map_ref):
    """Build the content-id sequence feature for the question rows of ``tes``.

    Each output row is the user's stored history ``r_ref[encoded_user]``
    followed by the ``content_id`` of the current row, as int16.
    """
    history_len = r_ref.shape[1]
    questions = tes.loc[tes['content_type_id'] == 0]
    encoded_users = user_map_ref[questions['user_id'].values]
    features = np.zeros((len(questions), history_len + 1), dtype = np.int16)
    features[:, :history_len] = r_ref[encoded_users]
    features[:, history_len] = questions['content_id'].values
    suffix = '_length_' + str(history_len)
    column_names = ['nn_content_id_history_' + str(i) + suffix for i in range(history_len + 1)]
    return pd.DataFrame(features, columns = column_names)

@noglobal
def nn_update_reference_get_content_id_history(r_ref, prev_tes, tes, user_map_ref):
    """Append the question ``content_id``s of ``prev_tes`` to each user's history.

    For every question row, in order, the user's row of ``r_ref`` is shifted
    left by one slot and the new value is written to the last slot. ``r_ref``
    is modified in place and returned. ``tes`` is not used.
    """
    questions = prev_tes.loc[prev_tes['content_type_id'] == 0]
    user_rows = user_map_ref[questions['user_id'].values]
    for row, new_content_id in zip(user_rows, questions['content_id'].values):
        r_ref[row, :-1] = r_ref[row, 1:]
        r_ref[row, -1] = new_content_id
    return r_ref

@noglobal
def nn_online_get_normed_log_timestamp_history(r_ref, tes, user_map_ref):
    """Build the standardised log-timestamp sequence feature for question rows.

    Each output row is the user's stored history followed by
    ``(log1p(timestamp) - mean) / std`` of the current row, as float32.
    """
    # Constants are local because @noglobal hides plain module-level values.
    std = 3.3530
    mean = 20.863
    history_len = r_ref.shape[1]
    questions = tes.loc[tes['content_type_id'] == 0]
    log_timestamp = np.log1p(questions['timestamp'].values.astype(np.float32))
    normed_now = (log_timestamp - mean)/std
    encoded_users = user_map_ref[questions['user_id'].values]
    features = np.zeros((len(questions), history_len + 1), dtype = np.float32)
    features[:, :history_len] = r_ref[encoded_users]
    features[:, history_len] = normed_now
    suffix = '_length_' + str(history_len)
    column_names = ['nn_normed_log_timestamp_history_' + str(i) + suffix
                    for i in range(history_len + 1)]
    return pd.DataFrame(features, columns = column_names)

@noglobal
def nn_update_reference_get_normed_log_timestamp_history(r_ref, prev_tes, tes, user_map_ref):
    """Append the standardised log-timestamps of ``prev_tes`` questions to ``r_ref``.

    Uses the same ``(log1p(timestamp) - mean) / std`` transform as the online
    feature. Each user's row is shifted left and the new value written last.
    ``r_ref`` is modified in place and returned. ``tes`` is not used.
    """
    std = 3.3530
    mean = 20.863
    questions = prev_tes.loc[prev_tes['content_type_id'] == 0]
    user_rows = user_map_ref[questions['user_id'].values]
    raw_timestamp = questions['timestamp'].values.astype(np.float32)
    normed_values = (np.log1p(raw_timestamp) - mean)/std
    for row, normed_value in zip(user_rows, normed_values):
        r_ref[row, :-1] = r_ref[row, 1:]
        r_ref[row, -1] = normed_value
    return r_ref

@noglobal
def nn_online_get_correctness_history(r_ref, tes, user_map_ref):
    """Build the correctness sequence feature for the question rows of ``tes``.

    Each output row is the user's stored history followed by the constant 2 in
    the slot of the current row, as int8.
    """
    history_len = r_ref.shape[1]
    questions = tes.loc[tes['content_type_id'] == 0]
    encoded_users = user_map_ref[questions['user_id'].values]
    features = np.zeros((len(questions), history_len + 1), dtype = np.int8)
    features[:, :history_len] = r_ref[encoded_users]
    features[:, history_len] = 2
    suffix = '_length_' + str(history_len)
    column_names = ['nn_correctness_history_' + str(i) + suffix for i in range(history_len + 1)]
    return pd.DataFrame(features, columns = column_names)

@noglobal
def nn_update_reference_get_correctness_history(r_ref, prev_tes, tes, user_map_ref):
    """Append the now-known answers of ``prev_tes`` questions to ``r_ref``.

    The answers are parsed from the string in the first row of
    ``tes['prior_group_answers_correct']`` and aligned row-for-row with
    ``prev_tes`` before non-question rows are dropped. ``r_ref`` is modified in
    place and returned.
    """
    prior_answers = ast.literal_eval(tes['prior_group_answers_correct'].iloc[0])
    labelled = prev_tes.assign(answered_correctly = prior_answers)
    questions = labelled.loc[labelled['content_type_id'] == 0]
    user_rows = user_map_ref[questions['user_id'].values]
    for row, answer in zip(user_rows, questions['answered_correctly'].values):
        r_ref[row, :-1] = r_ref[row, 1:]
        r_ref[row, -1] = answer
    return r_ref

@noglobal
def nn_online_get_question_had_explanation_history(r_ref, tes, user_map_ref):
    """Build the had-explanation sequence feature for the question rows of ``tes``.

    Each output row is the user's stored history followed by the constant 2 in
    the slot of the current row, as int8.
    """
    history_len = r_ref.shape[1]
    questions = tes.loc[tes['content_type_id'] == 0]
    encoded_users = user_map_ref[questions['user_id'].values]
    features = np.zeros((len(questions), history_len + 1), dtype = np.int8)
    features[:, :history_len] = r_ref[encoded_users]
    features[:, history_len] = 2
    suffix = '_length_' + str(history_len)
    column_names = ['nn_question_had_explanation_history_' + str(i) + suffix
                    for i in range(history_len + 1)]
    return pd.DataFrame(features, columns = column_names)

@noglobal
def nn_early_update_reference_get_question_had_explanation_history(r_ref, tes, user_map_ref):
    """Resolve pending had-explanation slots using the current batch.

    For every question row of ``tes`` with a non-null
    ``prior_question_had_explanation``, all slots of that user's history that
    still hold the placeholder 2 are overwritten with the value (as int).
    ``r_ref`` is modified in place and returned.
    """
    is_question = tes['content_type_id'] == 0
    has_value = tes['prior_question_had_explanation'].notnull()
    known = tes.loc[is_question & has_value]
    user_rows = user_map_ref[known['user_id'].values]
    flags = known['prior_question_had_explanation'].values.astype('int')
    for flag, row in zip(flags, user_rows):
        pending = r_ref[row] == 2
        r_ref[row, pending] = flag
    return r_ref

@noglobal
def nn_update_reference_get_question_had_explanation_history(r_ref, prev_tes, tes, user_map_ref):
    """Push a placeholder 2 onto the history of every ``prev_tes`` question row.

    Each user's row of ``r_ref`` is shifted left by one slot and 2 is written
    to the last slot. ``r_ref`` is modified in place and returned. ``tes`` is
    not used.
    """
    questions = prev_tes.loc[prev_tes['content_type_id'] == 0]
    for row in user_map_ref[questions['user_id'].values]:
        r_ref[row, :-1] = r_ref[row, 1:]
        r_ref[row, -1] = 2
    return r_ref

@noglobal
def nn_online_get_normed_elapsed_history(r_ref, tes, user_map_ref):
    """Build the normalised elapsed-time sequence feature for question rows.

    Each output row is the user's stored history followed by the constant -999
    in the slot of the current row, as float32.
    """
    history_len = r_ref.shape[1]
    questions = tes.loc[tes['content_type_id'] == 0]
    encoded_users = user_map_ref[questions['user_id'].values]
    features = np.zeros((len(questions), history_len + 1), dtype = np.float32)
    features[:, :history_len] = r_ref[encoded_users]
    features[:, history_len] = -999
    suffix = '_length_' + str(history_len)
    column_names = ['nn_normed_elapsed_history_' + str(i) + suffix for i in range(history_len + 1)]
    return pd.DataFrame(features, columns = column_names)

@noglobal
def nn_early_update_reference_get_normed_elapsed_history(r_ref, tes, user_map_ref):
    """Resolve pending elapsed-time slots using the current batch.

    For every question row of ``tes`` with a non-null
    ``prior_question_elapsed_time``, the value is clipped to [0, 300000],
    standardised, and written into all slots of that user's history that still
    hold the placeholder -999. ``r_ref`` is modified in place and returned.
    """
    mean = 25953
    std = 20418
    is_question = tes['content_type_id'] == 0
    has_value = tes['prior_question_elapsed_time'].notnull()
    known = tes.loc[is_question & has_value]
    user_rows = user_map_ref[known['user_id'].values]
    clipped = np.clip(known['prior_question_elapsed_time'].values, 0, 1000 * 300).astype(np.float32)
    normed_values = (clipped - mean)/std
    for normed_value, row in zip(normed_values, user_rows):
        pending = r_ref[row] == -999
        r_ref[row, pending] = normed_value
    return r_ref

@noglobal
def nn_update_reference_get_normed_elapsed_history(r_ref, prev_tes, tes, user_map_ref):
    """Push a placeholder -999 onto the history of every ``prev_tes`` question row.

    Each user's row of ``r_ref`` is shifted left by one slot and -999 is
    written to the last slot. ``r_ref`` is modified in place and returned.
    ``tes`` is not used.
    """
    questions = prev_tes.loc[prev_tes['content_type_id'] == 0]
    for row in user_map_ref[questions['user_id'].values]:
        r_ref[row, :-1] = r_ref[row, 1:]
        r_ref[row, -1] = -999
    return r_ref

@noglobal
def online_get_modified_timedelta(r_ref, tes, user_map_ref):
    """Per question row: time since the user's stored timestamp, split across the batch.

    Computes ``(timestamp - r_ref[encoded_user]) / k`` where ``k`` is the
    number of question rows the same user has in ``tes``. Returns a float32
    DataFrame with the single column ``modified_timedelta``.
    """
    questions = tes.loc[tes['content_type_id'] == 0]
    user_ids = questions['user_id'].values
    _, inverse = np.unique(user_ids, return_inverse = True)
    rows_per_user = np.bincount(inverse)[inverse]
    elapsed = questions['timestamp'].values - r_ref[user_map_ref[user_ids]]
    return pd.DataFrame(elapsed/rows_per_user, columns = ['modified_timedelta']).astype(np.float32)

@noglobal
def update_reference_get_timestamp_diff(r_ref, prev_tes, tes, user_map_ref):
    """Store the ``timestamp`` of every ``prev_tes`` row as its user's latest value.

    All rows of ``prev_tes`` are used (no content-type filter). ``r_ref`` is
    modified in place and returned. ``tes`` is not used.
    """
    user_rows = user_map_ref[prev_tes['user_id'].values]
    r_ref[user_rows] = prev_tes['timestamp'].values
    return r_ref

@noglobal
def nn_online_get_normed_modified_timedelta_history(r_ref, tes, f_tes_delta, user_map_ref):
    """Build the standardised timedelta sequence feature for question rows.

    ``f_tes_delta['modified_timedelta']`` is clipped to [0, 800000],
    standardised, and NaNs are replaced by 0. Each output row is the user's
    stored history followed by that value, as float32. ``f_tes_delta`` must
    have one row per question row of ``tes``.
    """
    mean = 126568
    std = 218000
    history_len = r_ref.shape[1]
    questions = tes.loc[tes['content_type_id'] == 0]

    normed_now = (np.clip(f_tes_delta['modified_timedelta'].values, 0, 1000 * 800) - mean)/std
    normed_now[np.isnan(normed_now)] = 0
    normed_now = normed_now.astype(np.float32)

    encoded_users = user_map_ref[questions['user_id'].values]
    features = np.zeros((len(questions), history_len + 1), dtype = np.float32)
    features[:, :history_len] = r_ref[encoded_users]
    features[:, history_len] = normed_now
    suffix = '_length_' + str(history_len)
    column_names = ['nn_normed_modified_timedelta_history_' + str(i) + suffix
                    for i in range(history_len + 1)]
    return pd.DataFrame(features, columns = column_names)

@noglobal
def nn_update_reference_get_normed_modified_timedelta_history(r_ref, prev_tes, tes, f_tes_delta, user_map_ref):
    """Append the standardised timedeltas of ``prev_tes`` questions to ``r_ref``.

    ``f_tes_delta['modified_timedelta']`` (one row per question row of
    ``prev_tes``) is clipped to [0, 800000], standardised, and NaNs are
    replaced by 0. Each user's row is shifted left and the value written last.
    ``r_ref`` is modified in place and returned. ``tes`` is not used.
    """
    mean = 126568
    std = 218000
    questions = prev_tes.loc[prev_tes['content_type_id'] == 0]

    normed_values = (np.clip(f_tes_delta['modified_timedelta'].values, 0, 1000 * 800) - mean)/std
    normed_values[np.isnan(normed_values)] = 0
    # Materialised as a column first so a row-count mismatch fails loudly.
    normed_values = questions.assign(
        normed_modified_timedelta = normed_values.astype(np.float32)
    )['normed_modified_timedelta'].values

    user_rows = user_map_ref[questions['user_id'].values]
    for row, normed_value in zip(user_rows, normed_values):
        r_ref[row, :-1] = r_ref[row, 1:]
        r_ref[row, -1] = normed_value
    return r_ref

@noglobal
def nn_online_get_user_answer_history(r_ref, tes, user_map_ref):
    """Build the user-answer sequence feature for the question rows of ``tes``.

    Each output row is the user's stored history followed by the constant 4 in
    the slot of the current row, as int8.
    """
    history_len = r_ref.shape[1]
    questions = tes.loc[tes['content_type_id'] == 0]
    encoded_users = user_map_ref[questions['user_id'].values]
    features = np.zeros((len(questions), history_len + 1), dtype = np.int8)
    features[:, :history_len] = r_ref[encoded_users]
    features[:, history_len] = 4
    suffix = '_length_' + str(history_len)
    column_names = ['nn_user_answer_history_' + str(i) + suffix for i in range(history_len + 1)]
    return pd.DataFrame(features, columns = column_names)

@noglobal
def nn_update_reference_get_user_answer_history(r_ref, prev_tes, tes, user_map_ref):
    """Append the now-known user answers of ``prev_tes`` questions to ``r_ref``.

    The answers are parsed from the string in the first row of
    ``tes['prior_group_responses']`` and aligned row-for-row with ``prev_tes``
    before non-question rows are dropped. ``r_ref`` is modified in place and
    returned.
    """
    prior_responses = ast.literal_eval(tes['prior_group_responses'].iloc[0])
    labelled = prev_tes.assign(user_answer = prior_responses)
    questions = labelled.loc[labelled['content_type_id'] == 0]
    user_rows = user_map_ref[questions['user_id'].values]
    for row, response in zip(user_rows, questions['user_answer'].values):
        r_ref[row, :-1] = r_ref[row, 1:]
        r_ref[row, -1] = response
    return r_ref

@noglobal
def online_get_task_container_id_diff(r_ref, tes, user_map_ref):
    """Per question row: ``task_container_id`` minus the user's stored value.

    Returns a float32 DataFrame with the single column
    ``task_container_id_diff``.
    """
    questions = tes.loc[tes['content_type_id'] == 0]
    previous = r_ref[user_map_ref[questions['user_id'].values]]
    diff = questions['task_container_id'].values - previous
    return pd.DataFrame(diff, columns = ['task_container_id_diff']).astype(np.float32)

@noglobal
def update_reference_get_task_container_id_diff(r_ref, prev_tes, tes, user_map_ref):
    """Store the ``task_container_id`` of every ``prev_tes`` row as its user's latest value.

    All rows of ``prev_tes`` are used (no content-type filter). ``r_ref`` is
    modified in place and returned. ``tes`` is not used.
    """
    user_rows = user_map_ref[prev_tes['user_id'].values]
    r_ref[user_rows] = prev_tes['task_container_id'].values
    return r_ref

@noglobal
def nn_online_get_task_container_id_diff_history(r_ref, tes, f_tes_delta, user_map_ref):
    """Build the task-container-id-diff sequence feature for question rows.

    Each output row is the user's stored history followed by the current
    ``f_tes_delta['task_container_id_diff']`` value (NaN replaced by 0), as
    float32. ``f_tes_delta`` must have one row per question row of ``tes`` and
    is left unmodified.
    """
    n_sample = r_ref.shape[1]
    spc_tes = tes[tes['content_type_id'] == 0]
    # Copy before filling NaNs: `.values` can alias the frame's storage, so an
    # in-place write would silently change the caller's f_tes_delta.
    value = f_tes_delta['task_container_id_diff'].values.copy()
    value[np.isnan(value)] = 0

    users = spc_tes['user_id'].values
    res = np.zeros((spc_tes.shape[0], n_sample + 1), dtype = np.float32)
    res[:, :n_sample] = r_ref[user_map_ref[users]]
    res[:, n_sample] = value
    return pd.DataFrame(res, columns = ['nn_task_container_id_diff_history_' + str(i) + '_length_' + str(n_sample) \
                                        for i in range(n_sample + 1)])

@noglobal
def nn_update_reference_get_task_container_id_diff_history(r_ref, prev_tes, tes, f_tes_delta, user_map_ref):
    """Append the task-container-id diffs of ``prev_tes`` questions to ``r_ref``.

    ``f_tes_delta['task_container_id_diff']`` must have one row per question
    row of ``prev_tes``; NaNs are replaced by 0 on a copy, so ``f_tes_delta``
    is left unmodified. Each user's row is shifted left and the value written
    last. ``r_ref`` is modified in place and returned. ``tes`` is not used.
    """
    spc_prev_tes = prev_tes[prev_tes['content_type_id'] == 0].copy().reset_index(drop = True)

    # Copy before filling NaNs: `.values` can alias the frame's storage, so an
    # in-place write would silently change the caller's f_tes_delta.
    value = f_tes_delta['task_container_id_diff'].values.copy()
    value[np.isnan(value)] = 0
    # Column assignment keeps the length check against the question rows.
    spc_prev_tes['task_container_id_diff'] = value

    enc_users = user_map_ref[spc_prev_tes['user_id'].values]
    task_container_id_diffs = spc_prev_tes['task_container_id_diff'].values
    for user, task_container_id_diff in zip(enc_users, task_container_id_diffs):
        r_ref[user, :-1] = r_ref[user, 1:]
        r_ref[user, -1] = task_container_id_diff
    return r_ref

@noglobal
def online_get_content_type_id_diff(r_ref, tes, user_map_ref):
    """Per question row: ``content_type_id`` minus the user's stored value.

    Returns a float32 DataFrame with the single column
    ``content_type_id_diff``.
    """
    questions = tes.loc[tes['content_type_id'] == 0]
    previous = r_ref[user_map_ref[questions['user_id'].values]]
    diff = questions['content_type_id'].values - previous
    return pd.DataFrame(diff, columns = ['content_type_id_diff']).astype(np.float32)

@noglobal
def update_reference_get_content_type_id_diff(r_ref, prev_tes, tes, user_map_ref):
    """Store the ``content_type_id`` of every ``prev_tes`` row as its user's latest value.

    All rows of ``prev_tes`` are used (no content-type filter). ``r_ref`` is
    modified in place and returned. ``tes`` is not used.
    """
    user_rows = user_map_ref[prev_tes['user_id'].values]
    r_ref[user_rows] = prev_tes['content_type_id'].values
    return r_ref

@noglobal
def nn_online_get_content_type_id_diff_history(r_ref, tes, f_tes_delta, user_map_ref):
    """Build the content-type-id-diff sequence feature for question rows.

    Each output row is the user's stored history followed by the current
    ``f_tes_delta['content_type_id_diff']`` value (NaN replaced by 0), as
    float32. ``f_tes_delta`` must have one row per question row of ``tes`` and
    is left unmodified.
    """
    n_sample = r_ref.shape[1]
    spc_tes = tes[tes['content_type_id'] == 0]
    # Copy before filling NaNs: `.values` can alias the frame's storage, so an
    # in-place write would silently change the caller's f_tes_delta.
    value = f_tes_delta['content_type_id_diff'].values.copy()
    value[np.isnan(value)] = 0

    users = spc_tes['user_id'].values
    res = np.zeros((spc_tes.shape[0], n_sample + 1), dtype = np.float32)
    res[:, :n_sample] = r_ref[user_map_ref[users]]
    res[:, n_sample] = value
    return pd.DataFrame(res, columns = ['nn_content_type_id_diff_history_' + str(i) + '_length_' + str(n_sample) \
                                        for i in range(n_sample + 1)])

@noglobal
def nn_update_reference_get_content_type_id_diff_history(r_ref, prev_tes, tes, f_tes_delta, user_map_ref):
    """Append the content-type-id diffs of ``prev_tes`` questions to ``r_ref``.

    ``f_tes_delta['content_type_id_diff']`` must have one row per question row
    of ``prev_tes``; NaNs are replaced by 0 on a copy, so ``f_tes_delta`` is
    left unmodified. Each user's row is shifted left and the value written
    last. ``r_ref`` is modified in place and returned. ``tes`` is not used.
    """
    spc_prev_tes = prev_tes[prev_tes['content_type_id'] == 0].copy().reset_index(drop = True)

    # Copy before filling NaNs: `.values` can alias the frame's storage, so an
    # in-place write would silently change the caller's f_tes_delta.
    value = f_tes_delta['content_type_id_diff'].values.copy()
    value[np.isnan(value)] = 0
    # Column assignment keeps the length check against the question rows.
    spc_prev_tes['content_type_id_diff'] = value

    enc_users = user_map_ref[spc_prev_tes['user_id'].values]
    content_type_id_diffs = spc_prev_tes['content_type_id_diff'].values
    for user, content_type_id_diff in zip(enc_users, content_type_id_diffs):
        r_ref[user, :-1] = r_ref[user, 1:]
        r_ref[user, -1] = content_type_id_diff
    return r_ref

@noglobal
def nn_online_get_task_container_id_history(r_ref, tes, user_map_ref):
    """Build the task-container-id sequence feature for the question rows of ``tes``.

    Each output row is the user's stored history followed by the
    ``task_container_id`` of the current row, as int16.
    """
    history_len = r_ref.shape[1]
    questions = tes.loc[tes['content_type_id'] == 0]
    encoded_users = user_map_ref[questions['user_id'].values]
    features = np.zeros((len(questions), history_len + 1), dtype = np.int16)
    features[:, :history_len] = r_ref[encoded_users]
    features[:, history_len] = questions['task_container_id'].values
    suffix = '_length_' + str(history_len)
    column_names = ['nn_task_container_id_history_' + str(i) + suffix for i in range(history_len + 1)]
    return pd.DataFrame(features, columns = column_names)

@noglobal
def nn_update_reference_get_task_container_id_history(r_ref, prev_tes, tes, user_map_ref):
    """Append the question ``task_container_id``s of ``prev_tes`` to each user's history.

    For every question row, in order, the user's row of ``r_ref`` is shifted
    left by one slot and the new value is written to the last slot. ``r_ref``
    is modified in place and returned. ``tes`` is not used.
    """
    questions = prev_tes.loc[prev_tes['content_type_id'] == 0]
    user_rows = user_map_ref[questions['user_id'].values]
    for row, new_container_id in zip(user_rows, questions['task_container_id'].values):
        r_ref[row, :-1] = r_ref[row, 1:]
        r_ref[row, -1] = new_container_id
    return r_ref

class Embedder(nn.Module):
    """Thin wrapper around ``nn.Embedding(n_proj, n_dims)``.

    The table is registered as ``embed`` (state-dict key ``embed.weight``).
    """

    def __init__(self, n_proj, n_dims):
        super().__init__()
        self.n_proj, self.n_dims = n_proj, n_dims
        self.embed = nn.Embedding(num_embeddings=n_proj, embedding_dim=n_dims)

    def forward(self, indices):
        """Look up the embedding vector of every entry of ``indices``."""
        return self.embed(indices)

class PositionalEncoding(nn.Module):
    """Additive sinusoidal positional encoding followed by dropout.

    A table ``pe`` of shape ``(max_len, 1, d_model)`` is precomputed and
    registered as a buffer: even feature columns hold sines, odd columns hold
    cosines. ``forward`` adds the first ``x.size(0)`` rows to ``x``, so ``x``
    must have the sequence position on dim 0 and at most ``max_len`` positions.
    """

    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)

        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        # With an odd d_model there is one fewer cosine column than sine
        # columns; trimming div_term avoids a shape mismatch and is a no-op
        # for even d_model.
        pe[:, 1::2] = torch.cos(position * div_term[:d_model // 2])
        pe = pe.unsqueeze(0).transpose(0, 1)
        self.register_buffer('pe', pe)

    def forward(self, x):
        """Return ``dropout(x + pe[:x.size(0)])``."""
        x = x + self.pe[:x.size(0), :]
        return self.dropout(x)

class MyEncoderLayer(nn.Module):
    """Post-norm attention block with separate query, key and value inputs.

    ``forward`` runs multi-head attention of ``q`` over ``k``/``v``, adds the
    result to ``q`` (residual), applies LayerNorm, then a two-layer
    feed-forward network with its own residual and LayerNorm.

    Args:
        d_model: feature size of the inputs and outputs.
        nhead: number of attention heads.
        dim_feedforward: hidden size of the feed-forward network.
        dropout: dropout probability used in attention and feed-forward paths.
        activation: ``"relu"`` or ``"gelu"``.

    Raises:
        ValueError: if ``activation`` is not one of the supported names.
    """

    def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, activation="relu"):
        super(MyEncoderLayer, self).__init__()
        self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
        # Implementation of Feedforward model
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)

        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        if activation == 'relu':
            self.activation = F.relu
        elif activation == 'gelu':
            self.activation = F.gelu
        else:
            # Fail at construction time instead of leaving self.activation
            # undefined until the first forward pass.
            raise ValueError("activation should be 'relu' or 'gelu', not {!r}".format(activation))

    def forward(self, q, k, v, src_mask, src_key_padding_mask):
        """Attend ``q`` over ``k``/``v`` and return a tensor shaped like ``q``.

        ``src_mask`` is passed as ``attn_mask`` and ``src_key_padding_mask`` as
        ``key_padding_mask`` to ``nn.MultiheadAttention``.
        """
        src2 = self.self_attn(q, k, v, attn_mask=src_mask,
                              key_padding_mask=src_key_padding_mask)[0]
        # Residual connection is taken from the query, not from key/value.
        src = q + self.dropout1(src2)
        src = self.norm1(src)
        src2 = self.linear2(self.dropout(self.activation(self.linear1(src))))
        src = src + self.dropout2(src2)
        src = self.norm2(src)
        return src

class MyEncoderExp162(nn.Module):
    def __init__(self, args):
        """Create the layers of the model.

        Reads ``emb_dim``, ``dropout``, ``nhead``, ``dim_feedforward``,
        ``num_features`` and ``n_conv`` from ``args``. Attribute names and
        creation order define the state-dict layout, so they must stay as they
        are for saved checkpoints to load.
        """
        super(MyEncoderExp162, self).__init__()
        #query, key and value
        self.embedder_content_id = Embedder(13523, 512)
        # 717 inputs = 512 (content embedding) + 7 (part) + 189 (tags) + 2 (time
        # features) + 4 (correct answer) + 2 (diff features) + 1 (position), as
        # concatenated into query_cat in forward().
        self.linear1 = nn.Linear(714 + 1 + 2, args.emb_dim * 2)
        self.linear2 = nn.Linear(args.emb_dim * 2, args.emb_dim)
        self.dropout = nn.Dropout(args.dropout)
        self.norm1 = nn.LayerNorm(args.emb_dim)

        #key and value
        # +12 = 3 (explanation) + 3 (correctness) + 1 (elapsed) + 5 (user answer),
        # as concatenated into memory_cat in forward().
        self.linear3 = nn.Linear(args.emb_dim + 12, args.emb_dim)
        self.linear4 = nn.Linear(args.emb_dim, args.emb_dim)
        self.norm2 = nn.LayerNorm(args.emb_dim)

        #MyEncoder
        self.MyEncoderLayer1 = MyEncoderLayer(args.emb_dim, args.nhead, dim_feedforward = args.dim_feedforward, 
                                                dropout = args.dropout)
        self.MyEncoderLayer2 = MyEncoderLayer(args.emb_dim, args.nhead, dim_feedforward = args.dim_feedforward, 
                                                dropout = args.dropout)
        self.MyEncoderLayer3 = MyEncoderLayer(args.emb_dim, args.nhead, dim_feedforward = args.dim_feedforward, 
                                                dropout = args.dropout)
        # NOTE(review): forward() only calls layers 1-3; layer 4 is created but
        # not used there. Removing it would change the state-dict keys.
        self.MyEncoderLayer4 = MyEncoderLayer(args.emb_dim, args.nhead, dim_feedforward = args.dim_feedforward, 
                                                dropout = args.dropout)
        # Head input = encoder output + extra features + flattened conv output
        # (conv3 has 4 output channels over a window of args.n_conv).
        self.last_fc1 = nn.Linear(args.emb_dim + args.num_features + 4 * args.n_conv, args.emb_dim * 4)
        self.last_fc2 = nn.Linear(args.emb_dim * 4, 1)
        
        #conv
        # 13 input channels = 3 + 3 one-hot channels plus 7 scalar channels, as
        # concatenated into for_conv in forward(). Paddings keep the length.
        self.conv1 = nn.Conv1d(11 + 2, 8, 7, padding = 3)
        self.conv2 = nn.Conv1d(8, 8, 5, padding = 2)
        self.conv3 = nn.Conv1d(8, 4, 3, padding = 1)

    def forward(self, batch, args):
        """Run the model on a training-style batch and return ``(err, score)``.

        ``batch`` is a dict of tensors; the keys read here are ``inputs_int``,
        ``inputs_float``, ``inputs_add``, ``inputs_add_2``, ``tag_mask``,
        ``end_pos_idx``, ``memory_mask``, ``padding_mask``, ``target``,
        ``loss_mask`` (or ``cut_mask`` when absent), ``position``,
        ``conv_inputs_int``, ``conv_inputs_float`` and ``conv_inputs_const``.

        Returns:
            err: binary cross-entropy (with logits) over positions selected by
                the loss mask.
            score: sigmoid of the logits, one value per (sample, position).
        """
        #input
        inputs_int = batch['inputs_int']
        inputs_float = batch['inputs_float']
        # Transposed so dim 0 is the sequence position, matching `out` below.
        inputs_add = batch['inputs_add'].transpose(0, 1)
        inputs_add_2 = batch['inputs_add_2']
        N = inputs_int.shape[0]
        # assumes the sequence length (dim 1) is even — TODO confirm
        n_length = int(inputs_int.shape[1]/2)

        #for query, key and value
        content_id = inputs_int[:, :, 0]
        part = (inputs_int[:, :, 1] - 1)
        tags = inputs_int[:, :, 2:8]
        tag_mask = batch['tag_mask']
        #digit_timedelta = inputs_int[:, :, 8]
        normed_timedelta = inputs_float[:, :, 2]
        normed_log_timestamp = inputs_float[:, :, 1]
        correct_answer = inputs_int[:, :, 13]
        task_container_id_diff = inputs_add_2[:, :, 0]
        content_type_id_diff = inputs_add_2[:, :, 1]

        #for key and value
        explanation = inputs_int[:, :, 9]
        correctness = inputs_int[:, :, 10]
        normed_elapsed = inputs_float[:, :, 0]
        user_answer = inputs_int[:, :, 11]
        end_pos_idx = batch['end_pos_idx']

        #mask
        # The attention mask is repeated once per head along dim 0.
        memory_mask = torch.repeat_interleave(batch['memory_mask'], args.nhead, dim = 0)
        padding_mask = batch['padding_mask']

        #target, loss_mask
        target = batch['target']
        if 'loss_mask' in batch.keys():
            loss_mask = batch['loss_mask']
        else:
            loss_mask = batch['cut_mask']
            
        # relative positional encoding
        position = batch['position']
            
        #query, key and value
        emb_content_id = self.embedder_content_id(content_id)
        ohe_part = F.one_hot(part, num_classes = 7)
        ohe_tags = F.one_hot(tags, num_classes = 189)
        # Zero out masked tag slots, then sum over the tag axis (multi-hot).
        ohe_tags = ohe_tags * tag_mask.unsqueeze(3)
        ohe_tags_sum = ohe_tags.sum(dim = 2)
        ohe_correct_answer = F.one_hot(correct_answer, num_classes = 4)

        #key and value
        ohe_explanation = F.one_hot(explanation, num_classes = 3)
        ohe_correctness = F.one_hot(correctness, num_classes = 3)
        ohe_user_answer = F.one_hot(user_answer, num_classes = 5)
        
        #query process
        query_cat = torch.cat([emb_content_id, ohe_part, ohe_tags_sum, 
                               normed_timedelta.unsqueeze(2), normed_log_timestamp.unsqueeze(2), 
                               ohe_correct_answer, task_container_id_diff.unsqueeze(2), 
                               content_type_id_diff.unsqueeze(2), position.unsqueeze(2)], dim = 2)
        query = self.norm1(self.linear2(self.dropout(F.relu(self.linear1(query_cat)))))
        
        #key and value process
        # The memory reuses the query features, except that each sample's
        # vector at position end_pos_idx - 1 is zeroed before response
        # features are appended.
        query_clone = query.clone()
        query_clone[torch.arange(N), end_pos_idx - 1] = query_clone[torch.arange(N), end_pos_idx - 1] * 0
        memory_cat = torch.cat([query_clone, ohe_explanation, ohe_correctness, normed_elapsed.unsqueeze(2), ohe_user_answer], dim = 2)
        memory = self.norm2(self.linear4(self.dropout(F.relu(self.linear3(memory_cat)))))

        #transpose
        # nn.MultiheadAttention (as built in MyEncoderLayer) takes the sequence
        # position on dim 0.
        query = query.transpose(0, 1)
        memory = memory.transpose(0, 1)
        
        #MyEncoder
        # Every layer attends over the same `memory`; only the query evolves.
        out = self.MyEncoderLayer1(query, memory, memory, memory_mask, padding_mask)
        out = self.MyEncoderLayer2(out, memory, memory, memory_mask, padding_mask)
        out = self.MyEncoderLayer3(out, memory, memory, memory_mask, padding_mask)
        
        #conv process
        conv_inputs_int = batch['conv_inputs_int']
        conv_inputs_float = batch['conv_inputs_float']
        conv_inputs_const = batch['conv_inputs_const']

        conv_explanation = conv_inputs_int[:, :, :, 0]
        conv_correctness = conv_inputs_int[:, :, :, 1]
        conv_normed_elapsed = conv_inputs_float[:, :, :, 0]
        conv_normed_log_timestamp = conv_inputs_float[:, :, :, 1]
        conv_normed_timedelta = conv_inputs_float[:, :, :, 2]
        conv_task_container_id_diff = conv_inputs_float[:, :, :, 3]
        conv_content_type_id_diff = conv_inputs_float[:, :, :, 4]
    
        const_normed_log_timestamp = conv_inputs_const[:, :, :, 0]
        const_normed_timedelta = conv_inputs_const[:, :, :, 1]
        ohe_conv_explanation = F.one_hot(conv_explanation, num_classes = 3)
        ohe_conv_correctness = F.one_hot(conv_correctness, num_classes = 3)

        # 13 channels in total, matching conv1's input channels.
        for_conv = torch.cat([ohe_conv_explanation, ohe_conv_correctness,
        conv_normed_elapsed.unsqueeze(3), conv_normed_log_timestamp.unsqueeze(3),
        conv_normed_timedelta.unsqueeze(3), conv_task_container_id_diff.unsqueeze(3),
        conv_content_type_id_diff.unsqueeze(3), const_normed_log_timestamp.unsqueeze(3), 
        const_normed_timedelta.unsqueeze(3)], dim = 3).transpose(0, 1)
        # Fold (position, sample) into one batch axis and put channels before
        # the window axis, as nn.Conv1d expects (batch, channels, length).
        for_conv = for_conv.contiguous().transpose(2, 3).view(N * n_length * 2, -1, args.n_conv)
        out_conv = self.conv3(F.relu(self.conv2(F.relu(self.conv1(for_conv))))).view(n_length * 2, N, -1)
        
        #concat 
        out = torch.cat([out, out_conv, inputs_add], dim = 2)
        # The head runs on flattened (position, sample) rows; the result is
        # reshaped back and transposed so dim 0 is the sample.
        out = self.last_fc2(self.dropout(F.relu(self.last_fc1(out.reshape(-1, args.emb_dim + args.num_features + 4 * args.n_conv)))))\
        .reshape(-1, N).transpose(0, 1)
        score = torch.sigmoid(out)
        # `out` holds logits, so the loss applies the sigmoid internally.
        err = F.binary_cross_entropy_with_logits(out[loss_mask], target[loss_mask])
        return err, score
    
    def predict_on_batch(self, inputs, args, pred_bs, que, que_proc_int):
        """Return one predicted score per row of ``inputs``, computed ``pred_bs`` rows at a time.

        For every row the last column is the position being scored: it is the
        only query sent through the encoder layers.  All earlier columns form
        the memory (key/value) sequence.

        Parameters
        ----------
        inputs : dict of numpy arrays
            Keys read: ``content_id``, ``normed_timedelta``,
            ``normed_log_timestamp``, ``explanation``, ``correctness``,
            ``normed_elapsed``, ``user_answer``, ``task_container_id_diff``,
            ``content_type_id_diff`` (indexed as ``[row, position]``) and
            ``features`` (indexed as ``[row, feature]``).  Positions where
            ``content_id == -1`` are treated as padding.  The per-position
            arrays need ``args.n_length + 1`` columns, because the positional
            encoding built from ``args.n_length`` is concatenated with them.
            None of the arrays is modified by this method.
        args : namespace
            Uses ``device``, ``n_length``, ``n_conv``, ``log_indices``,
            ``stdize_indices`` and ``stdize_mu_std`` (``[0]`` is subtracted from
            and ``[1]`` divides the ``stdize_indices`` feature columns).
        pred_bs : int
            Number of rows processed per chunk.
        que : mapping
            ``que['part'].values`` and ``que['correct_answer'].values`` are
            indexed by content id (presumably the questions table -- confirm
            against the caller).
        que_proc_int : numpy array
            Indexed by content id to get tag ids; entries equal to 188 are
            masked out before the tags are summed.

        Returns
        -------
        numpy.ndarray
            Sigmoid scores, shape ``(N,)`` with ``N = inputs['content_id'].shape[0]``.
        """
        # NOTE: this turns autograd off for everything that runs afterwards, not
        # only for this call. Left unchanged on purpose: code after this call may
        # depend on gradients staying disabled (e.g. calling .numpy() on outputs).
        torch.set_grad_enabled(False)
        N = inputs['content_id'].shape[0]
        scores = torch.zeros(N).to(args.device)

        # relative positional encoding: identical for every chunk, so it is built once
        base_position = torch.from_numpy(((np.arange(args.n_length + 1) - args.n_length)/(args.n_length)).astype(np.float32))
        base_position = base_position.to(args.device)

        for i in range(0, N, pred_bs):
            content_id = inputs['content_id'][i: i + pred_bs]
            n_batch = content_id.shape[0]
            n_sample = content_id.shape[1] - 1

            # padding_mask covers the memory positions only (all columns but the last).
            # token_idx marks rows whose newest memory slot is padding; that slot is
            # un-masked here and filled with the "token" values further down.
            padding_mask = (content_id == -1)[:, :n_sample]
            token_idx = padding_mask[:, -1].copy()
            padding_mask[:, -1] = False

            part = que['part'].values[content_id] - 1
            tags = que_proc_int[content_id]
            tag_mask = (tags != 188)
            correct_answer = que['correct_answer'].values[content_id]

            normed_timedelta = inputs['normed_timedelta'][i: i + pred_bs]
            normed_log_timestamp = inputs['normed_log_timestamp'][i: i + pred_bs]
            explanation = inputs['explanation'][i: i + pred_bs]
            correctness = inputs['correctness'][i: i + pred_bs]
            normed_elapsed = inputs['normed_elapsed'][i: i + pred_bs]
            user_answer = inputs['user_answer'][i: i + pred_bs]

            #diff features
            task_container_id_diff = inputs['task_container_id_diff'][i: i + pred_bs]/10000
            content_type_id_diff = inputs['content_type_id_diff'][i: i + pred_bs]

            #preprop
            content_id = torch.from_numpy(content_id.astype(np.int64)).to(args.device)
            content_id[content_id == -1] = 1
            padding_mask = torch.from_numpy(padding_mask).to(args.device)
            token_idx = torch.from_numpy(token_idx).to(args.device)
            part = torch.from_numpy(part.astype(np.int64)).to(args.device)
            tags = torch.from_numpy(tags.astype(np.int64)).to(args.device)
            tag_mask = torch.from_numpy(tag_mask).to(args.device)
            correct_answer = torch.from_numpy(correct_answer.astype(np.int64)).to(args.device)

            # copy = True: torch.from_numpy shares memory with the numpy slice and
            # .to() is a no-op on a CPU device, so without the copy the in-place
            # edits below would overwrite the caller's inputs[...] arrays.
            normed_timedelta = torch.from_numpy(normed_timedelta).to(args.device, copy = True)
            normed_log_timestamp = torch.from_numpy(normed_log_timestamp).to(args.device, copy = True)
            explanation = torch.from_numpy(explanation.astype(np.int64)).to(args.device)
            explanation[explanation == -2] = 1
            explanation[explanation == 2] = 1
            correctness = torch.from_numpy(correctness.astype(np.int64)).to(args.device)
            correctness[correctness == -1] = 1
            correctness[correctness == 2] = 1
            normed_elapsed = torch.from_numpy(normed_elapsed).to(args.device, copy = True)
            normed_elapsed[normed_elapsed == -999] = 0
            user_answer = torch.from_numpy(user_answer.astype(np.int64)).to(args.device)
            user_answer[user_answer == -1] = 1
            user_answer[user_answer == 4] = 1

            #diff features
            task_container_id_diff = torch.from_numpy(task_container_id_diff.astype(np.float32)).to(args.device)
            content_type_id_diff = torch.from_numpy(content_type_id_diff.astype(np.float32)).to(args.device)

            #generate token
            explanation[token_idx, n_sample - 1] = 2
            correctness[token_idx, n_sample - 1] = 2
            normed_elapsed[token_idx, n_sample - 1] = 0
            user_answer[token_idx, n_sample - 1] = 4

            #for conv process
            explanation[:, :-1][padding_mask] = 2
            correctness[:, :-1][padding_mask] = 2
            normed_elapsed[:, :-1][padding_mask] = 0
            normed_timedelta[:, :-1][padding_mask] = 0
            normed_log_timestamp[:, :-1][padding_mask] = 0

            #query, key and value
            emb_content_id = self.embedder_content_id(content_id)
            ohe_part = F.one_hot(part, num_classes = 7)
            ohe_tags = F.one_hot(tags, num_classes = 189)
            ohe_tags = ohe_tags * tag_mask.unsqueeze(3)
            ohe_tags_sum = ohe_tags.sum(dim = 2)
            ohe_correct_answer = F.one_hot(correct_answer, num_classes = 4)

            #key and value
            ohe_explanation = F.one_hot(explanation, num_classes = 3)
            ohe_correctness = F.one_hot(correctness, num_classes = 3)
            ohe_user_answer = F.one_hot(user_answer, num_classes = 5)

            # relative positional encoding (one copy of the shared row per sample)
            position = base_position.unsqueeze(0).repeat(n_batch, 1)

            #query process
            query_cat = torch.cat([emb_content_id, ohe_part, ohe_tags_sum,
                                   normed_timedelta.unsqueeze(2), normed_log_timestamp.unsqueeze(2),
                                   ohe_correct_answer, task_container_id_diff.unsqueeze(2),
                                   content_type_id_diff.unsqueeze(2), position.unsqueeze(2)], dim = 2)
            query = self.norm1(self.linear2(self.dropout(F.relu(self.linear1(query_cat)))))

            #key and value process
            # the memory drops the last (scored) position; the token slot gets a zero query vector
            query_clone = query[:, :n_sample].clone()
            query_clone[token_idx, -1] = query_clone[token_idx, -1] * 0
            memory_cat = torch.cat([query_clone, ohe_explanation[:, :-1], ohe_correctness[:, :-1],
                                    normed_elapsed.unsqueeze(2)[:, :-1], ohe_user_answer[:, :-1]], dim = 2)
            memory = self.norm2(self.linear4(self.dropout(F.relu(self.linear3(memory_cat)))))

            #transpose
            query = query.transpose(0, 1)
            memory = memory.transpose(0, 1)
            # only the last position is scored, so it is the only query
            spc_query = query[-1:]

            #MyEncoder
            out = self.MyEncoderLayer1(spc_query, memory, memory, None, padding_mask)
            out = self.MyEncoderLayer2(out, memory, memory, None, padding_mask)
            out = self.MyEncoderLayer3(out, memory, memory, None, padding_mask)
            out = self.MyEncoderLayer4(out, memory, memory, None, padding_mask).squeeze(0)

            #conv process
            # the args.n_conv positions immediately before the scored one
            conv_explanation = explanation[:, -args.n_conv - 1 : -1]
            conv_correctness = correctness[:, -args.n_conv - 1 : -1]
            conv_normed_elapsed = normed_elapsed[:, -args.n_conv - 1 : -1]
            conv_normed_log_timestamp = normed_log_timestamp[:, -args.n_conv - 1 : -1]
            conv_normed_timedelta = normed_timedelta[:, -args.n_conv - 1 : -1]
            conv_task_container_id_diff = task_container_id_diff[:, -args.n_conv - 1 : -1]
            conv_content_type_id_diff = content_type_id_diff[:, -args.n_conv - 1 : -1]

            # values of the scored position, repeated along the conv window
            const_normed_log_timestamp = normed_log_timestamp[:, -1:].repeat(1, args.n_conv)
            const_normed_timedelta = normed_timedelta[:, -1:].repeat(1, args.n_conv)

            ohe_conv_explanation = F.one_hot(conv_explanation, num_classes = 3)
            ohe_conv_correctness = F.one_hot(conv_correctness, num_classes = 3)

            for_conv = torch.cat([ohe_conv_explanation, ohe_conv_correctness,
            conv_normed_elapsed.unsqueeze(2), conv_normed_log_timestamp.unsqueeze(2),
            conv_normed_timedelta.unsqueeze(2), conv_task_container_id_diff.unsqueeze(2),
            conv_content_type_id_diff.unsqueeze(2), const_normed_log_timestamp.unsqueeze(2),
            const_normed_timedelta.unsqueeze(2)], dim = 2).transpose(1, 2)

            out_conv = self.conv3(F.relu(self.conv2(F.relu(self.conv1(for_conv))))).view(n_batch, - 1)

            #features
            # copied so that the log / standardisation steps do not touch the caller's array
            features = inputs['features'][i: i + pred_bs].copy()
            features[:, args.log_indices] = np.log1p(features[:, args.log_indices])
            features[:, args.stdize_indices] = (features[:, args.stdize_indices] - args.stdize_mu_std[0][None, :])/args.stdize_mu_std[1][None, :]
            features[np.isnan(features)] = 0
            features = torch.from_numpy(features).to(args.device)

            #concat and last fc
            out_cat = torch.cat([out, out_conv, features], dim = 1)
            out_cat = self.last_fc2(self.dropout(F.relu(self.last_fc1(out_cat))))
            score = torch.sigmoid(out_cat[:, 0])
            scores[i: i + pred_bs] = score
        return scores.cpu().numpy()
    
class MyEncoderExp166(nn.Module):
    """Attention-encoder model ("exp166") with a Conv1d branch over recent history.

    Each position is described by a query vector (question-side information)
    and a memory vector (query plus the observed response).  The encoder output
    is concatenated with the output of a three-layer Conv1d stack over the
    ``args.n_conv`` preceding positions and with ``args.num_features`` extra
    features, then mapped to one logit by a two-layer head.
    """
    def __init__(self, args):
        """Build the layers.

        ``args`` must provide ``emb_dim``, ``dropout``, ``nhead``,
        ``dim_feedforward``, ``num_features`` and ``n_conv``.
        """
        super(MyEncoderExp166, self).__init__()
        #query, key and value
        self.embedder_content_id = Embedder(13523, 512)
        # 717 inputs = content embedding (presumably 512 wide) + 7 (part) + 189 (tags)
        # + 2 (timedelta, log timestamp) + 4 (correct answer) + 2 (diff features) + 1 (position)
        self.linear1 = nn.Linear(714 + 1 + 2, args.emb_dim * 2)
        self.linear2 = nn.Linear(args.emb_dim * 2, args.emb_dim)
        self.dropout = nn.Dropout(args.dropout)
        self.norm1 = nn.LayerNorm(args.emb_dim)

        #key and value
        # + 12 = 3 (explanation) + 3 (correctness) + 1 (elapsed) + 5 (user answer)
        self.linear3 = nn.Linear(args.emb_dim + 12, args.emb_dim)
        self.linear4 = nn.Linear(args.emb_dim, args.emb_dim)
        self.norm2 = nn.LayerNorm(args.emb_dim)

        #MyEncoder
        self.MyEncoderLayer1 = MyEncoderLayer(args.emb_dim, args.nhead, dim_feedforward = args.dim_feedforward, 
                                                dropout = args.dropout)
        self.MyEncoderLayer2 = MyEncoderLayer(args.emb_dim, args.nhead, dim_feedforward = args.dim_feedforward, 
                                                dropout = args.dropout)
        self.MyEncoderLayer3 = MyEncoderLayer(args.emb_dim, args.nhead, dim_feedforward = args.dim_feedforward, 
                                                dropout = args.dropout)
        self.MyEncoderLayer4 = MyEncoderLayer(args.emb_dim, args.nhead, dim_feedforward = args.dim_feedforward, 
                                                dropout = args.dropout)
        # 4 * n_conv = flattened conv3 output (4 channels x n_conv positions)
        self.last_fc1 = nn.Linear(args.emb_dim + args.num_features + 4 * args.n_conv, args.emb_dim * 4)
        self.last_fc2 = nn.Linear(args.emb_dim * 4, 1)
        
        #conv
        # 13 input channels = 3 (explanation) + 3 (correctness) + 5 per-position floats + 2 constants
        self.conv1 = nn.Conv1d(11 + 2, 8, 7, padding = 3)
        self.conv2 = nn.Conv1d(8, 8, 5, padding = 2)
        self.conv3 = nn.Conv1d(8, 4, 3, padding = 1)

    def forward(self, batch, args):
        """Compute the loss and the per-position scores for a prepared batch.

        ``batch`` is a dict of tensors.  Keys read: ``inputs_int``,
        ``inputs_float``, ``inputs_add``, ``inputs_add_2``, ``tag_mask``,
        ``end_pos_idx``, ``memory_mask``, ``padding_mask``, ``target``,
        ``position``, ``conv_inputs_int``, ``conv_inputs_float``,
        ``conv_inputs_const`` and ``loss_mask`` (``cut_mask`` is used when
        ``loss_mask`` is absent).

        Returns ``(err, score)``: ``err`` is the binary cross-entropy (with
        logits) over the positions selected by the loss mask, ``score`` is the
        sigmoid of the logits, one row per sample.
        """
        #input
        inputs_int = batch['inputs_int']
        inputs_float = batch['inputs_float']
        inputs_add = batch['inputs_add'].transpose(0, 1)
        inputs_add_2 = batch['inputs_add_2']
        N = inputs_int.shape[0]
        n_length = int(inputs_int.shape[1]/2)

        #for query, key and value
        content_id = inputs_int[:, :, 0]
        part = (inputs_int[:, :, 1] - 1)
        tags = inputs_int[:, :, 2:8]
        tag_mask = batch['tag_mask']
        #digit_timedelta = inputs_int[:, :, 8]
        normed_timedelta = inputs_float[:, :, 2]
        normed_log_timestamp = inputs_float[:, :, 1]
        correct_answer = inputs_int[:, :, 13]
        task_container_id_diff = inputs_add_2[:, :, 0]
        content_type_id_diff = inputs_add_2[:, :, 1]

        #for key and value
        explanation = inputs_int[:, :, 9]
        correctness = inputs_int[:, :, 10]
        normed_elapsed = inputs_float[:, :, 0]
        user_answer = inputs_int[:, :, 11]
        end_pos_idx = batch['end_pos_idx']

        #mask
        # one copy of each sample's mask per attention head
        memory_mask = torch.repeat_interleave(batch['memory_mask'], args.nhead, dim = 0)
        padding_mask = batch['padding_mask']

        #target, loss_mask
        target = batch['target']
        if 'loss_mask' in batch.keys():
            loss_mask = batch['loss_mask']
        else:
            loss_mask = batch['cut_mask']
            
        # relative positional encoding
        position = batch['position']
            
        #query, key and value
        emb_content_id = self.embedder_content_id(content_id)
        ohe_part = F.one_hot(part, num_classes = 7)
        ohe_tags = F.one_hot(tags, num_classes = 189)
        ohe_tags = ohe_tags * tag_mask.unsqueeze(3)
        ohe_tags_sum = ohe_tags.sum(dim = 2)
        ohe_correct_answer = F.one_hot(correct_answer, num_classes = 4)

        #key and value
        ohe_explanation = F.one_hot(explanation, num_classes = 3)
        ohe_correctness = F.one_hot(correctness, num_classes = 3)
        ohe_user_answer = F.one_hot(user_answer, num_classes = 5)
        
        #query process
        query_cat = torch.cat([emb_content_id, ohe_part, ohe_tags_sum, 
                               normed_timedelta.unsqueeze(2), normed_log_timestamp.unsqueeze(2), 
                               ohe_correct_answer, task_container_id_diff.unsqueeze(2), 
                               content_type_id_diff.unsqueeze(2), position.unsqueeze(2)], dim = 2)
        query = self.norm1(self.linear2(self.dropout(F.relu(self.linear1(query_cat)))))
        
        #key and value process
        # the query part of the memory vector at index end_pos_idx - 1 is zeroed for every sample
        query_clone = query.clone()
        query_clone[torch.arange(N), end_pos_idx - 1] = query_clone[torch.arange(N), end_pos_idx - 1] * 0
        memory_cat = torch.cat([query_clone, ohe_explanation, ohe_correctness, normed_elapsed.unsqueeze(2), ohe_user_answer], dim = 2)
        memory = self.norm2(self.linear4(self.dropout(F.relu(self.linear3(memory_cat)))))

        #transpose
        query = query.transpose(0, 1)
        memory = memory.transpose(0, 1)
        
        #MyEncoder
        # NOTE(review): only layers 1-3 are applied here, while predict_on_batch also
        # applies MyEncoderLayer4 -- confirm which of the two matches the trained weights.
        out = self.MyEncoderLayer1(query, memory, memory, memory_mask, padding_mask)
        out = self.MyEncoderLayer2(out, memory, memory, memory_mask, padding_mask)
        out = self.MyEncoderLayer3(out, memory, memory, memory_mask, padding_mask)
        
        #conv process
        conv_inputs_int = batch['conv_inputs_int']
        conv_inputs_float = batch['conv_inputs_float']
        conv_inputs_const = batch['conv_inputs_const']

        conv_explanation = conv_inputs_int[:, :, :, 0]
        conv_correctness = conv_inputs_int[:, :, :, 1]
        conv_normed_elapsed = conv_inputs_float[:, :, :, 0]
        conv_normed_log_timestamp = conv_inputs_float[:, :, :, 1]
        conv_normed_timedelta = conv_inputs_float[:, :, :, 2]
        conv_task_container_id_diff = conv_inputs_float[:, :, :, 3]
        conv_content_type_id_diff = conv_inputs_float[:, :, :, 4]
    
        const_normed_log_timestamp = conv_inputs_const[:, :, :, 0]
        const_normed_timedelta = conv_inputs_const[:, :, :, 1]
        ohe_conv_explanation = F.one_hot(conv_explanation, num_classes = 3)
        ohe_conv_correctness = F.one_hot(conv_correctness, num_classes = 3)

        for_conv = torch.cat([ohe_conv_explanation, ohe_conv_correctness,
        conv_normed_elapsed.unsqueeze(3), conv_normed_log_timestamp.unsqueeze(3),
        conv_normed_timedelta.unsqueeze(3), conv_task_container_id_diff.unsqueeze(3),
        conv_content_type_id_diff.unsqueeze(3), const_normed_log_timestamp.unsqueeze(3), 
        const_normed_timedelta.unsqueeze(3)], dim = 3).transpose(0, 1)
        # every (position, sample) pair becomes one Conv1d item: channels first, window last
        for_conv = for_conv.contiguous().transpose(2, 3).view(N * n_length * 2, -1, args.n_conv)
        out_conv = self.conv3(F.relu(self.conv2(F.relu(self.conv1(for_conv))))).view(n_length * 2, N, -1)
        
        #concat 
        out = torch.cat([out, out_conv, inputs_add], dim = 2)
        out = self.last_fc2(self.dropout(F.relu(self.last_fc1(out.reshape(-1, args.emb_dim + args.num_features + 4 * args.n_conv)))))\
        .reshape(-1, N).transpose(0, 1)
        score = torch.sigmoid(out)
        err = F.binary_cross_entropy_with_logits(out[loss_mask], target[loss_mask])
        return err, score
    
    def predict_on_batch(self, inputs, args, pred_bs, que, que_proc_int):
        """Return one predicted score per row of ``inputs``, computed ``pred_bs`` rows at a time.

        For every row the last column is the position being scored: it is the
        only query sent through the encoder layers.  All earlier columns form
        the memory (key/value) sequence.

        Parameters
        ----------
        inputs : dict of numpy arrays
            Keys read: ``content_id``, ``normed_timedelta``,
            ``normed_log_timestamp``, ``explanation``, ``correctness``,
            ``normed_elapsed``, ``user_answer``, ``task_container_id_diff``,
            ``content_type_id_diff`` (indexed as ``[row, position]``) and
            ``features`` (indexed as ``[row, feature]``).  Positions where
            ``content_id == -1`` are treated as padding.  The per-position
            arrays need ``args.n_length + 1`` columns, because the positional
            encoding built from ``args.n_length`` is concatenated with them.
            None of the arrays is modified by this method.
        args : namespace
            Uses ``device``, ``n_length``, ``n_conv``, ``log_indices``,
            ``stdize_indices`` and ``stdize_mu_std`` (``[0]`` is subtracted from
            and ``[1]`` divides the ``stdize_indices`` feature columns).
        pred_bs : int
            Number of rows processed per chunk.
        que : mapping
            ``que['part'].values`` and ``que['correct_answer'].values`` are
            indexed by content id (presumably the questions table -- confirm
            against the caller).
        que_proc_int : numpy array
            Indexed by content id to get tag ids; entries equal to 188 are
            masked out before the tags are summed.

        Returns
        -------
        numpy.ndarray
            Sigmoid scores, shape ``(N,)`` with ``N = inputs['content_id'].shape[0]``.
        """
        # NOTE: this turns autograd off for everything that runs afterwards, not
        # only for this call. Left unchanged on purpose: code after this call may
        # depend on gradients staying disabled (e.g. calling .numpy() on outputs).
        torch.set_grad_enabled(False)
        N = inputs['content_id'].shape[0]
        scores = torch.zeros(N).to(args.device)

        # relative positional encoding: identical for every chunk, so it is built once
        base_position = torch.from_numpy(((np.arange(args.n_length + 1) - args.n_length)/(args.n_length)).astype(np.float32))
        base_position = base_position.to(args.device)

        for i in range(0, N, pred_bs):
            content_id = inputs['content_id'][i: i + pred_bs]
            n_batch = content_id.shape[0]
            n_sample = content_id.shape[1] - 1

            # padding_mask covers the memory positions only (all columns but the last).
            # token_idx marks rows whose newest memory slot is padding; that slot is
            # un-masked here and filled with the "token" values further down.
            padding_mask = (content_id == -1)[:, :n_sample]
            token_idx = padding_mask[:, -1].copy()
            padding_mask[:, -1] = False

            part = que['part'].values[content_id] - 1
            tags = que_proc_int[content_id]
            tag_mask = (tags != 188)
            correct_answer = que['correct_answer'].values[content_id]

            normed_timedelta = inputs['normed_timedelta'][i: i + pred_bs]
            normed_log_timestamp = inputs['normed_log_timestamp'][i: i + pred_bs]
            explanation = inputs['explanation'][i: i + pred_bs]
            correctness = inputs['correctness'][i: i + pred_bs]
            normed_elapsed = inputs['normed_elapsed'][i: i + pred_bs]
            user_answer = inputs['user_answer'][i: i + pred_bs]

            #diff features
            task_container_id_diff = inputs['task_container_id_diff'][i: i + pred_bs]/10000
            content_type_id_diff = inputs['content_type_id_diff'][i: i + pred_bs]

            #preprop
            content_id = torch.from_numpy(content_id.astype(np.int64)).to(args.device)
            content_id[content_id == -1] = 1
            padding_mask = torch.from_numpy(padding_mask).to(args.device)
            token_idx = torch.from_numpy(token_idx).to(args.device)
            part = torch.from_numpy(part.astype(np.int64)).to(args.device)
            tags = torch.from_numpy(tags.astype(np.int64)).to(args.device)
            tag_mask = torch.from_numpy(tag_mask).to(args.device)
            correct_answer = torch.from_numpy(correct_answer.astype(np.int64)).to(args.device)

            # copy = True: torch.from_numpy shares memory with the numpy slice and
            # .to() is a no-op on a CPU device, so without the copy the in-place
            # edits below would overwrite the caller's inputs[...] arrays.
            normed_timedelta = torch.from_numpy(normed_timedelta).to(args.device, copy = True)
            normed_log_timestamp = torch.from_numpy(normed_log_timestamp).to(args.device, copy = True)
            explanation = torch.from_numpy(explanation.astype(np.int64)).to(args.device)
            explanation[explanation == -2] = 1
            explanation[explanation == 2] = 1
            correctness = torch.from_numpy(correctness.astype(np.int64)).to(args.device)
            correctness[correctness == -1] = 1
            correctness[correctness == 2] = 1
            normed_elapsed = torch.from_numpy(normed_elapsed).to(args.device, copy = True)
            normed_elapsed[normed_elapsed == -999] = 0
            user_answer = torch.from_numpy(user_answer.astype(np.int64)).to(args.device)
            user_answer[user_answer == -1] = 1
            user_answer[user_answer == 4] = 1

            #diff features
            task_container_id_diff = torch.from_numpy(task_container_id_diff.astype(np.float32)).to(args.device)
            content_type_id_diff = torch.from_numpy(content_type_id_diff.astype(np.float32)).to(args.device)

            #generate token
            explanation[token_idx, n_sample - 1] = 2
            correctness[token_idx, n_sample - 1] = 2
            normed_elapsed[token_idx, n_sample - 1] = 0
            user_answer[token_idx, n_sample - 1] = 4

            #for conv process
            explanation[:, :-1][padding_mask] = 2
            correctness[:, :-1][padding_mask] = 2
            normed_elapsed[:, :-1][padding_mask] = 0
            normed_timedelta[:, :-1][padding_mask] = 0
            normed_log_timestamp[:, :-1][padding_mask] = 0

            #query, key and value
            emb_content_id = self.embedder_content_id(content_id)
            ohe_part = F.one_hot(part, num_classes = 7)
            ohe_tags = F.one_hot(tags, num_classes = 189)
            ohe_tags = ohe_tags * tag_mask.unsqueeze(3)
            ohe_tags_sum = ohe_tags.sum(dim = 2)
            ohe_correct_answer = F.one_hot(correct_answer, num_classes = 4)

            #key and value
            ohe_explanation = F.one_hot(explanation, num_classes = 3)
            ohe_correctness = F.one_hot(correctness, num_classes = 3)
            ohe_user_answer = F.one_hot(user_answer, num_classes = 5)

            # relative positional encoding (one copy of the shared row per sample)
            position = base_position.unsqueeze(0).repeat(n_batch, 1)

            #query process
            query_cat = torch.cat([emb_content_id, ohe_part, ohe_tags_sum,
                                   normed_timedelta.unsqueeze(2), normed_log_timestamp.unsqueeze(2),
                                   ohe_correct_answer, task_container_id_diff.unsqueeze(2),
                                   content_type_id_diff.unsqueeze(2), position.unsqueeze(2)], dim = 2)
            query = self.norm1(self.linear2(self.dropout(F.relu(self.linear1(query_cat)))))

            #key and value process
            # the memory drops the last (scored) position; the token slot gets a zero query vector
            query_clone = query[:, :n_sample].clone()
            query_clone[token_idx, -1] = query_clone[token_idx, -1] * 0
            memory_cat = torch.cat([query_clone, ohe_explanation[:, :-1], ohe_correctness[:, :-1],
                                    normed_elapsed.unsqueeze(2)[:, :-1], ohe_user_answer[:, :-1]], dim = 2)
            memory = self.norm2(self.linear4(self.dropout(F.relu(self.linear3(memory_cat)))))

            #transpose
            query = query.transpose(0, 1)
            memory = memory.transpose(0, 1)
            # only the last position is scored, so it is the only query
            spc_query = query[-1:]

            #MyEncoder
            # NOTE(review): four layers are applied here but forward() stops after
            # MyEncoderLayer3 -- confirm which of the two matches the trained weights.
            out = self.MyEncoderLayer1(spc_query, memory, memory, None, padding_mask)
            out = self.MyEncoderLayer2(out, memory, memory, None, padding_mask)
            out = self.MyEncoderLayer3(out, memory, memory, None, padding_mask)
            out = self.MyEncoderLayer4(out, memory, memory, None, padding_mask).squeeze(0)

            #conv process
            # the args.n_conv positions immediately before the scored one
            conv_explanation = explanation[:, -args.n_conv - 1 : -1]
            conv_correctness = correctness[:, -args.n_conv - 1 : -1]
            conv_normed_elapsed = normed_elapsed[:, -args.n_conv - 1 : -1]
            conv_normed_log_timestamp = normed_log_timestamp[:, -args.n_conv - 1 : -1]
            conv_normed_timedelta = normed_timedelta[:, -args.n_conv - 1 : -1]
            conv_task_container_id_diff = task_container_id_diff[:, -args.n_conv - 1 : -1]
            conv_content_type_id_diff = content_type_id_diff[:, -args.n_conv - 1 : -1]

            # values of the scored position, repeated along the conv window
            const_normed_log_timestamp = normed_log_timestamp[:, -1:].repeat(1, args.n_conv)
            const_normed_timedelta = normed_timedelta[:, -1:].repeat(1, args.n_conv)

            ohe_conv_explanation = F.one_hot(conv_explanation, num_classes = 3)
            ohe_conv_correctness = F.one_hot(conv_correctness, num_classes = 3)

            for_conv = torch.cat([ohe_conv_explanation, ohe_conv_correctness,
            conv_normed_elapsed.unsqueeze(2), conv_normed_log_timestamp.unsqueeze(2),
            conv_normed_timedelta.unsqueeze(2), conv_task_container_id_diff.unsqueeze(2),
            conv_content_type_id_diff.unsqueeze(2), const_normed_log_timestamp.unsqueeze(2),
            const_normed_timedelta.unsqueeze(2)], dim = 2).transpose(1, 2)

            out_conv = self.conv3(F.relu(self.conv2(F.relu(self.conv1(for_conv))))).view(n_batch, - 1)

            #features
            # copied so that the log / standardisation steps do not touch the caller's array
            features = inputs['features'][i: i + pred_bs].copy()
            features[:, args.log_indices] = np.log1p(features[:, args.log_indices])
            features[:, args.stdize_indices] = (features[:, args.stdize_indices] - args.stdize_mu_std[0][None, :])/args.stdize_mu_std[1][None, :]
            features[np.isnan(features)] = 0
            features = torch.from_numpy(features).to(args.device)

            #concat and last fc
            out_cat = torch.cat([out, out_conv, features], dim = 1)
            out_cat = self.last_fc2(self.dropout(F.relu(self.last_fc1(out_cat))))
            score = torch.sigmoid(out_cat[:, 0])
            scores[i: i + pred_bs] = score
        return scores.cpu().numpy()
    
class MyEncoderExp184(nn.Module):
    def __init__(self, args):
        """Build the layers of the GRU variant of the encoder model.

        ``args`` must provide ``emb_dim``, ``dropout``, ``nhead``,
        ``dim_feedforward`` and ``num_features``.
        """
        super(MyEncoderExp184, self).__init__()
        emb_dim = args.emb_dim
        # the final hidden state of every GRU layer is fed to the head, hence
        # rnn_hidden * rnn_layers extra inputs on last_fc1
        rnn_hidden, rnn_layers = 16, 8

        #query, key and value
        self.embedder_content_id = Embedder(13523, 512)
        self.linear1 = nn.Linear(714 + 1 + 2, emb_dim * 2)
        self.linear2 = nn.Linear(emb_dim * 2, emb_dim)
        self.dropout = nn.Dropout(args.dropout)
        self.norm1 = nn.LayerNorm(emb_dim)

        #key and value
        self.linear3 = nn.Linear(emb_dim + 12, emb_dim)
        self.linear4 = nn.Linear(emb_dim, emb_dim)
        self.norm2 = nn.LayerNorm(emb_dim)

        #MyEncoder: four identically configured layers, MyEncoderLayer1 .. MyEncoderLayer4
        for layer_no in range(1, 5):
            setattr(self, 'MyEncoderLayer{}'.format(layer_no),
                    MyEncoderLayer(emb_dim, args.nhead, dim_feedforward = args.dim_feedforward,
                                   dropout = args.dropout))
        self.last_fc1 = nn.Linear(emb_dim + args.num_features + rnn_hidden * rnn_layers, emb_dim * 4)
        self.last_fc2 = nn.Linear(emb_dim * 4, 1)

        #rnn (13 input channels per history step)
        self.rnn = nn.GRU(13, rnn_hidden, rnn_layers, batch_first = True)

    def forward(self, batch, args):
        """Compute the loss and the per-position scores for a prepared batch.

        ``batch`` is a dict of tensors.  Keys read: ``inputs_int``,
        ``inputs_float``, ``inputs_add``, ``inputs_add_2``, ``tag_mask``,
        ``end_pos_idx``, ``memory_mask``, ``padding_mask``, ``target``,
        ``position``, ``conv_inputs_int``, ``conv_inputs_float``,
        ``conv_inputs_const`` and ``loss_mask`` (``cut_mask`` is used when
        ``loss_mask`` is absent).  ``args`` provides ``nhead``, ``n_conv``,
        ``emb_dim`` and ``num_features``.

        Returns ``(err, score)``: ``err`` is the binary cross-entropy (with
        logits) over the positions selected by the loss mask, ``score`` is the
        sigmoid of the logits, one row per sample.
        """
        #input
        inputs_int = batch['inputs_int']
        inputs_float = batch['inputs_float']
        extra_features = batch['inputs_add'].transpose(0, 1)
        diff_features = batch['inputs_add_2']
        N = inputs_int.shape[0]
        seq_len = 2 * (inputs_int.shape[1] // 2)

        #mask (one copy of each sample's memory mask per attention head)
        memory_mask = torch.repeat_interleave(batch['memory_mask'], args.nhead, dim = 0)
        padding_mask = batch['padding_mask']

        #target, loss_mask
        target = batch['target']
        loss_mask = batch['loss_mask'] if 'loss_mask' in batch.keys() else batch['cut_mask']

        #query process
        # columns of inputs_int: 0 content id, 1 part, 2..7 tags, 13 correct answer
        # columns of inputs_float: 2 timedelta, 1 log timestamp
        tag_hot = F.one_hot(inputs_int[..., 2:8], num_classes = 189) * batch['tag_mask'].unsqueeze(3)
        query_parts = [
            self.embedder_content_id(inputs_int[..., 0]),
            F.one_hot(inputs_int[..., 1] - 1, num_classes = 7),
            tag_hot.sum(dim = 2),
            inputs_float[..., 2].unsqueeze(2),
            inputs_float[..., 1].unsqueeze(2),
            F.one_hot(inputs_int[..., 13], num_classes = 4),
            diff_features[..., 0].unsqueeze(2),
            diff_features[..., 1].unsqueeze(2),
            batch['position'].unsqueeze(2),
        ]
        query = self.norm1(self.linear2(self.dropout(F.relu(self.linear1(torch.cat(query_parts, dim = 2))))))

        #key and value process
        # the query part of the memory vector at index end_pos_idx - 1 is zeroed for every sample
        rows = torch.arange(N)
        last_pos = batch['end_pos_idx'] - 1
        query_clone = query.clone()
        query_clone[rows, last_pos] = query_clone[rows, last_pos] * 0
        # columns of inputs_int: 9 explanation, 10 correctness, 11 user answer; inputs_float 0: elapsed
        memory_parts = [
            query_clone,
            F.one_hot(inputs_int[..., 9], num_classes = 3),
            F.one_hot(inputs_int[..., 10], num_classes = 3),
            inputs_float[..., 0].unsqueeze(2),
            F.one_hot(inputs_int[..., 11], num_classes = 5),
        ]
        memory = self.norm2(self.linear4(self.dropout(F.relu(self.linear3(torch.cat(memory_parts, dim = 2))))))

        #MyEncoder (sequence-first layout)
        memory = memory.transpose(0, 1)
        out = query.transpose(0, 1)
        for encoder_layer in (self.MyEncoderLayer1, self.MyEncoderLayer2,
                              self.MyEncoderLayer3, self.MyEncoderLayer4):
            out = encoder_layer(out, memory, memory, memory_mask, padding_mask)

        #rnn input: per-window one-hots, five float channels, two constant channels
        conv_int = batch['conv_inputs_int']
        conv_float = batch['conv_inputs_float']
        conv_const = batch['conv_inputs_const']
        rnn_parts = [F.one_hot(conv_int[..., 0], num_classes = 3),
                     F.one_hot(conv_int[..., 1], num_classes = 3)]
        rnn_parts += [conv_float[..., k].unsqueeze(3) for k in range(5)]
        rnn_parts += [conv_const[..., k].unsqueeze(3) for k in range(2)]
        for_rnn = torch.cat(rnn_parts, dim = 3).transpose(0, 1)

        #rnn: every (position, sample) pair is one GRU sequence of args.n_conv steps
        for_rnn = for_rnn.contiguous().view(N * seq_len, args.n_conv, -1)
        _, hidden = self.rnn(for_rnn)
        out_rnn = hidden.transpose(0, 1).contiguous().view(seq_len, N, -1)

        #cat and last fc
        fc_in = args.emb_dim + args.num_features + 16 * 8
        out = torch.cat([out, out_rnn, extra_features], dim = 2).reshape(-1, fc_in)
        out = self.last_fc2(self.dropout(F.relu(self.last_fc1(out)))).reshape(-1, N).transpose(0, 1)
        score = torch.sigmoid(out)
        err = F.binary_cross_entropy_with_logits(out[loss_mask], target[loss_mask])
        return err, score
    
    def predict_on_batch(self, inputs, args, pred_bs, que, que_proc_int):
        torch.set_grad_enabled(False)
        N = inputs['content_id'].shape[0]
        scores = torch.zeros(N).to(args.device)
        for i in range(0, N, pred_bs):
            content_id = inputs['content_id'][i: i + pred_bs]
            n_batch = content_id.shape[0]
            n_sample = content_id.shape[1] - 1

            padding_mask = (content_id == -1)[:, :n_sample]
            token_idx = padding_mask[:, -1].copy()
            padding_mask[:, -1] = False

            part = que['part'].values[content_id] - 1
            tags = que_proc_int[content_id]
            tag_mask = (tags != 188)
            correct_answer = que['correct_answer'].values[content_id]

            normed_timedelta = inputs['normed_timedelta'][i: i + pred_bs]
            normed_log_timestamp = inputs['normed_log_timestamp'][i: i + pred_bs]
            explanation = inputs['explanation'][i: i + pred_bs]
            correctness = inputs['correctness'][i: i + pred_bs]
            normed_elapsed = inputs['normed_elapsed'][i: i + pred_bs]
            user_answer = inputs['user_answer'][i: i + pred_bs]
            
            #diff features
            task_container_id_diff = inputs['task_container_id_diff'][i: i + pred_bs]/10000
            content_type_id_diff = inputs['content_type_id_diff'][i: i + pred_bs]
            
            #preprop
            content_id = torch.from_numpy(content_id.astype(np.int64)).to(args.device)
            content_id[content_id == -1] = 1
            padding_mask = torch.from_numpy(padding_mask).to(args.device)
            token_idx = torch.from_numpy(token_idx).to(args.device)
            part = torch.from_numpy(part.astype(np.int64)).to(args.device)
            tags = torch.from_numpy(tags.astype(np.int64)).to(args.device)
            tag_mask = torch.from_numpy(tag_mask).to(args.device)
            correct_answer = torch.from_numpy(correct_answer.astype(np.int64)).to(args.device)

            normed_timedelta = torch.from_numpy(normed_timedelta).to(args.device)
            normed_log_timestamp = torch.from_numpy(normed_log_timestamp).to(args.device)
            explanation = torch.from_numpy(explanation.astype(np.int64)).to(args.device)
            explanation[explanation == -2] = 1
            explanation[explanation == 2] = 1
            correctness = torch.from_numpy(correctness.astype(np.int64)).to(args.device)
            correctness[correctness == -1] = 1
            correctness[correctness == 2] = 1
            normed_elapsed = torch.from_numpy(normed_elapsed).to(args.device)
            normed_elapsed[normed_elapsed == -999] = 0
            user_answer = torch.from_numpy(user_answer.astype(np.int64)).to(args.device)
            user_answer[user_answer == -1] = 1
            user_answer[user_answer == 4] = 1
            
            #diff features
            task_container_id_diff = torch.from_numpy(task_container_id_diff.astype(np.float32)).to(args.device)
            content_type_id_diff = torch.from_numpy(content_type_id_diff.astype(np.float32)).to(args.device)

            #generate token
            explanation[token_idx, n_sample - 1] = 2
            correctness[token_idx, n_sample - 1] = 2
            normed_elapsed[token_idx, n_sample - 1] = 0
            user_answer[token_idx, n_sample - 1] = 4

            #for conv process
            explanation[:, :-1][padding_mask] = 2
            correctness[:, :-1][padding_mask] = 2
            normed_elapsed[:, :-1][padding_mask] = 0
            normed_timedelta[:, :-1][padding_mask] = 0
            normed_log_timestamp[:, :-1][padding_mask] = 0

            #query, key and value
            emb_content_id = self.embedder_content_id(content_id)
            ohe_part = F.one_hot(part, num_classes = 7)
            ohe_tags = F.one_hot(tags, num_classes = 189)
            ohe_tags = ohe_tags * tag_mask.unsqueeze(3)
            ohe_tags_sum = ohe_tags.sum(dim = 2)
            ohe_correct_answer = F.one_hot(correct_answer, num_classes = 4)

            #key and value
            ohe_explanation = F.one_hot(explanation, num_classes = 3)
            ohe_correctness = F.one_hot(correctness, num_classes = 3)
            ohe_user_answer = F.one_hot(user_answer, num_classes = 5)

            # relative positional encoding
            position = torch.from_numpy(((np.arange(args.n_length + 1) - args.n_length)/(args.n_length)).astype(np.float32))
            position = position.unsqueeze(0).repeat(n_batch, 1).to(args.device)

            #query process
            query_cat = torch.cat([emb_content_id, ohe_part, ohe_tags_sum, 
                                   normed_timedelta.unsqueeze(2), normed_log_timestamp.unsqueeze(2), 
                                   ohe_correct_answer, task_container_id_diff.unsqueeze(2), 
                                   content_type_id_diff.unsqueeze(2), position.unsqueeze(2)], dim = 2)
            query = self.norm1(self.linear2(self.dropout(F.relu(self.linear1(query_cat)))))

            #key and value process
            query_clone = query[:, :n_sample].clone()
            query_clone[token_idx, -1] = query_clone[token_idx, -1] * 0
            memory_cat = torch.cat([query_clone, ohe_explanation[:, :-1], ohe_correctness[:, :-1], 
                                    normed_elapsed.unsqueeze(2)[:, :-1], ohe_user_answer[:, :-1]], dim = 2)
            memory = self.norm2(self.linear4(self.dropout(F.relu(self.linear3(memory_cat)))))

            #transpose
            query = query.transpose(0, 1)
            memory = memory.transpose(0, 1)
            spc_query = query[-1:]

            #MyEncoder
            out = self.MyEncoderLayer1(spc_query, memory, memory, None, padding_mask)
            out = self.MyEncoderLayer2(out, memory, memory, None, padding_mask)
            out = self.MyEncoderLayer3(out, memory, memory, None, padding_mask)
            out = self.MyEncoderLayer4(out, memory, memory, None, padding_mask).squeeze(0)
            
            #conv process
            conv_explanation = explanation[:, -args.n_conv - 1 : -1]
            conv_correctness = correctness[:, -args.n_conv - 1 : -1]
            conv_normed_elapsed = normed_elapsed[:, -args.n_conv - 1 : -1]
            conv_normed_log_timestamp = normed_log_timestamp[:, -args.n_conv - 1 : -1]
            conv_normed_timedelta = normed_timedelta[:, -args.n_conv - 1 : -1]
            conv_task_container_id_diff = task_container_id_diff[:, -args.n_conv - 1 : -1]
            conv_content_type_id_diff = content_type_id_diff[:, -args.n_conv - 1 : -1]
            
            
            const_normed_log_timestamp = normed_log_timestamp[:, -1:].repeat(1, args.n_conv)
            const_normed_timedelta = normed_timedelta[:, -1:].repeat(1, args.n_conv)

            ohe_conv_explanation = F.one_hot(conv_explanation, num_classes = 3)
            ohe_conv_correctness = F.one_hot(conv_correctness, num_classes = 3)

            for_conv = torch.cat([ohe_conv_explanation, ohe_conv_correctness,
            conv_normed_elapsed.unsqueeze(2), conv_normed_log_timestamp.unsqueeze(2),
            conv_normed_timedelta.unsqueeze(2), conv_task_container_id_diff.unsqueeze(2),
            conv_content_type_id_diff.unsqueeze(2), const_normed_log_timestamp.unsqueeze(2), 
            const_normed_timedelta.unsqueeze(2)], dim = 2)
            
            #rnn
            for_rnn = for_conv
            _, out_rnn = self.rnn(for_rnn)
            out_rnn = out_rnn.transpose(0, 1).contiguous().view(n_batch, -1)

            #features
            features = inputs['features'][i: i + pred_bs].copy()
            features[:, args.log_indices] = np.log1p(features[:, args.log_indices])
            features[:, args.stdize_indices] = (features[:, args.stdize_indices] - args.stdize_mu_std[0][None, :])/args.stdize_mu_std[1][None, :]
            features[np.isnan(features)] = 0
            features = torch.from_numpy(features).to(args.device)

            #concat and last fc
            out_cat = torch.cat([out, out_rnn, features], dim = 1)
            out_cat = self.last_fc2(self.dropout(F.relu(self.last_fc1(out_cat))))
            score = torch.sigmoid(out_cat[:, 0])
            scores[i: i + pred_bs] = score
        return scores.cpu().numpy()
    
class MyEncoderExp218(nn.Module):
    """Two-stage attention encoder (experiment 218).

    Stage 1 projects the per-position question features into a query and,
    combined with the response features, into a memory sequence that is run
    through a GRU.  Stage 2 rebuilds query and memory with gathered GRU states
    appended to the question features, then applies four ``MyEncoderLayer``
    blocks and a two-layer head that emits one logit per scored position.

    Attribute names and layer shapes define the ``state_dict`` layout, so they
    must stay stable for saved weights to keep loading.
    """

    def __init__(self, args):
        """Create all sub-modules.

        Args:
            args: configuration object; reads ``emb_dim``, ``dropout``,
                ``nhead``, ``dim_feedforward`` and ``num_features``.
        """
        super(MyEncoderExp218, self).__init__()
        #query, key and value
        self.embedder_content_id = Embedder(13523, 512)
        # 717 inputs = 7 (part) + 189 (tags) + 4 (correct answer) + 5 scalar channels
        # + the content embedding (512 channels, assuming Embedder(13523, 512) emits 512)
        self.linear1 = nn.Linear(714 + 1 + 2, args.emb_dim * 2)
        self.linear2 = nn.Linear(args.emb_dim * 2, args.emb_dim)
        self.dropout = nn.Dropout(args.dropout)
        self.norm1 = nn.LayerNorm(args.emb_dim)

        #key and value
        # +12 = 3 (explanation) + 3 (correctness) + 1 (elapsed) + 5 (user answer)
        self.linear3 = nn.Linear(args.emb_dim + 12, args.emb_dim)
        self.linear4 = nn.Linear(args.emb_dim, args.emb_dim)
        self.norm2 = nn.LayerNorm(args.emb_dim)

        #for rnn
        # +256 = hidden size of self.rnn, whose states are appended to the query features
        self.linear5 = nn.Linear(714 + 1 + 2 + 256, args.emb_dim * 2)
        self.linear6 = nn.Linear(args.emb_dim * 2, args.emb_dim)
        self.norm3 = nn.LayerNorm(args.emb_dim)

        self.linear7 = nn.Linear(args.emb_dim + 12, args.emb_dim)
        self.linear8 = nn.Linear(args.emb_dim, args.emb_dim)
        self.norm4 = nn.LayerNorm(args.emb_dim)
        # The GRU consumes the stage-1 memory, whose width is args.emb_dim; the
        # input size used to be hard-coded to 512 (identical when emb_dim == 512).
        self.rnn = nn.GRU(args.emb_dim, 256, 4, batch_first = True, dropout = args.dropout)

        #MyEncoder
        self.MyEncoderLayer1 = MyEncoderLayer(args.emb_dim, args.nhead, dim_feedforward = args.dim_feedforward, 
                                                dropout = args.dropout)
        self.MyEncoderLayer2 = MyEncoderLayer(args.emb_dim, args.nhead, dim_feedforward = args.dim_feedforward, 
                                                dropout = args.dropout)
        self.MyEncoderLayer3 = MyEncoderLayer(args.emb_dim, args.nhead, dim_feedforward = args.dim_feedforward, 
                                                dropout = args.dropout)
        self.MyEncoderLayer4 = MyEncoderLayer(args.emb_dim, args.nhead, dim_feedforward = args.dim_feedforward, 
                                                dropout = args.dropout)
        self.last_fc1 = nn.Linear(args.emb_dim + args.num_features, args.emb_dim * 4)
        self.last_fc2 = nn.Linear(args.emb_dim * 4, 1)

    def forward(self, batch, args):
        """Run the model on full sequences and compute the loss.

        Args:
            batch: dict of tensors; reads ``inputs_int``, ``inputs_float``,
                ``inputs_add``, ``inputs_add_2``, ``tag_mask``, ``end_pos_idx``,
                ``memory_mask``, ``padding_mask``, ``target``, ``position``,
                ``memory_idx`` and ``loss_mask`` (``cut_mask`` is used when
                ``loss_mask`` is absent).
            args: configuration object; reads ``nhead``, ``device``,
                ``n_length``, ``emb_dim`` and ``num_features``.

        Returns:
            ``(err, score)``: binary cross-entropy (with logits) over the
            positions selected by the loss mask, and the sigmoid of every logit
            with the batch dimension first.
        """
        #inputs
        inputs_int = batch['inputs_int']
        inputs_float = batch['inputs_float']
        inputs_add = batch['inputs_add'].transpose(0, 1)
        inputs_add_2 = batch['inputs_add_2']
        N = inputs_int.shape[0]
        batch_rows = torch.arange(N)

        #for query, key and value
        content_id = inputs_int[:, :, 0]
        part = (inputs_int[:, :, 1] - 1)
        tags = inputs_int[:, :, 2:8]
        tag_mask = batch['tag_mask']
        normed_timedelta = inputs_float[:, :, 2]
        normed_log_timestamp = inputs_float[:, :, 1]
        correct_answer = inputs_int[:, :, 13]
        task_container_id_diff = inputs_add_2[:, :, 0]
        content_type_id_diff = inputs_add_2[:, :, 1]

        #for key and value
        explanation = inputs_int[:, :, 9]
        correctness = inputs_int[:, :, 10]
        normed_elapsed = inputs_float[:, :, 0]
        user_answer = inputs_int[:, :, 11]
        end_pos_idx = batch['end_pos_idx']

        #mask
        memory_mask = torch.repeat_interleave(batch['memory_mask'], args.nhead, dim = 0)
        padding_mask = batch['padding_mask']

        #target, loss_mask
        target = batch['target']
        if 'loss_mask' in batch.keys():
            loss_mask = batch['loss_mask']
        else:
            loss_mask = batch['cut_mask']

        # relative positional encoding
        position = batch['position']

        #query, key and value
        emb_content_id = self.embedder_content_id(content_id)
        ohe_part = F.one_hot(part, num_classes = 7)
        ohe_tags = F.one_hot(tags, num_classes = 189)
        # zero the one-hot rows of masked-out tag slots before summing into a multi-hot vector
        ohe_tags = ohe_tags * tag_mask.unsqueeze(3)
        ohe_tags_sum = ohe_tags.sum(dim = 2)
        ohe_correct_answer = F.one_hot(correct_answer, num_classes = 4)

        #initial query and memory 
        ohe_explanation = F.one_hot(explanation, num_classes = 3)
        ohe_correctness = F.one_hot(correctness, num_classes = 3)
        ohe_user_answer = F.one_hot(user_answer, num_classes = 5)

        # Feature lists shared by both stages (same order as the trained weights expect).
        question_feats = [emb_content_id, ohe_part, ohe_tags_sum, 
                          normed_timedelta.unsqueeze(2), normed_log_timestamp.unsqueeze(2), 
                          ohe_correct_answer, task_container_id_diff.unsqueeze(2), 
                          content_type_id_diff.unsqueeze(2), position.unsqueeze(2)]
        response_feats = [ohe_explanation, ohe_correctness, normed_elapsed.unsqueeze(2), ohe_user_answer]

        query_cat = torch.cat(question_feats, dim = 2)
        query = self.norm1(self.linear2(self.dropout(F.relu(self.linear1(query_cat)))))

        # blank the query part of the memory entry at position end_pos_idx - 1 of every row
        query_clone = query.clone()
        query_clone[batch_rows, end_pos_idx - 1] = query_clone[batch_rows, end_pos_idx - 1] * 0
        memory_cat = torch.cat([query_clone] + response_feats, dim = 2)
        memory = self.norm2(self.linear4(self.dropout(F.relu(self.linear3(memory_cat)))))

        #for rnn process
        memory[padding_mask] = memory[padding_mask] * 0

        tmp, _ = self.rnn(memory)
        out_rnn = tmp.clone()
        out_rnn[batch_rows, end_pos_idx - 1] = out_rnn[batch_rows, end_pos_idx - 1] * 0

        # memory_idx holds per-row positions; adding row * row_length turns them
        # into indices of the flattened (rows * positions) GRU output
        memory_idx = batch['memory_idx']
        memory_idx_ = memory_idx + (torch.arange(memory_idx.shape[0]).to(args.device) * memory_idx.shape[1])[:, None]
        out_rnn = out_rnn.contiguous().view(N * args.n_length * 2, -1)
        memory_rnn = out_rnn[memory_idx_]

        #new query and memory
        query_cat = torch.cat(question_feats + [memory_rnn], dim = 2)
        query = self.norm3(self.linear6(self.dropout(F.relu(self.linear5(query_cat)))))

        query_clone = query.clone()
        query_clone[batch_rows, end_pos_idx - 1] = query_clone[batch_rows, end_pos_idx - 1] * 0
        memory_cat = torch.cat([query_clone] + response_feats, dim = 2)
        memory = self.norm4(self.linear8(self.dropout(F.relu(self.linear7(memory_cat)))))

        #transpose
        query = query.transpose(0, 1)
        memory = memory.transpose(0, 1)

        #MyEncoder
        out = self.MyEncoderLayer1(query, memory, memory, memory_mask, padding_mask)
        out = self.MyEncoderLayer2(out, memory, memory, memory_mask, padding_mask)
        out = self.MyEncoderLayer3(out, memory, memory, memory_mask, padding_mask)
        out = self.MyEncoderLayer4(out, memory, memory, memory_mask, padding_mask)

        #cat
        out = torch.cat([out, inputs_add], dim = 2)
        # flatten for the head, then restore a batch-first layout of logits
        out = self.last_fc2(self.dropout(F.relu(self.last_fc1(out.reshape(-1, args.emb_dim + args.num_features)))))\
        .reshape(-1, N).transpose(0, 1)
        score = torch.sigmoid(out)
        err = F.binary_cross_entropy_with_logits(out[loss_mask], target[loss_mask])
        return err, score
    
    def predict_on_batch(self, inputs, args, pred_bs, que, que_proc_int):
        """Score the last position of every row in ``inputs``.

        Args:
            inputs: dict of numpy arrays; reads ``content_id``,
                ``normed_timedelta``, ``normed_log_timestamp``, ``explanation``,
                ``correctness``, ``normed_elapsed``, ``user_answer``,
                ``task_container_id_diff``, ``content_type_id_diff``,
                ``task_container_id`` and ``features``.  The arrays are left
                unmodified.
            args: configuration object; reads ``device``, ``n_length``,
                ``log_indices``, ``stdize_indices`` and ``stdize_mu_std``.
                ``n_length + 1`` must equal the row length of
                ``inputs['content_id']``.
            pred_bs: number of rows processed per mini-batch.
            que: mapping with ``'part'`` and ``'correct_answer'`` entries that
                expose ``.values`` arrays indexable by content id.
            que_proc_int: array indexable by content id giving the tag ids of a
                question; tag id 188 is treated as "no tag".

        Returns:
            1-D numpy array with one sigmoid score per input row.

        Note:
            Autograd is switched off globally and is not re-enabled afterwards.
            The method does not call ``self.eval()``; presumably the caller has
            already put the model in eval mode -- TODO confirm.
        """
        torch.set_grad_enabled(False)
        N = inputs['content_id'].shape[0]
        scores = torch.zeros(N).to(args.device)

        # relative positional encoding: identical for every row, so build it once
        position_row = torch.from_numpy(((np.arange(args.n_length + 1) - args.n_length)/(args.n_length)).astype(np.float32)).to(args.device)

        for i in range(0, N, pred_bs):
            content_id = inputs['content_id'][i: i + pred_bs]
            n_batch = content_id.shape[0]
            n_sample = content_id.shape[1] - 1

            # -1 marks empty slots; the final history slot is tracked separately
            # (token_idx) and is not treated as attention padding
            padding_mask = (content_id == -1)[:, :n_sample]
            token_idx = padding_mask[:, -1].copy()
            padding_mask[:, -1] = False

            part = que['part'].values[content_id] - 1
            tags = que_proc_int[content_id]
            tag_mask = (tags != 188)
            correct_answer = que['correct_answer'].values[content_id]

            # The float arrays are edited in place further down.  torch.from_numpy
            # shares memory with its source and .to() is a no-op on CPU, so copy
            # the slices to keep the caller's arrays intact.
            normed_timedelta = inputs['normed_timedelta'][i: i + pred_bs].copy()
            normed_log_timestamp = inputs['normed_log_timestamp'][i: i + pred_bs].copy()
            explanation = inputs['explanation'][i: i + pred_bs]
            correctness = inputs['correctness'][i: i + pred_bs]
            normed_elapsed = inputs['normed_elapsed'][i: i + pred_bs].copy()
            user_answer = inputs['user_answer'][i: i + pred_bs]

            #diff features
            task_container_id_diff = inputs['task_container_id_diff'][i: i + pred_bs]/10000
            content_type_id_diff = inputs['content_type_id_diff'][i: i + pred_bs]

            #preprop
            content_id = torch.from_numpy(content_id.astype(np.int64)).to(args.device)
            content_id[content_id == -1] = 1
            padding_mask = torch.from_numpy(padding_mask).to(args.device)
            token_idx = torch.from_numpy(token_idx).to(args.device)
            part = torch.from_numpy(part.astype(np.int64)).to(args.device)
            tags = torch.from_numpy(tags.astype(np.int64)).to(args.device)
            tag_mask = torch.from_numpy(tag_mask).to(args.device)
            correct_answer = torch.from_numpy(correct_answer.astype(np.int64)).to(args.device)

            normed_timedelta = torch.from_numpy(normed_timedelta).to(args.device)
            normed_log_timestamp = torch.from_numpy(normed_log_timestamp).to(args.device)
            # map sentinel / out-of-range codes onto a valid class before one-hot encoding
            explanation = torch.from_numpy(explanation.astype(np.int64)).to(args.device)
            explanation[explanation == -2] = 1
            explanation[explanation == 2] = 1
            correctness = torch.from_numpy(correctness.astype(np.int64)).to(args.device)
            correctness[correctness == -1] = 1
            correctness[correctness == 2] = 1
            normed_elapsed = torch.from_numpy(normed_elapsed).to(args.device)
            normed_elapsed[normed_elapsed == -999] = 0
            user_answer = torch.from_numpy(user_answer.astype(np.int64)).to(args.device)
            user_answer[user_answer == -1] = 1
            user_answer[user_answer == 4] = 1

            #diff features
            task_container_id_diff = torch.from_numpy(task_container_id_diff.astype(np.float32)).to(args.device)
            content_type_id_diff = torch.from_numpy(content_type_id_diff.astype(np.float32)).to(args.device)

            #generate token
            # rows whose final history slot is empty get the dedicated extra class there
            explanation[token_idx, n_sample - 1] = 2
            correctness[token_idx, n_sample - 1] = 2
            normed_elapsed[token_idx, n_sample - 1] = 0
            user_answer[token_idx, n_sample - 1] = 4

            #for conv process
            explanation[:, :-1][padding_mask] = 2
            correctness[:, :-1][padding_mask] = 2
            normed_elapsed[:, :-1][padding_mask] = 0
            normed_timedelta[:, :-1][padding_mask] = 0
            normed_log_timestamp[:, :-1][padding_mask] = 0

            #query, key and value
            emb_content_id = self.embedder_content_id(content_id)
            ohe_part = F.one_hot(part, num_classes = 7)
            ohe_tags = F.one_hot(tags, num_classes = 189)
            ohe_tags = ohe_tags * tag_mask.unsqueeze(3)
            ohe_tags_sum = ohe_tags.sum(dim = 2)
            ohe_correct_answer = F.one_hot(correct_answer, num_classes = 4)

            #key and value
            ohe_explanation = F.one_hot(explanation, num_classes = 3)
            ohe_correctness = F.one_hot(correctness, num_classes = 3)
            ohe_user_answer = F.one_hot(user_answer, num_classes = 5)

            # relative positional encoding
            position = position_row.unsqueeze(0).repeat(n_batch, 1)

            # Feature lists shared by both stages (same order as in forward).
            question_feats = [emb_content_id, ohe_part, ohe_tags_sum, 
                              normed_timedelta.unsqueeze(2), normed_log_timestamp.unsqueeze(2), 
                              ohe_correct_answer, task_container_id_diff.unsqueeze(2), 
                              content_type_id_diff.unsqueeze(2), position.unsqueeze(2)]
            response_feats = [ohe_explanation, ohe_correctness, 
                              normed_elapsed.unsqueeze(2), ohe_user_answer]

            #query process
            query_cat = torch.cat(question_feats, dim = 2)
            query = self.norm1(self.linear2(self.dropout(F.relu(self.linear1(query_cat)))))

            #key and value process
            query_clone = query.clone()
            query_clone[token_idx, -2] = 0
            memory_cat = torch.cat([query_clone] + response_feats, dim = 2)
            memory = self.norm2(self.linear4(self.dropout(F.relu(self.linear3(memory_cat)))))
            # padded slots and the position being predicted contribute nothing to the GRU
            memory[:, :-1][padding_mask] = 0
            memory[:, -1] = 0
            out_rnn, _ = self.rnn(memory)
            out_rnn[:, -1] = 0

            # per-row indices from cget_memory_indices are shifted by row * row_length
            # so they address the flattened GRU output
            task_container_id = inputs['task_container_id'][i: i + pred_bs]
            memory_idx = cget_memory_indices(task_container_id)
            memory_idx_ = memory_idx + (np.arange(memory_idx.shape[0]) * memory_idx.shape[1])[:, None]
            out_rnn = out_rnn.contiguous().view(n_batch * (n_sample + 1), -1)
            memory_rnn = out_rnn[torch.from_numpy(memory_idx_).to(args.device)]

            #query process 2 
            query_cat = torch.cat(question_feats + [memory_rnn], dim = 2)
            query = self.norm3(self.linear6(self.dropout(F.relu(self.linear5(query_cat)))))

            #key and value process 2
            query_clone = query.clone()
            query_clone[token_idx, -2] = 0
            memory_cat = torch.cat([query_clone] + response_feats, dim = 2)
            memory = self.norm4(self.linear8(self.dropout(F.relu(self.linear7(memory_cat)))))
            memory = memory[:, :-1]

            #transpose
            query = query.transpose(0, 1)
            memory = memory.transpose(0, 1)
            # only the final position is scored at inference time
            spc_query = query[-1:]

            #MyEncoder
            out = self.MyEncoderLayer1(spc_query, memory, memory, None, padding_mask)
            out = self.MyEncoderLayer2(out, memory, memory, None, padding_mask)
            out = self.MyEncoderLayer3(out, memory, memory, None, padding_mask)
            out = self.MyEncoderLayer4(out, memory, memory, None, padding_mask).squeeze(0)

            #features
            features = inputs['features'][i: i + pred_bs].copy()
            features[:, args.log_indices] = np.log1p(features[:, args.log_indices])
            features[:, args.stdize_indices] = (features[:, args.stdize_indices] - args.stdize_mu_std[0][None, :])/args.stdize_mu_std[1][None, :]
            features[np.isnan(features)] = 0
            features = torch.from_numpy(features).to(args.device)

            #concat and last fc
            out_cat = torch.cat([out, features], dim = 1)
            out_cat = self.last_fc2(self.dropout(F.relu(self.last_fc1(out_cat))))
            score = torch.sigmoid(out_cat[:, 0])
            scores[i: i + pred_bs] = score
        return scores.cpu().numpy()
    
class MyEncoderExp224(nn.Module):
    def __init__(self, args):
        super(MyEncoderExp224, self).__init__()
        #query, key and value
        self.embedder_content_id = Embedder(13523, 512)
        self.linear1 = nn.Linear(714 + 1 + 2, args.emb_dim * 2)
        self.linear2 = nn.Linear(args.emb_dim * 2, args.emb_dim)
        self.dropout = nn.Dropout(args.dropout)
        self.norm1 = nn.LayerNorm(args.emb_dim)

        #key and value
        self.linear3 = nn.Linear(args.emb_dim + 12, args.emb_dim)
        self.linear4 = nn.Linear(args.emb_dim, args.emb_dim)
        self.norm2 = nn.LayerNorm(args.emb_dim)

        #for rnn
        self.linear5 = nn.Linear(714 + 1 + 2 + 256, args.emb_dim * 2)
        self.linear6 = nn.Linear(args.emb_dim * 2, args.emb_dim)
        self.norm3 = nn.LayerNorm(args.emb_dim)

        self.linear7 = nn.Linear(args.emb_dim + 12, args.emb_dim)
        self.linear8 = nn.Linear(args.emb_dim, args.emb_dim)
        self.norm4 = nn.LayerNorm(args.emb_dim)
        self.rnn = nn.LSTM(512, 256, 4, batch_first = True, dropout = args.dropout)

        #MyEncoder
        self.MyEncoderLayer1 = MyEncoderLayer(args.emb_dim, args.nhead, dim_feedforward = args.dim_feedforward, 
                                                dropout = args.dropout)
        self.MyEncoderLayer2 = MyEncoderLayer(args.emb_dim, args.nhead, dim_feedforward = args.dim_feedforward, 
                                                dropout = args.dropout)
        self.MyEncoderLayer3 = MyEncoderLayer(args.emb_dim, args.nhead, dim_feedforward = args.dim_feedforward, 
                                                dropout = args.dropout)
        self.MyEncoderLayer4 = MyEncoderLayer(args.emb_dim, args.nhead, dim_feedforward = args.dim_feedforward, 
                                                dropout = args.dropout)
        self.last_fc1 = nn.Linear(args.emb_dim + args.num_features + 256 + 16 * 8, args.emb_dim * 4)
        self.last_fc2 = nn.Linear(args.emb_dim * 4, 1)
        
        self.fixlen_rnn = nn.GRU(13, 16, 8, batch_first = True)

    def forward(self, batch, args):
        #inputs
        inputs_int = batch['inputs_int']
        inputs_float = batch['inputs_float']
        inputs_add = batch['inputs_add'].transpose(0, 1)
        inputs_add_2 = batch['inputs_add_2']
        N = inputs_int.shape[0]
        n_length = int(inputs_int.shape[1]/2)

        #for query, key and value
        content_id = inputs_int[:, :, 0]
        part = (inputs_int[:, :, 1] - 1)
        tags = inputs_int[:, :, 2:8]
        tag_mask = batch['tag_mask']
        #digit_timedelta = inputs_int[:, :, 8]
        normed_timedelta = inputs_float[:, :, 2]
        normed_log_timestamp = inputs_float[:, :, 1]
        correct_answer = inputs_int[:, :, 13]
        task_container_id_diff = inputs_add_2[:, :, 0]
        content_type_id_diff = inputs_add_2[:, :, 1]

        #for key and value
        explanation = inputs_int[:, :, 9]
        correctness = inputs_int[:, :, 10]
        normed_elapsed = inputs_float[:, :, 0]
        user_answer = inputs_int[:, :, 11]
        end_pos_idx = batch['end_pos_idx']

        #mask
        memory_mask = torch.repeat_interleave(batch['memory_mask'], args.nhead, dim = 0)
        padding_mask = batch['padding_mask']

        #target, loss_mask
        target = batch['target']
        if 'loss_mask' in batch.keys():
            loss_mask = batch['loss_mask']
        else:
            loss_mask = batch['cut_mask']

        # relative positional encoding
        position = batch['position']

        #query, key and value
        emb_content_id = self.embedder_content_id(content_id)
        ohe_part = F.one_hot(part, num_classes = 7)
        ohe_tags = F.one_hot(tags, num_classes = 189)
        ohe_tags = ohe_tags * tag_mask.unsqueeze(3)
        ohe_tags_sum = ohe_tags.sum(dim = 2)
        ohe_correct_answer = F.one_hot(correct_answer, num_classes = 4)

        #initial query and memory 
        ohe_explanation = F.one_hot(explanation, num_classes = 3)
        ohe_correctness = F.one_hot(correctness, num_classes = 3)
        ohe_user_answer = F.one_hot(user_answer, num_classes = 5)

        query_cat = torch.cat([emb_content_id, ohe_part, ohe_tags_sum, 
                               normed_timedelta.unsqueeze(2), normed_log_timestamp.unsqueeze(2), 
                               ohe_correct_answer, task_container_id_diff.unsqueeze(2), 
                               content_type_id_diff.unsqueeze(2), position.unsqueeze(2)], dim = 2)
        query = self.norm1(self.linear2(self.dropout(F.relu(self.linear1(query_cat)))))

        query_clone = query.clone()
        query_clone[torch.arange(N), end_pos_idx - 1] = query_clone[torch.arange(N), end_pos_idx - 1] * 0
        memory_cat = torch.cat([query_clone, ohe_explanation, ohe_correctness, normed_elapsed.unsqueeze(2), ohe_user_answer], dim = 2)
        memory = self.norm2(self.linear4(self.dropout(F.relu(self.linear3(memory_cat)))))

        #for rnn process
        memory[padding_mask] = memory[padding_mask] * 0

        tmp, _ = self.rnn(memory)
        out_rnn = tmp.clone()
        out_rnn[torch.arange(N), end_pos_idx - 1] = out_rnn[torch.arange(N), end_pos_idx - 1] * 0

        memory_idx = batch['memory_idx']
        memory_idx_ = memory_idx + (torch.arange(memory_idx.shape[0]).to(args.device) * memory_idx.shape[1])[:, None]
        out_rnn = out_rnn.contiguous().view(N * args.n_length * 2, -1)
        memory_rnn = out_rnn[memory_idx_]

        #new query and memory
        query_cat = torch.cat([emb_content_id, ohe_part, ohe_tags_sum, 
                               normed_timedelta.unsqueeze(2), normed_log_timestamp.unsqueeze(2), 
                               ohe_correct_answer, task_container_id_diff.unsqueeze(2), 
                               content_type_id_diff.unsqueeze(2), position.unsqueeze(2), memory_rnn], dim = 2)
        query = self.norm3(self.linear6(self.dropout(F.relu(self.linear5(query_cat)))))

        query_clone = query.clone()
        query_clone[torch.arange(N), end_pos_idx - 1] = query_clone[torch.arange(N), end_pos_idx - 1] * 0
        memory_cat = torch.cat([query_clone, ohe_explanation, ohe_correctness, normed_elapsed.unsqueeze(2), ohe_user_answer], dim = 2)
        memory = self.norm4(self.linear8(self.dropout(F.relu(self.linear7(memory_cat)))))

        #transpose
        query = query.transpose(0, 1)
        memory = memory.transpose(0, 1)

        #MyEncoder
        out = self.MyEncoderLayer1(query, memory, memory, memory_mask, padding_mask)
        out = self.MyEncoderLayer2(out, memory, memory, memory_mask, padding_mask)
        out = self.MyEncoderLayer3(out, memory, memory, memory_mask, padding_mask)
        out = self.MyEncoderLayer4(out, memory, memory, memory_mask, padding_mask)
        
        #conv process
        conv_inputs_int = batch['conv_inputs_int']
        conv_inputs_float = batch['conv_inputs_float']
        conv_inputs_const = batch['conv_inputs_const']

        conv_explanation = conv_inputs_int[:, :, :, 0]
        conv_correctness = conv_inputs_int[:, :, :, 1]
        conv_normed_elapsed = conv_inputs_float[:, :, :, 0]
        conv_normed_log_timestamp = conv_inputs_float[:, :, :, 1]
        conv_normed_timedelta = conv_inputs_float[:, :, :, 2]
        conv_task_container_id_diff = conv_inputs_float[:, :, :, 3]
        conv_content_type_id_diff = conv_inputs_float[:, :, :, 4]
    
        const_normed_log_timestamp = conv_inputs_const[:, :, :, 0]
        const_normed_timedelta = conv_inputs_const[:, :, :, 1]
        ohe_conv_explanation = F.one_hot(conv_explanation, num_classes = 3)
        ohe_conv_correctness = F.one_hot(conv_correctness, num_classes = 3)

        for_conv = torch.cat([ohe_conv_explanation, ohe_conv_correctness,
        conv_normed_elapsed.unsqueeze(3), conv_normed_log_timestamp.unsqueeze(3),
        conv_normed_timedelta.unsqueeze(3), conv_task_container_id_diff.unsqueeze(3),
        conv_content_type_id_diff.unsqueeze(3), const_normed_log_timestamp.unsqueeze(3), 
        const_normed_timedelta.unsqueeze(3)], dim = 3).transpose(0, 1)
        
        #rnn
        for_rnn = for_conv.contiguous().view(N * n_length * 2, args.n_conv, -1)
        _, out_fixlen_rnn = self.fixlen_rnn(for_rnn)
        out_fixlen_rnn = out_fixlen_rnn.transpose(0, 1).contiguous().view(n_length *2, N, -1)

        #cat
        out = torch.cat([out, out_fixlen_rnn, inputs_add, memory_rnn.transpose(0, 1)], dim = 2)
        out = self.last_fc2(self.dropout(F.relu(self.last_fc1(out.reshape(-1, args.emb_dim + args.num_features + 256 + 16 * 8)))))\
        .reshape(-1, N).transpose(0, 1)
        score = torch.sigmoid(out)
        err = F.binary_cross_entropy_with_logits(out[loss_mask], target[loss_mask])
        return err, score
    
    def predict_on_batch(self, inputs, args, pred_bs, que, que_proc_int):
        torch.set_grad_enabled(False)
        N = inputs['content_id'].shape[0]
        scores = torch.zeros(N).to(args.device)
        for i in range(0, N, pred_bs):
            content_id = inputs['content_id'][i: i + pred_bs]
            n_batch = content_id.shape[0]
            n_sample = content_id.shape[1] - 1

            padding_mask = (content_id == -1)[:, :n_sample]
            token_idx = padding_mask[:, -1].copy()
            padding_mask[:, -1] = False

            part = que['part'].values[content_id] - 1
            tags = que_proc_int[content_id]
            tag_mask = (tags != 188)
            correct_answer = que['correct_answer'].values[content_id]

            normed_timedelta = inputs['normed_timedelta'][i: i + pred_bs]
            normed_log_timestamp = inputs['normed_log_timestamp'][i: i + pred_bs]
            explanation = inputs['explanation'][i: i + pred_bs]
            correctness = inputs['correctness'][i: i + pred_bs]
            normed_elapsed = inputs['normed_elapsed'][i: i + pred_bs]
            user_answer = inputs['user_answer'][i: i + pred_bs]

            #diff features
            task_container_id_diff = inputs['task_container_id_diff'][i: i + pred_bs]/10000
            content_type_id_diff = inputs['content_type_id_diff'][i: i + pred_bs]

            #preprop
            content_id = torch.from_numpy(content_id.astype(np.int64)).to(args.device)
            content_id[content_id == -1] = 1
            padding_mask = torch.from_numpy(padding_mask).to(args.device)
            token_idx = torch.from_numpy(token_idx).to(args.device)
            part = torch.from_numpy(part.astype(np.int64)).to(args.device)
            tags = torch.from_numpy(tags.astype(np.int64)).to(args.device)
            tag_mask = torch.from_numpy(tag_mask).to(args.device)
            correct_answer = torch.from_numpy(correct_answer.astype(np.int64)).to(args.device)

            normed_timedelta = torch.from_numpy(normed_timedelta).to(args.device)
            normed_log_timestamp = torch.from_numpy(normed_log_timestamp).to(args.device)
            explanation = torch.from_numpy(explanation.astype(np.int64)).to(args.device)
            explanation[explanation == -2] = 1
            explanation[explanation == 2] = 1
            correctness = torch.from_numpy(correctness.astype(np.int64)).to(args.device)
            correctness[correctness == -1] = 1
            correctness[correctness == 2] = 1
            normed_elapsed = torch.from_numpy(normed_elapsed).to(args.device)
            normed_elapsed[normed_elapsed == -999] = 0
            user_answer = torch.from_numpy(user_answer.astype(np.int64)).to(args.device)
            user_answer[user_answer == -1] = 1
            user_answer[user_answer == 4] = 1

            #diff features
            task_container_id_diff = torch.from_numpy(task_container_id_diff.astype(np.float32)).to(args.device)
            content_type_id_diff = torch.from_numpy(content_type_id_diff.astype(np.float32)).to(args.device)

            #generate token
            explanation[token_idx, n_sample - 1] = 2
            correctness[token_idx, n_sample - 1] = 2
            normed_elapsed[token_idx, n_sample - 1] = 0
            user_answer[token_idx, n_sample - 1] = 4

            #for conv process
            explanation[:, :-1][padding_mask] = 2
            correctness[:, :-1][padding_mask] = 2
            normed_elapsed[:, :-1][padding_mask] = 0
            normed_timedelta[:, :-1][padding_mask] = 0
            normed_log_timestamp[:, :-1][padding_mask] = 0

            #query, key and value
            emb_content_id = self.embedder_content_id(content_id)
            ohe_part = F.one_hot(part, num_classes = 7)
            ohe_tags = F.one_hot(tags, num_classes = 189)
            ohe_tags = ohe_tags * tag_mask.unsqueeze(3)
            ohe_tags_sum = ohe_tags.sum(dim = 2)
            ohe_correct_answer = F.one_hot(correct_answer, num_classes = 4)

            #key and value
            ohe_explanation = F.one_hot(explanation, num_classes = 3)
            ohe_correctness = F.one_hot(correctness, num_classes = 3)
            ohe_user_answer = F.one_hot(user_answer, num_classes = 5)

            # relative positional encoding
            position = torch.from_numpy(((np.arange(args.n_length + 1) - args.n_length)/(args.n_length)).astype(np.float32))
            position = position.unsqueeze(0).repeat(n_batch, 1).to(args.device)

            #query process
            query_cat = torch.cat([emb_content_id, ohe_part, ohe_tags_sum, 
                                   normed_timedelta.unsqueeze(2), normed_log_timestamp.unsqueeze(2), 
                                   ohe_correct_answer, task_container_id_diff.unsqueeze(2), 
                                   content_type_id_diff.unsqueeze(2), position.unsqueeze(2)], dim = 2)
            query = self.norm1(self.linear2(self.dropout(F.relu(self.linear1(query_cat)))))

            #key and value process
            query_clone = query.clone()
            query_clone[token_idx, -2] = 0
            memory_cat = torch.cat([query_clone, ohe_explanation, ohe_correctness, 
                                    normed_elapsed.unsqueeze(2), ohe_user_answer], dim = 2)
            memory = self.norm2(self.linear4(self.dropout(F.relu(self.linear3(memory_cat)))))
            memory[:, :-1][padding_mask] = 0
            memory[:, -1] = 0
            out_rnn, _ = self.rnn(memory)
            out_rnn[:, -1] = 0

            task_container_id = inputs['task_container_id'][i: i + pred_bs]
            memory_idx = cget_memory_indices(task_container_id)
            memory_idx_ = memory_idx + (np.arange(memory_idx.shape[0]) * memory_idx.shape[1])[:, None]
            out_rnn = out_rnn.contiguous().view(n_batch * (n_sample + 1), -1)
            memory_rnn = out_rnn[torch.from_numpy(memory_idx_).to(args.device)]

            #query process 2 
            query_cat = torch.cat([emb_content_id, ohe_part, ohe_tags_sum, 
                                   normed_timedelta.unsqueeze(2), normed_log_timestamp.unsqueeze(2), 
                                   ohe_correct_answer, task_container_id_diff.unsqueeze(2), 
                                   content_type_id_diff.unsqueeze(2), position.unsqueeze(2), memory_rnn], dim = 2)
            query = self.norm3(self.linear6(self.dropout(F.relu(self.linear5(query_cat)))))

            #key and value process 2
            query_clone = query.clone()
            query_clone[token_idx, -2] = 0
            memory_cat = torch.cat([query_clone, ohe_explanation, ohe_correctness, 
                                    normed_elapsed.unsqueeze(2), ohe_user_answer], dim = 2)
            memory = self.norm4(self.linear8(self.dropout(F.relu(self.linear7(memory_cat)))))
            memory = memory[:, :-1]

            #transpose
            query = query.transpose(0, 1)
            memory = memory.transpose(0, 1)
            spc_query = query[-1:]

            #MyEncoder
            out = self.MyEncoderLayer1(spc_query, memory, memory, None, padding_mask)
            out = self.MyEncoderLayer2(out, memory, memory, None, padding_mask)
            out = self.MyEncoderLayer3(out, memory, memory, None, padding_mask)
            out = self.MyEncoderLayer4(out, memory, memory, None, padding_mask).squeeze(0)
            
            #conv process
            conv_explanation = explanation[:, -args.n_conv - 1 : -1]
            conv_correctness = correctness[:, -args.n_conv - 1 : -1]
            conv_normed_elapsed = normed_elapsed[:, -args.n_conv - 1 : -1]
            conv_normed_log_timestamp = normed_log_timestamp[:, -args.n_conv - 1 : -1]
            conv_normed_timedelta = normed_timedelta[:, -args.n_conv - 1 : -1]
            conv_task_container_id_diff = task_container_id_diff[:, -args.n_conv - 1 : -1]
            conv_content_type_id_diff = content_type_id_diff[:, -args.n_conv - 1 : -1]
            
            
            const_normed_log_timestamp = normed_log_timestamp[:, -1:].repeat(1, args.n_conv)
            const_normed_timedelta = normed_timedelta[:, -1:].repeat(1, args.n_conv)

            ohe_conv_explanation = F.one_hot(conv_explanation, num_classes = 3)
            ohe_conv_correctness = F.one_hot(conv_correctness, num_classes = 3)

            for_conv = torch.cat([ohe_conv_explanation, ohe_conv_correctness,
            conv_normed_elapsed.unsqueeze(2), conv_normed_log_timestamp.unsqueeze(2),
            conv_normed_timedelta.unsqueeze(2), conv_task_container_id_diff.unsqueeze(2),
            conv_content_type_id_diff.unsqueeze(2), const_normed_log_timestamp.unsqueeze(2), 
            const_normed_timedelta.unsqueeze(2)], dim = 2)
            
            #rnn
            for_rnn = for_conv
            _, out_fixlen_rnn = self.fixlen_rnn(for_rnn)
            out_fixlen_rnn = out_fixlen_rnn.transpose(0, 1).contiguous().view(n_batch, -1)

            #features
            features = inputs['features'][i: i + pred_bs].copy()
            features[:, args.log_indices] = np.log1p(features[:, args.log_indices])
            features[:, args.stdize_indices] = (features[:, args.stdize_indices] - args.stdize_mu_std[0][None, :])/args.stdize_mu_std[1][None, :]
            features[np.isnan(features)] = 0
            features = torch.from_numpy(features).to(args.device)

            #concat and last fc
            out_cat = torch.cat([out, out_fixlen_rnn, features, memory_rnn[:, -1]], dim = 1)
            out_cat = self.last_fc2(self.dropout(F.relu(self.last_fc1(out_cat))))
            score = torch.sigmoid(out_cat[:, 0])
            scores[i: i + pred_bs] = score
        return scores.cpu().numpy()
    
class MyEncoderExp219(nn.Module):
    """Encoder model (experiment 219) with an LSTM side channel.

    The per-position query features are projected, combined with the answer
    features into a "memory" sequence and run through an LSTM.  LSTM outputs
    gathered through ``memory_idx`` are appended to the query features, the
    query/memory projections are recomputed, and four ``MyEncoderLayer``
    blocks followed by a two-layer head produce one logit per position.

    ``forward`` is the full-sequence path that also returns the loss;
    ``predict_on_batch`` scores only the last position of every row, starting
    from numpy inputs.
    """
    def __init__(self, args):
        """Create the layers.

        Reads ``args.emb_dim``, ``args.dropout``, ``args.nhead``,
        ``args.dim_feedforward`` and ``args.num_features``.  The LSTM input
        size is fixed at 512 while it is fed ``args.emb_dim``-wide vectors, so
        ``args.emb_dim`` has to be 512 for this model to run.
        """
        super(MyEncoderExp219, self).__init__()
        #query, key and value
        # 714 + 1 + 2 = 717 input channels: content embedding + part(7) + tags(189)
        # + timedelta(1) + log timestamp(1) + correct answer(4) + two diff features
        # + position, in the order they are concatenated in forward(); the embedder
        # therefore has to emit 512 channels
        self.embedder_content_id = Embedder(13523, 512)
        self.linear1 = nn.Linear(714 + 1 + 2, args.emb_dim * 2)
        self.linear2 = nn.Linear(args.emb_dim * 2, args.emb_dim)
        self.dropout = nn.Dropout(args.dropout)
        self.norm1 = nn.LayerNorm(args.emb_dim)

        #key and value
        # + 12 = explanation(3) + correctness(3) + elapsed(1) + user answer(5)
        self.linear3 = nn.Linear(args.emb_dim + 12, args.emb_dim)
        self.linear4 = nn.Linear(args.emb_dim, args.emb_dim)
        self.norm2 = nn.LayerNorm(args.emb_dim)

        #for rnn
        # + 256 = LSTM hidden size appended to the query features
        self.linear5 = nn.Linear(714 + 1 + 2 + 256, args.emb_dim * 2)
        self.linear6 = nn.Linear(args.emb_dim * 2, args.emb_dim)
        self.norm3 = nn.LayerNorm(args.emb_dim)

        self.linear7 = nn.Linear(args.emb_dim + 12, args.emb_dim)
        self.linear8 = nn.Linear(args.emb_dim, args.emb_dim)
        self.norm4 = nn.LayerNorm(args.emb_dim)
        self.rnn = nn.LSTM(512, 256, 4, batch_first = True)

        #MyEncoder
        self.MyEncoderLayer1 = MyEncoderLayer(args.emb_dim, args.nhead, dim_feedforward = args.dim_feedforward, 
                                                dropout = args.dropout)
        self.MyEncoderLayer2 = MyEncoderLayer(args.emb_dim, args.nhead, dim_feedforward = args.dim_feedforward, 
                                                dropout = args.dropout)
        self.MyEncoderLayer3 = MyEncoderLayer(args.emb_dim, args.nhead, dim_feedforward = args.dim_feedforward, 
                                                dropout = args.dropout)
        self.MyEncoderLayer4 = MyEncoderLayer(args.emb_dim, args.nhead, dim_feedforward = args.dim_feedforward, 
                                                dropout = args.dropout)
        self.last_fc1 = nn.Linear(args.emb_dim + args.num_features, args.emb_dim * 4)
        self.last_fc2 = nn.Linear(args.emb_dim * 4, 1)

    def forward(self, batch, args):
        """Score every position of a batch of sequences and compute the loss.

        Args:
            batch: dict of tensors.  Keys read: 'inputs_int', 'inputs_float',
                'inputs_add', 'inputs_add_2', 'tag_mask', 'end_pos_idx',
                'memory_mask', 'padding_mask', 'target', 'position',
                'memory_idx', and 'loss_mask' ('cut_mask' is used when
                'loss_mask' is absent).
            args: config object; ``nhead``, ``device``, ``emb_dim`` and
                ``num_features`` are read.

        Returns:
            ``(err, score)``: binary cross-entropy (with logits) over the
            positions selected by the loss mask, and the sigmoid scores with
            shape (batch, sequence length).

        The sequence length is taken from ``inputs_int`` itself, so batches
        whose length differs from ``args.n_length * 2`` are handled too.
        """
        #inputs
        inputs_int = batch['inputs_int']
        inputs_float = batch['inputs_float']
        inputs_add = batch['inputs_add'].transpose(0, 1)
        inputs_add_2 = batch['inputs_add_2']
        N = inputs_int.shape[0]
        seq_len = inputs_int.shape[1]

        #for query, key and value
        content_id = inputs_int[:, :, 0]
        part = (inputs_int[:, :, 1] - 1)
        tags = inputs_int[:, :, 2:8]
        tag_mask = batch['tag_mask']
        normed_timedelta = inputs_float[:, :, 2]
        normed_log_timestamp = inputs_float[:, :, 1]
        correct_answer = inputs_int[:, :, 13]
        task_container_id_diff = inputs_add_2[:, :, 0]
        content_type_id_diff = inputs_add_2[:, :, 1]

        #for key and value
        explanation = inputs_int[:, :, 9]
        correctness = inputs_int[:, :, 10]
        normed_elapsed = inputs_float[:, :, 0]
        user_answer = inputs_int[:, :, 11]
        end_pos_idx = batch['end_pos_idx']

        #mask
        memory_mask = torch.repeat_interleave(batch['memory_mask'], args.nhead, dim = 0)
        padding_mask = batch['padding_mask']

        #target, loss_mask
        target = batch['target']
        if 'loss_mask' in batch.keys():
            loss_mask = batch['loss_mask']
        else:
            loss_mask = batch['cut_mask']

        # relative positional encoding
        position = batch['position']

        #query, key and value
        emb_content_id = self.embedder_content_id(content_id)
        ohe_part = F.one_hot(part, num_classes = 7)
        ohe_tags = F.one_hot(tags, num_classes = 189)
        # drop the one-hot rows of masked-out tag slots before summing over slots
        ohe_tags = ohe_tags * tag_mask.unsqueeze(3)
        ohe_tags_sum = ohe_tags.sum(dim = 2)
        ohe_correct_answer = F.one_hot(correct_answer, num_classes = 4)

        #initial query and memory 
        ohe_explanation = F.one_hot(explanation, num_classes = 3)
        ohe_correctness = F.one_hot(correctness, num_classes = 3)
        ohe_user_answer = F.one_hot(user_answer, num_classes = 5)

        # features shared by both query projections
        query_feats = torch.cat([emb_content_id, ohe_part, ohe_tags_sum, 
                                 normed_timedelta.unsqueeze(2), normed_log_timestamp.unsqueeze(2), 
                                 ohe_correct_answer, task_container_id_diff.unsqueeze(2), 
                                 content_type_id_diff.unsqueeze(2), position.unsqueeze(2)], dim = 2)
        query = self.norm1(self.linear2(self.dropout(F.relu(self.linear1(query_feats)))))

        # per-row index of the position just before end_pos_idx; it is blanked
        # (multiplied by zero) wherever the query is reused as memory
        rows = torch.arange(N)
        last_pos = end_pos_idx - 1

        query_clone = query.clone()
        query_clone[rows, last_pos] = query_clone[rows, last_pos] * 0
        memory_cat = torch.cat([query_clone, ohe_explanation, ohe_correctness, normed_elapsed.unsqueeze(2), ohe_user_answer], dim = 2)
        memory = self.norm2(self.linear4(self.dropout(F.relu(self.linear3(memory_cat)))))

        #for rnn process
        memory[padding_mask] = memory[padding_mask] * 0

        tmp, _ = self.rnn(memory)
        out_rnn = tmp.clone()
        out_rnn[rows, last_pos] = out_rnn[rows, last_pos] * 0

        # memory_idx holds per-row positions; adding row * row_length turns
        # them into indices of the flattened (batch * length, hidden) output
        memory_idx = batch['memory_idx']
        memory_idx_ = memory_idx + (torch.arange(memory_idx.shape[0]).to(args.device) * memory_idx.shape[1])[:, None]
        out_rnn = out_rnn.contiguous().view(N * seq_len, -1)
        memory_rnn = out_rnn[memory_idx_]

        #new query and memory
        query_cat = torch.cat([query_feats, memory_rnn], dim = 2)
        query = self.norm3(self.linear6(self.dropout(F.relu(self.linear5(query_cat)))))

        query_clone = query.clone()
        query_clone[rows, last_pos] = query_clone[rows, last_pos] * 0
        memory_cat = torch.cat([query_clone, ohe_explanation, ohe_correctness, normed_elapsed.unsqueeze(2), ohe_user_answer], dim = 2)
        memory = self.norm4(self.linear8(self.dropout(F.relu(self.linear7(memory_cat)))))

        #transpose
        query = query.transpose(0, 1)
        memory = memory.transpose(0, 1)

        #MyEncoder
        out = self.MyEncoderLayer1(query, memory, memory, memory_mask, padding_mask)
        out = self.MyEncoderLayer2(out, memory, memory, memory_mask, padding_mask)
        out = self.MyEncoderLayer3(out, memory, memory, memory_mask, padding_mask)
        out = self.MyEncoderLayer4(out, memory, memory, memory_mask, padding_mask)

        #cat
        out = torch.cat([out, inputs_add], dim = 2)
        hidden = self.last_fc1(out.reshape(-1, args.emb_dim + args.num_features))
        # logits come out as (length * batch, 1); bring them back to (batch, length)
        out = self.last_fc2(self.dropout(F.relu(hidden))).reshape(-1, N).transpose(0, 1)
        score = torch.sigmoid(out)
        err = F.binary_cross_entropy_with_logits(out[loss_mask], target[loss_mask])
        return err, score
    
    def predict_on_batch(self, inputs, args, pred_bs, que, que_proc_int):
        """Score the last position of every row in ``inputs``.

        Args:
            inputs: dict of numpy arrays.  Keys read: 'content_id',
                'normed_timedelta', 'normed_log_timestamp', 'explanation',
                'correctness', 'normed_elapsed', 'user_answer',
                'task_container_id_diff', 'content_type_id_diff',
                'task_container_id' and 'features'.  The arrays are not
                modified, whatever device is used.
            args: config object; ``device``, ``n_length``, ``log_indices``,
                ``stdize_indices`` and ``stdize_mu_std`` are read.
            pred_bs: number of rows processed per chunk.
            que: table whose 'part' and 'correct_answer' columns are looked
                up by content id.
            que_proc_int: array of tag ids looked up by content id; the value
                188 marks an unused tag slot.

        Returns:
            1-D numpy array with one sigmoid score per row of ``inputs``.

        NOTE: autograd is switched off globally and is not switched back on.
        Dropout is only inactive if the caller has put the module in eval
        mode; this method does not do so itself.
        """
        torch.set_grad_enabled(False)
        N = inputs['content_id'].shape[0]
        scores = torch.zeros(N).to(args.device)

        # relative positional encoding: the same row for every chunk, so build it once
        position_row = torch.from_numpy(((np.arange(args.n_length + 1) - args.n_length)/(args.n_length)).astype(np.float32))

        for i in range(0, N, pred_bs):
            content_id = inputs['content_id'][i: i + pred_bs]
            n_batch = content_id.shape[0]
            n_sample = content_id.shape[1] - 1

            # content_id == -1 marks padding; the last history slot is never
            # masked, rows where it was padding are remembered in token_idx
            padding_mask = (content_id == -1)[:, :n_sample]
            token_idx = padding_mask[:, -1].copy()
            padding_mask[:, -1] = False

            part = que['part'].values[content_id] - 1
            tags = que_proc_int[content_id]
            tag_mask = (tags != 188)
            correct_answer = que['correct_answer'].values[content_id]

            normed_timedelta = inputs['normed_timedelta'][i: i + pred_bs]
            normed_log_timestamp = inputs['normed_log_timestamp'][i: i + pred_bs]
            explanation = inputs['explanation'][i: i + pred_bs]
            correctness = inputs['correctness'][i: i + pred_bs]
            normed_elapsed = inputs['normed_elapsed'][i: i + pred_bs]
            user_answer = inputs['user_answer'][i: i + pred_bs]

            #diff features
            task_container_id_diff = inputs['task_container_id_diff'][i: i + pred_bs]/10000
            content_type_id_diff = inputs['content_type_id_diff'][i: i + pred_bs]

            #preprop
            content_id = torch.from_numpy(content_id.astype(np.int64)).to(args.device)
            content_id[content_id == -1] = 1
            padding_mask = torch.from_numpy(padding_mask).to(args.device)
            token_idx = torch.from_numpy(token_idx).to(args.device)
            part = torch.from_numpy(part.astype(np.int64)).to(args.device)
            tags = torch.from_numpy(tags.astype(np.int64)).to(args.device)
            tag_mask = torch.from_numpy(tag_mask).to(args.device)
            correct_answer = torch.from_numpy(correct_answer.astype(np.int64)).to(args.device)

            # copy=True: these tensors are edited in place below; without a
            # copy, from_numpy(...).to('cpu') would share memory with `inputs`
            normed_timedelta = torch.from_numpy(normed_timedelta).to(args.device, copy = True)
            normed_log_timestamp = torch.from_numpy(normed_log_timestamp).to(args.device, copy = True)
            # sentinel / out-of-range codes are folded onto class 1
            explanation = torch.from_numpy(explanation.astype(np.int64)).to(args.device)
            explanation[explanation == -2] = 1
            explanation[explanation == 2] = 1
            correctness = torch.from_numpy(correctness.astype(np.int64)).to(args.device)
            correctness[correctness == -1] = 1
            correctness[correctness == 2] = 1
            normed_elapsed = torch.from_numpy(normed_elapsed).to(args.device, copy = True)
            normed_elapsed[normed_elapsed == -999] = 0
            user_answer = torch.from_numpy(user_answer.astype(np.int64)).to(args.device)
            user_answer[user_answer == -1] = 1
            user_answer[user_answer == 4] = 1

            #diff features
            task_container_id_diff = torch.from_numpy(task_container_id_diff.astype(np.float32)).to(args.device)
            content_type_id_diff = torch.from_numpy(content_type_id_diff.astype(np.float32)).to(args.device)

            #generate token
            # rows whose last history slot was padding get the extra class there
            explanation[token_idx, n_sample - 1] = 2
            correctness[token_idx, n_sample - 1] = 2
            normed_elapsed[token_idx, n_sample - 1] = 0
            user_answer[token_idx, n_sample - 1] = 4

            #neutralise the padded history positions
            explanation[:, :-1][padding_mask] = 2
            correctness[:, :-1][padding_mask] = 2
            normed_elapsed[:, :-1][padding_mask] = 0
            normed_timedelta[:, :-1][padding_mask] = 0
            normed_log_timestamp[:, :-1][padding_mask] = 0

            #query, key and value
            emb_content_id = self.embedder_content_id(content_id)
            ohe_part = F.one_hot(part, num_classes = 7)
            ohe_tags = F.one_hot(tags, num_classes = 189)
            ohe_tags = ohe_tags * tag_mask.unsqueeze(3)
            ohe_tags_sum = ohe_tags.sum(dim = 2)
            ohe_correct_answer = F.one_hot(correct_answer, num_classes = 4)

            #key and value
            ohe_explanation = F.one_hot(explanation, num_classes = 3)
            ohe_correctness = F.one_hot(correctness, num_classes = 3)
            ohe_user_answer = F.one_hot(user_answer, num_classes = 5)

            # relative positional encoding
            position = position_row.unsqueeze(0).repeat(n_batch, 1).to(args.device)

            #query process
            query_feats = torch.cat([emb_content_id, ohe_part, ohe_tags_sum, 
                                     normed_timedelta.unsqueeze(2), normed_log_timestamp.unsqueeze(2), 
                                     ohe_correct_answer, task_container_id_diff.unsqueeze(2), 
                                     content_type_id_diff.unsqueeze(2), position.unsqueeze(2)], dim = 2)
            query = self.norm1(self.linear2(self.dropout(F.relu(self.linear1(query_feats)))))

            #key and value process
            query_clone = query.clone()
            query_clone[token_idx, -2] = 0
            memory_cat = torch.cat([query_clone, ohe_explanation, ohe_correctness, 
                                    normed_elapsed.unsqueeze(2), ohe_user_answer], dim = 2)
            memory = self.norm2(self.linear4(self.dropout(F.relu(self.linear3(memory_cat)))))
            # padded slots and the position being predicted carry no memory
            memory[:, :-1][padding_mask] = 0
            memory[:, -1] = 0
            out_rnn, _ = self.rnn(memory)
            out_rnn[:, -1] = 0

            # per-row indices -> indices into the flattened (rows * length, hidden) output
            task_container_id = inputs['task_container_id'][i: i + pred_bs]
            memory_idx = cget_memory_indices(task_container_id)
            memory_idx_ = memory_idx + (np.arange(memory_idx.shape[0]) * memory_idx.shape[1])[:, None]
            out_rnn = out_rnn.contiguous().view(n_batch * (n_sample + 1), -1)
            memory_rnn = out_rnn[torch.from_numpy(memory_idx_).to(args.device)]

            #query process 2 
            query_cat = torch.cat([query_feats, memory_rnn], dim = 2)
            query = self.norm3(self.linear6(self.dropout(F.relu(self.linear5(query_cat)))))

            #key and value process 2
            query_clone = query.clone()
            query_clone[token_idx, -2] = 0
            memory_cat = torch.cat([query_clone, ohe_explanation, ohe_correctness, 
                                    normed_elapsed.unsqueeze(2), ohe_user_answer], dim = 2)
            memory = self.norm4(self.linear8(self.dropout(F.relu(self.linear7(memory_cat)))))
            memory = memory[:, :-1]

            #transpose
            query = query.transpose(0, 1)
            memory = memory.transpose(0, 1)
            # only the last position is used as the attention query
            spc_query = query[-1:]

            #MyEncoder
            out = self.MyEncoderLayer1(spc_query, memory, memory, None, padding_mask)
            out = self.MyEncoderLayer2(out, memory, memory, None, padding_mask)
            out = self.MyEncoderLayer3(out, memory, memory, None, padding_mask)
            out = self.MyEncoderLayer4(out, memory, memory, None, padding_mask).squeeze(0)

            #features
            features = inputs['features'][i: i + pred_bs].copy()
            features[:, args.log_indices] = np.log1p(features[:, args.log_indices])
            features[:, args.stdize_indices] = (features[:, args.stdize_indices] - args.stdize_mu_std[0][None, :])/args.stdize_mu_std[1][None, :]
            features[np.isnan(features)] = 0
            features = torch.from_numpy(features).to(args.device)

            #concat and last fc
            out_cat = torch.cat([out, features], dim = 1)
            out_cat = self.last_fc2(self.dropout(F.relu(self.last_fc1(out_cat))))
            score = torch.sigmoid(out_cat[:, 0])
            scores[i: i + pred_bs] = score
        return scores.cpu().numpy()
    
class MyEncoderExp221(nn.Module):
    def __init__(self, args):
        super(MyEncoderExp221, self).__init__()
        #query, key and value
        self.embedder_content_id = Embedder(13523, 512)
        self.linear1 = nn.Linear(714 + 1 + 2, args.emb_dim * 2)
        self.linear2 = nn.Linear(args.emb_dim * 2, args.emb_dim)
        self.dropout = nn.Dropout(args.dropout)
        self.norm1 = nn.LayerNorm(args.emb_dim)

        #key and value
        self.linear3 = nn.Linear(args.emb_dim + 12, args.emb_dim)
        self.linear4 = nn.Linear(args.emb_dim, args.emb_dim)
        self.norm2 = nn.LayerNorm(args.emb_dim)

        #for rnn
        self.linear5 = nn.Linear(714 + 1 + 2 + 256, args.emb_dim * 2)
        self.linear6 = nn.Linear(args.emb_dim * 2, args.emb_dim)
        self.norm3 = nn.LayerNorm(args.emb_dim)

        self.linear7 = nn.Linear(args.emb_dim + 12, args.emb_dim)
        self.linear8 = nn.Linear(args.emb_dim, args.emb_dim)
        self.norm4 = nn.LayerNorm(args.emb_dim)
        self.rnn = nn.GRU(512, 256, 4, batch_first = True)

        #MyEncoder
        self.MyEncoderLayer1 = MyEncoderLayer(args.emb_dim, args.nhead, dim_feedforward = args.dim_feedforward, 
                                                dropout = args.dropout)
        self.MyEncoderLayer2 = MyEncoderLayer(args.emb_dim, args.nhead, dim_feedforward = args.dim_feedforward, 
                                                dropout = args.dropout)
        self.MyEncoderLayer3 = MyEncoderLayer(args.emb_dim, args.nhead, dim_feedforward = args.dim_feedforward, 
                                                dropout = args.dropout)
        self.MyEncoderLayer4 = MyEncoderLayer(args.emb_dim, args.nhead, dim_feedforward = args.dim_feedforward, 
                                                dropout = args.dropout)
        self.last_fc1 = nn.Linear(args.emb_dim + args.num_features + 256, args.emb_dim * 4)
        self.last_fc2 = nn.Linear(args.emb_dim * 4, 1)

    def forward(self, batch, args):
        #inputs
        inputs_int = batch['inputs_int']
        inputs_float = batch['inputs_float']
        inputs_add = batch['inputs_add'].transpose(0, 1)
        inputs_add_2 = batch['inputs_add_2']
        N = inputs_int.shape[0]
        n_length = int(inputs_int.shape[1]/2)

        #for query, key and value
        content_id = inputs_int[:, :, 0]
        part = (inputs_int[:, :, 1] - 1)
        tags = inputs_int[:, :, 2:8]
        tag_mask = batch['tag_mask']
        #digit_timedelta = inputs_int[:, :, 8]
        normed_timedelta = inputs_float[:, :, 2]
        normed_log_timestamp = inputs_float[:, :, 1]
        correct_answer = inputs_int[:, :, 13]
        task_container_id_diff = inputs_add_2[:, :, 0]
        content_type_id_diff = inputs_add_2[:, :, 1]

        #for key and value
        explanation = inputs_int[:, :, 9]
        correctness = inputs_int[:, :, 10]
        normed_elapsed = inputs_float[:, :, 0]
        user_answer = inputs_int[:, :, 11]
        end_pos_idx = batch['end_pos_idx']

        #mask
        memory_mask = torch.repeat_interleave(batch['memory_mask'], args.nhead, dim = 0)
        padding_mask = batch['padding_mask']

        #target, loss_mask
        target = batch['target']
        if 'loss_mask' in batch.keys():
            loss_mask = batch['loss_mask']
        else:
            loss_mask = batch['cut_mask']

        # relative positional encoding
        position = batch['position']

        #query, key and value
        emb_content_id = self.embedder_content_id(content_id)
        ohe_part = F.one_hot(part, num_classes = 7)
        ohe_tags = F.one_hot(tags, num_classes = 189)
        ohe_tags = ohe_tags * tag_mask.unsqueeze(3)
        ohe_tags_sum = ohe_tags.sum(dim = 2)
        ohe_correct_answer = F.one_hot(correct_answer, num_classes = 4)

        #initial query and memory 
        ohe_explanation = F.one_hot(explanation, num_classes = 3)
        ohe_correctness = F.one_hot(correctness, num_classes = 3)
        ohe_user_answer = F.one_hot(user_answer, num_classes = 5)

        query_cat = torch.cat([emb_content_id, ohe_part, ohe_tags_sum, 
                               normed_timedelta.unsqueeze(2), normed_log_timestamp.unsqueeze(2), 
                               ohe_correct_answer, task_container_id_diff.unsqueeze(2), 
                               content_type_id_diff.unsqueeze(2), position.unsqueeze(2)], dim = 2)
        query = self.norm1(self.linear2(self.dropout(F.relu(self.linear1(query_cat)))))

        query_clone = query.clone()
        query_clone[torch.arange(N), end_pos_idx - 1] = query_clone[torch.arange(N), end_pos_idx - 1] * 0
        memory_cat = torch.cat([query_clone, ohe_explanation, ohe_correctness, normed_elapsed.unsqueeze(2), ohe_user_answer], dim = 2)
        memory = self.norm2(self.linear4(self.dropout(F.relu(self.linear3(memory_cat)))))

        #for rnn process
        memory[padding_mask] = memory[padding_mask] * 0

        tmp, _ = self.rnn(memory)
        out_rnn = tmp.clone()
        out_rnn[torch.arange(N), end_pos_idx - 1] = out_rnn[torch.arange(N), end_pos_idx - 1] * 0

        memory_idx = batch['memory_idx']
        memory_idx_ = memory_idx + (torch.arange(memory_idx.shape[0]).to(args.device) * memory_idx.shape[1])[:, None]
        out_rnn = out_rnn.contiguous().view(N * args.n_length * 2, -1)
        memory_rnn = out_rnn[memory_idx_]

        #new query and memory
        query_cat = torch.cat([emb_content_id, ohe_part, ohe_tags_sum, 
                               normed_timedelta.unsqueeze(2), normed_log_timestamp.unsqueeze(2), 
                               ohe_correct_answer, task_container_id_diff.unsqueeze(2), 
                               content_type_id_diff.unsqueeze(2), position.unsqueeze(2), memory_rnn], dim = 2)
        query = self.norm3(self.linear6(self.dropout(F.relu(self.linear5(query_cat)))))

        query_clone = query.clone()
        query_clone[torch.arange(N), end_pos_idx - 1] = query_clone[torch.arange(N), end_pos_idx - 1] * 0
        memory_cat = torch.cat([query_clone, ohe_explanation, ohe_correctness, normed_elapsed.unsqueeze(2), ohe_user_answer], dim = 2)
        memory = self.norm4(self.linear8(self.dropout(F.relu(self.linear7(memory_cat)))))

        #transpose
        query = query.transpose(0, 1)
        memory = memory.transpose(0, 1)

        #MyEncoder
        out = self.MyEncoderLayer1(query, memory, memory, memory_mask, padding_mask)
        out = self.MyEncoderLayer2(out, memory, memory, memory_mask, padding_mask)
        out = self.MyEncoderLayer3(out, memory, memory, memory_mask, padding_mask)
        out = self.MyEncoderLayer4(out, memory, memory, memory_mask, padding_mask)

        #cat
        out = torch.cat([out, inputs_add, memory_rnn.transpose(0, 1)], dim = 2)
        out = self.last_fc2(self.dropout(F.relu(self.last_fc1(out.reshape(-1, args.emb_dim + args.num_features + 256)))))\
        .reshape(-1, N).transpose(0, 1)
        score = torch.sigmoid(out)
        err = F.binary_cross_entropy_with_logits(out[loss_mask], target[loss_mask])
        return err, score
    
    def predict_on_batch(self, inputs, args, pred_bs, que, que_proc_int):
        torch.set_grad_enabled(False)
        N = inputs['content_id'].shape[0]
        scores = torch.zeros(N).to(args.device)
        for i in range(0, N, pred_bs):
            content_id = inputs['content_id'][i: i + pred_bs]
            n_batch = content_id.shape[0]
            n_sample = content_id.shape[1] - 1

            padding_mask = (content_id == -1)[:, :n_sample]
            token_idx = padding_mask[:, -1].copy()
            padding_mask[:, -1] = False

            part = que['part'].values[content_id] - 1
            tags = que_proc_int[content_id]
            tag_mask = (tags != 188)
            correct_answer = que['correct_answer'].values[content_id]

            normed_timedelta = inputs['normed_timedelta'][i: i + pred_bs]
            normed_log_timestamp = inputs['normed_log_timestamp'][i: i + pred_bs]
            explanation = inputs['explanation'][i: i + pred_bs]
            correctness = inputs['correctness'][i: i + pred_bs]
            normed_elapsed = inputs['normed_elapsed'][i: i + pred_bs]
            user_answer = inputs['user_answer'][i: i + pred_bs]

            #diff features
            task_container_id_diff = inputs['task_container_id_diff'][i: i + pred_bs]/10000
            content_type_id_diff = inputs['content_type_id_diff'][i: i + pred_bs]

            #preprop
            content_id = torch.from_numpy(content_id.astype(np.int64)).to(args.device)
            content_id[content_id == -1] = 1
            padding_mask = torch.from_numpy(padding_mask).to(args.device)
            token_idx = torch.from_numpy(token_idx).to(args.device)
            part = torch.from_numpy(part.astype(np.int64)).to(args.device)
            tags = torch.from_numpy(tags.astype(np.int64)).to(args.device)
            tag_mask = torch.from_numpy(tag_mask).to(args.device)
            correct_answer = torch.from_numpy(correct_answer.astype(np.int64)).to(args.device)

            normed_timedelta = torch.from_numpy(normed_timedelta).to(args.device)
            normed_log_timestamp = torch.from_numpy(normed_log_timestamp).to(args.device)
            explanation = torch.from_numpy(explanation.astype(np.int64)).to(args.device)
            explanation[explanation == -2] = 1
            explanation[explanation == 2] = 1
            correctness = torch.from_numpy(correctness.astype(np.int64)).to(args.device)
            correctness[correctness == -1] = 1
            correctness[correctness == 2] = 1
            normed_elapsed = torch.from_numpy(normed_elapsed).to(args.device)
            normed_elapsed[normed_elapsed == -999] = 0
            user_answer = torch.from_numpy(user_answer.astype(np.int64)).to(args.device)
            user_answer[user_answer == -1] = 1
            user_answer[user_answer == 4] = 1

            #diff features
            task_container_id_diff = torch.from_numpy(task_container_id_diff.astype(np.float32)).to(args.device)
            content_type_id_diff = torch.from_numpy(content_type_id_diff.astype(np.float32)).to(args.device)

            #generate token
            explanation[token_idx, n_sample - 1] = 2
            correctness[token_idx, n_sample - 1] = 2
            normed_elapsed[token_idx, n_sample - 1] = 0
            user_answer[token_idx, n_sample - 1] = 4

            #for conv process
            explanation[:, :-1][padding_mask] = 2
            correctness[:, :-1][padding_mask] = 2
            normed_elapsed[:, :-1][padding_mask] = 0
            normed_timedelta[:, :-1][padding_mask] = 0
            normed_log_timestamp[:, :-1][padding_mask] = 0

            #query, key and value
            emb_content_id = self.embedder_content_id(content_id)
            ohe_part = F.one_hot(part, num_classes = 7)
            ohe_tags = F.one_hot(tags, num_classes = 189)
            ohe_tags = ohe_tags * tag_mask.unsqueeze(3)
            ohe_tags_sum = ohe_tags.sum(dim = 2)
            ohe_correct_answer = F.one_hot(correct_answer, num_classes = 4)

            #key and value
            ohe_explanation = F.one_hot(explanation, num_classes = 3)
            ohe_correctness = F.one_hot(correctness, num_classes = 3)
            ohe_user_answer = F.one_hot(user_answer, num_classes = 5)

            # relative positional encoding
            position = torch.from_numpy(((np.arange(args.n_length + 1) - args.n_length)/(args.n_length)).astype(np.float32))
            position = position.unsqueeze(0).repeat(n_batch, 1).to(args.device)

            #query process
            query_cat = torch.cat([emb_content_id, ohe_part, ohe_tags_sum, 
                                   normed_timedelta.unsqueeze(2), normed_log_timestamp.unsqueeze(2), 
                                   ohe_correct_answer, task_container_id_diff.unsqueeze(2), 
                                   content_type_id_diff.unsqueeze(2), position.unsqueeze(2)], dim = 2)
            query = self.norm1(self.linear2(self.dropout(F.relu(self.linear1(query_cat)))))

            #key and value process
            query_clone = query.clone()
            query_clone[token_idx, -2] = 0
            memory_cat = torch.cat([query_clone, ohe_explanation, ohe_correctness, 
                                    normed_elapsed.unsqueeze(2), ohe_user_answer], dim = 2)
            memory = self.norm2(self.linear4(self.dropout(F.relu(self.linear3(memory_cat)))))
            memory[:, :-1][padding_mask] = 0
            memory[:, -1] = 0
            out_rnn, _ = self.rnn(memory)
            out_rnn[:, -1] = 0

            task_container_id = inputs['task_container_id'][i: i + pred_bs]
            memory_idx = cget_memory_indices(task_container_id)
            memory_idx_ = memory_idx + (np.arange(memory_idx.shape[0]) * memory_idx.shape[1])[:, None]
            out_rnn = out_rnn.contiguous().view(n_batch * (n_sample + 1), -1)
            memory_rnn = out_rnn[torch.from_numpy(memory_idx_).to(args.device)]

            #query process 2 
            query_cat = torch.cat([emb_content_id, ohe_part, ohe_tags_sum, 
                                   normed_timedelta.unsqueeze(2), normed_log_timestamp.unsqueeze(2), 
                                   ohe_correct_answer, task_container_id_diff.unsqueeze(2), 
                                   content_type_id_diff.unsqueeze(2), position.unsqueeze(2), memory_rnn], dim = 2)
            query = self.norm3(self.linear6(self.dropout(F.relu(self.linear5(query_cat)))))

            #key and value process 2
            query_clone = query.clone()
            query_clone[token_idx, -2] = 0
            memory_cat = torch.cat([query_clone, ohe_explanation, ohe_correctness, 
                                    normed_elapsed.unsqueeze(2), ohe_user_answer], dim = 2)
            memory = self.norm4(self.linear8(self.dropout(F.relu(self.linear7(memory_cat)))))
            memory = memory[:, :-1]

            #transpose
            query = query.transpose(0, 1)
            memory = memory.transpose(0, 1)
            spc_query = query[-1:]

            #MyEncoder
            out = self.MyEncoderLayer1(spc_query, memory, memory, None, padding_mask)
            out = self.MyEncoderLayer2(out, memory, memory, None, padding_mask)
            out = self.MyEncoderLayer3(out, memory, memory, None, padding_mask)
            out = self.MyEncoderLayer4(out, memory, memory, None, padding_mask).squeeze(0)

            #features
            features = inputs['features'][i: i + pred_bs].copy()
            features[:, args.log_indices] = np.log1p(features[:, args.log_indices])
            features[:, args.stdize_indices] = (features[:, args.stdize_indices] - args.stdize_mu_std[0][None, :])/args.stdize_mu_std[1][None, :]
            features[np.isnan(features)] = 0
            features = torch.from_numpy(features).to(args.device)

            #concat and last fc
            out_cat = torch.cat([out, features, memory_rnn[:, -1]], dim = 1)
            out_cat = self.last_fc2(self.dropout(F.relu(self.last_fc1(out_cat))))
            score = torch.sigmoid(out_cat[:, 0])
            scores[i: i + pred_bs] = score
        return scores.cpu().numpy()
    
class MyEncoderExp222(nn.Module):
    """Sequence model (experiment 222) producing one correctness score per position.

    A content-id embedding is concatenated with one-hot and scalar side
    features and projected to a query; the query plus response features
    (explanation, correctness, elapsed time, user answer) is projected to a
    memory sequence. An LSTM runs over that memory, its outputs are gathered
    through ``memory_idx`` and appended to the query features for a second
    query/memory projection. Four stacked ``MyEncoderLayer`` blocks (defined
    elsewhere in this file) then combine query and memory, and a two-layer
    head on [encoder output, extra features, gathered LSTM output] yields a
    single logit.
    """
    def __init__(self, args):
        """Build all layers; reads emb_dim, dropout, nhead, dim_feedforward and num_features from ``args``."""
        super(MyEncoderExp222, self).__init__()
        #query, key and value
        self.embedder_content_id = Embedder(13523, 512)
        # Input width 714 + 1 + 2 matches the query concat in forward():
        # 512 (content embedding, presumably the Embedder output size) + 7 (part)
        # + 189 (tags) + 1 + 1 (timedelta, log timestamp) + 4 (correct answer) = 714,
        # plus 1 (position) and 2 (task_container_id_diff, content_type_id_diff).
        self.linear1 = nn.Linear(714 + 1 + 2, args.emb_dim * 2)
        self.linear2 = nn.Linear(args.emb_dim * 2, args.emb_dim)
        self.dropout = nn.Dropout(args.dropout)
        self.norm1 = nn.LayerNorm(args.emb_dim)

        #key and value
        # "+ 12" = 3 (explanation one-hot) + 3 (correctness one-hot) + 1 (elapsed) + 5 (user answer one-hot).
        self.linear3 = nn.Linear(args.emb_dim + 12, args.emb_dim)
        self.linear4 = nn.Linear(args.emb_dim, args.emb_dim)
        self.norm2 = nn.LayerNorm(args.emb_dim)

        #for rnn
        # Second query projection: same features as linear1 plus the 256-wide LSTM output.
        self.linear5 = nn.Linear(714 + 1 + 2 + 256, args.emb_dim * 2)
        self.linear6 = nn.Linear(args.emb_dim * 2, args.emb_dim)
        self.norm3 = nn.LayerNorm(args.emb_dim)

        self.linear7 = nn.Linear(args.emb_dim + 12, args.emb_dim)
        self.linear8 = nn.Linear(args.emb_dim, args.emb_dim)
        self.norm4 = nn.LayerNorm(args.emb_dim)
        # NOTE(review): the LSTM input size is hard-coded to 512 while it is fed
        # the emb_dim-wide memory, so args.emb_dim is presumably 512 -- confirm.
        self.rnn = nn.LSTM(512, 256, 4, batch_first = True, dropout = args.dropout)

        #MyEncoder
        self.MyEncoderLayer1 = MyEncoderLayer(args.emb_dim, args.nhead, dim_feedforward = args.dim_feedforward, 
                                                dropout = args.dropout)
        self.MyEncoderLayer2 = MyEncoderLayer(args.emb_dim, args.nhead, dim_feedforward = args.dim_feedforward, 
                                                dropout = args.dropout)
        self.MyEncoderLayer3 = MyEncoderLayer(args.emb_dim, args.nhead, dim_feedforward = args.dim_feedforward, 
                                                dropout = args.dropout)
        self.MyEncoderLayer4 = MyEncoderLayer(args.emb_dim, args.nhead, dim_feedforward = args.dim_feedforward, 
                                                dropout = args.dropout)
        # Head input: encoder output (emb_dim) + extra features (num_features) + LSTM output (256).
        self.last_fc1 = nn.Linear(args.emb_dim + args.num_features + 256, args.emb_dim * 4)
        self.last_fc2 = nn.Linear(args.emb_dim * 4, 1)

    def forward(self, batch, args):
        """Run the model on a prepared batch and compute the loss.

        ``batch`` is read with the keys 'inputs_int', 'inputs_float',
        'inputs_add', 'inputs_add_2', 'tag_mask', 'end_pos_idx', 'memory_mask',
        'padding_mask', 'target', 'position', 'memory_idx' and either
        'loss_mask' or (as a fallback) 'cut_mask'. ``args`` supplies nhead,
        device, n_length, emb_dim and num_features.

        Returns ``(err, score)``: ``err`` is the binary cross-entropy (with
        logits) over the entries selected by the loss mask, and ``score`` is
        the sigmoid of the logits.
        """
        #inputs
        inputs_int = batch['inputs_int']
        inputs_float = batch['inputs_float']
        # Transposed so it can be concatenated with the encoder output further down.
        inputs_add = batch['inputs_add'].transpose(0, 1)
        inputs_add_2 = batch['inputs_add_2']
        N = inputs_int.shape[0]
        # NOTE(review): this local is not used below; args.n_length is used instead.
        n_length = int(inputs_int.shape[1]/2)

        #for query, key and value
        content_id = inputs_int[:, :, 0]
        # Shifted by one so part can be one-hot encoded with 7 classes.
        part = (inputs_int[:, :, 1] - 1)
        tags = inputs_int[:, :, 2:8]
        tag_mask = batch['tag_mask']
        #digit_timedelta = inputs_int[:, :, 8]
        normed_timedelta = inputs_float[:, :, 2]
        normed_log_timestamp = inputs_float[:, :, 1]
        correct_answer = inputs_int[:, :, 13]
        task_container_id_diff = inputs_add_2[:, :, 0]
        content_type_id_diff = inputs_add_2[:, :, 1]

        #for key and value
        explanation = inputs_int[:, :, 9]
        correctness = inputs_int[:, :, 10]
        normed_elapsed = inputs_float[:, :, 0]
        user_answer = inputs_int[:, :, 11]
        end_pos_idx = batch['end_pos_idx']

        #mask
        # The memory mask is repeated once per attention head along dim 0.
        memory_mask = torch.repeat_interleave(batch['memory_mask'], args.nhead, dim = 0)
        padding_mask = batch['padding_mask']

        #target, loss_mask
        target = batch['target']
        if 'loss_mask' in batch.keys():
            loss_mask = batch['loss_mask']
        else:
            loss_mask = batch['cut_mask']

        # relative positional encoding
        position = batch['position']

        #query, key and value
        emb_content_id = self.embedder_content_id(content_id)
        ohe_part = F.one_hot(part, num_classes = 7)
        ohe_tags = F.one_hot(tags, num_classes = 189)
        # Masked tag slots are zeroed before the per-position sum (multi-hot tag vector).
        ohe_tags = ohe_tags * tag_mask.unsqueeze(3)
        ohe_tags_sum = ohe_tags.sum(dim = 2)
        ohe_correct_answer = F.one_hot(correct_answer, num_classes = 4)

        #initial query and memory 
        ohe_explanation = F.one_hot(explanation, num_classes = 3)
        ohe_correctness = F.one_hot(correctness, num_classes = 3)
        ohe_user_answer = F.one_hot(user_answer, num_classes = 5)

        query_cat = torch.cat([emb_content_id, ohe_part, ohe_tags_sum, 
                               normed_timedelta.unsqueeze(2), normed_log_timestamp.unsqueeze(2), 
                               ohe_correct_answer, task_container_id_diff.unsqueeze(2), 
                               content_type_id_diff.unsqueeze(2), position.unsqueeze(2)], dim = 2)
        query = self.norm1(self.linear2(self.dropout(F.relu(self.linear1(query_cat)))))

        # The query vector at position end_pos_idx - 1 of every row is zeroed
        # (out of place, on a clone) before it enters the memory.
        query_clone = query.clone()
        query_clone[torch.arange(N), end_pos_idx - 1] = query_clone[torch.arange(N), end_pos_idx - 1] * 0
        memory_cat = torch.cat([query_clone, ohe_explanation, ohe_correctness, normed_elapsed.unsqueeze(2), ohe_user_answer], dim = 2)
        memory = self.norm2(self.linear4(self.dropout(F.relu(self.linear3(memory_cat)))))

        #for rnn process
        # Padded positions contribute zeros to the LSTM input.
        memory[padding_mask] = memory[padding_mask] * 0

        tmp, _ = self.rnn(memory)
        out_rnn = tmp.clone()
        out_rnn[torch.arange(N), end_pos_idx - 1] = out_rnn[torch.arange(N), end_pos_idx - 1] * 0

        memory_idx = batch['memory_idx']
        # Add a per-row offset so the indices address the flattened
        # (N * sequence length) view of out_rnn created on the next line.
        memory_idx_ = memory_idx + (torch.arange(memory_idx.shape[0]).to(args.device) * memory_idx.shape[1])[:, None]
        out_rnn = out_rnn.contiguous().view(N * args.n_length * 2, -1)
        memory_rnn = out_rnn[memory_idx_]

        #new query and memory
        query_cat = torch.cat([emb_content_id, ohe_part, ohe_tags_sum, 
                               normed_timedelta.unsqueeze(2), normed_log_timestamp.unsqueeze(2), 
                               ohe_correct_answer, task_container_id_diff.unsqueeze(2), 
                               content_type_id_diff.unsqueeze(2), position.unsqueeze(2), memory_rnn], dim = 2)
        query = self.norm3(self.linear6(self.dropout(F.relu(self.linear5(query_cat)))))

        query_clone = query.clone()
        query_clone[torch.arange(N), end_pos_idx - 1] = query_clone[torch.arange(N), end_pos_idx - 1] * 0
        memory_cat = torch.cat([query_clone, ohe_explanation, ohe_correctness, normed_elapsed.unsqueeze(2), ohe_user_answer], dim = 2)
        memory = self.norm4(self.linear8(self.dropout(F.relu(self.linear7(memory_cat)))))

        #transpose
        # Swap the first two dims (batch-first -> sequence-first) for the encoder layers.
        query = query.transpose(0, 1)
        memory = memory.transpose(0, 1)

        #MyEncoder
        out = self.MyEncoderLayer1(query, memory, memory, memory_mask, padding_mask)
        out = self.MyEncoderLayer2(out, memory, memory, memory_mask, padding_mask)
        out = self.MyEncoderLayer3(out, memory, memory, memory_mask, padding_mask)
        out = self.MyEncoderLayer4(out, memory, memory, memory_mask, padding_mask)

        #cat
        out = torch.cat([out, inputs_add, memory_rnn.transpose(0, 1)], dim = 2)
        # Flatten to rows for the head, then reshape back to one logit per (row, position).
        out = self.last_fc2(self.dropout(F.relu(self.last_fc1(out.reshape(-1, args.emb_dim + args.num_features + 256)))))\
        .reshape(-1, N).transpose(0, 1)
        score = torch.sigmoid(out)
        err = F.binary_cross_entropy_with_logits(out[loss_mask], target[loss_mask])
        return err, score
    
    def predict_on_batch(self, inputs, args, pred_bs, que, que_proc_int):
        """Score the last position of every sequence in ``inputs``, ``pred_bs`` rows at a time.

        ``inputs`` is read with the keys 'content_id', 'normed_timedelta',
        'normed_log_timestamp', 'explanation', 'correctness', 'normed_elapsed',
        'user_answer', 'task_container_id_diff', 'content_type_id_diff',
        'task_container_id' and 'features' (numpy arrays, since they are
        converted with torch.from_numpy). ``que`` provides the 'part' and
        'correct_answer' columns and ``que_proc_int`` the tag ids, both looked
        up by content_id.

        Returns a numpy array with one sigmoid score per row of ``inputs``.

        NOTE(review): gradient tracking is switched off globally here and is
        not switched back on inside this method.
        """
        torch.set_grad_enabled(False)
        N = inputs['content_id'].shape[0]
        scores = torch.zeros(N).to(args.device)
        for i in range(0, N, pred_bs):
            content_id = inputs['content_id'][i: i + pred_bs]
            n_batch = content_id.shape[0]
            # The final column is the position being predicted; the n_sample columns before it are history.
            n_sample = content_id.shape[1] - 1

            # content_id == -1 marks padding in the history.
            padding_mask = (content_id == -1)[:, :n_sample]
            # token_idx flags rows whose last history slot is padding; that slot is
            # un-masked here and filled with dedicated "token" codes further down.
            token_idx = padding_mask[:, -1].copy()
            padding_mask[:, -1] = False

            # NOTE(review): a content_id of -1 selects the last row of these lookup
            # tables (numpy negative indexing); presumably harmless because those
            # positions are padding -- confirm.
            part = que['part'].values[content_id] - 1
            tags = que_proc_int[content_id]
            # Tag id 188 is treated as "no tag" and masked out of the multi-hot sum.
            tag_mask = (tags != 188)
            correct_answer = que['correct_answer'].values[content_id]

            normed_timedelta = inputs['normed_timedelta'][i: i + pred_bs]
            normed_log_timestamp = inputs['normed_log_timestamp'][i: i + pred_bs]
            explanation = inputs['explanation'][i: i + pred_bs]
            correctness = inputs['correctness'][i: i + pred_bs]
            normed_elapsed = inputs['normed_elapsed'][i: i + pred_bs]
            user_answer = inputs['user_answer'][i: i + pred_bs]

            #diff features
            task_container_id_diff = inputs['task_container_id_diff'][i: i + pred_bs]/10000
            content_type_id_diff = inputs['content_type_id_diff'][i: i + pred_bs]

            #preprocess: move everything to tensors on args.device and remap
            #sentinel codes into the ranges expected by the one-hot encodings below
            content_id = torch.from_numpy(content_id.astype(np.int64)).to(args.device)
            content_id[content_id == -1] = 1
            padding_mask = torch.from_numpy(padding_mask).to(args.device)
            token_idx = torch.from_numpy(token_idx).to(args.device)
            part = torch.from_numpy(part.astype(np.int64)).to(args.device)
            tags = torch.from_numpy(tags.astype(np.int64)).to(args.device)
            tag_mask = torch.from_numpy(tag_mask).to(args.device)
            correct_answer = torch.from_numpy(correct_answer.astype(np.int64)).to(args.device)

            normed_timedelta = torch.from_numpy(normed_timedelta).to(args.device)
            normed_log_timestamp = torch.from_numpy(normed_log_timestamp).to(args.device)
            explanation = torch.from_numpy(explanation.astype(np.int64)).to(args.device)
            explanation[explanation == -2] = 1
            explanation[explanation == 2] = 1
            correctness = torch.from_numpy(correctness.astype(np.int64)).to(args.device)
            correctness[correctness == -1] = 1
            correctness[correctness == 2] = 1
            normed_elapsed = torch.from_numpy(normed_elapsed).to(args.device)
            normed_elapsed[normed_elapsed == -999] = 0
            user_answer = torch.from_numpy(user_answer.astype(np.int64)).to(args.device)
            user_answer[user_answer == -1] = 1
            user_answer[user_answer == 4] = 1

            #diff features
            task_container_id_diff = torch.from_numpy(task_container_id_diff.astype(np.float32)).to(args.device)
            content_type_id_diff = torch.from_numpy(content_type_id_diff.astype(np.float32)).to(args.device)

            #generate token
            # Rows flagged by token_idx get the reserved codes (2 / 2 / 0 / 4) at the last history slot.
            explanation[token_idx, n_sample - 1] = 2
            correctness[token_idx, n_sample - 1] = 2
            normed_elapsed[token_idx, n_sample - 1] = 0
            user_answer[token_idx, n_sample - 1] = 4

            #padded history positions get fixed filler values
            explanation[:, :-1][padding_mask] = 2
            correctness[:, :-1][padding_mask] = 2
            normed_elapsed[:, :-1][padding_mask] = 0
            normed_timedelta[:, :-1][padding_mask] = 0
            normed_log_timestamp[:, :-1][padding_mask] = 0

            #query, key and value
            emb_content_id = self.embedder_content_id(content_id)
            ohe_part = F.one_hot(part, num_classes = 7)
            ohe_tags = F.one_hot(tags, num_classes = 189)
            ohe_tags = ohe_tags * tag_mask.unsqueeze(3)
            ohe_tags_sum = ohe_tags.sum(dim = 2)
            ohe_correct_answer = F.one_hot(correct_answer, num_classes = 4)

            #key and value
            ohe_explanation = F.one_hot(explanation, num_classes = 3)
            ohe_correctness = F.one_hot(correctness, num_classes = 3)
            ohe_user_answer = F.one_hot(user_answer, num_classes = 5)

            # relative positional encoding
            # Values run from -1 (oldest slot) to 0 (the position being predicted).
            position = torch.from_numpy(((np.arange(args.n_length + 1) - args.n_length)/(args.n_length)).astype(np.float32))
            position = position.unsqueeze(0).repeat(n_batch, 1).to(args.device)

            #query process
            query_cat = torch.cat([emb_content_id, ohe_part, ohe_tags_sum, 
                                   normed_timedelta.unsqueeze(2), normed_log_timestamp.unsqueeze(2), 
                                   ohe_correct_answer, task_container_id_diff.unsqueeze(2), 
                                   content_type_id_diff.unsqueeze(2), position.unsqueeze(2)], dim = 2)
            query = self.norm1(self.linear2(self.dropout(F.relu(self.linear1(query_cat)))))

            #key and value process
            query_clone = query.clone()
            # Zero the query part of the token slot (second-to-last position) for flagged rows.
            query_clone[token_idx, -2] = 0
            memory_cat = torch.cat([query_clone, ohe_explanation, ohe_correctness, 
                                    normed_elapsed.unsqueeze(2), ohe_user_answer], dim = 2)
            memory = self.norm2(self.linear4(self.dropout(F.relu(self.linear3(memory_cat)))))
            # Padding and the position being predicted are zeroed before and after the LSTM.
            memory[:, :-1][padding_mask] = 0
            memory[:, -1] = 0
            out_rnn, _ = self.rnn(memory)
            out_rnn[:, -1] = 0

            task_container_id = inputs['task_container_id'][i: i + pred_bs]
            # cget_memory_indices is defined elsewhere in this file; its result is
            # used as per-row indices into the LSTM outputs.
            memory_idx = cget_memory_indices(task_container_id)
            # Per-row offset so the indices address the flattened out_rnn below.
            memory_idx_ = memory_idx + (np.arange(memory_idx.shape[0]) * memory_idx.shape[1])[:, None]
            out_rnn = out_rnn.contiguous().view(n_batch * (n_sample + 1), -1)
            memory_rnn = out_rnn[torch.from_numpy(memory_idx_).to(args.device)]

            #query process 2 
            query_cat = torch.cat([emb_content_id, ohe_part, ohe_tags_sum, 
                                   normed_timedelta.unsqueeze(2), normed_log_timestamp.unsqueeze(2), 
                                   ohe_correct_answer, task_container_id_diff.unsqueeze(2), 
                                   content_type_id_diff.unsqueeze(2), position.unsqueeze(2), memory_rnn], dim = 2)
            query = self.norm3(self.linear6(self.dropout(F.relu(self.linear5(query_cat)))))

            #key and value process 2
            query_clone = query.clone()
            query_clone[token_idx, -2] = 0
            memory_cat = torch.cat([query_clone, ohe_explanation, ohe_correctness, 
                                    normed_elapsed.unsqueeze(2), ohe_user_answer], dim = 2)
            memory = self.norm4(self.linear8(self.dropout(F.relu(self.linear7(memory_cat)))))
            # Only the history (everything but the last position) serves as memory.
            memory = memory[:, :-1]

            #transpose
            query = query.transpose(0, 1)
            memory = memory.transpose(0, 1)
            # Only the final position is used as the query at inference time.
            spc_query = query[-1:]

            #MyEncoder
            out = self.MyEncoderLayer1(spc_query, memory, memory, None, padding_mask)
            out = self.MyEncoderLayer2(out, memory, memory, None, padding_mask)
            out = self.MyEncoderLayer3(out, memory, memory, None, padding_mask)
            out = self.MyEncoderLayer4(out, memory, memory, None, padding_mask).squeeze(0)

            #features
            # Copy first so the caller's feature array is not modified by the
            # log1p / standardisation / NaN-filling below.
            features = inputs['features'][i: i + pred_bs].copy()
            features[:, args.log_indices] = np.log1p(features[:, args.log_indices])
            features[:, args.stdize_indices] = (features[:, args.stdize_indices] - args.stdize_mu_std[0][None, :])/args.stdize_mu_std[1][None, :]
            features[np.isnan(features)] = 0
            features = torch.from_numpy(features).to(args.device)

            #concat and last fc
            out_cat = torch.cat([out, features, memory_rnn[:, -1]], dim = 1)
            out_cat = self.last_fc2(self.dropout(F.relu(self.last_fc1(out_cat))))
            score = torch.sigmoid(out_cat[:, 0])
            scores[i: i + pred_bs] = score
        return scores.cpu().numpy()

#func_cpu
@noglobal
def online_get_user_id_content_id_task_container_id(tes):
    """Return user_id, content_id and task_container_id of the rows with content_type_id == 0.

    The three columns are cast to float32 and the index is renumbered from 0.
    """
    id_columns = ['user_id', 'content_id', 'task_container_id']
    keep_rows = tes['content_type_id'] == 0
    selected = tes.loc[keep_rows, id_columns]
    return selected.reset_index(drop = True).astype('float32')

@noglobal
def online_get_timestamp_and_prior_question_elapsed_time_and_prior_question_had_explanation(tes):
    """Return timestamp and the two prior_question_* columns for rows with content_type_id == 0.

    All three columns are cast to float32 and the index is renumbered from 0.
    """
    wanted = ['timestamp', 'prior_question_elapsed_time', 'prior_question_had_explanation']
    keep_rows = tes['content_type_id'] == 0
    selected = tes.loc[keep_rows, wanted]
    return selected.reset_index(drop = True).astype('float32')

@noglobal
def online_get_part(tes, que):
    """Look up que['part'] by content_id for every row with content_type_id == 0.

    Returns a single-column float32 frame named 'part', indexed from 0.
    """
    content_ids = tes.loc[tes['content_type_id'] == 0, 'content_id'].values
    part_lookup = que['part'].values
    f_tes = pd.DataFrame(part_lookup[content_ids], columns = ['part'])
    return f_tes.astype(np.float32)

@noglobal
def online_get_correct_answer(tes, que):
    """Look up que['correct_answer'] by content_id for every row with content_type_id == 0.

    Returns a single-column float32 frame named 'correct_answer', indexed from 0.
    """
    content_ids = tes.loc[tes['content_type_id'] == 0, 'content_id'].values
    answer_lookup = que['correct_answer'].values
    f_tes = pd.DataFrame(answer_lookup[content_ids], columns = ['correct_answer'])
    return f_tes.astype(np.float32)

@noglobal
def online_get_modified_timedelta(r_ref, tes, user_map_ref):
    """Time since each user's stored reference timestamp, divided by that user's row count.

    For rows with content_type_id == 0: (timestamp - r_ref[user_map_ref[user_id]])
    divided by the number of such rows the same user has in ``tes``.
    Returns a float32 frame with the column 'modified_timedelta'.
    """
    kept = tes.loc[tes['content_type_id'] == 0]
    user_ids = kept['user_id'].values
    # Inverse codes let bincount broadcast each user's row count back to its rows.
    _, user_codes = np.unique(user_ids, return_inverse = True)
    rows_per_user = np.bincount(user_codes)[user_codes]
    elapsed = kept['timestamp'].values - r_ref[user_map_ref[user_ids]]
    return pd.DataFrame(elapsed/rows_per_user, columns = ['modified_timedelta']).astype(np.float32)

@noglobal
def online_get_tags(tes, que_proc_2):
    """Return the six tag columns of ``que_proc_2`` for every row with content_type_id == 0.

    Rows of ``que_proc_2`` are selected by content_id; the result columns are
    named tags_0 .. tags_5.
    """
    content_ids = tes.loc[tes['content_type_id'] == 0, 'content_id'].values
    tag_names = ['tags_' + str(i) for i in range(6)]
    tag_rows = que_proc_2.values[content_ids]
    return pd.DataFrame(tag_rows, columns = tag_names)

@noglobal
def online_get_hell_rolling_mean_for_tags(r_ref, tes, user_map_ref):
    """Per-user ratio r_ref[..., 0] / r_ref[..., 1] for each of 188 tags.

    The last slot along axis 1 of ``r_ref`` is dropped, then the rows of the
    users appearing in ``tes`` (content_type_id == 0) are selected. Columns are
    named hell_rolling_mean_for_tag_0 .. hell_rolling_mean_for_tag_187.
    """
    user_ids = tes.loc[tes['content_type_id'] == 0, 'user_id'].values
    per_tag = r_ref[:, :-1][user_map_ref[user_ids]]
    numer = per_tag[:, :, 0]
    denom = per_tag[:, :, 1]
    names = ['hell_rolling_mean_for_tag_' + str(i) for i in range(188)]
    return pd.DataFrame(numer/denom, columns = names)

@noglobal
def online_get_rolling_mean_sum_count(r_ref, tes, user_map_ref):
    """Per-user running mean, sum and count for rows with content_type_id == 0.

    Parameters
    ----------
    r_ref : array indexed by user row; column 0 holds the running sum and
        column 1 the running count (as read below).
    tes : DataFrame with 'content_type_id' and 'user_id' columns.
    user_map_ref : array-like mapping a raw user_id to its row in ``r_ref``.

    Returns
    -------
    DataFrame (float32) with the columns 'target_full_mean',
    'target_full_sum' and 'target_count'. The mean is sum / count, so a user
    with a zero count yields NaN.
    """
    # The boolean filter already yields a new frame and nothing is written to
    # it, so no extra copy is needed.
    spc_tes = tes[tes['content_type_id'] == 0]
    sum_count = r_ref[user_map_ref[spc_tes['user_id'].values]]
    # The result buffer is built once (the original allocated and filled it twice).
    res = np.zeros((spc_tes.shape[0], 3), dtype = np.float32)
    res[:, 0] = sum_count[:, 0]/sum_count[:, 1]
    res[:, 1:] = sum_count
    return pd.DataFrame(res, columns = ['target_full_mean', 'target_full_sum', 'target_count'])

@noglobal
def update_reference_get_rolling_mean_sum_count(r_ref, prev_tes, tes, user_map_ref):
    """Fold the previous group's graded answers into the per-user sum/count table.

    The answers are parsed from the first 'prior_group_answers_correct' entry
    of ``tes`` and aligned row-by-row with ``prev_tes``. For rows with
    content_type_id == 0, column 0 of ``r_ref`` gains the number of answers
    equal to 1 per user and column 1 the number of rows per user. ``r_ref`` is
    modified in place and returned; ``prev_tes`` is left untouched.
    """
    answers = ast.literal_eval(tes['prior_group_answers_correct'].iloc[0])
    graded = prev_tes.assign(answered_correctly = answers)
    graded = graded.loc[graded['content_type_id'] == 0].reset_index(drop = True)
    users = graded['user_id']
    attempts = users.value_counts()
    correct = users[graded['answered_correctly'] == 1].value_counts()
    # value_counts indices are unique, so the fancy-indexed += is safe here.
    r_ref[user_map_ref[correct.index], 0] += correct.values
    r_ref[user_map_ref[attempts.index], 1] += attempts.values
    return r_ref

@noglobal
def online_get_rolling_mean_sum_count_for_6_tags_and_whole_tag(r_ref, tes, que_proc, user_map_ref):
    """Mean/sum/count per tag slot (6 slots) plus a NaN-ignoring total over the slots.

    For every row with content_type_id == 0 the six tag ids of its content_id
    are read from ``que_proc`` and used, together with the user's row, to pick
    (sum, count) pairs out of ``r_ref``. The seventh group is the nansum of the
    six pairs. Each mean is sum / count. Returns a float32 frame with 21 columns.
    """
    kept = tes.loc[tes['content_type_id'] == 0]
    n_rows = kept.shape[0]
    tag_cols = que_proc.values[kept['content_id'].values]
    user_rows = user_map_ref[kept['user_id'].values]
    # Repeat each user row six times so it pairs up with the six tag columns.
    user_grid = np.repeat(user_rows[:, None], 6, axis = 1)
    per_tag = r_ref[user_grid, tag_cols]

    stats = np.zeros((n_rows, 7, 3), dtype = np.float32)
    stats[:, :6, 1:] = per_tag
    stats[:, 6, 1:] = np.nansum(per_tag, axis = 1)
    stats[:, :, 0] = stats[:, :, 1]/stats[:, :, 2]

    f_names = []
    for i in range(6):
        f_names += ['tags_order_mean_' + str(i), 'tags_order_sum_' + str(i), 'tags_order_count_' + str(i)]
    f_names += ['whole_tags_order_mean', 'whole_tags_order_sum', 'whole_tags_order_count']
    return pd.DataFrame(stats.reshape(-1, 3 * 7), columns = f_names)

@noglobal
def update_reference_get_rolling_mean_sum_count_for_6_tags_and_whole_tag(r_ref, prev_tes, tes, que_onehot, user_map_ref):
    """Add the previous group's per-tag correct counts and attempt counts to ``r_ref``.

    Answers come from the first 'prior_group_answers_correct' entry of ``tes``
    and are aligned with ``prev_tes``. Only rows with content_type_id == 0 are
    used. For each user the one-hot tag rows of ``que_onehot`` (first 188
    columns) are summed: rows answered with 1 go to r_ref[user, :-1, 0], all
    rows go to r_ref[user, :-1, 1]. ``r_ref`` is modified in place and returned.
    """
    def tag_totals_per_user(user_ids, tag_rows):
        # One bincount per tag column gives, per distinct user, the summed one-hot weights.
        distinct_users, codes = np.unique(user_ids, return_inverse = True)
        totals = np.array([np.bincount(codes, weights = tag_rows[:, t]) for t in range(188)]).T
        return distinct_users, totals

    answers = ast.literal_eval(tes['prior_group_answers_correct'].iloc[0])
    #prev_tes_cp['user_answer'] = ast.literal_eval(tes['prior_group_responses'].iloc[0])
    graded = prev_tes.assign(answered_correctly = answers)
    graded = graded.loc[graded['content_type_id'] == 0].reset_index(drop = True)

    onehot = que_onehot.values
    user_ids = graded['user_id'].values
    content_ids = graded['content_id'].values
    is_correct = graded['answered_correctly'].values == 1

    correct_users, correct_totals = tag_totals_per_user(user_ids[is_correct], onehot[content_ids[is_correct]])
    r_ref[user_map_ref[correct_users], :-1, 0] += correct_totals

    seen_users, attempt_totals = tag_totals_per_user(user_ids, onehot[content_ids])
    r_ref[user_map_ref[seen_users], :-1, 1] += attempt_totals
    return r_ref

@noglobal
def online_get_rolling_mean_sum_count_for_part(r_ref, tes, que, user_map_ref):
    """Mean/sum/count per part (7 parts) plus the stats of the row's own part.

    For rows with content_type_id == 0, the user's seven (sum, count) pairs are
    copied from ``r_ref`` and an eighth group ('part_cut_*') repeats the pair
    of the part given by que['part'] for the row's content_id. Each mean is
    sum / count. Returns a float32 frame with 24 columns.
    """
    kept = tes.loc[tes['content_type_id'] == 0]
    user_rows = user_map_ref[kept['user_id'].values]
    # Parts are 1-based in ``que``; subtract one to index the second axis of r_ref.
    own_part = que['part'].values[kept['content_id'].values] - 1

    stats = np.zeros((kept.shape[0], 8, 3), dtype = np.float32)
    stats[:, :7, 1:] = r_ref[user_rows]
    stats[:, 7, 1:] = r_ref[user_rows, own_part]
    stats[:, :, 0] = stats[:, :, 1]/stats[:, :, 2]

    f_names = []
    for i in range(1, 8):
        f_names += ['part_' + str(i) + '_mean', 'part_' + str(i) + '_sum', 'part_' + str(i) + '_count']
    f_names += ['part_cut_mean', 'part_cut_sum', 'part_cut_count']
    return pd.DataFrame(stats.reshape(-1, 3 * 8), columns = f_names)

@noglobal
def update_reference_get_rolling_mean_sum_count_for_part(r_ref, prev_tes, tes, que_part_onehot, user_map_ref):
    """Add the previous group's per-part correct counts and attempt counts to ``r_ref``.

    Answers come from the first 'prior_group_answers_correct' entry of ``tes``
    and are aligned with ``prev_tes``. Only rows with content_type_id == 0 are
    used. Per user, the one-hot part rows of ``que_part_onehot`` (7 columns)
    are summed: rows answered with 1 go to r_ref[user, :, 0], all rows go to
    r_ref[user, :, 1]. ``r_ref`` is modified in place and returned.
    """
    def part_totals_per_user(user_ids, part_rows):
        # One bincount per part column gives, per distinct user, the summed one-hot weights.
        distinct_users, codes = np.unique(user_ids, return_inverse = True)
        totals = np.array([np.bincount(codes, weights = part_rows[:, p]) for p in range(7)]).T
        return distinct_users, totals

    answers = ast.literal_eval(tes['prior_group_answers_correct'].iloc[0])
    graded = prev_tes.assign(answered_correctly = answers)
    graded = graded.loc[graded['content_type_id'] == 0].reset_index(drop = True)

    onehot = que_part_onehot.values
    user_ids = graded['user_id'].values
    content_ids = graded['content_id'].values
    is_correct = graded['answered_correctly'].values == 1

    correct_users, correct_totals = part_totals_per_user(user_ids[is_correct], onehot[content_ids[is_correct]])
    r_ref[user_map_ref[correct_users], :, 0] += correct_totals

    seen_users, attempt_totals = part_totals_per_user(user_ids, onehot[content_ids])
    r_ref[user_map_ref[seen_users], :, 1] += attempt_totals
    return r_ref

@noglobal
def online_get_lec_rolling_count(r_ref, tes, user_map_ref):
    """Return r_ref[user] for every row with content_type_id == 0 as column 'lec_rolling_count'."""
    user_ids = tes.loc[tes['content_type_id'] == 0, 'user_id'].values
    counts = r_ref[user_map_ref[user_ids]]
    return pd.DataFrame(counts, columns = ['lec_rolling_count'])

@noglobal
def update_reference_get_lec_rolling_count(r_ref, prev_tes, tes, user_map_ref):
    """Increment each user's count once per row of ``prev_tes`` with content_type_id == 1.

    Parameters
    ----------
    r_ref : numpy array of per-user counts, indexed by user row; updated in place.
    prev_tes : DataFrame of the previous group ('content_type_id', 'user_id').
    tes : unused; kept so all update functions share the same call shape.
    user_map_ref : array-like mapping a raw user_id to its row in ``r_ref``.

    Returns
    -------
    The same ``r_ref`` object.
    """
    lecture_users = prev_tes.loc[prev_tes['content_type_id'] == 1, 'user_id'].values
    row_idx = np.asarray(user_map_ref[lecture_users])
    # ``r_ref[row_idx] += 1`` is buffered: a user appearing several times in
    # row_idx would only be incremented once. np.add.at accumulates every
    # occurrence. The increment is cast to r_ref's dtype to avoid type promotion.
    np.add.at(r_ref, row_idx, r_ref.dtype.type(1))
    return r_ref

@noglobal
def online_get_lec_part_rolling_count(r_ref, tes, user_map_ref, que):
    """Per-user counts for the seven parts plus the count of the row's own part.

    For rows with content_type_id == 0 the user's seven values are read from
    ``r_ref``; 'lec_part_cut' repeats the value of the part given by
    que['part'] for the row's content_id. Returns a float32 frame.
    """
    kept = tes.loc[tes['content_type_id'] == 0]
    n_rows = kept.shape[0]
    per_part = r_ref[user_map_ref[kept['user_id'].values]]
    # Parts are 1-based in ``que``.
    own_part = que['part'].values[kept['content_id'].values] - 1

    out = np.zeros((n_rows, 8), dtype = np.float32)
    out[:, :7] = per_part
    out[:, 7] = per_part[np.arange(n_rows), own_part]
    names = ['lec_part_' + str(p) for p in range(1, 8)]
    names.append('lec_part_cut')
    return pd.DataFrame(out, columns = names)

@noglobal
def update_reference_get_lec_part_rolling_count(r_ref, prev_tes, tes, user_map_ref, lec, lec_map):
    """Increment r_ref[user, part - 1] once per row of ``prev_tes`` with content_type_id == 1.

    Parameters
    ----------
    r_ref : numpy array indexed as [user row, part index]; updated in place.
    prev_tes : DataFrame of the previous group ('content_type_id', 'user_id', 'content_id').
    tes : unused; kept so all update functions share the same call shape.
    user_map_ref : array-like mapping a raw user_id to its row in ``r_ref``.
    lec : DataFrame with a 1-based 'part' column.
    lec_map : array-like mapping a content_id to its row in ``lec``.

    Returns
    -------
    The same ``r_ref`` object.
    """
    lectures = prev_tes[prev_tes['content_type_id'] == 1]
    row_idx = np.asarray(user_map_ref[lectures['user_id'].values])
    col_idx = lec['part'].values[lec_map[lectures['content_id'].values]] - 1
    # A plain ``r_ref[row_idx, col_idx] += 1`` counts a repeated (user, part)
    # pair only once; np.add.at accumulates every occurrence.
    np.add.at(r_ref, (row_idx, col_idx), r_ref.dtype.type(1))
    return r_ref

@noglobal
def online_get_lec_type_of_rolling_count(r_ref, tes, user_map_ref):
    """Return the user's four r_ref values for every row with content_type_id == 0.

    Columns are named lec_type_of_0 .. lec_type_of_3 and cast to float32.
    """
    user_ids = tes.loc[tes['content_type_id'] == 0, 'user_id'].values
    names = ['lec_type_of_' + str(i) for i in range(4)]
    frame = pd.DataFrame(r_ref[user_map_ref[user_ids]], columns = names)
    return frame.astype(np.float32)

@noglobal
def update_reference_get_lec_type_of_rolling_count(r_ref, prev_tes, tes, user_map_ref, lec_proc, lec_map):
    """Increment r_ref[user, type_of] once per row of ``prev_tes`` with content_type_id == 1.

    Parameters
    ----------
    r_ref : numpy array indexed as [user row, type_of code]; updated in place.
    prev_tes : DataFrame of the previous group ('content_type_id', 'user_id', 'content_id').
    tes : unused; kept so all update functions share the same call shape.
    user_map_ref : array-like mapping a raw user_id to its row in ``r_ref``.
    lec_proc : DataFrame whose 'type_of' column is used directly as a column index.
    lec_map : array-like mapping a content_id to its row in ``lec_proc``.

    Returns
    -------
    The same ``r_ref`` object.
    """
    lectures = prev_tes[prev_tes['content_type_id'] == 1]
    row_idx = np.asarray(user_map_ref[lectures['user_id'].values])
    col_idx = lec_proc['type_of'].values[lec_map[lectures['content_id'].values]]
    # A plain ``r_ref[row_idx, col_idx] += 1`` counts a repeated (user, type)
    # pair only once; np.add.at accumulates every occurrence.
    np.add.at(r_ref, (row_idx, col_idx), r_ref.dtype.type(1))
    return r_ref

@noglobal
def online_get_lec_tags_rolling_count(r_ref, tes, user_map_ref, que_proc):
    """Per-user r_ref values for the six tag slots of each row, plus their NaN-ignoring sum.

    For rows with content_type_id == 0 the six tag ids of the row's content_id
    are read from ``que_proc`` and used as column indices into the user's row
    of ``r_ref``. Returns a float32 frame with seven columns.
    """
    kept = tes.loc[tes['content_type_id'] == 0]
    tag_cols = que_proc.values[kept['content_id'].values]
    user_rows = user_map_ref[kept['user_id'].values]
    # Repeat each user row six times so it pairs up with the six tag columns.
    user_grid = np.repeat(user_rows[:, None], 6, axis = 1)
    per_tag = r_ref[user_grid, tag_cols]

    out = np.zeros((kept.shape[0], 7), dtype = np.float32)
    out[:, :6] = per_tag
    out[:, 6] = np.nansum(per_tag, axis = 1)
    names = ['lec_tags_order_count_' + str(i) for i in range(6)]
    names.append('lec_whole_tags_order_count')
    return pd.DataFrame(out, columns = names)

@noglobal
def update_reference_get_lec_tags_rolling_count(r_ref, prev_tes, tes, user_map_ref, lec, lec_map):
    """Increment r_ref[user, tag] once per row of ``prev_tes`` with content_type_id == 1.

    Parameters
    ----------
    r_ref : numpy array indexed as [user row, tag id]; updated in place.
    prev_tes : DataFrame of the previous group ('content_type_id', 'user_id',
        'content_id'); it is only read, never modified.
    tes : unused; kept so all update functions share the same call shape.
    user_map_ref : array-like mapping a raw user_id to its row in ``r_ref``.
    lec : DataFrame whose 'tag' column is used directly as a column index.
    lec_map : array-like mapping a content_id to its row in ``lec``.

    Returns
    -------
    The same ``r_ref`` object.
    """
    # Filtering already yields a new frame and nothing is written to it, so
    # the full-frame copies of the original are unnecessary.
    lectures = prev_tes[prev_tes['content_type_id'] == 1]
    col_idx = lec['tag'].values[lec_map[lectures['content_id'].values]]
    row_idx = np.asarray(user_map_ref[lectures['user_id'].values])
    # A plain ``r_ref[row_idx, col_idx] += 1`` counts a repeated (user, tag)
    # pair only once; np.add.at accumulates every occurrence.
    np.add.at(r_ref, (row_idx, col_idx), r_ref.dtype.type(1))
    return r_ref

@noglobal 
def online_get_rolling_mean_sum_count_for_content_id(r_ref_sum, r_ref_count, tes, user_map_ref):
    """Mean/sum/count stored for each (user, content_id) pair of the rows with content_type_id == 0.

    ``r_ref_sum`` and ``r_ref_count`` are indexed by (user row, content_id) and
    must support ``.toarray()`` on the fancy-indexed result (e.g. scipy sparse
    matrices). The mean is sum / count. Returns a float32 frame with the
    columns content_id_cut_mean, content_id_cut_sum and content_id_cut_count.
    """
    kept = tes.loc[tes['content_type_id'] == 0]
    user_rows = user_map_ref[kept['user_id'].values]
    content_cols = kept['content_id'].values
    # The indexed result has a single row; [0] flattens it to 1-D.
    total_correct = r_ref_sum[user_rows, content_cols].toarray()[0]
    total_seen = r_ref_count[user_rows, content_cols].toarray()[0]

    f_tr = pd.DataFrame()
    columns = (('content_id_cut_mean', total_correct/total_seen),
               ('content_id_cut_sum', total_correct),
               ('content_id_cut_count', total_seen))
    for name, values in columns:
        f_tr[name] = values.astype(np.float32)
    return f_tr

@noglobal
def update_reference_get_rolling_mean_sum_count_for_content_id(r_ref_sum, r_ref_count, prev_tes, tes, user_map_ref):
    """Fold the previous batch's outcomes into the per-(user, content) references.

    The outcomes of `prev_tes` are parsed from the first
    'prior_group_answers_correct' entry of `tes` (a string literal of a list
    with one value per row of `prev_tes`). Only question rows
    (content_type_id == 0) are applied, one at a time, so repeated
    (user, content) pairs accumulate correctly.

    Both references are updated in place and returned as a tuple.
    """
    labelled = prev_tes.copy()
    labelled['answered_correctly'] = ast.literal_eval(tes['prior_group_answers_correct'].iloc[0])
    answered = labelled[labelled['content_type_id'] == 0].reset_index(drop=True)

    updates = zip(
        user_map_ref[answered['user_id'].values],
        answered['content_id'].values,
        answered['answered_correctly'].values,
    )
    for row, col, outcome in updates:
        r_ref_sum[row, col] = r_ref_sum[row, col] + outcome
        r_ref_count[row, col] = r_ref_count[row, col] + 1
    return r_ref_sum, r_ref_count

@noglobal
def online_get_timestamp_diff(r_ref, tes, user_map_ref):
    """Return, per question row, `timestamp` minus the user's stored reference value.

    `r_ref` is indexed by encoded user (via `user_map_ref`). The result is a
    single-column float32 DataFrame named 'timestamp_diff'.
    """
    questions = tes[tes['content_type_id'] == 0]
    stored = r_ref[user_map_ref[questions['user_id'].values]]
    elapsed = questions['timestamp'].values - stored
    frame = pd.DataFrame(elapsed, columns=['timestamp_diff'])
    return frame.astype(np.float32)

@noglobal
def update_reference_get_timestamp_diff(r_ref, prev_tes, tes, user_map_ref):
    """Store each user's timestamp from `prev_tes` into `r_ref` (in place).

    All rows of `prev_tes` are used, whatever their content type. `tes` is
    not read. When a user appears several times, one of its rows wins
    (NumPy assigns repeated indices sequentially).
    """
    user_rows = user_map_ref[prev_tes['user_id'].values]
    r_ref[user_rows] = prev_tes['timestamp'].values
    return r_ref

@noglobal
def online_get_whole_oof_target_encoding_content_id(r_ref, tes):
    """Per question row, return column 0 / column 1 of `r_ref[content_id]`.

    `r_ref` is a 2-D array indexed by content_id; the update helper for this
    feature adds correct-answer counts to column 0 and attempt counts to
    column 1. Output is a one-column float32 DataFrame.
    """
    questions = tes[tes['content_type_id'] == 0]
    stats = r_ref[questions['content_id'].values]
    numerator, denominator = stats[:, 0], stats[:, 1]
    rate = (numerator / denominator).astype(np.float32)
    return pd.DataFrame(rate, columns=['whole_oof_target_encoding_content_id'])

@noglobal
def update_reference_get_whole_oof_target_encoding_content_id(r_ref, prev_tes, tes):
    """Add the previous batch's per-content correct/attempt counts to `r_ref`.

    Outcomes for `prev_tes` are parsed from the first
    'prior_group_answers_correct' entry of `tes`. For question rows only,
    column 0 of `r_ref[content_id]` grows by the number of rows answered
    correctly (== 1) and column 1 by the number of rows. `r_ref` is updated
    in place and returned.
    """
    labelled = prev_tes.copy()
    labelled['answered_correctly'] = ast.literal_eval(tes['prior_group_answers_correct'].iloc[0])
    answered = labelled[labelled['content_type_id'] == 0]

    content_ids = answered['content_id'].values
    was_correct = answered['answered_correctly'].values == 1

    # Aggregate to unique ids first so the fancy-indexed += sees no repeats.
    hit_ids, hit_counts = np.unique(content_ids[was_correct], return_counts=True)
    seen_ids, seen_counts = np.unique(content_ids, return_counts=True)
    r_ref[hit_ids, 0] += hit_counts
    r_ref[seen_ids, 1] += seen_counts
    return r_ref

@noglobal
def online_get_whole_oof_target_encoding_tags_order_and_whole_tags(r_ref, tes, que_proc):
    """Per question row, compute tag-level ratios from `r_ref`.

    `que_proc.values[content_id]` supplies the tag rows to read (assumed six
    per question; confirm with caller) and `r_ref[tag]` holds a
    (numerator, denominator) pair. Columns 0-5 are the per-slot ratios;
    column 6 is the ratio of the NaN-ignoring sums over the six slots.

    Returns a float32 DataFrame with seven columns.
    """
    questions = tes[tes['content_type_id'] == 0]
    tag_slots = que_proc.values[questions['content_id'].values]
    stats = r_ref[tag_slots]                 # (rows, slots, 2)
    pooled = np.nansum(stats, axis=1)        # (rows, 2)

    features = np.zeros((questions.shape[0], 7), dtype=np.float32)
    features[:, :6] = stats[:, :, 0] / stats[:, :, 1]
    features[:, 6] = pooled[:, 0] / pooled[:, 1]

    columns = ['whole_oof_target_encoding_tags_order_' + str(slot) for slot in range(6)]
    columns.append('whole_oof_target_encoding_whole_tags')
    return pd.DataFrame(features, columns=columns)

@noglobal
def update_reference_get_whole_oof_target_encoding_tags_order_and_whole_tags(r_ref, prev_tes, tes, que_onehot):
    """Add the previous batch's per-tag correct/attempt counts to `r_ref`.

    `que_onehot.values[content_id]` gives one indicator row per question;
    summing those rows yields per-tag counts. Correctly answered questions
    feed column 0, all questions feed column 1. The last row of `r_ref` is
    never touched (presumably a padding slot; confirm with caller).
    `r_ref` is updated in place and returned.
    """
    labelled = prev_tes.copy()
    labelled['answered_correctly'] = ast.literal_eval(tes['prior_group_answers_correct'].iloc[0])
    answered = labelled[labelled['content_type_id'] == 0].reset_index(drop=True)

    tag_rows = que_onehot.values[answered['content_id'].values]
    was_correct = answered['answered_correctly'].values == 1

    r_ref[:-1, 0] += tag_rows[was_correct].sum(axis=0)
    r_ref[:-1, 1] += tag_rows.sum(axis=0)
    return r_ref

@noglobal
def online_get_correct_answer(tes, que):
    """Return `que['correct_answer']` for each question row of `tes`.

    `content_id` is used as a positional index into `que`. The result is a
    one-column float32 DataFrame named 'correct_answer'.
    """
    questions = tes[tes['content_type_id'] == 0]
    answer_key = que['correct_answer'].values
    picked = answer_key[questions['content_id'].values]
    return pd.DataFrame(picked, columns=['correct_answer']).astype(np.float32)

@noglobal
def online_get_modified_timedelta(r_ref, tes, user_map_ref):
    """Timestamp gap to the user's stored reference, split across the user's rows.

    For each question row: (timestamp - r_ref[user]) divided by the number of
    question rows that user has in this batch. Returns a one-column float32
    DataFrame named 'modified_timedelta'.
    """
    questions = tes[tes['content_type_id'] == 0]
    user_ids = questions['user_id'].values

    # Rows per user in this batch, broadcast back to row order.
    _, inverse = np.unique(user_ids, return_inverse=True)
    rows_per_user = np.bincount(inverse)[inverse]

    gap = questions['timestamp'].values - r_ref[user_map_ref[user_ids]]
    return pd.DataFrame(gap / rows_per_user, columns=['modified_timedelta']).astype(np.float32)

@noglobal
def update_reference_get_norm_rolling_count_and_cut_for_user_answer(r_ref, prev_tes, tes, user_map_ref):
    """Add the previous batch's per-user answer-choice counts to `r_ref`.

    The responses of `prev_tes` are parsed from the first
    'prior_group_responses' entry of `tes`. For question rows, each response
    (used as a column index into four slots) is counted per user, and the
    per-user totals are added to `r_ref[encoded user]` in place.
    """
    labelled = prev_tes.copy()
    labelled['user_answer'] = ast.literal_eval(tes['prior_group_responses'].iloc[0])
    answered = labelled[labelled['content_type_id'] == 0].reset_index(drop=True)

    # One indicator row per response.
    n_rows = answered.shape[0]
    picked = np.zeros((n_rows, 4))
    picked[np.arange(n_rows), answered['user_answer'].values] = 1

    # Collapse to one row per distinct user before the fancy-indexed +=.
    batch_users, inverse = np.unique(answered['user_id'].values, return_inverse=True)
    per_user = np.stack([np.bincount(inverse, weights=picked[:, choice]) for choice in range(4)], axis=1)
    r_ref[user_map_ref[batch_users]] += per_user
    return r_ref

@noglobal
def online_get_norm_rolling_count_and_cut_for_user_answer(r_ref, tes, que, user_map_ref):
    """Per question row, return the user's normalised answer-choice distribution.

    Columns 0-3 are `r_ref[user]` divided by its row sum; column 4 is the
    share at the position given by the question's `correct_answer`. Users
    with an all-zero row yield NaN. Output is a float32 DataFrame.
    """
    questions = tes[tes['content_type_id'] == 0]
    counts = r_ref[user_map_ref[questions['user_id'].values]]
    shares = counts / counts.sum(axis=1, keepdims=True)
    answer_key = que['correct_answer'].values[questions['content_id'].values]

    features = np.zeros((questions.shape[0], 5), dtype=np.float32)
    features[:, :4] = shares
    features[:, 4] = shares[np.arange(shares.shape[0]), answer_key]

    columns = ['norm_rolling_count_user_answer_' + str(choice) for choice in range(4)]
    columns.append('cut_norm_rolling_count_user_answer')
    return pd.DataFrame(features, columns=columns)

@noglobal
def online_get_rolling_mean_for_prior_question_elapsed_time(r_ref, tes, user_map_ref):
    """Per question row, return column 0 / column 1 of the user's `r_ref` row.

    Output is a one-column float32 DataFrame; a zero denominator gives NaN/inf.
    """
    question_users = tes.loc[tes['content_type_id'] == 0, 'user_id'].values
    totals = r_ref[user_map_ref[question_users]]
    ratio = totals[:, 0] / totals[:, 1]
    frame = pd.DataFrame(ratio, columns=['rolling_mean_for_prior_question_elapsed_time'])
    return frame.astype(np.float32)

@noglobal
def early_update_reference_get_rolling_mean_for_prior_question_elapsed_time(r_ref, tes, user_map_ref):
    """Update the elapsed-time accumulator from the *current* batch.

    `r_ref[user]` holds three columns: 0 = weighted sum, 1 = total weight,
    2 = the weight carried over from the last batch in which the user had
    questions. For question rows with a non-null
    'prior_question_elapsed_time', column 0 grows by value * weight and
    column 1 by the weight. Column 2 is then reset to the number of question
    rows each user has in this batch.

    The `+=` on a fancy index is buffered by NumPy, so a user with several
    rows in the batch is updated once, not once per row. This rewrite keeps
    that behaviour.
    """
    is_question = tes['content_type_id'] == 0
    has_value = tes['prior_question_elapsed_time'].notnull()
    timed = tes[is_question & has_value]

    rows = user_map_ref[timed['user_id'].values]
    carried_weight = r_ref[rows, 2]  # fancy indexing returns a copy
    r_ref[rows, 0] += timed['prior_question_elapsed_time'].values * carried_weight
    r_ref[rows, 1] += carried_weight

    # Weight for the next batch: question rows per user in this batch.
    batch_users, inverse = np.unique(tes.loc[is_question, 'user_id'].values, return_inverse=True)
    r_ref[user_map_ref[batch_users], 2] = np.bincount(inverse)
    return r_ref

@noglobal
def online_get_rolling_mean_for_prior_question_had_explanation(r_ref, tes, user_map_ref):
    """Ratio of `r_ref` column 0 to column 1 for the user of each question row.

    Returns a single float32 column named
    'rolling_mean_for_prior_question_had_explanation'.
    """
    questions = tes[tes['content_type_id'] == 0]
    per_user = r_ref[user_map_ref[questions['user_id'].values]]
    numerator = per_user[:, 0]
    denominator = per_user[:, 1]
    return pd.DataFrame(
        numerator / denominator,
        columns=['rolling_mean_for_prior_question_had_explanation'],
    ).astype(np.float32)

@noglobal
def early_update_reference_get_rolling_mean_for_prior_question_had_explanation(r_ref, tes, user_map_ref):
    """Update the had-explanation accumulator from the *current* batch.

    `r_ref[user]` holds three columns: 0 = weighted sum, 1 = total weight,
    2 = weight carried over from the user's last batch with questions.
    Question rows with a non-null 'prior_question_had_explanation' add
    int(flag) * weight to column 0 and the weight to column 1. Column 2 is
    then reset to the user's question-row count in this batch.

    The fancy-indexed `+=` is buffered, so a user with several rows in the
    batch is updated once. This rewrite keeps that behaviour.
    """
    is_question = tes['content_type_id'] == 0
    has_flag = tes['prior_question_had_explanation'].notnull()
    flagged = tes[is_question & has_flag]

    rows = user_map_ref[flagged['user_id'].values]
    carried_weight = r_ref[rows, 2]  # fancy indexing returns a copy
    flags = flagged['prior_question_had_explanation'].values.astype('int')
    r_ref[rows, 0] += flags * carried_weight
    r_ref[rows, 1] += carried_weight

    batch_users, inverse = np.unique(tes.loc[is_question, 'user_id'].values, return_inverse=True)
    r_ref[user_map_ref[batch_users], 2] = np.bincount(inverse)
    return r_ref

@noglobal
def online_get_rolling_sum_for_prior_question_isnull(r_ref, tes, user_map_ref):
    """Return the stored `r_ref` value of the user behind each question row.

    Output is a one-column float32 DataFrame named
    'rolling_sum_for_prior_question_isnull'.
    """
    question_users = tes.loc[tes['content_type_id'] == 0, 'user_id'].values
    stored = r_ref[user_map_ref[question_users]]
    frame = pd.DataFrame(stored, columns=['rolling_sum_for_prior_question_isnull'])
    return frame.astype(np.float32)

@noglobal
def update_reference_get_rolling_sum_for_prior_question_isnull(r_ref, prev_tes, tes, user_map_ref):
    """Increment `r_ref[user]` for users whose previous-batch question rows had a
    null 'prior_question_elapsed_time'.

    `tes` is not read. `r_ref` is updated in place and returned.

    NOTE(review): the fancy-indexed `+=` is buffered, so a user with several
    null rows in one batch is incremented by 1, not by the row count.
    Confirm this matches how the training feature was built.
    """
    is_question = prev_tes['content_type_id'] == 0
    is_missing = prev_tes['prior_question_elapsed_time'].isnull()
    null_rows = prev_tes[is_question & is_missing]

    user_rows = user_map_ref[null_rows['user_id'].values]
    # Every selected row is null, so this is a vector of ones.
    increments = null_rows['prior_question_elapsed_time'].isnull().values.astype('int')
    r_ref[user_rows] += increments
    return r_ref

@noglobal
def online_get_positive_rolling_mean_for_prior_question_elapsed_time(r_ref, tes, user_map_ref):
    """Per question row, divide the user's `r_ref` column 0 by column 1.

    Returns one float32 column named
    'positive_rolling_mean_for_prior_question_elapsed_time'.
    """
    feature = 'positive_rolling_mean_for_prior_question_elapsed_time'
    questions = tes[tes['content_type_id'] == 0]
    accum = r_ref[user_map_ref[questions['user_id'].values]]
    return pd.DataFrame(accum[:, 0] / accum[:, 1], columns=[feature]).astype(np.float32)

@noglobal
def early_update_reference_get_positive_rolling_mean_for_prior_question_elapsed_time(r_ref, tes, user_map_ref):
    """Accumulate elapsed time weighted by `r_ref` column 2, from the current batch.

    For question rows with a non-null 'prior_question_elapsed_time',
    column 0 grows by value * r_ref[user, 2] and column 1 by r_ref[user, 2].
    Column 2 is only read here; the matching update_reference_* helper
    writes it.

    The fancy-indexed `+=` is buffered: a user repeated in the batch is
    updated once. This rewrite keeps that behaviour.
    """
    keep = (tes['content_type_id'] == 0) & tes['prior_question_elapsed_time'].notnull()
    timed = tes[keep]
    rows = user_map_ref[timed['user_id'].values]
    weight = r_ref[rows, 2]  # copy, so later writes to columns 0/1 do not alter it
    r_ref[rows, 0] += timed['prior_question_elapsed_time'].values * weight
    r_ref[rows, 1] += weight
    return r_ref

@noglobal
def update_reference_get_positive_rolling_mean_for_prior_question_elapsed_time(r_ref, prev_tes, tes, user_map_ref):
    """Set `r_ref[user, 2]` to the user's sum of `answered_correctly` in `prev_tes`.

    Outcomes are parsed from the first 'prior_group_answers_correct' entry of
    `tes`; only question rows count. Users absent from `prev_tes` keep their
    existing value. `r_ref` is updated in place and returned.
    """
    labelled = prev_tes.copy()
    labelled['answered_correctly'] = ast.literal_eval(tes['prior_group_answers_correct'].iloc[0])
    answered = labelled[labelled['content_type_id'] == 0].reset_index(drop=True)

    batch_users, inverse = np.unique(answered['user_id'].values, return_inverse=True)
    correct_per_user = np.bincount(inverse, weights=answered['answered_correctly'].values)
    r_ref[user_map_ref[batch_users], 2] = correct_per_user
    return r_ref

@noglobal
def online_get_negative_rolling_mean_for_prior_question_elapsed_time(r_ref, tes, user_map_ref):
    """Per question row, return `r_ref[user, 0] / r_ref[user, 1]`.

    The single float32 output column is named
    'negative_rolling_mean_for_prior_question_elapsed_time'.
    """
    is_question = tes['content_type_id'] == 0
    user_rows = user_map_ref[tes.loc[is_question, 'user_id'].values]
    weighted_sum = r_ref[user_rows][:, 0]
    weight_total = r_ref[user_rows][:, 1]
    frame = pd.DataFrame(weighted_sum / weight_total, columns=['negative_rolling_mean_for_prior_question_elapsed_time'])
    return frame.astype(np.float32)

@noglobal
def early_update_reference_get_negative_rolling_mean_for_prior_question_elapsed_time(r_ref, tes, user_map_ref):
    """Accumulate elapsed time weighted by `r_ref` column 2, from the current batch.

    Only question rows with a non-null 'prior_question_elapsed_time' take
    part. Column 0 += value * r_ref[user, 2]; column 1 += r_ref[user, 2].
    Column 2 is read but not written here.

    A user that occurs several times in the batch is updated once (buffered
    fancy-index `+=`). This rewrite keeps that behaviour.
    """
    elapsed = tes['prior_question_elapsed_time']
    usable = tes[(tes['content_type_id'] == 0) & ~elapsed.isnull()]
    target_rows = user_map_ref[usable['user_id'].values]
    pending = r_ref[target_rows, 2]  # snapshot of the per-user weight
    r_ref[target_rows, 0] += usable['prior_question_elapsed_time'].values * pending
    r_ref[target_rows, 1] += pending
    return r_ref

@noglobal
def update_reference_get_negative_rolling_mean_for_prior_question_elapsed_time(r_ref, prev_tes, tes, user_map_ref):
    """Set `r_ref[user, 2]` to the user's sum of (1 - answered_correctly) in `prev_tes`.

    Outcomes come from the first 'prior_group_answers_correct' entry of
    `tes`; only question rows are counted. `r_ref` is updated in place and
    returned.
    """
    labelled = prev_tes.copy()
    labelled['answered_correctly'] = ast.literal_eval(tes['prior_group_answers_correct'].iloc[0])
    answered = labelled[labelled['content_type_id'] == 0].reset_index(drop=True)

    misses = 1 - answered['answered_correctly'].values
    batch_users, inverse = np.unique(answered['user_id'].values, return_inverse=True)
    r_ref[user_map_ref[batch_users], 2] = np.bincount(inverse, weights=misses)
    return r_ref

@noglobal
def online_get_positive_rolling_mean_for_prior_question_had_explanation(r_ref, tes, user_map_ref):
    """Column 0 over column 1 of the user's `r_ref` row, for each question row.

    Returns a float32 DataFrame with the single column
    'positive_rolling_mean_for_prior_question_had_explanation'.
    """
    questions = tes[tes['content_type_id'] == 0]
    lookup = user_map_ref[questions['user_id'].values]
    pair = r_ref[lookup]
    values = pair[:, 0] / pair[:, 1]
    return pd.DataFrame(values, columns=['positive_rolling_mean_for_prior_question_had_explanation']).astype(np.float32)

@noglobal
def early_update_reference_get_positive_rolling_mean_for_prior_question_had_explanation(r_ref, tes, user_map_ref):
    """Accumulate the had-explanation flag weighted by `r_ref` column 2.

    For question rows of `tes` with a non-null
    'prior_question_had_explanation', column 0 grows by
    int(flag) * r_ref[user, 2] and column 1 by r_ref[user, 2]. Column 2 is
    only read here; the matching update_reference_* helper writes it.

    Parameters
    ----------
    r_ref : 2-D float array indexed as [encoded user, (sum, weight, pending)];
        updated in place and returned.
    tes : DataFrame of the current batch.
    user_map_ref : array mapping a raw user_id to its row in `r_ref`.

    Notes
    -----
    The flag is cast to int, as the non-split sibling helper does. When the
    column is object dtype (booleans mixed with nulls), multiplying without
    the cast yields an object array, and the in-place add into the float
    reference then fails to cast back.

    The fancy-indexed `+=` is buffered: a user repeated in the batch is
    updated once, not once per row. That behaviour is preserved.
    """
    keep = (tes['content_type_id'] == 0) & tes['prior_question_had_explanation'].notnull()
    flagged = tes[keep]
    rows = user_map_ref[flagged['user_id'].values]
    flags = flagged['prior_question_had_explanation'].values.astype('int')
    weight = r_ref[rows, 2]  # copy, unaffected by the writes below
    r_ref[rows, 0] += flags * weight
    r_ref[rows, 1] += weight
    return r_ref

@noglobal
def update_reference_get_positive_rolling_mean_for_prior_question_had_explanation(r_ref, prev_tes, tes, user_map_ref):
    """Write each user's previous-batch sum of `answered_correctly` into `r_ref[:, 2]`.

    `answered_correctly` for `prev_tes` is parsed from the first
    'prior_group_answers_correct' entry of `tes`. Lecture rows are ignored.
    `r_ref` is updated in place and returned.
    """
    scored = prev_tes.copy()
    scored['answered_correctly'] = ast.literal_eval(tes['prior_group_answers_correct'].iloc[0])
    question_rows = scored[scored['content_type_id'] == 0].reset_index(drop=True)

    distinct_users, positions = np.unique(question_rows['user_id'].values, return_inverse=True)
    hits = np.bincount(positions, weights=question_rows['answered_correctly'].values)
    r_ref[user_map_ref[distinct_users], 2] = hits
    return r_ref

@noglobal
def online_get_negative_rolling_mean_for_prior_question_had_explanation(r_ref, tes, user_map_ref):
    """For each question row, compute `r_ref[user, 0] / r_ref[user, 1]`.

    Output: float32 DataFrame with one column,
    'negative_rolling_mean_for_prior_question_had_explanation'.
    """
    column = 'negative_rolling_mean_for_prior_question_had_explanation'
    user_ids = tes.loc[tes['content_type_id'] == 0, 'user_id'].values
    accum = r_ref[user_map_ref[user_ids]]
    out = pd.DataFrame(accum[:, 0] / accum[:, 1], columns=[column])
    return out.astype(np.float32)

@noglobal
def early_update_reference_get_negative_rolling_mean_for_prior_question_had_explanation(r_ref, tes, user_map_ref):
    """Accumulate the had-explanation flag weighted by `r_ref` column 2.

    For question rows of `tes` with a non-null
    'prior_question_had_explanation', column 0 grows by
    int(flag) * r_ref[user, 2] and column 1 by r_ref[user, 2]. Column 2 is
    only read here; the matching update_reference_* helper writes it.

    Parameters
    ----------
    r_ref : 2-D float array indexed as [encoded user, (sum, weight, pending)];
        updated in place and returned.
    tes : DataFrame of the current batch.
    user_map_ref : array mapping a raw user_id to its row in `r_ref`.

    Notes
    -----
    The flag is cast to int, as the non-split sibling helper does. Without
    the cast, an object-dtype flag column (booleans mixed with nulls)
    produces an object array that the in-place add cannot cast back into the
    float reference.

    The fancy-indexed `+=` is buffered: a user repeated in the batch is
    updated once, not once per row. That behaviour is preserved.
    """
    flag_col = tes['prior_question_had_explanation']
    usable = tes[(tes['content_type_id'] == 0) & ~flag_col.isnull()]
    rows = user_map_ref[usable['user_id'].values]
    flags = usable['prior_question_had_explanation'].values.astype('int')
    pending = r_ref[rows, 2]  # copy, unaffected by the writes below
    r_ref[rows, 0] += flags * pending
    r_ref[rows, 1] += pending
    return r_ref

@noglobal
def update_reference_get_negative_rolling_mean_for_prior_question_had_explanation(r_ref, prev_tes, tes, user_map_ref):
    """Write each user's previous-batch sum of (1 - answered_correctly) into `r_ref[:, 2]`.

    `answered_correctly` for `prev_tes` is parsed from the first
    'prior_group_answers_correct' entry of `tes`. Lecture rows are ignored.
    `r_ref` is updated in place and returned.
    """
    scored = prev_tes.copy()
    scored['answered_correctly'] = ast.literal_eval(tes['prior_group_answers_correct'].iloc[0])
    question_rows = scored[scored['content_type_id'] == 0].reset_index(drop=True)

    distinct_users, positions = np.unique(question_rows['user_id'].values, return_inverse=True)
    wrong = 1 - question_rows['answered_correctly'].values
    r_ref[user_map_ref[distinct_users], 2] = np.bincount(positions, weights=wrong)
    return r_ref

@noglobal
def online_get_diff_modified_timedelta_and_prior_elapsed_time_rolling_mean(f_tes_base, f_tes_pn, f_tes_p, f_tes_n):
    """Subtract the first column of three feature frames from the first column of `f_tes_base`.

    The three output columns are base - f_tes_pn, base - f_tes_p and
    base - f_tes_n, in that order. All inputs are read positionally through
    `.values[:, 0]`, so their row order must already agree.
    """
    base = f_tes_base.values[:, 0]
    subtrahends = [
        ('rolling_mean_for_prior_question_elapsed_time_diff_modified_timedelta', f_tes_pn),
        ('positive_rolling_mean_for_prior_question_elapsed_time_diff_modified_timedelta', f_tes_p),
        ('negative_rolling_mean_for_prior_question_elapsed_time_diff_modified_timedelta', f_tes_n),
    ]
    return pd.DataFrame({name: base - frame.values[:, 0] for name, frame in subtrahends})

@noglobal
def online_get_n_samples_rolling_mean(r_ref, tes, user_map_ref):
    """Mean of each user's most recent stored outcomes over several window sizes.

    `r_ref[user]` is a fixed-width history whose last entry is the newest;
    entries equal to -1 are treated as missing. For each window size w the
    feature is the NaN-ignoring mean of the last w entries. A window with no
    valid entry yields NaN.

    Returns a float32 DataFrame with one '<w>_samples_rolling_mean' column
    per window size.
    """
    windows = (1, 2, 3, 4, 5, 10, 20, 30, 40, 50, 100, 200)
    questions = tes[tes['content_type_id'] == 0]

    history = r_ref[user_map_ref[questions['user_id'].values]].astype(np.float32)
    history[history == -1] = np.nan  # -1 marks an unused slot

    features = np.zeros((questions.shape[0], len(windows)), dtype=np.float32)
    for col, window in enumerate(windows):
        features[:, col] = np.nanmean(history[:, -window:], axis=1)

    columns = [str(window) + '_samples_rolling_mean' for window in windows]
    return pd.DataFrame(features, columns=columns)

@noglobal
def update_reference_get_n_samples_rolling_mean(r_ref, prev_tes, tes, user_map_ref):
    """Push the previous batch's outcomes onto each user's history row in `r_ref`.

    Outcomes for `prev_tes` are parsed from the first
    'prior_group_answers_correct' entry of `tes`. For every question row, in
    order, the user's history is shifted left by one (dropping the oldest
    entry) and the outcome is written to the last slot. `r_ref` is updated
    in place and returned.
    """
    labelled = prev_tes.copy()
    labelled['answered_correctly'] = ast.literal_eval(tes['prior_group_answers_correct'].iloc[0])
    answered = labelled[labelled['content_type_id'] == 0].reset_index(drop=True)

    user_rows = user_map_ref[answered['user_id'].values]
    outcomes = answered['answered_correctly'].values
    # Row by row on purpose: a user may appear several times in one batch.
    for row, outcome in zip(user_rows, outcomes):
        r_ref[row] = np.roll(r_ref[row], -1)
        r_ref[row, -1] = outcome
    return r_ref

@noglobal 
def online_get_rolling_mean_sum_count_for_content_id_darkness(r_ref, tes, user_map_ref):
    """Decode the packed per-(user, content) reference for the question rows of `tes`.

    Each cell of `r_ref` stores `count * 128 + sum` (the matching update
    helper adds `target + 128` per attempt). `r_ref` is indexed with paired
    row/column arrays and must return an object exposing `.toarray()`
    (presumably a scipy sparse matrix; confirm with caller).

    Returns a float32 DataFrame with 'content_id_cut_mean',
    'content_id_cut_sum' and 'content_id_cut_count'. Cells never written
    decode to 0 / 0, i.e. NaN for the mean.

    NOTE(review): the packing only decodes cleanly while the per-cell sum
    stays below 128; confirm this holds for the stored dtype and data.
    """
    questions = tes[tes['content_type_id'] == 0]
    user_rows = user_map_ref[questions['user_id'].values]
    content_cols = questions['content_id'].values

    packed = r_ref[user_rows, content_cols].toarray()[0]
    attempts, correct_sum = np.divmod(packed, 128)

    return pd.DataFrame({
        'content_id_cut_mean': (correct_sum / attempts).astype(np.float32),
        'content_id_cut_sum': correct_sum.astype(np.float32),
        'content_id_cut_count': attempts.astype(np.float32),
    })

@noglobal
def update_reference_get_rolling_mean_sum_count_for_content_id_darkness(r_ref, prev_tes, tes, user_map_ref):
    """Fold the previous batch's outcomes into the packed per-(user, content) reference.

    Outcomes for `prev_tes` are parsed from the first
    'prior_group_answers_correct' entry of `tes`. For each question row the
    cell `r_ref[user, content]` grows by `outcome + 128`: 128 per attempt
    plus the outcome itself, so count and sum share one cell. Rows are
    applied one at a time, so repeated pairs accumulate. `r_ref` is updated
    in place and returned.
    """
    labelled = prev_tes.copy()
    labelled['answered_correctly'] = ast.literal_eval(tes['prior_group_answers_correct'].iloc[0])
    answered = labelled[labelled['content_type_id'] == 0].reset_index(drop=True)

    updates = zip(
        user_map_ref[answered['user_id'].values],
        answered['content_id'].values,
        answered['answered_correctly'].values,
    )
    for row, col, outcome in updates:
        r_ref[row, col] = r_ref[row, col] + outcome + 128
    return r_ref

In [6]:
#initialization
# Switch between a local run and a Kaggle run by (un)commenting the
# alternative `prefix` / `env` lines below.
#prefix = '..'
# Root of the dataset named in the path below.
prefix = '/kaggle/input/mamastan-gpu-v27'
#sn = pd.read_pickle('../others/ex_tes.pkl'); env = RiiidEnv(sn, iterate_wo_predict = False)
# Competition API: build the environment, then take its test-batch iterator.
import riiideducation; env = riiideducation.make_env()
iter_test = env.iter_test()
gc.collect()

20

In [7]:
#params162
# Configuration for experiment 162, stored as attributes on an OrderedDict
# instance and snapshotted into `args_exp162` at the end of this section.
# `OrderedDict` and `torch` are not imported in this section; they must
# already be in scope from an earlier cell.
args = OrderedDict()
args.n_length = 200
args.emb_dim = 512
args.dim_feedforward = 1028
args.nhead = 4
args.dropout = 0.2
args.dropout_pe = 0
args.num_features = 90
args.n_conv = 30
args.pred_bs = 256
args.device = torch.device('cuda')

# Ordered feature names (90 entries, matching args.num_features). The
# positions are used below to derive column indices, so the order matters.
f_names = ['target_full_mean',
 'target_count',
 'tags_order_mean_0',
 'tags_order_sum_0',
 'tags_order_count_0',
 'tags_order_mean_1',
 'tags_order_sum_1',
 'tags_order_count_1',
 'tags_order_mean_2',
 'tags_order_sum_2',
 'tags_order_count_2',
 'tags_order_mean_3',
 'tags_order_sum_3',
 'tags_order_count_3',
 'tags_order_mean_4',
 'tags_order_sum_4',
 'tags_order_count_4',
 'tags_order_mean_5',
 'tags_order_sum_5',
 'tags_order_count_5',
 'whole_tags_order_mean',
 'whole_tags_order_sum',
 'whole_tags_order_count',
 'part_1_mean',
 'part_1_sum',
 'part_1_count',
 'part_2_mean',
 'part_2_sum',
 'part_2_count',
 'part_3_mean',
 'part_3_sum',
 'part_3_count',
 'part_4_mean',
 'part_4_sum',
 'part_4_count',
 'part_5_mean',
 'part_5_sum',
 'part_5_count',
 'part_6_mean',
 'part_6_sum',
 'part_6_count',
 'part_7_mean',
 'part_7_sum',
 'part_7_count',
 'part_cut_mean',
 'part_cut_sum',
 'part_cut_count',
 'lec_rolling_count',
 'lec_part_1',
 'lec_part_2',
 'lec_part_3',
 'lec_part_4',
 'lec_part_5',
 'lec_part_6',
 'lec_part_7',
 'lec_part_cut',
 'lec_type_of_0',
 'lec_type_of_1',
 'lec_type_of_2',
 'lec_tags_order_count_0',
 'lec_tags_order_count_1',
 'lec_tags_order_count_2',
 'lec_whole_tags_order_count',
 'content_id_cut_mean',
 'content_id_cut_sum',
 'content_id_cut_count',
 'norm_rolling_count_user_answer_0',
 'norm_rolling_count_user_answer_1',
 'norm_rolling_count_user_answer_2',
 'norm_rolling_count_user_answer_3',
 'cut_norm_rolling_count_user_answer',
 'rolling_mean_for_prior_question_elapsed_time',
 'rolling_mean_for_prior_question_had_explanation',
 'rolling_sum_for_prior_question_isnull',
 'positive_rolling_mean_for_prior_question_elapsed_time',
 'negative_rolling_mean_for_prior_question_elapsed_time',
 'positive_rolling_mean_for_prior_question_had_explanation',
 'negative_rolling_mean_for_prior_question_had_explanation',
 '1_samples_rolling_mean',
 '2_samples_rolling_mean',
 '3_samples_rolling_mean',
 '4_samples_rolling_mean',
 '5_samples_rolling_mean',
 '10_samples_rolling_mean',
 '20_samples_rolling_mean',
 '30_samples_rolling_mean',
 '40_samples_rolling_mean',
 '50_samples_rolling_mean',
 '100_samples_rolling_mean',
 '200_samples_rolling_mean']
# Features whose positions end up in args.log_indices (presumably the ones
# that receive a log transform downstream; confirm where it is consumed).
log_f_names = [
 'target_count',
 'tags_order_sum_0',
 'tags_order_count_0',
 'tags_order_sum_1',
 'tags_order_count_1',
 'tags_order_sum_2',
 'tags_order_count_2',
 'tags_order_sum_3',
 'tags_order_count_3',
 'tags_order_sum_4',
 'tags_order_count_4',
 'tags_order_sum_5',
 'tags_order_count_5',
 'whole_tags_order_sum',
 'whole_tags_order_count',
 'part_1_sum',
 'part_1_count',
 'part_2_sum',
 'part_2_count',
 'part_3_sum',
 'part_3_count',
 'part_4_sum',
 'part_4_count',
 'part_5_sum',
 'part_5_count',
 'part_6_sum',
 'part_6_count',
 'part_7_sum',
 'part_7_count',
 'part_cut_sum',
 'part_cut_count',
 'lec_rolling_count',
 'lec_part_1',
 'lec_part_2',
 'lec_part_3',
 'lec_part_4',
 'lec_part_5',
 'lec_part_6',
 'lec_part_7',
 'lec_part_cut',
 'lec_type_of_0',
 'lec_type_of_1',
 'lec_type_of_2',
 'lec_tags_order_count_0',
 'lec_tags_order_count_1',
 'lec_tags_order_count_2',
 'lec_whole_tags_order_count',
 'content_id_cut_sum',
 'content_id_cut_count',
 'rolling_mean_for_prior_question_elapsed_time',
 'positive_rolling_mean_for_prior_question_elapsed_time',
 'negative_rolling_mean_for_prior_question_elapsed_time']
# Features that have an entry in stdize_params; their order defines the
# column order of args.stdize_indices and args.stdize_mu_std.
stdize_f_names = [
 'target_count',
 'tags_order_sum_0',
 'tags_order_count_0',
 'tags_order_sum_1',
 'tags_order_count_1',
 'tags_order_sum_2',
 'tags_order_count_2',
 'tags_order_sum_3',
 'tags_order_count_3',
 'tags_order_sum_4',
 'tags_order_count_4',
 'tags_order_sum_5',
 'tags_order_count_5',
 'whole_tags_order_sum',
 'whole_tags_order_count',
 'part_1_sum',
 'part_1_count',
 'part_2_sum',
 'part_2_count',
 'part_3_sum',
 'part_3_count',
 'part_4_sum',
 'part_4_count',
 'part_5_sum',
 'part_5_count',
 'part_6_sum',
 'part_6_count',
 'part_7_sum',
 'part_7_count',
 'part_cut_sum',
 'part_cut_count',
 'lec_rolling_count',
 'rolling_mean_for_prior_question_elapsed_time',
 'positive_rolling_mean_for_prior_question_elapsed_time',
 'negative_rolling_mean_for_prior_question_elapsed_time']
# Per-feature 2-tuples; judging by the `stdize_mu_std` name these are
# (mean, std) pairs. Every first element is 0.
stdize_params = {'target_count': (0, 5.929629),
 'tags_order_sum_0': (0, 2.034838),
 'tags_order_count_0': (0, 2.360298),
 'tags_order_sum_1': (0, 2.352074),
 'tags_order_count_1': (0, 2.646724),
 'tags_order_sum_2': (0, 3.368914),
 'tags_order_count_2': (0, 3.698544),
 'tags_order_sum_3': (0, 3.515896),
 'tags_order_count_3': (0, 3.854377),
 'tags_order_sum_4': (0, 3.355033),
 'tags_order_count_4': (0, 3.719298),
 'tags_order_sum_5': (0, 3.871832),
 'tags_order_count_5': (0, 4.218161),
 'whole_tags_order_sum': (0, 3.005077),
 'whole_tags_order_count': (0, 3.356816),
 'part_1_sum': (0, 2.838132),
 'part_1_count': (0, 3.105394),
 'part_2_sum': (0, 3.631354),
 'part_2_count': (0, 3.947554),
 'part_3_sum': (0, 2.574938),
 'part_3_count': (0, 2.8732),
 'part_4_sum': (0, 2.353698),
 'part_4_count': (0, 2.740394),
 'part_5_sum': (0, 4.350539),
 'part_5_count': (0, 4.86018),
 'part_6_sum': (0, 2.64874),
 'part_6_count': (0, 2.971092),
 'part_7_sum': (0, 1.864806),
 'part_7_count': (0, 2.166899),
 'part_cut_sum': (0, 4.154958),
 'part_cut_count': (0, 4.582286),
 'lec_rolling_count': (0, 2.10689),
 'rolling_mean_for_prior_question_elapsed_time': (0, 10.124079),
 'positive_rolling_mean_for_prior_question_elapsed_time': (0, 10.075675),
 'negative_rolling_mean_for_prior_question_elapsed_time': (0, 10.211306)}

args.f_names = f_names
args.log_f_names = log_f_names
args.stdize_f_names = stdize_f_names
args.stdize_params = stdize_params
# Positions in f_names of the entries that are also in log_f_names.
# NOTE(review): np.in1d is deprecated in NumPy 2.0 in favour of np.isin.
args.log_indices = np.where(np.in1d(args.f_names, args.log_f_names))[0]
# Position in f_names of each stdize feature, in stdize_f_names order.
args.stdize_indices = [np.where(el == np.array(args.f_names))[0][0] for el in args.stdize_f_names]
# Shape (2, K): row 0 holds the first tuple elements, row 1 the second.
args.stdize_mu_std = np.array([args.stdize_params[el] for el in args.stdize_f_names]).T
# Drop the temporaries so the next params section starts from a clean slate.
del f_names, log_f_names, stdize_f_names, stdize_params; gc.collect();
args_exp162 = deepcopy(args)
del args;

#params166
# Configuration for experiment 166, stored as attributes on an OrderedDict
# instance and snapshotted into `args_exp166` at the end of this section.
# `OrderedDict` and `torch` are not imported in this section; they must
# already be in scope from an earlier cell.
args = OrderedDict()
args.n_length = 400
args.emb_dim = 512
args.dim_feedforward = 1028
args.nhead = 4
args.dropout = 0.2
args.dropout_pe = 0
args.num_features = 90
args.n_conv = 30
args.pred_bs = 256
args.device = torch.device('cuda')

# Ordered feature names (90 entries, matching args.num_features). The
# positions are used below to derive column indices, so the order matters.
f_names = ['target_full_mean',
 'target_count',
 'tags_order_mean_0',
 'tags_order_sum_0',
 'tags_order_count_0',
 'tags_order_mean_1',
 'tags_order_sum_1',
 'tags_order_count_1',
 'tags_order_mean_2',
 'tags_order_sum_2',
 'tags_order_count_2',
 'tags_order_mean_3',
 'tags_order_sum_3',
 'tags_order_count_3',
 'tags_order_mean_4',
 'tags_order_sum_4',
 'tags_order_count_4',
 'tags_order_mean_5',
 'tags_order_sum_5',
 'tags_order_count_5',
 'whole_tags_order_mean',
 'whole_tags_order_sum',
 'whole_tags_order_count',
 'part_1_mean',
 'part_1_sum',
 'part_1_count',
 'part_2_mean',
 'part_2_sum',
 'part_2_count',
 'part_3_mean',
 'part_3_sum',
 'part_3_count',
 'part_4_mean',
 'part_4_sum',
 'part_4_count',
 'part_5_mean',
 'part_5_sum',
 'part_5_count',
 'part_6_mean',
 'part_6_sum',
 'part_6_count',
 'part_7_mean',
 'part_7_sum',
 'part_7_count',
 'part_cut_mean',
 'part_cut_sum',
 'part_cut_count',
 'lec_rolling_count',
 'lec_part_1',
 'lec_part_2',
 'lec_part_3',
 'lec_part_4',
 'lec_part_5',
 'lec_part_6',
 'lec_part_7',
 'lec_part_cut',
 'lec_type_of_0',
 'lec_type_of_1',
 'lec_type_of_2',
 'lec_tags_order_count_0',
 'lec_tags_order_count_1',
 'lec_tags_order_count_2',
 'lec_whole_tags_order_count',
 'content_id_cut_mean',
 'content_id_cut_sum',
 'content_id_cut_count',
 'norm_rolling_count_user_answer_0',
 'norm_rolling_count_user_answer_1',
 'norm_rolling_count_user_answer_2',
 'norm_rolling_count_user_answer_3',
 'cut_norm_rolling_count_user_answer',
 'rolling_mean_for_prior_question_elapsed_time',
 'rolling_mean_for_prior_question_had_explanation',
 'rolling_sum_for_prior_question_isnull',
 'positive_rolling_mean_for_prior_question_elapsed_time',
 'negative_rolling_mean_for_prior_question_elapsed_time',
 'positive_rolling_mean_for_prior_question_had_explanation',
 'negative_rolling_mean_for_prior_question_had_explanation',
 '1_samples_rolling_mean',
 '2_samples_rolling_mean',
 '3_samples_rolling_mean',
 '4_samples_rolling_mean',
 '5_samples_rolling_mean',
 '10_samples_rolling_mean',
 '20_samples_rolling_mean',
 '30_samples_rolling_mean',
 '40_samples_rolling_mean',
 '50_samples_rolling_mean',
 '100_samples_rolling_mean',
 '200_samples_rolling_mean']
# Features whose positions end up in args.log_indices (presumably the ones
# that receive a log transform downstream; confirm where it is consumed).
log_f_names = [
 'target_count',
 'tags_order_sum_0',
 'tags_order_count_0',
 'tags_order_sum_1',
 'tags_order_count_1',
 'tags_order_sum_2',
 'tags_order_count_2',
 'tags_order_sum_3',
 'tags_order_count_3',
 'tags_order_sum_4',
 'tags_order_count_4',
 'tags_order_sum_5',
 'tags_order_count_5',
 'whole_tags_order_sum',
 'whole_tags_order_count',
 'part_1_sum',
 'part_1_count',
 'part_2_sum',
 'part_2_count',
 'part_3_sum',
 'part_3_count',
 'part_4_sum',
 'part_4_count',
 'part_5_sum',
 'part_5_count',
 'part_6_sum',
 'part_6_count',
 'part_7_sum',
 'part_7_count',
 'part_cut_sum',
 'part_cut_count',
 'lec_rolling_count',
 'lec_part_1',
 'lec_part_2',
 'lec_part_3',
 'lec_part_4',
 'lec_part_5',
 'lec_part_6',
 'lec_part_7',
 'lec_part_cut',
 'lec_type_of_0',
 'lec_type_of_1',
 'lec_type_of_2',
 'lec_tags_order_count_0',
 'lec_tags_order_count_1',
 'lec_tags_order_count_2',
 'lec_whole_tags_order_count',
 'content_id_cut_sum',
 'content_id_cut_count',
 'rolling_mean_for_prior_question_elapsed_time',
 'positive_rolling_mean_for_prior_question_elapsed_time',
 'negative_rolling_mean_for_prior_question_elapsed_time']
# Features that have an entry in stdize_params; their order defines the
# column order of args.stdize_indices and args.stdize_mu_std.
stdize_f_names = [
 'target_count',
 'tags_order_sum_0',
 'tags_order_count_0',
 'tags_order_sum_1',
 'tags_order_count_1',
 'tags_order_sum_2',
 'tags_order_count_2',
 'tags_order_sum_3',
 'tags_order_count_3',
 'tags_order_sum_4',
 'tags_order_count_4',
 'tags_order_sum_5',
 'tags_order_count_5',
 'whole_tags_order_sum',
 'whole_tags_order_count',
 'part_1_sum',
 'part_1_count',
 'part_2_sum',
 'part_2_count',
 'part_3_sum',
 'part_3_count',
 'part_4_sum',
 'part_4_count',
 'part_5_sum',
 'part_5_count',
 'part_6_sum',
 'part_6_count',
 'part_7_sum',
 'part_7_count',
 'part_cut_sum',
 'part_cut_count',
 'lec_rolling_count',
 'rolling_mean_for_prior_question_elapsed_time',
 'positive_rolling_mean_for_prior_question_elapsed_time',
 'negative_rolling_mean_for_prior_question_elapsed_time']
# Per-feature 2-tuples; judging by the `stdize_mu_std` name these are
# (mean, std) pairs. Every first element is 0.
stdize_params = {'target_count': (0, 5.929629),
 'tags_order_sum_0': (0, 2.034838),
 'tags_order_count_0': (0, 2.360298),
 'tags_order_sum_1': (0, 2.352074),
 'tags_order_count_1': (0, 2.646724),
 'tags_order_sum_2': (0, 3.368914),
 'tags_order_count_2': (0, 3.698544),
 'tags_order_sum_3': (0, 3.515896),
 'tags_order_count_3': (0, 3.854377),
 'tags_order_sum_4': (0, 3.355033),
 'tags_order_count_4': (0, 3.719298),
 'tags_order_sum_5': (0, 3.871832),
 'tags_order_count_5': (0, 4.218161),
 'whole_tags_order_sum': (0, 3.005077),
 'whole_tags_order_count': (0, 3.356816),
 'part_1_sum': (0, 2.838132),
 'part_1_count': (0, 3.105394),
 'part_2_sum': (0, 3.631354),
 'part_2_count': (0, 3.947554),
 'part_3_sum': (0, 2.574938),
 'part_3_count': (0, 2.8732),
 'part_4_sum': (0, 2.353698),
 'part_4_count': (0, 2.740394),
 'part_5_sum': (0, 4.350539),
 'part_5_count': (0, 4.86018),
 'part_6_sum': (0, 2.64874),
 'part_6_count': (0, 2.971092),
 'part_7_sum': (0, 1.864806),
 'part_7_count': (0, 2.166899),
 'part_cut_sum': (0, 4.154958),
 'part_cut_count': (0, 4.582286),
 'lec_rolling_count': (0, 2.10689),
 'rolling_mean_for_prior_question_elapsed_time': (0, 10.124079),
 'positive_rolling_mean_for_prior_question_elapsed_time': (0, 10.075675),
 'negative_rolling_mean_for_prior_question_elapsed_time': (0, 10.211306)}

args.f_names = f_names
args.log_f_names = log_f_names
args.stdize_f_names = stdize_f_names
args.stdize_params = stdize_params
# Positions in f_names of the entries that are also in log_f_names.
# NOTE(review): np.in1d is deprecated in NumPy 2.0 in favour of np.isin.
args.log_indices = np.where(np.in1d(args.f_names, args.log_f_names))[0]
# Position in f_names of each stdize feature, in stdize_f_names order.
args.stdize_indices = [np.where(el == np.array(args.f_names))[0][0] for el in args.stdize_f_names]
# Shape (2, K): row 0 holds the first tuple elements, row 1 the second.
args.stdize_mu_std = np.array([args.stdize_params[el] for el in args.stdize_f_names]).T
# Drop the temporaries so the next params section starts from a clean slate.
del f_names, log_f_names, stdize_f_names, stdize_params; gc.collect();
args_exp166 = deepcopy(args)
del args;

# ---------------------------------------------------------------------------
# Encoder configs for exp184, exp218, exp224, exp219 and exp221.
#
# These five experiments use exactly the same feature lists and scaling
# constants; they differ only in a few hyper-parameters (`n_conv` is set for
# 184/218/224, `lr` only for 219). The shared literals are therefore written
# once and each config is produced by `_make_encoder_args`.
# ---------------------------------------------------------------------------

# Names of all input features, in column order (90 entries).
_ENCODER_F_NAMES = [
    'target_full_mean', 'target_count',
    'tags_order_mean_0', 'tags_order_sum_0', 'tags_order_count_0',
    'tags_order_mean_1', 'tags_order_sum_1', 'tags_order_count_1',
    'tags_order_mean_2', 'tags_order_sum_2', 'tags_order_count_2',
    'tags_order_mean_3', 'tags_order_sum_3', 'tags_order_count_3',
    'tags_order_mean_4', 'tags_order_sum_4', 'tags_order_count_4',
    'tags_order_mean_5', 'tags_order_sum_5', 'tags_order_count_5',
    'whole_tags_order_mean', 'whole_tags_order_sum', 'whole_tags_order_count',
    'part_1_mean', 'part_1_sum', 'part_1_count',
    'part_2_mean', 'part_2_sum', 'part_2_count',
    'part_3_mean', 'part_3_sum', 'part_3_count',
    'part_4_mean', 'part_4_sum', 'part_4_count',
    'part_5_mean', 'part_5_sum', 'part_5_count',
    'part_6_mean', 'part_6_sum', 'part_6_count',
    'part_7_mean', 'part_7_sum', 'part_7_count',
    'part_cut_mean', 'part_cut_sum', 'part_cut_count',
    'lec_rolling_count',
    'lec_part_1', 'lec_part_2', 'lec_part_3', 'lec_part_4',
    'lec_part_5', 'lec_part_6', 'lec_part_7', 'lec_part_cut',
    'lec_type_of_0', 'lec_type_of_1', 'lec_type_of_2',
    'lec_tags_order_count_0', 'lec_tags_order_count_1', 'lec_tags_order_count_2',
    'lec_whole_tags_order_count',
    'content_id_cut_mean', 'content_id_cut_sum', 'content_id_cut_count',
    'norm_rolling_count_user_answer_0', 'norm_rolling_count_user_answer_1',
    'norm_rolling_count_user_answer_2', 'norm_rolling_count_user_answer_3',
    'cut_norm_rolling_count_user_answer',
    'rolling_mean_for_prior_question_elapsed_time',
    'rolling_mean_for_prior_question_had_explanation',
    'rolling_sum_for_prior_question_isnull',
    'positive_rolling_mean_for_prior_question_elapsed_time',
    'negative_rolling_mean_for_prior_question_elapsed_time',
    'positive_rolling_mean_for_prior_question_had_explanation',
    'negative_rolling_mean_for_prior_question_had_explanation',
    '1_samples_rolling_mean', '2_samples_rolling_mean', '3_samples_rolling_mean',
    '4_samples_rolling_mean', '5_samples_rolling_mean', '10_samples_rolling_mean',
    '20_samples_rolling_mean', '30_samples_rolling_mean', '40_samples_rolling_mean',
    '50_samples_rolling_mean', '100_samples_rolling_mean', '200_samples_rolling_mean',
]

# Features listed under `log_f_names` (52 entries, original order preserved).
_ENCODER_LOG_F_NAMES = [
    'target_count',
    'tags_order_sum_0', 'tags_order_count_0',
    'tags_order_sum_1', 'tags_order_count_1',
    'tags_order_sum_2', 'tags_order_count_2',
    'tags_order_sum_3', 'tags_order_count_3',
    'tags_order_sum_4', 'tags_order_count_4',
    'tags_order_sum_5', 'tags_order_count_5',
    'whole_tags_order_sum', 'whole_tags_order_count',
    'part_1_sum', 'part_1_count',
    'part_2_sum', 'part_2_count',
    'part_3_sum', 'part_3_count',
    'part_4_sum', 'part_4_count',
    'part_5_sum', 'part_5_count',
    'part_6_sum', 'part_6_count',
    'part_7_sum', 'part_7_count',
    'part_cut_sum', 'part_cut_count',
    'lec_rolling_count',
    'lec_part_1', 'lec_part_2', 'lec_part_3', 'lec_part_4',
    'lec_part_5', 'lec_part_6', 'lec_part_7', 'lec_part_cut',
    'lec_type_of_0', 'lec_type_of_1', 'lec_type_of_2',
    'lec_tags_order_count_0', 'lec_tags_order_count_1', 'lec_tags_order_count_2',
    'lec_whole_tags_order_count',
    'content_id_cut_sum', 'content_id_cut_count',
    'rolling_mean_for_prior_question_elapsed_time',
    'positive_rolling_mean_for_prior_question_elapsed_time',
    'negative_rolling_mean_for_prior_question_elapsed_time',
]

# Per-feature scaling constants keyed by feature name (35 entries). The key
# order doubles as `stdize_f_names`. Each value is a 2-tuple; the attribute name
# `stdize_mu_std` suggests (mean, std) -- NOTE(review): confirm with the consumer.
_ENCODER_STDIZE_PARAMS = {
    'target_count': (0, 5.929629),
    'tags_order_sum_0': (0, 2.034838),
    'tags_order_count_0': (0, 2.360298),
    'tags_order_sum_1': (0, 2.352074),
    'tags_order_count_1': (0, 2.646724),
    'tags_order_sum_2': (0, 3.368914),
    'tags_order_count_2': (0, 3.698544),
    'tags_order_sum_3': (0, 3.515896),
    'tags_order_count_3': (0, 3.854377),
    'tags_order_sum_4': (0, 3.355033),
    'tags_order_count_4': (0, 3.719298),
    'tags_order_sum_5': (0, 3.871832),
    'tags_order_count_5': (0, 4.218161),
    'whole_tags_order_sum': (0, 3.005077),
    'whole_tags_order_count': (0, 3.356816),
    'part_1_sum': (0, 2.838132),
    'part_1_count': (0, 3.105394),
    'part_2_sum': (0, 3.631354),
    'part_2_count': (0, 3.947554),
    'part_3_sum': (0, 2.574938),
    'part_3_count': (0, 2.8732),
    'part_4_sum': (0, 2.353698),
    'part_4_count': (0, 2.740394),
    'part_5_sum': (0, 4.350539),
    'part_5_count': (0, 4.86018),
    'part_6_sum': (0, 2.64874),
    'part_6_count': (0, 2.971092),
    'part_7_sum': (0, 1.864806),
    'part_7_count': (0, 2.166899),
    'part_cut_sum': (0, 4.154958),
    'part_cut_count': (0, 4.582286),
    'lec_rolling_count': (0, 2.10689),
    'rolling_mean_for_prior_question_elapsed_time': (0, 10.124079),
    'positive_rolling_mean_for_prior_question_elapsed_time': (0, 10.075675),
    'negative_rolling_mean_for_prior_question_elapsed_time': (0, 10.211306),
}


def _make_encoder_args(**hyper_params):
    """Build one self-contained encoder config object.

    Parameters
    ----------
    **hyper_params
        Experiment-specific settings (e.g. ``n_length``, ``emb_dim``,
        ``device``). Each is stored as an attribute, in the order given.

    Returns
    -------
    OrderedDict
        An ``OrderedDict`` instance used as an attribute bag. Besides the
        hyper-parameters it carries:

        * ``f_names``, ``log_f_names``, ``stdize_f_names`` -- fresh lists;
        * ``stdize_params`` -- fresh dict of name -> 2-tuple;
        * ``log_indices`` -- ascending positions in ``f_names`` of the names
          in ``log_f_names`` (NumPy integer array);
        * ``stdize_indices`` -- list with the position in ``f_names`` of each
          name in ``stdize_f_names``, in that order;
        * ``stdize_mu_std`` -- array of shape ``(2, len(stdize_f_names))``;
          row 0 is the first tuple element, row 1 the second.

    Raises
    ------
    IndexError
        If a name in ``stdize_f_names`` does not occur in ``f_names``.
    """
    cfg = OrderedDict()
    for attr, value in hyper_params.items():
        setattr(cfg, attr, value)

    # Copy the shared literals so every config owns independent containers.
    cfg.f_names = list(_ENCODER_F_NAMES)
    cfg.log_f_names = list(_ENCODER_LOG_F_NAMES)
    cfg.stdize_f_names = list(_ENCODER_STDIZE_PARAMS)
    cfg.stdize_params = dict(_ENCODER_STDIZE_PARAMS)

    # Build the name array once instead of once per looked-up feature.
    f_names_arr = np.array(cfg.f_names)
    # np.isin supersedes np.in1d (deprecated since NumPy 2.0).
    cfg.log_indices = np.where(np.isin(f_names_arr, cfg.log_f_names))[0]
    cfg.stdize_indices = [np.where(f_names_arr == name)[0][0]
                          for name in cfg.stdize_f_names]
    cfg.stdize_mu_std = np.array([cfg.stdize_params[name]
                                  for name in cfg.stdize_f_names]).T
    return cfg


#params184
args_exp184 = _make_encoder_args(
    n_length=400, emb_dim=512, dim_feedforward=1028, nhead=4,
    dropout=0.2, dropout_pe=0, num_features=90, n_conv=30,
    device=torch.device('cuda'), pred_bs=256)

#params218
args_exp218 = _make_encoder_args(
    n_length=400, emb_dim=512, dim_feedforward=1028, nhead=4,
    dropout=0.2, dropout_pe=0, num_features=90, n_conv=30,
    device=torch.device('cuda'), pred_bs=256)

#params224
args_exp224 = _make_encoder_args(
    n_length=400, emb_dim=512, dim_feedforward=1028, nhead=4,
    dropout=0.2, dropout_pe=0, num_features=90, n_conv=30,
    device=torch.device('cuda'), pred_bs=256)

#params219 (carries `lr`, has no `n_conv`)
args_exp219 = _make_encoder_args(
    n_length=400, emb_dim=512, dim_feedforward=1028, nhead=4,
    dropout=0.2, dropout_pe=0, lr=0.2 * 1e-3, num_features=90,
    device=torch.device('cuda'), pred_bs=256)

#params221 (no `n_conv`, no `lr`)
args_exp221 = _make_encoder_args(
    n_length=400, emb_dim=512, dim_feedforward=1028, nhead=4,
    dropout=0.2, dropout_pe=0, num_features=90,
    device=torch.device('cuda'), pred_bs=256)

# Remove the scratch names so the global namespace matches what the original
# per-experiment sections left behind (only the args_exp* objects remain).
del _make_encoder_args, _ENCODER_F_NAMES, _ENCODER_LOG_F_NAMES, _ENCODER_STDIZE_PARAMS
gc.collect();

#params222
# Hyper-parameters and feature configuration for the exp222 encoder.
# The finished configuration is published as `args_exp222`; the scratch names
# used while building it (`args`, `f_names`, ...) are deleted at the end.
args = OrderedDict()
args.n_length = 400
args.emb_dim = 512
# NOTE(review): 1028 rather than 1024 — presumably has to match the trained
# checkpoint, so it is left untouched; confirm before changing.
args.dim_feedforward = 1028
args.nhead = 4
args.dropout = 0.2
args.dropout_pe = 0
args.num_features = 90  # equals len(f_names) below
args.device = torch.device('cuda')
args.pred_bs = 256

# Ordered feature columns; the index arrays computed at the bottom of this
# block refer to positions in this list.
f_names = ['target_full_mean',
 'target_count',
 'tags_order_mean_0',
 'tags_order_sum_0',
 'tags_order_count_0',
 'tags_order_mean_1',
 'tags_order_sum_1',
 'tags_order_count_1',
 'tags_order_mean_2',
 'tags_order_sum_2',
 'tags_order_count_2',
 'tags_order_mean_3',
 'tags_order_sum_3',
 'tags_order_count_3',
 'tags_order_mean_4',
 'tags_order_sum_4',
 'tags_order_count_4',
 'tags_order_mean_5',
 'tags_order_sum_5',
 'tags_order_count_5',
 'whole_tags_order_mean',
 'whole_tags_order_sum',
 'whole_tags_order_count',
 'part_1_mean',
 'part_1_sum',
 'part_1_count',
 'part_2_mean',
 'part_2_sum',
 'part_2_count',
 'part_3_mean',
 'part_3_sum',
 'part_3_count',
 'part_4_mean',
 'part_4_sum',
 'part_4_count',
 'part_5_mean',
 'part_5_sum',
 'part_5_count',
 'part_6_mean',
 'part_6_sum',
 'part_6_count',
 'part_7_mean',
 'part_7_sum',
 'part_7_count',
 'part_cut_mean',
 'part_cut_sum',
 'part_cut_count',
 'lec_rolling_count',
 'lec_part_1',
 'lec_part_2',
 'lec_part_3',
 'lec_part_4',
 'lec_part_5',
 'lec_part_6',
 'lec_part_7',
 'lec_part_cut',
 'lec_type_of_0',
 'lec_type_of_1',
 'lec_type_of_2',
 'lec_tags_order_count_0',
 'lec_tags_order_count_1',
 'lec_tags_order_count_2',
 'lec_whole_tags_order_count',
 'content_id_cut_mean',
 'content_id_cut_sum',
 'content_id_cut_count',
 'norm_rolling_count_user_answer_0',
 'norm_rolling_count_user_answer_1',
 'norm_rolling_count_user_answer_2',
 'norm_rolling_count_user_answer_3',
 'cut_norm_rolling_count_user_answer',
 'rolling_mean_for_prior_question_elapsed_time',
 'rolling_mean_for_prior_question_had_explanation',
 'rolling_sum_for_prior_question_isnull',
 'positive_rolling_mean_for_prior_question_elapsed_time',
 'negative_rolling_mean_for_prior_question_elapsed_time',
 'positive_rolling_mean_for_prior_question_had_explanation',
 'negative_rolling_mean_for_prior_question_had_explanation',
 '1_samples_rolling_mean',
 '2_samples_rolling_mean',
 '3_samples_rolling_mean',
 '4_samples_rolling_mean',
 '5_samples_rolling_mean',
 '10_samples_rolling_mean',
 '20_samples_rolling_mean',
 '30_samples_rolling_mean',
 '40_samples_rolling_mean',
 '50_samples_rolling_mean',
 '100_samples_rolling_mean',
 '200_samples_rolling_mean']
# Subset of f_names selected into `args.log_indices`.
# Presumably these columns are log-transformed by the model code — the
# transform itself is not in this block; TODO confirm.
log_f_names = [
 'target_count',
 'tags_order_sum_0',
 'tags_order_count_0',
 'tags_order_sum_1',
 'tags_order_count_1',
 'tags_order_sum_2',
 'tags_order_count_2',
 'tags_order_sum_3',
 'tags_order_count_3',
 'tags_order_sum_4',
 'tags_order_count_4',
 'tags_order_sum_5',
 'tags_order_count_5',
 'whole_tags_order_sum',
 'whole_tags_order_count',
 'part_1_sum',
 'part_1_count',
 'part_2_sum',
 'part_2_count',
 'part_3_sum',
 'part_3_count',
 'part_4_sum',
 'part_4_count',
 'part_5_sum',
 'part_5_count',
 'part_6_sum',
 'part_6_count',
 'part_7_sum',
 'part_7_count',
 'part_cut_sum',
 'part_cut_count',
 'lec_rolling_count',
 'lec_part_1',
 'lec_part_2',
 'lec_part_3',
 'lec_part_4',
 'lec_part_5',
 'lec_part_6',
 'lec_part_7',
 'lec_part_cut',
 'lec_type_of_0',
 'lec_type_of_1',
 'lec_type_of_2',
 'lec_tags_order_count_0',
 'lec_tags_order_count_1',
 'lec_tags_order_count_2',
 'lec_whole_tags_order_count',
 'content_id_cut_sum',
 'content_id_cut_count',
 'rolling_mean_for_prior_question_elapsed_time',
 'positive_rolling_mean_for_prior_question_elapsed_time',
 'negative_rolling_mean_for_prior_question_elapsed_time']
# Subset of f_names selected into `args.stdize_indices`; every name here must
# also be a key of stdize_params and an entry of f_names.
stdize_f_names = [
 'target_count',
 'tags_order_sum_0',
 'tags_order_count_0',
 'tags_order_sum_1',
 'tags_order_count_1',
 'tags_order_sum_2',
 'tags_order_count_2',
 'tags_order_sum_3',
 'tags_order_count_3',
 'tags_order_sum_4',
 'tags_order_count_4',
 'tags_order_sum_5',
 'tags_order_count_5',
 'whole_tags_order_sum',
 'whole_tags_order_count',
 'part_1_sum',
 'part_1_count',
 'part_2_sum',
 'part_2_count',
 'part_3_sum',
 'part_3_count',
 'part_4_sum',
 'part_4_count',
 'part_5_sum',
 'part_5_count',
 'part_6_sum',
 'part_6_count',
 'part_7_sum',
 'part_7_count',
 'part_cut_sum',
 'part_cut_count',
 'lec_rolling_count',
 'rolling_mean_for_prior_question_elapsed_time',
 'positive_rolling_mean_for_prior_question_elapsed_time',
 'negative_rolling_mean_for_prior_question_elapsed_time']
# Per-feature pair used for standardisation, keyed by feature name.
# Presumably (mu, std) judging by the `stdize_mu_std` name below — TODO confirm.
stdize_params = {'target_count': (0, 5.929629),
 'tags_order_sum_0': (0, 2.034838),
 'tags_order_count_0': (0, 2.360298),
 'tags_order_sum_1': (0, 2.352074),
 'tags_order_count_1': (0, 2.646724),
 'tags_order_sum_2': (0, 3.368914),
 'tags_order_count_2': (0, 3.698544),
 'tags_order_sum_3': (0, 3.515896),
 'tags_order_count_3': (0, 3.854377),
 'tags_order_sum_4': (0, 3.355033),
 'tags_order_count_4': (0, 3.719298),
 'tags_order_sum_5': (0, 3.871832),
 'tags_order_count_5': (0, 4.218161),
 'whole_tags_order_sum': (0, 3.005077),
 'whole_tags_order_count': (0, 3.356816),
 'part_1_sum': (0, 2.838132),
 'part_1_count': (0, 3.105394),
 'part_2_sum': (0, 3.631354),
 'part_2_count': (0, 3.947554),
 'part_3_sum': (0, 2.574938),
 'part_3_count': (0, 2.8732),
 'part_4_sum': (0, 2.353698),
 'part_4_count': (0, 2.740394),
 'part_5_sum': (0, 4.350539),
 'part_5_count': (0, 4.86018),
 'part_6_sum': (0, 2.64874),
 'part_6_count': (0, 2.971092),
 'part_7_sum': (0, 1.864806),
 'part_7_count': (0, 2.166899),
 'part_cut_sum': (0, 4.154958),
 'part_cut_count': (0, 4.582286),
 'lec_rolling_count': (0, 2.10689),
 'rolling_mean_for_prior_question_elapsed_time': (0, 10.124079),
 'positive_rolling_mean_for_prior_question_elapsed_time': (0, 10.075675),
 'negative_rolling_mean_for_prior_question_elapsed_time': (0, 10.211306)}

args.f_names = f_names
args.log_f_names = log_f_names
args.stdize_f_names = stdize_f_names
args.stdize_params = stdize_params

# Positions (within f_names) of the features listed in log_f_names, ascending.
# np.isin replaces the deprecated np.in1d; for 1-D inputs both give the same mask.
args.log_indices = np.where(np.isin(args.f_names, args.log_f_names))[0]
# Position of each standardised feature, in stdize_f_names order.
# The name array is built once instead of once per looked-up feature.
_f_names_arr = np.array(args.f_names)
args.stdize_indices = [np.where(el == _f_names_arr)[0][0] for el in args.stdize_f_names]
# Shape (2, n_stdize): row 0 = first element of every params tuple, row 1 = second.
args.stdize_mu_std = np.array([args.stdize_params[el] for el in args.stdize_f_names]).T
# Drop the scratch names so they cannot leak into later cells.
del f_names, log_f_names, stdize_f_names, stdize_params, _f_names_arr; gc.collect();
args_exp222 = deepcopy(args)
del args;

In [8]:
#load data
def _restore_encoder(encoder_cls, encoder_args, rel_weights_path):
    """Instantiate an encoder, load its saved state dict, move it to
    `encoder_args.device` and put it in eval mode.

    `rel_weights_path` is appended to the global `prefix`.
    """
    encoder = encoder_cls(encoder_args)
    encoder.load_state_dict(torch.load(prefix + rel_weights_path))
    encoder.to(encoder_args.device)
    encoder.eval()
    return encoder


def _np_from(rel_path):
    """np.load a file addressed relative to the global `prefix`."""
    return np.load(prefix + rel_path)


# --- trained encoders (restored in the same order as before) ---
my_encoder_exp162 = _restore_encoder(MyEncoderExp162, args_exp162, '/models/my_encoder_exp162_best.pth')
my_encoder_exp166 = _restore_encoder(MyEncoderExp166, args_exp166, '/models/my_encoder_exp166_best.pth')
my_encoder_exp184 = _restore_encoder(MyEncoderExp184, args_exp184, '/models/my_encoder_exp184_best.pth')
# NOTE(review): the next three encoders are restored from checkpoints whose
# experiment number differs from the class/args they are loaded into
# (218 <- exp248, 224 <- exp233, 222 <- exp249). Possibly retrained weights of
# the same architecture — confirm this is intentional.
my_encoder_exp218 = _restore_encoder(MyEncoderExp218, args_exp218, '/models/my_encoder_exp248_best.pth')
my_encoder_exp224 = _restore_encoder(MyEncoderExp224, args_exp224, '/models/my_encoder_exp233_best.pth')
my_encoder_exp222 = _restore_encoder(MyEncoderExp222, args_exp222, '/models/my_encoder_exp249_best.pth')

# --- question / lecture metadata and preprocessed lookup tables ---
que = pd.read_csv(prefix + '/data/questions.csv')
que = que.astype({'question_id': 'int16', 'bundle_id': 'int16', 'correct_answer': 'int8', 'part': 'int8'})
que_proc = pd.read_pickle(prefix + '/others/que_proc.pkl')
que_proc_2 = pd.read_pickle(prefix + '/others/que_proc_2.pkl')
que_onehot = pd.read_pickle(prefix + '/others/que_onehot.pkl')
que_part_onehot = pd.read_pickle(prefix + '/others/que_part_onehot.pkl')
lec = pd.read_csv(prefix + '/data/lectures.csv')
lec = lec.astype({'lecture_id': 'int16', 'tag': 'int16', 'part': 'int8'})
lec_map = _np_from('/others/lec_map.npy')
lec_proc = pd.read_pickle(prefix + '/others/lec_proc.pkl')
que_proc_int = _np_from('/others/que_proc_int.npy')

n_sample = 400
# File-name tail shared by every per-user history array below.
_history_tail = '_length_' + str(n_sample) + '_train.npy'
# user_map_train.npy is read through load_pickle, not np.load, despite its extension.
user_map_ref = load_pickle(prefix + '/others/user_map_train.npy')
n_users_ref = _np_from('/others/n_users_train.npy')
#nn
r_ref_g1 = _np_from('/references/nn_content_id_history' + _history_tail)
r_ref_g2 = _np_from('/references/nn_normed_log_timestamp_history' + _history_tail)
r_ref_g3 = _np_from('/references/nn_correctness_history' + _history_tail)
r_ref_g4 = _np_from('/references/nn_question_had_explanation_history' + _history_tail)
r_ref_g5 = _np_from('/references/nn_normed_elapsed_history' + _history_tail)
r_ref_g6 = _np_from('/references/nn_normed_modified_timedelta_history' + _history_tail)
r_ref_g7 = _np_from('/references/nn_user_answer_history' + _history_tail)
r_ref_g8 = _np_from('/references/nn_task_container_id_diff_history' + _history_tail)
r_ref_g9 = _np_from('/references/nn_content_type_id_diff_history' + _history_tail)
r_ref_g10 = _np_from('/references/nn_task_container_id_history' + _history_tail)
r_ref_delta = _np_from('/references/timestamp_diff_train.npy')
r_ref_delta_2 = _np_from('/references/task_container_id_diff_train.npy')
r_ref_delta_3 = _np_from('/references/content_type_id_diff_train.npy')
#cpu
r_ref_1 = _np_from('/references/rolling_mean_sum_count_train.npy')
r_ref_2 = _np_from('/references/rolling_mean_sum_count_for_6_tags_and_whole_tag_train.npy')
r_ref_3 = _np_from('/references/rolling_mean_sum_count_for_part_train.npy')
r_ref_4 = _np_from('/references/lec_rolling_count_train.npy')
r_ref_5 = _np_from('/references/lec_part_rolling_count_train.npy')
r_ref_6 = _np_from('/references/lec_type_of_rolling_count_train.npy')
r_ref_7 = _np_from('/references/lec_tags_rolling_count_train.npy')
# Like user_map_ref, this reference is read through load_pickle rather than np.load.
r_ref_8 = load_pickle(prefix + '/references/rolling_mean_sum_count_for_content_id_darkness_train.npy')
r_ref_12 = _np_from('/references/norm_rolling_count_and_cut_for_user_answer_train.npy')
r_ref_13 = _np_from('/references/rolling_mean_for_prior_question_elapsed_time_train.npy')
r_ref_14 = _np_from('/references/rolling_mean_for_prior_question_had_explanation_train.npy')
r_ref_15 = _np_from('/references/rolling_sum_for_prior_question_isnull_train.npy')
r_ref_16 = _np_from('/references/positive_rolling_mean_for_prior_question_elapsed_time_train.npy')
r_ref_17 = _np_from('/references/negative_rolling_mean_for_prior_question_elapsed_time_train.npy')
r_ref_18 = _np_from('/references/positive_rolling_mean_for_prior_question_had_explanation_train.npy')
r_ref_19 = _np_from('/references/negative_rolling_mean_for_prior_question_had_explanation_train.npy')
r_ref_20 = _np_from('/references/n_samples_rolling_mean_train.npy')

In [9]:
# run
# Streaming inference loop. For every batch `tes` yielded by `iter_test`:
#   1. register the batch's users in the user map,
#   2. (from the second batch on) fold the previous batch into the reference
#      state via the update_* / nn_update_* helpers,
#   3. apply the "early" updates that only need the current batch,
#   4. compute the online features for the current batch,
#   5. score with the six encoders, blend, and submit through `env.predict`.
# The order of these stages matters: each later stage reads references that an
# earlier stage has just rewritten. `iter_test`, `env` and all helper functions
# are defined outside this cell.
i = 0
for (tes, _) in tqdm(iter_test):
    user_map_ref, n_users_ref = update_user_map(tes, user_map_ref, n_users_ref)
    if i != 0:
        # update
        # Skipped on the first batch because `prev_tes` (and the f_tes_delta*
        # frames used below) do not exist until one iteration has completed.
        r_ref_1 = update_reference_get_rolling_mean_sum_count(r_ref_1, prev_tes, tes, user_map_ref)
        r_ref_2 = update_reference_get_rolling_mean_sum_count_for_6_tags_and_whole_tag(r_ref_2, prev_tes, tes, que_onehot, user_map_ref)
        r_ref_3 = update_reference_get_rolling_mean_sum_count_for_part(r_ref_3, prev_tes, tes, que_part_onehot, user_map_ref)
        r_ref_4 = update_reference_get_lec_rolling_count(r_ref_4, prev_tes, tes, user_map_ref)
        r_ref_5 = update_reference_get_lec_part_rolling_count(r_ref_5, prev_tes, tes, user_map_ref, lec, lec_map)
        r_ref_6 = update_reference_get_lec_type_of_rolling_count(r_ref_6, prev_tes, tes, user_map_ref, lec_proc, lec_map)
        r_ref_7 = update_reference_get_lec_tags_rolling_count(r_ref_7, prev_tes, tes, user_map_ref, lec, lec_map)
        r_ref_8 = update_reference_get_rolling_mean_sum_count_for_content_id_darkness(r_ref_8, prev_tes, tes, user_map_ref)
        
        # r_ref_13 and r_ref_14 have no update_* call here; they are only
        # touched by the early-update stage below.
        r_ref_12 = update_reference_get_norm_rolling_count_and_cut_for_user_answer(r_ref_12, prev_tes, tes, user_map_ref)
        r_ref_15 = update_reference_get_rolling_sum_for_prior_question_isnull(r_ref_15, prev_tes, tes, user_map_ref)
        r_ref_16 = update_reference_get_positive_rolling_mean_for_prior_question_elapsed_time(r_ref_16, prev_tes, tes, user_map_ref)
        r_ref_17 = update_reference_get_negative_rolling_mean_for_prior_question_elapsed_time(r_ref_17, prev_tes, tes, user_map_ref)
        r_ref_18 = update_reference_get_positive_rolling_mean_for_prior_question_had_explanation(r_ref_18, prev_tes, tes, user_map_ref)
        r_ref_19 = update_reference_get_negative_rolling_mean_for_prior_question_had_explanation(r_ref_19, prev_tes, tes, user_map_ref)
        r_ref_20 = update_reference_get_n_samples_rolling_mean(r_ref_20, prev_tes, tes, user_map_ref)

        r_ref_delta = update_reference_get_timestamp_diff(r_ref_delta, prev_tes, tes, user_map_ref)
        r_ref_delta_2 = update_reference_get_task_container_id_diff(r_ref_delta_2, prev_tes, tes, user_map_ref)
        r_ref_delta_3 = update_reference_get_content_type_id_diff(r_ref_delta_3, prev_tes, tes, user_map_ref)
        # NOTE: f_tes_delta, f_tes_delta_2 and f_tes_delta_3 passed below are
        # the frames computed in the PREVIOUS iteration's online stage; they
        # are only reassigned further down in this iteration.
        r_ref_g1 = nn_update_reference_get_content_id_history(r_ref_g1, prev_tes, tes, user_map_ref)
        r_ref_g2 = nn_update_reference_get_normed_log_timestamp_history(r_ref_g2, prev_tes, tes, user_map_ref)
        r_ref_g3 = nn_update_reference_get_correctness_history(r_ref_g3, prev_tes, tes, user_map_ref)
        r_ref_g4 = nn_update_reference_get_question_had_explanation_history(r_ref_g4, prev_tes, tes, user_map_ref)
        r_ref_g5 = nn_update_reference_get_normed_elapsed_history(r_ref_g5, prev_tes, tes, user_map_ref)
        r_ref_g6 = nn_update_reference_get_normed_modified_timedelta_history(r_ref_g6, prev_tes, tes, f_tes_delta, user_map_ref)
        r_ref_g7 = nn_update_reference_get_user_answer_history(r_ref_g7, prev_tes, tes, user_map_ref)
        r_ref_g8 = nn_update_reference_get_task_container_id_diff_history(r_ref_g8, prev_tes, tes, f_tes_delta_2, user_map_ref)
        r_ref_g9 = nn_update_reference_get_content_type_id_diff_history(r_ref_g9, prev_tes, tes, f_tes_delta_3, user_map_ref)
        r_ref_g10 = nn_update_reference_get_task_container_id_history(r_ref_g10, prev_tes, tes, user_map_ref)
        
    # early update
    # These helpers take only the current batch (no prev_tes) and run on every
    # iteration, including the first. Presumably they consume the
    # prior_question_* columns of `tes` — TODO confirm against the helpers.
    r_ref_g4 = nn_early_update_reference_get_question_had_explanation_history(r_ref_g4, tes, user_map_ref)
    r_ref_g5 = nn_early_update_reference_get_normed_elapsed_history(r_ref_g5, tes, user_map_ref)
    r_ref_13 = early_update_reference_get_rolling_mean_for_prior_question_elapsed_time(r_ref_13, tes, user_map_ref)
    r_ref_14 = early_update_reference_get_rolling_mean_for_prior_question_had_explanation(r_ref_14, tes, user_map_ref)
    r_ref_16 = early_update_reference_get_positive_rolling_mean_for_prior_question_elapsed_time(r_ref_16, tes, user_map_ref)
    r_ref_17 = early_update_reference_get_negative_rolling_mean_for_prior_question_elapsed_time(r_ref_17, tes, user_map_ref)
    r_ref_18 = early_update_reference_get_positive_rolling_mean_for_prior_question_had_explanation(r_ref_18, tes, user_map_ref)
    r_ref_19 = early_update_reference_get_negative_rolling_mean_for_prior_question_had_explanation(r_ref_19, tes, user_map_ref)
    
    # online function
    # Build the per-row feature frames for the current batch from the
    # references refreshed above.
    f_tes_1 = online_get_rolling_mean_sum_count(r_ref_1, tes, user_map_ref)
    f_tes_2 = online_get_rolling_mean_sum_count_for_6_tags_and_whole_tag(r_ref_2, tes, que_proc, user_map_ref)
    f_tes_3 = online_get_rolling_mean_sum_count_for_part(r_ref_3, tes, que, user_map_ref)
    f_tes_4 = online_get_lec_rolling_count(r_ref_4, tes, user_map_ref)
    f_tes_5 = online_get_lec_part_rolling_count(r_ref_5, tes, user_map_ref, que)
    f_tes_6 = online_get_lec_type_of_rolling_count(r_ref_6, tes, user_map_ref)
    f_tes_7 = online_get_lec_tags_rolling_count(r_ref_7, tes, user_map_ref, que_proc)
    f_tes_8 = online_get_rolling_mean_sum_count_for_content_id_darkness(r_ref_8, tes, user_map_ref)
    
    f_tes_12 = online_get_norm_rolling_count_and_cut_for_user_answer(r_ref_12, tes, que, user_map_ref)
    f_tes_13 = online_get_rolling_mean_for_prior_question_elapsed_time(r_ref_13, tes, user_map_ref)
    f_tes_14 = online_get_rolling_mean_for_prior_question_had_explanation(r_ref_14, tes, user_map_ref)
    f_tes_15 = online_get_rolling_sum_for_prior_question_isnull(r_ref_15, tes, user_map_ref)
    f_tes_16 = online_get_positive_rolling_mean_for_prior_question_elapsed_time(r_ref_16, tes, user_map_ref)
    f_tes_17 = online_get_negative_rolling_mean_for_prior_question_elapsed_time(r_ref_17, tes, user_map_ref)
    f_tes_18 = online_get_positive_rolling_mean_for_prior_question_had_explanation(r_ref_18, tes, user_map_ref)
    f_tes_19 = online_get_negative_rolling_mean_for_prior_question_had_explanation(r_ref_19, tes, user_map_ref)
    f_tes_20 = online_get_n_samples_rolling_mean(r_ref_20, tes, user_map_ref)
    
    # The three delta frames feed the g6/g8/g9 history features just below and
    # are also reused by the update stage of the NEXT iteration.
    f_tes_delta = online_get_modified_timedelta(r_ref_delta, tes, user_map_ref)
    f_tes_delta_2 = online_get_task_container_id_diff(r_ref_delta_2, tes, user_map_ref)
    f_tes_delta_3 = online_get_content_type_id_diff(r_ref_delta_3, tes, user_map_ref)
    
    f_tes_g1 = nn_online_get_content_id_history(r_ref_g1, tes, user_map_ref)
    f_tes_g2 = nn_online_get_normed_log_timestamp_history(r_ref_g2, tes, user_map_ref)
    f_tes_g3 = nn_online_get_correctness_history(r_ref_g3, tes, user_map_ref)
    f_tes_g4 = nn_online_get_question_had_explanation_history(r_ref_g4, tes, user_map_ref)
    f_tes_g5 = nn_online_get_normed_elapsed_history(r_ref_g5, tes, user_map_ref)
    f_tes_g6 = nn_online_get_normed_modified_timedelta_history(r_ref_g6, tes, f_tes_delta, user_map_ref)
    f_tes_g7 = nn_online_get_user_answer_history(r_ref_g7, tes, user_map_ref)
    f_tes_g8 = nn_online_get_task_container_id_diff_history(r_ref_g8, tes, f_tes_delta_2, user_map_ref)
    f_tes_g9 = nn_online_get_content_type_id_diff_history(r_ref_g9, tes, f_tes_delta_3, user_map_ref)
    f_tes_g10 = nn_online_get_task_container_id_history(r_ref_g10, tes, user_map_ref)
    

    # make a prediction 
    # Tabular features: concatenate all CPU feature frames column-wise and pick
    # the columns listed in args_exp166.f_names. The same X_tes is handed to
    # every encoder, not only exp166.
    concated = pd.concat([f_tes_1, f_tes_2, f_tes_3, f_tes_4, f_tes_5, f_tes_6, f_tes_7, 
                         f_tes_8, f_tes_12, f_tes_13, f_tes_14, f_tes_15, f_tes_16, f_tes_17, f_tes_18, 
                         f_tes_19, f_tes_20], axis = 1)
    X_tes = concated[args_exp166.f_names].values.astype(np.float32)
    
    inputs = {'content_id' : f_tes_g1.values, 'normed_timedelta' : f_tes_g6.values, 'normed_log_timestamp' : f_tes_g2.values,
         'explanation': f_tes_g4.values, 'correctness': f_tes_g3.values, 'normed_elapsed': f_tes_g5.values, 'user_answer' : f_tes_g7.values,
             'task_container_id_diff': f_tes_g8.values, 'content_type_id_diff': f_tes_g9.values, 'task_container_id': f_tes_g10.values, 
              'features': X_tes}
    # Same inputs truncated to the last 201 columns of every array; only exp162
    # consumes this variant. The slice is applied to 'features' as well, which
    # is a no-op whenever X_tes has at most 201 columns.
    inputs_200 = {key: inputs[key][:, -201:] for key in inputs.keys()}
    
    
    # Each model is skipped for a batch with zero feature rows; an empty
    # float32 array stands in for its scores so the blend below still works.
    score_exp162 = my_encoder_exp162.predict_on_batch(inputs_200, args_exp162, args_exp162.pred_bs, que, que_proc_int) \
    if X_tes.shape[0] > 0 else np.array([], dtype = np.float32)
    score_exp166 = my_encoder_exp166.predict_on_batch(inputs, args_exp166, args_exp166.pred_bs, que, que_proc_int) \
    if X_tes.shape[0] > 0 else np.array([], dtype = np.float32)
    score_exp184 = my_encoder_exp184.predict_on_batch(inputs, args_exp184, args_exp184.pred_bs, que, que_proc_int) \
    if X_tes.shape[0] > 0 else np.array([], dtype = np.float32)
    score_exp218 = my_encoder_exp218.predict_on_batch(inputs, args_exp218, args_exp218.pred_bs, que, que_proc_int) \
    if X_tes.shape[0] > 0 else np.array([], dtype = np.float32)
    score_exp224 = my_encoder_exp224.predict_on_batch(inputs, args_exp224, args_exp224.pred_bs, que, que_proc_int) \
    if X_tes.shape[0] > 0 else np.array([], dtype = np.float32)
    score_exp222 = my_encoder_exp222.predict_on_batch(inputs, args_exp222, args_exp222.pred_bs, que, que_proc_int) \
    if X_tes.shape[0] > 0 else np.array([], dtype = np.float32)
    
    # Weighted blend. Weights pair with `preds` by position, so the list order
    # (162, 166, 184, 218, 222, 224) is what matters — note it differs from the
    # order in which the scores were computed above. The last three weights get
    # +0.1 before all six are normalised to sum to 1.
    preds = [score_exp162, score_exp166, score_exp184, score_exp218, score_exp222, score_exp224]
    weights = np.array([0.42, 0.25, 0.49, 0.85, 0.84, 0.82])
    weights[[3, 4, 5]] += 0.1
    weights = weights/weights.sum()
    score = 0
    for pred, weight in zip(preds, weights):
        score += pred * weight
    
    # Only rows with content_type_id == 0 receive a prediction; the blended
    # scores are assigned to them positionally.
    spc_tes = tes[tes['content_type_id'] == 0].copy()
    spc_tes['answered_correctly'] = score.astype(np.float64)
    env.predict(spc_tes[['row_id', 'answered_correctly']])
    
    # save previous test
    # Kept so the next iteration's update stage can fold this batch into the references.
    prev_tes = tes.copy()
    i += 1

HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))


