In [27]:
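# NOTE: before running this notebook, edit parse.py so that it calls
# `args = parser.parse_args(args=[])`; with the default arguments argparse would
# try to parse the Jupyter kernel's own command line instead.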
%load_ext autoreload
%autoreload 2

import world
import utils
import torch
import os
import dataloader
import model
import numpy as np
from tqdm import tqdm
import os

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [28]:
def load_model(model_name):
    """Look up the model registered under ``model_name``.

    Recognised names are ``"localgcn"`` (-> ``model.LocalGCN``) and
    ``"lgn"`` (-> ``model.LightGCN``). The attribute is returned as-is; it is
    not called here. Any other name raises ``KeyError`` from the lookup.
    """
    registry = dict(localgcn=model.LocalGCN, lgn=model.LightGCN)
    return registry[model_name]

def load_dataset(dataset_path):
    """Create a ``dataloader.Loader`` for the data located at ``dataset_path``."""
    loader = dataloader.Loader(path=dataset_path)
    return loader

def test_one_batch(X):
    sorted_items = X[0].numpy()
    gt = X[1]
    r = utils.getLabel(gt, sorted_items)
    pre, recall, ndcg = [], [], []
    for k in world.config["topks"]:
        ret = utils.recall_precision_at_K(gt, r, k)
        pre.append(ret["precision"])
        recall.append(ret["recall"])
        ndcg.append(utils.NDCG_at_K(gt, r, k))
    return {"recall": np.array(recall), "precision": np.array(pre), "ndcg": np.array(ndcg)}

In [29]:
# Select the dataset: record its name in the shared config, then load it from
# <world.DATA_PATH>/<dataname>. `dataset` is read by the cells that follow.
dataname = "ml-1m"
world.config['dataset'] = dataname
data_path = os.path.join(world.DATA_PATH, dataname)
dataset = load_dataset(dataset_path=data_path)

2024-01-15 15:24:57,478 INFO  [dataloader.py:17] loading [/home/s1/eungikim/Research/LocalGCN_public/data/ml-1m]
2024-01-15 15:24:57,708 INFO  [dataloader.py:47] 407184 interactions for training
2024-01-15 15:24:57,710 INFO  [dataloader.py:48] 112428 interactions for testing
2024-01-15 15:24:57,711 INFO  [dataloader.py:50] ml-1m Sparsity : 0.027547671321246144
2024-01-15 15:24:57,711 INFO  [dataloader.py:52] number of users : 6034
2024-01-15 15:24:57,712 INFO  [dataloader.py:53] number of items : 3126
2024-01-15 15:24:58,411 INFO  [dataloader.py:68] ml-1m is ready to go


In [30]:
# Popularity scores: pop[i] is how many times item i occurs in dataset.train_item.
# Items that never occur keep a score of 0.
item, count = np.unique(dataset.train_item, return_counts=True)
pop = torch.zeros(dataset.m_items)
# Single vectorised indexed assignment instead of a Python loop that writes one
# tensor element per iteration; the resulting tensor is identical.
pop[torch.as_tensor(item, dtype=torch.long)] = torch.as_tensor(count, dtype=pop.dtype)

3125it [00:00, 231723.36it/s]


In [31]:
# Evaluate the popularity ranking on the test users: every user gets the same
# `pop` scores, already-seen items are masked out, and the top-K items are scored
# with test_one_batch. Final metrics are averaged over all test users.
u_batch_size = world.config["test_u_batch_size"]
test_dict: dict = dataset.test_dict

# Retrieve max(topks) items per user once; test_one_batch evaluates each cutoff.
max_K = max(world.config["topks"])
results = {
    "precision": np.zeros(len(world.config["topks"])),
    "recall": np.zeros(len(world.config["topks"])),
    "ndcg": np.zeros(len(world.config["topks"])),
}

users = list(test_dict.keys())
# Advisory only: an oversized batch is logged, not treated as an error. A plain
# `if` is used (rather than try/assert/except) so the check is not stripped
# when Python runs with -O.
if u_batch_size > len(users) / 10:
    world.LOGGER.info(f"test_u_batch_size is too big for this dataset, try a small one {len(users) // 10}")
users_list = []
rating_list = []
gt_list = []
# Ceiling division. The previous `len(users) // u_batch_size + 1` over-counted by
# one whenever len(users) was an exact multiple of the batch size, which made the
# batch-count assertion below fail.
# NOTE(review): assumes utils.minibatch yields consecutive chunks of at most
# `batch_size` users -- confirm against utils.
total_batch = (len(users) + u_batch_size - 1) // u_batch_size

print("TOTAL BATCH:", total_batch)
# Wrap the batch generator in tqdm exactly once (with a known total) so a single
# progress bar is shown.
test_loader = tqdm(utils.minibatch(users, batch_size=u_batch_size), total=total_batch)
for batch_users in test_loader:
    # Items to exclude from the ranking for each user in the batch
    # (presumably their training positives -- see dataset.get_user_pos_items).
    all_pos = dataset.get_user_pos_items(batch_users)
    gt = [test_dict[u] for u in batch_users]
    # Build the id tensor directly as int64; torch.Tensor(...).long() would go
    # through float32 first and can lose precision for large ids.
    batch_users = torch.as_tensor(batch_users, dtype=torch.long)

    # Same popularity row for every user in the batch.
    rating = pop.unsqueeze(0).repeat(batch_users.shape[0], 1)
    exclude_index = []
    exclude_items = []
    for range_i, items in enumerate(all_pos):
        exclude_index.extend([range_i] * len(items))
        exclude_items.extend(items)
    # Popularity counts are >= 0, so -1024 ranks every excluded item last.
    rating[exclude_index, exclude_items] = -(1 << 10)
    _, rating_K = torch.topk(rating, k=max_K)
    # Release the (batch x items) score matrix before the next iteration.
    del rating
    users_list.append(batch_users)
    rating_list.append(rating_K.cpu())
    gt_list.append(gt)
assert total_batch == len(users_list)

X = zip(rating_list, gt_list)
pre_results = [test_one_batch(x) for x in X]
# Not used below; kept because it is a notebook-level name other cells may read.
scale = float(u_batch_size / len(users))
# Accumulate the per-batch metric arrays, then average over the number of test users.
for result in pre_results:
    for metric in ("recall", "precision", "ndcg"):
        results[metric] += result[metric]
for metric in ("recall", "precision", "ndcg"):
    results[metric] /= float(len(users))

2024-01-15 15:24:58,530 INFO  [<ipython-input-31-aba5d156a18f>:16] test_u_batch_size is too big for this dataset, try a small one 603
0it [00:00, ?it/s]

TOTAL BATCH: 7


7it [00:00,  7.78it/s]
100%|██████████| 7/7 [00:00<00:00,  7.80it/s]


In [32]:
# Log the averaged precision / recall / ndcg arrays (one entry per cutoff in world.config["topks"]).
world.LOGGER.info(results)

2024-01-15 15:24:59,578 INFO  [<ipython-input-32-9ee1c07100a8>:1] {'precision': array([0.15545244, 0.13215114, 0.11823003, 0.10077892]), 'recall': array([0.01262393, 0.0458194 , 0.08138716, 0.13433756]), 'ndcg': array([0.15545244, 0.14241747, 0.1397087 , 0.14522995])}


In [11]:
# Tag the log with the name of the evaluated model ("POP").
world.LOGGER.info("model>> POP")

2024-01-14 21:11:30,713 INFO  [<ipython-input-11-3b42ed2ae836>:1] model>> POP
