<a target="_blank" href="https://colab.research.google.com/github/instadeepai/jumanji/blob/main/examples/load_checkpoints.ipynb">
  <img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/>
</a>

In [1]:
# Install (or upgrade, via -U) Jumanji with its `train` extra directly from the
# `main` branch of the GitHub repository.
# NOTE(review): `@main` is an unpinned moving target, so a later re-run may
# install different code; consider pinning a tag or commit for reproducibility.
%pip install --quiet -U "jumanji[train] @ git+https://github.com/instadeepai/jumanji.git@main"

Note: you may need to restart the kernel to use updated packages.


In [3]:
# @title Set up JAX for available hardware (run me) { display-mode: "form" }

import subprocess
import os

# Hardware probe, based on
# https://stackoverflow.com/questions/67504079/how-to-check-if-an-nvidia-gpu-is-available-on-my-system
# Running `nvidia-smi` succeeds only when the tool is present and exits cleanly;
# any exception is treated as "no GPU".
try:
    subprocess.check_output('nvidia-smi')
    print("a GPU is connected.")
except Exception:
    # No GPU: use a TPU if Colab advertises a non-empty address, otherwise CPU.
    colab_tpu_address = os.environ.get("COLAB_TPU_ADDR")
    if not colab_tpu_address:
        print("Only CPU accelerator is connected.")
    else:
        import jax.tools.colab_tpu

        jax.tools.colab_tpu.setup_tpu()
        print("A TPU is connected.")


Only CPU accelerator is connected.


In [4]:
import pickle

import jax
from huggingface_hub import hf_hub_download
from hydra import compose, initialize

from jumanji.training.setup_train import setup_agent, setup_env
from jumanji.training.utils import first_from_device

%matplotlib notebook

No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)


## Load configs

In [5]:
# Names of the Jumanji environment and agent to load. They are interpolated into
# the config download URL and the Hydra overrides in the following cells.
# Later cells rebind `env` and `agent` to the instantiated environment and agent
# objects, so re-run this cell before re-running the config cells.
# The checkpoint repository id used further down is built with "a2c" regardless
# of the `agent` value selected here.
env = "bin_pack"  # @param ['bin_pack', 'cleaner', 'connector', 'cvrp', 'game_2048', 'graph_coloring', 'job_shop', 'knapsack', 'maze', 'minesweeper', 'mmst', 'multi_cvrp', 'robot_warehouse', 'rubiks_cube', 'snake', 'sudoku', 'tetris', 'tsp']
agent = "a2c"  # @param ['random', 'a2c']

In [6]:
#@title Download Jumanji Configs (run me) { display-mode: "form" }

import os
import requests

def download_file(url: str, file_path: str, timeout: float = 30.0) -> None:
    """Download `url` and write the response body to `file_path`.

    Args:
        url: Address to fetch with an HTTP GET request.
        file_path: Destination path; opened in binary write mode, so an
            existing file is overwritten.
        timeout: Seconds to wait for the server before giving up, so that a
            stalled connection cannot hang the notebook indefinitely.

    The file is only written when the server answers with status code 200.
    Any other status is reported on stdout, together with the URL and the
    status code, and nothing is written.
    """
    # Without an explicit timeout `requests` may wait forever on a dead server.
    response = requests.get(url, timeout=timeout)
    if response.status_code == 200:
        with open(file_path, "wb") as f:
            f.write(response.content)
    else:
        # Name the URL and status so a missing config is easy to diagnose.
        print(f"Failed to download the file from {url} (HTTP status {response.status_code}).")

# Create the local config tree: `configs/` for the root config and
# `configs/env/` for the environment-specific one.
os.makedirs("configs", exist_ok=True)
os.makedirs("configs/env", exist_ok=True)

# `env` must still hold the environment *name* (a string) at this point, since
# it is interpolated into both the remote URL and the local file name.
config_url = "https://raw.githubusercontent.com/instadeepai/jumanji/main/jumanji/training/configs/config.yaml"
env_url = f"https://raw.githubusercontent.com/instadeepai/jumanji/main/jumanji/training/configs/env/{env}.yaml"

# Fetch the root config first, then the environment config.
for source_url, destination in (
    (config_url, "configs/config.yaml"),
    (env_url, f"configs/env/{env}.yaml"),
):
    download_file(source_url, destination)

In [7]:
# Compose the downloaded root config, overriding its `env` and `agent` entries
# with the selections made above.
hydra_overrides = [f"env={env}", f"agent={agent}"]
with initialize(version_base=None, config_path="configs"):
    cfg = compose(config_name="config.yaml", overrides=hydra_overrides)
# Display the composed config as the cell output.
cfg

{'agent': 'a2c', 'seed': 0, 'logger': {'type': 'terminal', 'save_checkpoint': False, 'name': '${agent}_${env.name}'}, 'env': {'name': 'bin_pack', 'registered_version': 'BinPack-v2', 'network': {'num_transformer_layers': 2, 'transformer_num_heads': 8, 'transformer_key_size': 16, 'transformer_mlp_units': [512]}, 'training': {'num_epochs': 550, 'num_learner_steps_per_epoch': 100, 'n_steps': 30, 'total_batch_size': 64}, 'evaluation': {'eval_total_batch_size': 10000, 'greedy_eval_total_batch_size': 10000}, 'a2c': {'normalize_advantage': False, 'discount_factor': 1.0, 'bootstrapping_factor': 0.95, 'l_pg': 1.0, 'l_td': 1.0, 'l_en': 0.005, 'learning_rate': 0.0001}}}

## Load a checkpoint from the Hugging Face Hub

In [8]:
# Choose the checkpoint matching the configured environment from the
# InstaDeep Model Hub: https://huggingface.co/InstaDeepAI
registered_version = cfg.env.registered_version
REPO_ID = f"InstaDeepAI/jumanji-benchmark-a2c-{registered_version}"
FILENAME = f"{registered_version}_training_state"

# Fetch the checkpoint file from the Hub and get the path to the local copy.
model_checkpoint = hf_hub_download(repo_id=REPO_ID, filename=FILENAME)

# NOTE(review): unpickling can execute arbitrary code. Only load checkpoints
# from a repository you trust.
with open(model_checkpoint, "rb") as checkpoint_file:
    training_state = pickle.load(checkpoint_file)

In [9]:
# Extract the network parameters from the loaded training state.
# NOTE(review): `first_from_device` presumably drops a per-device replication
# axis — confirm against jumanji.training.utils.
params = first_from_device(training_state.params_state.params)

# From here on `env` and `agent` are the instantiated environment and agent
# objects, no longer the name strings selected at the top of the notebook.
env = setup_env(cfg).unwrapped
agent = setup_agent(cfg, env)

# Build the policy from the actor parameters with `stochastic=False` and JIT it.
# The policy's return value is deliberately left untouched: the rollout cell
# unpacks it as `action, _`. Comparing `agent` (an object here) with the string
# "a2c" would never match, and wrapping `policy` in a lambda that calls the
# rebound name `policy` would recurse forever, so no such wrapper is applied.
policy = jax.jit(agent.make_policy(params.actor, stochastic=False))

## Roll out a few episodes

In [10]:
NUM_EPISODES = 2

# JIT-wrap the environment's reset and step functions once, outside the loops.
reset_fn, step_fn = jax.jit(env.reset), jax.jit(env.step)

# Every post-step state of every episode is collected here for the animation.
states = []
key = jax.random.PRNGKey(cfg.seed)
for episode in range(NUM_EPISODES):
    key, reset_key = jax.random.split(key)
    state, timestep = reset_fn(reset_key)
    while not timestep.last():
        key, action_key = jax.random.split(key)
        # Add a leading axis of size 1 to every observation leaf before calling
        # the policy; the matching axis is squeezed out of the action below.
        observation = jax.tree_util.tree_map(lambda leaf: leaf[None], timestep.observation)
        action, _ = policy(observation, action_key)
        state, timestep = step_fn(state, action.squeeze(axis=0))
        states.append(state)
    # Freeze the terminal frame to pause the GIF: repeat the last state 3 times.
    states.extend([state] * 3)



## Display animation

In [11]:
# Animate the states collected during the rollouts; as the last expression of
# the cell, the returned object is displayed as the cell output.
# NOTE(review): `interval=150` is presumably the delay between frames in
# milliseconds — confirm against `env.animate`.
env.animate(states, interval=150)

<IPython.core.display.Javascript object>