In [None]:
#@title
# Copyright 2021 Google LLC.

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

# https://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

## Intro

This notebook trains a transformer model on the [EdNet dataset](https://github.com/riiid/ednet) using the [google/trax library](https://github.com/google/trax). The EdNet dataset is a large set of student responses to multiple-choice questions related to English language learning. A recent Kaggle competition, [Riiid! Answer Correctness Prediction](https://www.kaggle.com/c/riiid-test-answer-prediction), provided a subset of this data, consisting of 100 million responses to 13 thousand questions from 300 thousand students.

The state of the art result, detailed in [SAINT+: Integrating Temporal Features for EdNet Correctness Prediction](https://arxiv.org/abs/2010.12042), achieves an AUC ROC of 0.7914. The winning solution in the [Riiid! Answer Correctness Prediction](https://www.kaggle.com/c/riiid-test-answer-prediction) competition achieved an AUC ROC of 0.820. This notebook achieves an AUC ROC of 0.776 implementing an approach similar to the state of the art approach, training for 25,000 steps. It demonstrates several techniques that may be useful to those getting started with the [google/trax library](https://github.com/google/trax) or deep learning in general. This notebook demonstrates how to:

* Use BigQuery to perform feature engineering
* Create TFRecords with multiple sequences per record
* Modify the trax Transformer model to accommodate a knowledge tracing dataset:
    * Utilize multiple encoder and decoder embeddings - aggregated either by concatenation or sum
    * Include a custom metric - AUC ROC
    * Utilize a combined padding and future mask
* Use trax's [gin-config](https://github.com/google/gin-config) integration to specify training parameters
* Display training progress using trax's tensorboard integration

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/CalebEverett/riiid_transformer/blob/master/riiid-trax-transformer.ipynb)

In [None]:
# Choose a location for your storage bucket and BigQuery dataset to minimize data egress charges. Once you have
# created them, if you restart your notebook you can run this to see where your colab is running
# and factory reset until you get a location that is near your data.
!curl ipinfo.io

## Imports

In [None]:
# <hide-output>
!git clone https://github.com/google/trax.git
!pip install ./trax
!pip install -U pyarrow
!pip install -U google-cloud-bigquery google-cloud-bigquery-storage

In [None]:
from functools import partial
import json
import math
import os
from pathlib import Path
import subprocess
import sys
import time

import gin
from google.cloud import storage, bigquery
from google.cloud.bigquery import LoadJobConfig, QueryJobConfig, \
    SchemaField, SourceFormat
import jax
from jax.config import config
import pandas as pd
import numpy as np
import requests
import sqlite3
import trax
from trax import fastmath
from trax import layers as tl
from trax.fastmath import numpy as tnp
import tensorflow as tf
from tqdm.notebook import tqdm
import zipfile

# Create google credentials and store in drive
# https://colab.research.google.com/drive/1LWhrqE2zLXqz30T0a0JqXnDPKweqd8ET
# 
# Create a config.json file with variables for:
# "BUCKET": "",
# "BQ_DATASET": "",
# "KAGGLE_USERNAME": "",
# "KAGGLE_KEY": "",
# "PROJECT": "",
# "LOCATION": ""
from google.colab import drive

DRIVE = Path('/content/drive/My Drive')
PATH = 'riiid-transformer'

# Mount Google Drive once per runtime; credentials and config live there.
if not DRIVE.exists():
    drive.mount(str(DRIVE.parent))
os.environ['GOOGLE_APPLICATION_CREDENTIALS'] = str(DRIVE/PATH/'google.json')

# Export every config.json entry as an environment variable.
# Update the real os.environ mapping in place. Rebinding os.environ to a plain
# dict would bypass os.putenv, so shell commands and child processes would
# never see the values. os.environ only accepts str values, hence str().
with open(str(DRIVE/PATH/'config.json')) as f:
    CONFIG = json.load(f)
os.environ.update({key: str(value) for key, value in CONFIG.items()})

from kaggle.api.kaggle_api_extended import KaggleApi

# Presumably picks up KAGGLE_USERNAME / KAGGLE_KEY from the environment
# (see the config.json key list above) — confirm against the kaggle client.
kaggle_api = KaggleApi()
kaggle_api.authenticate()

AUTO = tf.data.experimental.AUTOTUNE

# GCP resource names: read from the environment, falling back to the
# original author's defaults.
BUCKET = os.environ.get('BUCKET', 'riiid-transformer')
BQ_DATASET = os.environ.get('BQ_DATASET', 'my_data')
LOCATION = os.environ.get('LOCATION', 'us-central1')
PROJECT = os.environ.get('PROJECT', 'fastai-caleb')

# Shared handles used throughout the notebook.
bucket = storage.Client(project=PROJECT).get_bucket(BUCKET)
dataset = bigquery.Dataset(f'{PROJECT}.{BQ_DATASET}')
bq_client = bigquery.Client(project=PROJECT, location=LOCATION)

# Render matplotlib figures inline in the cell output.
%matplotlib inline
from matplotlib import pyplot as plt

# Load the TensorBoard notebook extension used to display training progress.
%load_ext tensorboard

# Allow gin-configurable functions to be re-registered when their cells are re-run.
gin.enter_interactive_mode()

## Control Panel

These variables can be set to True to run the code in the sections described or False to skip over them after they have been run for the first time.

In [None]:
USE_TPU = False                      # "Initialize TPU": point JAX at the Colab TPU driver
DOWNLOAD_DATASET = False             # fetch competition files from Kaggle and copy them to GCS
LOAD_DATA_TO_BQ = False              # load the CSVs from GCS into BigQuery tables
PERFORM_FEATURE_ENGINEERING = False  # run the BigQuery schema updates and feature queries
TEST_FEATURE_ENGNEERING = False      # compare BigQuery features with train.csv (name typo kept: referenced below)
CREATE_TFRECORDS = False             # write one TFRecord file per fold to GCS
TEST_TFRECORDS = False               # validate the written TFRecords against BigQuery
TRAIN_MODEL = False                  # model training section later in the notebook

## Initialize TPU

In [None]:
if USE_TPU:
    # Request the nightly TPU driver only once per kernel session.
    # TPU_DRIVER_MODE in the notebook globals marks that the request was made.
    if 'TPU_DRIVER_MODE' not in globals():
        tpu_host = os.environ['COLAB_TPU_ADDR'].split(':')[0]
        url = f'http://{tpu_host}:8475/requestversion/tpu_driver_nightly'
        resp = requests.post(url)
        # Fail here with the HTTP error instead of later with an opaque JAX
        # backend error; the response used to be ignored entirely.
        resp.raise_for_status()
        TPU_DRIVER_MODE = 1

    config.FLAGS.jax_xla_backend = "tpu_driver"
    config.FLAGS.jax_backend_target = "grpc://" + os.environ['COLAB_TPU_ADDR']
    print(config.FLAGS.jax_backend_target)

## Download Dataset

In [None]:
# Competition files mirrored in the GCS bucket.
DATASET_FILES = ['train.csv', 'questions.csv', 'lectures.csv']
# Set to True to copy the CSVs back from GCS instead of re-downloading from Kaggle.
RESTORE_DATASET_FROM_GCS = False

if DOWNLOAD_DATASET:
    kaggle_api.competition_download_cli('riiid-test-answer-prediction')
    with zipfile.ZipFile('riiid-test-answer-prediction.zip', 'r') as zip_ref:
        zip_ref.extractall()
    for f in DATASET_FILES:
        bucket.blob(f).upload_from_filename(f)

if RESTORE_DATASET_FROM_GCS:
    for f in tqdm(DATASET_FILES):
        bucket.blob(f).download_to_filename(f)

## Create BigQuery Dataset

In [None]:
# Set DELETE_BQ_DATASET to True to drop the dataset. DELETE_BQ_CONTENTS also
# drops its tables; without it, deleting a non-empty dataset fails.
DELETE_BQ_DATASET = False
DELETE_BQ_CONTENTS = False

if DELETE_BQ_DATASET:
    bq_client.delete_dataset(BQ_DATASET, delete_contents=DELETE_BQ_CONTENTS)
    print(f'Dataset {dataset.dataset_id} deleted from project {dataset.project}.')

# Check existence explicitly. A bare `except:` around get_dataset would treat
# auth failures, network errors or a keyboard interrupt as "missing" and then
# try to create the dataset.
existing_dataset_ids = {d.dataset_id for d in bq_client.list_datasets()}
if dataset.dataset_id in existing_dataset_ids:
    dataset = bq_client.get_dataset(dataset.dataset_id)
    print(f'Dataset {dataset.dataset_id} already exists '
          f'in location {dataset.location} in project {dataset.project}.')
else:
    dataset = bq_client.create_dataset(dataset)
    print(f'Dataset {dataset.dataset_id} created '
          f'in location {dataset.location} in project {dataset.project}.')

## Dtypes

In [None]:
# pandas dtypes of the columns in the three competition CSV files, by table.
dtypes_orig = {
    'lectures': {'lecture_id': 'uint16', 'tag': 'uint8', 'part': 'uint8', 'type_of': 'str'},
    'questions': {'question_id': 'uint16', 'bundle_id': 'uint16', 'correct_answer': 'uint8',
                  'part': 'uint8', 'tags': 'str'},
    'train': {
        'row_id': 'int64',
        'timestamp': 'int64',
        'user_id': 'int32',
        'content_id': 'int16',
        'content_type_id': 'int8',
        'task_container_id': 'int16',
        'user_answer': 'int8',
        'answered_correctly': 'int8',
        'prior_question_elapsed_time': 'float32',
        'prior_question_had_explanation': 'bool',
    },
}

# Columns created during feature engineering, by table.
dtypes_new = {
    'lectures': {},
    'questions': {'tags_array': 'str'},
    'train': {'task_container_id_q': 'int16', 'pqet_current': 'int32', 'ts_delta': 'int32'},
}

# Per-table union: original columns first, then the engineered ones.
dtypes = {table: {**dtypes_orig[table], **dtypes_new[table]} for table in dtypes_orig}

### Big Query Table Schemas

In [None]:
# <hide-input>
# pandas dtype -> BigQuery column type.
type_map = {
    'int64': 'INT64',
    'int32': 'INT64',
    'int16': 'INT64',
    'int8': 'INT64',
    'uint8': 'INT64',
    'uint16': 'INT64',
    'str': 'STRING',
    'bool': 'BOOL',
    'float32': 'FLOAT64'
}

# Schemas of the tables as they are loaded from CSV.
schemas_orig = {
    table: [SchemaField(name, type_map[dtype]) for name, dtype in fields.items()]
    for table, fields in dtypes_orig.items()
}

# Full schemas: original columns, then new scalar columns, then new array
# columns. A new column whose name contains 'array' becomes a REPEATED INT64,
# whatever its pandas dtype says.
schemas = {}
for table_id, fields in dtypes_new.items():
    scalar_fields = [SchemaField(name, type_map[dtype])
                     for name, dtype in fields.items() if 'array' not in name]
    array_fields = [SchemaField(name, 'INT64', 'REPEATED')
                    for name in fields if 'array' in name]
    schemas[table_id] = schemas_orig[table_id] + scalar_fields + array_fields

### Load Tables

In [None]:
def load_job_cb(future):
    """Print a one-line summary when a BigQuery load job finishes.

    Attached with ``add_done_callback`` in ``load_csv_from_uri``. It may also
    run for jobs that failed, so it must not assume row counts are present.

    Args:
        future: the finished load job. Reads ``job_id``, ``error_result``,
            ``created``, ``ended`` and ``output_rows``.
    """
    # Job ids are '<table_id>_' followed by a generated suffix (assumed to
    # contain no '_'). Strip only the last '_' chunk so table ids that
    # contain underscores are reported intact.
    table_id = future.job_id.rsplit('_', 1)[0]
    if future.error_result:
        print(f'Load to table {table_id} failed: {future.error_result}')
        return
    seconds = (future.ended - future.created).total_seconds()
    rows = future.output_rows or 0
    # Avoid a ZeroDivisionError inside the callback for near-instant jobs.
    rate = int(rows / seconds) if seconds > 0 else 0
    print(f'Loaded {rows:,d} rows to table {table_id} in '
          f'{seconds:>4,.1f} sec, {rate:,d} per sec.')

def load_csv_from_uri(table_id, schemas_orig):
    """Start loading gs://BUCKET/<table_id>.csv into BQ_DATASET.<table_id>.

    Args:
        table_id: table name. Also used as the CSV file stem and the job id prefix.
        schemas_orig: mapping of table name -> list of SchemaField.

    Returns:
        The started LoadJob, with ``load_job_cb`` attached as its done callback.
    """
    job_config = LoadJobConfig(
        schema=schemas_orig[table_id],
        source_format=SourceFormat.CSV,
        skip_leading_rows=1,  # skip the CSV header row
    )
    load_job = bq_client.load_table_from_uri(
        f'gs://{BUCKET}/{table_id}.csv',
        f'{BQ_DATASET}.{table_id}',
        job_config=job_config,
        job_id_prefix=f'{table_id}_',
    )
    print(f'job {load_job.job_id} started')
    load_job.add_done_callback(load_job_cb)
    return load_job

In [None]:
if LOAD_DATA_TO_BQ:
    # Load one table at a time, blocking until each job completes.
    for table_id in dtypes_orig:
        load_job = load_csv_from_uri(table_id, schemas_orig)
        lj = load_job.result()

### Update BigQuery Schemas

Before performing feature engineering, we have to update the table schemas in BigQuery to create columns for the new features.

In [None]:
if PERFORM_FEATURE_ENGINEERING:
    # Push the extended schemas (original + engineered columns) to BigQuery so
    # the feature UPDATE statements have columns to write into.
    for table_id in schemas:
        table = bq_client.get_table(f'{BQ_DATASET}.{table_id}')
        table.schema = schemas[table_id]
        table = bq_client.update_table(table, ['schema'])

## Feature Engineering

Using BigQuery for a dataset of 100 million rows is much faster than using local dataframes. In addition, you get to use the full power of SQL, including [window functions](https://cloud.google.com/bigquery/docs/reference/standard-sql/analytic-function-concepts), which are especially useful for time series feature engineering.

Feature engineering for this problem is fairly minimal and includes:
* Replacing missing null values for `prior_question_elapsed_time` and `prior_question_had_explanation` in the train table
* Replacing one missing tag value in the questions table
* Recalculating `task_container_id` as `task_container_id_q` so that it excludes lecture records and increases monotonically with `timestamp`. Elapsed time and time delta depend on values from the immediately preceding and immediately succeeding records, so they are only calculated correctly with this ordering.
* Calculating `pqet_current`, the time it took on average to answer the questions in the current `task_container_id_q`.
* Calculating `ts_delta`, the elapsed time between the last `task_container_id_q` and the current one.
* Creating `folds` table, in which users are assigned to one of 20 folds.
* Creating a `tags_array` field in the questions table that holds an array of six elements populated with the tags assigned to each question, padded with zeros if there are fewer than six.

In [None]:
def done_cb(future):
    """Print the outcome of a finished BigQuery query job.

    Attached with ``add_done_callback`` in ``run_query``. A failed job is
    reported as failed instead of "finished". If ``started`` is unset,
    ``created`` is used as the start time so the callback does not raise.

    Args:
        future: the finished job. Reads ``job_id``, ``error_result``,
            ``started``, ``created`` and ``ended``.
    """
    if future.error_result:
        print(f'Job {future.job_id} failed: {future.error_result}')
        return
    start = future.started or future.created
    seconds = (future.ended - start).total_seconds()
    print(f'Job {future.job_id} finished in {seconds} seconds.')

def run_query(query, job_id_prefix=None, wait=True,
              use_query_cache=True):
    """Submit a BigQuery query job, optionally waiting for it to finish.

    Args:
        query: SQL text. Callers pass multi-statement scripts as well.
        job_id_prefix: prefix for the generated job id.
        wait: if True, block on ``result()`` before returning.
        use_query_cache: forwarded to ``QueryJobConfig``.

    Returns:
        The QueryJob, with ``done_cb`` attached as its completion callback.
    """
    job = bq_client.query(
        query,
        job_id_prefix=job_id_prefix,
        job_config=QueryJobConfig(use_query_cache=use_query_cache),
    )
    print(f'Job {job.job_id} started.')
    job.add_done_callback(done_cb)
    if not wait:
        return job
    job.result()
    return job

In [None]:
def get_df_query_bqs(query, dtypes=None, fillna=None, client=None):
    """Run ``query`` and return its result as a DataFrame via the BQ Storage API.

    Args:
        query: SQL text to execute.
        dtypes: optional mapping of column name -> dtype. Every result column
            is cast, and columns missing from the mapping become 'int32'.
            Pass None to skip casting. NOTE(review): several callers pass the
            table-keyed ``dtypes`` dict, whose keys are table names, so every
            column falls back to 'int32' there.
        fillna: optional value used to fill nulls before casting.
        client: BigQuery client to use. Defaults to the notebook's ``bq_client``.

    Returns:
        The result DataFrame. If casting fails, the uncast frame is returned
        and the reason is printed.
    """
    if client is None:
        client = bq_client
    df = client.query(query).to_dataframe(
        create_bqstorage_client=True, progress_bar_type='tqdm_notebook')
    if fillna is not None:
        df = df.fillna(fillna)
    if dtypes is None:
        # Nothing to apply. Previously this path fell into the except branch
        # and printed a misleading "dtypes not applied." message.
        return df
    try:
        return df.astype({c: dtypes.get(c, 'int32') for c in df.columns})
    except (TypeError, ValueError, OverflowError) as exc:
        # Best effort: keep the raw frame, but say why the cast failed.
        print(f'dtypes not applied: {exc}')
        return df

### Replace Missing Values

In [None]:
def update_missing_values(table_id='train', column_id=None, value=None):
    """Build an UPDATE that replaces NULLs in one column with ``value``.

    ``value`` is inserted into the SQL verbatim, so string values must carry
    their own SQL quotes (e.g. '"188"').

    Returns:
        (query, job_id_prefix), ready to unpack into ``run_query``.
    """
    query = f"""
        UPDATE {BQ_DATASET}.{table_id}
        SET {column_id} = {value}
        WHERE {column_id} is NULL;
    """
    return query, 'update_missing_values_'

In [None]:
if PERFORM_FEATURE_ENGINEERING:
    missing_value_fills = [
        ('train', 'prior_question_elapsed_time', '0'),
        ('train', 'prior_question_had_explanation', 'false'),
        ('questions', 'tags', '"188"'),  # string value needs its own SQL quotes
    ]
    for fill_table, fill_column, fill_value in missing_value_fills:
        qj = run_query(*update_missing_values(fill_table, fill_column, fill_value))

### Recalculate Task Container Ids for Questions Only

In [None]:
def update_task_container_id(table_id='train',
                             column_id='task_container_id',
                             excl_lectures=True):
    """Build an UPDATE that renumbers task containers per user.

    ``column_id`` is set to the DENSE_RANK of each row's timestamp within its
    user, so the ids start at 1 and increase with time. With
    ``excl_lectures`` only question rows (content_type_id = 0) are ranked,
    and lecture rows are not matched by the UPDATE.

    Returns:
        (query, job_id_prefix) for ``run_query``.
    """
    if excl_lectures:
        excl_lec = 'WHERE content_type_id = 0'
    else:
        excl_lec = ''

    query = f"""
        UPDATE {BQ_DATASET}.{table_id} t
        SET {column_id} = target.calc
        FROM (
            SELECT row_id, DENSE_RANK()
            OVER (
                PARTITION BY user_id
                ORDER BY timestamp
            ) calc
            FROM {BQ_DATASET}.{table_id}
            {excl_lec}
        ) target
        WHERE target.row_id = t.row_id
    """
    return query, 'update_task_container_id_'

In [None]:
if PERFORM_FEATURE_ENGINEERING:
    # Rank question rows only, writing into the new task_container_id_q column.
    # The column name previously had a stray trailing space.
    q = update_task_container_id(table_id='train',
                                 column_id='task_container_id_q',
                                 excl_lectures=True)
    qj = run_query(*q)

### Calculate Current Question Elapsed Time and Timestamp Delta

In [None]:
def update_pqet_current(table_id='train'):
    """Build SQL that fills ``pqet_current`` for question rows of ``table_id``.

    A row's prior_question_elapsed_time refers to the previous task container,
    so the current container's elapsed time is read from the next container
    (RANGE 1 FOLLOWING over task_container_id_q). Rows left NULL, such as a
    user's last container, are then set to 0.

    Args:
        table_id: table to read from and update. The inner SELECT now reads
            this table as well; it used to be hard-coded to ``train``.

    Returns:
        (query, job_id_prefix) for ``run_query``.
    """
    return f"""
        UPDATE {BQ_DATASET}.{table_id} t
        SET t.pqet_current = CAST(p.pqet_current AS INT64)
        FROM (
            SELECT
            row_id, LAST_VALUE(prior_question_elapsed_time) OVER (
                PARTITION BY user_id ORDER BY task_container_id_q
                RANGE BETWEEN 1 FOLLOWING AND 1 FOLLOWING) pqet_current
            FROM {BQ_DATASET}.{table_id}
            WHERE content_type_id = 0
        ) p
        WHERE t.row_id = p.row_id;

        UPDATE {BQ_DATASET}.{table_id}
        SET pqet_current = 0
        WHERE pqet_current IS NULL;
    """, sys._getframe().f_code.co_name + '_'

In [None]:
if PERFORM_FEATURE_ENGINEERING:
    pqet_query, pqet_prefix = update_pqet_current()
    qj = run_query(pqet_query, job_id_prefix=pqet_prefix)

In [None]:
def update_ts_delta(table_id='train'):
    """Build SQL that fills ``ts_delta`` for question rows of ``table_id``.

    ts_delta is this row's timestamp minus the timestamp of the user's
    previous task container (RANGE 1 PRECEDING over task_container_id_q).
    Rows left NULL, such as a user's first container, are then set to 0.

    Args:
        table_id: table to read from and update. The inner SELECT now reads
            this table as well; it used to be hard-coded to ``train``.

    Returns:
        (query, job_id_prefix) for ``run_query``.
    """
    return f"""
        UPDATE {BQ_DATASET}.{table_id} t
        SET t.ts_delta = t.timestamp - p.ts_prior
        FROM (
            SELECT
            row_id, LAST_VALUE(timestamp) OVER (
                PARTITION BY user_id ORDER BY task_container_id_q
                RANGE BETWEEN 1 PRECEDING AND 1 PRECEDING) ts_prior
            FROM {BQ_DATASET}.{table_id}
            WHERE content_type_id = 0
        ) p
        WHERE t.row_id = p.row_id;

        UPDATE {BQ_DATASET}.{table_id}
        SET ts_delta = 0
        WHERE ts_delta IS NULL;
    """, sys._getframe().f_code.co_name + '_'

In [None]:
if PERFORM_FEATURE_ENGINEERING:
    ts_delta_query, ts_delta_prefix = update_ts_delta()
    qj = run_query(ts_delta_query, job_id_prefix=ts_delta_prefix)

### Create Folds Table
Assign users randomly to one of 20 folds. Store total records to facilitate filtering based on record count.

In [None]:
def create_table_folds(table_id='folds', n_folds=20):
    """Build SQL that (re)creates a table assigning each user to a random fold.

    Each user with at least one question row gets one row: user_id, a fold
    drawn with RAND() from 0..n_folds-1 (assignments differ on every run),
    and record_count, the user's number of question rows.

    Returns:
        (query, job_id_prefix) for ``run_query``.
    """
    # The unused `DECLARE f INT64;` script variable was dropped because it
    # shared its name with the subquery alias. The alias is now `u`.
    return f"""
        CREATE OR REPLACE TABLE {BQ_DATASET}.{table_id} (
            user_id INT64,
            fold INT64,
            record_count INT64
        );

        INSERT {BQ_DATASET}.{table_id} (user_id, fold, record_count)
        SELECT u.user_id, CAST(FLOOR(RAND() * {n_folds}) AS INT64) fold, u.record_count
        FROM (
        SELECT user_id,
            COUNT(row_id) record_count
        FROM {BQ_DATASET}.train
        WHERE content_type_id = 0
        GROUP BY user_id
        ) u
        ORDER BY user_id;
    """, sys._getframe().f_code.co_name + '_'

In [None]:
if PERFORM_FEATURE_ENGINEERING:
    folds_query, folds_prefix = create_table_folds()
    qj = run_query(folds_query, job_id_prefix=folds_prefix)

In [None]:
if PERFORM_FEATURE_ENGINEERING:
    folds_sql = f"""
        SELECT *
        FROM {BQ_DATASET}.folds
    """
    df_folds = get_df_query_bqs(folds_sql, dtypes=dtypes)

In [None]:
if PERFORM_FEATURE_ENGINEERING:
    users_per_fold = df_folds.groupby('fold').count().user_id
    users_per_fold.plot(kind='bar', title='Count of Users by Fold');

In [None]:
if PERFORM_FEATURE_ENGINEERING:
    mean_records_per_user = df_folds.groupby('fold')['record_count'].mean()
    mean_records_per_user.plot(kind='bar', title='Average Records per User by Fold');

In [None]:
if PERFORM_FEATURE_ENGINEERING:
    # Restrict to question rows (content_type_id = 0), matching record_count
    # in the folds table. Lecture rows are not answers; presumably they carry
    # a placeholder answered_correctly (-1 in the competition data — verify)
    # that would distort ac_sum and rec_count. ORDER BY fold makes row
    # position match the fold number.
    df_fold_ac = get_df_query_bqs(f"""
        SELECT fold, SUM(answered_correctly) ac_sum, COUNT(answered_correctly) rec_count
        FROM {BQ_DATASET}.train
        JOIN {BQ_DATASET}.folds
        ON train.user_id = folds.user_id
        WHERE content_type_id = 0
        GROUP BY fold
        ORDER BY fold
    """,
    dtypes=dtypes)

In [None]:
if PERFORM_FEATURE_ENGINEERING:
    # Index and sort by fold so each bar is labelled with its fold number,
    # not its row position in the (unordered) query result.
    records_by_fold = df_fold_ac.set_index('fold').sort_index()
    records_by_fold.rec_count.plot(kind='bar', title='Count of Records by Fold');

In [None]:
if PERFORM_FEATURE_ENGINEERING:
    # Index and sort by fold so bars line up with fold numbers. The plotted
    # value is a fraction (0-1), not a percentage.
    ac_by_fold = df_fold_ac.set_index('fold').sort_index()
    (ac_by_fold.ac_sum / ac_by_fold.rec_count).plot(kind='bar', title='Percent Answered Correctly by Fold');

### Create Tags Array on Questions Table
We need the tags as an array later when we create TFRecords. We also increment each tag by one and pad with zeros to a fixed length of 6 so that they can be concatenated as a feature for modeling.

In [None]:
def update_tags_array(table_id='questions', column_id='tags_array', max_tags=6):
    """Build SQL that stores each question's tags as a fixed-length int array.

    The space-separated ``tags`` string is split and each tag is shifted by
    +1, which leaves 0 free as padding. The array is right-padded with zeros
    and cut to ``max_tags`` elements. Tags keep the order they had in the
    string: ARRAY_AGG is ordered by split position, because without ORDER BY
    BigQuery does not guarantee element order.

    Args:
        table_id: questions table to update.
        column_id: REPEATED INT64 column that receives the array.
        max_tags: output array length. The default 6 matches the reshape in
            parse_example.

    Returns:
        (query, job_id_prefix) for ``run_query``.
    """
    # Every question has at least one tag (the missing one is filled with
    # "188"), so max_tags - 1 zeros are enough padding to reach max_tags.
    padding = ', '.join(['0'] * (max_tags - 1))
    return f"""
        UPDATE {BQ_DATASET}.{table_id} q
        SET {column_id} = tp.tags_fixed_len
        FROM (
            WITH tags_padded AS (
                WITH tags_table AS (SELECT question_id, tags FROM {BQ_DATASET}.{table_id})
                SELECT question_id,
                    ARRAY_CONCAT(ARRAY_AGG(CAST(tag AS INT64) + 1 ORDER BY tag_pos), ARRAY<INT64>[{padding}]) tags_array
                FROM tags_table, UNNEST(SPLIT(tags, ' ')) AS tag WITH OFFSET tag_pos
                GROUP BY question_id
            )
            SELECT question_id,
                ARRAY(SELECT x FROM UNNEST(tags_array) AS x WITH OFFSET off WHERE off < {max_tags} ORDER BY off) tags_fixed_len
            FROM tags_padded
        ) tp
        WHERE tp.question_id = q.question_id
    """, sys._getframe().f_code.co_name + '_'

In [None]:
if PERFORM_FEATURE_ENGINEERING:
    tags_query, tags_prefix = update_tags_array()
    qj = run_query(tags_query, tags_prefix)

In [None]:
if PERFORM_FEATURE_ENGINEERING:
    # Query the configured dataset instead of a hard-coded 'my_data'.
    df_q = get_df_query_bqs(f'select * from {BQ_DATASET}.questions', dtypes=dtypes)
    display(df_q.head())

## Feature Engineering Tests
* Features come back out of BigQuery with the same values they went in with
* `ts_delta` is equal to difference between timestamps on consecutive records
* `pqet_current` is equal to `prior_question_elapsed_time` from next record
* visually inspect distributions of `ts_delta` and `pqet_current`

### Load Sample from train.csv

In [None]:
if TEST_FEATURE_ENGNEERING:
    df_train_samp = pd.read_csv('train.csv', nrows=100000)
    # Apply the same NULL fills as the BigQuery UPDATEs so values compare equal.
    df_train_samp = df_train_samp.assign(
        prior_question_had_explanation=df_train_samp['prior_question_had_explanation'].fillna(False).astype(bool),
        prior_question_elapsed_time=df_train_samp['prior_question_elapsed_time'].fillna(0),
    )
    # Drop the final user: nrows presumably cut that user's history short.
    user_ids_samp = df_train_samp['user_id'].unique()[:-1]
    print(len(user_ids_samp))
    keep_rows = df_train_samp['user_id'].isin(user_ids_samp) & (df_train_samp['content_type_id'] == 0)
    df_train_samp = df_train_samp.loc[keep_rows].reset_index(drop=True)
    print(len(df_train_samp))

### Pull sample of corresponding user_ids from BigQuery

In [None]:
if TEST_FEATURE_ENGNEERING:
    sampled_user_list = ','.join(map(str, user_ids_samp))
    bq_sample_sql = f"""
        SELECT *
        FROM {BQ_DATASET}.train
        WHERE user_id IN ({sampled_user_list})
        AND content_type_id = 0
        ORDER BY user_id, timestamp, row_id
    """
    df_bq_samp = get_df_query_bqs(bq_sample_sql, dtypes=None)

### Tests

In [None]:
if TEST_FEATURE_ENGNEERING:
    # 1. Every CSV column survives the round trip through BigQuery unchanged.
    for col in df_train_samp.columns:
        assert (df_train_samp[col] == df_bq_samp[col]).all(), f'{col} is not the same'

    # 2. One row per (user, container): pqet_current of container k-1 must
    #    equal prior_question_elapsed_time of container k.
    df_bq_samp_tst = (df_bq_samp[['user_id', 'task_container_id_q', 'prior_question_elapsed_time', 'pqet_current']]
                      .groupby(['user_id', 'task_container_id_q'])
                      .max())
    for user_id in user_ids_samp:
        per_user = df_bq_samp_tst.loc[user_id]
        assert all(per_user.pqet_current.shift(1).iloc[1:] == per_user.prior_question_elapsed_time.iloc[1:])

    # 3. ts_delta of container k equals timestamp(k) - timestamp(k-1).
    df_bq_samp_tst = (df_bq_samp[['user_id', 'task_container_id_q', 'timestamp', 'ts_delta']]
                      .groupby(['user_id', 'task_container_id_q'])
                      .max())
    for user_id in user_ids_samp:
        per_user = df_bq_samp_tst.loc[user_id]
        step = per_user.timestamp - per_user.timestamp.shift(1)
        assert all(step.iloc[1:] == per_user.ts_delta.iloc[1:])

In [None]:
if TEST_FEATURE_ENGNEERING:
    # Eyeball the distribution of the derived elapsed-time feature.
    df_bq_samp['pqet_current'].hist();

In [None]:
if TEST_FEATURE_ENGNEERING:
    # Eyeball the distribution of the derived lag-time feature.
    df_bq_samp['ts_delta'].hist();

## Create TFRecords

We are going to create a set of TFRecords with one user per record and one fold per file. We are going to include the following columns as features:
* `user_id` - this won't get used as a feature, but is included to be able to tie back to the original data
* `content_id` - incremented by one to reserve 0 for padding character
* `answered_correctly` - incremented by one to reserve 0 for padding character
* `part`
* `pqet_current`
* `ts_delta`
* `tags` - already incremented by one with zeros as padding
* `task_container_id` - excluding lectures and already indexed to one
* `timestamp`

In [None]:
def _int64_feature(value):
    """Wrap an int, or a sequence of ints, as a ``tf.train.Feature`` int64_list.

    Scalars become a one-element list. Lists, tuples and non-scalar numpy
    arrays are used directly as the value list. The old exact-``list`` type
    check wrapped a tuple or an array as a single element.
    """
    is_sequence = isinstance(value, (list, tuple)) or (
        isinstance(value, np.ndarray) and value.ndim > 0)
    if not is_sequence:
        value = [value]

    return tf.train.Feature(int64_list=tf.train.Int64List(value=value))

In [None]:
def serialize_example(user_id, features):
    """Serialize one user's question sequences into ``tf.train.Example`` bytes.

    Args:
        user_id: stored as the 'user_id' int64 feature.
        features: indexable collection of per-question value lists, in the
            order of ``feature_names`` below (the tuple built in
            write_tfrecords).

    Returns:
        The serialized Example.
    """
    feature_names = ['content_id', 'answered_correctly', 'part', 'pqet_current', 'ts_delta', 'tags',
                     'task_container_id', 'timestamp']

    feature = {'user_id': _int64_feature(user_id)}
    feature.update({name: _int64_feature(features[idx]) for idx, name in enumerate(feature_names)})

    example = tf.train.Example(features=tf.train.Features(feature=feature))
    return example.SerializeToString()

In [None]:
def parse_example(example):
    """Parse a serialized Example written by ``serialize_example``.

    Returns a dict of tensors: 'user_id' with shape [1] (int64), and every
    other feature as a dense 1-D tensor cast to the dtype in ``cast_to``,
    except 'tags', which is reshaped to (number of questions, 6).
    """
    cast_to = {'content_id': tf.int32, 'answered_correctly': tf.int32, 'part': tf.int32,
               'pqet_current': tf.int32, 'ts_delta': tf.int64, 'tags': tf.int32,
               'task_container_id': tf.int32, 'timestamp': tf.int64}

    spec = {'user_id': tf.io.FixedLenFeature([1], tf.int64)}
    spec.update({name: tf.io.VarLenFeature(tf.int64) for name in cast_to})

    parsed = tf.io.parse_single_example(example, spec)
    for name, dtype in cast_to.items():
        parsed[name] = tf.cast(parsed[name].values, dtype)

    n_questions = tf.size(parsed['answered_correctly'])
    parsed['tags'] = tf.reshape(parsed['tags'], (n_questions, 6))
    return parsed

In [None]:
def get_ds_tfrec_raw(folds=[0]):
    """Dataset of parsed examples from the TFRecord files of ``folds``.

    Files are matched as gs://BUCKET/tfrec/<fold:02d>-*.tfrec. No record
    shuffle buffer, filtering or repeat is applied.
    """
    patterns = [f'gs://{BUCKET}/tfrec/{fold:02d}-*.tfrec' for fold in folds]
    files = tf.data.Dataset.list_files(patterns).with_options(tf.data.Options())
    records = files.interleave(tf.data.TFRecordDataset, num_parallel_calls=AUTO)
    return records.map(parse_example, num_parallel_calls=AUTO)

In [None]:
def get_df_tfrec(folds):
    """Fetch question rows for users in ``folds``, with columns shaped for TFRecords.

    content_id and answered_correctly are shifted by +1 so 0 can serve as
    padding. tags come from the questions table's tags_array. Rows are
    ordered by user, timestamp, row_id.
    """
    fold_list = ', '.join(map(str, folds))
    query = f"""
        SELECT fold, train.user_id, content_id + 1 content_id,
            answered_correctly + 1 answered_correctly, part, pqet_current, ts_delta,
            tags_array tags, task_container_id_q task_container_id, timestamp
        FROM {BQ_DATASET}.train
        JOIN {BQ_DATASET}.folds
        ON train.user_id = folds.user_id
        JOIN {BQ_DATASET}.questions
        ON train.content_id = questions.question_id
        WHERE fold IN ({fold_list})
        AND content_type_id = 0
        ORDER BY user_id, timestamp, row_id
    """
    return get_df_query_bqs(query, dtypes=None)

In [None]:
def write_tfrecords(folds):
    """Write one TFRecord file per fold, with one Example per user.

    Each file is gs://BUCKET/tfrec/<fold:02d>-<number of users>.tfrec.
    """
    df_tfrec = get_df_tfrec(folds)

    def user_sequences(rows):
        # Tuple order must match feature_names in serialize_example.
        return (list(rows['content_id'].values),
                list(rows['answered_correctly'].values),
                list(rows['part'].values),
                list(rows['pqet_current'].values.astype(np.int64)),
                list(rows['ts_delta'].values.astype(np.int64)),
                list(np.concatenate(rows['tags'].values)),
                list(rows['task_container_id'].values.astype(np.int64)),
                list(rows['timestamp'].values.astype(np.int64)),
                )

    for f in folds:
        fold_rows = df_tfrec[df_tfrec.fold == f]
        groups_dict = fold_rows.groupby('user_id').apply(user_sequences).to_dict()

        record_file = f'gs://{BUCKET}/tfrec/{f:02d}-{len(groups_dict)}.tfrec'
        with tf.io.TFRecordWriter(record_file) as writer:
            for user_id, features in tqdm(groups_dict.items(), desc=f'Fold {f:02d}'):
                writer.write(serialize_example(user_id, features))

## Write TFRecords

* Process in chunks to avoid running out of memory.

In [None]:
if CREATE_TFRECORDS:
    # 20 folds in 10 chunks of 2 keeps each BigQuery pull small enough for memory.
    fold_splits = np.array_split(np.arange(20), 10)
    for fold_chunk in tqdm(fold_splits):
        write_tfrecords(fold_chunk)

## Test TFRecords

* Same number of users and records as in `df_folds`
* Values in tfrecords are the same as in original data

In [None]:
def test_tfrecord_folds(folds_test, n_sample=100):
    """Spot-check TFRecords against BigQuery for a random sample of users.

    For ``n_sample`` examples drawn through a 10,000-element shuffle buffer,
    every sequence column and the flattened tags are compared with the rows
    returned by ``get_df_tfrec``. Mismatches are printed, naming the user
    and column, rather than raised.

    Returns:
        List of (user_id, column) pairs that did not match. Empty when every
        sampled user agrees. A user missing from BigQuery is reported with
        column 'user_id'.
    """
    ds = get_ds_tfrec_raw(folds_test)
    df = get_df_tfrec(folds_test)
    seq_cols = [c for c in df.columns if c not in ['tags', 'fold', 'user_id']]
    failures = []

    with tqdm(total=n_sample) as pbar:
        for b in ds.shuffle(10000).take(n_sample):
            user_id = b['user_id'].numpy()[0]
            # Filter once per user rather than once per column.
            user_df = df[df.user_id == user_id]
            if user_df.empty:
                bad_cols = ['user_id']
            else:
                # array_equal returns False on length mismatch instead of raising.
                bad_cols = [c for c in seq_cols
                            if not np.array_equal(user_df[c].to_numpy(), b[c].numpy())]
                user_tags = np.concatenate(user_df.tags.values)
                if not np.array_equal(user_tags, b['tags'].numpy().flatten()):
                    bad_cols.append('tags')
            for c in bad_cols:
                print(f'Error for user {user_id} in column {c}')
                failures.append((user_id, c))
            pbar.update()

    return failures

In [None]:
if TEST_TFRECORDS:
    folds_test = list(range(20))
    ds = get_ds_tfrec_raw(folds=folds_test)

    df_folds = get_df_query_bqs(f"""
        SELECT *
        FROM {BQ_DATASET}.folds
    """,
    dtypes=dtypes)

    # One pass over every record: collect user ids and sequence lengths.
    user_ids, seq_lengths = [], []
    for b in ds:
        user_ids.append(b['user_id'].numpy()[0])
        seq_lengths.append(len(b['content_id'].numpy()))
    count = sum(seq_lengths)

    # Same number of distinct users and same total question count as the folds table.
    assert len(set(user_ids)) == len(df_folds)
    assert df_folds.record_count.sum() == count

    # Value-level spot check on a single fold.
    test_tfrecord_folds([10])

    b = next(iter(ds))
    print(b)

## Dataset Functions

In [None]:
@gin.configurable
def get_ds_tfrec(folds=None, max_len=None, min_len=None):
    """Build an endless, shuffled training input generator.

    Pipeline: list the fold files (shuffled) -> read records -> shuffle
    (buffer 10,000) -> parse -> drop users with fewer than ``min_len``
    questions -> tuple -> random window of at most ``max_len`` -> discretize
    elapsed/lag time -> repeat -> prefetch.

    All three arguments are expected to be supplied, e.g. via gin. The None
    defaults are not usable as-is.

    Returns:
        ``gen(generator=None)``, a generator function that yields numpy tuples.
    """
    patterns = [f'gs://{BUCKET}/tfrec/{fold:02d}-*.tfrec' for fold in folds]

    ds = tf.data.Dataset.list_files(patterns, shuffle=True).with_options(tf.data.Options())
    ds = ds.interleave(tf.data.TFRecordDataset, num_parallel_calls=AUTO).shuffle(10000)
    ds = ds.map(parse_example, num_parallel_calls=AUTO)
    ds = ds.filter(partial(filter_min_len, min_len=min_len))
    for step in (example_to_tuple, partial(trunc_seq, max_len=max_len), con_to_cat):
        ds = ds.map(step, num_parallel_calls=AUTO)
    ds = ds.repeat().prefetch(AUTO)

    def gen(generator=None):
        del generator  # accepted and ignored (presumably the signature trax expects)
        yield from fastmath.dataset_as_numpy(ds)

    return gen

In [None]:
def filter_min_len(e, min_len):
    """``Dataset.filter`` predicate: keep examples with at least ``min_len`` questions."""
    n_questions = tf.size(e['content_id'])
    return n_questions >= min_len

In [None]:
def example_to_tuple(example):
    """Convert a parsed example dict into the positional tuple used downstream.

    Order: content_id, part, tags, task_container_id, answered_correctly,
    pqet_current, ts_delta. ``con_to_cat`` relies on the last two positions
    being the elapsed-time and lag-time features.
    """
    keys = ('content_id', 'part', 'tags', 'task_container_id',
            'answered_correctly', 'pqet_current', 'ts_delta')
    return tuple(example[key] for key in keys)

In [None]:
def trunc_seq(*b, max_len=None):
    """Cut every tensor in ``b`` to the same random window of at most ``max_len``.

    Assumes all tensors share their first-axis length (the number of
    questions, taken from ``b[0]``). Sequences no longer than ``max_len`` are
    returned whole. For longer ones the window end is drawn uniformly from
    [max_len, seq_len] inclusive, so every window position is possible,
    including the one that ends at the user's most recent question.
    """
    max_len = tf.constant(max_len)
    seq_len = tf.size(b[0])
    # The integer maxval of tf.random.uniform is exclusive, so maxval must be
    # seq_len + 1 for seq_len itself to be drawn. With maxval = seq_len the
    # window could never include the last question. minval < maxval always
    # holds because min(seq_len, max_len) <= seq_len.
    seq_end_min = tf.minimum(seq_len, max_len)
    seq_end = tf.maximum(max_len, tf.random.uniform((), seq_end_min, seq_len + 1, dtype=tf.int32))

    def get_seq(m):
        return m[seq_end - max_len:seq_end]

    return tuple(map(get_seq, b))

In [None]:
# SAINT+ Elapsed Time = prior_question_elapsed_time and Lag Time = time_stamp_1 - timestamp_0
# Elapsed Time categorical - capped at 300 seconds, discrete value for each second
# Lag Time - discretized to minutes 0, 1, 2, 3, 4, 5, 10, 20, 30 ... 1440. 150 discrete values.

# Lookup from whole minutes of lag (index 0..1440) to a category value:
#   0-5 min -> 0..5, 6-9 min -> 5, 10-19 -> 10, ..., 1430-1439 -> 1430, 1440 -> 1440.
# Using five 5s (indices 6..10) would shift every later bucket by one minute
# and leave the 1440 category at index 1441, which the clipped input never reaches.
ts_delta_lookup = tf.constant(np.concatenate([
    np.arange(6),                            # 0..5 minutes: one category each
    np.full(4, 5),                           # 6..9 minutes
    np.repeat(np.arange(10, 1440, 10), 10),  # 10..1439 minutes: 10-minute buckets
    [1440],                                  # 1440 minutes, the clip ceiling
]), dtype=tf.int32)


def con_to_cat(*b):
    """Discretize the last two features of ``b`` (elapsed time and lag time).

    ``b[-2]`` (pqet_current, milliseconds) is clipped to [0, 300000] and
    scaled linearly onto 0..299. ``b[-1]`` (ts_delta, milliseconds) is
    clipped to [0, 1440 minutes], floored to whole minutes and mapped
    through ``ts_delta_lookup``. All other elements pass through unchanged.
    """

    def pqet_cat(e, vocab_size=None, val_min=None, val_max=None):
        e = tf.clip_by_value(e, val_min, val_max)
        val_range = val_max - val_min
        e = tf.cast((e - val_min) * (vocab_size - 1) / val_range, tf.int32)
        return e

    def ts_delta_cat(e):
        val_max = tf.cast(tf.reduce_max(ts_delta_lookup) * 60000, tf.float64)
        e = tf.clip_by_value(tf.cast(e, tf.float64), 0, val_max)
        e = tf.cast(e / 60000, tf.int32)  # whole minutes, 0..1440
        e = tf.gather(ts_delta_lookup, e)
        return e

    pqet = pqet_cat(b[-2], vocab_size=300, val_min=0, val_max=300000)
    ts_delta = ts_delta_cat(b[-1])

    return tuple((*b[:-2], pqet, ts_delta))

## Metrics Functions

In [None]:
def RocAucScore(num_thresholds=100, pos_label=2):
    """Returns a trax metric layer computing a weighted ROC AUC.

    The layer takes (logits, targets, weights). Logits are turned into
    probabilities with a softmax over the last axis. The probability of class
    ``pos_label`` is used as the score, and a target counts as positive when it
    equals ``pos_label``. Positions with zero weight (e.g. padding) contribute
    nothing.

    The ROC curve is sampled at ``num_thresholds`` evenly spaced thresholds
    from 1 down to 0 and integrated with the trapezoidal rule.

    Args:
      num_thresholds: number of thresholds used to sample the ROC curve.
      pos_label: class id treated as positive. It also selects the softmax
        column used as the score, so the two always agree.
    """
    def f(y_score, y_true, weight):
        weight = tnp.expand_dims(tnp.ravel(weight), -1)  # (N, 1)

        softmax = tl.Softmax(axis=-1)
        # Score with the probability of the positive class itself, rather than
        # a hard-coded last column.
        y_score = tnp.ravel(softmax(y_score)[..., pos_label])
        y_score = tnp.expand_dims(y_score, -1)  # (N, 1)
        y_true = tnp.expand_dims(tnp.ravel(y_true) == pos_label, -1).astype(tnp.float32)

        thresholds = tnp.expand_dims(tnp.linspace(1, 0, num_thresholds), 0)  # (1, T)

        # (N, T): whether each example is predicted positive at each threshold.
        threshold_counts = y_score > thresholds

        tps = tnp.logical_and(threshold_counts, y_true)
        fps = tnp.logical_and(threshold_counts, tnp.logical_not(y_true))

        tps = tnp.sum(tps * weight, axis=0)
        fps = tnp.sum(fps * weight, axis=0)

        # The last threshold is 0 and softmax scores are positive, so the last
        # entries hold the total positive / negative weight.
        tpr = tps / tps[-1]
        fpr = fps / fps[-1]

        # Trapezoidal rule written out explicitly: `trapz` is deprecated or
        # removed in newer NumPy-compatible backends.
        return tnp.sum((fpr[1:] - fpr[:-1]) * (tpr[1:] + tpr[:-1]) / 2)

    return tl.Fn('RocAucScore', f)

In [None]:
# Metrics reported by the training loop. Each layer is fed
# (model output, targets, weights), matching RocAucScore's f(y_score, y_true, weight).
metrics = dict(
    loss=tl.WeightedCategoryCrossEntropy(),
    accuracy=tl.WeightedCategoryAccuracy(),
    sequence_accuracy=tl.MaskedSequenceAccuracy(),
    auc_all=RocAucScore(),
    # Drops model output and targets and sums the weights, i.e. counts the
    # non-masked tokens in a batch.
    weights_per_batch_per_core=tl.Serial(tl.Drop(), tl.Drop(), tl.Sum()),
)

## Model Functions

In [None]:
@gin.configurable
@tl.assert_shape('bl->b1ll')
def PaddingFutureMask(pad=0, block_self=False, tid=True, pad_end=False):
    """Builds a layer producing a combined padding + future attention mask.

    The layer maps integer tokens ``x`` of shape (batch, length) to a boolean
    mask of shape (batch, 1, length, length). Entry [b, 0, i, j] is True when
    query i may attend to key j.

    Args:
      pad: token id that marks padding. Keys equal to it are masked out.
      block_self: only used when ``tid`` is True. If True, query i may attend
        only to keys whose id is strictly smaller than its own.
      tid: if True, ``x`` holds task-container ids and query i may attend to
        key j when x[i] >= x[j] + block_self, so questions in the same
        container can see each other. If False, a positional causal mask
        (j <= i) is used and the values of ``x`` matter only for padding.
      pad_end: only used when ``tid`` is True. If True, padded positions are
        given the largest id in the batch before the comparison, so padded
        query rows are not left with every key masked.
    """
    def f(x):
        is_pad = tnp.equal(x, pad)
        # (batch, 1, 1, length): True for keys that are real tokens.
        mask_pad = tnp.logical_not(is_pad)[:, tnp.newaxis, tnp.newaxis, :]

        x_new = x
        if pad_end:
            # tnp.max is taken over the whole batch, not per row.
            x_new = tnp.where(is_pad, tnp.max(x), x)

        if tid:
            mask_future = x_new[:, :, tnp.newaxis] >= x_new[:, tnp.newaxis, :] + block_self
            mask_future = mask_future[:, tnp.newaxis, :, :]
        else:
            positions = tnp.arange(x.shape[-1])
            # (1, 1, length, length) lower-triangular mask: query i sees keys j <= i.
            mask_future = (positions[tnp.newaxis, tnp.newaxis, :, tnp.newaxis]
                           >= positions[tnp.newaxis, :])

        return tnp.logical_and(mask_future, mask_pad)

    return tl.Fn(f'PaddingFutureMask({pad})', f)


# The only difference from trax's default attention here is the shape assertion,
# which accommodates the change in mask shape from b11l to b1ll.

@tl.assert_shape('bld,b1ll->bld,b1ll')
@gin.configurable
def KTAttention(d_feature, n_heads=1, dropout=0.0, mode='train'):
    """Self-attention over (activations, mask) with a per-query mask.

    Takes activations of shape (batch, length, d) and a mask of shape
    (batch, 1, length, length); returns the attended activations and the mask.
    """
    # Fan the activations out into queries, keys and values; the mask stays
    # on the stack below them for AttentionQKV to consume.
    to_qkv = tl.Select([0, 0, 0])
    attend = tl.AttentionQKV(
        d_feature, n_heads=n_heads, dropout=dropout, mode=mode)
    return tl.Serial(to_qkv, attend)

def my_add_loss_weights(generator, id_to_mask=None):
    """Appends a float32 loss-weight array to every example from ``generator``.

    The weight is 1.0 wherever the example's first element differs from
    ``id_to_mask`` and 0.0 where it equals it. With the configured
    id_to_mask = 0 this presumably masks padding positions.
    """
    for batch in generator:
        keep = batch[0] != id_to_mask
        yield (*batch, keep.astype(tnp.float32))

@gin.configurable
def KTAddLossWeights(id_to_mask=0):  # pylint: disable=invalid-name
    """Returns a stream transform that appends loss weights via my_add_loss_weights."""
    def add_weights(stream):
        return my_add_loss_weights(stream, id_to_mask=id_to_mask)
    return add_weights

def trim_tags(generator):
    """Keeps only the first 6 tags per question in every batch from ``generator``.

    Each batch is read as (content_id, part, tags, tid, answered_correctly,
    pqet, ts_delta). Only ``tags`` is changed: it is sliced to ``[:, :, :6]``.
    Elements beyond the seventh are dropped.
    """
    max_tags = 6
    for batch in generator:
        content_id, part, tags, tid, answered, pqet, ts_delta = (batch[k] for k in range(7))
        yield content_id, part, tags[:, :, :max_tags], tid, answered, pqet, ts_delta

@gin.configurable
def TrimTags():
    """Returns a stream transform that applies trim_tags to each batch."""
    def trim(stream):
        return trim_tags(stream)
    return trim

@gin.configurable
def KTPositionalEncoder(max_position=10000.0, d_model=512, tid=False):
    """Sinusoidal positional-encoding layer.

    With ``tid`` False, the position of each token is its index in the batch
    sequence, and the output has shape (1, length, d_model), broadcasting over
    the batch. With ``tid`` True, the integer input values themselves
    (task_container_id) are used as positions, and the output has shape
    (batch, length, d_model).

    Sines fill the first ceil(d_model / 2) channels and cosines the remaining
    ones (concatenated, not interleaved).
    """
    def f(inputs):
        if tid:
            # Task-container ids as positions: (batch, length, 1).
            pos = tnp.expand_dims(inputs.astype(tnp.float32), -1)
        else:
            # Index within the sequence: (1, length, 1).
            seq_positions = tnp.arange(inputs.shape[1]).astype(tnp.float32)
            pos = seq_positions[tnp.newaxis, :, tnp.newaxis]

        channel = tnp.expand_dims(tnp.arange(d_model, dtype=tnp.float32), 0)
        exponent = (2 * (channel // 2)) / tnp.array(d_model, dtype=tnp.float32)
        inv_freq = 1 / tnp.power(max_position, exponent)
        phase = pos * inv_freq

        even_part = tnp.sin(phase[:, :, 0::2])
        odd_part = tnp.cos(phase[:, :, 1::2])
        return tnp.concatenate([even_part, odd_part], axis=-1)

    return tl.Fn('KTPositionalEncoder', f)

In [None]:
@gin.configurable
def KTTransformer(d_model,
                  d_input,
                  d_part,
                  d_tags,
                  d_out,
                  d_pqet,
                  d_ts_delta,
                  d_tid,
                  embed_concat=False,
                  d_ff=2048,
                  n_encoder_layers=6,
                  n_decoder_layers=6,
                  n_heads=8,
                  max_len=2048,
                  dropout=0.1,
                  dropout_shared_axes=None,
                  mode='train',
                  ff_activation=tl.Relu):
    """Encoder-decoder Transformer for knowledge tracing (SAINT+-style).

    The encoder embeds question features (content id, part, optional tags) plus
    a positional encoding. The decoder embeds the right-shifted response
    features (answered_correctly, prior-question elapsed time, lag time) plus a
    positional encoding, and attends to the encoder output.

    Inputs, in order: content_id, part, tags, task_container_id,
    answered_correctly, pqet, ts_delta, loss weights.
    Outputs: logits over d_out[0] classes, the answered_correctly targets, and
    the loss weights.

    Args:
      d_model: width of the residual stream.
      d_input, d_part, d_tags, d_out, d_pqet, d_ts_delta: (vocab_size, depth)
        for each embedding. d_tags may be None to ignore the tags input.
      d_tid: positional arguments for KTPositionalEncoder
        (max_position, d_model).
      embed_concat: if True, feature embeddings are concatenated (their depths
        must add up to d_model); otherwise they are summed.
      max_len: unused; kept so existing configurations still bind.
      d_ff, n_encoder_layers, n_decoder_layers, n_heads, dropout,
        dropout_shared_axes, mode, ff_activation: block hyperparameters.
    """

    def Embedder(vocab_size, d_embed):  # tokens --> vectors
        return [
            tl.Embedding(vocab_size, d_embed),
            tl.Dropout(
                rate=dropout, shared_axes=dropout_shared_axes, mode=mode),
        ]

    # Encoder Embeddings
    in_embedder = Embedder(*d_input)
    part_embedder = Embedder(*d_part)
    # Keeps the tags in the data batch tuple, but drops it if it
    # isn't included in the embeddings.
    if d_tags is not None:
        # Embed every tag and sum over the tags axis.
        tags_embedder = tl.Serial(Embedder(*d_tags), tl.Sum(axis=-2))
    else:
        tags_embedder = tl.Drop()
    in_pos_encoder = KTPositionalEncoder(*d_tid)

    # Decoder Embeddings
    out_embedder = Embedder(*d_out)
    pqet_embedder = Embedder(*d_pqet)
    ts_delta_embedder = Embedder(*d_ts_delta)
    out_pos_encoder = KTPositionalEncoder(*d_tid)

    in_encoder = [tl.Parallel(in_embedder, part_embedder, tags_embedder, in_pos_encoder)]
    out_encoder = [tl.Parallel(out_embedder, pqet_embedder, ts_delta_embedder, out_pos_encoder)]

    # Merge the feature embeddings (concatenation or summation); the final Add
    # then adds the positional encoding.
    if embed_concat:
        if d_tags is not None:
            in_encoder += [tl.Concatenate(n_items=3), tl.Add()]
        else:
            in_encoder += [tl.Concatenate(n_items=2), tl.Add()]
        out_encoder += [tl.Concatenate(n_items=3), tl.Add()]
    else:
        if d_tags is not None:
            in_encoder += [tl.Add(), tl.Add(), tl.Add()]
        else:
            in_encoder += [tl.Add(), tl.Add()]
        out_encoder += [tl.Add(), tl.Add(), tl.Add()]

    encoder_blocks = [
        _KTEncoderBlock(d_model, d_ff, n_heads, dropout, dropout_shared_axes,
                      mode, ff_activation)
        for i in range(n_encoder_layers)]

    encoder = tl.Serial(
        in_encoder,
        encoder_blocks,
        tl.LayerNorm()
    )

    encoder_decoder_blocks = [
        _KTEncoderDecoderBlock(d_model, d_ff, n_heads, dropout, dropout_shared_axes,
                             mode, ff_activation)
        for i in range(n_decoder_layers)]

                                        # output tuple - leading number is max index
    return tl.Serial(                   # 7: 0:tok_e 1:tok_p 2:tok_t 3:tok_tid 4:tok_d 5:tok_pq, 6:tok_tsd 7:wts_l
        tl.Select([0, 1, 2, 3, 3, 3,    # 10: 0:tok_e 1:tok_p 2:tok_t 3:tok_tid 4:tok_tid 5: tok_tid
                   4, 5, 6, 4]),        #     6:tok_d 7:tok_pq, 8:tok_tsd 9:tok_d 10:wts_l

        # Encode.
        tl.Parallel(
            tl.Select([0, 1, 2, 3]),
            PaddingFutureMask(tid=True)
        ),                              # 10: tok_e tok_p tok_t tok_tid mask_combined tok_tid tok_d tok_pq tok_tsd tok_d wts_l
        encoder,                        # 7: vec_e mask_combined tok_tid tok_d tok_pq tok_tsd tok_d wts_l
        # Decode.
        tl.Select([3, 4, 5, 2, 2, 0]),  # 7: tok_d tok_pq tok_tsd tok_tid tok_tid vec_e tok_d wts_l
        tl.Parallel(
            tl.ShiftRight(mode=mode),
            tl.ShiftRight(mode=mode),
            tl.ShiftRight(mode=mode),
            tl.ShiftRight(mode=mode),
            # The decoder self-attention mask is built from the UNshifted ids.
            # Shifting first would mark position 0 (the start token) as
            # padding, leaving query 0 with no visible key at all. A fully
            # masked row degenerates to roughly uniform attention over every
            # key, future answers included, which leaks targets.
            # The last real answer (key n after the shift) is masked instead;
            # causality already hides it from every real query.
            PaddingFutureMask(tid=False),
        ),                              # 7: tok_d tok_pq tok_tsd tok_tid mask_combined vec_e tok_d wts_l
        out_encoder,                    # 4: vec_d mask_combined vec_e tok_d wts_l
        encoder_decoder_blocks,         # 4: vec_d mask_combined vec_e tok_d wts_l
        tl.LayerNorm(),                 # 4: vec_d mask_combined vec_e tok_d wts_l

        # Map to output vocab.
        tl.Select([0], n_in=3),         # 3: vec_d tok_d wts_l
        tl.Dense(d_out[0]),             # vec_d .....
    )


def _KTEncoderBlock(d_model, d_ff, n_heads,
                  dropout, dropout_shared_axes, mode, ff_activation):
    """Pre-norm encoder block: masked self-attention, then feed-forward.

    Each sublayer is wrapped in a residual connection. Attention is KTAttention,
    so the block accepts the combined (batch, 1, length, length) padding and
    future mask alongside the activations.
    """
    attention = KTAttention(
        d_model, n_heads=n_heads, dropout=dropout, mode=mode)
    ffn_layers = _KTFeedForwardBlock(
        d_model, d_ff, dropout, dropout_shared_axes, mode, ff_activation)
    attn_dropout = tl.Dropout(
        rate=dropout, shared_axes=dropout_shared_axes, mode=mode)

    attention_sublayer = tl.Residual(tl.LayerNorm(), attention, attn_dropout)
    ffn_sublayer = tl.Residual(ffn_layers)
    return [attention_sublayer, ffn_sublayer]

def _KTEncoderDecoderBlock(d_model, d_ff, n_heads,
                         dropout, dropout_shared_axes, mode, ff_activation):
    """Pre-norm decoder block with three residual sublayers.

    The sublayers are: masked self-attention (KTAttention, so it accepts the
    combined (batch, 1, length, length) mask), encoder-decoder attention, and
    feed-forward. The stack is (vec_d, masks, vec_e) on entry and on exit.
    """
    def make_dropout():
        return tl.Dropout(rate=dropout, shared_axes=dropout_shared_axes, mode=mode)

    self_attention = KTAttention(
        d_model, n_heads=n_heads, dropout=dropout, mode=mode)
    cross_attention = tl.AttentionQKV(
        d_model, n_heads=n_heads, dropout=dropout, mode=mode)
    feed_forward = _KTFeedForwardBlock(
        d_model, d_ff, dropout, dropout_shared_axes, mode, ff_activation)

    self_attention_sublayer = tl.Residual(
        tl.LayerNorm(),                  # vec_d ..... .....
        self_attention,                  # vec_d ..... .....
        make_dropout(),                  # vec_d ..... .....
    )
    # Queries come from the decoder and keys/values from the encoder output;
    # the decoder's mask is reused for this attention.
    cross_attention_sublayer = tl.Residual(
        tl.LayerNorm(),                  # vec_d ..... .....
        tl.Select([0, 2, 2, 1, 2]),      # vec_d vec_e vec_e masks vec_e
        cross_attention,                 # vec_d masks vec_e
        make_dropout(),                  # vec_d masks vec_e
    )
    feed_forward_sublayer = tl.Residual(
        feed_forward                     # vec_d masks vec_e
    )
    return [self_attention_sublayer, cross_attention_sublayer, feed_forward_sublayer]

def _KTFeedForwardBlock(d_model, d_ff, dropout, dropout_shared_axes,
                      mode, activation):
    """Position-wise feed-forward sublayer.

    LayerNorm -> Dense(d_ff) -> activation -> Dropout -> Dense(d_model) -> Dropout.
    """
    def make_dropout():
        return tl.Dropout(
            rate=dropout, shared_axes=dropout_shared_axes, mode=mode)

    expand = [tl.LayerNorm(), tl.Dense(d_ff), activation(), make_dropout()]
    project = [tl.Dense(d_model), make_dropout()]
    return expand + project

## Configuration

In [None]:
# Configure hyperparameters.
#
# Notes on the gin string below:
# * The two "(Vocab, depth) Uncomment to use with aggregation by concatenation"
#   groups are already active. The commented-out "%d_model" group after them
#   is the alternative for embed_concat = False (summed embeddings).
# * With embed_concat = True the concatenated depths must add up to d_model:
#   encoder 384 + 8 + 120 = 512, decoder 384 + 64 + 64 = 512.
# * The string is an f-string: {total_steps} and the checkpoints_at list are
#   filled in from Python before gin parses it.
# * NOTE(review): KTTransformer passes tid= explicitly to both of its
#   PaddingFutureMask(...) calls. Gin gives explicit arguments precedence over
#   bindings, so "PaddingFutureMask.tid = False" below appears to have no
#   effect on the model. Confirm intent.

total_steps = 10000

gin.clear_config()
gin.parse_config(f"""
import trax.layers
import trax.models
import trax.optimizers
import trax.data.inputs
import trax.supervised.trainer_lib

# Parameters that will vary between experiments:
# ==============================================================================
# min_len = 12
# max_len = 64
# d_model = 512 # need to make sure this works with concat embeddings
# d_ff = 256
# n_encoder_layers = 2
# n_decoder_layers = 2
# n_heads = 2
# dropout = 0.0

min_len = 12
max_len = 256
d_model = 512 # need to make sure this works with concat embeddings
d_ff = 1024
n_encoder_layers = 6
n_decoder_layers = 6
n_heads = 8
dropout = 0.1

# Set to True to aggregate embeddings by concatenation. If set
# to False aggregation will be by sum.
embed_concat = True

# (Vocab, depth) Uncomment to use with aggregation by concatenation.
d_input = (13500, 384)
d_part = (8, 8)
d_tags = (189, 120)

# (Vocab, depth) Uncomment to use with aggregation by concatenation.
d_out = (3, 384)
d_pqet = (300, 64)
d_ts_delta = (150, 64)

# Used for positional encodings if not None. Positional encoding based
# on sequence in batch if None.
d_tid = (10000, %d_model)

# d_input = (13500, %d_model)
# d_part = (8, %d_model)
# d_tags = (189, %d_model)
# # d_tags = None
# d_out = (3, %d_model)
# d_pqet = (300, %d_model)
# d_ts_delta = (150, %d_model)
# d_tid = (10000, %d_model)

total_steps = {total_steps}

# Parameters for learning rate schedule:
# ==============================================================================
warmup_and_rsqrt_decay.n_warmup_steps = 3000
warmup_and_rsqrt_decay.max_value = 0.001

# multifactor.constant = 0.01
# multifactor.factors = 'constant * linear_warmup * cosine_decay'
# multifactor.warmup_steps = 4000
# multifactor.steps_per_cycle = %total_steps
# multifactor.minimum = .0001

# Parameters for Adam:
# ==============================================================================
# Adam.weight_decay_rate=0.0
Adam.b1 = 0.9
Adam.b2 = 0.999
Adam.eps = 1e-8

# Parameters for input pipeline:
# ==============================================================================
get_ds_tfrec.min_len = %min_len
get_ds_tfrec.max_len = %max_len
train/get_ds_tfrec.folds = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18]
eval/get_ds_tfrec.folds = [19]

BucketByLength.boundaries =  [32, 64, 128]
BucketByLength.batch_sizes = [512, 256, 128, 64]
# BucketByLength.batch_sizes = [16, 8, 4,  2]

BucketByLength.strict_pad_on_len = True

KTAddLossWeights.id_to_mask = 0

train/make_additional_stream.stream = [
  @train/get_ds_tfrec(),
  @BucketByLength(),
  @TrimTags(),
  @KTAddLossWeights()
]

eval/make_additional_stream.stream = [
  @eval/get_ds_tfrec(),
  @BucketByLength(),
  @TrimTags(),
  @KTAddLossWeights()
]

make_inputs.train_stream = @train/make_additional_stream()
make_inputs.eval_stream = @eval/make_additional_stream()

# Parameters for KTPositionalEncoder:
# ==============================================================================
KTPositionalEncoder.d_model = %d_model

# Set to True to calculate positional encodings based on position in orginal
# full length sequence, False to be based on position in batch sequence.
KTPositionalEncoder.tid = False

# Parameters for PaddingFutureMaske:
# ==============================================================================
PaddingFutureMask.pad_end = False

# Set to True to calculate future mask based on task container id (questions
# are delivered to users in groups identified by task_container id) or False
# to be based next question only.
PaddingFutureMask.tid = False

# Parameters for KTTransformer:
# ==============================================================================
KTTransformer.d_model = %d_model
KTTransformer.d_input = %d_input
KTTransformer.d_part = %d_part
KTTransformer.d_tags = %d_tags
KTTransformer.d_out = %d_out
KTTransformer.d_pqet = %d_pqet
KTTransformer.d_ts_delta = %d_ts_delta
KTTransformer.d_tid = %d_tid
KTTransformer.embed_concat = %embed_concat
KTTransformer.d_ff = %d_ff
KTTransformer.n_encoder_layers = %n_encoder_layers
KTTransformer.n_decoder_layers = %n_decoder_layers
KTTransformer.n_heads = %n_heads
KTTransformer.dropout = %dropout

# Parameters for train:
# ==============================================================================
train.inputs = @make_inputs
train.eval_frequency = 200
train.eval_steps = 20
train.checkpoints_at = {list(range(0,total_steps + 1, 2000))}
train.optimizer = @trax.optimizers.Adam
train.steps = %total_steps
train.model = @KTTransformer
train.lr_schedule_fn = @trax.supervised.lr_schedules.warmup_and_rsqrt_decay
""")

In [None]:
# Debugging aid (disabled): pull one training batch and print the shape of
# each component. Change the condition to True to run it.
if False:
    inputs = trax.data.inputs.make_inputs()
    n_devices = trax.fastmath.device_count()
    train_stream = inputs.train_stream(n_devices)
    train_eval_stream = inputs.train_eval_stream(n_devices)
    b = next(train_stream)
    for idx, component in enumerate(b):
        print(idx, component.shape)
    b

In [None]:
# Debugging aid (disabled): build the model, initialise it on the batch `b`
# produced by the previous debugging cell, and print each output's shape.
# That cell must have been enabled and run first.
if False:
    model = KTTransformer()
    model.init(trax.shapes.signature(b))
    outs = model(b)
    for idx, out in enumerate(outs):
        print(idx, out.shape)
    outs

## Training

In [None]:
# Each run writes to its own sub-directory of model_runs/. TensorBoard is
# pointed at the parent so all runs can be compared side by side.
run_no = 0
prefix = f'model_runs/{run_no:02d}'
output_dir = f'gs://{BUCKET}/{prefix}'
# Parent directory of output_dir. Splitting on the last '/' (instead of
# dropping a fixed 3 characters) works for run numbers of any width.
log_dir = output_dir.rsplit('/', 1)[0]

In [None]:
# Launch TensorBoard on log_dir (IPython expands $log_dir from the Python variable).
%tensorboard --logdir $log_dir

In [None]:
# Set to True to keep this run's existing outputs so training can continue from
# the checkpoint in output_dir. NOTE(review): this relies on trax's train()
# restoring from output_dir; confirm for the installed trax version.
resume_training = False

if TRAIN_MODEL:
    if not resume_training:
        # Start fresh: delete this run's previous outputs. The trailing '/'
        # restricts the match to this run's directory. A bare prefix such as
        # 'model_runs/10' would also match 'model_runs/100' and delete other runs.
        bucket.delete_blobs(list(bucket.list_blobs(prefix=f'{prefix}/')))

    loop = trax.supervised.trainer_lib.train(output_dir, metrics=metrics)