In [1]:
# Imports
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

from matplotlib.colors import ListedColormap
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm
from typing import List
import numpy as np
import itertools
import random
import math

In [2]:
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
device

device(type='cpu')

In [3]:
N = 80
k = 4

Inspiration: https://andrew-gordienko.medium.com/reinforcement-learning-dqn-w-pytorch-7c6faad3d1e

In [4]:
# Hyperparameters
EPISODES = 2_000
LEARNING_RATE = 0.001
MEM_SIZE = 1000
BATCH_SIZE = 128
GAMMA = 0.9999
EXPLORATION_MAX = 1.0
EXPLORATION_MIN = 0.0025
EXPLORATION_DECAY = 0.9999 # math.exp(math.log(EXPLORATION_MIN)/EPISODES) 
EXPLORATION_DECAY

0.9999

# Environment

In [5]:
# Simply a collection of methods to work with the state / observed environment
class NQEnv:
    def __init__(self, n, k):
        self.n = n
        self.k = k
        self._reset()
        
    def _reset(self):
        self.rows, self.cols, self.diags, self.anti_diags = set(), set(), set(), set()
        self.board = np.zeros(self.n * self.n)
        
    def reset(self) -> np.array:
        """
            Resets board with new partial N-Queens configuration
        """        
        has_set = False
        while not has_set:
            self._reset() # resets board and line sets
            num_set = 0
            while num_set < self.k:
                if not self.get_avail_squares(): break
                action = self.sample_action()
                self._perform_action(action)
                
                num_set += 1
                
            has_set = num_set == self.k
    
        return self.board # 1D array -> partial config
    
    def _perform_action(self, action:int):
        """
            Performs action i.e. places queen on square (action//n, action (mod n))
        """
        row,col = action//self.n, action%self.n
        self.rows.add(row)
        self.cols.add(col)
        self.diags.add(col-row)
        self.anti_diags.add(col+row)
        
        self.board[action] = 1
        
    def get_avail_squares(self) -> List[int]:
        """
            Returns a list of the squares available in col-major format
        """
        
        avail = []
        for row, col in itertools.product(range(self.n), range(self.n)):
            if row in self.rows or col in self.cols or col-row in self.diags or col+row in self.anti_diags: continue
            avail.append(self.n*row + col)
        return avail
    
    def sample_action(self):
        """
            Returns a valid action, assumes such an action exists
        """
        avail = self.get_avail_squares()
        return random.sample(avail,1)[0]
    
    def step(self, action):
        """
            Performs action on the board, assumes that the action is valid
        """
        self._perform_action(action)
        avail = self.get_avail_squares()
        terminate = len(self.rows) == self.n or len(avail) == 0
        reward = (len(self.rows)/self.n)**4 / (1 if len(self.rows) == self.n else 5)
        return (self.board.copy(), reward, terminate) # new_state, reward, teminated
    
    def get_mask(self):
        avail = self.get_avail_squares()
        mask = [0] * self.n * self.n
        for idx in avail: mask[idx] = 1
        return torch.as_tensor(mask)
    
    def render(self, fontsize=20):
        """
            Render the board nicely
        """
        matrix = np.zeros((self.n, self.n), dtype=int)
        for i in range(self.n**2):
            matrix[i//self.n][i%self.n] = int(self.board[i])
        
        n = matrix.shape[0]
        # Create a chess board (n x n) pattern
        board = np.zeros_like(matrix)
        board[1::2, ::2] = 1
        board[::2, 1::2] = 1

        cmap = ListedColormap(['#769656', '#eeeed2'])
        fig, ax = plt.subplots()
        ax.imshow(board, cmap=cmap, interpolation='nearest')

        # Place queens based on matrix
        for i in range(n):
            for j in range(n):
                if matrix[i, j] == 1:
                    ax.text(j, i, '♛', fontsize=fontsize, ha='center', va='center', color='black' if board[i, j] else 'white')

        # Hide the axes
        ax.set_xticks([])
        ax.set_yticks([])

In [6]:
env = NQEnv(N,k)
env

<__main__.NQEnv at 0x1dd89323eb0>

# Model

In [7]:
class DQN(nn.Module):
    def __init__(self):
        super().__init__()
        
        _observation_size = N*N
        _action_space_size = N*N
        self.fc1 = nn.Linear(_observation_size, 256)
        self.fc2 = nn.Linear(256, 512)
        self.fc3 = nn.Linear(512, 512)
        self.fc4 = nn.Linear(512, 512)
        self.fc5 = nn.Linear(512, 256)
        self.fc6 = nn.Linear(256, 128)
        self.out = nn.Linear(128, _action_space_size)
        
        self.optimizer = optim.Adam(self.parameters(), lr=LEARNING_RATE)
        self.loss = nn.MSELoss()
        self.to(device)
        
    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = F.sigmoid(self.fc3(x))
        x = F.sigmoid(self.fc4(x))
        x = F.relu(self.fc5(x))
        x = F.relu(self.fc6(x))
        
        return self.out(x)

In [8]:
class ReplayBuffer:
    def __init__(self):
        self.mem_count = 0
        
        _observationspace_shape = N*N
        self.states = np.zeros((MEM_SIZE, _observationspace_shape),dtype=np.float32)
        self.actions = np.zeros(MEM_SIZE, dtype=np.int64)
        self.rewards = np.zeros(MEM_SIZE, dtype=np.float32)
        self.states_ = np.zeros((MEM_SIZE, _observationspace_shape),dtype=np.float32)
        self.dones = np.zeros(MEM_SIZE, dtype=np.bool)
        
    def add(self, state, action, reward, state_, done):
        mem_index = self.mem_count % MEM_SIZE

        self.states[mem_index]  = state
        self.actions[mem_index] = action
        self.rewards[mem_index] = reward
        self.states_[mem_index] = state_
        self.dones[mem_index] =  1 - done

        self.mem_count += 1
        
    def sample(self):
        MEM_MAX = min(self.mem_count, MEM_SIZE)
        batch_indices = np.random.choice(MEM_MAX, BATCH_SIZE, replace=True)
        
        states  = self.states[batch_indices]
        actions = self.actions[batch_indices]
        rewards = self.rewards[batch_indices]
        states_ = self.states_[batch_indices]
        dones   = self.dones[batch_indices]

        return states, actions, rewards, states_, dones

# Training

### Training framework

In [9]:
class DQN_Solver:
    def __init__(self):
        self.memory = ReplayBuffer()
        self.exploration_rate = EXPLORATION_MAX
        self.network = DQN()
        
    def choose_action(self, observation):
        if random.random() < self.exploration_rate:
            return env.sample_action()
        
        state = torch.tensor(observation).float().detach()
        state = state.to(device)
        state = state.unsqueeze(0)
        q_values = self.network(state)
        
        mask = env.get_mask()
        q_values *= mask
        
        return torch.argmax(q_values).item()
    
    def learn(self):
        if self.memory.mem_count < BATCH_SIZE:
            return
        
        states, actions, rewards, states_, dones = self.memory.sample()
        states = torch.tensor(states , dtype=torch.float32).to(device)
        actions = torch.tensor(actions, dtype=torch.long).to(device)
        rewards = torch.tensor(rewards, dtype=torch.float32).to(device)
        states_ = torch.tensor(states_, dtype=torch.float32).to(device)
        dones = torch.tensor(dones, dtype=torch.bool).to(device)
        batch_indices = np.arange(BATCH_SIZE, dtype=np.int64)

        q_values = self.network(states)
        next_q_values = self.network(states_)
        
        predicted_value_of_now = q_values[batch_indices, actions]
        predicted_value_of_future = torch.max(next_q_values, dim=1)[0]
        
        q_target = rewards + GAMMA * predicted_value_of_future * dones

        loss = self.network.loss(q_target, predicted_value_of_now)
        self.network.optimizer.zero_grad()
        loss.backward()
        self.network.optimizer.step()

        self.exploration_rate *= EXPLORATION_DECAY
        self.exploration_rate = max(EXPLORATION_MIN, self.exploration_rate)

    def returning_epsilon(self):
        return self.exploration_rate

### Training

In [10]:
observation_space = N*N

best_reward = 0
average_reward = 0
episode_number = []
average_reward_number = []

In [11]:
agent = DQN_Solver()

# Create a tqdm progress bar for episodes
pbar = tqdm(range(1, EPISODES), desc='Initializing')

for i in pbar:
    state = env.reset()
    state = np.reshape(state, [1, observation_space])
    score = 0

    while True:
        action = agent.choose_action(state)
        state_, reward, done = env.step(action)
        state_ = np.reshape(state_, [1, observation_space])
        agent.memory.add(state, action, reward, state_, done)
        agent.learn()
        state = state_
        score += reward

        if done:
            if score > best_reward:
                best_reward = score
            average_reward += score 
            # Update progress bar description with the latest episode info instead of using print
            pbar.set_description("Episode: {} Avg Reward: {:.2f} Best Reward: {} Last Reward: {} Epsilon: {:.2f}".format(i, average_reward/i, best_reward, score, agent.returning_epsilon()))
            pbar.refresh() # to show immediately the update
            break
            
    episode_number.append(i)
    average_reward_number.append(average_reward/i)

plt.xlabel("Epoch")
plt.ylabel("Avg Reward")
plt.plot(episode_number, average_reward_number)
plt.show()


Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  self.dones = np.zeros(MEM_SIZE, dtype=np.bool)


Initializing:   0%|          | 0/1999 [00:00<?, ?it/s]

KeyboardInterrupt: 

# Inference

In [None]:
def plot_state_history(state_history, n):
    # Adjusting figsize and subplot layout for vertical display
    fig, axs = plt.subplots(len(state_history), 1, figsize=(4, len(state_history)*4))
    
    # Ensuring axs is iterable when there's only one plot
    if len(state_history) == 1:
        axs = [axs]
    
    for ax, state in zip(axs, state_history):
        matrix = state.reshape((n, n))  # Assuming state is a flat array
        board = np.zeros_like(matrix)
        board[1::2, ::2] = 1
        board[::2, 1::2] = 1
        cmap = ListedColormap(['#769656', '#eeeed2'])
        ax.imshow(board, cmap=cmap, interpolation='nearest')
        
        # Adding the queens to the board
        for i in range(n):
            for j in range(n):
                if matrix[i, j] == 1:
                    ax.text(j, i, '♛', fontsize=10, ha='center', va='center', color='black' if board[i, j] else 'white')
        
        # Removing the ticks from the axes
        ax.set_xticks([])
        ax.set_yticks([])
        
    plt.tight_layout()
    plt.show()

In [None]:
state_history = []
action_history = []
avail_actions_history = []

# Initial random partial config
state = env.reset()
state_history.append(state.copy())

# Place queens
while len(env.get_avail_squares()):
    avail_actions_history.append(env.get_avail_squares())
    
    action = agent.choose_action(state)
    env._perform_action(action)
    state = env.board
    state_history.append(state.copy())
    action_history.append(action)
    
print(f"Number of queens placed = {k + len(state_history) - 1}")

In [None]:
env.render()
plt.show()

In [None]:
action_history, avail_actions_history

In [None]:
plot_state_history(state_history, N)