<a href="https://colab.research.google.com/github/bagherig/riiid/blob/main/riiid_transformer.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
gpu_info = !nvidia-smi
gpu_info = '\n'.join(gpu_info)
if gpu_info.find('failed') >= 0:
  print('Select the Runtime > "Change runtime type" menu to enable a GPU accelerator, ')
  print('and then re-execute this cell.')
else:
  print(gpu_info)

Fri Jan 15 03:27:07 2021       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 460.27.04    Driver Version: 418.67       CUDA Version: 10.1     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla V100-SXM2...  Off  | 00000000:00:04.0 Off |                    0 |
| N/A   38C    P0    44W / 300W |      0MiB / 16130MiB |      0%      Default |
|                               |                      |                 ERR! |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

In [2]:
from psutil import virtual_memory

# Total physical memory of the runtime, expressed in decimal gigabytes.
ram_gb = virtual_memory().total / 1e9
print('Your runtime has {:.1f} gigabytes of available RAM\n'.format(ram_gb))

if ram_gb >= 20:
  print('You are using a high-RAM runtime!')
else:
  print('To enable a high-RAM runtime, select the Runtime > "Change runtime type"')
  print('menu, and then select High-RAM in the Runtime shape dropdown. Then, ')
  print('re-execute this cell.')

Your runtime has 27.4 gigabytes of available RAM

You are using a high-RAM runtime!


# Global Variables

In [3]:
# %% [code]

COMPILE = False #@param {type:"boolean"}
VERSION = 'v1' #@param ['v1', 'v2']
# Placeholder value; the imports cell re-derives TPU from the environment.
TPU = True

# Working directory of the runtime and root of the project on Google Drive.
HOME_DIR = './'
DRIVE_DIR = ('/content/drive/MyDrive/Colab Notebooks/datasets'
             '/riiid-answer-correctness-prediction/')

# Locations derived from the two roots above.
DATA_DIR = f'{DRIVE_DIR}data/riiid-test-answer-prediction/'
DRIVE_PARQUETS_DIR = f'{DRIVE_DIR}data/parquets/'
PARQUETS_DIR = f'{HOME_DIR}parquets/'
MODELS_DIR = f'{DRIVE_DIR}models/'

# Checkpoints written by save_model() land here.
OUT_DIR = f'{DRIVE_DIR}temp/'

# Imports

In [4]:
# import riiideducation 

import os
import gc
import sys
import math
import copy
import random
import psutil
import warnings
from typing import Optional, Any

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from torch.nn.modules.module import Module
from torch.nn.modules.activation import MultiheadAttention
from torch.nn.modules.container import ModuleList
from torch.nn.modules.dropout import Dropout
from torch.nn.modules.linear import Linear
from torch.nn.modules.normalization import LayerNorm
from sklearn.metrics import roc_auc_score
from tqdm.notebook import tqdm

# Silence every Python warning for the rest of the session and make sure the
# garbage collector is switched on.
warnings.filterwarnings('ignore')
gc.enable()

# First run only: copy the zipped parquet files from Drive into the working
# directory, unpack them there and delete the archive afterwards.
if not os.path.exists(PARQUETS_DIR):
    sys.path.append(DRIVE_PARQUETS_DIR)
    zip_path = DRIVE_PARQUETS_DIR + 'parquets.zip'
    !cp '{zip_path}' $HOME_DIR
    !unzip -q 'parquets.zip'
    !rm 'parquets.zip'

# A TPU runtime is detected through the COLAB_TPU_ADDR environment variable;
# this overrides the TPU flag set in the "Global Variables" cell.
TPU = 'COLAB_TPU_ADDR' in os.environ
if TPU:
    # NOTE(review): the wheel below is built for CPython 3.6 / torch_xla 1.7;
    # it will need updating if the runtime's Python or torch version differs.
    !pip install cloud-tpu-client==0.10 https://storage.googleapis.com/tpu-pytorch/wheels/torch_xla-1.7-cp36-cp36m-linux_x86_64.whl
    import torch_xla.core.xla_model as xm
    import torch_xla.distributed.parallel_loader as pl
    import torch_xla.distributed.xla_multiprocessing as xmp
    
print(torch.__version__)

def seed_everything(seed):
    """Seed Python, NumPy and PyTorch (CPU and CUDA) with the same value.

    Also exports ``PYTHONHASHSEED`` and puts cuDNN into deterministic,
    non-benchmarking mode.

    Args:
        seed: Integer seed handed to every generator.
    """
    os.environ['PYTHONHASHSEED'] = str(seed)
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,  # covers the multi-GPU case
    )
    for seeder in seeders:
        seeder(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


RANDOM_SEED = 1
seed_everything(RANDOM_SEED)

1.7.0+cu101


In [5]:
# Column holding the label and column identifying a user's history.
TARGET = 'answered_correctly'
KEY_FEATURE = 'user_id'

# Model input columns; the training loops index batches in exactly this order.
# Optional columns that are currently switched off:
# 'task_container_id', 'tag3', 'tag4', 'tag5', 'tag6'.
FEATURES = [
    'content_id',
    'timestamp',
    'prior_question_elapsed_time',
    'prior_question_had_explanation',
    'part',
    'tag1',
    'tag2',
]

# Tensor dtype per feature; every active feature is an integer.
DTYPES = {feature: int for feature in FEATURES}

# For Inference
# (optional 'tag3'..'tag6' are switched off here as well)
ADDED_FEATURES = ['part', 'tag1', 'tag2']

# Fallbacks used when a value is missing.
DEFAULT_VALUES = dict(
    prior_question_had_explanation=False,
    prior_question_elapsed_time=26000,
    timestamp=0,
    part=0,
    tag1=-1,
    tag2=-1,
)

# Divisors applied to the raw time columns before they are rounded up to
# integer buckets (raw values are presumably milliseconds - TODO confirm).
TIME_SCALES = dict(
    prior_question_elapsed_time=1000,
    timestamp=1000 * 5,
)

SUBMISSION_COLUMNS = ['row_id', TARGET]

# Helper Functions

In [6]:
def prepare_data():
    """Build the merged training table and cache it as a parquet file.

    Reads ``train.parquet`` and ``questions.parquet`` from ``PARQUETS_DIR``,
    keeps only rows with ``content_type_id == 0`` and a label other than -1,
    joins the question metadata on ``content_id`` and replaces ``timestamp``
    with the per-user difference to the previous row. The original timestamp
    is preserved in ``timestamp_raw``.

    Returns:
        The merged ``pandas.DataFrame``; the same frame is written to
        ``PARQUETS_DIR + 'train_merged.parquet'``.
    """
    print("Reading datasets...")
    dt = pd.read_parquet(PARQUETS_DIR + 'train.parquet')
    questions = pd.read_parquet(PARQUETS_DIR + 'questions.parquet')

    print("Merging datasets...")
    dt = dt[(dt['content_type_id'] == 0) & (dt['answered_correctly'] != -1)] \
        .drop(columns=['content_type_id'])
    # The questions table is keyed by its index; expose it as a join column.
    questions.index.name = 'content_id'
    questions = questions.reset_index()
    dt = dt.merge(questions, on='content_id', how='left').reset_index(drop=True)

    print('Processing timestamps...')
    if 'timestamp' in dt.columns:
        timestamps_raw = dt['timestamp'].values.astype(int)
        dt['timestamp'] = dt[[KEY_FEATURE, 'timestamp']]\
            .groupby(KEY_FEATURE)['timestamp'].diff()
        # Rows that share a timestamp with their predecessor get a delta of 0;
        # replace each such delta with the value of the row above until none
        # remain. `.loc` is used (instead of chained indexing) so the write is
        # guaranteed to reach `dt` - with chained assignment under pandas
        # copy-on-write the frame would never change and this loop would
        # never terminate.
        zero_delta = dt['timestamp'] == 0
        while zero_delta.any():
            dt.loc[zero_delta, 'timestamp'] = \
                dt['timestamp'].shift()[zero_delta]
            zero_delta = dt['timestamp'] == 0
        # The first row of every user has no predecessor (NaN delta).
        dt['timestamp'] = dt['timestamp'].fillna(
            DEFAULT_VALUES['timestamp']).astype(int)
        dt['timestamp_raw'] = timestamps_raw

    print('Writing dataset to .parquet files...')
    dt.to_parquet(PARQUETS_DIR + 'train_merged.parquet')

    return dt


def load_data(filename, cols):
    """Read the columns ``cols`` of a parquet file stored under PARQUETS_DIR."""
    parquet_path = PARQUETS_DIR + filename
    return pd.read_parquet(parquet_path, columns=cols)


def preprocess(dt: pd.DataFrame):
    """Convert raw training columns into the integer ids the model embeds.

    The frame is modified in place and also returned.

    * ``content_id`` and every present tag / ``task_container_id`` column are
      shifted by +1.
    * ``prior_question_elapsed_time``: missing values take the default, then
      the column is divided by its time scale and rounded up to ``int32``.
    * ``prior_question_had_explanation``: missing values take the default.
    * ``timestamp``: divided by its time scale and rounded up to ``int32``.

    Args:
        dt: Training data; must contain ``content_id``.

    Returns:
        The same ``DataFrame`` object, after modification.
    """
    dt["content_id"] += 1
    if 'prior_question_elapsed_time' in dt.columns:
        dt["prior_question_elapsed_time"] = \
            dt["prior_question_elapsed_time"].fillna(
                DEFAULT_VALUES['prior_question_elapsed_time'])
        dt["prior_question_elapsed_time"] = \
            np.ceil(dt["prior_question_elapsed_time"] /
                    TIME_SCALES['prior_question_elapsed_time']).astype(np.int32)
    if 'prior_question_had_explanation' in dt.columns:
        # Assign the filled column back explicitly: an in-place fillna on a
        # column selection is not guaranteed to update the parent frame.
        dt['prior_question_had_explanation'] = \
            dt['prior_question_had_explanation'].fillna(
                DEFAULT_VALUES['prior_question_had_explanation'])
    if 'timestamp' in dt.columns:
        dt['timestamp'] = np.ceil(
            dt['timestamp'] / TIME_SCALES['timestamp']).astype(np.int32)

    for col in ["task_container_id",
                "tag1", "tag2", "tag3", "tag4", "tag5", "tag6"]:
        if col in dt.columns:
            dt[col] += 1

    return dt


def preprocess_test(dt: np.ndarray, all_cols):
    """Array counterpart of ``preprocess`` for inference-time rows.

    The array is modified in place and also returned.

    Args:
        dt: 2-D array with one row per interaction.
        all_cols: Mapping from column name to the column's index in ``dt``;
            must contain ``content_id``.

    Returns:
        The same array, after modification.
    """
    dt[:, all_cols["content_id"]] += 1
    # Guard on the elapsed-time column itself. The previous check looked for
    # 'prior_question_had_explanation', which skipped this step (or raised
    # KeyError) whenever only one of the two columns was present.
    if 'prior_question_elapsed_time' in all_cols:
        time_col = all_cols["prior_question_elapsed_time"]
        time_nans = pd.isnull(dt[:, time_col])
        dt[time_nans, time_col] = \
            DEFAULT_VALUES['prior_question_elapsed_time']
        dt[:, time_col] = \
            np.ceil(dt[:, time_col] /
                    TIME_SCALES['prior_question_elapsed_time']).astype(np.int32)
    if 'prior_question_had_explanation' in all_cols:
        explanation_col = all_cols["prior_question_had_explanation"]
        explanation_nans = pd.isnull(dt[:, explanation_col])
        dt[explanation_nans, explanation_col] = \
            DEFAULT_VALUES['prior_question_had_explanation']
    for col in ["task_container_id",
                "tag1", "tag2", "tag3", "tag4", "tag5", "tag6"]:
        if col in all_cols:
            dt[:, all_cols[col]] += 1
    return dt


def split_train_valid(dt, val_fraction):
    """Split ``dt`` into train/validation parts along user boundaries.

    Users are visited from the one with the most rows to the one with the
    fewest; each user goes to validation while the validation row count is
    still at or below its target share, otherwise to training.

    Args:
        dt: Frame containing the KEY_FEATURE and TARGET columns.
        val_fraction: Share of rows aimed at validation. ``1`` returns
            ``(None, dt)``.

    Returns:
        ``(train_frame, validation_frame)``.
    """
    if val_fraction == 1:
        return None, dt

    # Re-express the share as a validation:train ratio.
    ratio = val_fraction / (1 - val_fraction)
    rows_per_user = dt.groupby(KEY_FEATURE)[TARGET].count()
    ascending = rows_per_user.sort_values().reset_index().values.tolist()

    n_val_rows = 0
    n_trn_rows = 0
    val_users = []
    for uid, n_rows in reversed(ascending):
        if n_trn_rows * ratio >= n_val_rows:
            val_users.append(uid)
            n_val_rows += n_rows
        else:
            n_trn_rows += n_rows

    val = dt[dt[KEY_FEATURE].isin(val_users)]
    trn = dt.drop(val.index)
    return trn, val


def pad_batch(x, window_size, pad_value=0):
    """Left-pad ``x`` along its first axis up to ``window_size`` entries.

    Only the first axis grows; the fill value is ``pad_value``.
    """
    n_missing = window_size - x.shape[0]
    pad_width = [(n_missing, 0)] + [(0, 0)] * (x.ndim - 1)
    return np.pad(x, pad_width, constant_values=pad_value)


def create_scheduler(optim):
    """Return the learning-rate scheduler selected by SCHED_TYPE.

    Only ``'plateau'`` yields a scheduler (a ``ReduceLROnPlateau`` that
    tracks a metric to be maximised); any other setting returns ``None``.
    """
    if SCHED_TYPE != 'plateau':
        return None
    return torch.optim.lr_scheduler.ReduceLROnPlateau(
        optim,
        mode='max',
        factor=0.5,
        patience=0,
        threshold=5e-6,
        min_lr=1e-6,
    )


def create_optimizer(estimator, lr):
    """Build an Adam optimizer over all parameters of ``estimator``."""
    trainable = estimator.parameters()
    return torch.optim.Adam(trainable, lr=lr)


def create_model(model_type, **params):
    """Instantiate ``model_type`` with ``params`` and move it to DEVICE."""
    model = model_type(**params)
    return model.to(DEVICE)


def save_model(estimator, optim, sched=None, val_score:float=0, epoch=0):
    """Write a training checkpoint to ``OUT_DIR + MODEL_FILENAME``.

    Args:
        estimator: Model to save; must expose ``params`` (its constructor
            keyword arguments) and ``state_dict()``.
        optim: Optimizer whose state and current learning rate are stored.
        sched: Optional scheduler; its state is stored when given.
        val_score: Validation score achieved by this checkpoint.
        epoch: Epoch index the checkpoint belongs to.
    """
    # 'epoch' is always stored: resuming reads it regardless of whether a
    # scheduler was in use when the checkpoint was written.
    checkpoint = {'model_params': estimator.params,
                  'model_state_dict': estimator.state_dict(),
                  'optimizer_state_dict': optim.state_dict(),
                  'learning_rate': optim.param_groups[0]['lr'],
                  'val_score': val_score,
                  'epoch': epoch}
    if sched is not None:
        # The key spelling ('sheduler') is kept on purpose so that existing
        # checkpoints and load_model() stay compatible.
        checkpoint['sheduler_state_dict'] = sched.state_dict()
    torch.save(checkpoint, OUT_DIR + MODEL_FILENAME)


def load_model(model_type, for_training=False):
    """Restore a model (and optionally its training state) from a checkpoint.

    The checkpoint in ``OUT_DIR`` is preferred; if it does not exist the
    versioned copy under ``MODELS_DIR`` is used instead.

    Args:
        model_type: Class used to rebuild the model from the stored
            ``model_params``.
        for_training: When true, the optimizer and - if its state was saved
            and a scheduler is configured - the scheduler are restored too.

    Returns:
        ``(estimator, optim, sched, checkpoint)``; ``optim`` and ``sched``
        are ``None`` when they were not restored.
    """
    if os.path.exists(OUT_DIR + MODEL_FILENAME):
        filepath = OUT_DIR + MODEL_FILENAME
    else:
        filepath = MODELS_DIR + f'transformer/{VERSION}/{MODEL_FILENAME}'

    print(f'Loading model from {filepath}')
    checkpoint = torch.load(filepath, map_location=DEVICE)

    estimator = create_model(model_type, **checkpoint['model_params'])
    estimator.load_state_dict(checkpoint['model_state_dict'])

    optim = None
    sched = None
    if for_training:
        optim = create_optimizer(estimator, lr=checkpoint['learning_rate'])
        optim.load_state_dict(checkpoint['optimizer_state_dict'])

        if 'sheduler_state_dict' in checkpoint:
            sched = create_scheduler(optim)
            # create_scheduler() returns None when the configured schedule
            # has no scheduler object; there is then nothing to restore.
            if sched is not None:
                sched.load_state_dict(checkpoint['sheduler_state_dict'])

    return estimator, optim, sched, checkpoint

# Transformer Model

In [7]:
class PositionalEncoding(nn.Module):
    """Fixed sinusoidal position table, looked up by position index."""

    def __init__(self, d_model, max_len):
        """Precompute a ``(max_len, d_model)`` table and store it as a buffer.

        Even feature columns hold sines, odd columns hold cosines.
        """
        super().__init__()
        positions = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        even_dims = torch.arange(0, d_model, 2).float()
        inv_freq = torch.exp(even_dims * (-math.log(10000.0) / d_model))
        angles = positions * inv_freq

        table = torch.zeros(max_len, d_model)
        table[:, 0::2] = torch.sin(angles)
        table[:, 1::2] = torch.cos(angles)
        self.register_buffer('pe', table)

    def forward(self, pos):
        """Return the table rows selected by the integer index tensor ``pos``."""
        return self.pe[pos]


class TransformerEncoder(Module):
    """Stack of identical encoder layers with an optional final norm."""

    __constants__ = ['norm']

    def __init__(self, encoder_layer, num_layers, norm=None):
        """Clone ``encoder_layer`` ``num_layers`` times and keep ``norm``."""
        super().__init__()
        self.layers = _get_clones(encoder_layer, num_layers)
        self.num_layers = num_layers
        self.norm = norm

    def forward(self, src: Tensor, mask: Optional[Tensor] = None) -> Tensor:
        """Feed ``src`` through every layer in turn, then through ``norm``."""
        hidden = src
        for layer in self.layers:
            hidden = layer(hidden, mask=mask)
        return hidden if self.norm is None else self.norm(hidden)


class TransformerDecoder(Module):
    """Stack of identical decoder layers with an optional final norm."""

    __constants__ = ['norm']

    def __init__(self, decoder_layer, num_layers, norm=None):
        """Clone ``decoder_layer`` ``num_layers`` times and keep ``norm``."""
        super().__init__()
        self.layers = _get_clones(decoder_layer, num_layers)
        self.num_layers = num_layers
        self.norm = norm

    def forward(self, tgt: Tensor, memory: Tensor, mask) -> Tensor:
        """Run ``tgt`` through every layer (each sees ``memory`` and ``mask``)."""
        hidden = tgt
        for layer in self.layers:
            hidden = layer(hidden, memory, mask=mask)
        return hidden if self.norm is None else self.norm(hidden)


class TransformerEncoderLayer(Module):
    """Self-attention followed by a feed-forward block, each with a residual
    connection and a LayerNorm applied after the addition."""

    def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1,
                 activation="relu"):
        super().__init__()
        self.self_attn = MultiheadAttention(d_model, nhead, dropout=dropout)

        # Position-wise feed-forward network.
        self.linear1 = Linear(d_model, dim_feedforward)
        self.dropout = Dropout(dropout)
        self.linear2 = Linear(dim_feedforward, d_model)

        self.norm1 = LayerNorm(d_model)
        self.norm2 = LayerNorm(d_model)
        self.dropout1 = Dropout(dropout)
        self.dropout2 = Dropout(dropout)

        self.activation = _get_activation_fn(activation)

    def __setstate__(self, state):
        # States lacking an activation entry fall back to ReLU.
        state.setdefault('activation', F.relu)
        super().__setstate__(state)

    def forward(self, src: Tensor, mask: Optional[Tensor] = None) -> Tensor:
        """Apply the layer to ``src``; ``mask`` is passed on as ``attn_mask``."""
        attended, _ = self.self_attn(src, src, src, attn_mask=mask)
        src = self.norm1(src + self.dropout1(attended))

        hidden = self.activation(self.linear1(src))
        projected = self.linear2(self.dropout(hidden))
        return self.norm2(src + self.dropout2(projected))


class TransformerDecoderLayer(Module):
    """Self-attention, attention over ``memory`` and a feed-forward block,
    each with a residual connection and a LayerNorm after the addition."""

    def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1,
                 activation="relu"):
        super().__init__()
        self.self_attn = MultiheadAttention(d_model, nhead, dropout=dropout)
        self.multihead_attn = MultiheadAttention(d_model, nhead,
                                                 dropout=dropout)

        # Position-wise feed-forward network.
        self.linear1 = Linear(d_model, dim_feedforward)
        self.dropout = Dropout(dropout)
        self.linear2 = Linear(dim_feedforward, d_model)

        self.norm1 = LayerNorm(d_model)
        self.norm2 = LayerNorm(d_model)
        self.norm3 = LayerNorm(d_model)
        self.dropout1 = Dropout(dropout)
        self.dropout2 = Dropout(dropout)
        self.dropout3 = Dropout(dropout)

        self.activation = _get_activation_fn(activation)

    def __setstate__(self, state):
        # States lacking an activation entry fall back to ReLU.
        state.setdefault('activation', F.relu)
        super().__setstate__(state)

    def forward(self, tgt: Tensor, memory: Tensor,
                mask: Optional[Tensor] = None) -> Tensor:
        """Apply the layer; the same ``mask`` is used for both attentions."""
        self_attended, _ = self.self_attn(tgt, tgt, tgt, attn_mask=mask)
        tgt = self.norm1(tgt + self.dropout1(self_attended))

        cross_attended, _ = self.multihead_attn(tgt, memory, memory,
                                                attn_mask=mask)
        tgt = self.norm2(tgt + self.dropout2(cross_attended))

        hidden = self.activation(self.linear1(tgt))
        projected = self.linear2(self.dropout(hidden))
        return self.norm3(tgt + self.dropout3(projected))


def _get_clones(module, N):
    return ModuleList([copy.deepcopy(module) for i in range(N)])


def _get_activation_fn(activation):
    if activation == "relu":
        return F.relu
    elif activation == "gelu":
        return F.gelu
    elif activation == "tanh":
        return F.tanh
    elif activation == "sigmoid":
        return F.sigmoid

    raise RuntimeError(
        "activation should be relu/gelu, not {}".format(activation))

In [8]:
class TransformerModel(nn.Module):
    """Encoder-decoder transformer producing one probability per position.

    Each position's features and the accompanying ``targets`` ids are
    embedded, summed and passed through an encoder and a decoder that share
    one future mask; a linear layer followed by a sigmoid maps each position
    to a value in (0, 1).
    """

    def __init__(self, **params):
        """Build the model from keyword hyper-parameters.

        Keys read here: ``d_model``, ``nhead``, ``dropout``, ``seq_len``,
        ``num_encoder_layers`` and ``num_decoder_layers``. The last three are
        removed from ``params``; everything that remains is forwarded to the
        encoder/decoder layer constructors. Also reads the module-level
        ``DEVICE`` and ``TIME_SCALES``.
        """
        print(params)
        super().__init__()
        # Untouched copy of the constructor arguments; save_model() stores it
        # so the model can be rebuilt from a checkpoint.
        self.params = params.copy()
        self.d_model = params['d_model']
        self.nhead = params['nhead']
        # One less than the configured window: the dataset drops one row when
        # it pairs each position with the previous position's target.
        self.seq_len = params.pop('seq_len') - 1
        n_enc_layers = params.pop('num_encoder_layers')
        n_dec_layers = params.pop('num_decoder_layers')
        dropout_rate = params['dropout']

        self.enc_norm = nn.LayerNorm(self.d_model)
        self.dec_norm = nn.LayerNorm(self.d_model)

        encoder_layer = TransformerEncoderLayer(**params)
        self.encoder = TransformerEncoder(encoder_layer=encoder_layer,
                                          num_layers=n_enc_layers,
                                          norm=self.enc_norm)
        decoder_layer = TransformerDecoderLayer(**params)
        self.decoder = TransformerDecoder(decoder_layer=decoder_layer,
                                          num_layers=n_dec_layers,
                                          norm=self.dec_norm)
        # Largest id accepted per input; each embedding table below has
        # (max + 1) rows. forward() clamps or zeroes some of the inputs
        # against these limits.
        self.max_exercise = 13523
        self.max_part = 7
        self.max_explanation = 1
        self.max_tag = 188
        self.max_target = 2
        self.max_time = 300  # seconds
        self.max_timestamp = int((1 * 24 * 60 * 60 * 1000) /
                                 TIME_SCALES['timestamp'])  # 1 day

        self.position_embeddings = PositionalEncoding(self.d_model,
                                                      self.seq_len)
        self.exercise_embeddings = \
            nn.Embedding(num_embeddings=self.max_exercise + 1,
                         embedding_dim=self.d_model,
                         scale_grad_by_freq=True)
        self.part_embeddings = \
            nn.Embedding(num_embeddings=self.max_part + 1,
                         embedding_dim=self.d_model,
                         scale_grad_by_freq=True)
        self.explanation_embeddings = \
            nn.Embedding(num_embeddings=self.max_explanation + 1,
                         embedding_dim=self.d_model,
                         scale_grad_by_freq=True)
        self.timestamp_embeddings = (
            nn.Embedding(num_embeddings=self.max_timestamp + 1,
                         embedding_dim=self.d_model))
        self.elapsed_time_embeddings = (
            nn.Embedding(num_embeddings=self.max_time + 1,
                         embedding_dim=self.d_model))
        self.tag1_embeddings = \
            nn.Embedding(num_embeddings=self.max_tag + 1,
                         embedding_dim=self.d_model,
                         scale_grad_by_freq=True)
        self.tag2_embeddings = \
            nn.Embedding(num_embeddings=self.max_tag + 1,
                         embedding_dim=self.d_model,
                         scale_grad_by_freq=True)
        self.target_embeddings = \
            nn.Embedding(num_embeddings=self.max_target + 1,
                         embedding_dim=self.d_model,
                         scale_grad_by_freq=True)

        self.linear = nn.Linear(self.d_model, 1)
        self.sigmoid = nn.Sigmoid()
        # NOTE(review): self.dropout is created here but forward() does not
        # apply it.
        self.dropout = nn.Dropout(dropout_rate)

        # The mask and the position indices are plain tensor attributes (not
        # registered buffers), created directly on DEVICE.
        self.device = DEVICE
        self.future_mask = self.get_future_mask().to(self.device)
        self.positions = torch.arange(0, self.seq_len,
                                      device=self.device).unsqueeze(0)
        self.init_weights()

    def get_future_mask(self):
        """Return a (seq_len, seq_len) boolean matrix, True above the diagonal.

        Passed as ``attn_mask``; by the PyTorch MultiheadAttention convention
        True entries are presumably the positions that may not be attended
        to - confirm against the installed version.
        """
        return torch.triu(torch.ones(self.seq_len, self.seq_len) == 1, 
                          diagonal=1)

    def init_weights(self):
        """Re-initialise all embeddings and the output layer.

        Weights are drawn uniformly from [-0.1, 0.1]; the output bias is set
        to zero.
        """
        initrange = 0.1
        self.exercise_embeddings.weight.data.uniform_(-initrange, initrange)
        self.part_embeddings.weight.data.uniform_(-initrange, initrange)
        self.explanation_embeddings.weight.data.uniform_(-initrange, initrange)
        self.elapsed_time_embeddings.weight.data.uniform_(-initrange,
                                                          initrange)
        self.timestamp_embeddings.weight.data.uniform_(-initrange, initrange)
        self.tag1_embeddings.weight.data.uniform_(-initrange, initrange)
        self.tag2_embeddings.weight.data.uniform_(-initrange, initrange)
        self.target_embeddings.weight.data.uniform_(-initrange, initrange)
        self.linear.bias.data.zero_()
        self.linear.weight.data.uniform_(-initrange, initrange)

    def forward(self, inputs: dict):
        """Score every position of a batch.

        Args:
            inputs: Dict of integer tensors. The values are unpacked
                positionally, so the insertion order must be content ids,
                timestamps, elapsed times, explanations, parts, tag1, tag2,
                targets. Some of these tensors are modified in place (see
                below).

        Returns:
            Tensor with one sigmoid output per (sample, position).
        """
        (content_ids,
         timestamps,
         elapsed_times,
         explanations,
         parts,
         tags1, tags2,
         targets) = inputs.values()  # (N, S)

        # Bring out-of-range values into the embedding tables: unknown
        # exercises/parts map to index 0, times are capped at their maximum.
        # These assignments write into the caller's tensors.
        content_ids[content_ids > self.max_exercise] = 0
        parts[parts > self.max_part] = 0
        timestamps[timestamps > self.max_timestamp] = self.max_timestamp
        elapsed_times[elapsed_times > self.max_time] = self.max_time

        # Sum of all feature embeddings (tag2 is down-weighted by 0.8).
        embedded_inp = (self.exercise_embeddings(content_ids)
                        + self.part_embeddings(parts)
                        + self.explanation_embeddings(explanations)
                        + self.elapsed_time_embeddings(elapsed_times)
                        + self.timestamp_embeddings(timestamps)
                        + self.tag1_embeddings(tags1)
                        + self.tag2_embeddings(tags2) * 0.8
                        + self.target_embeddings(targets)
                        ).transpose(0, 1) * np.sqrt(self.d_model) # (S, N, E)

        # Same position indices for every sample in the batch.
        pos = self.positions.repeat(content_ids.shape[0], 1)
        embedded_pos = self.position_embeddings(pos).transpose(0, 1)

        output = self.encoder(src=(embedded_inp + embedded_pos),
                              mask=self.future_mask)  # (S, N, E)

        # Decoder memory: encoder output plus the target embedding, rescaled.
        output = (self.target_embeddings(targets).transpose(0, 1)
                  + output) * np.sqrt(self.d_model)
        # Decoder query: exercise and target embeddings only.
        tgt = (self.exercise_embeddings(content_ids)
               + self.target_embeddings(targets)
               ).transpose(0, 1) * np.sqrt(self.d_model)
        output = self.decoder(tgt=(tgt + embedded_pos),
                              memory=(output + embedded_pos),
                              mask=self.future_mask)  # (S, N, E)

        # Back to batch-first, one logit per position, then a probability.
        output = self.linear(output.transpose(1, 0)).squeeze(-1)
        output = self.sigmoid(output)

        return output

# Training Functions

In [9]:
class Riiid(torch.utils.data.Dataset):
    """Per-user interaction windows for training and validation.

    Every sample is a window of at most ``seq_len`` consecutive rows of one
    user, left-padded to ``seq_len``, with the previous row's target (+1)
    appended as the last column.
    """

    def __init__(self, dt, seq_len, pad_value=0, is_training=True):
        """Group ``dt`` by user and enumerate the windows to draw from.

        Users with ``min_seq_len`` (20) rows or fewer are skipped. For
        training, at most the first 1000 rows of a user count towards the
        number of windows.
        """
        super().__init__()
        self.seq_len = seq_len
        self.min_seq_len = 20
        self.pad_value = pad_value
        self.is_training = is_training

        columns = FEATURES + [TARGET]
        self.dcols = {name: pos for pos, name in enumerate(columns)}
        self.users_data = dt.groupby(KEY_FEATURE).apply(
            lambda rows: rows[columns].values).values

        # One (user index, window end) pair per window; windows are laid out
        # backwards from the end of each user's history.
        self.samples = []
        for user_idx, history in enumerate(self.users_data):
            n_rows = len(history)
            if n_rows <= self.min_seq_len:
                continue
            counted_rows = min(1000, n_rows) if is_training else n_rows
            n_windows = counted_rows // self.seq_len + 1
            for window_no in range(n_windows):
                window_end = n_rows - window_no * self.seq_len
                if window_end > self.min_seq_len:
                    self.samples.append((user_idx, window_end))
        random.shuffle(self.samples)

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        """Return one padded window as a 2-D array with ``seq_len - 1`` rows."""
        user_idx, end_idx = self.samples[idx]
        history = self.users_data[user_idx]
        if self.is_training:
            # Training ignores the stored end and draws a random one instead.
            end_idx = np.random.randint(self.min_seq_len, len(history))
        start_idx = max(0, end_idx - self.seq_len)

        window = history[start_idx:end_idx]
        # Shift labels by one so that 0 stays reserved for padding.
        shifted_targets = window[:, self.dcols[TARGET]] + 1
        window = pad_batch(window, self.seq_len, self.pad_value)
        shifted_targets = pad_batch(shifted_targets, self.seq_len,
                                    self.pad_value)
        # Pair every row with the target of the row before it.
        return np.c_[window[1:], shifted_targets[:-1]]


def collate_fn(batch):
    """Stack the samples and move the column axis to the front.

    ``batch[b][s, c]`` ends up at ``result[c, b, s]``.
    """
    stacked = np.array(batch)
    return np.transpose(stacked, (2, 0, 1))


def train_epoch(estimator, trn_iterator, optim, criterion,
                device="cpu"):
    """Run one optimisation pass over ``trn_iterator``.

    Each batch is indexed by column: the first ``len(FEATURES)`` entries are
    the features, the second-to-last entry holds the labels and the last
    entry the shifted targets fed to the model. Positions whose shifted
    target equals PAD_VALUE are excluded from loss and accuracy.

    Returns:
        ``(mean loss per batch, accuracy over non-padded positions)``.
    """
    target_idx, label_idx = -1, -2
    # Tensor cast per declared feature dtype; other dtypes are skipped.
    casts = {int: torch.Tensor.long,
             float: torch.Tensor.float,
             bool: torch.Tensor.bool}
    estimator.train()

    total_loss = 0
    n_correct = 0
    n_batches = 0
    n_samples = 0

    if TPU:
        trn_iterator = pl.ParallelLoader(trn_iterator,
                                         [device]).per_device_loader(device)
    progress = tqdm(trn_iterator, ncols=700)
    for batch in progress:
        inputs = {}
        for col, feat in enumerate(FEATURES):
            tens = torch.Tensor(batch[col]).to(device)
            cast = casts.get(DTYPES[feat])
            if cast is not None:
                inputs[feat] = cast(tens)
        inputs['targets'] = torch.Tensor(batch[target_idx]).to(device).long()
        labels = torch.Tensor(batch[label_idx]).to(device).float()

        optim.zero_grad()
        output = estimator(inputs=inputs)
        keep = inputs['targets'] != PAD_VALUE
        output = torch.masked_select(output, keep)
        labels = torch.masked_select(labels, keep)
        loss = criterion(output, labels)
        loss.backward()
        if TPU:
            xm.optimizer_step(optim)
        else:
            optim.step()

        total_loss += loss.item()
        pred = (output >= 0.5).long()
        n_correct += (pred == labels).sum().item()
        n_batches += 1
        n_samples += len(labels)

        progress.set_description(
            'trn_loss - {:.4f}'.format(total_loss / n_batches))

    return total_loss / n_batches, n_correct / n_samples


def val_epoch(estimator, val_iterator, criterion, device="cpu"):
    """Evaluate ``estimator`` on ``val_iterator`` without updating it.

    Batches use the same column layout as in training: features first, then
    the labels (second-to-last) and the shifted targets (last). Positions
    whose shifted target equals PAD_VALUE are excluded from loss, accuracy
    and the overall AUC.

    Returns:
        ``(loss, acc, auc, single_auc)`` where ``single_auc`` is the ROC AUC
        computed on the last position of every sequence only.
    """
    target_idx = -1
    label_idx = -2
    estimator.eval()

    loss_sum = 0
    batch_count = 0
    num_corrects = 0
    sample_count = 0
    # Per-batch results are collected in lists and joined once at the end;
    # growing an array with np.r_ on every batch copies all previous data
    # each time. The leading empty float64 array keeps the result dtype (and
    # the empty-loader behaviour) identical to repeated np.r_ accumulation.
    truth_parts = [np.empty(0)]
    outs_parts = [np.empty(0)]
    truth_single_parts = [np.empty(0)]
    outs_single_parts = [np.empty(0)]

    if TPU:
        val_iterator = pl.ParallelLoader(val_iterator,
                                         [device]).per_device_loader(device)
    tbar = tqdm(val_iterator, ncols=700)
    for batch in tbar:
        inputs = {}
        for i, feat in enumerate(FEATURES):
            tens = torch.Tensor(batch[i]).to(device)
            if DTYPES[feat] is int:
                inputs[feat] = tens.long()
            elif DTYPES[feat] is float:
                inputs[feat] = tens.float()
            elif DTYPES[feat] is bool:
                inputs[feat] = tens.bool()
        inputs['targets'] = torch.Tensor(batch[target_idx]).to(device).long()
        labels = torch.Tensor(batch[label_idx]).to(device).float()

        with torch.no_grad():
            output = estimator(inputs=inputs)
        # Keep the last position of each sequence before masking flattens
        # the tensors.
        output_single = output[:, -1]
        labels_single = labels[:, -1]
        mask = inputs['targets'] != PAD_VALUE
        output = torch.masked_select(output, mask)
        labels = torch.masked_select(labels, mask)
        loss = criterion(output, labels)
        loss_sum += loss.item()
        batch_count += 1

        pred = (output >= 0.5).long()
        num_corrects += (pred == labels).sum().item()
        sample_count += len(labels)
        truth_parts.append(labels.view(-1).detach().cpu().numpy())
        outs_parts.append(output.view(-1).detach().cpu().numpy())
        truth_single_parts.append(
            labels_single.view(-1).detach().cpu().numpy())
        outs_single_parts.append(
            output_single.view(-1).detach().cpu().numpy())

        tbar.set_description(
            'val_loss - {:.4f}'.format(loss_sum / batch_count))

    truth = np.concatenate(truth_parts)
    outs = np.concatenate(outs_parts)
    truth_single = np.concatenate(truth_single_parts)
    outs_single = np.concatenate(outs_single_parts)

    loss = loss_sum / batch_count
    acc = num_corrects / sample_count
    auc = roc_auc_score(truth, outs)
    single_auc = roc_auc_score(truth_single, outs_single)
    return loss, acc, auc, single_auc


def train_transformer(estimator_type, estimator_params, epochs, batch_size,
                      device, early_stopping, eps, nworkers, cont, debug,
                      debug_size, data_size, valid_size):
    """Train (or resume training) a model and checkpoint the best epochs.

    Args:
        estimator_type: Model class to instantiate or restore.
        estimator_params: Keyword arguments for ``estimator_type``.
        epochs: Exclusive upper bound of the epoch index.
        batch_size: Batch size of both data loaders.
        device: Device the batches and the loss are moved to.
        early_stopping: Number of consecutive epochs without sufficient
            improvement after which training stops.
        eps: Minimum gain of the last-position validation AUC that counts as
            an improvement.
        nworkers: Worker processes per data loader.
        cont: Resume from a saved checkpoint (ignored when ``debug``).
        debug: Use ``debug_size`` instead of ``data_size``.
        debug_size: Fraction of the rows to keep in debug mode.
        data_size: Fraction of the rows to keep otherwise.
        valid_size: Fraction of rows aimed at the validation split.

    Returns:
        The trained estimator (its final state, not necessarily the
        checkpointed one).
    """
    train = load_data('train_merged.parquet',
                      cols=[KEY_FEATURE] + FEATURES + [TARGET])
    print('Using Columns -', list(train.columns))
    print('total size:', train.shape)
    data_size = int(len(train) * (debug_size if debug else data_size))
    train = train.tail(data_size)
    print('data size:', train.shape,
          '- num users:', train['user_id'].nunique())
    train = preprocess(train)
    train, valid = split_train_valid(train, valid_size)
    print('train size:', train.shape,
          '- num users:', train['user_id'].nunique())
    print('valid size:', valid.shape,
          '- num users:', valid['user_id'].nunique())
    gc.collect()

    last_auc = 0
    last_epoch = 0
    if cont and not debug:
        # Restore the class that was passed in, mirroring the fresh-start
        # branch below, rather than depending on a module-level name.
        estimator, optim, sched, cp = \
            load_model(estimator_type, for_training=True)
        last_auc = cp['val_score']
        # Checkpoints written without a scheduler may lack 'epoch'.
        # NOTE(review): the stored value is the index of the epoch that
        # produced the checkpoint, so resuming repeats that index - confirm
        # this is intended.
        last_epoch = cp.get('epoch', 0)
        if sched is None:
            # No scheduler state was restored; fall back to a fresh one so
            # the plateau step below has an object to call.
            sched = create_scheduler(optim)
        print('Previous Validation AUC:', last_auc)
    else:
        estimator = create_model(estimator_type, **estimator_params)
        optim = create_optimizer(estimator, lr=LEARNING_RATE)
        sched = create_scheduler(optim)

    trn_dataset = Riiid(dt=train, seq_len=SEQ_LEN, pad_value=PAD_VALUE,
                        is_training=True)
    val_dataset = Riiid(dt=valid, seq_len=SEQ_LEN, pad_value=PAD_VALUE,
                        is_training=False)

    if TPU:
        trn_sampler = torch.utils.data.distributed.DistributedSampler(
            trn_dataset,
            num_replicas=xm.xrt_world_size(),
            rank=xm.get_ordinal(),
            shuffle=True)

        val_sampler = torch.utils.data.distributed.DistributedSampler(
            val_dataset,
            num_replicas=xm.xrt_world_size(),
            rank=xm.get_ordinal(),
            shuffle=False)

    # A sampler and shuffle=True are mutually exclusive, hence the switch.
    trn_dataloader = torch.utils.data.DataLoader(
        dataset=trn_dataset,
        batch_size=batch_size,
        collate_fn=collate_fn,
        num_workers=nworkers,
        sampler=trn_sampler if TPU else None,
        shuffle=False if TPU else True)
    val_dataloader = torch.utils.data.DataLoader(
        dataset=val_dataset,
        batch_size=batch_size,
        collate_fn=collate_fn,
        num_workers=nworkers,
        sampler=val_sampler if TPU else None,
        shuffle=False)

    criterion = nn.BCELoss().to(device)
    over_fit = 0
    seed_everything(RANDOM_SEED)
    for epoch in range(last_epoch, epochs):
        if SCHED_TYPE == 'lambda':
            # Warm-up followed by inverse-square-root decay, set per epoch.
            for g in optim.param_groups:
                g['lr'] = ((estimator.d_model ** -0.5) *
                           min((epoch + 1) ** -0.5,
                               (epoch + 1) * (WARMUP_STEPS ** -1.5)))
        print(f'\nEpoch {epoch + 1} - Learning rate:',
              optim.param_groups[0]['lr'])
        trn_loss, trn_acc = train_epoch(estimator, trn_dataloader,
                                        optim, criterion, device)
        print("  Training -- loss: {:.6f} - acc: {:.6f}" \
              .format(trn_loss, trn_acc))

        val_loss, val_acc, val_auc, val_single_auc = \
            val_epoch(estimator, val_dataloader, criterion, device)
        # Highlight the line (ANSI red) when the tracked score improved.
        color = '\033[91m' if val_single_auc > last_auc and epoch != 0 else ''
        print(color + "  Validation -- loss: {:.6f} - acc: {:.6f} "\
              .format(val_loss, val_acc) +
              "- auc: {:.6f} - custom_auc: {:.6f}"\
              .format(val_auc, val_single_auc) + '\033[0m')
        if SCHED_TYPE == 'plateau':
            sched.step(val_auc)

        if val_single_auc > last_auc + eps:
            last_auc = val_single_auc
            over_fit = 0
            save_model(estimator, optim, sched, val_single_auc, epoch)
        else:
            over_fit += 1

        if over_fit >= early_stopping:
            print("early stop epoch ", epoch + 1)
            break

    return estimator


def tpu_map_fn(index, flags):
    """Per-process entry point handed to ``xmp.spawn`` for TPU training.

    Args:
        index: First positional argument supplied by ``xmp.spawn`` to the
            spawned function; not used here.
        flags: Dict of keyword arguments for ``train_transformer`` (the
            ``train_params`` dict built in the training cell).
    """
    # The non-TPU path calls `train_transformer(**train_params)` with this same
    # dict, so it has to be unpacked here as well; passing it positionally
    # would bind the whole dict to the first parameter.
    train_transformer(**flags)


# ========================== TEST =======================================


class RiiidTest(torch.utils.data.Dataset):
    """Inference-time dataset: one item per query row, built on top of the
    querying user's stored history.

    Each item is the user's most recent history rows followed by the query row,
    padded with ``pad_batch`` and returned as a float32 2-D array whose last
    column holds the shifted targets of the preceding rows.

    Args:
        dt: Per-user history container supporting ``uid in dt`` and ``dt[uid]``;
            each entry is indexed as a 2-D array whose columns follow
            ``[KEY_FEATURE] + FEATURES + [TARGET, 'timestamp_raw']``.
        queries: Rows to predict. In the default mode it is indexed like a
            2-D numpy array; in ``local`` mode it must support ``.groupby``
            (presumably a DataFrame that also carries TARGET -- confirm).
        seq_len: Length every sequence is padded to.
        pad_value: Fill value forwarded to ``pad_batch``.
        local: When true, items are drawn at random from per-user groups and
            the user id and labels are appended as extra columns.
    """
    def __init__(self, dt, queries, seq_len, pad_value=0, local=False):
        self.data = dt
        # Column name -> position for rows laid out as
        # [KEY_FEATURE] + FEATURES + [TARGET, 'timestamp_raw'].
        self.dcols = {col: i for i, col in enumerate(
            [KEY_FEATURE] + FEATURES + [TARGET, 'timestamp_raw'])}

        self.queries = queries.copy()
        self.groups = None
        if local:
            # One 2-D array of rows per user, kept in a plain list so that
            # __getitem__ can shuffle and consume it.
            self.groups = queries.groupby(KEY_FEATURE) \
                .apply(lambda r: r.values).values.tolist()
        self.seq_len = seq_len
        self.pad_value = pad_value
        self.is_local = local

    def __len__(self):
        """Number of query rows."""
        return len(self.queries)

    def __getitem__(self, idx):
        """Build the padded model input for one query.

        In local mode ``idx`` is ignored: the groups are shuffled and the first
        row of the first group is consumed, so every call mutates
        ``self.groups``.
        """
        if self.is_local:
            random.shuffle(self.groups)
            # `[[0]]` keeps the row 2-D (shape (1, n_cols)).
            query = self.groups[0][[0]]
            query_label = query[0, self.dcols[TARGET]]
            # TARGET sits after KEY_FEATURE and FEATURES in dcols, so removing
            # it leaves the positions of those columns unchanged.
            query = np.delete(query, self.dcols[TARGET], axis=1)
            # Drop the consumed row, and the whole group once it is exhausted.
            self.groups[0] = self.groups[0][1:]
            if self.groups[0].shape[0] == 0:
                self.groups = self.groups[1:]
        else:
            # Fancy indexing returns a copy, so the in-place timestamp edit
            # below does not touch self.queries.
            query = self.queries[[idx]]
        uid = query[0, self.dcols[KEY_FEATURE]]

        if uid in self.data:
            # Keep at most seq_len - 1 history rows.
            inputs = self.data[uid][-self.seq_len + 1:]
            labels = inputs[:, self.dcols[TARGET]]
            # Replace the query's timestamp by the elapsed time since the last
            # stored interaction, scaled by TIME_SCALES['timestamp'] and
            # rounded up.
            last_timestamp = inputs[-1, self.dcols['timestamp_raw']]
            curr_timestamp = query[0, self.dcols['timestamp']]
            diff_timestamp = np.ceil((curr_timestamp - last_timestamp) /
                                     TIME_SCALES['timestamp'])
            query[0, self.dcols['timestamp']] = diff_timestamp
            # Strip the non-feature columns so history and query line up,
            # then append the query as the final row.
            query = np.delete(query, self.dcols[KEY_FEATURE], axis=1)
            inputs = np.delete(inputs, [self.dcols[KEY_FEATURE],
                                        self.dcols[TARGET],
                                        self.dcols['timestamp_raw']], axis=1)
            inputs = np.r_[inputs, query]
        else:
            # Unknown user: no history, default timestamp feature, no labels.
            query[0, self.dcols['timestamp']] = DEFAULT_VALUES['timestamp']
            query = np.delete(query, self.dcols[KEY_FEATURE], axis=1)
            inputs = query
            labels = np.empty(0)

        # Labels are shifted by one before padding -- presumably so that the
        # pad value stays distinct from a real label; confirm against
        # pad_batch and the model's target embedding.
        targets = labels + 1
        # pad_batch is defined elsewhere in the notebook.
        inputs = pad_batch(inputs, self.seq_len, self.pad_value)
        targets = pad_batch(targets, self.seq_len, self.pad_value)

        if self.is_local:
            # Local evaluation also carries the user id and the true labels
            # (history labels followed by the query's label) as extra columns.
            labels = np.r_[labels, query_label]
            labels = pad_batch(labels, self.seq_len,
                               self.pad_value)
            uids = np.full(inputs.shape[0], uid)
            inputs = np.c_[inputs, uids, labels]

        # Drop the first input row and the last target row, then append the
        # targets as the final column (predict_test reads them at index -1
        # after collate_fn_test has moved the column axis to the front).
        return np.c_[inputs[1:], targets[:-1]].astype(np.float32)


def collate_fn_test(batch):
    """Stack the per-sample arrays and bring the column axis to the front.

    A list of ``B`` arrays shaped ``(S, C)`` becomes a single array shaped
    ``(C, B, S)``, so that indexing the result by position yields one
    ``(B, S)`` plane per column.
    """
    stacked = np.array(batch)
    return np.transpose(stacked, axes=(2, 0, 1))


def update_stats(prev_data, prev_batch):
    """Append the previous group's answered rows to each user's history.

    For every row the raw timestamp is replaced by the elapsed time since the
    user's last stored interaction (scaled by ``TIME_SCALES['timestamp']`` and
    rounded up), the raw timestamp is appended as a trailing column, and the
    row is added to ``prev_data[uid]``. Users not yet in ``prev_data`` get
    ``DEFAULT_VALUES['timestamp']`` as their timestamp feature.

    Args:
        prev_data: Per-user history container supporting ``in``, item access
            and item assignment; updated in place.
        prev_batch: 2-D array with one row per answered question, columns
            ordered ``[KEY_FEATURE] + FEATURES + [TARGET]``. The timestamp
            column of each row is overwritten in place. May have zero rows.

    Returns:
        None.
    """
    dcols = {col: i for i, col in
             enumerate([KEY_FEATURE] + FEATURES +
                       [TARGET, 'timestamp_raw'])}
    ts_col = dcols['timestamp']

    # A plain row loop is used instead of np.apply_along_axis: the latter
    # raises ValueError when the batch has zero rows (e.g. the previous group
    # held no question rows) and builds a result array nobody uses.
    for trow in prev_batch:
        uid = trow[0]
        curr_timestamp = trow[ts_col]
        if uid in prev_data:
            last_timestamp = prev_data[uid][-1, dcols['timestamp_raw']]
            trow[ts_col] = np.ceil((curr_timestamp - last_timestamp) /
                                   TIME_SCALES['timestamp'])
            trow = np.r_[trow, curr_timestamp]  # Add timestamp_raw of the query
            # Keep the most recent SEQ_LEN - 2 stored rows plus the new one.
            prev_data[uid] = np.r_[prev_data[uid][-SEQ_LEN + 2:], [trow]] \
                .astype(int)
        else:
            trow[ts_col] = DEFAULT_VALUES['timestamp']
            trow = np.r_[trow, curr_timestamp]  # Add timestamp_raw of the query
            # NOTE(review): unlike the branch above, a new user's first row is
            # stored without the int cast; it is only cast on the next update.
            prev_data[uid] = np.array([trow])


def predict_test(estimator,
                 tst_iterator,
                 local=False,
                 prev_data=None,
                 device="cpu"):
    """Run the model over every batch and collect its last-position outputs.

    Args:
        estimator: Model called as ``estimator(inputs=...)``; switched to eval
            mode here.
        tst_iterator: Iterable of batches indexed by column position first
            (the layout produced by ``collate_fn_test``): position ``i`` holds
            feature ``FEATURES[i]`` and the last position holds the targets.
        local: Unused in this function.
        prev_data: Unused in this function.
        device: Device the input tensors are moved to.

    Returns:
        1-D numpy array with the concatenated ``[:, -1]`` slice of the model
        output for each batch, in iteration order.
    """
    targets_pos = -1
    estimator.eval()

    collected = np.empty(0)
    for feature_major in tst_iterator:
        model_inputs = {}
        for pos, name in enumerate(FEATURES):
            kind = DTYPES[name]
            # Features whose declared dtype is none of int/float/bool are
            # left out of the model inputs.
            if kind is int:
                as_tensor = torch.Tensor(feature_major[pos].astype(int))
                model_inputs[name] = as_tensor.to(device).long()
            elif kind is float:
                as_tensor = torch.Tensor(feature_major[pos])
                model_inputs[name] = as_tensor.to(device).float()
            elif kind is bool:
                as_tensor = torch.Tensor(feature_major[pos].astype(bool))
                model_inputs[name] = as_tensor.to(device).bool()
        shifted_targets = torch.Tensor(
            feature_major[targets_pos].astype(np.int64))
        model_inputs['targets'] = shifted_targets.to(device).long()

        with torch.no_grad():
            # Only the final sequence position (the query row) is kept.
            last_step = estimator(inputs=model_inputs)[:, -1]
        collected = np.r_[collected, last_step.data.cpu().numpy()]

    return collected


def predict_submission_group(estimator,
                             tst_batch,
                             prev_batch,
                             prev_data,
                             batch_size=128):
    """Predict one test group and fold the previous group into the history.

    Args:
        estimator: Trained model forwarded to ``predict_test``.
        tst_batch: DataFrame holding the current test group.
        prev_batch: The ``'prev_batch'`` array returned by the previous call,
            or ``None`` for the first group.
        prev_data: Per-user history container; updated in place with the
            previous group's answered questions.
        batch_size: Batch size of the inference DataLoader.

    Returns:
        Dict with, in this order, ``'preds'`` (DataFrame with
        ``SUBMISSION_COLUMNS`` for the question rows), ``'prev_batch'`` (the
        preprocessed current group, lecture rows included) and
        ``'prev_data'``. Callers unpack ``.values()``, so the order matters.
    """
    # Column name -> position for the enriched batch: the raw columns, then
    # the question features added below, then TARGET as the last column.
    all_cols = list(tst_batch.columns) + ADDED_FEATURES + [TARGET]
    all_cols = dict(zip(all_cols, range(len(all_cols))))
    used_cols = [all_cols[feat] for feat in [KEY_FEATURE] + FEATURES]
    tst_batch = tst_batch.values

    # The history update is skipped when RAM usage reaches 95%. `and` (not
    # `&`) so psutil is only queried when there is a previous group at all.
    if prev_batch is not None and psutil.virtual_memory().percent < 95:
        # The previous group's answers arrive as a stringified list in the
        # first row of the current group.
        # NOTE(review): eval() executes arbitrary code; the string comes from
        # the test feed. Consider ast.literal_eval / json.loads.
        prior_answers = eval(
            tst_batch[0, all_cols['prior_group_answers_correct']])
        prev_batch = np.c_[prev_batch, prior_answers]
        # Keep question rows only, reduced to the model columns plus TARGET.
        prev_batch = prev_batch[prev_batch[:, all_cols['content_type_id']] == 0
                                ][:, used_cols + [all_cols[TARGET]]]
        # A group without any question row leaves nothing to add; calling
        # update_stats on the empty array would fail in np.apply_along_axis.
        if len(prev_batch):
            update_stats(prev_data, prev_batch)

    # Look up the extra question features by content id, falling back to the
    # defaults for ids missing from QUESTIONS_DF.
    default_values = [DEFAULT_VALUES[feat] for feat in ADDED_FEATURES]
    question_feats = np.apply_along_axis(
        lambda rid: [QUESTIONS_DF[rid[0]][feat] for feat in ADDED_FEATURES]
        if rid[0] in QUESTIONS_DF else default_values,
        axis=1, arr=tst_batch[:, [all_cols['content_id']]])
    tst_batch = np.c_[tst_batch, question_feats]
    tst_batch = preprocess_test(tst_batch, all_cols)
    # Remember the full preprocessed group (lectures included) so that the
    # next call can line it up with prior_group_answers_correct.
    prev_batch = tst_batch.copy()

    # Only question rows are predicted.
    qrows = tst_batch[:, all_cols['content_type_id']] == 0
    tst_batch = tst_batch[qrows]
    tst_dataset = RiiidTest(prev_data, tst_batch[:, used_cols], SEQ_LEN)
    tst_dataloader = torch.utils.data.DataLoader(dataset=tst_dataset,
                                                 batch_size=batch_size,
                                                 collate_fn=collate_fn_test,
                                                 num_workers=0)

    _preds = predict_test(estimator, tst_dataloader, device=DEVICE)
    # The predictions become the TARGET column (last position in all_cols).
    tst_batch = np.c_[tst_batch, _preds]
    _predictions = pd.DataFrame(
        tst_batch[:, [all_cols[col] for col in SUBMISSION_COLUMNS]],
        columns=SUBMISSION_COLUMNS)

    return {'preds': _predictions,
            'prev_batch': prev_batch,
            'prev_data': prev_data}


def load_previous_data(pdata_path):
    """Load the per-user history and the question table used at inference.

    Args:
        pdata_path: Path handed to ``load_data`` for the interaction data.

    Returns:
        Tuple ``(pdata, qdata)``: ``pdata`` is the result of grouping the
        preprocessed interactions by ``KEY_FEATURE`` (last ``SEQ_LEN`` rows per
        user, each group as a 2-D value array); ``qdata`` is
        ``questions.parquet`` converted with ``to_dict('index')``.
    """
    wanted_cols = [KEY_FEATURE] + FEATURES + [TARGET, 'timestamp_raw']
    pdata = preprocess(load_data(pdata_path, cols=wanted_cols))
    gc.collect()

    print(f'getting {SEQ_LEN} long sequences...')
    # Trim every user to the most recent SEQ_LEN rows before materialising
    # one array per user.
    pdata = pdata.groupby(KEY_FEATURE).tail(SEQ_LEN)
    pdata = pdata.groupby(KEY_FEATURE).apply(lambda g: g.values)
    gc.collect()

    questions_path = PARQUETS_DIR + 'questions.parquet'
    qdata = pd.read_parquet(questions_path).to_dict('index')
    return pdata, qdata

# Train

In [10]:
# Colab form cell: the `# @param` / `# @markdown` comments drive the form UI,
# so they must stay on the same lines as the values they describe.

# Model class that is trained below and restored by load_model in the test cell.
MODEL = TransformerModel  # @param ["TransformerModel"] {type:"raw"}

# File name derived from the model class name -- presumably the checkpoint
# written by save_model / read by load_model; confirm in those helpers.
MODEL_FILENAME = f'{MODEL.__name__}_best.pth'  # @param {type:"string"}


# @markdown # Model Settings

# Sequence length: passed to the model as 'seq_len' and used by the
# inference-time history handling (RiiidTest, update_stats,
# load_previous_data).
SEQ_LEN = 257  # @param {type:"integer"}

# The values below are forwarded unchanged to the model constructor through
# the `model_params` dict built in the training cell.
D_MODEL = 256  # @param {type:"integer"}

NHEAD = 8  # @param {type:"integer"}

N_ENC_LAYERS = 2  # @param {type:"integer"}

N_DEC_LAYERS = 2  # @param {type:"integer"}

DIM_FEEDFORWARD = 2048  # @param {type:"integer"}

DROPOUT = 0.1  # @param {type:"slider", min:0, max:1, step:0.05}

ACTIVATION = "gelu"  # @param ["relu", "gelu"]

LEARNING_RATE = 0.0001  # @param {type:"number"}


# @markdown # Training Settings

# Upper bound on the number of training epochs.
EPOCHS = 100  # @param {type:"integer"}

BATCH_SIZE = 64  # @param {type:"integer"}

# Training stops after this many consecutive epochs in which the validation
# custom AUC did not improve by more than `eps`.
EARLY_STOPPING = 10  # @param {type:"integer"}

# Passed to the DataLoaders as `num_workers`.
NUM_WORKERS = 0  # @param {type:"integer"}

# "plateau": the scheduler is stepped with the validation AUC every epoch.
# "lambda": the learning rate is recomputed every epoch from d_model and
# WARMUP_STEPS. Any other value leaves the learning rate untouched in the
# training loop.
SCHED_TYPE = "plateau"  # @param ["plateau", "lambda", "NONE"]

# Only read by the "lambda" schedule.
WARMUP_STEPS = 4000  # @param {type:"integer"}




PAD_VALUE = 0
# XLA device when TPU is set; otherwise the first CUDA GPU if one is
# available, falling back to the CPU.
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") \
    if not TPU else xm.xla_device()

In [11]:
# Run switches (Colab form cell; keep the `# @param` comments in place).

# Train the model in the next cell (only takes effect when COMPILE is falsy).
RETRAIN = True  # @param {type:"boolean"}

# Forwarded to train_transformer as `cont` -- presumably resumes from a saved
# checkpoint; confirm in train_transformer.
CONTINUE = False  # @param {type:"boolean"}

# Forwarded to train_transformer as `debug`.
DEBUG = False  # @param {type:"boolean"}

# Forwarded to train_transformer as `data_size`.
DATA_SIZE = 1  # @param {type:"slider", min:0.01, max:1, step:0.01}

# Forwarded to train_transformer as `debug_size`.
DEBUG_SIZE = 0.1  # @param {type:"slider", min:0.01, max:1, step:0.01}

LOCAL_SAMPLE = False  # @param {type:"boolean"}

In [None]:
if __name__ == '__main__':
    print('Using Device -', DEVICE)
    print('Using Model -', MODEL.__name__)

    # Architecture settings, handed to the model as `estimator_params`.
    model_params = dict(
        seq_len=SEQ_LEN,
        d_model=D_MODEL,
        nhead=NHEAD,
        num_encoder_layers=N_ENC_LAYERS,
        num_decoder_layers=N_DEC_LAYERS,
        dim_feedforward=DIM_FEEDFORWARD,
        dropout=DROPOUT,
        activation=ACTIVATION,
    )

    # Keyword arguments of train_transformer.
    train_params = dict(
        estimator_type=MODEL,
        estimator_params=model_params,
        epochs=EPOCHS,
        batch_size=BATCH_SIZE,
        early_stopping=EARLY_STOPPING,
        nworkers=NUM_WORKERS,
        device=DEVICE,
        eps=1e-6,
        cont=CONTINUE,
        debug=DEBUG,
        data_size=DATA_SIZE,
        debug_size=DEBUG_SIZE,
        valid_size=0.1,
    )

    if RETRAIN and not COMPILE:
        print('Parameters:', train_params)
        if not TPU:
            model = train_transformer(**train_params)
        else:
            # One training process per TPU core.
            xmp.spawn(tpu_map_fn, args=(train_params,), nprocs=8,
                      start_method='fork')

Using Device - cuda:0
Using Model - TransformerModel
Parameters: {'estimator_type': <class '__main__.TransformerModel'>, 'estimator_params': {'seq_len': 257, 'd_model': 256, 'nhead': 8, 'num_encoder_layers': 2, 'num_decoder_layers': 2, 'dim_feedforward': 2048, 'dropout': 0.1, 'activation': 'gelu'}, 'epochs': 100, 'batch_size': 64, 'early_stopping': 10, 'nworkers': 0, 'device': device(type='cuda', index=0), 'eps': 1e-06, 'cont': False, 'debug': False, 'data_size': 1, 'debug_size': 0.1, 'valid_size': 0.1}
Using Columns - ['user_id', 'content_id', 'timestamp', 'prior_question_elapsed_time', 'prior_question_had_explanation', 'part', 'tag1', 'tag2', 'answered_correctly']
total size: (99271300, 9)
data size: (99271300, 9) - num users: 393656
train size: (89344170, 9) - num users: 354295
valid size: (9927130, 9) - num users: 39361
{'seq_len': 257, 'd_model': 256, 'nhead': 8, 'num_encoder_layers': 2, 'num_decoder_layers': 2, 'dim_feedforward': 2048, 'dropout': 0.1, 'activation': 'gelu'}

Epoch

HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=6642.0), HTML(value='')), layout=Layout(d…


  Training -- loss: 0.561163 - acc: 0.709668


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=951.0), HTML(value='')), layout=Layout(di…


  Validation -- loss: 0.541736 - acc: 0.724489 - auc: 0.761726 - custom_auc: 0.766612[0m

Epoch 2 - Learning rate: 0.0001


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=6642.0), HTML(value='')), layout=Layout(d…


  Training -- loss: 0.541063 - acc: 0.722509


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=951.0), HTML(value='')), layout=Layout(di…


[91m  Validation -- loss: 0.525650 - acc: 0.732457 - auc: 0.776894 - custom_auc: 0.776189[0m

Epoch 3 - Learning rate: 0.0001


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=6642.0), HTML(value='')), layout=Layout(d…


  Training -- loss: 0.533975 - acc: 0.726331


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=951.0), HTML(value='')), layout=Layout(di…


[91m  Validation -- loss: 0.522515 - acc: 0.734564 - auc: 0.780570 - custom_auc: 0.780706[0m

Epoch 4 - Learning rate: 0.0001


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=6642.0), HTML(value='')), layout=Layout(d…


  Training -- loss: 0.530640 - acc: 0.728371


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=951.0), HTML(value='')), layout=Layout(di…


[91m  Validation -- loss: 0.520654 - acc: 0.735573 - auc: 0.782469 - custom_auc: 0.782450[0m

Epoch 5 - Learning rate: 0.0001


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=6642.0), HTML(value='')), layout=Layout(d…


  Training -- loss: 0.528391 - acc: 0.729845


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=951.0), HTML(value='')), layout=Layout(di…


[91m  Validation -- loss: 0.518911 - acc: 0.736614 - auc: 0.784251 - custom_auc: 0.784358[0m

Epoch 6 - Learning rate: 0.0001


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=6642.0), HTML(value='')), layout=Layout(d…


  Training -- loss: 0.526218 - acc: 0.731366


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=951.0), HTML(value='')), layout=Layout(di…


[91m  Validation -- loss: 0.517989 - acc: 0.737284 - auc: 0.785477 - custom_auc: 0.785865[0m

Epoch 7 - Learning rate: 0.0001


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=6642.0), HTML(value='')), layout=Layout(d…


  Training -- loss: 0.524382 - acc: 0.732511


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=951.0), HTML(value='')), layout=Layout(di…


[91m  Validation -- loss: 0.517307 - acc: 0.738070 - auc: 0.786610 - custom_auc: 0.786990[0m

Epoch 8 - Learning rate: 0.0001


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=6642.0), HTML(value='')), layout=Layout(d…


  Training -- loss: 0.522850 - acc: 0.733647


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=951.0), HTML(value='')), layout=Layout(di…


[91m  Validation -- loss: 0.516803 - acc: 0.738763 - auc: 0.787397 - custom_auc: 0.787486[0m

Epoch 9 - Learning rate: 0.0001


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=6642.0), HTML(value='')), layout=Layout(d…


  Training -- loss: 0.521206 - acc: 0.734798


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=951.0), HTML(value='')), layout=Layout(di…


[91m  Validation -- loss: 0.515382 - acc: 0.739443 - auc: 0.788345 - custom_auc: 0.789655[0m

Epoch 10 - Learning rate: 0.0001


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=6642.0), HTML(value='')), layout=Layout(d…


  Training -- loss: 0.519619 - acc: 0.735936


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=951.0), HTML(value='')), layout=Layout(di…


  Validation -- loss: 0.514885 - acc: 0.739585 - auc: 0.788606 - custom_auc: 0.788490[0m

Epoch 11 - Learning rate: 0.0001


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=6642.0), HTML(value='')), layout=Layout(d…


  Training -- loss: 0.518072 - acc: 0.737111


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=951.0), HTML(value='')), layout=Layout(di…


  Validation -- loss: 0.514851 - acc: 0.739956 - auc: 0.788941 - custom_auc: 0.789917[0m

Epoch 12 - Learning rate: 0.0001


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=6642.0), HTML(value='')), layout=Layout(d…


  Training -- loss: 0.516610 - acc: 0.738038


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=951.0), HTML(value='')), layout=Layout(di…


  Validation -- loss: 0.514795 - acc: 0.740146 - auc: 0.789243 - custom_auc: 0.790131[0m

Epoch 13 - Learning rate: 0.0001


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=6642.0), HTML(value='')), layout=Layout(d…


  Training -- loss: 0.514995 - acc: 0.739217


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=951.0), HTML(value='')), layout=Layout(di…


  Validation -- loss: 0.514954 - acc: 0.740422 - auc: 0.789468 - custom_auc: 0.791141[0m

Epoch 14 - Learning rate: 0.0001


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=6642.0), HTML(value='')), layout=Layout(d…


  Training -- loss: 0.513208 - acc: 0.740425


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=951.0), HTML(value='')), layout=Layout(di…


  Validation -- loss: 0.515881 - acc: 0.740163 - auc: 0.789144 - custom_auc: 0.790713[0m

Epoch 15 - Learning rate: 5e-05


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=6642.0), HTML(value='')), layout=Layout(d…


  Training -- loss: 0.510193 - acc: 0.742446


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=951.0), HTML(value='')), layout=Layout(di…


  Validation -- loss: 0.514358 - acc: 0.740656 - auc: 0.789868 - custom_auc: 0.792054[0m

Epoch 16 - Learning rate: 5e-05


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=6642.0), HTML(value='')), layout=Layout(d…


  Training -- loss: 0.508635 - acc: 0.743480


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=951.0), HTML(value='')), layout=Layout(di…


  Validation -- loss: 0.516144 - acc: 0.740264 - auc: 0.789257 - custom_auc: 0.791602[0m

Epoch 17 - Learning rate: 2.5e-05


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=6642.0), HTML(value='')), layout=Layout(d…

# Test

In [None]:
# ============================= TESTING ===================================
# Restore the saved model and predict group by group: on example_test.csv
# with local scoring when COMPILE is falsy, otherwise through the
# riiideducation submission environment.
print('TESTING')
print('loading previous data...')
previous_data, QUESTIONS_DF = load_previous_data('train_merged.parquet')
model, _, _, _ = load_model(MODEL)
model.eval()
gc.collect()

if not COMPILE:
    print('Submitting locally...')
    previous_batch = None
    tgts = []
    group_preds = []
    # Question mask of the previously predicted group. Predictions only cover
    # question rows, while prior_group_answers_correct covers every row of the
    # previous group (lectures included), so the answers must be filtered with
    # this mask to stay aligned with the predictions.
    prev_is_question = None
    example_test = pd.read_csv(DATA_DIR + 'example_test.csv')
    for gnum in tqdm(example_test['group_num'].unique()):
        test_batch = example_test[
            example_test['group_num'] == gnum].copy()
        preds, previous_batch, previous_data = predict_submission_group(
            model, test_batch,
            previous_batch,
            previous_data,
            batch_size=1024).values()
        group_preds.append(preds)

        # NOTE(review): eval() executes arbitrary code; the string comes from
        # the CSV. Consider ast.literal_eval / json.loads.
        prior_answers = eval(
            test_batch['prior_group_answers_correct'].iloc[0])
        if prev_is_question is not None:
            tgts.extend(
                np.asarray(prior_answers)[prev_is_question].tolist())
        prev_is_question = (test_batch['content_type_id'] == 0).values

    # The answers of the last group never arrive: mark its predicted rows as
    # unscored (-1).
    if prev_is_question is not None:
        tgts.extend([-1] * int(prev_is_question.sum()))

    # Concatenate once instead of growing the frame with DataFrame.append
    # inside the loop (deprecated, and quadratic in the number of groups).
    if group_preds:
        submission = pd.concat(group_preds, ignore_index=True)
    else:
        submission = pd.DataFrame(columns=SUBMISSION_COLUMNS)
    submission['target'] = tgts
    submission['pred'] = (submission[TARGET] >= 0.5).astype(np.int8)

    scored = submission[submission['target'] != -1]
    acc = (scored['target'] == scored['pred']).mean()
    print(submission)
    print('Accuracy:', acc)
    # AUC is a ranking metric: score it on the predicted probabilities, not
    # on the thresholded 0/1 labels.
    print('AUC:', roc_auc_score(scored['target'],
                                scored[TARGET].astype(float)))
else:
    print('Submitting...')
    env = riiideducation.make_env()
    previous_batch = None
    # The second item yielded by iter_test is not used here.
    for test_batch, _ in env.iter_test():
        preds, previous_batch, previous_data = \
            predict_submission_group(model,
                                     test_batch,
                                     previous_batch,
                                     previous_data,
                                     batch_size=1024).values()
        env.predict(preds)