In [None]:
#@title mount your Google Drive
#@markdown Your work will be stored in a Google Drive folder called `rl_class` by default to prevent Colab instance timeouts from deleting your edits.

import os
from google.colab import drive
# Mount Google Drive at /content/gdrive; DRIVE_PATH in a later cell is built on this mount point.
drive.mount('/content/gdrive')

In [None]:
from google.colab import drive
# Use the same mount point that DRIVE_PATH is built on ('/content/gdrive'), so
# running this cell never creates a second, unused mount at a different path.
drive.mount('/content/gdrive')

In [None]:
#@title set up mount symlink

# Shell-escaped form of the Drive folder (backslash before the space), kept for
# use in `!`/`%` commands. A raw string avoids the invalid "\ " escape sequence.
DRIVE_PATH = r'/content/gdrive/My\ Drive/rl_class'
# Unescaped form for Python filesystem calls.
DRIVE_PYTHON_PATH = DRIVE_PATH.replace('\\', '')
if not os.path.exists(DRIVE_PYTHON_PATH):
  # os.mkdir (not makedirs) so a missing parent, i.e. an unmounted Drive,
  # fails loudly instead of silently creating a local directory tree.
  os.mkdir(DRIVE_PYTHON_PATH)

## the space in `My Drive` causes some issues,
## make a symlink to avoid this
SYM_PATH = '/content/rl_class'
# lexists also detects a dangling symlink left over from an earlier session,
# which would otherwise make link creation fail with "File exists".
if not os.path.lexists(SYM_PATH):
  os.symlink(DRIVE_PYTHON_PATH, SYM_PATH)

In [None]:
# print('NOTE: Intentionally crashing session to use the newly installed library.\n')

# !pip uninstall -y pyarrow
# !pip install ray[rllib]
# !pip install bs4


In [None]:
#@title apt install requirements

#@markdown Run each section with Shift+Enter

#@markdown Double-click on section headers to show code.

!apt update
!apt install -y --no-install-recommends \
        build-essential \
        curl \
        git \
        gnupg2 \
        make \
        cmake \
        ffmpeg \
        swig \
        libz-dev \
        unzip \
        zlib1g-dev \
        libglfw3 \
        libglfw3-dev \
        libxrandr2 \
        libxinerama-dev \
        libxi6 \
        libxcursor-dev \
        libgl1-mesa-dev \
        libgl1-mesa-glx \
        libglew-dev \
        libosmesa6-dev \
        lsb-release \
        ack-grep \
        patchelf \
        wget \
        xpra \
        xserver-xorg-dev \
        xvfb \
        python-opengl \
        ffmpeg

clone homework repo

In [None]:
%cd $SYM_PATH
!git clone https://github.com/heng2j/multigrid.git
%git checkout colab_training_notebook
%cd multigrid/
%pip install -r requirements_colab.txt
%pip install -e .

In [None]:
!pip install gym tensorboard moviepy torch opencv-python swig box2d-py ray[rllib] scikit-image pygame numba Gymnasium black PyYAML

In [None]:
from platform import python_version

# Show which Python interpreter this runtime is using.
runtime_python = python_version()
print(runtime_python)

In [None]:
from __future__ import annotations

import argparse
import json
import os
import ray

from multigrid.envs import *
from multigrid.rllib.models import TFModel, TorchModel, TorchLSTMModel
from pathlib import Path
from pprint import pprint
from ray import tune
from ray.rllib.algorithms import AlgorithmConfig
from ray.rllib.utils.framework import try_import_tf, try_import_torch
from ray.rllib.utils.from_config import NotProvided
from ray.tune.registry import get_trainable_cls


In [None]:


def get_checkpoint_dir(search_dir: Path | str | None) -> Path | None:
    """
    Recursively search for checkpoints within the given directory.

    If more than one is found, returns the most recently modified checkpoint directory.

    Parameters
    ----------
    search_dir : Path or str or None
        The directory to search for checkpoints within. A falsy value
        (``None`` or ``''``) skips the search.

    Returns
    -------
    Path or None
        Parent directory of the most recently modified ``*.is_checkpoint``
        marker file, or ``None`` when no marker file is found.
    """
    if not search_dir:
        return None

    # ``Path.glob`` returns a lazy generator, which is always truthy. Materialize
    # it so that an empty search yields ``None`` instead of an IndexError.
    markers = list(Path(search_dir).expanduser().glob('**/*.is_checkpoint'))
    if not markers:
        return None

    return max(markers, key=os.path.getmtime).parent

def can_use_gpu() -> bool:
    """
    Return whether or not GPU training is available.

    TensorFlow is probed first, then PyTorch; the result is ``True`` as soon as
    either framework reports a GPU. A framework that cannot be imported or that
    fails while being probed is treated as reporting no GPU.
    """
    try:
        _, tf, _ = try_import_tf()
        # Only a positive answer is final: TensorFlow seeing no GPU must not
        # prevent PyTorch from being checked.
        if tf.test.is_gpu_available():
            return True
    except Exception:
        # Best-effort probe (e.g. the framework is not installed).
        pass

    try:
        torch, _ = try_import_torch()
        return bool(torch.cuda.is_available())
    except Exception:
        # Best-effort probe, same as above.
        pass

    return False

def policy_mapping_fn(agent_id: int, *args, **kwargs) -> str:
    """
    Translate an environment agent ID into the ID of its RLlib policy.

    Any extra positional or keyword arguments are accepted and ignored.
    """
    policy_id = 'policy_{}'.format(agent_id)
    return policy_id

def model_config(
    framework: str = 'torch',
    lstm: bool = False,
    custom_model_config: dict | None = None):
    """
    Return a model configuration dictionary for RLlib.

    Parameters
    ----------
    framework : str
        ``'torch'`` selects the Torch models; any other value selects the
        TensorFlow model.
    lstm : bool
        Whether to use the LSTM variant of the model (Torch only).
    custom_model_config : dict, optional
        Passed through under the ``'custom_model_config'`` key. When omitted,
        a fresh empty dict is used for each call.

    Raises
    ------
    NotImplementedError
        If ``lstm`` is requested for a non-Torch framework.
    """
    # Build a new dict per call: the value is handed back to the caller, so a
    # shared mutable default would leak modifications between calls.
    if custom_model_config is None:
        custom_model_config = {}

    if framework == 'torch':
        model = TorchLSTMModel if lstm else TorchModel
    elif lstm:
        raise NotImplementedError(
            f'LSTM models are only implemented for the torch framework, got {framework!r}')
    else:
        model = TFModel

    return {
        'custom_model': model,
        'custom_model_config': custom_model_config,
        'conv_filters': [
            [16, [3, 3], 1],
            [32, [3, 3], 1],
            [64, [3, 3], 1],
        ],
        'fcnet_hiddens': [64, 64],
        'post_fcnet_hiddens': [],
        'lstm_cell_size': 256,
        'max_seq_len': 20,
    }

def algorithm_config(
    algo: str = 'PPO',
    env: str = 'MultiGrid-Empty-8x8-v0',
    env_config: dict | None = None,
    num_agents: int = 2,
    framework: str = 'torch',
    lstm: bool = False,
    num_workers: int = 0,
    num_gpus: int = 0,
    lr: float | None = None,
    **kwargs) -> AlgorithmConfig:
    """
    Return the RL algorithm configuration.

    Parameters
    ----------
    algo : str
        Name of the algorithm, resolved with ``get_trainable_cls``.
    env : str
        Environment ID.
    env_config : dict, optional
        Environment settings; ``'agents'`` is always set to ``num_agents``.
        The dict passed in is not modified.
    num_agents : int
        Number of agents; one policy ``policy_<i>`` is declared per agent.
    framework : str
        Deep learning framework passed to the config and to ``model_config``.
    lstm : bool
        Whether to request the LSTM model from ``model_config``.
    num_workers : int
        Number of rollout workers.
    num_gpus : int
        Number of GPUs to request; forced to 0 when ``can_use_gpu()`` is false.
    lr : float, optional
        Learning rate. ``None`` keeps the algorithm's default.
    **kwargs
        Accepted and ignored, so a full argument namespace can be splatted in.
    """
    env_config = {**(env_config or {}), 'agents': num_agents}
    # Compare against None explicitly so that an explicit `lr=0.0` is honoured
    # rather than being treated as "not provided".
    learning_rate = NotProvided if lr is None else lr
    return (
        get_trainable_cls(algo)
        .get_default_config()
        .environment(env=env, env_config=env_config)
        .framework(framework)
        .rollouts(num_rollout_workers=num_workers)
        .resources(num_gpus=num_gpus if can_use_gpu() else 0)
        .multi_agent(
            policies={f'policy_{i}' for i in range(num_agents)},
            policy_mapping_fn=policy_mapping_fn,
        )
        .training(
            model=model_config(framework=framework, lstm=lstm),
            lr=learning_rate,
        )
    )

def train(
    algo: str,
    config: AlgorithmConfig,
    stop_conditions: dict,
    save_dir: str,
    load_dir: str | None = None):
    """
    Train an RLlib algorithm.

    Parameters
    ----------
    algo : str
        Name of the algorithm to run with ``tune.run``.
    config : AlgorithmConfig
        Algorithm configuration; ``config.num_rollout_workers + 1`` CPUs are
        requested from Ray.
    stop_conditions : dict
        Stopping criteria passed to ``tune.run`` as ``stop``.
    save_dir : str
        Directory passed to ``tune.run`` as ``local_dir``.
    load_dir : str, optional
        Directory searched with ``get_checkpoint_dir`` for a checkpoint to
        restore from.
    """
    ray.init(num_cpus=(config.num_rollout_workers + 1))
    try:
        tune.run(
            algo,
            stop=stop_conditions,
            config=config,
            local_dir=save_dir,
            verbose=1,
            restore=get_checkpoint_dir(load_dir),
            checkpoint_freq=20,
            checkpoint_at_end=True,
        )
    finally:
        # Always release Ray, even if training fails or is interrupted, so the
        # next call can run `ray.init` again.
        ray.shutdown()



In [None]:
from argparse import Namespace

# Run settings, bundled in a Namespace so they read like parsed CLI arguments.
my_dict = dict(
    algo="PPO",
    framework="torch",
    env="MultiGrid-CompetativeRedBlueDoor-v0",
    env_config={},
    num_agents=1,
    num_workers=10,
    num_gpus=1,
    save_dir="./ray_results/",
    num_timesteps=1e7,
)
args = Namespace(**my_dict)

In [None]:
# parser


In [None]:

# Stop training once this many environment timesteps have been collected.
stop_conditions = dict(timesteps_total=args.num_timesteps)
# Every run setting is forwarded; algorithm_config accepts extra keys via **kwargs.
config = algorithm_config(**vars(args))


In [None]:
# Take the algorithm from the same settings that built `config`, so the two
# cannot drift apart when `args.algo` is changed.
algo = args.algo

train(algo, config, stop_conditions, args.save_dir, None)


In [None]:
# Make sure Ray is stopped (e.g. after interrupting the training cell) so that
# `ray.init` can be called again on the next run.
ray.shutdown()

In [None]:
# Only need to train for 12 mins for 40 iterations with 10 workers and 1 GPU

TODO:
1. Show TensorBoard plots
2. Visualize agent performance
3. Save a GIF or video
4. Submit plots and checkpoint
