# About the tutorial
In this tutorial, we will explain how to register classes and objects into the RLHive framework.

# Introduction and Setup

### What is RLHive and how to install it

RLHive is a framework designed to facilitate research in reinforcement learning. It provides the components necessary to run a full RL experiment, for both single-agent and multi-agent environments. It is designed to be readable and easily extensible, so that users can quickly run and experiment with their own ideas.

Installation: to install a specific branch from GitHub using pip, take the **clone URL** of the repository, prefix it with `git+`, and append the **@ symbol** followed by the name of the branch you want to install. For example, to install from the `dev` branch, use:

`pip install git+https://github.com/chandar-lab/RLHive.git@dev`

In [None]:
## used for updating config.yaml files 
!pip install ruamel.yaml
!pip install pyglet

!pip install git+https://github.com/chandar-lab/RLHive.git@dev

In [None]:
!apt-get install x11-utils > /dev/null 2>&1 
!pip install pyglet > /dev/null 2>&1 
!apt-get install -y xvfb python-opengl > /dev/null 2>&1
!pip install gym pyvirtualdisplay > /dev/null 2>&1

In [None]:
## Required imports

import hive
import torch
from hive.agents.dqn import DQNAgent
from hive.runners.utils import load_config
from hive.runners.single_agent_loop import set_up_experiment
from hive.utils.loggers import get_logger
from ruamel import yaml
import os.path
import numpy as np
import matplotlib.pyplot as plt
from IPython import display as ipythondisplay
from hive.utils.utils import Registrable
import os
import sys

### How to install environments

RLHive currently supports the following environments:

* Gym classic control
* Atari
* Minatar (simplified Atari)
* Minigrid (single-agent grid world)
* Marlgrid (multi-agent)
* Pettingzoo (multi-agent)

To install Gym, run `pip install gym==0.21.0`. You can also install the dependencies for the environments that RLHive comes with by running `pip install rlhive[<env_names>]`, where `<env_names>` is a comma-separated list made up of `atari`, `gym_minigrid`, and `pettingzoo`.

Minatar and Marlgrid are also supported, but must be installed separately.

* To install Minatar, run `pip install MinAtar@git+https://github.com/kenjyoung/MinAtar.git@8b39a18a60248ede15ce70142b557f3897c4e1eb`
* To install Marlgrid, run `pip install marlgrid@https://github.com/kandouss/marlgrid/archive/refs/heads/master.zip`

In [None]:
!pip install gym==0.21.0
!pip install rlhive[atari]


# Introduction to Hive Registry

<!-- ### `hive.utils.registry` -->
The registry module `hive.utils.registry` is used to register classes (corresponding to the agent, environment, logger, or runner) in the RLHive Registry. In other words, it allows you to register different types of `Registrable` classes and objects, and generates constructors for those classes in the form of `get_{type_name}`. These constructors allow you to construct objects from dictionary configs. These configs should have two fields: 


1. `name` - name used when registering a class in the registry
2. `**kwargs` - keyword arguments that will be passed to the constructor of the object
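
The relationship between a registered name, a config, and the resulting object can be sketched with a toy registry. This is not RLHive's implementation: `ToyRegistry`, `ToyGrid`, and the assumption that keyword arguments live under a `kwargs` key are hypothetical and used only for illustration.

```python
# Toy stand-in for a registry; all names here are hypothetical, not RLHive's code.
class ToyRegistry:
    def __init__(self):
        # Maps a type name (e.g. "env") to {registered name: constructor}.
        self._constructors = {}

    def register(self, name, constructor, type_name):
        self._constructors.setdefault(type_name, {})[name] = constructor

    def get(self, type_name, config):
        # Look up the constructor by `name`, then pass `kwargs` to it.
        constructor = self._constructors[type_name][config["name"]]
        return constructor(**config.get("kwargs", {}))


class ToyGrid:
    def __init__(self, width=5, height=5):
        self.width = width
        self.height = height


toy_registry = ToyRegistry()
toy_registry.register("ToyGrid", ToyGrid, type_name="env")

config = {"name": "ToyGrid", "kwargs": {"width": 8}}
env = toy_registry.get("env", config)
print(env.width, env.height)  # 8 5
```

Arguments missing from the config (here `height`) fall back to the constructor's defaults.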


These constructors can also build objects recursively: if a config contains the config for another `Registrable` object, that object is created automatically before being passed to the constructor of the original object. The constructors also let you specify or override constructor arguments from the command line, using dot (`.`) notation, and they can handle lists and dictionaries of `Registrable` objects.
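
Recursive construction can be sketched in plain Python. The classes, the `CONSTRUCTORS` table, and the `build` helper below are hypothetical. The sketch detects nested configs by their `name` key, which is a simplification and may differ from how RLHive decides what to build.

```python
# Hypothetical classes standing in for registered objects.
class ToyReplayBuffer:
    def __init__(self, capacity=100):
        self.capacity = capacity


class ToyAgent:
    def __init__(self, replay_buffer, buffers=None, gamma=0.99):
        self.replay_buffer = replay_buffer
        self.buffers = buffers or []
        self.gamma = gamma


CONSTRUCTORS = {"ToyReplayBuffer": ToyReplayBuffer, "ToyAgent": ToyAgent}


def is_object_config(value):
    """Treat a dict whose `name` is registered as a nested object config."""
    return isinstance(value, dict) and value.get("name") in CONSTRUCTORS


def build(value):
    """Recursively build nested configs, including those inside lists and dicts."""
    if is_object_config(value):
        # Build the arguments first, then the object that needs them.
        kwargs = {k: build(v) for k, v in value.get("kwargs", {}).items()}
        return CONSTRUCTORS[value["name"]](**kwargs)
    if isinstance(value, list):
        return [build(item) for item in value]
    if isinstance(value, dict):
        return {k: build(v) for k, v in value.items()}
    return value


agent = build(
    {
        "name": "ToyAgent",
        "kwargs": {
            "gamma": 0.9,
            "replay_buffer": {"name": "ToyReplayBuffer", "kwargs": {"capacity": 10}},
            "buffers": [{"name": "ToyReplayBuffer"}, {"name": "ToyReplayBuffer"}],
        },
    }
)
print(type(agent.replay_buffer).__name__, agent.replay_buffer.capacity)  # ToyReplayBuffer 10
```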

For example, consider the following scenario: your `agent` class has an argument `arg1` annotated as `List[Class1]` (where `Class1` is `Registrable`), and the `Class1` constructor takes an argument `arg2`. The passed YAML config lists two different `Class1` object configs. The constructor will check whether `agent.arg1.0.arg2` and `agent.arg1.1.arg2` have been passed on the command line.
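
To make the dotted keys concrete, here is a minimal sketch of applying such overrides to a nested config. The helper `set_by_dotted_key` and the flattened config layout are hypothetical; they only show how a list index such as `0` or `1` can appear as one segment of the path.

```python
def set_by_dotted_key(config, dotted_key, value):
    """Set a value in a nested dict/list structure using a dot-separated path."""
    *parents, last = dotted_key.split(".")
    node = config
    for part in parents:
        # A numeric segment indexes into a list; anything else is a dict key.
        node = node[int(part)] if isinstance(node, list) else node[part]
    if isinstance(node, list):
        node[int(last)] = value
    else:
        node[last] = value


# Hypothetical config layout, flattened for brevity.
config = {
    "agent": {
        "arg1": [
            {"name": "Class1", "arg2": 1},
            {"name": "Class1", "arg2": 2},
        ]
    }
}

# Hypothetical command-line overrides, already split into key/value pairs.
overrides = {"agent.arg1.0.arg2": 10, "agent.arg1.1.arg2": 20}
for key, value in overrides.items():
    set_by_dotted_key(config, key, value)

print([c["arg2"] for c in config["agent"]["arg1"]])  # [10, 20]
```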

The parameters passed on the command line are parsed according to the type annotation of the corresponding low-level constructor. If the annotation is not one of `int`, `float`, `str`, or `bool`, the string is simply loaded into Python using a YAML loader.
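
The annotation-driven parsing described above might look roughly like the sketch below. `parse_cli_value` and `ToyAgent` are hypothetical, the `bool` handling is an assumption, and `json.loads` is used as a stand-in for the YAML loader only because it ships with Python.

```python
import inspect
import json


def parse_cli_value(constructor, arg_name, raw):
    """Convert a raw command-line string using the constructor's type annotation."""
    annotation = inspect.signature(constructor).parameters[arg_name].annotation
    if annotation is bool:
        # bool("False") would be True, so compare against the text instead.
        return raw.strip().lower() in ("true", "1", "yes")
    if annotation in (int, float, str):
        return annotation(raw)
    # Fallback for any other annotation: stand-in for the YAML loader.
    return json.loads(raw)


class ToyAgent:
    def __init__(self, lr: float, steps: int, name: str, debug: bool, layers: list):
        pass


print(parse_cli_value(ToyAgent, "lr", "0.001"))  # 0.001
print(parse_cli_value(ToyAgent, "debug", "False"))  # False
print(parse_cli_value(ToyAgent, "layers", "[64, 64]"))  # [64, 64]
```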

Each constructor returns the object, as well as a dictionary config with all the parameters used to create the object and any `Registrable` objects created in the process.

### Registering an Environment

Consider registering a custom environment class named `Grid` (which inherits `BaseEnv`) in the RLHive registry. 

In [None]:
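from hive.envs.base import BaseEnv


class Grid(BaseEnv):
    """Skeleton of a custom environment; method bodies are left as stubs."""

    def __init__(self, env_name="Grid", **kwargs):
        pass

    def reset(self):
        pass

    def step(self, action):
        # An environment step needs the action chosen by the agent.
        pass

    def render(self):
        pass

    def close(self):
        pass

    def save(self):
        pass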

In [None]:
from hive.utils.registry import registry
registry.register(name = 'Grid', 
                  constructor = Grid, 
                  type = BaseEnv)

More than one environment can be registered at once using the `register_all` method. Consider registering three environments, `Gridv1`, `Gridv2`, and `Gridv3` (assumed to be defined already as subclasses of `BaseEnv`), in the RLHive registry.

In [None]:
registry.register_all(
    BaseEnv,
    {
        "Gridv1": Gridv1,
        "Gridv2": Gridv2,
        "Gridv3": Gridv3,
    },
)

### Registering an Agent

Consider registering a custom agent class named `LearningAgent` (which inherits from the `Agent` class) in the RLHive registry.

In [None]:
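from hive.agents.agent import Agent


class LearningAgent(Agent):
    """Skeleton of a custom agent; method bodies are left as stubs."""

    def __init__(self):
        pass

    def act(self, observation):
        # Acting requires the current observation from the environment.
        pass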

In [None]:
from hive.utils.registry import registry
registry.register(name = 'LearningAgent', 
                  constructor = LearningAgent, 
                  type = Agent)

More than one agent can be registered at once using the `register_all` method. Consider registering three agents, `LearningAgentV1`, `LearningAgentV2`, and `LearningAgentV3` (assumed to be defined already as subclasses of `Agent`), in the RLHive registry.

In [None]:
registry.register_all(
    Agent,
    {
        "LearningAgentV1": LearningAgentV1,
        "LearningAgentV2": LearningAgentV2,
        "LearningAgentV3": LearningAgentV3,
    },
)

### Registering a Logger

Consider registering a custom logger class named `CustomLogger` (which inherits from the `Logger` class) in the RLHive registry.

In [None]:
from hive.utils.loggers import Logger

class CustomLogger(Logger):
    def __init__(self):
        pass
    def update_step(self, timescale):
        pass

In [None]:
from hive.utils.registry import registry
registry.register(name = 'CustomLogger', 
                  constructor = CustomLogger, 
                  type = Logger)

More than one logger can be registered at once using the `register_all` method. Consider registering three loggers, `CustomLoggerV1`, `CustomLoggerV2`, and `CustomLoggerV3` (assumed to be defined already as subclasses of `Logger`), in the RLHive registry.

In [None]:
registry.register_all(
    Logger,
    {
        "CustomLoggerV1": CustomLoggerV1,
        "CustomLoggerV2": CustomLoggerV2,
        "CustomLoggerV3": CustomLoggerV3,
    },
)

### Registering with a Custom Data Type

#### Registering Initialization Function with Custom Data Type

In this example, a custom initialization function `variance_scaling_` is defined below.


<!-- [Optimizer function, initialization function
wrap funciton/class https://github.com/chandar-lab/RLHive/blob/5ec70776b25c81df1236b8879e6bf7903352d390/hive/agents/qnets/utils.py#L114 ] -->

In [None]:
import math
def variance_scaling_(tensor, scale=1.0, mode="fan_in", distribution="uniform"):
    """Implements the :py:class:`tf.keras.initializers.VarianceScaling`
    initializer in PyTorch.
    Args:
        tensor (torch.Tensor): Tensor to initialize.
        scale (float): Scaling factor (must be positive).
        mode (str): Must be one of `"fan_in"`, `"fan_out"`, and `"fan_avg"`.
        distribution: Random distribution to use, must be one of
            "truncated_normal", "untruncated_normal" and "uniform".
    Returns:
        Initialized tensor.
    """
    fan = calculate_correct_fan(tensor, mode)
    scale /= fan
    if distribution == "truncated_normal":
        stddev = math.sqrt(scale) / 0.87962566103423978
        return torch.nn.init.trunc_normal_(tensor, 0.0, stddev, -2 * stddev, 2 * stddev)
    elif distribution == "untruncated_normal":
        stddev = math.sqrt(scale)
        return torch.nn.init.normal_(tensor, 0.0, stddev)
    elif distribution == "uniform":
        limit = math.sqrt(3.0 * scale)
        return torch.nn.init.uniform_(tensor, -limit, limit)
    else:
        raise ValueError(f"Distribution {distribution} not supported")

def calculate_correct_fan(tensor, mode):
    """Calculate fan of tensor.
    Args:
        tensor (torch.Tensor): Tensor to calculate fan of.
        mode (str): Which type of fan to compute. Must be one of `"fan_in"`,
            `"fan_out"`, and `"fan_avg"`.
    Returns:
        Fan of the tensor based on the mode.
    """
    fan_in, fan_out = torch.nn.init._calculate_fan_in_and_fan_out(tensor)
    if mode == "fan_in":
        return fan_in
    elif mode == "fan_out":
        return fan_out
    elif mode == "fan_avg":
        return (fan_in + fan_out) / 2
    else:
        raise ValueError(f"Fan mode {mode} not supported")
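
As a worked example of the arithmetic in `variance_scaling_`, the sketch below recomputes the uniform limit and the normal standard deviation in plain Python for a 2-D weight of shape `(2, 4)`, such as the weight of a linear layer with 4 inputs and 2 outputs. The helper names are hypothetical, and the fan computation covers only the 2-D case, where fan-in is the second dimension and fan-out is the first.

```python
import math


def fans_for_2d_weight(shape):
    """Return (fan_in, fan_out) for a 2-D weight shaped (out_features, in_features)."""
    out_features, in_features = shape
    return in_features, out_features


def variance_scaling_spread(shape, scale=1.0, mode="fan_in", distribution="uniform"):
    """Return the uniform limit or the normal stddev that would be used."""
    fan_in, fan_out = fans_for_2d_weight(shape)
    fan = {"fan_in": fan_in, "fan_out": fan_out, "fan_avg": (fan_in + fan_out) / 2}[mode]
    scale = scale / fan
    if distribution == "uniform":
        return math.sqrt(3.0 * scale)
    if distribution == "untruncated_normal":
        return math.sqrt(scale)
    if distribution == "truncated_normal":
        return math.sqrt(scale) / 0.87962566103423978
    raise ValueError(f"Distribution {distribution} not supported")


shape = (2, 4)  # fan_in = 4, fan_out = 2
uniform_limit = variance_scaling_spread(shape)
normal_stddev = variance_scaling_spread(shape, distribution="untruncated_normal")
print(round(uniform_limit, 4), normal_stddev)  # 0.866 0.5
```

With `scale=1.0` and `mode="fan_in"`, the scaled variance is `1 / 4`, so the uniform limit is `sqrt(3 / 4)` and the untruncated normal standard deviation is `sqrt(1 / 4)`.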

The cell below demonstrates how to register `variance_scaling_` (and other standard initialization functions) with a custom data type `InitializationFn`.

In [None]:
class InitializationFn(Registrable):
    """A wrapper for callables that produce initialization functions.
    These wrapped callables can be partially initialized through configuration
    files or command line arguments.
    """

    @classmethod
    def type_name(cls):
        """
        Returns:
            "init_fn"
        """
        return "init_fn"


registry.register_all(
    InitializationFn,
    {
        "uniform": torch.nn.init.uniform_,
        "normal": torch.nn.init.normal_,
        "constant": torch.nn.init.constant_,
        "ones": torch.nn.init.ones_,
        "zeros": torch.nn.init.zeros_,
        "eye": torch.nn.init.eye_,
        "dirac": torch.nn.init.dirac_,
        "xavier_uniform": torch.nn.init.xavier_uniform_,
        "xavier_normal": torch.nn.init.xavier_normal_,
        "kaiming_uniform": torch.nn.init.kaiming_uniform_,
        "kaiming_normal": torch.nn.init.kaiming_normal_,
        "orthogonal": torch.nn.init.orthogonal_,
        "sparse": torch.nn.init.sparse_,
        "variance_scaling": variance_scaling_,
    },
)

In [None]:
get_init_fn = getattr(registry, f"get_{InitializationFn.type_name()}")
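
The docstring of `InitializationFn` mentions callables that are partially initialized from configuration. One way to picture this, without claiming it is how RLHive does it, is with `functools.partial`. The toy initializer, the `TOY_INIT_FNS` table, and `toy_get_init_fn` below are all hypothetical.

```python
import functools


def toy_variance_scaling(values, scale=1.0):
    """Stand-in initializer: returns a scaled copy of a list of floats."""
    return [scale * v for v in values]


# Hypothetical table mapping registered names to callables.
TOY_INIT_FNS = {"toy_variance_scaling": toy_variance_scaling}


def toy_get_init_fn(config):
    """Resolve a config into a callable with its keyword arguments pre-filled."""
    fn = TOY_INIT_FNS[config["name"]]
    return functools.partial(fn, **config.get("kwargs", {}))


init_fn = toy_get_init_fn({"name": "toy_variance_scaling", "kwargs": {"scale": 2.0}})
print(init_fn([1.0, 2.0]))  # [2.0, 4.0]
```

The config fixes `scale` ahead of time, and the remaining argument (the values to initialize) is supplied later by whoever calls `init_fn`.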

#### Registering Optimizer Function with Custom Data Type

Similar to the previous example, we can also register optimizer functions with a custom data type. 


In [None]:
class OptimizationFn(Registrable):
    """A wrapper for callables for optimization functions.
    These wrapped callables can be partially initialized through configuration
    files or command line arguments.
    """

    @classmethod
    def type_name(cls):
        """
        Returns:
            "opt_fn"
        """
        return "opt_fn"


registry.register_all(
    OptimizationFn,
    {
        "Adam": torch.optim.Adam,
        "SGD": torch.optim.SGD,
        "RMSprop": torch.optim.RMSprop,
        "Adagrad": torch.optim.Adagrad,
    },
)

In [None]:
get_optimizer_fn = getattr(registry, f"get_{OptimizationFn.type_name()}")