In [None]:
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm, trange
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
%matplotlib inline

In [None]:
class Policy(nn.Module):
    """Small convolutional policy network for 9x9 boards.

    Takes a batch of single-channel 9x9 inputs and produces 81 raw scores
    (logits) per sample. No softmax is applied here; the scores are meant to
    be fed to a loss such as ``nn.CrossEntropyLoss`` that normalises them.

    Shape flow for an input of shape (N, 1, 9, 9):
        conv 3x3, no padding -> (N, 30, 7, 7)
        max-pool 2x2         -> (N, 30, 3, 3)
        flatten              -> (N, 270)
        l1 -> l2 -> l3       -> (N, 200) -> (N, 120) -> (N, 81)
    """

    def __init__(self):
        super().__init__()
        # 3x3 convolution without padding: 9x9 --> 7x7 x (30 features)
        self.conv = nn.Conv2d(1, 30, 3)
        # 2x2 max-pooling: 7x7 --> 3x3 (the odd last row/column is dropped)
        self.pool = nn.MaxPool2d((2, 2))
        self.l1 = nn.Linear(30 * 3 * 3, 200)
        self.l2 = nn.Linear(200, 120)
        self.l3 = nn.Linear(120, 81)

    def forward(self, x):
        """Compute the 81 move logits for each board in ``x``.

        Args:
            x: float tensor of shape (N, 1, 9, 9).

        Returns:
            Tensor of shape (N, 81) holding unnormalised scores.
        """
        x = self.pool(F.relu(self.conv(x)))
        # Collapse every non-batch dimension. Unlike ``view`` this also
        # accepts activations that are not contiguous in memory.
        x = torch.flatten(x, 1)
        x = F.relu(self.l1(x))
        x = F.relu(self.l2(x))
        return self.l3(x)

    def num_flat_features(self, x):
        """Return the number of elements per sample, i.e. the product of all
        dimensions of ``x`` except the first (batch) one.

        ``forward`` no longer needs this helper; it is kept so existing
        callers continue to work.
        """
        num_features = 1
        for s in x.size()[1:]:  # all dimensions except the batch dimension
            num_features *= s
        return num_features
    

In [None]:
import os
import re
import go
from parse_sgf import get_moves

def randReflect(board, move):
    """Randomly mirror a position across the main diagonal.

    A single draw from ``np.random.randint(2)`` decides the outcome: on 1 the
    board is transposed and the two coordinates of ``move`` are swapped so the
    move still refers to the same stone; on 0 both arguments are returned
    unchanged.

    Args:
        board: 2-D array-like object exposing ``.T`` (e.g. a numpy array).
        move: pair of coordinates.

    Returns:
        ``(board, move)``, possibly reflected. The transposed board is
        whatever ``board.T`` yields (for numpy arrays, a view).
    """
    flip = np.random.randint(2)
    if not flip:
        return board, move
    first, second = move
    return board.T, (second, first)

class NinebyNineGames(Dataset):
    """Dataset of (position, next move) pairs taken from .sgf game records.

    Every move of every game found in ``root_dir`` is one sample, so the
    dataset length is the total number of moves over all games. Samples are
    numbered game by game, in the order of ``self.sgf_files``.

    Attributes (set by ``__init__``):
        transform: optional callable ``(board, move) -> (board, move)``.
        root_dir: directory that was scanned.
        sgf_files: directory entries whose path ends in ".sgf", sorted by path.
        games: one move sequence per file, as returned by ``get_moves``.

    Note: the lookup index is built once from ``games``; call
    ``_build_index()`` again if ``games`` is modified afterwards.
    """

    def __init__(self, root_dir, transform = None):
        '''Read all the .sgf files from root_dir.

        Args:
            root_dir: directory containing the .sgf files (not recursive).
            transform: optional callable applied to each (board, move) pair
                before it is converted to the returned sample.
        '''
        self.transform = transform
        self.root_dir = root_dir
        # os.scandir yields entries in arbitrary order; sort them so that a
        # given index always refers to the same sample across runs. The
        # ``with`` block releases the directory handle deterministically.
        with os.scandir(self.root_dir) as entries:
            self.sgf_files = sorted(
                (entry for entry in entries if entry.path.endswith(".sgf")),
                key=lambda entry: entry.path,
            )
        self.games = [get_moves(sgf) for sgf in self.sgf_files]
        self._build_index()

    def _build_index(self):
        '''Cache the running total of moves after each game for fast lookup.'''
        self._game_ends = np.cumsum(
            [len(game) for game in self.games], dtype=np.int64
        )

    def __len__(self):
        '''Total number of moves across all loaded games.'''
        return int(self._game_ends[-1]) if len(self._game_ends) else 0

    def _locate(self, idx):
        '''Map a flat sample index to (game index, move index in that game).

        Negative indices count from the end, as for lists.

        Raises:
            IndexError: if idx is outside [-len(self), len(self)).
        '''
        total = len(self)
        idx = int(idx)
        if idx < 0:
            idx += total
        if not 0 <= idx < total:
            raise IndexError(f"sample index out of range for {total} samples")
        # First game whose cumulative end exceeds idx; side="right" skips
        # over games that contain no moves.
        g_idx = int(np.searchsorted(self._game_ends, idx, side="right"))
        start = int(self._game_ends[g_idx - 1]) if g_idx else 0
        return g_idx, idx - start

    def __getitem__(self, idx):
        '''Return the sample at flat index idx.

        Returns:
            A pair ``(board, target)`` where ``board`` is a float tensor of
            shape (1, 9, 9) holding the position before the move, and
            ``target`` is ``go.squash(move)`` for the move played next
            (presumably an integer in 0 - 80; see the go module).

        Raises:
            IndexError: if idx is out of range.
        '''
        g_idx, move_idx = self._locate(idx)
        g = go.Game( moves = self.games[g_idx])

        # Replay the moves that precede the requested one so the board
        # shows the position in which that move was chosen.
        for _ in range(move_idx):
            g.play_move()
        board = np.array(g.get_board()).reshape(9, 9)
        move = g.moves[move_idx]
        if self.transform:
            board, move = self.transform(board, move)
        return torch.Tensor(board).unsqueeze(0), go.squash(move)

In [None]:
# Directory holding the .sgf training games. Named DATA_DIR rather than `dir`
# so the builtin dir() is not shadowed for the rest of the session.
DATA_DIR = r"/home/jupyter/BokeGo/data"
# Kept under its own name (`dataset`) so the per-batch loop variable used
# during training cannot overwrite it.
dataset = NinebyNineGames(DATA_DIR)
dataloader = DataLoader(dataset, batch_size = 32, shuffle = True)

# Model, loss and optimiser. CrossEntropyLoss consumes the raw scores
# produced by the network together with integer class targets.
pi = Policy()
err = nn.CrossEntropyLoss()
optimizer = optim.Adam(pi.parameters(), lr = 0.001)

In [None]:
#GPU
# Train on the first CUDA device when one is available, otherwise on the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)
pi.to(device)

epochs = 2
# The loop variable must be called `epoch`: the progress message below reads it.
for epoch in range(epochs):
    running_loss = 0.0
    # Wrap the loader itself (not the enumerate object) so tqdm can read its
    # length and show the total number of batches.
    for i, (inputs, moves) in enumerate(tqdm(dataloader)):
        inputs, moves = inputs.to(device), moves.to(device)

        optimizer.zero_grad()
        outputs = pi(inputs)

        #backprop
        loss = err(outputs, moves)
        loss.backward()
        optimizer.step()

        # running_loss is the SUM of the batch losses since the last report
        # (10000 batches), not their mean.
        running_loss += loss.item()
        if i % 10000 == 9999 :
            print(f"Epoch: {epoch+1}, Loss: {running_loss:.3f}")
            running_loss = 0.0


In [None]:
# Store the model's state_dict (its parameters and buffers) rather than the
# module object itself; load it back with Policy().load_state_dict(...).
torch.save(
    obj=pi.state_dict(),
    f=r"/home/jupyter/BokeGo/policy_train_1.pt",
)