Two Tower (https://storage.googleapis.com/pub-tools-public-publication-data/pdf/6c8a86c981a62b0126a11896b7f6ae0dae4c3566.pdf)
- as the name suggests, the model consists of two separate neural networks ("towers"), one for items and one for users; each tower here is an MLP
- user & item encoder: neural nets that take features about the item and users and encode them into fixed dimension embedding vectors
- the user and item embedding vectors are finally combined with a dot product to get the prediction
- this architecture is widely used for retrieval because of its scalability
  - since user & item features don't interact early on, we don't need to pass the user & item features through the complex neural network to get our prediction
  - user & item feature embeddings can be precomputed, and at serve time, we just multiply the user's embedding vector with all the item embeddings
  - instead of naively taking the dot product of the user embedding with every item embedding, which has linear time complexity in the number of items, a more efficient approach is to run an approximate nearest neighbor (ANN) search that retrieves the top k items most similar to the user embedding (since similar vectors have a larger dot product), and then pass only those candidates to a ranking model
  
The model architecture is nicely illustrated here ([image source](https://www.linkedin.com/pulse/personalized-recommendations-iv-two-tower-models-gaurav-chakravorty/)):

<img src="model_illustration.png" style="width:70%">

In [19]:
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
from torch.utils import data
from torch.utils.data import DataLoader
from sklearn.metrics import accuracy_score, roc_auc_score, f1_score, confusion_matrix


In [2]:
class TwoTower(nn.Module):
    """Two-tower model: independent item and user encoders scored by a dot product.

    Each tower looks up one embedding per categorical field, concatenates the
    field embeddings into a single vector and passes it through an MLP
    (``Linear -> BatchNorm1d -> ReLU -> Dropout`` per hidden layer, followed by
    a final ``Linear`` projection) to produce a vector of size
    ``mlp_out_embed_dim``. The prediction is the sigmoid of the dot product of
    the two tower outputs.

    Args:
        n_item_features: Number of rows in the item embedding table; every
            item feature index must be smaller than this.
        n_user_features: Number of rows in the user embedding table; every
            user feature index must be smaller than this.
        n_item_fields: Number of item columns at the start of each input row.
        n_user_fields: Number of user columns following the item columns.
        embed_dim: Size of each per-field embedding vector.
        mlp_out_embed_dim: Size of the final item/user vectors produced by
            the towers.
        item_mlp_dims: Hidden layer widths of the item tower.
        user_mlp_dims: Hidden layer widths of the user tower. ``None`` (the
            default) reuses ``item_mlp_dims``; an empty sequence means "no
            hidden layers".
        dropout: Dropout probability applied after every hidden layer.
    """

    def __init__(
        self,
        n_item_features,
        n_user_features,
        n_item_fields,
        n_user_fields,
        embed_dim,
        mlp_out_embed_dim,
        item_mlp_dims,
        user_mlp_dims=None,
        dropout=0.2,
    ):
        super().__init__()
        self.n_item_fields = n_item_fields
        self.n_user_fields = n_user_fields

        # Only fall back to the item structure when nothing was specified;
        # an explicitly empty list must not be silently replaced.
        if user_mlp_dims is None:
            user_mlp_dims = item_mlp_dims

        # item tower
        self.item_embedding = nn.Embedding(n_item_features, embed_dim)
        self.item_mlp_input_dim = n_item_fields * embed_dim
        self.item_mlp = self._build_mlp(
            self.item_mlp_input_dim, item_mlp_dims, mlp_out_embed_dim, dropout
        )

        # user tower
        self.user_embedding = nn.Embedding(n_user_features, embed_dim)
        self.user_mlp_input_dim = n_user_fields * embed_dim
        self.user_mlp = self._build_mlp(
            self.user_mlp_input_dim, user_mlp_dims, mlp_out_embed_dim, dropout
        )

    @staticmethod
    def _build_mlp(input_dim, hidden_dims, output_dim, dropout):
        """Stack Linear/BatchNorm1d/ReLU/Dropout blocks plus a final Linear projection."""
        layers = []
        for dim in hidden_dims:
            layers.extend(
                [
                    nn.Linear(input_dim, dim),
                    nn.BatchNorm1d(dim),
                    nn.ReLU(),
                    nn.Dropout(dropout),
                ]
            )
            input_dim = dim
        layers.append(nn.Linear(input_dim, output_dim))
        return nn.Sequential(*layers)

    def forward(self, x):
        """Return the interaction probability for each row of ``x``.

        Args:
            x: Integer tensor of shape ``(batch, n_item_fields + n_user_fields)``
                laid out as ``[item_feature_1..n, user_feature_1..m]``.

        Returns:
            Tensor of shape ``(batch,)`` with values in ``[0, 1]``.

        Raises:
            ValueError: If ``x`` is not 2-D with the expected number of columns.
        """
        expected_fields = self.n_item_fields + self.n_user_fields
        if x.dim() != 2 or x.size(1) != expected_fields:
            # Without this check a wrong width could be silently reshaped
            # into a different batch size by the flattening below.
            raise ValueError(
                f"expected input of shape (batch, {expected_fields}), "
                f"got {tuple(x.shape)}"
            )
        batch_size = x.size(0)
        # concat each item's field embeddings into a 1d vector
        item_embed = self.item_embedding(x[:, :self.n_item_fields]).view(
            batch_size, self.item_mlp_input_dim
        )
        # concat each user's field embeddings into a 1d vector
        user_embed = self.user_embedding(x[:, self.n_item_fields:]).view(
            batch_size, self.user_mlp_input_dim
        )
        # pass feature embeddings through the towers to get item & user vectors
        item_mlp_embed = self.item_mlp(item_embed)
        user_mlp_embed = self.user_mlp(user_embed)
        # row-wise dot product of item & user vectors
        dot = (item_mlp_embed * user_mlp_embed).sum(dim=1)
        return torch.sigmoid(dot)

    def predict(self, x):
        """Run inference on ``x`` without tracking gradients.

        Switches the module to eval mode (it is left in eval mode afterwards).
        The input is moved to the device holding the model's parameters and
        the result is returned on the input's original device, so a CPU
        tensor can be scored by a model that was trained on GPU.
        """
        self.eval()
        model_device = next(self.parameters()).device
        with torch.no_grad():
            return self(x.to(model_device)).to(x.device)

In [3]:
def train(model, dataloader, epochs=20, lr=0.001):
    """Train ``model`` with Adam and binary cross-entropy loss.

    The model is moved to the first CUDA device when one is available,
    otherwise it stays on CPU; every batch is moved to the same device before
    the forward pass. The mean batch loss of every 10th epoch (starting with
    epoch 0) is printed.

    Args:
        model: Module whose output is a probability per sample (``nn.BCELoss``
            is applied directly to it).
        dataloader: Iterable yielding ``(x, y)`` tensor batches, with ``y``
            holding float targets shaped like the model output.
        epochs: Number of passes over ``dataloader``.
        lr: Adam learning rate.

    Returns:
        Tuple ``(model, training_history)`` where ``training_history`` is a
        list with the mean batch loss of each epoch as a Python float.

    Raises:
        ValueError: If ``dataloader`` yields no batches.
    """
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    model.to(device)
    model.train()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    criterion = nn.BCELoss()
    training_history = []
    for epoch in range(epochs):
        running_loss = 0.0
        n_batches = 0
        for x, y in dataloader:
            # the batch must live on the same device as the model
            x, y = x.to(device), y.to(device)
            y_pred = model(x)
            loss = criterion(y_pred, y)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            # .item() detaches the value so the autograd graph of each batch
            # can be freed instead of being kept alive for the whole epoch
            running_loss += loss.item()
            n_batches += 1
        if n_batches == 0:
            raise ValueError("dataloader yielded no batches")
        epoch_loss = running_loss / n_batches
        training_history.append(epoch_loss)
        if epoch % 10 == 0:
            print(f"Epoch {epoch}: {epoch_loss:.4f}")
    return model, training_history

# Data Preparation
- X is an array of feature indices: [[item_feature_1...item_feature_n, user_feature_1...user_feature_n]]
- y is an (n_sample,) array of binary ground-truth labels (1 = positive rating, 0 = negative rating)

In [4]:
import sys
sys.path.append('..')
import utils

In [6]:
# load the rating, item and user tables through the project-local `utils` helper;
# `rating` is indexed below by its 'item_id', 'user_id' and 'rating' columns
rating, item, user = utils.get_movielens()

In [7]:
# encode item and user attributes as integer feature indices via the project helpers;
# with return_df=False the results are indexed below like 2-D arrays
# (presumably one row per item/user and one column per field -- confirm in utils)
item_label = utils.get_items_label_encoding(item, return_df=False)
user_label = utils.get_users_label_encoding(user, return_df=False)

In [8]:
# concat item & user feature matrices to get X:
# every rating row becomes [item fields..., user fields...], looked up by the row's item_id / user_id
X = np.hstack((item_label[rating['item_id']-1,:], user_label[rating['user_id']-1,:])) # offset -1 since item & user ids start at 1

In [9]:
# binarize the ratings: 1 if rating >= threshold, otherwise 0
threshold = 3
y = np.where(rating['rating'].to_numpy()>=threshold, 1, 0)

Train test split

Here, for simplicity, we are only using a random split, with 80% as the train set and 20% as the test set. In practice, the splitting may be done per user, e.g. an 80/20 split of each user's rating/interaction history.

In [10]:
from sklearn.model_selection import train_test_split
# random 80/20 split over individual rating rows; random_state is fixed for reproducibility
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

In [11]:
# wrap the training arrays as tensors; labels are cast to float because train() feeds them to nn.BCELoss
dataset = data.TensorDataset(torch.from_numpy(X_train), torch.from_numpy(y_train).float())
# shuffled mini-batches of 128 samples
train_dataloader = DataLoader(dataset, batch_size=128, shuffle=True)

# Model

In [12]:
# embedding table sizes: largest feature index + 1 (indices are used directly as embedding rows)
n_item_features = np.max(item_label)+1
n_user_features = np.max(user_label)+1
# number of fields = number of columns in each label-encoded matrix
n_item_fields = item_label.shape[-1]
n_user_fields = user_label.shape[-1]
embed_dim = 50  # size of each per-field embedding
mlp_out_embed_dim = 64  # size of the final item/user vectors
item_mlp_dims = [128, 64]  # hidden layer widths of the item tower

model = TwoTower(
    n_item_features=n_item_features,
    n_user_features=n_user_features,
    n_item_fields=n_item_fields,
    n_user_fields=n_user_fields,
    embed_dim=embed_dim,
    mlp_out_embed_dim=mlp_out_embed_dim,
    item_mlp_dims=item_mlp_dims,
    user_mlp_dims=None,  # None -> the user tower reuses item_mlp_dims
    dropout=0.1,
)

In [13]:
# train for 100 epochs; train() prints the mean batch loss of every 10th epoch
model, history = train(model=model, 
                       dataloader=train_dataloader, 
                       epochs=100, 
                       lr=0.001)

Epoch 0: 0.4428
Epoch 10: 0.3273
Epoch 20: 0.2862
Epoch 30: 0.2574
Epoch 40: 0.2353
Epoch 50: 0.2210
Epoch 60: 0.2081
Epoch 70: 0.1973
Epoch 80: 0.1894
Epoch 90: 0.1820


In [16]:
# predicted probabilities for the held-out rows (predict() switches to eval mode and disables gradients)
y_pred_soft = model.predict(torch.from_numpy(X_test))

In [20]:
# threshold the predicted probabilities at 0.5 to obtain hard 0/1 labels
y_pred = np.where(y_pred_soft.numpy() > 0.5, 1, 0)

# scikit-learn metrics take (y_true, y_pred) in that order
acc = accuracy_score(y_test, y_pred)
# AUC is computed from the probabilities, not from the thresholded labels
auc = roc_auc_score(y_test, y_pred_soft)
f1 = f1_score(y_test, y_pred)
cf_mat = confusion_matrix(y_test, y_pred)

In [21]:
# report the classification metrics computed on the test split
print(f"Accuracy: {acc}")
print(f"AUC: {auc}")
print(f"F1 Score: {f1}")

Accuracy: 0.8256
AUC: 0.778323289898644
F1 Score: 0.8967314069161535


In [22]:
# display the confusion matrix (scikit-learn convention: rows = true label, columns = predicted label)
cf_mat

array([[ 1368,  2100],
       [ 1388, 15144]])

In [69]:
# ranking quality of the top-k recommendations, computed by the project helper
# (presumably aggregated per user -- confirm in utils.top_k_precision_recall)
top = 10
# user_id_col = item_label.shape[1] is the index of the first user column in X,
# since X is laid out as [item fields..., user fields...]
top_k_precision, top_k_recall = utils.top_k_precision_recall(
    model, X_train, X_test, y_test, item_label, top=top, user_id_col=item_label.shape[1])

print(f"Top K({top}) precision: {top_k_precision}")
print(f"Top K({top}) recall: {top_k_recall}")

Top K(10) precision: 0.013829787234042556
Top K(10) recall: 0.008143358090614682
