<a href="https://colab.research.google.com/github/nhngmnh/RLAgentResearch/blob/master/GraphDQNAgent.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [3]:
import random
import time
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from collections import deque

# ---------------------------
# Graph Q-Network (CPU-only)
# ---------------------------
class GraphQNetwork(nn.Module):
    """Node encoder plus one round of mean neighbour aggregation.

    Each node's features go through two ReLU-activated linear layers, the
    resulting embeddings are averaged over each node's neighbours, and the
    Q-value of moving to node ``j`` is the dot product between the aggregated
    embedding of ``j`` and the aggregated embedding of the current node.

    Args:
        feature_dim: Number of input features per node.
        hidden_dim: Width of both linear layers (and of the embeddings).
    """

    def __init__(self, feature_dim, hidden_dim):
        super().__init__()
        self.fc1 = nn.Linear(feature_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)

    def forward(self, node_features, adjacency_matrix, current_index):
        """Return one Q-value per node for the agent standing at ``current_index``.

        Args:
            node_features: Tensor of shape ``[N, feature_dim]``.
            adjacency_matrix: Tensor of shape ``[N, N]``; row ``i`` holds the
                weights used to average the neighbours of node ``i``.
            current_index: Index of the node the agent is currently at.

        Returns:
            Tensor of shape ``[N]`` with the Q-value of every node.
        """
        embeddings = F.relu(self.fc1(node_features))
        embeddings = F.relu(self.fc2(embeddings))

        degree = adjacency_matrix.sum(1, keepdim=True)
        # An isolated node has degree 0; dividing by it would give 0/0 = NaN
        # and poison every Q-value computed from that row. Its aggregated
        # numerator is already all zeros, so dividing by 1 keeps it at zero.
        safe_degree = degree.masked_fill(degree == 0, 1.0)
        aggregated = adjacency_matrix @ embeddings / safe_degree

        # NOTE: if every row of adjacency_matrix is identical (for example an
        # all-ones matrix), every row of `aggregated` is the same vector and
        # therefore all returned Q-values are equal.
        current_embedding = aggregated[current_index]
        q_values = aggregated @ current_embedding
        return q_values

# ---------------------------
# Graph DQN Agent
# ---------------------------
class GraphDQNAgent:
    """DQN agent that builds a closed tour over all cities, starting at city 0.

    The agent keeps an online and a target ``GraphQNetwork``, a replay buffer
    of transitions, and the best tour found so far.

    Each stored transition is the 10-tuple
    ``(features, adjacency, current, action, reward,
    next_features, next_adjacency, next_current, done, next_legal_actions)``.

    Args:
        distance_matrix: Square matrix; entry ``[i][j]`` is the cost of
            travelling from city ``i`` to city ``j``.
        feature_dim: Number of features per node fed to the network.
        hidden_dim: Hidden width of the network.
        gamma: Discount factor used in the TD target.
        epsilon: Initial exploration probability.
        decay: Multiplicative epsilon decay applied after every episode.
        min_epsilon: Lower bound for epsilon.
        episodes: Number of tours to build in ``solve``.
        batch_size: Number of transitions sampled per training step.
        buffer_size: Maximum number of transitions kept in the replay buffer.
        lr: Adam learning rate.
        target_update_every: Copy the online weights into the target network
            every this many episodes (values below 1 are treated as 1).
    """

    def __init__(self, distance_matrix, feature_dim=3, hidden_dim=128,
                 gamma=1.0, epsilon=1.0, decay=0.999, min_epsilon=0.01,
                 episodes=500, batch_size=64, buffer_size=5000, lr=1e-3,
                 target_update_every=10):
        self.numCities = len(distance_matrix)
        self.distanceMatrix = torch.tensor(distance_matrix, dtype=torch.float)
        self.gamma = gamma
        self.epsilon = epsilon
        self.decay = decay
        self.min_epsilon = min_epsilon
        self.episodes = episodes
        self.batch_size = batch_size
        self.target_update_every = max(1, int(target_update_every))

        self.feature_dim = feature_dim
        self.hidden_dim = hidden_dim

        # Networks: the target network starts as a frozen copy of the online one.
        self.network = GraphQNetwork(feature_dim, hidden_dim)
        self.target_network = GraphQNetwork(feature_dim, hidden_dim)
        self.target_network.load_state_dict(self.network.state_dict())
        self.target_network.eval()
        self.optimizer = optim.Adam(self.network.parameters(), lr=lr)

        # Replay buffer (oldest transitions are dropped once it is full)
        self.replay_buffer = deque(maxlen=buffer_size)

        # Best solutions
        self.best_solution = None
        self.best_distance = float('inf')
        self.initial_solution = None
        self.initial_distance = float('inf')

    # -----------------------
    # Node features and adjacency
    # -----------------------
    def get_node_features(self, visited_mask):
        """Return a float tensor with one ``[x, y, visited]`` row per city.

        ``x`` and ``y`` are both set to the city index as a placeholder; if
        real coordinates are available they should replace it.
        """
        features = [
            [i, i, 1 if visited_mask[i] else 0]
            for i in range(self.numCities)
        ]
        return torch.tensor(features, dtype=torch.float)

    def get_adjacency(self):
        """Return the all-ones ``[N, N]`` adjacency of a fully connected graph.

        NOTE(review): GraphQNetwork.forward in this file averages embeddings
        row-wise over this matrix; with identical rows that average is the
        same for every city, so the Q-values cannot differ between actions.
        A distance-weighted adjacency may be what is wanted - confirm.
        """
        return torch.ones((self.numCities, self.numCities))  # fully connected

    # -----------------------
    # Select action
    # -----------------------
    def _mask_illegal(self, q_values, legal_actions):
        """Return a copy of ``q_values`` with every non-legal entry set to -inf."""
        masked = torch.full_like(q_values, float('-inf'))
        legal = list(legal_actions)
        if legal:
            legal_idx = torch.tensor(legal, dtype=torch.long)
            masked[legal_idx] = q_values[legal_idx]
        return masked

    def select_action(self, node_features, adjacency_matrix, current_index, unvisited):
        """Pick the next city with an epsilon-greedy policy over ``unvisited``.

        With probability ``self.epsilon`` a uniformly random unvisited city is
        returned; otherwise the unvisited city with the highest Q-value.
        """
        if random.random() < self.epsilon:
            return random.choice(list(unvisited))

        with torch.no_grad():
            q_values = self.network(node_features, adjacency_matrix, current_index)
        return int(torch.argmax(self._mask_illegal(q_values, unvisited)))

    # -----------------------
    # Store & train
    # -----------------------
    def store_transition(self, transition):
        """Append one transition tuple to the replay buffer."""
        self.replay_buffer.append(transition)

    def train_step(self):
        """Run one optimisation step on a random batch from the replay buffer.

        Does nothing until the buffer holds at least ``batch_size``
        transitions. The loss is the sum of the per-transition squared TD
        errors against the target network.
        """
        if len(self.replay_buffer) < self.batch_size:
            return
        batch = random.sample(self.replay_buffer, self.batch_size)

        losses = []
        for (state_feat, adj, cur_idx, action, reward,
             next_feat, next_adj, next_idx, done, next_legal) in batch:
            q_pred = self.network(state_feat, adj, cur_idx)[action]

            # Bootstrap only over actions that are legal in the next state.
            # With no legal action (terminal state, or an empty list) the
            # bootstrap is 0: taking max over an all -inf vector would give a
            # -inf target, an infinite loss and NaN weights after the update.
            bootstrap = 0.0
            if not done and len(next_legal) > 0:
                with torch.no_grad():
                    q_next = self.target_network(next_feat, next_adj, next_idx)
                    bootstrap = self._mask_illegal(q_next, next_legal).max().item()

            # Build the target with the prediction's dtype so the loss never
            # mixes float32 and float64 operands.
            q_target = torch.tensor(reward + self.gamma * bootstrap,
                                    dtype=q_pred.dtype)
            losses.append(F.mse_loss(q_pred, q_target))

        if not losses:  # batch_size == 0: nothing to optimise
            return
        loss = torch.stack(losses).sum()
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

    def decay_epsilon_func(self):
        """Multiply epsilon by ``decay`` without going below ``min_epsilon``."""
        self.epsilon = max(self.min_epsilon, self.epsilon * self.decay)

    # -----------------------
    # Solve function
    # -----------------------
    def solve(self):
        """Build ``self.episodes`` tours, training after every step.

        Returns:
            dict with keys ``initial_solution`` / ``initial_distance`` (tour of
            the first episode), ``final_solution`` / ``final_distance`` (best
            tour seen) and ``execution_time`` in seconds. Tours are lists of
            city indices that start and end at city 0.
        """
        start_time = time.time()

        start = 0
        # The adjacency never changes, so build it once. Doing it here also
        # keeps it defined when there is a single city and the step loop below
        # never runs.
        adjacency_matrix = self.get_adjacency()

        for episode in range(self.episodes):
            state = start
            unvisited = list(range(1, self.numCities))
            visited_mask = [0] * self.numCities
            visited_mask[start] = 1
            path = [state]
            tour_distance = 0.0
            node_features = self.get_node_features(visited_mask)

            while unvisited:
                action = self.select_action(node_features, adjacency_matrix, state, unvisited)
                step_distance = self.distanceMatrix[state][action].item()
                reward = -step_distance

                visited_mask[action] = 1
                unvisited.remove(action)
                next_features = self.get_node_features(visited_mask)
                # Once every city is visited the only legal move left is the
                # return to the start city; store that instead of an empty
                # list so the TD target bootstraps from the closing edge.
                next_legal = list(unvisited) if unvisited else [start]

                # store
                self.store_transition((node_features, adjacency_matrix, state, action, reward,
                                       next_features, adjacency_matrix, action, False, next_legal))
                self.train_step()

                state = action
                node_features = next_features
                path.append(state)
                tour_distance += step_distance

            # end of tour: return to start
            closing_distance = self.distanceMatrix[state][start].item()
            tour_distance += closing_distance
            self.store_transition((node_features, adjacency_matrix, state, start, -closing_distance,
                                   node_features, adjacency_matrix, start, True, []))
            self.train_step()

            # update best
            tour = path + [start]
            if episode == 0:
                self.initial_solution = list(tour)
                self.initial_distance = tour_distance
            if tour_distance < self.best_distance:
                self.best_solution = list(tour)
                self.best_distance = tour_distance

            # decay epsilon
            self.decay_epsilon_func()

            # update target network
            if episode % self.target_update_every == 0:
                self.target_network.load_state_dict(self.network.state_dict())

        end_time = time.time()
        execution_time = end_time - start_time

        return {
            'initial_solution': self.initial_solution,
            'final_solution': self.best_solution,
            'initial_distance': self.initial_distance,
            'final_distance': self.best_distance,
            'execution_time': execution_time
        }

# -----------------------
# Example usage
# -----------------------
if __name__ == "__main__":
    # Simulated distance matrix for 10 cities: random integer costs in
    # [1, 100) with a zero diagonal (no cost to stay in place).
    N = 10
    distance_matrix = np.random.randint(low=1, high=100, size=(N, N))
    distance_matrix[np.diag_indices(N)] = 0

    # Train for a short run and print the summary dictionary.
    agent = GraphDQNAgent(distance_matrix=distance_matrix, episodes=50)
    result = agent.solve()
    print(result)


{'initial_solution': [0, 5, 7, 6, 8, 2, 9, 4, 3, 1, 0], 'final_solution': [0, 4, 1, 3, 9, 7, 6, 2, 5, 8, 0], 'initial_distance': 516.0, 'final_distance': 367.0, 'execution_time': 20.139946937561035}
