In [None]:
import time
import pickle
import pandas as pd
import torch
import torch.nn as nn
import numpy as np
import torch.optim as optim
from parser import parser
from model import SaintPlus, NoamOpt
from torch.utils.data import DataLoader
from data_generator import Riiid_Sequence
from sklearn.metrics import roc_auc_score

from sklearn.model_selection import train_test_split

In [None]:
# Train on the GPU when one is visible to torch, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

In [None]:
# --- Embedding vocabulary sizes ---
# NOTE(review): these presumably must cover the id ranges in the preprocessed group data — confirm.
n_problems = 9454
n_tests = 1537
n_categories = 9
n_tags = 912

# --- Transformer architecture ---
num_layers = 2
num_heads = 4
d_model = 128
d_ffn = d_model*4       # feed-forward width: 4x the model dimension

# --- Training ---
seq_len = 100
warmup_steps = 4000     # warm-up steps passed to NoamOpt
dropout = 0.1
epochs = 1000
patience = 100          # epochs without a validation-AUC improvement before early stopping
batch_size = 128

fold_n = 10             # number of K-fold cross-validation splits

# Seed the global RNGs so runs are reproducible:
#  - KFold(shuffle=True) with no random_state draws from NumPy's global RNG;
#  - DataLoader(shuffle=True) with no explicit generator draws from torch's default generator.
# NOTE(review): if Riiid_Sequence samples with Python's `random` module, seed that as well.
seed = 42
np.random.seed(seed)
torch.manual_seed(seed)

In [None]:
# Load the grouped training data consumed by the K-fold loop below.
# NOTE(review): despite the ".zip" suffix, the file is read with plain pickle.load, so it
# presumably is an uncompressed pickle — confirm. pickle.load can execute arbitrary code;
# only load files from a trusted source.
with open("/opt/ml/input/data/train_group.pkl.zip", mode="rb") as pickled_file:
    group = pickle.load(pickled_file)

In [None]:
from sklearn.model_selection import KFold

In [None]:
# Shuffled K-fold splitter over the group index.
# NOTE(review): with no random_state, the shuffle draws from NumPy's global RNG, so splits are only
# reproducible if np.random has been seeded beforehand.
kf = KFold(fold_n, shuffle=True)

In [None]:
# K-fold cross-validation of SAINT+. For each fold: train a fresh model with the NoamOpt-wrapped
# Adam optimizer, evaluate masked BCE loss and ROC-AUC on the held-out groups every epoch,
# checkpoint whenever validation AUC improves, and stop early after `patience` epochs without
# improvement. The best AUC of each fold is recorded in `fold_best_auc`.
fold_best_auc = [0] * fold_n


def _masked_forward(model, data):
    """Move one batch to `device`, run `model`, and keep only positions whose answer is non-zero.

    Batch fields are indexed positionally: 0 category, 1 category-test, 2 category-test-problem,
    3 problem tag, 6 answer (cast to long); 4 problem time, 5 break time, 7 label (cast to float).
    NOTE(review): answer == 0 presumably marks padding — confirm against Riiid_Sequence.

    Returns:
        (preds_masked, label_masked): 1-D tensors produced by torch.masked_select.
    """
    category_sample = data[0].to(device).long()
    category_test_sample = data[1].to(device).long()
    category_test_problem_sample = data[2].to(device).long()
    problem_tag_sample = data[3].to(device).long()
    problem_time_sample = data[4].to(device).float()
    break_time_sample = data[5].to(device).float()
    answer_sample = data[6].to(device).long()
    label = data[7].to(device).float()

    # Argument order (break time before problem time) is kept exactly as in the original calls.
    # NOTE(review): confirm this matches the SaintPlus forward signature.
    preds = model(category_sample, category_test_sample, category_test_problem_sample,
                  problem_tag_sample, break_time_sample, problem_time_sample, answer_sample)
    loss_mask = answer_sample != 0
    return torch.masked_select(preds, loss_mask), torch.masked_select(label, loss_mask)


for i, (train_index, valid_index) in enumerate(kf.split(group.index)):
    print("==========Fold {} / {}==========".format(i + 1, fold_n))
    train_group = group[group.index[train_index]]
    val_group = group[group.index[valid_index]]

    train_loader = DataLoader(Riiid_Sequence(train_group, seq_len), batch_size=batch_size, shuffle=True, num_workers=8)
    val_loader = DataLoader(Riiid_Sequence(val_group, seq_len), batch_size=batch_size, shuffle=False, num_workers=8)

    loss_fn = nn.BCELoss()
    model = SaintPlus(seq_len=seq_len, num_layers=num_layers, d_ffn=d_ffn, d_model=d_model, num_heads=num_heads,
                      n_problems=n_problems, n_tests=n_tests, n_categories=n_categories, n_tags=n_tags,
                      dropout=dropout)
    # Move the model to its device *before* building the optimizer (PyTorch's recommended order).
    model.to(device)
    loss_fn.to(device)
    # Use the configured warm-up rather than a duplicated literal.
    optimizer = NoamOpt(d_model, 1, warmup_steps, optim.Adam(model.parameters(), lr=0))

    train_losses = []
    val_losses = []
    val_aucs = []
    best_auc = 0
    count = 0  # consecutive epochs without a validation-AUC improvement
    for e in range(epochs):
        print("==========Epoch {} Start Training==========".format(e+1))
        model.train()
        t_s = time.time()
        train_loss = []
        for data in train_loader:
            optimizer.optimizer.zero_grad()
            preds_masked, label_masked = _masked_forward(model, data)
            loss = loss_fn(preds_masked, label_masked)
            loss.backward()
            optimizer.step()
            train_loss.append(loss.item())
        train_loss = np.mean(train_loss)

        print("==========Epoch {} Start Validation==========".format(e+1))
        model.eval()
        val_loss = []
        val_labels = []
        val_preds = []
        # Evaluation needs no gradients; no_grad avoids building the autograd graph.
        with torch.no_grad():
            for data in val_loader:
                preds_masked, label_masked = _masked_forward(model, data)
                # Compute the loss on the validation batch (the original appended the stale
                # training loss here, so the reported validation loss was meaningless).
                val_loss.append(loss_fn(preds_masked, label_masked).item())
                val_labels.append(label_masked.cpu().numpy())
                val_preds.append(preds_masked.cpu().numpy())

        val_loss = np.mean(val_loss)
        val_auc = roc_auc_score(np.concatenate(val_labels), np.concatenate(val_preds))

        if val_auc > best_auc:
            print("Save model at epoch {}".format(e+1))
            # NOTE(review): every fold writes this same path, so after cross-validation the file
            # holds only the last fold's best model; use a per-fold name if all folds are needed.
            torch.save(model.state_dict(), "./saint.pt")
            best_auc = val_auc
            fold_best_auc[i] = best_auc
            count = 0
        else:
            count += 1

        train_losses.append(train_loss)
        val_losses.append(val_loss)
        val_aucs.append(val_auc)
        exec_t = int((time.time() - t_s)/60)
        print("Train Loss {:.4f}/ Val Loss {:.4f}, Val AUC {:.4f} / Exec time {} min".format(train_loss, val_loss, val_auc, exec_t))
        if count >= patience:
            print('Early Stopping!')
            break

# Mean of the per-fold best validation AUCs (the original called an undefined `mean`).
print(np.mean(fold_best_auc))