In [None]:
#|hide
#|eval: false
# Colab-only setup: the install runs only if `/content` exists (presumably true only on Colab).
# apt-get output (including errors) is discarded by the redirect, so failures there are silent.
! [ -e /content ] && pip install -Uqq fastrl['dev'] pyvirtualdisplay && \
                     apt-get install -y xvfb python-opengl > /dev/null 2>&1 
# NOTE: IF YOU SEE VERSION ERRORS, IT IS SAFE TO IGNORE THEM. COLAB IS BEHIND IN SOME OF THE PACKAGE VERSIONS

In [None]:
#|hide
#|eval: false
# `os` is used below; import it explicitly instead of relying on nbdev's star import providing it.
import os
from fastcore.imports import in_colab
# Since colab still requires tornado<6, we don't want to import nbdev if we don't have to
if not in_colab():
    from nbdev.showdoc import *
    from nbdev.imports import *
    # Outside automated test runs (IN_TEST unset), sanity-check that we are in a local notebook.
    if not os.environ.get("IN_TEST", None):
        assert IN_NOTEBOOK
        assert not IN_COLAB
        assert IN_IPYTHON
else:
    # Virtual display is needed for colab
    from pyvirtualdisplay import Display
    # Named `virtual_display` (not `display`) so it does not shadow IPython's `display()` function.
    virtual_display = Display(visible=0, size=(400, 300))
    virtual_display.start()

In [None]:
#|default_exp layers

In [None]:
#|export
# Python native modules
from copy import deepcopy
from typing import *
from functools import partial
# Third party libs
from fastrl.torch_core import *
from torch import nn
import torch
from fastcore.all import L,Self,partialler,add_docs,test_eq
import numpy as np
# Local modules

# Layers
> Functions and Modules for RL that are pure pytorch

In [None]:
#|export
# 3dHistogram rendering code taken from:
# https://stackoverflow.com/questions/60432713/filled-3d-histogram-from-2d-histogram-with-plotly
# and
# https://community.plotly.com/t/adding-a-shape-to-a-3d-plot/1441/10
def triangulate_histogram(x, y, z):
    """Triangulate a histogram outline so it can be drawn as a plotly `Mesh3d`.

    `x`, `y`, `z` hold the coordinates of the outline vertices. Returns `(pts3d, tri)`:
    `pts3d` stacks the input vertices row-wise, followed by one extra point at `z=0`
    under each odd-indexed vertex 3, 5, ..., n-3. `tri` is an integer array of
    vertex-index triples, one row per triangle.

    Raises `ValueError` when the lengths differ, are odd, or are below 6. The index
    arithmetic below needs at least 6 vertices to stay inside `pts3d`.
    """
    # Must be an explicit three-way equality: the chained form
    # `len(x) != len(y) != len(z)` misses cases such as len(x) == len(y) != len(z).
    if not (len(x) == len(y) == len(z)):
        raise ValueError("The  lists x, y, z, must have the same length")
    n = len(x)
    if n % 2:
        raise ValueError("The length of lists x, y, z must be an even number")
    if n < 6:
        # Fewer vertices leave the extra-point list empty and the triangle indices out of range.
        raise ValueError(f"At least 6 points are needed to triangulate a histogram, got {n}")
    pts3d = np.vstack((x, y, z)).T
    # Extra z=0 points under vertices 3, 5, ..., n-3; they get indices n, n+1, ...
    base = np.array([[x[2*k+1], y[2*k+1], 0] for k in range(1, n//2-1)])
    pts3d = np.vstack((pts3d, base))
    # Triangles anchored on the first extra point (index n).
    tri = [[0,1,2], [0,2,n]]
    # Each pair of consecutive extra points (k, k+1) spans the vertices i, i+1 above them.
    for k, i in zip(range(n, n-3+n//2), range(3, n-4, 2)):
        tri.extend([[k, i, i+1], [k, i+1, k+1]])
    # Close the outline with the last extra point.
    last = n-3+n//2
    tri.extend([[last, n-3, n-2], [last, n-2, n-1]])
    return pts3d, np.array(tri)

def _create_3d_mesh(layer:str,weights:torch.Tensor):
    """Build one plotly `Mesh3d` trace that draws `weights` as a histogram placed at x=`layer`.

    `weights` is assumed to be 1-D (one bar per element, indexed by `weights.shape[0]`).
    Any object with `.tolist()` and `.shape` works. NOTE(review): `show_sequential_layer_weights`
    actually passes flattened numpy arrays, despite the `torch.Tensor` annotation.
    Returns a `go.Mesh3d` with opacity 0.7.
    """
    import plotly.graph_objects as go
    # Heights: every weight is repeated twice, then shifted one slot to the right with a
    # leading 0 (insert + pop keep the length at 2*len(weights)). The final height is
    # forced to 0, so the first and last heights are both 0.
    a0=weights.tolist()
    a0=np.repeat(a0,2).tolist()
    a0.insert(0,0)
    a0.pop()
    a0[-1]=0
    # Neuron positions: each index repeated twice to pair up with the heights above.
    a1=np.arange(weights.shape[0]).tolist() 
    a1=np.repeat(a1,2)

    verts, tri = triangulate_histogram([layer]*len(a0), a1, a0)
    x, y, z = verts.T
    I, J, K = tri.T
    # `verts` mixes the `layer` strings with numbers, so numpy presumably stores it as a
    # string array. z is converted back to float to round it, then back to str for plotly.
    z = np.round(z.astype(float),4).astype(str)
    return go.Mesh3d(x=x, y=y, z=z, i=I, j=J, k=K, opacity=0.7)

def show_sequential_layer_weights(
        seq:nn.Sequential, # Module whose `nn.Linear` / `nn.Conv2d` submodules are plotted (visited via `.apply`)
        title='Layer weights' # Figure title; a falsy value keeps the original tight top margin
    ):
    "Plots the flattened weights of every `nn.Linear` and `nn.Conv2d` in `seq` as 3D histograms, one row per layer."
    import plotly.io as pio
    import plotly.graph_objects as go
    # NOTE: this sets plotly's renderer globally, not only for this figure.
    pio.renderers.default = "plotly_mimetype+notebook_connected"

    weights = {}
    counter = {}
    def append_weight_dict(m):
        # Exact type checks (not isinstance): subclasses of Linear/Conv2d are skipped.
        if type(m) == nn.Linear: kind = 'ln'
        elif type(m) == nn.Conv2d: kind = 'conv'
        else: return
        counter[kind] = counter.get(kind,0)+1
        weights[f"{kind}_{counter[kind]}"] = to_detach(m.weight.view(-1,)).numpy()

    seq.apply(append_weight_dict)
    if not weights:
        raise ValueError("`seq` contains no `nn.Linear` or `nn.Conv2d` layers to plot")

    # Centre each layer's weights inside a zero vector as long as the longest layer, so all
    # layers share one neuron axis. With odd padding, the extra zero goes on the right.
    max_len = max(w.shape[0] for w in weights.values())
    padded = {}
    for name,w in weights.items():
        total = max_len-w.shape[0]
        padded[name] = np.pad(w,(total//2,total-total//2))

    fig=go.Figure()
    for name,w in padded.items():
        fig.add_traces(_create_3d_mesh(name,w))

    fig.update_layout(
        title=title,
        scene=dict(
            xaxis_title='Layer',
            yaxis_title='Neuron',
            zaxis_title='Weight Value',
        ),
        width=700,
        height=600,
        autosize=False,
        # Leave room at the top when a title is shown so it is not clipped.
        margin=dict(l=30, r=30, b=50, t=50 if title else 10),
        scene_camera_eye_z=0.8,
    )
    return fig.show()

Given a `nn.Sequential`, we can display the weights for `nn.Linear` and `nn.Conv2d` modules...

In [None]:
# Seed before building the layers so PyTorch's default initialization (and the plot) is reproducible.
torch.manual_seed(0)
layers = nn.Sequential(
    nn.Linear(2,12), nn.ReLU(),
    nn.Linear(12,6), nn.ReLU(),
    nn.Linear(6,1),
)
show_sequential_layer_weights(layers)

In [None]:
#|export
def init_xavier_uniform_weights(
        m:Module, # Module to initialize in place; only exact `nn.Linear` instances are touched
        bias=0.01 # Constant written into `m.bias`; skipped when the layer has no bias
    ):
    "Initializes weights for linear layers using `torch.nn.init.xavier_uniform_`"
    if type(m) == nn.Linear:
        torch.nn.init.xavier_uniform_(m.weight)
        # `nn.Linear(..., bias=False)` stores `bias=None`; filling it would raise AttributeError.
        if m.bias is not None:
            torch.nn.init.constant_(m.bias, bias)

Show how `init_xavier_uniform_weights` affects the weights...

In [None]:
# Same seeded network as the default-init demo, then re-initialized with Xavier-uniform.
torch.manual_seed(0)
layers = nn.Sequential(
    nn.Linear(2,12), nn.ReLU(),
    nn.Linear(12,6), nn.ReLU(),
    nn.Linear(6,1),
)
# `.apply` visits every submodule; the init function only acts on `nn.Linear`.
layers.apply(init_xavier_uniform_weights)
show_sequential_layer_weights(layers)

In [None]:
#|export
def init_uniform_weights(
        m:Module, # Module to initialize in place; anything that is not exactly `nn.Linear` is ignored
        bound:float # Weights are drawn from U(-bound, bound); the bias is left unchanged
    ):
    "Initializes `nn.Linear` weights uniformly in [-bound, bound] via `torch.nn.init.uniform_`"
    if type(m) is not nn.Linear:
        return
    torch.nn.init.uniform_(m.weight, a=-bound, b=bound)

Show how `init_uniform_weights` affects the weights and that they are
randomly initialized between the bounds...

In [None]:
# Same seeded network, re-initialized uniformly within a very small bound.
torch.manual_seed(0)
layers = nn.Sequential(
    nn.Linear(2,12), nn.ReLU(),
    nn.Linear(12,6), nn.ReLU(),
    nn.Linear(6,1),
)
# `partial` fixes `bound`, so the init function fits the single-argument `.apply` callback.
layers.apply(partial(init_uniform_weights,bound=0.002))
show_sequential_layer_weights(layers)

In [None]:
#|export
def init_kaiming_normal_weights(
        m:Module, # Module to initialize in place; only exact `nn.Linear` instances are touched
        bias=0.01 # Constant written into `m.bias`; skipped when the layer has no bias
    ):
    "Initializes weights for linear layers using `torch.nn.init.kaiming_normal_`"
    if type(m) == nn.Linear:
        torch.nn.init.kaiming_normal_(m.weight)
        # `nn.Linear(..., bias=False)` stores `bias=None`; filling it would raise AttributeError.
        if m.bias is not None:
            torch.nn.init.constant_(m.bias, bias)

Show how `init_kaiming_normal_weights` affects the weights...

In [None]:
# Same seeded network, re-initialized with Kaiming-normal weights.
torch.manual_seed(0)
layers = nn.Sequential(
    nn.Linear(2,12), nn.ReLU(),
    nn.Linear(12,6), nn.ReLU(),
    nn.Linear(6,1),
)
# `.apply` visits every submodule; the init function only acts on `nn.Linear`.
layers.apply(init_kaiming_normal_weights)
show_sequential_layer_weights(layers)

In [None]:
#|export
def simple_conv2d_block(
        # A tuple of state sizes generally representing an image of format: 
        # [channel,width,height]
        state_sz:Tuple[int,int,int],
        # Kernel size of every conv layer. Despite its name, this is NOT the number of
        # output filters: it is passed to `nn.Conv2d` as `kernel_size`, and every conv
        # keeps `channels` output channels.
        filters:int=32,
        # Activation function class, instantiated after the first and second conv layers.
        activation_fn=nn.ReLU,
        # We assume the channels dim should be size 3 max. If it is more
        # we assume the width/height are in the location of channel and need to
        # be transposed. Set to `True` to silence the resulting warning.
        ignore_warning:bool=False
    ) -> Tuple[nn.Sequential,int]: # (Convolutional block,n_features_out)
    "Creates a 3 layer conv block from `state_sz` along with expected n_feature output shape."
    channels = state_sz[0]
    if channels>3 and not ignore_warning:
        warn(f'Channels is {channels}>3 in state_sz {state_sz}')
    # Stride 1 and no padding: each conv shrinks both spatial dims by `filters-1`.
    layers = nn.Sequential(
        nn.BatchNorm2d(channels),
        nn.Conv2d(channels,channels,kernel_size=filters),
        activation_fn(),
        nn.Conv2d(channels,channels,kernel_size=filters),
        activation_fn(),
        nn.Conv2d(channels,channels,kernel_size=filters),
        nn.Flatten()
    )
    # Infer the flattened output size with a dry run of a copy on the `meta` device.
    # Shapes are computed without real data, and the returned `layers` are left untouched.
    m_layers = deepcopy(layers).to(device='meta')
    out_sz = m_layers(torch.ones((1,*state_sz),device='meta')).shape[-1]
    return layers.to(device='cpu'),out_sz

In [None]:
# Displays the (conv block, flattened feature count) pair for a 3x100x100 image.
simple_conv2d_block(state_sz=(3,100,100))

(Sequential(
   (0): BatchNorm2d(3, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
   (1): Conv2d(3, 3, kernel_size=(32, 32), stride=(1, 1))
   (2): ReLU()
   (3): Conv2d(3, 3, kernel_size=(32, 32), stride=(1, 1))
   (4): ReLU()
   (5): Conv2d(3, 3, kernel_size=(32, 32), stride=(1, 1))
   (6): Flatten(start_dim=1, end_dim=-1)
 ),
 147)

In [None]:
#|export
class Critic(Module):
    def __init__(
            self,
            state_sz:int,  # The input dim of the state / flattened conv output
            action_sz:int=0, # The input dim of the actions; 0 builds a state-only critic
            hidden1:int=400,    # Number of neurons in the first hidden layer
            hidden2:int=300,    # Number of neurons in the second hidden layer
            head_layer:Type[Module]=nn.Linear, # Output layer class, built as `head_layer(hidden2,1)`
            activation_fn:Type[Module]=nn.ReLU, # The activation function class, instantiated with no arguments
            weight_init_fn:Callable=init_kaiming_normal_weights, # Weight initialization applied to `self.layers` via `.apply`
            # Final layer initialization strategy (skipped when `None`)
            final_layer_init_fn:Optional[Callable]=partial(init_uniform_weights,bound=1e-4),
            # For pixel inputs, we can plug in a `nn.Sequential` block from `simple_conv2d_block`.
            # This means that actions will be fed into the second linear layer instead of the
            # first.
            conv_block:Optional[nn.Sequential]=None,
            # Whether to apply `nn.BatchNorm1d` to the concatenated state+action input.
            # Only used when `conv_block` is `None`.
            batch_norm:bool=False
        ):
        self.action_sz = action_sz
        self.state_sz = state_sz
        self.conv_block = conv_block
        if conv_block is None:
            if batch_norm:
                ln_bn = nn.Sequential(
                    nn.BatchNorm1d(state_sz+action_sz),
                    nn.Linear(state_sz+action_sz,hidden1)
                )
            else:
                ln_bn = nn.Linear(state_sz+action_sz,hidden1)
            self.layers = nn.Sequential(
                ln_bn,
                activation_fn(),
                nn.Linear(hidden1,hidden2),
                activation_fn(),
                head_layer(hidden2,1),
            )
        else:
            # Conv features are projected to `hidden1` before the actions are concatenated.
            # NOTE(review): `weight_init_fn` is only applied to `self.layers` below, so this
            # `nn.Linear` keeps PyTorch's default init -- confirm that is intended.
            self.conv_block = nn.Sequential(
                self.conv_block,
                nn.Linear(state_sz,hidden1),
                activation_fn(),
            )
            self.layers = nn.Sequential(
                nn.Linear(hidden1+action_sz,hidden2),
                activation_fn(),
                head_layer(hidden2,1),
            )
        self.layers.apply(weight_init_fn)
        if final_layer_init_fn is not None:
            final_layer_init_fn(self.layers[-1])

    def forward(
            self,
            s:torch.Tensor, # A single tensor of shape [Batch,`state_sz`] (an image batch when `conv_block` is set)
            a:Optional[torch.Tensor]=None # A single tensor of shape [Batch,`action_sz`]
            # A single tensor of shape [B,1] representing the cumulative value estimate of state+action combinations  
        ) -> torch.Tensor: 
        # Compare against None explicitly rather than relying on `nn.Sequential.__len__` truthiness.
        if self.conv_block is not None:
            s = self.conv_block(s)
        if a is None:
            if self.action_sz!=0:
                raise RuntimeError(f'`action_sz` is not 0, but no action was provided.')
            return self.layers(s)
        # Actions are concatenated along the feature dim (dim 1 for 2-D inputs).
        return self.layers(torch.hstack((s,a)))

add_docs(
Critic,
"""Takes either:
 - 2 tensors of size [B,`state_sz`], [B,`action_sz`]
 - 1 tensor of size [B,`state_sz`]

 Returns a [B,1] tensor representing the Q value""",
forward="""Takes in either:
- 2 tensors of a state tensor and action tensor
or
- a single state tensor
and outputs the Q value estimates of that state,action combination"""
)

The `Critic` is used by `DDPG` and `TRPO` to estimate the Q value of state-action pairs. It is updated using
the Bellman equation, similarly to DQN/Q-Learning, and is represented by $Q(s,a)$.

Check that low dim input works...

In [None]:
# The expected value below is pinned to this exact sequence of seeded RNG calls:
# critic construction/init first, then the random state and action.
torch.manual_seed(0)
critic = Critic(4,2)

state = torch.randn(1,4)
action = torch.randn(1,2)

# `evaluating` comes from fastrl.torch_core; presumably it puts `critic` in eval mode for the block -- TODO confirm.
with torch.no_grad(),evaluating(critic):
    # Compares printed representations rather than raw float values.
    test_eq(
        str(critic(state,action)),
        str(tensor([[0.0083]]))
    )

Check that image input works...

In [None]:
# The expected value below is pinned to this exact sequence of seeded RNG calls.
torch.manual_seed(0)

image_shape = (3,100,100)

# Build the conv feature extractor first; its flattened output size becomes the Critic's `state_sz`.
conv_block,feature_out = simple_conv2d_block(image_shape)
critic = Critic(feature_out,2,conv_block=conv_block)

state = torch.randn(1,*image_shape)
action = torch.randn(1,2)

# `evaluating` comes from fastrl.torch_core; presumably it puts `critic` in eval mode for the block -- TODO confirm.
with torch.no_grad(),evaluating(critic):
    # Compares printed representations rather than raw float values.
    test_eq(
        str(critic(state,action)),
        str(tensor([[0.0102]]))
    )

In [None]:
#|hide
#|eval: false
from fastcore.imports import in_colab

# Since colab still requires tornado<6, we don't want to import nbdev if we don't have to
# Outside Colab, export this notebook's `#|export` cells via nbdev.
if not in_colab():
    from nbdev import nbdev_export
    nbdev_export()