In [22]:
from collections import defaultdict
import torch
import math
import numpy as np
import pandas as pd

def HitRate(proba_all_topic_csv, n=10):
    """
    Compute per-user Hit Ratio@n and NDCG@n from per-(user, topic) scores.

    @args:
    proba_all_topic_csv: path/buffer of a CSV file, or an already-loaded
        pandas DataFrame, with columns {user_id, topic_id, was_interaction,
        predict_proba}: one row per (user, candidate topic) holding the
        predicted recommendation probability for that topic.
    n: keep the top n topic recommendations per user

    Each user must have exactly one row with was_interaction == 1 (the
    held-out positive topic); otherwise a ValueError is raised.

    @output:
    hit_list: (list) hit_list[i] == 1/0 indicates hit/miss for user i
    ndcg_list: (list) ndcg_list[i] is the NDCG score for user i
    Users are visited in sorted user_id order (pandas groupby default).
    Also prints the mean HR@n and NDCG@n (NaN when there are no users).
    """
    if isinstance(proba_all_topic_csv, pd.DataFrame):
        proba_all_topic_df = proba_all_topic_csv
    else:
        proba_all_topic_df = pd.read_csv(proba_all_topic_csv)

    hit_list = []
    ndcg_list = []
    # Scalar key (not ['user_id']) so pandas yields plain user ids, not 1-tuples.
    for user, group in proba_all_topic_df.groupby('user_id'):
        # Rank by probability, breaking ties by larger topic_id, then keep top n.
        ranked = group.sort_values(['predict_proba', 'topic_id'], ascending=False)
        top_n = ranked['topic_id'].head(n).tolist()

        positives = group.loc[group['was_interaction'] == 1, 'topic_id']
        if len(positives) != 1:
            # int(Series) used to fail opaquely here (and is deprecated in pandas 2.x).
            raise ValueError(
                f"user {user!r} has {len(positives)} positive topics; expected exactly 1"
            )
        positive_topic = int(positives.iloc[0])

        hit_list.append(getHitRatio(top_n, positive_topic))
        ndcg_list.append(getNDCG(top_n, positive_topic))

    # Avoid numpy's "mean of empty slice" warning when there are no users.
    mean_hr = float(np.mean(hit_list)) if hit_list else float('nan')
    mean_ndcg = float(np.mean(ndcg_list)) if ndcg_list else float('nan')
    print(f'HR@{n}: {mean_hr}, NDCG@{n}: {mean_ndcg}')
    return (hit_list, ndcg_list)


def getHitRatio(ranklist, topic):
    """
    @args:
    ranklist: ordered recommendation list (e.g. the top-N topic ids)
    topic: the topic id to look for

    @output:
    1 if any element of ranklist equals topic, otherwise 0
    """
    found = any(candidate == topic for candidate in ranklist)
    return 1 if found else 0

def getNDCG(ranklist, topic):
    """
    @args:
    ranklist: ordered recommendation list (e.g. the top-N topic ids)
    topic: the topic id to look for

    @output:
    log(2) / log(rank + 2) for the 0-based rank of the first element equal
    to topic (i.e. 1 / log2(position + 1) with a 1-based position), or 0 if
    topic does not appear in ranklist
    """
    match_rank = next(
        (rank for rank, candidate in enumerate(ranklist) if candidate == topic),
        None,
    )
    if match_rank is None:
        return 0
    return math.log(2) / math.log(match_rank + 2)

# Evaluate the saved predictions in this CSV with a top-5 cutoff. As the
# cell's last expression, the returned (hit_list, ndcg_list) tuple is displayed.
HitRate('ncf_64_predictive_factors_first_try_outputs.csv', n=5)

HR@5: 0.328125, NDCG@5: 0.2130844848571618


([1,
  0,
  0,
  0,
  0,
  0,
  0,
  1,
  0,
  0,
  1,
  1,
  0,
  0,
  1,
  0,
  0,
  0,
  1,
  1,
  1,
  1,
  0,
  1,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  1,
  0,
  0,
  0,
  0,
  0,
  1,
  1,
  0,
  1,
  0,
  0,
  1,
  1,
  0,
  0,
  0,
  1,
  0,
  0,
  0,
  0,
  0,
  1,
  0,
  0,
  1,
  1,
  0,
  0,
  1,
  0],
 [0.5,
  0,
  0,
  0,
  0,
  0,
  0,
  0.3868528072345416,
  0,
  0,
  1.0,
  0.6309297535714574,
  0,
  0,
  0.3868528072345416,
  0,
  0,
  0,
  0.43067655807339306,
  0.43067655807339306,
  0.43067655807339306,
  0.43067655807339306,
  0,
  0.43067655807339306,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  1.0,
  0,
  0,
  0,
  0,
  0,
  1.0,
  0.6309297535714574,
  0,
  1.0,
  0,
  0,
  0.6309297535714574,
  1.0,
  0,
  0,
  0,
  1.0,
  0,
  0,
  0,
  0,
  0,
  0.43067655807339306,
  0,
  0,
  0.3868528072345416,
  1.0,
  0,
  0,
  0.5,
  0])