In [1]:
# Show the kernel's starting working directory.
%pwd

'/Users/tolya/Documents/code/dlrm/src/research'

In [2]:
import os
import pathlib
# The kernel starts in src/research (see %pwd output above); move to the repository
# root so relative paths such as "conf/..." and "data/..." resolve. Guarded so that
# re-running this cell is a no-op instead of climbing two more directories each time.
if pathlib.Path.cwd().parts[-2:] == ("src", "research"):
    os.chdir("../..")

In [3]:
# Confirm the working directory after the chdir above.
%pwd

'/Users/tolya/Documents/code/dlrm'

In [4]:
import sys
# Make the project's `src` packages (utils, models, ...) importable. The entry is
# relative to the current working directory. Guarded so re-running the cell does not
# append duplicate entries to sys.path.
if './src' not in sys.path:
    sys.path.append('./src')

In [5]:
# Automatically reload edited modules (e.g. utils.*, models.*) before each cell executes.
%load_ext autoreload
%autoreload 2

In [6]:
from torch.utils import tensorboard

In [7]:
import logging
# Logger used by the helper functions defined in this notebook (e.g. measure_speed).
logger = logging.getLogger("notebooks.debug")

In [8]:
import utils.logging
# Apply the project's logging configuration from conf/logging/default.yaml.
utils.logging.setup("conf/logging/default.yaml")

In [9]:
import utils.configs
# Load the application config. The return value is discarded via `_`; presumably
# setup() also keeps the config accessible elsewhere — TODO confirm.
_ = utils.configs.setup("conf/app.yaml")

2023-11-26 00:28:53,529 - utils.configs - INFO - loading app config 'conf/app.yaml'...
2023-11-26 00:28:53,531 - utils.configs - INFO - loading app config 'conf/app.yaml': done


In [10]:
import dotenv
# Load the dev environment variables. load_dotenv returns False when it found no
# variables (e.g. a missing or empty file), so the assert fails fast in that case.
assert dotenv.load_dotenv(dotenv_path="conf/envs/dev.env")

---

In [11]:
import time
import datetime
import tqdm.auto as tqdm
import numpy as np
import pyarrow
import pyarrow.dataset

  from .autonotebook import tqdm as notebook_tqdm


In [12]:
import torch
import torch.nn

In [13]:
LABEL = "label"
# Dense columns are f1..f13; sparse id columns are f14_idx..f39_idx.
DENSE_FEATURES = ["f" + str(col) for col in range(1, 14)]
SPARSE_FEATURES = ["f%d_idx" % col for col in range(14, 40)]

In [14]:
# Open the hive-partitioned dataset under data/joined/compact; it is the source for
# every scan and the training loop below.
src_dataset = pyarrow.dataset.dataset(
    "data/joined/compact",
    partitioning = "hive"
)

In [15]:
def measure_speed(
    dataset,
    filter = None,
    limit = None,
    read_params = None,
    payload_fn = None
):
    """Measure how fast `dataset` can be streamed batch-by-batch and log the rate.

    Each record batch is converted to pandas and, optionally, passed to
    `payload_fn`, so the cost of downstream processing is included in the
    measurement. Progress is shown with a tqdm bar.

    Args:
        dataset: a pyarrow dataset (anything exposing `count_rows` and `to_batches`).
        filter: optional filter expression forwarded to `count_rows`/`to_batches`.
        limit: stop once at least this many rows were processed. When None the
            whole (filtered) dataset is read and its size is counted first so the
            progress bar has a total.
        read_params: extra keyword arguments for `dataset.to_batches` (default: none).
        payload_fn: optional callable applied to every pandas batch.
    """
    # Avoid a shared mutable default argument.
    read_params = {} if read_params is None else read_params

    if limit is None:
        logger.info("getting dataset size...")
        total_records = dataset.count_rows(filter=filter)
        logger.info(f"getting dataset size: done ({total_records} records)")
    else:
        total_records = limit

    logger.info("reading dataset...")
    # perf_counter is monotonic, unlike time.time(), so it is the right clock for durations.
    time_start = time.perf_counter()
    rows_processed = 0
    src_batches = dataset.to_batches(filter=filter, **read_params)
    # Context manager closes the progress bar even if payload_fn raises.
    with tqdm.tqdm(desc="reading data", total=total_records) as pbar:
        for batch_id, batch in enumerate(src_batches, start=1):
            batch = batch.to_pandas()
            # Run the payload before the limit check: previously the last batch was
            # counted towards the rate but never handed to payload_fn.
            if payload_fn is not None:
                payload_fn(batch)
            pbar.set_postfix({'batches': batch_id}, refresh=False)
            pbar.update(batch.shape[0])
            rows_processed += batch.shape[0]
            if limit is not None and rows_processed >= limit:
                break

    elapsed_time = time.perf_counter() - time_start
    # Guard against a zero-length interval (e.g. an empty dataset).
    read_speed = rows_processed / elapsed_time if elapsed_time > 0 else 0.0
    logger.info(f"reading dataset: done ({int(elapsed_time)} seconds, {int(read_speed)} rows/sec)")

In [16]:
def get_dense_features(batch, device = "cpu", columns = None):
    """Build the log-transformed dense-feature tensor for one pandas batch.

    Each value x becomes log(x + 3). Any entry that ends up non-finite is
    replaced with 0: NaN (e.g. from missing values, or inputs below -3),
    -inf (inputs exactly -3) and +inf (infinite inputs).

    Args:
        batch: pandas DataFrame containing the dense columns.
        device: torch device for the returned tensor.
        columns: dense column names to read; defaults to DENSE_FEATURES.

    Returns:
        float32 tensor of shape (len(batch), len(columns)) on `device`.
    """
    if columns is None:
        columns = DENSE_FEATURES
    raw = torch.tensor(batch[columns].to_numpy(dtype="float32"), device=device)
    dense_features = torch.log(raw + 3)
    # Mask every non-finite value, not only NaN: log(0) = -inf would otherwise
    # propagate into the model and poison the loss.
    dense_features.masked_fill_(~torch.isfinite(dense_features), 0)
    return dense_features

In [17]:
def get_sparse_features(batch, device = "cpu"):
    """Return the SPARSE_FEATURES id columns of a pandas batch as an int32 tensor on `device`."""
    id_matrix = batch[SPARSE_FEATURES].to_numpy(dtype="int32")
    return torch.tensor(id_matrix, device=device)

In [18]:
def get_labels(batch, device = "cpu"):
    """Return the LABEL column of a pandas batch as an int8 tensor on `device`."""
    label_column = batch[LABEL]
    return torch.tensor(label_column.to_numpy(dtype="int8"), device=device)

In [19]:
def convert_batch(batch, device = "cpu"):
    """Convert one pandas batch into a (dense_features, sparse_features, labels) tuple of tensors on `device`."""
    return (
        get_dense_features(batch, device),
        get_sparse_features(batch, device),
        get_labels(batch, device),
    )

In [20]:
# measure_speed(src_dataset) ## warmup the disk cache

In [21]:
# measure_speed(src_dataset, payload_fn = lambda b: convert_batch(b, device="cpu"))

In [22]:
# measure_speed(src_dataset, payload_fn = lambda b: convert_batch(b, device="mps"))

In [23]:
def get_feature_cardinality(dataset, device = "mps"):
    """Scan `dataset` and return the largest id seen in each sparse column.

    Note: this is the per-column *maximum index*, not a count of distinct
    values; an embedding table indexed by these ids needs at least
    ``max + 1`` rows.

    Args:
        dataset: pyarrow dataset to scan (exposes `count_rows` and `to_batches`).
        device: torch device used for the per-batch max reduction.

    Returns:
        int32 tensor with one entry per column in SPARSE_FEATURES, on `device`.

    Raises:
        ValueError: if the dataset contains no rows.
    """
    running_max = None
    # `desc` must be a keyword: tqdm's first positional argument is the iterable.
    with tqdm.tqdm(desc="reading dataset", total=dataset.count_rows()) as pbar:
        # Read from the `dataset` argument, not the global src_dataset.
        for record_batch in dataset.to_batches():
            batch = record_batch.to_pandas()
            pbar.update(batch.shape[0])
            if batch.shape[0] == 0:
                # max() over an empty dimension raises; empty batches carry no ids anyway.
                continue
            batch_max = get_sparse_features(batch, device=device).max(dim=0).values
            # Keep a running maximum instead of stacking one tensor per batch.
            running_max = batch_max if running_max is None else torch.maximum(running_max, batch_max)
    if running_max is None:
        raise ValueError("dataset contains no rows; cannot compute sparse feature maxima")
    return running_max

In [24]:
# Full scan of the dataset to get the per-column maximum sparse id (took ~19 s in the
# recorded run below). Note these are max ids, not counts.
sparse_feature_sizes = get_feature_cardinality(src_dataset)
sparse_feature_sizes

100%|██████████| 195841983/195841983 [00:19<00:00, 10170011.73it/s]


tensor([100000,   4460,   7123,    393,   5313,      2,   3353,    356,      8,
        100001,  16640,  19153,      8,    261,   1515,     11,      3,     20,
             7, 100000, 100001, 100001,  18156,   4112,      8,     10],
       device='mps:0', dtype=torch.int32)

In [25]:
# Map each sparse column name to its maximum observed id. The isinstance guard makes
# the cell safe to re-run: after the first run `sparse_feature_sizes` is already a dict
# and the tensor-only conversion below would fail.
if isinstance(sparse_feature_sizes, torch.Tensor):
    # .tolist() copies to the host and yields plain Python ints (not numpy int32 scalars).
    sparse_feature_sizes = dict(zip(SPARSE_FEATURES, sparse_feature_sizes.tolist()))
sparse_feature_sizes

{'f14_idx': 100000,
 'f15_idx': 4460,
 'f16_idx': 7123,
 'f17_idx': 393,
 'f18_idx': 5313,
 'f19_idx': 2,
 'f20_idx': 3353,
 'f21_idx': 356,
 'f22_idx': 8,
 'f23_idx': 100001,
 'f24_idx': 16640,
 'f25_idx': 19153,
 'f26_idx': 8,
 'f27_idx': 261,
 'f28_idx': 1515,
 'f29_idx': 11,
 'f30_idx': 3,
 'f31_idx': 20,
 'f32_idx': 7,
 'f33_idx': 100000,
 'f34_idx': 100001,
 'f35_idx': 100001,
 'f36_idx': 18156,
 'f37_idx': 4112,
 'f38_idx': 8,
 'f39_idx': 10}

### Playing around with torch components

In [26]:
EMBEDDING_DIM = 64
# The dense tower ends at EMBEDDING_DIM — presumably so its output can interact with the
# sparse embeddings inside the DLRM model (TODO confirm in models.dlrm).
DENSE_LAYERS = [512,256,EMBEDDING_DIM]
# Final tower ends in a single unit: one logit per example (fed to BCEWithLogitsLoss).
FINAL_LAYERS = [512,512,256,1]
# Prefer Apple's MPS backend, but fall back to CPU so the notebook also runs where MPS is unavailable.
DEVICE = "mps" if torch.backends.mps.is_available() else "cpu"

In [27]:
def exp_id():
    """Return a run identifier ``exp-YYYY-MM-DDTHH:MM:SS`` built from the current local time (second resolution)."""
    started_at = datetime.datetime.now()
    return "exp-" + started_at.isoformat(timespec="seconds")

In [28]:
# Preview the experiment-id format; every call produces a new timestamp.
exp_id()

'exp-2023-11-26T00:29:13'

In [29]:
# Write the training loss to TensorBoard every LOG_INTERVAL batches.
LOG_INTERVAL = 1
# NOTE(review): VAL_INTERVAL is not used by any of the visible cells.
VAL_INTERVAL = LOG_INTERVAL * 5

In [30]:
import models.dlrm

In [31]:
# Build the DLRM model. Each embedding table is sized as the column's maximum observed
# id plus 100 — presumably slack covering the +1 needed for 0-based ids and any ids
# above the observed maximum (TODO confirm the intent of the constant 100).
model = models.dlrm.DLRM(
    sparse_feature_dim = EMBEDDING_DIM,
    sparse_feature_sizes = [size+100 for size in list(sparse_feature_sizes.values())],
    dense_in_features = len(DENSE_FEATURES),
    dense_layer_sizes = DENSE_LAYERS,
    final_layer_sizes = FINAL_LAYERS,
    dense_device = DEVICE
)

In [32]:
# One pass over the full dataset with Adam + BCE-with-logits, logging loss to TensorBoard.
# Keep the id so later cells can refer to this run instead of re-typing its timestamp.
run_id = exp_id()
writer = tensorboard.SummaryWriter(log_dir=f"data/exps/{run_id}")
loss_fn = torch.nn.BCEWithLogitsLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
model.train()
try:
    # `desc` must be a keyword: tqdm's first positional argument is the iterable.
    with tqdm.tqdm(desc="training", total=src_dataset.count_rows()) as pbar:
        for batch_id, batch in enumerate(src_dataset.to_batches(), start=1):
            batch = batch.to_pandas()
            dense_features, sparse_features, labels = convert_batch(batch)
            logits = model(dense_features, sparse_features)
            loss = loss_fn(logits.squeeze(-1), labels.to(DEVICE).float())
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
            # .item() forces a device sync, so fetch the scalar once per step.
            loss_value = loss.item()
            pbar.set_postfix({'batches': batch_id, 'loss': loss_value}, refresh=False)
            if batch_id % LOG_INTERVAL == 0:
                writer.add_scalar('loss/train', loss_value, batch_id)
            pbar.update(batch.shape[0])
finally:
    # Flush pending TensorBoard events even if training is interrupted.
    writer.close()

100%|██████████| 195841983/195841983 [3:33:44<00:00, 27397.09it/s, batches=1510, loss=0.127]  

In [36]:
import utils.filesystem
# Create the checkpoint directory. NOTE(review): the experiment id is hard-coded here
# and must be kept in sync by hand with the training run it belongs to.
utils.filesystem.mkdir("data/checkpoints/exp-2023-11-26T00:29:14")

2023-11-26 13:02:40,961 - utils.filesystem - INFO - creating path 'data/checkpoints/exp-2023-11-26T00:29:14'...
2023-11-26 13:02:40,965 - utils.filesystem - INFO - creating path 'data/checkpoints/exp-2023-11-26T00:29:14': done


In [35]:
# Bundle model and optimizer state so training can be resumed, not only evaluated.
checkpoint = {
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
}
checkpoint_path = "data/checkpoints/exp-2023-11-26T00:29:14/first-attempt.pt"
torch.save(checkpoint, checkpoint_path)

In [37]:
# Rebuild the model/optimizer with the training-time hyper-parameters and restore their state.
# The checkpoint may contain MPS tensors (dense_device=DEVICE); map_location="cpu" lets it
# load on any machine, and load_state_dict then copies values onto each parameter's device.
# NOTE: torch.load unpickles the file — only load checkpoints you produced yourself.
checkpoint = torch.load(checkpoint_path, map_location="cpu")

model = models.dlrm.DLRM(
    sparse_feature_dim = EMBEDDING_DIM,
    sparse_feature_sizes = [size+100 for size in list(sparse_feature_sizes.values())],
    dense_in_features = len(DENSE_FEATURES),
    dense_layer_sizes = DENSE_LAYERS,
    final_layer_sizes = FINAL_LAYERS,
    dense_device = DEVICE
)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

In [44]:
# Sanity check after loading: show the first row of the first embedding table.
model.sparse_arch.embeddings[0].weight[0]

tensor([-2.3739,  0.0116, -0.7886, -0.1465,  0.9212, -0.4239, -0.2930,  0.1731,
        -0.1319, -1.0849,  0.1730, -2.4127, -0.6039,  0.7543,  0.8656,  1.3840,
         0.5706, -0.1873, -0.1842,  0.4883,  0.7234, -1.1200,  0.7445,  0.0559,
         1.2161,  1.0311, -0.3917,  0.2094,  0.9472, -0.1171,  0.5646, -0.9953,
        -0.2787,  0.3100,  0.5355, -0.4890,  0.3734,  0.6972,  1.1223,  0.1922,
        -0.8874,  0.9448, -0.4160,  0.6944,  1.1935, -0.1178,  0.4508,  0.2079,
        -0.5156, -1.0846, -0.5752, -1.1570, -0.9603, -0.2282, -1.0797,  0.3148,
         1.2818,  1.3543,  0.4588, -1.2980,  1.2508, -1.2206,  0.1699,  0.0982],
       grad_fn=<SelectBackward0>)