In [6]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [7]:
class CNN_QNET(nn.Module):
    """Small convolutional network: three conv+ReLU layers followed by one linear layer.

    The input area is divided into a grid of ``width // block_size`` by
    ``height // block_size`` cells. Every convolution is unpadded with stride 1,
    so each one removes ``kernel - 1`` cells from both spatial sides. The size
    of the flattened feature vector fed to ``fc`` is therefore known up front.
    Given the class name, the outputs are presumably Q-values, one per action
    (NOTE(review): confirm against the training code).

    Args:
        width: Horizontal extent of the input area; divided by ``block_size``.
        height: Vertical extent of the input area; divided by ``block_size``.
        block_size: Size of one grid cell in the same units as width/height.
        in_channels: Number of channels of the input tensor (default 3).
        n_actions: Number of output values per sample (default 3).

    Raises:
        ValueError: If either grid side is too small to survive the conv stack.
            With the default kernels (5, 3, 3), each side needs at least 9 cells.
    """

    def __init__(self, width, height, block_size, in_channels=3, n_actions=3):
        super().__init__()
        self.width = int(width // block_size)
        self.height = int(height // block_size)

        self.hyperparams = {
            "kernel_1": 5,
            "kernel_2": 3,
            "kernel_3": 3,
            "out_1": 10,
            "out_2": 20,
            "out_3": 30,
        }
        kernels = (
            self.hyperparams["kernel_1"],
            self.hyperparams["kernel_2"],
            self.hyperparams["kernel_3"],
        )

        # Each unpadded, stride-1 convolution shrinks every spatial side by (kernel - 1).
        shrink = sum(k - 1 for k in kernels)
        conv_width = self.width - shrink
        conv_height = self.height - shrink
        # Without this check, two negative sides would multiply into a bogus
        # positive feature count and the model could never run forward().
        if conv_width < 1 or conv_height < 1:
            raise ValueError(
                f"grid of {self.width}x{self.height} cells is too small for the "
                f"convolution stack; each side needs at least {shrink + 1} cells"
            )
        self.hyperparams["in_fc"] = self.hyperparams["out_3"] * conv_width * conv_height

        self.conv1 = nn.Conv2d(
            in_channels, self.hyperparams["out_1"], self.hyperparams["kernel_1"]
        )
        self.conv2 = nn.Conv2d(
            self.hyperparams["out_1"], self.hyperparams["out_2"], self.hyperparams["kernel_2"]
        )
        self.conv3 = nn.Conv2d(
            self.hyperparams["out_2"], self.hyperparams["out_3"], self.hyperparams["kernel_3"]
        )
        self.fc = nn.Linear(self.hyperparams["in_fc"], n_actions)

    def forward(self, x):
        """Run the conv stack and the linear head.

        Args:
            x: Batched input, presumably shaped (N, in_channels, H, W), where H and W
                are the grid dimensions (TODO confirm). The flattened size is a
                product, so either (height, width) or (width, height) orientation
                yields the feature count ``fc`` expects.

        Returns:
            Tensor of shape (N, n_actions).
        """
        # convolution layers
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = F.relu(self.conv3(x))

        # fully connected layer: keep the batch dim, flatten channels x spatial
        x = torch.flatten(x, 1)
        return self.fc(x)

In [8]:
# 840 // 20 = 42 and 640 // 20 = 32 grid cells along each side.
cnn = CNN_QNET(width=840, height=640, block_size=20)

In [18]:
# Show the shape of every learnable tensor in registration order
# (conv1.weight, conv1.bias, ..., fc.weight, fc.bias).
for param in cnn.parameters():
    print(param.shape)

torch.Size([10, 3, 5, 5])
torch.Size([10])
torch.Size([20, 10, 3, 3])
torch.Size([20])
torch.Size([30, 20, 3, 3])
torch.Size([30])
torch.Size([3, 24480])
torch.Size([3])


In [19]:
# nn.Module renders its layer tree through __repr__.
print(repr(cnn))

CNN_QNET(
  (conv1): Conv2d(3, 10, kernel_size=(5, 5), stride=(1, 1))
  (conv2): Conv2d(10, 20, kernel_size=(3, 3), stride=(1, 1))
  (conv3): Conv2d(20, 30, kernel_size=(3, 3), stride=(1, 1))
  (fc): Linear(in_features=24480, out_features=3, bias=True)
)
