In [None]:
#Dependencies

import os
from tqdm import tqdm
import pandas as pd
import numpy as np
import torch
import itertools
import matplotlib.pyplot as plt
import time
import math
import collections

import json 
import pickle

from src.environment.ml_env import SimulatedEnv, SimulatedFairEnv

import obp
from obp.policy.policy_type import PolicyType
from bandit import EpsilonGreedy, LinUCB, WFairLinUCB, FairLinUCB
from src.model.pmf import PMF

In [None]:
# JSON index: maps each preprocessed artefact name to the pickle file path holding it.
dataset_path = "data/movie_lens_100k_output_path.json"
with open(dataset_path) as json_file:
    _dataset_path = json.load(json_file)

# Load every artefact listed below from its pickle into `dataset`, keyed by the same name.
# NOTE(review): pickle.load can execute arbitrary code - only load trusted files.
dataset = {}
for key in (
    "eval_users_dict",
    "eval_users_dict_positive_items",
    "eval_users_history_lens",
    "users_history_lens",
    "movies_groups",
):
    with open(_dataset_path[key], "rb") as pkl_file:
        dataset[key] = pickle.load(pkl_file)

In [None]:
# Algorithm name -> simulated environment class used to build evaluation environments.
ENV = {"drr": SimulatedEnv, "fairrec": SimulatedFairEnv}

In [None]:
# Trained-bandit checkpoint file names; the config cell selects one by index.
train_ids = [
    "egreedy_0.1_2021-10-25_19-45-40.pkl",            # index 0
    "linear_ucb_0.25_2021-10-25_19-46-08.pkl",        # index 1
    "wfair_linear_ucb_0.25_2021-10-25_19-46-18.pkl",  # index 2
]

In [None]:
# ---- Evaluation configuration ----------------------------------------------
algorithm = "drr"        # key into ENV: selects the simulated environment class
train_version = "bandits"
train_id = train_ids[1]  # which trained bandit checkpoint to evaluate
output_path = "model/{}/{}".format(train_version, train_id)

# Sizes passed to PMF(users_num, items_num, embedding_dim); they must match the
# checkpoint in `embedding_network_weights` or load_state_dict will fail.
users_num = 943
items_num = 1682
state_size = 5
embedding_dim = 50
emb_model = "user_movie"
embedding_network_weights = "model/pmf/emb_50_ratio_0.800000_bs_1000_e_258_wd_0.100000_lr_0.000100_trained_pmf.pt"
n_groups = 10
# One constraint per item group. Derived from n_groups so the list length cannot
# drift out of sync when the number of groups is changed.
fairness_constraints = [1.0] * n_groups

top_k = 10
done_count = 10

In [None]:
# Prefer the GPU when one is visible to torch.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Pretrained PMF model: passed to the evaluation environments as the reward model,
# and its embedding tables are used to build contextual-bandit contexts.
reward_model = PMF(users_num, items_num, embedding_dim)
pmf_state_dict = torch.load(embedding_network_weights, map_location=device)
reward_model.load_state_dict(pmf_state_dict)

# Raw embedding weight tensors (no autograd tracking via .data).
user_embeddings = reward_model.user_embeddings.weight.data
item_embeddings = reward_model.item_embeddings.weight.data

In [None]:
def calculate_ndcg(rel, irel):
    """Compute the discounted cumulative gain of a ranking and of its ideal ranking.

    Args:
        rel: relevance scores in ranked order. Each score is binarised
            (positive -> 1, otherwise 0) before discounting.
        irel: ideal relevance scores in ranked order, used as given.

    Returns:
        ``(dcg, idcg)``. Rank ``i`` (0-based) is weighted by ``1 / log2(i + 2)``,
        and only the first ``min(len(rel), len(irel))`` ranks contribute
        (``zip`` semantics). The caller divides the two values to obtain nDCG.
    """
    hits = [int(score > 0) for score in rel]
    ranked = list(zip(hits, irel))
    dcg = sum(hit / np.log2(rank + 2) for rank, (hit, _) in enumerate(ranked))
    idcg = sum(ideal / np.log2(rank + 2) for rank, (_, ideal) in enumerate(ranked))
    return dcg, idcg

In [None]:
# Load the trained bandit checkpoint selected in the config cell. The evaluation
# cell below re-reads this same file for every user.
# NOTE(review): pickle.load can execute arbitrary code - only load trusted checkpoints.
with open(output_path, mode="rb") as pkl_file:
    bandit = pickle.load(pkl_file)

In [None]:
# ---------------------------------------------------------------------------
# Online evaluation of the trained bandit.
#
# Each available user gets a fresh copy of the trained bandit (re-read from the
# checkpoint) and a dedicated environment pinned to that user, so online
# updates made while serving one user do not leak into the next one.
# Per-user precision / nDCG are averaged over that user's steps, then
# averaged over the users that were actually evaluated.
# ---------------------------------------------------------------------------

# Environment built only to enumerate the users that can be evaluated.
env = ENV[algorithm](
    users_dict=dataset["eval_users_dict"],
    users_history_lens=dataset["eval_users_history_lens"],
    n_groups=n_groups,
    movies_groups=dataset["movies_groups"],
    state_size=state_size,
    done_count=done_count,
    fairness_constraints=fairness_constraints,
)
available_users = env.available_users

sum_precision = 0
sum_ndcg = 0
sum_propfair = 0
n_evaluated_users = 0  # users that completed at least one step

for eval_user_id in tqdm(available_users):
    # Reload so every user starts from the trained parameters.
    # NOTE(review): pickle.load can execute arbitrary code - only load trusted checkpoints.
    with open(output_path, "rb") as pkl_file:
        bandit = pickle.load(pkl_file)

    eval_env = ENV[algorithm](
        users_dict=dataset["eval_users_dict"],
        users_history_lens=dataset["eval_users_history_lens"],
        n_groups=n_groups,
        movies_groups=dataset["movies_groups"],
        state_size=state_size,
        done_count=done_count,
        fairness_constraints=fairness_constraints,
        fix_user_id=eval_user_id,
        reward_model=reward_model,
    )

    steps = 0
    mean_precision = 0
    mean_ndcg = 0

    # environment
    user_id, items_ids, done = eval_env.reset()
    bandit.len_list = top_k
    bandit.clear_group_count()

    while not done:
        steps += 1

        # select a list of actions
        context = None
        if bandit.policy_type == PolicyType.CONTEXT_FREE:
            selected_actions = bandit.select_action()
        elif bandit.policy_type == PolicyType.CONTEXTUAL:
            # Context = [user, user * mean(state items), mean(state items)],
            # flattened into a single row for the bandit.
            user_eb = user_embeddings[user_id]
            items_eb = item_embeddings[items_ids]
            item_ave = torch.mean(items_eb, 0)
            context = torch.cat((user_eb, user_eb * item_ave, item_ave), 0).cpu().numpy()
            # -1 instead of a hard-coded 150 (== 3 * embedding_dim) so other
            # embedding sizes work.
            context = context.reshape(1, -1)
            selected_actions = bandit.select_action(context)
        else:
            raise ValueError(f"Unsupported bandit policy type: {bandit.policy_type}")

        # calculate reward and observe new state
        next_items_ids, rewards, done, _ = eval_env.step(
            selected_actions, top_k=top_k
        )

        for action, reward in zip(selected_actions, rewards):
            if bandit.policy_type == PolicyType.CONTEXT_FREE:
                bandit.update_params(action=action, reward=reward)
            else:
                bandit.update_params(action=action, reward=reward, context=context)

        if top_k:
            correct_list = [1 if r > 0 else 0 for r in rewards]
            # nDCG against an ideal list in which every returned slot is relevant;
            # guard against an empty reward list (idcg == 0).
            dcg, idcg = calculate_ndcg(correct_list, [1] * len(rewards))
            mean_ndcg += dcg / idcg if idcg > 0 else 0
            # precision@k = hits / k. Counting hits directly avoids over-counting
            # (top_k - misses) when fewer than top_k rewards come back.
            mean_precision += sum(correct_list) / top_k
        else:
            mean_precision += 1 if rewards > 0 else 0

        # Carry the observed state forward so the next context reflects it.
        items_ids = next_items_ids

    # After the loop `done` is True; users whose episode ended at reset (no
    # steps) contribute nothing and are not counted in the averages.
    if steps:
        sum_propfair += bandit.propfair
        sum_precision += mean_precision / steps
        sum_ndcg += mean_ndcg / steps
        n_evaluated_users += 1

    del eval_env

# Average over the users actually evaluated: env.available_users may be a strict
# subset of dataset["eval_users_dict"], so dividing by the latter would deflate
# every metric.
n_users = max(n_evaluated_users, 1)
avg_precision = sum_precision / n_users
avg_ndcg = sum_ndcg / n_users
avg_propfair = sum_propfair / n_users
# ufg = propfair / (1 - precision); unbounded when precision reaches 1.
ufg = avg_propfair / (1 - avg_precision) if avg_precision < 1 else float("inf")

print("---------- Evaluation")
print("- precision@: ", round(avg_precision, 4))
print("- ndcg@: ", round(avg_ndcg, 4))
print("- propfair: ", round(avg_propfair, 4))
print("- ufg: ", round(ufg, 4))
print()