In [1]:
import torch
from torch import nn
import gym
import numpy as np
from go_ai import data
from tqdm import tqdm_notebook

BOARD_SIZE = 5
go_env = gym.make('gym_go:go-v0', size=BOARD_SIZE)

In [2]:
class PolicyValueNet(nn.Module):
    def __init__(self, board_size):
        super().__init__()
        self.board_size = board_size
        self.main = nn.Sequential(
            nn.Linear(6 * board_size * board_size, 256),
            nn.ReLU(),
            nn.Linear(256, 256),
            nn.ReLU(),
        )
        self.policy = nn.Sequential(
            nn.Linear(256, board_size * board_size + 1),
        )
        self.value = nn.Sequential(
            nn.Linear(256, 1),
            nn.Sigmoid(),
        )
        
    def forward(self, state):
        invalid_values = data.batch_invalid_values(state)
        x = torch.flatten(state, start_dim=1)
        x = self.main(x)
        policy = self.policy(x)
        policy += invalid_values
        policy = nn.functional.softmax(policy, dim=1)
        value = self.value(x)
        return policy, value
        

In [3]:
net = PolicyValueNet(BOARD_SIZE)
state = go_env.get_state()
state_batches = state[np.newaxis]
states_tensor = torch.from_numpy(state_batches).type(torch.FloatTensor)
policy, value = net(states_tensor)
print(policy)
print(value)
policy.shape

tensor([[0.0414, 0.0378, 0.0377, 0.0373, 0.0401, 0.0384, 0.0393, 0.0372, 0.0364,
         0.0370, 0.0394, 0.0394, 0.0408, 0.0407, 0.0382, 0.0357, 0.0373, 0.0400,
         0.0365, 0.0369, 0.0378, 0.0379, 0.0392, 0.0395, 0.0392, 0.0387]],
       grad_fn=<SoftmaxBackward>)
tensor([[0.4983]], grad_fn=<SigmoidBackward>)


torch.Size([1, 26])

In [4]:
opt = torch.optim.Adam(net.parameters(), lr=0.001)

In [5]:
def play_game(env, model1, model2):
    states = []
    
    env.reset()
    state = env.get_canonical_state()
    states.append(state)
    done = False
    while not done:
        turn = go_env.turn()
        state_tensor = torch.from_numpy(state[np.newaxis]).type(torch.FloatTensor)
        if turn == 0:
            action_probs, _ = model1(state_tensor)
        else:
            action_probs, _ = model2(state_tensor)
        action = np.random.choice(np.arange(0, env.size * env.size + 1), p=action_probs.detach().numpy()[0])
        _, _, done, _ = env.step(action)
        state = env.get_canonical_state()
        states.append(state)
    winner = go_env.get_winner()
    canonical_winners = [winner if i % 2 == 0 else 1 - winner for i in range(len(states))]
    return states, canonical_winners

def generate_trajectories(env, model1, model2, num_episodes):
    state_list = []
    winner_list = []
    for i in range(num_episodes):
        states, winners = play_game(env, model1, model2)
        state_list.extend(states)
        winner_list.extend(winners)
    return state_list, winner_list

In [9]:
states, winners = generate_trajectories(go_env, net, net, 100)

In [7]:
def policy_eval(model, opt, states, winners, batch_size):
    state_batches = np.array_split(states, len(states) // batch_size)
    winner_batches = np.array_split(winners, len(winners) // batch_size)
    pbar = tqdm_notebook(range(len(state_batches)))
    for b in pbar:
        b_s = torch.from_numpy(state_batches[b]).type(torch.FloatTensor)
        b_w = winner_batches[b]
        b_w_tensor = torch.from_numpy(b_w).type(torch.FloatTensor)
        opt.zero_grad()
        _, pred_win = model(b_s)
        loss = nn.functional.binary_cross_entropy(pred_win, b_w_tensor)
        loss.backward()
        opt.step()
        correct = (pred_win > 0.5).type(torch.IntTensor)[:,0] == b_w_tensor.type(torch.IntTensor)
        accuracy = np.mean(correct.numpy())
        pbar.set_postfix_str('Loss: ' + str(loss.item()) + ' Accuracy: ' + str(accuracy))

In [10]:
policy_eval(net, opt, states, winners, 32)

HBox(children=(IntProgress(value=0, max=114), HTML(value='')))

  # This is added back by InteractiveShellApp.init_path()
  # This is added back by InteractiveShellApp.init_path()





In [13]:
def get_qvals(env, model, states):
    canonical_next_states = batch_canonical_children_states(states)
    _, canonical_next_vals = model(canonical_next_states)

    curr_idx = 0
    batch_qvals = []
    for state in states:
        valid_moves = env.gogame.get_valid_moves(state)
        Qs = []
        for move in range(env_gogame.get_action_size(state)):
            if valid_moves[move]:
                canonical_next_state = canonical_next_states[curr_idx]
                terminal = env.gogame.get_game_ended(canonical_next_state)
                winning = canonical_winning(canonical_next_state)
                oppo_val = (1 - terminal) * canonical_next_vals[curr_idx].item() + (terminal) * winning
                qval = invert_qval(oppo_val)
                Qs.append(qval)
                curr_idx += 1
            else:
                Qs.append(0)

        batch_qvals.append(Qs)

    assert curr_idx == len(canonical_next_vals), (curr_idx, len(canonical_next_vals))
    return np.array(batch_qvals)

def policy_iter(env, model, opt, states, batch_size):
    state_batches = np.array_split(states, len(states) // batch_size)
    pbar = tqdm_notebook(state_batches)
    for states in pbar:
        states_tensor = torch.from_numpy(states).type(torch.FloatTensor)
        policy, _ = model(states_tensor)
        qvals = get_qvals(env, model, states)
        greedy = np.argmax(qvals, axis=1)
        opt.zero_grad()
        loss = nn.functional.cross_entropy(policy, greedy)
        loss.backward()
        opt.step()
        pbar.set_postfix_str('Loss: ' + str(loss.item()))

In [14]:
policy_iter(go_env, net, opt, states, 32)

HBox(children=(IntProgress(value=0, max=114), HTML(value='')))

NameError: name 'batch_canonical_children_states' is not defined