Deep Factorization Machines
- as its name suggests, there is a FM component & a Deep component
- FM: models 2-way (pairwise) feature interactions
- Deep MLP: capture complex feature interactions & non-linearities
- similar to Wide & Deep, which combines a linear (wide) component with an MLP (deep) component, except that here the wide component is replaced by the FM component

The model architecture is nicely illustrated here (image from d2l.ai):

<img src="model_illustration.png" style="width:80%">

In [4]:
import numpy as np
import torch
import torch.nn as nn

In [None]:
class DeepFM(nn.Module):
    """DeepFM: a factorization machine and an MLP sharing one embedding table.

    The FM part contributes a bias, a first-order (per-feature weight) term and
    a second-order (pairwise embedding interaction) term. The deep part feeds
    the concatenated field embeddings through an MLP. Both logits are summed
    and squashed with a sigmoid.

    Args:
        n_fields: number of feature fields per sample (columns of the input).
        n_features: size of the shared embedding vocabulary.
        embed_dim: dimensionality of each feature embedding.
        mlp_dims: widths of the hidden MLP layers, in order.
        dropout: dropout probability applied after every hidden layer.

    NOTE(review): all fields index into a single embedding table, so feature
    ids are presumably offset per field by the caller -- confirm upstream.
    """

    def __init__(self, n_fields, n_features, embed_dim, mlp_dims=(128, 64), dropout=0.2):
        super().__init__()
        # Shared by the FM second-order term and the deep component.
        self.embedding = nn.Embedding(n_features, embed_dim)
        # First-order weights: one scalar per feature.
        self.fc = nn.Embedding(n_features, 1)
        # Single global bias; shape (1,) broadcasts against (batch, 1).
        self.bias = nn.Parameter(torch.zeros((1,)))
        self.mlp_input_dim = n_fields * embed_dim

        input_dim = self.mlp_input_dim
        mlp_layers = []
        for dim in mlp_dims:
            mlp_layers.append(nn.Linear(input_dim, dim))
            mlp_layers.append(nn.BatchNorm1d(dim))
            mlp_layers.append(nn.ReLU())
            mlp_layers.append(nn.Dropout(dropout))
            input_dim = dim
        # The output layer consumes the last hidden width (or the raw
        # concatenated embeddings when mlp_dims is empty).
        mlp_layers.append(nn.Linear(input_dim, 1))
        self.mlp = nn.Sequential(*mlp_layers)

    def _fm_logit(self, x, embed):
        """Compute the FM logit from indices ``x`` and their embeddings ``embed``."""
        first_order = self.fc(x).sum(dim=1) + self.bias
        # sum_{i<j} <v_i, v_j> = 0.5 * ((sum_i v_i)^2 - sum_i v_i^2), summed
        # over the embedding dimension.
        sq_sum = embed.sum(dim=1) ** 2
        sum_sq = (embed ** 2).sum(dim=1)
        second_order = 0.5 * (sq_sum - sum_sq).sum(dim=1, keepdim=True)
        return first_order + second_order

    def factorization_machine(self, x):
        """Return the FM logit (bias + first-order + second-order terms).

        Args:
            x: integer tensor of feature indices, expected shape
                (batch, n_fields).

        Returns:
            Tensor of shape (batch, 1).
        """
        return self._fm_logit(x, self.embedding(x))

    def forward(self, x):
        """Return the predicted probability for each sample.

        Args:
            x: integer tensor of feature indices, expected shape
                (batch, n_fields).

        Returns:
            Tensor of shape (batch, 1) with values in [0, 1].
        """
        # Look the embeddings up once and share them between both components.
        embed = self.embedding(x)
        fm_out = self._fm_logit(x, embed)
        mlp_input = embed.view(-1, self.mlp_input_dim)
        mlp_out = self.mlp(mlp_input)
        return torch.sigmoid(fm_out + mlp_out)

    def predict(self, x):
        """Run inference on ``x`` without tracking gradients.

        Switches the module to eval mode (affecting dropout and batch norm)
        and leaves it there; call ``train()`` before resuming training.
        """
        self.eval()
        with torch.no_grad():
            return self(x)
            
            
           
        

In [3]:
# Toy nested list: two identical samples, each with two rows of three values.
t = [[[1, 2, 3], [4, 5, 6]] for _ in range(2)]

# Flatten each sample's rows into one row of six values, as the
# `.view(-1, n_fields * embed_dim)` call does in the model.
torch.tensor(t).view(-1, 6)

tensor([[1, 2, 3, 4, 5, 6],
        [1, 2, 3, 4, 5, 6]])