In [9]:
import torch
import torch.nn as nn
import numpy as np
from typing import TypeVar, List, Tuple
from utils import addActivation

In [12]:
class ConvBlock(nn.Module):
    """
    Convolutional portion of the rl model.

    Builds one nn.Conv2d per entry of ``parameters``, in the order given, and
    hands the growing module list to ``addActivation`` after each convolution.
    The resulting layers are wrapped in a single nn.Sequential stored as
    ``self.block_``.
    """
    def __init__(
                self,
                parameters: List[Tuple[int, int, int, int]],
                activation: str = "relu"):
        """
        params:
        parameters: a list of tuples for Conv2ds. Each tuple describes one
        conv2d and must hold exactly four values:
        (in_channels, out_channels, kernel_size, stride).
        The length of the list is the number of Conv2ds inserted, and they are
        inserted in the order the tuples are listed. An empty list produces an
        empty nn.Sequential.

        activation: activation function to use, forwarded unchanged to
        addActivation. Options are relu, leaky
        """
        super().__init__()
        modules = []
        for in_channels, out_channels, kernel_size, stride in parameters:
            modules.append(
                nn.Conv2d(in_channels, out_channels, kernel_size, stride=stride)
            )
            # addActivation (imported from utils) returns the list to keep
            # building on; presumably it appends the activation module that
            # matches `activation` -- confirm against utils.
            modules = addActivation(modules, activation)
        self.block_ = nn.Sequential(*modules)

    def forward(self, X: torch.Tensor) -> torch.Tensor:
        """Pass X through the layers of ``self.block_`` in order and return the result."""
        return self.block_(X)

In [3]:
# Random dummy batch of shape (3, 2, 16, 16); the 2 on axis 1 matches the
# in_channels of the stand-alone conv built in the next cell.
X = torch.rand((3,2,16,16))

In [4]:
# Hyperparameters for a single stand-alone convolution; stride and padding
# are left at nn.Conv2d's defaults.
in_channels, out_channels, kernel_size = 2, 16, 8
conv = nn.Conv2d(
    in_channels=in_channels,
    out_channels=out_channels,
    kernel_size=kernel_size,
)

In [6]:
# With an 8x8 kernel and no stride/padding arguments, each 16-pixel side
# shrinks to 16 - 8 + 1 = 9.
conv(X).shape

torch.Size([3, 16, 9, 9])

In [24]:
# Each entry configures one Conv2d layer, in order:
# (in_channels, out_channels, kernel_size, stride)
conv_params = [(4, 16, 8, 4), (16, 32, 4, 2)]
conv_block = ConvBlock(parameters=conv_params)

In [33]:
# Random input batch of shape (10, 4, 84, 84); the 4 on axis 1 matches the
# in_channels of the first layer in conv_params.
X = torch.rand((10,4,84,84))

In [34]:
# Same computation as X = conv_block(X), unrolled layer by layer so that the
# shape after every module can be printed.
# NOTE(review): X is rebound to the block's output here, so this cell is not
# re-runnable on its own -- a second run would feed a 32-channel tensor into a
# first conv that expects 4 channels. Re-create X before re-running.
for layer in conv_block.block_:
    X = layer(X)
    print(X.shape)

torch.Size([10, 16, 20, 20])
torch.Size([10, 16, 20, 20])
torch.Size([10, 32, 9, 9])
torch.Size([10, 32, 9, 9])


In [35]:
# Shape after the last layer of conv_block (X was rebound by the loop above).
X.shape

torch.Size([10, 32, 9, 9])

In [28]:
# Show the module structure of the block as built from conv_params.
print(conv_block)

ConvBlock(
  (block_): Sequential(
    (0): Conv2d(4, 16, kernel_size=(8, 8), stride=(4, 4))
    (1): ReLU()
    (2): Conv2d(16, 32, kernel_size=(4, 4), stride=(2, 2))
    (3): ReLU()
  )
)


In [37]:
# Collapse every axis except the first (batch) into one feature vector per
# sample and inspect the resulting size.
X.flatten(start_dim = 1).shape

torch.Size([10, 2592])

In [38]:
# Sanity check: channels * height * width of the conv output equals the
# flattened feature size shown above.
32*9*9

2592