In [86]:
import sys
import os
sys.path.append(os.path.dirname(os.getcwd()))

from MazeGenerationAlgorithms.RandomizedKruskal import RandomizedKruskalMaze

import numpy as np
import matplotlib.pyplot as plt
import matplotlib.patches as patches
import scipy.special as sp
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import copy
from IPython.display import clear_output
from collections import namedtuple, deque
import random

In [None]:
folder_path = "files"

# Later cells write checkpoints and logs into this folder, so make sure it
# exists; on a fresh checkout os.listdir() would otherwise raise.
os.makedirs(folder_path, exist_ok=True)

# Clear artifacts left by a previous run. os.remove() cannot delete
# directories, so only regular files and symlinks are removed.
for filename in os.listdir(folder_path):
    file_path = os.path.join(folder_path, filename)
    if os.path.isfile(file_path) or os.path.islink(file_path):
        os.remove(file_path)

print(f'All contents of "{folder_path}" have been deleted.')

In [88]:
# One replay-buffer record: the observed state, the action taken, the state
# that followed, the reward received, and whether the episode is still running.
# The typename matches the bound name so instances print as Transition(...)
# and can be pickled (pickle looks the class up by its typename).
Transition = namedtuple('Transition',
                        field_names=['state', 'action', 'next_state', 'reward', 'is_game_on'])

In [89]:
class ReplayMemory:
    """Bounded FIFO store of transitions with uniform random sampling.

    Backed by a ``deque(maxlen=capacity)``: once full, appending a new
    transition silently drops the oldest one.
    """

    def __init__(self, capacity):
        """Create an empty buffer holding at most ``capacity`` transitions."""
        self.capacity = capacity
        self.memory = deque(maxlen=capacity)

    def __len__(self):
        """Return the number of transitions currently stored."""
        return len(self.memory)

    def append(self, transition):
        """Store one transition (a 5-field record: state, action, next state, reward, running flag)."""
        self.memory.append(transition)

    def sample(self, batch_size, device='mps'):
        """Draw ``batch_size`` distinct transitions and return them as batched tensors.

        Args:
            batch_size: number of transitions to draw without replacement.
            device: torch device the returned tensors are created on.

        Returns:
            Tuple ``(states, actions, next_states, rewards, is_game_on)`` with
            dtypes float, long, float, float and bool respectively; the first
            dimension of every tensor is ``batch_size``.

        Raises:
            ValueError: from ``random.sample`` if ``batch_size`` exceeds the
                number of stored transitions.
        """
        transitions = random.sample(self.memory, batch_size)
        states, actions, next_states, rewards, isgameon = zip(*transitions)
        # Stack each field into a single ndarray first: handing torch.tensor a
        # tuple of ndarrays takes a slow element-by-element conversion path.
        return (torch.tensor(np.array(states), dtype=torch.float, device=device),
                torch.tensor(np.array(actions), dtype=torch.long, device=device),
                torch.tensor(np.array(next_states), dtype=torch.float, device=device),
                torch.tensor(np.array(rewards), dtype=torch.float, device=device),
                torch.tensor(np.array(isgameon), dtype=torch.bool, device=device))

In [90]:
class DeepQNetwork(nn.Module):
    """Fully connected network with two SiLU-activated hidden layers.

    Maps an input of width ``Ni`` to ``No`` outputs (one per action in this
    notebook); the output layer has no activation.
    """

    def __init__(self, Ni, Nh1, Nh2, No=4):
        """Build the three linear layers: Ni -> Nh1 -> Nh2 -> No."""
        super().__init__()
        self.fc1 = nn.Linear(Ni, Nh1)
        self.fc2 = nn.Linear(Nh1, Nh2)
        self.fc3 = nn.Linear(Nh2, No)
        self.act = nn.SiLU()
        # Plain attribute initialised to 0; nothing in this class reads it.
        self.weights = 0

    def forward(self, x):
        """Run ``x`` through both hidden layers and the linear output layer."""
        hidden = x
        for layer in (self.fc1, self.fc2):
            hidden = self.act(layer(hidden))
        return self.fc3(hidden)

In [91]:
def Qloss(batch, net, gamma, target_net=None):
    """Smooth-L1 temporal-difference loss for a batch of transitions.

    Args:
        batch: tuple ``(states, actions, next_states, rewards, is_game_on)``
            of tensors sharing the same first (batch) dimension.
        net: network producing one value per action for a flattened state.
        gamma: discount factor applied to the bootstrapped next-state value.
        target_net: optional separate network used to evaluate the next
            states. When omitted, ``net`` is used, as before.

    Returns:
        Scalar tensor: ``F.smooth_l1_loss`` between Q(s, a) and
        ``reward + gamma * max_a' Q(s', a')``.
    """
    states, actions, next_states, rewards, is_game_on = batch
    lbatch = len(states)
    q_all = net(states.view(lbatch, -1))
    state_action_values = q_all.gather(1, actions.unsqueeze(-1)).squeeze(-1)

    bootstrap_net = net if target_net is None else target_net
    # The target is treated as a constant, so no gradient flows through it.
    with torch.no_grad():
        next_state_values = bootstrap_net(next_states.view(lbatch, -1)).max(1)[0]
        # Transitions that ended the episode have no future return: their
        # target is the reward alone, so the bootstrap term is zeroed out.
        next_state_values = torch.where(is_game_on.bool(),
                                        next_state_values,
                                        torch.zeros_like(next_state_values))
    target = rewards + gamma * next_state_values
    return F.smooth_l1_loss(state_action_values, target)

In [92]:
class MazeEnvironment:
    """Grid-world environment built from a generated maze.

    The maze object must provide ``to_grid()`` (a 2-D array in which 0 marks a
    free cell and 1 a wall) and ``to_graph()`` (an object whose ``degree()``
    yields ``(node, degree)`` pairs). Positions are ``(row, col)`` indices
    into the grid; the agent starts at ``(1, 1)`` and the goal is the cell one
    step inside the opposite corner.
    """

    def __init__(self, maze):
        """Cache the grid/graph views of ``maze`` and precompute reward tables."""
        self.graph_maze = maze.to_graph()
        self.maze = maze.to_grid()
        self.maze_size = self.maze.size
        # Read both dimensions from the grid so non-square mazes get the
        # correct bounds and goal cell.
        x, y = self.maze.shape
        self.boundary = np.asarray([x, y])
        self.init_position = np.array([1, 1])
        self.current_position = self.init_position.copy()
        self.goal = (x - 2, y - 2)
        self.visited = {tuple(self.current_position)}
        # Every free cell except the goal is a possible (re)start state.
        self.allowed_states = np.argwhere(self.maze == 0).tolist()
        distances = np.linalg.norm(np.array(self.allowed_states) - self.goal, axis=1)
        goal_idx = np.where(distances == 0)[0][0]
        self.allowed_states.pop(goal_idx)
        # Euclidean distance of each allowed state to the goal, index-aligned
        # with allowed_states.
        self.distances = np.delete(distances, goal_idx)
        self.action_map = {0: [0, 1], 1: [0, -1], 2: [1, 0], 3: [-1, 0]}
        self.directions = {0: '→', 1: '←', 2: '↓', 3: '↑'}
        self.dead_end_rewards = self.calculate_dead_end_rewards()

    def reset_policy(self, eps, reg=7):
        """Return start-state probabilities that favour cells near the goal.

        The softmax temperature is ``reg * (1 - eps ** (2 / reg)) ** (reg / 2)``,
        so it shrinks as ``eps`` approaches 1 and the distribution concentrates
        on the allowed states closest to the goal.
        """
        scaling_factor = reg * (1 - eps ** (2 / reg)) ** (reg / 2)
        # eps == 1 makes the temperature exactly 0; floor it so the division
        # below cannot produce NaN probabilities.
        scaling_factor = max(scaling_factor, 1e-8)
        return sp.softmax(-self.distances / scaling_factor).squeeze()

    def reset(self, epsilon, prand=0):
        """Place the agent on a sampled start cell and return the new state.

        With probability ``prand`` the cell is drawn uniformly from the
        allowed states; otherwise it is drawn from ``reset_policy(epsilon)``.
        """
        if np.random.rand() < prand:
            idx = np.random.choice(len(self.allowed_states))
        else:
            idx = np.random.choice(len(self.allowed_states), p=self.reset_policy(epsilon))
        self.current_position = np.array(self.allowed_states[idx])
        self.visited = {tuple(self.current_position)}
        return self.state()

    def state_update(self, action):
        """Apply ``action`` and return ``(state, reward, isgameon)``.

        Rewards:
            * agent already on the goal: minus the sum of all dead-end
              rewards, and the episode ends (the agent does not move);
            * move into a wall or off the grid: -1, the agent stays put;
            * legal move: -0.05 into a new cell, -0.2 into a cell already
              visited this episode, plus the dead-end reward of the cell the
              agent is leaving, if that cell is a dead end.
        """
        isgameon = True
        next_position = self.current_position + self.action_map[action]
        if np.array_equal(self.current_position, self.goal):
            return self.state(), -sum(self.dead_end_rewards.values()), False
        if self.is_valid_state(next_position):
            # The revisit penalty is decided by the destination cell. The
            # current cell is always in `visited`, so testing it would charge
            # the penalty on every single step.
            reward = -0.2 if tuple(next_position) in self.visited else -0.05
            reward += self.dead_end_rewards.get(tuple(self.current_position), 0)
            self.current_position = next_position
        else:
            reward = -1
        self.visited.add(tuple(self.current_position))
        return self.state(), reward, isgameon

    def state(self):
        """Return a copy of the grid with the agent's cell set to 2."""
        state = self.maze.copy()
        state[tuple(self.current_position)] = 2
        return state

    def is_valid_state(self, position):
        """Return True when ``position`` is inside the grid and not a wall."""
        return not (self.is_out_of_bounds(position) or self.is_wall(position))

    def is_out_of_bounds(self, position):
        """Return True when any coordinate of ``position`` lies outside the grid."""
        return np.any(position < 0) or np.any(position >= self.boundary)

    def is_wall(self, position):
        """Return True when the grid cell at ``position`` holds 1."""
        return self.maze[tuple(position)] == 1

    @staticmethod
    def _grid_to_graph(position):
        """Map a (row, col) grid cell to its (x, y) graph node.

        Inverse of the node -> cell mapping ``(2 * y + 1, 2 * x + 1)`` used in
        ``find_dead_ends``.
        """
        row, col = position
        return ((int(col) - 1) // 2, (int(row) - 1) // 2)

    def find_dead_ends(self):
        """Return the grid cells of all degree-1 graph nodes except start and goal."""
        excluded = {self._grid_to_graph(self.init_position),
                    self._grid_to_graph(self.goal)}
        dead_ends = {node for node, degree in self.graph_maze.degree()
                     if degree == 1 and node not in excluded}
        dead_end_positions = {(2 * y + 1, 2 * x + 1) for x, y in dead_ends}
        return dead_end_positions

    def dead_end_L1_distance(self):
        """Return ``{dead_end_cell: Manhattan distance to the goal}``."""
        goal_row, goal_col = self.goal
        return {
            dead_end: abs(dead_end[0] - goal_row) + abs(dead_end[1] - goal_col)
            for dead_end in self.find_dead_ends()
        }

    def calculate_dead_end_rewards(self):
        """Return ``{dead_end_cell: reward}`` scaled by distance to the goal.

        Rewards range linearly from -0.4 (dead end closest to the goal) to
        -1.0 (farthest). If all dead ends are equally far, each gets -1. A
        maze without dead ends yields an empty mapping.
        """
        dead_end_distances = self.dead_end_L1_distance()
        if not dead_end_distances:
            # max()/min() would raise on an empty mapping.
            return {}
        max_dist = max(dead_end_distances.values())
        min_dist = min(dead_end_distances.values())
        if max_dist == min_dist:
            return {dead_end: -1 for dead_end in dead_end_distances}
        spread = max_dist - min_dist
        return {
            dead_end: -0.4 - (0.6 * (dist - min_dist) / spread)
            for dead_end, dist in dead_end_distances.items()
        }

    def draw_rewards(self):
        """Plot the static reward map: 0 on free cells, goal and dead-end values overlaid."""
        # NaN marks walls, which are skipped when annotating cells below.
        rewards = np.full(self.maze.shape, np.nan)
        for x, y in self.allowed_states:
            rewards[x, y] = 0
        rewards[self.goal] = -sum(self.dead_end_rewards.values())
        for dead_end, reward in self.dead_end_rewards.items():
            x, y = dead_end
            rewards[x, y] = reward
        fig, ax = plt.subplots(figsize=(11, 11))
        fig.patch.set_facecolor("#AACFC3")
        ax.set_facecolor("#AACFC3")
        cmap = plt.cm.RdYlBu_r
        im = ax.imshow(rewards, cmap=cmap, interpolation="nearest")
        plt.colorbar(im, label="Reward")
        for x in range(self.maze.shape[0]):
            for y in range(self.maze.shape[1]):
                if not np.isnan(rewards[x, y]):
                    # imshow puts column on the horizontal axis, hence (y, x).
                    cell_rect = patches.Rectangle(
                        (y - 0.5, x - 0.5),
                        1, 1,
                        linewidth=1,
                        edgecolor='black',
                        facecolor='none'
                    )
                    ax.add_patch(cell_rect)
                    ax.text(
                        y, x,
                        f"{rewards[x, y]:.1f}",
                        ha="center", va="center",
                        fontsize=8,
                        color="black"
                    )
        ax.set_xticks([])
        ax.set_yticks([])
        plt.title("Reward Map")
        plt.show()


In [93]:
class Agent:
    """Maze-walking agent that acts from a Q-network and records transitions."""

    def __init__(self, env, maze, memory_buffer, alpha=1e-6, use_softmax=True):
        """
        Args:
            env: environment exposing ``state()``, ``state_update(action)``,
                ``maze``, ``allowed_states``, ``goal`` and ``directions``.
            maze: the maze object the environment was built from.
            memory_buffer: replay buffer with an ``append`` method.
            alpha: probability mass given to each non-selected action in the
                final randomisation step of ``select_action``.
            use_softmax: sample actions from a softmax over Q-values instead
                of using epsilon-greedy selection.
        """
        self.env = env
        self.maze = maze
        self.buffer = memory_buffer
        self.alpha = alpha
        self.num_act = 4
        self.use_softmax = use_softmax
        # Episode is cut short once the accumulated reward drops below this.
        self.min_reward = -self.env.maze.size // 3
        self.total_reward = 0
        self.isgameon = True
        self.rwd = 0

    @staticmethod
    def _resolve_device(net, device=None):
        """Return ``device`` if given, else the device of ``net``'s first parameter (CPU if none)."""
        if device is not None:
            return device
        first_param = next(net.parameters(), None)
        return first_param.device if first_param is not None else torch.device('cpu')

    def make_a_move(self, net, epsilon, device=None):
        """Take one environment step and append the transition to the buffer.

        Updates ``self.rwd`` (last reward), ``self.total_reward`` and
        ``self.isgameon``. The episode also ends when the accumulated reward
        falls below ``self.min_reward``; the accumulator is reset to 0
        whenever an episode ends.

        Args:
            net: Q-network used to choose the action.
            epsilon: exploration parameter forwarded to ``select_action``.
            device: device for the input tensor; defaults to ``net``'s device.
        """
        current_state = self.env.state()
        action = self.select_action(net, epsilon, device)
        next_state, reward, self.isgameon = self.env.state_update(action)
        self.total_reward += reward
        self.rwd = reward
        if self.total_reward < self.min_reward:
            self.isgameon = False
        if not self.isgameon:
            self.total_reward = 0
        self.buffer.append(Transition(current_state, action, next_state, reward, self.isgameon))

    def select_action(self, net, epsilon, device=None):
        """Choose an action index for the environment's current state.

        With ``use_softmax`` the action is sampled from a softmax over the
        Q-values with temperature ``epsilon`` (floored at 1e-8). Otherwise it
        is epsilon-greedy. In both cases the choice is then kept with
        probability ``1 - alpha * (num_act - 1)`` and replaced by each other
        action with probability ``alpha``.

        Args:
            net: Q-network mapping a flattened state to one value per action.
            epsilon: softmax temperature or epsilon-greedy exploration rate.
            device: device for the input tensor; defaults to ``net``'s device.
        """
        device = self._resolve_device(net, device)
        state_tensor = torch.tensor(self.env.state(), dtype=torch.float32, device=device).view(1, -1)
        with torch.no_grad():
            qvalues = net(state_tensor).cpu().numpy().squeeze()
        if self.use_softmax:
            p = sp.softmax(qvalues / max(epsilon, 1e-8))
            action = np.random.choice(self.num_act, p=p)
        else:
            if random.random() < epsilon:
                action = random.randint(0, self.num_act - 1)
            else:
                action = np.argmax(qvalues)
        probs = np.full(self.num_act, self.alpha)
        probs[action] = 1 - self.alpha * (self.num_act - 1)
        return np.random.choice(self.num_act, p=probs)

    def plot_policy_map(self, net):
        """Draw the greedy action and a normalised value heat-map for every free cell.

        The environment's position and the network's train/eval mode are
        restored afterwards, so calling this in the middle of training does
        not disturb the running episode.
        """
        device = self._resolve_device(net)
        saved_position = self.env.current_position
        was_training = net.training
        q_values_grid = np.full(self.env.maze.shape, np.nan)
        q_values_list = []
        greedy_actions = []
        net.eval()
        try:
            with torch.no_grad():
                # One forward pass per cell yields both the value and the action.
                for free_cell in self.env.allowed_states:
                    self.env.current_position = np.asarray(free_cell)
                    state_tensor = torch.tensor(self.env.state(), dtype=torch.float32,
                                                device=device).view(1, -1)
                    q_values = net(state_tensor).squeeze()
                    # NOTE(review): the heat-map uses |max Q|, so a strongly
                    # negative best value plots as "high" — confirm intent.
                    max_q = torch.abs(torch.max(q_values)).item()
                    q_values_grid[free_cell[0], free_cell[1]] = max_q
                    q_values_list.append(max_q)
                    greedy_actions.append(torch.argmax(q_values, dim=0).item())
        finally:
            self.env.current_position = saved_position
            net.train(was_training)

        epsilon = 1e-8  # avoids 0/0 when every cell has the same value
        q_min, q_max = min(q_values_list), max(q_values_list)
        normalized_q_grid = (q_values_grid - q_min) / (q_max - q_min + epsilon)

        fig, ax = plt.subplots(figsize=(11, 11))
        fig.patch.set_facecolor("#F6EDDD")
        for free_cell, action in zip(self.env.allowed_states, greedy_actions):
            policy_direction = self.env.directions[action]
            # imshow puts column on the horizontal axis, hence (col, row).
            cell_rect = patches.Rectangle((free_cell[1] - 0.5, free_cell[0] - 0.5), 1, 1,
                                          linewidth=1, edgecolor='black', facecolor='none')
            ax.add_patch(cell_rect)
            ax.text(free_cell[1], free_cell[0], policy_direction, ha='center', va='center',
                    fontsize=9, color='black', fontweight='bold')
        goal_cell = self.env.goal
        goal_rect = patches.Rectangle((goal_cell[1] - 0.5, goal_cell[0] - 0.5), 1, 1,
                                      linewidth=1, edgecolor='black', facecolor='green')
        ax.add_patch(goal_rect)
        ax.text(goal_cell[1], goal_cell[0], 'Goal', ha='center', va='center',
                fontsize=7, color='white', fontweight='bold')
        cax = ax.imshow(normalized_q_grid, cmap='RdYlBu_r', interpolation='nearest')
        cbar = fig.colorbar(cax, ax=ax, shrink=0.75)
        cbar.set_label('Normalized Max Q-Value', fontsize=10)
        ax.set_xticks([])
        ax.set_yticks([])
        ax.set_title('Policy Map', fontsize=14, color='black')
        plt.show()
        # This is called repeatedly during training; release the figure.
        plt.close(fig)


In [94]:
def train(agent, net, target, optimizer, num_epochs, cutoff, batch_size,
          buffer_start_size, gamma, exploration_rate=-1, device="mps"):
    """Train ``net`` by playing one maze episode per epoch.

    Args:
        agent: agent exposing ``env``, ``buffer``, ``make_a_move``, ``rwd``,
            ``isgameon`` and ``plot_policy_map``.
        net: Q-network being optimised.
        target: network that receives a copy of ``net``'s weights every 5
            epochs.
        optimizer: optimiser over ``net``'s parameters.
        num_epochs: number of episodes to play.
        cutoff: decay constant of the schedule ``exp(-epoch / cutoff)``.
        batch_size: transitions sampled per gradient step.
        buffer_start_size: minimum buffer length before optimisation starts.
        gamma: discount factor passed to ``Qloss``.
        exploration_rate: fixed exploration value; ``-1`` selects the
            decaying schedule.
        device: device used for sampled batches and for the agent's inputs.

    Returns:
        ``(net, r_list, loss_log)``: the trained network, the per-epoch reward
        totals and the per-epoch summed losses.

    Side effects:
        Writes ``files/best.torch``, ``files/r_list.torch`` and
        ``files/loss_log.torch``; clears the cell output and redraws the
        policy map every 20 epochs.
    """
    r_list = []
    loss_log = []
    if num_epochs <= 0:
        # Nothing to play; also avoids indexing an empty schedule below.
        return net, r_list, loss_log

    # Checkpoints and logs are written here; create it if it is missing.
    os.makedirs("files", exist_ok=True)

    # Exponentially decaying exploration schedule, capped from above at its
    # value at `cutoff_idx`. The index is clamped so short runs cannot read
    # past the end of the schedule.
    epsilon = np.exp(-np.arange(num_epochs) / cutoff)
    cutoff_idx = min(100 * int(num_epochs / cutoff), num_epochs - 1)
    threshold = epsilon[cutoff_idx]
    epsilon = np.minimum(epsilon, threshold)

    best_loss = float("inf")
    running_loss = 0
    estop = None  # epoch at which the best running loss was seen
    last_counters = deque(maxlen=20)
    last_losses = deque(maxlen=20)
    last_results = deque(maxlen=20)

    for epoch in range(num_epochs):
        r_sum, loss, counter = 0, 0, 0
        eps = epsilon[epoch] if exploration_rate == -1 else exploration_rate

        agent.isgameon = True
        _ = agent.env.reset(eps, prand=0.1)

        while agent.isgameon:
            # Forward the device so the agent's input tensor lives where the
            # network does instead of relying on the agent's own default.
            agent.make_a_move(net, eps, device)
            counter += 1
            r_sum += agent.rwd

            # Keep collecting experience until the buffer is warm.
            if len(agent.buffer) < buffer_start_size:
                continue

            optimizer.zero_grad()
            batch = agent.buffer.sample(batch_size, device=device)

            loss_t = Qloss(batch, net, gamma=gamma)
            loss_t.backward()
            torch.nn.utils.clip_grad_norm_(net.parameters(), max_norm=1.0)
            optimizer.step()

            loss += loss_t.item()

        game_result = int((agent.env.current_position == agent.env.goal).all())
        loss_log.append(loss)

        # NOTE(review): `target` is kept in sync here but the loss above is
        # computed from `net` alone, so the copy currently has no effect.
        if epoch % 5 == 0:
            target.load_state_dict(net.state_dict())

        # Track the best 50-epoch mean loss once the first 5% of training is over.
        if epoch > num_epochs // 20:
            running_loss = np.mean(loss_log[-50:])
            if running_loss < best_loss:
                best_loss = running_loss
                torch.save(net.state_dict(), "files/best.torch")
                estop = epoch

        r_list.append(r_sum)

        # Also flush on the last epoch so the saved logs are complete.
        if epoch % 10 == 0 or epoch == num_epochs - 1:
            torch.save(r_list, "files/r_list.torch")
            torch.save(loss_log, "files/loss_log.torch")

        last_counters.append(counter)
        last_losses.append(loss)
        last_results.append(game_result)

        avg_last_counters = np.mean(last_counters) if last_counters else 0
        avg_last_losses = np.mean(last_losses) if last_losses else 0
        win_last_results = np.sum(last_results)

        if epoch % 20 == 0 or epoch == num_epochs - 1:
            clear_output(wait=True)

            if epoch > num_epochs // 20:
                print(f'Best loss so far: {best_loss:.5f} at epoch {estop}')
                print('\n')

            if epoch >= 20:
                print(f'Epoch {epoch - 20} to {epoch}')
                progress_bar = '#' * int(100 * (epoch / num_epochs)) + ' ' * int(100 * (1 - epoch / num_epochs))
                print(f'[{progress_bar}]')
                print(f'Games won in last 20: {win_last_results}')
                print(f'Avg moves in last 20: {avg_last_counters:.2f}')
                print(f'Avg loss in last 20: {avg_last_losses:.5f}')
                agent.plot_policy_map(net)
            else:
                print('Starting...')

    return net, r_list, loss_log

In [None]:
# Seed every RNG this notebook draws from so a Restart & Run All repeats the
# same training run. NOTE(review): the maze layout is only reproducible if
# RandomizedKruskalMaze draws from one of these global generators — confirm.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Maze dimensions passed to the generator.
width, height = 5, 5
maze = RandomizedKruskalMaze(width, height)

env = MazeEnvironment(maze)

# Show the static reward layout (goal and dead-end values) before training.
env.draw_rewards()

In [None]:
# Prefer Apple's Metal backend, then CUDA, and fall back to the CPU so the
# notebook also runs on machines without an MPS device.
_mps_backend = getattr(torch.backends, "mps", None)
if _mps_backend is not None and _mps_backend.is_available():
    DEVICE = "mps"
elif torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

# The network sees the flattened grid, so its input width is the cell count;
# both hidden layers use the same width.
input_size = env.maze_size
net = DeepQNetwork(input_size, input_size, input_size, 4).to(DEVICE)
target = DeepQNetwork(input_size, input_size, input_size, 4).to(DEVICE)
target.load_state_dict(net.state_dict())

# Hyperparameters; the buffer sizes and decay constant are derived from the
# number of epochs.
NUM_EPOCHS = 10000
CUTOFF = NUM_EPOCHS // 3
BUFFER_CAPACITY = NUM_EPOCHS
BUFFER_START_SIZE = NUM_EPOCHS // 15
BATCH_SIZE = 24
LEARNING_RATE = 1e-4
GAMMA = 0.9

memory = ReplayMemory(BUFFER_CAPACITY)
agent = Agent(env, maze, memory)
optimizer = optim.Adam(net.parameters(), lr=LEARNING_RATE)

trained_net, rewards, losses = train(
    agent, net, target, optimizer,
    num_epochs=NUM_EPOCHS,
    cutoff=CUTOFF,
    batch_size=BATCH_SIZE,
    buffer_start_size=BUFFER_START_SIZE,
    gamma=GAMMA,
    device=DEVICE
)

In [None]:
# Reload the per-epoch reward totals that training checkpointed to disk.
r_list = torch.load("files/r_list.torch")

fig, ax = plt.subplots(figsize=(12, 5))

# Reward accumulated in each episode, indexed by epoch.
ax.plot(
    range(len(r_list)),
    r_list,
    linewidth=2,
    color='#4D7C7D',
    label='Accumulated Reward',
)

# Grid and tick styling.
ax.grid(True, linestyle='--', alpha=0.6)
ax.tick_params(axis='both', which='major', labelsize=12)

# Title, axis labels and legend.
ax.set_title('Accumulated Reward Over Epochs', fontsize=18, fontweight='bold', pad=15)
ax.set_xlabel('Epoch', fontsize=14, fontweight='medium', labelpad=10)
ax.set_ylabel('Accumulated Reward', fontsize=14, fontweight='medium', labelpad=10)
ax.legend(fontsize=12, loc='lower right')

plt.show()

In [None]:
# Write the final weights to disk, then read the same file back into the
# network so the plotted policy reflects exactly what was saved.
torch.save(net.state_dict(), "files/net.torch")
saved_weights = torch.load("files/net.torch")
net.load_state_dict(saved_weights)

agent.plot_policy_map(net)

In [None]:
# Clone the network so the final weights stay untouched, then overwrite the
# clone with the checkpoint saved at the best running loss.
best_net = copy.deepcopy(net)
best_weights = torch.load("files/best.torch")
best_net.load_state_dict(best_weights)

agent.plot_policy_map(best_net)