In [20]:
import math
import random
import numpy as np
import pandas as pd
from tqdm import tqdm
from collections import defaultdict
import os

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

from sklearn.model_selection import KFold

from copy import deepcopy

from gensim.models import Word2Vec

import warnings

warnings.filterwarnings(action='ignore')
torch.set_printoptions(sci_mode=True)

In [21]:
import gc

gc.collect()
torch.cuda.empty_cache()

In [22]:
def seed_everything(seed):
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)  # if use multi-GPU
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    np.random.seed(seed)
    random.seed(seed)

# 데이터 전처리

In [23]:
class MakeDataset():

    def __init__(self, DATA_PATH):
        self.preporcessing(DATA_PATH)
        self.oof_user_set = self.split_data()
    
    def split_data(self):
        user_list = self.all_df['userID'].unique().tolist()
        oof_user_set = {}
        kf = KFold(n_splits = 5, random_state = 22, shuffle = True)
        for idx, (train_user, valid_user) in enumerate(kf.split(user_list)):
            oof_user_set[idx] = valid_user.tolist()
        
        return oof_user_set

    def preporcessing(self, DATA_PATH):

        dtype = {
            'userID': 'int16',
            'answerCode': 'int8',
            'KnowledgeTag': 'int16'
        }
        
        train_df = pd.read_csv(os.path.join(DATA_PATH, 'train_data.csv'), dtype=dtype, parse_dates=['Timestamp'])
        train_df = train_df.sort_values(by=['userID', 'Timestamp']).reset_index(drop=True)

        test_df = pd.read_csv(os.path.join(DATA_PATH, 'test_data.csv'), dtype=dtype, parse_dates=['Timestamp'])
        test_df = test_df.sort_values(by=['userID', 'Timestamp']).reset_index(drop=True)


        def get_large_paper_number(x):
            return x[1:4]
        
        train_df['large_paper_number'] = train_df['assessmentItemID'].apply(lambda x : get_large_paper_number(x))
        test_df['large_paper_number'] = test_df['assessmentItemID'].apply(lambda x : get_large_paper_number(x))

        # 문제 푸는데 걸린 시간
        def get_now_elapsed(df):
            
            diff = df.loc[:, ['userID','Timestamp']].groupby('userID').diff().fillna(pd.Timedelta(seconds=0))
            diff = diff.fillna(pd.Timedelta(seconds=0))
            diff = diff['Timestamp'].apply(lambda x: x.total_seconds())
            df['now_elapsed'] = diff
            df['now_elapsed'] = df['now_elapsed'].apply(lambda x : x if x < 650 and x >=0 else 0)
            df['now_elapsed'] = df['now_elapsed']

            return df

        train_df = get_now_elapsed(df = train_df)
        test_df = get_now_elapsed(df = test_df)

        all_df = pd.concat([train_df, test_df])
        all_df = all_df[all_df['answerCode'] != -1].reset_index(drop = True)

        # normalize_score
        def get_normalize_score(df, all_df):
            ret_df = []

            group_df = df.groupby('userID')
            mean_answerCode_df = all_df.groupby('testId').mean()['answerCode']
            std_answerCode_df = all_df.groupby('testId').std()['answerCode']
            for userID, get_df in group_df:
                normalize_score_df = (get_df[get_df['answerCode'] != -1].groupby('testId').mean()['answerCode'] - mean_answerCode_df) / std_answerCode_df
                get_df = get_df.copy().set_index('testId')
                get_df['normalize_score'] = normalize_score_df
                ret_df.append(get_df.reset_index(drop = False))

            ret_df = pd.concat(ret_df).reset_index(drop = True)
            
            return ret_df
        
        # train_df = get_normalize_score(df = train_df, all_df = all_df)
        # test_df = get_normalize_score(df = test_df, all_df = all_df)

        # 문항별 정답률
        train_df = train_df.set_index('assessmentItemID')
        train_df['assessmentItemID_mean_answerCode'] = all_df.groupby('assessmentItemID').mean()['answerCode']
        train_df = train_df.reset_index(drop = False)

        test_df = test_df.set_index('assessmentItemID')
        test_df['assessmentItemID_mean_answerCode'] = all_df.groupby('assessmentItemID').mean()['answerCode']
        test_df = test_df.reset_index(drop = False)

        # 문항별 정답률 표준편차
        train_df = train_df.set_index('assessmentItemID')
        train_df['assessmentItemID_std_answerCode'] = all_df.groupby('assessmentItemID').std()['answerCode']
        train_df = train_df.reset_index(drop = False)

        test_df = test_df.set_index('assessmentItemID')
        test_df['assessmentItemID_std_answerCode'] = all_df.groupby('assessmentItemID').std()['answerCode']
        test_df = test_df.reset_index(drop = False)

        # 올바르게 푼 사람들의 문항별 풀이 시간 평균
        train_df = train_df.set_index('assessmentItemID')
        train_df['assessmentItemID_mean_now_elapsed'] = all_df[all_df['answerCode'] == 1].groupby('assessmentItemID').mean()['now_elapsed']
        train_df = train_df.reset_index(drop = False)

        test_df = test_df.set_index('assessmentItemID')
        test_df['assessmentItemID_mean_now_elapsed'] = all_df[all_df['answerCode'] == 1].groupby('assessmentItemID').mean()['now_elapsed']
        test_df = test_df.reset_index(drop = False)

        # 올바르게 푼 사람들의 문항별 풀이 시간 표준 편차
        train_df = train_df.set_index('assessmentItemID')
        train_df['assessmentItemID_std_now_elapsed'] = all_df[all_df['answerCode'] == 1].groupby('assessmentItemID').std()['now_elapsed']
        train_df = train_df.reset_index(drop = False)

        test_df = test_df.set_index('assessmentItemID')
        test_df['assessmentItemID_std_now_elapsed'] = all_df[all_df['answerCode'] == 1].groupby('assessmentItemID').std()['now_elapsed']
        test_df = test_df.reset_index(drop = False)

        # 문제 푼 시간
        train_df['hour'] = train_df['Timestamp'].dt.hour
        test_df['hour'] = test_df['Timestamp'].dt.hour

        # 문제 푼 요일
        train_df['dayofweek'] = train_df['Timestamp'].dt.dayofweek
        test_df['dayofweek'] = test_df['Timestamp'].dt.dayofweek

        # index 로 변환

        def get_val2idx(val_list : list) -> dict:
            val2idx = {}
            for idx, val in enumerate(val_list):
                val2idx[val] = idx
            
            return val2idx

        assessmentItemID2idx = get_val2idx(all_df['assessmentItemID'].unique().tolist())
        testId2idx = get_val2idx(all_df['testId'].unique().tolist())
        KnowledgeTag2idx = get_val2idx(all_df['KnowledgeTag'].unique().tolist())
        large_paper_number2idx = get_val2idx(all_df['large_paper_number'].unique().tolist())

        train_df['assessmentItemID2idx'] = train_df['assessmentItemID'].apply(lambda x : assessmentItemID2idx[x])
        train_df['testId2idx'] = train_df['testId'].apply(lambda x : testId2idx[x])
        train_df['KnowledgeTag2idx'] = train_df['KnowledgeTag'].apply(lambda x : KnowledgeTag2idx[x])
        train_df['large_paper_number2idx'] = train_df['large_paper_number'].apply(lambda x : large_paper_number2idx[x])

        test_df['assessmentItemID2idx'] = test_df['assessmentItemID'].apply(lambda x : assessmentItemID2idx[x])
        test_df['testId2idx'] = test_df['testId'].apply(lambda x : testId2idx[x])
        test_df['KnowledgeTag2idx'] = test_df['KnowledgeTag'].apply(lambda x : KnowledgeTag2idx[x])
        test_df['large_paper_number2idx'] = test_df['large_paper_number'].apply(lambda x : large_paper_number2idx[x])

        self.assessmentItemID2idx = assessmentItemID2idx
        self.train_df, self.test_df = train_df, test_df
        self.all_df = pd.concat([train_df, test_df[test_df['answerCode'] != -1]]).reset_index(drop=True)
        self.num_assessmentItemID = len(assessmentItemID2idx)
        self.num_testId = len(testId2idx)
        self.num_KnowledgeTag = len(KnowledgeTag2idx)
        self.num_large_paper_number = len(large_paper_number2idx)
        self.num_hour = 24
        self.num_dayofweek = 7

    def get_oof_data(self, oof):

        val_user_list = self.oof_user_set[oof]

        train = []
        valid = []

        group_df = self.all_df.groupby('userID')

        for userID, df in group_df:
            if userID in val_user_list:
                trn_df = df.iloc[:-1, :]
                val_df = df.copy()
                train.append(trn_df)
                valid.append(val_df)
            else:
                train.append(df)

        # normalize_score
        def get_normalize_score(df, all_df, vailid = False):
            ret_df = []

            group_df = df.groupby('userID')
            mean_answerCode_df = all_df.groupby('testId').mean()['answerCode']
            std_answerCode_df = all_df.groupby('testId').std()['answerCode']
            for userID, get_df in group_df:
                if vailid:
                    normalize_score_df = (get_df.iloc[:-1, :].groupby('testId').mean()['answerCode'] - mean_answerCode_df) / std_answerCode_df
                else:
                    normalize_score_df = (get_df.groupby('testId').mean()['answerCode'] - mean_answerCode_df) / std_answerCode_df
                    
                get_df = get_df.copy().set_index('testId')
                get_df['normalize_score'] = normalize_score_df
                ret_df.append(get_df.reset_index(drop = False))

            ret_df = pd.concat(ret_df).reset_index(drop = True)
            
            return ret_df

        train = pd.concat(train).reset_index(drop = True)
        valid = pd.concat(valid).reset_index(drop = True)

        # train = get_normalize_score(df = train, all_df = train)
        # valid = get_normalize_score(df = valid, all_df = train, vailid = True)
        
        return train, valid
    
    def get_test_data(self):
        return self.test_df.copy()

In [24]:
class CustomDataset(Dataset):
    def __init__(
        self, 
        df,
        cat_cols = ['assessmentItemID2idx', 'testId2idx', 'KnowledgeTag2idx', 'large_paper_number2idx', 'hour', 'dayofweek'],
        num_cols = ['now_elapsed', 'assessmentItemID_mean_now_elapsed', 'assessmentItemID_std_now_elapsed', 'assessmentItemID_mean_answerCode', 'assessmentItemID_std_answerCode'],
        max_len = None,
        window = None,
        data_augmentation = False,
        ):

        self.cat_cols = cat_cols
        self.num_cols = num_cols
        self.get_df = df.groupby('userID')
        self.user_list = df['userID'].unique().tolist()
        self.max_len = max_len
        self.window = window
        self.data_augmentation = data_augmentation
        if self.data_augmentation:
            self.cat_feature_list, self.num_feature_list, self.answerCode_list = self._data_augmentation()


    def __len__(self):
        if self.data_augmentation:
            return len(self.cat_feature_list)
        return len(self.user_list)

    def __getitem__(self, idx):
        if self.data_augmentation:
            cat_feature = self.cat_feature_list[idx]
            num_feature = self.num_feature_list[idx]
            answerCode = self.answerCode_list[idx]

            now_cat_feature = cat_feature[1:, :]
            now_num_feature = num_feature[1:, :]
            now_answerCode = answerCode[1:]
            
            past_cat_feature = cat_feature[:-1, :]
            past_num_feature = num_feature[:-1, :]
            past_answerCode = answerCode[:-1]
            
        else:
            user = self.user_list[idx]
            if self.max_len:
                get_df = self.get_df.get_group(user).iloc[-self.max_len:, :]
            else:
                get_df = self.get_df.get_group(user)

            now_df = get_df.iloc[1:, :]
            now_cat_feature = now_df[self.cat_cols].values
            now_num_feature = now_df[self.num_cols].values
            now_answerCode = now_df['answerCode'].values

            past_df = get_df.iloc[:-1, :]
            past_cat_feature = past_df[self.cat_cols].values
            past_num_feature = past_df[self.num_cols].values
            past_answerCode = past_df['answerCode'].values

        return {
            'past_cat_feature' : past_cat_feature, 
            'past_num_feature' : past_num_feature, 
            'past_answerCode' : past_answerCode, 
            'now_cat_feature' : now_cat_feature, 
            'now_num_feature' : now_num_feature, 
            'now_answerCode' : now_answerCode
            }
    

    def _data_augmentation(self):
        cat_feature_list = []
        num_feature_list = []
        answerCode_list = []
        for userID, get_df in tqdm(self.get_df):
            cat_feature = get_df[self.cat_cols].values[::-1]
            num_feature = get_df[self.num_cols].values[::-1]
            answerCode = get_df['answerCode'].values[::-1]

            start_idx = 0

            if len(get_df) <= self.max_len:
                cat_feature_list.append(cat_feature[::-1])
                num_feature_list.append(num_feature[::-1])
                answerCode_list.append(answerCode[::-1])
            else:
                while True:
                    if len(cat_feature[start_idx: start_idx + self.max_len, :]) < self.max_len:
                        cat_feature_list.append(cat_feature[start_idx: start_idx + self.max_len, :][::-1])
                        num_feature_list.append(num_feature[start_idx: start_idx + self.max_len, :][::-1])
                        answerCode_list.append(answerCode[start_idx: start_idx + self.max_len][::-1])
                        break
                    cat_feature_list.append(cat_feature[start_idx: start_idx + self.max_len, :][::-1])
                    num_feature_list.append(num_feature[start_idx: start_idx + self.max_len, :][::-1])
                    answerCode_list.append(answerCode[start_idx: start_idx + self.max_len][::-1])
                    start_idx += self.window
            
        return cat_feature_list, num_feature_list, answerCode_list

In [25]:
def pad_sequence(seq, max_len, padding_value = 0):
    try:
        seq_len, col = seq.shape
        padding = np.zeros((max_len - seq_len, col)) + padding_value
    except:
        seq_len = seq.shape[0]
        padding = np.zeros((max_len - seq_len, )) + padding_value

    padding_seq = np.concatenate([padding, seq])

    return padding_seq

def train_make_batch(samples):
    max_len = 0
    for sample in samples:
        seq_len, col = sample['past_cat_feature'].shape
        if max_len < seq_len:
            max_len = seq_len
    
    past_cat_feature = []
    past_num_feature = []
    past_answerCode = []
    now_cat_feature = []
    now_num_feature = []
    now_answerCode = []

    for sample in samples:
        past_cat_feature += [pad_sequence(sample['past_cat_feature'] + 1, max_len = max_len, padding_value = 0)]
        past_num_feature += [pad_sequence(sample['past_num_feature'], max_len = max_len, padding_value = 0)]
        past_answerCode += [pad_sequence(sample['past_answerCode'] + 1, max_len = max_len, padding_value = 0)]
        now_cat_feature += [pad_sequence(sample['now_cat_feature'] + 1, max_len = max_len, padding_value = 0)]
        now_num_feature += [pad_sequence(sample['now_num_feature'], max_len = max_len, padding_value = 0)]
        now_answerCode += [pad_sequence(sample['now_answerCode'], max_len = max_len, padding_value = -1)]

    return torch.tensor(past_cat_feature, dtype = torch.long), torch.tensor(past_num_feature, dtype = torch.float32), torch.tensor(past_answerCode, dtype = torch.long), torch.tensor(now_cat_feature, dtype = torch.long), torch.tensor(now_num_feature, dtype = torch.float32), torch.tensor(now_answerCode, dtype = torch.float32)

# 모델

In [26]:
class ScaledDotProductAttention(nn.Module):
    def __init__(self, hidden_units, dropout_rate):
        super(ScaledDotProductAttention, self).__init__()
        self.hidden_units = hidden_units
        self.dropout = nn.Dropout(dropout_rate)

    def forward(self, Q, K, V, mask):
        """
        Q, K, V : (batch_size, num_heads, max_len, hidden_units)
        mask : (batch_size, 1, max_len, max_len)
        """
        attn_score = torch.matmul(Q, K.transpose(2, 3)) / math.sqrt(self.hidden_units) # (batch_size, num_heads, max_len, max_len)
        attn_score = attn_score.masked_fill(mask == 0, -1e9)  # 유사도가 0인 지점은 -infinity로 보내 softmax 결과가 0이 되도록 함
        attn_dist = self.dropout(F.softmax(attn_score, dim=-1))  # attention distribution
        output = torch.matmul(attn_dist, V)  # (batch_size, num_heads, max_len, hidden_units) / # dim of output : batchSize x num_head x seqLen x hidden_units
        return output, attn_dist


class MultiHeadAttention(nn.Module):
    def __init__(self, num_heads, hidden_units, dropout_rate):
        super(MultiHeadAttention, self).__init__()
        self.num_heads = num_heads # head의 수
        self.hidden_units = hidden_units
        
        # query, key, value, output 생성을 위해 Linear 모델 생성
        self.W_Q = nn.Linear(hidden_units, hidden_units, bias=False)
        self.W_K = nn.Linear(hidden_units, hidden_units, bias=False)
        self.W_V = nn.Linear(hidden_units, hidden_units, bias=False)
        self.W_O = nn.Linear(hidden_units, hidden_units, bias=False)

        self.attention = ScaledDotProductAttention(hidden_units, dropout_rate)
        self.dropout = nn.Dropout(dropout_rate) # dropout rate
        self.layerNorm = nn.LayerNorm(hidden_units, 1e-6) # layer normalization

    def forward(self, enc, mask):
        """
        enc : (batch_size, max_len, hidden_units)
        mask : (batch_size, 1, max_len, max_len)
        
        """
        residual = enc # residual connection을 위해 residual 부분을 저장
        batch_size, seqlen = enc.size(0), enc.size(1)

        # Query, Key, Value를 (num_head)개의 Head로 나누어 각기 다른 Linear projection을 통과시킴
        Q = self.W_Q(enc).view(batch_size, seqlen, self.num_heads, self.hidden_units // self.num_heads) # (batch_size, max_len, num_heads, hidden_units)
        K = self.W_K(enc).view(batch_size, seqlen, self.num_heads, self.hidden_units // self.num_heads) # (batch_size, max_len, num_heads, hidden_units)
        V = self.W_V(enc).view(batch_size, seqlen, self.num_heads, self.hidden_units // self.num_heads) # (batch_size, max_len, num_heads, hidden_units)

        # Head별로 각기 다른 attention이 가능하도록 Transpose 후 각각 attention에 통과시킴
        Q, K, V = Q.transpose(1, 2), K.transpose(1, 2), V.transpose(1, 2) # (batch_size, num_heads, max_len, hidden_units)
        output, attn_dist = self.attention(Q, K, V, mask) # output : (batch_size, num_heads, max_len, hidden_units) / attn_dist : (batch_size, num_heads, max_len, max_len)

        # 다시 Transpose한 후 모든 head들의 attention 결과를 합칩니다.
        output = output.transpose(1, 2).contiguous() # (batch_size, max_len, num_heads, hidden_units) / contiguous() : 가변적 메모리 할당
        output = output.view(batch_size, seqlen, -1) # (batch_size, max_len, hidden_units * num_heads)

        # Linear Projection, Dropout, Residual sum, and Layer Normalization
        output = self.layerNorm(self.dropout(self.W_O(output)) + residual) # (batch_size, max_len, hidden_units)
        return output, attn_dist


class PositionwiseFeedForward(nn.Module):
    def __init__(self, hidden_units, dropout_rate):
        super(PositionwiseFeedForward, self).__init__()

        self.W_1 = nn.Linear(hidden_units, hidden_units)
        self.W_2 = nn.Linear(hidden_units, hidden_units)
        self.dropout = nn.Dropout(dropout_rate)
        self.layerNorm = nn.LayerNorm(hidden_units, 1e-6) # layer normalization

    def forward(self, x):
        residual = x
        output = self.W_2(F.relu(self.dropout(self.W_1(x))))
        output = self.layerNorm(self.dropout(output) + residual)
        return output


class SASRecBlock(nn.Module):
    def __init__(self, num_heads, hidden_units, dropout_rate):
        super(SASRecBlock, self).__init__()
        self.attention = MultiHeadAttention(num_heads, hidden_units, dropout_rate)
        self.pointwise_feedforward = PositionwiseFeedForward(hidden_units, dropout_rate)

    def forward(self, input_enc, mask):
        """
        input_enc : (batch_size, max_len, hidden_units)
        mask : (batch_size, 1, max_len, max_len)
        """
        output_enc, attn_dist = self.attention(input_enc, mask)
        output_enc = self.pointwise_feedforward(output_enc)
        return output_enc, attn_dist


class SASRec(nn.Module):
    def __init__(
        self, 
        num_assessmentItemID, 
        num_testId,
        num_KnowledgeTag,
        num_large_paper_number,
        num_hour,
        num_dayofweek,
        num_cols,
        cat_cols,
        emb_size,
        hidden_units,
        num_heads, 
        num_layers, 
        dropout_rate, 
        device):
        super(SASRec, self).__init__()

        # past
        self.past_assessmentItemID_emb = nn.Embedding(num_assessmentItemID + 1, emb_size, padding_idx = 0) # 문항에 대한 정보
        self.past_testId_emb = nn.Embedding(num_testId + 1, emb_size, padding_idx = 0) # 시험지에 대한 정보
        self.past_KnowledgeTag_emb = nn.Embedding(num_KnowledgeTag + 1, emb_size, padding_idx = 0) # 지식 태그에 대한 정보
        self.past_large_paper_number_emb = nn.Embedding(num_large_paper_number + 1, emb_size, padding_idx = 0) # 핫년에 대한 정보
        self.past_hour_emb = nn.Embedding(num_hour + 1, emb_size, padding_idx = 0) # 문제 풀이 시간에 대한 정보
        self.past_dayofweek_emb = nn.Embedding(num_dayofweek + 1, emb_size, padding_idx = 0) # 문제 풀이 요일에 대항 정보
        self.past_answerCode_emb = nn.Embedding(3, hidden_units, padding_idx = 0) # 문제 정답 여부에 대한 정보

        self.past_cat_emb = nn.Sequential(
            nn.Linear(len(cat_cols) * emb_size, hidden_units // 2),
            nn.LayerNorm(hidden_units // 2, eps=1e-6)
        )

        self.past_num_emb = nn.Sequential(
            nn.Linear(len(num_cols), hidden_units // 2),
            nn.LayerNorm(hidden_units // 2, eps=1e-6)
        )

        self.emb_layernorm = nn.LayerNorm(hidden_units, eps=1e-6)

        self.past_lstm = nn.LSTM(
            input_size = hidden_units,
            hidden_size = hidden_units,
            num_layers = num_layers,
            batch_first = True,
            bidirectional = False,
            dropout = dropout_rate,
            )

        self.past_blocks = nn.ModuleList([SASRecBlock(num_heads, hidden_units, dropout_rate) for _ in range(num_layers)])

        # now
        self.now_assessmentItemID_emb = nn.Embedding(num_assessmentItemID + 1, emb_size, padding_idx = 0) # 문항에 대한 정보
        self.now_testId_emb = nn.Embedding(num_testId + 1, emb_size, padding_idx = 0) # 시험지에 대한 정보
        self.now_KnowledgeTag_emb = nn.Embedding(num_KnowledgeTag + 1, emb_size, padding_idx = 0) # 지식 태그에 대한 정보
        self.now_large_paper_number_emb = nn.Embedding(num_large_paper_number + 1, emb_size, padding_idx = 0) # 핫년에 대한 정보
        self.now_hour_emb = nn.Embedding(num_hour + 1, emb_size, padding_idx = 0) # 문제 풀이 시간에 대한 정보
        self.now_dayofweek_emb = nn.Embedding(num_dayofweek + 1, emb_size, padding_idx = 0) # 문제 풀이 요일에 대항 정보

        self.now_cat_emb = nn.Sequential(
            nn.Linear(len(cat_cols) * emb_size, hidden_units // 2),
            nn.LayerNorm(hidden_units // 2, eps=1e-6)
        )

        self.now_num_emb = nn.Sequential(
            nn.Linear(len(num_cols), hidden_units // 2),
            nn.LayerNorm(hidden_units // 2, eps=1e-6)
        )

        self.now_lstm = nn.LSTM(
            input_size = hidden_units,
            hidden_size = hidden_units,
            num_layers = num_layers,
            batch_first = True,
            bidirectional = False,
            dropout = dropout_rate,
            )

        self.now_blocks = nn.ModuleList([SASRecBlock(num_heads, hidden_units, dropout_rate) for _ in range(num_layers)])

        # predict

        self.dropout = nn.Dropout(dropout_rate)

        self.predict_layer = nn.Sequential(
            nn.Linear(hidden_units * 2, 1),
            nn.Sigmoid()
        )

        self.cat_cols = cat_cols
        self.num_cols = num_cols
        
        self.hidden_units = hidden_units
        self.num_heads = num_heads
        self.num_layers = num_layers
        self.device = device
    
    
    def forward(self, past_cat_feature, past_num_feature, past_answerCode, now_cat_feature, now_num_feature):
        """
        past_cat_feature : (batch_size, max_len, cat_cols)
        past_num_feature : (batch_size, max_len, num_cols)
        past_answerCode : (batch_size, max_len)

        now_cat_feature : (batch_size, max_len, cat_cols)
        now_num_feature : (batch_size, max_len, num_cols)
        
        """

        past_cat_emb_list = []
        for idx in range(len(self.cat_cols)):
            if self.cat_cols[idx] == 'assessmentItemID2idx':
                past_cat_emb_list.append(self.past_assessmentItemID_emb(past_cat_feature[:, :, idx]))
            elif self.cat_cols[idx] == 'testId2idx':
                past_cat_emb_list.append(self.past_testId_emb(past_cat_feature[:, :, idx]))
            elif self.cat_cols[idx] == 'KnowledgeTag2idx':
                past_cat_emb_list.append(self.past_KnowledgeTag_emb(past_cat_feature[:, :, idx]))
            elif self.cat_cols[idx] == 'large_paper_number2idx':
                past_cat_emb_list.append(self.past_large_paper_number_emb(past_cat_feature[:, :, idx]))
            elif self.cat_cols[idx] == 'hour':
                past_cat_emb_list.append(self.past_hour_emb(past_cat_feature[:, :, idx]))
            elif self.cat_cols[idx] == 'dayofweek':
                past_cat_emb_list.append(self.past_dayofweek_emb(past_cat_feature[:, :, idx]))

        past_cat_emb = torch.concat(past_cat_emb_list, dim = -1)
        past_cat_emb = self.past_cat_emb(past_cat_emb)
        past_num_emb = self.past_num_emb(past_num_feature)

        past_emb = torch.concat([past_cat_emb, past_num_emb], dim = -1)
        past_emb += self.past_answerCode_emb(past_answerCode.to(self.device))
        past_emb = self.emb_layernorm(past_emb) # LayerNorm

        # masking 
        mask_pad = torch.BoolTensor(past_answerCode > 0).unsqueeze(1).unsqueeze(1) # (batch_size, 1, 1, max_len)
        mask_time = (1 - torch.triu(torch.ones((1, 1, past_answerCode.size(1), past_answerCode.size(1))), diagonal=1)).bool() # (batch_size, 1, max_len, max_len)
        mask = (mask_pad & mask_time).to(self.device) # (batch_size, 1, max_len, max_len)
        for block in self.past_blocks:
            past_emb, attn_dist = block(past_emb, mask)

        past_emb, _ = self.past_lstm(past_emb)

        now_cat_emb_list = []
        for idx in range(len(self.cat_cols)):
            if self.cat_cols[idx] == 'assessmentItemID2idx':
                now_cat_emb_list.append(self.now_assessmentItemID_emb(now_cat_feature[:, :, idx]))
            elif self.cat_cols[idx] == 'testId2idx':
                now_cat_emb_list.append(self.now_testId_emb(now_cat_feature[:, :, idx]))
            elif self.cat_cols[idx] == 'KnowledgeTag2idx':
                now_cat_emb_list.append(self.now_KnowledgeTag_emb(now_cat_feature[:, :, idx]))
            elif self.cat_cols[idx] == 'large_paper_number2idx':
                now_cat_emb_list.append(self.now_large_paper_number_emb(now_cat_feature[:, :, idx]))
            elif self.cat_cols[idx] == 'hour':
                now_cat_emb_list.append(self.now_hour_emb(now_cat_feature[:, :, idx]))
            elif self.cat_cols[idx] == 'dayofweek':
                now_cat_emb_list.append(self.now_dayofweek_emb(now_cat_feature[:, :, idx]))

        now_cat_emb = torch.concat(now_cat_emb_list, dim = -1)
        now_cat_emb = self.now_cat_emb(now_cat_emb)
        now_num_emb = self.now_num_emb(now_num_feature)

        now_emb = torch.concat([now_cat_emb, now_num_emb], dim = -1)

        for block in self.now_blocks:
            now_emb, attn_dist = block(now_emb, mask)

        now_emb, _ = self.now_lstm(now_emb)

        emb = torch.concat([past_emb, now_emb], dim = -1)
        
        output = self.predict_layer(self.dropout(emb))

        return output

# 학습 함수

In [27]:
from sklearn.metrics import roc_auc_score

def train(model, data_loader, criterion, optimizer):
    model.train()
    loss_val = 0

    for past_cat_feature, past_num_feature, past_answerCode, now_cat_feature, now_num_feature, now_answerCode in data_loader:

        past_cat_feature, past_num_feature, past_answerCode = past_cat_feature.to(device), past_num_feature.to(device), past_answerCode
        now_cat_feature, now_num_feature, now_answerCode = now_cat_feature.to(device), now_num_feature.to(device), now_answerCode.to(device)

        optimizer.zero_grad()

        output = model(past_cat_feature, past_num_feature, past_answerCode, now_cat_feature, now_num_feature).squeeze(2)
        loss = criterion(output[now_answerCode != -1], now_answerCode[now_answerCode != -1])

        loss.backward()
        optimizer.step()

        loss_val += loss.item()

    loss_val /= len(data_loader)

    return loss_val

def evaluate(model, data_loader):
    model.eval()

    target = []
    pred = []

    with torch.no_grad():
        for past_cat_feature, past_num_feature, past_answerCode, now_cat_feature, now_num_feature, now_answerCode in data_loader:
            past_cat_feature, past_num_feature, past_answerCode = past_cat_feature.to(device), past_num_feature.to(device), past_answerCode
            now_cat_feature, now_num_feature, now_answerCode = now_cat_feature.to(device), now_num_feature.to(device), now_answerCode.to(device)
            
            output = model(past_cat_feature, past_num_feature, past_answerCode, now_cat_feature, now_num_feature).squeeze(2)

            target.extend(now_answerCode[:, -1].cpu().numpy().tolist())
            pred.extend(output[:, -1].cpu().numpy().tolist())

    roc_auc = roc_auc_score(target, pred)

    return roc_auc


def predict(model, data_loader):
    model.eval()

    pred = []

    with torch.no_grad():
        for past_cat_feature, past_num_feature, past_answerCode, now_cat_feature, now_num_feature, now_answerCode in data_loader:
            past_cat_feature, past_num_feature, past_answerCode = past_cat_feature.to(device), past_num_feature.to(device), past_answerCode
            now_cat_feature, now_num_feature = now_cat_feature.to(device), now_num_feature.to(device)
            
            output = model(past_cat_feature, past_num_feature, past_answerCode, now_cat_feature, now_num_feature).squeeze(2)
            pred.extend(output[:, -1].cpu().numpy().tolist())

    return pred

# 학습

In [28]:
batch_size = 32
epochs = 20
lr = 0.001
device = 'cuda' if torch.cuda.is_available() else 'cpu'

emb_size = 64
hidden_units = 128
num_heads = 2 # 2,4,8,16,32
num_layers = 1
dropout_rate = 0.5
num_workers = 8

max_len = 50
window = 10
data_augmentation = False

DATA_PATH = '/opt/ml/input/data'
MODEL_PATH = '/opt/ml/model'
SUBMISSION_PATH = '/opt/ml/submission'

model_name = 'Transformer-and-LSTM-Encoder-Decoder-each-Embedding-add-feature.pt'
submission_name = 'Transformer-and-LSTM-Encoder-Decoder-each-Embedding-add-feature.csv'

In [29]:
if not os.path.isdir(MODEL_PATH):
    os.mkdir(MODEL_PATH)

In [30]:
if not os.path.isdir(SUBMISSION_PATH):
    os.mkdir(SUBMISSION_PATH)

In [12]:
make_dataset = MakeDataset(DATA_PATH = DATA_PATH)

# OOF Ensemble

In [31]:
oof_roc_auc = 0

for oof in make_dataset.oof_user_set.keys():
    train_df, valid_df = make_dataset.get_oof_data(oof)
    
    seed_everything(22 + oof)
    
    train_dataset = CustomDataset(df = train_df,)
    train_data_loader = DataLoader(
        train_dataset, 
        batch_size = batch_size, 
        shuffle = True, 
        drop_last = False,
        collate_fn = train_make_batch,
        num_workers = num_workers)

    valid_dataset = CustomDataset(df = valid_df,)
    valid_data_loader = DataLoader(
        valid_dataset, 
        batch_size = 1, 
        shuffle = False, 
        drop_last = False,
        collate_fn = train_make_batch,
        num_workers = num_workers)

    model = SASRec(
        num_assessmentItemID = make_dataset.num_assessmentItemID, 
        num_testId = make_dataset.num_testId,
        num_KnowledgeTag = make_dataset.num_KnowledgeTag,
        num_large_paper_number = make_dataset.num_large_paper_number,
        num_hour = make_dataset.num_hour,
        num_dayofweek = make_dataset.num_dayofweek,
        num_cols = train_dataset.num_cols,
        cat_cols = train_dataset.cat_cols,
        emb_size = emb_size,
        hidden_units = hidden_units,
        num_heads = num_heads,
        num_layers = num_layers,
        dropout_rate = dropout_rate,
        device = device).to(device)

    optimizer = torch.optim.Adam(model.parameters(), lr = lr)
    criterion = nn.BCELoss()

    # pre_emb = Word2Vec.load(os.path.join(MODEL_PATH, 'Word2Vec_Embedding_Model_window_50.model'))

    # assessmentItemID_li = make_dataset.assessmentItemID2idx.keys()

    # with torch.no_grad():
    #     for assessmentItemID in assessmentItemID_li:
    #         idx = make_dataset.assessmentItemID2idx[assessmentItemID]
    #         model.assessmentItemID_emb.weight[idx + 1] = torch.tensor(pre_emb.wv[assessmentItemID]).to(device)

    best_epoch = 0
    best_train_loss = 0
    best_roc_auc = 0

    for epoch in range(1, epochs + 1):
        tbar = tqdm(range(1))
        for _ in tbar:
            train_loss = train(model = model, data_loader = train_data_loader, criterion = criterion, optimizer = optimizer)
            roc_auc = evaluate(model = model, data_loader = valid_data_loader)
            if best_roc_auc < roc_auc:
                best_epoch = epoch
                best_train_loss = train_loss
                best_roc_auc = roc_auc
                torch.save(model.state_dict(), os.path.join(MODEL_PATH, f'oof_{oof}_' + model_name))

            tbar.set_description(f'OOF-{oof}| Epoch: {epoch:3d}| Train loss: {train_loss:.5f}| roc_auc: {roc_auc:.5f}')
    
    print(f'BEST OOF-{oof}| Epoch: {best_epoch:3d}| Train loss: {best_train_loss:.5f}| roc_auc: {best_roc_auc:.5f}')

    oof_roc_auc += best_roc_auc

print(f'Total roc_auc: {oof_roc_auc / len(make_dataset.oof_user_set.keys()):.5f}')

OOF-0| Epoch:   1| Train loss: 0.51096| roc_auc: 0.82697: 100%|██████████| 1/1 [02:05<00:00, 125.52s/it]
OOF-0| Epoch:   2| Train loss: 0.46838| roc_auc: 0.84584: 100%|██████████| 1/1 [02:03<00:00, 123.60s/it]
OOF-0| Epoch:   3| Train loss: 0.45416| roc_auc: 0.84911: 100%|██████████| 1/1 [02:04<00:00, 124.01s/it]
OOF-0| Epoch:   4| Train loss: 0.44686| roc_auc: 0.85202: 100%|██████████| 1/1 [02:03<00:00, 123.37s/it]
OOF-0| Epoch:   5| Train loss: 0.44319| roc_auc: 0.85429: 100%|██████████| 1/1 [02:03<00:00, 123.33s/it]
OOF-0| Epoch:   6| Train loss: 0.44022| roc_auc: 0.85450: 100%|██████████| 1/1 [02:02<00:00, 122.78s/it]
OOF-0| Epoch:   7| Train loss: 0.43847| roc_auc: 0.85417: 100%|██████████| 1/1 [02:03<00:00, 123.18s/it]
OOF-0| Epoch:   8| Train loss: 0.43620| roc_auc: 0.85411: 100%|██████████| 1/1 [02:02<00:00, 122.66s/it]
OOF-0| Epoch:   9| Train loss: 0.43414| roc_auc: 0.85365: 100%|██████████| 1/1 [02:05<00:00, 125.45s/it]
OOF-0| Epoch:  10| Train loss: 0.43314| roc_auc: 0.8544

```
OOF-0| Epoch:   1| Train loss: 0.52036| roc_auc: 0.81918: 100%|██████████| 1/1 [01:54<00:00, 114.87s/it]
OOF-0| Epoch:   2| Train loss: 0.47700| roc_auc: 0.84003: 100%|██████████| 1/1 [01:57<00:00, 117.49s/it]
OOF-0| Epoch:   3| Train loss: 0.45824| roc_auc: 0.84969: 100%|██████████| 1/1 [01:58<00:00, 118.67s/it]
OOF-0| Epoch:   4| Train loss: 0.45016| roc_auc: 0.85025: 100%|██████████| 1/1 [01:57<00:00, 117.37s/it]
OOF-0| Epoch:   5| Train loss: 0.44466| roc_auc: 0.85088: 100%|██████████| 1/1 [01:57<00:00, 117.79s/it]
OOF-0| Epoch:   6| Train loss: 0.44167| roc_auc: 0.85585: 100%|██████████| 1/1 [01:55<00:00, 115.66s/it]
OOF-0| Epoch:   7| Train loss: 0.43878| roc_auc: 0.85691: 100%|██████████| 1/1 [01:56<00:00, 116.95s/it]
OOF-0| Epoch:   8| Train loss: 0.43657| roc_auc: 0.85590: 100%|██████████| 1/1 [01:55<00:00, 115.28s/it]
OOF-0| Epoch:   9| Train loss: 0.43426| roc_auc: 0.85459: 100%|██████████| 1/1 [01:57<00:00, 117.32s/it]
OOF-0| Epoch:  10| Train loss: 0.43261| roc_auc: 0.85731: 100%|██████████| 1/1 [01:54<00:00, 114.69s/it]
OOF-0| Epoch:  11| Train loss: 0.43107| roc_auc: 0.85737: 100%|██████████| 1/1 [01:56<00:00, 116.13s/it]
OOF-0| Epoch:  12| Train loss: 0.42911| roc_auc: 0.85611: 100%|██████████| 1/1 [01:57<00:00, 117.01s/it]
OOF-0| Epoch:  13| Train loss: 0.42757| roc_auc: 0.85523: 100%|██████████| 1/1 [01:55<00:00, 115.94s/it]
OOF-0| Epoch:  14| Train loss: 0.42626| roc_auc: 0.85689: 100%|██████████| 1/1 [01:56<00:00, 116.23s/it]
OOF-0| Epoch:  15| Train loss: 0.42450| roc_auc: 0.85636: 100%|██████████| 1/1 [01:54<00:00, 114.13s/it]
OOF-0| Epoch:  16| Train loss: 0.42297| roc_auc: 0.85745: 100%|██████████| 1/1 [01:55<00:00, 115.89s/it]
OOF-0| Epoch:  17| Train loss: 0.42085| roc_auc: 0.85434: 100%|██████████| 1/1 [01:56<00:00, 116.49s/it]
OOF-0| Epoch:  18| Train loss: 0.41954| roc_auc: 0.85886: 100%|██████████| 1/1 [01:58<00:00, 118.36s/it]
OOF-0| Epoch:  19| Train loss: 0.41753| roc_auc: 0.85614: 100%|██████████| 1/1 [01:57<00:00, 117.64s/it]
OOF-0| Epoch:  20| Train loss: 0.41564| roc_auc: 0.85646: 100%|██████████| 1/1 [01:58<00:00, 118.19s/it]
BEST OOF-0| Epoch:  18| Train loss: 0.41954| roc_auc: 0.85886

```

```
num-head-8

OOF-0| Epoch:   1| Train loss: 0.50920| roc_auc: 0.82475: 100%|██████████| 1/1 [02:19<00:00, 139.90s/it]
OOF-0| Epoch:   2| Train loss: 0.46884| roc_auc: 0.84497: 100%|██████████| 1/1 [02:18<00:00, 138.22s/it]
OOF-0| Epoch:   3| Train loss: 0.45432| roc_auc: 0.85040: 100%|██████████| 1/1 [02:19<00:00, 139.15s/it]
OOF-0| Epoch:   4| Train loss: 0.44692| roc_auc: 0.85332: 100%|██████████| 1/1 [02:16<00:00, 136.10s/it]
OOF-0| Epoch:   5| Train loss: 0.44264| roc_auc: 0.85380: 100%|██████████| 1/1 [02:16<00:00, 136.87s/it]
OOF-0| Epoch:   6| Train loss: 0.44024| roc_auc: 0.85650: 100%|██████████| 1/1 [02:17<00:00, 137.41s/it]
OOF-0| Epoch:   7| Train loss: 0.43706| roc_auc: 0.85545: 100%|██████████| 1/1 [02:15<00:00, 135.47s/it]
OOF-0| Epoch:   8| Train loss: 0.43546| roc_auc: 0.85398: 100%|██████████| 1/1 [02:18<00:00, 138.32s/it]
OOF-0| Epoch:   9| Train loss: 0.43349| roc_auc: 0.85546: 100%|██████████| 1/1 [02:16<00:00, 136.43s/it]
OOF-0| Epoch:  10| Train loss: 0.43202| roc_auc: 0.85714: 100%|██████████| 1/1 [02:17<00:00, 137.22s/it]
OOF-0| Epoch:  11| Train loss: 0.43065| roc_auc: 0.85590: 100%|██████████| 1/1 [02:18<00:00, 138.73s/it]
OOF-0| Epoch:  12| Train loss: 0.42957| roc_auc: 0.85310: 100%|██████████| 1/1 [02:16<00:00, 136.90s/it]
OOF-0| Epoch:  13| Train loss: 0.42805| roc_auc: 0.85498: 100%|██████████| 1/1 [02:16<00:00, 136.56s/it]
OOF-0| Epoch:  14| Train loss: 0.42686| roc_auc: 0.85525: 100%|██████████| 1/1 [02:17<00:00, 137.80s/it]
OOF-0| Epoch:  15| Train loss: 0.42572| roc_auc: 0.85763: 100%|██████████| 1/1 [02:16<00:00, 136.77s/it]
OOF-0| Epoch:  16| Train loss: 0.42424| roc_auc: 0.85718: 100%|██████████| 1/1 [02:18<00:00, 138.39s/it]
OOF-0| Epoch:  17| Train loss: 0.42323| roc_auc: 0.85417: 100%|██████████| 1/1 [02:18<00:00, 138.44s/it]
OOF-0| Epoch:  18| Train loss: 0.42205| roc_auc: 0.85348: 100%|██████████| 1/1 [02:23<00:00, 143.78s/it]
OOF-0| Epoch:  19| Train loss: 0.42115| roc_auc: 0.85546: 100%|██████████| 1/1 [02:17<00:00, 137.05s/it]
OOF-0| Epoch:  20| Train loss: 0.41956| roc_auc: 0.85327: 100%|██████████| 1/1 [02:18<00:00, 138.90s/it]
BEST OOF-0| Epoch:  15| Train loss: 0.42572| roc_auc: 0.85763

OOF-1| Epoch:   1| Train loss: 0.51279| roc_auc: 0.81536: 100%|██████████| 1/1 [02:22<00:00, 142.50s/it]
OOF-1| Epoch:   2| Train loss: 0.46936| roc_auc: 0.83921: 100%|██████████| 1/1 [02:22<00:00, 142.29s/it]
OOF-1| Epoch:   3| Train loss: 0.45532| roc_auc: 0.84452: 100%|██████████| 1/1 [02:21<00:00, 141.83s/it]
OOF-1| Epoch:   4| Train loss: 0.44919| roc_auc: 0.84952: 100%|██████████| 1/1 [02:22<00:00, 142.93s/it]
OOF-1| Epoch:   5| Train loss: 0.44409| roc_auc: 0.85425: 100%|██████████| 1/1 [02:36<00:00, 156.87s/it]
OOF-1| Epoch:   6| Train loss: 0.44063| roc_auc: 0.85290: 100%|██████████| 1/1 [02:21<00:00, 141.98s/it]
OOF-1| Epoch:   7| Train loss: 0.43809| roc_auc: 0.85419: 100%|██████████| 1/1 [02:20<00:00, 140.42s/it]
OOF-1| Epoch:   8| Train loss: 0.43596| roc_auc: 0.85666: 100%|██████████| 1/1 [02:21<00:00, 141.32s/it]
OOF-1| Epoch:   9| Train loss: 0.43374| roc_auc: 0.85673: 100%|██████████| 1/1 [02:21<00:00, 141.56s/it]
OOF-1| Epoch:  10| Train loss: 0.43214| roc_auc: 0.85754: 100%|██████████| 1/1 [02:20<00:00, 140.37s/it]
OOF-1| Epoch:  11| Train loss: 0.43081| roc_auc: 0.85575: 100%|██████████| 1/1 [02:19<00:00, 139.63s/it]
OOF-1| Epoch:  12| Train loss: 0.42971| roc_auc: 0.85489: 100%|██████████| 1/1 [02:21<00:00, 141.79s/it]
OOF-1| Epoch:  13| Train loss: 0.42812| roc_auc: 0.85803: 100%|██████████| 1/1 [02:21<00:00, 141.54s/it]
OOF-1| Epoch:  14| Train loss: 0.42707| roc_auc: 0.85766: 100%|██████████| 1/1 [02:20<00:00, 140.56s/it]
OOF-1| Epoch:  15| Train loss: 0.42542| roc_auc: 0.85789: 100%|██████████| 1/1 [02:21<00:00, 141.67s/it]
OOF-1| Epoch:  16| Train loss: 0.42444| roc_auc: 0.85779: 100%|██████████| 1/1 [02:20<00:00, 140.40s/it]
OOF-1| Epoch:  17| Train loss: 0.42325| roc_auc: 0.85840: 100%|██████████| 1/1 [02:20<00:00, 140.55s/it]
OOF-1| Epoch:  18| Train loss: 0.42214| roc_auc: 0.85964: 100%|██████████| 1/1 [02:24<00:00, 144.61s/it]
OOF-1| Epoch:  19| Train loss: 0.42051| roc_auc: 0.85680: 100%|██████████| 1/1 [02:21<00:00, 141.18s/it]
OOF-1| Epoch:  20| Train loss: 0.41990| roc_auc: 0.85859: 100%|██████████| 1/1 [02:18<00:00, 138.96s/it]
BEST OOF-1| Epoch:  18| Train loss: 0.42214| roc_auc: 0.85964

OOF-2| Epoch:   1| Train loss: 0.51173| roc_auc: 0.79473: 100%|██████████| 1/1 [02:17<00:00, 137.77s/it]
OOF-2| Epoch:   2| Train loss: 0.46982| roc_auc: 0.81169: 100%|██████████| 1/1 [02:17<00:00, 137.17s/it]
OOF-2| Epoch:   3| Train loss: 0.45566| roc_auc: 0.82065: 100%|██████████| 1/1 [02:20<00:00, 140.65s/it]
OOF-2| Epoch:   4| Train loss: 0.44883| roc_auc: 0.82155: 100%|██████████| 1/1 [02:30<00:00, 150.64s/it]
OOF-2| Epoch:   5| Train loss: 0.44482| roc_auc: 0.82134: 100%|██████████| 1/1 [02:17<00:00, 137.73s/it]
OOF-2| Epoch:   6| Train loss: 0.44070| roc_auc: 0.82342: 100%|██████████| 1/1 [02:22<00:00, 142.38s/it]
OOF-2| Epoch:   7| Train loss: 0.43813| roc_auc: 0.82699: 100%|██████████| 1/1 [02:28<00:00, 148.92s/it]
OOF-2| Epoch:   8| Train loss: 0.43610| roc_auc: 0.82629: 100%|██████████| 1/1 [02:26<00:00, 146.50s/it]
OOF-2| Epoch:   9| Train loss: 0.43438| roc_auc: 0.82808: 100%|██████████| 1/1 [02:34<00:00, 154.04s/it]
OOF-2| Epoch:  10| Train loss: 0.43267| roc_auc: 0.82730: 100%|██████████| 1/1 [02:27<00:00, 147.95s/it]
OOF-2| Epoch:  11| Train loss: 0.43127| roc_auc: 0.82876: 100%|██████████| 1/1 [02:22<00:00, 142.05s/it]
OOF-2| Epoch:  12| Train loss: 0.43009| roc_auc: 0.82857: 100%|██████████| 1/1 [02:17<00:00, 137.14s/it]
OOF-2| Epoch:  13| Train loss: 0.42863| roc_auc: 0.82909: 100%|██████████| 1/1 [02:21<00:00, 141.41s/it]
OOF-2| Epoch:  14| Train loss: 0.42723| roc_auc: 0.83026: 100%|██████████| 1/1 [02:22<00:00, 142.32s/it]
OOF-2| Epoch:  15| Train loss: 0.42572| roc_auc: 0.82854: 100%|██████████| 1/1 [02:20<00:00, 140.03s/it]
OOF-2| Epoch:  16| Train loss: 0.42461| roc_auc: 0.82929: 100%|██████████| 1/1 [02:21<00:00, 141.04s/it]
OOF-2| Epoch:  17| Train loss: 0.42370| roc_auc: 0.83052: 100%|██████████| 1/1 [02:27<00:00, 147.07s/it]
OOF-2| Epoch:  18| Train loss: 0.42196| roc_auc: 0.82952: 100%|██████████| 1/1 [02:20<00:00, 140.20s/it]
OOF-2| Epoch:  19| Train loss: 0.42149| roc_auc: 0.83075: 100%|██████████| 1/1 [02:20<00:00, 140.22s/it]
OOF-2| Epoch:  20| Train loss: 0.41978| roc_auc: 0.82729: 100%|██████████| 1/1 [02:21<00:00, 141.82s/it]
BEST OOF-2| Epoch:  19| Train loss: 0.42149| roc_auc: 0.83075

OOF-3| Epoch:   1| Train loss: 0.50995| roc_auc: 0.79665: 100%|██████████| 1/1 [02:37<00:00, 157.60s/it]
OOF-3| Epoch:   2| Train loss: 0.46852| roc_auc: 0.81843: 100%|██████████| 1/1 [02:38<00:00, 158.77s/it]
OOF-3| Epoch:   3| Train loss: 0.45517| roc_auc: 0.82834: 100%|██████████| 1/1 [02:37<00:00, 157.10s/it]
OOF-3| Epoch:   4| Train loss: 0.44760| roc_auc: 0.82956: 100%|██████████| 1/1 [02:31<00:00, 151.33s/it]
OOF-3| Epoch:   5| Train loss: 0.44386| roc_auc: 0.82926: 100%|██████████| 1/1 [02:28<00:00, 148.99s/it]
OOF-3| Epoch:   6| Train loss: 0.44080| roc_auc: 0.83329: 100%|██████████| 1/1 [02:25<00:00, 145.31s/it]
OOF-3| Epoch:   7| Train loss: 0.43836| roc_auc: 0.83230: 100%|██████████| 1/1 [02:24<00:00, 144.73s/it]
OOF-3| Epoch:   8| Train loss: 0.43664| roc_auc: 0.83361: 100%|██████████| 1/1 [02:27<00:00, 147.86s/it]
OOF-3| Epoch:   9| Train loss: 0.43412| roc_auc: 0.83277: 100%|██████████| 1/1 [02:25<00:00, 145.65s/it]
OOF-3| Epoch:  10| Train loss: 0.43257| roc_auc: 0.83396: 100%|██████████| 1/1 [02:29<00:00, 149.01s/it]
OOF-3| Epoch:  11| Train loss: 0.43092| roc_auc: 0.83664: 100%|██████████| 1/1 [02:25<00:00, 145.97s/it]
OOF-3| Epoch:  12| Train loss: 0.42942| roc_auc: 0.83432: 100%|██████████| 1/1 [02:24<00:00, 144.50s/it]
OOF-3| Epoch:  13| Train loss: 0.42819| roc_auc: 0.83605: 100%|██████████| 1/1 [02:23<00:00, 143.15s/it]
OOF-3| Epoch:  14| Train loss: 0.42680| roc_auc: 0.83396: 100%|██████████| 1/1 [02:32<00:00, 152.35s/it]
OOF-3| Epoch:  15| Train loss: 0.42528| roc_auc: 0.83877: 100%|██████████| 1/1 [02:25<00:00, 145.03s/it]
OOF-3| Epoch:  16| Train loss: 0.42438| roc_auc: 0.83804: 100%|██████████| 1/1 [02:22<00:00, 142.87s/it]
OOF-3| Epoch:  17| Train loss: 0.42302| roc_auc: 0.83676: 100%|██████████| 1/1 [02:24<00:00, 144.06s/it]
OOF-3| Epoch:  18| Train loss: 0.42179| roc_auc: 0.83798: 100%|██████████| 1/1 [02:27<00:00, 147.40s/it]
OOF-3| Epoch:  19| Train loss: 0.42055| roc_auc: 0.83865: 100%|██████████| 1/1 [02:23<00:00, 143.90s/it]
OOF-3| Epoch:  20| Train loss: 0.41964| roc_auc: 0.83802: 100%|██████████| 1/1 [02:30<00:00, 150.22s/it]
BEST OOF-3| Epoch:  15| Train loss: 0.42528| roc_auc: 0.83877

OOF-4| Epoch:   1| Train loss: 0.50810| roc_auc: 0.80734: 100%|██████████| 1/1 [02:26<00:00, 146.30s/it]
OOF-4| Epoch:   2| Train loss: 0.46918| roc_auc: 0.82926: 100%|██████████| 1/1 [02:23<00:00, 143.44s/it]
OOF-4| Epoch:   3| Train loss: 0.45513| roc_auc: 0.83609: 100%|██████████| 1/1 [02:23<00:00, 143.62s/it]
OOF-4| Epoch:   4| Train loss: 0.44781| roc_auc: 0.83864: 100%|██████████| 1/1 [02:24<00:00, 144.23s/it]
OOF-4| Epoch:   5| Train loss: 0.44328| roc_auc: 0.83966: 100%|██████████| 1/1 [02:24<00:00, 144.25s/it]
OOF-4| Epoch:   6| Train loss: 0.44016| roc_auc: 0.84340: 100%|██████████| 1/1 [02:22<00:00, 142.59s/it]
OOF-4| Epoch:   7| Train loss: 0.43716| roc_auc: 0.84532: 100%|██████████| 1/1 [02:30<00:00, 150.12s/it]
OOF-4| Epoch:   8| Train loss: 0.43563| roc_auc: 0.84293: 100%|██████████| 1/1 [02:24<00:00, 144.37s/it]
OOF-4| Epoch:   9| Train loss: 0.43350| roc_auc: 0.84527: 100%|██████████| 1/1 [02:25<00:00, 145.76s/it]
OOF-4| Epoch:  10| Train loss: 0.43198| roc_auc: 0.84752: 100%|██████████| 1/1 [02:24<00:00, 144.53s/it]
OOF-4| Epoch:  11| Train loss: 0.43080| roc_auc: 0.84394: 100%|██████████| 1/1 [02:24<00:00, 144.84s/it]
OOF-4| Epoch:  12| Train loss: 0.42884| roc_auc: 0.84691: 100%|██████████| 1/1 [02:24<00:00, 144.77s/it]
OOF-4| Epoch:  13| Train loss: 0.42803| roc_auc: 0.84576: 100%|██████████| 1/1 [02:23<00:00, 143.13s/it]
OOF-4| Epoch:  14| Train loss: 0.42662| roc_auc: 0.84762: 100%|██████████| 1/1 [02:25<00:00, 145.26s/it]
OOF-4| Epoch:  15| Train loss: 0.42544| roc_auc: 0.84574: 100%|██████████| 1/1 [02:27<00:00, 147.13s/it]
OOF-4| Epoch:  16| Train loss: 0.42445| roc_auc: 0.84528: 100%|██████████| 1/1 [02:26<00:00, 146.96s/it]
OOF-4| Epoch:  17| Train loss: 0.42323| roc_auc: 0.84549: 100%|██████████| 1/1 [02:25<00:00, 145.53s/it]
OOF-4| Epoch:  18| Train loss: 0.42183| roc_auc: 0.84524: 100%|██████████| 1/1 [02:23<00:00, 143.60s/it]
OOF-4| Epoch:  19| Train loss: 0.42043| roc_auc: 0.84609: 100%|██████████| 1/1 [02:24<00:00, 144.84s/it]
OOF-4| Epoch:  20| Train loss: 0.41954| roc_auc: 0.84506: 100%|██████████| 1/1 [02:26<00:00, 146.87s/it]
BEST OOF-4| Epoch:  14| Train loss: 0.42662| roc_auc: 0.84762

Total roc_auc: 0.84688
```

```
num-head-4
OOF-0| Epoch:   1| Train loss: 0.50902| roc_auc: 0.82562: 100%|██████████| 1/1 [01:57<00:00, 117.15s/it]
OOF-0| Epoch:   2| Train loss: 0.46809| roc_auc: 0.84612: 100%|██████████| 1/1 [01:55<00:00, 115.90s/it]
OOF-0| Epoch:   3| Train loss: 0.45299| roc_auc: 0.85097: 100%|██████████| 1/1 [01:56<00:00, 116.11s/it]
OOF-0| Epoch:   4| Train loss: 0.44614| roc_auc: 0.85433: 100%|██████████| 1/1 [01:54<00:00, 114.93s/it]
OOF-0| Epoch:   5| Train loss: 0.44228| roc_auc: 0.85423: 100%|██████████| 1/1 [01:56<00:00, 116.62s/it]
OOF-0| Epoch:   6| Train loss: 0.43995| roc_auc: 0.85500: 100%|██████████| 1/1 [01:56<00:00, 116.18s/it]
OOF-0| Epoch:   7| Train loss: 0.43697| roc_auc: 0.85591: 100%|██████████| 1/1 [01:55<00:00, 115.39s/it]
OOF-0| Epoch:   8| Train loss: 0.43548| roc_auc: 0.85338: 100%|██████████| 1/1 [01:55<00:00, 115.17s/it]
OOF-0| Epoch:   9| Train loss: 0.43343| roc_auc: 0.85503: 100%|██████████| 1/1 [01:55<00:00, 115.27s/it]
OOF-0| Epoch:  10| Train loss: 0.43202| roc_auc: 0.85662: 100%|██████████| 1/1 [01:55<00:00, 115.12s/it]
OOF-0| Epoch:  11| Train loss: 0.43063| roc_auc: 0.85552: 100%|██████████| 1/1 [01:55<00:00, 115.99s/it]
OOF-0| Epoch:  12| Train loss: 0.42975| roc_auc: 0.85550: 100%|██████████| 1/1 [01:53<00:00, 113.77s/it]
OOF-0| Epoch:  13| Train loss: 0.42828| roc_auc: 0.85552: 100%|██████████| 1/1 [01:55<00:00, 115.69s/it]
OOF-0| Epoch:  14| Train loss: 0.42702| roc_auc: 0.85437: 100%|██████████| 1/1 [01:56<00:00, 116.17s/it]
OOF-0| Epoch:  15| Train loss: 0.42586| roc_auc: 0.85719: 100%|██████████| 1/1 [01:54<00:00, 114.93s/it]
OOF-0| Epoch:  16| Train loss: 0.42442| roc_auc: 0.85702: 100%|██████████| 1/1 [01:53<00:00, 113.31s/it]
OOF-0| Epoch:  17| Train loss: 0.42344| roc_auc: 0.85581: 100%|██████████| 1/1 [01:55<00:00, 115.23s/it]
OOF-0| Epoch:  18| Train loss: 0.42239| roc_auc: 0.85651: 100%|██████████| 1/1 [01:54<00:00, 114.78s/it]
OOF-0| Epoch:  19| Train loss: 0.42130| roc_auc: 0.85805: 100%|██████████| 1/1 [01:53<00:00, 113.14s/it]
OOF-0| Epoch:  20| Train loss: 0.41978| roc_auc: 0.85687: 100%|██████████| 1/1 [01:54<00:00, 114.15s/it]
BEST OOF-0| Epoch:  19| Train loss: 0.42130| roc_auc: 0.85805
OOF-1| Epoch:   1| Train loss: 0.51285| roc_auc: 0.81700: 100%|██████████| 1/1 [01:55<00:00, 115.82s/it]
OOF-1| Epoch:   2| Train loss: 0.46937| roc_auc: 0.83834: 100%|██████████| 1/1 [01:57<00:00, 117.06s/it]
OOF-1| Epoch:   3| Train loss: 0.45446| roc_auc: 0.84686: 100%|██████████| 1/1 [01:56<00:00, 116.05s/it]
OOF-1| Epoch:   4| Train loss: 0.44771| roc_auc: 0.85008: 100%|██████████| 1/1 [01:57<00:00, 117.28s/it]
OOF-1| Epoch:   5| Train loss: 0.44301| roc_auc: 0.85404: 100%|██████████| 1/1 [01:58<00:00, 118.22s/it]
OOF-1| Epoch:   6| Train loss: 0.44004| roc_auc: 0.85412: 100%|██████████| 1/1 [01:57<00:00, 117.40s/it]
OOF-1| Epoch:   7| Train loss: 0.43781| roc_auc: 0.85416: 100%|██████████| 1/1 [01:56<00:00, 116.22s/it]
OOF-1| Epoch:   8| Train loss: 0.43591| roc_auc: 0.85418: 100%|██████████| 1/1 [01:56<00:00, 116.29s/it]
OOF-1| Epoch:   9| Train loss: 0.43387| roc_auc: 0.85624: 100%|██████████| 1/1 [01:57<00:00, 117.77s/it]
OOF-1| Epoch:  10| Train loss: 0.43199| roc_auc: 0.85756: 100%|██████████| 1/1 [01:56<00:00, 116.91s/it]
OOF-1| Epoch:  11| Train loss: 0.43090| roc_auc: 0.85780: 100%|██████████| 1/1 [01:57<00:00, 117.43s/it]
OOF-1| Epoch:  12| Train loss: 0.42961| roc_auc: 0.85557: 100%|██████████| 1/1 [01:56<00:00, 116.16s/it]
OOF-1| Epoch:  13| Train loss: 0.42827| roc_auc: 0.85748: 100%|██████████| 1/1 [01:56<00:00, 116.58s/it]
OOF-1| Epoch:  14| Train loss: 0.42734| roc_auc: 0.85987: 100%|██████████| 1/1 [01:57<00:00, 117.85s/it]
OOF-1| Epoch:  15| Train loss: 0.42588| roc_auc: 0.85912: 100%|██████████| 1/1 [01:55<00:00, 115.99s/it]
OOF-1| Epoch:  16| Train loss: 0.42482| roc_auc: 0.86096: 100%|██████████| 1/1 [01:58<00:00, 118.42s/it]
OOF-1| Epoch:  17| Train loss: 0.42367| roc_auc: 0.86259: 100%|██████████| 1/1 [01:55<00:00, 115.72s/it]
OOF-1| Epoch:  18| Train loss: 0.42282| roc_auc: 0.86145: 100%|██████████| 1/1 [01:58<00:00, 118.47s/it]
OOF-1| Epoch:  19| Train loss: 0.42118| roc_auc: 0.85960: 100%|██████████| 1/1 [01:58<00:00, 118.09s/it]
OOF-1| Epoch:  20| Train loss: 0.42068| roc_auc: 0.86045: 100%|██████████| 1/1 [01:57<00:00, 117.78s/it]
BEST OOF-1| Epoch:  17| Train loss: 0.42367| roc_auc: 0.86259
OOF-2| Epoch:   1| Train loss: 0.51174| roc_auc: 0.79507: 100%|██████████| 1/1 [01:53<00:00, 113.93s/it]
OOF-2| Epoch:   2| Train loss: 0.46915| roc_auc: 0.81242: 100%|██████████| 1/1 [01:56<00:00, 116.49s/it]
OOF-2| Epoch:   3| Train loss: 0.45443| roc_auc: 0.82134: 100%|██████████| 1/1 [01:56<00:00, 116.67s/it]
OOF-2| Epoch:   4| Train loss: 0.44754| roc_auc: 0.82415: 100%|██████████| 1/1 [01:56<00:00, 116.01s/it]
OOF-2| Epoch:   5| Train loss: 0.44379| roc_auc: 0.82391: 100%|██████████| 1/1 [01:53<00:00, 113.19s/it]
OOF-2| Epoch:   6| Train loss: 0.44011| roc_auc: 0.82470: 100%|██████████| 1/1 [01:56<00:00, 116.60s/it]
OOF-2| Epoch:   7| Train loss: 0.43781| roc_auc: 0.82604: 100%|██████████| 1/1 [01:57<00:00, 117.48s/it]
OOF-2| Epoch:   8| Train loss: 0.43609| roc_auc: 0.82529: 100%|██████████| 1/1 [01:56<00:00, 116.44s/it]
OOF-2| Epoch:   9| Train loss: 0.43446| roc_auc: 0.82994: 100%|██████████| 1/1 [01:57<00:00, 117.08s/it]
OOF-2| Epoch:  10| Train loss: 0.43298| roc_auc: 0.82598: 100%|██████████| 1/1 [01:53<00:00, 113.98s/it]
OOF-2| Epoch:  11| Train loss: 0.43168| roc_auc: 0.83130: 100%|██████████| 1/1 [01:54<00:00, 114.59s/it]
OOF-2| Epoch:  12| Train loss: 0.43063| roc_auc: 0.83011: 100%|██████████| 1/1 [01:55<00:00, 115.30s/it]
OOF-2| Epoch:  13| Train loss: 0.42915| roc_auc: 0.83015: 100%|██████████| 1/1 [01:56<00:00, 116.12s/it]
OOF-2| Epoch:  14| Train loss: 0.42755| roc_auc: 0.83018: 100%|██████████| 1/1 [01:56<00:00, 116.48s/it]
OOF-2| Epoch:  15| Train loss: 0.42616| roc_auc: 0.82988: 100%|██████████| 1/1 [01:55<00:00, 115.58s/it]
OOF-2| Epoch:  16| Train loss: 0.42525| roc_auc: 0.83054: 100%|██████████| 1/1 [01:55<00:00, 115.25s/it]
OOF-2| Epoch:  17| Train loss: 0.42407| roc_auc: 0.83171: 100%|██████████| 1/1 [01:53<00:00, 113.34s/it]
OOF-2| Epoch:  18| Train loss: 0.42251| roc_auc: 0.82987: 100%|██████████| 1/1 [01:56<00:00, 116.50s/it]
OOF-2| Epoch:  19| Train loss: 0.42215| roc_auc: 0.83093: 100%|██████████| 1/1 [01:54<00:00, 114.01s/it]
OOF-2| Epoch:  20| Train loss: 0.42035| roc_auc: 0.82769: 100%|██████████| 1/1 [01:55<00:00, 115.53s/it]
BEST OOF-2| Epoch:  17| Train loss: 0.42407| roc_auc: 0.83171
OOF-3| Epoch:   1| Train loss: 0.50986| roc_auc: 0.79456: 100%|██████████| 1/1 [01:56<00:00, 116.35s/it]
OOF-3| Epoch:   2| Train loss: 0.46829| roc_auc: 0.81776: 100%|██████████| 1/1 [01:58<00:00, 118.48s/it]
OOF-3| Epoch:   3| Train loss: 0.45429| roc_auc: 0.82726: 100%|██████████| 1/1 [01:55<00:00, 115.56s/it]
OOF-3| Epoch:   4| Train loss: 0.44735| roc_auc: 0.82892: 100%|██████████| 1/1 [01:55<00:00, 115.90s/it]
OOF-3| Epoch:   5| Train loss: 0.44347| roc_auc: 0.82818: 100%|██████████| 1/1 [01:58<00:00, 118.92s/it]
OOF-3| Epoch:   6| Train loss: 0.44045| roc_auc: 0.83282: 100%|██████████| 1/1 [01:58<00:00, 118.88s/it]
OOF-3| Epoch:   7| Train loss: 0.43825| roc_auc: 0.83213: 100%|██████████| 1/1 [01:56<00:00, 116.99s/it]
OOF-3| Epoch:   8| Train loss: 0.43650| roc_auc: 0.83542: 100%|██████████| 1/1 [01:58<00:00, 118.83s/it]
OOF-3| Epoch:   9| Train loss: 0.43390| roc_auc: 0.83386: 100%|██████████| 1/1 [01:57<00:00, 117.81s/it]
OOF-3| Epoch:  10| Train loss: 0.43257| roc_auc: 0.83359: 100%|██████████| 1/1 [01:56<00:00, 116.30s/it]
OOF-3| Epoch:  11| Train loss: 0.43083| roc_auc: 0.83521: 100%|██████████| 1/1 [01:56<00:00, 116.46s/it]
OOF-3| Epoch:  12| Train loss: 0.42935| roc_auc: 0.83358: 100%|██████████| 1/1 [01:57<00:00, 117.09s/it]
OOF-3| Epoch:  13| Train loss: 0.42800| roc_auc: 0.83720: 100%|██████████| 1/1 [01:57<00:00, 117.71s/it]
OOF-3| Epoch:  14| Train loss: 0.42677| roc_auc: 0.83551: 100%|██████████| 1/1 [01:57<00:00, 117.13s/it]
OOF-3| Epoch:  15| Train loss: 0.42529| roc_auc: 0.83762: 100%|██████████| 1/1 [01:57<00:00, 117.83s/it]
OOF-3| Epoch:  16| Train loss: 0.42435| roc_auc: 0.83665: 100%|██████████| 1/1 [01:59<00:00, 119.27s/it]
OOF-3| Epoch:  17| Train loss: 0.42324| roc_auc: 0.83612: 100%|██████████| 1/1 [01:57<00:00, 117.74s/it]
OOF-3| Epoch:  18| Train loss: 0.42194| roc_auc: 0.83714: 100%|██████████| 1/1 [01:56<00:00, 116.28s/it]
OOF-3| Epoch:  19| Train loss: 0.42063| roc_auc: 0.83718: 100%|██████████| 1/1 [01:57<00:00, 117.85s/it]
OOF-3| Epoch:  20| Train loss: 0.41991| roc_auc: 0.83624: 100%|██████████| 1/1 [02:03<00:00, 123.29s/it]
BEST OOF-3| Epoch:  15| Train loss: 0.42529| roc_auc: 0.83762
OOF-4| Epoch:   1| Train loss: 0.50803| roc_auc: 0.80679: 100%|██████████| 1/1 [02:02<00:00, 122.05s/it]
OOF-4| Epoch:   2| Train loss: 0.46878| roc_auc: 0.83093: 100%|██████████| 1/1 [01:57<00:00, 117.42s/it]
OOF-4| Epoch:   3| Train loss: 0.45452| roc_auc: 0.83637: 100%|██████████| 1/1 [01:58<00:00, 118.88s/it]
OOF-4| Epoch:   4| Train loss: 0.44752| roc_auc: 0.83825: 100%|██████████| 1/1 [02:02<00:00, 122.77s/it]
OOF-4| Epoch:   5| Train loss: 0.44336| roc_auc: 0.84125: 100%|██████████| 1/1 [02:02<00:00, 122.48s/it]
OOF-4| Epoch:   6| Train loss: 0.44035| roc_auc: 0.84286: 100%|██████████| 1/1 [02:01<00:00, 121.50s/it]
OOF-4| Epoch:   7| Train loss: 0.43780| roc_auc: 0.84547: 100%|██████████| 1/1 [02:03<00:00, 123.44s/it]
OOF-4| Epoch:   8| Train loss: 0.43600| roc_auc: 0.84276: 100%|██████████| 1/1 [02:03<00:00, 123.28s/it]
OOF-4| Epoch:   9| Train loss: 0.43430| roc_auc: 0.84459: 100%|██████████| 1/1 [02:03<00:00, 123.39s/it]
OOF-4| Epoch:  10| Train loss: 0.43291| roc_auc: 0.84620: 100%|██████████| 1/1 [02:02<00:00, 122.22s/it]
OOF-4| Epoch:  11| Train loss: 0.43139| roc_auc: 0.84427: 100%|██████████| 1/1 [02:03<00:00, 123.91s/it]
OOF-4| Epoch:  12| Train loss: 0.42975| roc_auc: 0.84544: 100%|██████████| 1/1 [02:00<00:00, 120.78s/it]
OOF-4| Epoch:  13| Train loss: 0.42908| roc_auc: 0.84650: 100%|██████████| 1/1 [02:00<00:00, 120.85s/it]
OOF-4| Epoch:  14| Train loss: 0.42732| roc_auc: 0.84850: 100%|██████████| 1/1 [02:01<00:00, 121.80s/it]
OOF-4| Epoch:  15| Train loss: 0.42605| roc_auc: 0.84544: 100%|██████████| 1/1 [02:02<00:00, 122.84s/it]
OOF-4| Epoch:  16| Train loss: 0.42498| roc_auc: 0.84496: 100%|██████████| 1/1 [02:01<00:00, 121.84s/it]
OOF-4| Epoch:  17| Train loss: 0.42330| roc_auc: 0.84284: 100%|██████████| 1/1 [02:03<00:00, 123.95s/it]
OOF-4| Epoch:  18| Train loss: 0.42218| roc_auc: 0.84292: 100%|██████████| 1/1 [02:01<00:00, 121.40s/it]
OOF-4| Epoch:  19| Train loss: 0.42094| roc_auc: 0.84504: 100%|██████████| 1/1 [02:02<00:00, 122.80s/it]
OOF-4| Epoch:  20| Train loss: 0.42022| roc_auc: 0.84301: 100%|██████████| 1/1 [02:03<00:00, 123.39s/it]
BEST OOF-4| Epoch:  14| Train loss: 0.42732| roc_auc: 0.84850
Total roc_auc: 0.84769

```

```
num-head-2

OOF-0| Epoch:   1| Train loss: 0.50858| roc_auc: 0.82706: 100%|██████████| 1/1 [02:03<00:00, 123.25s/it]
OOF-0| Epoch:   2| Train loss: 0.46791| roc_auc: 0.84448: 100%|██████████| 1/1 [02:02<00:00, 122.69s/it]
OOF-0| Epoch:   3| Train loss: 0.45332| roc_auc: 0.84909: 100%|██████████| 1/1 [02:03<00:00, 123.32s/it]
OOF-0| Epoch:   4| Train loss: 0.44674| roc_auc: 0.85422: 100%|██████████| 1/1 [02:02<00:00, 122.90s/it]
OOF-0| Epoch:   5| Train loss: 0.44278| roc_auc: 0.85371: 100%|██████████| 1/1 [02:02<00:00, 122.54s/it]
OOF-0| Epoch:   6| Train loss: 0.44034| roc_auc: 0.85612: 100%|██████████| 1/1 [02:01<00:00, 121.98s/it]
OOF-0| Epoch:   7| Train loss: 0.43729| roc_auc: 0.85691: 100%|██████████| 1/1 [02:02<00:00, 122.18s/it]
OOF-0| Epoch:   8| Train loss: 0.43586| roc_auc: 0.85380: 100%|██████████| 1/1 [02:03<00:00, 123.93s/it]
OOF-0| Epoch:   9| Train loss: 0.43391| roc_auc: 0.85510: 100%|██████████| 1/1 [02:05<00:00, 125.92s/it]
OOF-0| Epoch:  10| Train loss: 0.43245| roc_auc: 0.85718: 100%|██████████| 1/1 [02:02<00:00, 122.81s/it]
OOF-0| Epoch:  11| Train loss: 0.43095| roc_auc: 0.85469: 100%|██████████| 1/1 [02:01<00:00, 121.42s/it]
OOF-0| Epoch:  12| Train loss: 0.42998| roc_auc: 0.85344: 100%|██████████| 1/1 [02:03<00:00, 123.49s/it]
OOF-0| Epoch:  13| Train loss: 0.42843| roc_auc: 0.85439: 100%|██████████| 1/1 [02:00<00:00, 120.31s/it]
OOF-0| Epoch:  14| Train loss: 0.42698| roc_auc: 0.85367: 100%|██████████| 1/1 [02:00<00:00, 120.31s/it]
OOF-0| Epoch:  15| Train loss: 0.42620| roc_auc: 0.85526: 100%|██████████| 1/1 [02:01<00:00, 121.08s/it]
OOF-0| Epoch:  16| Train loss: 0.42500| roc_auc: 0.85509: 100%|██████████| 1/1 [01:58<00:00, 118.33s/it]
OOF-0| Epoch:  17| Train loss: 0.42390| roc_auc: 0.85435: 100%|██████████| 1/1 [01:59<00:00, 119.80s/it]
OOF-0| Epoch:  18| Train loss: 0.42260| roc_auc: 0.85455: 100%|██████████| 1/1 [01:57<00:00, 117.17s/it]
OOF-0| Epoch:  19| Train loss: 0.42165| roc_auc: 0.85624: 100%|██████████| 1/1 [01:56<00:00, 116.83s/it]
OOF-0| Epoch:  20| Train loss: 0.42004| roc_auc: 0.85612: 100%|██████████| 1/1 [01:56<00:00, 116.26s/it]
BEST OOF-0| Epoch:  10| Train loss: 0.43245| roc_auc: 0.85718

OOF-1| Epoch:   1| Train loss: 0.51246| roc_auc: 0.81910: 100%|██████████| 1/1 [02:01<00:00, 121.98s/it]
OOF-1| Epoch:   2| Train loss: 0.46823| roc_auc: 0.84305: 100%|██████████| 1/1 [02:02<00:00, 122.80s/it]
OOF-1| Epoch:   3| Train loss: 0.45340| roc_auc: 0.84986: 100%|██████████| 1/1 [02:01<00:00, 121.49s/it]
OOF-1| Epoch:   4| Train loss: 0.44746| roc_auc: 0.85024: 100%|██████████| 1/1 [01:59<00:00, 119.15s/it]
OOF-1| Epoch:   5| Train loss: 0.44332| roc_auc: 0.85493: 100%|██████████| 1/1 [01:57<00:00, 117.11s/it]
OOF-1| Epoch:   6| Train loss: 0.44067| roc_auc: 0.85214: 100%|██████████| 1/1 [01:57<00:00, 117.35s/it]
OOF-1| Epoch:   7| Train loss: 0.43816| roc_auc: 0.85739: 100%|██████████| 1/1 [01:55<00:00, 115.40s/it]
OOF-1| Epoch:   8| Train loss: 0.43633| roc_auc: 0.85651: 100%|██████████| 1/1 [02:02<00:00, 122.66s/it]
OOF-1| Epoch:   9| Train loss: 0.43419| roc_auc: 0.85588: 100%|██████████| 1/1 [01:59<00:00, 119.82s/it]
OOF-1| Epoch:  10| Train loss: 0.43259| roc_auc: 0.85834: 100%|██████████| 1/1 [01:58<00:00, 118.95s/it]
OOF-1| Epoch:  11| Train loss: 0.43123| roc_auc: 0.85700: 100%|██████████| 1/1 [01:58<00:00, 118.09s/it]
OOF-1| Epoch:  12| Train loss: 0.43013| roc_auc: 0.85617: 100%|██████████| 1/1 [02:05<00:00, 125.98s/it]
OOF-1| Epoch:  13| Train loss: 0.42880| roc_auc: 0.85912: 100%|██████████| 1/1 [01:59<00:00, 119.05s/it]
OOF-1| Epoch:  14| Train loss: 0.42790| roc_auc: 0.85804: 100%|██████████| 1/1 [01:58<00:00, 118.77s/it]
OOF-1| Epoch:  15| Train loss: 0.42638| roc_auc: 0.85811: 100%|██████████| 1/1 [02:01<00:00, 121.06s/it]
OOF-1| Epoch:  16| Train loss: 0.42547| roc_auc: 0.85898: 100%|██████████| 1/1 [01:59<00:00, 119.59s/it]
OOF-1| Epoch:  17| Train loss: 0.42433| roc_auc: 0.85918: 100%|██████████| 1/1 [02:01<00:00, 121.23s/it]
OOF-1| Epoch:  18| Train loss: 0.42331| roc_auc: 0.85816: 100%|██████████| 1/1 [01:58<00:00, 118.67s/it]
OOF-1| Epoch:  19| Train loss: 0.42187| roc_auc: 0.85730: 100%|██████████| 1/1 [02:02<00:00, 122.19s/it]
OOF-1| Epoch:  20| Train loss: 0.42127| roc_auc: 0.85648: 100%|██████████| 1/1 [02:00<00:00, 120.38s/it]
BEST OOF-1| Epoch:  17| Train loss: 0.42433| roc_auc: 0.85918

OOF-2| Epoch:   1| Train loss: 0.51142| roc_auc: 0.79601: 100%|██████████| 1/1 [01:58<00:00, 118.50s/it]
OOF-2| Epoch:   2| Train loss: 0.46885| roc_auc: 0.81396: 100%|██████████| 1/1 [01:59<00:00, 119.01s/it]
OOF-2| Epoch:   3| Train loss: 0.45433| roc_auc: 0.82116: 100%|██████████| 1/1 [01:58<00:00, 118.16s/it]
OOF-2| Epoch:   4| Train loss: 0.44758| roc_auc: 0.82253: 100%|██████████| 1/1 [01:55<00:00, 115.09s/it]
OOF-2| Epoch:   5| Train loss: 0.44414| roc_auc: 0.82432: 100%|██████████| 1/1 [01:58<00:00, 118.77s/it]
OOF-2| Epoch:   6| Train loss: 0.44046| roc_auc: 0.82371: 100%|██████████| 1/1 [02:00<00:00, 120.85s/it]
OOF-2| Epoch:   7| Train loss: 0.43899| roc_auc: 0.82584: 100%|██████████| 1/1 [01:56<00:00, 116.90s/it]
OOF-2| Epoch:   8| Train loss: 0.43688| roc_auc: 0.82684: 100%|██████████| 1/1 [01:57<00:00, 117.17s/it]
OOF-2| Epoch:   9| Train loss: 0.43528| roc_auc: 0.82799: 100%|██████████| 1/1 [01:58<00:00, 118.52s/it]
OOF-2| Epoch:  10| Train loss: 0.43384| roc_auc: 0.82661: 100%|██████████| 1/1 [02:00<00:00, 120.03s/it]
OOF-2| Epoch:  11| Train loss: 0.43257| roc_auc: 0.83009: 100%|██████████| 1/1 [01:57<00:00, 117.98s/it]
OOF-2| Epoch:  12| Train loss: 0.43140| roc_auc: 0.82843: 100%|██████████| 1/1 [01:59<00:00, 119.56s/it]
OOF-2| Epoch:  13| Train loss: 0.43006| roc_auc: 0.82755: 100%|██████████| 1/1 [01:56<00:00, 116.50s/it]
OOF-2| Epoch:  14| Train loss: 0.42855| roc_auc: 0.83049: 100%|██████████| 1/1 [01:57<00:00, 117.12s/it]
OOF-2| Epoch:  15| Train loss: 0.42716| roc_auc: 0.82899: 100%|██████████| 1/1 [01:56<00:00, 116.98s/it]
OOF-2| Epoch:  16| Train loss: 0.42623| roc_auc: 0.82861: 100%|██████████| 1/1 [01:58<00:00, 118.81s/it]
OOF-2| Epoch:  17| Train loss: 0.42522| roc_auc: 0.82610: 100%|██████████| 1/1 [01:53<00:00, 113.30s/it]
OOF-2| Epoch:  18| Train loss: 0.42364| roc_auc: 0.82639: 100%|██████████| 1/1 [01:55<00:00, 115.48s/it]
OOF-2| Epoch:  19| Train loss: 0.42318| roc_auc: 0.82957: 100%|██████████| 1/1 [01:56<00:00, 116.82s/it]
OOF-2| Epoch:  20| Train loss: 0.42144| roc_auc: 0.82631: 100%|██████████| 1/1 [01:59<00:00, 119.69s/it]
BEST OOF-2| Epoch:  14| Train loss: 0.42855| roc_auc: 0.83049

OOF-3| Epoch:   1| Train loss: 0.50971| roc_auc: 0.79570: 100%|██████████| 1/1 [02:02<00:00, 122.54s/it]
OOF-3| Epoch:   2| Train loss: 0.46781| roc_auc: 0.82055: 100%|██████████| 1/1 [02:04<00:00, 124.66s/it]
OOF-3| Epoch:   3| Train loss: 0.45411| roc_auc: 0.82631: 100%|██████████| 1/1 [02:02<00:00, 122.37s/it]
OOF-3| Epoch:   4| Train loss: 0.44730| roc_auc: 0.82849: 100%|██████████| 1/1 [02:02<00:00, 122.98s/it]
OOF-3| Epoch:   5| Train loss: 0.44350| roc_auc: 0.83021: 100%|██████████| 1/1 [02:03<00:00, 123.41s/it]
OOF-3| Epoch:   6| Train loss: 0.44060| roc_auc: 0.83509: 100%|██████████| 1/1 [02:01<00:00, 121.74s/it]
OOF-3| Epoch:   7| Train loss: 0.43862| roc_auc: 0.83220: 100%|██████████| 1/1 [02:01<00:00, 121.48s/it]
OOF-3| Epoch:   8| Train loss: 0.43681| roc_auc: 0.83439: 100%|██████████| 1/1 [02:02<00:00, 122.34s/it]
OOF-3| Epoch:   9| Train loss: 0.43458| roc_auc: 0.83282: 100%|██████████| 1/1 [02:02<00:00, 122.76s/it]
OOF-3| Epoch:  10| Train loss: 0.43325| roc_auc: 0.83455: 100%|██████████| 1/1 [02:03<00:00, 123.84s/it]
OOF-3| Epoch:  11| Train loss: 0.43151| roc_auc: 0.83559: 100%|██████████| 1/1 [02:02<00:00, 122.65s/it]
OOF-3| Epoch:  12| Train loss: 0.43006| roc_auc: 0.83558: 100%|██████████| 1/1 [02:02<00:00, 122.41s/it]
OOF-3| Epoch:  13| Train loss: 0.42861| roc_auc: 0.83806: 100%|██████████| 1/1 [01:59<00:00, 119.82s/it]
OOF-3| Epoch:  14| Train loss: 0.42748| roc_auc: 0.83684: 100%|██████████| 1/1 [02:01<00:00, 121.69s/it]
OOF-3| Epoch:  15| Train loss: 0.42606| roc_auc: 0.83885: 100%|██████████| 1/1 [02:01<00:00, 121.51s/it]
OOF-3| Epoch:  16| Train loss: 0.42530| roc_auc: 0.83961: 100%|██████████| 1/1 [02:02<00:00, 122.61s/it]
OOF-3| Epoch:  17| Train loss: 0.42400| roc_auc: 0.83737: 100%|██████████| 1/1 [02:00<00:00, 120.90s/it]
OOF-3| Epoch:  18| Train loss: 0.42272| roc_auc: 0.83888: 100%|██████████| 1/1 [02:01<00:00, 121.95s/it]
OOF-3| Epoch:  19| Train loss: 0.42148| roc_auc: 0.83861: 100%|██████████| 1/1 [02:01<00:00, 121.19s/it]
OOF-3| Epoch:  20| Train loss: 0.42069| roc_auc: 0.83983: 100%|██████████| 1/1 [02:01<00:00, 121.66s/it]
BEST OOF-3| Epoch:  20| Train loss: 0.42069| roc_auc: 0.83983

OOF-4| Epoch:   1| Train loss: 0.50790| roc_auc: 0.80640: 100%|██████████| 1/1 [02:01<00:00, 121.78s/it]
OOF-4| Epoch:   2| Train loss: 0.46806| roc_auc: 0.82975: 100%|██████████| 1/1 [02:01<00:00, 121.26s/it]
OOF-4| Epoch:   3| Train loss: 0.45404| roc_auc: 0.83418: 100%|██████████| 1/1 [02:01<00:00, 121.38s/it]
OOF-4| Epoch:   4| Train loss: 0.44747| roc_auc: 0.83818: 100%|██████████| 1/1 [02:02<00:00, 122.35s/it]
OOF-4| Epoch:   5| Train loss: 0.44355| roc_auc: 0.83937: 100%|██████████| 1/1 [01:59<00:00, 119.69s/it]
OOF-4| Epoch:   6| Train loss: 0.44079| roc_auc: 0.84145: 100%|██████████| 1/1 [02:00<00:00, 120.66s/it]
OOF-4| Epoch:   7| Train loss: 0.43819| roc_auc: 0.84354: 100%|██████████| 1/1 [02:04<00:00, 124.13s/it]
OOF-4| Epoch:   8| Train loss: 0.43658| roc_auc: 0.84237: 100%|██████████| 1/1 [02:01<00:00, 121.35s/it]
OOF-4| Epoch:   9| Train loss: 0.43476| roc_auc: 0.84456: 100%|██████████| 1/1 [02:00<00:00, 120.28s/it]
OOF-4| Epoch:  10| Train loss: 0.43347| roc_auc: 0.84327: 100%|██████████| 1/1 [02:02<00:00, 122.39s/it]
OOF-4| Epoch:  11| Train loss: 0.43257| roc_auc: 0.84308: 100%|██████████| 1/1 [02:00<00:00, 120.90s/it]
OOF-4| Epoch:  12| Train loss: 0.43087| roc_auc: 0.84594: 100%|██████████| 1/1 [01:59<00:00, 119.75s/it]
OOF-4| Epoch:  13| Train loss: 0.43005| roc_auc: 0.84592: 100%|██████████| 1/1 [02:01<00:00, 121.57s/it]
OOF-4| Epoch:  14| Train loss: 0.42850| roc_auc: 0.84861: 100%|██████████| 1/1 [02:00<00:00, 120.06s/it]
OOF-4| Epoch:  15| Train loss: 0.42722| roc_auc: 0.84901: 100%|██████████| 1/1 [02:00<00:00, 120.92s/it]
OOF-4| Epoch:  16| Train loss: 0.42609| roc_auc: 0.84590: 100%|██████████| 1/1 [02:00<00:00, 120.62s/it]
OOF-4| Epoch:  17| Train loss: 0.42477| roc_auc: 0.84602: 100%|██████████| 1/1 [02:02<00:00, 122.59s/it]
OOF-4| Epoch:  18| Train loss: 0.42326| roc_auc: 0.84428: 100%|██████████| 1/1 [02:00<00:00, 120.75s/it]
OOF-4| Epoch:  19| Train loss: 0.42234| roc_auc: 0.84595: 100%|██████████| 1/1 [02:01<00:00, 121.15s/it]
OOF-4| Epoch:  20| Train loss: 0.42147| roc_auc: 0.84370: 100%|██████████| 1/1 [01:59<00:00, 119.21s/it]
BEST OOF-4| Epoch:  15| Train loss: 0.42722| roc_auc: 0.84901

Total roc_auc: 0.84714


```

# 예측

In [14]:
test_df = make_dataset.get_test_data()
test_dataset = CustomDataset(df = test_df,)
test_data_loader = DataLoader(
    test_dataset,
    batch_size = 1, 
    shuffle = False, 
    drop_last = False,
    collate_fn = train_make_batch,
    num_workers = num_workers)

pred_list = []

model = SASRec(
    num_assessmentItemID = make_dataset.num_assessmentItemID, 
    num_testId = make_dataset.num_testId,
    num_KnowledgeTag = make_dataset.num_KnowledgeTag,
    num_large_paper_number = make_dataset.num_large_paper_number,
    num_cols = train_dataset.num_cols,
    cat_cols = train_dataset.cat_cols,
    emb_size = emb_size, 
    hidden_units = hidden_units, 
    num_heads = num_heads, 
    num_layers = num_layers, 
    dropout_rate = dropout_rate, 
    device = device).to(device)

for oof in make_dataset.oof_user_set.keys():
    model.load_state_dict(torch.load(os.path.join(MODEL_PATH, f'oof_{oof}_' + model_name)))
    pred = predict(model = model, data_loader = test_data_loader)
    pred_list.append(pred)

pred_list = np.array(pred_list).mean(axis = 0)

In [15]:
submission = pd.DataFrame(data = np.array(pred_list), columns = ['prediction'])
submission['id'] = submission.index
submission = submission[['id', 'prediction']]
submission.to_csv(os.path.join(SUBMISSION_PATH, 'OOF-Ensemble-' + submission_name), index = False)