In [1]:
import functools

import torch
from torch import nn
import torch.nn.functional as F
from torch.optim import lr_scheduler
from torch.nn import init

In [15]:
class ResBlock(nn.Module):
    """Residual block: two 3x3 conv + BatchNorm layers around an identity skip.

    Computes ``relu(x + F(x))`` where ``F`` is conv-BN-ReLU-conv-BN. Both
    convolutions keep ``c`` channels and use ``padding=1`` with a 3x3 kernel,
    so ``F(x)`` has the same shape as ``x`` and the skip can be added
    element-wise.

    Args:
        c: number of input (and output) channels.
    """

    def __init__(self, c):
        super(ResBlock, self).__init__()
        res = [
            nn.Conv2d(c, c, kernel_size=3, padding=1),
            nn.BatchNorm2d(c),
            nn.ReLU(True),
            nn.Conv2d(c, c, kernel_size=3, padding=1),
            nn.BatchNorm2d(c),
        ]
        self.res = nn.Sequential(*res)
        self.relu = nn.ReLU(True)

    def forward(self, x):
        """Return ``relu(x + res(x))``; the input tensor ``x`` is not modified."""
        out = x + self.res(x)
        # The activation must be applied to the sum. Applying it to `x`
        # (as before) discarded the residual branch, and because the ReLU is
        # in-place it also mutated the caller's tensor that `self.res` saved
        # for backward. `out` is a fresh tensor, so in-place ReLU is safe here.
        return self.relu(out)

In [22]:
class Net(nn.Module):
    """Residual conv tower with separate policy and value heads.

    The trunk is a 3x3 conv stem (conv-BN-ReLU) followed by ``depth``
    ``ResBlock`` modules, all with ``n_c`` channels and same-size padding.
    On top of the trunk:

    * the policy head is a 1x1 conv to ``n_a`` planes, returned as raw
      (unbounded) logits of shape ``(N, n_a, H, W)``;
    * the value head is a 1x1 conv to one plane (BN + ReLU), flattened and
      fed to an MLP ending in ``tanh``, giving shape ``(N, 1)`` in ``[-1, 1]``.

    Args:
        in_c: number of input channels.
        n_c: channel width of the residual tower.
        depth: number of residual blocks.
        n_a: number of policy output planes (73 presumably matches an
            AlphaZero-style chess move encoding -- TODO confirm).
        n_hidden: hidden units in the value MLP.
        width: board side length. The input's ``H * W`` must equal
            ``width**2`` because the value MLP's first Linear layer expects
            exactly that many features.
    """

    def __init__(self, in_c, n_c=256, depth=8, n_a=73, n_hidden=256, width=8):
        super(Net, self).__init__()
        # Stem: lift the `in_c` input planes to `n_c` channels.
        model = [
            nn.Conv2d(in_c, n_c, kernel_size=3, padding=1),
            nn.BatchNorm2d(n_c),
            nn.ReLU(True),
        ]
        for _ in range(depth):
            model += [ResBlock(n_c)]
        self.net = nn.Sequential(*model)
        # Policy head: raw logits. No activation here -- a ReLU would clamp
        # every negative logit to 0 and zero its gradient. Kept as a
        # Sequential so state_dict keys (`policy.0.*`) match old checkpoints.
        self.policy = nn.Sequential(
            nn.Conv2d(n_c, n_a, kernel_size=1),
        )
        # Value head, part 1: squeeze the trunk to a single plane.
        self.value_conv = nn.Sequential(
            nn.Conv2d(n_c, 1, kernel_size=1),
            nn.BatchNorm2d(1),
            nn.ReLU(True)
        )
        # Value head, part 2: MLP over the flattened plane; tanh bounds the
        # scalar output to [-1, 1].
        self.value = nn.Sequential(
            nn.Linear(width**2, n_hidden),
            nn.ReLU(True),
            nn.Linear(n_hidden, 1),
            nn.Tanh()
        )

    def forward(self, x):
        """Return ``(policy_logits, value)`` for a batch ``x`` of shape ``(N, in_c, H, W)``.

        ``policy_logits`` has shape ``(N, n_a, H, W)``; ``value`` has shape
        ``(N, 1)``.
        """
        x = self.net(x)
        p = self.policy(x)
        v = self.value_conv(x)
        # (N, 1, H, W) -> (N, H*W) for the value MLP.
        v = self.value(v.flatten(start_dim=1))
        return p, v
            

In [23]:
net = Net(in_c=10)  # 10 input planes; every other hyperparameter uses Net's defaults

In [24]:
# Smoke-run one all-ones dummy input: (batch=1, channels=10, 8, 8).
# With Net's defaults, p should be (1, 73, 8, 8) and v should be (1, 1).
board = torch.ones(1, 10, 8, 8)
p, v = net(board)

In [25]:
from torchinfo import summary

# Layer-by-layer output shapes and parameter counts for a batch of dummy inputs.
batch_size = 5
summary(model=net, input_data=[torch.ones((batch_size, 10, 8, 8))])

Layer (type:depth-idx)                   Output Shape              Param #
Net                                      --                        --
├─Sequential: 1-1                        [5, 256, 8, 8]            --
│    └─Conv2d: 2-1                       [5, 256, 8, 8]            23,296
│    └─BatchNorm2d: 2-2                  [5, 256, 8, 8]            512
│    └─ReLU: 2-3                         [5, 256, 8, 8]            --
│    └─ResBlock: 2-4                     [5, 256, 8, 8]            --
│    │    └─Sequential: 3-1              [5, 256, 8, 8]            1,181,184
│    │    └─ReLU: 3-2                    [5, 256, 8, 8]            --
│    └─ResBlock: 2-5                     [5, 256, 8, 8]            --
│    │    └─Sequential: 3-3              [5, 256, 8, 8]            1,181,184
│    │    └─ReLU: 3-4                    [5, 256, 8, 8]            --
│    └─ResBlock: 2-6                     [5, 256, 8, 8]            --
│    │    └─Sequential: 3-5              [5, 256, 8, 8]           