In [1]:
%load_ext autoreload

In [2]:
%autoreload 2

import pathlib

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchinfo import summary
from sklearn.metrics import roc_auc_score, average_precision_score, accuracy_score
from tqdm.notebook import tqdm
import pandas as pd
import numpy as np

from dataset import CriteoAdDataset
from utils import CategoryEncoder
from models import DeepFM

## Dataset

In [3]:
data_dir = pathlib.Path("./data/criteo-ad-data")

category_encoder = CategoryEncoder()
train_dataset = CriteoAdDataset(
    data_dir, type="train", nums=10000, category_encoder=category_encoder
)
train_dataloader = DataLoader(train_dataset, batch_size=10, shuffle=True)
val_dataset = CriteoAdDataset(data_dir, type="val", nums=10000, category_encoder=category_encoder)
val_dataloader = DataLoader(val_dataset, batch_size=10, shuffle=False)


In [4]:
label, count_features, category_features = next(iter(train_dataloader))
label.shape, count_features.shape, category_features.shape


(torch.Size([10]), torch.Size([10, 13]), torch.Size([10, 26]))

## Model

In [5]:
embedding_dims = 20
category_feature_names = train_dataset.category_feature_columns
category_cardinalities = train_dataset.category_cardinalities

dense_embedding_in_features = len(train_dataset.count_feature_columns)
dense_embedding_hidden_features = 30
deep_layer_out_features = 10

model = DeepFM(
    embedding_dims=embedding_dims,
    category_cardinalities=category_cardinalities,
    dense_embedding_in_features=dense_embedding_in_features,
    dense_embedding_hidden_features=dense_embedding_hidden_features,
    deep_layer_out_features=deep_layer_out_features,
)


In [6]:
summary(
    model,
    count_features=count_features,
    category_features=category_features,
    category_feature_names=category_feature_names,
)


Layer (type:depth-idx)                   Param #
DeepFM                                   --
├─SparseEmbedding: 1-1                   --
│    └─ModuleDict: 2-1                   --
│    │    └─EmbeddingBag: 3-1            3,520
│    │    └─EmbeddingBag: 3-2            7,740
│    │    └─EmbeddingBag: 3-3            41,700
│    │    └─EmbeddingBag: 3-4            105,700
│    │    └─EmbeddingBag: 3-5            34,520
│    │    └─EmbeddingBag: 3-6            500
│    │    └─EmbeddingBag: 3-7            40,720
│    │    └─EmbeddingBag: 3-8            94,500
│    │    └─EmbeddingBag: 3-9            200
│    │    └─EmbeddingBag: 3-10           23,000
│    │    └─EmbeddingBag: 3-11           10,960
│    │    └─EmbeddingBag: 3-12           100
│    │    └─EmbeddingBag: 3-13           110,440
│    │    └─EmbeddingBag: 3-14           100,760
│    │    └─EmbeddingBag: 3-15           180
│    │    └─EmbeddingBag: 3-16           260
│    │    └─EmbeddingBag: 3-17           50,520
│    │    └─Embed

## Train

In [7]:
def train(
    train_dataloader: DataLoader,
    model: nn.Module,
    criterion: nn.Module,
    optimizer: torch.optim.Optimizer,
    log_interval: int = 10000,
):
    size = len(train_dataloader.dataset)
    sum_loss = 0
    for batch, (labels, count_features, category_features) in tqdm(
        enumerate(train_dataloader, 1), total=len(train_dataloader), unit="iter"
    ):
        model.train()
        logits = model(
            count_features=count_features,
            category_features=category_features,
            category_feature_names=category_feature_names,
        )
        loss = criterion(logits.squeeze(), labels.float())
        optimizer.zero_grad()
        loss.backward()
        # to prevent gradient explosion
        nn.utils.clip_grad_norm_(model.parameters(), 10)
        optimizer.step()

        sum_loss += loss.item()

        if batch % log_interval == 0:
            mean_loss = sum_loss / log_interval
            tqdm.write(f"loss: {mean_loss:>7f} [{batch * len(labels):>5d}/{size:>5d}]")
            sum_loss = 0


def eval(dataloader: DataLoader, model: nn.Module, criterion: nn.Module):
    model.eval()
    sum_loss = 0
    num_batches = len(dataloader)
    labels_list = []
    preds_list = []
    with torch.no_grad():
        for labels, count_features, category_features in tqdm(
            dataloader, total=len(dataloader), unit="iter"
        ):
            logits = model(
                count_features=count_features,
                category_features=category_features,
                category_feature_names=category_feature_names,
            )
            preds = F.sigmoid(logits.squeeze())
            sum_loss += criterion(preds, labels.float()).item()

            labels_list.extend(labels.tolist())
            preds_list.extend(preds.tolist())

    mean_loss = sum_loss / num_batches
    accuracy = accuracy_score(labels_list, np.round(np.asarray(preds_list)))
    roc_auc = roc_auc_score(labels_list, preds_list)
    pr_auc = average_precision_score(labels_list, preds_list)

    return mean_loss, accuracy, roc_auc, pr_auc


In [8]:
epochs = 3
criterion = nn.BCEWithLogitsLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

for epoch in tqdm(range(epochs), unit="epoch"):
    train(train_dataloader, model, criterion, optimizer, log_interval=2000)
    mean_loss, accuracy, roc_auc, pr_auc = eval(train_dataloader, model, criterion)
    tqdm.write(f"Train | avg loss: {mean_loss:.4f}, accuracy: {accuracy:.4f}, roc_auc: {roc_auc:.4f}, pr_auc: {pr_auc:.4f}")
    mean_loss, accuracy, roc_auc, pr_auc = eval(val_dataloader, model, criterion)
    tqdm.write(f"Val | avg loss: {mean_loss:.4f}, accuracy: {accuracy:.4f}, roc_auc: {roc_auc:.4f}, pr_auc: {pr_auc:.4f}")


  0%|          | 0/3 [00:00<?, ?epoch/s]

  0%|          | 0/1000 [00:00<?, ?iter/s]

  0%|          | 0/1000 [00:00<?, ?iter/s]

Train | avg loss: 0.7416, accuracy: 0.7890, roc_auc: 0.7100, pr_auc: 0.4179


  0%|          | 0/1000 [00:00<?, ?iter/s]

Val | avg loss: 0.7375, accuracy: 0.7381, roc_auc: 0.6072, pr_auc: 0.3474


  0%|          | 0/1000 [00:00<?, ?iter/s]

  0%|          | 0/1000 [00:00<?, ?iter/s]

Train | avg loss: 0.7256, accuracy: 0.7827, roc_auc: 0.7787, pr_auc: 0.5003


  0%|          | 0/1000 [00:00<?, ?iter/s]

Val | avg loss: 0.7235, accuracy: 0.7397, roc_auc: 0.6445, pr_auc: 0.3820


  0%|          | 0/1000 [00:00<?, ?iter/s]

  0%|          | 0/1000 [00:00<?, ?iter/s]

Train | avg loss: 0.7154, accuracy: 0.8159, roc_auc: 0.8548, pr_auc: 0.6526


  0%|          | 0/1000 [00:00<?, ?iter/s]

Val | avg loss: 0.7259, accuracy: 0.7411, roc_auc: 0.6310, pr_auc: 0.3718
