# Experiment 2) A2C

## 1) Setup

### 1.1) Install necessary dependencies

In [2]:
# Update and install display packages and stable baseline 3
!pip install -r requirements.txt

# !pip install "stable-baselines3[extra]>=2.0.0a4"
!pip install wandb

[0m

### 1.2) Import Libraries

In [None]:
#=== Standard library plus numpy / gymnasium, used by the custom environment
import sys
import random
import os
import numpy as np
import gymnasium as gym
from typing import List
from gymnasium import spaces
from datetime import datetime
from copy import deepcopy
# from sklearn.preprocessing import MinMaxScaler

#=== Custom utilities stored in .py files
# Project modules, imported through the `app.src.spark` package path (sys.path is NOT modified in this cell)
from app.src.spark.data import loader
from app.src.spark.data.models import Customer, Product, Category, Interaction, InteractionType
from app.src.spark import utils

#=== Stable Baselines3 model libraries
# NOTE(review): PPO, SubprocVecEnv, get_flattened_obs_dim, BaseFeaturesExtractor, torch and nn are not used
# by the cells shown (the CustomANN feature extractor is commented out) — keep only if re-enabled.
import stable_baselines3
from stable_baselines3 import A2C, PPO
from stable_baselines3.common.logger import configure
from stable_baselines3.common.callbacks import EvalCallback, BaseCallback
from stable_baselines3.common.monitor import Monitor
from stable_baselines3.common.vec_env import DummyVecEnv, SubprocVecEnv
from stable_baselines3.common.preprocessing import get_flattened_obs_dim
from stable_baselines3.common.torch_layers import BaseFeaturesExtractor
import torch
from torch import nn
import wandb

#=== Print the versions of gymnasium and stable_baselines3 for debugging / reproducibility purposes
print(f"{gym.__version__=}")
print(f"{stable_baselines3.__version__=}")

gym.__version__='1.0.0'
stable_baselines3.__version__='2.4.0a11'


### 1.3) Experiment configuration

In [None]:
#=== Experiment identity (used for W&B, output folders and file names)
project='spark'
env_name = 'Spark'
model_id='a2c'
tb_log_name='A2C'
label='main'
inc=2
run_name=f'{model_id}-{label}-{inc}'
model_name_final = f"{model_id}_model_final"

#=== Training budget
total_timesteps=int(3e6) # 3m timesteps for full training
param_n_envs=1 # Added in case we wanted to allow multi-environment training via vector wrapper
# Checkpoint / evaluation period in timesteps (10 over the whole run).
# Integer division keeps it an int (callbacks compare it against whole step counts with `%`),
# and max(1, ...) avoids a zero period (ZeroDivisionError) for tiny debug budgets.
save_interval = max(1, total_timesteps // 10)

# Weights and Biases run configuration
config = {
    "total_timesteps": total_timesteps,
    "env_name": env_name,
}

### 1.4) Define Directories

In [5]:
# Output directories: ./output/<project>/logs/<run_name> and ./output/<project>/models/<run_name>
base_dir = '.'
output_dir = os.path.join(base_dir, 'output')
env_dir = os.path.join(output_dir, project)
logs_dir = os.path.join(env_dir, fr'logs/{run_name}')
models_dir = os.path.join(env_dir, fr'models/{run_name}')

# Create both (idempotent on re-run), then show where artefacts will be written.
for required_dir in (logs_dir, models_dir):
    os.makedirs(required_dir, exist_ok=True)

print(logs_dir, models_dir, sep='\n')

./output/spark/logs/a2c-main-2
./output/spark/models/a2c-main-2


### 1.5) Weights and Biases initialisation

In [6]:
# Credentials: never hard-code API keys in a notebook — they end up in version control and in the
# code snapshot uploaded with `save_code=True`. `wandb.login()` uses the WANDB_API_KEY environment
# variable (or ~/.netrc) when present and prompts interactively otherwise.
# NOTE(review): the key previously committed in this cell should be revoked in the W&B settings.
os.environ['WANDB_NOTEBOOK_NAME'] = 'Exp2-A2C'
wandb.login()

run = wandb.init(
    project=project,
    name=run_name,
    config=config,
    sync_tensorboard=True,  # auto-upload sb3's tensorboard metrics
    save_code=True,  # optional
)

[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33mcrowley-m[0m ([33mcrowley-m-university-of-technology-sydney[0m). Use [1m`wandb login --relogin`[0m to force relogin


## 2) Utilities

### 2.1) Custom Callback utility

In [7]:
class WandbEvalCallback(EvalCallback):
    """
        EvalCallback that additionally reports to Weights & Biases and writes periodic checkpoints.
            - Evaluation (inherited from EvalCallback): every `eval_freq` callback calls the policy is
              evaluated on `eval_env`; the best model so far is saved under `save_path`.
            - W&B logging: after each evaluation, `mean_reward` (and `best_mean_reward` when `save_path`
              is set) is logged to `wandb_run` together with the current timestep.
            - Checkpoints: every `save_interval` timesteps the model is saved to
              `<save_path>/<name_prefix>_<num_timesteps>.zip`.

        Sources:
            - SB3 API Docs:         https://stable-baselines3.readthedocs.io/en/master/guide/callbacks.html
            - SB3 codebase:         https://github.com/DLR-RM/stable-baselines3/blob/c62e9259db363bf32bd920405dbe83db94123271/stable_baselines3/common/callbacks.py
            - Wandb codebase:       https://github.com/wandb/wandb/blob/main/wandb/integration/sb3/sb3.py
    """

    def __init__(self,
                 eval_env,
                 wandb_run,
                 save_interval: int = 10000,
                 eval_freq: int = 10000,
                 save_path: str = None,
                 n_eval_episodes: int = 5,
                 deterministic: bool = True,
                 render: bool = False,
                 verbose: int = 1,
                 name_prefix: str = None):
        """
        Args:
            eval_env: environment used for evaluation (forwarded to EvalCallback).
            wandb_run: W&B run (anything with a `log(dict)` method) that receives the metrics.
            save_interval: checkpoint period in timesteps; floats are truncated, values <= 0 disable checkpoints.
            eval_freq: evaluation period in callback calls (forwarded to EvalCallback).
            save_path: directory for the best model and for checkpoints; None disables both.
            n_eval_episodes, deterministic, render, verbose: forwarded to EvalCallback.
            name_prefix: checkpoint file-name prefix; None falls back to f"{model_id}_model"
                using the notebook-level `model_id`.
        """
        # EvalCallback (a BaseCallback subclass) initialises all counters, so no second
        # BaseCallback.__init__ call is needed; it also initialises best_mean_reward to -inf.
        super().__init__(eval_env=eval_env,
                         eval_freq=eval_freq,
                         best_model_save_path=save_path,
                         n_eval_episodes=n_eval_episodes,
                         deterministic=deterministic,
                         render=render,
                         verbose=verbose)
        self.wandb_run = wandb_run
        self.save_interval = save_interval
        self.save_path = save_path
        self.name_prefix = name_prefix

    def _on_step(self) -> bool:
        # EvalCallback evaluates when due and returns whether training should continue.
        continue_training = super()._on_step()

        # Log only when an evaluation actually ran on this call (same condition EvalCallback uses);
        # otherwise the stale last_mean_reward would be re-logged on every environment step.
        if self.eval_freq > 0 and self.n_calls % self.eval_freq == 0:
            self.wandb_run.log({"mean_reward": self.last_mean_reward, "step": self.num_timesteps})

            # Best model tracking only happens when a save path is configured
            if self.save_path is not None:
                self.wandb_run.log({"best_mean_reward": self.best_mean_reward, "step": self.num_timesteps})

        # Save the model every 'save_interval' steps (int() tolerates e.g. total_timesteps / 10)
        interval = int(self.save_interval)
        if self.save_path is not None and interval > 0 and self.num_timesteps % interval == 0:
            prefix = self.name_prefix if self.name_prefix is not None else f'{model_id}_model'
            save_file = os.path.join(self.save_path, f'{prefix}_{self.num_timesteps}')
            self.model.save(save_file)
            if self.verbose > 0:
                print(f'Saving model to {save_file}.zip')

        # A checkpoint must never override EvalCallback's decision to stop training.
        return continue_training

    def _on_training_end(self) -> None:
        """Mark the end of training in the W&B run."""
        super()._on_training_end()
        self.wandb_run.log({"training_end": True})


## 3) Custom Environment Set-up

### 3.1) Product Recommendation Environment

In [None]:
class RecommendationEnv(gym.Env):
    """
    Gymnasium environment that simulates product recommendations for a population of customers.

    Action:
        MultiDiscrete([len(products)] * top_k): the top_k product indices recommended to the current user.
    Observation:
        Dict with the user's product/category preference scores, per-product buy/view/like/rating
        counters, and a one-hot description (product, interaction type) plus rating of the latest interaction.
    Reward (only when the interacted product is one of the recommendations, otherwise 0):
        NONE -1, VIEW 3, LIKE 10, BUY 50, RATE (rating - 2) * 3.
    Termination:
        a SESSION_CLOSE interaction.

    Args:
        users: customers; their views/likes/buys/ratings arrays are rebuilt from their `interactions`
            and attached to these same objects, so environments built from the same list share
            (and mutate) user state.
        products: product catalogue; the list position is the product index used by actions.
        categories: categories; `product.category.idx` indexes into this list.
        top_k: number of products recommended per step.
        learning_mode: 'online' keeps the counters across episodes; any other value restores the
            counters captured at construction time on every reset().

    Randomness comes from `self.np_random`, so `reset(seed=...)` (and `DummyVecEnv.seed`) make
    episodes reproducible.
    """

    def __init__(self, users:List[Customer], products:List[Product], categories:List[Category], top_k:int, learning_mode:str='online'):
        super().__init__()

        self.users = users                  # list of users as states
        self.products = products            # products as actions, potential recommendations
        self.categories = categories
        self.top_k = top_k                  # number of recommendations
        self.learning_mode = learning_mode  # Online learning keeps interaction history between episodes, offline learning resets it
        self.user_idx = 0                   # index of users list, not user_id
        self.current_step = 0               # step is also the interactions list index

        # Build per-user interaction counters from the stored interaction history
        for user in self.users:
            user.views = np.zeros(len(products), dtype=np.int16)
            user.likes = np.zeros(len(products), dtype=np.int16)
            user.buys = np.zeros(len(products), dtype=np.int16)
            user.ratings = np.zeros(len(products), dtype=np.int16)
            for interaction in user.interactions:
                i_type = interaction.type.value
                product_idx = interaction.product_idx
                if i_type == InteractionType.VIEW.value:
                    user.views[product_idx] += 1
                elif i_type == InteractionType.LIKE.value:
                    user.likes[product_idx] += 1
                elif i_type == InteractionType.BUY.value:
                    user.buys[product_idx] += 1
                elif i_type == InteractionType.RATE.value:
                    user.ratings[product_idx] = interaction.value

        # Snapshot used to restore the counters on every reset() outside online mode
        if self.learning_mode != 'online':
            self.baseline_users = deepcopy(self.users)

        # define action space to return top_k recommended product ids.
        self.action_space = spaces.MultiDiscrete([len(products)] * top_k)

        # States are derived from customer profiles and interactions (product, interaction type, rating).
        # They exclude user ids so the policy generalises; the internal users list is the reference.
        self.observation_space = spaces.Dict({
            'pref_prod': spaces.Box(low=0, high=1, shape=(len(self.products),), dtype=np.float32),
            'pref_cat': spaces.Box(low=0, high=1, shape=(len(self.categories),), dtype=np.float32),
            'buys': spaces.Box(low=0, high=1000, shape=(len(self.products),), dtype=np.int16),
            'views': spaces.Box(low=0, high=1000, shape=(len(self.products),), dtype=np.int16),
            'likes': spaces.Box(low=0, high=1000, shape=(len(self.products),), dtype=np.int16),
            'ratings': spaces.Box(low=0, high=5, shape=(len(self.products),), dtype=np.uint8),
            'product': spaces.Box(low=0, high=1, shape=(len(self.products),), dtype=np.uint8),
            'interaction': spaces.Box(low=0, high=1, shape=(len(list(InteractionType)),), dtype=np.uint8),
            'rating': spaces.Discrete(6)
            })

    def reset(self, seed=None, options=None):
        """Start an episode for a randomly chosen user; returns (observation, {})."""
        # Seeds self.np_random when a seed is given
        super().reset(seed=seed)

        # Restore the construction-time counters outside online mode
        if self.learning_mode != 'online':
            self.users = deepcopy(self.baseline_users)

        # Draw from self.np_random (not the global numpy RNG) so seeding is effective
        self.user_idx = int(self.np_random.integers(len(self.users)))
        self.current_step = 0
        return self._get_observation(self.users[self.user_idx]), {}

    def step(self, rec_products, interaction:Interaction=None):
        """
        Apply one recommendation.

        Args:
            rec_products: recommended product indices (the action).
            interaction: optional externally supplied interaction (it also switches the current user to
                interaction.customer_idx); when omitted, the user's reaction is simulated.
        Returns:
            (observation, reward, terminated, truncated=False, info)
        """
        self.current_step += 1

        if interaction is not None:
            # interaction passed in from outside (e.g. a logged event)
            self.user_idx = interaction.customer_idx
        else:
            # simulate selected recommended product and interaction
            interaction = self._simulate_interaction(rec_products)

        # reward only counts if the interacted product was actually recommended
        reward = 0
        if interaction.product_idx in rec_products:
            if interaction.type == InteractionType.NONE:
                reward = -1 # no interaction, customers not interested in recommendations
            elif interaction.type == InteractionType.VIEW:
                reward = 3
            elif interaction.type == InteractionType.LIKE:
                reward = 10
            elif interaction.type == InteractionType.BUY:
                reward = 50
            elif interaction.type == InteractionType.RATE:
                reward = (interaction.value - 2) * 3 # rating of 1 is negative

        done = interaction.type == InteractionType.SESSION_CLOSE

        interaction_info = {
                'interaction': {
                    'idx': -1,
                    'timestamp': interaction.timestamp.strftime("%Y-%m-%d %H:%M:%S"),
                    'customer_id': interaction.customer_idx,
                    'product_id': interaction.product_idx,
                    'interaction': interaction.type.value,
                    'rating': interaction.value
                }
            }

        # remembered for render()
        self.last_action = rec_products
        self.last_reward = reward

        user = self.users[self.user_idx]

        return self._update_observation(user, interaction), reward, done, False, interaction_info

    def _update_observation(self, user:Customer, interaction:Interaction):
        """Record `interaction` on `user`'s counters and return the resulting observation dict."""
        if interaction.type == InteractionType.VIEW:
            user.views[interaction.product_idx] += 1
        elif interaction.type == InteractionType.LIKE:
            user.likes[interaction.product_idx] += 1
        elif interaction.type == InteractionType.BUY:
            user.buys[interaction.product_idx] += 1
        elif interaction.type == InteractionType.RATE:
            user.ratings[interaction.product_idx] = interaction.value

        # one-hot of the product that was interacted with (all zeros for NONE / SESSION_CLOSE etc.)
        product_obs = np.zeros(len(self.products), dtype=np.uint8)
        if interaction.type in (InteractionType.VIEW, InteractionType.LIKE,
                                InteractionType.BUY, InteractionType.RATE):
            product_obs = utils.one_hot_encode(interaction.product_idx, len(self.products))

        # Copies so returned observations are snapshots, not views of the live counters
        obs = {
                'pref_prod': self._get_product_preferences(user),
                'pref_cat': self._get_category_preferences(user),
                'buys': user.buys.copy(),
                'views': user.views.copy(),
                'likes': user.likes.copy(),
                'ratings': user.ratings.copy(),
                'product': product_obs,
                'interaction': self._get_interaction_observation(interaction),
                'rating': interaction.value if interaction.type == InteractionType.RATE else 0
            }

        return obs

    def update_observation(self, user:Customer, interaction:Interaction):
        """Public wrapper around _update_observation (mutates `user`'s counters)."""
        return self._update_observation(user, interaction)

    def _get_observation(self, user:Customer):
        """Observation for `user` with no latest interaction (used at reset)."""
        obs = {
                'pref_prod': self._get_product_preferences(user),
                'pref_cat': self._get_category_preferences(user),
                'buys': user.buys.copy(),
                'views': user.views.copy(),
                'likes': user.likes.copy(),
                'ratings': user.ratings.copy(),
                # dtypes match the observation space
                'product': np.zeros(len(self.products), dtype=np.uint8),
                'interaction': np.zeros(len(list(InteractionType)), dtype=np.uint8),
                'rating': 0
            }

        return obs

    def _get_interaction_observation(self, interaction:Interaction):
        """One-hot encoding of the interaction type (position in the InteractionType enum)."""
        idx = list(InteractionType).index(interaction.type)
        size = len(InteractionType)

        return utils.one_hot_encode(idx, size)

    # calculate preferences based on past interactions
    def _get_product_preferences(self, user:Customer):
        """Per-product preference score from views, buys, likes and ratings, scaled down by 5."""
        view_prefs = user.views / 20
        purchase_prefs = user.buys
        like_prefs = user.likes / 15

        # ratings are centred so a rating of 1 contributes negatively; unrated (0) stays 0
        rating_prefs = user.ratings.copy()
        rating_prefs[rating_prefs > 0] -= 2

        product_prefs = view_prefs + purchase_prefs + like_prefs + rating_prefs

        # reduce space (previously a bare `product_prefs / 5` whose result was discarded)
        product_prefs = product_prefs / 5

        return product_prefs

    def _get_category_preferences(self, user:Customer):
        """Per-category sum of the user's positive product preferences, scaled down by 5."""
        prod_prefs = self._get_product_preferences(user)
        cat_prefs = np.zeros(len(self.categories), np.float32)

        for idx, prod_pref in enumerate(prod_prefs):
            if prod_pref > 0:
                product = self.products[idx]
                cat_idx = product.category.idx
                cat_prefs[cat_idx] += prod_pref # accumulation of fav products for this cat

        cat_prefs = cat_prefs / 5 # reduce space

        return cat_prefs

    def _simulate_interaction(self, product_ids):
        """
        Simulate the current user's reaction to the recommended `product_ids`.

        A product is chosen with probability proportional to (product preference + category preference),
        then an interaction type weighted by the user's existing counters for that product, then a rating.
        """
        user = self.users[self.user_idx]
        num_products = len(product_ids)
        product_prefs = self._get_product_preferences(user)
        category_prefs = self._get_category_preferences(user)

        # --- simulate selection ---
        # Float scores: integer (uint8) buffers truncated fractional preferences and wrapped negative ones.
        prod_scores = np.array(
            [product_prefs[pid] + category_prefs[self.products[pid].category.idx] for pid in product_ids],
            dtype=np.float64,
        )
        # Negative preferences (e.g. a rating of 1) would yield invalid probabilities
        prod_scores = np.clip(prod_scores, 0.0, None)
        if prod_scores.sum() > 0: # probability selection based on preferences
            product_probs = prod_scores / prod_scores.sum()
        else:
            product_probs = np.full((num_products,), 1.0 / num_products) # equal probs by default

        selected_product_id = product_ids[self.np_random.choice(num_products, p=product_probs)]

        # --- simulate interaction for the selected product ---
        inter_types = list(InteractionType)
        inter_probs = np.full((len(inter_types),), 1.0 / len(inter_types)) # equal probs by default
        # NOTE(review): assumes SESSION_CLOSE is the last InteractionType member; its lower weight
        # makes episodes last longer.
        inter_probs[-1] = 0.1

        counters = {
            InteractionType.VIEW: user.views,
            InteractionType.LIKE: user.likes,
            InteractionType.BUY: user.buys,
            InteractionType.RATE: user.ratings,
        }
        # float scores so counts above 255 do not overflow
        inter_scores = np.array(
            [float(counters[t][selected_product_id]) if t in counters else 0.0 for t in inter_types]
        )

        if inter_scores.max() > 0:
            inter_scores = (inter_scores + 1) * inter_probs # +1 default score, scaled by the base distribution
            inter_probs = inter_scores / inter_scores.sum()

        inter_probs = inter_probs / inter_probs.sum() # normalise
        # pick an index so the result is always an InteractionType member
        selected_interaction_type = inter_types[self.np_random.choice(len(inter_types), p=inter_probs)]

        # --- simulate rating ---
        selected_product_rating = int(self.np_random.integers(0, 6))
        if selected_interaction_type == InteractionType.RATE:
            # likely rating if the product is a preferred product
            prod_pref_score = product_prefs[selected_product_id]
            max_pref_score = np.max(product_prefs)
            if prod_pref_score > 0: # one of favourite products, rating of 2-5
                selected_product_rating = int((prod_pref_score / max_pref_score) * 3) + 2

        rating_probs = np.full(6, 1/6)
        rating_probs[selected_product_rating] = 0.7 # up chance of this rating
        rating_probs = rating_probs / rating_probs.sum() # normalise
        selected_product_rating = int(self.np_random.choice(6, p=rating_probs))

        user_id = self.user_idx
        interaction_time = datetime.now()
        interaction = Interaction(-1, interaction_time, user_id, selected_product_id, selected_interaction_type, selected_product_rating)

        return interaction

    def seed(self, seed=None):
        """
        Set the seed for reproducibility (legacy API; prefer reset(seed=...)).
        """
        self.np_random, seed = gym.utils.seeding.np_random(seed)
        random.seed(seed)
        np.random.seed(seed)
        return [seed]

    def render(self, mode='human'):
        """Print the last recommended products and the reward they earned."""
        if hasattr(self, 'last_action'):
            print(f"Recommended Product ID (Last Action): {self.last_action}")
        else:
            print("No product recommended yet.")

        # Optionally, print the reward received for the last action
        if hasattr(self, 'last_reward'):
            print(f"Reward for Last Action: {self.last_reward}")

        print("-----")

    def close(self):
        pass

### 3.2) Load environment data

Load the products, the customers (including their interaction histories) and the categories using the custom data loaders in the `app.src.spark.data.loader` module.

In [9]:
# Load the catalogue and the customers (with their interaction histories) via the project loader.
# Tuple items are evaluated left to right, so the load order is products -> customers -> categories.
products, customers, categories = (
    loader.load_products(),
    loader.load_customers(include_interactions=True),
    loader.load_categories(),
)

### 3.3) Initialise training and evaluation environments

In [None]:
# Training environment; the seed is applied by DummyVecEnv at the next reset()
env = DummyVecEnv([lambda: Monitor(RecommendationEnv(customers, products, categories, top_k=10, learning_mode='online'))])
env.seed(100)

# Evaluation environment for WandbEvalCallback.
# RecommendationEnv attaches its interaction counters to the Customer objects it receives and, in
# 'online' mode, keeps mutating them. A deep copy stops evaluation episodes from leaking simulated
# interactions into the training environment's user state (and vice versa).
eval_env = DummyVecEnv([lambda: Monitor(RecommendationEnv(deepcopy(customers), products, categories, top_k=10, learning_mode='online'))])

print(env.observation_space)

Dict('buys': Box(0, 1000, (295,), int16), 'interaction': Box(0, 1, (7,), uint8), 'likes': Box(0, 1000, (295,), int16), 'pref_cat': Box(0.0, 1.0, (38,), float32), 'pref_prod': Box(0.0, 1.0, (295,), float32), 'product': Box(0, 1, (295,), uint8), 'rating': Discrete(6), 'ratings': Box(0, 5, (295,), uint8), 'views': Box(0, 1000, (295,), int16))


## 3) Define model

### 3.2) Define Hyperparameters

UserWarning: You have specified a mini-batch size of 64, but because the `RolloutBuffer` is of size `n_steps * n_envs = 100`, after every 1 untruncated mini-batches, there will be a truncated mini-batch of size 36
We recommend using a `batch_size` that is a factor of `n_steps * n_envs`.
Info: (n_steps=100 and n_envs=1)
  warnings.warn(


In [12]:
# A2C hyperparameters
param_learning_rate = 2.5e-4  # optimiser step size
param_n_steps       = 100     # rollout length per environment before each update
param_gamma         = 0.95    # discount factor
param_gae_lambda    = 0.99    # GAE lambda for advantage estimation
param_ent_coef      = 0.01    # entropy bonus weight (exploration)
param_clip_range    = 0.2     # PPO-only setting; not passed to the A2C model below

### 3.3) Define model

In [13]:
# Advantage Actor-Critic with a multi-input policy (the observation space is a Dict)
model = A2C(
    policy='MultiInputPolicy',
    env=env,
    learning_rate=param_learning_rate,
    n_steps=param_n_steps,
    gamma=param_gamma,
    gae_lambda=param_gae_lambda,
    ent_coef=param_ent_coef,
    tensorboard_log=logs_dir,
    verbose=0,
    # policy_kwargs={'features_extractor_class': CustomANN},
)


## 4) Model training

### 4.1) Execute training

In [None]:
# Route SB3 logging to stdout, a CSV file and TensorBoard event files under logs_dir
new_logger = configure(str(logs_dir), ["stdout", "csv", "tensorboard"])
model.set_logger(new_logger)

# Evaluation + W&B reporting + interim checkpoints, all on the same save_interval period
eval_callback = WandbEvalCallback(
    eval_env=eval_env,
    wandb_run=run,
    save_interval=save_interval,
    eval_freq=save_interval,
    save_path=models_dir,
)
model.learn(total_timesteps=config["total_timesteps"], callback=eval_callback)

# Persist the final weights once training has finished
final_model_path = os.path.join(models_dir, model_name_final)
model.save(final_model_path)

### 4.2) Close out WandB session

In [None]:
# Capture the run identifiers first: wandb.run is reset to None once the run is finished.
username = wandb.run.entity
project = wandb.run.project
run_id = wandb.run.id

# Flush pending metrics / TensorBoard syncs and close the session (this section's purpose).
run.finish()

# Post-hoc config update through the public API, on a separate name so `run` is not shadowed.
# NOTE(review): "bar": 32 looks like placeholder code from the W&B docs — replace with a real value or drop it.
api = wandb.Api()
api_run = api.run(f"{username}/{project}/{run_id}")
api_run.config["bar"] = 32
api_run.update()

## 5) Testing

### 5.1) Load tensorboard session

In [None]:
# %load_ext tensorboard

In [None]:
# %tensorboard --logdir={logs_dir}

### 5.2) Outcome debugging

In [None]:
# Construct observation data for customer 10
test_customer = customers[10]

# simulated last selected product
test_product = products[53]

# The observation combines this latest interaction with the user's accumulated counters
previous_interaction = Interaction(idx = '0',
                          timestamp = datetime.now(),
                          customer_idx = test_customer.idx,
                          product_idx = test_product.idx,
                          type = InteractionType.RATE,
                          value = 5,)

# `env` is a DummyVecEnv, which does not expose custom env methods; call the underlying
# RecommendationEnv (unwrapping the Monitor) directly.
# Note: this also records the rating on test_customer's counters.
new_obs = env.envs[0].unwrapped.update_observation(test_customer, previous_interaction)

In [None]:
# Ask the trained policy for top_k product recommendations for the constructed observation
prediction = model.predict(new_obs)
recommended_products, _states = prediction
print(recommended_products)

In [None]:
# Step the raw RecommendationEnv: it takes a single (unbatched) action and returns the Gymnasium 5-tuple.
# DummyVecEnv.step() would expect a batch of actions and return only (obs, rewards, dones, infos).
obs, rewards, done, flag, simulated_interaction_info = env.envs[0].unwrapped.step(recommended_products)
simulated_interaction_info

In [None]:
# Roll the trained policy through the vectorised env for 20 recommendation steps
# (DummyVecEnv automatically resets an episode when it ends).
obs = env.reset()
for _ in range(20):  # Make 20 recommendations
    action, _states = model.predict(obs)
    obs, rewards, done, infos = env.step(action)
    print("Recommended products:", action)