# Breaking Isolation: A Localized Monte Carlo Tree Search Approach
Minjae Kim, Shu Yang Wei, Stephen Yang

- Env specification is detailed in `env.yml`
- Bots are detailed in `bots.py`
- Isolation game environment is detailed in `isolation_env.py`
- A human-playable Isolation game environment is detailed in `isolation_human.py`
- A DQN implementation using the Tianshou RL library is detailed in `tianshou_train.py`

In [1]:
import matplotlib.pyplot as plt
from matplotlib.ticker import FuncFormatter
from itertools import product
import seaborn as sns
import datetime
import numpy as np
import copy
import ast

import isolation_env
from bots import RandomBot, HeuristicBot, MCTSBot, MCTSRicherBot, MCTSBiggerBot, DQNBot, QStarBot




pygame 2.5.2 (SDL 2.30.10, Python 3.8.20)
Hello from the pygame community. https://www.pygame.org/contribute.html


## Training MCTSBot

In [2]:
BOARD_SIZE = (6, 8)

def run_one_game(bots, env, test=False):
    """Play one game to completion in an already-reset AEC-style environment.

    Args:
        bots: two bots; ``bots[0]`` plays ``'player_0'`` and ``bots[1]`` plays
            ``'player_1'``. Each must provide ``take_step(observation)`` and,
            unless ``test`` is true, ``learn(observation, reward)``.
        env: environment exposing ``agent_iter()``, ``last()`` and ``step(action)``.
        test: when true, bots are not asked to learn from the final outcome.

    Returns:
        The name of the agent whose final reward was 1, or ``None`` when no
        agent finished with a reward of 1 (e.g. a drawn or truncated game).
    """
    agent_names = ['player_0', 'player_1']
    winner = None  # stays None if no agent ends the game with reward 1

    for agent in env.agent_iter():
        bot = bots[agent_names.index(agent)]
        observation, reward, termination, truncation, info = env.last()

        if termination or truncation:
            if not test:
                bot.learn(observation, reward)
            # A finished agent is stepped with a None action (PettingZoo AEC convention).
            action = None
            if reward == 1:
                winner = agent
        else:
            action = bot.take_step(observation)

        env.step(action)

    return winner

def run_games(bots, num_games, test=False, shaping=False, batch_size=100):
    """Play ``num_games`` Isolation games and track player_0's win rate per batch.

    A single environment is created, reset before each game, and always
    closed, even if a game raises.

    Args:
        bots: ``[bot_for_player_0, bot_for_player_1]``, passed to ``run_one_game``.
        num_games: number of games to play.
        test: when true, bots do not learn from game outcomes.
        shaping: forwarded to ``isolation_env.env`` (see that module for its meaning).
        batch_size: number of games per reported batch.

    Returns:
        A list with one player_0 win rate (a fraction in [0, 1]) per complete
        batch of ``batch_size`` games. Games in a trailing incomplete batch are
        played but not reported.

    Raises:
        ValueError: if ``batch_size`` is not positive.
    """
    if batch_size <= 0:
        raise ValueError("batch_size must be positive")

    env = isolation_env.env(board_size=BOARD_SIZE, shaping=shaping, render_mode=None)

    batch_win_rates = []
    batch_wins = 0

    try:
        env.reset()
        for i in range(num_games):
            winner = run_one_game(bots, env, test=test)
            env.reset()

            if winner == "player_0":
                batch_wins += 1

            # End of a full batch: report and restart the counter.
            if (i + 1) % batch_size == 0:
                batch = i // batch_size
                batch_win_rate = batch_wins / batch_size

                timestamp = datetime.datetime.now().strftime("%H:%M:%S")
                print("{} Batch {} Win Percentage: {:.0%}".format(timestamp, batch, batch_win_rate))

                batch_win_rates.append(batch_win_rate)
                batch_wins = 0
    finally:
        env.close()

    return batch_win_rates

In [None]:
# Train several independent MCTSBots; each plays its own games against one shared
# HeuristicBot, and its list of per-batch win rates is kept for plotting.
num_mcts_runs = 25
games_per_run = 30000

mcts_bots, mcts_train_data = [], []
hb = HeuristicBot(board_size=BOARD_SIZE)

for _run_idx in range(num_mcts_runs):
    bot = MCTSBot(board_size=BOARD_SIZE)
    mcts_bots.append(bot)
    run_rates = run_games([bot, hb], games_per_run)
    mcts_train_data.append(run_rates)

In [None]:
# Plot train performance of all MCTS bots: every run as a faint line, plus the
# per-batch mean across runs (rows of data_array are runs, columns are batches).
sns.set(style="whitegrid")

data_array = np.array(mcts_train_data)
mean = np.mean(data_array, axis=0)

for run_curve in data_array:
    plt.plot(run_curve, color='lightblue', alpha=0.2)

plt.plot(mean, color='blue', linewidth=2, label='Mean')

plt.xlabel('Batch')
plt.ylabel('Win % against HeuristicBot')
# Win rates are fractions in [0, 1]; show them as percentages to match the label.
plt.gca().yaxis.set_major_formatter(FuncFormatter(lambda y, _: f'{y:.0%}'))
plt.title('Train Performance of MCTSBot, All Trains')
plt.legend()
plt.show()

In [None]:
# Plot train performance of select MCTS bots: only runs whose final batch win
# rate exceeds `cutoff`, plus the per-batch mean of those runs.
cutoff = 0.4

selected = [run for run in mcts_train_data if len(run) > 0 and run[-1] > cutoff]

if not selected:
    # Averaging an empty selection would only produce NaNs and warnings.
    print("No training run finished above a {:.0%} win rate; nothing to plot.".format(cutoff))
else:
    data_array = np.array(selected)
    mean = np.mean(data_array, axis=0)

    for run_curve in data_array:
        plt.plot(run_curve, color='lightblue', alpha=0.4)

    plt.plot(mean, color='blue', linewidth=2, label='Mean')

    plt.xlabel('Batch')
    plt.ylabel('Win % against HeuristicBot')
    # Win rates are fractions in [0, 1]; show them as percentages to match the label.
    plt.gca().yaxis.set_major_formatter(FuncFormatter(lambda y, _: f'{y:.0%}'))
    plt.title('Train Performance of MCTSBot, Selected')
    plt.legend()
    plt.show()

In [None]:
# Evaluate every trained MCTSBot with exploration switched off (exploration_weight = 0,
# which stays set on the bot afterwards). Each bot plays 500 non-learning games
# against RandomBot and 500 against HeuristicBot. Each stored entry is the list of
# per-batch win rates returned by run_games.
rb = RandomBot()
hb = HeuristicBot(board_size=BOARD_SIZE)

wins_random, wins_heuristic = [], []

for bot in mcts_bots:
    bot.exploration_weight = 0
    vs_random = run_games([bot, rb], 500, test=True)
    wins_random.append(vs_random)
    vs_heuristic = run_games([bot, hb], 500, test=True)
    wins_heuristic.append(vs_heuristic)

In [None]:
# Plot test performance (no exploration): one point per MCTSBot, with bots ordered
# by their win rate against HeuristicBot (best first).
# run_games returns a list of per-batch win rates, so collapse each bot's batches
# into one overall win rate first. Otherwise scatter receives several y-values per x.
rate_random = [float(np.mean(rates)) for rates in wins_random]
rate_heuristic = [float(np.mean(rates)) for rates in wins_heuristic]

s = sorted(zip(rate_random, rate_heuristic), key=lambda x: x[1], reverse=True)
sorted_random, sorted_heuristic = zip(*s)

xs = np.arange(len(sorted_random))
plt.scatter(xs, sorted_random, label='Wins against RandomBot')
plt.scatter(xs, sorted_heuristic, label='Wins against HeuristicBot')

plt.xlabel('Bot')
plt.ylabel('% Wins in 500 Games')
# Win rates are fractions in [0, 1]; show them as percentages to match the label.
plt.gca().yaxis.set_major_formatter(FuncFormatter(lambda y, _: f'{y:.0%}'))
plt.legend()
plt.grid(True)
plt.title("Test Performance of {} MCTSBots (No Exploration)".format(len(sorted_random)))
plt.show()

In [None]:
# Plot state visits for the first trained MCTSBot, bucketed by visit count.
# total_states is MCTSBot's state-space size: the number of rotation-distinct
# 12-cell 0/1 patterns printed by the state-space cell in the MCTSRicherBot section.
# States never recorded in STATS_MOVE_VISITED are counted in the '0' bucket.
total_states = 1044
visits = list(mcts_bots[0].STATS_MOVE_VISITED.values())  # assumes integer visit counts

# np.histogram bins are [lo, hi), except the last one, which also includes hi.
# These edges therefore line up exactly with bin_labels. The final edge grows so
# that no large count is silently dropped.
bins = [0, 1, 2, 3, 11, 101, 1001, max(1001, max(visits, default=0)) + 1]
bin_labels = ['0', '1', '2', '3-10', '11-100', '101-1000', '1001+']

counts, _ = np.histogram(visits, bins=bins)
counts = list(counts)
counts[0] += total_states - len(visits)

plt.bar(bin_labels, counts, alpha=0.7)

plt.xlabel('# of Visits to State')
plt.ylabel('# of States (Log Scale)')
plt.title('MCTSBot, # of Move States by Visits')

plt.yscale('log')  # the y label promises a log scale
plt.xticks(rotation=45)
def format_ticks(x, pos):
    return f'{int(x):,}'
formatter = FuncFormatter(format_ticks)
plt.gca().yaxis.set_major_formatter(formatter)
plt.tight_layout()
plt.show()

## Training MCTSRicherBot

In [None]:
# Count number of star states
# And number of canonical rotations
ROTATIONS = np.array([
    [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11],  # 0° (no rotation)
    [7, 3, 6, 10, 0, 2, 9, 11, 1, 5, 8, 4],  # 90° CW
    [11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0],  # 180° CW
    [4, 8, 5, 1, 11, 9, 2, 0, 10, 6, 3, 7]   # 270° CW
])

def _canonical_rotation( array):
    """Find the lexicographically smallest rotation of a given array"""
    rotated = [tuple(array[rotation]) for rotation in ROTATIONS]
    return np.array(min(rotated))

def _count_canonical_patterns(patterns):
    """Return how many distinct canonical rotations occur among ``patterns``.

    Each pattern is canonicalised with ``_canonical_rotation`` and the distinct
    results are counted. Any iterable is accepted, so patterns can be generated
    lazily instead of being materialised in a list first.
    """
    return len({tuple(_canonical_rotation(np.array(p))) for p in patterns})


def _richer_patterns(values, marker, n_cells):
    """Yield every pattern over ``values``, then every pattern with exactly one ``marker``.

    The second group is built by inserting ``marker`` at each position of
    every (n_cells - 1)-long pattern over ``values``.
    """
    yield from product(values, repeat=n_cells)
    for pos in range(n_cells):
        for rest in product(values, repeat=n_cells - 1):
            yield rest[:pos] + (marker,) + rest[pos:]


n_cells = 12

# MCTSBot: every cell is 0 or 1.
print("Total permutations for MCTSBot (0/1): ", 2 ** n_cells)
print("Size of state space for MCTSBot (0/1): ",
      _count_canonical_patterns(product([0, 1], repeat=n_cells)))

# MCTSRicherBot: cells take -2/0/1, optionally with exactly one cell set to -1.
richer_values = [-2, 0, 1]
num_richer = len(richer_values) ** n_cells + n_cells * len(richer_values) ** (n_cells - 1)
print("Total permutations for MCTSRicherBot (-2/-1/0/1): ", num_richer)
print("Size of state space for MCTSRicherBot (-2/-1/0/1): ",
      _count_canonical_patterns(_richer_patterns(richer_values, -1, n_cells)))

In [None]:
# A fresh MCTSRicherBot, plus a list that will receive one list of per-batch
# win rates for every run_games call made while training it.
data = []
richer_bot = MCTSRicherBot(board_size=BOARD_SIZE)

In [None]:
# Train richer_bot against HeuristicBot `hb` in 30 rounds of 10,000 games.
# The same bot object keeps learning from round to round.
for _round_idx in range(30):
    round_win_rates = run_games([richer_bot, hb], 10000)
    data.append(round_win_rates)

In [None]:
# Plot train performance of MCTSRicherBot. run_games reports one win rate per
# 100-game batch. Consecutive groups of `game_batch` batches are summarised by
# their mean and a 10th-90th percentile band.
games_per_batch = 100  # batch size used by run_games
data_flat = [rate for run in data for rate in run]
data_array = np.array(data_flat)
game_batch = 5
usable = (len(data_array) // game_batch) * game_batch  # drop a trailing partial group
batched = data_array[:usable].reshape(-1, game_batch)

m = batched.mean(axis=1)
lb = np.quantile(batched, .10, axis=1)
ub = np.quantile(batched, .90, axis=1)

plt.fill_between(range(len(lb)), lb, ub, color='lightgray', alpha=0.5, label='10th to 90th Percentile')
plt.plot(m, label='mean')
plt.ylabel('Win % against HeuristicBot')
# Win rates are fractions in [0, 1]; show them as percentages to match the label.
plt.gca().yaxis.set_major_formatter(FuncFormatter(lambda y, _: f'{y:.0%}'))
plt.xlabel("Batch, each batch contains {} games".format(game_batch * games_per_batch))
plt.title('Train Performance of MCTSRicherBot, One Train')
plt.legend()  # the labels above are otherwise never shown
plt.show()

In [None]:
# Plot state visits for MCTSRicherBot, bucketed by visit count.
# total_states is MCTSRicherBot's state-space size: the number of rotation-distinct
# -2/-1/0/1 patterns printed by the state-space cell above.
# States never recorded in STATS_MOVE_VISITED are counted in the '0' bucket.
total_states = 664497
visits = list(richer_bot.STATS_MOVE_VISITED.values())  # assumes integer visit counts

# np.histogram bins are [lo, hi), except the last one, which also includes hi.
# These edges therefore line up exactly with bin_labels. The final edge grows so
# that no large count is silently dropped.
bins = [0, 1, 2, 3, 11, 101, 1001, max(1001, max(visits, default=0)) + 1]
bin_labels = ['0', '1', '2', '3-10', '11-100', '101-1000', '1001+']

# Count the number of states in each bin
counts, _ = np.histogram(visits, bins=bins)
counts = list(counts)
counts[0] += total_states - len(visits)

plt.bar(bin_labels, counts, alpha=0.7)

plt.xlabel('# of Visits to State')
plt.ylabel('# of States (Log Scale)')
plt.title('MCTSRicherBot, # of Move States by Visits')

plt.yscale('log')
plt.xticks(rotation=45)
def format_ticks(x, pos):
    return f'{int(x):,}'
formatter = FuncFormatter(format_ticks)
plt.gca().yaxis.set_major_formatter(formatter)
plt.tight_layout()
plt.show()

## Training MCTSBiggerBot

In [None]:
# Rotation table for the 24-cell 0/1 patterns counted in the next statements
# (MCTSBiggerBot section). Row k is an index permutation: array[ROTATIONS[k]] is
# `array` rotated by k * 90° clockwise.
# This rebinds the module-level ROTATIONS that _canonical_rotation reads when no
# table is passed, so from here on it uses this 24-cell table.
ROTATIONS = np.array([
            [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23],  # 0° (no rotation)
            [14, 8, 13, 19, 3, 7, 12, 18, 22, 0, 2, 6, 17, 21, 23, 1, 5, 11, 16, 20, 4, 10, 15, 9],  # 90° CW
            [23, 22, 21, 20, 19, 18, 17, 16, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0],  # 180° CW (applying the 90° row twice)
            [9, 15, 10, 4, 20, 16, 11, 5, 1, 23, 21, 17, 6, 2, 0, 22, 18, 12, 7, 3, 19, 13, 8, 14]   # 270° CW (applying the 90° row three times)
        ])


def _rotate_bit_codes(codes, rotation, n_cells):
    """Apply a cell permutation to 0/1 patterns packed into integers.

    Pattern position k is stored in bit (n_cells - 1 - k), so position 0 is the
    most significant bit. This is the same layout as f"{i:0{n_cells}b}". With
    this layout, numeric order of codes equals lexicographic order of patterns.

    Returns the codes of the patterns ``pattern[rotation]``.
    """
    rotated = np.zeros_like(codes)
    for dst, src in enumerate(rotation):
        bit = (codes >> (n_cells - 1 - int(src))) & 1
        rotated |= bit << (n_cells - 1 - dst)
    return rotated


# Count the rotation-distinct 0/1 patterns over all cells of ROTATIONS (2**24 patterns here).
# A pattern's canonical form is its lexicographically smallest rotation. Because the
# rotations form a closed group, the number of distinct canonical forms equals the
# number of patterns that are already their own smallest rotation. Counting those on
# packed integers, chunk by chunk, avoids 2**24 per-pattern NumPy allocations and a
# set holding millions of tuples.
n_cells = ROTATIONS.shape[1]
num_patterns = 1 << n_cells
chunk_size = 1 << 20
num_canonical = 0

for start in range(0, num_patterns, chunk_size):
    codes = np.arange(start, min(start + chunk_size, num_patterns), dtype=np.int64)
    is_smallest = np.ones(codes.shape, dtype=bool)
    for rotation in ROTATIONS:
        is_smallest &= codes <= _rotate_bit_codes(codes, rotation, n_cells)
    num_canonical += int(is_smallest.sum())

print(num_canonical)

In [None]:
# A fresh MCTSBiggerBot, plus a list that will receive one list of per-batch
# win rates for every run_games call made while training it.
data = []
bigger_bot = MCTSBiggerBot(board_size=BOARD_SIZE)

In [None]:
# Train bigger_bot against HeuristicBot `hb` in 30 rounds of 10,000 games.
# The same bot object keeps learning from round to round.
for _round_idx in range(30):
    round_win_rates = run_games([bigger_bot, hb], 10000)
    data.append(round_win_rates)

In [None]:
# Plot train performance of MCTSBiggerBot. run_games reports one win rate per
# 100-game batch. Consecutive groups of `game_batch` batches are summarised by
# their mean and a 10th-90th percentile band.
games_per_batch = 100  # batch size used by run_games
data_flat = [rate for run in data for rate in run]
data_array = np.array(data_flat)
game_batch = 5
usable = (len(data_array) // game_batch) * game_batch  # drop a trailing partial group
batched = data_array[:usable].reshape(-1, game_batch)

m = batched.mean(axis=1)
lb = np.quantile(batched, .10, axis=1)
ub = np.quantile(batched, .90, axis=1)

plt.fill_between(range(len(lb)), lb, ub, color='lightgray', alpha=0.5, label='10th to 90th Percentile')
plt.plot(m, label='mean')
plt.ylabel('Win % against HeuristicBot')
# Win rates are fractions in [0, 1]; show them as percentages to match the label.
plt.gca().yaxis.set_major_formatter(FuncFormatter(lambda y, _: f'{y:.0%}'))
plt.xlabel("Batch, each batch contains {} games".format(game_batch * games_per_batch))
plt.title('Train Performance of MCTSBiggerBot, One Train')
plt.legend()  # the labels above are otherwise never shown
plt.show()

In [None]:
# Plot state visits for MCTSBiggerBot, bucketed by visit count.
# total_states is MCTSBiggerBot's state-space size: the number of rotation-distinct
# 24-cell 0/1 patterns printed by the state-space cell above.
# States never recorded in STATS_MOVE_VISITED are counted in the '0' bucket.
total_states = 4195360
visits = list(bigger_bot.STATS_MOVE_VISITED.values())  # assumes integer visit counts

# np.histogram bins are [lo, hi), except the last one, which also includes hi.
# These edges therefore line up exactly with bin_labels. The final edge grows so
# that no large count is silently dropped.
bins = [0, 1, 2, 3, 11, 101, 1001, max(1001, max(visits, default=0)) + 1]
bin_labels = ['0', '1', '2', '3-10', '11-100', '101-1000', '1001+']

# Count the number of states in each bin
counts, _ = np.histogram(visits, bins=bins)
counts = list(counts)
counts[0] += total_states - len(visits)

plt.bar(bin_labels, counts, alpha=0.7)

plt.xlabel('# of Visits to State')
plt.ylabel('# of States (Log Scale)')
plt.title('MCTSBiggerBot, # of Move States by Visits')

plt.yscale('log')
plt.xticks(rotation=45)
def format_ticks(x, pos):
    return f'{int(x):,}'
formatter = FuncFormatter(format_ticks)
plt.gca().yaxis.set_major_formatter(formatter)
plt.tight_layout()
plt.show()