In [1]:
import pandas as pd
import numpy as np
from tqdm.notebook import tqdm

import gc
import pickle
from collections import defaultdict, OrderedDict
from sklearn.metrics import roc_auc_score
from bitarray import bitarray
import lightgbm as lgb

import psutil
import joblib
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

# LGBM with user-content (uc) timestamp cache - score 0.794

In [2]:
class LRUCache(OrderedDict):
    """Bounded key -> value map that evicts the least-recently-used entry.

    Entries live in ``self.cache``, an ``OrderedDict`` ordered from least to
    most recently used. The ``OrderedDict`` base class is never used for
    storage. It is presumably kept so the pickled cache loaded from disk
    still unpickles.

    Args:
        capacity: maximum number of entries kept before the oldest is evicted.
    """

    def __init__(self, capacity=8000000):
        self.capacity = capacity
        self.cache = OrderedDict()

    def get(self, key):
        """Return the value stored for ``key`` and mark it most recently used.

        Returns -1 when ``key`` is not cached.
        """
        if key not in self.cache:
            return -1
        self.cache.move_to_end(key)
        return self.cache[key]

    def set(self, key, value):
        """Store ``value`` under ``key`` and mark it most recently used.

        An existing key takes the NEW value. Before this fix, the argument
        was overwritten by the popped old value, so updates were lost.
        NOTE(review): if the training features were built with the old
        behaviour, ``uc_question_timestamp_diff`` will now differ from
        training. Confirm before deploying.
        """
        if key in self.cache:
            self.cache[key] = value
            self.cache.move_to_end(key)
            return
        if self.capacity <= 0:
            # A zero-capacity cache stores nothing. Popping from an empty
            # dict would raise KeyError.
            return
        while len(self.cache) >= self.capacity:
            self.cache.popitem(last=False)  # drop least recently used
        self.cache[key] = value

In [3]:
# loading prepared data
# All precomputed state comes from one uploaded dataset directory.
# NOTE: pickle.load can execute arbitrary code. Only load trusted files.
PREPARED_DATA_DIR = '../input/riiiduploadcache800w'


def load_prepared_pickle(name):
    """Unpickle ``<PREPARED_DATA_DIR>/<name>.pickle`` and return the object."""
    with open(f'{PREPARED_DATA_DIR}/{name}.pickle', 'rb') as handle:
        return pickle.load(handle)


# user/content timestamp cache (unpickling needs the LRUCache class defined above)
lru_cache = load_prepared_pickle('uc_cache')

u_question_seen_dict = load_prepared_pickle('u_question_seen_dict')
u_answered_correctly_count_dict = load_prepared_pickle('u_answered_correctly_count_dict')
u_answered_count_dict = load_prepared_pickle('u_answered_count_dict')
u_question_part_correctly_count_dict = load_prepared_pickle('u_question_part_correctly_count_dict')
u_question_part_count_dict = load_prepared_pickle('u_question_part_count_dict')
u_question_tag1_correctly_count_dict = load_prepared_pickle('u_question_tag1_correctly_count_dict')
u_question_tag1_count_dict = load_prepared_pickle('u_question_tag1_count_dict')
u_prior_question_correctly_timestamp_dict = load_prepared_pickle('u_prior_question_correctly_timestamp_dict')
u_prior2_question_correctly_timestamp_dict = load_prepared_pickle('u_prior2_question_correctly_timestamp_dict')
u_prior3_question_correctly_timestamp_dict = load_prepared_pickle('u_prior3_question_correctly_timestamp_dict')
u_prior_question_timestamp_dict = load_prepared_pickle('u_prior_question_timestamp_dict')
u_prior2_question_timestamp_dict = load_prepared_pickle('u_prior2_question_timestamp_dict')
u_prior3_question_timestamp_dict = load_prepared_pickle('u_prior3_question_timestamp_dict')
u_prior4_question_timestamp_dict = load_prepared_pickle('u_prior4_question_timestamp_dict')
u_prior_lecture_timestamp_dict = load_prepared_pickle('u_prior_lecture_timestamp_dict')
u_prior2_lecture_timestamp_dict = load_prepared_pickle('u_prior2_lecture_timestamp_dict')
u_task_container_id_dict = load_prepared_pickle('u_task_container_id_dict')
u_prior_question_explanation_count_dict = load_prepared_pickle('u_prior_question_explanation_count_dict')
u_prior_question_explanation_correctly_count_dict = load_prepared_pickle('u_prior_question_explanation_correctly_count_dict')
u_question_listening_correctly_count_dict = load_prepared_pickle('u_question_listening_correctly_count_dict')
u_question_reading_correctly_count_dict = load_prepared_pickle('u_question_reading_correctly_count_dict')
u_question_listening_count_dict = load_prepared_pickle('u_question_listening_count_dict')
u_question_reading_count_dict = load_prepared_pickle('u_question_reading_count_dict')
u_question_incorrect_timestamp_dict = load_prepared_pickle('u_question_incorrect_timestamp_dict')
u_question_incorrect_timestamp2_dict = load_prepared_pickle('u_question_incorrect_timestamp2_dict')
u_question_incorrect_timestamp3_dict = load_prepared_pickle('u_question_incorrect_timestamp3_dict')
u_question_part_correct_timestamp_dict = load_prepared_pickle('u_question_part_correct_timestamp_dict')
u_question_part_incorrect_timestamp_dict = load_prepared_pickle('u_question_part_incorrect_timestamp_dict')

questions_df = pd.read_pickle(f'{PREPARED_DATA_DIR}/questions_df.pickle')
content_df = pd.read_pickle(f'{PREPARED_DATA_DIR}/content_df.pickle')

# loading trained model

model = lgb.Booster(model_file=f'{PREPARED_DATA_DIR}/lgb_model.txt')

In [4]:
def add_user_feats_without_update(df, 
                                  u_answered_correctly_count_dict, 
                                  u_answered_count_dict, 
                                  u_question_part_correctly_count_dict,
                                  u_question_part_count_dict,
                                  u_question_tag1_correctly_count_dict,
                                  u_question_tag1_count_dict,
                                  u_prior_question_correctly_timestamp_dict,
                                  u_prior2_question_correctly_timestamp_dict,
                                  u_prior3_question_correctly_timestamp_dict,
                                  u_prior_question_timestamp_dict,
                                  u_prior2_question_timestamp_dict,
                                  u_prior3_question_timestamp_dict,
                                  u_prior4_question_timestamp_dict,
                                  u_prior_lecture_timestamp_dict,
                                  u_prior2_lecture_timestamp_dict,
                                  u_task_container_id_dict,
                                  u_prior_question_explanation_count_dict,
                                  u_prior_question_explanation_correctly_count_dict,
                                  u_question_listening_correctly_count_dict,
                                  u_question_reading_correctly_count_dict,
                                  u_question_listening_count_dict,
                                  u_question_reading_count_dict,
                                  u_question_incorrect_timestamp_dict,
                                  u_question_incorrect_timestamp2_dict,
                                  u_question_incorrect_timestamp3_dict,
                                  u_question_seen_dict,
                                  u_question_part_correct_timestamp_dict,
                                  u_question_part_incorrect_timestamp_dict,
                                 ):
    """Append per-user history features to ``df``. State is read, not advanced.

    For each row, the current per-user state is looked up and turned into
    features: running counts, "time since" differences against stored
    timestamps, and ratios derived from them. The counters and timestamps are
    advanced separately by ``update_user_feats``.

    One side effect is intentional. On a user's first question row, an
    all-zero bitarray is inserted into ``u_question_seen_dict``, because
    ``update_user_feats`` later sets bits in it.

    Also reads the module-level ``lru_cache``, which maps a user/content key
    to a stored timestamp, or -1 when the pair is absent.

    Args:
        df: rows to featurize. Must contain ``user_id``, ``content_type_id``,
            ``part``, ``tag1``, ``timestamp``, ``task_container_id``,
            ``prior_question_had_explanation``, ``LorR``, ``content_id`` and
            ``prior_question_elapsed_time``.
        u_*_dict: per-user state. Keys are ``user_id``, except the part/tag
            dictionaries, which use ``"<user_id>_<part>"`` and
            ``"<user_id>_<tag1>"``. Lookups use ``d[key]`` directly, so these
            are presumably ``defaultdict``s. TODO confirm.

    Returns:
        A new DataFrame: ``df`` with the user feature columns appended,
        row-aligned on ``df.index``.
    """
    n_rows = len(df)
    # NOTE(review): timestamp differences are stored as int32 (max ~2.1e9 ms,
    # about 24.8 days). Larger gaps would wrap. Kept as-is, presumably to
    # match training. Confirm before widening.
    uacc = np.zeros(n_rows, dtype=np.int32)
    uac = np.zeros(n_rows, dtype=np.int32)
    uqpcc = np.zeros(n_rows, dtype=np.int32)
    uqpc = np.zeros(n_rows, dtype=np.int32)
    uqt1cc = np.zeros(n_rows, dtype=np.int32)
    uqt1c = np.zeros(n_rows, dtype=np.int32)
    upqct = np.zeros(n_rows, dtype=np.int32)
    up2qct = np.zeros(n_rows, dtype=np.int32)
    up3qct = np.zeros(n_rows, dtype=np.int32)
    upqt = np.zeros(n_rows, dtype=np.int32)
    up2qt = np.zeros(n_rows, dtype=np.int32)
    up3qt = np.zeros(n_rows, dtype=np.int32)
    up4qt = np.zeros(n_rows, dtype=np.int32)
    uplt = np.zeros(n_rows, dtype=np.int32)
    up2lt = np.zeros(n_rows, dtype=np.int32)
    utci = np.zeros(n_rows, dtype=np.int32)
    upqec = np.zeros(n_rows, dtype=np.int32)
    upqecc = np.zeros(n_rows, dtype=np.int32)
    uqlcc = np.zeros(n_rows, dtype=np.int32)
    uqrcc = np.zeros(n_rows, dtype=np.int32)
    uqlc = np.zeros(n_rows, dtype=np.int32)
    uqrc = np.zeros(n_rows, dtype=np.int32)
    uqict = np.zeros(n_rows, dtype=np.int32)
    uqict2 = np.zeros(n_rows, dtype=np.int32)
    uqict3 = np.zeros(n_rows, dtype=np.int32)
    uqs = np.zeros(n_rows, dtype=np.int32)
    uqn = np.zeros(n_rows, dtype=np.int32)
    uqpct = np.zeros(n_rows, dtype=np.int32)
    uqpict = np.zeros(n_rows, dtype=np.int32)
    uct = np.zeros(n_rows, dtype=np.int32)
    
    for cnt, row in enumerate(df[['user_id',
                                  'content_type_id', 
                                  'part',
                                  'tag1',
                                  'timestamp',
                                  'task_container_id',
                                  'prior_question_had_explanation',
                                  'LorR',
                                  'content_id']].values):
        user_id = row[0]
        content_type = row[1]
        question_part = row[2]
        question_tag1 = row[3]
        timestamp = row[4]
        task_container_id = row[5]
        # row[6] (prior_question_had_explanation) and row[7] (LorR) are not
        # needed to read the state.
        content_id = row[8]
        
        uid_part = str(user_id) + '_' + str(question_part)
        uid_tag1 = str(user_id) + '_' + str(question_tag1)
        # Composite user/content key. 10e10 == 1e11, so the key is a float.
        # It must be built the same way as in update_user_feats.
        uid_cid = int(user_id) + int(content_id)*10e10
        
        uct[cnt] = timestamp - lru_cache.get(uid_cid)
        uacc[cnt] = u_answered_correctly_count_dict[user_id]
        uac[cnt] = u_answered_count_dict[user_id]
        uqpcc[cnt] = u_question_part_correctly_count_dict[uid_part]
        uqpc[cnt] = u_question_part_count_dict[uid_part]
        uqt1cc[cnt] = u_question_tag1_correctly_count_dict[uid_tag1]
        uqt1c[cnt] = u_question_tag1_count_dict[uid_tag1]
        upqct[cnt] = timestamp - u_prior_question_correctly_timestamp_dict[user_id]
        up2qct[cnt] = timestamp - u_prior2_question_correctly_timestamp_dict[user_id]
        up3qct[cnt] = timestamp - u_prior3_question_correctly_timestamp_dict[user_id]
        upqt[cnt] = timestamp - u_prior_question_timestamp_dict[user_id]
        up2qt[cnt] = timestamp - u_prior2_question_timestamp_dict[user_id]
        up3qt[cnt] = timestamp - u_prior3_question_timestamp_dict[user_id]
        up4qt[cnt] = timestamp - u_prior4_question_timestamp_dict[user_id]
        uplt[cnt] = timestamp - u_prior_lecture_timestamp_dict[user_id]
        up2lt[cnt] = timestamp - u_prior2_lecture_timestamp_dict[user_id]
        utci[cnt] = task_container_id - u_task_container_id_dict[user_id]
        upqec[cnt] = u_prior_question_explanation_count_dict[user_id]
        upqecc[cnt] = u_prior_question_explanation_correctly_count_dict[user_id]
        uqlcc[cnt] = u_question_listening_correctly_count_dict[user_id]
        uqrcc[cnt] = u_question_reading_correctly_count_dict[user_id]
        uqlc[cnt] = u_question_listening_count_dict[user_id]
        uqrc[cnt] = u_question_reading_count_dict[user_id]
        uqict[cnt] = timestamp - u_question_incorrect_timestamp_dict[user_id]
        uqict2[cnt] = timestamp - u_question_incorrect_timestamp2_dict[user_id]
        uqict3[cnt] = timestamp - u_question_incorrect_timestamp3_dict[user_id]
        uqpct[cnt] = timestamp - u_question_part_correct_timestamp_dict[uid_part]
        uqpict[cnt] = timestamp - u_question_part_incorrect_timestamp_dict[uid_part]
        if content_type == 0:
            # Create the seen-questions bitmap on first sight. update_user_feats
            # writes into it once the answer is known.
            if user_id not in u_question_seen_dict:
                u_question_seen_dict[user_id] = bitarray('0'*14000, endian='little')
            uqs[cnt] = u_question_seen_dict[user_id][content_id]
            uqn[cnt] = u_question_seen_dict[user_id].count()
        
    # Build on df's own index so the axis=1 concat below lines up row by row
    # even when df does not carry a default RangeIndex.
    user_feats_df = pd.DataFrame({'u_answered_correctly_count': uacc, 
                                  'u_answered_count': uac,
                                  'u_question_part_correctly_count': uqpcc,
                                  'u_question_part_count': uqpc,
                                  'u_question_tag1_correctly_count': uqt1cc,
                                  'u_question_tag1_count': uqt1c,
                                  'u_prior_question_correctly_timestamp_diff': upqct,
                                  'u_prior_question_correctly_timestamp_diff2': up2qct,
                                  'u_prior_question_correctly_timestamp_diff3': up3qct,
                                  'u_prior_question_timestamp_diff': upqt,
                                  'u_prior_question_timestamp_diff2': up2qt,
                                  'u_prior_question_timestamp_diff3': up3qt,
                                  'u_prior_question_timestamp_diff4': up4qt,
                                  'u_prior_lecture_timestamp_diff': uplt,
                                  'u_prior2_lecture_timestamp_diff': up2lt,
                                  'u_task_container_id_diff': utci,
                                  'u_prior_question_explanation_count': upqec,
                                  'u_prior_question_explanation_correctly_count': upqecc,
                                  'u_question_listening_correctly_count': uqlcc,
                                  'u_question_reading_correctly_count': uqrcc,
                                  'u_question_listening_count': uqlc,
                                  'u_question_reading_count': uqrc,
                                  'u_question_incorrect_timestamp_diff': uqict,
                                  'u_question_incorrect_timestamp_diff2': uqict2,
                                  'u_question_incorrect_timestamp_diff3': uqict3,
                                  'u_question_seen': uqs,
                                  'u_question_nunique': uqn,
                                  'u_question_part_correct_timestamp_diff': uqpct,
                                  'u_question_part_incorrect_timestamp_diff': uqpict,
                                  'uc_question_timestamp_diff': uct,
                                  },
                                 index=df.index)
    
    # Ratios are int/int divisions, so a zero count yields NaN (or inf), not an error.
    user_feats_df['u_answered_correctly_avg'] = user_feats_df['u_answered_correctly_count'] / user_feats_df['u_answered_count']
    user_feats_df['u_question_part_correctly_avg'] = user_feats_df['u_question_part_correctly_count'] / user_feats_df['u_question_part_count']
    user_feats_df['u_question_tag1_correctly_avg'] = user_feats_df['u_question_tag1_correctly_count'] / user_feats_df['u_question_tag1_count']
    user_feats_df['u_prior_question_timestamp_diff_2_1'] = user_feats_df['u_prior_question_timestamp_diff2'] - user_feats_df['u_prior_question_timestamp_diff']
    user_feats_df['u_prior_question_timestamp_diff_3_2'] = user_feats_df['u_prior_question_timestamp_diff3'] - user_feats_df['u_prior_question_timestamp_diff2']
    user_feats_df['u_prior_question_timestamp_diff_4_3'] = user_feats_df['u_prior_question_timestamp_diff4'] - user_feats_df['u_prior_question_timestamp_diff3']
    
    df = pd.concat([df, user_feats_df], axis=1)
    df['u_question_timestamp_diff']  = df['u_prior_question_timestamp_diff_2_1'] - df['prior_question_elapsed_time']
    df['u_question_timestamp_diff2']  = df['u_prior_question_timestamp_diff_3_2'] - df['prior_question_elapsed_time']
    df['u_question_timestamp_diff3']  = df['u_prior_question_timestamp_diff_4_3'] - df['prior_question_elapsed_time']
    df['u_listening_reading_ratio'] = df['u_question_listening_correctly_count'] / (1 + df['u_question_reading_correctly_count'])
    df['u_listening_correctly_avg'] = df['u_question_listening_correctly_count'] / (1 + df['u_question_listening_count'])
    df['u_reading_correctly_avg'] = df['u_question_reading_correctly_count'] / (1 + df['u_question_reading_count'])
    
    return df


def update_user_feats(df, 
                      u_answered_correctly_count_dict,
                      u_answered_count_dict, 
                      u_question_part_correctly_count_dict,
                      u_question_part_count_dict,
                      u_question_tag1_correctly_count_dict,
                      u_question_tag1_count_dict,
                      u_prior_question_correctly_timestamp_dict,
                      u_prior2_question_correctly_timestamp_dict,
                      u_prior3_question_correctly_timestamp_dict,
                      u_prior_question_timestamp_dict,
                      u_prior2_question_timestamp_dict,
                      u_prior3_question_timestamp_dict,
                      u_prior4_question_timestamp_dict,
                      u_prior_lecture_timestamp_dict, 
                      u_prior2_lecture_timestamp_dict,
                      u_task_container_id_dict,
                      u_prior_question_explanation_count_dict,
                      u_prior_question_explanation_correctly_count_dict,
                      u_question_listening_correctly_count_dict, 
                      u_question_reading_correctly_count_dict,
                      u_question_listening_count_dict,
                      u_question_reading_count_dict,
                      u_question_incorrect_timestamp_dict,
                      u_question_incorrect_timestamp2_dict,
                      u_question_incorrect_timestamp3_dict,
                      u_question_seen_dict,
                      u_question_part_correct_timestamp_dict,
                      u_question_part_incorrect_timestamp_dict,
                     ):
    """Advance the per-user state using rows whose answers are now known.

    Question rows (``content_type_id == 0``):
    - update the correct/total counters;
    - shift the "previous N timestamps" histories and push the current one;
    - mark the question as seen;
    - record the user/content timestamp in the module-level ``lru_cache``.

    Lecture rows only shift the lecture-timestamp history. All dictionaries
    are mutated in place, and nothing is returned.

    Args:
        df: labelled rows containing the columns selected below, including
            ``answered_correctly``.
        u_*_dict: the same state dictionaries read by
            ``add_user_feats_without_update``.
    """
    for row in df[['user_id',
                   'content_type_id',
                   'answered_correctly', 
                   'part',
                   'tag1',
                   'timestamp',
                   'task_container_id',
                   'prior_question_had_explanation',
                   'LorR',
                   'content_id']].values:
        user_id = row[0]
        content_type = row[1]
        answered_correctly = row[2]
        question_part = row[3]
        question_tag1 = row[4]
        timestamp = row[5]
        task_container_id = row[6]
        prior_question_had_explanation = row[7]
        LorR = row[8]
        content_id = row[9]
        
        if content_type == 0:
            uid_part = str(user_id) + '_' + str(question_part)
            uid_tag1 = str(user_id) + '_' + str(question_tag1)
            # Same composite key as add_user_feats_without_update (10e10 == 1e11, a float).
            uid_cid = int(user_id) + int(content_id)*10e10
            
            if answered_correctly == 1:
                # Shift the last-three-correct timestamps, then push this one.
                u_prior3_question_correctly_timestamp_dict[user_id] = u_prior2_question_correctly_timestamp_dict[user_id]
                u_prior2_question_correctly_timestamp_dict[user_id] = u_prior_question_correctly_timestamp_dict[user_id]
                u_prior_question_correctly_timestamp_dict[user_id] = timestamp
                # NOTE(review): this increments on every correct answer,
                # regardless of prior_question_had_explanation, i.e. exactly
                # when u_answered_correctly_count_dict does. Possibly meant
                # `+= prior_question_had_explanation`. Left unchanged because
                # the trained model presumably saw this definition. Confirm.
                u_prior_question_explanation_correctly_count_dict[user_id] += 1
                u_question_part_correct_timestamp_dict[uid_part] = timestamp
                if LorR == 0:
                    u_question_listening_correctly_count_dict[user_id] += 1
                else:
                    u_question_reading_correctly_count_dict[user_id] += 1
            else:
                # Shift the last-three-incorrect timestamps, then push this one.
                u_question_incorrect_timestamp3_dict[user_id] = u_question_incorrect_timestamp2_dict[user_id]
                u_question_incorrect_timestamp2_dict[user_id] = u_question_incorrect_timestamp_dict[user_id]
                u_question_incorrect_timestamp_dict[user_id] = timestamp
                u_question_part_incorrect_timestamp_dict[uid_part] = timestamp
                
            u_answered_correctly_count_dict[user_id] += answered_correctly
            u_answered_count_dict[user_id] += 1
            u_question_part_correctly_count_dict[uid_part] += answered_correctly
            u_question_part_count_dict[uid_part] += 1
            u_question_tag1_correctly_count_dict[uid_tag1] += answered_correctly
            u_question_tag1_count_dict[uid_tag1] += 1
            # Shift the last-four-question timestamps (oldest first), then push.
            u_prior4_question_timestamp_dict[user_id] = u_prior3_question_timestamp_dict[user_id]
            u_prior3_question_timestamp_dict[user_id] = u_prior2_question_timestamp_dict[user_id]
            u_prior2_question_timestamp_dict[user_id] = u_prior_question_timestamp_dict[user_id]
            u_prior_question_timestamp_dict[user_id] = timestamp
            u_task_container_id_dict[user_id] = task_container_id
            u_prior_question_explanation_count_dict[user_id] += prior_question_had_explanation
            if LorR == 0:
                u_question_listening_count_dict[user_id] += 1
            else:
                u_question_reading_count_dict[user_id] += 1
            # add_user_feats_without_update normally creates this bitmap. The
            # guard avoids a KeyError for a user it has not processed yet.
            if user_id not in u_question_seen_dict:
                u_question_seen_dict[user_id] = bitarray('0'*14000, endian='little')
            u_question_seen_dict[user_id][content_id] = 1
            
            lru_cache.set(uid_cid, timestamp)
            
        else:
            # Lecture row: shift the two-deep lecture timestamp history.
            u_prior2_lecture_timestamp_dict[user_id] = u_prior_lecture_timestamp_dict[user_id]
            u_prior_lecture_timestamp_dict[user_id] = timestamp

In [5]:
# Label column.
TARGET = 'answered_correctly'

# LightGBM input columns. Every string is a column key and is kept verbatim,
# including the 'rignt' spellings. The order presumably has to match the
# order used at training time, so do not reorder.
FEATS = [
    # question / content-level features
    'content_id', 'prior_question_elapsed_time', 'prior_question_had_explanation',
    'answered_correctly_avg_c', 'answered_correctly_var_c', 'part',
    'question_asked', 'right_answers',
    'bundle_size', 'bundle_rignt_answers', 'bundle_questions_asked', 'bundle_accuracy',
    'part_rignt_answers', 'part_questions_asked', 'part_accuracy',
    'tag1', 'tag1_answered_correctly_mean', 'tag1_answered_correctly_var',
    'tags_emb_0', 'tags_emb_1',
    'c_question_count_percent', 'c_question_part_percent', 'c_question_tag1_percent',
    'c_question_unique_users_seen',
    # per-user history features (built in add_user_feats_without_update)
    'u_answered_correctly_count', 'u_answered_count',
    'u_question_part_correctly_count', 'u_question_part_count',
    'u_question_tag1_correctly_count', 'u_question_tag1_count',
    'u_prior_question_correctly_timestamp_diff',
    'u_prior_question_correctly_timestamp_diff2',
    'u_prior_question_correctly_timestamp_diff3',
    'u_prior_question_timestamp_diff', 'u_prior_question_timestamp_diff2',
    'u_prior_question_timestamp_diff3', 'u_prior_question_timestamp_diff4',
    'u_prior_lecture_timestamp_diff', 'u_prior2_lecture_timestamp_diff',
    'u_task_container_id_diff',
    'u_prior_question_explanation_count', 'u_prior_question_explanation_correctly_count',
    'u_question_listening_correctly_count', 'u_question_reading_correctly_count',
    'u_question_listening_count', 'u_question_reading_count',
    'u_question_incorrect_timestamp_diff', 'u_question_incorrect_timestamp_diff2',
    'u_question_incorrect_timestamp_diff3',
    'u_question_seen', 'u_question_nunique',
    'u_question_part_correct_timestamp_diff', 'u_question_part_incorrect_timestamp_diff',
    'uc_question_timestamp_diff',
    # ratios / differences derived from the user features above
    'u_answered_correctly_avg', 'u_question_part_correctly_avg', 'u_question_tag1_correctly_avg',
    'u_prior_question_timestamp_diff_2_1', 'u_prior_question_timestamp_diff_3_2',
    'u_prior_question_timestamp_diff_4_3',
    'u_question_timestamp_diff', 'u_question_timestamp_diff2', 'u_question_timestamp_diff3',
    'u_listening_reading_ratio', 'u_listening_correctly_avg', 'u_reading_correctly_avg',
]

# Presumably the training-set mean of prior_question_elapsed_time, used for
# filling missing values further down. TODO confirm.
prior_question_elapsed_time_mean = 25340.273

# SAKT (self-attentive knowledge tracing) model - score 0.777

In [6]:
# SAKT hyper-parameters. They must agree with the checkpoint loaded later:
# its embeddings are sized 2*n_skill+1, MAX_SEQ-1 and EMBED_SIZE.
MAX_SEQ = 240                    # window length of past interactions per user
ACCEPTED_USER_CONTENT_SIZE = 2   # not used in the visible part of this notebook
EMBED_SIZE = 256                 # embedding / hidden width
BATCH_SIZE = 96                  # not used in the visible part of this notebook
DROPOUT = 0.1                    # default dropout for the attention blocks

n_skill = 13523                  # number of distinct content ids

In [7]:
class FFN(nn.Module):
    """Feed-forward sub-layer with BatchNorm: Linear -> ReLU -> BatchNorm1d -> Linear -> Dropout.

    ``nn.BatchNorm1d`` treats dim 1 of a 3-D input as the channel axis. For
    batch-first input ``(batch, seq_len, state_size)``, it therefore
    normalizes per sequence position, and ``bn_size`` must equal ``seq_len``.

    The attribute names ``lr1``, ``bn`` and ``lr2`` are state-dict keys of the
    trained checkpoint and must not change.
    """
    def __init__(self, state_size = 200, forward_expansion = 1, bn_size = MAX_SEQ - 1, dropout=0.2):
        super().__init__()
        hidden_size = forward_expansion * state_size
        self.state_size = state_size
        
        self.lr1 = nn.Linear(state_size, hidden_size)
        self.relu = nn.ReLU()
        self.bn = nn.BatchNorm1d(bn_size)
        self.lr2 = nn.Linear(hidden_size, state_size)
        self.dropout = nn.Dropout(dropout)
        
    def forward(self, x):
        normalized = self.bn(self.relu(self.lr1(x)))
        projected = self.lr2(normalized)
        return self.dropout(projected)
    
class FFN0(nn.Module):
    """Feed-forward sub-layer with LayerNorm: Linear -> ReLU -> Linear -> LayerNorm -> Dropout.

    ``bn_size`` is accepted only so the signature mirrors ``FFN``. It is unused.
    The attribute names are state-dict keys of the trained checkpoint.
    """
    def __init__(self, state_size = 200, forward_expansion = 1, bn_size = MAX_SEQ - 1, dropout=0.2):
        super(FFN0, self).__init__()
        self.state_size = state_size

        self.lr1 = nn.Linear(state_size, forward_expansion * state_size)
        self.relu = nn.ReLU()
        self.lr2 = nn.Linear(forward_expansion * state_size, state_size)
        self.layer_normal = nn.LayerNorm(state_size) 
        # Honour the caller's dropout rate. It was hard-coded to 0.2, which
        # silently ignored the argument. Dropout has no parameters, so the
        # checkpoint and eval-mode outputs are unaffected.
        self.dropout = nn.Dropout(dropout)
    
    def forward(self, x):
        x = self.lr1(x)
        x = self.relu(x)
        x = self.lr2(x)
        x = self.layer_normal(x)
        return self.dropout(x)
    

def future_mask(seq_length):
    """Return a (seq_length, seq_length) bool causal mask.

    Entry [i, j] is True when j > i, i.e. position i must not attend to the
    later position j. This matches ``attn_mask`` semantics of
    ``nn.MultiheadAttention``, where True means "masked out".
    """
    return torch.ones(seq_length, seq_length).triu(diagonal=1).bool()


class TransformerBlock(nn.Module):
    """One SAKT block: masked multi-head attention followed by two parallel FFN branches.

    Despite the parameter names, ``forward(value, key, query, att_mask)``
    hands its arguments to ``nn.MultiheadAttention`` positionally, i.e. as
    (query, key, value). The first argument therefore acts as the attention
    query, and it is also the residual input.

    Inputs are sequence-first ``(s_len, bs, embed)``. The returned tensor is
    batch-first ``(bs, s_len, embed)``. The final ``squeeze(-1)`` only has an
    effect when ``embed_dim == 1``.
    """
    def __init__(self, embed_dim, heads = 8, dropout = DROPOUT, forward_expansion = 1):
        super().__init__()
        self.multi_att = nn.MultiheadAttention(embed_dim=embed_dim, num_heads=heads, dropout=dropout)
        self.dropout = nn.Dropout(dropout)
        self.layer_normal = nn.LayerNorm(embed_dim)
        self.ffn = FFN(embed_dim, forward_expansion=forward_expansion, dropout=dropout)
        self.ffn0 = FFN0(embed_dim, forward_expansion=forward_expansion, dropout=dropout)
        self.layer_normal_2 = nn.LayerNorm(embed_dim)

    def forward(self, value, key, query, att_mask):
        attended, att_weight = self.multi_att(value, key, query, attn_mask=att_mask)
        # residual around the attention, still (s_len, bs, embed)
        attended = self.dropout(self.layer_normal(attended + value))
        # to batch-first: (s_len, bs, embed) -> (bs, s_len, embed)
        attended = attended.permute(1, 0, 2)
        # FFN branch first, then FFN0 branch, then the residual
        merged = self.ffn(attended) + self.ffn0(attended) + attended
        out = self.dropout(self.layer_normal_2(merged))
        return out.squeeze(-1), att_weight
    
    
class Encoder(nn.Module):
    """SAKT encoder: question embeddings attend over past-interaction embeddings.

    Args:
        n_skill: number of content ids. Interaction ids lie in
            ``[0, 2*n_skill]`` and question ids in ``[0, n_skill]``.
        max_seq: window length. The positional table has ``max_seq - 1``
            rows, so inputs may be at most ``max_seq - 1`` long.
        embed_dim: embedding width.
        dropout: dropout applied to the summed embeddings.
        forward_expansion: FFN width multiplier for each block.
        num_layers: number of TransformerBlocks.
        heads: accepted but unused.

    NOTE(review): ``heads`` and ``dropout`` are not forwarded to
    TransformerBlock, so every block uses its defaults (8 heads,
    dropout=DROPOUT). The checkpoint was presumably trained that way.
    Forwarding them would change predictions without any load error, because
    MultiheadAttention weight shapes do not depend on the head count.

    forward(x, question_ids): both are ``(bs, s_len)`` id tensors. Returns
    ``(bs, s_len, embed_dim)`` features and the last block's attention
    weights, or None when there are no blocks.
    """
    def __init__(self, n_skill, max_seq=100, embed_dim=128, dropout = DROPOUT, forward_expansion = 1, num_layers=1, heads = 8):
        super(Encoder, self).__init__()
        self.n_skill, self.embed_dim = n_skill, embed_dim
        self.embedding = nn.Embedding(2 * n_skill + 1, embed_dim)
        self.pos_embedding = nn.Embedding(max_seq - 1, embed_dim)
        self.e_embedding = nn.Embedding(n_skill+1, embed_dim)
        self.layers = nn.ModuleList([TransformerBlock(embed_dim, forward_expansion = forward_expansion) for _ in range(num_layers)])
        self.dropout = nn.Dropout(dropout)
        
    def forward(self, x, question_ids):
        device = x.device
        x = self.embedding(x)
        pos_id = torch.arange(x.size(1)).unsqueeze(0).to(device)
        pos_x = self.pos_embedding(pos_id)
        x = self.dropout(x + pos_x)
        x = x.permute(1, 0, 2) # x: [bs, s_len, embed] => [s_len, bs, embed]
        e = self.e_embedding(question_ids)
        e = e.permute(1, 0, 2)
        # The causal mask depends only on the sequence length, so build it once.
        att_mask = future_mask(e.size(0)).to(device)
        att_weight = None
        for layer in self.layers:
            # Blocks return batch-first output. Switch back to sequence-first
            # for the next block.
            x, att_weight = layer(e, x, x, att_mask=att_mask)
            x = x.permute(1, 0, 2)
        x = x.permute(1, 0, 2) # back to [bs, s_len, embed]
        return x, att_weight
    

class SAKTModel(nn.Module):
    """SAKT knowledge-tracing model: ``Encoder`` followed by a per-position logit head.

    ``forward(x, question_ids)`` returns ``(logits, att_weight)``. ``logits``
    has one value per sequence position, ``(bs, s_len)``.

    NOTE(review): ``heads`` is accepted but not passed to ``Encoder``, whose
    attention blocks use their own default head count.
    """
    def __init__(self, n_skill, max_seq=100, embed_dim=128, dropout = DROPOUT, forward_expansion = 1, enc_layers=1, heads = 8):
        super().__init__()
        self.encoder = Encoder(n_skill,
                               max_seq=max_seq,
                               embed_dim=embed_dim,
                               dropout=dropout,
                               forward_expansion=forward_expansion,
                               num_layers=enc_layers)
        self.pred = nn.Linear(embed_dim, 1)
        
    def forward(self, x, question_ids):
        encoded, att_weight = self.encoder(x, question_ids)
        logits = self.pred(encoded).squeeze(-1)
        return logits, att_weight

In [8]:
# Per-user SAKT history. TestDataset reads it as
# samples[user_id] -> (content_id array, answered_correctly array), and the
# inference loop extends it batch by batch.
# NOTE: joblib.load unpickles, so only load trusted files.
GROUP_PATH = "../input/v4-fork-of-riiid-sakt-model-full/group.pkl.zip"
group = joblib.load(GROUP_PATH)

In [9]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# NOTE(review): `heads=4` has no effect. SAKTModel does not forward it and
# Encoder builds its blocks with the default head count (8).
nn_model = SAKTModel(n_skill, 
                     max_seq=MAX_SEQ, 
                     embed_dim=EMBED_SIZE, 
                     forward_expansion=1, 
                     enc_layers=1, 
                     heads=4, 
                     dropout=0.1)

# map_location='cpu' loads a GPU-saved checkpoint on any machine (the model
# is still on CPU here). It replaces a bare `except:` retry that also hid
# unrelated failures, such as a missing file or state-dict key mismatch.
nn_model.load_state_dict(torch.load("../input/v4-fork-of-riiid-sakt-model-full/sakt_model.pt",
                                    map_location='cpu'))

nn_model.to(device)
nn_model.eval()

SAKTModel(
  (encoder): Encoder(
    (embedding): Embedding(27047, 256)
    (pos_embedding): Embedding(239, 256)
    (e_embedding): Embedding(13524, 256)
    (layers): ModuleList(
      (0): TransformerBlock(
        (multi_att): MultiheadAttention(
          (out_proj): _LinearWithBias(in_features=256, out_features=256, bias=True)
        )
        (dropout): Dropout(p=0.1, inplace=False)
        (layer_normal): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
        (ffn): FFN(
          (lr1): Linear(in_features=256, out_features=256, bias=True)
          (relu): ReLU()
          (bn): BatchNorm1d(239, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (lr2): Linear(in_features=256, out_features=256, bias=True)
          (dropout): Dropout(p=0.1, inplace=False)
        )
        (ffn0): FFN0(
          (lr1): Linear(in_features=256, out_features=256, bias=True)
          (relu): ReLU()
          (lr2): Linear(in_features=256, out_features=256, bias=True)

In [10]:
class TestDataset(Dataset):
    """Fixed-length SAKT inputs, one item per row of ``test_df``.

    Args:
        samples: per-user history indexed by user_id. Each value is a
            ``(content_id_array, answered_correctly_array)`` tuple.
        test_df: rows to score. Needs ``user_id`` and ``content_id``.
        n_skill: number of content ids. A correct past answer is encoded
            as ``content_id + n_skill``.
        max_seq: window length. The last ``max_seq`` interactions are kept,
            and shorter histories are left-padded with zeros.

    ``__getitem__`` returns ``(x, questions)``, both of length ``max_seq - 1``.
    ``x`` holds the encoded past interactions. ``questions`` holds the
    question ids shifted by one, with the row's target ``content_id`` last.
    """
    def __init__(self, samples, test_df, n_skill, max_seq=100):
        super(TestDataset, self).__init__()
        self.samples, self.user_ids, self.test_df = samples, [x for x in test_df["user_id"].unique()], test_df
        self.n_skill, self.max_seq = n_skill, max_seq

    def __len__(self):
        return self.test_df.shape[0]
    
    def __getitem__(self, index):
        test_info = self.test_df.iloc[index]
        
        user_id = test_info['user_id']
        target_id = test_info['content_id']
        
        content_id_seq = np.zeros(self.max_seq, dtype=int)
        answered_correctly_seq = np.zeros(self.max_seq, dtype=int)
        
        if user_id in self.samples.index:
            content_id, answered_correctly = self.samples[user_id]
            
            seq_len = len(content_id)
            
            if seq_len >= self.max_seq:
                content_id_seq = content_id[-self.max_seq:]
                answered_correctly_seq = answered_correctly[-self.max_seq:]
            elif seq_len > 0:
                # Guard needed: with seq_len == 0, `[-0:]` selects the whole
                # buffer, and assigning an empty array to it fails to broadcast.
                content_id_seq[-seq_len:] = content_id
                answered_correctly_seq[-seq_len:] = answered_correctly
                
        # .copy() so the in-place += never writes into the shared history arrays.
        x = content_id_seq[1:].copy()
        x += (answered_correctly_seq[1:] == 1) * self.n_skill
        
        questions = np.append(content_id_seq[2:], [target_id])
        
        return x, questions

In [11]:
import riiideducation

# Competition environment: `iter_test` yields (test_df, sample_prediction_df)
# batches and `set_predict` receives the predictions for each batch.
env = riiideducation.make_env()
iter_test, set_predict = env.iter_test(), env.predict

In [12]:
import ast

use_cols = ['user_id', 'content_type_id', 'part', 
            'tag1', 'timestamp', 'task_container_id', 
            'prior_question_had_explanation', 'LorR', 
            'content_id']

# Weight of the LightGBM prediction in the final blend; SAKT gets the remainder.
LGBM_BLEND_WEIGHT = 0.75
# Rows per SAKT forward pass.
SAKT_BATCH_SIZE = 51200


def _user_feat_dicts():
    """Return the per-user feature dicts in the positional order shared by
    ``update_user_feats`` and ``add_user_feats_without_update``.

    Resolved at call time, so every call sees the current module-level
    objects -- the same as spelling the 28 names out at each call site.
    """
    return (u_answered_correctly_count_dict,
            u_answered_count_dict,
            u_question_part_correctly_count_dict,
            u_question_part_count_dict,
            u_question_tag1_correctly_count_dict,
            u_question_tag1_count_dict,
            u_prior_question_correctly_timestamp_dict,
            u_prior2_question_correctly_timestamp_dict,
            u_prior3_question_correctly_timestamp_dict,
            u_prior_question_timestamp_dict,
            u_prior2_question_timestamp_dict,
            u_prior3_question_timestamp_dict,
            u_prior4_question_timestamp_dict,
            u_prior_lecture_timestamp_dict,
            u_prior2_lecture_timestamp_dict,
            u_task_container_id_dict,
            u_prior_question_explanation_count_dict,
            u_prior_question_explanation_correctly_count_dict,
            u_question_listening_correctly_count_dict,
            u_question_reading_correctly_count_dict,
            u_question_listening_count_dict,
            u_question_reading_count_dict,
            u_question_incorrect_timestamp_dict,
            u_question_incorrect_timestamp2_dict,
            u_question_incorrect_timestamp3_dict,
            u_question_seen_dict,
            u_question_part_correct_timestamp_dict,
            u_question_part_incorrect_timestamp_dict)


previous_test_df = None  # previous batch after LGBM preprocessing, lectures included
prev_test_df1 = None     # raw copy of the previous batch, for the SAKT history update

# Inference only: disable dropout and make BatchNorm use its running statistics.
# No-op if an earlier cell already put the model in eval mode.
nn_model.eval()

for (test_df, sample_prediction_df) in iter_test:

    # The answers to the previous batch arrive as a list literal in the first
    # row's `prior_group_answers_correct`. Parse it once for both models, using
    # ast.literal_eval: unlike eval() it cannot execute code from the feed.
    prior_answers = None
    if prev_test_df1 is not None or previous_test_df is not None:
        prior_answers = ast.literal_eval(test_df['prior_group_answers_correct'].iloc[0])

    # SAKT inference
    test_df1 = test_df.copy()
    if prev_test_df1 is not None:
        prev_test_df1['answered_correctly'] = prior_answers
        prev_test_df1 = prev_test_df1[prev_test_df1.content_type_id == False]
        prev_group = prev_test_df1[['user_id', 'content_id', 'answered_correctly']].groupby('user_id').apply(lambda r: (
            r['content_id'].values,
            r['answered_correctly'].values))

        # Append the previous batch's answered questions to each user's history,
        # keeping only the last MAX_SEQ interactions (for new users as well).
        for prev_user_id in prev_group.index:
            new_content, new_answers = prev_group[prev_user_id]
            if prev_user_id in group.index:
                old_content, old_answers = group[prev_user_id]
                new_content = np.append(old_content, new_content)
                new_answers = np.append(old_answers, new_answers)
            group[prev_user_id] = (new_content[-MAX_SEQ:], new_answers[-MAX_SEQ:])

    prev_test_df1 = test_df1.copy()
    test_df1 = test_df1[test_df1.content_type_id == False]
    test_dataset = TestDataset(group, test_df1, n_skill, max_seq=MAX_SEQ)
    test_dataloader = DataLoader(test_dataset, batch_size=SAKT_BATCH_SIZE, shuffle=False)

    outs = []
    with torch.no_grad():
        for item in test_dataloader:
            x = item[0].to(device).long()
            target_id = item[1].to(device).long()
            output, _ = nn_model(x, target_id)
            outs.extend(torch.sigmoid(output)[:, -1].view(-1).cpu().numpy())

    # LGBM inference
    test_df['prior_question_had_explanation'] = test_df.prior_question_had_explanation.fillna(False).astype('int8')
    test_df = pd.merge(test_df, questions_df, on='content_id', how='left')
    test_df = pd.merge(test_df, content_df, on='content_id', how='left')
    test_df['prior_question_elapsed_time'] = test_df.prior_question_elapsed_time.fillna(prior_question_elapsed_time_mean)
    test_df = test_df.fillna(-1)
    test_df[use_cols] = test_df[use_cols].astype('int')

    if previous_test_df is not None:
        previous_test_df[TARGET] = prior_answers
        previous_test_df[TARGET] = previous_test_df[TARGET].astype('int')
        update_user_feats(previous_test_df, *_user_feat_dicts())

    previous_test_df = test_df.copy()
    test_df = test_df[test_df['content_type_id'] == 0].reset_index(drop=True)
    test_df = add_user_feats_without_update(test_df, *_user_feat_dicts())

    test_df[TARGET] = model.predict(test_df[FEATS])

    # blend -- positional: both pipelines keep the batch's question rows in
    # their original order (left merges preserve order, assuming questions_df /
    # content_df have one row per content_id). Fail loudly if that breaks.
    sakt_preds = np.asarray(outs)
    assert len(sakt_preds) == len(test_df), (
        f'SAKT produced {len(sakt_preds)} predictions for {len(test_df)} question rows')
    test_df[TARGET] = LGBM_BLEND_WEIGHT * test_df[TARGET] + (1 - LGBM_BLEND_WEIGHT) * sakt_preds

    set_predict(test_df[['row_id', TARGET]])

In [13]:
!head -20 submission.csv

row_id,answered_correctly
0,0.36327344343356793
1,0.8857490012077305
2,0.6622038262751303
3,0.8570988935355467
4,0.32162905766284416
5,0.6885190924316829
6,0.5101505815670148
7,0.5999243195653847
8,0.784834434358329
9,0.5629938440087059
10,0.8245712406821528
11,0.5602268239093686
12,0.4653338264689878
13,0.7435863065782544
14,0.31396741173064713
15,0.7635788674639628
16,0.5497390390475636
17,0.9463706452982978
18,0.7103668356173949
