<a target="_blank" href="https://colab.research.google.com/github/instadeepai/jumanji/blob/main/examples/training.ipynb">
  <img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/>
</a>

In [1]:
# Install jumanji with its training extras directly from the `main` branch of the GitHub repo.
# NOTE(review): this is unpinned, so the installed version (and the config files fetched from
# `main` further below) can change between runs; `%pip` targets the running kernel's environment.
%pip install --quiet -U "jumanji[train] @ git+https://github.com/instadeepai/jumanji.git@main"

Note: you may need to restart the kernel to use updated packages.


In [2]:
# @title Set up JAX for available hardware (run me) { display-mode: "form" }

import subprocess
import os


def has_nvidia_gpu() -> bool:
    """Return True when the `nvidia-smi` command runs successfully.

    A successful `nvidia-smi` call is used as a proxy for "an NVIDIA GPU is available".
    Based on https://stackoverflow.com/questions/67504079/how-to-check-if-an-nvidia-gpu-is-available-on-my-system

    Returns:
        False if the binary is missing or not executable (OSError) or it exits with an
        error (subprocess.SubprocessError); True otherwise. Any other exception propagates
        instead of being silently treated as "no GPU".
    """
    try:
        # Discard stderr so a failing nvidia-smi does not clutter the cell output.
        subprocess.check_output("nvidia-smi", stderr=subprocess.DEVNULL)
    except (OSError, subprocess.SubprocessError):
        return False
    return True


if has_nvidia_gpu():
    print("A GPU is connected.")
elif os.environ.get("COLAB_TPU_ADDR"):
    # Colab TPU runtimes set COLAB_TPU_ADDR; JAX needs an explicit setup call there.
    # NOTE(review): `jax.tools.colab_tpu` only ships with older JAX releases — confirm it
    # exists in the installed version before relying on this branch.
    import jax.tools.colab_tpu

    jax.tools.colab_tpu.setup_tpu()
    print("A TPU is connected.")
else:
    print("Only CPU accelerator is connected.")

Only CPU accelerator is connected.


In [3]:
import warnings
# Silence *all* warnings for the rest of the kernel session. This runs before the imports
# below so that warnings emitted while importing jumanji/hydra are hidden as well.
# NOTE(review): a blanket "ignore" also hides deprecation notices; consider narrowing it.
warnings.filterwarnings("ignore")

from jumanji.training.train import train
from hydra import compose, initialize

No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)


In [4]:
# Pick the environment and agent to train. The `# @param` annotations render these as
# dropdowns in Colab forms; `env` also selects which configs/env/<env>.yaml gets downloaded.
env = "snake"  # @param ['bin_pack', 'cleaner', 'connector', 'cvrp', 'game_2048', 'graph_coloring', 'job_shop', 'knapsack', 'maze', 'minesweeper', 'mmst', 'multi_cvrp', 'robot_warehouse', 'rubiks_cube', 'snake', 'sudoku', 'tetris', 'tsp']
agent = "random"  # @param ['random', 'a2c']

In [5]:
#@title Download Jumanji Configs (run me) { display-mode: "form" }

import os
import requests


def download_file(url: str, file_path: str, timeout: float = 30.0) -> None:
    """Download ``url`` and write the raw response body to ``file_path``.

    Args:
        url: HTTP(S) URL to fetch.
        file_path: Destination path; its parent directory must already exist.
        timeout: Seconds to wait for the server (forwarded to ``requests.get``) so a
            stalled connection cannot hang the notebook indefinitely.

    On a non-200 status nothing is written and a message naming the URL and status code
    is printed (best-effort, so the cell keeps running). Network errors raised by
    ``requests`` propagate to the caller.
    """
    response = requests.get(url, timeout=timeout)
    if response.status_code != 200:
        # Name the URL/status so e.g. a misspelled env config is obvious here, rather than
        # surfacing later as an opaque "config not found" error when composing.
        print(f"Failed to download the file from {url} (HTTP status {response.status_code}).")
        return
    with open(file_path, "wb") as f:
        f.write(response.content)


# Fetch the top-level training config and the selected env config from the `main` branch,
# mirroring the repo layout under a local ./configs directory for Hydra to read.
configs_base_url = "https://raw.githubusercontent.com/instadeepai/jumanji/main/jumanji/training/configs"

os.makedirs("configs", exist_ok=True)
config_url = f"{configs_base_url}/config.yaml"
download_file(config_url, "configs/config.yaml")

env_url = f"{configs_base_url}/env/{env}.yaml"
os.makedirs("configs/env", exist_ok=True)
download_file(env_url, f"configs/env/{env}.yaml")

In [6]:
# Command-line style overrides applied on top of configs/config.yaml: select the chosen
# env/agent, log to the terminal, and save checkpoints.
cfg_overrides = [
    f"env={env}",
    f"agent={agent}",
    "logger.type=terminal",
    "logger.save_checkpoint=true",
]
# compose() is called inside the initialize() context, which points Hydra at the
# `configs` directory populated by the download cell.
with initialize(version_base=None, config_path="configs"):
    cfg = compose(config_name="config.yaml", overrides=cfg_overrides)
cfg

{'agent': 'random', 'seed': 0, 'logger': {'type': 'terminal', 'save_checkpoint': True, 'name': '${agent}_${env.name}'}, 'env': {'name': 'snake', 'registered_version': 'Snake-v1', 'network': {'num_channels': 32, 'policy_layers': [64, 64], 'value_layers': [128, 128]}, 'training': {'num_epochs': 200, 'num_learner_steps_per_epoch': 500, 'n_steps': 20, 'total_batch_size': 128}, 'evaluation': {'eval_total_batch_size': 1000, 'greedy_eval_total_batch_size': 1000}, 'a2c': {'normalize_advantage': False, 'discount_factor': 0.997, 'bootstrapping_factor': 0.95, 'l_pg': 1.0, 'l_td': 1.0, 'l_en': 0.01, 'learning_rate': 0.0003}}}

In [7]:
# Run training/evaluation with the composed config; progress is logged to the terminal
# (logger.type=terminal). The epoch count comes from cfg.env.training.num_epochs (200 for
# snake in the config shown above), and each epoch took several seconds on CPU in the
# logged run below, so expect a long run without an accelerator.
train(cfg)

INFO:root:{'devices': [CpuDevice(id=0)]}
INFO:root:Experiment: random_snake.
INFO:root:Starting logger.
INFO:root:Eval Stochastic >> Env Steps: 0.00e+00 | Episode Length: 1552.182 | Episode Return: 6.197 | Time: 8.931
INFO:root:Train >> Env Steps: 0.00e+00 | Steps Per Second: 200,421 | Time: 6.387
INFO:root:Eval Stochastic >> Env Steps: 1.28e+06 | Episode Length: 1553.735 | Episode Return: 6.130 | Time: 23.055
INFO:root:Train >> Env Steps: 1.28e+06 | Steps Per Second: 252,801 | Time: 5.063
INFO:root:Eval Stochastic >> Env Steps: 2.56e+06 | Episode Length: 1536.620 | Episode Return: 6.155 | Time: 6.869
INFO:root:Train >> Env Steps: 2.56e+06 | Steps Per Second: 328,756 | Time: 3.893
INFO:root:Eval Stochastic >> Env Steps: 3.84e+06 | Episode Length: 1512.204 | Episode Return: 6.259 | Time: 7.027
INFO:root:Train >> Env Steps: 3.84e+06 | Steps Per Second: 325,557 | Time: 3.932
INFO:root:Eval Stochastic >> Env Steps: 5.12e+06 | Episode Length: 1524.308 | Episode Return: 6.067 | Time: 7.090
I