In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import sys

# Put the parent directory on the import path.
# NOTE(review): presumably this is what lets the later
# `models.tensorflow.gdcn` import resolve — confirm against the repo layout.
sys.path.append("..")

# Prepare data


In [4]:
import polars as pl
import numpy as np
from tensorflow.python.data import Dataset, AUTOTUNE


def train_test_split(df, train_frac, seed=42):
    """Shuffle the rows of ``df`` and cut them into a train part and a test part.

    Args:
        df: polars DataFrame to split.
        train_frac: Fraction of rows assigned to the train part. Values outside
            [0, 1] are clamped, so the train part holds between zero and all rows.
        seed: Seed passed to the shuffle; the same seed reproduces the same split.

    Returns:
        Tuple ``(df_train, df_test)``: the first ``round(height * train_frac)``
        shuffled rows and the remaining shuffled rows.
    """
    # Every column is shuffled with the same seed.
    # NOTE(review): this relies on polars applying the same permutation to
    # equally long columns for a given seed, which keeps rows aligned —
    # confirm when upgrading polars.
    shuffled = df.with_columns(pl.all().shuffle(seed))

    # Size the cut from the row count itself. Scaling the largest row index
    # (height - 1) instead would send one row to the test part even for
    # train_frac=1.0.
    n_rows = shuffled.height
    n_train = min(max(round(n_rows * train_frac), 0), n_rows)

    df_train = shuffled.head(n_train)
    df_test = shuffled.slice(n_train)

    return df_train, df_test


def to_dataset(df, batch_size, shuffle=True, buffer_size=10_000):
    """Wrap a frame with a ``label`` column as a batched dataset of (inputs, labels).

    Args:
        df: Frame whose ``label`` column is the target; every other column is an input.
        batch_size: Number of examples per batch.
        shuffle: When true, a shuffle stage is inserted before batching.
        buffer_size: Buffer size handed to the shuffle stage.

    Returns:
        The cached, optionally shuffled, batched and prefetched dataset.
    """
    features = df.select(pl.all().exclude("label"))
    targets = df.select(pl.col("label"))

    pipeline = Dataset.from_tensor_slices((features, targets))
    # The cache sits ahead of the shuffle stage, so it holds the unshuffled slices.
    pipeline = pipeline.cache()
    if shuffle:
        pipeline = pipeline.shuffle(buffer_size)

    return pipeline.batch(batch_size).prefetch(AUTOTUNE)


def prepare_data(filename, batch_size):
    """Load the categorical click data and build train/validation/test datasets.

    Args:
        filename: Path of a tab-separated file with ``cat<N>`` columns and a
            ``click`` column; ``click`` is renamed to ``label``.
        batch_size: Batch size used for all three datasets.

    Returns:
        Tuple ``(ds_train, ds_val, ds_test, num_embeddings)``. Only the training
        dataset is shuffled. ``num_embeddings`` is the number of distinct values
        found across all ``cat<N>`` columns, plus one.
    """
    # Raw string: "\d" inside a plain literal is an invalid escape sequence that
    # Python warns about. The pattern handed to polars is unchanged.
    df = (
        pl.scan_csv(filename, separator="\t")
        .select(pl.col(r"^cat\d+$"), pl.col("click").alias("label"))
        .collect()
    )

    # Distinct values are pooled over every feature column before counting.
    # NOTE(review): this is a count of distinct ids, not max id + 1; it only
    # bounds the ids if they are contiguous from 0 or the model re-indexes
    # them — confirm against the GDCNS embedding layer.
    feature_values = df.select(pl.all().exclude("label")).to_numpy()
    num_embeddings = len(np.unique(feature_values)) + 1

    # First cut: 90% train / 10% test. Second cut: 2% of train becomes validation.
    df_train, df_test = train_test_split(df, train_frac=0.9)
    df_train, df_val = train_test_split(df_train, train_frac=0.98)

    ds_train = to_dataset(df_train, batch_size)
    ds_val = to_dataset(df_val, batch_size, shuffle=False)
    ds_test = to_dataset(df_test, batch_size, shuffle=False)

    return ds_train, ds_val, ds_test, num_embeddings


# Tab-separated input file, located relative to this notebook.
filename = "../data/criteo_attribution_dataset.tsv"
# 16,384 examples per batch.
batch_size = 1 << 14
# Batched train/validation/test datasets plus the embedding table size.
(ds_train, ds_val, ds_test, num_embeddings) = prepare_data(
    filename=filename,
    batch_size=batch_size,
)

# Benchmark models


In [7]:
from timeit import default_timer as timer
import tensorflow as tf


class EpochTimer(tf.keras.callbacks.Callback):
    """Callback that collects per-epoch timing statistics during ``fit``.

    After training, ``epoch_durations`` holds the elapsed time of each epoch and
    ``step_durations`` holds the mean batch time of each epoch, both measured
    with ``timeit.default_timer``.
    """

    def __init__(self):
        super().__init__()
        self._reset()

    def _reset(self):
        """Start over with empty measurement lists."""
        self._batch_times = []
        self.step_durations = []
        self.epoch_durations = []

    def on_train_begin(self, logs=None):
        # An instance reused for another run only reports that latest run.
        self._reset()

    def on_batch_begin(self, batch, logs=None):
        self._batch_started = timer()

    def on_batch_end(self, batch, logs=None):
        elapsed = timer() - self._batch_started
        self._batch_times.append(elapsed)

    def on_epoch_begin(self, epoch, logs=None):
        self._epoch_started = timer()

    def on_epoch_end(self, epoch, logs=None):
        # Take the finished epoch's batch times and start a fresh list for the next one.
        finished, self._batch_times = self._batch_times, []
        self.step_durations.append(np.mean(finished))
        self.epoch_durations.append(timer() - self._epoch_started)


def train(model, ds_train, ds_val, lr, epochs, verbose=1):
    """Compile ``model``, fit it, and return the fit history extended with timings.

    Args:
        model: Model exposing ``compile`` and ``fit``.
        ds_train: Dataset passed to ``fit`` as training data.
        ds_val: Dataset passed to ``fit`` as ``validation_data``.
        lr: Learning rate given to the Adam optimizer.
        epochs: Maximum number of epochs; early stopping may end training sooner.
        verbose: Verbosity forwarded to ``fit``.

    Returns:
        The ``history`` dict of the fit result, with two added keys:
        ``epoch_duration`` and ``step_duration`` (one value per epoch each).
    """
    tf.keras.backend.clear_session()

    model.compile(
        optimizer=tf.keras.optimizers.legacy.Adam(learning_rate=lr),
        loss="binary_crossentropy",
    )

    epoch_timer = EpochTimer()
    # Both schedulers watch the validation loss; the LR reduction has the
    # shorter patience (2 epochs) than early stopping (3 epochs).
    lr_on_plateau = tf.keras.callbacks.ReduceLROnPlateau(
        monitor="val_loss", patience=2, mode="min", verbose=1
    )
    stop_early = tf.keras.callbacks.EarlyStopping(monitor="val_loss", patience=3, verbose=1)

    fit_result = model.fit(
        ds_train,
        validation_data=ds_val,
        epochs=epochs,
        callbacks=[epoch_timer, lr_on_plateau, stop_early],
        verbose=verbose,
    )

    # The timings are written into the fit result's own history dict.
    metrics = fit_result.history
    metrics.update(
        epoch_duration=epoch_timer.epoch_durations,
        step_duration=epoch_timer.step_durations,
    )

    return metrics

In [8]:
from models.tensorflow.gdcn import GDCNS

# The input width is read from the batched inputs spec: axis 0 is the batch,
# axis 1 the feature columns.
gcdn_model = GDCNS(
    num_embedding=num_embeddings,
    dim_embedding=8,
    dim_input=ds_train.element_spec[0].shape[1],
    num_cross=3,
    num_hidden=3,
    dim_hidden=128,
)

# NOTE(review): the logged run reduces the learning rate every two epochs from
# this 0.1 starting point — confirm the initial value is intentional.
hist = train(
    model=gcdn_model,
    ds_train=ds_train,
    ds_val=ds_val,
    lr=0.1,
    epochs=10,
)

Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 3: ReduceLROnPlateau reducing learning rate to 0.010000000149011612.
Epoch 4/10
Epoch 5/10
Epoch 5: ReduceLROnPlateau reducing learning rate to 0.0009999999776482583.
Epoch 6/10
Epoch 7/10
Epoch 7: ReduceLROnPlateau reducing learning rate to 9.999999310821295e-05.
Epoch 8/10
Epoch 9/10
Epoch 9: ReduceLROnPlateau reducing learning rate to 9.999999019782991e-06.
Epoch 9: early stopping


In [9]:
# Run the model over the unshuffled test dataset and keep its predictions.
y_pred = gcdn_model.predict(ds_test)

