In [1]:
import chess
import chess.engine
import os
import copy
from pathlib import Path
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
from torch import nn as nn
from torch.optim import AdamW
from tqdm import tqdm
import numpy as np
import random
import pandas as pd

In [2]:
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
num_gpus = torch.cuda.device_count()

chess_dict = {
    'P_l': [1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
    'P_r': [0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
    'p':   [0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
    'N_l': [0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
    'N_r': [0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
    'n':   [0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
    'B_l': [0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
    'B_r': [0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0],
    'b':   [0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0],
    'R_l': [0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0],
    'R_r': [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0],
    'r':   [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0],
    'q':   [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0],
    'Q':   [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0],
    'k':   [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0],
    'K':   [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1],
    '.':   [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
}

num2move = {}
move2num = {}

counter = 0
for from_sq in range(64):
    for to_sq in range(64):
        num2move[counter] = chess.Move(from_sq, to_sq)
        move2num[chess.Move(from_sq, to_sq)] = counter
        counter += 1
        
        
def translate_board(board):
    pgn = board.epd()
    foo = []
    pieces = pgn.split(" ", 1)[0]
    rows = pieces.split("/")
    for row in rows:
        foo2 = []
        for index, thing in enumerate(row):
            if thing.isdigit():
                for i in range(0, int(thing)):
                    foo2.append(chess_dict['.'])
            else:
                if thing not in ["P", "N", "R", "B"]:
                    foo2.append(chess_dict[thing])
                else:
                    if(index < 4):
                        foo2.append(chess_dict[thing+"_l"])
                    else:
                        foo2.append(chess_dict[thing+"_r"])
        foo.append(foo2)
    return np.array(foo)
        
def find_piece_key(piece_representation):
    index = np.argmax(piece_representation)
    piece_key = ""
    for key in chess_dict:
        if index == np.argmax(chess_dict[key]):
            piece_key = key
            break
    return piece_key


def can_move(move, agent_num, translated_board):
    from_square = move.from_square
    piece = translated_board[7-from_square//8][from_square % 8]
    piece_key = find_piece_key(piece)
    if agent_num == 0 and "_l" in piece_key:
        return True
    if agent_num == -1 and "_r" in piece_key:
        return True
    if "Q" in piece_key or "K" in piece_key:
        return True
    return False


def filter_legal_moves(board, logits, agent_num, translated_board):
    filter_mask = np.zeros(logits.shape)
    legal_moves = board.legal_moves
    num_legal_mobes = list(board.legal_moves)
    for legal_move in legal_moves:
        if agent_num is not None:
            if(can_move(legal_move, agent_num, translated_board)):
                from_square = legal_move.from_square
                to_square = legal_move.to_square
                idx = move2num[chess.Move(from_square, to_square)]
                filter_mask[idx] = 1
        else:
            from_square = legal_move.from_square
            to_square = legal_move.to_square
            idx = move2num[chess.Move(from_square, to_square)]
            filter_mask[idx] = 1
    new_logits = logits*filter_mask
    return new_logits


def check_legal_move(board, move):
    legal_moves = board.legal_moves
    legal = False
    for legal_move in legal_moves:
        from_square = legal_move.from_square
        to_square = legal_move.to_square
        if from_square == move.from_square and to_square == move.to_square:
            legal = True
            break

    return legal


In [3]:
class agent_dqn(nn.Module):
  def __init__(self, obs_shape, n_actions):
    super().__init__()
    self.conv = nn.Sequential(
        nn.Conv2d(obs_shape[0], 64, kernel_size=2,stride=2),
        nn.Conv2d(64, 128, kernel_size=2,stride=2),
        nn.Conv2d(128, 256, kernel_size=2,stride=2),
    )
    conv_out_size = self._get_conv_out(obs_shape)
 
    self.fc = nn.Linear(conv_out_size, n_actions)

  def _get_conv_out(self, shape):
    conv_out = self.conv(torch.zeros(1, *shape))
    return int(np.prod(conv_out.size()))

  def forward(self, x):
    x = self.conv(x.float()).view(x.size()[0], -1)
    x = self.fc(x)
    return torch.softmax(x, dim=1)

class agent_double_dqn(nn.Module):
  def __init__(self, obs_shape, n_actions):
    super().__init__()
    self.conv = nn.Sequential(
        nn.Conv2d(obs_shape[0], 64, kernel_size=2,stride=2),
        nn.Conv2d(64, 128, kernel_size=2,stride=2),
        nn.Conv2d(128, 256, kernel_size=2,stride=2),
    )
    conv_out_size = self._get_conv_out(obs_shape)
 
    self.fc = nn.Linear(conv_out_size, n_actions)

  def _get_conv_out(self, shape):
    conv_out = self.conv(torch.zeros(1, *shape))
    return int(np.prod(conv_out.size()))

  def forward(self, x):
    x = self.conv(x.float()).view(x.size()[0], -1)
    x = self.fc(x)
    return torch.softmax(x, dim=1)

class agent_duelling_dqn(nn.Module):
  def __init__(self, obs_shape, n_actions):
    super().__init__()
    self.conv = nn.Sequential(
        nn.Conv2d(obs_shape[0], 64, kernel_size=2,stride=2),
        nn.Conv2d(64, 128, kernel_size=2,stride=2),
        nn.Conv2d(128, 256, kernel_size=2,stride=2),
    )
    conv_out_size = self._get_conv_out(obs_shape)
    
    self.fc = nn.Linear(conv_out_size, 512)
    self.fc_adv = nn.Linear(512, n_actions) 
    self.fc_value = nn.Linear(512, 1)

  def _get_conv_out(self, shape):
    conv_out = self.conv(torch.zeros(1, *shape))
    return int(np.prod(conv_out.size()))

  def forward(self, x):
    x = self.conv(x.float()).view(x.size()[0], -1)
    x = self.fc(x)
    adv = self.fc_adv(x)
    value = self.fc_value(x)
    x = value + adv - torch.mean(adv, dim=1, keepdim=True)
    return torch.softmax(x, dim=1)

In [4]:
def policy(state, env, q_network, agent_num = None, epsilon=0.):
    if torch.rand(1) < epsilon:
        action_probs = 1 * 2 + np.random.uniform(0, 1, 4096)
        action_space = filter_legal_moves(env.board, action_probs, agent_num, env.translated_board)
        action_space = torch.from_numpy(action_space)
        action = torch.argmax(action_space, dim=-1, keepdim=True)
        move = num2move[action.item()]
        return action, move
    else:
        action_probs = q_network(state).detach()
        action_space = filter_legal_moves(env.board, action_probs[0], agent_num, env.translated_board)
        action =  torch.argmax(action_space, dim=-1, keepdim=True)

        move = num2move[action.item()]
        return action, move

class ChessEnv():
    def __init__(self):
        pass

    def reset(self):
        self.board = chess.Board()
        self.translated_board = translate_board(self.board)
        self.next_agent = 0
        return self.translated_board
    
    def get_state(self):
        return self.board

    def get_next_agent(self):
        return self.next_agent


    def step(self, action):
        self.board.push(action)
        self.update_translated_borad(action)
        self.next_agent = ~self.next_agent
        state_next = self.translated_board
        self.done = self.board.is_checkmate() 
        if chess.Status.PAWNS_ON_BACKRANK == self.board.status():
            self.done = True 
        is_game_over = self.board.is_insufficient_material()
        return state_next, None, self.done, None, is_game_over
    
    
    def preprocess(self, board):
        pgn = board.epd()
        processed_board = []
        pieces = pgn.split(" ", 1)[0]
        rows = pieces.split("/")
        
        for row in rows:
            processed_row = []
            for index, element in enumerate(row):
                if element.isdigit():
                    for i in range(0, int(element)):
                        processed_row.append(chess_dict['.'])
                else:
                    if element not in ["P", "N", "R", "B"]:
                        processed_row.append(chess_dict[element])
                    else:
                        if(index < 4):
                            processed_row.append(chess_dict[element+"_l"])
                        else:
                            processed_row.append(chess_dict[element+"_r"])
            processed_board.append(processed_row)
        return np.array(processed_board)
    
    def update_translated_borad(self, action):
        from_square = action.from_square
        to_square = action.to_square
        tmp = self.translated_board[7-from_square//8][from_square % 8].copy()
        self.translated_board[7-from_square//8][from_square %
                                                8] = self.translated_board[7-to_square//8][to_square % 8]
        self.translated_board[7-to_square//8][to_square % 8] = tmp



In [5]:
env = ChessEnv()
state = env.reset()
obs_size = state.shape
num_actions = 4096

def load(path, n_actions, state_size, model_name):
    if model_name == "DQN":
            model = agent_dqn(n_actions,state_size).to(device)
    if model_name == "DOUBLE_DQN":
            model = agent_double_dqn(n_actions,state_size).to(device)
    if model_name == "DUELLING_DQN":
            model = agent_duelling_dqn(n_actions,state_size).to(device)
    model.load_state_dict(torch.load(path))
    return model 

In [6]:
def run(episodes, team1_path, team1_name,  team2_path, team2_name):
    agent1 =  load(team1_path, obs_size,num_actions, team1_name)
    agent2 =  load(team1_path, obs_size,num_actions, team1_name)
    agent3 =  load(team2_path,  obs_size,num_actions, team2_name)
    results = {"team_1_wins":0, "team_2_wins":0, "draw":0}
    for i in tqdm(range(episodes)):
        state = env.reset()
        state = torch.from_numpy(state).unsqueeze(dim=0)
        next_agent = env.get_next_agent()
        done = False
        game_over = False
        agent_num = 0
        step = 0
        
        while not done and not game_over:
            step += 1
            state = state.float().to(device)
            if next_agent == 0:
                action, move = policy(state,env, agent1, None, 0.05)
                agent_num = ~agent_num
                if action == 0:
                        action, move = policy(state,env, agent2, ~agent_num, 0.05)
            else:
                action, move = policy(state,env, agent3, None, 0.05)
        
            next_state, _, done, _, game_over = env.step(move)
            if game_over:
                results["draw"] += 1
            if  done:
                if agent_num == 0:
                    results["team_1_wins"] += 1
                else:
                    results["team_2_wins"] += 1
            next_state = torch.from_numpy(next_state).unsqueeze(dim=0).float()
            state = next_state
            next_agent = env.get_next_agent()
    return results

In [7]:
results = run(200, "./models/multi-agent-deep-q-learning","DQN", "./models/one-agent-deep-q-learning", "DQN")
print("multi-agent-deep-q-learning","one-agent-deep-q-learning",results)
results = run(200, "./models/multi-agent-deep-q-learning","DQN", "./models/one-agent-double-q-learning", "DOUBLE_DQN")
print("multi-agent-deep-q-learning","one-agent-double-q-learning",results)
results = run(200, "./models/multi-agent-deep-q-learning","DQN", "./models/one-agent-duelling-q-learning", "DUELLING_DQN")
print("multi-agent-deep-q-learning","one-agent-duelling-q-learning",results)

results = run(200, "./models/multi-agent-double-q-learning","DOUBLE_DQN", "./models/one-agent-deep-q-learning", "DQN")
print("multi-agent-double-q-learning","one-agent-deep-q-learning",results)
results = run(200, "./models/multi-agent-double-q-learning","DOUBLE_DQN", "./models/one-agent-double-q-learning", "DOUBLE_DQN")
print("multi-agent-double-q-learning","one-agent-double-q-learning",results)
results = run(200, "./models/multi-agent-double-q-learning","DOUBLE_DQN", "./models/one-agent-duelling-q-learning", "DUELLING_DQN")
print("multi-agent-double-q-learning","one-agent-duelling-q-learning",results)

results = run(200, "./models/multi-agent-duelling-q-learning","DUELLING_DQN", "./models/one-agent-deep-q-learning", "DQN")
print("multi-agent-duelling-q-learning","one-agent-deep-q-learning",results)
results = run(200, "./models/multi-agent-duelling-q-learning","DUELLING_DQN", "./models/one-agent-double-q-learning", "DOUBLE_DQN")
print("multi-agent-duelling-q-learning","one-agent-double-q-learning",results)
results = run(200, "./models/multi-agent-duelling-q-learning","DUELLING_DQN", "./models/one-agent-duelling-q-learning", "DUELLING_DQN")
print("multi-agent-duelling-q-learning","one-agent-duelling-q-learning",results)



100%|██████████| 200/200 [10:30<00:00,  3.15s/it]


multi-agent-deep-q-learning one-agent-deep-q-learning {'team_1_wins': 95, 'team_2_wins': 94, 'draw': 11}


100%|██████████| 200/200 [12:07<00:00,  3.64s/it]


multi-agent-deep-q-learning one-agent-double-q-learning {'team_1_wins': 96, 'team_2_wins': 97, 'draw': 7}


100%|██████████| 200/200 [17:08<00:00,  5.14s/it]


multi-agent-deep-q-learning one-agent-duelling-q-learning {'team_1_wins': 88, 'team_2_wins': 101, 'draw': 11}


100%|██████████| 200/200 [19:08<00:00,  5.74s/it]


multi-agent-double-q-learning one-agent-deep-q-learning {'team_1_wins': 90, 'team_2_wins': 95, 'draw': 15}


100%|██████████| 200/200 [15:18<00:00,  4.59s/it]


multi-agent-double-q-learning one-agent-double-q-learning {'team_1_wins': 92, 'team_2_wins': 93, 'draw': 15}


100%|██████████| 200/200 [17:04<00:00,  5.12s/it]


multi-agent-double-q-learning one-agent-duelling-q-learning {'team_1_wins': 88, 'team_2_wins': 89, 'draw': 23}


100%|██████████| 200/200 [12:25<00:00,  3.73s/it]


multi-agent-duelling-q-learning one-agent-deep-q-learning {'team_1_wins': 91, 'team_2_wins': 105, 'draw': 4}


100%|██████████| 200/200 [11:32<00:00,  3.46s/it]


multi-agent-duelling-q-learning one-agent-double-q-learning {'team_1_wins': 92, 'team_2_wins': 94, 'draw': 14}


100%|██████████| 200/200 [12:32<00:00,  3.76s/it]

multi-agent-duelling-q-learning one-agent-duelling-q-learning {'team_1_wins': 96, 'team_2_wins': 91, 'draw': 13}



