In [1]:
import torch
import torch.nn as nn
import numpy as np
import torch.nn.functional as F

In [2]:
class Connect4Model(nn.Module):
    """Convolutional policy/value network for a single-channel board.

    The trunk is a stack of 3x3 convolution blocks and residual blocks that
    ends in a one-channel feature map. That map is flattened and fed to two
    linear heads: one scoring every action and one producing a scalar value.

    Args:
        board_size: Number of cells on the board (rows * cols). The conv
            blocks built here use kernel 3 / stride 1 / padding 1 and the last
            one has a single output channel, so this is also the length of
            the flattened feature vector the heads receive.
        action_size: Number of outputs of the action head.
    """

    def __init__(self, board_size, action_size):
        super().__init__()
        # NOTE: the attribute name and the order of the entries determine the
        # state_dict keys ("layers.<idx>...."), so they must stay stable for
        # previously saved checkpoints to load.
        self.layers = nn.ModuleList(
            [
                ConvBlock(in_channels=1, out_channels=32, kernel_size=3, stride=1, padding=1),
                ConvBlock(in_channels=32, out_channels=64, kernel_size=3, stride=1, padding=1),
                ConvBlock(in_channels=64, out_channels=128, kernel_size=3, stride=1, padding=1),
                ResidualBlock(channels=128, num_repeats=4),
                ConvBlock(in_channels=128, out_channels=64, kernel_size=3, stride=1, padding=1),
                ResidualBlock(channels=64, num_repeats=4),
                ConvBlock(in_channels=64, out_channels=32, kernel_size=3, stride=1, padding=1),
                ConvBlock(in_channels=32, out_channels=1, kernel_size=3, stride=1, padding=1),
                nn.Flatten()
            ]
        )

        # Two heads on our network, both reading the flattened trunk output.
        self.action_head = nn.Linear(in_features=board_size, out_features=action_size)
        self.value_head = nn.Linear(in_features=board_size, out_features=1)

    def forward(self, x):
        """Run the trunk and both heads.

        Args:
            x: Float tensor shaped (batch, 1, rows, cols) with
                rows * cols == board_size.

        Returns:
            A tuple ``(action_probs, value)``: ``action_probs`` is the softmax
            of the action head over dim 1, shape (batch, action_size);
            ``value`` is the tanh of the value head, shape (batch, 1).
        """
        for layer in self.layers:
            x = layer(x)
        action_logits = self.action_head(x)
        value_logit = self.value_head(x)

        return F.softmax(action_logits, dim=1), torch.tanh(value_logit)

    def predict(self, board):
        """Evaluate ``board`` without gradients and return numpy results.

        Switches the module to eval mode (it is left in eval mode afterwards),
        runs a forward pass under ``torch.no_grad()`` and returns the outputs
        for the first item of the batch only.

        Args:
            board: Tensor accepted by ``forward``.

        Returns:
            A tuple ``(pi, v)`` of numpy arrays: ``pi`` has shape
            (action_size,) and ``v`` has shape (1,).
        """
        self.eval()
        with torch.no_grad():
            # Call the module rather than forward() so registered hooks run.
            pi, v = self(board)

        # Tensors produced under no_grad carry no graph, so they can be
        # converted directly without going through the legacy `.data` view.
        return pi.cpu().numpy()[0], v.cpu().numpy()[0]

class ConvBlock(nn.Module):
    """Conv2d followed by BatchNorm2d and ReLU.

    Args:
        in_channels: Number of channels of the input tensor.
        out_channels: Number of channels produced by the convolution (and
            normalized by the batch norm).
        **kwargs: Forwarded unchanged to ``nn.Conv2d`` (e.g. ``kernel_size``,
            ``stride``, ``padding``).
    """

    def __init__(self, in_channels, out_channels, **kwargs):
        super().__init__()
        # Attribute names end up in the state_dict keys, so keep them as-is.
        self.conv = nn.Conv2d(in_channels, out_channels, **kwargs)
        self.bn = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU()

    def forward(self, x):
        """Return ``relu(bn(conv(x)))``."""
        return self.relu(self.bn(self.conv(x)))
        
class ResidualBlock(nn.Module):
    """Chain of ``num_repeats`` residual units with a channel bottleneck.

    Each unit is a 1x1 ConvBlock that halves the channel count followed by a
    3x3 ConvBlock (padding 1) that restores it; the unit's output is added to
    its input.

    Args:
        channels: Channel count of the input and of the output.
        num_repeats: How many residual units to stack.
    """

    def __init__(self, channels, num_repeats=1):
        super().__init__()
        self.num_repeats = num_repeats
        bottleneck = channels // 2
        self.layers = nn.ModuleList(
            [
                nn.Sequential(
                    ConvBlock(channels, bottleneck, kernel_size=1),
                    ConvBlock(bottleneck, channels, kernel_size=3, padding=1),
                )
                for _ in range(num_repeats)
            ]
        )

    def forward(self, x):
        """Apply every unit in order, adding each unit's output to its input."""
        out = x
        for unit in self.layers:
            # Out-of-place add: the caller's tensor is never modified.
            out = out + unit(out)
        return out

In [3]:
# Network for a 6x7 board (42 flattened cells) with 7 possible actions.
model = Connect4Model(board_size=6 * 7, action_size=7)

# Read the saved weights, mapping every tensor onto the CPU.
# NOTE: torch.load unpickles the file, so only load checkpoints you trust.
checkpoint = torch.load('move/AlphaZero.pth', map_location='cpu')
model.load_state_dict(checkpoint)

# Inference mode: BatchNorm layers use their running statistics.
model.eval()
# End the cell on a call that returns None so the module repr returned by
# model.eval() is not echoed as the cell output.
print()




In [16]:
# Test position: an empty 6x7 board with ones in the last row, columns 4-6.
board = np.zeros((6, 7))
board[5, 4:] = 1

# The first conv layer takes a float tensor shaped (batch, channel, rows, cols),
# so add the two leading singleton dimensions.
board_tensor = torch.from_numpy(board).float().unsqueeze(0).unsqueeze(0)

# Inference only: no_grad avoids building an autograd graph for this call.
with torch.no_grad():
    # model(...) returns (action probabilities, value); keep the probabilities.
    move = model(board_tensor)[0]
print(move)

tensor([[0.5677, 0.1362, 0.0136, 0.1098, 0.0047, 0.1365, 0.0316]],
       grad_fn=<SoftmaxBackward>)


In [17]:

# Display the board array that the surrounding cells feed to the model.
print(board)

[[0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 1. 1. 1.]]


In [21]:
# Shape the board as (batch, channel, rows, cols) for the network.
board_tensor = torch.from_numpy(board).float().unsqueeze(0).unsqueeze(0)

# Inference only: no_grad avoids building an autograd graph for this call.
with torch.no_grad():
    policy = model(board_tensor)[0]

# Index of the action with the highest probability, as a plain Python int.
move = policy.argmax().item()
print(move)

0
