In [1]:
import matplotlib.pyplot as plt

%matplotlib inline
import matplotlib.style as style
import numpy as np
import pandas as pd
import seaborn as sns

style.use("fivethirtyeight")
import gc
import math
import os
import sys

import h5py
import lightgbm as lgb
import pytorch_lightning as pl
import torch
from pytorch_lightning.metrics.functional import accuracy
from sklearn.metrics import roc_auc_score
from sklearn.preprocessing import LabelEncoder
from torch import nn as nn
from torch.nn import functional as F
from torch.utils.data import DataLoader, Dataset, random_split
from tqdm.auto import tqdm

pd.set_option("display.max_rows", 100)


# Seed the global NumPy and PyTorch random generators so that sampling and
# weight initialisation are repeatable between runs of the notebook.
SEED = 69
np.random.seed(SEED)
torch.manual_seed(SEED)

<torch._C.Generator at 0x7f6015ac0af0>

# Prepare Dataset

In [2]:
%%time

# Interaction log and the question / lecture metadata tables.
# NOTE(review): read_pickle executes arbitrary code from the file - only load
# pickles that were produced locally / come from a trusted source.
train = pd.read_pickle("riiid_train.pkl.gzip")
questions_df = pd.read_csv("questions.csv")
lectures_df = pd.read_csv("lectures.csv")

# Pre-computed lookup arrays stored as .npy files under `folder_path`.
folder_path = "data"
print("Loading lectures arrays")
lectures_ids = np.load(f"{folder_path}/lectures_ids.npy")
lectures_parts = np.load(f"{folder_path}/lectures_parts.npy")
lectures_types = np.load(f"{folder_path}/lectures_types.npy")
lectures_tags = lectures_df.tag.values 

print("Loading questions arrays")
# Used below as a lookup from question content_id to part.
questions_parts = np.load(f"{folder_path}/questions_parts.npy")

# process tags
def split_tags(t):
    """Parse a space-separated tag string into a list of ints.

    Values without a ``split`` method (e.g. NaN for a missing tag cell) yield
    an empty list. Tokens that are not valid integers raise ``ValueError``.
    """
    try:
        tokens = t.split(" ")
    except AttributeError:
        # not a string -> treat as "no tags"
        return list()
    return list(map(int, tokens))

# Get tags to be 2D array of shape (Q, T), where Q is question_idx, and T is the max number of tag possible (6)
questions_df["tags"] = questions_df.tags.apply(split_tags)
# Expanding the per-question lists into columns leaves NaN where a question
# has fewer tags than the widest row.
questions_tags = pd.DataFrame(questions_df["tags"].tolist(), index=questions_df.index)
# pad with max tag + 1 so the padding id is distinct from every real tag.
# np.int64 replaces the `np.int` alias, which was deprecated in NumPy 1.20
# and later removed (it raises AttributeError on current NumPy).
questions_tags = questions_tags.fillna(questions_tags.max().max() + 1).astype(np.int64).values

Loading lectures arrays
Loading questions arrays
CPU times: user 25.2 ms, sys: 1.03 s, total: 1.05 s
Wall time: 1.05 s


#### Time Features

In [3]:
def ffill(arr):
    """Forward-fill zeros in a 1-D array with the last preceding non-zero value.

    Leading zeros (with no earlier non-zero entry) are left as they are.
    Returns a new array; ``arr`` is not modified.
    """
    # https://stackoverflow.com/questions/41190852/most-efficient-way-to-forward-fill-nan-values-in-numpy-array
    positions = np.arange(len(arr))
    # Keep the position of non-zero entries, 0 elsewhere; the running maximum
    # then yields, for each slot, the index of the latest non-zero entry.
    last_valid = np.maximum.accumulate(np.where(arr == 0, 0, positions))
    return arr[last_valid]


# def get_time_elapsed_from_timestamp(arr):
#     # this does the pandas operation but on numpy
#     return ffill(np.diff(arr, prepend=1)).astype(np.float32)


def get_time_elapsed_from_timestamp(arr, max_minutes=300):
    """Convert a timestamp sequence into capped elapsed-time features (Saint+ style).

    Args:
        arr: 1-D array of timestamps (presumably milliseconds, given the
            division by 1000 - TODO confirm).
        max_minutes: upper cap applied to the result. Despite the name, it is
            compared directly against the ``// 1000`` values, i.e. it caps in
            the same unit as the output (seconds if ``arr`` is milliseconds).

    Returns:
        float32 array of the same length as ``arr``. The first entry is
        ``(arr[0] - 1) // 1000`` because the difference is taken against a
        prepended value of 1.
    """
    # Note: this is not the smartest way of doing it.
    deltas = np.diff(arr, prepend=1)
    elapsed = deltas // 1000
    over_cap = elapsed > max_minutes
    elapsed[over_cap] = max_minutes
    return elapsed.astype(np.float32)

## Generate Hdf5

In [5]:
# Add part to lecture and questions
# Start with 0 everywhere, then fill question rows and lecture rows separately.
# `~train.content_type_id` selects question rows, so the column must be boolean.
train["part"] = 0
# Questions: look the part up positionally by content_id in questions_parts.
train.loc[~train.content_type_id, "part"] = train[
    ~train.content_type_id
].content_id.map(pd.Series(questions_parts))
# Lectures: content_id -> lectures_ids[content_id] -> lectures_parts[...].
# NOTE(review): pd.Series(lectures_ids) maps position -> stored value; this is
# only right if lectures_ids.npy is a lookup indexed by lecture content_id.
# Unmapped ids become NaN and would break the uint8 cast below - confirm.
train.loc[train.content_type_id, "part"] = (
    train[train.content_type_id]
    .content_id.map(pd.Series(lectures_ids))
    .map(pd.Series(lectures_parts))
)
train.part = train.part.astype(np.uint8)

In [6]:
# Store the explanation flag as a compact unsigned integer.
train["prior_question_had_explanation"] = train[
    "prior_question_had_explanation"
].astype(np.uint8)
# Fill missing elapsed times with the constant 21000.
# Assign the result back instead of calling fillna(inplace=True) on the
# attribute-accessed column: that is chained assignment, which pandas warns
# about and which does not update `train` under copy-on-write.
train["prior_question_elapsed_time"] = train["prior_question_elapsed_time"].fillna(
    21000
)
# errors="ignore" keeps this cell re-runnable once row_id is already gone.
train.drop(columns="row_id", inplace=True, errors="ignore")

In [7]:
# Sanity check: list the columns left after the clean-up above.
train.columns

Index(['timestamp', 'user_id', 'content_id', 'content_type_id',
       'task_container_id', 'user_answer', 'answered_correctly',
       'prior_question_elapsed_time', 'prior_question_had_explanation',
       'time_between_question_bundles', 'part'],
      dtype='object')

In [None]:
# Write one HDF5 group per user, holding that user's question interactions as
# six 1-D datasets. Lectures are ignored for now.

# (dataset name inside the user's group, source column in `train`)
feature_datasets = [
    ("content_ids", "content_id"),
    ("parts", "part"),
    ("answered_correctly", "answered_correctly"),
    ("timestamps", "timestamp"),
    ("prior_question_elapsed_time", "prior_question_elapsed_time"),
    ("prior_question_had_explanation", "prior_question_had_explanation"),
]
feature_columns = [column for _, column in feature_datasets]

# The context manager guarantees the file is flushed and closed even if the
# loop raises part-way through (previously the handle leaked, leaving a
# half-written file open in the kernel).
with h5py.File("feats.h5", "w") as hf:
    for user_id, data in tqdm(train[~train.content_type_id].groupby("user_id")):
        # .values stacks the selected columns into a single 2-D array, so all
        # six datasets share that array's common dtype.
        processed_feats = data[feature_columns].values

        for column_idx, (dataset_name, _) in enumerate(feature_datasets):
            hf.create_dataset(
                f"{user_id}/{dataset_name}", data=processed_feats[:, column_idx]
            )

## Pytorch Stuff

### Data

Here we define the PyTorch `Dataset` object and a custom collate function that pads the samples of a batch to a common length.

In [4]:
class RIIDDataset(Dataset):
    """RIID dataset: one sample per user, read lazily from the HDF5 feature file.

    Each sample is a contiguous window of at most ``window_size`` interactions
    of a single user. Users with more interactions than the window get a
    randomly positioned window on every access.

    Note: the window position is drawn from the global NumPy RNG, so when the
    dataset is used with DataLoader worker processes, NumPy should be seeded
    per worker to avoid every worker drawing the same positions.
    """

    def __init__(
        self,
        user_mapping,
        hf5_file="feats.h5",
        window_size=100,
        content_id_padding=13523,
        answered_correctly_padding=3,
    ):
        """
        Args:
            user_mapping (np.array): array of all unique user ids
            hf5_file (string): location of hf5 feats file
            window_size (int): maximum number of interactions per sample
            content_id_padding (int): stored on the instance; not used by
                ``__getitem__`` (padding happens in the collate function)
            answered_correctly_padding (int): stored on the instance; not used
                by ``__getitem__``
        """
        self.user_mapping = user_mapping
        self.hf5_file = hf5_file
        self.max_window_size = window_size

        self.content_id_padding = content_id_padding
        self.answered_correctly_padding = answered_correctly_padding

    def open_hdf5(self):
        """Open the HDF5 feature file read-only and keep the handle on ``self.f``."""
        self.f = h5py.File(self.hf5_file, "r")

    def __len__(self):
        """Number of users (= number of samples)."""
        return len(self.user_mapping)

    def _read_window(self, user_id, name, start_index, window_size, dtype):
        """Read ``window_size`` values of dataset ``name`` for ``user_id`` into a new array."""
        buffer = np.zeros(window_size, dtype=dtype)
        self.f[f"{user_id}/{name}"].read_direct(
            buffer,
            source_sel=np.s_[start_index : start_index + window_size],
            dest_sel=np.s_[0:window_size],
        )
        return buffer

    def __getitem__(self, idx):
        """Return a dict of tensors describing one window of one user's history.

        Keys: ``parts``, ``tags``, ``content_ids``, ``answered_correctly``
        (int64), ``timestamps`` (float32 elapsed-time feature) and ``length``
        (int, the number of interactions in the window).
        """
        # open the hdf5 file in the iterator to allow multiple workers
        # https://github.com/pytorch/pytorch/issues/11929
        if not hasattr(self, "f"):
            self.open_hdf5()

        if torch.is_tensor(idx):
            idx = idx.tolist()

        user_id = self.user_mapping[idx]
        length = self.f[f"{user_id}/answered_correctly"].len()

        window_size = min(self.max_window_size, length)

        # index for loading larger than window size
        start_index = 0
        if length > window_size:
            # randomly select window size subset instead of trying to cram in everything.
            # Valid starts are 0 .. length - window_size inclusive; randint's
            # upper bound is exclusive, hence the +1 (without it the user's
            # last interaction could never be sampled).
            start_index = np.random.randint(length - window_size + 1)

        content_ids = self._read_window(
            user_id, "content_ids", start_index, window_size, np.int64
        )
        parts = self._read_window(user_id, "parts", start_index, window_size, np.int64)
        answered_correctly = self._read_window(
            user_id, "answered_correctly", start_index, window_size, np.int64
        )
        # Read raw timestamps as float64: float32 only represents integers
        # exactly up to 2**24, so differencing large timestamps in float32
        # would distort the elapsed times. The feature below is still float32.
        timestamps = self._read_window(
            user_id, "timestamps", start_index, window_size, np.float64
        )

        # convert timestamps to time elapsed
        timestamps = get_time_elapsed_from_timestamp(timestamps)

        # get question tags
        tags = questions_tags[content_ids, :].astype(np.int64)

        sample = {
            "parts": torch.from_numpy(parts),
            "tags": torch.from_numpy(tags),
            "content_ids": torch.from_numpy(content_ids),
            "answered_correctly": torch.from_numpy(answered_correctly),
            "timestamps": torch.from_numpy(timestamps),
            "length": window_size,
        }
        return sample

In [5]:
from torch.nn.utils.rnn import pad_sequence

# The collate function is used to merge individual data samples into a batch
# It handles the padding aspect
def collate_fn(batch):
    """Merge a list of per-user samples into one padded, sequence-first batch.

    Every sequence is padded to the longest sample in the batch. The result
    holds ``length`` (N,), the padded features of shape (S, N, ...), the
    shifted decoder input ``answers`` and a float ``loss_mask`` of shape (S, N)
    that is 1 on real positions and 0 on padding.
    """
    # collate lengths into a 1D tensor
    lengths = torch.tensor([sample["length"] for sample in batch])
    items = {"length": lengths}

    # longest sequence in the batch = padded sequence length S
    longest = int(lengths.max())

    # padding value per key
    pad_values = {
        "parts": 0,
        "content_ids": 13523,
        "answered_correctly": 3,
        "timestamps": 0.0,  # note timestamps isnt an embedding
        "tags": 188,
    }
    for key, pad_value in pad_values.items():
        sequences = [sample[key] for sample in batch]
        items[key] = pad_sequence(
            sequences, batch_first=False, padding_value=pad_value
        )

    # decoder input: the answer sequence shifted one step to the right,
    # with the start token (2) in the first position
    shifted_answers = torch.roll(items["answered_correctly"], shifts=1, dims=0)
    shifted_answers[0, :] = 2
    items["answers"] = shifted_answers

    # mask to weight loss by (S, N): position s of sample n is real iff s < length[n]
    step_index = torch.arange(longest).unsqueeze(1)
    items["loss_mask"] = (step_index < lengths.unsqueeze(0)).float()

    items["answered_correctly"] = items["answered_correctly"].float()

    return items

In [6]:
# Create a Dataset with all users: one sample per unique user_id in `train`.
user_ids = train.user_id.unique()
dataset = RIIDDataset(user_ids)
# Number of samples (= number of users).
len(dataset)

393656

### Model

In [14]:
from pytorch_lightning.metrics.functional.classification import auroc
from torch.nn import TransformerEncoder, TransformerEncoderLayer


def init_weights(m):
    """Xavier-initialise the weight and zero the bias of a linear layer.

    Modules that are not ``nn.Linear`` are left untouched, so the function is
    suitable for ``module.apply(init_weights)``.

    Args:
        m (nn.Module): module to (possibly) initialise in place.
    """
    if isinstance(m, nn.Linear):
        torch.nn.init.xavier_uniform_(m.weight)
        # Layers built with bias=False have m.bias set to None.
        if m.bias is not None:
            torch.nn.init.zeros_(m.bias)


class PositionalEncoding(nn.Module):
    """Fixed sinusoidal position table, stored as a non-trainable buffer.

    Even feature indices hold ``sin(pos * freq)`` and odd indices hold
    ``cos(pos * freq)``. The table has shape (max_len, 1, d_model) so that it
    broadcasts over the batch dimension of sequence-first inputs.
    """

    def __init__(self, d_model, max_len=1000):
        super().__init__()

        positions = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        even_dims = torch.arange(0, d_model, 2).float()
        div_term = torch.exp(even_dims * (-math.log(10000.0) / d_model))
        angles = positions * div_term

        table = torch.zeros(max_len, d_model)
        table[:, 0::2] = torch.sin(angles)
        table[:, 1::2] = torch.cos(angles)

        # (max_len, d_model) -> (max_len, 1, d_model)
        self.register_buffer("pe", table.unsqueeze(1))

    def forward(self, sequence_length):
        """Return the first ``sequence_length`` rows: (sequence_length, 1, d_model)."""
        return self.pe[:sequence_length]


class RIIDDTransformerModel(pl.LightningModule):
    """Encoder-decoder transformer predicting whether each question is answered correctly.

    The encoder consumes the exercise sequence (a softmax-weighted sum of
    content, part and summed tag embeddings). The decoder consumes the
    previous-answer sequence (optionally averaged with an embedding of the
    elapsed-time feature). Both use a causal mask, and all sequence inputs are
    sequence-first, i.e. shaped (S, N, ...).
    """

    def __init__(
        self,
        learning_rate=0.001,
        warmup=4000,
        decay=True,  # whether to decay after warmup
        n_content_id=13524,  # number of different contents = 13523 + 1 (for padding)
        n_part=8,  # number of different parts = 7 + 1 (for padding)
        n_tags=189,  # number of different tags = 188 + 1 (for padding)
        emb_dim=64,  # embedding dimension
        dropout=0.1,
        n_heads: int = 1,
        n_encoder_layers: int = 2,
        n_decoder_layers: int = 2,
        dim_feedforward: int = 256,
        activation: str = "relu",
        batch_size=256,  # will get saved as hyperparam
        num_user_train=300000,  # will get saved as hyperparam
        num_user_val=30000,  # will get saved as hyperparam
        use_time_feats=True,
        tie_weights=True,
    ):
        super(RIIDDTransformerModel, self).__init__()
        self.model_type = "RiiidTransformer"
        self.learning_rate = learning_rate
        # warmup / decay are stored (and saved as hyperparameters) but are not
        # used by configure_optimizers at the moment.
        self.warmup = warmup
        self.use_time_feats = use_time_feats
        self.decay = decay
        self.tie_weights = tie_weights

        # save params of models to yml
        self.save_hyperparameters()

        self.embed_content_id = nn.Embedding(n_content_id, emb_dim, padding_idx=13523)
        self.embed_parts = nn.Embedding(n_part, emb_dim, padding_idx=0)
        self.embed_tags = nn.Embedding(n_tags, emb_dim, padding_idx=188)
        # exercise weights to weight the mean embedded exercise embeddings
        # (passed through a softmax in forward)
        self.exercise_weights = torch.nn.Parameter(torch.tensor([0.3, 0.55, 0.15]))

        # 4 ids: incorrect (0), correct (1), start token (2), padding (3)
        self.embed_answered_correctly = nn.Embedding(4, emb_dim, padding_idx=3)

        self.embed_timestamps = nn.Linear(1, emb_dim)

        self.pos_encoder = PositionalEncoding(emb_dim)

        self.transformer = nn.Transformer(
            d_model=emb_dim,
            nhead=n_heads,
            num_encoder_layers=n_encoder_layers,
            num_decoder_layers=n_decoder_layers,
            dropout=dropout,
            dim_feedforward=dim_feedforward,
            activation=activation,
        )

        if self.tie_weights:
            # must also predict start token since weight tied
            self.out_linear = nn.Linear(emb_dim, 4)

            # masking the extra dimensions (<start> and <pad>)
            # should be of length embed_answered_correctly
            out_mask = torch.tensor([0, 0, float("-inf"), float("-inf")])
            self.register_buffer("out_mask", out_mask)
        else:
            self.out_linear = nn.Linear(emb_dim, 1)

        # Initialise every linear layer. `init_weights(self)` only inspected
        # the top-level module, which is never an nn.Linear, so it silently
        # did nothing; Module.apply visits all sub-modules.
        self.apply(init_weights)

        # tie weights of output layer. Done after the initialisation above so
        # the shared matrix keeps the embedding's own initialisation, including
        # its zeroed padding row.
        if self.tie_weights:
            self.out_linear.weight = self.embed_answered_correctly.weight

    def generate_square_subsequent_mask(self, sz):
        """Return a (sz, sz) float mask: 0.0 on/below the diagonal, -inf above it."""
        mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
        mask = (
            mask.float()
            .masked_fill(mask == 0, float("-inf"))
            .masked_fill(mask == 1, float(0.0))
        )
        return mask

    def forward(self, content_ids, parts, answers, tags, timestamps):
        """Predict the probability of a correct answer for every position.

        Args:
            content_ids: (S, N) question ids - encoder input.
            parts: (S, N) part ids - encoder input.
            answers: (S, N) shifted previous answers - decoder input.
            tags: (S, N, T) tag ids per question; embeddings are summed over T.
            timestamps: (S, N) float elapsed-time feature.

        Returns:
            (S, N) tensor of probabilities in [0, 1].
        """
        sequence_length = content_ids.shape[0]

        # sequence that will go into encoder
        embeded_content = self.embed_content_id(content_ids)
        embeded_parts = self.embed_parts(parts)
        embeded_tags = self.embed_tags(tags).sum(dim=2)
        e_w = F.softmax(self.exercise_weights, dim=0)

        embeded_exercise_sequence = (
            (embeded_content * e_w[0])
            + (embeded_parts * e_w[1])
            + (embeded_tags * e_w[2])
        )

        # sequence that will go into decoder
        embeded_responses = self.embed_answered_correctly(answers)
        if self.use_time_feats:
            # (S, N) -> (S, N, 1) so the linear layer maps each scalar to emb_dim
            embeded_timestamps = self.embed_timestamps(timestamps.unsqueeze(2))
            embeded_responses = (embeded_responses + embeded_timestamps) * 0.5

        # adding positional vector
        embedded_positions = self.pos_encoder(sequence_length)
        embeded_responses = embeded_responses + embedded_positions
        embeded_exercise_sequence = embeded_exercise_sequence + embedded_positions

        # mask of shape S x S -> prevents attention looking forward
        top_right_attention_mask = self.generate_square_subsequent_mask(
            sequence_length
        ).type_as(embeded_exercise_sequence)

        output = self.transformer(
            embeded_exercise_sequence,
            embeded_responses,
            tgt_mask=top_right_attention_mask,
            src_mask=top_right_attention_mask,
        )

        if self.tie_weights:
            # returns softmax but ignores the last two columns (<start> and <pad>)
            output = self.out_linear(output)
            return F.softmax(output + self.out_mask, dim=2,)[:, :, 1]
        else:
            output = self.out_linear(output).squeeze(2)
            return torch.sigmoid(output)

    def training_step(self, batch, batch_nb):
        """Compute and log the masked binary cross-entropy for one training batch."""
        result = self(
            batch["content_ids"],
            batch["parts"],
            batch["answers"],
            batch["tags"],
            batch["timestamps"],
        )
        # NOTE(review): with the default mean reduction the weighted loss is
        # averaged over all (S, N) positions, padded ones included, so its
        # scale depends on how much padding the batch contains.
        loss = F.binary_cross_entropy(
            result, batch["answered_correctly"], weight=batch["loss_mask"]
        )
        self.log("train_loss", loss)
        return loss

    def validation_step(self, batch, batch_nb, dataset_nb=None):
        """Log the validation loss and return (predictions, targets) of unpadded positions."""
        result = self(
            batch["content_ids"],
            batch["parts"],
            batch["answers"],
            batch["tags"],
            batch["timestamps"],
        )
        loss = F.binary_cross_entropy(
            result, batch["answered_correctly"], weight=batch["loss_mask"]
        )

        self.log("val_loss_step", loss)

        # keep only real (non-padded) positions for the epoch-level AUC
        return (
            torch.masked_select(result, batch["loss_mask"] > 0),
            torch.masked_select(batch["answered_correctly"], batch["loss_mask"] > 0),
        )

    def validation_epoch_end(self, outputs):
        """Concatenate all validation predictions and log the AUC as ``avg_val_auc``."""
        results = torch.cat([out[0] for out in outputs], dim=0)
        correct = torch.cat([out[1] for out in outputs], dim=0)
        auc = auroc(results, correct)

        self.log("avg_val_auc", auc, prog_bar=True)

    def configure_optimizers(self):
        """Adam plus a ReduceLROnPlateau scheduler driven by ``avg_val_auc``.

        A warmup/decay ("noam") LambdaLR schedule based on ``self.warmup`` and
        ``self.decay`` was tried previously and is currently disabled.
        """
        optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)

        scheduler = {
            "scheduler": torch.optim.lr_scheduler.ReduceLROnPlateau(
                optimizer, mode="max", patience=10
            ),
            "monitor": "avg_val_auc",
            "interval": "epoch",
            "frequency": 1,
            "strict": True,
        }

        return [optimizer], [scheduler]

In [None]:
from pytorch_lightning.callbacks import (
    EarlyStopping,
    LearningRateMonitor,
    ModelCheckpoint,
)
from pytorch_lightning.loggers import TensorBoardLogger
from torch.utils.data import random_split

# Re-seed so the user split below is the same every time this cell runs.
torch.manual_seed(SEED)


# Params
learning_rate = 0.001
emb_dim = 64  # 256
dropout = 0.0
n_heads = 1  # 2
n_encoder_layers = 2  # 4
n_decoder_layers = 2  # 4
dim_feedforward = 256
batch_size = 256
num_user_train = 300000
num_user_val = 30000
use_time_feats = True
warmup = 4000
decay = False
tie_weights = False

# create split (by user): train / validation / remaining users as test
train_dataset, val_dataset, test_dataset = random_split(
    dataset,
    [num_user_train, num_user_val, len(dataset) - num_user_train - num_user_val,],
)


def seed_worker(worker_id):
    """Seed NumPy inside a DataLoader worker from that worker's torch seed.

    The dataset draws its random windows with ``np.random``. PyTorch gives each
    worker a distinct torch seed, but older releases do not reseed NumPy, so
    workers would otherwise start from the same NumPy state and draw identical
    window positions.
    """
    np.random.seed(torch.initial_seed() % 2 ** 32)


# Init DataLoader from RIIID Dataset subset
train_loader = torch.utils.data.DataLoader(
    dataset=train_dataset,
    batch_size=batch_size,
    shuffle=True,
    collate_fn=collate_fn,
    num_workers=6,
    worker_init_fn=seed_worker,
    pin_memory=torch.cuda.is_available(),  # if GPU then pin memory for perf
)
val_loader = torch.utils.data.DataLoader(
    dataset=val_dataset,
    batch_size=512,
    collate_fn=collate_fn,
    num_workers=6,
    pin_memory=torch.cuda.is_available(),
)


# Init our model
model = RIIDDTransformerModel(
    learning_rate=learning_rate,
    emb_dim=emb_dim,  # embedding dimension - this is for everything
    dropout=dropout,
    n_heads=n_heads,
    n_encoder_layers=n_encoder_layers,
    n_decoder_layers=n_decoder_layers,
    dim_feedforward=dim_feedforward,
    batch_size=batch_size,
    num_user_train=num_user_train,
    num_user_val=num_user_val,
    use_time_feats=use_time_feats,
    warmup=warmup,
    decay=decay,
    tie_weights=tie_weights,
)


# changed this:
experiment = ""


# One TensorBoard run directory per hyperparameter combination.
logger = TensorBoardLogger(
    "lightning_logs",
    name=f"tie_{tie_weights}_lr_{learning_rate}_emb_{emb_dim}_h_{n_heads}_enc_{n_encoder_layers}_dec_{n_decoder_layers}_ff_{dim_feedforward}",
)

# Initialize a trainer
trainer = pl.Trainer(
    gpus=1,
    max_epochs=500,
    progress_bar_refresh_rate=1,
    callbacks=[
        # stop once validation AUC has not improved for 20 epochs
        EarlyStopping(monitor="avg_val_auc", patience=20, mode="max"),
        # checkpoint selection is driven by validation AUC
        ModelCheckpoint(
            monitor="avg_val_auc",
            filename="{epoch}-{val_loss_step:.2f}-{avg_val_auc:.2f}",
            mode="max"
        ),
        LearningRateMonitor(logging_interval="step"),
    ],
    logger=logger,
)

# Train the model ⚡
trainer.fit(model, train_dataloader=train_loader, val_dataloaders=[val_loader])

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name                     | Type               | Params
----------------------------------------------------------------
0 | embed_content_id         | Embedding          | 865 K 
1 | embed_parts              | Embedding          | 512   
2 | embed_tags               | Embedding          | 12.1 K
3 | embed_answered_correctly | Embedding          | 256   
4 | embed_timestamps         | Linear             | 128   
5 | pos_encoder              | PositionalEncoding | 0     
6 | transformer              | Transformer        | 233 K 
7 | out_linear               | Linear             | 65    


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validation sanity check', layout=Layout…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Training', layout=Layout(flex='2'), max…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

In [None]:
# avg_val_auc=0.761 - basic

In [11]:
# load_from_checkpoint is a classmethod: it builds and returns a NEW model from
# the checkpoint's saved hyperparameters and weights, and does not modify the
# object it is called on. Bind the result, otherwise `model` keeps pointing at
# the previously trained network.
model = RIIDDTransformerModel.load_from_checkpoint(
    "lightning_logs/tie_False_lr_0.0001_emb_256_h_2_enc_4_dec_4_ff_512/version_2/checkpoints/epoch=0.ckpt"
)
# Display the restored architecture.
model

RIIDDTransformerModel(
  (embed_content_id): Embedding(13524, 256, padding_idx=13523)
  (embed_parts): Embedding(8, 256, padding_idx=0)
  (embed_tags): Embedding(189, 256, padding_idx=188)
  (embed_answered_correctly): Embedding(4, 256, padding_idx=3)
  (embed_timestamps): Linear(in_features=1, out_features=256, bias=True)
  (pos_encoder): PositionalEncoding()
  (transformer): Transformer(
    (encoder): TransformerEncoder(
      (layers): ModuleList(
        (0): TransformerEncoderLayer(
          (self_attn): MultiheadAttention(
            (out_proj): _LinearWithBias(in_features=256, out_features=256, bias=True)
          )
          (linear1): Linear(in_features=256, out_features=512, bias=True)
          (dropout): Dropout(p=0.0, inplace=False)
          (linear2): Linear(in_features=512, out_features=256, bias=True)
          (norm1): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
          (norm2): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
          (dropout1)