<a target="_blank" href="https://colab.research.google.com/github/instadeepai/jumanji/blob/main/examples/training.ipynb">
  <img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/>
</a>

In [8]:
%%capture
!pip install --quiet -U "jumanji[train] @ git+https://github.com/instadeepai/jumanji.git@main"

In [9]:
# @title Set up JAX for available hardware (run me) { display-mode: "form" }

import subprocess
import os

# Accelerator probe, adapted from
# https://stackoverflow.com/questions/67504079/how-to-check-if-an-nvidia-gpu-is-available-on-my-system
# `nvidia-smi` succeeding is taken as evidence that a GPU is attached.
try:
    subprocess.check_output("nvidia-smi")
except Exception:
    # `nvidia-smi` is missing or failed, so this is either a TPU or a CPU runtime.
    # A non-empty COLAB_TPU_ADDR environment variable selects the TPU branch.
    if not os.environ.get("COLAB_TPU_ADDR"):
        print("Only CPU accelerator is connected.")
    else:
        import jax.tools.colab_tpu

        jax.tools.colab_tpu.setup_tpu()
        print("A TPU is connected.")
else:
    print("a GPU is connected.")

Only CPU accelerator is connected.


In [10]:
import warnings

# Silence every Python warning from this point on. Because this runs before the
# imports below, warnings emitted while importing them are suppressed too.
warnings.filterwarnings("ignore")

from jumanji.training.train import train
from hydra import compose, initialize

In [11]:
# Environment and agent to train. Both values are interpolated into the Hydra
# overrides (`env=...`, `agent=...`) when the config is composed, and `env` also
# selects which environment config file is downloaded.
env = "tetris"  # @param ['bin_pack', 'cleaner', 'connector', 'cvrp', 'game_2048', 'graph_coloring', 'job_shop', 'knapsack', 'maze', 'minesweeper', 'mmst', 'multi_cvrp', 'robot_warehouse', 'rubiks_cube', 'search_and_rescue', 'snake', 'sudoku', 'tetris', 'tsp']
agent = "a2c_simba"  # @param ['random', 'a2c', 'a2c_simba']

In [12]:
# @title Download Jumanji Configs (run me) { display-mode: "form" }

import os
import requests


def download_file(url: str, file_path: str, timeout: float = 30.0) -> None:
    """Download `url` and save the response body to `file_path`.

    The file is only written when the server answers with HTTP 200; for any
    other status a failure message is printed and nothing is written.

    Args:
        url: address to fetch with an HTTP GET request.
        file_path: destination path; opened in binary write mode, so an
            existing file is overwritten.
        timeout: seconds to wait for the server before `requests` gives up.
            Without a timeout a stalled connection would block the notebook
            indefinitely.
    """
    # Send an HTTP GET request to the URL, bounded by `timeout`.
    response = requests.get(url, timeout=timeout)
    # Check if the request was successful (status code 200)
    if response.status_code == 200:
        with open(file_path, "wb") as f:
            f.write(response.content)
    else:
        print("Failed to download the file.")


# Build the local Hydra config tree in one go: `configs/` holds the root config
# and `configs/env/` holds the config of the selected environment.
os.makedirs("configs/env", exist_ok=True)

# Raw GitHub locations of the root config and of the chosen environment's config.
config_url = "https://raw.githubusercontent.com/instadeepai/jumanji/main/jumanji/training/configs/config.yaml"
env_url = f"https://raw.githubusercontent.com/instadeepai/jumanji/main/jumanji/training/configs/env/{env}.yaml"

# Fetch both files into the local tree.
download_file(config_url, "configs/config.yaml")
download_file(env_url, f"configs/env/{env}.yaml")

In [7]:
# Compose the Hydra config from the local `configs/` directory that the download
# cell populates, so the notebook also runs outside a checkout of the repository
# (e.g. on Colab), where `../jumanji/training/configs` does not exist.
with initialize(version_base=None, config_path="configs"):
    cfg = compose(
        config_name="config.yaml",
        overrides=[
            # Select the environment and agent chosen in the parameter cell.
            f"env={env}",
            f"agent={agent}",
            # Log to the terminal and save a checkpoint during training.
            "logger.type=terminal",
            "logger.save_checkpoint=true",
        ],
    )

# Launch training with the composed configuration.
train(cfg)

INFO:root:{'devices': [CpuDevice(id=0)]}
INFO:root:Experiment: a2c_tetris.
INFO:root:Starting logger.
INFO:root:Eval Stochastic >> Env Steps: 0.00e+00 | Episode Length: 14.085 | Episode Return: 3.301 | Time: 1.133
INFO:root:Eval Greedy >> Env Steps: 0.00e+00 | Episode Length: 13.934 | Episode Return: 3.926 | Time: 1.049
INFO:root:Saving checkpoint...

KeyboardInterrupt

