In [1]:
import torch
from torch.distributions.dirichlet import Dirichlet

In [2]:
from hexconvolution import NoMCTSModel
from hexboard import Board
from hexgame import HexGame

In [3]:
# Seed torch's RNG (all devices) so the Dirichlet exploration noise and the
# self-played game below are reproducible under Restart & Run All.
# NOTE(review): if hexgame/hexconvolution also draw from `random` or numpy, seed those too — TODO confirm.
SEED = 0
torch.manual_seed(SEED)

# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [5]:
# Single source of truth for the board dimensions used by the model and its noise.
BOARD_SIZE = 11
DIRICHLET_ALPHA = 0.03  # concentration of the exploration noise (small => spiky samples)

# One Dirichlet component per board cell (BOARD_SIZE**2 == 121 for an 11x11 board),
# so the noise vector always matches the model's board_size.
# NOTE(review): presumably NoMCTSModel mixes this noise into its move distribution
# with weight `noise_level` (AlphaZero-style root noise) — TODO confirm in hexconvolution.
noise = Dirichlet(torch.full((BOARD_SIZE * BOARD_SIZE,), DIRICHLET_ALPHA))
model = NoMCTSModel(board_size=BOARD_SIZE, layers=5, noise=noise, noise_level=0.25)

In [6]:
# Create an 11x11 Hex board; the output below shows every cell starts at 0.
# Keep `size` in sync with the board_size passed to NoMCTSModel.
board = Board(size=11)

In [7]:
# Display the board's repr: the cell grid followed by its set of legal moves.
board

Board
[[0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]]
Legal moves
{(7, 3), (6, 9), (0, 7), (1, 6), (0, 10), (3, 7), (2, 5), (8, 5), (5, 8), (4, 0), (10, 8), (9, 0), (6, 7), (5, 5), (10, 7), (7, 6), (6, 10), (0, 4), (1, 1), (4, 10), (3, 2), (2, 6), (8, 2), (4, 5), (9, 3), (6, 0), (7, 5), (0, 1), (3, 1), (9, 9), (7, 8), (2, 1), (8, 9), (9, 4), (5, 1), (10, 3), (7, 2), (1, 5), (3, 6), (2, 2), (1, 10), (8, 6), (4, 1), (10, 9), (9, 7), (6, 4), (5, 4), (10, 4), (7, 1), (0, 5), (1, 0), (0, 8), (3, 5), (2, 7), (8, 3), (5, 10), (4, 6), (10, 10), (9, 2), (6, 1), (5, 7), (7, 4), (0, 2), (1, 3), (4, 8), (3, 0), (2, 8), (9, 8), (8, 0), (6, 2), (3, 10), (8,

In [8]:
# Bundle the board, the model and the compute device into a HexGame driver;
# play_moves() is called on this object in the next cell.
# NOTE(review): presumably HexGame plays its moves on this same `board` object — TODO confirm.
hexgame = HexGame(board, model, device)

In [9]:
# Let the model play out a game and collect the recorded data as three outputs.
# NOTE(review): presumably one entry per move played (board_states has 81 rows below) — confirm in hexgame.
(board_states, moves, targets) = hexgame.play_moves()

In [10]:
# Display the game after play_moves(): final board grid plus remaining legal and illegal moves.
hexgame

Board
[[-1.  0.  1.  1.  0.  1.  1.  1. -1.  0.  0.]
 [-1. -1. -1.  0. -1. -1. -1.  1.  0.  1.  0.]
 [ 1.  0.  0.  1. -1.  0.  1.  0. -1.  0.  0.]
 [ 0.  0.  0. -1.  0.  1.  0. -1.  1.  0.  1.]
 [-1.  1.  0.  0.  1. -1.  1.  0.  1.  1. -1.]
 [ 0.  1.  1.  1. -1.  1.  1. -1. -1.  0. -1.]
 [ 1. -1.  0.  0. -1.  1. -1. -1. -1.  0.  0.]
 [ 1. -1.  1. -1.  0.  0. -1.  0.  1. -1.  1.]
 [ 1.  0. -1.  1.  0. -1.  1.  0. -1.  0.  0.]
 [ 1. -1.  1.  0.  1.  1. -1. -1.  0.  1. -1.]
 [ 1. -1.  1.  0.  0. -1.  1.  1. -1. -1. -1.]]
Legal moves
{(6, 9), (0, 10), (2, 5), (6, 10), (0, 4), (3, 2), (9, 3), (7, 5), (0, 1), (3, 1), (2, 1), (8, 9), (10, 3), (3, 6), (2, 2), (1, 10), (10, 4), (2, 7), (7, 4), (1, 3), (3, 0), (9, 8), (6, 2), (8, 10), (5, 0), (3, 9), (8, 7), (4, 2), (0, 9), (3, 4), (8, 4), (5, 9), (4, 7), (7, 7), (2, 9), (8, 1), (6, 3), (2, 10), (1, 8), (4, 3)}
Illegal moves
{(7, 3), (0, 7), (1, 6), (3, 7), (8, 5), (5, 8), (4, 0), (10, 8), (9, 0), (6, 7), (5, 5), (10, 7), (7, 6), (1, 1), (4, 10)

In [25]:
# Inspect the collected states; this run produced (81, 2, 11, 11).
board_states.shape

torch.Size([81, 2, 11, 11])

In [27]:
# Sum the two channels of each stored state (dim=1), then flatten every board
# to a row vector: (n_states, 2, H, W) -> (n_states, H*W).
# flatten(start_dim=1) derives H*W from the tensor itself instead of hard-coding 121,
# so this also works for board sizes other than 11.
# NOTE(review): presumably the two channels are per-player stone planes — TODO confirm.
a = board_states.sum(dim=1).flatten(start_dim=1)

In [30]:
# Reduce each row of `a` to a single total: one value per stored state.
b = torch.sum(a, dim=1)

In [34]:
# Flag every state whose total in `b` exceeds 1, shown as 0 or 1e10 (float64).
# `.to(torch.float64)` keeps the result on `b`'s device; the old
# `.type(torch.DoubleTensor)` always produced a CPU tensor.
# The cell ends on the expression so Jupyter renders it (no print()).
(b > 1).to(torch.float64) * 10**10

tensor([0.0000e+00, 0.0000e+00, 1.0000e+10, 1.0000e+10, 1.0000e+10, 1.0000e+10,
        1.0000e+10, 1.0000e+10, 1.0000e+10, 1.0000e+10, 1.0000e+10, 1.0000e+10,
        1.0000e+10, 1.0000e+10, 1.0000e+10, 1.0000e+10, 1.0000e+10, 1.0000e+10,
        1.0000e+10, 1.0000e+10, 1.0000e+10, 1.0000e+10, 1.0000e+10, 1.0000e+10,
        1.0000e+10, 1.0000e+10, 1.0000e+10, 1.0000e+10, 1.0000e+10, 1.0000e+10,
        1.0000e+10, 1.0000e+10, 1.0000e+10, 1.0000e+10, 1.0000e+10, 1.0000e+10,
        1.0000e+10, 1.0000e+10, 1.0000e+10, 1.0000e+10, 1.0000e+10, 1.0000e+10,
        1.0000e+10, 1.0000e+10, 1.0000e+10, 1.0000e+10, 1.0000e+10, 1.0000e+10,
        1.0000e+10, 1.0000e+10, 1.0000e+10, 1.0000e+10, 1.0000e+10, 1.0000e+10,
        1.0000e+10, 1.0000e+10, 1.0000e+10, 1.0000e+10, 1.0000e+10, 1.0000e+10,
        1.0000e+10, 1.0000e+10, 1.0000e+10, 1.0000e+10, 1.0000e+10, 1.0000e+10,
        1.0000e+10, 1.0000e+10, 1.0000e+10, 1.0000e+10, 1.0000e+10, 1.0000e+10,
        1.0000e+10, 1.0000e+10, 1.0000e+

In [49]:
# Broadcast the per-state flag (b > 1) across every cell of the matching row of `a`
# and scale by 1e10; rows whose flag is False become all zeros.
# `.unsqueeze(1)` gives shape (n_states, 1), which broadcasts against `a` without expand_as.
# The flag is cast with `.to(torch.get_default_dtype())`: it keeps the same default float
# dtype as the old `.type(torch.Tensor)`, but stays on `b`'s device instead of being
# forced to the CPU (which would raise a device mismatch when `a` lives on CUDA).
c = a * (b > 1).unsqueeze(1).to(torch.get_default_dtype()) * 10**10

In [50]:
# Inspect row 5 of `c` (the sixth stored state); ending on the expression lets
# Jupyter render it directly instead of going through print().
c[5]

tensor([0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
        0.0000e+00, 0.0000e+00, 1.0000e+10, 0.0000e+00, 0.0000e+00, 0.0000e+00,
        0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
        0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
        0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
        0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
        0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
        0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
        0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
        0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
        0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
        0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
        0.0000e+00, 0.0000e+00, 0.0000e+