In [1]:
import ataxx
import ataxx.pgn
f = open('data/000cde9536aa3c15a6722452cb733632.txt', 'r')
game = ataxx.pgn.parse(f.read())

In [2]:
board = ataxx.Board(game.headers['FEN'], 512)
def board_to_list(board, inverse=None):
    if inverse is None:
        inverse = (board.turn == ataxx.WHITE)
    if not inverse:
        mapping = {ataxx.EMPTY: 0, ataxx.BLACK: 1, ataxx.WHITE: -1}
    else:
        mapping = {ataxx.EMPTY: 0, ataxx.BLACK: -1, ataxx.WHITE: 1}
    return [[mapping[board.get(x, y)] for x in range(7)] for y in range(7)]
print(list(reversed(board_to_list(board))))

[[0, 0, 1, 0, 0, -1, -1], [0, 0, 1, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0], [-1, 0, 0, 0, 0, 0, 0], [-1, 0, 0, 0, 0, 0, 1]]


In [3]:
board.makemove(list(game.main_line())[0].move)
print(list(game.main_line())[0].move)
print(board_to_list(board))

b5
[[1, 0, 0, 0, 0, 0, -1], [1, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0], [0, -1, 0, 0, 0, 0, 0], [0, 0, -1, 0, 0, 0, 0], [0, 0, -1, 0, 0, 1, 1]]


In [4]:
iter_line = iter(game.main_line())

In [5]:
 _a = next(iter_line)

In [6]:
print(str(_a.move))
print(str(_a.children[0].move))
_a.move

b5
e7


<ataxx.Move at 0x7f12781c2190>

In [7]:
from torch.utils.data import IterableDataset, DataLoader
from random import sample
from os import listdir
_direction_map = {(0,2): 1,
                  (1,2): 2,
                  (2,2): 3,
                  (2,1): 4,
                  (2,0): 5,
                  (2,-1): 6,
                  (2,-2): 7,
                  (1,-2): 8,
                  (0,-2): 9,
                  (-1,-2): 10,
                  (-2,-2): 11,
                  (-2,-1): 12,
                  (-2,0): 13,
                  (-2,1): 14,
                  (-2,2): 15,
                  (-1,2): 16
                  }
def direction(fr_x, fr_y, to_x, to_y):
    return _direction_map[(to_x-fr_x, to_y-fr_y)]

class PgnDataset(IterableDataset):
    def __init__(self, filename=None):
        super(PgnDataset).__init__()
        if filename is None:
            filename = ['data/'+x for x in listdir('data')]
        self.file = filename
        self.file_queue = sample(self.file, len(self.file))
    def __iter__(self):
        '''Current player is always 1'''
        self.iter = GameIter(self.file_queue[:])
        return self.iter
    
class GameIter():
    def __init__(self, file_queue):
        self.file_queue = file_queue
        filename = self.file_queue.pop()
        self.load_game(filename)

    def load_game(self, filename):
        with open(filename, 'r') as f:
            game = ataxx.pgn.parse(f.read())
        self.game = game
        self.board = ataxx.Board(game.headers['FEN'], 512)
        self.game_iter = iter(game.main_line())

    def __iter__(self):
        return self

    def __next__(self):
        if self.board.gameover():
            if len(self.file_queue) == 0:
                raise StopIteration
            else:
                filename = self.file_queue.pop()
                self.load_game(filename)
        current_move = next(self.game_iter)
        self.board.makemove(current_move.move)

        # Make input tensor
        board_tensor = torch.tensor(board_to_list(self.board), dtype=torch.float)
        board_player = F.relu(board_tensor)
        board_opponent = F.relu(-board_tensor)
        countdown = (512 - self.board.fullmove_clock) / 512
        turn_tensor = torch.full(board_tensor.size(), countdown, dtype=torch.float)
        input_stack = (board_player, board_opponent, turn_tensor)

        # Make value tensor
        if self.game.headers['Result'] == '1-0': # black is winner
            v = 1 if self.board.turn == ataxx.BLACK else -1
        elif self.game.headers['Result'] == '0-1': # white is winner
            v = -1 if self.board.turn == ataxx.BLACK else 1
        else:
            v = 0

        # Make policy tensor
        policy = 0
        if len(current_move.children) == 0:
            next_move = current_move.move
        else:
            next_move = current_move.children[0].move
        move_tensor = torch.zeros((17, 7, 7))
        if next_move.is_single():
            move_tensor[0, next_move.to_x, next_move.to_y] = 1
            move_direction = 0
        else:
            move_direction = direction(next_move.fr_x, next_move.fr_y, next_move.to_x, next_move.to_y)
            move_tensor[move_direction, next_move.to_x, next_move.to_y] = 1
        policy = move_direction * 7 * 7 + next_move.to_x * 7 + next_move.to_y
        if policy < 0:
            policy = 0

        return {'input': torch.stack(input_stack, dim=0),
                'value': torch.tensor([v], dtype=torch.float),
                'policy': torch.tensor(policy, dtype=torch.long)}
        

In [8]:
len(listdir('data'))

19942

In [9]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim

class EvalNet(nn.Module):
    def __init__(self, size=7):
        self.size = size
        super(EvalNet, self).__init__()
        self.fc_value = nn.Linear(size*size, 1, bias=False)

    def forward(self, batch):
        x = batch.view(-1, self.size*self.size)
        v1 = self.fc_value(F.relu(x))
        v2 = self.fc_value(F.relu(-x))
        return torch.tanh(v1 - v2)

class ResidualBlock(nn.Module):
    def __init__(self, size=7, channel=16):
        super(ResidualBlock, self).__init__()
        self.size = size
        self.channel = channel
        self.conv1a = nn.Conv2d(channel, channel//2, kernel_size=3, padding=1, bias=False)
        self.conv1b = nn.Conv2d(channel, channel//2, kernel_size=5, padding=2, bias=False)
        self.bn1 = nn.BatchNorm2d(channel)
        self.relu = nn.ReLU(inplace=True)
        self.conv2a = nn.Conv2d(channel, channel//2, kernel_size=3, padding=1, bias=False)
        self.conv2b = nn.Conv2d(channel, channel//2, kernel_size=5, padding=2, bias=False)
        self.bn2 = nn.BatchNorm2d(channel)

    def forward(self, input):
        xa = self.conv1a(input)
        xb = self.conv1b(input)
        x = torch.cat((xa, xb), dim=1)
        x = self.bn1(x)
        x = self.relu(x)
        xa = self.conv2a(x)
        xb = self.conv2b(x)
        x = torch.cat((xa, xb), dim=1)
        x = self.bn2(x)
        x += input
        x = self.relu(x)
        return x

class SidusAtaxxNet(nn.Module):
    def __init__(self, size=7):
        self.size = size
        
        super(SidusAtaxxNet, self).__init__()
        self.conv1a = nn.Conv2d(3, 8, kernel_size=3, padding=1, bias=False)
        self.conv1b = nn.Conv2d(3, 4, kernel_size=5, padding=2, bias=False)
        self.conv1c = nn.Conv2d(3, 4, kernel_size=7, padding=3, bias=False)
        self.relu = nn.ReLU(inplace=True)
        self.bn1 = nn.BatchNorm2d(16)

        self.res_layers = nn.Sequential(*[ResidualBlock() for i in range(6)])

        self.conv_val = nn.Conv2d(16, 1, kernel_size=1, bias=False)
        self.bn_val = nn.BatchNorm2d(1)
        self.fc_val1 = nn.Linear(size*size, size*size)
        self.fc_val2 = nn.Linear(size*size, 1)

        self.conv_pol = nn.Conv2d(16, 3, kernel_size=1, bias=False)
        self.bn_pol = nn.BatchNorm2d(3)
        self.fc_pol = nn.Linear(3*size*size, 17*size*size)
        
    def forward(self, batch):
        x = batch.view(-1, 3, self.size, self.size)

        xa = self.relu(self.conv1a(x))
        xb = self.relu(self.conv1b(x))
        xc = self.relu(self.conv1b(x))
        x = torch.cat((xa, xb, xc), dim=1)
        x = self.bn1(x)
        x = self.relu(x)

        x = self.res_layers(x)

        x_val = self.conv_val(x)
        x_val = self.bn_val(x_val)
        x_val = self.relu(x_val)
        x_val = x_val.view(-1, 1*self.size*self.size)
        x_val = self.fc_val1(x_val)
        x_val = self.relu(x_val)
        x_val = self.fc_val2(x_val)

        x_pol = self.conv_pol(x)
        x_pol = self.bn_pol(x_pol)
        x_pol = self.relu(x_pol)
        x_pol = x_pol.view(-1, 3*self.size*self.size)
        x_pol = self.fc_pol(x_pol)
        
        return F.log_softmax(x_pol, dim=1), torch.tanh(x_val)

In [10]:
net = SidusAtaxxNet()
net.load_state_dict(torch.load('model.pt'))
def train(use_local_batch=False, reduce_data=None):
    if use_local_batch:
        if reduce_data != None:
            datafeeder = PgnDataset(['data/'+x for x in listdir('data')][:reduce_data])
        else:
            datafeeder = PgnDataset(['data/'+x for x in listdir('data')])
    else:
        datafeeder = DataLoader()
    # Generate batch
    dataloader = DataLoader(datafeeder, batch_size=32)
    # Don't forget dataset augumentation!
    # Train the model
    criterion_pol = nn.NLLLoss()
    criterion_val = nn.MSELoss()
    optim = torch.optim.Adam(net.parameters(), lr=0.001)
    net.train()
    net.cuda()
    for epoch in range(10):
        total_loss = 0
        loss_pol_sum = 0
        loss_val_sum = 0
        print('epoch', epoch)
        for i, data in enumerate(dataloader):
            optim.zero_grad()
            output = net(data['input'].to('cuda'))
            policy, value = output
            loss_pol = criterion_pol(policy, data['policy'].to('cuda'))
            loss_val = criterion_val(value, data['value'].to('cuda'))
            loss = loss_pol + 5 * loss_val
            loss.backward()
            optim.step()

            loss_pol_sum += loss_pol.detach().cpu().item()
            loss_val_sum += loss_val.detach().cpu().item()
            total_loss += loss.detach().cpu().item()
            print_interval = 1000
            if i % print_interval == print_interval - 1:
                print('[%d, %5d] loss: %.8f\tpol: %.8f\tval: %.8f' %
                  (epoch + 1, i + 1, total_loss/print_interval, loss_pol_sum/print_interval, loss_val_sum/print_interval))
                total_loss = 0
                loss_pol_sum = 0
                loss_val_sum = 0
                torch.save({'epoch': epoch,
                            'step': i,
                            'model_state_dict': net.state_dict(),
                            'optimizer_state_dict': optim.state_dict(),
                            'loss': loss}, 'checkpoint.tar')
        torch.save(net.state_dict(), 'model.pt')
    # Compete with the best model

def main(reduce_data=None):
    train(use_local_batch=True, reduce_data=reduce_data)

In [11]:
main()

epoch 0
[1,  1000] loss: 12.56895155	pol: 8.99324589	val: 0.71514113
[1,  2000] loss: 6.50627532	pol: 3.48140749	val: 0.60497356
[1,  3000] loss: 5.62427087	pol: 2.95286922	val: 0.53428033
[1,  4000] loss: 5.23985610	pol: 2.65146947	val: 0.51767733
[1,  5000] loss: 5.25085525	pol: 2.44507223	val: 0.56115660
[1,  6000] loss: 5.10035934	pol: 2.41169110	val: 0.53773364
[1,  7000] loss: 4.85898345	pol: 2.32592551	val: 0.50661159
[1,  8000] loss: 4.89109935	pol: 2.31540198	val: 0.51513947
[1,  9000] loss: 4.80934175	pol: 2.25203333	val: 0.51146169
[1, 10000] loss: 4.99558542	pol: 2.27726388	val: 0.54366431
[1, 11000] loss: 4.88507943	pol: 2.24559147	val: 0.52789759
[1, 12000] loss: 4.60812975	pol: 2.20129177	val: 0.48136759
[1, 13000] loss: 4.66350889	pol: 2.17827464	val: 0.49704685
[1, 14000] loss: 4.99946029	pol: 2.19628846	val: 0.56063437
[1, 15000] loss: 4.72954157	pol: 2.19178023	val: 0.50755227
[1, 16000] loss: 4.54320518	pol: 2.14176703	val: 0.48028763
[1, 17000] loss: 4.40359769	pol

KeyboardInterrupt: 

In [None]:
main()

In [None]:
pgn = PgnDataset(['data/'+x for x in listdir('data')][:10])

In [None]:
dataloader = DataLoader(pgn, batch_size=8)
for i in dataloader:
    print(i)

In [None]:
net.eval()
board_test = ataxx.Board(fen='1xoxxxx/xxooxxx/xooooox/xxooooo/xxooooo/oo1ooxx/ooxoxxx')
input_test = torch.tensor(board_to_list(board_test), dtype=torch.float32)

In [None]:
print(input_test)
with torch.no_grad():
    print(net(input_test))

In [None]:
import numpy as np
x = np.array(board_to_list(board_test))
x = x.reshape(-1, 7*7)
x = np.matmul(x, net.fc1.weight.detach().numpy().transpose()) + net.fc1.bias.detach().numpy()
x = np.maximum(x, 0)
x = np.matmul(x, net.fc2.weight.detach().numpy().transpose()) + net.fc2.bias.detach().numpy()
x = np.maximum(x, 0)
x = np.matmul(x, net.fc3.weight.detach().numpy().transpose()) + net.fc3.bias.detach().numpy()
x = np.maximum(x, 0)
x = np.matmul(x, net.fc_value.weight.detach().numpy().transpose()) + net.fc_value.bias.detach().numpy()
x

In [12]:
torch.save(net.state_dict(), 'model.pt')

In [None]:
torch.sqrt(torch.tensor(4, dtype=torch.float)).item()