In [1]:
import logging
# Install a stderr handler on the root logger explicitly (basicConfig's default
# "LEVEL:name:message" format) instead of relying on the first module-level
# logging.* call to run basicConfig implicitly. basicConfig is a no-op when the
# root logger already has handlers, so re-running this cell adds no duplicates.
logging.basicConfig()
logger = logging.getLogger()
# Set the level separately: basicConfig(level=...) would be skipped whenever the
# root logger already has handlers (e.g. on a cell re-run).
logger.setLevel(logging.DEBUG)
# Smoke check that DEBUG records reach the notebook output.
logging.debug("test")

DEBUG:root:test


In [2]:
from data import (
    DataSetSplit,
    FeatureConfig,
    FloatFeatureConfig,
    IdListFeatureConfig,
    SyntheticDataset,
    collate_fn
)

# The user and item towers share one feature layout: two id-list features with
# (num_embeddings, embedding_dim) of (1001, 64) and (10001, 128), and two float
# features with padding_val 0 and 1. Only the feature ids (fids) differ per tower.
def _id_list_feature_configs(fids):
    """Build IdListFeatureConfigs for ``fids``, pairing them in order with the
    (1001, 64) and (10001, 128) embedding shapes."""
    shapes = [(1001, 64), (10001, 128)]
    return [
        IdListFeatureConfig(fid=fid, num_embeddings=num_embeddings, embedding_dim=embedding_dim)
        for fid, (num_embeddings, embedding_dim) in zip(fids, shapes)
    ]


def _float_feature_configs(fids):
    """Build FloatFeatureConfigs for ``fids``, pairing them in order with padding_val 0 and 1."""
    padding_vals = [0, 1]
    return [
        FloatFeatureConfig(fid=fid, padding_val=padding_val)
        for fid, padding_val in zip(fids, padding_vals)
    ]


feature_config = FeatureConfig(
    user_id_list_features=_id_list_feature_configs([2, 4]),
    user_float_features=_float_feature_configs([6, 8]),
    item_id_list_features=_id_list_feature_configs([13, 15]),
    item_float_features=_float_feature_configs([17, 19]),
)

# Training split of the synthetic dataset, generated from the feature layout above.
synthetic_data = SyntheticDataset(
    dataset_split=DataSetSplit.TRAIN,
    feature_config=feature_config,
)

In [3]:
from torch.utils.data import DataLoader

# Peek at one small batch to inspect the structure that collate_fn produces.
data_loader = DataLoader(
    synthetic_data,
    batch_size=4,
    collate_fn=collate_fn,
)
preview_batch = next(iter(data_loader))
preview_batch

({'2': (tensor([735,  90, 248, 403, 171, 957, 921, 633,  73, 948,   7, 144, 840, 721,
           888, 628, 872, 373, 557], dtype=torch.int32),
   tensor([ 0,  7, 12, 12])),
  '4': (tensor([ 727, 6847, 4638,  788, 6278, 8330, 4180, 3574, 3003, 8547],
          dtype=torch.int32),
   tensor([0, 0, 0, 2]))},
 tensor([[-0.1406, -0.3140],
         [-0.3135, -0.3646],
         [-0.2842, -0.2584],
         [-0.2969,  0.0941]]),
 {'13': (tensor([701, 774, 785, 160, 845, 536, 943, 454, 700, 190, 478, 465, 517, 791,
           948, 860, 156, 290, 353, 160, 562, 839, 229], dtype=torch.int32),
   tensor([ 0,  7, 11, 16])),
  '15': (tensor([3803, 7441, 2267, 7878, 5870, 1619, 4709, 3269, 3430,  935, 2707, 4190,
           6184], dtype=torch.int32),
   tensor([ 0,  8,  8, 12]))},
 tensor([[-0.3519, -0.2478],
         [-0.3756,  1.3885],
         [-0.1763, -0.4541],
         [-0.2459,  3.0867]]),
 tensor([0.0113, 0.8387, 0.5503, 0.7679]))

In [4]:
from model import SparseNNTwoTower

# The user and item towers are configured with identical projection dims. Each
# argument receives its own list copy so the towers never share a mutable list.
SPARSE_PROJ_DIMS = [256, 128]
FLOAT_PROJ_DIMS = [128, 128]
OVERARCH_PROJ_DIMS = [256, 128]

model = SparseNNTwoTower(
    feature_config=feature_config,
    user_sparse_proj_dims=list(SPARSE_PROJ_DIMS),
    user_float_proj_dims=list(FLOAT_PROJ_DIMS),
    user_overarch_proj_dims=list(OVERARCH_PROJ_DIMS),
    item_sparse_proj_dims=list(SPARSE_PROJ_DIMS),
    item_float_proj_dims=list(FLOAT_PROJ_DIMS),
    item_overarch_proj_dims=list(OVERARCH_PROJ_DIMS),
    output_dim=128,
)

In [None]:
from train import train, test

# Single training pass over the synthetic training split.
train(
    model=model,
    dataset=synthetic_data,
    batch_size=32,
    num_epoch=1,
    # A loss line is logged every 10 batches (see the INFO output below).
    verbose_log_every_n=10,
)

INFO:root:epoch 0, batch 10, loss: 63.6168212890625
INFO:root:epoch 0, batch 20, loss: 59.43405532836914
INFO:root:epoch 0, batch 30, loss: 64.50531005859375
INFO:root:epoch 0, batch 40, loss: 56.27800369262695
INFO:root:epoch 0, batch 50, loss: 49.429100036621094
INFO:root:epoch 0, batch 60, loss: 54.393646240234375
INFO:root:epoch 0, batch 70, loss: 60.784759521484375
INFO:root:epoch 0, batch 80, loss: 55.25569534301758
INFO:root:epoch 0, batch 90, loss: 53.60392379760742
INFO:root:epoch 0, batch 100, loss: 51.56278991699219
INFO:root:epoch 0, batch 110, loss: 53.11770248413086
INFO:root:epoch 0, batch 120, loss: 62.186859130859375
INFO:root:epoch 0, batch 130, loss: 52.91277313232422
INFO:root:epoch 0, batch 140, loss: 64.42668151855469
INFO:root:epoch 0, batch 150, loss: 50.59214782714844
INFO:root:epoch 0, batch 160, loss: 57.41949462890625


In [None]:
# Evaluate on a held-out split. `synthetic_data` was built with DataSetSplit.TRAIN,
# so evaluating on it would report metrics on the data the model was trained on.
# NOTE(review): assumes DataSetSplit defines a TEST member — confirm against data.py.
test_data = SyntheticDataset(
    dataset_split=DataSetSplit.TEST,
    feature_config=feature_config,
)
test(model=model, dataset=test_data, batch_size=64)