In [1]:
from __future__ import absolute_import, division, print_function
import json
from collections import Counter
import os
import argparse
from math import log
from tqdm.auto import tqdm
from easydict import EasyDict as edict
import torch
from functools import reduce
from kg_env_m import BatchKGEnvironment
from actor_critic import ActorCritic
from utils import *
import wandb

class Agrument:
    config = "../../config/beauty/graph_reasoning/UPGPR.json"
    seed = 0 
    set_name = "test_cold_start"
    domain = 'Beauty'
    
args = Agrument() #parser.parse_args()
    
with open(args.config, "r") as f:
    config = edict(json.load(f))

config.processed_data_dir = "../../data/beauty/Amazon_Beauty_01_01"



In [2]:
###load_user_pref
def load_user_pref(path, domain):
    user_pref_path = os.path.join(path)
    # Load JSON data from a file
    user_pref = json.load(open(f'{user_pref_path}/user_preference_{domain}.json', 'r'))
    return user_pref


if args.domain is not None:
    user_pref = load_user_pref(config.processed_data_dir, args.domain)

In [3]:
from make_cold_start_kg import InitalUserEmbedding
if args.set_name in ['test', 'test_cold_start', 'test_cold_start_trim', 'test_cold_start_trim_past_other']:
    set_name = 'test'
    
cold_start_uids = {}

init_embed = InitalUserEmbedding(
    set_name=set_name,
    config=config
)

Load embedding: ../../data/beauty/Amazon_Beauty_01_01/test_transe_embed.pkl


In [4]:
# embeds  = load_embed(config.processed_data_dir, set_name)
# dataset = load_dataset(config.processed_data_dir, set_name)

for idx in tqdm(range(len(user_pref))):
    user_id = user_pref[str(idx)]['idx_user']
    target_item = user_pref[str(idx)]['idx_item']
    user_acc_feature = user_pref[str(idx)]['user_acc_feature']
    user_rej_feature = user_pref[str(idx)]['user_rej_feature']
    user_rej_items = user_pref[str(idx)]['user_rej_items']
    
    user_preferred = init_embed.user_preference_config(
        user_acc_feature = user_acc_feature, 
        user_rej_feature = user_rej_feature, 
        user_rej_items = user_rej_items, 
    )
    
    user_pref_emb = init_embed.embeds['user'][user_pref[str(idx)]['idx_user']]
    
    idx_cand_user, cand_user_emb = init_embed.distance(user_pref_emb, top_k=6)
    user_preferred['related_user'] = idx_cand_user
    cold_start_uids[user_pref[str(idx)]['idx_user']] = user_preferred
    # break

print('all_user_pref', len(cold_start_uids))

  0%|          | 0/2251 [00:00<?, ?it/s]

all_user_pref 2079


In [9]:
# cold_start_uids[867]['non-purchase']

In [10]:
config.seed = args.seed
config_agent = config.AGENT
config.processed_data_dir = '../../data/beauty/Amazon_Beauty_01_01'

os.environ["CUDA_VISIBLE_DEVICES"] = config_agent.gpu
config_agent.device = torch.device("cuda:0") if torch.cuda.is_available() else "cpu"


if config_agent.early_stopping == True:
    with open("early_stopping.txt", "r") as f:
        config_agent.epochs = int(f.read())

config_agent.log_dir = config.processed_data_dir + "/" + config_agent.name

In [11]:

####test(config, args.set_name)
set_name = args.set_name
config_agent = config.AGENT
kg_config = config.KG_ARGS

policy_file = config_agent.log_dir + "/tmp_policy_model_epoch_{}.ckpt".format(
    config_agent.epochs
)
if set_name == 'test':
    path_file = config_agent.log_dir + "/policy_paths_epoch_{}.pkl".format(
        config_agent.epochs
    )
elif set_name == 'test_cold_start':
    path_file = config_agent.log_dir + "/policy_paths_epoch_{}_cold_start.pkl".format(
        config_agent.epochs
    )
elif set_name == 'test_cold_start_trim':
    path_file = config_agent.log_dir + "/policy_paths_epoch_{}_cold_start_trim.pkl".format(
        config_agent.epochs
    )
elif set_name == 'test_cold_start_trim_past_other':
    path_file = config_agent.log_dir + "/policy_paths_epoch_{}_cold_start_trim_past_other.pkl".format(
        config_agent.epochs
    )

train_labels = load_labels(config.processed_data_dir, "train")
test_labels = load_labels(config.processed_data_dir, "test")

dataset_name = config.processed_data_dir.split("/")[-1]

train_labels = load_labels(config.processed_data_dir, "train")
test_labels = load_labels(config.processed_data_dir, "test")

dataset_name = config.processed_data_dir.split("/")[-1]

model_name = (
    "UPGPR_len_"
    + str(config_agent.max_path_len)
    + "_"
    + config.AGENT.reward
    + "_"
    + config.TRAIN_EMBEDS.cold_start_embeddings
    + "_mask_"
    + str(config.AGENT.mask_first_interaction)
    + "_max_cold_concept_"
    + str(kg_config.max_nb_cold_entities)
    + "_topk_"
    + "_".join(map(str, config_agent.topk))
)

config_agent.result_file_dir = os.path.join(
    config_agent.result_file_dir, dataset_name, model_name, str(config.seed)
)

os.makedirs(
    config_agent.result_file_dir,
    exist_ok=True,
)

In [7]:
# if config_agent.run_path:
#     predict_paths(
#         policy_file, 
#         path_file, config, 
#         config_agent, 
#         kg_config
#     )
set_name = 'test'

In [8]:
from test_agent import batch_beam_search, batch_beam_search_cold_start

print("Predicting paths...")
env = BatchKGEnvironment(
    config.processed_data_dir,
    kg_config,
    set_name=set_name,
    max_acts=config_agent.max_acts,
    max_path_len=config_agent.max_path_len,
    state_history=config_agent.state_history,
    reward_function=config_agent.reward,
    mask_first_interaction=True,
    use_pattern=config_agent.use_pattern,
)
pretrain_sd = torch.load(policy_file, map_location=torch.device("cpu"))
model = ActorCritic(
    env.state_dim,
    env.act_dim,
    gamma=config_agent.gamma,
    hidden_sizes=config_agent.hidden,
    modified_policy=config_agent.modified_policy,
    embed_size=env.embed_size,
).to(config_agent.device)
model_sd = model.state_dict()
model_sd.update(pretrain_sd)
model.load_state_dict(model_sd)

if set_name in ['test', 'test_cold_start', 'test_cold_start_trim', 'test_cold_start_trim_past_other']:
    test_labels = load_labels(config.processed_data_dir, "test")
else:
    test_labels = load_labels(config.processed_data_dir, set_name)
test_uids = list(test_labels.keys())

Predicting paths...
Orginal kg.G.keys dict_keys(['user', 'item', 'brand', 'category', 'related_product', 'word', 'feature'])
delete kg.G['feaure']
Updated kg.G.keys dict_keys(['user', 'item', 'brand', 'category', 'related_product', 'word'])
Load embedding: ../../data/beauty/Amazon_Beauty_01_01/test_transe_embed.pkl
Key embed dict_keys(['user', 'item', 'brand', 'category', 'related_product', 'word', 'also_bought', 'also_viewed', 'bought_together', 'described', 'belong_to', 'category_of', 'mentioned', 'interested_in', 'like', 'dislike', 'purchase'])


In [9]:
# # load cold start users
# cold_users_path = os.path.join(config.processed_data_dir, "cold_start_users.json")
# cold_users = json.load(open(cold_users_path, "r"))

# # load cold start items
# # cold_items_path = os.path.join(config.processed_data_dir, "cold_start_items.json")
# # cold_items = json.load(open(cold_items_path, "r"))

In [10]:
# Convert lists to sets for fast membership checking
test_uids_set = set(test_uids)
cold_start_set = set(cold_start_uids)
# len(test_uids), len(all_user_pref)
# Extract elements in test_uids but not in all_user_pref
extracted_uids = test_uids_set - cold_start_set
# Convert the result back to a list if needed
non_cold_start_uids = list(extracted_uids)
# len(non_cold_start_uids)
assert len(test_uids) == len(cold_start_uids) + len(non_cold_start_uids)

In [11]:
len(non_cold_start_uids)

16228

In [30]:
batch_size = 16
start_idx = 0
all_paths, all_probs = [], []
pbar = tqdm(total=len(non_cold_start_uids))
# Non-cold start user
while start_idx < len(non_cold_start_uids):
    end_idx = min(start_idx + batch_size, len(non_cold_start_uids))
    batch_uids = non_cold_start_uids[start_idx:end_idx]
    paths, probs = batch_beam_search(
        env,
        model,
        kg_config,
        batch_uids,
        config_agent.device,
        topk=config_agent.topk,
        policy=config_agent.modified_policy,
    )
    all_paths.extend(paths)
    all_probs.extend(probs)
    start_idx = end_idx
    pbar.update(batch_size)
    # break

# Cold start user
def update_paths_with_uid(paths, uid):
    updated_paths = []
    for path in paths:
        updated_path = [(path[0][0], path[0][1], uid)] + path[1:]
        updated_paths.append(updated_path)
    return updated_paths

start_idx = 0
for uid in tqdm(cold_start_uids):
    batch_uids = cold_start_uids[uid]['related_user'][1:]
    paths, probs = batch_beam_search_cold_start(
        env,
        model,
        kg_config,
        batch_uids,
        config_agent.device,
        topk=config_agent.topk,
        policy=config_agent.modified_policy,
        user_pref_embed = user_pref_emb #adding user_pref 
    )

    updated_paths = update_paths_with_uid(paths, uid)
    
    all_paths.extend(updated_paths)
    all_probs.extend(probs)
    # break
    
predicts = {"paths": all_paths, "probs": all_probs}
# pickle.dump(predicts, open(path_file, "wb"))
# if config.use_wandb:
#     wandb.save(path_file)

  0%|          | 0/16228 [00:00<?, ?it/s]

  0%|          | 0/2079 [00:00<?, ?it/s]

In [31]:
# len(predicts['paths'])

In [32]:
# if config_agent.run_eval:
#     evaluate_paths(
#         config.processed_data_dir,
#         path_file,
#         train_labels,
#         test_labels,
#         kg_config,
#         config.use_wandb,
#         config_agent.result_file_dir,
#         validation=False,
#     )

In [12]:
result_file_dir = config_agent.result_file_dir
dir_path = config.processed_data_dir

embeds = load_embed(dir_path, set_name)
# embeds = load_embed(dir_path, 'train')
user_embeds = embeds["user"]
interaction_embeds = embeds[kg_config.interaction][0]
item_embeds = embeds["item"]
scores = np.dot(user_embeds + interaction_embeds, item_embeds.T)

# 1) Get all valid paths for each user, compute path score and path probability.
results = pickle.load(open(path_file, "rb"))
pred_paths = {uid: {} for uid in test_labels}

Load embedding: ../../data/beauty/Amazon_Beauty_01_01/test_cold_start_transe_embed.pkl


In [16]:
867 in cold_start_uids.keys()

True

In [15]:
cold_start_uids[pid]['non-purchase']

[11777,
 1548,
 11803,
 6686,
 10272,
 9257,
 8234,
 9774,
 4155,
 4160,
 4674,
 5698,
 7238,
 11336,
 7246,
 10321,
 2641,
 4188,
 5729,
 112,
 10365,
 10367,
 7808,
 3212,
 10899,
 4246,
 11420,
 3740,
 668,
 4767,
 9888,
 4257,
 2722,
 5796,
 7343,
 8890,
 7872,
 1217,
 8396,
 2766,
 724,
 8930,
 249,
 10498,
 7940,
 3844,
 6408,
 3353,
 10017,
 7974,
 9519,
 2352,
 10550,
 10551,
 9018,
 5445,
 4954,
 3950,
 9584,
 6010,
 7557,
 7045,
 2958,
 6051,
 5028,
 1446,
 10158,
 5553,
 438,
 5047,
 2493,
 4543,
 8143,
 10194,
 7637,
 990,
 991,
 4076,
 9716,
 4091]

In [13]:
results

{'paths': [[('self_loop', 'user', 5164),
   ('like', 'brand', 58),
   ('belong_to', 'item', 617),
   ('self_loop', 'item', 617)],
  [('self_loop', 'user', 5164),
   ('like', 'brand', 58),
   ('belong_to', 'item', 2458),
   ('self_loop', 'item', 2458)],
  [('self_loop', 'user', 5164),
   ('mentioned', 'word', 4924),
   ('described', 'item', 235),
   ('self_loop', 'item', 235)],
  [('self_loop', 'user', 5164),
   ('mentioned', 'word', 4924),
   ('described', 'item', 1315),
   ('self_loop', 'item', 1315)],
  [('self_loop', 'user', 5164),
   ('dislike', 'brand', 34),
   ('belong_to', 'item', 2723),
   ('self_loop', 'item', 2723)],
  [('self_loop', 'user', 5164),
   ('dislike', 'brand', 34),
   ('belong_to', 'item', 167),
   ('self_loop', 'item', 167)],
  [('self_loop', 'user', 5164),
   ('mentioned', 'word', 120),
   ('described', 'item', 9),
   ('self_loop', 'item', 9)],
  [('self_loop', 'user', 5164),
   ('mentioned', 'word', 120),
   ('described', 'item', 85),
   ('self_loop', 'item', 8

In [34]:
for path, probs in zip(results["paths"], results["probs"]):
    if path[-1][1] != "item":
        continue
    uid = path[0][2]
    # 2) Triming item which are assosiacted with user_rej_items        
    # if pid in user_rej_items: 
    #     continue  # Skip this item if it's in the user_rej_items list
    if uid not in pred_paths:
        continue
    pid = path[-1][2]
    if pid not in pred_paths[uid]:
        pred_paths[uid][pid] = []
    path_score = scores[uid][pid]
    path_prob = reduce(lambda x, y: x * y, probs)
    pred_paths[uid][pid].append((path_score, path_prob, path))

In [35]:
# 3) Pick best path for each user-product pair, also remove pid if it is in train set.
best_pred_paths = {}
for uid in pred_paths:
    train_pids = set(train_labels.get(uid, []))
    # if len(train_pids) == 0:
    #     continue
    best_pred_paths[uid] = []
    for pid in pred_paths[uid]:
        if pid in train_pids:
            continue
        # Get the path with highest probability
        sorted_path = sorted(pred_paths[uid][pid], key=lambda x: x[1], reverse=True)
        best_pred_paths[uid].append(sorted_path[0])

# path_patterns = {}
# for uid in best_pred_paths:
#     for path in best_pred_paths[uid]:
#         path_pattern = path[2]
#         pattern_key = ""
#         for node in path_pattern:
#             pattern_key += node[0] + "_" + node[1] + "-->"
#         path_patterns[pattern_key] = path_patterns.get(pattern_key, 0) + 1

# path_patterns

In [36]:
# filename = os.path.join(result_file_dir, "patterns.json")
# # json.dump(path_patterns, open(filename, "w"), indent=4)

# cold_start_users_path = os.path.join(dir_path, "cold_start_users.json")
# cold_start_users = json.load(open(cold_start_users_path, "r"))
# cold_start_users = set(cold_start_users["train"])

# cold_path_patterns = {}
# for uid in best_pred_paths:
#     if uid not in cold_start_users:
#         for path in best_pred_paths[uid]:
#             path_pattern = path[2]
#             pattern_key = ""
#             for node in path_pattern:
#                 pattern_key += node[0] + "_" + node[1] + "-->"
#             cold_path_patterns[pattern_key] = (
#                 cold_path_patterns.get(pattern_key, 0) + 1
#             )

# cold_path_patterns

In [37]:
filename = os.path.join(result_file_dir, "cold_patterns.json")
# json.dump(cold_path_patterns, open(filename, "w"), indent=4)

# computes the item distribution from train_labels
item_distribution = Counter()
for uid in train_labels:
    item_distribution.update(train_labels[uid])

filename = os.path.join(result_file_dir, "item_distribution.json")
# json.dump(item_distribution, open(filename, "w"), indent=4)

# 3) Compute top 10 recommended products for each user.
sort_by = "score"
pred_labels = {}
for uid in best_pred_paths:
    if sort_by == "score":
        sorted_path = sorted(
            best_pred_paths[uid], key=lambda x: (x[0], x[1]), reverse=True
        )
    elif sort_by == "prob":
        sorted_path = sorted(
            best_pred_paths[uid], key=lambda x: (x[1], x[0]), reverse=True
        )
    top_pids = [p[-1][2] for _, _, p in sorted_path]  # from largest to smallest

    pred_labels[uid] = top_pids[:10]  # change order to from smallest to largest!

In [38]:
pred_labels

{5164: [43, 1315, 617, 537, 85, 9, 167, 235, 2723, 2458],
 1333: [903, 734, 1047, 72, 106, 247, 195, 167, 91],
 12080: [1155, 463, 278, 936, 31, 84, 151, 355],
 867: [512, 19, 1050, 1135, 608],
 12887: [106, 821, 838, 1114, 92, 219, 414],
 3563: [1750, 2779, 21, 27, 987, 1848, 488, 968],
 297: [1389, 1193, 122, 1030, 63, 133],
 22289: [1506, 195, 43, 725, 839, 161, 823, 566, 937, 85],
 6214: [473, 56, 617, 217, 1496, 44, 23, 314, 161],
 21576: [21, 363, 96, 987, 4926, 5295, 367, 152],
 378: [93, 96, 363, 815, 27, 5630],
 20047: [581, 806, 901, 137, 35, 734, 82, 2026, 1126],
 3167: [62, 676, 111, 1426, 1111, 65, 1110, 1802, 6452, 6782],
 6045: [106, 415, 1251, 167, 43, 286, 367, 56, 179],
 15355: [1453, 1062, 218, 439, 161, 43, 4329, 4699],
 4492: [695, 439, 38, 37, 247, 65, 3, 826],
 12566: [2673, 106, 23, 124, 436, 132, 1313, 3249, 320, 245],
 12757: [581, 728, 107, 1076, 921, 161, 81, 36],
 21991: [21, 705, 27, 126, 4865, 43],
 3012: [5301, 1144, 9224, 530, 3349, 295, 99, 23, 11],
 5

In [39]:
test_labels

{5164: [2085],
 1333: [10758],
 12080: [9151],
 867: [1643, 1811],
 12887: [4919, 7614, 9578, 7623],
 3563: [9918, 2171],
 297: [6846],
 22289: [9576],
 6214: [10865],
 21576: [1284],
 378: [5630],
 20047: [3541],
 3167: [11369],
 6045: [7186],
 15355: [5486],
 4492: [10371],
 12566: [9063, 8923, 2171, 7309, 9706, 167, 8180],
 12757: [1057],
 21991: [6856],
 3012: [2339],
 5373: [4870],
 4248: [2246],
 1259: [4158],
 4598: [6424],
 7562: [503],
 14357: [7182],
 1749: [11153],
 2189: [2910],
 16766: [5845],
 18384: [12005],
 1005: [11586],
 20135: [2325],
 9841: [8407],
 12552: [10907],
 7344: [9617],
 1063: [11133],
 3645: [5603],
 10969: [3200],
 4887: [3637],
 3606: [6299],
 9079: [10772],
 20406: [1901],
 13371: [10640],
 21617: [6021],
 11703: [1430],
 17322: [9523],
 13036: [5112],
 10612: [8486],
 14620: [8374],
 22019: [4807],
 9033: [4165],
 15205: [9614],
 6756: [3266],
 18137: [7689],
 6400: [1975, 10969],
 12153: [6378],
 3648: [11632],
 18831: [5467],
 1187: [6408],
 12497:

In [40]:
topk_matches = pred_labels
test_user_products = test_labels
train_user_products = train_labels
use_wandb = False
dir_path = config.processed_data_dir
result_file_dir =result_file_dir
min_items=10
compute_all=True
k = 10

In [41]:
user_relations = Counter()
kg = load_kg(dir_path, set_name="test")
for uid, relations in kg.G["user"].items():
    for entities in relations.values():
        user_relations[uid] += len(entities)
        
invalid_users = []

# Metrics for all users
user_metrics = dict()

# Compute metrics
precisions, recalls, ndcgs, hits, hits_at_1, hits_at_3, hits_at_5 = ([], [], [], [], [], [], [], )
(precisions_all, recalls_all, ndcgs_all, hits_all, hits_at_1_all, hits_at_3_all, hits_at_5_all, ) = ([], [], [], [], [], [], [])
test_user_idxs = list(test_user_products.keys())

In [42]:
for uid in test_user_idxs:
    is_invalid = False
    if uid not in topk_matches or len(topk_matches[uid]) < min_items:
        invalid_users.append(uid)
        is_invalid = True
    pred_list, rel_set = topk_matches.get(uid, []), test_user_products[uid]
    nb_train = len(train_user_products.get(uid, []))
    if len(pred_list) == 0:
        ndcgs_all.append(0.0)
        recalls_all.append(0.0)
        precisions_all.append(0.0)
        hits_all.append(0.0)
        hits_at_1_all.append(0.0)
        hits_at_3_all.append(0.0)
        hits_at_5_all.append(0.0)
        continue

    if is_invalid == False:
        dcg = 0.0
        hit_num = 0.0
        hit_at_1 = 0.0
        hit_at_3 = 0.0
        hit_at_5 = 0.0

        for i in range(len(pred_list)):
            if pred_list[i] in rel_set:
                dcg += 1.0 / (log(i + 2) / log(2))
                hit_num += 1
                if i < 1:
                    hit_at_1 += 1
                if i < 3:
                    hit_at_3 += 1
                if i < 5:
                    hit_at_5 += 1
        # idcg
        idcg = 0.0
        for i in range(min(len(rel_set), k)):
            idcg += 1.0 / (log(i + 2) / log(2))
        ndcg = dcg / idcg
        recall = hit_num / len(rel_set)
        precision = hit_num / k
        hit = 1.0 if hit_num > 0.0 else 0.0
        hit_at_1 = 1.0 if hit_at_1 > 0.0 else 0.0
        hit_at_3 = 1.0 if hit_at_3 > 0.0 else 0.0
        hit_at_5 = 1.0 if hit_at_5 > 0.0 else 0.0

        ndcgs.append(ndcg)
        recalls.append(recall)
        precisions.append(precision)
        hits.append(hit)
        hits_at_1.append(hit_at_1)
        hits_at_3.append(hit_at_3)
        hits_at_5.append(hit_at_5)

        ndcgs_all.append(ndcg)
        recalls_all.append(recall)
        precisions_all.append(precision)
        hits_all.append(hit)
        hits_at_1_all.append(hit_at_1)
        hits_at_3_all.append(hit_at_3)
        hits_at_5_all.append(hit_at_5)
            
    elif compute_all == True:
        dcg_all = 0.0
        hit_num_all = 0.0
        hit_at_1_all = 0.0
        hit_at_3_all = 0.0
        hit_at_5_all = 0.0
        for i in range(len(pred_list)):
            if pred_list[i] in rel_set:
                dcg_all += 1.0 / (log(i + 2) / log(2))
                hit_num_all += 1
                if i < 1:
                    hit_at_1_all += 1
                if i < 3:
                    hit_at_3_all += 1
                if i < 5:
                    hit_at_5_all += 1
        # idcg
        idcg_all = 0.0
        for i in range(min(len(rel_set), k)):
            idcg_all += 1.0 / (log(i + 2) / log(2))
        ndcg_all = dcg_all / idcg_all
        recall_all = hit_num_all / len(rel_set)
        precision_all = hit_num_all / k
        hit_all = 1.0 if hit_num_all > 0.0 else 0.0
        hit_at_1_all = 1.0 if hit_at_1_all > 0.0 else 0.0
        hit_at_3_all = 1.0 if hit_at_3_all > 0.0 else 0.0
        hit_at_5_all = 1.0 if hit_at_5_all > 0.0 else 0.0
        ndcgs_all.append(ndcg_all)
        recalls_all.append(recall_all)
        precisions_all.append(precision_all)
        hits_all.append(hit_all)
        hits_at_1_all.append(hit_at_1_all)
        hits_at_3_all.append(hit_at_3_all)
        hits_at_5_all.append(hit_at_5_all)
        
    else:
        ndcgs_all.append(0.0)
        recalls_all.append(0.0)
        precisions_all.append(0.0)
        hits_all.append(0.0)
        hits_at_1_all.append(0.0)
        hits_at_3_all.append(0.0)
        hits_at_5_all.append(0.0)
        

    user_metrics[uid] = {
        "ndcg": ndcgs_all[-1] * 100,
        "recall": recalls_all[-1] * 100,
        "hit": hits_all[-1] * 100,
        "precision": precisions_all[-1] * 100,
        "predictions": pred_list,
        "nb_train": nb_train,
        "nb_relations": user_relations[uid],
    }

    # break

In [44]:
avg_precision = np.mean(precisions) * 100
avg_recall = np.mean(recalls) * 100
avg_ndcg = np.mean(ndcgs) * 100
avg_hit = np.mean(hits) * 100
avg_hit_at_1 = np.mean(hits_at_1) * 100
avg_hit_at_3 = np.mean(hits_at_3) * 100
avg_hit_at_5 = np.mean(hits_at_5) * 100

avg_precision_all = np.mean(precisions_all) * 100
avg_recall_all = np.mean(recalls_all) * 100
avg_ndcg_all = np.mean(ndcgs_all) * 100
avg_hit_all = np.mean(hits_all) * 100
avg_hit_at_1_all = np.mean(hits_at_1_all) * 100
avg_hit_at_3_all = np.mean(hits_at_3_all) * 100
avg_hit_at_5_all = np.mean(hits_at_5_all) * 100

print(
    "NDCG={:.3f} |  Recall={:.3f} | HR={:.3f} | Precision={:.3f} | HR@1={:.3f} | HR@3={:.3f} | HR@5={:.3f} | Computed for all users.\n".format(
        avg_ndcg_all,
        avg_recall_all,
        avg_hit_all,
        avg_precision_all,
        avg_hit_at_1_all,
        avg_hit_at_3_all,
        avg_hit_at_5_all,
    )
)

NDCG=3.459 |  Recall=6.875 | HR=6.997 | Precision=0.700 | HR@1=1.169 | HR@3=2.584 | HR@5=3.884 | Computed for all users.



In [None]:
# # from test_agent import evaluate

# use_wandb = False

# evaluate(
#     pred_labels,
#     test_labels,
#     train_labels,
#     use_wandb,
#     config.processed_data_dir,
#     result_file_dir=result_file_dir,
#     min_items=10,
#     compute_all=True,
# )