In [None]:
import sys 
sys.path.append('..')

#Dependencies
import os
import json 
import pickle
import wandb
from tqdm import tqdm

import pandas as pd

import torch
import numpy as np

from src.environment.ml_env import OfflineEnv
from src.model.pmf import PMF

from obp.policy.policy_type import PolicyType
from src.model.bandit import EpsilonGreedy, LinUCB, WFairLinUCB, LogisticUCB

# Registry of bandit policy classes, keyed by the short name used to pick one.
AGENT = {
    "egreedy": EpsilonGreedy,
    "lin_ucb": LinUCB,
    "wfair_linucb": WFairLinUCB,
    "logistic_ucb": LogisticUCB,
}

In [None]:
# Read the JSON manifest that maps artifact names to pickle file paths, then
# load each artifact used by this notebook into the `dataset` dict.
dataset_path = "../data/movie_lens_100k_output_path.json"
with open(dataset_path) as json_file:
    _dataset_path = json.load(json_file)

dataset = {}
for artifact in (
    "train_users_dict",
    "train_users_history_lens",
    "users_history_lens",
    "movies_groups",
):
    # Manifest paths are joined onto ".." before opening.
    # pickle.load can execute arbitrary code: only point this at trusted files.
    artifact_file = os.path.join("..", _dataset_path[artifact])
    with open(artifact_file, "rb") as pkl_file:
        dataset[artifact] = pickle.load(pkl_file)

In [None]:
from sklearn.preprocessing import OneHotEncoder
# NOTE(review): `enc` is not referenced by any other cell visible in this
# notebook (the one-hot columns are built with pd.get_dummies) — confirm it is
# still needed. handle_unknown='ignore' is the only option set here.
enc = OneHotEncoder(handle_unknown='ignore')

def age_group_bukets(age):
    """Map an age to a bucket index in 0..5.

    Buckets are delimited by the exclusive upper bounds 20, 30, 40, 50 and 60:
    ages below 20 map to 0, ages in [20, 30) map to 1, and so on. Any age that
    is not below 60 maps to 5.
    """
    upper_bounds = (20, 30, 40, 50, 60)
    for bucket, bound in enumerate(upper_bounds):
        if age < bound:
            return bucket
    return len(upper_bounds)


# Build the per-user feature table: one-hot columns for gender, occupation and
# age bucket. The raw id/categorical columns are dropped at the end, so each
# remaining row is a purely numeric vector.
user_df = pd.read_csv("../data/ml-100k/users.csv").drop(columns=["zip_code"])

# dtype=float is explicit because pandas >= 2.0 makes get_dummies return bool
# columns by default; the rows are later used as numeric context vectors.
gender = pd.get_dummies(user_df.gender, dtype=float)
occupation = pd.get_dummies(user_df.occupation, dtype=float)

# Replace the raw age with its bucket index before one-hot encoding it.
user_df["age"] = user_df["age"].apply(age_group_bukets)
age = pd.get_dummies(user_df.age, dtype=float)

user_df = pd.concat([user_df, gender, occupation, age], axis=1)
user_df = user_df.drop(columns=["user_id", "age", "gender", "occupation"])

# Show a preview rather than the whole frame.
user_df.head()

# Bandit Models

In [None]:
# --- Run selection -----------------------------------------------------------
train_version = "lin_ucb"          # which bandit policy to train
output_path = "../model/bandits"   # root directory for saved bandit checkpoints

# --- Problem dimensions ------------------------------------------------------
users_num = 943
items_num = 1682
state_size = 5
context_dim = 29      # width of the context vector handed to the bandit
embedding_dim = 50
n_groups = 10

# --- Policy / training hyperparameters ---------------------------------------
epsilon = 0.1
batch_size = 32
fairness_constraints = [1.0] * n_groups   # one weight per item group

# --- Pretrained PMF weights --------------------------------------------------
embedding_network_weights = "../model/pmf/emb_50_ratio_0.800000_bs_1000_e_258_wd_0.100000_lr_0.000100_trained_pmf.pt"

# --- Evaluation / episode control --------------------------------------------
top_k = None          # falsy => a single action is stepped per interaction
done_count = 10
max_episode_num = 50_000

# --- Logging -----------------------------------------------------------------
use_wandb = True

In [None]:
# Start the Weights & Biases run only when logging is enabled. The training
# loop guards its wandb.log calls with the same `use_wandb` flag, so turning
# the flag off should not open a run either.
if use_wandb:
    wandb.init(
        project="bandits",
        # Hyperparameters recorded alongside the run.
        config={
            "bandit": train_version,
            "users_num": users_num,
            "items_num": items_num,
            "state_size": state_size,
            "context_dim": context_dim,
            "batch_size": batch_size,
            "embedding_dim": embedding_dim,
            "group_fairness": n_groups,
            "fairness_constraints": fairness_constraints,
            "reward_model": True,
        },
    )

In [None]:
# Instantiate the bandit policy selected by `train_version`. The same keyword
# set is passed regardless of which policy class is chosen.
bandit = AGENT[train_version](
    batch_size=batch_size,
    dim=context_dim,
    n_actions=items_num,
    epsilon=epsilon,
    n_group=n_groups,
    item_group=dataset["movies_groups"],
    # Keys start at 1: key k + 1 takes fairness_constraints[k].
    fairness_weight={k + 1: fairness_constraints[k] for k in range(n_groups)},
)

# Checkpoints are written to <output_path>/<policy_name>. makedirs with
# exist_ok=True creates missing parents and does not fail when the cell is
# re-run and the directory already exists (os.mkdir raised in both cases).
os.makedirs(os.path.join(output_path, bandit.policy_name), exist_ok=True)

In [None]:
# Use the GPU when torch can see one, otherwise fall back to the CPU.
device_name = "cuda" if torch.cuda.is_available() else "cpu"
device = torch.device(device_name)

# Build the PMF model and restore its pretrained weights, remapping the saved
# tensors onto `device` while loading.
reward_model = PMF(users_num, items_num, embedding_dim)
pmf_state = torch.load(embedding_network_weights, map_location=device)
reward_model.load_state_dict(pmf_state)

# Raw embedding tables taken from the restored model.
user_embeddings = reward_model.user_embeddings.weight.data
item_embeddings = reward_model.item_embeddings.weight.data

In [None]:
# Offline environment built from the training split, the item-group mapping,
# the fairness weights and the pretrained reward model.
env_kwargs = {
    "users_dict": dataset["train_users_dict"],
    "users_history_lens": dataset["train_users_history_lens"],
    "n_groups": n_groups,
    "movies_groups": dataset["movies_groups"],
    "state_size": state_size,
    "done_count": done_count,
    "fairness_constraints": fairness_constraints,
    "reward_model": reward_model,
}
env = OfflineEnv(**env_kwargs)

In [None]:
def calculate_ndcg(rel, irel):
    """Return the (dcg, idcg) pair for a ranked list.

    `rel` is binarized first (1 where the entry is > 0, else 0); `irel` is used
    as given. The item at 0-based position i is discounted by log2(i + 2).
    Only positions present in both sequences contribute. The caller divides
    dcg by idcg to obtain the NDCG value.
    """
    hits = [1 if r > 0 else 0 for r in rel]

    dcg = 0
    idcg = 0
    # Enumerating from 2 makes `rank` the discount argument (position + 2).
    for rank, (hit, ideal) in enumerate(zip(hits, irel), start=2):
        discount = np.log2(rank)
        dcg += hit / discount
        idcg += ideal / discount

    return dcg, idcg

In [None]:
# Train the bandit online against the offline environment.
#
# Each episode resets the environment and steps until it reports `done`; the
# bandit is updated after every step. Per-episode precision, NDCG, proportional
# fairness and reward are added to the running totals below, which are averaged
# over `max_episode_num` and printed once training finishes.
sum_precision = 0
sum_ndcg = 0
sum_propfair = 0
sum_reward = 0

for episode in tqdm(range(max_episode_num)):

    # Per-episode accumulators.
    episode_reward = 0
    steps = 0
    mean_precision = 0
    mean_ndcg = 0

    # Start a new episode and clear the bandit's group counts.
    user_id, items_ids, done = env.reset()
    bandit.clear_group_count()

    while not done:

        # Select a list of actions.
        if bandit.policy_type == PolicyType.CONTEXT_FREE:
            selected_actions = bandit.select_action()
        elif bandit.policy_type == PolicyType.CONTEXTUAL:
            # The context is the user's one-hot feature row, shaped (1, context_dim).
            # NOTE(review): iloc is positional — this assumes user_id is a valid
            # 0-based row position in user_df; confirm against the env's id scheme.
            context = user_df.iloc[user_id].values.reshape(1, context_dim)
            selected_actions = bandit.select_action(context)
        else:
            # Fail loudly instead of hitting an undefined `selected_actions` below.
            raise ValueError(
                "Unsupported policy type: {}".format(bandit.policy_type)
            )

        # Step the environment with the whole list (top-k) or the first action.
        # NOTE(review): top_k=False is passed even when `top_k` is set and a list
        # of actions is supplied — confirm against OfflineEnv.step whether this
        # should forward the flag.
        next_items_ids, rewards, done, _ = env.step(
            selected_actions if top_k else selected_actions[0], top_k=False
        )

        # Normalize to a list so the single-action and top-k cases share one path.
        rewards = rewards if top_k else [rewards]

        for action, reward in zip(selected_actions, rewards):
            if bandit.policy_type == PolicyType.CONTEXT_FREE:
                bandit.update_params(action=action, reward=reward)
            elif bandit.policy_type == PolicyType.CONTEXTUAL:
                # Contextual policies are updated with a binarized reward.
                bandit.update_params(
                    action=action,
                    reward=1 if reward >= 0.5 else 0,
                    context=context,
                )
        steps += 1
        # Accumulate per episode so the logged "total_reward" is the real total
        # (it was always 0 before); it is folded into sum_reward at episode end.
        episode_reward += np.sum(rewards)

        if top_k:
            # Metrics are computed over the full reward list, not the scalar
            # left over from the update loop above.
            correct_list = [1 if r > 0 else 0 for r in rewards]
            # ndcg
            dcg, idcg = calculate_ndcg(
                correct_list, [1 for _ in range(len(rewards))]
            )
            mean_ndcg += dcg / idcg

            # precision
            correct_num = top_k - correct_list.count(0)
            mean_precision += correct_num / top_k
        else:
            mean_precision += 1 if rewards[0] > 0 else 0

        if done:
            # Proportional fairness: sum over groups of
            # weight * log(1 + group_count / total_count).
            propfair = 0
            group_counts = np.array(list(env.group_count.values()))
            total_exp = np.sum(group_counts)
            if total_exp > 0:
                propfair = np.sum(
                    np.array(fairness_constraints)
                    * np.log(1 + group_counts / total_exp)
                )

            sum_precision += mean_precision / steps
            sum_ndcg += mean_ndcg / steps
            sum_propfair += propfair
            sum_reward += episode_reward

            if use_wandb:
                wandb.log(
                    {
                        "precision": (mean_precision / steps) * 100,
                        "ndcg": mean_ndcg / steps,
                        "total_reward": episode_reward,
                        "propfair": propfair,
                        "cvr": mean_precision / steps,
                        # The denominator is floored at 0.01 to avoid dividing by zero.
                        "ufg": propfair
                        / max(1 - (mean_precision / steps), 0.01),
                    }
                )

    # Checkpoint the bandit every 1000 episodes.
    if (episode + 1) % 1000 == 0:
        checkpoint_path = os.path.join(
            output_path,
            bandit.policy_name,
            "{}_{}.pkl".format(bandit.policy_name, episode + 1),
        )
        with open(checkpoint_path, "wb") as file:
            pickle.dump(bandit, file)

print("Precision ",  round(sum_precision / max_episode_num, 4))
print("NDCG ", round(sum_ndcg / max_episode_num, 4))
print("Propfair ", round(sum_propfair / max_episode_num, 4))
print("Reward ", round(sum_reward / max_episode_num, 4))
