<a href="https://colab.research.google.com/github/bagherig/riiid/blob/main/riiid_transformer.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
gpu_info = !nvidia-smi
gpu_info = '\n'.join(gpu_info)
if gpu_info.find('failed') >= 0:
  print('Select the Runtime > "Change runtime type" menu to enable a GPU accelerator, ')
  print('and then re-execute this cell.')
else:
  print(gpu_info)

Sun Dec 20 02:26:38 2020       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 455.45.01    Driver Version: 418.67       CUDA Version: 10.1     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla V100-SXM2...  Off  | 00000000:00:04.0 Off |                    0 |
| N/A   33C    P0    23W / 300W |      0MiB / 16130MiB |      0%      Default |
|                               |                      |                 ERR! |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

In [None]:
from psutil import virtual_memory

# Total memory reported by psutil, expressed in decimal gigabytes.
ram_gb = virtual_memory().total / 1e9
print('Your runtime has {:.1f} gigabytes of available RAM\n'.format(ram_gb))

# 20 GB is the threshold this notebook uses to call a runtime "high-RAM".
if ram_gb >= 20:
  print('You are using a high-RAM runtime!')
else:
  print('To enable a high-RAM runtime, select the Runtime > "Change runtime type"')
  print('menu, and then select High-RAM in the Runtime shape dropdown. Then, ')
  print('re-execute this cell.')

Your runtime has 13.7 gigabytes of available RAM

To enable a high-RAM runtime, select the Runtime > "Change runtime type"
menu, and then select High-RAM in the Runtime shape dropdown. Then, 
re-execute this cell.


# Global Variables

In [None]:
# %% [code]

# Colab form fields (the `#@param` markers must stay on the same line).
COMPILE = False #@param {type:"boolean"}
VERSION = 'v1' #@param ['v1', 'v2']

# Project root under /content/drive (presumably the Google Drive mount).
DRIVE_DIR = ('/content/drive/MyDrive/Colab Notebooks/datasets'
             '/riiid-answer-correctness-prediction/')
# Working directory of the runtime.
HOME_DIR = './'

# Every other location is derived from the two roots above.
DATA_DIR = f'{DRIVE_DIR}data/riiid-test-answer-prediction/'
DRIVE_PARQUETS_DIR = f'{DRIVE_DIR}data/parquets/'
PARQUETS_DIR = f'{HOME_DIR}parquets/'
MODELS_DIR = f'{DRIVE_DIR}models/'

# save_model writes its checkpoint here and load_model looks here first.
OUT_DIR = f'{DRIVE_DIR}temp/'

In [None]:
# Label column, and the column that groups rows into one sequence per user.
TARGET = 'answered_correctly'
KEY_FEATURE = 'user_id'

# Python type each input column is cast to before it is fed to the model.
# Disabled candidates are kept as comments; keep FEATURES in sync with this.
DTYPES = {
    'content_id': int,
    'prior_question_elapsed_time': int,
    # 'prior_question_had_explanation': bool,
    'task_container_id': int,
    'part': int,
    # 'tag_1': int,
}

# Input columns, in the order they are laid out in every per-user array.
FEATURES = [
    'content_id',
    'prior_question_elapsed_time',
    # 'prior_question_had_explanation',
    'task_container_id',
    'part',
    # 'tag_1',
]

# Features that come from the questions table rather than the train table.
ADDED_FEATURES = ['part']

SUBMISSION_COLUMNS = ['row_id', TARGET]

# Imports

In [None]:
# import riiideducation 

import os
import gc
import sys
import math
import random
import psutil
import warnings

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from sklearn.metrics import roc_auc_score
from tqdm.notebook import tqdm

RANDOM_SEED = 1
np.random.seed(RANDOM_SEED)

warnings.filterwarnings('ignore')
gc.enable()

if not os.path.exists(PARQUETS_DIR):
    sys.path.append(DRIVE_PARQUETS_DIR)
    zip_path = DRIVE_PARQUETS_DIR + 'transformer.zip'
    !cp '{zip_path}' $HOME_DIR
    !unzip -q 'transformer.zip'
    !rm 'transformer.zip'

if torch.__version__ != '1.6.0+cu101':
    !pip uninstall torch==1.7.0 -y
    !pip install torch==1.6.0+cu101 torchvision==0.7.0+cu101 -f https://download.pytorch.org/whl/torch_stable.html
    os.kill(os.getpid(), 9)
print(torch.__version__)

1.6.0+cu101


In [None]:
TPU = 'COLAB_TPU_ADDR' in os.environ
if TPU:
    !pip install cloud-tpu-client==0.10 https://storage.googleapis.com/tpu-pytorch/wheels/torch_xla-1.7-cp36-cp36m-linux_x86_64.whl
    
    import torch_xla.core.xla_model as xm
    import torch_xla.distributed.parallel_loader as pl

# Helper Functions

In [None]:
def prepare_data():
    """Build the merged / train / validation parquet files on Drive.

    Reads ``train.parquet`` and ``questions.parquet`` from PARQUETS_DIR, keeps
    only rows with ``content_type_id == 0`` and a target other than -1, joins
    the question columns on ``content_id`` and writes three files to
    DRIVE_PARQUETS_DIR: ``train_merged.parquet``, ``train_transformer.parquet``
    and ``test_transformer.parquet`` (a 2.5% validation split by user).

    The files are read directly rather than through ``load_data``: that helper
    restricts the columns to the *merged* schema, which drops
    ``content_type_id`` (needed by the filter below) and asks the raw train
    file for the ADDED_FEATURES that only exist after the merge.
    """
    # Raw train columns: the model columns minus the ones the merge adds,
    # plus the column used to filter rows.
    train_columns = ([KEY_FEATURE]
                     + [feat for feat in FEATURES
                        if feat not in ADDED_FEATURES]
                     + [TARGET, 'content_type_id'])
    dt = pd.read_parquet(PARQUETS_DIR + 'train.parquet',
                         columns=train_columns)
    # Read the questions table whole (assumed small - TODO confirm).
    questions = pd.read_parquet(PARQUETS_DIR + 'questions.parquet')

    dt = dt[(dt['content_type_id'] == 0) & (dt['answered_correctly'] != -1)] \
        .drop(columns=['content_type_id'])

    # The questions index is used as the join key.
    questions.index.name = 'content_id'
    questions = questions.reset_index()

    dt = dt.merge(questions, on='content_id', how='left')
    dt.to_parquet(DRIVE_PARQUETS_DIR + 'train_merged.parquet')

    trn, tst = split_train_valid(dt, 0.025)
    trn.to_parquet(DRIVE_PARQUETS_DIR + 'train_transformer.parquet')
    tst.to_parquet(DRIVE_PARQUETS_DIR + 'test_transformer.parquet')


def load_data(filename):
    """Read ``filename`` from PARQUETS_DIR, keeping only the model columns.

    The columns loaded are KEY_FEATURE, every entry of FEATURES and TARGET,
    in that order.
    """
    wanted_columns = [KEY_FEATURE, *FEATURES, TARGET]
    path = PARQUETS_DIR + filename
    return pd.read_parquet(path, columns=wanted_columns)


def preprocess(dt):
    """Normalise the raw feature columns of ``dt`` in place and return it.

    * ``prior_question_elapsed_time``: missing values become 26000, then the
      column is divided by 1000, rounded up and stored as int16 (presumably
      milliseconds to whole seconds).
    * ``content_id`` and ``task_container_id`` are shifted up by one
      (presumably so that 0 stays free for padding - confirm against
      PAD_VALUE).

    Note that the frame passed in is modified, so calling this twice on the
    same frame shifts the ids twice.
    """
    elapsed = dt["prior_question_elapsed_time"].fillna(26000)
    dt["prior_question_elapsed_time"] = \
        np.ceil(elapsed / 1000).astype(np.int16)

    for id_column in ("content_id", "task_container_id"):
        dt[id_column] += 1
    # dt[["tag1", "tag2", "tag3", "tag4", "tag5", "tag6"]] += 1

    return dt


def split_train_valid(dt, val_fraction):
    """Split ``dt`` into train and validation frames without splitting users.

    Users are visited from the one with the most rows to the one with the
    fewest. A user is sent to validation whenever the validation set is
    currently smaller than ``val_fraction`` times the training set, otherwise
    to training, so the user with the most rows always lands in training.

    Args:
        dt: frame containing the KEY_FEATURE and TARGET columns.
        val_fraction: target ratio of validation rows to training rows.

    Returns:
        ``(trn, val)``. Both are selected with a boolean mask, so the result
        is also correct when ``dt`` has repeated index labels (dropping by
        index label would remove unrelated training rows in that case).
    """
    val_size = 0
    trn_size = 0
    val_uids = []
    # [user, row count] pairs in ascending count order; pop() takes the
    # largest remaining user.
    n_samples_per_user = dt.groupby(KEY_FEATURE)[
        TARGET].count().sort_values().reset_index().values.tolist()
    while n_samples_per_user:
        uid, nsamples = n_samples_per_user.pop()
        if trn_size * val_fraction > val_size:
            val_uids.append(uid)
            val_size += nsamples
        else:
            trn_size += nsamples

    in_val = dt[KEY_FEATURE].isin(val_uids)
    return dt[~in_val], dt[in_val]


def pad_batch(x, window_size, pad_value=0):
    """Pad ``x`` at the end of axis 0 until it has ``window_size`` entries.

    All other axes are left untouched. ``x`` must not be longer than
    ``window_size`` (np.pad rejects a negative pad width).
    """
    n_missing = window_size - x.shape[0]
    pad_widths = ((0, n_missing),) + ((0, 0),) * (len(x.shape) - 1)
    return np.pad(x, pad_widths, constant_values=pad_value)


def rolling_window(a, w):
    """Return every length-``w`` window of consecutive rows of a 2-D array.

    The result has shape ``(rows - w + 1, w, cols)`` and is built with
    ``as_strided``, i.e. it is a view that reuses the memory of ``a`` rather
    than a copy.
    """
    n_rows, n_cols = a.shape
    row_stride, col_stride = a.strides
    # Stepping to the next window and to the next row inside a window both
    # advance by exactly one row of `a`.
    return np.lib.stride_tricks.as_strided(
        a,
        shape=(n_rows - w + 1, w, n_cols),
        strides=(row_stride, row_stride, col_stride))


def make_time_series(x, windows_size, pad_value=0):
    """Turn a 2-D array of rows into one fixed-length window per row.

    ``windows_size - 1`` rows of ``pad_value`` are appended after the last
    row, so window ``i`` starts at row ``i`` and the final windows are
    right-padded. Output shape: ``(rows, windows_size, cols)``.
    """
    tail_padding = [[0, windows_size - 1], [0, 0]]
    padded = np.pad(x, tail_padding, constant_values=pad_value)
    return rolling_window(padded, windows_size)


def create_scheduler(estimator, optim, warmup_steps=10, last_epoch=-1):
    """Create a LambdaLR scheduler with linear warm-up and inverse-sqrt decay.

    With ``step = epoch + 1`` the factor applied to the optimizer's base
    learning rate is
    ``d_model ** -0.5 * min(step ** -0.5, step * warmup_steps ** -1.5)``.

    Args:
        estimator: object exposing ``d_model``.
        optim: optimizer whose learning rate is scheduled.
        warmup_steps: number of scheduler steps the linear ramp lasts.
        last_epoch: forwarded to LambdaLR (use -1 for a fresh schedule).
    """
    def lr_factor(epoch):
        step = epoch + 1
        decay = step ** -0.5
        warmup = step * (warmup_steps ** -1.5)
        return (estimator.d_model ** -0.5) * min(decay, warmup)

    return torch.optim.lr_scheduler.LambdaLR(optim,
                                             lr_lambda=lr_factor,
                                             last_epoch=last_epoch)


def create_optimizer(estimator, lr):
    """Return an Adam optimizer over all parameters of ``estimator``."""
    trainable = estimator.parameters()
    return torch.optim.Adam(trainable, lr=lr)


def create_model(model_type, **params):
    """Instantiate ``model_type`` with ``params`` as keyword arguments."""
    model = model_type(**params)
    return model


def save_model(estimator, optim, sched=None, val_score:float=0):
    """Write a training checkpoint to ``OUT_DIR + MODEL_FILENAME``.

    The checkpoint holds the model's constructor arguments
    (``estimator.params``), the model and optimizer state dicts, the current
    learning rate of the first parameter group and ``val_score``. When a
    scheduler is given, its state dict and ``last_epoch`` are stored as well.

    OUT_DIR is created if it is missing, so a finished epoch is not lost to a
    missing directory.
    """
    checkpoint = {'model_params': estimator.params,
                  'model_state_dict': estimator.state_dict(),
                  'optimizer_state_dict': optim.state_dict(),
                  'learning_rate': optim.param_groups[0]['lr'],
                  'val_score': val_score}
    if sched is not None:
        # The misspelt key is what load_model (and any checkpoint already on
        # disk) looks up, so it must not be "corrected" here alone.
        checkpoint = {**checkpoint,
                      'sheduler_state_dict': sched.state_dict(),
                      'epoch': sched.last_epoch}
    os.makedirs(OUT_DIR, exist_ok=True)
    torch.save(checkpoint, OUT_DIR + MODEL_FILENAME)

def load_model(model_type, for_training=False, warmup_steps=10):
    """Rebuild a model (and optionally its optimizer/scheduler) from disk.

    The checkpoint in OUT_DIR is preferred; if it does not exist, the one
    under ``MODELS_DIR/transformer/<VERSION>/`` is used.

    Args:
        model_type: class to instantiate with the stored ``model_params``.
        for_training: also restore the optimizer and, when the checkpoint
            contains one, the scheduler.
        warmup_steps: passed to create_scheduler when a scheduler is restored.

    Returns:
        ``(estimator, optim, sched, val_score)``; ``optim`` and ``sched`` are
        None unless they were restored.
    """
    filepath = OUT_DIR + MODEL_FILENAME
    if not os.path.exists(filepath):
        filepath = MODELS_DIR + f'transformer/{VERSION}/{MODEL_FILENAME}'

    print(f'Loading model from {filepath}')
    checkpoint = torch.load(filepath, map_location=DEVICE)

    estimator = create_model(model_type, **checkpoint['model_params'])
    estimator.load_state_dict(checkpoint['model_state_dict'])
    estimator.to(DEVICE)

    optim, sched = None, None
    if for_training:
        optim = create_optimizer(estimator, lr=checkpoint['learning_rate'])
        optim.load_state_dict(checkpoint['optimizer_state_dict'])

        # The key is spelt exactly as save_model writes it.
        if 'sheduler_state_dict' in checkpoint:
            sched = create_scheduler(estimator, optim, warmup_steps,
                                     last_epoch=checkpoint['epoch'])
            sched.load_state_dict(checkpoint['sheduler_state_dict'])

    return estimator, optim, sched, checkpoint['val_score']

# TransformerEncoder Model

In [None]:
# ========================== Transformer Model ================================

class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position encodings to a sequence, then dropout.

    Even feature indices hold ``sin(pos * w_k)`` and odd indices
    ``cos(pos * w_k)`` with ``w_k = exp(-ln(10000) * 2k / d_model)``.
    The table is stored in the buffer ``pe`` of shape
    ``(max_len, 1, d_model)``.

    Args:
        d_model: size of the feature axis. Odd sizes are supported; the
            cosine half is then one column shorter than the sine half.
        dropout: dropout probability applied after the addition.
        max_len: longest sequence the table covers.
    """

    def __init__(self, d_model, dropout=0.1, max_len=96):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)

        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (
                -math.log(10000.0) / d_model))
        angles = position * div_term  # (max_len, ceil(d_model / 2))

        pe = torch.zeros(max_len, d_model)
        pe[:, 0::2] = torch.sin(angles)
        # Only floor(d_model / 2) odd slots exist; without the slice an odd
        # d_model fails with a shape mismatch.
        pe[:, 1::2] = torch.cos(angles[:, :d_model // 2])
        pe = pe.unsqueeze(0).transpose(0, 1)
        self.register_buffer('pe', pe)

    def forward(self, x):
        """Add the encodings for the first ``x.size(0)`` positions.

        ``x`` is indexed by position on axis 0; the encoding broadcasts over
        axis 1.
        """
        x = x + self.pe[:x.size(0), :]
        return self.dropout(x)


class TransformerEncoderModel(nn.TransformerEncoder):
    """Encoder-only transformer that scores the last real step of a window.

    Each position is embedded as the sum of exercise, elapsed-time, part and
    (shifted) target embeddings plus a sinusoidal position encoding. The
    encoder output is projected to one value per position, passed through a
    sigmoid, and only the value at the last non-padded position of every
    sequence is returned.

    The module reads the globals START_TOKEN, PAD_VALUE and DEVICE when it is
    constructed.
    """

    def __init__(self, **params):
        """Build the encoder stack, embedding tables and output head.

        ``params`` must contain ``d_model`` and ``dropout``. ``seq_len``
        (default 96) and ``num_encoder_layers`` (default 4) are consumed
        here; everything else is forwarded unchanged to
        ``nn.TransformerEncoderLayer``.
        """
        print(params)
        # Untouched copy of the constructor arguments: save_model stores it
        # and load_model uses it to rebuild the same architecture.
        self.params = params.copy()
        self.d_model = params['d_model']
        self.seq_len = params.pop('seq_len', 96)
        n_layers = params.pop('num_encoder_layers', 4)
        dropout_rate = params['dropout']
        # What is left in `params` configures a single encoder layer.
        encoder_layer = nn.TransformerEncoderLayer(**params)
        encoder_norm = nn.LayerNorm(self.d_model)
        super().__init__(encoder_layer=encoder_layer,
                         num_layers=n_layers,
                         norm=encoder_norm)
        # Upper bounds used in forward(): larger exercise / part ids are
        # mapped to 0 and larger elapsed times are capped at max_time.
        # max_container is not used by this class.
        self.max_exercise = 13523
        self.max_part = 7
        self.max_time = 300
        self.max_container = 10000
        self.max_target = 2
        self.start_token = START_TOKEN
        self.pad_value = PAD_VALUE

        # NOTE: the creation order of the sub-modules below fixes the order of
        # estimator.parameters(), which a saved optimizer state relies on.
        # self.pos_embedding = nn.Embedding(self.d_model, self.d_model) # positional embeddings
        self.pos_embedding = PositionalEncoding(self.d_model,
                                                dropout_rate,
                                                self.seq_len)
        self.exercise_embeddings = \
            nn.Embedding(num_embeddings=self.max_exercise + 1,
                         embedding_dim=self.d_model)  # exercise_id
        self.part_embeddings = \
            nn.Embedding(num_embeddings=self.max_part + 1,
                         embedding_dim=self.d_model)
        self.elapsed_time_embeddings = (# nn.Linear(self.seq_len, self.d_model))
            nn.Embedding(num_embeddings=self.max_time + 1,
                         embedding_dim=self.d_model))
        # max_target + 2 rows; presumably pad, the two shifted answers and
        # the start token - confirm against START_TOKEN / PAD_VALUE.
        self.target_embeddings = \
            nn.Embedding(num_embeddings=self.max_target + 2,
                         embedding_dim=self.d_model)
        self.linear1 = nn.Linear(self.d_model, 1)
        # linear2, relu, dropout1 and norm1 are created but not used by
        # forward(); they still contribute entries to the state dict.
        self.linear2 = nn.Linear(self.seq_len, 1)
        self.sigmoid = nn.Sigmoid()
        self.relu = nn.ReLU()
        self.dropout1 = nn.Dropout(dropout_rate)
        # self.dropout2 = nn.Dropout(dropout_rate)
        self.norm1 = nn.LayerNorm(self.d_model)

        self.device = DEVICE
        # Causal mask, kept as a plain attribute (not a registered buffer).
        # forward() currently does not pass it to the encoder.
        self.future_mask = self.generate_square_subsequent_mask(
            self.seq_len).to(self.device)
        self.init_weights()

    def generate_square_subsequent_mask(self, sz):
        """Return a (sz, sz) float mask: -inf above the diagonal, 0 elsewhere."""
        mask = torch.triu(torch.ones(sz, sz), 1)
        mask = mask.masked_fill(mask == 1, float('-inf'))
        return mask

    def make_pad_mask(self, inp):
        """Return a boolean tensor that is True where ``inp`` equals the pad value."""
        return inp == self.pad_value

    def set_start_token(self, inp):
        """Overwrite column 0 of ``inp`` with the start token, in place."""
        inp[:, 0] = self.start_token

    def init_weights(self):
        """Re-initialise embeddings and both linear heads uniformly in [-0.1, 0.1].

        Biases of the linear layers are set to zero.
        """
        initrange = 0.1
        # init embeddings
        self.exercise_embeddings.weight.data.uniform_(-initrange, initrange)
        self.part_embeddings.weight.data.uniform_(-initrange, initrange)
        self.elapsed_time_embeddings.weight.data.uniform_(-initrange,
                                                          initrange)
        self.target_embeddings.weight.data.uniform_(-initrange, initrange)
        self.linear1.bias.data.zero_()
        self.linear1.weight.data.uniform_(-initrange, initrange)
        self.linear2.bias.data.zero_()
        self.linear2.weight.data.uniform_(-initrange, initrange)

    def forward(self, inputs: dict):
        """Predict the probability of a correct answer for each sequence.

        ``inputs`` is a mapping with the keys ``'content_id'``,
        ``'prior_question_elapsed_time'``, ``'part'`` and ``'targets'``;
        each value is an integer tensor laid out as (N, S).

        S is the sequence length, N the batch size and E the embedding
        dimension (number of features).
        src: (S, N, E)
        src_mask: (S, S)
        src_key_padding_mask: (N, S)
        The padding mask is (N, S) with boolean True/False.
        SRC_MASK is (S, S) with float('-inf') and float(0.0).

        Side effects: the tensors in ``inputs`` are modified in place
        (start token written to column 0 of the targets, out-of-range ids
        replaced or capped).

        Returns a tensor of shape (N,) with values in (0, 1): the sigmoid
        output at the last non-padded position of every sequence.
        """
        content_ids, elapsed_times, parts, targets = \
            inputs['content_id'], inputs['prior_question_elapsed_time'], \
            inputs['part'], inputs['targets']
        self.set_start_token(targets)
        pad_mask = self.make_pad_mask(targets)
        # Running count of real positions; argmax returns the first index at
        # which the count peaks, i.e. the last non-padded position.
        pred_col_idx = (~pad_mask).byte().data.cpu().numpy().cumsum(1).argmax(1)
        assert(True not in pad_mask[np.arange(pad_mask.shape[0]),
                                    pred_col_idx])
        
        # Keep every index inside its embedding table.
        content_ids[content_ids > self.max_exercise] = 0
        elapsed_times[elapsed_times > self.max_time] = self.max_time
        parts[parts > self.max_part] = 0

        embedded_inp = (self.exercise_embeddings(content_ids)
                        + self.elapsed_time_embeddings(elapsed_times)
                        + self.part_embeddings(parts)
                        + self.target_embeddings(targets)
                        ) * np.sqrt(self.d_model)  # (N, S, E)
        embedded_inp = self.pos_embedding(embedded_inp.transpose(0, 1))
        # embedded_inp = self.norm1(embedded_inp)  # (S, N, E)

        # No causal mask is applied; only padded positions are hidden.
        output = super().forward(src=embedded_inp,
                                 # mask=self.future_mask,
                                 src_key_padding_mask=pad_mask)  # (S, N, E)
        # output = self.norm1(output)
        # output = self.dropout1(output)
        output = self.linear1(output).squeeze(-1).transpose(1, 0) # (N, S)
        # output = self.relu(output)

        # output = self.dropout1(output)
        # output = self.linear2(output).squeeze(-1) # (N)
        output = self.sigmoid(output)
        # Keep one score per sequence: the one at the prediction position.
        output = output[np.arange(output.shape[0]), pred_col_idx]
        
        return output

# Transformer Model

In [None]:
class TransformerModel(nn.Transformer):
    """Encoder-decoder transformer that scores the last real step of a window.

    The encoder input ("src") is the sum of exercise, task-container and part
    embeddings; the decoder input ("tgt") is the sum of the shifted target
    and elapsed-time embeddings. Both get their own sinusoidal position
    encoding. The decoder output is projected to one value per position and
    only the value at the last non-padded source position is returned, after
    a sigmoid.

    The module reads the globals START_TOKEN, PAD_VALUE and DEVICE when it is
    constructed.
    """

    def __init__(self, **params):
        """Build the transformer, embedding tables and output head.

        ``params`` must contain ``seq_len`` (consumed here) and ``dropout``;
        everything that remains is forwarded to ``nn.Transformer`` (for
        example the number of attention heads, the feed-forward width and
        the number of stacked layers). ``self.d_model`` is read afterwards
        and is presumably set by ``nn.Transformer.__init__``.
        """
        print(params)
        # Untouched copy of the constructor arguments: save_model stores it
        # and load_model uses it to rebuild the same architecture.
        self.params = params.copy()
        self.seq_len = params.pop('seq_len')
        dropout_rate = params['dropout']
        super().__init__(**params)
        self.dropout_rate = params['dropout']
        # Upper bounds used in forward(): larger exercise / container / part
        # ids are mapped to 0 and larger elapsed times are capped at max_time.
        self.max_exercise = 13523
        self.max_part = 7
        self.max_time = 300
        self.max_container = 10000
        self.max_target = 2
        self.start_token = START_TOKEN
        self.pad_value = PAD_VALUE

        # NOTE: the creation order of the sub-modules below fixes the order of
        # estimator.parameters(), which a saved optimizer state relies on.
        # self.pos_embedding = nn.Embedding(self.d_model, self.d_model) # positional embeddings
        self.src_pos_embedding = PositionalEncoding(self.d_model, dropout_rate,
                                                    self.seq_len)
        self.tgt_pos_embedding = PositionalEncoding(self.d_model, dropout_rate,
                                                    self.seq_len)
        self.exercise_embeddings = \
            nn.Embedding(num_embeddings=self.max_exercise + 1,
                         embedding_dim=self.d_model)  # exercise_id
        self.part_embeddings = \
            nn.Embedding(num_embeddings=self.max_part + 1,
                         embedding_dim=self.d_model)
        self.elapsed_time_embeddings = \
            nn.Embedding(num_embeddings=self.max_time + 1,
                         embedding_dim=self.d_model)
        self.container_embeddings = \
            nn.Embedding(num_embeddings=self.max_container + 1,
                         embedding_dim=self.d_model)
        # max_target + 2 rows; presumably pad, the two shifted answers and
        # the start token - confirm against START_TOKEN / PAD_VALUE.
        self.target_embeddings = \
            nn.Embedding(num_embeddings=self.max_target + 2,
                         embedding_dim=self.d_model)
        self.linear1 = nn.Linear(self.d_model, 1)
        # linear2, relu, dropout1, dropout2 and norm1 are created but not
        # used by forward(); they still contribute to the state dict.
        self.linear2 = nn.Linear(self.seq_len, 1)
        self.sigmoid = nn.Sigmoid()
        self.relu = nn.ReLU()
        self.dropout1 = nn.Dropout(dropout_rate)
        self.dropout2 = nn.Dropout(dropout_rate)
        self.norm1 = nn.LayerNorm(self.d_model)

        self.device = DEVICE
        # Causal mask, kept as a plain attribute (not a registered buffer).
        # forward() currently does not pass it to the transformer.
        self.future_mask = self.generate_square_subsequent_mask(
            self.seq_len).to(self.device)
        self.init_weights()

    def generate_square_subsequent_mask(self, sz):
        """Return a (sz, sz) float mask: -inf above the diagonal, 0 elsewhere."""
        mask = torch.triu(torch.ones(sz, sz), 1)
        mask = mask.masked_fill(mask == 1, float('-inf'))
        return mask

    def make_pad_mask(self, inp):
        """Return a boolean tensor that is True where ``inp`` equals the pad value."""
        return inp == self.pad_value

    def set_start_token(self, inp):
        """Overwrite column 0 of ``inp`` with the start token, in place."""
        inp[:, 0] = self.start_token

    def init_weights(self):
        """Re-initialise the embeddings and linear1 uniformly in [-0.1, 0.1].

        The bias of linear1 is zeroed; linear2 keeps its default
        initialisation.
        """
        initrange = 0.1
        # init embeddings
        self.exercise_embeddings.weight.data.uniform_(-initrange, initrange)
        self.part_embeddings.weight.data.uniform_(-initrange, initrange)
        self.elapsed_time_embeddings.weight.data.uniform_(-initrange, initrange)
        self.container_embeddings.weight.data.uniform_(-initrange, initrange)
        self.target_embeddings.weight.data.uniform_(-initrange, initrange)
        self.linear1.bias.data.zero_()
        self.linear1.weight.data.uniform_(-initrange, initrange)
        # self.linear2.bias.data.zero_()
        # self.linear2.weight.data.uniform_(-initrange, initrange)

    def forward(self, inputs):
        """Predict the probability of a correct answer for each sequence.

        ``inputs`` is a mapping with the keys ``'content_id'``,
        ``'prior_question_elapsed_time'``, ``'task_container_id'``,
        ``'part'`` and ``'targets'``; each value is an integer tensor laid
        out as (N, S).

        S is the sequence length, N the batch size and E the embedding
        dimension (number of features).
        src: (S, N, E)
        src_mask: (S, S)
        src_key_padding_mask: (N, S)
        The padding mask is (N, S) with boolean True/False.
        SRC_MASK is (S, S) with float('-inf') and float(0.0).

        Side effects: the tensors in ``inputs`` are modified in place
        (start token written to column 0 of the targets, out-of-range ids
        replaced or capped).

        Returns a tensor of shape (N,) with values in (0, 1).
        """
        content_ids, elapsed_times, containers, parts, targets = \
            inputs['content_id'], inputs['prior_question_elapsed_time'], \
            inputs['task_container_id'], inputs['part'], inputs['targets']
        self.set_start_token(targets)
        # Source padding is derived from the exercise ids, target padding
        # from the shifted answers.
        src_pad_mask = self.make_pad_mask(content_ids)
        tgt_pad_mask = self.make_pad_mask(targets)
        # Running count of real positions; argmax returns the first index at
        # which the count peaks, i.e. the last non-padded source position.
        pred_col_idx = (~src_pad_mask).byte().data.cpu().numpy().cumsum(1).argmax(1)
        assert(True not in src_pad_mask[np.arange(src_pad_mask.shape[0]),
                                        pred_col_idx])

        # Keep every index inside its embedding table.
        content_ids[content_ids > self.max_exercise] = 0
        elapsed_times[elapsed_times > self.max_time] = self.max_time
        containers[containers > self.max_container] = 0
        parts[parts > self.max_part] = 0

        embedded_src = (self.exercise_embeddings(content_ids)
                        + self.container_embeddings(containers)
                        + self.part_embeddings(parts)
                        # + self.target_embeddings(targets)
                        # + self.elapsed_time_embeddings(elapsed_times)
                        ) * np.sqrt(self.d_model)  
                        # (N, S, E)
        embedded_src = self.src_pos_embedding(embedded_src.transpose(0, 1)) 
        # (S, N, E)

        embedded_tgt = (self.target_embeddings(targets)
                        + self.elapsed_time_embeddings(elapsed_times)
                        ) * np.sqrt(self.d_model)
        embedded_tgt = self.tgt_pos_embedding(embedded_tgt.transpose(0, 1))
        # No causal masks are applied; only padded positions are hidden.
        output = super().forward(src=embedded_src,
                                 tgt=embedded_tgt,
                                #  src_mask=self.future_mask,
                                #  tgt_mask=self.future_mask,
                                #  memory_mask=self.future_mask,
                                 src_key_padding_mask=src_pad_mask,
                                 tgt_key_padding_mask=tgt_pad_mask,
                                 memory_key_padding_mask=src_pad_mask)
        # output = self.norm1(output)
        # output = self.dropout1(output)
        output = self.linear1(output).squeeze(-1).transpose(1, 0) # (N, S)
        # Keep one score per sequence (the prediction position) before the
        # sigmoid is applied.
        output = output[np.arange(output.shape[0]), pred_col_idx]
        # output = self.relu(output)

        # output = self.dropout1(output)
        # output = self.linear2(output).squeeze(-1) # (N)
        output = self.sigmoid(output)

        return output

# Training Functions

In [None]:
class Riiid(torch.utils.data.Dataset):
    """Per-user dataset that yields all sliding windows of one user.

    Every item is an array of shape ``(n_windows, seq_len, n_columns)`` whose
    columns are FEATURES, TARGET and a final "shifted target" column
    (START_TOKEN followed by the previous answers plus one).

    In training mode the requested index is ignored: a user is drawn at
    random with probability proportional to their row count clipped to
    ``sample_cap``. Windows can additionally be sub-sampled per user with
    ``max_samples_per_user`` (an int caps the count, a float in (0, 1] keeps
    that fraction of at most ``sample_cap[1]`` windows).

    Note: windows are padded with the global PAD_VALUE; the ``pad_value``
    argument is only stored.
    """

    def __init__(self, dt, seq_len, pad_value=0,
                 max_samples_per_user=None, for_training=True):
        columns = FEATURES + [TARGET]
        self.len = dt.shape[0]
        self.dcols = {name: pos for pos, name in enumerate(columns)}

        per_user = dt.groupby(KEY_FEATURE)
        # One 2-D array (rows x columns) per user, in group order.
        self.groups = per_user.apply(lambda rows: rows[columns].values).values

        self.sample_cap = (50, 500)
        lowest, highest = self.sample_cap
        counts = per_user.count()[TARGET].clip(lower=lowest, upper=highest)
        self.probs = counts / counts.sum()
        self.idx = range(len(self.groups))
        self.seq_len = seq_len
        self.pad_value = pad_value
        self.max_samples_per_user = max_samples_per_user
        self.for_training = for_training

    def __len__(self):
        return len(self.groups)

    def __getitem__(self, idx):
        if self.for_training:
            idx = np.random.choice(self.idx, size=1, p=self.probs)[0]
        history = self.groups[idx].copy()

        # Previous answers shifted by one step (and by +1 in value), with the
        # start token in front.
        shifted = np.r_[START_TOKEN, history[:-1, self.dcols[TARGET]] + 1]
        windows = make_time_series(np.c_[history, shifted], self.seq_len,
                                   pad_value=PAD_VALUE)

        limit = self.max_samples_per_user
        if limit is None:
            return windows

        n_sequences = len(windows)
        if isinstance(limit, int):
            nsamples = min(limit, n_sequences)
        else:
            assert (0 < limit <= 1)
            nsamples = int(min(n_sequences, self.sample_cap[1]) * limit)
        picked = np.random.choice(np.arange(n_sequences),
                                  nsamples, replace=False)
        return windows[picked]


def collate_fn(batch):
    """Merge per-user window arrays into one column-major batch.

    The arrays are concatenated along axis 0 and the axes are reordered from
    ``(sample, position, column)`` to ``(column, sample, position)`` so a
    whole column can be picked with a single index.
    """
    stacked = np.concatenate(batch, axis=0)
    return np.transpose(stacked, (2, 0, 1))


# %% [code]


def train_epoch(estimator, train_iterator, optim, criterion,
                device="cpu",
                batch_limit=128):
    """Run one optimisation pass over ``train_iterator``.

    Every item of the iterator is a column-major array (see collate_fn): one
    slice per entry of FEATURES, then the labels (index -2) and the shifted
    targets (index -1). Each item is processed in chunks of at most
    ``batch_limit`` sequences, with one optimizer step per chunk. The loss is
    computed only at the last non-padded position of each sequence.

    Returns:
        ``(mean loss per chunk, accuracy at a 0.5 threshold)``.
    """
    estimator.train()

    progress = tqdm(train_iterator, ncols=500)
    loss_sum = 0
    num_corrects = 0
    sample_count = 0
    batch_count = 0

    for batch in progress:
        # Cast every feature column according to DTYPES; columns with any
        # other declared type are skipped.
        inputs = {}
        for position, feat in enumerate(FEATURES):
            kind = DTYPES[feat]
            if kind is int:
                inputs[feat] = torch.Tensor(batch[position]).to(device).long()
            elif kind is float:
                inputs[feat] = torch.Tensor(batch[position]).to(device).float()
            elif kind is bool:
                inputs[feat] = torch.Tensor(batch[position]).to(device).bool()
        inputs['targets'] = torch.Tensor(batch[-1]).to(device).long()
        labels_all = torch.Tensor(batch[-2]).to(device).long()

        n_samples = len(labels_all)
        n_batches = int(np.ceil(n_samples / batch_limit))
        for nbatch in range(n_batches):
            rows = slice(nbatch * batch_limit, (nbatch + 1) * batch_limit)
            optim.zero_grad()

            # Locate the last real position of each sequence before the model
            # runs (the forward pass rewrites the targets in place).
            shifted = inputs['targets'][rows].data.cpu().numpy()
            pred_col_idx = (shifted != PAD_VALUE).cumsum(1).argmax(1)
            row_idx = np.arange(shifted.shape[0])
            assert(PAD_VALUE not in shifted[row_idx, pred_col_idx])

            output = estimator(inputs={name: feat[rows]
                                       for name, feat in inputs.items()})
            labels = labels_all[rows].float()[row_idx, pred_col_idx]

            loss = criterion(output, labels)
            loss.backward()
            optim.step()

            loss_sum += loss.item()
            pred = (output >= 0.5).long()
            num_corrects += (pred == labels).sum().item()
            sample_count += len(labels)
            batch_count += 1

            progress.set_description(
                f'{nbatch + 1}/{n_batches} | ' + 'trn_loss - {:.4f}'.format(
                    loss_sum / batch_count))

    return loss_sum / batch_count, num_corrects / sample_count


# %% [code]


def val_epoch(estimator, val_iterator, criterion, device="cpu",
              batch_limit=128):
    """Evaluate ``estimator`` on ``val_iterator`` without updating it.

    Every item of the iterator is a column-major array (see collate_fn): one
    slice per entry of FEATURES, then the labels (index -2) and the shifted
    targets (index -1). Each item is scored in chunks of at most
    ``batch_limit`` sequences; only the last non-padded position of each
    sequence is evaluated.

    Returns:
        ``(mean loss per chunk, accuracy at a 0.5 threshold, ROC AUC)``.
    """
    estimator.eval()

    loss_sum = 0
    batch_count = 0
    num_corrects = 0
    sample_count = 0
    # Per-chunk results are collected in lists and joined once at the end;
    # re-concatenating the running arrays for every chunk is quadratic. The
    # empty float64 seed keeps the final dtype (and the empty case) as before.
    truth_chunks = [np.empty(0)]
    out_chunks = [np.empty(0)]

    tbar = tqdm(val_iterator, ncols=500)
    for batch in tbar:
        # Cast every feature column according to DTYPES; columns with any
        # other declared type are skipped.
        inputs = {}
        for i, feat in enumerate(FEATURES):
            if DTYPES[feat] is int:
                inputs[feat] = torch.Tensor(batch[i]).to(device).long()
            elif DTYPES[feat] is float:
                inputs[feat] = torch.Tensor(batch[i]).to(device).float()
            elif DTYPES[feat] is bool:
                inputs[feat] = torch.Tensor(batch[i]).to(device).bool()
        inputs['targets'] = torch.Tensor(batch[-1]).to(device).long()
        labels_all = torch.Tensor(batch[-2]).to(device).long()

        n_samples = len(labels_all)
        n_batches = int(np.ceil(n_samples / batch_limit))
        for nbatch in range(n_batches):
            start_idx = nbatch * batch_limit
            end_idx = (nbatch + 1) * batch_limit
            # Locate the last real position of each sequence before the model
            # runs (the forward pass rewrites the targets in place).
            targets = inputs['targets'][start_idx: end_idx].data.cpu().numpy()
            pred_col_idx = (targets != PAD_VALUE).cumsum(1).argmax(1)
            assert(PAD_VALUE not in targets[np.arange(targets.shape[0]),
                                            pred_col_idx])

            with torch.no_grad():
                output = estimator(inputs={name: feat[start_idx: end_idx]
                                           for name, feat in inputs.items()})
            labels = labels_all[start_idx: end_idx].float()
            labels = labels[np.arange(targets.shape[0]), pred_col_idx]

            loss = criterion(output, labels)
            loss_sum += loss.item()
            batch_count += 1

            pred = (output >= 0.5).long()
            num_corrects += (pred == labels).sum().item()
            sample_count += len(labels)
            truth_chunks.append(labels.view(-1).data.cpu().numpy())
            out_chunks.append(output.view(-1).data.cpu().numpy())

            tbar.set_description(
                f'{nbatch + 1}/{n_batches} | ' + 'val_loss - {:.4f}'.format(
                    loss_sum / batch_count))

    truth = np.concatenate(truth_chunks)
    outs = np.concatenate(out_chunks)

    acc = num_corrects / sample_count
    auc = roc_auc_score(truth, outs)
    loss = loss_sum / batch_count

    return loss, acc, auc


# %% [code]


def _seed_numpy_worker(worker_id):
    """Give each DataLoader worker process its own NumPy random stream.

    Without this, forked workers can inherit the same NumPy state and draw
    identical "random" users and windows. ``torch.initial_seed()`` differs
    per worker; NumPy only accepts 32-bit seeds, hence the modulo.
    """
    np.random.seed(torch.initial_seed() % 2 ** 32)


def train_transformer(estimator, optim, sched, train, valid,
                      epochs=10, n_user_batches=32, batch_limit=128,
                      max_samples_per_user=100, device="cpu", early_stopping=2,
                      eps=1e-4, nworkers=4, last_auc=0):
    """Train ``estimator`` with early stopping on the validation AUC.

    Args:
        estimator: model to train (moved to ``device``).
        optim: optimizer stepping the model's parameters.
        sched: learning-rate scheduler or None. ReduceLROnPlateau is stepped
            with the validation AUC, any other scheduler without arguments;
            both once per epoch.
        train, valid: frames used to build the Riiid datasets.
        epochs: maximum number of epochs.
        n_user_batches: number of users per DataLoader batch.
        batch_limit: maximum number of sequences per optimisation step.
        max_samples_per_user: forwarded to the training dataset.
        device: device used for the model, the loss and the batches.
        early_stopping: stop after this many consecutive epochs without an
            AUC improvement larger than ``eps``.
        eps: minimum AUC gain that counts as an improvement.
        nworkers: number of DataLoader worker processes.
        last_auc: best AUC so far (non-zero when resuming from a checkpoint).

    Returns:
        The trained estimator. A checkpoint is written with save_model every
        time the validation AUC improves.
    """
    trn_dataset = Riiid(dt=train, seq_len=SEQ_LEN, pad_value=PAD_VALUE,
                        max_samples_per_user=max_samples_per_user,
                        for_training=True)
    val_dataset = Riiid(dt=valid, seq_len=SEQ_LEN, pad_value=PAD_VALUE,
                        max_samples_per_user=None, for_training=False)

    trn_dataloader = torch.utils.data.DataLoader(
        dataset=trn_dataset,
        batch_size=n_user_batches,
        collate_fn=collate_fn,
        num_workers=nworkers,
        worker_init_fn=_seed_numpy_worker)
    val_dataloader = torch.utils.data.DataLoader(
        dataset=val_dataset,
        batch_size=n_user_batches,
        collate_fn=collate_fn,
        num_workers=nworkers,
        worker_init_fn=_seed_numpy_worker)
    if TPU:
        # NOTE(review): torch_xla's ParallelLoader is normally consumed via
        # per_device_loader(device); confirm that iterating it directly in
        # train_epoch / val_epoch works.
        trn_dataloader = pl.ParallelLoader(trn_dataloader, [device])
        val_dataloader = pl.ParallelLoader(val_dataloader, [device])

    criterion = nn.BCELoss()
    criterion.to(device)
    estimator.to(device)

    over_fit = 0
    for epoch in range(epochs):
        print('\nLearning rate:', optim.param_groups[0]['lr'])
        trn_loss, trn_acc = train_epoch(estimator, trn_dataloader, optim,
                                        criterion, device, batch_limit)
        print("Training epoch {} - loss:{:.4f} - acc: {:.4f}"\
              .format(epoch + 1, trn_loss, trn_acc))
        # Validation uses its own, larger chunk size.
        val_loss, val_acc, val_auc = val_epoch(estimator, val_dataloader,
                                               criterion, device, 1028)
        # Highlight improved epochs in red and reset the colour afterwards so
        # it does not leak into the following output.
        color, reset = '', ''
        if val_auc > last_auc and epoch != 0:
            color, reset = '\033[91m', '\033[0m'
        print(color + "Validation epoch {} - loss: {:.4f} - acc: {:.4f}, auc: {:.6f}"\
              .format(epoch + 1, val_loss, val_acc, val_auc) + reset)
        if sched is not None:
            if isinstance(sched, torch.optim.lr_scheduler.ReduceLROnPlateau):
                sched.step(val_auc)
            else:
                sched.step()

        if val_auc > last_auc + eps:
            last_auc = val_auc
            over_fit = 0
            # Save the optimizer that was actually trained (the `optim`
            # argument), not whatever global happens to be called `optimizer`.
            save_model(estimator, optim, sched, val_auc)
        else:
            over_fit += 1

        if over_fit >= early_stopping:
            print("early stop epoch ", epoch + 1)
            break

    return estimator


# %% [code]
# ========================== TEST =======================================


class RiiidTest(torch.utils.data.Dataset):
    """Inference-time dataset: one padded history-plus-query matrix per item.

    Each item stacks the stored history of the query's user (if any) on top
    of the query row, pads it with ``pad_batch`` (defined elsewhere in the
    notebook) and appends a ``targets`` column built as
    ``[START_TOKEN, history_labels + 1]``.

    Parameters
    ----------
    dt :
        Per-user history, looked up with ``dt[uid]`` and ``uid in dt.index``.
        Each value is indexed as a 2-D array whose columns follow
        ``[KEY_FEATURE] + FEATURES + [TARGET]``.
    queries :
        Rows to predict. With ``local=True`` it is grouped with
        ``queries.groupby(KEY_FEATURE)`` (so it must be a DataFrame that also
        carries the TARGET column); otherwise it is indexed as
        ``queries[[idx]]`` and holds ``[KEY_FEATURE] + FEATURES`` columns.
    seq_len, pad_value :
        Forwarded to ``pad_batch``.
    local :
        When true, items are drawn from a randomly chosen user (ignoring
        ``idx``) and the returned matrix additionally carries the user id and
        the true label so that callers can score and update history.
    """

    def __init__(self, dt, queries, seq_len, pad_value=0, local=False):
        column_order = [KEY_FEATURE] + FEATURES + [TARGET]
        self.data = dt
        # column name -> position inside a history / local-query row
        self.dcols = {name: pos for pos, name in enumerate(column_order)}

        self.queries = queries
        self.groups = None
        if local:
            # One 2-D array of rows per user, consumed front-to-back.
            rows_per_user = queries.groupby(KEY_FEATURE) \
                .apply(lambda r: r.values)
            self.groups = rows_per_user.values.tolist()
        self.seq_len = seq_len
        self.pad_value = pad_value
        self.is_local = local

    def __len__(self):
        return len(self.queries)

    def __getitem__(self, idx):
        key_col = self.dcols[KEY_FEATURE]
        target_col = self.dcols[TARGET]

        if self.is_local:
            # Pick a random user and pop that user's earliest remaining row.
            random.shuffle(self.groups)
            head_group = self.groups[0]
            query = head_group[[0]]
            query_label = query[0, target_col]
            query = np.delete(query, target_col, axis=1)
            self.groups[0] = head_group[1:]
            if self.groups[0].shape[0] == 0:
                # User exhausted: drop the now-empty group.
                self.groups = self.groups[1:]
        else:
            query = self.queries[[idx]]
        uid = query[0, key_col]
        query = np.delete(query, key_col, axis=1)

        if uid in self.data.index:
            history = self.data[uid]
            labels = history[:, target_col]
            features_only = np.delete(history, [key_col, target_col], axis=1)
            inputs = np.r_[features_only, query]
        else:
            # Unseen user: the query is the whole sequence.
            inputs = query
            labels = np.empty(0)

        inputs = pad_batch(inputs, self.seq_len, self.pad_value)
        # Labels are shifted by one (0/1 -> 1/2) and prefixed with START_TOKEN.
        shifted = np.r_[START_TOKEN, labels + 1]
        targets = pad_batch(shifted, self.seq_len, self.pad_value)
        if self.is_local:
            labels = pad_batch(np.r_[labels, query_label],
                               self.seq_len, self.pad_value)
            uids = np.full(inputs.shape[0], uid)
            # Local layout: features..., uid, true label, targets.
            inputs = np.c_[inputs, uids, labels]
        return np.c_[inputs, targets]


def collate_fn_test(batch):
    """Stack equally-shaped 2-D samples and bring the column axis to the front.

    ``batch`` is a sequence of ``(rows, cols)`` arrays; the result has shape
    ``(cols, len(batch), rows)`` so that ``result[c]`` is column ``c`` of
    every sample.
    """
    stacked = np.array(batch)  # (sample, row, col)
    return np.transpose(stacked, (2, 0, 1))


# %% [code]


def update_stats(prev_data, prev_batch, seq_len=None):
    """Append every row of ``prev_batch`` to its user's stored history, in place.

    Parameters
    ----------
    prev_data :
        Per-user history store. Must support ``uid in prev_data.index`` and
        item get/set by user id (the notebook passes a pandas Series of 2-D
        arrays).
    prev_batch : 2-D array-like
        Rows to record; column 0 is the user id. An empty batch is a no-op.
    seq_len : int, optional
        Sequence length of the model. At most ``seq_len - 1`` rows are kept
        per user (``seq_len - 2`` old rows plus the new one). Defaults to the
        notebook-level ``SEQ_LEN``.

    Returns
    -------
    None
        ``prev_data`` is mutated.
    """
    if seq_len is None:
        seq_len = SEQ_LEN
    # Number of existing rows that survive next to the newly appended one.
    keep = max(seq_len - 2, 0)

    # A plain loop instead of np.apply_along_axis: the latter raises
    # ValueError for a zero-row batch and is not meant for side effects.
    for trow in np.asarray(prev_batch):
        uid = trow[0]
        if uid in prev_data.index:
            history = prev_data[uid]
            # Explicit start index so that keep == 0 really keeps nothing
            # (a ``[-0:]`` slice would keep the whole history).
            tail = history[max(len(history) - keep, 0):]
            prev_data[uid] = np.r_[tail, [trow]].astype(np.float32)
        else:
            prev_data[uid] = np.array([trow])


# %% [code]


def predict_local(filename, prev_data, is_debug, batch_size,
                  estimator=None, debug_rows=50_000):
    """Run the model over a local hold-out file and return its predictions.

    Parameters
    ----------
    filename :
        File passed to ``load_data``; the result is run through ``preprocess``.
    prev_data :
        Per-user history store. It is updated in place while predicting
        (``predict_test`` is called with ``local=True``).
    is_debug : bool
        When true, only the last ``debug_rows`` rows of the test set are used.
    batch_size : int
        DataLoader batch size.
    estimator : optional
        Model to evaluate. Defaults to the notebook-level ``model`` so that
        existing calls keep working.
    debug_rows : int, optional
        Size of the debug tail (previously a hard-coded 50_000).

    Returns
    -------
    numpy.ndarray
        The predictions returned by ``predict_test``.
    """
    if estimator is None:
        # Backward-compatible fallback to the global defined by the test cell.
        estimator = model

    test_set = preprocess(load_data(filename))
    if is_debug:
        test_set = test_set.iloc[-debug_rows:]
    print('test shape:', test_set.shape)

    test_dataset = RiiidTest(prev_data, test_set, SEQ_LEN, local=True)
    # num_workers=0: RiiidTest mutates its own state in local mode.
    test_dataloader = torch.utils.data.DataLoader(dataset=test_dataset,
                                                  batch_size=batch_size,
                                                  collate_fn=collate_fn_test,
                                                  num_workers=0)
    return predict_test(estimator,
                        test_dataloader,
                        local=True,
                        prev_data=prev_data,
                        device=DEVICE)


# %% [code]


def predict_test(estimator,
                 tst_iterator,
                 local=False,
                 prev_data=None,
                 device="cpu"):
    """Run ``estimator`` over every batch of ``tst_iterator``.

    Parameters
    ----------
    estimator :
        Model called as ``estimator(inputs=dict_of_tensors)``; its output for
        a batch must be a 1-D tensor (one score per sample).
    tst_iterator :
        Iterable of arrays laid out column-first, as produced by
        ``collate_fn_test``: ``batch[i]`` is feature ``FEATURES[i]`` and
        ``batch[-1]`` the shifted targets. In local mode ``batch[-3]`` is the
        user id and ``batch[-2]`` the true label.
    local : bool
        When true, wrap the iterator in a progress bar, push the predicted
        rows into ``prev_data`` via ``update_stats`` and show a running AUC.
    prev_data :
        Per-user history store; only used (and mutated) when ``local``.
    device :
        Torch device the inputs are moved to.

    Returns
    -------
    numpy.ndarray
        float64 vector with all predictions in iteration order (empty when the
        iterator yields nothing).
    """
    estimator.eval()

    truth = np.empty(0)
    # Seeded with an empty float64 array so the result dtype and the
    # empty-iterator result match the former ``np.r_`` accumulation.
    out_chunks = [np.empty(0)]
    if local:
        tst_iterator = tqdm(tst_iterator, ncols=500)
    for batch in tst_iterator:
        inputs = {}
        for i, feat in enumerate(FEATURES):
            # torch.from_numpy keeps the numpy dtype; torch.Tensor(...) would
            # route integers through float32 and corrupt values above 2**24.
            if DTYPES[feat] is int:
                inputs[feat] = torch.from_numpy(
                    batch[i].astype(np.int64)).to(device)
            elif DTYPES[feat] is float:
                inputs[feat] = torch.from_numpy(
                    batch[i].astype(np.float32)).to(device)
            elif DTYPES[feat] is bool:
                inputs[feat] = torch.from_numpy(
                    batch[i].astype(bool)).to(device)
        inputs['targets'] = torch.from_numpy(
            batch[-1].astype(np.int64)).to(device)

        with torch.no_grad():
            output = estimator(inputs=inputs)
        # atleast_1d: tolerate a 0-d output for a single-sample batch.
        out_chunks.append(np.atleast_1d(output.detach().cpu().numpy()))

        if local:
            targets = inputs['targets'].cpu().numpy()
            # The cumulative count of non-pad entries peaks at the last
            # non-pad position; argmax returns the first index of that peak.
            pred_col_idx = (targets != PAD_VALUE).cumsum(1).argmax(1)
            assert (PAD_VALUE not in targets[np.arange(targets.shape[0]),
                                             pred_col_idx])

            # Re-assemble (uid, features..., label) for the predicted rows so
            # they can be appended to each user's history.
            sample_idx = np.arange(batch.shape[1])
            prev_batch = batch[[-3] + list(range(len(FEATURES))) + [-2]]
            prev_batch = prev_batch[:, sample_idx, pred_col_idx].T
            update_stats(prev_data, prev_batch)

            labels = batch[-2, sample_idx, pred_col_idx]
            truth = np.r_[truth, labels]
            # roc_auc_score is undefined until both classes were seen.
            if np.unique(truth).size > 1:
                auc_text = '{:.4f}'.format(
                    roc_auc_score(truth, np.concatenate(out_chunks)))
            else:
                auc_text = 'n/a'
            tst_iterator.set_description('test_auc - ' + auc_text)

    return np.concatenate(out_chunks)


# %% [code]


def predict_submission_group(estimator,
                             tst_batch,
                             prev_batch,
                             prev_data,
                             batch_size=128):
    """Predict one test group and fold the previous group into the history.

    Parameters
    ----------
    estimator :
        Model forwarded to ``predict_test``.
    tst_batch : pandas.DataFrame
        The current test group. Row 0 of ``prior_group_answers_correct``
        holds the stringified answer list of the previous group.
    prev_batch : numpy.ndarray or None
        The ``'prev_batch'`` value returned for the previous group (all of its
        rows, with the ``part`` column appended), or ``None`` for the first
        group.
    prev_data :
        Per-user history store, updated in place.
    batch_size : int
        DataLoader batch size.

    Returns
    -------
    dict
        Keys, in this order (callers unpack ``.values()``): ``'preds'`` (a
        DataFrame with ``SUBMISSION_COLUMNS`` for the question rows),
        ``'prev_batch'`` and ``'prev_data'``.
    """
    import ast  # local import: only needed to parse the answer list

    all_cols = list(tst_batch.columns) + ADDED_FEATURES + [TARGET]
    all_cols = dict(zip(all_cols, range(len(all_cols))))
    used_cols = [all_cols[feat] for feat in [KEY_FEATURE] + FEATURES]

    tst_batch = preprocess(tst_batch).values
    # History updates are skipped under memory pressure (>= 90 % RAM in use).
    if prev_batch is not None and psutil.virtual_memory().percent < 90:
        # The answers arrive as text from the test feed; parse them as a
        # literal instead of eval() so no code in that field can execute.
        prior_answers = ast.literal_eval(
            tst_batch[0, all_cols['prior_group_answers_correct']])
        prev_batch = np.c_[prev_batch, prior_answers]
        was_question = prev_batch[:, all_cols['content_type_id']] == 0
        prev_batch = prev_batch[was_question][
            :, used_cols + [all_cols[TARGET]]]
        # A previous group made only of lectures leaves nothing to record.
        if len(prev_batch):
            update_stats(prev_data, prev_batch)

    parts = np.apply_along_axis(
        lambda rid: TAGS_DF[rid[0]]['part'] if rid[0] in TAGS_DF else 0,
        axis=1, arr=tst_batch[:, [all_cols['content_id']]])
    tst_batch = np.c_[tst_batch, parts]
    # Keep every row (lectures included) so the next call can align it with
    # that group's answer list.
    prev_batch = tst_batch.copy()

    qrows = tst_batch[:, all_cols['content_type_id']] == 0
    tst_batch = tst_batch[qrows]
    tst_dataset = RiiidTest(prev_data, tst_batch[:, used_cols], SEQ_LEN)
    tst_dataloader = torch.utils.data.DataLoader(dataset=tst_dataset,
                                                 batch_size=batch_size,
                                                 collate_fn=collate_fn_test,
                                                 num_workers=0)

    _preds = predict_test(estimator, tst_dataloader, device=DEVICE)
    # The prediction lands in the TARGET slot of ``all_cols``.
    tst_batch = np.c_[tst_batch, _preds]
    _predictions = pd.DataFrame(
        tst_batch[:, [all_cols[col] for col in SUBMISSION_COLUMNS]],
        columns=SUBMISSION_COLUMNS)

    return {'preds': _predictions,
            'prev_batch': prev_batch,
            'prev_data': prev_data}

# Train

In [None]:
# --- Model selection -------------------------------------------------------
# Model class handed to create_model / load_model by the cells below.
MODEL = TransformerModel #@param ["TransformerModel", "TransformerEncoderModel"] {type:"raw"}

# File name derived from the model class name.
# NOTE(review): presumably the checkpoint used by save_model/load_model — confirm.
MODEL_FILENAME = f'{MODEL.__name__}_best.pth' #@param {type:"string"}

# --- Architecture: forwarded to create_model through `model_params` ---------
# SEQ_LEN is also used as the padded sequence length at test time and bounds
# the per-user history (SEQ_LEN - 1 rows).
SEQ_LEN =  96#@param {type:"integer"}

D_MODEL =  512#@param {type:"integer"}

NHEAD =  4#@param {type:"integer"}

# Used for num_encoder_layers and, for TransformerModel, num_decoder_layers.
N_LAYERS =  4#@param {type:"integer"}

DIM_FEEDFORWARD = 256 #@param {type:"integer"}

DROPOUT = 0.3 #@param {type:"slider", min:0, max:1, step:0.05}

ACTIVATION = "relu" #@param ["relu", "tanh", "sigmoid"]

# Passed to create_optimizer as `lr`.
LEARNING_RATE = 1e-3 #@param {type:"slider", min:1e-5, max:1, step:1e-5}



# --- Training loop: forwarded to train_transformer --------------------------
EPOCHS = 100 #@param {type:"integer"}

# DataLoader batch size (passed as `n_user_batches`).
N_USER_BATCHES = 64 #@param {type:"integer"}

# Passed as `batch_limit` and handed on to train_epoch.
BATCH_LIMIT = 64 #@param {type:"integer"}

MAX_SAMPLES_PER_USER =  1#@param {type:"integer"}

# Epochs without a validation-AUC improvement before training stops.
EARLY_STOPPING = 10 #@param {type:"integer"}

# Passed to create_scheduler / load_model as `warmup_steps`.
WARMUP_STEPS = 20 #@param {type:"integer"}

# DataLoader `num_workers` for training and validation.
NUM_WORKERS = 0 #@param {type:"integer"}



# Target vocabulary: PAD_VALUE marks padding, labels are stored shifted by one
# (labels + 1) and START_TOKEN is prepended to every target sequence.
PAD_VALUE = 0
START_TOKEN = 3
# XLA device when TPU is set, otherwise the first CUDA GPU if available, else CPU.
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") \
         if not TPU else xm.xla_device()

In [None]:
# Run the training cell (only when COMPILE is also false).
retrain_transformer = True #@param {type:"boolean"}

# Resume from load_model(...) instead of creating a fresh model/optimizer/scheduler.
cont = False #@param {type:"boolean"}

# Train on the last `debug_size` rows only; also selects the debug mode and
# the smaller batch size (32 instead of 512) of predict_local.
debug = True #@param {type:"boolean"}

# Number of trailing rows of the preprocessed training data to use.
data_size = 20000000 #@param {type:"slider", min:1000000, max:100000000, step:1000000}

# Number of trailing rows kept when `debug` is true.
debug_size = 1600000 #@param {type:"slider", min:100000, max:100000000, step:100000}




# In the test cell: evaluate on the local sample via predict_local (when
# COMPILE is false) instead of replaying example_test.csv.
local_sample = True #@param {type:"boolean"}

In [None]:
if __name__ == '__main__':
    print('Using Device -', DEVICE)
    # Keyword arguments for create_model, taken from the form cell above.
    model_params = dict(
        seq_len=SEQ_LEN,
        d_model=D_MODEL,
        nhead=NHEAD,
        num_encoder_layers=N_LAYERS,
        dim_feedforward=DIM_FEEDFORWARD,
        dropout=DROPOUT,
        activation=ACTIVATION,
    )
    # The decoder depth is passed only when MODEL is TransformerModel.
    if MODEL is TransformerModel:
        model_params['num_decoder_layers'] = N_LAYERS

    if retrain_transformer and not COMPILE:
        # Load the most recent `data_size` rows (or `debug_size` in debug mode).
        data_path = 'train_transformer.parquet'
        data = preprocess(load_data(data_path))[-data_size:]
        if debug:
            data = data.iloc[-debug_size:]
        # NOTE(review): 0.05 is presumably the validation share — confirm
        # against split_train_valid.
        df_train, df_valid = split_train_valid(data, 0.05)
        print('train size:', df_train.shape,
              '- num users:', df_train['user_id'].nunique())
        print('valid size:', df_valid.shape,
              '- num users:', df_valid['user_id'].nunique())
        # The full frame is no longer needed once it has been split.
        del data
        gc.collect()

        if not cont:
            model = create_model(MODEL, **model_params)
            optimizer = create_optimizer(model, lr=LEARNING_RATE)
            scheduler = create_scheduler(model, optimizer,
                                         warmup_steps=WARMUP_STEPS)
        else:
            # Resume: restore model, optimizer, scheduler and best score.
            model, optimizer, scheduler, val_score = load_model(
                MODEL, for_training=True, warmup_steps=WARMUP_STEPS)

        # Best validation AUC so far; 0 for a fresh run.
        starting_auc = val_score if cont else 0
        model = train_transformer(
            model, optimizer, scheduler,
            df_train, df_valid,
            epochs=EPOCHS,
            n_user_batches=N_USER_BATCHES,
            batch_limit=BATCH_LIMIT,
            max_samples_per_user=MAX_SAMPLES_PER_USER,
            early_stopping=EARLY_STOPPING,
            eps=1e-4,
            nworkers=NUM_WORKERS,
            device=DEVICE,
            last_auc=starting_auc)

Using Device - cuda:0
train size: (1904758, 6) - num users: 7579
valid size: (95242, 6) - num users: 375
{'seq_len': 96, 'd_model': 512, 'nhead': 4, 'num_encoder_layers': 4, 'dim_feedforward': 256, 'dropout': 0.3, 'activation': 'relu', 'num_decoder_layers': 4}

Learning rate: 4.941058844013094e-07


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=119.0), HTML(value='')), layout=Layout(di…


Training epoch 1 - loss:0.8000 - acc: 0.5331


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=6.0), HTML(value='')), layout=Layout(disp…


Validation epoch 1 - loss: 0.6479 - acc: 0.6805, auc: 0.518416

Learning rate: 9.882117688026188e-07


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=119.0), HTML(value='')), layout=Layout(di…


Training epoch 2 - loss:0.7852 - acc: 0.5499


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=6.0), HTML(value='')), layout=Layout(disp…


[91mValidation epoch 2 - loss: 0.6391 - acc: 0.6863, auc: 0.580697

Learning rate: 1.482317653203928e-06


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=119.0), HTML(value='')), layout=Layout(di…


Training epoch 3 - loss:0.7657 - acc: 0.5579


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=6.0), HTML(value='')), layout=Layout(disp…


[91mValidation epoch 3 - loss: 0.6259 - acc: 0.6906, auc: 0.625173

Learning rate: 1.9764235376052376e-06


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=119.0), HTML(value='')), layout=Layout(di…


Training epoch 4 - loss:0.7655 - acc: 0.5572


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=6.0), HTML(value='')), layout=Layout(disp…


[91mValidation epoch 4 - loss: 0.6346 - acc: 0.6928, auc: 0.648534

Learning rate: 2.4705294220065464e-06


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=119.0), HTML(value='')), layout=Layout(di…


Training epoch 5 - loss:0.7561 - acc: 0.5668


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=6.0), HTML(value='')), layout=Layout(disp…


[91mValidation epoch 5 - loss: 0.6202 - acc: 0.6942, auc: 0.672120

Learning rate: 2.964635306407856e-06


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=119.0), HTML(value='')), layout=Layout(di…


Training epoch 6 - loss:0.7549 - acc: 0.5581


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=6.0), HTML(value='')), layout=Layout(disp…


[91mValidation epoch 6 - loss: 0.6139 - acc: 0.6963, auc: 0.677316

Learning rate: 3.4587411908091652e-06


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=119.0), HTML(value='')), layout=Layout(di…


Training epoch 7 - loss:0.7324 - acc: 0.5806


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=6.0), HTML(value='')), layout=Layout(disp…


[91mValidation epoch 7 - loss: 0.6140 - acc: 0.6957, auc: 0.681707

Learning rate: 3.952847075210475e-06


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=119.0), HTML(value='')), layout=Layout(di…


Training epoch 8 - loss:0.7296 - acc: 0.5773


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=6.0), HTML(value='')), layout=Layout(disp…


[91mValidation epoch 8 - loss: 0.6272 - acc: 0.6941, auc: 0.683928

Learning rate: 4.446952959611784e-06


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=119.0), HTML(value='')), layout=Layout(di…


Training epoch 9 - loss:0.7238 - acc: 0.5733


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=6.0), HTML(value='')), layout=Layout(disp…


[91mValidation epoch 9 - loss: 0.5923 - acc: 0.6976, auc: 0.689279

Learning rate: 4.941058844013093e-06


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=119.0), HTML(value='')), layout=Layout(di…


Training epoch 10 - loss:0.7101 - acc: 0.5907


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=6.0), HTML(value='')), layout=Layout(disp…


Validation epoch 10 - loss: 0.5959 - acc: 0.6977, auc: 0.687811

Learning rate: 5.4351647284144016e-06


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=119.0), HTML(value='')), layout=Layout(di…


Training epoch 11 - loss:0.7131 - acc: 0.5833


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=6.0), HTML(value='')), layout=Layout(disp…


[91mValidation epoch 11 - loss: 0.6311 - acc: 0.6935, auc: 0.689540

Learning rate: 5.929270612815712e-06


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=119.0), HTML(value='')), layout=Layout(di…


Training epoch 12 - loss:0.7059 - acc: 0.5931


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=6.0), HTML(value='')), layout=Layout(disp…


Validation epoch 12 - loss: 0.5987 - acc: 0.6950, auc: 0.688902

Learning rate: 6.423376497217022e-06


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=119.0), HTML(value='')), layout=Layout(di…


Training epoch 13 - loss:0.6902 - acc: 0.5988


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=6.0), HTML(value='')), layout=Layout(disp…


[91mValidation epoch 13 - loss: 0.5778 - acc: 0.7025, auc: 0.690640

Learning rate: 6.9174823816183304e-06


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=119.0), HTML(value='')), layout=Layout(di…


Training epoch 14 - loss:0.7029 - acc: 0.5883


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=6.0), HTML(value='')), layout=Layout(disp…


Validation epoch 14 - loss: 0.6003 - acc: 0.6955, auc: 0.690550

Learning rate: 7.41158826601964e-06


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=119.0), HTML(value='')), layout=Layout(di…


Training epoch 15 - loss:0.7025 - acc: 0.5858


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=6.0), HTML(value='')), layout=Layout(disp…


[91mValidation epoch 15 - loss: 0.5854 - acc: 0.6969, auc: 0.693745

Learning rate: 7.90569415042095e-06


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=119.0), HTML(value='')), layout=Layout(di…


Training epoch 16 - loss:0.6926 - acc: 0.5990


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=6.0), HTML(value='')), layout=Layout(disp…


[91mValidation epoch 16 - loss: 0.5864 - acc: 0.6957, auc: 0.693920

Learning rate: 8.399800034822258e-06


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=119.0), HTML(value='')), layout=Layout(di…


Training epoch 17 - loss:0.6880 - acc: 0.6064


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=6.0), HTML(value='')), layout=Layout(disp…


Validation epoch 17 - loss: 0.5793 - acc: 0.7011, auc: 0.689741

Learning rate: 8.893905919223568e-06


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=119.0), HTML(value='')), layout=Layout(di…


Training epoch 18 - loss:0.6929 - acc: 0.5957


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=6.0), HTML(value='')), layout=Layout(disp…


Validation epoch 18 - loss: 0.5952 - acc: 0.6946, auc: 0.689976

Learning rate: 9.388011803624876e-06


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=119.0), HTML(value='')), layout=Layout(di…


Training epoch 19 - loss:0.6850 - acc: 0.6051


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=6.0), HTML(value='')), layout=Layout(disp…


Validation epoch 19 - loss: 0.5881 - acc: 0.6979, auc: 0.690904

Learning rate: 9.882117688026186e-06


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=119.0), HTML(value='')), layout=Layout(di…


Training epoch 20 - loss:0.6842 - acc: 0.5998


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=6.0), HTML(value='')), layout=Layout(disp…


Validation epoch 20 - loss: 0.5801 - acc: 0.7032, auc: 0.692443

Learning rate: 9.643959372630746e-06


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=119.0), HTML(value='')), layout=Layout(di…


Training epoch 21 - loss:0.6804 - acc: 0.6005


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=6.0), HTML(value='')), layout=Layout(disp…


[91mValidation epoch 21 - loss: 0.5805 - acc: 0.7006, auc: 0.696518

Learning rate: 9.422229518055114e-06


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=119.0), HTML(value='')), layout=Layout(di…


Training epoch 22 - loss:0.6729 - acc: 0.6164


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=6.0), HTML(value='')), layout=Layout(disp…


[91mValidation epoch 22 - loss: 0.5745 - acc: 0.7042, auc: 0.696947

Learning rate: 9.215122259681072e-06


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=119.0), HTML(value='')), layout=Layout(di…


Training epoch 23 - loss:0.6764 - acc: 0.6118


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=6.0), HTML(value='')), layout=Layout(disp…


Validation epoch 23 - loss: 0.5799 - acc: 0.7006, auc: 0.695096

Learning rate: 9.021097956087905e-06


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=119.0), HTML(value='')), layout=Layout(di…


Training epoch 24 - loss:0.6704 - acc: 0.6145


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=6.0), HTML(value='')), layout=Layout(disp…


Validation epoch 24 - loss: 0.5794 - acc: 0.7024, auc: 0.696260

Learning rate: 8.838834764831846e-06


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=119.0), HTML(value='')), layout=Layout(di…


Training epoch 25 - loss:0.6793 - acc: 0.6088


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=6.0), HTML(value='')), layout=Layout(disp…


[91mValidation epoch 25 - loss: 0.5724 - acc: 0.7052, auc: 0.698113

Learning rate: 8.667190566019206e-06


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=119.0), HTML(value='')), layout=Layout(di…


Training epoch 26 - loss:0.6716 - acc: 0.6129


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=6.0), HTML(value='')), layout=Layout(disp…


[91mValidation epoch 26 - loss: 0.5750 - acc: 0.7015, auc: 0.698829

Learning rate: 8.505172717997149e-06


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=119.0), HTML(value='')), layout=Layout(di…


Training epoch 27 - loss:0.6788 - acc: 0.6018


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=6.0), HTML(value='')), layout=Layout(disp…


[91mValidation epoch 27 - loss: 0.5716 - acc: 0.7057, auc: 0.700143

Learning rate: 8.351913809763263e-06


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=119.0), HTML(value='')), layout=Layout(di…


Training epoch 28 - loss:0.6705 - acc: 0.6117


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=6.0), HTML(value='')), layout=Layout(disp…


Validation epoch 28 - loss: 0.5995 - acc: 0.6956, auc: 0.699290

Learning rate: 8.20665205373266e-06


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=119.0), HTML(value='')), layout=Layout(di…


Training epoch 29 - loss:0.6646 - acc: 0.6212


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=6.0), HTML(value='')), layout=Layout(disp…


[91mValidation epoch 29 - loss: 0.5737 - acc: 0.7021, auc: 0.701129

Learning rate: 8.068715304598784e-06


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=119.0), HTML(value='')), layout=Layout(di…


Training epoch 30 - loss:0.6749 - acc: 0.6080


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=6.0), HTML(value='')), layout=Layout(disp…


Validation epoch 30 - loss: 0.5747 - acc: 0.7000, auc: 0.701078

Learning rate: 7.937507937511906e-06


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=119.0), HTML(value='')), layout=Layout(di…


Training epoch 31 - loss:0.6750 - acc: 0.6046


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=6.0), HTML(value='')), layout=Layout(disp…


Validation epoch 31 - loss: 0.5795 - acc: 0.6972, auc: 0.701056

Learning rate: 7.812500000000002e-06


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=119.0), HTML(value='')), layout=Layout(di…


Training epoch 32 - loss:0.6646 - acc: 0.6178


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=6.0), HTML(value='')), layout=Layout(disp…


Validation epoch 32 - loss: 0.5793 - acc: 0.6983, auc: 0.699920

Learning rate: 7.693218186208296e-06


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=119.0), HTML(value='')), layout=Layout(di…


Training epoch 33 - loss:0.6675 - acc: 0.6123


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=6.0), HTML(value='')), layout=Layout(disp…


Validation epoch 33 - loss: 0.5799 - acc: 0.6989, auc: 0.699279

Learning rate: 7.579238282385407e-06


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=119.0), HTML(value='')), layout=Layout(di…


Training epoch 34 - loss:0.6733 - acc: 0.6027


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=6.0), HTML(value='')), layout=Layout(disp…


Validation epoch 34 - loss: 0.5790 - acc: 0.6999, auc: 0.699160

Learning rate: 7.47017880833996e-06


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=119.0), HTML(value='')), layout=Layout(di…


Training epoch 35 - loss:0.6692 - acc: 0.6150


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=6.0), HTML(value='')), layout=Layout(disp…


Validation epoch 35 - loss: 0.5805 - acc: 0.6993, auc: 0.699591

Learning rate: 7.36569563735987e-06


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=119.0), HTML(value='')), layout=Layout(di…


Training epoch 36 - loss:0.6712 - acc: 0.6064


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=6.0), HTML(value='')), layout=Layout(disp…


Validation epoch 36 - loss: 0.5743 - acc: 0.7024, auc: 0.699186

Learning rate: 7.265477421488705e-06


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=119.0), HTML(value='')), layout=Layout(di…


Training epoch 37 - loss:0.6680 - acc: 0.6149


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=6.0), HTML(value='')), layout=Layout(disp…

KeyboardInterrupt: ignored

# Test

In [None]:
# ============================= TESTING ===================================
# Question metadata as {content_id: {column: value}}; predict_submission_group
# reads the 'part' entry from it.
tags_table = pd.read_parquet(PARQUETS_DIR + 'tags.parquet')
# Add 1 to content ids to match embeddings.
tags_table.index = tags_table.index + 1
TAGS_DF = tags_table.to_dict('index')

if COMPILE:
    pdata_path = 'train_merged.parquet'
else:
    pdata_path = 'train_transformer.parquet'


def _recent_rows(user_rows):
    """Return the last SEQ_LEN - 1 rows of one user's frame as an array."""
    return user_rows.tail(SEQ_LEN - 1).values


# Per-user history: one 2-D array of the most recent rows for every user.
previous_data = preprocess(load_data(pdata_path)) \
    .groupby(KEY_FEATURE).apply(_recent_rows)

# In[]
# Only the model is needed for inference; the other three returned values
# are discarded.
model, _, _, _ = load_model(MODEL)
model.eval()

# Three mutually exclusive modes:
#   1. local sample   - score a held-out parquet file with a running AUC;
#   2. local replay   - replay example_test.csv group by group and report
#                       accuracy against the answers revealed one group later;
#   3. submission     - stream groups from the competition environment.
if local_sample and not COMPILE:
    print('predicting on local sample...')
    predictions = predict_local('test_transformer.parquet',
                                previous_data,
                                is_debug=debug,
                                batch_size=32 if debug else 512)
elif not COMPILE:
    import ast  # local import: parse the stringified answer lists safely

    print('Submitting locally...')
    previous_batch = None
    prev_is_question = None  # question mask of the previous group
    tgts = []
    group_preds = []
    example_test = pd.read_csv(DATA_DIR + 'example_test.csv')
    for gnum in tqdm(example_test['group_num'].unique()):
        test_batch = example_test[
            example_test['group_num'] == gnum].copy()
        is_question = (test_batch['content_type_id'] == 0).values
        preds, previous_batch, previous_data = predict_submission_group(
            model, test_batch,
            previous_batch,
            previous_data,
            batch_size=1024).values()
        # Answers of the previous group, one entry per row of that group.
        # literal_eval instead of eval: the text comes from a data file.
        prior_answers = ast.literal_eval(
            test_batch['prior_group_answers_correct'].iloc[0])
        if prev_is_question is not None:
            # Predictions exist for question rows only, so drop the entries
            # that belong to lecture rows to keep targets aligned.
            prior_answers = [ans for ans, keep
                             in zip(prior_answers, prev_is_question) if keep]
        tgts.extend(prior_answers)
        group_preds.append(preds)
        prev_is_question = is_question
    # The answers of the last group are never revealed.
    if prev_is_question is not None:
        tgts.extend([-1] * int(prev_is_question.sum()))
    # Concatenate once instead of DataFrame.append in the loop (quadratic,
    # and removed in pandas 2.0).
    if group_preds:
        submission = pd.concat(group_preds, ignore_index=True)
    else:
        submission = pd.DataFrame(columns=SUBMISSION_COLUMNS)
    submission['target'] = tgts
    submission['pred'] = (submission[TARGET] >= 0.5).astype(np.int8)
    print(submission)
    known = submission[submission['target'] != -1]
    # mean() yields NaN rather than ZeroDivisionError when nothing is known.
    acc = (known['target'] == known['pred']).mean()
    print('Accuracy', acc)
else:
    print('Submitting...')
    env = riiideducation.make_env()
    previous_batch = None
    for test_batch, _ in env.iter_test():
        preds, previous_batch, previous_data = \
            predict_submission_group(model,
                                     test_batch,
                                     previous_batch,
                                     previous_data,
                                     batch_size=1024).values()
        env.predict(preds)

In [None]:
# Display the predictions returned by predict_local.
# NOTE(review): within this section `predictions` is assigned only in the
# `local_sample and not COMPILE` branch of the test cell, so this cell
# presumably raises NameError in the other two modes — confirm.
predictions

In [None]:
# Record the PyTorch version of this runtime.
print(torch.__version__)