In [1]:
import torch
from torch import nn
import gym
import numpy as np
from go_ai import data
from tqdm import tqdm_notebook

BOARD_SIZE = 5
# Checkpoint path for the trained network. It is written with torch.save
# (a pickled state_dict) despite the .h5 suffix.
MODEL_SAVE_FILE = 'models/actorcritic_{}x{}.h5'.format(BOARD_SIZE, BOARD_SIZE)
LOAD_TRAINED = True  # resume from MODEL_SAVE_FILE instead of starting fresh

# Policy softmax temperature schedule. After every training iteration the
# temperature is multiplied by TEMP_DECAY and clamped below at MIN_TEMP.
# With INIT_TEMP == MIN_TEMP the temperature stays fixed at 1.
TEMP_DECAY = 0.9
INIT_TEMP = 1
MIN_TEMP = 1

go_env = gym.make('gym_go:go-v0', size=BOARD_SIZE)

In [19]:
class PolicyValueNet(nn.Module):
    """Convolutional policy/value network for a square Go board.

    Args:
        board_size: side length of the board. The policy head emits
            ``board_size ** 2 + 1`` action probabilities.
        init_temp: initial softmax temperature applied to the policy logits.
        min_temp: floor that ``decay_temp`` never goes below.
    """

    def __init__(self, board_size, init_temp, min_temp):
        super().__init__()
        self.board_size = board_size
        # Plain attributes (not buffers), so they are not part of state_dict().
        self.temp = init_temp
        self.min_temp = min_temp

        area = board_size ** 2
        # Trunk: five 3x3 same-padding conv -> batchnorm -> relu stages taking
        # the 6 input planes down to one feature plane. Layers are created in
        # the same order as listed, so state_dict keys (main1.0 .. main1.14)
        # and parameter-initialisation order are fixed.
        widths = [6, 64, 128, 128, 64, 1]
        trunk = []
        for c_in, c_out in zip(widths, widths[1:]):
            trunk.extend([nn.Conv2d(c_in, c_out, 3, padding=1), nn.BatchNorm2d(c_out), nn.ReLU()])
        self.main1 = nn.Sequential(*trunk)
        # Shared fully-connected layer over the flattened single feature plane.
        self.main2 = nn.Sequential(nn.Linear(area, area), nn.BatchNorm1d(area), nn.ReLU())
        # Policy head: one logit per board point plus one extra action.
        self.policy = nn.Sequential(nn.Linear(area, area + 1))
        # Value head: a single sigmoid-squashed output in (0, 1).
        self.value = nn.Sequential(nn.Linear(area, 1), nn.Sigmoid())

    def forward(self, state):
        """Return ``(policy, value)`` for a batch of board states.

        Args:
            state: float tensor of shape ``(batch, 6, board_size, board_size)``.
                The conv trunk keeps spatial size and main2 expects
                ``board_size ** 2`` features.

        Returns:
            ``policy``: ``(batch, board_size ** 2 + 1)`` softmax probabilities
            after adding the invalid-move offsets and dividing by ``self.temp``.
            ``value``: ``(batch, 1)`` sigmoid output.
        """
        # Offsets from go_ai.data. Presumably large negatives on illegal moves
        # (TODO confirm).
        offsets = data.batch_invalid_values(state)
        features = torch.flatten(self.main1(state), start_dim=1)
        features = self.main2(features)
        logits = self.policy(features)
        logits += offsets
        logits /= self.temp
        probs = nn.functional.softmax(logits, dim=1)
        return probs, self.value(features)

    def decay_temp(self, decay):
        """Scale the temperature by ``decay``, clamped below at ``min_temp``."""
        self.temp = max(self.temp * decay, self.min_temp)
        
class RandomAgent():
    """Baseline player that samples uniformly among the legal moves.

    It mimics a model's call signature: ``agent(state)`` returns
    ``(probs, None)``, where ``probs`` is a ``(1, num_actions)`` tensor built
    from ``go_env``'s legal-move mask. The last action (pass) is masked out
    whenever more than one move is legal.
    """

    def __call__(self, state):
        # Only the first entry of the batch is inspected.
        legal = go_env.gogame.get_valid_moves(state[0])
        if np.sum(legal) > 1:
            # Another move exists, so avoid passing (final action index).
            legal[-1] = 0
        uniform = legal / np.sum(legal)
        return torch.from_numpy(np.expand_dims(uniform, 0)), None

    def eval(self):
        """No-op so the agent can stand in wherever ``model.eval()`` is called."""
        pass

In [3]:
def play_game(env, model1, model2):
    """Play one game between two policies and record every canonical state.

    Args:
        env: Go environment. It is reset at the start of the game.
        model1: called when ``env.turn() == 0``. It must return
            ``(action_probs, _)`` with ``action_probs`` of shape
            ``(1, env.size * env.size + 1)``.
        model2: called on the other player's turns.

    Returns:
        ``(states, canonical_winners)``:
        - ``states``: every canonical state visited, including the final one.
        - ``canonical_winners``: ``env.get_winner()`` re-expressed per state,
          i.e. ``winner`` on even indices and ``1 - winner`` on odd ones.
    """
    states = []

    env.reset()
    actions = np.arange(0, env.size * env.size + 1)
    state = env.get_canonical_state()
    states.append(state)
    done = False
    while not done:
        mover = model1 if env.turn() == 0 else model2
        state_tensor = torch.from_numpy(state[np.newaxis]).type(torch.FloatTensor)
        # Pure inference: don't build an autograd graph that is thrown away.
        with torch.no_grad():
            action_probs, _ = mover(state_tensor)
        action = np.random.choice(actions, p=action_probs.detach().numpy()[0])
        _, _, done, _ = env.step(action)
        state = env.get_canonical_state()
        states.append(state)
    winner = env.get_winner()
    canonical_winners = [winner if i % 2 == 0 else 1 - winner for i in range(len(states))]
    return states, canonical_winners

def generate_trajectories(env, model1, model2, num_episodes):
    """Play ``num_episodes`` games and pool every visited state with its label.

    Both models are put in eval mode first. Returns two flat, aligned lists:
    the canonical states from every game and the matching canonical winner
    labels produced by ``play_game``.
    """
    for net in (model1, model2):
        net.eval()
    all_states, all_winners = [], []
    progress = tqdm_notebook(range(num_episodes), desc='Trajectory generation')
    for episode in progress:
        game_states, game_winners = play_game(env, model1, model2)
        all_states += game_states
        all_winners += game_winners
        mean_length = len(all_states) / (episode + 1)
        progress.set_postfix_str('Average length: ' + str(mean_length))
    return all_states, all_winners

def _pit_side(env, black, white, num_games, desc, label, track_black):
    """Play ``num_games`` games with fixed colours; return ``(black_wins, white_wins)``.

    ``winners[0]`` from ``play_game`` is read from the first mover's view:
    1 counts for ``black``, 0 for ``white``, and any other value for neither.
    The progress-bar postfix reports the running win rate of the tracked
    colour (black if ``track_black`` else white).
    """
    black_wins = 0
    white_wins = 0
    pbar = tqdm_notebook(range(num_games), desc=desc)
    for i in pbar:
        _, winners = play_game(env, black, white)
        if winners[0] == 1:
            black_wins += 1
        elif winners[0] == 0:
            white_wins += 1
        tracked = black_wins if track_black else white_wins
        pbar.set_postfix_str('{}: {}'.format(label, tracked / (i + 1)))
    return black_wins, white_wins


def pit(env, model1, model2, num_episodes):
    """Play ``model1`` against ``model2``, half the games with each colour.

    Each model plays ``num_episodes // 2`` games as the first mover. Win rates
    are printed over the games actually played (``2 * (num_episodes // 2)``).
    Games that are neither a 1 nor a 0 outcome count for nobody.

    Returns:
        ``(model1_wins, model2_wins)``.
    """
    model1.eval()
    model2.eval()
    games_per_side = num_episodes // 2
    m1_as_black, m2_as_white = _pit_side(env, model1, model2, games_per_side,
                                         'Playing black', 'Black WR', True)
    m2_as_black, m1_as_white = _pit_side(env, model2, model1, games_per_side,
                                         'Playing white', 'White WR', False)
    model1_wins = m1_as_black + m1_as_white
    model2_wins = m2_as_white + m2_as_black
    # An odd num_episodes plays one game fewer; max() avoids dividing by zero
    # when no games are played at all.
    games_played = max(2 * games_per_side, 1)
    print('Model 1 WR: {}'.format(model1_wins / games_played))
    print('Model 2 WR: {}'.format(model2_wins / games_played))
    return model1_wins, model2_wins

In [4]:
def policy_eval(model, opt, states, winners, batch_size):
    """Train the value head for one pass over ``states`` with binary cross-entropy.

    Args:
        model: network returning ``(policy, value)``. Only ``value`` is used,
            read as ``value[:, 0]``.
        opt: optimizer over ``model``'s parameters.
        states: sequence of canonical states, stacked by ``np.array_split``.
        winners: labels aligned with ``states``. They are BCE targets, so they
            are expected to lie in [0, 1].
        batch_size: target batch size. The data is split into
            ``max(1, len(states) // batch_size)`` nearly equal chunks, so real
            batches can be somewhat larger than ``batch_size``.
    """
    model.train()
    if len(states) == 0:
        return
    # At least one chunk: array_split rejects 0 sections (len(states) < batch_size).
    num_batches = max(1, len(states) // batch_size)
    state_batches = np.array_split(states, num_batches)
    winner_batches = np.array_split(winners, num_batches)
    total_correct = 0
    total_seen = 0
    pbar = tqdm_notebook(range(num_batches), desc='Policy evaluation')
    for b in pbar:
        b_s = torch.from_numpy(state_batches[b]).type(torch.FloatTensor)
        b_w_tensor = torch.from_numpy(winner_batches[b]).type(torch.FloatTensor)
        opt.zero_grad()
        _, pred_win = model(b_s)
        loss = nn.functional.binary_cross_entropy(pred_win[:, 0], b_w_tensor)
        loss.backward()
        opt.step()
        # Accuracy of thresholding the predicted win probability at 0.5.
        predicted = (pred_win[:, 0] > 0.5).type(torch.IntTensor)
        correct = predicted == b_w_tensor.type(torch.IntTensor)
        total_correct += int(correct.sum().item())
        # Divide by samples actually seen: array_split chunks are not exactly batch_size.
        total_seen += len(b_w_tensor)
        accuracy = total_correct / total_seen
        pbar.set_postfix_str('Loss: ' + str(loss.item()) + ' Accuracy: ' + str(accuracy))

In [5]:
from go_ai.montecarlo import invert_qval, canonical_winning, batch_canonical_children_states

def get_qvals(env, model, states):
    """Estimate a Q-value for every action of every state in ``states``.

    ``batch_canonical_children_states`` supplies one child position per legal
    move, in state-then-move order (the loop below consumes them in that order
    and the final assert checks the count). Each child is scored as follows:
    - If ``get_game_ended`` marks it terminal (assumed 0/1), use
      ``canonical_winning``.
    - Otherwise use the model's value head.
    That score is from the opponent's perspective and is mapped back with
    ``invert_qval``. Illegal moves get a Q-value of 0.

    Returns:
        ``np.ndarray`` of shape ``(len(states), action_size)``.
    """
    canonical_next_states = batch_canonical_children_states(states)
    next_states_tensor = torch.from_numpy(canonical_next_states).type(torch.FloatTensor)
    # The values are used only as constant targets; skip building an autograd graph.
    with torch.no_grad():
        _, canonical_next_vals = model(next_states_tensor)
    # One scalar per child; convert once rather than calling .item() per move.
    next_vals = canonical_next_vals.reshape(-1).tolist()

    curr_idx = 0
    batch_qvals = []
    for state in states:
        valid_moves = env.gogame.get_valid_moves(state)
        Qs = []
        for move in range(env.gogame.get_action_size(state)):
            if not valid_moves[move]:
                Qs.append(0)
                continue
            child = canonical_next_states[curr_idx]
            terminal = env.gogame.get_game_ended(child)
            winning = canonical_winning(child)
            # Model estimate if the game continues, actual outcome if it ended.
            oppo_val = (1 - terminal) * next_vals[curr_idx] + terminal * winning
            Qs.append(invert_qval(oppo_val))
            curr_idx += 1
        batch_qvals.append(Qs)

    assert curr_idx == len(canonical_next_vals), (curr_idx, len(canonical_next_vals))
    return np.array(batch_qvals)

def policy_iter(env, model, opt, states, batch_size):
    """Policy-improvement pass: push the policy toward the greedy one-ply move.

    For each batch the target is the legal action with the highest Q-value
    from ``get_qvals``. The policy head is trained with negative log-likelihood
    on that target.

    Args:
        env: Go environment, used for legal-move masks and by ``get_qvals``.
        model: network returning ``(policy_probs, value)``, where
            ``policy_probs`` are softmax probabilities (not logits).
        opt: optimizer over ``model``'s parameters.
        states: sequence of canonical states.
        batch_size: target batch size. Chunks come from ``np.array_split`` and
            may be slightly larger.
    """
    model.train()
    if len(states) == 0:
        return
    # At least one chunk: array_split rejects 0 sections (len(states) < batch_size).
    num_batches = max(1, len(states) // batch_size)
    pbar = tqdm_notebook(np.array_split(states, num_batches), desc='Policy iteration')
    for batch in pbar:
        states_tensor = torch.from_numpy(batch).type(torch.FloatTensor)
        policy, _ = model(states_tensor)
        # Targets are constants; no gradient should flow through the Q estimates.
        with torch.no_grad():
            qvals = get_qvals(env, model, batch)
        # get_qvals gives illegal moves Q=0, which could beat legal moves whose
        # Q-values are <= 0. Exclude them from the argmax.
        valid = np.array([env.gogame.get_valid_moves(s) for s in batch]).astype(bool)
        greedy = np.argmax(np.where(valid, qvals, -np.inf), axis=1)
        greedy_tensor = torch.from_numpy(greedy).type(torch.LongTensor)
        opt.zero_grad()
        # The model already applied softmax, so take NLL of the log-probabilities.
        # cross_entropy would apply a second softmax on top of the probabilities.
        log_policy = torch.log(policy.clamp(min=1e-12))
        loss = nn.functional.nll_loss(log_policy, greedy_tensor)
        loss.backward()
        opt.step()
        pbar.set_postfix_str('Loss: ' + str(loss.item()))

In [6]:
def train_step(env, model, opt, batch_size, num_episodes):
    """Run one training iteration on ``env``.

    The steps are: self-play ``num_episodes`` games with ``model`` against
    itself, fit the value head, improve the policy head, then decay the
    softmax temperature by ``TEMP_DECAY``. The same ``env`` is used for
    self-play and for policy iteration.
    """
    states, winners = generate_trajectories(env, model, model, num_episodes)
    policy_eval(model, opt, states, winners, batch_size)
    policy_iter(env, model, opt, states, batch_size)
    model.decay_temp(TEMP_DECAY)
    
def train(env, model, iterations, lr, batch_size, eps_per_iter):
    """Run ``iterations`` rounds of ``train_step``.

    All rounds share a single Adam optimizer with learning rate ``lr``. Each
    round self-plays ``eps_per_iter`` games and trains on them in batches of
    about ``batch_size``.
    """
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    for iteration in range(iterations):
        print('Iteration ' + str(iteration))
        train_step(env, model, optimizer, batch_size, eps_per_iter)

In [None]:
# Build the network and optionally resume from the last checkpoint. Only the
# state_dict is saved/loaded, so the softmax temperature (`model.temp`, a plain
# attribute) is not persisted and restarts at INIT_TEMP.
model = PolicyValueNet(BOARD_SIZE, INIT_TEMP, MIN_TEMP)
if LOAD_TRAINED:
    checkpoint = torch.load(MODEL_SAVE_FILE)
    model.load_state_dict(checkpoint)

train(go_env, model, iterations=10, lr=0.001, batch_size=32, eps_per_iter=256)
torch.save(model.state_dict(), MODEL_SAVE_FILE)

In [8]:
# Head-to-head against a previously saved baseline of the same board size.
# The baseline keeps a fixed softmax temperature of 1 (initial and minimum).
baseline = PolicyValueNet(BOARD_SIZE, 1, 1)
baseline.load_state_dict(torch.load('models/acbaseline_{0}x{0}.h5'.format(BOARD_SIZE)))
pit(go_env, model, baseline, 500)

HBox(children=(IntProgress(value=0, description='Playing black', max=250, style=ProgressStyle(description_widt…




HBox(children=(IntProgress(value=0, description='Playing white', max=250, style=ProgressStyle(description_widt…


Model 1 WR: 0.518
Model 2 WR: 0.328


(259, 164)

In [20]:
# Sanity check: the trained model against a uniformly random legal-move player.
rand = RandomAgent()
pit(go_env, model, rand, num_episodes=500)

HBox(children=(IntProgress(value=0, description='Playing black', max=250, style=ProgressStyle(description_widt…




HBox(children=(IntProgress(value=0, description='Playing white', max=250, style=ProgressStyle(description_widt…


Model 1 WR: 0.924
Model 2 WR: 0.058


(462, 29)