In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [22]:
class PolicyNet(nn.Module):
    """Actor-critic policy network: a shared conv/FC trunk feeding two heads.

    The trunk is two unpadded 3x3 convolutions (stride 1), each of which
    shrinks every spatial dimension by 2, followed by one fully connected
    layer. The action head yields a softmax distribution over ``n_actions``;
    the value head yields one scalar per input.

    With the defaults the network expects a single-channel 6x6 input, either
    unbatched ``(1, 6, 6)`` or batched ``(N, 1, 6, 6)``, and returns
    probabilities of shape ``(N, 4)`` (``N = 1`` when unbatched) and values
    of shape ``(N, 1)``.

    Args:
        in_channels: Number of input channels. Default 1.
        board_size: Height and width of the square input; must be >= 5 so
            the two 3x3 convolutions leave at least a 1x1 feature map.
            Default 6.
        n_actions: Size of the action distribution. Default 4.
        hidden_size: Width of the shared fully connected layer. Default 256.
        verbose: If True, print every intermediate tensor shape during
            ``forward`` (shape debugging). Defaults to True to keep the
            original notebook output; set False for training loops.

    Raises:
        ValueError: If ``board_size`` is smaller than 5.
    """

    def __init__(self, in_channels=1, board_size=6, n_actions=4,
                 hidden_size=256, verbose=True):
        super().__init__()
        if board_size < 5:
            raise ValueError(
                f'board_size must be >= 5 for two 3x3 convolutions, got {board_size}')
        self.board_size = board_size
        self.verbose = verbose
        # Each unpadded 3x3 conv with stride 1 shrinks H and W by 2, so conv2
        # emits 32 x (board_size - 4) x (board_size - 4) features.
        self.flat_features = 32 * (board_size - 4) ** 2
        # Layer names and creation order match the original definition so
        # state_dict keys and parameter-initialisation order are unchanged.
        self.conv1 = nn.Conv2d(in_channels, 16, kernel_size=3, stride=1)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, stride=1)
        self.fc1 = nn.Linear(self.flat_features, hidden_size)
        self.action_head = nn.Linear(hidden_size, n_actions)
        self.value_head = nn.Linear(hidden_size, 1)

    def _log_shape(self, label, tensor):
        """Print ``label: shape`` when verbose shape debugging is enabled."""
        if self.verbose:
            print(f'{label}: {tensor.shape}')

    def forward(self, x):
        """Compute action probabilities and state values.

        Args:
            x: Input of shape ``(in_channels, board_size, board_size)`` or
                ``(N, in_channels, board_size, board_size)``.

        Returns:
            Tuple ``(probs, values)`` with shapes ``(N, n_actions)`` and
            ``(N, 1)``; ``N = 1`` for an unbatched input.

        Raises:
            ValueError: If the spatial size of ``x`` does not match
                ``board_size``. Without this check the flatten below would
                silently fold the mismatch into the batch dimension.
        """
        self._log_shape('input shape', x)

        x = F.relu(self.conv1(x))
        self._log_shape('shape after conv 1', x)

        x = F.relu(self.conv2(x))
        self._log_shape('shape after conv 2', x)

        # The trailing dims are (32, H', W') for batched and unbatched inputs
        # alike; they must match what fc1 was sized for.
        side = self.board_size - 4
        if x.shape[-2:] != (side, side):
            raise ValueError(
                f'expected a {self.board_size}x{self.board_size} input, but conv '
                f'features have spatial shape {tuple(x.shape[-2:])}')
        # view(-1, ...) also turns an unbatched (32, H', W') map into a batch of 1.
        x = x.view(-1, self.flat_features)
        self._log_shape('shape after flattening', x)

        x = F.relu(self.fc1(x))
        self._log_shape('shape after FC layer', x)

        action_scores = self.action_head(x)
        self._log_shape('action scores shape', action_scores)

        state_values = self.value_head(x)
        self._log_shape('state values shape', state_values)

        return F.softmax(action_scores, dim=-1), state_values


In [23]:
# Seed before building the network so both the weight initialisation and the
# sample input are reproducible under Restart & Run All.
torch.manual_seed(0)
pn = PolicyNet()

# One unbatched single-channel 6x6 board; the conv layers treat it as (C, H, W).
x = torch.randn((1, 6, 6))
x

tensor([[[ 0.2799, -0.7359, -0.4680, -0.5308, -1.3366, -2.9482],
         [-0.8219, -1.3620,  0.2111, -1.3075,  0.3672, -1.7308],
         [ 0.7033,  0.5948,  0.5323, -0.7808, -1.2091, -0.5517],
         [ 1.0189, -0.2869, -1.5958, -1.0252,  0.3963,  0.2542],
         [ 0.3537,  1.4718,  0.1433, -1.0709,  1.0813,  2.0001],
         [-0.5400, -0.1685, -0.7335, -1.1585,  0.9359, -1.6388]]])


In [25]:
# Run the network on the sample board; forward() prints intermediate tensor shapes.
(probs, value) = pn(x)

input shape: torch.Size([1, 6, 6])
shape after conv 1: torch.Size([16, 4, 4])
shape after conv 2: torch.Size([32, 2, 2])
shape after flattening: torch.Size([1, 128])
shape after FC layer: torch.Size([1, 256])
action scores shape: torch.Size([1, 4])
state values shape: torch.Size([1, 1])


In [27]:
# NOTE(review): doubling the action probabilities has no evident purpose here — confirm intent.
probs.detach().mul(2)

tensor([[0.4755, 0.5239, 0.5186, 0.4820]])