# Deep Q-Learning: Space Invaders

In [1]:
import json
from functools import partial
from tempfile import TemporaryFile
from typing import Any, NamedTuple

import numpy as np
import optuna
import pandas as pd
import torch
from dotenv import dotenv_values
from tqdm import tqdm

from r2seedo.io import load_keypair, load_n_verify_model, sign_n_save_model
from r2seedo.models.dqn import (
    DQN,
    DQNConfig,
    DQNTrainingConfig,
    calculate_loss,
    update_target_network,
)
from r2seedo.utils import get_device
from r2seedo.utils.environment import (
    AtariEnvConfig,
    capture_replay,
    copy_n_adj_next_observation,
    get_replay_buffer,
)
from r2seedo.utils.training import (
    ExplorationConfig,
    LearningRateConfig,
    ReplayBufferConfig,
    ScheduleMode,
    TrainingConfig,
    set_learning_rate,
)

# Environment Configuration

In [2]:
# Training environment: Space Invaders with the remaining AtariEnvConfig defaults
env_id = "SpaceInvadersNoFrameskip-v4"
env_config = AtariEnvConfig(env_id=env_id)

# Evaluation environment: a copy of the training configuration in which only
# `clip_reward` is switched off; every other field is inherited unchanged.
# NOTE(review): this includes `terminal_on_life_loss` — confirm that inheriting it
# for evaluation episodes is intended.
eval_env_config = AtariEnvConfig(**dict(env_config.to_dict(), clip_reward=False))

# Show the training configuration
print(env_config)

{
  "env_id": "SpaceInvadersNoFrameskip-v4",
  "clip_reward": true,
  "frame_skip": 4,
  "frame_stack": 4,
  "grayscale": true,
  "noop_max": 30,
  "screen_size": [
    84,
    84
  ],
  "terminal_on_life_loss": true
}


# Hyperparameter Sampler

In [3]:
def sample_hyperparams(trial: optuna.Trial) -> dict[str, Any]:
    """Draw one set of DQN hyperparameters from an Optuna trial.

    Only the exploration fraction, the discount factor and the target-network
    inertia are searched; the remaining candidates are listed as disabled
    suggestions in the comments below.

    ``gamma`` and ``target_inertia`` are suggested as their complement and
    converted with ``1.0 - value``. The parameters Optuna records under those
    two names are therefore the distance from 1.0, not the values returned here.

    Args:
        trial: Optuna trial used to suggest the values.

    Returns:
        A dictionary with the keys ``"dqn_params"`` (currently empty) and
        ``"train_params"`` (``"exploration_fraction"``, ``"gamma"`` and
        ``"target_inertia"``).
    """
    # Model parameters: nothing is searched at the moment.
    # Disabled: "exp_hidden_dim": trial.suggest_int("exp_hidden_dim", 3, 7)  # [8, 128]
    dqn_params: dict[str, Any] = {}

    # Training parameters, suggested in a fixed order.
    # Disabled: "lr_start": trial.suggest_float("lr_start", 2e-5, 1e-2, log=True)
    # Disabled: "lr_constant_fraction":
    #     trial.suggest_float("lr_constant_fraction", 0.0, 0.5)
    exploration_fraction = trial.suggest_float("exploration_fraction", 0.05, 0.7)
    # Disabled: "explore_anneal_mode": trial.suggest_categorical(
    #     "explore_anneal_mode", [ScheduleMode.LINEAR, ScheduleMode.EXPONENTIAL]
    # )
    gamma_complement = trial.suggest_float("gamma", 0.001, 0.2)
    # Disabled: "double_q": trial.suggest_categorical("double_q", [True, False])
    inertia_complement = trial.suggest_float(
        "target_inertia", 0.0001, 0.1, log=True
    )

    train_params = {
        "exploration_fraction": exploration_fraction,
        "gamma": 1.0 - gamma_complement,
        "target_inertia": 1.0 - inertia_complement,
    }

    return {"dqn_params": dqn_params, "train_params": train_params}


def get_configs_from_hyperparams(
    total_timesteps: int,
    evaluation_rate: float,
    observation_shape: tuple[int, int, int],
    num_actions: int,
    hyperparams: dict[str, Any] | None = None,
) -> tuple[DQNConfig, TrainingConfig]:
    """Build the DQN model and training configurations from hyperparameters.

    Any hyperparameter that is missing falls back to the default written in
    this function.

    Args:
        total_timesteps: Value passed to ``TrainingConfig.total_timesteps``.
        evaluation_rate: Value passed to ``TrainingConfig.evaluation_rate``.
        observation_shape: Observation shape passed to ``DQNConfig``.
        num_actions: Number of actions passed to ``DQNConfig``.
        hyperparams: Optional dictionary with the keys ``"dqn_params"`` and
            ``"train_params"``, laid out as returned by ``sample_hyperparams``.
            ``None`` (or a missing key) selects the defaults.

    Returns:
        The ``(DQNConfig, TrainingConfig)`` pair.
    """
    # Split the hyperparameters into their model and training parts
    params = hyperparams if hyperparams is not None else {}
    dqn_params: dict[str, Any] = params.get("dqn_params", {})
    train_params: dict[str, Any] = params.get("train_params", {})

    # Model: the hidden width is given as a power-of-two exponent (default 2**6)
    hidden_dim = 2 ** dqn_params.get("exp_hidden_dim", 6)
    dqn_config = DQNConfig(
        observation_shape=observation_shape,
        num_actions=num_actions,
        hidden_dim=hidden_dim,
    )

    # Training: assemble each sub-configuration, then the top-level configuration
    learning_rate_config = LearningRateConfig(
        lr_start=train_params.get("lr_start", 1e-3),
        constant_fraction=train_params.get("lr_constant_fraction", 0.25),
    )
    exploration_config = ExplorationConfig(
        exploration_fraction=train_params.get("exploration_fraction", 0.2),
        anneal_mode=train_params.get("explore_anneal_mode", ScheduleMode.LINEAR),
    )
    replay_buffer_config = ReplayBufferConfig(buffer_fraction=0.25, batch_size=32)
    dqn_training_config = DQNTrainingConfig(
        gamma=train_params.get("gamma", 0.99),
        double_q=train_params.get("double_q", True),
        target_inertia=train_params.get("target_inertia", 0.999),
    )
    training_config = TrainingConfig(
        total_timesteps=total_timesteps,
        num_envs=3,
        train_freq=4,
        evaluation_rate=evaluation_rate,
        learning_rate_config=learning_rate_config,
        exploration_config=exploration_config,
        replay_buffer_config=replay_buffer_config,
        other=dqn_training_config,
    )

    return dqn_config, training_config

# Evaluation function

In [4]:
class EvaluationResults(NamedTuple):
    """Summary statistics over a batch of evaluation episodes."""

    # Mean of the per-episode scores
    mean_score: float
    # Standard deviation of the per-episode scores
    score_std: float
    # Mean of the per-episode lengths
    mean_len: float
    # Standard deviation of the per-episode lengths
    len_std: float


def evaluate(
    dqn: DQN,
    env_config: AtariEnvConfig,
    epsilon: float = 0,
    num_episodes: int = 30,
    seed: int | None = 1234,
) -> EvaluationResults:
    """Evaluate DQN agent on Atari environment.

    A fresh single-environment instance is created from ``env_config`` for every
    call and is always closed again, even when an episode raises.

    Note that the model is switched to evaluation mode and is left in that mode;
    callers that continue training must call ``dqn.train()`` themselves.

    Args:
        dqn: Agent to evaluate; its ``get_action`` method selects the actions.
        env_config: Configuration used to create the evaluation environment.
        epsilon: Exploration rate passed to ``dqn.get_action``.
        num_episodes: Number of episodes to run.
        seed: Seed passed to ``env_config.make_env``.

    Returns:
        Mean and standard deviation of the episode scores (sum of rewards) and
        of the episode lengths (number of steps).
    """
    # Set model to evaluation mode
    dqn.eval()

    # Initialize environment
    evaluation_env = env_config.make_env(num_envs=1, seed=seed)

    scores: list[float] = []
    episode_lens: list[int] = []
    try:
        for _ in range(num_episodes):
            # Reset environment
            obs = evaluation_env.reset()
            episode_score = 0.0
            episode_len = 0
            done = False

            # Run episode until the environment reports it as done
            while not done:
                # Choose action
                with torch.no_grad():
                    action = (
                        dqn.get_action(
                            observation=torch.from_numpy(obs).to(dqn.device),
                            epsilon=epsilon,
                        )
                        .cpu()
                        .numpy()
                    )

                # Step environment
                obs, reward, done, _ = evaluation_env.step(action)

                # Update episode totals
                episode_score += reward.item()
                episode_len += 1

            # Tabulate episode results
            scores.append(episode_score)
            episode_lens.append(episode_len)
    finally:
        # Release the environment on every exit path (errors, interrupts, success)
        evaluation_env.close()

    # Return results
    return EvaluationResults(
        mean_score=np.mean(scores).item(),
        score_std=np.std(scores).item(),
        mean_len=np.mean(episode_lens).item(),
        len_std=np.std(episode_lens).item(),
    )

# Training Function

In [7]:
def train_dqn(
    env_config: AtariEnvConfig,
    dqn_config: DQNConfig,
    training_config: TrainingConfig,
    eval_env_config: AtariEnvConfig | None = None,
    trial: optuna.Trial | None = None,
    device: torch.device | None = None,
) -> tuple[DQN, pd.DataFrame]:
    """Initialize and train a DQN model, keeping the best-scoring weights.

    The online network is evaluated before training, every
    ``training_config.evaluation_freq`` steps, and once more after training.
    Whenever an evaluation matches or beats the best mean score so far, the
    weights are checkpointed to a temporary file. If the final evaluation is
    worse than the best one, that checkpoint is loaded before returning.

    The training environment and the checkpoint file are released on every exit
    path, including pruning and unexpected errors.

    Args:
        env_config: Configuration used to create the training environment.
        dqn_config: Configuration of the online and target networks.
        training_config: Training schedule, replay buffer and optimizer settings.
        eval_env_config: Configuration used for evaluation; defaults to
            ``env_config``.
        trial: Optional Optuna trial. When given, every in-training evaluation
            is reported to it and the trial may be pruned.
        device: Device to train on; defaults to ``get_device()``.

    Returns:
        The trained online network and a DataFrame with one row of
        ``EvaluationResults`` per evaluation, in chronological order.

    Raises:
        optuna.exceptions.TrialPruned: If ``trial`` reports that it should be
            pruned after an in-training evaluation.
    """
    # Set device
    if device is None:
        device = get_device()

    # Use training environment configuration for evaluation unless told otherwise
    if eval_env_config is None:
        eval_env_config = env_config

    # The temporary file stores the best model state; the context manager closes
    # it no matter how the function exits.
    with TemporaryFile() as state_dict_fp:

        def save_checkpoint(net: DQN) -> None:
            """Replace the checkpoint file's contents with ``net``'s weights."""
            # Truncate first so no bytes of an earlier checkpoint can survive
            state_dict_fp.seek(0)
            state_dict_fp.truncate()
            torch.save(net.state_dict(), state_dict_fp)
            # Rewind so the file is ready for the next save or for loading
            state_dict_fp.seek(0)

        # Initialize training environment
        training_env = env_config.make_env(num_envs=training_config.num_envs)
        try:
            # Initialize networks
            online_net = DQN(dqn_config).to(device)
            target_net = DQN(dqn_config).to(device).eval()

            # Initialize replay buffer
            replay_buffer = get_replay_buffer(
                env=training_env,
                buffer_size=training_config.replay_buffer_config.get_buffer_size(
                    training_config.total_timesteps
                ),
                device=device,
            )

            # Initialize optimizer
            optimizer = training_config.optimizer_cls(online_net.parameters())

            # Parameter schedules over training
            param_schedule = training_config.get_parameter_schedule()

            # Reset environment and get initial observation
            obs = training_env.reset()

            # Initialize evaluation statistics and checkpoint the initial weights
            eval_stats = [evaluate(online_net, eval_env_config)]
            best_score = eval_stats[0].mean_score
            save_checkpoint(online_net)

            # Run training loop
            for step_i in tqdm(range(training_config.total_timesteps)):
                # Set online network to training mode (evaluate() leaves it in
                # evaluation mode)
                online_net.train()

                # Get action from online network
                with torch.no_grad():
                    action = (
                        online_net.get_action(
                            observation=torch.from_numpy(obs).to(device),
                            epsilon=param_schedule[step_i].exploration_rate,
                        )
                        .cpu()
                        .numpy()
                    )

                # Take step in environment given action
                next_obs, reward, termination, infos = training_env.step(action)

                # Add to replay buffer
                replay_buffer.add(
                    obs=obs,
                    next_obs=copy_n_adj_next_observation(
                        next_obs, termination, infos
                    ),
                    action=action,
                    reward=reward,
                    done=termination,
                    infos=infos,
                )

                # Update observation
                obs = next_obs

                # Train online network every `train_freq` steps, once at least
                # `batch_size` steps have been taken
                if (step_i % training_config.train_freq == 0) and (
                    step_i >= training_config.replay_buffer_config.batch_size
                ):
                    # Sample batch from replay buffer
                    samples = replay_buffer.sample(
                        training_config.replay_buffer_config.batch_size
                    )

                    # Compute TD loss relative to target network's value estimate
                    loss = calculate_loss(
                        online_net,
                        target_net,
                        samples,
                        gamma=training_config.other.gamma,
                        double_q=training_config.other.double_q,
                    )

                    # Update online network
                    set_learning_rate(
                        optimizer, param_schedule[step_i].learning_rate
                    )
                    optimizer.zero_grad()
                    loss.backward()
                    optimizer.step()

                    # Update target network
                    update_target_network(
                        target_net,
                        online_net,
                        target_inertia=training_config.other.target_inertia,
                    )

                # Evaluate.
                # NOTE: this also fires at step 0, directly after the initial
                # evaluation above.
                if step_i % training_config.evaluation_freq == 0:
                    # Evaluate online network
                    stats_i = evaluate(online_net, eval_env_config)
                    eval_stats.append(stats_i)

                    # Save model state when it matches or beats the best so far
                    if stats_i.mean_score >= best_score:
                        best_score = stats_i.mean_score
                        save_checkpoint(online_net)

                    if trial is not None:
                        # Report intermediate evaluation results to Optuna
                        trial.report(stats_i.mean_score, step_i)

                        if trial.should_prune():
                            # The `finally` clause closes the environment
                            raise optuna.exceptions.TrialPruned()
        finally:
            # Close environment on every exit path
            training_env.close()

        # Evaluate final model
        final_stats = evaluate(online_net, eval_env_config)
        eval_stats.append(final_stats)

        if final_stats.mean_score < best_score:
            # Load best model state
            online_net.load_state_dict(torch.load(state_dict_fp))

    # Convert evaluation statistics to DataFrame
    return online_net, pd.DataFrame(eval_stats)

# Objective function

In [None]:
def objective(
    trial: optuna.Trial,
    env_config: AtariEnvConfig,
    eval_env_config: AtariEnvConfig | None,
    observation_shape: tuple[int, int, int],
    num_actions: int,
    total_timesteps: int,
    evaluation_rate: float,
) -> float:
    """Optuna objective: train one DQN with sampled hyperparameters and score it.

    Args:
        trial: Optuna trial that supplies the hyperparameters and receives the
            intermediate evaluation reports.
        env_config: Configuration of the training environment.
        eval_env_config: Configuration of the evaluation environment, or
            ``None`` to let ``train_dqn`` choose.
        observation_shape: Observation shape of the environment.
        num_actions: Number of available actions.
        total_timesteps: Number of training steps.
        evaluation_rate: Evaluation rate forwarded to the training configuration.

    Returns:
        The ``mean_score`` in the last row of the evaluation statistics returned
        by ``train_dqn`` (the last evaluation, not the maximum over the run).
    """
    # Turn freshly sampled hyperparameters into model and training configurations
    dqn_config, training_config = get_configs_from_hyperparams(
        total_timesteps=total_timesteps,
        evaluation_rate=evaluation_rate,
        observation_shape=observation_shape,
        num_actions=num_actions,
        hyperparams=sample_hyperparams(trial),
    )

    # Train; the returned model is not needed, only its evaluation history
    _, eval_stats = train_dqn(
        env_config=env_config,
        dqn_config=dqn_config,
        training_config=training_config,
        eval_env_config=eval_env_config,
        trial=trial,
    )

    # Score the trial by its last recorded evaluation
    mean_scores = eval_stats["mean_score"]
    return mean_scores.iloc[-1]

# Run hyperparameter optimization

In [None]:
# Search budget
NUM_TIMESTEPS = 100_000
EVAL_RATE = 0.2
NUM_TRIALS = 30
NUM_STARTUP_TRIALS = 3

# Pruning warm-up, expressed as 2 * EVAL_RATE of the timesteps of a run
WARMUP_STEPS = int(NUM_TIMESTEPS * 2 * EVAL_RATE)

# Sampler: Gaussian-process sampler, preceded by a few start-up trials
sampler = optuna.samplers.GPSampler(n_startup_trials=NUM_STARTUP_TRIALS)

# Pruner: median rule, inactive during the start-up trials and the warm-up steps
pruner = optuna.pruners.MedianPruner(
    n_startup_trials=NUM_STARTUP_TRIALS,
    n_warmup_steps=WARMUP_STEPS,
)

# Study that maximizes the value returned by the objective
study = optuna.create_study(direction="maximize", sampler=sampler, pruner=pruner)

# Bind everything except the trial, then run the trials one at a time
search_objective = partial(
    objective,
    env_config=env_config,
    eval_env_config=eval_env_config,
    observation_shape=(4, 84, 84),
    num_actions=6,
    total_timesteps=NUM_TIMESTEPS,
    evaluation_rate=EVAL_RATE,
)
study.optimize(func=search_objective, n_trials=NUM_TRIALS, n_jobs=1)

# Persist one row per trial
study_results = study.trials_dataframe()
study_results.to_csv("../models/deep-q-space-invaders/study_results.csv", index=False)

# Train final model

In [8]:
# Hyperparameters for the final run; anything not listed keeps the default from
# get_configs_from_hyperparams
final_hyperparams = {
    "train_params": {
        "exploration_fraction": 0.10,
        "gamma": 0.8,
        "target_inertia": 0.995,
    }
}

# Build the model and training configurations for the final run
dqn_config, training_config = get_configs_from_hyperparams(
    total_timesteps=300_000,
    evaluation_rate=0.1,
    observation_shape=(4, 84, 84),
    num_actions=6,
    hyperparams=final_hyperparams,
)

# Train without an Optuna trial; `results` holds the evaluation history
model, results = train_dqn(
    env_config=env_config,
    dqn_config=dqn_config,
    training_config=training_config,
    eval_env_config=eval_env_config,
    device=get_device(),
)

100%|██████████| 300000/300000 [42:43<00:00, 117.03it/s]  


# Save model

In [13]:
# Sign and save the model to local directory.
# The dotenv values get their own name: binding them to `env_config` would
# overwrite the AtariEnvConfig of the same name and break any re-run of the
# training cells above.
dotenv_config = dotenv_values()

sign_n_save_model(
    model=model.cpu(),
    destination="../models/deep-q-space-invaders",
    keypair=load_keypair(
        "../models/private.pem", dotenv_config["MODEL_KEY_PASSWORD"]
    ),
    overwrite=True,
)

# Load and evaluate model

In [14]:
# Load model
agent: DQN = load_n_verify_model("../models/deep-q-space-invaders")
agent.eval()

# Evaluate the model.
# `num_episodes` is passed explicitly so the count written to the results file
# is always the count that was actually run.
num_episodes = 30
eval_results = evaluate(
    agent,
    eval_env_config,
    num_episodes=num_episodes,
)

# Summary written next to the saved model
eval_stats = {
    "env_id": eval_env_config.env_id,
    "num_episodes": num_episodes,
    "mean_reward": eval_results.mean_score,
    "std_reward": eval_results.score_std,
}

with open(
    "../models/deep-q-space-invaders/eval_results.json", "w", encoding="utf-8"
) as f:
    json.dump(eval_stats, f, indent=2)

# Capture replay

In [None]:
# Build a single evaluation environment with a fixed seed and hand its first
# underlying environment to the recorder
replay_vec_env = eval_env_config.make_env(num_envs=1, seed=0)

# Record a replay of the agent acting with epsilon = 0.0; the video is written
# next to the saved model
rewards = capture_replay(
    env=replay_vec_env.envs[0],
    action_func=partial(agent.get_action, epsilon=0.0),
    video_folder="../models/deep-q-space-invaders",
    max_steps=None,
)

# Summarize the recorded rewards
print(f"Total steps: {len(rewards)}\nTotal reward: {sum(rewards)}")