Copyright (c) 2022, salesforce.com, inc and MILA.  
All rights reserved.  
SPDX-License-Identifier: BSD-3-Clause  
For full license text, see the LICENSE file in the repo root  
or https://opensource.org/licenses/BSD-3-Clause  

In [None]:
import logging
import numpy as np
import os
import shutil
import subprocess
import sys
import time
import yaml

from scripts.train_with_rllib import create_trainer, fetch_episode_states, load_model_checkpoints

In [None]:
# Silence everything below ERROR on the root logger. Other standard levels
# (DEBUG, INFO, WARNING) can be substituted here for more verbose output.
root_logger = logging.getLogger()
root_logger.setLevel(logging.ERROR)

In [None]:
# Run this install if the system raises a "missing shared library (.so file)" error.
# ! apt-get install libglib2.0-0 --yes

In [None]:
print("Training with RLlib...")
# Read the run configuration specific to the environment.
# Note: the run config yaml can be edited at scripts/rice_rllib.yaml.
# -----------------------------------------------------------------
config_path = os.path.join("scripts", "rice_rllib.yaml")
if not os.path.exists(config_path):
    # Fail early with a readable message (including the path that was tried).
    raise ValueError(
        f"The run configuration is missing at {config_path!r}. "
        "Please make sure the correct path is specified."
    )

with open(config_path, "r", encoding="utf8") as fp:
    # safe_load only constructs plain YAML types, never arbitrary objects.
    run_config = yaml.safe_load(fp)

# Create trainer
# --------------
# create_trainer returns the trainer plus the directory used for results below.
trainer, save_dir = create_trainer(run_config)

# Copy the source files into the results directory
# ------------------------------------------------
# NOTE(review): os.makedirs raises FileExistsError if save_dir already exists
# (e.g. if create_trainer creates it) — confirm that it never does.
os.makedirs(save_dir)
for filename in ["rice.py"]:
    # Paths are relative to the current working directory.
    shutil.copyfile(filename, os.path.join(save_dir, filename))

# Add an empty identifier file. Mode "x" raises FileExistsError if it already
# exists; the `with` block closes the file, so no explicit close is needed.
with open(os.path.join(save_dir, ".rllib"), "x", encoding="utf-8"):
    pass

### Invoke training

In [None]:
# Run a fixed number of training iterations and print the value returned by
# the last trainer.train() call (presumably RLlib's dict of training
# metrics — TODO confirm).
NUM_ITERS = 5
# Bind `result` up front so the print below cannot raise NameError when
# NUM_ITERS is 0.
result = None
for _ in range(NUM_ITERS):
    result = trainer.train()
print(result)

### Fetch episode states

In [None]:
# Collect the listed episode state variables from the trained `trainer`
# (via fetch_episode_states from scripts.train_with_rllib).
episode_state_keys = ["T_i", "carbon_mass_i", "capital_i"]
outputs = fetch_episode_states(trainer, episode_state_keys)