In [None]:
# default_exp networks

In [None]:
# hide
from nbdev.showdoc import *

In [None]:
# export
from typing import Callable

import torch
import torch.nn as nn

from torch import tensor

# Networks

> Common neural network architectures for *Collaborative Filtering*.

# Overview

This package implements several neural network architectures that can be used to build recommendation systems. Users of the library can add or define their own implementations or use the existing ones. There are two layers that every architecture should define:

* **user_embeddings**: The user embedding matrix
* **item_embeddings**: The item embedding matrix

Every implementation should be a subclass of `torch.nn.Module`.

## Simple Collaborative Filtering

This is the simplest architecture for *Collaborative Filtering*. It defines only the embedding matrices for users and items; the predicted rating is the dot product of the corresponding user and item embedding rows.

In [None]:
# export
class SimpleCF(nn.Module):
    """Embedding-only model for collaborative filtering.

    Holds one embedding matrix for users and one for items; ratings are
    obtained as dot products between user and item embedding rows.

    Args:
        n_users: Number of rows of the user embedding matrix.
        n_items: Number of rows of the item embedding matrix.
        factors: Dimension of the embedding space.
        init: Initializer called on each embedding weight matrix
            (default: ``torch.nn.init.normal_``).
        **kwargs: Extra keyword arguments forwarded unchanged to ``init``
            (e.g. ``mean`` and ``std`` for ``torch.nn.init.normal_``).
    """

    def __init__(self, n_users: int, n_items: int, factors: int = 16,
                 init: Callable[..., torch.Tensor] = torch.nn.init.normal_, **kwargs):
        super().__init__()
        # Each matrix is initialised right after it is created so the order in
        # which random numbers are drawn stays: user matrix first, then items.
        self.user_embeddings = nn.Embedding(n_users, factors)
        # Initialise the parameter itself under no_grad rather than through
        # the discouraged `.data` attribute.
        with torch.no_grad():
            init(self.user_embeddings.weight, **kwargs)
        self.item_embeddings = nn.Embedding(n_items, factors)
        with torch.no_grad():
            init(self.item_embeddings.weight, **kwargs)

    def forward(self, u: torch.Tensor, i: torch.Tensor) -> torch.Tensor:
        """Score the users in ``u`` against the items in ``i``.

        Args:
            u: Tensor of user indices (rows of ``user_embeddings``).
            i: Tensor of item indices (rows of ``item_embeddings``).

        Returns:
            The matrix product of the looked-up user embeddings with the
            transposed item embeddings. For 1-D index tensors of length U and
            I this is a ``(U, I)`` tensor whose entry ``[a, b]`` is the dot
            product of the embeddings of user ``u[a]`` and item ``i[b]``.
        """
        user_embedding = self.user_embeddings(u)
        item_embedding = self.item_embeddings(i)
        # Swap the first two axes so the factor dimension lines up for matmul.
        rating = torch.matmul(user_embedding, item_embedding.transpose(0, 1))
        return rating

Arguments:

* n_users (int): The number of unique users
* n_items (int): The number of unique items
* factors (int): The dimension of the embedding space
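* init (torch.nn.init): The initialization function applied to both embedding matrices - default: `torch.nn.init.normal_`
* **kwargs: Extra keyword arguments forwarded to `init` (e.g. `mean` and `std` for `torch.nn.init.normal_`)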

In [None]:
# build a model for 100 users and 50 items embedded in a 16-dimensional space;
# `mean` and `std` are handed to the default initializer
model = SimpleCF(n_users=100, n_items=50, factors=16, mean=0., std=.1)

# indices are zero-based: index 2 is user 3 and index 32 is item 33
user_idx = torch.tensor([2])
item_idx = torch.tensor([32])
model(user_idx, item_idx)

tensor([[0.0055]], grad_fn=<MmBackward>)