## Imports and Requirements

In [1]:
# install necessary requirements
%pip install -r requirements.txt

Note: you may need to restart the kernel to use updated packages.


## Dataset

In [3]:
# get BRIDGE dataset setup after download
import os

path = os.path.join(os.getcwd(), "../data/bridge_dataset.zip")
!unzip $path

Archive:  /Users/laurenc/Documents/GitHub/282_expansion/scripts/../data/bridge_dataset.zip
  inflating: bridge_dataset/README.md  
  inflating: bridge_dataset/summary.csv  
   creating: bridge_dataset/mdps_with_reward_shaping/
   creating: bridge_dataset/mdps_with_reward_shaping/MiniGrid-KeyCorridorS3R3-PickupShaped-v0/
  inflating: bridge_dataset/mdps_with_reward_shaping/MiniGrid-KeyCorridorS3R3-PickupShaped-v0/gorp_3.json  
  inflating: bridge_dataset/mdps_with_reward_shaping/MiniGrid-KeyCorridorS3R3-PickupShaped-v0/gorp_5.json  
  inflating: bridge_dataset/mdps_with_reward_shaping/MiniGrid-KeyCorridorS3R3-PickupShaped-v0/gorp_2.json  
  inflating: bridge_dataset/mdps_with_reward_shaping/MiniGrid-KeyCorridorS3R3-PickupShaped-v0/gorp_4.json  
  inflating: bridge_dataset/mdps_with_reward_shaping/MiniGrid-KeyCorridorS3R3-PickupShaped-v0/consolidated_gorp_bounds.json  
  inflating: bridge_dataset/mdps_with_reward_shaping/MiniGrid-KeyCorridorS3R3-PickupShaped-v0/consolidated.npz  
  infla

## Making BRIDGE Environments

In [2]:
# imports
import gym

In [None]:
# making environment from the BRIDGE dataset
def make_environment(type, version_num, rom="pong", horizon="50", frameskip="30", game="coinrun", difficulty="easy", level="10", env="Empty-5x5"):
    allowed_types = ["Atari","Procgen","Minigrid"]
    if type == "Atari":
        return gym.make(f"BRIDGE/{rom}_{horizon}_fs{frameskip}-v{version_num}")
    elif type == "Procgen":
        return gym.make(f"BRIDGE/{game}_{difficulty}_l{level}_{horizon}_fs{frameskip}-v{version_num}")
    elif type == "Minigrid":
        return gym.make(f"BRIDGE/MiniGrid_{env}-v{version_num}")
    else:
        raise(ValueError("Input 'type' must be one of: {}".format(allowed_types)))

## Deep RL Training

In [None]:
# imports
import multiprocessing
import os
from datetime import datetime
from logging import Logger
from typing import Callable, Dict, Optional, Any, cast

import cloudpickle
import numpy as np
import ray
import torch
from ray.rllib.algorithms import Algorithm
from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID
from ray.tune.logger import UnifiedLogger
from ray.tune.registry import ENV_CREATOR, _global_registry, get_trainable_cls
from ray.rllib.utils.typing import PolicyID
from ray.rllib.policy.policy import Policy
from ray.rllib.evaluation.worker_set import WorkerSet
from sacred import SETTINGS as sacred_settings
from sacred import Experiment

ex = Experiment("train")
sacred_settings.CONFIG.READ_ONLY_CONFIG = False

torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True

In [None]:
# utils to use
def build_logger_creator(log_dir: str, experiment_name: str):
    experiment_dir = os.path.join(
        log_dir,
        experiment_name,
        datetime.now().strftime("%Y-%m-%d_%H-%M-%S"),
    )

    def custom_logger_creator(config):
        """
        Creates a Unified logger that stores results in
        <log_dir>/<experiment_name>_<timestamp>
        """

        if not os.path.exists(experiment_dir):
            os.makedirs(experiment_dir, exist_ok=True)
        return UnifiedLogger(config, experiment_dir)

    return custom_logger_creator


def load_policies_from_checkpoint(
    checkpoint_fname: str,
    trainer: Algorithm,
    policy_map: Callable[[PolicyID], PolicyID] = lambda policy_id: policy_id,
):
    """
    Load policy model weights from a checkpoint and copy them into the given
    trainer.
    """

    with open(checkpoint_fname, "rb") as checkpoint_file:
        checkpoint_data = cloudpickle.load(checkpoint_file)
    policy_states: Dict[str, Any] = cloudpickle.loads(checkpoint_data["worker"])[
        "state"
    ]

    policy_weights = {
        policy_map(policy_id): policy_state["weights"]
        for policy_id, policy_state in policy_states.items()
    }

    def copy_policy_weights(policy: Policy, policy_id: PolicyID):
        if policy_id in policy_weights:
            policy.set_weights(policy_weights[policy_id])

    workers: WorkerSet = cast(Any, trainer).workers
    workers.foreach_policy(copy_policy_weights)

In [None]:
# train function
def train(
    config,
    log_dir,
    experiment_name,
    run,
    num_training_iters,
    stop_on_timesteps: Optional[int],
    stop_on_eval_reward: Optional[float],
    stop_on_kl: Optional[float],
    save_freq,
    checkpoint_path: Optional[str],
    checkpoint_to_load_policies: Optional[str],
    _log: Logger,
):
    multiprocessing.set_start_method("spawn")
    ray.init(
        ignore_reinit_error=True,
        include_dashboard=False,
    )

    AlgorithmClass = get_trainable_cls(run)
    trainer: Algorithm = AlgorithmClass(
        config,
        logger_creator=build_logger_creator(
            log_dir,
            experiment_name,
        ),
    )

    if checkpoint_to_load_policies is not None:
        _log.info(f"Initializing policies from {checkpoint_to_load_policies}")
        load_policies_from_checkpoint(checkpoint_to_load_policies, trainer)

    if checkpoint_path is not None:
        _log.info(f"Restoring checkpoint at {checkpoint_path}")
        trainer.restore(checkpoint_path)

    num_iters_below_kl = 0

    result = None
    for train_iter in range(num_training_iters):
        _log.info(f"Starting training iteration {trainer.iteration}")
        result = trainer.train()

        if trainer.iteration % save_freq == 0:
            checkpoint = trainer.save()
            _log.info(f"Saved checkpoint to {checkpoint}")

        if stop_on_eval_reward is not None:
            if (
                result["evaluation"]["episode_reward_mean"]
                >= stop_on_eval_reward - 1e-4
            ):
                break
        if stop_on_kl is not None:
            kl = result["info"]["learner"][DEFAULT_POLICY_ID]["learner_stats"].get(
                "kl", np.inf
            )
            if kl < stop_on_kl:
                num_iters_below_kl += 1
            else:
                num_iters_below_kl = 0
            if num_iters_below_kl >= 10:
                break
        if stop_on_timesteps is not None:
            if result["timesteps_total"] >= stop_on_timesteps:
                break

        episode_lengths = result["sampler_results"]["hist_stats"]["episode_lengths"]
        if len(episode_lengths) > 0:
            max_episode_len = max(episode_lengths)
            if run == "GORP" and trainer.iteration > max_episode_len:
                break

    checkpoint = trainer.save()
    _log.info(f"Saved final checkpoint to {checkpoint}")

    # Symlink final checkpoint to checkpoint_final
    os.symlink(
        os.path.basename(checkpoint),
        os.path.join(os.path.dirname(checkpoint), "checkpoint_final"),
    )

    return result

## Get Tabular MDP Representations -> Get the consolidated.npz file

In [None]:
# imports
from dataclasses import dataclass, field
from typing import Union

In [None]:
# utils
@dataclass
class MDPConfig:
    horizon: int
    log_dir: Union[str, None] = None
    done_on_reward: bool = False
    no_done_reward: float = 0.0
    noops_after_horizon: int = 0
    frameskip: int = 5
    num_workers: int = 1

@dataclass
class AtariMDPConfig(MDPConfig):
    rom_file: str
    done_on_life_lost: bool

@dataclass
class MiniGridMDPConfig(MDPConfig):
    env_name: str

@dataclass
class ProcgenMDPConfig(MDPConfig):
    env_name: str
    level: int = 0
    distribution_mode: str

@dataclass
class AtariEnv:
    ale: ALE
    lives: int
    config: AtariMDPConfig
    life_lost: bool

    def __post_init__(self):
        # Set ALE configurations
        self.ale.set_logger_mode('error')
        self.ale.set_int('random_seed', 0)
        self.ale.set_int('system_random_seed', 4753849)
        self.ale.set_float('repeat_action_probability', 0)

        # Load ROM and reset the game
        self.ale.load_rom(self.config.rom_file)
        self.ale.reset_game()

        # Retrieve initial number of lives
        self.lives = self.ale.lives()

# Function to create an instance of AtariEnv
def get_env(config):
    return AtariEnv(ale=ALE(), lives=0, config=config, life_lost=False)

In [None]:
def setup_tabular_configs(type, type_args):
    allowed_types = ["Atari","Procgen","Minigrid"]
    try:
        if type == "Atari":
            return AtariMDPConfig(
                rom_file = type_args["rom"],
                horizon = type_args["horizon"],
                done_on_reward = type_args["done_on_reward"],
                done_on_life_lost = type_args["done_on_life_lost"],
                no_done_reward = type_args["no_done_reward"],
                noops_after_horizon = type_args["noops_after_horizon"],
                frameskip = type_args["frameskip"],
                log_dir = type_args["out"])
        elif type == "Procgen":
            return ProcgenMDPConfig(
                env_name = type_args["env_name"],
                distribution_mode = type_args["distribution_mode"],
                level = type_args["level"],
                horizon = type_args["horizon"],
                done_on_reward = type_args["done_on_reward"],
                no_done_reward = type_args["no_done_reward"],
                noops_after_horizon = type_args["noops_after_horizon"],
                frameskip = type_args["frameskip"],
                log_dir = type_args["out"])
        elif type == "Minigrid":
            return MiniGridMDPConfig(
                env_name = type_args["env_name"],
                horizon = type_args["horizon"],
                done_on_reward = type_args["done_on_reward"],
                no_done_reward = type_args["no_done_reward"],
                noops_after_horizon = type_args["noops_after_horizon"],
                frameskip = type_args["frameskip"],
                log_dir = type_args["out"])
        else:
            raise(ValueError("Input 'type' must be one of: {}".format(allowed_types)))
    except:
        raise(ValueError("Input 'type_args' dictionary does not contain all necessary values."))


# TODO
def construct_mdp(mdp_config):
    # get env by input config type
    if isinstance(mdp_config, AtariMDPConfig):
        env = AtariEnv(ale=ALE(), lives=0, config=mdp_config, life_lost=False)
    elif isinstance(mdp_config, MiniGridMDPConfig):
        pass
    elif isinstance(mdp_config, ProcgenMDPConfig):
        pass
    else:
        raise ValueError("Input 'mdp_config' type is not one of the handled types.")
    pass

def get_tabular_reps(type, type_args):
    # setup configuration struct based on inputs
    config = setup_tabular_configs(type, type_args)
    # construct the MDP
    mdp = construct_mdp(config)
    # get the npz file

    pass

## Get complexity bounds (EPW and EH)

## Instance

In [None]:
# setup env of interest spec (constants)
