# Experiment 1) PPO

## 1) Setup

### 1.1) Install necessary dependencies

In [None]:
# !pip install "stable-baselines3[extra]>=2.0.0a4"
# !pip install wandb

### 1.2) Import Libraries

In [17]:
#=== Standard libraries for custom environment
import sys
import random
import os
import numpy as np
import gymnasium as gym
from typing import List
from gymnasium import spaces
from datetime import datetime
# from sklearn.preprocessing import MinMaxScaler

#=== Custom utilities stored in .py files
# Add the /app/src directory to the Python path
sys.path.append(os.path.abspath('../app/src'))
from spark.data import loader
from spark.data.models import Customer, Product, Category, Interaction, InteractionType
from spark import utils

#=== Stable Baselines3 model libraries
import stable_baselines3
from stable_baselines3 import A2C, PPO
from stable_baselines3.common.logger import configure
from stable_baselines3.common.callbacks import EvalCallback, BaseCallback
from stable_baselines3.common.monitor import Monitor
from stable_baselines3.common.vec_env import DummyVecEnv, VecVideoRecorder, VecNormalize
from stable_baselines3.common.preprocessing import get_flattened_obs_dim
from stable_baselines3.common.torch_layers import BaseFeaturesExtractor
import torch
from torch import nn
import wandb

#=== Print the versions of gymnasium and stable_baselines3 for debugging purposes 
print(f"{gym.__version__=}")
print(f"{stable_baselines3.__version__=}")

### 1.3) Experiment configuration

In [None]:
# --- Run identity ------------------------------------------------------------
project = 'spark'            # W&B project name; also the sub-folder under ./output
env_name = 'Spark'
model_id = 'ppo'             # prefix used for every saved model file
tb_log_name = 'PPO'
label = 'main'
inc = 1                      # run counter embedded in run_name
run_name = f'{model_id}-{label}-{inc}'
model_name_final = f"{model_id}_model_final"

# --- Training budget ---------------------------------------------------------
total_timesteps = int(3e6)   # 3m timesteps for full training
param_n_envs = 1             # Added in case we wanted to allow multi-environment training via vector wrapper

# Integer division keeps the interval an int: it is passed to callback
# parameters annotated as `int` (save_interval / eval_freq) and is compared
# against integer step counters with `%`.
save_interval = total_timesteps // 10

# Weights and Biases callback configuration
config = {
    "total_timesteps": total_timesteps,
    "env_name": env_name,
}

### 1.4) Define Directories

In [None]:
# Output directory layout: <base_dir>/output/<project>/{logs,models}/<run_name>
base_dir = '.'
output_dir = os.path.join(base_dir, 'output')
env_dir = os.path.join(output_dir, project)
logs_dir, models_dir = (
    os.path.join(env_dir, sub_path)
    for sub_path in (fr'logs/{run_name}', fr'models/{run_name}')
)

# Create both run folders up front (no error if they already exist).
for run_dir in (logs_dir, models_dir):
    os.makedirs(run_dir, exist_ok=True)

# Echo the resolved locations so they are visible in the notebook output.
for run_dir in (logs_dir, models_dir):
    print(run_dir)

### 1.5) Weights and Biases initialisation

In [None]:
# Credentials and notebook name are handed to wandb through environment
# variables; replace the API-key placeholder before running this cell.
os.environ.update({
    'WANDB_API_KEY': 'please insert your API key here!',
    'WANDB_NOTEBOOK_NAME': 'Exp1-PPO',
})
wandb.login()

# Start the tracked run for this experiment.
run = wandb.init(
    name=run_name,
    project=project,
    config=config,
    save_code=True,         # optional
    sync_tensorboard=True,  # auto-upload sb3's tensorboard metrics
)

## 2) Utilities

### 2.1) Custom Callback utility

In [None]:
class WandbEvalCallback(EvalCallback):
    """
        EvalCallback that also reports to Weights and Biases and writes periodic checkpoints.

            - The EvalCallback part periodically evaluates the model on `eval_env` and keeps
              the best-performing model in `save_path`.
            - After each evaluation the mean reward and the best mean reward so far are
              logged to the given wandb run.
            - Every `save_interval` timesteps the current model is saved to `save_path`.

        EvalCallback already derives from BaseCallback, so no second base class or
        second base-class initialisation is required.

        Args:
            eval_env:         environment used for the periodic evaluation.
            wandb_run:        object exposing `log(dict)`; here the run returned by `wandb.init`.
            save_interval:    number of timesteps between checkpoints (<= 0 disables them).
            eval_freq:        number of callback calls between evaluations.
            save_path:        directory for the best model and the checkpoints;
                              None disables checkpointing.
            n_eval_episodes:  episodes per evaluation.
            deterministic:    whether evaluation uses deterministic actions.
            render:           whether to render during evaluation.
            verbose:          verbosity level; > 0 prints a line per checkpoint.

        Sources:
            - SB3 API Docs:         https://stable-baselines3.readthedocs.io/en/master/guide/callbacks.html
            - SB3 codebase:         https://github.com/DLR-RM/stable-baselines3/blob/c62e9259db363bf32bd920405dbe83db94123271/stable_baselines3/common/callbacks.py
            - Wandb codebase:       https://github.com/wandb/wandb/blob/main/wandb/integration/sb3/sb3.py
    """

    def __init__(self,
                 eval_env,
                 wandb_run,
                 save_interval: int = 10000,
                 eval_freq: int = 10000,
                 save_path: str = None,
                 n_eval_episodes: int = 5,
                 deterministic: bool = True,
                 render: bool = False,
                 verbose: int = 1):
        super().__init__(eval_env=eval_env,
                         eval_freq=eval_freq,
                         best_model_save_path=save_path,
                         n_eval_episodes=n_eval_episodes,
                         deterministic=deterministic,
                         render=render,
                         verbose=verbose)
        self.wandb_run = wandb_run
        self.save_interval = save_interval
        self.save_path = save_path

    def _on_step(self) -> bool:
        """Run the evaluation step, mirror its results to wandb and checkpoint.

        Returns:
            The parent's continue-training flag, unchanged. Saving a checkpoint
            never overrides a request to stop training.
        """
        continue_training = super()._on_step()

        # Same trigger condition the parent uses: only log on steps where an
        # evaluation actually ran, otherwise stale values would be logged on
        # every single environment step.
        evaluated = self.eval_freq > 0 and self.n_calls % self.eval_freq == 0
        if evaluated:
            self.wandb_run.log({
                "mean_reward": self.last_mean_reward,
                "best_mean_reward": self.best_mean_reward,
                "step": self.num_timesteps,
            })

        # Save the model every 'save_interval' steps (needs a target directory).
        checkpoint_due = (
            self.save_path is not None
            and self.save_interval > 0
            and self.num_timesteps % self.save_interval == 0
        )
        if checkpoint_due:
            # `model_id` is the notebook-level experiment identifier.
            save_file = os.path.join(self.save_path, f'{model_id}_model_{self.num_timesteps}')
            self.model.save(save_file)
            if self.verbose > 0:
                print(f'Saving model to {save_file}.zip')

        return continue_training

    def _on_training_end(self) -> None:
        """Log a final marker to wandb when training finishes."""
        self.wandb_run.log({"training_end": True})


## 3) Custom Environment Set-up

### 3.1) Product Recommendation Environment

In [None]:

"""
Custom gym environment for product recommendations, where states represent customer interactions
and actions are the recommended products.

Each customer interaction is a state. In the step function, transition of states will be the same 
customer with the new interaction for a product. The transition ends when the customer made a purchase or session ends.

The aim is to maximise rewards that can lead to a purchase.

It is better to have a single customer interaction represent a state than a time series of steps leading to a purchase. 
The latter may be incomplete when the customer exits the application. Also, even if the customer does not purchase,
this information is still valuable for recommendations. Hence every state is a customer interaction.

Additional metadata for each customer is stored to capture the context of the user. For example, if a product is 
purchased many times, this may factor into the preferences encoded in the state.
---
Personalization: By incorporating the user ID, the model can tailor recommendations specifically to individual users, allowing 
it to learn unique user preferences and behaviors over time.

VS

Overfitting: If the model learns too much from the user ID directly, it might overfit to individual user patterns, potentially 
missing out on broader trends that could be useful for all users.

SOLUTION
Use Embeddings: Instead of directly using user IDs, consider using an embedding layer that transforms the user ID into a dense vector representation. This approach reduces dimensionality while capturing user-specific features.

Combine Features: Use user ID embeddings in conjunction with other features like user demographics, interaction history, and product attributes. This can create a more holistic view of user preferences.

Regularization: Implement techniques like dropout or weight regularization to mitigate overfitting when using user IDs or their embeddings.

Batch Normalization: Use batch normalization to stabilize learning, especially if the user ID leads to a wide range of outputs.
----
"""
class RecommendationEnv(gym.Env):
    """Gymnasium environment that simulates a customer reacting to recommendations.

    An action is a list of `top_k` product indices (the recommendations). On each
    step the environment samples which recommended product the current customer
    picks and how they interact with it, updates the customer's profile, and
    returns an observation built from that profile plus the latest interaction.
    An episode ends when the sampled interaction is SESSION_CLOSE.

    Customers are expected to carry the per-product arrays `views`, `likes`,
    `buys` and `ratings` (created in the notebook before the environment is built).
    Note that stepping the environment mutates these arrays on the shared
    customer objects.

    Args:
        users:    customers; one is drawn at random on every reset.
        products: product catalogue; indices into this list are the action values.
        top_k:    number of products recommended per step.
    """

    MAX_RATING = 5  # ratings are integers 0..5; 0 is used as "no rating" in the observation

    def __init__(self, users:List[Customer], products:List[Product], top_k:int):
        super().__init__()

        self.users = users                  # list of users as states
        self.products = products            # products as actions, potential recommendations
        self.top_k = top_k                  # number of recommendations
        self.user_idx = 0                   # index of users list, not user_id
        self.current_step = 0               # step is also the interactions list index
        self.categories = loader.load_categories()
        self.last_action = None             # most recent recommendation list (for render)
        self.last_reward = None             # reward of the most recent step (for render)

        n_products = len(self.products)
        n_categories = len(self.categories)
        n_interactions = len(list(InteractionType))

        # Rewards for every interaction type whose reward does not depend on a rating.
        self._fixed_rewards = {
            InteractionType.NONE: -1,           # no interaction, customer not interested
            InteractionType.VIEW: 3,
            InteractionType.LIKE: 10,
            InteractionType.BUY: 50,
            InteractionType.SESSION_START: 0,
            InteractionType.SESSION_CLOSE: 0,   # TODO: check if engagement is too short
        }

        # One product index per recommendation slot (previously hard-coded to 10 slots).
        self.action_space = spaces.MultiDiscrete([n_products] * top_k)

        # States are derived from customer profiles and interactions.
        # States exclude user ids so the policy network generalises across users;
        # the internal users list is only used as a reference.
        # Vectorised-env wrappers copy observations into buffers of the declared
        # dtype, so the normalised profile vectors are declared float32 (a uint8
        # declaration would truncate fractional values), and 'ratings' is bounded
        # by the maximum rating rather than 1.
        # NOTE(review): the preference scores are not guaranteed to stay inside
        # [0, 1] (low ratings contribute negative terms) — confirm whether they
        # should be clipped or the bounds widened.
        self.observation_space = spaces.Dict({
            'pref_prod': spaces.Box(low=0, high=1, shape=(n_products,), dtype=np.float32),
            'pref_cat': spaces.Box(low=0, high=1, shape=(n_categories,), dtype=np.float32),
            'buys': spaces.Box(low=0, high=1, shape=(n_products,), dtype=np.float32),
            'views': spaces.Box(low=0, high=1, shape=(n_products,), dtype=np.float32),
            'likes': spaces.Box(low=0, high=1, shape=(n_products,), dtype=np.float32),
            'ratings': spaces.Box(low=0, high=self.MAX_RATING, shape=(n_products,), dtype=np.uint8),
            'product': spaces.Box(low=0, high=1, shape=(n_products,), dtype=np.uint8),
            'interaction': spaces.Box(low=0, high=1, shape=(n_interactions,), dtype=np.uint8),
            'rating': spaces.Discrete(self.MAX_RATING + 1)
            })
            ## add more features like time, ignored recommendations, engagement etc

    def reset(self, seed=None, options=None):
        """Start a new episode with a randomly drawn customer.

        Returns:
            (observation, info) where info is an empty dict.
        """
        # The parent reset seeds self.np_random when a seed is given.
        super().reset(seed=seed)

        # Draw from the environment's own generator so reset(seed=...) is reproducible.
        self.user_idx = int(self.np_random.integers(len(self.users)))
        self.current_step = 0
        return self._get_observation(self.users[self.user_idx]), {}

    def step(self, rec_products):
        """Simulate the current customer's reaction to the recommended products.

        Args:
            rec_products: sequence of product indices recommended this step.

        Returns:
            (observation, reward, terminated, truncated, info). `terminated` is
            True only for SESSION_CLOSE; `truncated` is always False. `info`
            describes the simulated interaction.
        """
        self.current_step += 1
        user = self.users[self.user_idx]

        # simulate selected recommended product and interaction
        selected_pid, interaction_type = self._simulate_interaction(rec_products)

        random_rating = 0
        if interaction_type == InteractionType.RATE:
            # NOTE(review): the draw includes 0, which the observation treats as
            # "no rating", and the reward is rating - 1; an earlier comment
            # described ratings 1-2 as negative and 3 as neutral. Both are kept
            # as they were — confirm the intended range and offset.
            random_rating = int(self.np_random.integers(0, self.MAX_RATING + 1))
            reward = random_rating - 1
        else:
            reward = self._fixed_rewards.get(interaction_type, 0)

        done = interaction_type == InteractionType.SESSION_CLOSE

        new_interaction = Interaction(self.current_step, datetime.now(), user.idx,
                                      selected_pid, interaction_type, random_rating)

        # Remembered for render().
        self.last_action = rec_products
        self.last_reward = reward

        info = {
            'product_idx': selected_pid,
            'interaction': interaction_type,
            'rating': random_rating,
        }
        return self._update_observation(new_interaction), reward, done, False, info

    @staticmethod
    def _increment_count(counts, pid):
        """Add one to counts[pid], saturating instead of wrapping at the integer dtype's maximum."""
        dtype = getattr(counts, 'dtype', None)
        if (dtype is not None and np.issubdtype(dtype, np.integer)
                and counts[pid] >= np.iinfo(dtype).max):
            return
        counts[pid] += 1

    def _build_observation(self, user, product_vec, interaction_vec, rating):
        """Assemble the observation dict from the user's profile and the latest interaction."""
        return {
            'pref_prod': self._get_product_preferences(user),
            'pref_cat': self._get_category_preferences(user),
            # utils.normalise is a project helper; presumably scales into [0, 1] — confirm.
            'buys': utils.normalise(user.buys),
            'views': utils.normalise(user.views),
            'likes': utils.normalise(user.likes),
            'ratings': user.ratings,
            'product': product_vec,
            'interaction': interaction_vec,
            'rating': rating,
        }

    def _update_observation(self, interaction:Interaction):
        """Apply `interaction` to the current user's profile and return the new observation."""
        user = self.users[self.user_idx]
        pid = interaction.product_idx
        kind = interaction.type     # compare the interaction's type, not the Interaction object

        if kind == InteractionType.VIEW:
            self._increment_count(user.views, pid)
        elif kind == InteractionType.LIKE:
            self._increment_count(user.likes, pid)
        elif kind == InteractionType.BUY:
            self._increment_count(user.buys, pid)
        elif kind == InteractionType.RATE:
            # 'ratings' is the profile array read everywhere else in this class.
            user.ratings[pid] = interaction.value

        return self._build_observation(
            user,
            utils.one_hot_encode(pid, len(self.products)),
            self._get_interaction_observation(interaction),
            interaction.value if kind == InteractionType.RATE else 0,
        )

    def _get_observation(self, user:Customer):
        """Return the observation for `user` before any interaction in this episode."""
        return self._build_observation(
            user,
            np.zeros(len(self.products)),
            np.zeros(len(list(InteractionType))),
            0,
        )

    def _get_interaction_observation(self, interaction:Interaction):
        """Encode the interaction's type by its position in InteractionType."""
        idx = list(InteractionType).index(interaction.type)
        size = len(InteractionType)

        return utils.one_hot_encode(idx, size)

    def _get_product_preferences(self, user:Customer):
        """Score every product from the user's past views, likes, purchases and ratings."""
        view_prefs = user.views / 20
        purchase_prefs = user.buys
        like_prefs = user.likes / 15

        # Shift given ratings by -2 (1 -> -1 ... 5 -> 3); unrated products stay at 0.
        rating_prefs = user.ratings.copy()
        rating_prefs[rating_prefs > 0] -= 2

        return view_prefs + purchase_prefs + like_prefs + rating_prefs

    def _get_category_preferences(self, user:Customer):
        """Accumulate the positive product scores per product category."""
        prod_prefs = self._get_product_preferences(user)
        cat_prefs = np.zeros(len(self.categories), np.float32)

        for idx, prod_pref in enumerate(prod_prefs):
            if prod_pref > 0:
                cat_idx = self.products[idx].category.idx
                cat_prefs[cat_idx] += prod_pref     # accumulation of fav products for this cat

        return cat_prefs / 5    # reduce space

    @staticmethod
    def _to_probabilities(scores):
        """Turn non-negative scores into a distribution; uniform when no score is positive."""
        scores = np.clip(np.asarray(scores, dtype=np.float64), 0.0, None)
        total = scores.sum()
        if total <= 0:
            return np.full(len(scores), 1.0 / len(scores))
        return scores / total

    def _simulate_interaction(self, product_ids):
        """Sample which recommended product the user picks and how they interact with it.

        Args:
            product_ids: the recommended product indices.

        Returns:
            (selected_product_id, selected_interaction_type)
        """
        user = self.users[self.user_idx]
        product_ids = np.asarray(product_ids).reshape(-1)

        # --- product selection -------------------------------------------------
        product_prefs = self._get_product_preferences(user)
        category_prefs = self._get_category_preferences(user)

        # Combine product and category preference. Scores are kept as floats:
        # a uint8 buffer would truncate fractions and wrap negatives, and the
        # category score must add to, not replace, the product score.
        prod_scores = np.array([
            product_prefs[pid] + category_prefs[self.products[pid].category.idx]
            for pid in product_ids
        ], dtype=np.float64)

        # Uniform choice unless at least one recommendation has a positive score.
        product_probs = self._to_probabilities(prod_scores)
        selected_product_id = int(self.np_random.choice(product_ids, p=product_probs))

        # --- interaction selection ---------------------------------------------
        inter_types = list(InteractionType)
        history = {
            InteractionType.VIEW: user.views,
            InteractionType.LIKE: user.likes,
            InteractionType.BUY: user.buys,
            InteractionType.RATE: user.ratings,
        }
        inter_scores = np.array([
            float(history[kind][selected_product_id]) if kind in history else 0.0
            for kind in inter_types
        ])

        if np.any(inter_scores > 0):
            inter_scores[inter_scores == 0] = 1     # default score for interactions that are 0
        inter_probs = self._to_probabilities(inter_scores)

        # Sample an index rather than the enum member itself so the result is
        # always a genuine InteractionType.
        chosen = int(self.np_random.choice(len(inter_types), p=inter_probs))

        return selected_product_id, inter_types[chosen]

    def seed(self, seed=None):
        """
        Set the seed for reproducibility.

        Seeds the environment's generator as well as the global `random` and
        `numpy.random` generators. Returns the seed in a one-element list.
        """
        self.np_random, seed = gym.utils.seeding.np_random(seed)
        random.seed(seed)
        np.random.seed(seed)
        return [seed]

    def render(self, mode='human'):
        """Print the last recommendation list and its reward, if a step has been taken."""
        last_action = getattr(self, 'last_action', None)
        if last_action is not None:
            print(f"Recommended Product ID (Last Action): {last_action}")
        else:
            print("No product recommended yet.")

        # Optionally, print the reward received for the last action
        last_reward = getattr(self, 'last_reward', None)
        if last_reward is not None:
            print(f"Reward for Last Action: {last_reward}")

        print("-----")

    def close(self):
        """Nothing to release."""
        pass

### 2.3) Load environment data

Loading products and customers with the custom data loaders in `spark.data.loader`.

In [None]:
# Load the product catalogue through the project's loader module.
products = loader.load_products()
# Load the customers; include_interactions=True presumably attaches each
# customer's interaction history (read later as `customer.interactions`) —
# confirm against the loader implementation.
customers = loader.load_customers(include_interactions=True)

Initialising customer profiles from interactions data

In [None]:
# Build per-product profile arrays for every customer from their recorded interactions.
for customer in customers:
    # Counters use int32: int8 silently wraps to negative values once a
    # product has been viewed/liked/bought more than 127 times.
    customer.views = np.zeros(len(products), dtype=np.int32)
    customer.likes = np.zeros(len(products), dtype=np.int32)
    customer.buys = np.zeros(len(products), dtype=np.int32)
    # Ratings store the latest rating value per product (0 = not rated).
    customer.ratings = np.zeros(len(products), dtype=np.int8)

    # Interaction types that simply increment a counter, keyed by enum value.
    counters = {
        InteractionType.VIEW.value: customer.views,
        InteractionType.LIKE.value: customer.likes,
        InteractionType.BUY.value: customer.buys,
    }

    for interaction in customer.interactions:
        i_type = interaction.type.value
        product_idx = interaction.product_idx
        if i_type in counters:
            counters[i_type][product_idx] += 1
        elif i_type == InteractionType.RATE.value:
            # A later rating of the same product overwrites the earlier one.
            customer.ratings[product_idx] = interaction.value

### 2.4) Initialise training and evaluation environments

In [None]:
def _make_monitored_env():
    """Build one Monitor-wrapped recommendation environment over the shared data."""
    return Monitor(RecommendationEnv(customers, products, top_k=10))


# Training environment (seeded)
env = DummyVecEnv([_make_monitored_env])
env.seed(100)

# Evaluation environment for WandbEvalCallback
eval_env = DummyVecEnv([_make_monitored_env])

print(env.observation_space)

## 3) Define model

### 3.1) Custom Feature Extractor

In [None]:
# Feature extraction from the (flattened) observation vector
class CustomANN(BaseFeaturesExtractor):
    """Fully connected features extractor: input -> 256 -> 128 -> features_dim.

    Works with flat observation spaces and with `spaces.Dict` observation
    spaces (as used by the recommendation environment): dict observations are
    flattened per key and concatenated before entering the network.

    Args:
        observation_space: the environment's observation space.
        features_dim:      size of the extracted feature vector.
    """

    def __init__(self, observation_space, features_dim=128):
        super(CustomANN, self).__init__(observation_space, features_dim)

        # Fix the key order once so the concatenation in forward() is stable.
        if isinstance(observation_space, spaces.Dict):
            self._obs_keys = list(observation_space.spaces.keys())
        else:
            self._obs_keys = None

        # A Dict space has no `.shape`, so the input width is taken from the
        # flattened dimension of the whole space instead of `shape[0]`.
        input_dim = int(get_flattened_obs_dim(observation_space))

        self.net = nn.Sequential(
            nn.Linear(input_dim, 256),
            nn.ReLU(),
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Linear(128, features_dim)  # Output dimension should match features_dim
        )

    def forward(self, observations):
        """Map a batch of observations (tensor or dict of tensors) to features."""
        if isinstance(observations, dict):
            keys = self._obs_keys if self._obs_keys is not None else list(observations.keys())
            observations = torch.cat(
                [torch.flatten(observations[key], start_dim=1) for key in keys],
                dim=1,
            )
        else:
            observations = torch.flatten(observations, start_dim=1)
        return self.net(observations.float())
    

### 3.2) Define Hyperparameters

UserWarning: You have specified a mini-batch size of 64, but because the `RolloutBuffer` is of size `n_steps * n_envs = 100`, after every 1 untruncated mini-batches, there will be a truncated mini-batch of size 36
We recommend using a `batch_size` that is a factor of `n_steps * n_envs`.
Info: (n_steps=100 and n_envs=1)
  warnings.warn(


In [None]:
# PPO hyperparameters; each value is passed to the PPO constructor below
# under the keyword named after the `param_` prefix.
param_clip_range = 0.2
param_learning_rate = 0.0001
param_gamma=0.995
param_gae_lambda=0.95
# Rollout length per update; with a single environment the rollout buffer
# holds n_steps * n_envs = 100 transitions.
param_n_steps=100
param_batch_size=50 # Set to 50 as it is a factor of 100
param_n_epochs=10
param_ent_coef = 0.01

### 3.3) Define model

In [None]:
# Collect the constructor arguments first so the configuration reads as data.
ppo_kwargs = dict(
    policy='MultiInputPolicy',   # policy for Dict observation spaces
    env=env,
    learning_rate=param_learning_rate,
    n_steps=param_n_steps,
    batch_size=param_batch_size,
    n_epochs=param_n_epochs,
    gamma=param_gamma,
    gae_lambda=param_gae_lambda,
    clip_range=param_clip_range,
    ent_coef=param_ent_coef,
    tensorboard_log=logs_dir,
    verbose=0,
    # policy_kwargs={'features_extractor_class': CustomANN},
)

model = PPO(**ppo_kwargs)

## 4) Model training

### 4.1) Execute training

In [None]:
# Route SB3 logging to stdout, a CSV file and TensorBoard inside the run's log folder.
new_logger = configure(str(logs_dir), ["stdout", "csv", "tensorboard"])
model.set_logger(new_logger)

# Evaluation, wandb reporting and interim checkpoints share one interval.
eval_and_checkpoint = WandbEvalCallback(
    eval_env=eval_env,
    wandb_run=run,
    save_interval=save_interval,
    eval_freq=save_interval,
    save_path=models_dir,
)

# Train for the configured number of timesteps.
model.learn(
    total_timesteps=config["total_timesteps"],
    callback=eval_and_checkpoint,
)

# Save the final model after training completes
final_model_path = os.path.join(models_dir, model_name_final)
model.save(final_model_path)

### 4.2) Close out WandB session

In [None]:
api = wandb.Api()

# Identify the live run (the same values are also shown in the W&B App).
live_run = wandb.run
username, project, run_id = live_run.entity, live_run.project, live_run.id

# Fetch the same run through the public API and push a config change.
# NOTE(review): the "bar" = 32 entry looks like a documentation placeholder —
# confirm whether it is meant to be recorded for this experiment.
run = api.run(f"{username}/{project}/{run_id}")
run.config["bar"] = 32
run.update()

## 5) Testing

### 5.1) Load tensorboard session

In [None]:
# %load_ext tensorboard

In [None]:
# %tensorboard --logdir={logs_dir}

### 5.2) Outcome debugging

In [None]:
# Construct observation data for customer 10
test_customer_idx = 10
test_customer = customers[test_customer_idx]

# simulated last selected product
test_product = products[53]

# the policy only uses the previous interaction to predict
previous_interaction = Interaction(idx = '0',
                          timestamp = datetime.now(),
                          customer_idx = test_customer.idx,
                          product_idx = test_product.idx,
                          type = InteractionType.RATE,
                          value = 5,)

# `env` is a DummyVecEnv around a Monitor wrapper, so the observation helper
# lives on the unwrapped RecommendationEnv. The helper reads the customer via
# `user_idx`, so point the environment at the test customer first, then pass
# the interaction built above.
raw_env = env.envs[0].unwrapped
raw_env.user_idx = test_customer_idx
new_obs = raw_env._update_observation(previous_interaction)

In [None]:
# Ask the trained policy for an action (the recommended product indices) given
# the debugging observation; the second return value is not used here.
recommended_products, _states = model.predict(new_obs)
print(recommended_products)

In [None]:
# `env` is a vectorised environment: step() expects one action per
# sub-environment (hence the added batch dimension) and returns four values
# (observations, rewards, dones, infos) rather than Gymnasium's five.
obs, rewards, done, infos = env.step(np.atleast_2d(recommended_products))

# One info dict per sub-environment; there is a single environment here.
simulated_interaction_info = infos[0]
simulated_interaction_info

In [None]:
# Roll the trained policy forward and print each recommendation list.
n_rounds = 20  # number of recommendation rounds
obs = env.reset()
for _ in range(n_rounds):
    action, _states = model.predict(obs)
    obs, rewards, done, _ = env.step(action)
    print("Recommended products:", action)