# Relational Deep Reinforcement Learning

**Plan:**
1. Architecture
2. Agent
3. Environment
4. Training cycle

## Relational architecture

**Input: (b,n,n,1)** = (batch size, grid height, grid width, greyscale channel)

**Extract entities: (b,n,n,1) -> (b, m, m, 2k)** 
* embedding layer: vocab_size = MAX_PIXELS+1, embedding_dim = n_dim
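* convolutional_layer1(kernel_size = (2,2), input_filters = n_dim, output_filters = k, stride = 1, no padding)
* convolutional_layer2(kernel_size = (2,2), input_filters = k, output_filters = 2k, stride = 1, no padding)
* without padding each 2x2 convolution shrinks the grid by one, so m = n - 2 (14 -> 12 in the shape trace printed by the code cells)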
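
The entity-extraction step can be sketched as follows. This is a minimal, hypothetical re-implementation, not the code in `RelationalModule`: the class name `ExtractEntitiesSketch` is made up, and the hyperparameters (vocab_size = 117, n_dim = 3, k = 12, no padding) are read off the printed module and shape trace in the code cells.

```python
import torch
import torch.nn as nn


class ExtractEntitiesSketch(nn.Module):
    """Embedding followed by two 2x2 convolutions (hypothetical sketch)."""

    def __init__(self, vocab_size=117, n_dim=3, k=12):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, n_dim)
        self.net = nn.Sequential(
            nn.Conv2d(n_dim, k, kernel_size=2, stride=1),
            nn.ReLU(),
            nn.Conv2d(k, 2 * k, kernel_size=2, stride=1),
            nn.ReLU(),
        )

    def forward(self, x):
        # x: (b, n, n) integer pixel codes
        x = self.embed(x)           # (b, n, n, n_dim)
        x = x.permute(0, 3, 1, 2)   # (b, n_dim, n, n): channels first for Conv2d
        return self.net(x)          # (b, 2k, n-2, n-2)


entities = ExtractEntitiesSketch()
frame = torch.randint(high=117, size=(1, 14, 14))
entity_maps = entities(frame)
print(entity_maps.shape)  # torch.Size([1, 24, 12, 12])
```

PyTorch keeps the channel axis second, so the (b, m, m, 2k) layout used in the outline corresponds to (b, 2k, m, m) in the tensors.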

**Relational block: (b, m, m, 2k) -> (b, m^2, d_m)**
* Positional Encoding: (b, m, m, 2k) -> (b, m^2, d_m)
* N Multi-Headed Attention blocks: (b, m^2, d_m) -> (b, m^2, d_m)
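
A rough sketch of the two steps above, under stated assumptions. The shape trace in the code cells shows two extra channels after the encoding (24 -> 26) and a linear projection to 256 features, which suggests that x/y coordinates are appended before projecting. The coordinate range [-1, 1], the number of heads, the placement of the residual connections and layer norm, and both class names are assumptions, not the author's implementation.

```python
import torch
import torch.nn as nn


class PositionalEncodingSketch(nn.Module):
    """Append x/y coordinate channels, flatten the grid, project to d_m."""

    def __init__(self, n_kernels=24, n_features=256):
        super().__init__()
        self.projection = nn.Linear(n_kernels + 2, n_features)

    def forward(self, x):
        b, _, h, w = x.shape
        # Coordinate channels in [-1, 1] (the range is an assumption)
        ys = torch.linspace(-1.0, 1.0, h, device=x.device).view(1, 1, h, 1).expand(b, 1, h, w)
        xs = torch.linspace(-1.0, 1.0, w, device=x.device).view(1, 1, 1, w).expand(b, 1, h, w)
        x = torch.cat([x, xs, ys], dim=1)   # (b, 2k+2, m, m)
        x = x.flatten(2).permute(2, 0, 1)   # (m^2, b, 2k+2), sequence first
        return self.projection(x)           # (m^2, b, d_m)


class AttentionBlockSketch(nn.Module):
    """Self-attention plus position-wise feed-forward, with residuals."""

    def __init__(self, n_features=256, n_heads=4, hidden_dim=64, dropout=0.1):
        super().__init__()
        self.norm = nn.LayerNorm(n_features)
        self.dropout = nn.Dropout(dropout)
        self.attn = nn.MultiheadAttention(n_features, n_heads, dropout=dropout)
        self.ff = nn.Sequential(
            nn.Linear(n_features, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, n_features),
        )

    def forward(self, x):
        # nn.MultiheadAttention expects (sequence, batch, features) by default
        attn_out, _ = self.attn(x, x, x)
        x = self.norm(x + self.dropout(attn_out))   # residual + layer norm (assumed order)
        return x + self.dropout(self.ff(x))          # residual around the feed-forward


entity_maps = torch.rand(1, 24, 12, 12)
tokens = PositionalEncodingSketch()(entity_maps)
out = AttentionBlockSketch()(tokens)
print(tokens.shape, out.shape)  # both torch.Size([144, 1, 256])
```

The tensors are sequence-first, (m^2, b, d_m), which matches the printed trace; the outline writes the same data as (b, m^2, d_m).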

**Feature-wise max pooling: (b, m^2, d_m) -> (b, d_m)**
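
As a small illustration, feature-wise max pooling takes the maximum of each feature over all m^2 positions. The helper name `featurewise_max_pool` is hypothetical; with the sequence-first layout (m^2, b, d_m) shown in the printed trace, the reduction would run over `dim=0` instead.

```python
import torch


def featurewise_max_pool(x):
    """x: (b, m^2, d_m) -> (b, d_m), max over positions for each feature."""
    return x.max(dim=1).values


# One sample, two positions, two features
x = torch.tensor([[[1.0, 5.0],
                   [3.0, 2.0]]])
pooled = featurewise_max_pool(x)
print(pooled)  # tensor([[3., 5.]])
```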

**Multi-Layer Perceptron: (b, d_m) -> (b, d_m)**
* 4 fully connected layers (d_m,d_m) with ReLUs (TODO: add skip-connections)

**Actor output: (b,d_m) -> (b,a)** [a = number of possible actions]
* Single linear layer with softmax at the end

**Critic output: (b,d_m) -> (b,1)** 
* Single linear layer without activation function
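
The MLP and the two output heads described above could look like the sketch below. The class name `ActorCriticHeadsSketch` and the value `n_actions=4` are placeholders chosen for illustration; skip-connections are left out, as in the TODO.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class ActorCriticHeadsSketch(nn.Module):
    """Shared MLP followed by an actor head and a critic head."""

    def __init__(self, n_features=256, n_linears=4, n_actions=4):
        super().__init__()
        layers = []
        for _ in range(n_linears):
            layers += [nn.Linear(n_features, n_features), nn.ReLU()]
        self.mlp = nn.Sequential(*layers)
        self.actor = nn.Linear(n_features, n_actions)   # softmax applied in forward
        self.critic = nn.Linear(n_features, 1)          # no activation

    def forward(self, features):
        h = self.mlp(features)                      # (b, d_m)
        probs = F.softmax(self.actor(h), dim=-1)    # (b, a), rows sum to 1
        value = self.critic(h)                      # (b, 1)
        return probs, value


heads = ActorCriticHeadsSketch()
probs, value = heads(torch.rand(5, 256))
print(probs.shape, value.shape)
```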

## Control architecture

**Input: (b,n,n,1)** = (batch size, grid height, grid width, greyscale channel)

**Extract entities: (b,n,n,1) -> (b, m, m, 2k)** 
* embedding layer: vocab_size = MAX_PIXELS+1, embedding_dim = n_dim
* convolutional_layer1(kernel_size = (2,2), input_filters = n_dim, output_filters = k, stride = 1, no padding)
* convolutional_layer2(kernel_size = (2,2), input_filters = k, output_filters = 2k, stride = 1, no padding)
* without padding each 2x2 convolution shrinks the grid by one, so m = n - 2

**1D Convolutional block: (b, m, m, 2k) -> (b, m^2, d_m)**
* Positional Encoding: (b, m, m, 2k) -> (b, m^2, d_m)
* Two 1D convolutional blocks with ReLUs, applied pixel-wise: (b, m^2, d_m) -> (b, m^2, d_m)
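
One way to read "pixel-wise" is a 1D convolution with kernel size 1, which transforms each position independently and never mixes information between entities (unlike attention). The sketch below assumes that reading; the class name and the 256 -> 64 -> 256 bottleneck are guesses based on the `hidden_dim=64` argument, not the contents of `ControlNetworks`.

```python
import torch
import torch.nn as nn


class PixelwiseConv1dBlockSketch(nn.Module):
    """Two kernel-size-1 convolutions with ReLUs (hypothetical control block)."""

    def __init__(self, n_features=256, hidden_dim=64):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv1d(n_features, hidden_dim, kernel_size=1),
            nn.ReLU(),
            nn.Conv1d(hidden_dim, n_features, kernel_size=1),
            nn.ReLU(),
        )

    def forward(self, x):
        # x: (b, m^2, d_m); Conv1d expects channels first, i.e. (b, d_m, m^2)
        return self.net(x.transpose(1, 2)).transpose(1, 2)


block = PixelwiseConv1dBlockSketch()
x = torch.rand(2, 144, 256)
out = block(x)

# Perturbing position 0 leaves the outputs at all other positions unchanged
x_perturbed = x.clone()
x_perturbed[:, 0, :] += 1.0
out_perturbed = block(x_perturbed)
print(out.shape, torch.allclose(out[:, 1:], out_perturbed[:, 1:]))
```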

**Feature-wise max pooling: (b, m^2, d_m) -> (b, d_m)**

**Multi-Layer Perceptron: (b, d_m) -> (b, d_m)**
* 4 fully connected layers (d_m,d_m) with ReLUs - feature-wise
* (TODO: add skip-connections)

**Actor output: (b,d_m) -> (b,a)** [a = number of possible actions]
* Single linear layer with softmax at the end

**Critic output: (b,d_m) -> (b,1)** 
* Single linear layer without activation function

In [17]:
import numpy as np
import torch 

import torch.nn as nn
import torch.nn.functional as F

In [18]:
from RelationalModule import RelationalNetworks as rnet
from RelationalModule import ControlNetworks as cnet

In [27]:
from importlib import reload
reload(rnet)
reload(cnet)

<module 'RelationalModule.ControlNetworks' from '/home/nicola/Nicola_unipd/MasterThesis/RelationalDeepRL/RelationalModule/ControlNetworks.py'>

In [20]:
control_net = cnet.ControlNet(in_channels=1, n_kernels=24, vocab_size=117, n_dim=3,
                              n_features=256, hidden_dim=64, n_control_modules=2, n_linears=4)

In [29]:
x = torch.randint(high=116, size = (1,14,14))
y = control_net(x)

In [28]:
box_net = rnet.BoxWorldNet(in_channels=1, n_kernels=24, vocab_size=117, n_dim=3,
                           n_features=256, n_attn_modules=2, n_linears=4)

Sequential(
  (0): ExtractEntities(
    (embed): Embedding(117, 3)
    (net): Sequential(
      (0): Conv2d(3, 12, kernel_size=(2, 2), stride=(1, 1))
      (1): ReLU()
      (2): Conv2d(12, 24, kernel_size=(2, 2), stride=(1, 1))
      (3): ReLU()
    )
  )
  (1): RelationalModule(
    (net): Sequential(
      (0): PositionalEncoding(
        (projection): Linear(in_features=26, out_features=256, bias=True)
      )
      (1): AttentionBlock(
        (norm): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
        (dropout): Dropout(p=0.1, inplace=False)
        (attn): MultiheadAttention(
          (out_proj): Linear(in_features=256, out_features=256, bias=True)
        )
        (ff): PositionwiseFeedForward(
          (w_1): Linear(in_features=256, out_features=64, bias=True)
          (w_2): Linear(in_features=64, out_features=256, bias=True)
          (dropout): Dropout(p=0.1, inplace=False)
        )
      )
      (2): AttentionBlock(
        (norm): LayerNorm((256,), eps=1e-0

In [30]:
box_net(x).shape

x.shape (before ExtractEntities):  torch.Size([1, 14, 14])
x.shape (ExtractEntities):  torch.Size([1, 24, 12, 12])
x.shape (After encoding):  torch.Size([1, 26, 12, 12])
x.shape (Before transposing and projection):  torch.Size([1, 26, 144])
x.shape (PositionalEncoding):  torch.Size([144, 1, 256])
x.shape (RelationalModule):  torch.Size([144, 1, 256])
x.shape (FeaturewiseMaxPool):  torch.Size([1, 256])
x.shape (BoxWorldNet):  torch.Size([1, 256])


torch.Size([1, 256])

In [27]:
get_entities = rnet.ExtractEntities(k_out = 24, k_in=3)
pe = rnet.PositionalEncoding(24, 256)
encoder = rnet.AttentionBlock(256, 2)
rel = rnet.RelationalModule(24, 256, 4, 2)

Sequential(
  (0): PositionalEncoding(
    (projection): Linear(in_features=26, out_features=256, bias=True)
  )
  (1): AttentionBlock(
    (norm): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
    (dropout): Dropout(p=0.1, inplace=False)
    (attn): MultiheadAttention(
      (out_proj): Linear(in_features=256, out_features=256, bias=True)
    )
    (ff): PositionwiseFeedForward(
      (w_1): Linear(in_features=256, out_features=64, bias=True)
      (w_2): Linear(in_features=64, out_features=256, bias=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
  )
  (2): AttentionBlock(
    (norm): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
    (dropout): Dropout(p=0.1, inplace=False)
    (attn): MultiheadAttention(
      (out_proj): Linear(in_features=256, out_features=256, bias=True)
    )
    (ff): PositionwiseFeedForward(
      (w_1): Linear(in_features=256, out_features=64, bias=True)
      (w_2): Linear(in_features=64, out_features=256, bias=True)
      (dropout

In [28]:
# batch of 20 frame-like inputs (3 channels, 12x12)
x = torch.rand((20,3,12,12))

# Convolutional pass
y = get_entities(x)

# Positional encoding
z = pe(y)

# MHA
w = encoder(z)

x.shape (before ExtractEntities):  torch.Size([20, 3, 12, 12])
x.shape (ExtractEntities):  torch.Size([20, 24, 10, 10])
x.shape (After encoding):  torch.Size([20, 26, 10, 10])
x.shape (Before transposing and projection):  torch.Size([20, 26, 100])
x.shape (PositionalEncoding):  torch.Size([100, 20, 256])


In [47]:
# Positional encoding + multiple MHA
w2 = rel(y)
print("w2.shape: ", w2.shape)

w2.shape:  torch.Size([100, 1, 256])


In [48]:
# All together
full_net = rnet.BoxWorldNet()
out = full_net(x)
out.shape

torch.Size([1, 256])

In [49]:
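# Quick check of how keyword arguments are forwarded through **kwargs
def f(a):
    pass

def g(a, **k):
    # k only holds the keyword arguments that g does not name itself
    f(**k)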

In [50]:
d = {'a':1}
# 'a' is passed twice here: explicitly and through **d
g(a=2, **d)

TypeError: g() got multiple values for keyword argument 'a'
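
The error comes from `a` reaching `g` twice, once explicitly and once through the unpacked dictionary. A possible workaround, sketched here with a hypothetical `lr` entry, is to merge the dictionaries first so that the explicit value overrides the default:

```python
def g(a, **extra):
    # Return what g received so the binding is visible
    return a, extra


defaults = {"a": 1, "lr": 0.01}  # 'lr' is a made-up extra option

# Passing 'a' both explicitly and via ** raises TypeError
try:
    g(a=2, **defaults)
except TypeError as err:
    message = str(err)

# Merging first avoids the clash: later keys win
merged = {**defaults, "a": 2}
result = g(**merged)
print(result)  # (2, {'lr': 0.01})
```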

In [52]:
help(F.log_softmax)

Help on function log_softmax in module torch.nn.functional:

log_softmax(input, dim=None, _stacklevel=3, dtype=None)
    Applies a softmax followed by a logarithm.
    
    While mathematically equivalent to log(softmax(x)), doing these two
    operations separately is slower, and numerically unstable. This function
    uses an alternative formulation to compute the output and gradient correctly.
    
    See :class:`~torch.nn.LogSoftmax` for more details.
    
    Arguments:
        input (Tensor): input
        dim (int): A dimension along which log_softmax will be computed.
        dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor.
          If specified, the input tensor is casted to :attr:`dtype` before the operation
          is performed. This is useful for preventing data type overflows. Default: None.
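
The numerical point made in the docstring can be checked with a small example. The logits below are deliberately extreme and purely illustrative:

```python
import torch
import torch.nn.functional as F

logits = torch.tensor([[1000.0, 0.0, -1000.0]])

# Two-step version: softmax underflows to exactly 0, so log gives -inf
naive = torch.log(F.softmax(logits, dim=-1))

# Fused version stays finite
stable = F.log_softmax(logits, dim=-1)

print(naive)   # tensor([[0., -inf, -inf]])
print(stable)  # tensor([[0., -1000., -2000.]])
```

For an actor head this matters when computing log-probabilities of actions for the policy-gradient loss: an `-inf` log-probability would propagate into the loss and its gradients.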



## Multi-channel embedding layer test

In [6]:
help(nn.Embedding)

Help on class Embedding in module torch.nn.modules.sparse:

class Embedding(torch.nn.modules.module.Module)
 |  Embedding(num_embeddings, embedding_dim, padding_idx=None, max_norm=None, norm_type=2.0, scale_grad_by_freq=False, sparse=False, _weight=None)
 |  
 |  A simple lookup table that stores embeddings of a fixed dictionary and size.
 |  
 |  This module is often used to store word embeddings and retrieve them using indices.
 |  The input to the module is a list of indices, and the output is the corresponding
 |  word embeddings.
 |  
 |  Args:
 |      num_embeddings (int): size of the dictionary of embeddings
 |      embedding_dim (int): the size of each embedding vector
 |      padding_idx (int, optional): If given, pads the output with the embedding vector at :attr:`padding_idx`
 |                                       (initialized to zeros) whenever it encounters the index.
 |      max_norm (float, optional): If given, each embedding vector with norm larger than :attr:`max_nor
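
Because `nn.Embedding` accepts an index tensor of any shape and appends the embedding dimension at the end, it can be applied directly to a multi-channel integer grid. The sketch below shows one possible way to fold the embedding dimension into the channel axis before a convolution; it is an assumption for illustration and not necessarily what `ExtractEntities` does.

```python
import torch
import torch.nn as nn

embed = nn.Embedding(num_embeddings=117, embedding_dim=3)

x = torch.randint(high=117, size=(2, 3, 14, 14))   # (b, c, n, n) integer codes
e = embed(x)                                       # (b, c, n, n, n_dim)
e = e.permute(0, 1, 4, 2, 3)                       # (b, c, n_dim, n, n)
e = e.reshape(x.shape[0], -1, x.shape[-2], x.shape[-1])  # (b, c*n_dim, n, n)
print(e.shape)  # torch.Size([2, 9, 14, 14])
```

With this layout, channels `c*n_dim` to `c*n_dim + n_dim - 1` hold the embedding of input channel `c`, so a following `Conv2d` would need `c * n_dim` input filters.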

In [22]:
reload(rnet)

<module 'RelationalModule.RelationalNetworks' from '/home/nicola/Nicola_unipd/MasterThesis/RelationalDeepRL/RelationalModule/RelationalNetworks.py'>

In [30]:
get_entities = rnet.ExtractEntities(k_out = 24, k_in=3, n_dim=3)

In [31]:
embed = nn.Embedding(255,3)

In [35]:
x = torch.randint(high=116, size = (3,14,14))

In [36]:
get_entities(x).shape

torch.Size([1, 24, 12, 12])

In [21]:
x = x.transpose(-1,-2).reshape(x.shape[0],-1,x.shape[-2],x.shape[-1])
x.shape

torch.Size([10, 3, 14, 14])

In [10]:
bw_net = rnet.BoxWorldNet()

In [11]:
bw_net(x).shape

torch.Size([1, 256])