In [1]:
import numpy as np
from tqdm.auto import tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F

In [2]:
class ConvBlock(nn.Module):
    """2-D convolution optionally followed by batch normalisation.

    No activation is applied here; callers apply their own non-linearity to
    the returned tensor.

    Args:
        in_filters: Number of input channels.
        out_filters: Number of output channels.
        kernel_size: Side length of the square convolution kernel. Odd sizes
            keep the spatial dimensions unchanged ("same" padding); even
            sizes cannot be padded symmetrically and will change the size.
        batchnorm: When true, a ``BatchNorm2d`` layer is applied after the
            convolution.
    """

    def __init__(self, in_filters, out_filters, kernel_size=3, batchnorm=True):
        super().__init__()
        # Pad by half the kernel so the spatial size is preserved for odd
        # kernels. A fixed padding of 1 is only right for kernel_size=3; with
        # a 1x1 kernel it would grow each spatial dimension by 2.
        self.c1 = nn.Conv2d(in_filters, out_filters, kernel_size, padding=kernel_size // 2)
        self.bn = batchnorm
        # Only create the normalisation layer when it is actually used, so a
        # disabled block carries no unused parameters.
        self.b1 = nn.BatchNorm2d(out_filters) if batchnorm else None

    def forward(self, x):
        """Apply the convolution and, if enabled, batch normalisation."""
        x = self.c1(x)
        if self.bn and self.b1 is not None:
            x = self.b1(x)
        return x

In [3]:
class ResidualBlock(nn.Module):
    """Two stacked ``ConvBlock`` layers wrapped in a skip connection.

    The input is added back onto the output of the second convolution and the
    activation is applied to the sum.

    Args:
        filters: Channel count used for both the input and the output of each
            convolution.
        kernel_size: Kernel size forwarded to both ``ConvBlock`` layers.
        batchnorm: Batch-norm flag forwarded to both ``ConvBlock`` layers.
        activ: Callable applied after the first convolution and after the
            residual sum.
    """

    def __init__(self, filters, kernel_size=3, batchnorm=True, activ=F.relu):
        super().__init__()
        self.activ = activ
        # Both convolutions share one configuration (channels in == channels out).
        conv_kwargs = dict(
            in_filters=filters,
            out_filters=filters,
            kernel_size=kernel_size,
            batchnorm=batchnorm,
        )
        self.conv1 = ConvBlock(**conv_kwargs)
        self.conv2 = ConvBlock(**conv_kwargs)

    def forward(self, x):
        """Return ``activ(conv2(activ(conv1(x))) + x)``."""
        hidden = self.conv1(x)
        hidden = self.activ(hidden)
        hidden = self.conv2(hidden)
        # Skip connection: requires the conv output to have the same shape as x.
        return self.activ(hidden + x)

In [4]:
class PolicyHead(nn.Module):
    """Policy head: convolution, activation, then a linear layer and softmax.

    Args:
        filters: Output channels of the head's convolution.
        kernel_size: Kernel size of the head's convolution.
        batchnorm: Batch-norm flag forwarded to the ``ConvBlock``.
        activ: Callable applied to the convolution output.
        board_size: Side length of the square board; fixes the width of the
            linear layer.
        res_filters: Channel count of the incoming feature map.

    The linear layer emits ``board_size * board_size + 2`` values per sample
    (presumably one per board point plus two extra moves -- confirm against
    the move encoding used by the caller).
    """

    def __init__(self, filters=2, kernel_size=1, batchnorm=True, activ=F.relu, board_size=13, res_filters=256):
        super().__init__()
        self.conv = ConvBlock(in_filters=res_filters, out_filters=filters, kernel_size=kernel_size, batchnorm=batchnorm)
        self.activ = activ
        # Assumes the convolution leaves a board_size x board_size map per
        # channel; otherwise the flattened width will not match this value.
        self.fc_infeatures = filters * board_size * board_size
        self.fc = nn.Linear(in_features=self.fc_infeatures, out_features=board_size * board_size + 2)

    def forward(self, x):
        """Return per-sample move probabilities of shape ``(batch, moves)``."""
        x = self.activ(self.conv(x))
        # Flatten everything except the batch dimension.
        x = x.flatten(start_dim=1)
        x = self.fc(x)
        # Normalise over the move dimension (dim=1) so each sample's
        # probabilities sum to 1; dim=0 would normalise across the batch.
        return torch.softmax(x, dim=1)

In [24]:
class ValueHead(nn.Module):
    """Value head: convolution, activation, two linear layers, then tanh.

    Args:
        filters: Output channels of the head's convolution.
        kernel_size: Kernel size of the head's convolution.
        batchnorm: Batch-norm flag forwarded to the ``ConvBlock``.
        activ: Callable applied to the convolution output and to the hidden
            linear layer.
        board_size: Side length of the square board; fixes the width of the
            first linear layer.
        res_filters: Channel count of the incoming feature map.
    """

    def __init__(self, filters=1, kernel_size=1, batchnorm=True, activ=F.relu, board_size=13, res_filters=256):
        super().__init__()
        self.conv = ConvBlock(in_filters=res_filters, out_filters=filters, kernel_size=kernel_size, batchnorm=batchnorm)
        self.activ = activ
        # Assumes the convolution leaves a board_size x board_size map per
        # channel; otherwise the flattened width will not match this value.
        self.fc_infeatures = filters * board_size * board_size
        self.fc1 = nn.Linear(in_features=self.fc_infeatures, out_features=256)
        self.fc2 = nn.Linear(in_features=256, out_features=1)

    def forward(self, x):
        """Return one value per sample, shape ``(batch, 1)``, in ``[-1, 1]``."""
        x = self.activ(self.conv(x))
        # Flatten everything except the batch dimension.
        x = x.flatten(start_dim=1)
        # Non-linearity between the two linear layers; without it they would
        # collapse into a single affine map.
        x = self.activ(self.fc1(x))
        x = self.fc2(x)
        return torch.tanh(x)