In [None]:
import math
import os
import random

import numpy as np
import torch
import torch.nn.functional as F
from torch.optim import Adam
from tqdm.notebook import trange

from games import ConnectFour, TicTacToe
from mcts import MCTS
from models import ResNet
from utils import KaggleAgent

In [None]:
class Node:
    """One position in the search tree built during self-play.

    A node keeps the board ``state``, the ``player`` stored for that state,
    the ``prior`` assigned to the move that led here, and the running
    statistics (``visit_count`` / ``total_value``) updated by
    ``backpropagate``. ``game`` supplies the rules and ``args`` the search
    hyper-parameters (only ``args['c_puct']`` is read here).
    """

    def __init__(self, state, player, prior, game, args, parent=None, action_taken=None):
        # Shared collaborators: rules object and hyper-parameter dict.
        self.game = game
        self.args = args

        # What this node represents.
        self.state = state
        self.player = player
        self.prior = prior
        self.action_taken = action_taken

        # Tree links.
        self.parent = parent
        self.children = []

        # Search statistics.
        self.visit_count = 0
        self.total_value = 0

    def expand(self, action_probs):
        """Create one child for every action whose probability is non-zero.

        ``action_probs`` is indexed by action; each child receives a copy of
        this node's state with ``self.player``'s piece dropped at that action
        and is assigned to the opponent player.
        """
        for action, prob in enumerate(action_probs):
            if prob == 0:
                # Zero probability: the action is skipped, no child is created.
                continue

            next_state = self.game.drop_piece(self.state.copy(), action, self.player)
            next_player = self.game.get_opponent_player(self.player)
            self.children.append(
                Node(
                    next_state,
                    next_player,
                    prob,
                    self.game,
                    self.args,
                    parent=self,
                    action_taken=action,
                )
            )

    def backpropagate(self, value):
        """Add ``value`` to this node and walk up to the root.

        The value is passed through ``game.get_opponent_value`` each time it
        moves one level up, and every node on the path gets one extra visit.
        """
        node = self
        while True:
            node.total_value += value
            node.visit_count += 1
            if node.parent is None:
                break
            value = node.game.get_opponent_value(value)
            node = node.parent

    def is_expandable(self):
        """Return True when this node already has children.

        Despite the name, the search loop uses this as "keep descending":
        a node without children is the leaf where selection stops.
        """
        return bool(self.children)

    def select_child(self):
        """Return the child with the highest UCB score (first one on ties).

        Returns ``None`` when there are no children.
        """
        best_child, best_score = None, -np.inf
        for candidate in self.children:
            score = self.get_ucb_score(candidate)
            if score > best_score:
                best_child, best_score = candidate, score
        return best_child

    def get_ucb_score(self, child):
        """Score ``child`` as exploration bonus minus the child's mean value.

        The bonus is ``c_puct * prior * sqrt(parent visits) / (1 + child visits)``.
        An unvisited child has no mean value yet, so only the bonus is returned.
        """
        exploration = (
            self.args['c_puct'] * child.prior * math.sqrt(self.visit_count) / (1 + child.visit_count)
        )
        if child.visit_count == 0:
            return exploration

        mean_value = child.total_value / child.visit_count
        # The child's statistics are subtracted, i.e. a high value stored on
        # the child counts against selecting it from the parent's side.
        return exploration - mean_value

In [None]:
class SelfPlayGame:
    """Mutable bookkeeping for one of the games played in parallel.

    Holds the real board ``state`` and the ``player`` to move, the
    ``game_memory`` list that collects one entry per move, and the scratch
    slots (``root``, ``node``, ``encoded_state``) that the self-play loop
    fills in while searching.
    """

    def __init__(self, game):
        # Actual game progress.
        self.state = game.get_initial_state()
        self.player = 1
        self.game_memory = []

        # Per-search scratch slots; populated by the self-play loop.
        self.root = self.node = self.encoded_state = None

In [None]:
class Trainer:
    """Self-play / training loop for a policy-value network.

    ``run`` alternates between generating games with ``self_play`` (many
    games searched in lock-step so the network is evaluated in batches) and
    fitting the network to the collected samples with ``train``.

    Args:
        model: callable network returning ``(policy_logits, value)`` for a
            batch of encoded states.
        optimizer: optimizer already bound to ``model``'s parameters.
        game: rules object (state transitions, terminal check, encoding).
        args: hyper-parameter dict; the keys read here are
            ``num_simulation_games``, ``temperature``, ``augment``,
            ``batch_size``, ``num_iterations`` and ``num_epochs``.
    """

    def __init__(self, model, optimizer, game, args):
        self.model = model
        self.optimizer = optimizer
        self.game = game
        # Tensors are moved to this device; the model itself is not moved
        # here, so the caller is expected to have placed it already.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.args = args
        self.mcts = MCTS(self.model, self.game, self.args)

    def self_play(self, group_size=10):
        """Play ``group_size`` games to completion and return training samples.

        All unfinished games advance one move per outer iteration: a fresh
        search tree is built for each, ``args['num_simulation_games']``
        simulations are run across the whole group, then every game plays a
        move sampled from its root visit counts.

        Returns:
            A list of ``(encoded_state, action_probs, outcome)`` tuples.
        """
        active_games = [SelfPlayGame(self.game) for _ in range(group_size)]
        self_play_memory = []

        while active_games:
            for spg in active_games:
                # The root always holds the canonical state with player 1.
                spg.root = Node(
                    self.game.get_canonical_state(spg.state, spg.player),
                    1,
                    prior=0,
                    game=self.game,
                    args=self.args,
                )

            for _ in range(self.args['num_simulation_games']):
                self._simulate_once(active_games)

            still_running = []
            for spg in active_games:
                finished_samples = self._play_move(spg)
                if finished_samples is None:
                    still_running.append(spg)
                else:
                    self_play_memory.extend(finished_samples)
            active_games = still_running

        return self_play_memory

    def _simulate_once(self, active_games):
        """Run one search simulation in every game, batching the network call."""
        pending = []
        for spg in active_games:
            spg.encoded_state = None

            # Selection: descend until a node without children is reached.
            node = spg.root
            while node.is_expandable():
                node = node.select_child()

            is_terminal, value = self.game.check_terminal_and_value(node.state, node.action_taken)
            value = self.game.get_opponent_value(value)

            if is_terminal:
                node.backpropagate(value)
            else:
                # Non-terminal leaf: queue it for the batched evaluation below.
                canonical_state = self.game.get_canonical_state(node.state, node.player)
                spg.encoded_state = self.game.get_encoded_state(canonical_state)
                spg.node = node
                pending.append(spg)

        if not pending:
            return

        batch = np.stack([spg.encoded_state for spg in pending])
        # Keep the per-sample (channels, rows, cols) dimensions taken from the
        # data instead of a hard-coded 3x6x7, so any board size works.
        batch = batch.reshape(-1, *batch.shape[-3:])
        batch = torch.from_numpy(batch).float().to(self.device)

        # Pure inference: no autograd graph is needed during search.
        with torch.no_grad():
            policy_logits, values = self.model(batch)
        action_probs = torch.softmax(policy_logits, dim=1).cpu().numpy()
        values = values.cpu().numpy()

        for probs, value, spg in zip(action_probs, values, pending):
            # Zero out invalid moves and renormalise before expanding.
            valid_moves = self.game.get_valid_locations(spg.node.state)
            probs = probs * valid_moves
            probs = probs / np.sum(probs)
            spg.node.expand(probs)
            # Backpropagate a plain Python float so node statistics stay
            # scalars rather than one-element arrays.
            spg.node.backpropagate(value.item())

    def _play_move(self, spg):
        """Record the search result for ``spg`` and play one move.

        Returns ``None`` while the game continues, or the list of training
        samples for the whole game once the move ends it.
        """
        children = spg.root.children

        # Policy target: root visit counts normalised over all actions.
        action_probs = np.zeros(self.game.action_size)
        for child in children:
            action_probs[child.action_taken] = child.visit_count
        action_probs /= np.sum(action_probs)
        spg.game_memory.append((spg.root.state, spg.player, action_probs))

        visit_counts = [child.visit_count for child in children]
        actions = [child.action_taken for child in children]
        temperature = self.args['temperature']
        if temperature == 0:
            action = actions[np.argmax(visit_counts)]
        elif temperature == float('inf'):
            action = np.random.choice(actions)
        else:
            weights = np.array(visit_counts) ** (1 / temperature)
            weights = weights / weights.sum()
            action = np.random.choice(actions, p=weights)

        spg.state = self.game.drop_piece(spg.state, action, spg.player)
        is_terminal, reward = self.game.check_terminal_and_value(spg.state, action)
        if not is_terminal:
            spg.player = self.game.get_opponent_player(spg.player)
            return None

        # ``spg.player`` is still the player who made the final move; positions
        # recorded for the other player get the negated reward.
        samples = []
        for hist_state, hist_player, hist_action_probs in spg.game_memory:
            outcome = reward if hist_player == spg.player else -reward
            samples.append((self.game.get_encoded_state(hist_state), hist_action_probs, outcome))
            if self.args['augment']:
                samples.append((
                    self.game.get_encoded_state(self.game.get_augmented_state(hist_state)),
                    np.flip(hist_action_probs),
                    outcome,
                ))
        return samples

    def train(self, memory):
        """Run one epoch of gradient steps over ``memory``.

        ``memory`` is shuffled in place and consumed in mini-batches of
        ``args['batch_size']``; every sample is used, including a final
        shorter batch. The loss is policy cross-entropy plus value MSE.
        """
        random.shuffle(memory)
        batch_size = self.args['batch_size']
        for start in range(0, len(memory), batch_size):
            # Slicing past the end is safe, so the last sample is not dropped.
            state, policy, value = zip(*memory[start:start + batch_size])
            state, policy, value = np.array(state), np.array(policy), np.array(value).reshape(-1, 1)

            state = torch.tensor(state, dtype=torch.float32).to(self.device)
            policy = torch.tensor(policy, dtype=torch.float32).to(self.device)
            value = torch.tensor(value, dtype=torch.float32).to(self.device)

            out_policy, out_value = self.model(state)
            loss_policy = F.cross_entropy(out_policy, policy)
            loss_value = F.mse_loss(out_value, value)
            loss = loss_policy + loss_value

            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

    def run(self):
        """Alternate self-play and training for ``args['num_iterations']`` rounds.

        After each iteration the model and optimizer state dicts are saved
        under ``Models/<game>/``.
        """
        for iteration in range(self.args['num_iterations']):
            print(f"iteration: {iteration}")
            memory = []

            self.model.eval()
            # NOTE(review): 5 rounds of 200 parallel games are hard-coded;
            # args['num_train_games'] is not consulted here - confirm intent.
            for train_game in trange(5, desc="train_game"):
                memory += self.self_play(200)

            self.model.train()
            for epoch in trange(self.args['num_epochs'], desc="epoch"):
                self.train(memory)

            # Make sure the checkpoint directory exists before saving, so a
            # long iteration is not lost to a missing folder.
            save_dir = f"Models/{self.game}"
            os.makedirs(save_dir, exist_ok=True)
            torch.save(self.model.state_dict(), f"{save_dir}/model_{iteration}.pt")
            torch.save(self.optimizer.state_dict(), f"{save_dir}/optimizer_{iteration}.pt")


In [None]:
# Seed every RNG used by the pipeline so runs are repeatable.
torch.manual_seed(0)
random.seed(0)
np.random.seed(0)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

GAME = 'ConnectFour'
LOAD = True

# Everything that differs between the supported games lives in this table;
# the model/optimizer construction and checkpoint loading below are shared.
GAME_SETUPS = {
    'ConnectFour': {
        'game_cls': ConnectFour,
        # First positional argument of ResNet (presumably the number of
        # residual blocks - confirm in models.py).
        'resnet_depth': 9,
        'args': {
            'num_iterations': 48,             # number of highest level iterations
            'num_train_games': 500,           # number of self-play games to play within each iteration
            'num_simulation_games': 600,      # number of mcts simulations when selecting a move within self-play
            'num_epochs': 4,                  # number of epochs for training on self-play data for each iteration
            'batch_size': 128,                # batch size for training
            'temperature': 1,                 # temperature applied to visit counts when sampling a move
            'c_puct': 2,                      # exploration constant in the UCB score
            'augment': True,                  # whether to augment the training data with flipped states
        },
    },
    'TicTacToe': {
        'game_cls': TicTacToe,
        'resnet_depth': 4,
        'args': {
            'num_iterations': 8,              # number of highest level iterations
            'num_train_games': 500,           # number of self-play games to play within each iteration
            'num_simulation_games': 60,       # number of mcts simulations when selecting a move within self-play
            'num_epochs': 4,                  # number of epochs for training on self-play data for each iteration
            'batch_size': 64,                 # batch size for training
            'temperature': 1,                 # temperature applied to visit counts when sampling a move
            'c_puct': 2,                      # exploration constant in the UCB score
            'augment': False,                 # whether to augment the training data with flipped states
        },
    },
}

# Fail with a clear message instead of a NameError further down.
if GAME not in GAME_SETUPS:
    raise ValueError(f"Unknown GAME {GAME!r}; expected one of {sorted(GAME_SETUPS)}")

setup = GAME_SETUPS[GAME]
args = setup['args']
game = setup['game_cls']()
model = ResNet(setup['resnet_depth'], game).to(device)
optimizer = Adam(model.parameters(), lr=0.001, weight_decay=0.0001)
if LOAD:
    # Resume from the checkpoint stored under Models/<game>/.
    model.load_state_dict(torch.load(f'Models/{game}/model.pt', map_location=device))
    optimizer.load_state_dict(torch.load(f'Models/{game}/optimizer.pt', map_location=device))

trainer = Trainer(model, optimizer, game, args)
trainer.run()
