In [1]:
import torch
from tqdm.auto import tqdm
import itertools
import numpy as np
from torch.utils.data import DataLoader
import itertools

from recoxplainer.utils.torch_utils import use_optimizer
from recoxplainer.data_reader import DataReader
from recoxplainer.config import cfg
from recoxplainer.models import PyTorchModel
from recoxplainer.data_reader.user_item_rating_dataset import UserItemRatingDataset
from recoxplainer.recommender import Recommender 
from recoxplainer.evaluator import Splitter, Evaluator, ExplanationEvaluator
from recoxplainer.explain import ARPostHocExplainer, KNNPostHocExplainer

In [31]:
class MFModel(torch.nn.Module):
    """Matrix-factorisation rating model.

    A (user, item) pair is scored as the dot product of the user's and the
    item's embedding vectors. The embedding tables are created in
    :meth:`fit`, because their sizes are read from the dataset metadata;
    calling :meth:`forward` or :meth:`predict` before :meth:`fit` therefore
    fails with an ``AttributeError``.
    """

    def __init__(self,
                 learning_rate: float,
                 weight_decay: float,
                 latent_dim: int,
                 epochs: int,
                 batch_size: int,
                 device_id=None):
        """Store the training hyperparameters.

        Args:
            learning_rate: Step size handed to the SGD optimizer.
            weight_decay: Weight-decay coefficient handed to the SGD optimizer.
            latent_dim: Size of each user/item embedding vector.
            epochs: Number of passes over the training data in :meth:`fit`.
            batch_size: Mini-batch size used by the training ``DataLoader``.
            device_id: Accepted but currently unused by this class.
        """
        super().__init__()

        self.weight_decay = weight_decay
        self.latent_dim = latent_dim
        self.learning_rate = learning_rate
        self.epochs = epochs
        self.batch_size = batch_size

        self.criterion = torch.nn.MSELoss()

    def forward(self, user_indices, item_indices):
        """Return the dot product of user and item embeddings.

        Both arguments are integer index tensors. The element-wise product is
        summed over dimension 1, so for 1-D index tensors the result holds one
        score per (broadcast) row.
        """
        user_embeddings = self.embedding_user(user_indices)
        item_embeddings = self.embedding_item(item_indices)
        return (user_embeddings * item_embeddings).sum(1)

    def fit(self, dataset_metadata):
        """Train the embeddings with SGD on an MSE objective.

        Reads ``num_user``, ``num_item`` and ``dataset`` (with ``userId``,
        ``itemId`` and ``rating`` columns) from ``dataset_metadata``. Every
        call creates fresh embedding tables and a fresh optimizer, so
        re-fitting discards previously learned weights.

        Returns:
            ``True`` once all epochs have run.
        """
        self.embedding_user = torch.nn.Embedding(
            num_embeddings=dataset_metadata.num_user,
            embedding_dim=self.latent_dim)

        self.embedding_item = torch.nn.Embedding(
            num_embeddings=dataset_metadata.num_item,
            embedding_dim=self.latent_dim)

        self.optimizer = torch.optim.SGD(self.parameters(),
                                         lr=self.learning_rate,
                                         weight_decay=self.weight_decay)

        dataset = UserItemRatingDataset(user_tensor=torch.LongTensor(dataset_metadata.dataset.userId),
                                        item_tensor=torch.LongTensor(dataset_metadata.dataset.itemId),
                                        target_tensor=torch.FloatTensor(dataset_metadata.dataset.rating))

        # With shuffle=True the loader reshuffles each time it is iterated,
        # so one loader can serve every epoch.
        loader = DataLoader(dataset, batch_size=self.batch_size, shuffle=True)

        with tqdm(total=self.epochs) as progress:
            for epoch in range(self.epochs):

                tot_loss = 0.0
                cnt = 0
                for batch in loader:

                    self.optimizer.zero_grad()

                    user, item, rating = batch[0], batch[1], batch[2]

                    prediction = self(user, item)
                    loss = self.criterion(prediction, rating)
                    loss.backward()
                    self.optimizer.step()
                    tot_loss += loss.item()
                    cnt += 1

                progress.update(1)
                # Guard the average so an epoch that produced no batches
                # reports NaN instead of raising ZeroDivisionError.
                mean_loss = tot_loss / cnt if cnt else float("nan")
                progress.set_postfix({"MSE": mean_loss})

        return True

    def predict(self, user_id, item_id):
        """Score user/item pairs without tracking gradients.

        Args:
            user_id: A single user index or a sequence/array of indices.
            item_id: A single item index or a sequence/array of indices.

        A single user is broadcast against all given items; equal-length
        sequences are scored pair-wise.

        Returns:
            A flat Python list of float scores.
        """
        # Normalise both inputs to 1-D index tensors. reshape(-1) turns a
        # scalar into a length-1 tensor and leaves a 1-D sequence unchanged,
        # so ids are never mistaken for a tensor size and no extra leading
        # dimension is introduced.
        users = torch.as_tensor(user_id, dtype=torch.long).reshape(-1)
        items = torch.as_tensor(item_id, dtype=torch.long).reshape(-1)
        with torch.no_grad():
            return self.forward(users, items).cpu().tolist()


In [32]:
# Hyperparameters are passed by keyword, listed in the order of MFModel.__init__.
mf = MFModel(
    learning_rate=0.01,
    weight_decay=0.001,
    latent_dim=100,
    epochs=100,
    batch_size=128,
)

In [33]:
sp = Splitter()

# Read the dataset described by the `ml100k` config entry and remap its ids
# to consecutive values before splitting.
data = DataReader(**cfg.ml100k)
data.make_consecutive_ids_in_dataset()

train, test = sp.split_leave_n_out(data, frac=0.1)

In [34]:
# Train the model on the training split. `fit` returns True when it finishes,
# which is the value echoed as this cell's output.
mf.fit(train)

  0%|          | 0/100 [00:00<?, ?it/s]

True

In [35]:
# Build the recommender for the trained model and keep only what
# `recommend_all()` returns under the name `rec`.
rec = Recommender(train, mf).recommend_all()

Recommending for users:   0%|          | 0/943 [00:00<?, ?it/s]

In [36]:
# Evaluate the recommendations against the held-out split and display the hit ratio.
evaluator = Evaluator(test)
hit_ratio = evaluator.cal_hit_ratio(rec)
hit_ratio

0.009489919311771776

In [37]:
# Build a post-hoc explainer from the trained model, the recommendations and
# the training split, then compute explanations for the recommendations.
ar = ARPostHocExplainer(mf, rec, train)
expl_ar = ar.explain_recommendations()

Computing explanations:   0%|          | 0/9430 [00:00<?, ?it/s]

In [38]:
# Explanation evaluator, constructed from `train.num_user`.
expl_eval = ExplanationEvaluator(train.num_user)

In [39]:
# Display the model-fidelity score for the ARPostHocExplainer explanations.
expl_eval.model_fidelity(expl_ar)

0.14209968186638347