In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import random
import numpy as np
from collections import deque
from pathlib import Path
from world.delivery_environment import Environment

def get_device() -> torch.device:
    """Select the torch device to run on.

    Preference order: CUDA if available, then Apple's MPS backend if this
    torch build exposes it and reports it available, otherwise the CPU.

    Returns:
        torch.device: the chosen device.
    """
    # Older torch builds have no `torch.backends.mps`; getattr keeps that case safe.
    mps_backend = getattr(torch.backends, "mps", None)
    if torch.cuda.is_available():
        device_name = "cuda"
    elif mps_backend and mps_backend.is_available():
        device_name = "mps"
    else:
        device_name = "cpu"
    return torch.device(device_name)


# Resolved once; every tensor and model in the notebook is placed on this device.
device = get_device()

pygame 2.6.1 (SDL 2.28.4, Python 3.11.11)
Hello from the pygame community. https://www.pygame.org/contribute.html


In [3]:
# Build the delivery environment without a GUI from a saved grid layout.
grid_config = Path("grid_configs/custom_medium_grid_3.npy")
env = Environment(grid_config, no_gui=True, sigma=0.0)
# Reset before reading attributes — presumably this populates the grid and
# target count; TODO confirm against Environment.
env.reset()

# The encoded state has three features (row, column, remaining targets).
state_dim = 3
# Grid extent, used to scale the agent's coordinates.
n_rows, n_cols = env.grid.shape
# Target count reported by the environment after the reset.
max_deliveries = env.initial_target_count

def encode_state_norm(raw: tuple[int,int,int]) -> torch.Tensor:
    """Turn a raw environment state into a float32 feature vector.

    Args:
        raw: a 3-tuple unpacked as (row index, column index, remaining count).

    Returns:
        A 1-D float32 tensor of length 3 on the global `device`:
        row / (n_rows - 1), column / (n_cols - 1), and the remaining count.

    NOTE(review): the third component is passed through unscaled; only the
    coordinates are divided by the grid extent. Confirm this is intended.
    """
    row, col, remaining = raw
    features = [row / (n_rows - 1), col / (n_cols - 1), remaining]
    return torch.tensor(features, device=device, dtype=torch.float32)

In [4]:
class DQN(nn.Module):
    """Small fully-connected Q-network.

    Maps a batch of state vectors of width `input_dim` to one value per
    action through two hidden ReLU layers of 64 units each.
    """

    def __init__(self, input_dim: int, n_actions: int):
        """Build the layer stack.

        Args:
            input_dim: number of features in one encoded state.
            n_actions: number of outputs (one per action).
        """
        super().__init__()
        hidden_units = 64
        # Layers are created in forward order and wrapped in a Sequential stored
        # as `network`, so parameter names stay `network.<index>.*`.
        stack = [
            nn.Linear(input_dim, hidden_units),
            nn.ReLU(),
            nn.Linear(hidden_units, hidden_units),
            nn.ReLU(),
            nn.Linear(hidden_units, n_actions),
        ]
        self.network = nn.Sequential(*stack)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the network output for `x` (last dimension must be `input_dim`)."""
        action_values = self.network(x)
        return action_values

In [5]:
class ReplayBuffer:
    """Bounded FIFO store of transitions with uniform random sampling.

    Transitions are kept as raw `(state, action, reward, next_state, done)`
    tuples; states are only encoded into tensors when a batch is sampled.
    """

    def __init__(self, capacity: int):
        """Create an empty buffer that holds at most `capacity` transitions."""
        # A deque with maxlen discards the oldest entry once it is full.
        self.buffer = deque(maxlen=capacity)

    def push(self, state, action, reward, next_state, done):
        """Append one transition, keeping the states in raw (unencoded) form."""
        transition = (state, action, reward, next_state, done)
        self.buffer.append(transition)

    def sample(self, batch_size: int):
        """Draw `batch_size` distinct transitions and return them as tensors.

        Returns:
            A 5-tuple `(states, actions, rewards, next_states, dones)`.
            States are stacked outputs of `encode_state_norm`; actions are
            int64; rewards and dones are float32. All tensors are created on
            the global `device`.
        """
        drawn = random.sample(self.buffer, batch_size)
        # Transpose the list of transitions into one column per field.
        states, actions, rewards, next_states, dones = zip(*drawn)

        def encode_all(raw_states):
            return torch.stack([encode_state_norm(raw) for raw in raw_states])

        def column(values, dtype):
            return torch.tensor(values, dtype=dtype, device=device)

        state_batch = encode_all(states)
        next_state_batch = encode_all(next_states)
        return (
            state_batch,
            column(actions, torch.int64),
            column(rewards, torch.float32),
            next_state_batch,
            column(dones, torch.float32),
        )

    def __len__(self):
        """Number of transitions currently stored."""
        return len(self.buffer)

In [6]:
# --- Network sizing ---
n_actions = 4                  # number of discrete actions / Q-network outputs

# --- Training budget ---
num_episodes = 2_000           # episodes in the training loop
max_steps_per_episode = 600    # step cap per episode

# --- Replay and optimisation ---
buffer_capacity = 10_000       # maximum transitions kept in the replay buffer
batch_size = 128               # transitions per gradient step
lr = 0.001                     # Adam learning rate
gamma = 0.99                   # discount applied to the bootstrapped next-state value

# --- Epsilon-greedy schedule (epsilon is multiplied by the decay once per episode) ---
epsilon_start = 1.0
epsilon_end = 0.01
epsilon_decay = 0.995

# --- Target network: copy policy weights every this many episodes ---
target_update_freq = 50

In [10]:
# Seed every RNG the run draws from so that weight initialisation, epsilon-greedy
# exploration and replay sampling are repeatable after Restart & Run All.
# NOTE(review): whether the environment itself draws from `random` / `np.random`
# is not visible here — confirm if environment-side reproducibility is required.
SEED = 0
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Online network (trained) and target network (periodically synced copy).
policy_net = DQN(state_dim, n_actions).to(device)
target_net = DQN(state_dim, n_actions).to(device)
# Start both networks from identical weights.
target_net.load_state_dict(policy_net.state_dict())
# The target network is only used for inference, never optimised directly.
target_net.eval()

# Only the policy network's parameters are handed to the optimizer.
optimizer = optim.Adam(policy_net.parameters(), lr=lr)
replay_buffer = ReplayBuffer(buffer_capacity)

In [11]:
# Epsilon-greedy DQN training: act, store the transition, and run one gradient
# step per environment step once the buffer holds at least one full batch.
loss_fn = nn.MSELoss()  # stateless criterion: build once, not on every gradient step
epsilon = epsilon_start
for episode in range(1, num_episodes + 1):
    raw_state = env.reset()
    state = encode_state_norm(raw_state)
    total_reward = 0

    for step in range(max_steps_per_episode):
        # Explore with probability epsilon, otherwise act greedily on policy_net.
        if random.random() < epsilon:
            action = random.randrange(n_actions)
        else:
            with torch.no_grad():
                # unsqueeze adds a batch dimension of size 1 for the forward pass
                q_vals = policy_net(state.unsqueeze(0))
                action = q_vals.argmax(dim=1).item()

        raw_next, reward, done, _ = env.step(action)
        next_state = encode_state_norm(raw_next)

        # The buffer keeps raw tuples; encoding happens when a batch is sampled.
        replay_buffer.push(raw_state, action, reward, raw_next, done)
        total_reward += reward
        raw_state = raw_next
        state = next_state

        if len(replay_buffer) >= batch_size:
            s_b, a_b, r_b, ns_b, d_b = replay_buffer.sample(batch_size)
            # Q(s, a) for the actions that were actually taken.
            q_values = policy_net(s_b).gather(1, a_b.unsqueeze(1)).squeeze(1)
            with torch.no_grad():
                # Bootstrap from the target network; (1 - d_b) removes the
                # next-state term for transitions flagged as done.
                next_q = target_net(ns_b).max(dim=1)[0]
                target_q = r_b + gamma * next_q * (1 - d_b)
            loss = loss_fn(q_values, target_q)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        # Ends on the environment's done flag; otherwise the loop runs to the step cap.
        if done:
            break

    # decay ε (once per episode, floored at epsilon_end)
    epsilon = max(epsilon_end, epsilon * epsilon_decay)
    # update target network (frequency is counted in episodes, not steps)
    if episode % target_update_freq == 0:
        target_net.load_state_dict(policy_net.state_dict())
    # log every 10th episode; state[-1] is the last feature of the encoded state
    if episode % 10 == 0:
        print(f"Episode {episode}, Total Reward: {total_reward:.2f}, Epsilon: {epsilon:.2f}, Steps: {step + 1}, Rem Targets: {state[-1]}")

Episode 10, Total Reward: -788.00, Epsilon: 0.95, Steps: 453, Rem Targets: 0.0
Episode 20, Total Reward: -1247.00, Epsilon: 0.90, Steps: 600, Rem Targets: 0.0
Episode 30, Total Reward: -749.00, Epsilon: 0.86, Steps: 490, Rem Targets: 0.0
Episode 40, Total Reward: -1077.00, Epsilon: 0.82, Steps: 600, Rem Targets: 2.0
Episode 50, Total Reward: -1117.00, Epsilon: 0.78, Steps: 600, Rem Targets: 2.0
Episode 60, Total Reward: -966.00, Epsilon: 0.74, Steps: 600, Rem Targets: 1.0
Episode 70, Total Reward: -561.00, Epsilon: 0.70, Steps: 478, Rem Targets: 0.0
Episode 80, Total Reward: -990.00, Epsilon: 0.67, Steps: 600, Rem Targets: 1.0
Episode 90, Total Reward: -1013.00, Epsilon: 0.64, Steps: 600, Rem Targets: 2.0
Episode 100, Total Reward: -938.00, Epsilon: 0.61, Steps: 600, Rem Targets: 1.0
Episode 110, Total Reward: -1102.00, Epsilon: 0.58, Steps: 600, Rem Targets: 1.0
Episode 120, Total Reward: -906.00, Epsilon: 0.55, Steps: 600, Rem Targets: 1.0
Episode 130, Total Reward: -882.00, Epsilon:

KeyboardInterrupt: 

In [8]:
# Persist only the policy network's weights (its state_dict), not the module object.
model_path = Path("models/A1_grid_policy.pt")
# torch.save does not create missing directories, so make sure the target folder exists.
model_path.parent.mkdir(parents=True, exist_ok=True)
torch.save(policy_net.state_dict(), model_path)

In [10]:
# Display the grid dimensions of the loaded environment as (rows, cols).
env.grid.shape

(15, 15)