## Training an Atari Breakout agent using PPO

In [None]:
import ray
from ray.tune import run

In [None]:
# Start Ray with its default settings before launching the training run.
ray.init()

In [None]:
# Train a PPO agent on Breakout through Tune. The returned object is kept in
# `result` so a later cell can look up the checkpoint written at the end.
result = run(
    "PPO",
    name="atari-breakout",
    local_dir="/tmp/ray-results",
    # A checkpoint is saved when the trial finishes; the rollout and
    # attribution steps later in the notebook restore from it.
    checkpoint_at_end=True,
    # NOTE(review): a threshold of 1 timestep stops the trial as soon as the
    # first result is reported -- presumably a smoke-test value; raise it for
    # an agent that actually learns to play.
    stop={"timesteps_total": 1},
    config={
        # Based on `rllib/tuned_examples/ppo/atari-ppo.yaml`
        "env": "BreakoutNoFrameskip-v4",
        "framework": "torch",
        "lambda": 0.95,
        "kl_coeff": 0.5,
        "clip_rewards": True,
        "clip_param": 0.1,
        "vf_clip_param": 10.0,
        "entropy_coeff": 0.01,
        "train_batch_size": 5000,
        "rollout_fragment_length": 100,
        "sgd_minibatch_size": 500,
        "num_sgd_iter": 10,
        "num_workers": 2,
        "num_envs_per_worker": 5,
        "batch_mode": "truncate_episodes",
        "observation_filter": "NoFilter",
        "vf_share_layers": True,
        # No GPUs are requested for the trainer.
        "num_gpus": 0,
    },
)

In [None]:
# Locate the checkpoint of the trial with the highest mean episode reward.
# `mode="max"` is passed explicitly: newer Tune releases raise an error when
# no mode is given, while older ones already defaulted to maximising.
checkpoint_path = result.get_best_trial(
    "episode_reward_mean", mode="max"
).checkpoint.value

In [None]:
# Stop the notebook's Ray instance; the remaining steps are shell commands,
# and the generated attribution config starts its own Ray session.
ray.shutdown()

## Collect a rollout

In [None]:
# Roll out the trained PPO policy for 5 episodes without rendering and write
# the result to /tmp/atari-breakout.ray_rollout. `$checkpoint_path` is filled
# in by IPython from the notebook variable set in the cell above.
!rllib rollout \
--run PPO \
--use-shelve \
--no-render \
--episodes 5 \
--out /tmp/atari-breakout.ray_rollout \
$checkpoint_path

## Calculate attributions

In [None]:
# Source code of the config module consumed by `rld attribute` below.
# This is an f-string: `{checkpoint_path}` is the only interpolated field, so
# the checkpoint location found in this notebook is baked into the generated
# file as a raw-string path. The generated module:
#   * reads `params.json` two directory levels above the checkpoint file and
#     builds a PPOTrainer from it,
#   * restores the checkpoint and wraps the policy model in RayModelWrapper,
#   * uses an all-zeros array shaped like the observation as the baseline,
#   * exposes the result as a module-level `config` object.
config_content = f"""
import json
from pathlib import Path

import numpy as np
import ray
from ray.rllib.agents.ppo import PPOTrainer

from rld.attributation import AttributationTarget, AttributationNormalizationMode
from rld.config import Config
from rld.model import Model, RayModelWrapper
from rld.typing import ObsLike


def get_model() -> Model:
    checkpoint_path = Path(r"{checkpoint_path}")
    params_path = checkpoint_path.parents[1] / "params.json"
    with open(params_path) as f:
        params = json.load(f)
    ray.init()
    trainer = PPOTrainer(config=params)
    trainer.restore(str(checkpoint_path))
    model = RayModelWrapper(trainer.get_policy().model)
    ray.shutdown()
    return model


def baseline_builder(obs: ObsLike):
    return np.zeros_like(obs)


model = get_model()


config = Config(
    model=model,
    baseline=baseline_builder,
    target=AttributationTarget.TOP5,
    normalize_sign=AttributationNormalizationMode.POSITIVE,
)

"""
# Write the generated module to disk so the CLI can be pointed at it.
with open("/tmp/atari_breakout_config.py", "w") as f:
    f.write(config_content)

In [None]:
# Compute attributions for the recorded rollout using the config module
# written above; the output goes to /tmp/atari-breakout.rld.
# `--rllib` presumably tells rld the input is an RLlib rollout file -- confirm
# against the rld CLI help.
!rld attribute \
--rllib \
--out /tmp/atari-breakout.rld \
/tmp/atari_breakout_config.py \
/tmp/atari-breakout.ray_rollout

## Visualize results

In [None]:
# Open the rld viewer on the attribution results produced above.
# NOTE(review): the viewer is set to `cartpole` although this notebook works
# with Atari Breakout -- possibly carried over from another example; check
# `rld start --help` for an Atari-specific viewer and switch if one exists.
!rld start --viewer cartpole /tmp/atari-breakout.rld