From b725868de448738ddfce49e724b1094b7fe121eb Mon Sep 17 00:00:00 2001 From: thiagodks Date: Mon, 8 May 2023 11:35:54 -0300 Subject: [PATCH] add tutorial for ensemble-agent --- tutorials/example_ensemble_agents.ipynb | 437 ++++++++++++++++++++++++ 1 file changed, 437 insertions(+) create mode 100644 tutorials/example_ensemble_agents.ipynb diff --git a/tutorials/example_ensemble_agents.ipynb b/tutorials/example_ensemble_agents.ipynb new file mode 100644 index 0000000..36c9b43 --- /dev/null +++ b/tutorials/example_ensemble_agents.ipynb @@ -0,0 +1,437 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 20, + "metadata": {}, + "outputs": [], + "source": [ + "from irec.recommendation.agents.value_functions import LinearUCB, MostPopular, GenericThompsonSampling\n", + "from irec.recommendation.agents.action_selection_policies import ASPGenericGreedy, ASPGreedy\n", + "from irec.offline_experiments.metric_evaluators import UserCumulativeInteraction\n", + "from irec.offline_experiments.evaluation_policies import FixedInteraction\n", + "from irec.recommendation.agents import SimpleEnsembleAgent, SimpleAgent\n", + "from irec.offline_experiments.metrics import Hits\n", + "from irec.environment.loader import FullData" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "import pandas as pd\n", + "import numpy as np" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Load Dataset" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "Applying splitting strategy: global\n", + "\n", + "Test shape: (16892, 4)\n", + "Train shape: (80393, 4)\n", + "\n", + "Generating x_validation and y_validation: \n", + "Test shape: (15729, 4)\n", + "Train shape: (61345, 4)\n" + ] + } + ], + "source": [ + "# Dataset\n", + "dataset = {\n", + " 'path': \"../datasets/MovieLens 100k/ratings.csv\",\n", + " 'random_seed': 0,\n", + " 'file_delimiter': \",\",\n", + " 'skip_head': True\n", + "}\n", + "# Splitting\n", + "splitting = {'strategy': \"global\", 'train_size': 0.8, 'test_consumes': 5}\n", + "validation = {'validation_size': 0.2}\n", + "# Loader\n", + "loader = FullData(dataset, splitting, validation)\n", + "train_dataset, test_dataset, x_validation, y_validation = loader.process()" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Creating the agents" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [], + "source": [ + "params = {\n", + " \"LinearUCB\": {\"alpha\": 1.0, \"item_var\": 0.01, \"iterations\": 20, \"num_lat\": 20, \"stop_criteria\": 0.0009, \"user_var\": 0.01, \"var\": 0.05},\n", + " \"MostPopular\": {},\n", + " \"GenericThompsonSampling\": {\"alpha_0\": {\"LinearUCB\": 100, \"MostPopular\": 1}, \"beta_0\": {\"LinearUCB\": 100, \"MostPopular\": 1}},\n", + "}" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Creating the simple agents" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "vf1 = LinearUCB(**params[\"LinearUCB\"])\n", + "vf2 = MostPopular(**params[\"MostPopular\"])\n", + "\n", + "asp_sa = ASPGreedy()\n", + "\n", + "agent1 = SimpleAgent(vf1, asp_sa, name=\"LinearUCB\")\n", + "agent2 = SimpleAgent(vf2, asp_sa, name=\"MostPopular\")" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Creating the Ensemble Agent" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": {}, + "outputs": [], + "source": [ + "asp_sea = ASPGenericGreedy()\n", + "vf_sea = GenericThompsonSampling(**params[\"GenericThompsonSampling\"])\n", + "ensemble_agent = SimpleEnsembleAgent(agents=[agent1, agent2], action_selection_policy=asp_sea, name=\"EnsebleAgent\", value_function=vf_sea)" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": {}, + "outputs": [], + "source": [ + "agents = [agent1, agent2, ensemble_agent]" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Getting the recommendations" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "metadata": {}, + "outputs": [], + "source": [ + "eval_policy = FixedInteraction(num_interactions=100, interaction_size=1, save_info=False)" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "LinearUCB\n", + "Starting LinearUCB Training\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "rmse=0.800: 100%|██████████| 20/20 [00:15<00:00, 1.29it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Ended LinearUCB Training\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "LinearUCB: 100%|██████████| 18900/18900 [00:17<00:00, 1083.03it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "MostPopular\n", + "Starting MostPopular Training\n", + "Ended MostPopular Training\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "MostPopular: 100%|██████████| 18900/18900 [00:02<00:00, 6731.79it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "EnsebleAgent\n", + "Starting EnsebleAgent Training\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "rmse=0.801: 100%|██████████| 20/20 [00:18<00:00, 1.08it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Ended EnsebleAgent Training\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "EnsebleAgent: 100%|██████████| 18900/18900 [00:22<00:00, 844.46it/s] \n" + ] + } + ], + "source": [ + "interactions = {}\n", + "for agent in agents:\n", + " print(agent.name)\n", + " agent_interactions, action_info = eval_policy.evaluate(agent, train_dataset, test_dataset)\n", + " interactions[agent.name] = agent_interactions" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Evaluating the models" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "metadata": {}, + "outputs": [], + "source": [ + "# Cumulative Evaluation Setup\n", + "evaluator = UserCumulativeInteraction(\n", + " ground_truth_dataset=test_dataset,\n", + " num_interactions=100,\n", + " interaction_size=1,\n", + " interactions_to_evaluate=[5, 10, 20, 50, 100],\n", + " relevance_evaluator_threshold=3.99\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "Evaluating LinearUCB\n", + "\n", + "Computing interaction 5 with UserCumulativeInteraction\n", + "Computing interaction 10 with UserCumulativeInteraction\n", + "Computing interaction 20 with UserCumulativeInteraction\n", + "Computing interaction 50 with UserCumulativeInteraction\n", + "Computing interaction 100 with UserCumulativeInteraction\n", + "UserCumulativeInteraction spent 0.86 seconds executing Hits metric\n", + "\n", + "Evaluating MostPopular\n", + "\n", + "Computing interaction 5 with UserCumulativeInteraction\n", + "Computing interaction 10 with UserCumulativeInteraction\n", + "Computing interaction 20 with UserCumulativeInteraction\n", + "Computing interaction 50 with UserCumulativeInteraction\n", + "Computing interaction 100 with UserCumulativeInteraction\n", + "UserCumulativeInteraction spent 0.38 seconds executing Hits metric\n", + "\n", + "Evaluating EnsebleAgent\n", + "\n", + "Computing interaction 5 with UserCumulativeInteraction\n", + "Computing interaction 10 with UserCumulativeInteraction\n", + "Computing interaction 20 with UserCumulativeInteraction\n", + "Computing interaction 50 with UserCumulativeInteraction\n", + "Computing interaction 100 with UserCumulativeInteraction\n", + "UserCumulativeInteraction spent 0.38 seconds executing Hits metric\n" + ] + } + ], + "source": [ + "# Getting the results\n", + "cumulative_results = {}\n", + "for agent_name, agent_results in interactions.items():\n", + " print(f\"\\nEvaluating {agent_name}\\n\")\n", + " hits_values = evaluator.evaluate(metric_class=Hits, results=agent_results)\n", + " cumulative_results[agent_name] = hits_values" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
5102050100
Model
LinearUCB1.9365083.476196.33333314.08465622.915344
MostPopular1.6666672.9047625.12698410.06349216.703704
EnsebleAgent1.8465613.4708996.49206314.13227523.026455
\n", + "
" + ], + "text/plain": [ + " 5 10 20 50 100\n", + "Model \n", + "LinearUCB 1.936508 3.47619 6.333333 14.084656 22.915344\n", + "MostPopular 1.666667 2.904762 5.126984 10.063492 16.703704\n", + "EnsebleAgent 1.846561 3.470899 6.492063 14.132275 23.026455" + ] + }, + "execution_count": 19, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "df_cumulative = pd.DataFrame(columns=[\"Model\", 5, 10, 20, 50, 100])\n", + "df_cumulative[\"Model\"] = list(cumulative_results.keys())\n", + "df_cumulative.set_index(\"Model\", inplace=True)\n", + "for agent_name, results in cumulative_results.items():\n", + " df_cumulative.loc[agent_name] = [\n", + " np.mean(list(metric_values.values())) for metric_values in results\n", + " ]\n", + "df_cumulative" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": ".venv", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.5" + }, + "orig_nbformat": 4 + }, + "nbformat": 4, + "nbformat_minor": 2 +}