In [None]:
import sys

# Make the project's source directory importable (presumably where the
# ``nn_models`` and ``utils`` modules used by this notebook live).
# Guarded so that re-running this cell does not keep appending duplicate
# entries to ``sys.path``.
if '../src/' not in sys.path:
    sys.path.append('../src/')

In [None]:
import json

import faiss
import numpy as np
import torch
from scipy import sparse as sp
from torch import nn
from tqdm.notebook import tqdm

from nn_models import UserModel
from utils import (
    ProductEncoder,
    coo_to_pytorch_sparse,
    make_coo_row,
    get_shard_path,
    normalized_average_precision,
)

In [None]:
class TorchPredictor:
    """Recommend products for one user with a user-embedding model and a FAISS index.

    A transaction history is encoded into a sparse input row, passed through
    ``UserModel`` to obtain a user vector, and the nearest entries of the FAISS
    index are mapped back to product ids via ``ProductEncoder.toPid``.
    """

    def __init__(self, product_csv_path, user_model_path, knn_index_path, dim):
        """Load the product encoder, the model weights and the k-NN index.

        Args:
            product_csv_path: Path handed to ``ProductEncoder``.
            user_model_path: Path to the ``state_dict`` saved for ``UserModel``.
            knn_index_path: Path to a serialized FAISS index.
            dim: Second constructor argument of ``UserModel`` (presumably the
                embedding size; it must match the checkpoint -- TODO confirm).
        """
        self.product_encoder = ProductEncoder(product_csv_path)

        user_model = UserModel(self.product_encoder.num_products, dim)
        # ``predict`` converts the output to NumPy, so the weights must live on
        # the CPU regardless of the device they were saved from.
        user_model.load_state_dict(torch.load(user_model_path, map_location="cpu"))
        # Inference only: switch off training-time behaviour (dropout etc.).
        user_model.eval()

        self.user_model = user_model
        self.knn_index = faiss.read_index(knn_index_path)

    def predict(self, trans_history, top_k=30):
        """Return recommended product ids for a single transaction history.

        Args:
            trans_history: Transaction history in the format accepted by
                ``make_coo_row`` (the evaluation loop passes the record's
                ``"transactions_history"`` field).
            top_k: Number of nearest neighbours requested from the index.

        Returns:
            Whatever ``ProductEncoder.toPid`` returns for the retrieved indices
            (presumably a list of product ids, nearest first -- TODO confirm).
        """
        # ``make_coo_row`` is a module-level helper imported from ``utils``,
        # not a method of this class.
        user_input_row = coo_to_pytorch_sparse(
            make_coo_row(trans_history, self.product_encoder)
        )

        # No gradients are needed for a forward-only pass.
        with torch.no_grad():
            user_vectors = self.user_model(user_input_row).detach().numpy()

        # Scale by the norm of the whole array; skip the division for an
        # all-zero output so the query is not turned into NaNs.
        norm = np.linalg.norm(user_vectors)
        if norm > 0:
            user_vectors /= norm

        # ``search`` returns (distances, labels); take the labels of the first
        # (and only) query row.
        preds = self.knn_index.search(user_vectors, top_k)[1][0]

        # FAISS pads missing neighbours with -1; index 0 is kept as a valid
        # label (assuming zero-based product indices -- TODO confirm).
        return self.product_encoder.toPid([x for x in preds if x >= 0])

In [None]:
# Locations of the product catalogue and of the 128-dimensional artifacts
# (model weights and k-NN index) that the predictor loads.
predictor_config = {
    'product_csv_path': '../data/raw/products.csv',
    'user_model_path': '../artifacts/embds_d128/user_model_cpu.pth',
    'knn_index_path': '../artifacts/embds_d128/knn.idx',
    'dim': 128,
}
predictor = TorchPredictor(**predictor_config)

In [None]:
# Score the predictor on shard 15: one JSON record per line, each holding the
# user's transaction history and the ground-truth products of the target
# transaction.
scores = []
# Open the shard in a context manager so the file handle is closed once the
# loop finishes (or fails).
with open(get_shard_path(15)) as shard_file:
    for line in tqdm(shard_file):
        js = json.loads(line)
        gt_items = js["target"][0]["product_ids"]
        recommended_items = predictor.predict(js["transactions_history"])
        scores.append(normalized_average_precision(gt_items, recommended_items))
print(np.mean(scores))