In [25]:
from irec.recommendation.agents.value_functions import LinearUCB, MostPopular, GenericThompsonSampling, EGreedy
from irec.recommendation.agents.action_selection_policies import ASPGenericGreedy, ASPGreedy
from irec.offline_experiments.metric_evaluators import UserCumulativeInteraction
from irec.offline_experiments.evaluation_policies import FixedInteraction
from irec.recommendation.agents import SimpleEnsembleAgent, SimpleAgent
from irec.offline_experiments.metrics import Hits, EPC, Recall, ILD
from irec.environment.loader import FullData

In [26]:
import pandas as pd
import numpy as np

## Load Dataset

In [27]:
# Source file for the ratings and the options controlling how it is read.
dataset = dict(
    path="datasets/MovieLens 100k/ratings.csv",
    random_seed=0,
    file_delimiter=",",
    skip_head=True,
)

# Split configuration and validation configuration handed to the loader.
splitting = dict(strategy="global", train_size=0.8, test_consumes=5)
validation = dict(validation_size=0.2)

# Build the loader, then unpack the four objects returned by process().
loader = FullData(dataset, splitting, validation)
train_dataset, test_dataset, x_validation, y_validation = loader.process()


Applying splitting strategy: global

Test shape: (16892, 4)
Train shape: (80393, 4)

Generating x_validation and y_validation: 
Test shape: (15729, 4)
Train shape: (61345, 4)


## Creating the agents

In [30]:
# Constructor keyword arguments for every value function, keyed by class name.
params = {
    "LinearUCB": {
        "alpha": 1.0,
        "item_var": 0.01,
        "iterations": 20,
        "num_lat": 20,
        "stop_criteria": 0.0009,
        "user_var": 0.01,
        "var": 0.05,
    },
    # These two value functions are created with their defaults.
    "MostPopular": {},
    "EGreedy": {},
    # One alpha_0 / beta_0 entry per agent name used inside the ensemble.
    "GenericThompsonSampling": {
        "alpha_0": {"LinearUCB": 100, "MostPopular": 1},
        "beta_0": {"LinearUCB": 100, "MostPopular": 1},
    },
}

### Creating the simple agents

In [31]:
# Value functions, each configured from its entry in `params`.
vf1, vf2, vf3 = (
    LinearUCB(**params["LinearUCB"]),
    MostPopular(**params["MostPopular"]),
    EGreedy(**params["EGreedy"]),
)

# A single greedy action-selection policy instance is passed to all three agents.
asp_sa = ASPGreedy()

# Pair each value function with the policy; `name` is the key used for the
# agent's results later in the notebook.
agent1 = SimpleAgent(vf1, asp_sa, name="LinearUCB")
agent2 = SimpleAgent(vf2, asp_sa, name="MostPopular")
agent3 = SimpleAgent(vf3, asp_sa, name="EGreedy")

### Creating the Ensemble Agent

In [32]:
# Policy and value function used by the ensemble itself.
asp_sea = ASPGenericGreedy()
vf_sea = GenericThompsonSampling(**params["GenericThompsonSampling"])

# The ensemble wraps the LinearUCB and MostPopular agents.
# NOTE(review): agent1 and agent2 are the same objects that are also evaluated
# on their own in the recommendation loop — confirm that evaluation resets or
# retrains them, so no state carries over into the ensemble run.
ensemble_agent = SimpleEnsembleAgent(
    name="EnsembleAgent",
    agents=[agent1, agent2],
    value_function=vf_sea,
    action_selection_policy=asp_sea,
)

In [33]:
# Every agent to run, in the order they will be evaluated.
agents = [
    agent1,
    agent2,
    agent3,
    ensemble_agent,
]

## Getting the recommendations

In [34]:
# Evaluation policy shared by all agents in the recommendation loop.
eval_policy = FixedInteraction(
    num_interactions=100,
    interaction_size=1,
    save_info=False,
)

In [35]:
# Run every agent through the evaluation policy and keep its interactions,
# keyed by agent name.
interactions = {}
for agent in agents:
    print(agent.name)
    evaluation_output = eval_policy.evaluate(agent, train_dataset, test_dataset)
    # evaluate() yields a pair; only the first element is stored.
    agent_interactions, action_info = evaluation_output
    interactions.update({agent.name: agent_interactions})

LinearUCB
Starting LinearUCB Training


rmse=0.800: 100%|██████████| 20/20 [00:24<00:00,  1.21s/it]


Ended LinearUCB Training


LinearUCB: 100%|██████████| 18900/18900 [00:24<00:00, 759.71it/s]


MostPopular
Starting MostPopular Training
Ended MostPopular Training


MostPopular: 100%|██████████| 18900/18900 [00:04<00:00, 4266.61it/s]


EGreedy
Starting EGreedy Training
Ended EGreedy Training


EGreedy: 100%|██████████| 18900/18900 [00:03<00:00, 6151.87it/s]


EnsembleAgent
Starting EnsembleAgent Training


rmse=0.800: 100%|██████████| 20/20 [00:20<00:00,  1.04s/it]


Ended EnsembleAgent Training


EnsembleAgent: 100%|██████████| 18900/18900 [00:27<00:00, 682.31it/s] 


## Evaluating the models

In [37]:
# Cumulative evaluation against the test split, reported at several cutoffs.
evaluator_settings = {
    "ground_truth_dataset": test_dataset,
    "num_interactions": 100,
    "interaction_size": 1,
    "interactions_to_evaluate": [5, 10, 20, 50, 100],
    "relevance_evaluator_threshold": 3.99,
}
evaluator = UserCumulativeInteraction(**evaluator_settings)

In [None]:
# Evaluate every metric for every agent.
# Result layout: cumulative_results[metric class name][agent name] -> evaluator output.
cumulative_results = {}
for metric_class in [Hits, EPC, Recall, ILD]:
    metric_name = metric_class.__name__
    for agent_name, agent_results in interactions.items():
        print(f"\nEvaluating {agent_name}\n")
        metric_values = evaluator.evaluate(metric_class=metric_class, results=agent_results)
        # The per-metric dict is created on first use, then filled agent by agent.
        cumulative_results.setdefault(metric_name, {})[agent_name] = metric_values

In [39]:
# Sanity check: the top-level keys are the names of the evaluated metric classes.
cumulative_results.keys()

dict_keys(['Hits', 'EPC', 'Recall', 'ILD'])

In [40]:
# Sanity check: under each metric there is one entry per agent name.
cumulative_results["Hits"].keys()

dict_keys(['LinearUCB', 'MostPopular', 'EGreedy', 'EnsembleAgent'])

In [41]:
# Column labels of the summary: one per evaluated interaction cutoff. Each
# agent must provide exactly one result per cutoff, in this order.
cutoffs = [5, 10, 20, 50, 100]

# One table per metric: rows are agents ("Model"), columns are (metric, cutoff).
all_results = []
for metric_name, agent_values in cumulative_results.items():
    # Collapse each cutoff's dict of values into a single NaN-ignoring mean
    # (the dicts are presumably keyed by user — confirm against the evaluator).
    mean_per_cutoff = {
        agent_name: [np.nanmean(list(metric_values.values())) for metric_values in values]
        for agent_name, values in agent_values.items()
    }
    # Build the frame in one constructor call so the columns are numeric from
    # the start, instead of filling an empty object-dtype frame row by row.
    df_cumulative = pd.DataFrame(
        list(mean_per_cutoff.values()),
        index=pd.Index(list(mean_per_cutoff), name="Model"),
        columns=pd.MultiIndex.from_product([[metric_name], cutoffs]),
    )
    all_results.append(df_cumulative)

In [42]:
# Join the per-metric tables side by side into one summary table.
# The guard keeps this cell safe to re-run: once `all_results` is already the
# combined DataFrame, passing it to pd.concat again would raise a TypeError.
if not isinstance(all_results, pd.DataFrame):
    all_results = pd.concat(all_results, axis=1)
all_results

Unnamed: 0_level_0,Hits,Hits,Hits,Hits,Hits,EPC,EPC,EPC,EPC,EPC,Recall,Recall,Recall,Recall,Recall,ILD,ILD,ILD,ILD,ILD
Unnamed: 0_level_1,5,10,20,50,100,5,10,20,50,100,5,10,20,50,100,5,10,20,50,100
Model,Unnamed: 1_level_2,Unnamed: 2_level_2,Unnamed: 3_level_2,Unnamed: 4_level_2,Unnamed: 5_level_2,Unnamed: 6_level_2,Unnamed: 7_level_2,Unnamed: 8_level_2,Unnamed: 9_level_2,Unnamed: 10_level_2,Unnamed: 11_level_2,Unnamed: 12_level_2,Unnamed: 13_level_2,Unnamed: 14_level_2,Unnamed: 15_level_2,Unnamed: 16_level_2,Unnamed: 17_level_2,Unnamed: 18_level_2,Unnamed: 19_level_2,Unnamed: 20_level_2
LinearUCB,1.94709,3.529101,6.285714,13.835979,22.57672,0.911348,0.922463,0.934473,0.94796,0.957696,0.054507,0.105812,0.175569,0.331571,0.508525,0.247045,0.281743,0.302369,0.329065,0.352404
MostPopular,1.666667,2.904762,5.126984,10.063492,16.703704,0.90456,0.911084,0.922276,0.938327,0.947497,0.04906,0.085504,0.13688,0.232438,0.365345,0.259048,0.282381,0.27854,0.291894,0.313466
EGreedy,1.126984,2.15873,4.010582,8.497354,14.613757,0.941562,0.945466,0.948347,0.953116,0.957866,0.023325,0.043733,0.084292,0.182637,0.318325,0.294796,0.29907,0.310328,0.330818,0.350296
EnsembleAgent,1.941799,3.550265,6.507937,14.079365,22.804233,0.912394,0.920391,0.931499,0.945274,0.954998,0.054251,0.102216,0.179246,0.350907,0.525458,0.234386,0.281526,0.302593,0.327921,0.350689
