In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import random
import numpy as np
from collections import deque
from pathlib import Path
from world.delivery_environment import Environment

def get_device() -> torch.device:
    """Return the preferred torch device.

    Preference order: CUDA if available, otherwise Apple MPS if this torch
    build exposes it and it is available, otherwise CPU.
    """
    # Older torch builds have no ``torch.backends.mps``; treat that as "no MPS".
    mps_backend = getattr(torch.backends, "mps", None)
    if torch.cuda.is_available():
        chosen = "cuda"
    elif mps_backend and mps_backend.is_available():
        chosen = "mps"
    else:
        chosen = "cpu"
    return torch.device(chosen)

# Device used by every tensor and network created in this notebook.
device = get_device()

In [None]:
# Build the delivery environment from a saved grid file, with the GUI disabled.
env = Environment(Path("grid_configs/small_grid_2.npy"), no_gui=True)
# NOTE(review): reset() is called before reading env.grid / env.initial_target_count,
# presumably because those attributes are populated by reset — confirm against Environment.
env.reset()
# Grid dimensions; used to scale the row/column features in encode_state_norm.
n_rows, n_cols = env.grid.shape
# Used as the denominator for the "remaining" feature in encode_state_norm.
max_deliveries = env.initial_target_count
# Number of features produced by encode_state_norm: (row, col, remaining).
state_dim = 3

def encode_state_norm(raw: tuple[int,int,int]) -> torch.Tensor:
    """Encode a raw environment state as a normalised float32 feature vector.

    Args:
        raw: ``(row, col, remaining)`` tuple as returned by the environment.

    Returns:
        A 1-D tensor of three features on ``device``: the row divided by
        ``n_rows - 1``, the column divided by ``n_cols - 1`` and the remaining
        count divided by ``max_deliveries``.
    """
    i, j, rem = raw
    # Clamp each denominator to at least 1 so a single-row/column grid or an
    # environment with zero targets does not raise ZeroDivisionError. For all
    # other configurations the result is unchanged.
    row_span = max(n_rows - 1, 1)
    col_span = max(n_cols - 1, 1)
    delivery_span = max(max_deliveries, 1)
    return torch.tensor([
        i / row_span,
        j / col_span,
        rem / delivery_span
    ], device=device, dtype=torch.float32)

In [None]:
class DQN(nn.Module):
    """Fully connected Q-network: two ReLU hidden layers and a linear output.

    Args:
        input_dim: Size of the state feature vector.
        n_actions: Number of discrete actions; one Q-value is produced per action.
        hidden_dim: Width of both hidden layers. Defaults to 64, the value that
            was previously hard-coded, so existing callers and saved
            checkpoints are unaffected.
    """

    def __init__(self, input_dim: int, n_actions: int, hidden_dim: int = 64):
        super().__init__()
        # Keep the attribute name and layer order stable: state_dict keys
        # ("network.0.weight", ...) are what saved checkpoints are matched on.
        self.network = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, n_actions)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return Q-values for ``x``; the last dimension of ``x`` must be ``input_dim``."""
        return self.network(x)

In [None]:
class ReplayBuffer:
    """Bounded FIFO store of transitions for experience replay.

    Transitions are kept as raw ``(state, action, reward, next_state, done)``
    tuples; states are only encoded into tensors when a batch is sampled.
    """

    def __init__(self, capacity: int):
        # A deque with maxlen drops the oldest transition once it is full.
        self.buffer = deque(maxlen=capacity)

    def push(self, state, action, reward, next_state, done):
        """Append one transition, keeping the states in their raw tuple form."""
        transition = (state, action, reward, next_state, done)
        self.buffer.append(transition)

    def sample(self, batch_size: int):
        """Draw ``batch_size`` distinct transitions and return them as tensors.

        Returns:
            ``(states, actions, rewards, next_states, dones)`` where both state
            entries are stacked ``encode_state_norm`` outputs, ``actions`` is
            int64, and ``rewards`` / ``dones`` are float32, all on ``device``.
        """
        def _encode_all(raw_states):
            return torch.stack([encode_state_norm(raw) for raw in raw_states])

        def _to_tensor(values, dtype):
            return torch.tensor(values, dtype=dtype, device=device)

        picked = random.sample(self.buffer, batch_size)
        # Transpose the list of transitions into one tuple per field.
        cur_raw, acts, rews, nxt_raw, flags = zip(*picked)

        cur_enc = _encode_all(cur_raw)
        nxt_enc = _encode_all(nxt_raw)
        act_t = _to_tensor(acts, torch.int64)
        rew_t = _to_tensor(rews, torch.float32)
        done_t = _to_tensor(flags, torch.float32)
        return (cur_enc, act_t, rew_t, nxt_enc, done_t)

    def __len__(self):
        """Number of transitions currently stored."""
        return len(self.buffer)

In [None]:
# --- Hyperparameters -------------------------------------------------------
seed = 42                    # seed for random / numpy / torch so runs are repeatable
n_actions = 4                # size of the discrete action space (DQN output width)
buffer_capacity = 5000       # maximum number of transitions kept in the replay buffer
batch_size = 64              # transitions per gradient step; learning starts once the buffer holds this many
gamma = 0.99                 # discount factor in the TD target
lr = 1e-3                    # Adam learning rate
epsilon_start = 1.0          # initial exploration rate
epsilon_end = 0.1            # floor for the exploration rate
epsilon_decay = 0.995        # multiplicative decay applied to epsilon once per episode
target_update_freq = 100     # episodes between hard copies of policy_net into target_net
num_episodes = 500           # number of training episodes
max_steps_per_episode = 500  # step cap per episode

# Seed every RNG this notebook draws from *before* the networks are built, so
# weight initialisation, ε-greedy exploration and replay sampling are
# reproducible after Restart & Run All.
# NOTE(review): whether Environment itself draws from these RNGs is not visible here.
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)

# --- Networks, optimiser and replay memory ---------------------------------
policy_net = DQN(state_dim, n_actions).to(device)
target_net = DQN(state_dim, n_actions).to(device)
# The target network starts as an exact copy and is never trained directly.
target_net.load_state_dict(policy_net.state_dict())
target_net.eval()

optimizer = optim.Adam(policy_net.parameters(), lr=lr)
replay_buffer = ReplayBuffer(buffer_capacity)

In [None]:
def _select_action(state: torch.Tensor, epsilon: float) -> int:
    """Pick an action ε-greedily with respect to ``policy_net``.

    With probability ``epsilon`` a uniformly random action is returned;
    otherwise the action with the highest predicted Q-value for ``state``.
    """
    if random.random() < epsilon:
        return random.randrange(n_actions)
    with torch.no_grad():
        q_vals = policy_net(state.unsqueeze(0))
    return q_vals.argmax(dim=1).item()


def _optimize_step(loss_fn: nn.Module) -> None:
    """Run one gradient step of ``policy_net`` on a sampled replay mini-batch."""
    s_b, a_b, r_b, ns_b, d_b = replay_buffer.sample(batch_size)
    # Q(s, a) for the actions that were actually taken.
    q_values = policy_net(s_b).gather(1, a_b.unsqueeze(1)).squeeze(1)
    # Target: r + γ · max_a' Q_target(s', a'), with the bootstrap term zeroed
    # for transitions flagged as done.
    with torch.no_grad():
        next_q = target_net(ns_b).max(dim=1)[0]
        target_q = r_b + gamma * next_q * (1 - d_b)
    loss = loss_fn(q_values, target_q)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()


# Built once instead of on every gradient step.
mse_loss = nn.MSELoss()

epsilon = epsilon_start
for episode in range(1, num_episodes + 1):
    raw_state = env.reset()                  # (row, col, remaining)
    state = encode_state_norm(raw_state)
    total_reward = 0

    for _step in range(max_steps_per_episode):
        action = _select_action(state, epsilon)

        raw_next, reward, done, _info = env.step(action)
        next_state = encode_state_norm(raw_next)

        # Store the raw tuples; the buffer encodes them when sampling.
        replay_buffer.push(raw_state, action, reward, raw_next, done)
        total_reward += reward
        raw_state = raw_next
        state = next_state

        # Learn only once the buffer can supply a full mini-batch.
        if len(replay_buffer) >= batch_size:
            _optimize_step(mse_loss)

        if done:
            break

    # Decay ε once per episode, never below epsilon_end.
    epsilon = max(epsilon_end, epsilon * epsilon_decay)
    # Hard-update the target network every target_update_freq episodes.
    if episode % target_update_freq == 0:
        target_net.load_state_dict(policy_net.state_dict())
    # Log progress every 10 episodes.
    if episode % 10 == 0:
        print(f"Episode {episode}, Total Reward: {total_reward:.2f}, Epsilon: {epsilon:.2f}")

In [None]:
# Persist the trained policy weights. Create the output directory first:
# torch.save does not create missing parent directories and would fail on a
# fresh checkout without a models/ folder.
model_path = Path("models/small_grid_2_policy.pt")
model_path.parent.mkdir(parents=True, exist_ok=True)
torch.save(policy_net.state_dict(), model_path)

In [None]:
# Rebuild the network and load the saved weights so the evaluation exercises
# the checkpoint on disk rather than the in-memory training network.
eval_net = DQN(state_dim, n_actions).to(device)

eval_net.load_state_dict(torch.load("models/small_grid_2_policy.pt", map_location=device))
eval_net.eval()

raw_state = env.reset()
# Must use the same encoder as training (encode_state_norm); the previous
# call to `encode_state` referenced a name that is not defined in this notebook.
state = encode_state_norm(raw_state)
done = False
# NOTE(review): if env.agent_pos is mutated in place by the environment, store a copy here.
trajectory = [env.agent_pos]

# Greedy rollout, capped at the training step limit: a purely greedy policy
# that is stuck would otherwise never reach `done` and loop forever.
for _step in range(max_steps_per_episode):
    with torch.no_grad():
        action = eval_net(state.unsqueeze(0)).argmax(dim=1).item()
    raw_state, _, done, _ = env.step(action)
    trajectory.append(env.agent_pos)
    state = encode_state_norm(raw_state)
    if done:
        break

if not done:
    print(f"Evaluation stopped after {max_steps_per_episode} steps without finishing the episode.")
print("Evaluation trajectory:", trajectory)