In [None]:
# Show the kernel's starting working directory (before moving to the project root).
%pwd

In [None]:
import os

# Change into the project root exactly once. Later cells use paths relative to it
# ("conf/...", "./src", "data/..."). Recording the resolved root in PROJECT_ROOT
# keeps a re-run of this cell from climbing two more directory levels each time.
if "PROJECT_ROOT" not in globals():
    PROJECT_ROOT = os.path.abspath("../..")
os.chdir(PROJECT_ROOT)

In [None]:
# Confirm the new working directory; presumably the project root, since later cells
# use paths such as conf/... and data/... relative to it.
%pwd

In [None]:
import os
import sys

# Make the project's `src/` directory importable. Use an absolute path so later
# working-directory changes don't break imports, and append it only once so
# re-running this cell doesn't keep growing sys.path.
SRC_DIR = os.path.abspath("src")
if SRC_DIR not in sys.path:
    sys.path.append(SRC_DIR)

In [None]:
# Reload imported modules automatically before executing each cell, so edits to the
# project code under src/ take effect without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
from torch.utils import tensorboard

In [None]:
import logging

# Named logger used by the helper functions in this notebook (e.g. measure_speed).
logger = logging.getLogger(name="notebooks.debug")

In [None]:
# Configure logging via the project-local helper. The YAML path is relative, so this
# assumes the working directory is the project root.
import utils.logging
utils.logging.setup("conf/logging/default.yaml")

In [None]:
# Load the application config via the project-local helper; the return value is
# discarded, so presumably the call matters for its side effects — confirm in utils.configs.
import utils.configs
_ = utils.configs.setup("conf/app.yaml")

In [None]:
import dotenv

# Load the dev environment variables. Check the result explicitly instead of using
# `assert`: asserts are stripped under `python -O`, which would skip loading entirely,
# and a bare AssertionError does not say what went wrong.
if not dotenv.load_dotenv(dotenv_path="conf/envs/dev.env"):
    raise RuntimeError("no environment variables were loaded from conf/envs/dev.env")

---

In [None]:
import time
import datetime
import tqdm.auto as tqdm
import numpy as np
import pyarrow
import pyarrow.dataset

In [None]:
import torch
import torch.nn

In [None]:
# Column holding the training target.
LABEL = "label"
# Dense (numeric) columns f1..f13.
DENSE_FEATURES = ["f%d" % i for i in range(1, 14)]
# Sparse (integer id) columns f14_idx..f39_idx.
SPARSE_FEATURES = ["f{}_idx".format(i) for i in range(14, 40)]

In [None]:
# Arrow dataset over the joined data, using hive-style (key=value) partition directories.
src_dataset = pyarrow.dataset.dataset("data/joined/compact", partitioning="hive")

In [None]:
def measure_speed(
    dataset,
    filter = None,
    limit = None,
    read_params = None,
    payload_fn = None
):
    """Benchmark streaming record batches out of an Arrow dataset.

    Each batch is converted to pandas and, if given, passed to ``payload_fn``, so
    the reported rate covers reading + conversion (+ payload) for every counted row.

    Args:
        dataset: Arrow dataset (anything with ``count_rows`` and ``to_batches``).
        filter: optional filter passed to both ``count_rows`` and ``to_batches``.
        limit: stop once at least this many rows were read; ``None`` reads everything.
            Batches are never split, so slightly more than ``limit`` rows may be read.
        read_params: extra keyword arguments forwarded to ``dataset.to_batches``.
        payload_fn: optional callable applied to every pandas batch that is read.

    Returns:
        Measured throughput in rows per second (0.0 if no time elapsed).
    """
    # Avoid a shared mutable default argument.
    if read_params is None:
        read_params = {}

    if limit is None:
        logger.info("getting dataset size...")
        total_records = dataset.count_rows(filter=filter)
        logger.info(f"getting dataset size: done ({total_records} records)")
    else:
        total_records = limit

    logger.info("reading dataset...")
    # perf_counter is monotonic and high-resolution, unlike wall-clock time.time().
    time_start = time.perf_counter()
    rows_processed = 0
    # The context manager closes the progress bar even if reading or the payload raises.
    with tqdm.tqdm(desc="reading data", total=total_records) as pbar:
        src_batches = dataset.to_batches(filter=filter, **read_params)
        for batch_id, batch in enumerate(src_batches, start=1):
            batch = batch.to_pandas()
            pbar.set_postfix({'batches': batch_id}, refresh=False)
            pbar.update(batch.shape[0])
            rows_processed += batch.shape[0]
            # Run the payload before the limit check: the final batch's rows are
            # counted in rows_processed, so they must be processed too.
            if payload_fn is not None:
                payload_fn(batch)
            if limit is not None and rows_processed >= limit:
                break

    elapsed_time = time.perf_counter() - time_start
    read_speed = rows_processed / elapsed_time if elapsed_time > 0 else 0.0
    logger.info(f"reading dataset: done ({int(elapsed_time)} seconds, {int(read_speed)} rows/sec)")
    return read_speed

In [None]:
def get_dense_features(batch, device = "cpu", columns = None):
    """Build the log-scaled dense feature tensor for one pandas batch.

    Computes ``log(x + 3)`` element-wise and replaces NaN results (missing inputs,
    or inputs below -3) with 0.

    Args:
        batch: pandas DataFrame containing the dense feature columns.
        device: torch device for the returned tensor.
        columns: dense column names to read; defaults to ``DENSE_FEATURES``.

    Returns:
        float32 tensor of shape ``(len(batch), len(columns))``.
    """
    if columns is None:
        columns = DENSE_FEATURES
    raw_values = torch.tensor(batch[columns].to_numpy(dtype="float32"), device=device)
    dense_features = torch.log(raw_values + 3)
    # NOTE(review): an input of exactly -3 yields -inf, which is not masked below.
    dense_features.masked_fill_(dense_features.isnan(), 0)
    return dense_features

In [None]:
def get_sparse_features(batch, device = "cpu", columns = None):
    """Build the int32 tensor of sparse (integer id) features for one pandas batch.

    Args:
        batch: pandas DataFrame containing the sparse feature columns.
        device: torch device for the returned tensor.
        columns: sparse column names to read; defaults to ``SPARSE_FEATURES``.

    Returns:
        int32 tensor of shape ``(len(batch), len(columns))``.
    """
    if columns is None:
        columns = SPARSE_FEATURES
    return torch.tensor(batch[columns].to_numpy(dtype="int32"), device=device)

In [None]:
def get_labels(batch, device = "cpu", column = None):
    """Extract the label column of one pandas batch as an int8 tensor.

    Args:
        batch: pandas DataFrame containing the label column.
        device: torch device for the returned tensor.
        column: label column name; defaults to ``LABEL``.

    Returns:
        1-D int8 tensor with one entry per row of ``batch``.
    """
    if column is None:
        column = LABEL
    return torch.tensor(batch[column].to_numpy(dtype="int8"), device=device)

In [None]:
def convert_batch(batch, device = "cpu"):
    """Turn one pandas batch into a ``(dense_features, sparse_features, labels)`` tuple of tensors on ``device``."""
    return (
        get_dense_features(batch, device),
        get_sparse_features(batch, device),
        get_labels(batch, device),
    )

In [None]:
# measure_speed(src_dataset) ## warmup the disk cache

In [None]:
# measure_speed(src_dataset, payload_fn = lambda b: convert_batch(b, device="cpu"))

In [None]:
# measure_speed(src_dataset, payload_fn = lambda b: convert_batch(b, device="mps"))

In [None]:
def get_feature_cardinality(dataset, device = "mps"):
    """Scan ``dataset`` and return the largest id seen in each sparse feature column.

    Note that the result is the per-column *maximum id*, not a count. Assuming ids start
    at 0, an embedding table for a column with max id k needs at least k + 1 rows.

    Args:
        dataset: Arrow dataset (anything with ``count_rows`` and ``to_batches``).
        device: torch device used for the per-batch reductions and the result.

    Returns:
        int32 tensor with one entry per column of ``SPARSE_FEATURES`` (same order).

    Raises:
        ValueError: if the dataset yields no non-empty batches.
    """
    max_ids = None
    # `desc` must be a keyword: tqdm's first positional parameter is the iterable.
    with tqdm.tqdm(desc="reading dataset", total=dataset.count_rows()) as pbar:
        for batch in dataset.to_batches():
            batch = batch.to_pandas()
            pbar.update(batch.shape[0])
            if batch.shape[0] == 0:
                # max(dim=0) is undefined on an empty tensor.
                continue
            batch_max = get_sparse_features(batch, device=device).max(dim=0).values
            # Running element-wise maximum instead of keeping every batch's result.
            max_ids = batch_max if max_ids is None else torch.maximum(max_ids, batch_max)
    if max_ids is None:
        raise ValueError("dataset contains no rows; cannot compute feature cardinality")
    return max_ids

In [None]:
sparse_feature_sizes = get_feature_cardinality(src_dataset)
sparse_feature_sizes

In [None]:
sparse_feature_sizes = dict(zip(SPARSE_FEATURES, list(sparse_feature_sizes.data.cpu().numpy())))
sparse_feature_sizes

### Playing around with torch components

In [None]:
EMBEDDING_DIM = 64
# Layer sizes of the dense-feature MLP. The last size is tied to EMBEDDING_DIM, presumably
# so its output matches the embedding width inside models.dlrm — confirm there.
DENSE_LAYERS = [512, 256, EMBEDDING_DIM]
# Layer sizes of the final MLP; it ends in a single output (one logit per row).
FINAL_LAYERS = [512, 512, 256, 1]
# Prefer Apple's MPS backend, but fall back to CPU on machines without it.
DEVICE = "mps" if torch.backends.mps.is_available() else "cpu"

In [None]:
def exp_id():
    """Return a run name of the form ``exp-YYYY-MM-DDTHH:MM:SS`` (local time, whole seconds).

    Note: the ':' characters make this value unusable as a directory name on Windows.
    """
    now = datetime.datetime.now()
    return "exp-" + now.isoformat(timespec="seconds")

In [None]:
# Preview the run-name format produced by exp_id().
exp_id()

In [None]:
# Write the training loss to TensorBoard every LOG_INTERVAL batches.
LOG_INTERVAL = 1
# NOTE(review): VAL_INTERVAL is not read by any cell visible here.
VAL_INTERVAL = 5 * LOG_INTERVAL

In [None]:
import models.dlrm

In [None]:
# Seed torch's global RNG so that parameter initialisation is reproducible across runs
# (presumably DLRM initialises its layers from this RNG — confirm in models.dlrm).
torch.manual_seed(0)
model = models.dlrm.DLRM(
    sparse_feature_dim = EMBEDDING_DIM,
    # sparse_feature_sizes holds per-column max ids; a max id k needs at least k + 1
    # embedding rows, and the +100 adds slack on top of that.
    sparse_feature_sizes = [size + 100 for size in sparse_feature_sizes.values()],
    dense_in_features = len(DENSE_FEATURES),
    dense_layer_sizes = DENSE_LAYERS,
    final_layer_sizes = FINAL_LAYERS,
    dense_device = DEVICE
)

In [None]:
# Single pass over the dataset: one optimizer step per Arrow record batch.
writer = tensorboard.SummaryWriter(log_dir=f"data/exps/{exp_id()}")
loss_fn = torch.nn.BCEWithLogitsLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
# `desc` must be a keyword: tqdm's first positional parameter is the iterable.
pbar = tqdm.tqdm(desc="training", total=src_dataset.count_rows())
try:
    for batch_id, batch in enumerate(src_dataset.to_batches(), start=1):
        batch = batch.to_pandas()
        dense_features, sparse_features, labels = convert_batch(batch)
        logits = model(dense_features, sparse_features)
        # Presumably logits are (batch, 1) since FINAL_LAYERS ends in 1; squeeze to match labels.
        loss = loss_fn(logits.squeeze(-1), labels.to(DEVICE).float())
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        # .item() synchronises with the device, so fetch the scalar once per batch.
        loss_value = loss.item()
        pbar.set_postfix({'batches': batch_id, 'loss': loss_value}, refresh=False)
        if batch_id % LOG_INTERVAL == 0:
            writer.add_scalar('loss/train', loss_value, batch_id)
        pbar.update(batch.shape[0])
finally:
    # Close the bar and flush pending TensorBoard events even if training is interrupted.
    pbar.close()
    writer.flush()