In [None]:
import os
import pickle
import random
import time

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import wandb
from sklearn.metrics import roc_auc_score
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader

from data_generator import Riiid_Sequence
from model import SaintPlus, NoamOpt
from parser import parser

In [None]:
# Single seed reused for torch / random / numpy seeding and for the train/valid split below.
seed: int = 42

In [None]:
# Seed every RNG this notebook draws from (Python, PyTorch, NumPy) from the shared `seed`.
# np.random.seed stays last so the cell's final expression is None (no stray output).
random.seed(seed)
torch.manual_seed(seed)
np.random.seed(seed)

In [None]:
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [None]:
# Vocabulary sizes handed to SaintPlus (n_problems, n_tests, n_categories, n_tags).
# NOTE(review): presumably the number of distinct ids per feature in this dataset (including
# any padding id) — confirm against the preprocessing that produced train_group.pkl.zip.
n_problems, n_tests, n_categories, n_tags = 9454, 1537, 9, 912

In [None]:
# Load the pre-grouped training data; `group` is split into train/valid further down.
# NOTE(review): despite the ".zip" suffix the file is read as a plain pickle stream — confirm
# how it was written (pandas `to_pickle` would zip-compress it based on that suffix).
# SECURITY: pickle.load can execute arbitrary code; only load files from a trusted source.
group_pickle_path = "/opt/ml/input/data/train_group.pkl.zip"
with open(group_pickle_path, "rb") as pickle_file:
    group = pickle.load(pickle_file)

In [None]:
# Split the index labels of `group` 80/20 into train/valid, reproducibly via `seed`.
# NOTE(review): presumably one entry per user, so the split is user-disjoint — confirm.
train_index, valid_index = train_test_split(
    group.index,
    random_state=seed,
    train_size=0.8,
)

In [None]:
# train_index / valid_index are *labels* taken from group.index, so select them with the
# explicit label indexer .loc (plain `group[...]` would select columns if group were a DataFrame).
train_group = group.loc[train_index]
val_group = group.loc[valid_index]

In [None]:
# W&B sweep search space. "d_ffn" is a multiplier: the model's FFN width is d_model * d_ffn.
_sweep_parameters = dict(
    # architecture
    num_layers={"values": [1, 2, 4]},
    num_heads={"values": [2, 4, 8]},
    d_model={"values": [64, 128, 256]},
    d_ffn={"values": [2, 3, 4, 5, 6]},
    # sequence length / schedule / regularisation
    seq_len={"distribution": "int_uniform", "min": 10, "max": 500},
    warmup_steps={"distribution": "int_uniform", "min": 100, "max": 10000},
    dropout={"distribution": "uniform", "min": 0, "max": 0.9},
    # optimisation
    lr={"distribution": "uniform", "min": 0.001, "max": 0.01},
    batch_size={"values": [64, 128]},
    epochs={"value": 100},
    patience={"value": 10},
)

# Bayesian search that maximises the "auc" value logged at the end of each trial.
sweep_configuration = dict(
    method="bayes",
    metric=dict(goal="maximize", name="auc"),
    parameters=_sweep_parameters,
)

In [None]:
def _batch_to_device(data):
    """Move one DataLoader batch to `device`, casting each field to the dtype fed to SaintPlus.

    Returns the tuple (category, category_test, category_test_problem, problem_tag,
    problem_time, break_time, answer, label) built from data[0] .. data[7].
    """
    return (
        data[0].to(device).long(),   # category
        data[1].to(device).long(),   # category_test
        data[2].to(device).long(),   # category_test_problem
        data[3].to(device).long(),   # problem_tag
        data[4].to(device).float(),  # problem_time
        data[5].to(device).float(),  # break_time
        data[6].to(device).long(),   # answer
        data[7].to(device).float(),  # label
    )


def _masked_predictions(model, batch):
    """Run `model` on a device batch and keep only positions whose answer id is non-zero.

    NOTE(review): answer == 0 presumably marks padding from Riiid_Sequence — confirm.
    Returns two 1-D tensors (predictions, labels) of equal length.
    """
    (category, category_test, category_test_problem, problem_tag,
     problem_time, break_time, answer, label) = batch
    # SaintPlus receives break_time *before* problem_time — keep this argument order.
    preds = model(category, category_test, category_test_problem, problem_tag,
                  break_time, problem_time, answer)
    mask = answer != 0
    return torch.masked_select(preds, mask), torch.masked_select(label, mask)


def objective_function():
    """Train and validate one SaintPlus model for the current W&B sweep trial.

    Hyper-parameters come from ``wandb.config`` (keys of ``sweep_configuration``). Trains with
    early stopping on validation AUC, keeps the best weights in ``./saintplus.pt`` and finally
    renames that file to ``./saintplus_<best_auc>.pt``. Logs per-epoch train_loss / val_loss /
    val_auc scalars and the final ``auc`` that the sweep maximises.
    """
    wandb.init()
    args = wandb.config

    train_loader = DataLoader(Riiid_Sequence(train_group, args.seq_len), batch_size=args.batch_size,
                              shuffle=True, num_workers=8)
    val_loader = DataLoader(Riiid_Sequence(val_group, args.seq_len), batch_size=args.batch_size,
                            shuffle=False, num_workers=8)

    loss_fn = nn.BCELoss()
    model = SaintPlus(seq_len=args.seq_len, num_layers=args.num_layers,
                      d_ffn=args.d_model * args.d_ffn,  # d_ffn in the sweep is a multiplier
                      d_model=args.d_model, num_heads=args.num_heads,
                      n_problems=n_problems, n_tests=n_tests, n_categories=n_categories, n_tags=n_tags,
                      dropout=args.dropout)
    optimizer = NoamOpt(args.d_model, 1, args.warmup_steps, optim.Adam(model.parameters(), lr=args.lr))
    model.to(device)
    loss_fn.to(device)

    checkpoint_path = "./saintplus.pt"
    best_auc = 0
    count = 0  # consecutive epochs without a validation-AUC improvement
    for e in range(args.epochs):
        # ---- training ----
        model.train()
        batch_losses = []
        for data in train_loader:
            optimizer.optimizer.zero_grad()
            preds_masked, label_masked = _masked_predictions(model, _batch_to_device(data))
            loss = loss_fn(preds_masked, label_masked)
            loss.backward()
            optimizer.step()
            batch_losses.append(loss.item())
        train_loss = np.mean(batch_losses)

        # ---- validation (no autograd graph needed) ----
        model.eval()
        val_batch_losses = []
        val_labels = []
        val_preds = []
        with torch.no_grad():
            for data in val_loader:
                preds_masked, label_masked = _masked_predictions(model, _batch_to_device(data))
                # Validation loss must be computed on the validation batch itself.
                val_batch_losses.append(loss_fn(preds_masked, label_masked).item())
                val_labels.extend(label_masked.cpu().numpy())
                val_preds.extend(preds_masked.cpu().numpy())
        val_loss = np.mean(val_batch_losses)
        val_auc = roc_auc_score(val_labels, val_preds)

        if val_auc > best_auc:
            best_auc = val_auc
            torch.save(model.state_dict(), checkpoint_path)
            count = 0
        else:
            count += 1

        # Log this epoch's scalars so W&B plots them against the step/epoch.
        wandb.log({
            "epoch": e,
            "train_loss": train_loss,
            "val_loss": val_loss,
            "val_auc": val_auc,
        })
        if count >= args.patience:
            print('Early Stopping!')
            break

    wandb.log({"auc": best_auc})
    # best_auc is a float, so format it into the file name; skip if no checkpoint was written.
    if os.path.exists(checkpoint_path):
        os.replace(checkpoint_path, f"./saintplus_{best_auc}.pt")

In [None]:
# Register the sweep with W&B under the "saintplus" project; the returned id drives the agent.
sweep_id = wandb.sweep(
    sweep_configuration,
    project="saintplus",
)

In [None]:
# Launch a W&B agent in this process that runs up to 50 trials of objective_function.
wandb.agent(
    sweep_id,
    function=objective_function,
    count=50,
)