<a href="https://colab.research.google.com/github/joefarrington/viso_jax/blob/main/notebooks/reproduce_viso_jax_experiments.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Reproduce experiments from "Going faster to see further: GPU-accelerated value iteration and simulation for perishable inventory control"

# Introduction



This notebook accompanies the paper "*Going faster to see further: GPU-accelerated value iteration and simulation for perishable inventory control*". It provides a way to reproduce the main experiments without a local GPU, and without requiring any local setup. 

If you are new to Google Colab, you may wish to first work through this [introductory notebook](https://colab.research.google.com/) and/or read the [FAQ](https://research.google.com/colaboratory/faq.html). 

This notebook was last tested on 2023-05-27. Subsequent changes to the GPU drivers and default Python environment on Google Colab may cause compatibility issues; please raise an issue on the GitHub repository if you encounter one.

## Using a GPU runtime on Colab

Colab provides access to free cloud-based GPUs. This notebook is set to use a GPU runtime by default. Running the cell below will print the details of the GPU. If it fails, go to the menu on the top left of the screen and select **Runtime** -> **Change runtime type**. Select **GPU** from the dropdown list for hardware accelerator.  

In [None]:
!nvidia-smi

## Value iteration wall time and checkpoints

Our value iteration method can save checkpoints at the end of each iteration, which can be used to restart value iteration if Colab times out. 

While testing this notebook, we observed that saving checkpoints is much slower on Colab than on our local development machine and this can lead to significantly increased wall times for some settings.

We therefore provide an option to set the checkpoint frequency for Scenario A and Scenario B; a checkpoint frequency of 0 corresponds to not saving checkpoints at all. By default these are set to 100 for Scenario A and 1 for Scenario B, the frequencies we used for the results reported in the paper.

For Scenario C, the checkpoints are required for the convergence test and therefore we do not provide this option. 
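
As a rough illustration of what the checkpoint frequency controls, the sketch below assumes that a frequency of `n` means "save at the end of every `n`-th iteration" and that 0 disables saving. The helper `should_save_checkpoint` is hypothetical and is not part of viso_jax; the actual logic in the value iteration runner may differ.

```python
def should_save_checkpoint(iteration: int, checkpoint_frequency: int) -> bool:
    """Return True if a checkpoint would be saved at the end of `iteration`.

    Hypothetical helper: assumes iterations are numbered from 1 and that a
    frequency of 0 means checkpoints are never saved.
    """
    if checkpoint_frequency <= 0:
        return False
    return iteration % checkpoint_frequency == 0


# Count how many of the first 300 iterations would be checkpointed
for frequency in (0, 1, 100):
    saved = [i for i in range(1, 301) if should_save_checkpoint(i, frequency)]
    print(f"checkpoint_frequency={frequency}: {len(saved)} checkpoints")
```

Under these assumptions, a higher frequency value means fewer writes to disk, which is why it can reduce wall time on Colab.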

## Mounting Google Drive for permanent storage

Colab provides you with a temporary working directory.  This working space is not permanent, and any outputs stored in this space may be lost when the runtime is restarted. 

If you want to store outputs, you can [mount your Google Drive onto Colab](https://colab.research.google.com/notebooks/io.ipynb). If you mount your Google Drive to this notebook, the viso_jax GitHub repo will be cloned onto your Google Drive and any outputs/checkpoints will be stored on your Google Drive. This will be particularly helpful if you want to run any of the longer experiments and be able to restart them from a checkpoint if Colab times out. 

To mount your Google Drive, check the box for the variable `mount_google_drive` at the start of the [Setup](#setup) section, run that cell, and follow the instructions to authorize the process before running the next cell.

## Running experiments

First, run all of the cells in the [Setup](#setup) section to clone the viso_jax GitHub repository, and install viso_jax and its dependencies. Once setup has been completed, the cells in the [Run experiments](#run-experiments) section can be run in any order.

Each form corresponds to a scenario in the paper. Use the dropdown boxes to select the maximum useful life $m$, the ID number of the experiment, and whether to use value iteration or simulation optimization to reproduce the results for the specified experiment using the specified method.
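
Each form ultimately runs one command line. The sketch below shows, under the assumption that the pattern used in the cells of this notebook holds for every experiment, how the selections map onto that command. `build_command` is a hypothetical helper written for this illustration, not a viso_jax function.

```python
from typing import Dict, Optional


def build_command(script: str, scenario: str, m: str, exp_name: str,
                  overrides: Optional[Dict[str, object]] = None) -> str:
    """Compose a command in the same form as the cells in this notebook.

    Hypothetical helper: each entry in `overrides` becomes a `key=value`
    argument appended after the experiment selection.
    """
    parts = ["python", script, f"+experiment={scenario}/m{m}/{exp_name}"]
    for key, value in (overrides or {}).items():
        parts.append(f"{key}={value}")
    return " ".join(parts)


# Scenario A, m=2, experiment 1, value iteration, checkpoint every 100 iterations
command = build_command(
    "run_value_iteration.py",
    "de_moor_perishable",
    m="2",
    exp_name="exp1",
    overrides={"vi_runner.checkpoint_frequency": 100},
)
print(command)
```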

## Advanced

The [Advanced](#advanced) section demonstrates how to:
* restart value iteration from a checkpoint
* reduce the GPU memory requirements
* run value iteration at single-precision
* run experiments using different random seeds, and
* change a scenario setting to run a modified version of an experiment. 

The cells within each subsection should be run in order (but the subsections are independent).

## Running this notebook locally

This notebook can be run locally, as a [Jupyter Notebook](https://jupyter.org/). We recommend that you follow the installation instructions for viso_jax on the [README](https://github.com/joefarrington/viso_jax) page of the GitHub repository, and then run the notebook using the local virtual environment in which you have installed viso_jax. 

Enter the path to the local copy of the viso_jax git repository in the final cell of the [Setup](#setup) section so that subsequent commands that move between directories start in the correct place. 

# Setup

In [None]:
#@title Mount Google Drive
mount_google_drive = True #@param {type:"boolean"}

if mount_google_drive:
  from google.colab import drive
  drive.mount('/content/gdrive/')

In [None]:
import sys
import os
from pathlib import Path

# If we're in a Colab notebook, pull any changes to viso_jax repo (or create if it doesn't exist)
# and install viso_jax and dependencies
# This process is more complicated on Colab due to changes in default packages and drivers
if 'google.colab' in sys.modules:
  try:
    # If Google Drive was mounted successfully, clone into the drive
    os.chdir("/content/gdrive/MyDrive")
  except OSError:
    # Otherwise clone into the temporary working directory
    os.chdir("/content/")
  # Clone over plain HTTPS; the URL should not contain credentials
  !git -C viso_jax pull || git clone https://github.com/joefarrington/viso_jax.git viso_jax
  viso_jax_dir = Path("viso_jax").absolute()
  os.chdir("viso_jax")

  # Install jedi to avoid Colab error (the requirement is quoted so the shell does not treat >= as a redirect)
  !pip install "jedi>=0.10" > install_jedi.log
  # Install our package
  !pip install . > install_viso_jax.log
  # New default on Colab, manually specify versions of chex and orbax-checkpoint
  !pip install chex==0.1.6
  !pip install orbax-checkpoint==0.1.1
  # Manually specify version of jax and jaxlib; Colab seems to prefer this to reading it from pyproject.toml
  !pip install jax==0.3.25 > install_jax.log
  !pip install "jax[cuda11_cudnn82]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html > install_jaxlib.log


# If we're not in a Colab notebook, don't take any of those actions and ask the user to manually add the path to the viso_jax github repo in the cell below
else:
  print("Current environment appears not to be a Colab notebook. Ensure that viso_jax has been installed in the current environment. Add the path to the local copy of the viso_jax git repo in the field below.")

# See https://stackoverflow.com/questions/74117246/python-logging-module-not-working-in-google-colab
import logging
# Remove all handlers associated with the root logger object, and update config
# to print value iteration and simopt logs as output
for handler in logging.root.handlers[:]:
    logging.root.removeHandler(handler)
logging.basicConfig(level=logging.INFO,)

In [None]:
#@title Add path to repo if running locally
path_to_local_viso_jax_git_repo = "" #@param {type:"string"}
if 'google.colab' not in sys.modules:
  viso_jax_dir = Path(path_to_local_viso_jax_git_repo).absolute()

# Run experiments

In [None]:
#@title Run an experiment for Scenario A
m = "2" #@param ["2", "3", "4", "5"]
experiment = "1" #@param ["1", "2", "3", "4", "5", "6", "7", "8"]
method = "Value iteration" #@param ["Value iteration", "Simulation optimization"]
checkpoint_frequency = 100 #@param {type:"slider", min:0, max:500, step:1}

if method == "Value iteration":
  os.chdir(viso_jax_dir / "viso_jax/value_iteration")
  !python run_value_iteration.py +experiment=de_moor_perishable/m{m}/exp{experiment} vi_runner.checkpoint_frequency={checkpoint_frequency}
elif method == "Simulation optimization":
  os.chdir(viso_jax_dir / "viso_jax/simopt")
  !python run_optuna_simopt.py +experiment=de_moor_perishable/m{m}/exp{experiment}

In [None]:
#@title Run an experiment for Scenario B
m = "2" #@param ["2", "3"]
experiment = "1" #@param ["1", "2", "3", "4", "P1", "P2", "P3", "P4"]
method = "Value iteration" #@param ["Value iteration", "Simulation optimization"]
checkpoint_frequency = 1 #@param {type:"slider", min:0, max:100, step:1}

if m=="2" and experiment in ["3", "4"]:
  raise ValueError(f"No experiment {experiment} for m=2")
if m=="3" and experiment in ["P1", "P2", "P3", "P4"]:
  raise ValueError(f"No experiment {experiment} for m=3")

if experiment in ["1", "2", "3", "4"]:
  exp_name = f"exp{experiment}"
else:
  exp_name = f"ortega_{experiment}"

if method == "Value iteration":
  os.chdir(viso_jax_dir / "viso_jax/value_iteration")
  !python run_value_iteration.py +experiment=hendrix_perishable_substitution_two_product/m{m}/{exp_name} vi_runner.checkpoint_frequency={checkpoint_frequency}
elif method == "Simulation optimization":
  os.chdir(viso_jax_dir / "viso_jax/simopt")
  !python run_optuna_simopt.py +experiment=hendrix_perishable_substitution_two_product/m{m}/{exp_name}

In [None]:
#@title Run an experiment for Scenario C
m = "3" #@param ["3", "5", "8"]
experiment = "1" #@param ["1", "2"]
method = "Value iteration" #@param ["Value iteration", "Simulation optimization"]

if method == "Value iteration":
  if m == "8":
    raise ValueError("Value iteration not feasible when m=8")
  os.chdir(viso_jax_dir / "viso_jax/value_iteration")
  !python run_value_iteration.py +experiment=mirjalili_perishable_platelet/m{m}/exp{experiment}
elif method == "Simulation optimization":
  os.chdir(viso_jax_dir / "viso_jax/simopt")
  !python run_optuna_simopt.py +experiment=mirjalili_perishable_platelet/m{m}/exp{experiment}

# Advanced

## Resuming value iteration from a checkpoint

At the time of writing (February 2023), the Colab [FAQ](https://research.google.com/colaboratory/faq.html) states that the free tier of Google Colab has a maximum time limit of 12 hours but a session may also time out if left idle. The time limit can be extended using Colab Pro. 

The 12-hour time limit of the free tier is long enough to run all of the experiments except value iteration for Scenario C when $m=5$. In practice, however, we have found that our notebooks can time out due to inactivity even when running some of the shorter experiments.

One possible solution to this is to resume value iteration from a checkpoint file. If you wish to do this, we recommend that you mount your Google Drive using the process described above. 

In the example below, there is no interruption. For illustrative purposes, we load a checkpoint from the temporary working directory, but in the event of a time-out there is no guarantee that it will still be available when restarting the runtime.

By default, value iteration outputs are saved in an outputs directory in the value_iteration directory, with a path specifying the scenario, value of $m$, experiment ID, date and time, for example:
`value_iteration/outputs/hendrix_perishable_substitution_two_product/m2/exp1/2023-02-07/12-53-07`. The checkpoints are saved in a subdirectory, called `checkpoints`. 
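
If you rely on the default output location, the most recent run can be found by sorting the date and time directories. The sketch below assumes the `<date>/<time>` layout shown above (`YYYY-MM-DD/HH-MM-SS`), for which alphabetical order matches chronological order. `latest_run_dir` is a hypothetical helper, and the demonstration uses a temporary directory with made-up runs rather than real outputs.

```python
import tempfile
from pathlib import Path
from typing import Optional


def latest_run_dir(experiment_dir: Path) -> Optional[Path]:
    """Return the most recent <date>/<time> directory under `experiment_dir`.

    Hypothetical helper: sorts on the (date, time) directory names, which
    works because both are zero-padded.
    """
    runs = [p for p in experiment_dir.glob("*/*") if p.is_dir()]
    runs.sort(key=lambda p: p.parts[-2:])
    return runs[-1] if runs else None


# Demonstrate on a throwaway directory tree with made-up dates and times
with tempfile.TemporaryDirectory() as tmp:
    experiment_dir = Path(tmp) / "outputs/hendrix_perishable_substitution_two_product/m2/exp1"
    for run in ("2023-02-07/12-53-07", "2023-02-07/09-15-00", "2023-02-06/23-59-59"):
        (experiment_dir / run).mkdir(parents=True)
    latest = latest_run_dir(experiment_dir)
    print(latest.relative_to(experiment_dir))
```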

In this example, to ensure that the cells below run without needing to specify a date and time, we set a custom output directory using the command line argument `hydra.run.dir`.

In [None]:
# Run a complete experiment
os.chdir(viso_jax_dir / "viso_jax/value_iteration")
!python run_value_iteration.py +experiment=hendrix_perishable_substitution_two_product/m2/exp1 hydra.run.dir='restore_cp_demo'

The experiment saves checkpoints and, by default, a log file (`run_value_iteration.log`), the policy (`policy.csv`), the final values (`V.csv`), and an output file that we use to record information for inclusion in results tables (`output_info.yaml`).

In [None]:
!ls restore_cp_demo

Within the `checkpoints` subdirectory of the output, there is a checkpoint from the end of each iteration.

In [None]:
!ls restore_cp_demo/checkpoints

We resume from a checkpoint by using the command line argument `vi_runner.resume_from_checkpoint` and providing the path to the checkpoint we want to use.

In [None]:
# Repeat the experiment, starting with a checkpoint for iteration 8
os.chdir(viso_jax_dir / "viso_jax/value_iteration")
checkpoint_path = viso_jax_dir / "viso_jax/value_iteration/restore_cp_demo/checkpoints/values_8.csv"
!python run_value_iteration.py +experiment=hendrix_perishable_substitution_two_product/m2/exp1 vi_runner.resume_from_checkpoint={checkpoint_path} hydra.run.dir='restored_from_cp_demo'

We can see that the outputs for value iteration from iterations 9 to 11 are the same as in the original run above, and that the metrics reported on the simulated rollouts are the same.
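
After a real time-out you would normally want the most recent checkpoint rather than a fixed iteration. The sketch below assumes that every checkpoint follows the `values_<iteration>.csv` naming seen in the example above. `latest_checkpoint` is a hypothetical helper, and the demonstration creates empty placeholder files in a temporary directory.

```python
import re
import tempfile
from pathlib import Path
from typing import Optional


def latest_checkpoint(checkpoints_dir: Path) -> Optional[Path]:
    """Return the checkpoint file with the highest iteration number, if any."""
    pattern = re.compile(r"values_(\d+)\.csv")
    found = []
    for path in checkpoints_dir.iterdir():
        match = pattern.fullmatch(path.name)
        if match:
            found.append((int(match.group(1)), path))
    if not found:
        return None
    # Compare iteration numbers as integers so that 11 ranks above 9
    return max(found, key=lambda item: item[0])[1]


# Demonstrate with empty placeholder files
with tempfile.TemporaryDirectory() as tmp:
    checkpoints_dir = Path(tmp)
    for iteration in (8, 9, 10, 11):
        (checkpoints_dir / f"values_{iteration}.csv").touch()
    newest = latest_checkpoint(checkpoints_dir)
    print(newest.name)
```

The returned path could then be passed to `vi_runner.resume_from_checkpoint` in the same way as `checkpoint_path` in the cell above.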

## Reducing the GPU memory demand

The default parallelism in the code has been set to run on the Nvidia GeForce RTX 3060, with 12GB of VRAM, on our development machine. 

In our recent experience of the free tier of Google Colab, we have been allocated an Nvidia Tesla T4 with 15GB of VRAM and therefore this has not been a problem. However, the GPU allocated by Colab is not guaranteed and you may wish to run this notebook, or other code from the repository, on a GPU with less VRAM.

If you encounter a GPU out-of-memory error, you can try to avoid it by reducing the amount of work performed in parallel as set out below.

### Value iteration

For value iteration, the key setting is the maximum batch size: the number of states being simultaneously updated. This can be adjusted using the command line argument `vi_runner.max_batch_size`.

In [None]:
# Run an experiment with a large max batch size
os.chdir(viso_jax_dir / "viso_jax/value_iteration")
!python run_value_iteration.py +experiment=hendrix_perishable_substitution_two_product/m2/exp1 vi_runner.max_batch_size=5000

In [None]:
# Run an experiment with a small max batch size
os.chdir(viso_jax_dir / "viso_jax/value_iteration")
!python run_value_iteration.py +experiment=hendrix_perishable_substitution_two_product/m2/exp1 vi_runner.max_batch_size=50
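
To see why a smaller maximum batch size lowers peak memory use, the sketch below splits a set of states into consecutive batches of at most `max_batch_size`. It is a simplified illustration only: the number of states is made up, `batch_bounds` is a hypothetical helper, and the way viso_jax actually forms and pads batches is not shown here.

```python
import math


def batch_bounds(num_states: int, max_batch_size: int):
    """Yield (start, stop) index pairs that cover all states in order."""
    for start in range(0, num_states, max_batch_size):
        yield start, min(start + max_batch_size, num_states)


num_states = 12_345  # made-up number of states, for illustration only
for max_batch_size in (5000, 50):
    batches = list(batch_bounds(num_states, max_batch_size))
    assert len(batches) == math.ceil(num_states / max_batch_size)
    largest = max(stop - start for start, stop in batches)
    print(f"max_batch_size={max_batch_size}: {len(batches)} batches, "
          f"largest holds {largest} states")
```

Fewer states are updated at once with the smaller setting, at the cost of more sequential batches per iteration.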

### Simulation optimization

For simulation optimization, the key settings are `param_search.max_parallel_trials` (the number of different combinations of parameters for the heuristic policy to run in one iteration) and `param_search.num_rollouts` (the number of rollouts to run for each set of parameters for the heuristic policy). By default these are set to 50 and 4,000 respectively. For experiments that use Optuna's NSGAII sampler, we also use the command line argument `param_search.sampler.population_size` to match `param_search.max_parallel_trials`, so that the size of each generation of the genetic algorithm is the same as the number of trials being run in one iteration.


In [None]:
# Run a simulation optimization experiment with the default settings
os.chdir(viso_jax_dir / "viso_jax/simopt")
!python run_optuna_simopt.py +experiment=hendrix_perishable_substitution_two_product/m2/exp1

In [None]:
# Run an experiment that evaluates fewer combinations of parameters in parallel
os.chdir(viso_jax_dir / "viso_jax/simopt")
!python run_optuna_simopt.py +experiment=hendrix_perishable_substitution_two_product/m2/exp1 param_search.max_parallel_trials=10 param_search.sampler.population_size=10

In [None]:
# Run an experiment with fewer rollouts per set of parameters
os.chdir(viso_jax_dir / "viso_jax/simopt")
!python run_optuna_simopt.py +experiment=hendrix_perishable_substitution_two_product/m2/exp1 param_search.num_rollouts=1000

In [None]:
# Run an experiment that evaluates fewer combinations of parameters in parallel and runs fewer rollouts per set of parameters
os.chdir(viso_jax_dir / "viso_jax/simopt")
!python run_optuna_simopt.py +experiment=hendrix_perishable_substitution_two_product/m2/exp1 param_search.max_parallel_trials=10 param_search.sampler.population_size=10  param_search.num_rollouts=1000
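
The effect of these two settings can be compared with simple arithmetic: the number of rollouts simulated in one iteration is the product of the two values. The sketch below tabulates this for the four cells above. It assumes that memory demand grows with this product, which is a simplification, and `rollouts_per_iteration` is a hypothetical helper.

```python
def rollouts_per_iteration(max_parallel_trials: int, num_rollouts: int) -> int:
    """Number of rollouts simulated in one iteration of the parameter search."""
    return max_parallel_trials * num_rollouts


# (max_parallel_trials, num_rollouts) for each of the four cells above
settings = {
    "default": (50, 4000),
    "fewer parallel trials": (10, 4000),
    "fewer rollouts": (50, 1000),
    "both reduced": (10, 1000),
}
for name, (trials, rollouts) in settings.items():
    print(f"{name}: {rollouts_per_iteration(trials, rollouts):,} rollouts per iteration")
```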

### Evaluation rollouts

For both value iteration and simulation optimization we run 10,000 rollouts with the best identified policy by default. This number can be reduced using the command line argument `evaluation.num_rollouts`. 

In [None]:
# Run a value iteration experiment with fewer evaluation rollouts
os.chdir(viso_jax_dir / "viso_jax/value_iteration")
!python run_value_iteration.py +experiment=hendrix_perishable_substitution_two_product/m2/exp1 evaluation.num_rollouts=5000

In [None]:
# Run a simulation optimization experiment with fewer evaluation rollouts
os.chdir(viso_jax_dir / "viso_jax/simopt")
!python run_optuna_simopt.py +experiment=hendrix_perishable_substitution_two_product/m2/exp1 evaluation.num_rollouts=5000
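
Reducing the number of evaluation rollouts trades memory for a noisier estimate. As a rough guide, and assuming the rollouts are independent, the standard error of a mean scales with one over the square root of the number of rollouts. The sketch below expresses this relative to the default of 10,000; `relative_standard_error` is a hypothetical helper.

```python
import math


def relative_standard_error(num_rollouts: int, baseline: int = 10_000) -> float:
    """Standard error of the mean relative to that obtained with `baseline` rollouts.

    Assumes independent rollouts, so the standard error is proportional
    to 1 / sqrt(num_rollouts).
    """
    return math.sqrt(baseline / num_rollouts)


for n in (10_000, 5_000, 1_000):
    print(f"{n:>6} rollouts: standard error x {relative_standard_error(n):.2f}")
```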

## Running value iteration at single-precision

We explain in the paper that we have used double-precision numbers for our value iteration experiments, due to the instability in convergence we observed for some problems when using the default single-precision while conducting preliminary experiments.
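
The sketch below is not taken from the paper. It illustrates, with made-up numbers, one way in which single precision can interfere with a convergence check: an update that is visible in double precision can be rounded away entirely in single precision. It uses only the Python standard library to emulate single-precision rounding, and `to_float32` is a hypothetical helper.

```python
import struct


def to_float32(x: float) -> float:
    """Round a Python float (double precision) to the nearest single-precision value."""
    return struct.unpack("f", struct.pack("f", x))[0]


value = 1000.0   # made-up state value
update = 1e-5    # made-up change between two iterations

changed_in_double = (value + update) != value
changed_in_single = to_float32(value + update) != to_float32(value)

print(f"change visible in double precision: {changed_in_double}")
print(f"change visible in single precision: {changed_in_single}")
```

Near 1000, adjacent single-precision values are about 6e-5 apart, so a change of 1e-5 cannot be represented.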

We provide the ability to switch between these options using the command line, as demonstrated below.

In [None]:
# Run value iteration at double-precision, the default in our configs (but not for JAX)
os.chdir(viso_jax_dir / "viso_jax/value_iteration")
!python run_value_iteration.py +experiment=hendrix_perishable_substitution_two_product/m2/ortega_P4

In [None]:
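# Run value iteration at single-precision, the default for JAX
# (the override key jax_settings.double_precision is assumed here; check the value iteration config if Hydra rejects it)
os.chdir(viso_jax_dir / "viso_jax/value_iteration")
!python run_value_iteration.py +experiment=hendrix_perishable_substitution_two_product/m2/ortega_P4 jax_settings.double_precision=False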

## Running experiments with different random seeds

Our experiments use seeds to ensure reproducibility. These seeds can be changed using command line arguments.

JAX handles random number generation and seeding differently to other libraries like NumPy - we recommend reading this [tutorial](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#rngs-and-state). 

The seeds we set in the configuration files, and override below using the command line, are used to create an initial JAX PRNGKey. This key is then split so that a separate key is passed to each call of a function that generates a random output.

### Value iteration

A seed is only used at the evaluation stage of value iteration. It is used to make the stochastic elements of the transitions in the environment reproducible. 

We use the same seed for the evaluation stages of value iteration and simulation optimization so that, for a given experiment, it would be possible to perform pairwise comparisons to test how each policy performs when faced with the same pattern of demand (and any other stochastic elements in the transition). 

The command line argument for the seed is `evaluation.seed`.

In [None]:
# Run value iteration and evaluate on scenarios generated using the default seed from the config file
os.chdir(viso_jax_dir / "viso_jax/value_iteration")
!python run_value_iteration.py +experiment=hendrix_perishable_substitution_two_product/m2/exp1

In [None]:
# Run value iteration and evaluate on scenarios generated using a seed specified at the command line
os.chdir(viso_jax_dir / "viso_jax/value_iteration")
!python run_value_iteration.py +experiment=hendrix_perishable_substitution_two_product/m2/exp1 evaluation.seed=999

We can see from the results above that the evaluation results are similar but not identical, because they are based on two different sets of 10,000 rollouts.
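
The pairwise comparison that a shared evaluation seed makes possible can be sketched with a toy example. The code below uses Python's `random` module rather than JAX PRNG keys, and a deliberately simplified single-period model with made-up costs and demand. It is not one of the scenarios from the paper, and `simulate_cost` is a hypothetical helper.

```python
import random


def simulate_cost(order_up_to: int, seed: int, periods: int = 50) -> float:
    """Total cost of a toy order-up-to policy on a seeded demand sequence.

    Toy model for illustration only: stock is reset to `order_up_to` each
    period, leftover units cost 1 each and unmet demand costs 5 per unit.
    """
    rng = random.Random(seed)
    total = 0.0
    for _ in range(periods):
        demand = rng.randint(0, 10)
        leftover = order_up_to - demand
        total += leftover if leftover >= 0 else -5 * leftover
    return total


# The same seed gives both policies the same demand sequence,
# so the costs can be compared pair by pair
seeds = range(100)
paired_differences = [simulate_cost(6, s) - simulate_cost(8, s) for s in seeds]
mean_difference = sum(paired_differences) / len(paired_differences)
print(f"mean paired cost difference (order-up-to 6 minus 8): {mean_difference:.2f}")
```

Because each pair shares its demand sequence, differences between policies are not confounded by differences in the demand they happened to face.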

### Simulation optimization

For simulation optimization we supply two different seeds using the configuration file. 

One, `evaluation.seed`, is the same as for the value iteration experiments, and is used at the evaluation stage after the best parameters for the heuristic policy have been identified.

The second, `param_search.seed`, fulfills the same function for the rollouts performed during the simulation optimization process.

`param_search.seed` is also separately passed into our Optuna sampler to make the heuristic search of the policy parameter space deterministic. 

In [None]:
# Run simulation optimization and evaluate on scenarios generated using the default seeds from the config file
os.chdir(viso_jax_dir / "viso_jax/simopt")
!python run_optuna_simopt.py +experiment=hendrix_perishable_substitution_two_product/m2/exp1

In [None]:
# Run simulation optimization using a param_search.seed specified at the command line and evaluate on scenarios generated using the default evaluation.seed from the config file
os.chdir(viso_jax_dir / "viso_jax/simopt")
!python run_optuna_simopt.py +experiment=hendrix_perishable_substitution_two_product/m2/exp1 param_search.seed=999

In [None]:
# Run simulation optimization using the default param_search.seed from the config file and evaluate on scenarios generated using an evaluation.seed specified at the command line
os.chdir(viso_jax_dir / "viso_jax/simopt")
!python run_optuna_simopt.py +experiment=hendrix_perishable_substitution_two_product/m2/exp1 evaluation.seed=999

We can see that the search process is different in the first and second cases, but both identified the same combination of parameters as the best. The evaluation outputs for the first and second case are the same - they use the same evaluation seed and therefore the policies are being evaluated on the same set of 10,000 rollouts.

The third case follows the same search process as the first, but the evaluation outputs are different because they are calculated based on a different set of 10,000 rollouts.

## Changing a scenario setting

Scenario settings are settings that relate to the problem description, for example the parameters of the demand distribution, the maximum useful life or the variable order cost. 

We use [hydra](https://hydra.cc/) for configuration, which supports composable configuration files. The final configuration for an experiment is therefore drawn from multiple yaml files.

If you wish to change scenario settings, you should update the entry in the scenario settings config. The change will then be propagated to any other configs that require the same information. For example, when running value iteration, the scenario settings config feeds through both to the config that parameterises the class that performs value iteration (a value iteration runner) and the config that parameterises the class that manages the simulation rollouts of the final policy (a rollout wrapper).

In the examples below, we run an experiment with the default settings followed by a modified version, changing one or more scenario settings at the command line.

### Value iteration

In [None]:
# Run a value iteration experiment
os.chdir(viso_jax_dir / "viso_jax/value_iteration")
!python run_value_iteration.py +experiment=de_moor_perishable/m2/exp1

In [None]:
# Run a value iteration experiment with a higher shortage cost
os.chdir(viso_jax_dir / "viso_jax/value_iteration")
!python run_value_iteration.py +experiment=de_moor_perishable/m2/exp1 scenario_settings.shortage_cost=10

### Simulation optimization

In [None]:
# Run a simulation optimization experiment
os.chdir(viso_jax_dir / "viso_jax/simopt")
!python run_optuna_simopt.py +experiment=de_moor_perishable/m2/exp1

In [None]:
# Run a simulation optimization experiment with a higher mean demand and higher max order quantity
os.chdir(viso_jax_dir / "viso_jax/simopt")
!python run_optuna_simopt.py +experiment=de_moor_perishable/m2/exp1 scenario_settings.demand_gamma_mean=8 scenario_settings.max_order_quantity=20