In [1]:
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm, trange
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
%matplotlib inline

In [2]:
class Policy(nn.Module):
    def __init__(self):
        super(Policy, self).__init__()
        # 30 features
        #9x9 -> 5x5 -> 3x3
        self.conv = nn.Conv2d(1,30,3)
        self.pool = nn.MaxPool2d((2,2), ceil_mode = True)
        self.l1 = nn.Linear(30*4*4,200)
        self.l2 = nn.Linear(200,120)
        self.l3 = nn.Linear(120, 81)
    def forward(self, x):
        x = self.pool(F.relu(self.conv(x)))
        x= x.view(-1, self.num_flat_features(x))
        x = F.relu(self.l1(x))
        x = F.relu(self.l2(x))
        x = self.l3(x)
        return x

    def num_flat_features(self, x):
        size = x.size()[1:]  # all dimensions except the batch dimension
        num_features = 1
        for s in size:
            num_features *= s
        return num_features
    

In [3]:
import os
import re
import go

def randReflect(board, move):
    if np.random.randint(2):
        board = board.T
        (x, y) = move
        move = (y,x)
    return board, move

class NinebyNineGames(Dataset):
    def __init__(self, root_dir, transform = None):
        '''Read all the .sgf files from root_dir'''
        self.transform = transform
        self.root_dir = root_dir
        self.sgf_files = [entry for entry in os.scandir(self.root_dir) if entry.path.endswith(".sgf")]
        self.games = []
        for sgf in self.sgf_files:
            with open(sgf, 'r') as f:
                match = re.findall(r"[BW]\[(\w\w)\]", f.read())
            self.games.append([(ord(m[0])-97, ord(m[1])-97) for m in match])
            
    def __len__(self):
        return sum( len(game) for game in self.games)
    
    def __getitem__(self, idx):
        '''return the position as a (9,9) ndarray
        and next move as a squashed coordinate'''
        g_idx = 0
        count = 0
        while(idx > count + len(self.games[g_idx]) -1 ) :
            count += len(self.games[g_idx])
            g_idx += 1
        g = go.Game( moves = self.games[g_idx])
        for _ in range(idx-count):
            try:
                g.play_move()
            except go.IllegalMove:
                print(f"game: {g_idx} move: {idx- count}")
        board, move =  np.array(g.get_board()).reshape(9,9) , g.moves[idx - count]
        if self.transform:
            board , move = self.transform(board, move)
        return torch.Tensor(board).unsqueeze(0) , go.squash(move)

In [4]:
dir = r"/home/jupyter/BokeGo/data"
data = NinebyNineGames(dir, transform = randReflect)
dataloader = DataLoader(data, batch_size = 16, shuffle = True)

pi = Policy()
err = nn.CrossEntropyLoss()
optimizer = optim.Adam(pi.parameters(), lr = 0.01)

for i, data in enumerate(dataloader):
    b, m = data
    print(b.size())
    break


torch.Size([16, 1, 9, 9])


In [5]:
#GPU
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
pi.to(device)

for epoch in trange(2):
    loss = 0.0
    for i, data in enumerate(dataloader,0):
        inputs, moves = data
        inputs, moves = inputs.to(device), moves.to(device)
        
        optimizer.zero_grad()
        outputs = pi(inputs)
        #backprop
        loss = err(outputs, moves)
        loss.backward()
        optimizer.step()
    
        loss += loss.item()
        if i % 1000 == 1 :
            print(f"Epoch: {epoch+1}, Loss: {loss:.3f}")
            loss = 0.0


  0%|          | 0/2 [00:00<?, ?it/s]

Epoch: 1, Loss: 9.455
game: 8409 move: 54
game: 8409 move: 54
game: 8409 move: 54
game: 8409 move: 54
game: 8409 move: 54
game: 8409 move: 54
game: 7896 move: 66
game: 7896 move: 66
game: 1286 move: 57
game: 1286 move: 57
game: 1286 move: 57
game: 1286 move: 57
game: 9617 move: 81
game: 9617 move: 81
game: 9617 move: 81
game: 9617 move: 81
game: 9617 move: 81
game: 9617 move: 81
game: 9617 move: 81
game: 9617 move: 81
game: 9617 move: 81
game: 9617 move: 81
game: 9617 move: 81
game: 9617 move: 81
game: 9617 move: 81
game: 9617 move: 81
game: 9617 move: 81
game: 9617 move: 81
game: 9617 move: 81
game: 9617 move: 81
game: 9617 move: 81
game: 7516 move: 94
game: 7516 move: 94
game: 7516 move: 94
game: 7516 move: 94
game: 7516 move: 94
game: 7516 move: 94
game: 7516 move: 94
game: 5633 move: 62
game: 5633 move: 62


KeyboardInterrupt: 

In [None]:
for i , data in enumerate(dataloader):
    board, move = data
    print(data)