# FullConvNet 

<img src="Supplementary material/SC2 architectures.png">

## How to implement 
1. Conv2D from 3 to n_channels
2. Block of N ResidualConvolutional layers: <br>
    - 2 convolutions kernel 3x3 with stride and padding of 1, keeps spatial resolution unchanged
    - n_channels to n_channels, passing through hidden_channels
    - LayerNorm before processing the input to the module
    - ReLU between the two convolutions
    - Add input at the end (input is free to flow untouched)
3. This part outputs the spatial features, which are then processed in two ways: spatially (to sample the spatial arguments) and non-spatially (to sample the action id and the non-spatial parameters)

### Non-spatial architecture
Start from the spatial features:
- flatten the two pixel dimensions into a single one,
- apply N residual layers feature-wise (each of them acts on the i-th feature along the pixel axis),
- max-pool feature-wise to suppress the pixel dimension; now each feature represents the result of the interaction between pixels in a different way (so it is a viable alternative to the relational module),
- apply N residual layers along the n_channels axis.

(from there a simple MLP with final softmax can be plugged in at the end to get a categorical distribution for the action ids or their non-spatial parameters)

### Spatial architecture
Start from spatial features, 
- apply ReLU
- use convolution with 1 output channel

In [41]:
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import Categorical

In [33]:
# Global switch read by the forward passes of SpatialParameters, FeaturewiseMaxPool
# and NonSpatialFeatures: when True they print intermediate tensor shapes.
debug = True

In [31]:
class ResidualLayer(nn.Module):
    """
    Pre-norm residual MLP block acting on the last axis of its input.

    Computes ``x + w2(relu(w1(norm(x))))``: the input is layer-normalised,
    expanded to ``n_hidden`` units, passed through a ReLU, projected back to
    ``n_features`` and finally added to the untouched input.

    Args:
        n_features: size of the last axis of the input (and of the output).
        n_hidden: width of the intermediate linear layer.
    """
    def __init__(self, n_features, n_hidden):
        super().__init__()
        self.norm = nn.LayerNorm(n_features)
        self.w1 = nn.Linear(n_features, n_hidden)
        self.w2 = nn.Linear(n_hidden, n_features)

    def forward(self, x):
        """Return ``x`` plus the output of the residual branch (same shape as ``x``)."""
        normed = self.norm(x)
        hidden = self.w1(normed).relu()
        # Skip connection: the input flows through unchanged.
        return self.w2(hidden) + x

In [2]:
class ResidualConvolutional(nn.Module):
    """
    Residual block built from two resolution-preserving convolutions.

    The branch is LayerNorm over the two spatial axes -> Conv2d
    (n_channels -> hidden_channels) -> ReLU -> Conv2d
    (hidden_channels -> n_channels); its output is added to the input.

    Args:
        linear_size: side of the (square) spatial grid, used as the
            normalised shape ``(linear_size, linear_size)`` of the LayerNorm.
        n_channels: number of input and output channels.
        hidden_channels: number of channels between the two convolutions.
        kernel_size: convolution kernel size; must be odd.
    """

    def __init__(self, linear_size, n_channels, hidden_channels=12, kernel_size=3):
        super().__init__()

        # Only an odd kernel admits a symmetric padding that keeps H and W unchanged.
        assert (kernel_size - 1) % 2 == 0, 'Provide odd kernel size to use this layer'
        same_padding = (kernel_size - 1) // 2

        spatial_norm = nn.LayerNorm((linear_size, linear_size))
        conv_in = nn.Conv2d(n_channels, hidden_channels, kernel_size, stride=1, padding=same_padding)
        conv_out = nn.Conv2d(hidden_channels, n_channels, kernel_size, stride=1, padding=same_padding)
        self.net = nn.Sequential(spatial_norm, conv_in, nn.ReLU(), conv_out)

    def forward(self, x):
        """Return ``net(x) + x``; the output has the same shape as ``x``."""
        residual = self.net(x)
        return residual + x

In [12]:
class SpatialFeatures(nn.Module):
    """
    Convolutional trunk that produces the spatial features.

    A 3x3 convolution (in_channels -> n_channels) followed by a ReLU and then
    ``n_layers - 1`` ResidualConvolutional blocks.

    Args:
        n_layers: total depth; the first convolution counts as one layer.
        linear_size: screen resolution (side of the square input).
        in_channels: number of channels of the input.
        n_channels: number of channels of the output features.
        **HPs: extra keyword arguments forwarded to every ResidualConvolutional.
    """
    def __init__(self, n_layers, linear_size, in_channels, n_channels, **HPs):
        super().__init__()

        self.linear_size = linear_size # screen resolution

        # Residual blocks are created first, then the input convolution.
        residual_blocks = [
            ResidualConvolutional(linear_size, n_channels, **HPs)
            for _ in range(n_layers - 1)
        ]
        stem = nn.Conv2d(in_channels, n_channels, kernel_size=3, stride=1, padding=1)

        self.net = nn.Sequential(stem, nn.ReLU(), *residual_blocks)

    def forward(self, x):
        """Run the input through the convolutional stack and return the result."""
        return self.net(x)

In [61]:
class SpatialParameters(nn.Module):
    """
    Samples one pixel of a square grid from a categorical distribution.

    A 3x3 convolution maps the spatial features to a single channel of logits,
    which are flattened and normalised with a softmax over all pixels.

    Args:
        n_channels: number of channels of the incoming spatial features.
        linear_size: side of the square grid.
    """

    def __init__(self, n_channels, linear_size):
        super(SpatialParameters, self).__init__()

        self.size = linear_size
        self.conv = nn.Conv2d(n_channels, 1, kernel_size=3, stride=1, padding=1)

    def forward(self, x):
        """
        Sample a pixel and return its coordinates and log-probability.

        Only a batch of size 1 is supported, because the sampled index is
        converted to a Python scalar with ``.item()``.

        Returns:
            arg: ``[col, row]`` of the sampled pixel, where the sampled flat
                index equals ``row * size + col``.
            log_prob: 0-dim tensor, log-probability of the sampled pixel.
            probs: tensor of shape ``(batch, size * size)`` with all pixel
                probabilities.
        """
        x = self.conv(x)
        x = x.reshape((x.shape[0], -1))
        log_probs = F.log_softmax(x, dim=(-1))
        if debug:
            print("log_probs.shape: ", log_probs.shape)
            print("log_probs.shape (reshaped): ", log_probs.view(self.size, self.size).shape)
        probs = torch.exp(log_probs)

        # Lookup table from flat index to coordinates (assumes a square space):
        # entry `row * size + col` holds (col, row).
        x_lin = torch.arange(self.size)
        xx = x_lin.repeat(self.size, 1)
        args = torch.stack([xx, xx.T], dim=2).reshape(-1, 2)

        distribution = Categorical(probs)
        # The sampled index itself is not differentiable; gradients flow
        # through the log-probability returned below.
        index = distribution.sample().item()
        arg = list(args[index].numpy())  # sampled coordinates

        # Index the flat log-probs with the sampled index. Looking the value up
        # as a 2D grid at [arg[0], arg[1]] would read the transposed cell,
        # since arg is (col, row) while the grid is laid out as [row, col].
        return arg, log_probs.view(-1)[index], probs

In [46]:
class FeaturewiseMaxPool(nn.Module):
    """Reduces a tensor by taking the maximum over one fixed axis."""
    def __init__(self, pixel_axis):
        super().__init__()
        self.max_along_axis = pixel_axis

    def forward(self, x):
        """Return the maxima of ``x`` along the configured axis (that axis is removed)."""
        # Tensor.max with a dim returns (values, indices); only the values are kept.
        pooled = x.max(dim=self.max_along_axis).values
        if debug:
            print("x.shape (FeaturewiseMaxPool): ", pooled.shape)
        return pooled

In [47]:
class NonSpatialFeatures(nn.Module):
    """
    Turns spatial features into a flat, per-channel representation.

    Pipeline: flatten the two pixel axes -> residual layers acting along the
    pixel axis -> max over the pixel axis -> residual layers acting along the
    channel axis.

    Args:
        linear_size: side of the square grid (the pixel axis has linear_size**2 entries).
        n_channels: number of channels of the spatial features.
        pixel_hidden_dim: hidden width of the pixel-axis residual layers.
        pixel_n_residuals: number of pixel-axis residual layers.
        feature_hidden_dim: hidden width of the channel-axis residual layers.
        feature_n_residuals: number of channel-axis residual layers.
    """

    def __init__(self, linear_size, n_channels, pixel_hidden_dim=128, pixel_n_residuals=4,
                 feature_hidden_dim=64, feature_n_residuals=4):
        super().__init__()

        n_pixels = linear_size**2
        self.pixel_res_block = nn.Sequential(
            *[ResidualLayer(n_pixels, pixel_hidden_dim) for _ in range(pixel_n_residuals)]
        )

        # After flattening, the pixels live on axis 2.
        self.maxpool = FeaturewiseMaxPool(pixel_axis=2)

        self.feature_res_block = nn.Sequential(
            *[ResidualLayer(n_channels, feature_hidden_dim) for _ in range(feature_n_residuals)]
        )

    def forward(self, x):
        """
        Input shape (batch_dim, n_channels, linear_size, linear_size);
        output shape (batch_dim, n_channels).
        """
        flat = x.view(x.shape[0], x.shape[1], -1)
        if debug: print("x.shape: ", flat.shape)

        # Interaction between pixels, feature-wise
        mixed = self.pixel_res_block(flat)
        if debug: print("x.shape: ", mixed.shape)

        # Feature-wise max pooling removes the pixel axis
        pooled = self.maxpool(mixed)
        if debug: print("x.shape: ", pooled.shape)

        # Interaction between features -> final representation
        features = self.feature_res_block(pooled)
        if debug: print("x.shape: ", features.shape)

        return features

In [48]:
class CategoricalNet(nn.Module):
    """
    MLP head that samples from a categorical distribution over ``size`` choices.

    The network is Linear -> ReLU for every width in ``hiddens`` followed by a
    final Linear to ``size`` logits. With an empty ``hiddens`` it reduces to a
    single Linear layer.

    Args:
        n_features: size of the input representation.
        size: number of categories.
        hiddens: widths of the hidden layers (any sequence of ints).
    """

    def __init__(self, n_features, size, hiddens=(32, 16)):
        super(CategoricalNet, self).__init__()

        # Consecutive pairs of widths define the hidden Linear layers.
        widths = [n_features, *hiddens]
        layers = []
        for fan_in, fan_out in zip(widths, widths[1:]):
            layers.append(nn.Linear(fan_in, fan_out))
            layers.append(nn.ReLU())

        layers.append(nn.Linear(widths[-1], size))
        self.net = nn.Sequential(*layers)

    def forward(self, state_rep):
        """
        Sample one category.

        Only a single sample is supported (the sampled index is converted with
        ``.item()``).

        Returns:
            arg: one-element list with the sampled category index.
            log_prob: 0-dim tensor, log-probability of the sampled category.
            probs: probabilities of all categories, same shape as the logits.
        """
        logits = self.net(state_rep)
        log_probs = F.log_softmax(logits, dim=(-1))
        probs = torch.exp(log_probs)
        distribution = Categorical(probs)
        arg = distribution.sample().item()
        return [arg], log_probs.view(-1)[arg], probs

## Testing

In [62]:
# Input geometry
in_channels = 3    # channels of the dummy input
linear_size = 16   # side of the square grid

# Network sizes
n_channels = 12    # channels of the spatial features
n_layers = 2       # depth of the convolutional trunk
n_actions = 3      # number of categories of the action head

In [63]:
# Instantiate the trunk and the heads.
spatial_features_net = SpatialFeatures(
    n_layers=n_layers,
    linear_size=linear_size,
    in_channels=in_channels,
    n_channels=n_channels,
)
spatial_params_net = SpatialParameters(n_channels=n_channels, linear_size=linear_size)
nonspatial_features_net = NonSpatialFeatures(linear_size=linear_size, n_channels=n_channels)
action_net = CategoricalNet(n_features=n_channels, size=n_actions)

In [64]:
# Random dummy input of shape (1, in_channels, linear_size, linear_size)
x = torch.rand(1, in_channels, linear_size, linear_size)

In [65]:
# Convolutional trunk: the output keeps the input's spatial resolution
spatial_features = spatial_features_net(x)
print("spatial_features: ", spatial_features.shape)

spatial_features:  torch.Size([1, 12, 16, 16])


In [67]:
# Sample a pixel: returns the coordinate list, the log-prob of the sample and all pixel probabilities
spatial_params, log_prob, probs = spatial_params_net(spatial_features)
print("Spatial params: ", spatial_params)

log_probs.shape:  torch.Size([1, 256])
log_probs.shape (reshaped):  torch.Size([16, 16])
Spatial params:  [6, 7]


In [54]:
# Sanity check: the pixel probabilities must sum to 1.
# `spatial_params` is the sampled coordinate list (not a tensor of log-probs),
# so the check is done on `probs` instead.
probs.sum()

tensor(1.0000, grad_fn=<SumBackward0>)

In [37]:
# Collapse the spatial features into one value per channel
nonspatial_features = nonspatial_features_net(spatial_features)
print("nonspatial_features: ", nonspatial_features.shape)

x.shape:  torch.Size([1, 12, 256])
x.shape:  torch.Size([1, 12, 256])
x.shape (FeaturewiseMaxPool):  torch.Size([1, 12])
x.shape:  torch.Size([1, 12])
x.shape:  torch.Size([1, 12])
nonspatial_features:  torch.Size([1, 12])


In [43]:
# Sample an action id from the non-spatial features (rebinds log_prob and probs)
a, log_prob, probs = action_net(nonspatial_features)
print("Action sampled: ", a)

Action sampled:  [1]
