## Note
In this notebook we will load a trained GMF++ model and go over the evaluation procedure. GMF++ is based on a simple model introduced by [He et al.](https://arxiv.org/abs/1708.05031). You can try to adapt other models such as MLP and NMF. The [original implementation](https://github.com/hexiangnan/neural_collaborative_filtering/tree/4aab159e81c44b062c091bdaed0ab54ac632371f) as well as other implementations are available for single-market settings.

In [2]:
import argparse
import pandas as pd
import torch
from torch.autograd import Variable
from torch.utils.data import DataLoader, Dataset, ConcatDataset

import os
import json
import resource
import sys
import pickle

sys.path.insert(1, 'src')
from model import Model
from utils import *
from data import *
from train_baseline import *

In [3]:
# Build the run configuration from the project's argument parser, overriding
# only the target / source market identifiers.
tgt_market = 't1'
src_markets = 's1'

parser = create_arg_parser()
cli_tokens = f'--tgt_market {tgt_market} --src_markets {src_markets}'.split()
args = parser.parse_args(cli_tokens)

# Use GPU 0 only when CUDA is both available and requested via args.cuda;
# otherwise fall back to the CPU.
use_cuda = torch.cuda.is_available() and args.cuda
if use_cuda:
    torch.cuda.set_device(0)
    args.device = torch.device('cuda')
else:
    args.device = torch.device('cpu')

In [5]:
# Load the pretrained model and the pickled ID bank stored next to it.
run_tag = f'{tgt_market}_{src_markets}_toytest'
model_dir = f'checkpoints/{run_tag}.model'
id_bank_dir = f'checkpoints/{run_tag}.pickle'

# File name reserved for the validation run of this configuration.
valid_run = f'valid_{run_tag}.tsv'

# NOTE: unpickling can execute arbitrary code -- only open ID-bank files that
# come from a trusted source.
with open(id_bank_dir, 'rb') as id_bank_file:
    my_id_bank = pickle.load(id_bank_file)

# Build the model around the restored ID bank, then load the saved weights.
mymodel = Model(args, my_id_bank)
mymodel.load(model_dir)

Model is GMF++!
GMF(
  (embedding_user): Embedding(9164, 8)
  (embedding_item): Embedding(10341, 8)
  (affine_output): Linear(in_features=8, out_features=1, bias=True)
  (logistic): Sigmoid()
)
Pretrained weights from checkpoints/t1_s1_toytest.model are loaded!


In [7]:
############
## Target Market data
############
tgt_train_data_dir = os.path.join(args.data_dir, args.tgt_market, 'train.tsv')

# Announce the path *before* reading it, so a failed read still shows which
# file was being loaded.
print(f'loading {tgt_train_data_dir}')
tgt_train_ratings = pd.read_csv(tgt_train_data_dir, sep='\t')
tgt_task_generator = TaskGenerator(tgt_train_ratings, my_id_bank)
print('loaded target data!')

# Validation and test loaders are built by the same factory method; only the
# first argument (args.tgt_market_valid vs. args.tgt_market_test) differs.
tgt_valid_dataloader = tgt_task_generator.instance_a_market_valid_dataloader(args.tgt_market_valid, args.batch_size)
tgt_test_dataloader = tgt_task_generator.instance_a_market_valid_dataloader(args.tgt_market_test, args.batch_size)
print('loaded target test and validation data!')

loading DATA/t1/train.tsv
loaded target data!
loaded target test and validation data!


In [8]:
# Score the validation dataloader with the loaded model.
valid_run_mf = mymodel.predict(tgt_valid_dataloader)

# Derive the qrel path from the same data_dir / tgt_market settings used for
# train.tsv, so changing the target market cannot leave the evaluation
# pointing at another market's relevance judgments.
tgt_valid_qrel_dir = os.path.join(args.data_dir, args.tgt_market, 'valid_qrel.tsv')
tgt_valid_qrel = read_qrel_file(tgt_valid_qrel_dir)

# get_evaluations_final returns two results; the first (task_ov) is displayed
# in the next cell.
task_ov, task_ind = get_evaluations_final(valid_run_mf, tgt_valid_qrel)

In [11]:
# Display the first result of get_evaluations_final: a dict mapping metric
# names (P_k, recall_k, ndcg_cut_10, map_cut_10) to scores.
task_ov

{'P_5': 0.0668149796069707,
 'P_10': 0.047608453837597334,
 'P_20': 0.030385613644790505,
 'recall_5': 0.3340748980348535,
 'recall_10': 0.4760845383759733,
 'recall_20': 0.6077122728958102,
 'ndcg_cut_10': 0.27169953636762195,
 'map_cut_10': 0.20906974827998187}