# Experiment functions

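> Helper and driver functions for running experiments: early stopping, score calculation, logging, saving the best agent, running test episodes, and the main training/evaluation loop.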

In [None]:
#| default_exp experiment_functions

In [None]:
#| hide
from nbdev.showdoc import *

In [None]:
#| export

from abc import ABC, abstractmethod
from typing import Union
import logging
from datetime import datetime  
import numpy as np
import sys

from ddopnew.envs.base import BaseEnvironment
from ddopnew.agents.base import BaseAgent

from tqdm import tqdm

# Think about how to handle mushroom integration.
from mushroom_rl.core import Core

In [None]:
#| export

class EarlyStoppingHandler():

    '''
    Class to handle early stopping.

    Every evaluation result is appended to `history`. Once at least `warmup`
    results have been collected, the sum of the most recent `patience` results
    is compared with the sum of the `patience` results before them; if the
    recent window is not strictly better, stopping is signalled.

    `warmup` is raised to `2 * patience` when it is `None` or smaller, so that
    both comparison windows are always fully populated.

    The handler can be used either via `add_result(J, R)` or by calling the
    instance directly (`handler(J, R)`).
    '''
    def __init__(
        self,
        patience: int = 50,
        warmup: int = 100,
        criteria: str = "J",  # Whether to use discounted rewards J or total rewards R as criteria
        direction: str = "max"  # Whether reward shall be maximized or minimized
    ):

        self.history = list()
        self.patience = patience
        if warmup is None or warmup < patience * 2:
            warmup = patience * 2
        self.warmup = warmup
        self.criteria = criteria
        self.direction = direction

    def add_result(self, J, R) -> bool:

        '''
        Records a new evaluation result and reports whether training should stop.

        Returns `True` if the early-stopping condition is met and `False`
        otherwise (including during the warm-up phase).

        Raises a `ValueError` if `criteria` is neither "J" nor "R" or, once the
        warm-up phase is over, if `direction` is neither "max" nor "min".
        '''

        if self.criteria == "J":
            self.history.append(J)
        elif self.criteria == "R":
            self.history.append(R)
        else:
            raise ValueError("Criteria must be J or R")

        # Not enough results yet: explicitly report "do not stop" instead of
        # falling through and returning None.
        if len(self.history) < self.warmup:
            return False

        previous_window = sum(self.history[-self.patience*2:-self.patience])
        recent_window = sum(self.history[-self.patience:])

        if self.direction == "max":
            # Stop when the recent window did not improve on the previous one.
            return bool(previous_window >= recent_window)
        elif self.direction == "min":
            return bool(previous_window <= recent_window)
        else:
            raise ValueError("Direction must be max or min")

    def __call__(self, J, R) -> bool:

        '''
        Shortcut for `add_result(J, R)` so the handler can be called directly.
        '''

        return self.add_result(J, R)

## Helper functions

* Utility functions used when running an experiment: score calculation, logging, tracking the best scores, saving the agent, and testing the agent.

In [None]:
#| export

def calculate_score(dataset, env):

    """
    Computes the total and the discounted reward of an episode.

    Each entry of `dataset` is indexed as `row[0][2]` to obtain the reward of
    that step (i.e. the third element of the sample stored first in the row).
    The discount factor is read from `env.mdp_info.gamma`.

    Returns the tuple `(R, J)` where `R` is the plain sum of rewards and `J`
    is the sum of `gamma**t * reward_t` with `t` starting at 0, so the first
    reward is not discounted.
    """

    rewards = [entry[0][2] for entry in dataset]
    R = sum(rewards)

    discount = env.mdp_info.gamma
    J = sum(discount**(t) * reward for t, reward in enumerate(rewards))

    return R, J

def log_info(R, J, n_epochs, logging, mode):

    '''
    Logs the R, J information repeatedly for n_epochs.
    E.g., to draw a straight line in wandb for algorithms
    such as XGB, RF, etc. that can be compared to the learning
    curves of supervised or reinforcement learning algorithms.

    `logging` is the name of the tracking backend; only "wandb" triggers any
    logging, every other value is a no-op. Note that this parameter shadows
    the standard `logging` module inside this function.
    `mode` is used as the prefix of the logged keys (`{mode}/R`, `{mode}/J`).
    '''

    if logging == "wandb":
        # Imported lazily so that wandb is only required when it is actually used.
        import wandb

        for _ in range(n_epochs):
            wandb.log({f"{mode}/R": R, f"{mode}/J": J})

def update_best(R, J, best_R, best_J):

    """
    Updates the best scores seen so far with a new evaluation result.

    A score only replaces the stored best if it is strictly greater.
    Returns the tuple `(best_R, best_J)` - in this order.
    """

    updated_R = R if R > best_R else best_R
    updated_J = J if J > best_J else best_J

    return updated_R, updated_J

def save_agent(agent, experiment_dir, save_best, R, J, best_R, best_J, criteria="J"):

    """
    Saves the agent if its current score equals the best score so far.

    Nothing happens when `save_best` is falsy or when `criteria` is neither
    "R" nor "J". Otherwise the score selected by `criteria` is compared with
    the corresponding best score and, if they are equal, `agent.save` is
    called with `{experiment_dir}/saved_models/best`.
    """

    if not save_best:
        return

    if criteria == "R":
        is_best = R == best_R
    elif criteria == "J":
        is_best = J == best_J
    else:
        return

    if is_best:
        agent.save(f"{experiment_dir}/saved_models/best")

def test_agent(agent: BaseAgent,
            env: BaseEnvironment,
            return_dataset = False,
            tracking = None, # other: "wandb",
):

    """
    Tests the agent on the environment for a single episode.

    Runs one episode via `run_test_episode`, computes the total reward R and
    the discounted reward J via `calculate_score` and, if `tracking` is
    "wandb", logs both values under the keys `{env.mode}/R` and `{env.mode}/J`.

    Returns `(R, J)` or, if `return_dataset` is truthy, `(R, J, dataset)`.

    # OPEN TODO: Make possible to save dataset via wandb

    """

    # Run the test episode
    dataset = run_test_episode(env, agent)

    # Calculate the score
    R, J = calculate_score(dataset, env)

    if tracking == "wandb":
        # Imported lazily so that wandb is only required when it is actually used.
        import wandb

        mode = env.mode
        wandb.log({f"{mode}/R": R, f"{mode}/J": J})

    if return_dataset:
        return R, J, dataset
    else:
        return R, J

## Experiment functions

In [None]:
#| export

def run_test_episode(   env: BaseEnvironment,
                        agent: BaseAgent,
                ):

    """
    Runs an episode to test the agent's performance.
    It assumes that agent and environment are initialized and already set to
    test/val mode. The environment is reset at the start of the episode.

    Returns a list with one entry per step. Each entry is a tuple
    `(sample, info)` where `sample` is
    `(obs, action, reward, next_obs, terminated, truncated)`.
    The current step count is written to stdout while the episode runs.
    """

    # Get initial observation
    obs, *_ = env.reset()

    dataset = []

    finished = False
    step = 0
    while not finished:

        # Sample action from agent
        action = agent.draw_action(obs)

        # Take a step in the environment
        next_obs, reward, terminated, truncated, info = env.step(action)

        logging.debug("##### STEP: %d #####", env.index)
        logging.debug("reward: %s", reward)
        logging.debug("info: %s", info)
        # Log the observation returned by this step, not the one the action was based on
        logging.debug("next observation: %s", next_obs)
        logging.debug("truncated: %s", truncated)

        sample = (obs, action, reward, next_obs, terminated, truncated) # unlike mushroom do not include policy_state

        obs = next_obs

        dataset.append((sample, info))

        finished = terminated or truncated

        # Progress indicator: overwrite the same console line on every step
        step += 1
        sys.stdout.write(f"\rStep {step}")
        sys.stdout.flush()
    print()


    return dataset

def run_experiment( agent: BaseAgent,
                    env: BaseEnvironment,

                    n_epochs: int,

                    early_stopping_handler: Union[EarlyStoppingHandler, None] = None,
                    save_best: bool = True,

                    tracking: Union[str, None]  = None, # other: "wandb"

                    results_dir: str = "results",

                    run_id: Union[str, None] = None,

                    logging_level =  logging.WARNING,

                    performance_criteria: str = "J", # "J" or "R": score used to decide whether the agent is saved
                ):

    """
    Trains the agent on the environment and evaluates it on the validation set.

    The agent is evaluated once before training. What happens next depends on
    `agent.train_mode`:

    * "direct_fit": the agent is fitted once on all data of the dataloader and
      evaluated once; the result is logged `n_epochs - 1` more times via
      `log_info` so that it shows up as a flat line in the tracking tool.
    * "epochs_fit": the agent is fitted for up to `n_epochs` epochs and
      evaluated after each of them. After each evaluation the agent is saved
      if `save_best` is set and the score selected by `performance_criteria`
      equals the best score so far. If an `early_stopping_handler` is given,
      training stops as soon as it signals to stop.
    * "env_interaction": not implemented yet, no training is performed.

    Any other train mode raises a `ValueError`.

    Saved models go to `{results_dir}/{run_id}/saved_models/best`; if no
    `run_id` is given, the current timestamp is used.
    """

    logging.basicConfig(level=logging_level)

    # use start_time as id if no run_id is given
    if run_id is None:
        run_id = datetime.now().strftime("%Y%m%d_%H%M%S_%f")

    experiment_dir = f"{results_dir}/{run_id}"

    logging.info("Starting experiment")

    env.reset()
    env.train()
    agent.train()

    # Not used yet; kept for the planned "env_interaction" training mode
    core = Core(agent, env)

    # initial evaluation
    env.val()
    agent.eval()
    R, J = test_agent(agent, env, tracking = tracking)

    env.train()
    agent.train()

    logging.info(f"Initial evaluation: R={R}, J={J}")

    best_J = J
    best_R = R

    if agent.train_mode == "direct_fit":

        logging.info("Starting training with direct fit")
        agent.fit(X=env.dataloader.get_all_X(), Y=env.dataloader.get_all_Y())
        logging.info("Finished training with direct fit")

        env.val()
        agent.eval()

        R, J = test_agent(agent, env, tracking = tracking)
        # update_best returns (best_R, best_J) - keep this order when unpacking
        best_R, best_J = update_best(R, J, best_R, best_J)

        # Pass the tracking backend (not the logging module) to log_info
        log_info(R, J, n_epochs-1, tracking, "val")

    elif agent.train_mode == "epochs_fit":

        logging.info("Starting training with epochs fit")
        for epoch in range(n_epochs):

            agent.fit_epoch(X=env.dataloader.get_X_train(), Y=env.dataloader.get_Y_train())

            env.val()
            agent.eval()

            R, J = test_agent(agent, env, tracking = tracking)

            # update_best returns (best_R, best_J) - keep this order when unpacking
            best_R, best_J = update_best(R, J, best_R, best_J)
            save_agent(agent, experiment_dir, save_best, R, J, best_R, best_J, performance_criteria)

            # Without a handler training never stops early
            stop = False
            if early_stopping_handler is not None:
                stop = bool(early_stopping_handler.add_result(J, R))

            if stop:
                logging.info(f"Early stopping after {epoch+1} epochs")
                break

            env.train()
            agent.train()

        logging.info("Finished training with epochs fit")

    elif agent.train_mode == "env_interaction":
        # Make the missing implementation visible instead of silently skipping training
        logging.warning("Train mode 'env_interaction' is not implemented yet - no training performed")

    else:
        raise ValueError("Unknown train mode")

    logging.info(f"Evaluation after training: R={R}, J={J}")

In [None]:
from ddopnew.envs.inventory import NewsvendorEnv
from ddopnew.dataloaders.tabular import XYDataLoader
from ddopnew.agents.basic import RandomAgent


# Demo: evaluate a random agent on a newsvendor environment built from random data.

# Indices handed to the data loader as start of the validation and test split.
# The values in the trailing comments are larger alternatives that are currently not used.
val_index_start = 80 #90_000
test_index_start = 90 #100_000

# Random features (100 x 2) and targets (100 x 1), uniformly drawn from [0, 1)
X = np.random.rand(100, 2)
Y = np.random.rand(100, 1)

dataloader = XYDataLoader(X, Y, val_index_start, test_index_start)

environment = NewsvendorEnv(
    dataloader = dataloader,
    underage_cost = 0.42857,
    overage_cost = 1.0,
    gamma = 0.999,
    horizon_train = 365,
)

agent = RandomAgent(environment.mdp_info)

# Switch the environment to test mode before running the test episode
environment.test()

# Run a single episode and compute total reward R and discounted reward J
R, J = test_agent(agent, environment)

print(f"R: {R}, J: {J}")

Step 10
R: -6.604986797188729, J: -6.578556148664966


In [None]:
#| hide
import nbdev; nbdev.nbdev_export()