In [1]:
import torch
import numpy as np
from encode import encode_board
from game import Connect4
from mcts import MCTS

In [2]:
class vanilla(torch.nn.Module):
    """Minimal fully-connected network with a policy head and a value head.

    Each sample is flattened to ``input_dim`` features, passed through one
    shared linear layer, and then fed to two independent heads:

    * ``policy`` -- linear layer followed by a softmax over ``n_actions``.
    * ``value``  -- linear layer followed by ``tanh`` (output in [-1, 1]).

    Args:
        input_dim: Number of features per sample once the input is flattened.
        n_actions: Number of entries in the policy output.
        hidden_dim: Width of the shared hidden layer.
    """

    def __init__(self, input_dim, n_actions, hidden_dim=100):
        # nn.Module must be initialised before attributes are attached:
        # assigning a Parameter or sub-Module earlier raises, and keeping
        # every assignment after this call avoids that trap entirely.
        super().__init__()
        self.input_dim = input_dim
        self.l1 = torch.nn.Linear(input_dim, hidden_dim)
        self.policy = torch.nn.Sequential(
            torch.nn.Linear(hidden_dim, n_actions),
            torch.nn.Softmax(dim=1)
        )
        self.value = torch.nn.Sequential(
            torch.nn.Linear(hidden_dim, 1),
            torch.nn.Tanh()
        )

    def forward(self, x):
        """Run the network on ``x``.

        Args:
            x: Tensor whose elements can be reshaped to ``(-1, input_dim)``.

        Returns:
            A ``(policy, value)`` tuple. Both tensors go through ``.squeeze()``,
            so every size-1 dimension is dropped: a single sample yields a
            policy of shape ``(n_actions,)`` and a 0-d value, while a batch of
            ``B > 1`` samples yields ``(B, n_actions)`` and ``(B,)``.
        """
        flat = x.reshape(-1, self.input_dim)
        # NOTE(review): there is no activation between l1 and the heads, so
        # the hidden layer is purely linear -- confirm this is intentional.
        hidden = self.l1(flat)
        return self.policy(hidden).squeeze(), self.value(hidden).squeeze()

In [3]:
# Board geometry, matching the (6, 7, 3) array that encode_board yields.
N_ROWS, N_COLS, N_PLANES = 6, 7, 3
# One policy output per column (assumes a move is a column choice -- confirm
# against Connect4).
N_ACTIONS = N_COLS

net = vanilla(N_ROWS * N_COLS * N_PLANES, N_ACTIONS)

In [4]:
# Fresh game environment with default constructor arguments.
env = Connect4()

In [5]:
# Encode the current position as a float tensor, prepend a batch axis, and
# run one forward pass to inspect the (policy, value) pair.
t = torch.tensor(encode_board(env)).float().unsqueeze(0)
net(t)

(tensor([0.1602, 0.1512, 0.1305, 0.1384, 0.1343, 0.1458, 0.1396],
        grad_fn=<SqueezeBackward0>),
 tensor(-0.1011, grad_fn=<SqueezeBackward0>))

In [6]:
# Check that the encoded input carries the leading batch axis.
t.shape

torch.Size([1, 6, 7, 3])

In [7]:
# Settings handed to MCTS; "n_sim" is presumably the number of simulations
# per move -- confirm against the mcts module.
args = {
    "n_sim": 10,
    "exploration_constant": 1.0,
}

In [8]:
# Build the tree search around the network and the settings dict.
# NOTE(review): the variable shares its name with the `mcts` module; a later
# `import mcts` would clash with it.
mcts = MCTS(net, args)

In [9]:
# Display the board array currently held by the environment.
env.current_state

array([[' ', ' ', ' ', ' ', ' ', ' ', ' '],
       [' ', ' ', ' ', ' ', ' ', ' ', ' '],
       [' ', ' ', ' ', ' ', ' ', ' ', ' '],
       [' ', ' ', ' ', ' ', ' ', ' ', ' '],
       [' ', ' ', ' ', ' ', ' ', ' ', ' '],
       [' ', ' ', ' ', ' ', ' ', ' ', ' ']], dtype='<U32')

In [10]:
# Scratch check: re-encode the board, move it to CPU and show its shape.
# NOTE(review): the same encoding is already built as `t`; this cell looks
# like leftover debugging.
torch.tensor(encode_board(env)).float().unsqueeze(0).cpu().shape

torch.Size([1, 6, 7, 3])

In [11]:
# Single search call from the current position.
# NOTE(review): the displayed result (0.10108617) is the negation of the value
# the network returned for this position (-0.1011), so this call presumably
# only evaluates the root -- confirm against MCTS.search. The shape and board
# dump shown before the number look like debug prints inside search.
mcts.search(env)

(6, 7, 3)
[[' ' ' ' ' ' ' ' ' ' ' ' ' ']
 [' ' ' ' ' ' ' ' ' ' ' ' ' ']
 [' ' ' ' ' ' ' ' ' ' ' ' ' ']
 [' ' ' ' ' ' ' ' ' ' ' ' ' ']
 [' ' ' ' ' ' ' ' ' ' ' ' ' ']
 [' ' ' ' ' ' ' ' ' ' ' ' ' ']]


0.10108617

In [14]:
# Inspect the Qsa attribute after the search call above.
# NOTE(review): it is empty here; presumably one search call only expands the
# root and records no edge statistics -- verify in the mcts module.
mcts.Qsa

{}