<a href="https://colab.research.google.com/github/calvinpozderac/AZ-testing/blob/az-c4/AZ_C4.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.amp import autocast, GradScaler
import math
import random
from collections import deque, defaultdict
import time
import matplotlib.pyplot as plt
from typing import List, Tuple, Dict, Optional
import copy
import heapq

# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f"Using device: {device}")

Using device: cuda


In [None]:
class ConnectFourGame:
    """Connect Four rules on a 6x7 board.

    Boards are ``np.int8`` arrays of shape (rows, cols): 0 = empty, +1 / -1 = the
    two players' stones. Row 0 is the top row; stones fall toward row ``rows - 1``.
    """

    def __init__(self):
        self.rows = 6
        self.cols = 7
        self.board_size = (self.rows, self.cols)
        self.action_size = self.cols  # one action per column

    def get_initial_state(self):
        """Return an empty board."""
        return np.zeros(self.board_size, dtype=np.int8)

    def get_valid_moves(self, state):
        """Boolean mask over columns: True where the column's top cell is still empty."""
        return state[0, :] == 0

    def get_next_state(self, state, action, player):
        """Drop ``player``'s stone into column ``action``.

        Returns:
            ``(next_state, -player)``; ``state`` itself is not modified.

        Raises:
            ValueError: if column ``action`` is already full. Previously such a move
                was silently ignored while the turn still passed to the opponent.
        """
        next_state = state.copy()
        empty_rows = np.flatnonzero(next_state[:, action] == 0)
        if empty_rows.size == 0:
            raise ValueError(f"column {action} is full")
        # The largest empty row index is the lowest free cell in the column.
        next_state[empty_rows[-1], action] = player
        return next_state, -player

    def get_game_ended(self, state, player):
        """Return the game result from player +1's point of view.

        Returns:
            1.0 if +1 has four in a row, -1.0 if -1 does, 0 if the board is full
            (draw), or None while the game continues.

        NOTE: ``player`` is accepted for interface compatibility but ignored; the
        result is always relative to +1. Callers flip the sign themselves (e.g. by
        passing a canonical board, or by multiplying by the player).
        """
        board = np.asarray(state, dtype=np.int64)
        # Sum of every length-4 line, one array per direction. Each array is indexed
        # by the line's first cell. Scan order: horizontal (row by row), vertical
        # (column by column, hence the transpose), down-right, then up-right.
        # When several lines exist, the first one found decides the result.
        line_sums = (
            board[:, :-3] + board[:, 1:-2] + board[:, 2:-1] + board[:, 3:],
            (board[:-3, :] + board[1:-2, :] + board[2:-1, :] + board[3:, :]).T,
            board[:-3, :-3] + board[1:-2, 1:-2] + board[2:-1, 2:-1] + board[3:, 3:],
            board[3:, :-3] + board[2:-1, 1:-2] + board[1:-2, 2:-1] + board[:-3, 3:],
        )
        for sums in line_sums:
            winners = sums[np.abs(sums) == 4]
            if winners.size:
                return float(winners[0]) / 4

        # No winner and no room left in any column: draw.
        if not self.get_valid_moves(state).any():
            return 0

        return None  # Game continues

    def get_canonical_form(self, state, player):
        """Return the board from ``player``'s perspective (their stones become +1)."""
        return state * player

    def get_symmetries(self, state, pi):
        """Return ``[(state, pi), (mirrored_state, mirrored_pi)]``.

        Connect Four is symmetric only under a left-right mirror, which also reverses
        the per-column policy.
        """
        state_flip = np.fliplr(state)
        pi_flip = np.fliplr(pi.reshape(1, self.cols)).flatten()
        return [(state, pi), (state_flip, pi_flip)]

    def string_representation(self, state):
        """Render the board as text ('X' = +1, 'O' = -1, '.' = empty) with column labels."""
        symbols = {0: '.', 1: 'X', -1: 'O'}
        rows = [" ".join(symbols[state[r, c]] for c in range(self.cols)) + "\n"
                for r in range(self.rows)]
        return "".join(rows) + "--------------------\n" + "0 1 2 3 4 5 6"

# --- Neural Network ---
# To extend to a new game, adjust the network architecture:
# - Input shape: Match the new game's board representation.
# - Output shapes: Match the new game's action space (policy head) and value output (value head).
# - Consider different convolutional filter sizes or types if the board structure is very different.
class NeuralNetwork(nn.Module):
    """Convolutional policy/value network for AlphaZero.

    A four-layer 3x3 convolutional trunk (BatchNorm + ReLU) feeds two heads:
    a policy head producing log-probabilities over ``game.action_size`` actions,
    and a value head producing a scalar squashed into [-1, 1] by tanh.

    Args:
        game: object exposing ``board_size`` as (rows, cols) and ``action_size``.
        num_channels: width of the convolutional trunk.
        dropout: dropout probability applied in both heads before their
            fully-connected layers (0.3 matches the previously hard-coded value).
    """

    def __init__(self, game, num_channels=32, dropout=0.3):
        super().__init__()
        self.board_size = game.board_size
        self.action_size = game.action_size
        self.num_channels = num_channels
        rows, cols = self.board_size

        # Shared trunk: padding=1 keeps the board's spatial size through every layer.
        # (Creation order is kept stable so parameter init order and state_dict keys
        # are unchanged.)
        self.conv1 = nn.Conv2d(1, num_channels, 3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(num_channels, num_channels, 3, stride=1, padding=1)
        self.conv3 = nn.Conv2d(num_channels, num_channels, 3, stride=1, padding=1)
        self.conv4 = nn.Conv2d(num_channels, num_channels, 3, stride=1, padding=1)

        self.bn1 = nn.BatchNorm2d(num_channels)
        self.bn2 = nn.BatchNorm2d(num_channels)
        self.bn3 = nn.BatchNorm2d(num_channels)
        self.bn4 = nn.BatchNorm2d(num_channels)

        # Policy head: 1x1 conv to 32 maps, then a linear layer over all actions.
        self.conv_policy = nn.Conv2d(num_channels, 32, 1)
        self.bn_policy = nn.BatchNorm2d(32)
        self.fc_policy = nn.Linear(32 * rows * cols, self.action_size)

        # Value head: 1x1 conv to 3 maps, then 64 hidden units, then one scalar.
        self.conv_value = nn.Conv2d(num_channels, 3, 1)
        self.bn_value = nn.BatchNorm2d(3)
        self.fc_value1 = nn.Linear(3 * rows * cols, 64)
        self.fc_value2 = nn.Linear(64, 1)

        self.dropout = nn.Dropout(dropout)

    def forward(self, s):
        """Evaluate a batch of boards.

        Args:
            s: tensor holding rows*cols values per board. It is reshaped to
                (batch, 1, rows, cols), so (rows, cols), (batch, rows, cols) and
                (batch, 1, rows, cols) inputs are all accepted.

        Returns:
            ``(log_pi, v)``: log-softmax policy of shape (batch, action_size) and
            value of shape (batch, 1) in [-1, 1].
        """
        # reshape (unlike view) also accepts non-contiguous inputs.
        s = s.reshape(-1, 1, self.board_size[0], self.board_size[1])
        for conv, bn in ((self.conv1, self.bn1), (self.conv2, self.bn2),
                         (self.conv3, self.bn3), (self.conv4, self.bn4)):
            s = F.relu(bn(conv(s)))

        # Policy head
        pi = F.relu(self.bn_policy(self.conv_policy(s))).flatten(1)
        pi = self.fc_policy(self.dropout(pi))
        pi = F.log_softmax(pi, dim=1)

        # Value head
        v = F.relu(self.bn_value(self.conv_value(s))).flatten(1)
        v = F.relu(self.fc_value1(self.dropout(v)))
        v = torch.tanh(self.fc_value2(v))

        return pi, v

# --- MCTS ---
# The MCTS class is generally game-agnostic, relying on the Game and NeuralNetwork classes.
# No major changes needed here for a new game unless the MCTS algorithm itself needs modification
# (e.g., for games with very large branching factors).
class MCTS:
    """Monte Carlo Tree Search guided by a policy/value network (AlphaZero-style PUCT).

    Value convention: every value returned by ``search``, and every ``Qsa`` entry,
    is from the point of view of the player to move at the node it belongs to.
    A child's value is therefore NEGATED before it is backed up into its parent.

    Tree statistics are keyed by a text form of the canonical board, so one
    instance can be reused across successive moves of the same game.
    """

    def __init__(self, game, nnet, args):
        self.game = game
        self.nnet = nnet
        self.args = args  # reads args.num_mcts_sims and args.cpuct
        self.Qsa = {}  # (state_key, action) -> mean value for the player to move at state
        self.Nsa = {}  # (state_key, action) -> visit count
        self.Ns = {}   # state_key -> number of simulations that passed through state
        self.Ps = {}   # state_key -> network prior, masked to legal moves and renormalised
        self.Es = {}   # state_key -> terminal result (None while the game continues)
        self.Vs = {}   # state_key -> legal-move mask

    @staticmethod
    def _state_key(canonical_state):
        """Hashable key for a board; threshold=size prevents '...' summarisation of large boards."""
        return np.array2string(canonical_state, separator=',', max_line_width=np.inf,
                               threshold=canonical_state.size)

    def get_action_prob(self, state, temp=1):
        """Run ``args.num_mcts_sims`` simulations from ``state`` and return a move distribution.

        Args:
            state: board from the perspective of the player to move (canonical form).
            temp: 0 gives a one-hot on a most-visited move (ties broken at random);
                otherwise probabilities are proportional to ``visits ** (1 / temp)``.

        Returns:
            ``(probs, mcts_value)``: ``probs`` has length ``game.action_size``, and
            ``mcts_value`` is the visit-weighted mean root Q for the player to move.
        """
        for _ in range(self.args.num_mcts_sims):
            self.search(state, 1)

        s = self.game.get_canonical_form(state, 1)
        s_key = self._state_key(s)
        visits = np.array([self.Nsa.get((s_key, a), 0) for a in range(self.game.action_size)],
                          dtype=np.float64)
        total_visits = visits.sum()

        # The first simulation on a fresh root only expands it, so with very few
        # simulations every count can be 0. Fall back to a uniform choice over legal
        # moves instead of dividing by zero (temp > 0) or allowing illegal moves (temp == 0).
        if total_visits > 0:
            weights = visits
        else:
            weights = np.asarray(self.game.get_valid_moves(s), dtype=np.float64)

        if temp == 0:
            best_actions = np.flatnonzero(weights == weights.max())
            best_action = np.random.choice(best_actions)
            probs = np.zeros(len(weights))
            probs[best_action] = 1
        else:
            weights = weights ** (1.0 / temp)
            probs = weights / weights.sum()

        mcts_value = 0.0
        if total_visits > 0:
            for a, n in enumerate(visits):
                if n > 0:
                    mcts_value += (n / total_visits) * self.Qsa.get((s_key, a), 0.0)

        return probs, mcts_value

    def _expand(self, s, s_key):
        """Evaluate a new leaf with the network, cache its masked prior, and return its value."""
        with torch.no_grad():
            # Shape (batch=1, channels=1, *board).
            s_tensor = torch.FloatTensor(s).unsqueeze(0).unsqueeze(0).to(device)
            log_pi, v = self.nnet(s_tensor)
            pi = torch.exp(log_pi).cpu().numpy()[0]

        valid_moves = self.game.get_valid_moves(s)
        pi = pi * valid_moves  # Mask invalid moves
        pi_sum = np.sum(pi)
        if pi_sum > 0:
            pi = pi / pi_sum
        else:
            # The network put no mass on any legal move: use uniform over legal moves.
            pi = valid_moves / np.sum(valid_moves)

        self.Ps[s_key] = pi
        self.Vs[s_key] = valid_moves
        self.Ns[s_key] = 0
        return v.item()

    def search(self, state, player):
        """Run one simulation from ``state`` with ``player`` to move.

        Returns:
            The value of ``state`` for ``player``: the exact result at a terminal
            state, the network estimate at a newly expanded leaf, or otherwise the
            negated value backed up from the selected child.
        """
        s = self.game.get_canonical_form(state, player)
        s_key = self._state_key(s)

        if s_key not in self.Es:
            # On the canonical board the player to move is +1, so the result "from
            # +1's perspective" is the result for the player to move.
            self.Es[s_key] = self.game.get_game_ended(s, 1)
        if self.Es[s_key] is not None:
            return self.Es[s_key]

        if s_key not in self.Ps:
            return self._expand(s, s_key)

        # PUCT selection over legal moves. Unvisited moves use Q = 0; the epsilon keeps
        # the prior term non-zero on the node's first selection (Ns == 0).
        valid_moves = self.Vs[s_key]
        priors = self.Ps[s_key]
        best_score = -float('inf')
        best_act = -1
        for a in range(self.game.action_size):
            if not valid_moves[a]:
                continue
            key = (s_key, a)
            if key in self.Qsa:
                score = (self.Qsa[key] + self.args.cpuct * priors[a]
                         * math.sqrt(self.Ns[s_key]) / (1 + self.Nsa[key]))
            else:
                score = self.args.cpuct * priors[a] * math.sqrt(self.Ns[s_key] + 1e-8)
            if score > best_score:
                best_score, best_act = score, a

        next_state, next_player = self.game.get_next_state(state, best_act, player)
        # The child's value is from the opponent's point of view; negate it for this node.
        v = -self.search(next_state, next_player)

        # Incremental mean update of Q(s, a).
        key = (s_key, best_act)
        n = self.Nsa.get(key, 0)
        self.Qsa[key] = (n * self.Qsa.get(key, 0.0) + v) / (n + 1)
        self.Nsa[key] = n + 1
        self.Ns[s_key] += 1
        return v

# --- AlphaZero Trainer ---
# The trainer orchestrates self-play, training, and evaluation.
# Mostly game-agnostic, but requires the Game, NeuralNetwork, MCTS, and Args classes.
class AlphaZeroTrainer:
    """Main AlphaZero trainer"""

    def __init__(self):
        self.game = ConnectFourGame()
        self.args = Args()
        self.nnet = NeuralNetwork(self.game).to(device)
        self.optimizer = optim.Adam(self.nnet.parameters(), lr=self.args.lr,
                                   weight_decay=self.args.weight_decay)
        self.scheduler = optim.lr_scheduler.StepLR(self.optimizer, step_size=30, gamma=0.1)
        self.scaler = GradScaler('cuda')

        self.train_examples_history = deque([], maxlen=self.args.max_memory_size // self.args.num_eps) # Store examples from recent episodes

        # Heap: (-elo, iteration, model_copy)
        self.model_heap = []
        heapq.heappush(self.model_heap, (-1000, 0, copy.deepcopy(self.nnet)))  # Iter 0, base net


    def execute_episode(self):
        """Execute one episode of self-play (fast version)"""
        train_examples = []
        state = self.game.get_initial_state()
        current_player = 1
        episode_step = 0

        # Create a single MCTS instance for the whole episode
        mcts = MCTS(self.game, self.nnet, self.args)

        while True:
            episode_step += 1
            canonical_state = self.game.get_canonical_form(state, current_player)
            temp = int(episode_step < self.args.temp_threshold)

            # Get action probabilities from MCTS
            pi, _ = mcts.get_action_prob(canonical_state, temp=temp)

            # Store training example (value target will be filled in later)
            sym = self.game.get_symmetries(canonical_state, pi)
            for s, p in sym:
                train_examples.append([s, current_player, p])  # no value yet

            # Sample action from probabilities
            action = np.random.choice(len(pi), p=pi)

            # Make move
            state, current_player = self.game.get_next_state(state, action, current_player)

            # Check if game ended
            result = self.game.get_game_ended(state, current_player)
            if result is not None:
                # Assign final outcome z to all training examples
                final_examples = []
                for s, player, p in train_examples:
                    z = result * player  # flip perspective depending on who played
                    final_examples.append([s, player, p, z])
                return final_examples

    def train(self, examples):
        """Train the neural network"""
        random.shuffle(examples)

        total_loss = 0
        num_batches = 0

        self.nnet.train() # Set model to training mode

        for epoch in range(self.args.epochs):
            batch_idx = 0

            while batch_idx < len(examples):
                batch_examples = examples[batch_idx:batch_idx + self.args.batch_size]
                batch_idx += self.args.batch_size

                # Ensure input tensor has the correct shape (batch, channels, height, width)
                boards = torch.FloatTensor(np.array([ex[0] for ex in batch_examples])).unsqueeze(1).to(device)
                target_pis = torch.FloatTensor(np.array([ex[2] for ex in batch_examples])).to(device)
                target_vs = torch.FloatTensor(np.array([ex[3] for ex in batch_examples])).to(device)

                self.optimizer.zero_grad()

                with autocast(device_type='cuda'):
                    out_pi, out_v = self.nnet(boards)
                    l_pi = self.loss_pi(target_pis, out_pi)
                    l_v = self.loss_v(target_vs, out_v)
                    total_loss_batch = l_pi + l_v

                self.scaler.scale(total_loss_batch).backward()
                self.scaler.step(self.optimizer)
                self.scaler.update()

                total_loss += total_loss_batch.item()
                num_batches += 1

        self.scheduler.step()
        avg_loss = total_loss / num_batches

        self.nnet.eval()
        return avg_loss

    def loss_pi(self, targets, outputs):
        return -torch.sum(targets * outputs) / targets.size()[0]

    def loss_v(self, targets, outputs):
        return torch.sum((targets - outputs.view(-1)) ** 2) / targets.size()[0]

    def play_game(self, nnet1, nnet2):
        """Play a single game between two nets (fast version)"""
        state = self.game.get_initial_state()
        player = 1

        # Create MCTS instances once per player
        mcts1 = MCTS(self.game, nnet1, self.args)
        mcts2 = MCTS(self.game, nnet2, self.args)

        while True:
            nnet = nnet1 if player == 1 else nnet2
            mcts = mcts1 if player == 1 else mcts2

            pi, _ = mcts.get_action_prob(self.game.get_canonical_form(state, player), temp=0)
            action = np.argmax(pi)

            state, player = self.game.get_next_state(state, action, player)

            result = self.game.get_game_ended(state, player)
            if result is not None:
                if result == 0:
                    return 0.5
                return 1 if result == 1 else 0

    def update_elo(self, elo_a, elo_b, score_a, k=32):
        """Update Elo ratings after a game. score_a is 1 if A wins, 0.5 for draw, 0 if A loses."""
        expected_a = 1 / (1 + 10 ** ((elo_b - elo_a) / 400))
        expected_b = 1 - expected_a
        new_elo_a = elo_a + k * (score_a - expected_a)
        new_elo_b = elo_b + k * ((1 - score_a) - expected_b)
        return new_elo_a, new_elo_b

    def evaluate_models(self, new_model, iteration):
        # Pop all previous models from the heap
        previous_models = []
        while self.model_heap:
            previous_models.append(heapq.heappop(self.model_heap))

        new_elo = 1000  # starting Elo for new model

        # Play two games against each previous model: once as first, once as second
        for i, (neg_elo, iter_idx, old_model) in enumerate(previous_models):
            # Game 1: new_model goes first
            score1 = self.play_game(new_model, old_model)
            new_elo, old_elo = self.update_elo(new_elo, -neg_elo, score1)

            # Game 2: new_model goes second
            score2 = self.play_game(old_model, new_model)
            # Swap score for perspective of new_model
            new_elo, old_elo = self.update_elo(new_elo, old_elo, 1 - score2)

            # Update the heap tuple with the new opponent Elo
            previous_models[i] = (-old_elo, iter_idx, old_model)

        # Push all previous models back into the heap
        for model in previous_models:
            heapq.heappush(self.model_heap, model)

        # Push new model into heap
        heapq.heappush(self.model_heap, (-new_elo, iteration, new_model))

        # Optionally, set self.nnet to the model with highest Elo
        top_neg_elo, _, top_model = max(self.model_heap)
        self.nnet = top_model

        return new_elo

    def learn(self):
        """Main training loop"""
        print("Starting AlphaZero training...")

        for i in range(1, self.args.num_iters + 1):
            print(f"\n=== Iteration {i}/{self.args.num_iters} ===")

            # Self-play
            iteration_train_examples = deque([], maxlen=self.args.max_memory_size)

            print("Collecting self-play games...")
            for eps in range(self.args.num_eps):
                iteration_train_examples.extend(self.execute_episode()) # Use extend
                if (eps + 1) % 10 == 0: # Print progress less frequently
                    print(f"Episode {eps + 1}/{self.args.num_eps}")

            # Add to training history
            self.train_examples_history.append(iteration_train_examples)

            # Prepare training data from history
            train_examples = []
            for e in self.train_examples_history:
                train_examples.extend(e)

            # Train
            print(f"Training on {len(train_examples)} examples...")
            avg_loss = self.train(train_examples)

            # Save and evaluate
            new_model = copy.deepcopy(self.nnet)
            elo = self.evaluate_models(new_model, i)

            print("\n--- Elo Leaderboard ---")
            # Convert heap into a sorted list by Elo (highest first)
            sorted_models = sorted(self.model_heap, key=lambda x: x[0])  # x[0] is -elo
            for neg_elo, iteration, _ in sorted_models:
                print(f"Iteration {iteration}: {-neg_elo:.2f}")
            print("-----------------------\n")

            # Pick the best model (highest Elo = smallest neg_elo in heap)
            neg_elo, best_iter, best_model = self.model_heap[0]

            # Update current nnet to the best model
            #self.nnet = best_model

            # Reinitialize optimizer and scaler for the new model
            self.optimizer = torch.optim.Adam(self.nnet.parameters(), lr=self.args.lr)
            self.scaler = torch.amp.GradScaler('cuda')



        print("Training completed!")
        return self.model_heap

def demo_game(num_iters=100, num_eps=50, num_mcts_sims=100):
    """Train a Connect Four AlphaZero agent with demo settings and return its model heap.

    Despite its name, this function plays no demo game and evaluates no specific
    position. It builds an ``AlphaZeroTrainer``, overrides a few hyper-parameters,
    and runs ``trainer.learn()``.

    Args:
        num_iters: training iterations (overrides ``args.num_iters``).
        num_eps: self-play episodes per iteration (overrides ``args.num_eps``).
        num_mcts_sims: MCTS simulations per move (overrides ``args.num_mcts_sims``).

    Returns:
        The trainer's heap of ``(-elo, iteration, model)`` tuples.
    """
    print("\n=== Demo Game and Evaluation ===")
    trainer = AlphaZeroTrainer()

    trainer.args.num_iters = num_iters
    trainer.args.num_eps = num_eps
    trainer.args.num_mcts_sims = num_mcts_sims
    # The trainer sized its history buffer from num_eps in __init__; rebuild it so
    # the window reflects the overridden episode count.
    trainer.train_examples_history = deque(
        [], maxlen=trainer.args.max_memory_size // trainer.args.num_eps)

    print("Training AlphaZero (reduced for demo)...")
    model_heap = trainer.learn()

    return model_heap

In [None]:
# Build a fresh trainer and collect one buffer of self-play examples for the
# hyper-parameter sweep below. `self` is an ordinary global reused by later cells.
self = AlphaZeroTrainer()

iteration_train_examples = deque([], maxlen=self.args.max_memory_size)

self.args.num_eps = 50  # self-play episodes to collect
self.args.num_mcts_sims = 100  # MCTS simulations per move (Args default is 5)

# The trainer never switches its fresh network out of training mode before the
# first training step; evaluate in inference mode so dropout and per-batch
# BatchNorm statistics do not perturb MCTS.
self.nnet.eval()

print("Collecting self-play games...")
for eps in range(self.args.num_eps):
    iteration_train_examples.extend(self.execute_episode())
    if (eps + 1) % 10 == 0:  # report progress every 10 episodes
        print(f"Episode {eps + 1}/{self.args.num_eps}")

Collecting self-play games...
Episode 10/50
Episode 20/50
Episode 30/50
Episode 40/50
Episode 50/50


In [None]:
# Run a second round of self-play, appending to the same example buffer.
print("Collecting self-play games...")
for episode in range(1, self.args.num_eps + 1):
    iteration_train_examples.extend(self.execute_episode())
    if episode % 10 == 0:
        print(f"Episode {episode}/{self.args.num_eps}")

Collecting self-play games...
Episode 10/50
Episode 20/50
Episode 30/50
Episode 40/50
Episode 50/50


In [None]:
# `iteration_train_examples` is already a flat buffer of [board, player, pi, z]
# examples (it was filled with extend(execute_episode())), so copy it as-is.
# Calling extend() on each example would splice its four fields into the list.
train_examples = list(iteration_train_examples)

In [None]:
# Inspect the first collected example: [canonical board, player to move, MCTS policy, outcome z].
iteration_train_examples[0]

[array([[ 0,  0,  0,  0,  0,  0,  0],
        [ 0,  0,  0,  0,  0,  0,  0],
        [ 0,  0,  0,  0,  0,  0,  0],
        [ 0,  0,  0,  0,  0,  0,  0],
        [ 0,  0, -1,  0,  0,  0,  0],
        [ 1,  1, -1,  1, -1, -1,  0]], dtype=int8),
 -1,
 array([0.17322835, 0.11811024, 0.14173228, 0.08661417, 0.24409449,
        0.11023622, 0.12598425]),
 0]

In [None]:
# Hyper-parameter sweep: for every (data fraction, epochs, learning rate) triple,
# train a fresh network on a random subset of the collected self-play examples
# and keep an eval-mode copy for the Elo tournament in the next cell.
all_nns = []
names = []
use_amp = device.type == 'cuda'  # mixed precision only where CUDA is available
for num_parts in [10, 25, 50, 75, 100, 125, 150]:
    fraction = num_parts / 150  # share of the example buffer used for training
    for num_steps in [25, 50, 75, 100, 200]:
        for lr in [0.001, 0.005, 0.01]:
            nnet = NeuralNetwork(self.game).to(device)
            optimizer = optim.Adam(nnet.parameters(), lr=lr,
                                   weight_decay=self.args.weight_decay)
            scaler = GradScaler('cuda', enabled=use_amp)

            # Materialise and shuffle once per configuration, then train on an exact
            # prefix. Rebuilding list(deque) for every batch was quadratic, and the
            # old loop's last batch could run past the requested fraction.
            examples = list(iteration_train_examples)
            random.shuffle(examples)
            subset = examples[:round(len(examples) * fraction)]

            nnet.train()
            for epoch in range(num_steps):
                for start in range(0, len(subset), self.args.batch_size):
                    batch_examples = subset[start:start + self.args.batch_size]

                    # Boards become (batch, channels=1, rows, cols).
                    boards = torch.FloatTensor(
                        np.array([ex[0] for ex in batch_examples])).unsqueeze(1).to(device)
                    target_pis = torch.FloatTensor(
                        np.array([ex[2] for ex in batch_examples])).to(device)
                    target_vs = torch.FloatTensor(
                        np.array([ex[3] for ex in batch_examples])).to(device)

                    optimizer.zero_grad()
                    with autocast(device_type=device.type, enabled=use_amp):
                        out_pi, out_v = nnet(boards)
                        total_loss_batch = (self.loss_pi(target_pis, out_pi)
                                            + self.loss_v(target_vs, out_v))

                    scaler.scale(total_loss_batch).backward()
                    scaler.step(optimizer)
                    scaler.update()

            nnet.eval()  # store models ready for inference
            names.append(f"num={num_parts}, num_steps={num_steps}, lr={lr}")
            print(names[-1])
            all_nns.append(copy.deepcopy(nnet))


num=10, num_steps=25, lr=0.001
num=10, num_steps=25, lr=0.005
num=10, num_steps=25, lr=0.01
num=10, num_steps=50, lr=0.001
num=10, num_steps=50, lr=0.005
num=10, num_steps=50, lr=0.01
num=10, num_steps=75, lr=0.001
num=10, num_steps=75, lr=0.005
num=10, num_steps=75, lr=0.01
num=10, num_steps=100, lr=0.001
num=10, num_steps=100, lr=0.005
num=10, num_steps=100, lr=0.01
num=10, num_steps=200, lr=0.001
num=10, num_steps=200, lr=0.005
num=10, num_steps=200, lr=0.01
num=25, num_steps=25, lr=0.001
num=25, num_steps=25, lr=0.005
num=25, num_steps=25, lr=0.01
num=25, num_steps=50, lr=0.001
num=25, num_steps=50, lr=0.005
num=25, num_steps=50, lr=0.01
num=25, num_steps=75, lr=0.001
num=25, num_steps=75, lr=0.005
num=25, num_steps=75, lr=0.01
num=25, num_steps=100, lr=0.001
num=25, num_steps=100, lr=0.005
num=25, num_steps=100, lr=0.01
num=25, num_steps=200, lr=0.001
num=25, num_steps=200, lr=0.005
num=25, num_steps=200, lr=0.01
num=50, num_steps=25, lr=0.001
num=50, num_steps=25, lr=0.005
num=50

In [None]:
# Round-robin Elo tournament over every swept network. Each of the 10 rounds visits
# all unordered pairs in a fresh random order; every pair plays twice so that each
# network gets to move first once. The current top 10 is printed after every round.
elos = [1000] * len(all_nns)
ijs = [(i, j) for i in range(len(all_nns)) for j in range(i + 1, len(all_nns))]
for _round in range(10):
    random.shuffle(ijs)
    for i, j in ijs:
        first_net, second_net = all_nns[i], all_nns[j]
        first_net.eval()
        second_net.eval()

        # Game 1: network i moves first.
        score = self.play_game(first_net, second_net)
        rating_i, rating_j = self.update_elo(elos[i], elos[j], score)

        # Game 2: network j moves first; flip the score to network i's perspective.
        score = self.play_game(second_net, first_net)
        elos[i], elos[j] = self.update_elo(rating_i, rating_j, 1 - score)

    for k in np.argsort(-np.array(elos))[:10]:
        print(names[k], elos[k])

In [None]:
# --- Training Arguments ---
# Adjust these parameters based on the complexity of the new game.
# More complex games may require more simulations, episodes, iterations, and a larger batch size.
class Args:
    """Bundle of hyper-parameters for self-play, MCTS, training and evaluation.

    Values are plain instance attributes and may be overridden after construction,
    e.g. ``trainer.args.num_mcts_sims = 100``.
    """

    def __init__(self):
        # Self-play: training iterations, episodes per iteration, and the temperature
        # cut-off (moves with episode_step < temp_threshold are sampled with temp 1,
        # later moves are greedy).
        self.num_iters, self.num_eps, self.temp_threshold = 100, 50, 15

        # MCTS: simulations per move and the PUCT exploration constant.
        self.num_mcts_sims, self.cpuct = 5, 1.0

        # Optimisation.
        self.epochs, self.batch_size, self.lr = 50, 256, 0.01
        # NOTE(review): AlphaZeroTrainer constructs NeuralNetwork(self.game) without
        # passing this value, so it does not currently control the network's dropout.
        self.dropout = 0.3
        self.weight_decay = 1e-4

        # Replay memory size, in training examples.
        self.max_memory_size = 100000

        # Arena settings: games per comparison and the win rate needed to accept a
        # new network. NOTE(review): the Elo-based AlphaZeroTrainer does not read these.
        self.arena_compare, self.update_threshold = 40, 0.6

In [None]:
if __name__ == "__main__":
    # Script entry point: run demo_game(), which trains with its demo settings; the
    # returned model heap is discarded here.
    # NOTE(review): in a Jupyter kernel __name__ is typically "__main__" as well, so
    # executing this cell likely starts the full (long) training run — confirm.
    demo_game()