In [4]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [9]:
class View(nn.Module):

    def __init__(self, *shape):
        # shape is a list
        super(View, self).__init__()
        self.shape = shape

    def forward(self, input):
        return input.view(*self.shape)

class policy(nn.Module):
    
    def __init__(self):
        super(policy, self).__init__()
        
        self.sense_im = nn.Sequential( # BxCx32x32
                    nn.Conv2d(3, 32, kernel_size=5, stride=1, padding=2), # Bx32x32x32
                    nn.MaxPool2d(kernel_size=3, stride=2), # Bx32x15x15
                    nn.ReLU(inplace=True),
                    nn.Conv2d(32, 32, kernel_size=5, stride=1, padding=2), # Bx32x15x15
                    nn.ReLU(inplace=True),
                    nn.AvgPool2d(kernel_size=3, stride=2), # Bx32x7x7
                    nn.Conv2d(32, 64, kernel_size=5, stride=1, padding=2), # Bx64x7x7
                    nn.ReLU(inplace=True),
                    nn.AvgPool2d(kernel_size=3, stride=2), # Bx64x3x3
                    View(-1, 576),
                    nn.Linear(576, 256),
                    nn.ReLU(inplace=True)
                )
        self.sense_pro = nn.Sequential(
                    nn.Linear(2, 16),
                    nn.ReLU(inplace=True)
                 )
        self.fuse = nn.Sequential( # 256+16
                nn.Linear(272, 256),
                nn.BatchNorm1d(256),
                nn.ReLU(inplace=True),
                nn.Linear(256, 256), # Bx256
                nn.BatchNorm1d(256)
            )
        self.aggregate = nn.LSTM(input_size=256, hidden_size=256, num_layers=1)
        self.act = nn.Sequential( # self.rnn_hidden_size
                nn.Linear(256, 128),
                nn.BatchNorm1d(128),
                nn.ReLU(inplace=True),
                nn.Linear(128, 128),
                nn.BatchNorm1d(128),
                nn.ReLU(inplace=True),
                nn.Linear(128, 2)
           )
        
    def forward(self, x, l, hidden=None):
        
        batch_size = x.size(0)

        if hidden is None:
            hidden = [torch.zeros(1, batch_size, 256), # hidden state: (num_layers, batch_size, hidden size)
                      torch.zeros(1, batch_size, 256)] # cell state  :(num_layers, batch_size, hidden size)
        # ---- Sense the inputs ----
        x1 = self.sense_im(x)
        x2 = self.sense_pro(l)
        
        x = torch.cat([x1, x2], dim=1)
        # ---- Fuse the representations ----
        x = self.fuse(x)

        # ---- Update the belief state about the panorama ----
        # Note: input to aggregate lstm has to be seq_length x batch_size x input_dims
        # Since we are feeding in the inputs one by one, it is 1 x batch_size x 256
        x, hidden = self.aggregate(x.view(1, *x.size()), hidden)
        
        
        act_input = hidden[0].view(batch_size, -1)
        # Concatenate the relative change

        # ---- Predict the action propabilities ----
        out = self.act(act_input)
        
        return out, hidden

In [10]:
model = policy()

In [11]:
print(model)

policy(
  (sense_im): Sequential(
    (0): Conv2d(3, 32, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
    (1): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
    (2): ReLU(inplace=True)
    (3): Conv2d(32, 32, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
    (4): ReLU(inplace=True)
    (5): AvgPool2d(kernel_size=3, stride=2, padding=0)
    (6): Conv2d(32, 64, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
    (7): ReLU(inplace=True)
    (8): AvgPool2d(kernel_size=3, stride=2, padding=0)
    (9): View()
    (10): Linear(in_features=576, out_features=256, bias=True)
    (11): ReLU(inplace=True)
  )
  (sense_pro): Sequential(
    (0): Linear(in_features=2, out_features=16, bias=True)
    (1): ReLU(inplace=True)
  )
  (fuse): Sequential(
    (0): Linear(in_features=272, out_features=256, bias=True)
    (1): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
    (3): Linear(in_features=256,