# Introduction to Hive Registry
In this tutorial, we explain how to register classes and objects in the RLHive registry.

The registry module, `hive.utils.registry`, is used to register classes (such as agents, environments, loggers, and runners) in the RLHive registry. More generally, it lets you register any type of `Registrable` class or object, and it generates a constructor for each type in the form `get_{type_name}`. These constructors build objects from dictionary configs, which should have two fields:


1. `name` - the name used when registering the class in the registry
2. `kwargs` - keyword arguments that will be passed to the constructor of the object
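
To make the config format concrete, here is a minimal, hypothetical sketch of a registry that stores constructors by name and generates a `get_{type_name}` method the first time a type is registered. `ToyRegistry`, `ToyRegistrable`, `ToyEnv`, and `ToyGrid` are illustrative names, not part of RLHive, and the real `hive.utils.registry` does more than this (recursion, command-line overrides, annotation-based parsing).

```python
from typing import Any, Callable, Dict


class ToyRegistrable:
    """Hypothetical stand-in for Registrable: each base type names itself."""

    @classmethod
    def type_name(cls) -> str:
        raise NotImplementedError


class ToyRegistry:
    """A minimal, hypothetical registry; NOT RLHive's implementation."""

    def __init__(self) -> None:
        self._constructors: Dict[str, Dict[str, Callable[..., Any]]] = {}

    def register(self, name: str, constructor: Callable[..., Any], type: type) -> None:
        type_name = type.type_name()
        if type_name not in self._constructors:
            self._constructors[type_name] = {}
            # Generate a get_{type_name} constructor the first time a type is seen.
            setattr(self, f"get_{type_name}", self._make_getter(type_name))
        self._constructors[type_name][name] = constructor

    def _make_getter(self, type_name: str) -> Callable[[Dict[str, Any]], Any]:
        def getter(config: Dict[str, Any]) -> Any:
            # Look up the constructor by `name` and pass `kwargs` through.
            constructor = self._constructors[type_name][config["name"]]
            return constructor(**config.get("kwargs", {}))

        return getter


class ToyEnv(ToyRegistrable):
    @classmethod
    def type_name(cls) -> str:
        return "env"


class ToyGrid(ToyEnv):
    def __init__(self, env_name: str = "Grid", size: int = 5) -> None:
        self.env_name = env_name
        self.size = size


toy_registry = ToyRegistry()
toy_registry.register(name="Grid", constructor=ToyGrid, type=ToyEnv)
toy_env = toy_registry.get_env({"name": "Grid", "kwargs": {"size": 8}})
print(toy_env.env_name, toy_env.size)  # Grid 8
```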


These constructors can also build objects recursively: if a config contains the config for another `Registrable` object, that object is created first and then passed to the constructor of the original object. The constructors also let you specify or override constructor arguments directly from the command line, using dot (`.`) notation, and they can handle lists and dictionaries of `Registrable` objects.
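
The recursive behaviour can be sketched as follows. This is a hypothetical illustration, not RLHive's code: `build`, `CONSTRUCTORS`, and `MyAgent` are made-up names, and, as a simplifying assumption, any dict with exactly the keys `name` and `kwargs` is treated as an object config, whereas the registry described here relies on type annotations.

```python
from typing import Any, Callable, Dict

# Hypothetical name -> constructor table standing in for the registry.
CONSTRUCTORS: Dict[str, Callable[..., Any]] = {}


def build(value: Any) -> Any:
    """Recursively construct objects from {"name": ..., "kwargs": ...} configs."""
    if isinstance(value, dict) and set(value) == {"name", "kwargs"}:
        # Build the children first, then pass them to this object's constructor.
        kwargs = {key: build(arg) for key, arg in value["kwargs"].items()}
        return CONSTRUCTORS[value["name"]](**kwargs)
    if isinstance(value, list):
        return [build(item) for item in value]
    if isinstance(value, dict):
        return {key: build(item) for key, item in value.items()}
    return value


class Class1:
    def __init__(self, arg2: int) -> None:
        self.arg2 = arg2


class MyAgent:
    def __init__(self, arg1: list) -> None:
        self.arg1 = arg1


CONSTRUCTORS.update({"Class1": Class1, "MyAgent": MyAgent})

agent = build({
    "name": "MyAgent",
    "kwargs": {"arg1": [
        {"name": "Class1", "kwargs": {"arg2": 1}},
        {"name": "Class1", "kwargs": {"arg2": 2}},
    ]},
})
print([c.arg2 for c in agent.arg1])  # [1, 2]
```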

For example, consider the following scenario: your `agent` class has an argument `arg1` annotated as `List[Class1]` (where `Class1` is `Registrable`), and the `Class1` constructor takes an argument `arg2`. The YAML config lists two different `Class1` object configs. The constructor will then check whether `agent.arg1.0.arg2` and `agent.arg1.1.arg2` have been passed on the command line.
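
One way to picture how a dotted key such as `agent.arg1.0.arg2` maps onto a nested config is the hypothetical sketch below. `apply_override` and `is_object_config` are illustrative names; the assumptions are that object configs are entered through their `kwargs` dict and that integer segments index into lists.

```python
from typing import Any, Dict


def is_object_config(node: Any) -> bool:
    """Simplifying assumption: an object config is a dict with exactly name/kwargs."""
    return isinstance(node, dict) and set(node) == {"name", "kwargs"}


def apply_override(config: Dict[str, Any], dotted_key: str, value: Any) -> None:
    """Set a value in a nested config using dot notation."""
    *path, last = dotted_key.split(".")
    node: Any = config
    for segment in path:
        # Integer segments index into lists; other segments are dict keys.
        node = node[int(segment)] if isinstance(node, list) else node[segment]
        # Step through an object config into its constructor arguments.
        if is_object_config(node):
            node = node["kwargs"]
    if isinstance(node, list):
        node[int(last)] = value
    else:
        node[last] = value


config = {"agent": {"name": "MyAgent", "kwargs": {"arg1": [
    {"name": "Class1", "kwargs": {"arg2": 1}},
    {"name": "Class1", "kwargs": {"arg2": 2}},
]}}}
apply_override(config, "agent.arg1.0.arg2", 10)
apply_override(config, "agent.arg1.1.arg2", 20)
print(config["agent"]["kwargs"]["arg1"][1]["kwargs"]["arg2"])  # 20
```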

Parameters passed on the command line are parsed according to the type annotation of the corresponding low-level constructor. If the annotation is not one of `int`, `float`, `str`, or `bool`, the string is simply loaded into Python using a YAML loader.
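
A minimal sketch of annotation-driven parsing is shown below. It is hypothetical: `parse_override`, `parse_for`, and `make_grid` are illustrative names, the boolean rule is an assumption, and `ast.literal_eval` is swapped in for the YAML loader so the example needs only the standard library.

```python
import ast
import inspect
from typing import Any, Callable, Dict


def parse_override(raw: str, annotation: Any) -> Any:
    """Convert a command-line string according to a type annotation."""
    if annotation is bool:
        # bool("False") would be True, so compare the text instead.
        return raw.strip().lower() in ("true", "1", "yes")
    if annotation in (int, float, str):
        return annotation(raw)
    # Stand-in for the YAML loader used for all other annotations.
    return ast.literal_eval(raw)


def parse_for(constructor: Callable[..., Any], overrides: Dict[str, str]) -> Dict[str, Any]:
    """Parse each override using the matching parameter annotation of `constructor`."""
    params = inspect.signature(constructor).parameters
    return {name: parse_override(raw, params[name].annotation) for name, raw in overrides.items()}


def make_grid(size: int, reward_scale: float, render: bool, goal: tuple):
    return size, reward_scale, render, goal


parsed = parse_for(make_grid, {"size": "8", "reward_scale": "0.5", "render": "false", "goal": "(3, 4)"})
print(parsed)  # {'size': 8, 'reward_scale': 0.5, 'render': False, 'goal': (3, 4)}
```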

Each constructor returns the object, as well as a dictionary config containing all the parameters used to create the object and any `Registrable` objects created in the process.
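
The "object plus full config" return value can be sketched with `inspect.Signature.bind` and `apply_defaults`, so that defaults not given in the config are recorded too. This is a hypothetical illustration (`construct_with_config` and `ToyGrid` are made-up names) and, for brevity, it does not record nested `Registrable` objects.

```python
import inspect
from typing import Any, Callable, Dict, Tuple


def construct_with_config(
    name: str, constructor: Callable[..., Any], kwargs: Dict[str, Any]
) -> Tuple[Any, Dict[str, Any]]:
    """Return the object and a config listing every parameter it was built with."""
    bound = inspect.signature(constructor).bind(**kwargs)
    # Fill in defaults that the config did not specify.
    bound.apply_defaults()
    obj = constructor(*bound.args, **bound.kwargs)
    return obj, {"name": name, "kwargs": dict(bound.arguments)}


class ToyGrid:
    def __init__(self, env_name: str = "Grid", size: int = 5) -> None:
        self.env_name = env_name
        self.size = size


grid, full_config = construct_with_config("Grid", ToyGrid, {"size": 8})
print(full_config)  # {'name': 'Grid', 'kwargs': {'env_name': 'Grid', 'size': 8}}
```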

```python
%%capture
# ruamel.yaml is used for updating config.yaml files
!pip install ruamel.yaml
!pip install pyglet

!pip install git+https://github.com/chandar-lab/RLHive.git@dev
```

```python
%%capture
!apt-get install x11-utils > /dev/null 2>&1
!pip install pyglet > /dev/null 2>&1
!apt-get install -y xvfb python-opengl > /dev/null 2>&1
!pip install gym pyvirtualdisplay > /dev/null 2>&1
```

```python
# Required imports
import hive
import torch
from hive.utils.utils import Registrable
```

### Registering an Environment

Consider registering a custom environment class named `Grid` (a subclass of `BaseEnv`) in the RLHive registry.

```python
from hive.envs.base import BaseEnv


class Grid(BaseEnv):
    def __init__(self, env_name='Grid', **kwargs):
        pass

    def reset(self):
        pass

    def step(self):
        pass

    def render(self):
        pass

    def close(self):
        pass

    def save(self):
        pass
```

`registry.register` registers a single class of interest. Its parameters are:
- `name` *(str)* - Name of the class/object being registered.
- `constructor` *(callable)* - Callable that will be passed all kwargs from configs; its signature is analyzed to get type annotations.
- `type` *(type)* - Type of the class/object being registered. Should be a subclass of `Registrable`.

```python
from hive.utils.registry import registry

registry.register(name='Grid',
                  constructor=Grid,
                  type=BaseEnv)
```

More than one environment can be registered at once with the `registry.register_all` method. Its parameters are:

- `base_class` *(type)* - The type under which every class in `class_dict` is registered; it plays the same role as the `type` argument of `register`.
- `class_dict` *(dict[str, callable])* - A dictionary mapping from name to constructor.

Consider registering three environments, `Gridv1`, `Gridv2`, and `Gridv3`, in the RLHive registry.

```python
class Gridv1(BaseEnv):
    def __init__(self, env_name='Gridv1', **kwargs):
        pass


class Gridv2(BaseEnv):
    def __init__(self, env_name='Gridv2', **kwargs):
        pass


class Gridv3(BaseEnv):
    def __init__(self, env_name='Gridv3', **kwargs):
        pass


registry.register_all(
    BaseEnv,
    {
        "Gridv1": Gridv1,
        "Gridv2": Gridv2,
        "Gridv3": Gridv3,
    },
)
```
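
Conceptually, `register_all` can be thought of as one `register` call per dictionary entry, all under the same base type. The sketch below illustrates that reading with a hypothetical `MiniRegistry` and stub classes (all names made up so the code runs without RLHive); it is not RLHive's actual implementation.

```python
from typing import Any, Callable, Dict, Tuple


class MiniRegistry:
    """Hypothetical registry used only to illustrate register_all."""

    def __init__(self) -> None:
        # Maps (base type name, registered name) -> constructor.
        self.entries: Dict[Tuple[str, str], Callable[..., Any]] = {}

    def register(self, name: str, constructor: Callable[..., Any], type: type) -> None:
        self.entries[(type.__name__, name)] = constructor

    def register_all(self, base_class: type, class_dict: Dict[str, Callable[..., Any]]) -> None:
        # One register call per entry, all under the same base type.
        for name, constructor in class_dict.items():
            self.register(name=name, constructor=constructor, type=base_class)


class BaseEnvStub:
    """Stand-in for hive.envs.base.BaseEnv so the sketch runs without RLHive."""


class ToyGridv1(BaseEnvStub):
    pass


class ToyGridv2(BaseEnvStub):
    pass


class ToyGridv3(BaseEnvStub):
    pass


mini = MiniRegistry()
mini.register_all(BaseEnvStub, {"Gridv1": ToyGridv1, "Gridv2": ToyGridv2, "Gridv3": ToyGridv3})
print(sorted(mini.entries))
# [('BaseEnvStub', 'Gridv1'), ('BaseEnvStub', 'Gridv2'), ('BaseEnvStub', 'Gridv3')]
```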

### Registering an Agent

Consider registering a custom agent class named `LearningAgent` (a subclass of `Agent`) in the RLHive registry.

```python
from hive.agents.agent import Agent


class LearningAgent(Agent):
    def __init__(self):
        pass

    def act(self):
        pass
```

```python
from hive.utils.registry import registry

registry.register(name='LearningAgent',
                  constructor=LearningAgent,
                  type=Agent)
```

More than one agent can be registered at once using the `register_all` method. Consider registering three agents, `LearningAgentV1`, `LearningAgentV2`, and `LearningAgentV3`, in the RLHive registry.

```python
class LearningAgentV1(Agent):
    def __init__(self):
        pass

    def act(self):
        pass


class LearningAgentV2(Agent):
    def __init__(self):
        pass

    def act(self):
        pass


class LearningAgentV3(Agent):
    def __init__(self):
        pass

    def act(self):
        pass


registry.register_all(
    Agent,
    {
        "LearningAgentV1": LearningAgentV1,
        "LearningAgentV2": LearningAgentV2,
        "LearningAgentV3": LearningAgentV3,
    },
)
```

### Registering a Logger

Consider registering a custom logger class named `CustomLogger` (a subclass of `Logger`) in the RLHive registry.

```python
from hive.utils.loggers import Logger


class CustomLogger(Logger):
    def __init__(self):
        pass

    def update_step(self, timescale):
        pass
```

```python
from hive.utils.registry import registry

registry.register(name='CustomLogger',
                  constructor=CustomLogger,
                  type=Logger)
```

More than one logger can be registered at once using the `register_all` method. Consider registering three loggers, `CustomLoggerV1`, `CustomLoggerV2`, and `CustomLoggerV3`, in the RLHive registry.

```python
class CustomLoggerV1(Logger):
    def __init__(self):
        pass

    def update_step(self, timescale):
        pass


class CustomLoggerV2(Logger):
    def __init__(self):
        pass

    def update_step(self, timescale):
        pass


class CustomLoggerV3(Logger):
    def __init__(self):
        pass

    def update_step(self, timescale):
        pass


registry.register_all(
    Logger,
    {
        "CustomLoggerV1": CustomLoggerV1,
        "CustomLoggerV2": CustomLoggerV2,
        "CustomLoggerV3": CustomLoggerV3,
    },
)
```

### Registering with a Custom Data Type

#### Registering Initialization Functions with a Custom Data Type

In this example, we define a custom initialization function, `variance_scaling_`, which implements the `tf.keras.initializers.VarianceScaling` initializer in PyTorch.



```python
import math

import torch


def variance_scaling_(tensor, scale=1.0, mode="fan_in", distribution="uniform"):
    """Implements the :py:class:`tf.keras.initializers.VarianceScaling`
    initializer in PyTorch.
    Args:
        tensor (torch.Tensor): Tensor to initialize.
        scale (float): Scaling factor (must be positive).
        mode (str): Must be one of `"fan_in"`, `"fan_out"`, and `"fan_avg"`.
        distribution: Random distribution to use, must be one of
            "truncated_normal", "untruncated_normal" and "uniform".
    Returns:
        Initialized tensor.
    """
    fan = calculate_correct_fan(tensor, mode)
    scale /= fan
    if distribution == "truncated_normal":
        stddev = math.sqrt(scale) / 0.87962566103423978
        return torch.nn.init.trunc_normal_(tensor, 0.0, stddev, -2 * stddev, 2 * stddev)
    elif distribution == "untruncated_normal":
        stddev = math.sqrt(scale)
        return torch.nn.init.normal_(tensor, 0.0, stddev)
    elif distribution == "uniform":
        limit = math.sqrt(3.0 * scale)
        return torch.nn.init.uniform_(tensor, -limit, limit)
    else:
        raise ValueError(f"Distribution {distribution} not supported")


def calculate_correct_fan(tensor, mode):
    """Calculate fan of tensor.
    Args:
        tensor (torch.Tensor): Tensor to calculate fan of.
        mode (str): Which type of fan to compute. Must be one of `"fan_in"`,
            `"fan_out"`, and `"fan_avg"`.
    Returns:
        Fan of the tensor based on the mode.
    """
    fan_in, fan_out = torch.nn.init._calculate_fan_in_and_fan_out(tensor)
    if mode == "fan_in":
        return fan_in
    elif mode == "fan_out":
        return fan_out
    elif mode == "fan_avg":
        return (fan_in + fan_out) / 2
    else:
        raise ValueError(f"Fan mode {mode} not supported")
```
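
The constants in `variance_scaling_` come from matching each distribution's variance to the target $s/n$, where $s$ is `scale` and $n$ is the fan returned by `calculate_correct_fan` (this is what `scale /= fan` computes). A uniform distribution on $[-L, L]$ has variance $L^2/3$, and a normal distribution with parameter $\sigma$ truncated to $[-2\sigma, 2\sigma]$ has standard deviation $c\,\sigma$, where $\varphi$ and $\Phi$ below are the standard normal density and CDF:

$$
\begin{aligned}
\text{uniform:}\quad & \frac{L^2}{3} = \frac{s}{n} &&\Rightarrow\; L = \sqrt{\frac{3s}{n}} \\
\text{untruncated normal:}\quad & \sigma^2 = \frac{s}{n} &&\Rightarrow\; \sigma = \sqrt{\frac{s}{n}} \\
\text{truncated normal:}\quad & (c\,\sigma)^2 = \frac{s}{n} &&\Rightarrow\; \sigma = \frac{\sqrt{s/n}}{c}, \qquad c = \sqrt{1 - \frac{4\,\varphi(2)}{\Phi(2) - \Phi(-2)}} \approx 0.87962566
\end{aligned}
$$

The last line accounts for both the `0.87962566103423978` divisor and the `-2 * stddev, 2 * stddev` bounds passed to `torch.nn.init.trunc_normal_`.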

The cell below registers `variance_scaling_`, along with PyTorch's built-in initializers from `torch.nn.init`, under a custom data type, `InitializationFn`.

```python
class InitializationFn(Registrable):
    """A wrapper for callables that produce initialization functions.
    These wrapped callables can be partially initialized through configuration
    files or command line arguments.
    """

    @classmethod
    def type_name(cls):
        """
        Returns:
            "init_fn"
        """
        return "init_fn"


registry.register_all(
    InitializationFn,
    {
        "uniform": torch.nn.init.uniform_,
        "normal": torch.nn.init.normal_,
        "constant": torch.nn.init.constant_,
        "ones": torch.nn.init.ones_,
        "zeros": torch.nn.init.zeros_,
        "eye": torch.nn.init.eye_,
        "dirac": torch.nn.init.dirac_,
        "xavier_uniform": torch.nn.init.xavier_uniform_,
        "xavier_normal": torch.nn.init.xavier_normal_,
        "kaiming_uniform": torch.nn.init.kaiming_uniform_,
        "kaiming_normal": torch.nn.init.kaiming_normal_,
        "orthogonal": torch.nn.init.orthogonal_,
        "sparse": torch.nn.init.sparse_,
        "variance_scaling": variance_scaling_,
    },
)
```

```python
# Retrieve the generated constructor for this type, i.e. registry.get_init_fn
get_init_fn = getattr(registry, f"get_{InitializationFn.type_name()}")
```
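
The `InitializationFn` docstring says the wrapped callables "can be partially initialized through configuration files or command line arguments." The standard-library sketch below shows that idea with `functools.partial`: the config fixes the hyperparameters now, and the tensor is supplied later. It is an assumption that this mirrors what RLHive does internally; `scaled_uniform_`, `INIT_FNS`, and `get_partial_init_fn` are hypothetical names, and a plain list stands in for a tensor so no PyTorch is needed.

```python
import functools
import random
from typing import Any, Callable, Dict, List


def scaled_uniform_(values: List[float], scale: float = 1.0) -> List[float]:
    """Hypothetical in-place initializer standing in for a torch.nn.init function."""
    for i in range(len(values)):
        values[i] = random.uniform(-scale, scale)
    return values


INIT_FNS: Dict[str, Callable[..., Any]] = {"scaled_uniform": scaled_uniform_}


def get_partial_init_fn(config: Dict[str, Any]) -> Callable[..., Any]:
    """Bind the config's kwargs now; the tensor-like argument is supplied later."""
    return functools.partial(INIT_FNS[config["name"]], **config.get("kwargs", {}))


init_fn = get_partial_init_fn({"name": "scaled_uniform", "kwargs": {"scale": 0.1}})
weights = init_fn([0.0] * 4)  # the "tensor" is only provided at call time
print(all(-0.1 <= w <= 0.1 for w in weights))  # True
```

Deferring the tensor argument is what lets a config name an initializer before the network that owns the tensors has been built.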

#### Registering Optimizer Functions with a Custom Data Type

As in the previous example, we can register optimizers (here, classes from `torch.optim`) under a custom data type.


```python
class OptimizationFn(Registrable):
    """A wrapper for callables for optimization functions.
    These wrapped callables can be partially initialized through configuration
    files or command line arguments.
    """

    @classmethod
    def type_name(cls):
        """
        Returns:
            "opt_fn"
        """
        return "opt_fn"


registry.register_all(
    OptimizationFn,
    {
        "Adam": torch.optim.Adam,
        "SGD": torch.optim.SGD,
        "RMSprop": torch.optim.RMSprop,
        "Adagrad": torch.optim.Adagrad,
    },
)
```

```python
# Retrieve the generated constructor for this type, i.e. registry.get_opt_fn
get_optimizer_fn = getattr(registry, f"get_{OptimizationFn.type_name()}")
```