In [None]:
# default_exp step

In [None]:
# hide
from nbdev.showdoc import *

In [None]:
# export
import torch
import numpy as np

from abc import ABCMeta
from abc import abstractmethod
from typing import Callable
from tqdm import tqdm

# Step

> *Incremental Collaborative Filtering* algorithms.

## Step Base

This class defines the interface that every `Step` model should implement. Namely, each subclass must implement five methods:

* `batch_fit`: To support batch training
* `step`: To support incremental learning
* `predict`: To offer recommendations
* `save`: To save the model parameters
* `load`: To load the model parameters

In [None]:
# export
class StepBase(metaclass=ABCMeta):
    """Defines the interface that all step models here expose.

    Instantiating ``StepBase`` itself, or any subclass that does not override
    every abstract method below, raises ``TypeError``.
    """
    # ABCMeta must be given via the Python 3 ``metaclass=`` keyword; the old
    # ``__metaclass__`` class attribute is ignored by Python 3, which left the
    # ``@abstractmethod`` markers unenforced.

    @abstractmethod
    def batch_fit(self, data_loader: torch.utils.data.DataLoader, epochs: int):
        """Trains the model on a batch of user-item interactions.

        Args:
            data_loader: Source of training batches.
            epochs: Number of passes over ``data_loader``.
        """

    @abstractmethod
    def step(self, user: torch.Tensor, item: torch.Tensor,
             rating: torch.Tensor, preference: torch.Tensor):
        """Trains the model incrementally.

        Args:
            user: User(s) of the new interaction(s).
            item: Item(s) of the new interaction(s).
            rating: Observed rating(s) for the interaction(s).
            preference: Target preference(s) for the interaction(s).
        """

    @abstractmethod
    def predict(self, user: torch.Tensor, k: int):
        """Recommends the top-k items to a specific user.

        Args:
            user: The user to recommend items to.
            k: Number of items to recommend.
        """

    @abstractmethod
    def save(self, path: str):
        """Saves the model parameters to the given path."""

    @abstractmethod
    def load(self, path: str):
        """Loads the model parameters from a given path."""

## Step

The `Step` class implements the basic *Incremental Collaborative Filtering* recommender system.

In [None]:
# export
class Step(StepBase):
    """Incremental and batch training of recommender systems.

    Wraps a model together with an objective and an optimizer. It offers:

    * batch training (``batch_fit``)
    * incremental updates (``step``)
    * top-k recommendation (``predict``)
    * parameter persistence (``save`` / ``load``)

    Args:
        model: Network called as ``model(users, items)``. It must expose
            ``user_embeddings`` and ``item_embeddings`` modules.
        objective: Called as ``objective(predictions, preferences)``.
        optimizer: Object exposing ``step()`` and ``zero_grad()``, e.g. a
            ``torch.optim`` optimizer built over ``model``'s parameters.
        conf_func: Maps ratings to confidence weights that scale the loss.
            The default returns a constant confidence of 1.

    Raises:
        AttributeError: If ``model`` has no user or item embedding module.
    """

    def __init__(self, model: torch.nn.Module, objective: Callable,
                 optimizer: Callable, conf_func: Callable = lambda x: 1):
        self.model = model
        self.objective = objective
        self.optimizer = optimizer
        self.conf_func = conf_func

        # Check if the user has provided user and item embeddings. This is an
        # explicit check rather than `assert`, which is stripped under
        # `python -O`, so a missing attribute reports this message instead
        # of torch's generic AttributeError.
        if getattr(self.model, 'user_embeddings', None) is None:
            raise AttributeError('User embedding matrix could not be found.')
        if getattr(self.model, 'item_embeddings', None) is None:
            raise AttributeError('Item embedding matrix could not be found.')

    @property
    def user_embeddings(self):
        """The wrapped model's ``user_embeddings`` module."""
        return self.model.user_embeddings

    @property
    def item_embeddings(self):
        """The wrapped model's ``item_embeddings`` module."""
        return self.model.item_embeddings

    def batch_fit(self, data_loader: torch.utils.data.DataLoader, epochs: int = 1):
        """Trains the model on a batch of user-item interactions.

        Each batch from ``data_loader`` must unpack as
        ``(users, items, ratings, preferences)``. The objective's output is
        weighted by ``conf_func(ratings)`` and averaged into a scalar loss
        before each optimizer step. A progress bar is shown per epoch.

        Args:
            data_loader: Iterable of ``(users, items, ratings, preferences)``
                batches that also supports ``len()``.
            epochs: Number of full passes over ``data_loader``.
        """
        self.model.train()
        for _ in range(epochs):
            with tqdm(total=len(data_loader)) as pbar:
                for users, items, ratings, preferences in data_loader:
                    predictions = self.model(users, items)
                    conf = self.conf_func(ratings)
                    loss = (conf * self.objective(predictions, preferences)).mean()
                    loss.backward()
                    self.optimizer.step()
                    self.optimizer.zero_grad()
                    pbar.update(1)

    def step(self, user: torch.Tensor, item: torch.Tensor,
             rating: torch.Tensor = None, preference: torch.Tensor = None):
        """Trains the model incrementally on the given interaction(s).

        Performs one forward pass, one backward pass and one optimizer step.

        Args:
            user: User(s) of the new interaction(s).
            item: Item(s) of the new interaction(s).
            rating: Passed to ``conf_func`` to weight the loss.
            preference: Target passed to ``objective``.
        """
        self.model.train()
        prediction = self.model(user, item)
        conf = self.conf_func(rating)
        # Reduce to a scalar as `batch_fit` does. `backward()` needs a scalar,
        # and an objective with `reduction='none'` returns one value per
        # element. For an already-scalar loss, `.mean()` changes nothing.
        loss = (conf * self.objective(prediction, preference)).mean()
        loss.backward()
        self.optimizer.step()
        self.optimizer.zero_grad()

    def predict(self, user: torch.Tensor, k: int = 10) -> torch.Tensor:
        """Recommends the top-k items to a specific user.

        Each item is scored by the dot product of its embedding with the
        user's embedding. The indices of the ``k`` highest-scoring items are
        returned, best first. If ``k`` exceeds the number of items, every
        item is returned; ``k == 0`` yields an empty tensor.

        Args:
            user: Index tensor for a single user. Presumably of shape ``(1,)``,
                so that its embedding is ``(1, dim)``; TODO confirm against
                callers.
            k: Number of items to recommend. Must be non-negative.

        Returns:
            1-D tensor of item indices sorted by descending score.

        Raises:
            ValueError: If ``k`` is negative.
        """
        if k < 0:
            raise ValueError(f'k must be non-negative, got {k}')
        self.model.eval()
        # Inference only: no autograd graph is needed.
        with torch.no_grad():
            user_embedding = self.user_embeddings(user)
            item_embeddings = self.item_embeddings.weight
            scores = (item_embeddings @ user_embedding.transpose(0, 1)).squeeze(-1)
            k = min(k, scores.size(-1))
            # `topk` puts the best item first. `argsort()[-k:]` returned the
            # top-k in ascending order, and returned every item for k == 0.
            predictions = torch.topk(scores, k).indices
        return predictions

    def save(self, path: str):
        """Saves the model parameters (its ``state_dict``) to the given path."""
        torch.save(self.model.state_dict(), path)

    def load(self, path: str, map_location=None):
        """Loads the model parameters from a given path.

        Args:
            path: File previously written by ``save``.
            map_location: Forwarded to ``torch.load``, e.g. ``'cpu'`` to load
                a GPU checkpoint on a CPU-only machine. Defaults to ``None``,
                which is ``torch.load``'s own default.

        NOTE: ``torch.load`` unpickles the file, so only load checkpoints
        from trusted sources.
        """
        self.model.load_state_dict(torch.load(path, map_location=map_location))

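Arguments:

* model (torch.nn.Module): The neural network architecture. It is called as `model(users, items)` and must expose `user_embeddings` and `item_embeddings` modules
* objective (Callable): The objective (loss) function, called as `objective(predictions, preferences)`
* optimizer (Callable): The method used to optimize the objective function. Usually a `torch.optim` optimizer (e.g. `torch.optim.SGD`) constructed over the model's parameters
* conf_func (Callable): A function that converts implicit ratings to confidence scores. The default assigns every interaction a confidence of 1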

In [None]:
# nbdev: render the signature and docstring of `Step.batch_fit` into the docs.
show_doc(Step.batch_fit)

<h4 id="Step.batch_fit" class="doc_header"><code>Step.batch_fit</code><a href="__main__.py#L23" class="source_link" style="float:right">[source]</a></h4>

> <code>Step.batch_fit</code>(**`data_loader`**:`DataLoader`, **`epochs`**:`int`=*`1`*)

Trains the model on a batch of user-item interactions.

In [None]:
# nbdev: render the signature and docstring of `Step.step` into the docs.
show_doc(Step.step)

<h4 id="Step.step" class="doc_header"><code>Step.step</code><a href="__main__.py#L37" class="source_link" style="float:right">[source]</a></h4>

> <code>Step.step</code>(**`user`**:`tensor`, **`item`**:`tensor`, **`rating`**:`tensor`=*`None`*, **`preference`**:`tensor`=*`None`*)

Trains the model incrementally.

In [None]:
# nbdev: render the signature and docstring of `Step.predict` into the docs.
show_doc(Step.predict)

<h4 id="Step.predict" class="doc_header"><code>Step.predict</code><a href="__main__.py#L48" class="source_link" style="float:right">[source]</a></h4>

> <code>Step.predict</code>(**`user`**:`tensor`, **`k`**:`int`=*`10`*)

Recommends the top-k items to a specific user.

In [None]:
# nbdev: render the signature and docstring of `Step.save` into the docs.
show_doc(Step.save)

<h4 id="Step.save" class="doc_header"><code>Step.save</code><a href="__main__.py#L57" class="source_link" style="float:right">[source]</a></h4>

> <code>Step.save</code>(**`path`**:`str`)

Saves the model parameters to the given path.

In [None]:
# nbdev: render the signature and docstring of `Step.load` into the docs.
show_doc(Step.load)

<h4 id="Step.load" class="doc_header"><code>Step.load</code><a href="__main__.py#L61" class="source_link" style="float:right">[source]</a></h4>

> <code>Step.load</code>(**`path`**:`str`)

Loads the model parameters from a given path.