In [None]:
%load_ext autoreload

In [None]:
%autoreload 2

import pathlib

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchinfo import summary
from tqdm.notebook import tqdm
import pandas as pd
import numpy as np

import sys

# Make the repo root importable so `dataset` and `models` resolve. The path is
# relative, so it is resolved against the kernel's working directory (usually
# the notebook's folder). The membership check keeps this cell idempotent:
# re-running it does not stack duplicate entries onto sys.path.
PROJECT_ROOT = '../..'
if PROJECT_ROOT not in sys.path:
    sys.path.append(PROJECT_ROOT)

from dataset import CriteoAdDataset
from models import DeepFM

## Dataset

Load the Criteo ad training split and pull one shuffled batch to check the tensor shapes.


In [None]:
# Fixed seed so the shuffled batch order is reproducible across
# Restart & Run All. Seeding the global torch RNG here also makes later random
# draws deterministic when cells run in order (presumably including DeepFM's
# weight init — depends on its implementation).
SEED = 0
BATCH_SIZE = 10

torch.manual_seed(SEED)

data_dir = pathlib.Path("data/criteo-ad-data")
train_dataset = CriteoAdDataset(data_dir, type="train")
# A dedicated generator decouples the shuffle order from any other consumer of
# the global RNG.
train_dataloader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    generator=torch.Generator().manual_seed(SEED),
)


In [None]:
# Take one shuffled batch; the loader yields (label, count features, category features).
(label, count_features, category_features) = next(iter(train_dataloader))

# Show the shape of each tensor side by side.
tuple(t.shape for t in (label, count_features, category_features))


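## DeepFM

Build a DeepFM model sized from the dataset's feature columns. Run one forward pass on the sample batch as a shape check, then inspect the architecture.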

In [None]:
# Model configuration. Input-side sizes are read from the dataset so the model
# matches its feature columns; the remaining sizes are hand-picked.
embedding_dims = 20
category_feature_names = train_dataset.category_feature_columns
category_cardinalities = train_dataset.category_cardinalities

# One dense input per count-feature column of the dataset.
dense_embedding_in_features = len(train_dataset.count_feature_columns)
dense_embedding_hidden_features = 30
deep_layer_out_features = 10

deepfm = DeepFM(
    embedding_dims=embedding_dims,
    category_cardinalities=category_cardinalities,
    dense_embedding_in_features=dense_embedding_in_features,
    dense_embedding_hidden_features=dense_embedding_hidden_features,
    deep_layer_out_features=deep_layer_out_features,
)

# Single forward pass on the sample batch, used only to check the output shape.
# no_grad() skips building an autograd graph that `logits` would otherwise keep
# alive for no purpose.
with torch.no_grad():
    logits = deepfm(
        count_features=count_features,
        category_features=category_features,
        category_feature_names=category_feature_names,
    )
logits.shape


In [None]:
# Display the module's repr as the cell's output instead of printing it.
deepfm

In [None]:
# Per-layer torchinfo summary. The keyword arguments are the same ones passed
# to `deepfm(...)` above; torchinfo forwards them to the model's forward pass
# so it can trace the model on the sample batch.
summary(
    deepfm,
    category_feature_names=category_feature_names,
    category_features=category_features,
    count_features=count_features,
)
