In [1]:
import torch
import torch.nn as nn


In [2]:
class CollaborativeFiltering(nn.Module):
    """Two-tower scoring model over user ids, item ids and item content vectors.

    The user tower maps a learned user embedding through a small feed-forward
    network. The item tower concatenates a learned item embedding with an
    externally supplied content vector and maps that through its own
    feed-forward network. Both towers project into the same ``output_dim``
    space, and ``forward`` returns the matrix of dot products between every
    user vector and every item vector in the batch.

    Args:
        num_users: Number of rows in the user embedding table.
        num_items: Number of rows in the item embedding table.
        content_embed_size: Width of the per-item content vector passed to
            ``forward``.
        embedding_dim: Width of the learned user and item embeddings.
        hidden_dim: Width of the hidden layer in both towers.
        output_dim: Width of the final projection produced by both towers.
    """

    def __init__(self, num_users, num_items, content_embed_size=12, embedding_dim=32, hidden_dim=32, output_dim=16):
        super().__init__()
        # Submodule names and construction order are kept as-is so that
        # state_dict keys and seeded parameter initialisation are unchanged.
        self.user_embedding = nn.Embedding(num_users, embedding_dim)
        self.item_embedding = nn.Embedding(num_items, embedding_dim)
        # Item tower input is [item embedding ‖ content vector].
        self.ffnn_item = nn.Sequential(
            nn.Linear(embedding_dim + content_embed_size, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, output_dim)
        )
        self.ffnn_user = nn.Sequential(
            nn.Linear(embedding_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, output_dim)
        )

    def forward(self, user_ids, item_ids, content_embed):
        """Score every user in ``user_ids`` against every item in ``item_ids``.

        Args:
            user_ids: Integer index tensor into the user embedding table.
                Assumed 1-D, shape ``(n_users_in_batch,)`` — confirm against
                the caller.
            item_ids: Integer index tensor into the item embedding table.
                Assumed 1-D, shape ``(n_items_in_batch,)`` — confirm against
                the caller.
            content_embed: Content vectors aligned row-for-row with
                ``item_ids``; concatenated along ``dim=1``, so it is expected
                to be ``(n_items_in_batch, content_embed_size)``.

        Returns:
            Tensor of shape ``(n_users_in_batch, n_items_in_batch)`` whose
            ``[u, i]`` entry is the dot product of the user-tower output for
            user ``u`` and the item-tower output for item ``i``.
        """
        user_embed = self.user_embedding(user_ids)
        item_embed = self.item_embedding(item_ids)
        item_content_embed = torch.cat([item_embed, content_embed], dim=1)

        item_scores = self.ffnn_item(item_content_embed)
        user_scores = self.ffnn_user(user_embed)

        # Both operands must live in the shared output_dim space. Multiplying
        # the raw user embedding (embedding_dim wide) against the item-tower
        # output (output_dim wide) fails whenever the two widths differ and
        # otherwise bypasses the user tower entirely.
        return torch.matmul(user_scores, item_scores.T)