In [60]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import random

class GridWorld:
    """Square grid environment with a single agent.

    Cells are numbered row-major from 0 (top-left) to ``size * size - 1``
    (bottom-right). The agent starts in cell 0 and the episode ends when it
    reaches the last cell.
    """

    def __init__(self, size=5):
        """Create a ``size`` x ``size`` grid and place the agent on cell 0."""
        self.size = size
        self.reset()

    def reset(self):
        """Move the agent back to cell 0 and return that position."""
        self.agent_position = 0
        return self.agent_position

    def step(self, action):
        """Apply one move and report the outcome.

        Actions: 0 = left, 1 = right, 2 = up, 3 = down. A move that would
        leave the grid, or any other action value, leaves the agent in place.

        Returns:
            ``(position, reward, done, info)`` where ``reward`` is 10 on
            reaching the last cell and -1 otherwise, and ``info`` is an
            empty dict.
        """
        side = self.size
        here = self.agent_position

        # Work out the index offset for this action; 0 means "blocked / no-op".
        if action == 0:
            offset = -1 if here % side > 0 else 0
        elif action == 1:
            offset = 1 if here % side < side - 1 else 0
        elif action == 2:
            offset = -side if here >= side else 0
        elif action == 3:
            offset = side if here < side * (side - 1) else 0
        else:
            offset = 0

        self.agent_position = here + offset

        reached_goal = self.agent_position == side * side - 1
        if reached_goal:
            payoff = 10
        else:
            payoff = -1
        return self.agent_position, payoff, reached_goal, {}

class QLearningAgent:
    """Epsilon-greedy Q-learning agent backed by a small fully connected network.

    States are integer indices in ``range(state_size)`` and are fed to the
    network as one-hot vectors. A second network with identical architecture
    (the target network) supplies the bootstrap value for the TD target and is
    only changed by ``update_target_network``.
    """

    def __init__(self, state_size, action_size, learning_rate=0.001, discount_factor=0.99, exploration_rate=1.0, exploration_decay=0.995, exploration_min=0.01):
        """Build the online/target networks and the optimizer.

        Args:
            state_size: Number of discrete states (length of the one-hot input).
            action_size: Number of discrete actions (network output width).
            learning_rate: Step size passed to the Adam optimizer.
            discount_factor: Multiplier applied to the bootstrapped next-state value.
            exploration_rate: Probability of picking a random action in ``get_action``.
            exploration_decay: Stored for the caller's decay schedule; not applied here.
            exploration_min: Stored for the caller's decay schedule; not applied here.
        """
        self.state_size = state_size
        self.action_size = action_size
        self.learning_rate = learning_rate
        self.discount_factor = discount_factor
        self.exploration_rate = exploration_rate
        self.exploration_min = exploration_min
        self.exploration_decay = exploration_decay

        # Online Q-network: used for action selection and updated on every train() call.
        self.q_network = self._build_network(state_size, action_size)

        # Target network: same architecture, starts as an exact copy of the online weights.
        self.target_network = self._build_network(state_size, action_size)
        self.target_network.load_state_dict(self.q_network.state_dict())

        self.optimizer = optim.Adam(self.q_network.parameters(), lr=learning_rate)
        self.loss_fn = nn.MSELoss()
        self.target_update_frequency = 10  # Update target every 10 episodes
        self.steps = 0

    @staticmethod
    def _build_network(state_size, action_size):
        """Return a fresh two-hidden-layer MLP mapping a one-hot state to per-action values."""
        return nn.Sequential(
            nn.Linear(state_size, 24),
            nn.ReLU(),
            nn.Linear(24, 24),
            nn.ReLU(),
            nn.Linear(24, action_size)
        )

    def _state_tensor(self, state):
        """Return the one-hot encoding of ``state`` as a float32 tensor."""
        return torch.tensor(self.one_hot_encode(state), dtype=torch.float32)

    def get_action(self, state):
        """Pick an action for ``state``.

        With probability ``exploration_rate`` a uniformly random action is
        returned; otherwise the index of the largest online-network output.
        """
        if np.random.rand() < self.exploration_rate:
            return random.randrange(self.action_size)
        with torch.no_grad():
            q_values = self.q_network(self._state_tensor(state))
            return torch.argmax(q_values).item()

    def one_hot_encode(self, state):
        """Return a list of length ``state_size`` with a 1 at index ``state``.

        Raises:
            IndexError: If ``state`` is outside ``range(state_size)``. Negative
                indices are rejected explicitly because plain list indexing
                would otherwise wrap them around to a different state.
        """
        if not 0 <= state < self.state_size:
            raise IndexError(
                f"state {state} is outside the valid range [0, {self.state_size})"
            )
        encoding = [0] * self.state_size
        encoding[state] = 1
        return encoding

    def train(self, state, action, reward, next_state, done):
        """Run one gradient step on a single ``(s, a, r, s', done)`` transition.

        The target for the taken action is
        ``reward + discount_factor * max(target_network(s'))``, with the
        bootstrap term zeroed when ``done`` is true. Increments ``self.steps``.
        """
        current_q_values = self.q_network(self._state_tensor(state))

        with torch.no_grad():
            next_q_values = self.target_network(self._state_tensor(next_state))
            max_next_q_value = torch.max(next_q_values).item()
            target_q_value = reward + (self.discount_factor * max_next_q_value * (not done))

        # The regression target is a constant: detach it so no autograd graph
        # is carried through it. Entries for the other actions equal the
        # current predictions, so only the taken action contributes to the loss.
        target = current_q_values.detach().clone()
        target[action] = target_q_value
        loss = self.loss_fn(current_q_values, target)

        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        self.steps += 1

    def update_target_network(self):
        """Copy the online network's current weights into the target network."""
        self.target_network.load_state_dict(self.q_network.state_dict())

def train_agent(episodes=500):
    """Train a fresh QLearningAgent on a default GridWorld and return it.

    Each environment step is followed by one agent.train() call. After every
    episode the exploration rate is decayed (never below the agent's minimum);
    the target network is refreshed every ``agent.target_update_frequency``
    episodes, and progress is printed every 100 episodes.
    """
    env = GridWorld()
    agent = QLearningAgent(
        state_size=env.size**2,
        action_size=4,
        learning_rate=0.001,
        exploration_decay=0.995,
        exploration_min=0.01
    )

    for episode in range(episodes):
        total_reward = 0
        state = env.reset()

        # Play one episode to termination, learning from every transition.
        while True:
            action = agent.get_action(state)
            next_state, reward, finished, _ = env.step(action)
            agent.train(state, action, reward, next_state, finished)
            total_reward += reward
            state = next_state
            if finished:
                break

        # Per-episode exploration decay, clamped at the configured floor.
        decayed_rate = agent.exploration_rate * agent.exploration_decay
        agent.exploration_rate = max(agent.exploration_min, decayed_rate)

        # Periodic sync of the target network with the online network.
        if not episode % agent.target_update_frequency:
            agent.update_target_network()

        if not episode % 100:
            print(f"Episode {episode}: Total Reward = {total_reward}, Exploration Rate = {agent.exploration_rate:.3f}")

    return agent

# Seed every RNG the training run draws from so that Restart & Run All
# repeats the same run: Python's `random` (random action choice), NumPy
# (the exploration coin flip) and PyTorch (network weight initialisation).
SEED = 0
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Run the training
trained_agent = train_agent()

Episode 0: Total Reward = -13, Exploration Rate = 0.995
Episode 100: Total Reward = -8, Exploration Rate = 0.603
Episode 200: Total Reward = -4, Exploration Rate = 0.365
Episode 300: Total Reward = 3, Exploration Rate = 0.221
Episode 400: Total Reward = 1, Exploration Rate = 0.134


In [61]:
def test_agent(agent, env):
    state = env.reset()
    done = False
    trajectory = []
    
    # Disable exploration for testing
    original_exploration = agent.exploration_rate
    agent.exploration_rate = 0
    
    while not done:
        action = agent.get_action(state)
        next_state, reward, done, _ = env.step(action)
        trajectory.append((state, action))
        state = next_state
    
    # Restore exploration rate (if you plan to keep training)
    agent.exploration_rate = original_exploration
    
    return trajectory

# Evaluate the trained agent on a fresh grid and show the (state, action) pairs it took
env = GridWorld()
path = test_agent(agent=trained_agent, env=env)
print("Path taken:", path)

Path taken: [(0, 1), (1, 3), (6, 1), (7, 3), (12, 3), (17, 3), (22, 1), (23, 1)]


In [62]:
def print_grid_path(size, path):
    """Print a ``size`` x ``size`` text grid showing the action taken in each visited cell.

    ``path`` is an iterable of ``(state, action)`` pairs with row-major state
    indices. Visited cells show an arrow for the action (0 left, 1 right,
    2 up, 3 down), unvisited cells show "·", and the bottom-right cell is
    always drawn as "G".
    """
    arrows = {0: "←", 1: "→", 2: "↑", 3: "↓"}

    # Start from an all-empty board; each row is its own list.
    cells = []
    for _ in range(size):
        cells.append(["·"] * size)

    for state, action in path:
        r, c = divmod(state, size)
        cells[r][c] = arrows[action]

    # The goal marker is drawn last so it wins over any arrow on that cell.
    last = size - 1
    cells[last][last] = "G"

    for line in map(" ".join, cells):
        print(line)

# Draw the greedy rollout on the grid
print_grid_path(size=env.size, path=path)

→ ↓ · · ·
· → ↓ · ·
· · ↓ · ·
· · ↓ · ·
· · → → G
