# Deep Factorization Machines

For real-world data, where the underlying feature-crossing structures are usually complex and nonlinear, the second-order feature interactions typically used in factorization machines are often insufficient. Modeling higher-order feature combinations with factorization machines is theoretically possible, but it is rarely done in practice because of numerical instability and high computational complexity. One effective solution is to use deep neural networks.

Deep neural networks are powerful at feature representation learning and can learn sophisticated feature interactions, so it is natural to integrate them with factorization machines. Adding nonlinear transformation layers gives the model the capacity to capture both low-order and high-order feature combinations, as well as nonlinear structure inherent in the inputs. In this notebook we train a representative model, the deep factorization machine (DeepFM) [Guo et al., 2017], which combines FM and deep neural networks.

DeepFM consists of an FM component and a deep component integrated in a parallel structure. The FM component is the same as the 2-way factorization machine and models low-order feature interactions. The deep component is a multilayer perceptron that captures high-order feature interactions and nonlinearities. The two components share the same inputs and embeddings, and their outputs are summed to form the final prediction. The spirit of DeepFM resembles that of the Wide & Deep architecture, which captures both memorization and generalization. The advantage of DeepFM over Wide & Deep is that it reduces the effort of hand-crafted feature engineering by identifying feature combinations automatically.
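
Written out, the parallel structure amounts to the equations below. The notation is chosen here for illustration and follows the usual FM conventions: $x \in \mathbb{R}^d$ is the sparse input, $w_0$ and $w_i$ are the bias and first-order weights, $\mathbf{v}_i$ is the embedding of feature $i$, $\hat{y}^{(DNN)}$ is the scalar output of the multilayer perceptron applied to the concatenated embeddings, and $\sigma$ is the sigmoid function.

$$
\hat{y}^{(FM)}(x) = w_0 + \sum_{i=1}^{d} w_i x_i + \sum_{i=1}^{d} \sum_{j=i+1}^{d} \langle \mathbf{v}_i, \mathbf{v}_j \rangle x_i x_j
$$

$$
\hat{y} = \sigma\left(\hat{y}^{(FM)} + \hat{y}^{(DNN)}\right)
$$

In the PyTorch implementation later in this notebook, the first-order part, the pairwise part, and the MLP output are computed by separate modules and added together before the sigmoid.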

![](https://drive.google.com/uc?id=1KXC_8TRNC5Dj1w_NyDfyagzxAnbQDABb)



* Integrating neural networks with FM enables it to model complex, high-order interactions.

* DeepFM outperforms the original FM on the advertising dataset.


# Model implementation in PyTorch

In [None]:
import numpy as np
import pandas as pd
import tqdm
from sklearn.metrics import roc_auc_score

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

In [None]:
class MovieLensDataset(Dataset):
    """
        MovieLens 1M Dataset
        Data preparation: treat samples with a rating of 3 or lower as negative samples
        :param dataset_path: MovieLens dataset path
    """

    def __init__(self, dataset_path, sep='::', engine='python', header=None):
        # Read the data into a Pandas dataframe
        data = pd.read_csv(dataset_path, sep=sep, engine=engine, header=header).to_numpy()[:, :3]

        # Retrieve the items and ratings data
        self.items = data[:, :2].astype(np.int64) - 1  # -1 because IDs begin at 1 (np.int was removed from recent NumPy)
        self.targets = self.__preprocess_target(data[:, 2]).astype(np.float32)

        # Get the range of the items
        self.field_dims = np.max(self.items, axis=0) + 1

        # Initialize NumPy arrays to store user and item field indices
        self.user_field_idx = np.array((0,), dtype=np.int64)
        self.item_field_idx = np.array((1,), dtype=np.int64)

    def __len__(self):
        """
        :return: number of total ratings
        """
        return self.targets.shape[0]

    def __getitem__(self, index):
        """
        :param index: current index
        :return: the items and ratings at current index
        """
        return self.items[index], self.targets[index]

    def __preprocess_target(self, target):
        """
        Preprocess the ratings into negative and positive samples
        :param target: ratings
        :return: binary ratings (0 or 1)
        """
        target[target <= 3] = 0  # ratings less than or equal to 3 classified as 0
        target[target > 3] = 1  # ratings bigger than 3 classified as 1
        return target
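
To make the preprocessing concrete, here is a small sketch on made-up (user, item, rating) rows; the toy values are assumptions for illustration, not MovieLens data. It mirrors the steps in the class above: shift IDs to start at 0, derive `field_dims` from the maximum ID per column, and binarize ratings with the same threshold of 3. Unlike `__preprocess_target`, which overwrites the ratings array in place, the sketch uses a boolean comparison on a copy.

```python
import numpy as np

# Toy (user_id, item_id, rating) rows; values are made up for illustration.
data = np.array([
    [1, 10, 5],
    [1, 12, 3],
    [2, 10, 4],
    [3, 11, 1],
])

items = data[:, :2].astype(np.int64) - 1      # zero-based user and item IDs
field_dims = np.max(items, axis=0) + 1        # vocabulary size per field

ratings = data[:, 2].copy()                   # copy so `data` is not mutated
targets = (ratings > 3).astype(np.float32)    # 1 if rating > 3, else 0

print(field_dims.tolist())  # [3, 12]
print(targets.tolist())     # [1.0, 0.0, 1.0, 0.0]
```

Note that `field_dims` is derived from the largest ID seen, so any unused IDs below the maximum still get a row in the embedding table.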

In [None]:
class DeepFM(nn.Module):
    """
    A PyTorch implementation of the Deep Factorization Machine (DeepFM)
    """

    def __init__(self, field_dims, embed_dim, mlp_dims, dropout):
        super().__init__()
        # First-order (linear) term of the FM component
        self.linear = FeaturesLinear(field_dims)
        # Second-order (pairwise interaction) term of the FM component
        self.fm = FactorizationMachine(reduce_sum=True)
        # Embeddings shared by the FM and deep components
        self.embedding = FeaturesEmbedding(field_dims, embed_dim)
        self.embed_output_dim = len(field_dims) * embed_dim
        # Deep component applied to the concatenated embeddings
        self.mlp = MultiLayerPerceptron(self.embed_output_dim, mlp_dims, dropout)

    def forward(self, x):
        """
        :param x: Long tensor of size ``(batch_size, num_fields)``
        :return: Float tensor of size ``(batch_size,)`` with predicted probabilities
        """
        embed_x = self.embedding(x)
        x = self.linear(x) + self.fm(embed_x) + self.mlp(embed_x.view(-1, self.embed_output_dim))
        return torch.sigmoid(x.squeeze(1))

In [None]:
class FeaturesLinear(torch.nn.Module):
    """
    Class to perform a linear transformation on the features
    """

    def __init__(self, field_dims, output_dim=1):
        super().__init__()
        self.fc = torch.nn.Embedding(sum(field_dims), output_dim)
        self.bias = torch.nn.Parameter(torch.zeros((output_dim,)))
        # Start index of each field inside the shared table (np.long was removed from recent NumPy)
        self.offsets = np.array((0, *np.cumsum(field_dims)[:-1]), dtype=np.int64)

    def forward(self, x):
        """
        :param x: Long tensor of size ``(batch_size, num_fields)``
        """
        x = x + x.new_tensor(self.offsets).unsqueeze(0)
        return torch.sum(self.fc(x), dim=1) + self.bias


class FeaturesEmbedding(torch.nn.Module):
    """
    Class to get feature embeddings
    """

    def __init__(self, field_dims, embed_dim):
        super().__init__()
        self.embedding = torch.nn.Embedding(sum(field_dims), embed_dim)
        # Start index of each field inside the shared embedding table
        self.offsets = np.array((0, *np.cumsum(field_dims)[:-1]), dtype=np.int64)
        torch.nn.init.xavier_uniform_(self.embedding.weight.data)

    def forward(self, x):
        """
        :param x: Long tensor of size ``(batch_size, num_fields)``
        """
        x = x + x.new_tensor(self.offsets).unsqueeze(0)
        return self.embedding(x)


class MultiLayerPerceptron(torch.nn.Module):
    """
    Class to instantiate a Multilayer Perceptron model
    """

    def __init__(self, input_dim, embed_dims, dropout, output_layer=True):
        super().__init__()
        layers = list()
        for embed_dim in embed_dims:
            layers.append(torch.nn.Linear(input_dim, embed_dim))
            layers.append(torch.nn.BatchNorm1d(embed_dim))
            layers.append(torch.nn.ReLU())
            layers.append(torch.nn.Dropout(p=dropout))
            input_dim = embed_dim
        if output_layer:
            layers.append(torch.nn.Linear(input_dim, 1))
        self.mlp = torch.nn.Sequential(*layers)

    def forward(self, x):
        """
        :param x: Float tensor of size ``(batch_size, input_dim)``
        """
        return self.mlp(x)


class FactorizationMachine(torch.nn.Module):
    """
        Class to instantiate a Factorization Machine model
    """

    def __init__(self, reduce_sum=True):
        super().__init__()
        self.reduce_sum = reduce_sum

    def forward(self, x):
        """
        :param x: Float tensor of size ``(batch_size, num_fields, embed_dim)``
        """
        square_of_sum = torch.sum(x, dim=1) ** 2
        sum_of_square = torch.sum(x ** 2, dim=1)
        ix = square_of_sum - sum_of_square
        if self.reduce_sum:
            ix = torch.sum(ix, dim=1, keepdim=True)
        return 0.5 * ix
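
`FactorizationMachine.forward` relies on the identity $\sum_{i<j} \langle \mathbf{v}_i, \mathbf{v}_j \rangle = \frac{1}{2}\left(\lVert \sum_i \mathbf{v}_i \rVert^2 - \sum_i \lVert \mathbf{v}_i \rVert^2\right)$, which replaces the quadratic loop over field pairs with two sums. The sketch below checks this numerically on random tensors; `fm_fast` and `fm_naive` are hypothetical helper names introduced only for this comparison.

```python
import torch

def fm_fast(x):
    """Second-order term via the square-of-sum trick; x is (batch, fields, dim)."""
    square_of_sum = torch.sum(x, dim=1) ** 2
    sum_of_square = torch.sum(x ** 2, dim=1)
    return 0.5 * torch.sum(square_of_sum - sum_of_square, dim=1, keepdim=True)

def fm_naive(x):
    """Same quantity computed as an explicit sum over field pairs i < j."""
    num_fields = x.shape[1]
    total = torch.zeros(x.shape[0], 1)
    for i in range(num_fields):
        for j in range(i + 1, num_fields):
            # Inner product between the embeddings of fields i and j
            total += torch.sum(x[:, i] * x[:, j], dim=1, keepdim=True)
    return total

torch.manual_seed(0)
x = torch.randn(4, 5, 8)  # 4 samples, 5 fields, embedding size 8
print(torch.allclose(fm_fast(x), fm_naive(x), atol=1e-4))
```

The fast version costs time linear in the number of fields, while the naive loop is quadratic; with only two fields (user and item) the difference is negligible, but it matters for datasets with many fields.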

In [None]:
def fit(model, optimizer, data_loader, criterion, device, log_interval=1000):
    """
    Train the model
    :param model: choice of model
    :param optimizer: choice of optimizer
    :param data_loader: data loader class
    :param criterion: choice of loss function
    :param device: choice of device
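    :param log_interval: number of batches between loss printouts
    :return: None; the average loss is printed every ``log_interval`` batches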
    """
    # Step into train mode
    model.train()
    total_loss = 0
    for i, (fields, target) in enumerate(tqdm.tqdm(data_loader, smoothing=0, mininterval=1.0)):
        fields, target = fields.to(device), target.to(device)
        y = model(fields)
        loss = criterion(y, target.float())
        model.zero_grad()
        loss.backward()
        optimizer.step()
        total_loss += loss.item()

        # Log the average loss over the last log_interval batches
        if (i + 1) % log_interval == 0:
            print('    - loss:', total_loss / log_interval)
            total_loss = 0

def test(model, data_loader, device):
    """
    Evaluate the model
    :param model: choice of model
    :param data_loader: data loader class
    :param device: choice of device
    :return: AUC score
    """
    # Step into evaluation mode
    model.eval()
    targets, predicts = list(), list()
    with torch.no_grad():
        for fields, target in tqdm.tqdm(data_loader, smoothing=0, mininterval=1.0):
            fields, target = fields.to(device), target.to(device)
            y = model(fields)
            targets.extend(target.tolist())
            predicts.extend(y.tolist())

    # Return AUC score between predicted ratings and actual ratings
    return roc_auc_score(targets, predicts)
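
For intuition about the metric returned by `test`, the AUC for binary labels can be read as the probability that a randomly chosen positive sample is scored higher than a randomly chosen negative one. The sketch below computes it by brute force over all pairs; `pairwise_auc` is a hypothetical helper written for illustration, and `roc_auc_score` should still be used in practice because the pairwise version needs memory proportional to the number of positives times the number of negatives.

```python
import numpy as np

def pairwise_auc(targets, scores):
    """AUC as the fraction of (positive, negative) pairs ranked correctly."""
    targets = np.asarray(targets)
    scores = np.asarray(scores, dtype=np.float64)
    pos = scores[targets == 1]
    neg = scores[targets == 0]
    # Compare every positive score with every negative score.
    diff = pos[:, None] - neg[None, :]
    wins = np.sum(diff > 0) + 0.5 * np.sum(diff == 0)  # ties count as half
    return wins / (len(pos) * len(neg))

print(pairwise_auc([0, 0, 1, 1], [0.1, 0.4, 0.35, 0.8]))  # 0.75
```

In the toy example, three of the four positive-negative pairs are ordered correctly, which gives 0.75. Because AUC depends only on the ranking of scores, it is unaffected by the classification threshold.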

In [None]:
# get the data
!wget http://files.grouplens.org/datasets/movielens/ml-1m.zip
!unzip ml-1m.zip

--2021-02-09 06:12:18--  http://files.grouplens.org/datasets/movielens/ml-1m.zip
Resolving files.grouplens.org (files.grouplens.org)... 128.101.65.152
Connecting to files.grouplens.org (files.grouplens.org)|128.101.65.152|:80... connected.
HTTP request sent, awaiting response... 200 OK
Length: 5917549 (5.6M) [application/zip]
Saving to: ‘ml-1m.zip’


2021-02-09 06:12:18 (27.2 MB/s) - ‘ml-1m.zip’ saved [5917549/5917549]

Archive:  ml-1m.zip
   creating: ml-1m/
  inflating: ml-1m/movies.dat        
  inflating: ml-1m/ratings.dat       
  inflating: ml-1m/README            
  inflating: ml-1m/users.dat         


In [None]:
# Get the dataset
dataset = MovieLensDataset('./ml-1m/ratings.dat')

# Split the data into 80% train, 10% validation, and 10% test
train_length = int(len(dataset) * 0.8)
valid_length = int(len(dataset) * 0.1)
test_length = len(dataset) - train_length - valid_length
train_dataset, valid_dataset, test_dataset = torch.utils.data.random_split(dataset, (train_length, valid_length, test_length))

# Instantiate data loader classes for train, validation, and test sets
train_data_loader = DataLoader(train_dataset, batch_size=1024, num_workers=8)
valid_data_loader = DataLoader(valid_dataset, batch_size=1024, num_workers=8)
test_data_loader = DataLoader(test_dataset, batch_size=1024, num_workers=8)

In [None]:
# Get the model
field_dims = dataset.field_dims
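learning_rate = 0.001
weight_decay = 1e-6
epoch = 10
device = 'cpu'

# Move the model to the chosen device so it matches the batches moved in fit() and test()
model = DeepFM(field_dims, embed_dim=16, mlp_dims=(16, 16), dropout=0.5).to(device)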
# Use binary cross entropy loss
criterion = torch.nn.BCELoss()
# Use Adam optimizer
optimizer = torch.optim.Adam(params=model.parameters(), lr=learning_rate, weight_decay=weight_decay)

# Loop through pre-defined number of epochs
for epoch_i in range(epoch):
    # Perform training on the train set
    fit(model, optimizer, train_data_loader, criterion, device)
    # Perform evaluation on the validation set
    valid_auc = test(model, valid_data_loader, device)
    # Log the epochs and AUC on the validation set
    print('epoch:', epoch_i, 'validation: auc:', valid_auc)

# Perform evaluation on the test set
test_auc = test(model, test_data_loader, device)
# Log the final AUC on the test set
print('test auc:', test_auc)

# Save the model checkpoint
torch.save(model.state_dict(), 'deepfm.pt')



epoch: 0 validation: auc: 0.7716942366617788
epoch: 1 validation: auc: 0.778057202956051
epoch: 2 validation: auc: 0.7817825948637652
epoch: 3 validation: auc: 0.7823289699267744
epoch: 4 validation: auc: 0.7844714420679844
epoch: 5 validation: auc: 0.7857398733393367
epoch: 6 validation: auc: 0.7858337378842934
epoch: 7 validation: auc: 0.7866237857522
epoch: 8 validation: auc: 0.7879520626528944
epoch: 9 validation: auc: 0.7886707771358512
test auc: 0.7920102605357593
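
Since the training cell saves only the `state_dict`, reusing the checkpoint requires rebuilding the same architecture first and then loading the parameters into it. The sketch below shows that round trip with a small stand-in network so that it runs on its own; `build_model` and the temporary path are assumptions for illustration. In this notebook you would construct `DeepFM` with the same hyperparameters and load `'deepfm.pt'` instead.

```python
import os
import tempfile

import torch
import torch.nn as nn

def build_model():
    # Stand-in for DeepFM(field_dims, embed_dim=16, mlp_dims=(16, 16), dropout=0.5)
    return nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 1))

trained = build_model()
path = os.path.join(tempfile.mkdtemp(), 'checkpoint.pt')  # hypothetical path
torch.save(trained.state_dict(), path)

# Rebuild the same architecture, then load the saved parameters into it.
restored = build_model()
restored.load_state_dict(torch.load(path, map_location='cpu'))
restored.eval()  # evaluation mode (matters for Dropout/BatchNorm in DeepFM)

x = torch.randn(2, 4)
with torch.no_grad():
    same = torch.allclose(trained.eval()(x), restored(x))
print(same)
```

Loading fails with a key or shape mismatch if the rebuilt model differs from the saved one, so `field_dims`, `embed_dim`, and `mlp_dims` need to match the values used at training time.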



