In [1]:
from recsim.simulator import environment
from recsim.simulator import recsim_gym

from movies_lib.samplers import MovieDocumentSampler
from movies_lib.model import MovieUserModel
from movies_lib.sb3_wrapper import RecSimWrapper
import matplotlib.pyplot as plt

from stable_baselines3 import PPO  # Using Proximal Policy Optimization, but you can choose another algorithm
from stable_baselines3.common.env_util import make_vec_env
from stable_baselines3.common.callbacks import EvalCallback


2024-04-19 16:28:37.375858: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [2]:
genres = ['Action', 'Adventure', 'Comedy', 'Drama', 'Fantasy', 'Horror', 'Mystery', 'Romance', 'Sci-Fi', 'Thriller']

In [3]:
slate_size = 5
num_candidates = 20
user = MovieUserModel(slate_size)
doc = MovieDocumentSampler()

# Initialize the environment for the movie recommendation system
movie_env = environment.Environment(
    user,  # Use the adapted user model for movies
    doc,     # Use the adapted document sampler for movies
    num_candidates,
    slate_size,
    resample_documents=True  # Enable resampling of documents for each step
)

eval_env = environment.Environment(
    user,  # Use the adapted user model for movies
    doc,     # Use the adapted document sampler for movies
    num_candidates,
    slate_size,
    resample_documents=False  # Enable resampling of documents for each step
)

In [4]:
def movie_watched_rating_reward(responses):
    reward = 0.0
    for response in responses:
        if response.watched:
            reward += response.rating
    return reward/(len(responses)*5)

In [5]:
movie_gym_env = recsim_gym.RecSimGymEnv(movie_env, movie_watched_rating_reward)
movie_gym_env_eval = recsim_gym.RecSimGymEnv(eval_env, movie_watched_rating_reward)

In [6]:
env = RecSimWrapper(movie_gym_env)
eval_env = RecSimWrapper(movie_gym_env_eval)

In [7]:
env.action_space

MultiDiscrete([20 20 20 20 20])

In [8]:
policy_kwargs = dict(
    net_arch=[256, 256]  # Two hidden layers with 256 neurons each
)

In [9]:
model = PPO("MlpPolicy", env, verbose=1, policy_kwargs=policy_kwargs,learning_rate=0.0001)

  return torch._C._cuda_getDeviceCount() > 0


Using cpu device
Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.


In [10]:
eval_callback = EvalCallback(eval_env, best_model_save_path='./logs/',
                             log_path='./logs/', eval_freq=500,n_eval_episodes=10,
                             deterministic=True, render=False)

# Include the callback in the learning process
model.learn(total_timesteps=10000, callback=eval_callback,progress_bar=True)



---------------------------------
| eval/              |          |
|    mean_ep_length  | 120      |
|    mean_reward     | 23.8     |
| time/              |          |
|    total_timesteps | 500      |
---------------------------------


---------------------------------
| eval/              |          |
|    mean_ep_length  | 120      |
|    mean_reward     | 23.7     |
| time/              |          |
|    total_timesteps | 1000     |
---------------------------------


---------------------------------
| eval/              |          |
|    mean_ep_length  | 120      |
|    mean_reward     | 23.8     |
| time/              |          |
|    total_timesteps | 1500     |
---------------------------------


---------------------------------
| eval/              |          |
|    mean_ep_length  | 120      |
|    mean_reward     | 23.7     |
| time/              |          |
|    total_timesteps | 2000     |
---------------------------------
---------------------------------
| rollout/           |          |
|    ep_len_mean     | 124      |
|    ep_rew_mean     | 24.1     |
| time/              |          |
|    fps             | 168      |
|    iterations      | 1        |
|    time_elapsed    | 12       |
|    total_timesteps | 2048     |
---------------------------------


-----------------------------------------
| eval/                   |             |
|    mean_ep_length       | 120         |
|    mean_reward          | 23.8        |
| time/                   |             |
|    total_timesteps      | 2500        |
| train/                  |             |
|    approx_kl            | 0.021363668 |
|    clip_fraction        | 0.271       |
|    clip_range           | 0.2         |
|    entropy_loss         | -15         |
|    explained_variance   | -0.256      |
|    learning_rate        | 0.0001      |
|    loss                 | 0.185       |
|    n_updates            | 10          |
|    policy_gradient_loss | -0.0933     |
|    value_loss           | 0.928       |
-----------------------------------------


---------------------------------
| eval/              |          |
|    mean_ep_length  | 120      |
|    mean_reward     | 23.8     |
| time/              |          |
|    total_timesteps | 3000     |
---------------------------------


---------------------------------
| eval/              |          |
|    mean_ep_length  | 120      |
|    mean_reward     | 23.8     |
| time/              |          |
|    total_timesteps | 3500     |
---------------------------------


---------------------------------
| eval/              |          |
|    mean_ep_length  | 120      |
|    mean_reward     | 23.8     |
| time/              |          |
|    total_timesteps | 4000     |
---------------------------------
---------------------------------
| rollout/           |          |
|    ep_len_mean     | 124      |
|    ep_rew_mean     | 24.3     |
| time/              |          |
|    fps             | 164      |
|    iterations      | 2        |
|    time_elapsed    | 24       |
|    total_timesteps | 4096     |
---------------------------------


-----------------------------------------
| eval/                   |             |
|    mean_ep_length       | 120         |
|    mean_reward          | 23.4        |
| time/                   |             |
|    total_timesteps      | 4500        |
| train/                  |             |
|    approx_kl            | 0.019764272 |
|    clip_fraction        | 0.235       |
|    clip_range           | 0.2         |
|    entropy_loss         | -15         |
|    explained_variance   | -0.0265     |
|    learning_rate        | 0.0001      |
|    loss                 | 0.397       |
|    n_updates            | 20          |
|    policy_gradient_loss | -0.084      |
|    value_loss           | 1.77        |
-----------------------------------------


---------------------------------
| eval/              |          |
|    mean_ep_length  | 120      |
|    mean_reward     | 23.6     |
| time/              |          |
|    total_timesteps | 5000     |
---------------------------------


---------------------------------
| eval/              |          |
|    mean_ep_length  | 120      |
|    mean_reward     | 23.5     |
| time/              |          |
|    total_timesteps | 5500     |
---------------------------------


---------------------------------
| eval/              |          |
|    mean_ep_length  | 120      |
|    mean_reward     | 23.5     |
| time/              |          |
|    total_timesteps | 6000     |
---------------------------------
---------------------------------
| rollout/           |          |
|    ep_len_mean     | 125      |
|    ep_rew_mean     | 24.4     |
| time/              |          |
|    fps             | 161      |
|    iterations      | 3        |
|    time_elapsed    | 38       |
|    total_timesteps | 6144     |
---------------------------------


-----------------------------------------
| eval/                   |             |
|    mean_ep_length       | 120         |
|    mean_reward          | 23          |
| time/                   |             |
|    total_timesteps      | 6500        |
| train/                  |             |
|    approx_kl            | 0.019290417 |
|    clip_fraction        | 0.228       |
|    clip_range           | 0.2         |
|    entropy_loss         | -15         |
|    explained_variance   | -0.0372     |
|    learning_rate        | 0.0001      |
|    loss                 | 0.944       |
|    n_updates            | 30          |
|    policy_gradient_loss | -0.0857     |
|    value_loss           | 2.79        |
-----------------------------------------


---------------------------------
| eval/              |          |
|    mean_ep_length  | 120      |
|    mean_reward     | 23.2     |
| time/              |          |
|    total_timesteps | 7000     |
---------------------------------


---------------------------------
| eval/              |          |
|    mean_ep_length  | 120      |
|    mean_reward     | 23       |
| time/              |          |
|    total_timesteps | 7500     |
---------------------------------


---------------------------------
| eval/              |          |
|    mean_ep_length  | 120      |
|    mean_reward     | 23.1     |
| time/              |          |
|    total_timesteps | 8000     |
---------------------------------
---------------------------------
| rollout/           |          |
|    ep_len_mean     | 125      |
|    ep_rew_mean     | 24.4     |
| time/              |          |
|    fps             | 159      |
|    iterations      | 4        |
|    time_elapsed    | 51       |
|    total_timesteps | 8192     |
---------------------------------


-----------------------------------------
| eval/                   |             |
|    mean_ep_length       | 120         |
|    mean_reward          | 23.8        |
| time/                   |             |
|    total_timesteps      | 8500        |
| train/                  |             |
|    approx_kl            | 0.018872611 |
|    clip_fraction        | 0.228       |
|    clip_range           | 0.2         |
|    entropy_loss         | -15         |
|    explained_variance   | -0.0356     |
|    learning_rate        | 0.0001      |
|    loss                 | 1.19        |
|    n_updates            | 40          |
|    policy_gradient_loss | -0.0855     |
|    value_loss           | 3.47        |
-----------------------------------------


---------------------------------
| eval/              |          |
|    mean_ep_length  | 120      |
|    mean_reward     | 23.8     |
| time/              |          |
|    total_timesteps | 9000     |
---------------------------------


---------------------------------
| eval/              |          |
|    mean_ep_length  | 120      |
|    mean_reward     | 23.8     |
| time/              |          |
|    total_timesteps | 9500     |
---------------------------------


---------------------------------
| eval/              |          |
|    mean_ep_length  | 120      |
|    mean_reward     | 23.7     |
| time/              |          |
|    total_timesteps | 10000    |
---------------------------------
---------------------------------
| rollout/           |          |
|    ep_len_mean     | 125      |
|    ep_rew_mean     | 24.4     |
| time/              |          |
|    fps             | 158      |
|    iterations      | 5        |
|    time_elapsed    | 64       |
|    total_timesteps | 10240    |
---------------------------------


<stable_baselines3.ppo.ppo.PPO at 0x7efecbe22460>

In [11]:
obs = env.reset()
action, _states = model.predict(obs)

In [12]:
action

array([12,  5, 14,  3,  4])

In [13]:
observation, reward, done, info = env.env.step(action)

In [14]:
reward

0.2

In [15]:
action

array([12,  5, 14,  3,  4])

In [16]:
observation

{'user': array([0.49106572]),
 'doc': OrderedDict([('206520', array([3.43594356, 4.        ])),
              ('206521', array([0.56382752, 0.        ])),
              ('206522', array([4.65348487, 0.        ])),
              ('206523', array([4.02508554, 1.        ])),
              ('206524', array([1.41051154, 0.        ])),
              ('206525', array([0.61482368, 5.        ])),
              ('206526', array([4.51292521, 9.        ])),
              ('206527', array([4.37440284, 8.        ])),
              ('206528', array([4.7987108, 5.       ])),
              ('206529', array([3.31983315, 7.        ])),
              ('206530', array([3.37729225, 0.        ])),
              ('206531', array([2.16570497, 8.        ])),
              ('206532', array([2.72696541, 9.        ])),
              ('206533', array([1.50283216, 7.        ])),
              ('206534', array([1.21736726, 0.        ])),
              ('206535', array([2.24862341, 9.        ])),
              ('20653

In [17]:
k = list(observation['doc'].keys())
v = list(observation['doc'].values())
mv = list(zip(k,v))
for i in action:
    print(mv[i])

('206532', array([2.72696541, 9.        ]))
('206525', array([0.61482368, 5.        ]))
('206534', array([1.21736726, 0.        ]))
('206523', array([4.02508554, 1.        ]))
('206524', array([1.41051154, 0.        ]))
