In [1]:
import numpy as np
print(np.__version__)


import torch
print(torch.__version__)
import torch.nn as nn
import torch.nn.functional as F

from tqdm.notebook import trange
import math
import copy
import time
import random

1.26.4
2.2.2


In [2]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)




cpu


In [3]:
if torch.cuda.is_available():
    print("CUDA is available")
    print(f"Number of CUDA devices: {torch.cuda.device_count()}")
    
    for i in range(torch.cuda.device_count()):
        device = torch.device(f'cuda:{i}')
        print(f"\nCUDA Device {i}:")
        print(f"  Name: {torch.cuda.get_device_name(device)}")
        print(f"  Compute Capability: {torch.cuda.get_device_capability(device)}")
        print(f"  Total Memory: {torch.cuda.get_device_properties(device).total_memory / 1024**3:.2f} GB")
        print(f"  CUDA Version: {torch.version.cuda}")
else:
    print("CUDA is not available")

CUDA is not available


In [229]:
np.set_printoptions(linewidth=200)

INIT_BOARD = np.array([0, 2, 0, 0, 0, 0, -5, 0, -3, 0, 0, 0, 5,
              -5, 0, 0, 0, 3, 0, 5, 0, 0, 0, 0, -2, 0])

UNIQUE_JUMPS = [[1, 1, 1, 1], [1, 2], [1, 3], [1, 4], [1, 5], [1, 6],
                [2, 2, 2, 2], [2, 3], [2, 4], [2, 5], [2, 6],
                [3, 3, 3, 3], [3, 4], [3, 5], [3, 6],
                [4, 4, 4, 4], [4, 5], [4, 6],
                [5, 5, 5, 5], [5, 6],
                [6, 6, 6, 6]]

def timed(function):
    def wrapper(*args, **kwargs):
        start_time = time.time()
        result = function(*args, **kwargs)
        end_time = time.time()
        print(f"{function.__name__} took {end_time - start_time} seconds to run.")
        return result
    return wrapper

def action_to_playidx(action):
    ''' Convert action to playidx'''
    playidx = []
    for i in range(4):
        playidx.append(action % 26)
        action = action // 26
    return playidx

def playidx_to_action(playidx):
    '''Convert playidx to action'''
    action = 0
    for i in range(4):
        action += playidx[i] * (26**i)
    return action

def roll_dice():
    '''returns list of jumps sorted smallest to largest'''
    die1 = random.randint(1, 6)
    die2 = random.randint(1, 6)

    if die1 == die2:
        return [die1] * 4
    else:
        return sorted([die1, die2])


In [230]:
INIT_BOARD = np.array(INIT_BOARD)
print(INIT_BOARD)


[ 0  2  0  0  0  0 -5  0 -3  0  0  0  5 -5  0  0  0  3  0  5  0  0  0  0 -2  0]


In [244]:
class Backgammon:
    def __init__(self):
        self.idx_count = 26
        self.action_size = 26 ** 4

    def get_initial_board(self):
        return np.array(INIT_BOARD)

    def get_opponent(self, player):
        return -player

    def get_opponent_value(self, value):
        return -value

    def change_perspective(self, board, player):
        if player == -1:
            board = np.flip(-board)
        return board

    def get_next_state(self, board, play, player):
        for move in play:
            self.make_move(board, move)
        return board

    def check_win(self, board):
        return np.all(board >= 0) or np.all(board <= 0)

    def get_value_and_terminated(self, board):
        '''Always returns 1 for value'''
        if self.check_win(board):
            return 1, True
        return 0, False

    def get_encoded_state(self, board, jumps):
        # [26] + jumps
        # (7, 26) + (6)

        white_cnt = np.maximum(board, 0).astype(np.float32) / 15
        white_one = (board == 1).astype(np.float32)
        white_tower = (board > 1).astype(np.float32)
        black_cnt = np.maximum(-board, 0).astype(np.float32) / 15
        black_one = (board == -1).astype(np.float32)
        black_tower = (board < -1).astype(np.float32)
        empty = (board == 0).astype(np.float32)  # Using small threshold for float comparison

        # Stack the arrays vertically
        encoded_board = np.vstack((white_cnt, white_one, white_tower, black_cnt, black_one, black_tower, empty))

        jumps_encoded = np.zeros(4, dtype=np.float32)
        jumps_encoded[:len(jumps)] = jumps
        jumps_encoded = jumps_encoded/6

        indicies = np.arange(26)
        black_pip = np.sum(black_cnt * indicies) / 200
        indicies = 25-indicies
        white_pip = np.sum(white_cnt * indicies) / 200

        features = np.concatenate([jumps_encoded, [white_pip, black_pip]]).astype(np.float32)

        return encoded_board, features

    def get_encoded_states_batched(self, boards, jumps):
        # Encode boards
        white_cnt = np.maximum(boards, 0).astype(np.float32) / 15
        white_one = (boards == 1).astype(np.float32)
        white_tower = (boards > 1).astype(np.float32)
        black_cnt = np.maximum(-boards, 0).astype(np.float32) / 15
        black_one = (boards == -1).astype(np.float32)
        black_tower = (boards < -1).astype(np.float32)
        empty = (boards == 0).astype(np.float32)

        encoded_boards = np.stack([white_cnt, white_one, white_tower, black_cnt, black_one, black_tower, empty], axis=1)

        # Encode features
        jumps_encoded = jumps.astype(np.float32) / 6

        indices = np.arange(26)
        black_pip = np.sum(black_cnt * indices, axis=1) / 200
        white_pip = np.sum(white_cnt * (25 - indices), axis=1) / 200

        encoded_features = np.column_stack([jumps_encoded, white_pip, black_pip]).astype(np.float32)

        return encoded_boards, encoded_features

    def get_valid_plays(self, board, jumps, player):
        if len(jumps) == 4:
            return self._generate_plays_quads(board, jumps)
        return self._generate_plays(board, jumps)

    def plays_to_actions(self, plays):
        actions = np.zeros(self.action_size)
        for play in plays:
            play_idx = []
            for move in sorted(play, key=lambda move: (move[1]-move[0])):
                play_idx.append(move[0])
            while len(play_idx) < 4:
                play_idx.append(25)

            action_idx = playidx_to_action(play_idx)
            actions[action_idx] = 1
        return actions

    def play_to_action(self, play):
        play = sorted(play, key=lambda move: (move[1]-move[0]))
        play_idx = []
        for move in play:
            play_idx.append(move[0])
        while len(play_idx) < 4:
            play_idx.append(25)
        return playidx_to_action(play_idx)


    def is_valid_move(self, board, jumps, move):
        # print(board)
        start, end = move[0], move[1]

        jump = abs(end - start)
        if jump not in jumps:
            return False

        # Check if start exists
        # print(board)
        # print(move)
        if board[start] <= 0:
            return False

        # Check if move is in right direction
        if end <= start:
            return False

        # Check if jailed
        if board[0] > 0:
            if start != 0:
                return False

        # Check if trying to bear off
        if end >= 25:
            if np.all(board[0:19] <= 0):
                if end == 25:
                    return True
                elif np.all(board[19:start] <= 0):
                    return True
                return False
            else:
                return False

        # Check if end is blocked
        if board[end] < -1:
            return False
        return True

    def _generate_plays(self, board, jumps):
        # print("Generating plays")
        board = board
        jumps = jumps

        def get_movable_pieces(board):
            if board[0] > 0:
                return [0]
            moveable_pieces = []
            for i in range(len(board)):
                if board[i] > 0:
                    moveable_pieces.append(i)
            return moveable_pieces

        res = []

        play = []

        def dfs(board, jumps):
            #print("called dfs")
            movable_pieces = get_movable_pieces(board)
            # print(f"movable_pieces: {get_movable_pieces(board)} with jumps {jumps}")
            if movable_pieces == [] or jumps == []:
                res.append(copy.deepcopy(play))
                return



            for piece in movable_pieces:
                for j in range(len(jumps)):
                    move = (piece, piece+jumps[j])  # make negative for black
                    # print(board, movable_pieces, move)
                    # print("hi")
                    if self.is_valid_move(board, jumps, move) == False:
                        # print("Invalid move")
                        res.append(copy.deepcopy(play))
                        continue

                    #print(f"board: {board}, jumps={jumps}")
                    #print(f"move: {move}")
                    tmp_board = board.copy()
                    tmp_board = self.make_move(tmp_board, move)

                    tmp_jump = copy.copy(jumps)
                    tmp_jump.pop(j)
                    #print(f"tmp_board: {tmp_board}, tmp_jump={tmp_jump}")
                    #print()

                    play.append(move)

                    dfs(tmp_board, tmp_jump)
                    play.pop()

        def remove_duplicate_plays(res):
            res = [tuple(play) for play in res]
            res = set(res)
            res = list(res)
            res = [list(play) for play in res]
            return res

        def sort_plays(res):
            for i in range(len(res)):
                res[i] = sorted(res[i], key=lambda x: (x[0], x[1]))
            return res

        dfs(board, jumps)
        # print(f"{len(res)} plays generated by dfs")
        # print(res)
        if not res:
            return []
        res = sort_plays(res)
        res = remove_duplicate_plays(res)
        # print(res)
        # res.sort(key=lambda x: [x[0][0], x[1][0], x[0][1]])
        max_length_play = max(len(play) for play in res)
        res = [play for play in res if len(play) == max_length_play]

        return res

    def _generate_plays_quads(self, board, jumps):

        def get_movable_pieces(board):
            if board[0] > 0:
                return [0]
            moveable_pieces = []
            for i in range(len(board)):
                if board[i] > 0:
                    moveable_pieces.append(i)
            return moveable_pieces

        res = []
        play = []

        def dfs(board, jumps):
            movable_pieces = get_movable_pieces(board)
            if movable_pieces == [] or jumps == []:
                res.append(copy.deepcopy(play))
                return

            for piece in movable_pieces:
                move = (piece, piece+jumps[0])

                if self.is_valid_move(board, jumps, move) == False:
                    res.append(copy.deepcopy(play))
                    continue


                #print(f"board: {board}, jumps={jumps}")
                #print(f"move: {move}")
                tmp_board = board.copy()
                tmp_board = self.make_move(tmp_board, move)
                tmp_jump = copy.copy(jumps)
                tmp_jump.pop()
                #print(f"tmp_board: {tmp_board}, tmp_jump={tmp_jump}")
                #print()
                play.append(move)

                dfs(tmp_board, tmp_jump)
                play.pop()
        dfs(board, jumps)

        def sort_plays(res):
            for i in range(len(res)):
                res[i] = sorted(res[i], key=lambda x: (x[0]))
            return res
        sorted_res = sort_plays(res)

        def remove_duplicate_plays(res):
            res = [tuple(play) for play in res]
            res = set(res)
            res = list(res)
            res = [list(play) for play in res]
            return res
        processed_res = remove_duplicate_plays(sorted_res)

        if not processed_res:
            return []

        max_length_play = max(len(play) for play in processed_res)
        longest_plays = [play for play in processed_res if len(play) == max_length_play]

        return longest_plays

    def make_move(self, board, move) -> None:
        '''Apply move on board'''
        start, end = move

        if board[start] == 0:
            raise ValueError(f"Invalid move {move}: No piece to move at start position. Board: \n{self.draw(board)}")
        board[start] -= 1
        # Handle bearoff
        if end >= 25:
            return board
        # Handle hitting opponent checker
        if board[end] == -1:
            board[end] = 1
            board[25] -= 1
        elif board[end] < -1:
            raise ValueError(f"Invalid move {move}: End position has more than -1 checker. Board: \n{self.draw(board)}")
        # Handle regular move
        else:
            board[end] += 1
        return board

    def draw(self, board) -> str:
        '''String representation of the board'''
        def transform(i):
            x = str(i)
            if x[0] != "-":
                x = " " + x
            if len(x) < 3:
                x += " "
            return x

        # top index numbers for the board
        top_idx = "  " + " ".join(f"{transform(i)}" for i in range(12, 0, -1)) + "   " + transform(0)

        # bottom index numbers for the board
        bot_idx = "  " + " ".join(f"{transform(i)}" for i in range(13, 25)) + "   " + transform(25)

        # ===== boarder for top and bottom
        boarder = "="*57

        # top row of the board from board
        top = "||" + \
            " ".join(f"{transform(board[i])}" for i in range(12, 6, -1)) + "|" + \
            " ".join(f"{transform(board[i])}" for i in range(6, 0, -1)) + "|| " + \
            str(transform(board[0]))

        # bottom row of the board from board
        bot = "||" + \
            " ".join(f"{transform(board[i])}" for i in range(13, 19)) + "|" + \
            " ".join(f"{transform(board[i])}" for i in range(19, 25)) + "|| " + \
            str(transform(board[25]))

        # x combines all the strings together to make the board
        x = top_idx + "\n" + boarder + "\n" + top + \
            "\n\n\n" + bot + "\n" + boarder + "\n" + bot_idx

        # print(x)
        return x

In [245]:
class Node:
    def __init__(self, game, args, board, jumps, parent=None, action_taken=None, play_taken=None, prior=0, visit_count=0, level=0):
        self.game = game
        self.args = args
        self.board = board
        self.jumps = jumps
        self.parent = parent
        self.play_taken = play_taken
        self.action_taken = action_taken

        self.prior = prior
        self.children = []

        self.visit_count = visit_count
        self.value_sum = 0
        self.state_value = None

        self.level = level

        self.search_weight = 1
        if self.jumps and len(self.jumps) == 2:
            self.search_weight = 2

    def __str__(self):
        board_repr = self.game.draw(board=self.board)
        board_repr = '\n'.join([" "*2*self.level + line for line in board_repr.splitlines()])
        return \
f"""
{" "*2*self.level}{"-"*60}
{" "*2*self.level}Level: {self.level}, N: {self.visit_count}, val: {self.value_sum:.3f}, prior: {self.prior:.3f}, uct:{self.parent.get_ucb(self) if self.parent else None}, weight: {self.search_weight}, num_children: {len(self.children)}
{board_repr}
{" "*2*self.level}jumps={self.jumps}, state_value={self.state_value}, action_taken={self.action_taken}
{"  "*2*self.level}board = {self.board}
{" "*2*self.level}{"-"*60}
"""


    def is_fully_expanded(self):
        return len(self.children) > 0

    def select(self):
        if self.jumps == []:
            return self.find_random_child_weighted()
        else:
            best_child = None
            best_ucb = -np.inf

            for child in self.children:
                ucb = self.get_ucb(child)
                if ucb > best_ucb:
                    best_child = child
                    best_ucb = ucb

            return best_child

    def get_ucb(self, child):
        if child.visit_count == 0:
            q_value = float("inf")
        else:
            q_value = 1 - ((child.value_sum / child.visit_count) + 1) / 2
        return q_value + self.args['C'] * (math.sqrt(self.visit_count) / (child.visit_count + 1)) * child.prior

    def find_random_child_weighted(self):
        if not self.children:
            raise RuntimeError(
                f"find random child weighted called on leaf node {self.node}")
        weights = [child.search_weight for child in self.children]
        selected_child = random.choices(self.children, weights=weights, k=1)[0]
        return selected_child

    def expand(self, policy, plays):
        # plays = self.game.get_valid_plays(self.board, self.jumps, 1)
        if len(self.children) > 0:
            # print(f"expand():skip already expanded node")
            return
        for play in plays:
            child_action_taken = self.game.play_to_action(play)
            child_board = self.board.copy()
            child_board = self.game.get_next_state(child_board, play, 1)
            child_board = self.game.change_perspective(child_board, player=-1)
            child_prior = float(policy[child_action_taken])


            child = Node(self.game, self.args, board=child_board, jumps=[], parent=self, action_taken=child_action_taken, prior=child_prior, visit_count=0, level=self.level+1)
            self.children.append(child)

            if play == [[]]:
                print(f"created {len(self.children)} children for play {play}")

            for possible_jump in UNIQUE_JUMPS:
                child_child = Node(self.game, self.args, board=child_board, jumps=possible_jump, parent=child, action_taken=child_action_taken, prior=1, visit_count=0, level=self.level+2)
                child.children.append(child_child)


    def backpropagate(self, value):
        self.value_sum += value
        self.visit_count += 1
        if self.jumps == []:
            value = self.game.get_opponent_value(value)
        if self.parent is not None:
            self.parent.backpropagate(value)


In [233]:
class ResNet(nn.Module):
    def __init__(self, game, num_resBlocks, num_hidden, num_features, device):
        super().__init__()

        self.device = device
        self.startBlock = nn.Sequential(
            nn.Conv1d(7, num_hidden, kernel_size=3, padding=1),
            nn.BatchNorm1d(num_hidden),
            nn.ReLU()
        )

        self.backBone = nn.ModuleList(
            [ResBlock(num_hidden) for i in range(num_resBlocks)]
        )

        # Policy head
        self.policyConv = nn.Conv1d(num_hidden, num_hidden, kernel_size=3, padding=1)
        self.policyBN = nn.BatchNorm1d(num_hidden)
        self.policyFlatten = nn.Flatten()
        self.policyLinear = nn.Sequential(
            nn.Linear(num_hidden * game.idx_count + num_features, 1024),
            nn.BatchNorm1d(1024),
            nn.ReLU(),
            nn.Linear(1024, 1024),
            nn.BatchNorm1d(1024),
            nn.ReLU(),
            nn.Linear(1024, 1024),
            nn.BatchNorm1d(1024),
            nn.ReLU(),
            nn.Linear(1024, game.action_size)
        )

        # Value head
        self.valueConv = nn.Conv1d(num_hidden, 2, kernel_size=3, padding=1)
        self.valueBN = nn.BatchNorm1d(2)
        self.valueFlatten = nn.Flatten()
        self.valueLinear = nn.Sequential(
            nn.Linear(2 * game.idx_count + num_features, 128),
            nn.BatchNorm1d(128),
            nn.ReLU(),
            nn.Linear(128, 1),
            nn.Tanh()
        )

        self.to(device)
    #@timed
    def forward(self, x, features):
        x = self.startBlock(x)
        for resBlock in self.backBone:
            x = resBlock(x)

        # Policy head
        policy = self.policyConv(x)
        policy = self.policyBN(policy)
        policy = F.relu(policy)
        policy = self.policyFlatten(policy)
        policy = torch.cat([policy, features], dim=1)
        policy = self.policyLinear(policy)

        # Value head
        value = self.valueConv(x)
        value = self.valueBN(value)
        value = F.relu(value)
        value = self.valueFlatten(value)
        value = torch.cat([value, features], dim=1)
        value = self.valueLinear(value)

        return policy, value


class ResBlock(nn.Module):
    def __init__(self, num_hidden):
        super().__init__()
        self.conv1 = nn.Conv1d(num_hidden, num_hidden, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm1d(num_hidden)
        self.conv2 = nn.Conv1d(num_hidden, num_hidden, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm1d(num_hidden)

    def forward(self, x):
        residual = x
        x = F.relu(self.bn1(self.conv1(x)))
        x = self.bn2(self.conv2(x))
        x += residual
        x = F.relu(x)
        return x

In [234]:
class MCTS:
    def __init__(self, game, args, model):
        self.game = game
        self.args = args
        self.model = model

    @torch.no_grad()
    def search(self, board, jumps):
        root = Node(self.game, self.args, board, jumps, visit_count=1)
        device = self.model.device

        enc_board, enc_features = self.game.get_encoded_state(root.board, root.jumps)
        policy, value = self.model(torch.tensor(enc_board, device=device).unsqueeze(0), torch.tensor(enc_features, device=device).unsqueeze(0))

        value = float(value.item())


        policy = torch.softmax(policy, axis=1).squeeze(0).detach().cpu().numpy()
        policy = (1 - self.args['dirichlet_epsilon']) * policy + self.args['dirichlet_epsilon'] \
            * np.random.dirichlet([self.args['dirichlet_alpha']] * self.game.action_size)

        valid_plays = self.game.get_valid_plays(root.board, root.jumps, 1)
        valid_actions = self.game.plays_to_actions(valid_plays)
        policy *= valid_actions
        policy /= np.sum(policy)


        root.state_value = value
        root.expand(policy, valid_plays)

        for search in range(self.args['num_searches']):
            # print(search)
            node = root

            while node.is_fully_expanded():
                node = node.select()

            value, is_terminal = self.game.get_value_and_terminated(node.board)
            value = self.game.get_opponent_value(value)

            if not is_terminal:
                enc_board, enc_features = self.game.get_encoded_state(node.board, node.jumps)
                policy, value = self.model(torch.tensor(enc_board, device=device).unsqueeze(0), torch.tensor(enc_features, device=device).unsqueeze(0))
                value = float(value.item())
                node.state_value = value

                policy = torch.softmax(policy, axis=1).squeeze(0).detach().cpu().numpy()

                valid_plays = self.game.get_valid_plays(node.board, node.jumps, 1)
                valid_actions = self.game.plays_to_actions(valid_plays)
                policy *= valid_actions
                policy /= np.sum(policy)

                node.expand(policy, valid_plays)

            node.backpropagate(value)


        action_probs = np.zeros(self.game.action_size)
        for child in root.children:
            action_probs[child.action_taken] = child.visit_count
        action_probs /= np.sum(action_probs)
        return action_probs, root

In [247]:
TEST_BOARD = np.array([0, 2, 0, 0, 0, 0, -3, 0, -1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 3, 0, 0, 0, 0, -2, 0])

class AlphaZero:
    def __init__(self, model, optimizer, game, args):
        self.model = model
        self.optimizer = optimizer
        self.game = game
        self.args = args
        self.mcts = MCTS(game, args, model)

    def selfPlay(self):
        memory = []
        player = 1
        # board = self.game.get_initial_board()
        # jumps = roll_dice()
        board = INIT_BOARD
        jumps = [1,3]

        while True:
            # neutral_board = self.game.change_perspective(board, player)
            action_probs, root = self.mcts.search(board, jumps)


            memory.append((board, jumps, action_probs, player))

            temperature_action_probs = action_probs ** (1 / self.args['temperature'])
            temperature_action_probs /= np.sum(temperature_action_probs)
            # Add this check
            if np.isnan(temperature_action_probs).any():
                raise ValueError(f"AZ.selfPlay(): NaN detected in temperature_action_probs: {temperature_action_probs}")

            action = np.random.choice(self.game.action_size, p=temperature_action_probs)
            # print("+"*60)
            # print(f"Finished search for player {player}:")
            # print(root)
            # print(f"action taken: {action}, {action_to_playidx(action)}")
            # print("+"*60)
            # print()

            # Make move based on policy
            #state = self.game.get_next_state(state, action, player)
            child = None
            for child in root.children:
                if child.action_taken == action:
                    child = child
                    break
            if child is None:
                raise ValueError(f"Child node with action {action} not found")
            child = child.find_random_child_weighted()

            board = child.board
            # board = self.game.change_perspective(board, -1)
            jumps = child.jumps


            value, is_terminal = self.game.get_value_and_terminated(board)
            if is_terminal:
                returnMemory = []
                for hist_neutral_board, hist_jumps, hist_action_probs, hist_player in memory:
                    hist_outcome = value if hist_player == player else self.game.get_opponent_value(value)
                    encoded_board, encoded_features = self.game.get_encoded_state(hist_neutral_board, hist_jumps)
                    returnMemory.append((
                        encoded_board,
                        encoded_features,
                        hist_action_probs,
                        hist_outcome
                    ))
                # print("Game over. Value: ", value)
                # print(f"state[-4] Value: {returnMemory[-4][3]}")
                # print(f"state[-3] Value: {returnMemory[-3][3]}")
                # print(f"state[-2] Value: {returnMemory[-2][3]}")
                # print(f"state[-1] Value: {returnMemory[-1][3]}")
                return returnMemory

            player = self.game.get_opponent(player)



    def train(self, memory):
        random.shuffle(memory)
        for batchIdx in range(0, len(memory), self.args['batch_size']):
            sample = memory[batchIdx:min(len(memory) - 1, batchIdx + self.args['batch_size'])] # Change to memory[batchIdx:batchIdx+self.args['batch_size']] in case of an error
            if len(sample) == 0:
                continue
            board, jumps, policy_targets, value_targets = zip(*sample)

            board, jumps, policy_targets, value_targets = np.array(board), np.array(jumps), np.array(policy_targets), np.array(value_targets).reshape(-1, 1)

            board = torch.tensor(board, dtype=torch.float32, device=self.model.device)
            jumps = torch.tensor(jumps, dtype=torch.float32, device=self.model.device)
            policy_targets = torch.tensor(policy_targets, dtype=torch.float32, device=self.model.device)
            value_targets = torch.tensor(value_targets, dtype=torch.float32, device=self.model.device)

            out_policy, out_value = self.model(board, jumps)

            policy_loss = F.cross_entropy(out_policy, policy_targets)
            value_loss = F.mse_loss(out_value, value_targets)
            loss = policy_loss + value_loss

            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

    def learn(self):
        for iteration in range(self.args['num_iterations']):
            memory = []

            self.model.eval()
            for selfPlay_iteration in trange(self.args['num_selfPlay_iterations']):
                memory += self.selfPlay()

            self.model.train()
            for epoch in trange(self.args['num_epochs']):
                self.train(memory)

            # torch.save(self.model.state_dict(), f"model_{iteration}_{self.game}.pt")
            # torch.save(self.optimizer.state_dict(), f"optimizer_{iteration}_{self.game}.pt")
        torch.save(self.model.state_dict(), f"model_final.pt")
        torch.save(self.optimizer.state_dict(), f"optimizer_final.pt")

In [236]:
class MCTSParallel:
    def __init__(self, game, args, model):
        self.game = game
        self.args = args
        self.model = model

    @torch.no_grad()
    def search(self, spGames):

        boards, padded_jumps = [], []
        for spg in spGames:
            boards.append(spg.board)
            if len(spg.jumps) < 4:
                padded_jumps.append(spg.jumps + [0,0])
            else:
                padded_jumps.append(spg.jumps)
        boards = np.array(boards, dtype=np.float32)
        padded_jumps = np.array(padded_jumps, dtype=np.float32)

        enc_boards, enc_features = self.game.get_encoded_states_batched(boards, padded_jumps)
        policy, value = self.model(
            torch.tensor(enc_boards, device=self.model.device),
            torch.tensor(enc_features, device=self.model.device)
        )

        policy = torch.softmax(policy, axis=1).detach().cpu().numpy()
        value = value.detach().cpu().numpy()
        policy = (1 - self.args['dirichlet_epsilon']) * policy + self.args['dirichlet_epsilon'] \
            * np.random.dirichlet([self.args['dirichlet_alpha']] * self.game.action_size, size=policy.shape[0])

        # print(f"Top of search, policy shape: {policy.shape}, boards shape: {boards.shape}, jumps shape: {jumps.shape}")


        for i, spg in enumerate(spGames):
            # if i == 0:
            #     print(f"In enumerate spGames, policy[i] shape: {policy[i].shape}, boards[i] shape: {boards[i].shape}, jumps[i] shape: {jumps[i].shape}")
            spg_board, spg_jumps = spg.board, spg.jumps
            spg_policy, spg_value = policy[i], float(value[i][0])
            valid_plays = self.game.get_valid_plays(spg_board, spg_jumps, 1)
            valid_actions = self.game.plays_to_actions(valid_plays)
            spg_policy *= valid_actions
            spg_policy /= np.sum(spg_policy)

            if spg.root is None:
                spg.root = Node(self.game, self.args, spg_board, spg_jumps, visit_count=1)
            spg.root.expand(spg_policy, valid_plays)

        for search in range(self.args['num_searches']):
            for spg in spGames:
                spg.node = None
                node = spg.root
                while node.is_fully_expanded():
                    node = node.select()
                value, is_terminal = self.game.get_value_and_terminated(node.board)
                value = self.game.get_opponent_value(value)
                if is_terminal:
                    node.backpropagate(value)
                else:
                    spg.node = node


            expandable_spGames = [mappingIdx for mappingIdx in range(len(spGames)) if spGames[mappingIdx].node is not None]

            if len(expandable_spGames) > 0:
                boards = []
                jumps = []
                for mappingIdx in expandable_spGames:
                    boards.append(spGames[mappingIdx].node.board)
                    if len(spGames[mappingIdx].node.jumps) < 4:
                        jumps.append(spGames[mappingIdx].node.jumps + [0,0])
                    else:
                        jumps.append(spGames[mappingIdx].node.jumps)
                boards = np.array(boards, dtype=np.float32)
                jumps = np.array(jumps, dtype=np.float32)

                enc_boards, enc_features = self.game.get_encoded_states_batched(boards, jumps)
                policy, value = self.model(
                    torch.tensor(enc_boards, device=self.model.device),
                    torch.tensor(enc_features, device=self.model.device)
                )
                policy = torch.softmax(policy, axis=1).detach().cpu().numpy()
                value = value.detach().cpu().numpy()

            for i, mappingIdx in enumerate(expandable_spGames):
                node = spGames[mappingIdx].node
                spg_policy, spg_value = policy[i], float(value[i][0])
                node.state_value = spg_value

                valid_plays = self.game.get_valid_plays(node.board, node.jumps, 1)
                valid_actions = self.game.plays_to_actions(valid_plays)
                spg_policy *= valid_actions
                spg_policy /= np.sum(spg_policy)

                node.expand(spg_policy, valid_plays)
                node.backpropagate(spg_value)

In [237]:
TEST_BOARD = np.array([0, 2, 0, 0, 0, 0, -3, 0, -1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 3, 0, 0, 0, 0, -2, 0])


class AlphaZeroParallel:
    def __init__(self, model, optimizer, game, args):
        self.model = model
        self.optimizer = optimizer
        self.game = game
        self.args = args
        self.mcts = MCTSParallel(game, args, model)

    def selfPlay(self):
        return_memory = []
        player = 1
        spGames = [SPG(self.game) for spg in range(self.args['num_parallel_games'])]

        P1_wins = 0
        P2_wins = 0

        while len(spGames) > 0:
            start_time = time.time()

            self.mcts.search(spGames)

            for i in range(len(spGames))[::-1]:
                spg = spGames[i]

                action_probs = np.zeros(self.game.action_size)
                for child in spg.root.children:
                    action_probs[child.action_taken] = child.visit_count
                if np.sum(action_probs) == 0:
                    print(f"    AZF.selfPlay(): Sum of spGames action probs is 0, action_probs{action_probs}")


                action_probs /= np.sum(action_probs)

                spg.memory.append((spg.root.board, spg.root.jumps, action_probs, player))

                temperature_action_probs = action_probs  ** (1 / self.args['temperature'])
                temperature_action_probs /= np.sum(temperature_action_probs)
                # Add this check
                if np.isnan(temperature_action_probs).any():
                    print(spg.root)
                    for child in spg.root.children:
                        print(child)
                    raise ValueError(f"AZF.selfPlay(): NaN detected in temperature_action_probs: {temperature_action_probs}, action_probs{action_probs}")

                action = np.random.choice(self.game.action_size, p=temperature_action_probs)

                chosen_child = None
                for child in spg.root.children:
                    if child.action_taken == action:
                        chosen_child = child
                        break
                if chosen_child is None:
                    raise ValueError(f"AZF.selfPlay(): Chosen child node with action {action} not found")

                spg.root = chosen_child.find_random_child_weighted()
                spg.root.parent = None
                spg.board = spg.root.board
                spg.jumps = spg.root.jumps

                value, is_terminal = self.game.get_value_and_terminated(spg.board)
                if is_terminal:
                    if player == 1:
                        P1_wins += 1
                    else:
                        P2_wins += 1
                    # print(f"AZ.selfPlay: Terminal State. Add  {len(spg.memory)} rows of data to return_memory.")
                    for hist_neutral_board, hist_jumps, hist_action_probs, hist_player in spg.memory:
                        hist_outcome = value if hist_player == player else self.game.get_opponent_value(value)
                        encoded_board, encoded_features = self.game.get_encoded_state(hist_neutral_board, hist_jumps)
                        return_memory.append((
                            encoded_board,
                            encoded_features,
                            hist_action_probs,
                            hist_outcome
                        ))
                    # print(f"\nAZF.selfPlay(): SFgame ended with {len(spg.memory)} states:")
                    # print(f"memory[0]: {spg.memory[0][0]}, Player= {spg.memory[0][3]}, V= {return_memory[0][3]}")
                    # print(f"memory[1]: {spg.memory[1][0]}, Player= {spg.memory[1][3]}, V= {return_memory[1][3]}")
                    # print(f"memory[2]: {spg.memory[2][0]}, Player= {spg.memory[2][3]}, V= {return_memory[2][3]}")
                    # print(f"memory[-3]: {spg.memory[-3][0]}, jumps={spg.memory[-3][1]}, Player= {spg.memory[-3][3]}, V= {return_memory[-3][3]}")
                    # print(f"memory[-2]: {spg.memory[-2][0]}, jumps={spg.memory[-2][1]}, Player= {spg.memory[-2][3]}, V= {return_memory[-2][3]}")
                    # print(f"memory[-1]: {spg.memory[-1][0]}, jumps={spg.memory[-1][1]}, Player= {spg.memory[-1][3]}, V= {return_memory[-1][3]}")
                    # print(f"")
                    del spGames[i]
            player = self.game.get_opponent(player)

            end_time = time.time()
            iteration_time = end_time - start_time
            print(f"AZF.selfPlay(): SP iteration time: {iteration_time:.4f} seconds for {len(spGames)} games with {self.args['num_searches']} num_searches")
        print(f"AFZ.selfPlay(): P1 wins {P1_wins} P-1 wins {P2_wins}")
        return return_memory

    def train(self, memory):
        random.shuffle(memory)
        total_policy_loss = 0
        total_value_loss = 0
        for batchIdx in range(0, len(memory), self.args['batch_size']):
            sample = memory[batchIdx:min(len(memory) - 1, batchIdx + self.args['batch_size'])] # Change to memory[batchIdx:batchIdx+self.args['batch_size']] in case of an error
            if len(sample) == 0:
                continue
            board, jumps, policy_targets, value_targets = zip(*sample)

            board, jumps, policy_targets, value_targets = np.array(board), np.array(jumps), np.array(policy_targets), np.array(value_targets).reshape(-1, 1)

            board = torch.tensor(board, dtype=torch.float32, device=self.model.device)
            jumps = torch.tensor(jumps, dtype=torch.float32, device=self.model.device)
            policy_targets = torch.tensor(policy_targets, dtype=torch.float32, device=self.model.device)
            value_targets = torch.tensor(value_targets, dtype=torch.float32, device=self.model.device)

            out_policy, out_value = self.model(board, jumps)

            policy_loss = F.cross_entropy(out_policy, policy_targets)
            value_loss = F.mse_loss(out_value, value_targets)
            loss = policy_loss + value_loss

            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()
            
            total_policy_loss += policy_loss.item()
            total_value_loss += value_loss.item()
        total_value_loss /= len(memory)
        total_policy_loss /= len(memory)
        print(f"AZ.Train(): Total policy loss: {total_policy_loss}, Total value loss: {total_value_loss}")
        

    def learn(self):
        for iteration in range(self.args['num_iterations']):
            memory = []

            self.model.eval()
            for selfPlay_iteration in trange(self.args['num_selfPlay_iterations'] // self.args['num_parallel_games']):
                memory += self.selfPlay()

            print(f"AZ.Learn(): Total memory length: {len(memory)}")
            self.model.train()
            for epoch in trange(self.args['num_epochs']):
                self.train(memory)

            torch.save(self.model.state_dict(), f"model_init_{iteration}.pt")
            torch.save(self.optimizer.state_dict(), f"optimizer_init_{iteration}.pt")
        torch.save(self.model.state_dict(), f"model_final.pt")
        torch.save(self.optimizer.state_dict(), f"optimizer_final.pt")

class SPG:
    def __init__(self, game):
        self.board = INIT_BOARD
        self.jumps = [1, 3]
        self.memory = []
        self.root = None
        self.node = None






In [238]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

# Set seeds
random.seed(0)
np.random.seed(0)
torch.manual_seed(0)

# If you're using CUDA (GPU), also set the CUDA seed
if torch.cuda.is_available():
    torch.cuda.manual_seed(0)
    torch.cuda.manual_seed_all(0)  # if you are using multi-GPU.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
print(torch.cuda.memory_summary())

cuda
|                  PyTorch CUDA memory summary, device ID 0                 |
|---------------------------------------------------------------------------|
|            CUDA OOMs: 0            |        cudaMalloc retries: 0         |
|        Metric         | Cur Usage  | Peak Usage | Tot Alloc  | Tot Freed  |
|---------------------------------------------------------------------------|
| Allocated memory      |   9039 MiB |  14433 MiB |    956 GiB |    948 GiB |
|       from large pool |   9028 MiB |  14416 MiB |    585 GiB |    576 GiB |
|       from small pool |     10 MiB |     82 MiB |    371 GiB |    371 GiB |
|---------------------------------------------------------------------------|
| Active memory         |   9039 MiB |  14433 MiB |    956 GiB |    948 GiB |
|       from large pool |   9028 MiB |  14416 MiB |    585 GiB |    576 GiB |
|       from small pool |     10 MiB |     82 MiB |    371 GiB |    371 GiB |
|----------------------------------------------------------

#AlphaParallel

In [239]:
bg = Backgammon()
model = ResNet(game=bg, num_resBlocks=20, num_hidden=64, num_features=6, device=device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.0002, weight_decay=0.0001)

In [240]:
torch.save(model.state_dict(), f"model_init.pt")
torch.save(optimizer.state_dict(), f"optimizer_init.pt")

# model = ResNet(game=bg, num_resBlocks=20, num_hidden=64, num_features=6, device=device)
# model.load_state_dict(torch.load("model_final.pt", map_location=device))
# optimizer = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=0.0001)

In [241]:
args = {
    'C': 2,
    'num_searches': 200,
    'num_iterations': 1,
    'num_selfPlay_iterations': 64,
    'num_parallel_games':64,
    'num_epochs': 5,
    'batch_size': 1024,
    'temperature': 1.25,
    'dirichlet_epsilon': 0.25,
    'dirichlet_alpha': 0.001,
}


model.eval()

alphaZero = AlphaZero(model, optimizer, bg, args)
action_probs, root = alphaZero.mcts.search(INIT_BOARD, [1,3])

print(root)
for child in sorted(root.children, key=lambda x: (child.value_sum/child.visit_count), reverse=True):
    if child.visit_count > 0:
        print(f"avg_value = {child.value_sum/child.visit_count}")
    print(child)



------------------------------------------------------------
Level: 0, N: 201, val: 5.028, prior: 0.000, uct:None, weight: 2, num_children: 19
   12  11  10  9   8   7   6   5   4   3   2   1     0 
|| 5   0   0   0  -3   0 |-5   0   0   0   0   2 ||  0 


||-5   0   0   0   3   0 | 5   0   0   0   0  -2 ||  0 
   13  14  15  16  17  18  19  20  21  22  23  24    25
jumps=[1, 3], state_value=-0.06579335778951645, action_taken=None
------------------------------------------------------------

avg_value = -0.04994314288099607

  ------------------------------------------------------------
  Level: 1, N: 6, val: -0.300, prior: 0.031, uct:0.6485708339475742, weight: 1, num_children: 21
     12  11  10  9   8   7   6   5   4   3   2   1     0 
  || 5   0   0   0  -2   0 |-4  -2   0   0   0   2 ||  0 
  
  
  ||-5   0   0   0   3   0 | 5   0   0   0   0  -2 ||  0 
     13  14  15  16  17  18  19  20  21  22  23  24    25
  jumps=[], state_value=None, action_taken=456761
  ------------------

In [None]:
# torch.save(model.state_dict(), f"/content/gdrive/My\ Drive/model_final.pt")
# torch.save(optimizer.state_dict(), f"/content/gdrive/My\ Drive/optimizer_final.pt")

In [261]:
args = {
    'C': 2,
    'num_searches': 100,
    'num_iterations': 1,
    'num_selfPlay_iterations': 128,
    'num_parallel_games':64,
    'num_epochs': 5,
    'batch_size': 1024,
    'temperature': 1.25,
    'dirichlet_epsilon': 0.25,
    'dirichlet_alpha': 0.001,
}
alphaZeroParallel = AlphaZeroParallel(model, optimizer, bg, args)

alphaZeroParallel.learn()

  0%|          | 0/2 [00:00<?, ?it/s]

AZF.selfPlay(): SP iteration time: 29.1642 seconds for 64 games with 100 num_searches
AZF.selfPlay(): SP iteration time: 33.2692 seconds for 64 games with 100 num_searches
AZF.selfPlay(): SP iteration time: 32.2430 seconds for 64 games with 100 num_searches
AZF.selfPlay(): SP iteration time: 35.6067 seconds for 64 games with 100 num_searches
AZF.selfPlay(): SP iteration time: 34.2859 seconds for 64 games with 100 num_searches
AZF.selfPlay(): SP iteration time: 34.6754 seconds for 64 games with 100 num_searches
AZF.selfPlay(): SP iteration time: 34.0656 seconds for 64 games with 100 num_searches
AZF.selfPlay(): SP iteration time: 36.0275 seconds for 64 games with 100 num_searches
AZF.selfPlay(): SP iteration time: 34.5095 seconds for 64 games with 100 num_searches
AZF.selfPlay(): SP iteration time: 35.8991 seconds for 64 games with 100 num_searches
AZF.selfPlay(): SP iteration time: 34.9479 seconds for 64 games with 100 num_searches
AZF.selfPlay(): SP iteration time: 33.0146 seconds for

  0%|          | 0/5 [00:00<?, ?it/s]

AZ.Train(): Total policy loss: 0.00816875962729198, Total value loss: 0.0008284165623236877
AZ.Train(): Total policy loss: 0.006780055729786826, Total value loss: 0.0006364574919377562
AZ.Train(): Total policy loss: 0.005827723900074881, Total value loss: 0.00048435059739456804
AZ.Train(): Total policy loss: 0.005340210780022917, Total value loss: 0.0003536576123135936
AZ.Train(): Total policy loss: 0.005073262835541949, Total value loss: 0.00024844560545628013


In [262]:
args = {
    'C': 2,
    'num_searches': 1000,
    'num_iterations': 1,
    'num_selfPlay_iterations': 128,
    'num_parallel_games':64,
    'num_epochs': 3,
    'batch_size': 128,
    'temperature': 1.25,
    'dirichlet_epsilon': 0.25,
    'dirichlet_alpha': 0.01,
}


model.eval()

alphaZero = AlphaZero(model, optimizer, bg, args)
action_probs, root = alphaZero.mcts.search(INIT_BOARD, [1,3])

# for child in sorted(root.children, key=lambda x: (x.prior) , reverse=True):#/x.visit_count)):
#     print(child)



In [263]:
print(root)
for child in sorted(root.children, key=lambda x: (x.prior), reverse=True):
    if child.visit_count > 0:
        print(f"avg_value = {child.value_sum/child.visit_count}{child}")


------------------------------------------------------------
Level: 0, N: 1001, val: 214.528, prior: 0.000, uct:None, weight: 2, num_children: 19
   12  11  10  9   8   7   6   5   4   3   2   1     0 
|| 5   0   0   0  -3   0 |-5   0   0   0   0   2 ||  0 


||-5   0   0   0   3   0 | 5   0   0   0   0  -2 ||  0 
   13  14  15  16  17  18  19  20  21  22  23  24    25
jumps=[1, 3], state_value=-0.016226641833782196, action_taken=None
board = [ 0  2  0  0  0  0 -5  0 -3  0  0  0  5 -5  0  0  0  3  0  5  0  0  0  0 -2  0]
------------------------------------------------------------

avg_value = -0.11954281895199835
  ------------------------------------------------------------
  Level: 1, N: 83, val: -9.922, prior: 0.144, uct:0.6682293799192851, weight: 1, num_children: 21
     12  11  10  9   8   7   6   5   4   3   2   1     0 
  || 5   0   0   0  -2   0 |-4  -2   0   0   0   2 ||  0 
  
  
  ||-5   0   0   0   3   0 | 5   0   0   0   0  -2 ||  0 
     13  14  15  16  17  18  19  20 

In [258]:
print(root)
print(root.children[0])
for child in root.children[0].children:
    if child.visit_count > 0:
        print(f"avg_value = {child.value_sum/child.visit_count}")
    print(child)


------------------------------------------------------------
Level: 0, N: 1001, val: 3.228, prior: 0.000, uct:None, weight: 2, num_children: 19
   12  11  10  9   8   7   6   5   4   3   2   1     0 
|| 5   0   0   0  -3   0 |-5   0   0   0   0   2 ||  0 


||-5   0   0   0   3   0 | 5   0   0   0   0  -2 ||  0 
   13  14  15  16  17  18  19  20  21  22  23  24    25
jumps=[1, 3], state_value=0.07290183752775192, action_taken=None
board = [ 0  2  0  0  0  0 -5  0 -3  0  0  0  5 -5  0  0  0  3  0  5  0  0  0  0 -2  0]
------------------------------------------------------------


  ------------------------------------------------------------
  Level: 1, N: 100, val: -0.087, prior: 0.102, uct:0.5643172322610804, weight: 1, num_children: 21
     12  11  10  9   8   7   6   5   4   3   2   1     0 
  || 5   0   0   0  -2   0 |-4  -2   0   0   0   2 ||  0 
  
  
  ||-5   0   0   0   3   0 | 5   0   0   0   0  -2 ||  0 
     13  14  15  16  17  18  19  20  21  22  23  24    25
  jumps=[], s

In [252]:
board = np.array([ 0,  2,  0,  0,  0,  0, -5,  0, -3,  0,  0,  0,  5, -5,  0,  0,  0,  3,  0,  5,  0,  0,  0,  0, -2,  0])  

model.eval()
enc_board, enc_features = bg.get_encoded_state(board, [1,3])
_, value = model(torch.tensor(enc_board, device=device).unsqueeze(0), torch.tensor(enc_features, device=device).unsqueeze(0))
print(f"value = {value.item()}")


value = 0.07290183752775192


In [None]:
args = {
    'C': 2,
    'num_searches': 100,
    'num_iterations': 1,
    'num_selfPlay_iterations': 128,
    'num_parallel_games':64,
    'num_epochs': 3,
    'batch_size': 128,
    'temperature': 1.25,
    'dirichlet_epsilon': 0.25,
    'dirichlet_alpha': 0.01,
}

model.eval()

alphaZero = AlphaZero(model, optimizer, bg, args)
action_probs, root = alphaZero.mcts.search(TEST_BOARD, [1,3])

print(root)
for child in sorted(root.children, key=lambda x: (x.value_sum / x.visit_count)):#/x.visit_count)):
    print(child)


------------------------------------------------------------
Level: 0, N: 101, val: -9.363, prior: 0.000, uct:None, weight: 2, num_children: 14
   12  11  10  9   8   7   6   5   4   3   2   1     0 
|| 0   0   0   0  -1   0 |-3   0   0   0   0   2 ||  0 


|| 0   0   0   0   1   0 | 3   0   0   0   0  -2 ||  0 
   13  14  15  16  17  18  19  20  21  22  23  24    25
jumps=[1, 3], state_value=0.34306830167770386, action_taken=None
------------------------------------------------------------


  ------------------------------------------------------------
  Level: 1, N: 28, val: -2.103, prior: 0.131, uct:0.6282955757239076, weight: 1, num_children: 21
     12  11  10  9   8   7   6   5   4   3   2   1     0 
  || 0   0   0   0  -1   0 |-3   0   0   0   0   2 ||  0 
  
  
  || 0   0   0   0   1   0 | 3   0  -1   0  -1   0 ||  0 
     13  14  15  16  17  18  19  20  21  22  23  24    25
  jumps=[], state_value=None, action_taken=456327
  --------------------------------------------------