### Authors
- Gabriel Souza Lima
- Guilherme Mertens
- Kiyoshi Araki
- Lucas Tramonte

### Libraries

In [None]:
import gymnasium as gym
import highway_env
import pickle

# Read the custom highway-env configuration dictionary from disk.
# NOTE(review): pickle.load can execute arbitrary code embedded in the file;
# only use a config.pkl from a trusted source (JSON would be a safer format).
with open("config.pkl", "rb") as config_file:
    config = pickle.load(config_file)

# Build the fast highway environment (frames rendered as RGB arrays), apply
# the loaded configuration, then reset once and display what reset() returns
# (an (observation, info) pair in the Gymnasium API).
env = gym.make('highway-fast-v0', render_mode='rgb_array')
env.unwrapped.configure(config)
print(env.reset())

### Neural Network

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class DQN(nn.Module):
    """Fully connected Q-network with two ReLU hidden layers and a linear head.

    Maps a flat state vector of ``input_dim`` features to one Q-value per
    discrete action.

    Args:
        input_dim: Number of input features (the flattened observation size).
        output_dim: Number of discrete actions (one output per action).
        hidden_dim: Width of both hidden layers. Defaults to 128, the size
            previously hard-coded here.
    """

    def __init__(self, input_dim, output_dim, hidden_dim=128):
        super().__init__()
        # Coerce to plain ints: callers may pass NumPy integers (for example
        # the result of np.prod(obs_shape)), which would otherwise be stored
        # verbatim as the layers' in_features/out_features.
        input_dim = int(input_dim)
        output_dim = int(output_dim)
        hidden_dim = int(hidden_dim)
        # Layers are created in the same order as before, so parameter names
        # (fc1/fc2/out) and seeded initialization are unchanged.
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.out = nn.Linear(hidden_dim, output_dim)

    def forward(self, x):
        """Return Q-values of shape ``(..., output_dim)`` for ``x`` of shape ``(..., input_dim)``."""
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return self.out(x)

In [None]:
import random
from collections import deque, namedtuple

# One stored environment step; field order matches ReplayBuffer.push(*args).
Transition = namedtuple("Transition", "state action reward next_state done")


class ReplayBuffer:
    """Fixed-capacity FIFO store of Transitions with uniform random sampling.

    Once ``capacity`` transitions are held, each new push evicts the oldest.
    """

    def __init__(self, capacity=100_000):
        # deque(maxlen=...) silently drops from the left when full.
        self.buffer = deque(maxlen=capacity)

    def push(self, *args):
        """Append one transition built from (state, action, reward, next_state, done)."""
        transition = Transition(*args)
        self.buffer.append(transition)

    def sample(self, batch_size):
        """Return ``batch_size`` distinct stored transitions chosen uniformly at random.

        Raises ValueError (from random.sample) if fewer are stored.
        """
        return random.sample(self.buffer, batch_size)

    def __len__(self):
        """Number of transitions currently stored."""
        return len(self.buffer)

### Training function

In [None]:
import numpy as np

def _obs_to_state(obs, device):
    """Copy an observation into a flat float32 tensor on ``device``."""
    # torch.tensor (not as_tensor) copies, so states kept in the replay buffer
    # never alias an array the environment might reuse.
    return torch.tensor(obs, dtype=torch.float32).flatten().to(device)


def _optimize_step(policy_net, target_net, optimizer, memory, batch_size, gamma, device):
    """Apply one Adam step on the one-step TD (MSE) loss over a sampled minibatch."""
    batch = Transition(*zip(*memory.sample(batch_size)))

    state_batch = torch.stack(batch.state)
    action_batch = torch.tensor(batch.action, dtype=torch.long, device=device).unsqueeze(1)
    reward_batch = torch.tensor(batch.reward, dtype=torch.float32, device=device)
    next_state_batch = torch.stack(batch.next_state)
    done_batch = torch.tensor(batch.done, dtype=torch.float32, device=device)

    # Q(s, a) for the actions actually taken; squeeze(1) keeps a (batch,)
    # shape even when batch_size == 1 (a bare squeeze() would yield a scalar).
    q_values = policy_net(state_batch).gather(1, action_batch).squeeze(1)
    with torch.no_grad():
        next_q = target_net(next_state_batch).max(1).values
    # Terminal transitions (done == 1) do not bootstrap from the next state.
    expected_q = reward_batch + (1.0 - done_batch) * gamma * next_q

    loss = F.mse_loss(q_values, expected_q)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()


def train(
    env,
    episodes=300,
    *,
    gamma=0.99,
    lr=1e-3,
    epsilon_start=1.0,
    epsilon_min=0.05,
    epsilon_decay=0.995,
    batch_size=64,
    target_update_freq=10,
):
    """Train a DQN agent on a discrete-action Gymnasium environment.

    Uses epsilon-greedy exploration, a uniform replay buffer and a target
    network that is hard-synced every ``target_update_freq`` episodes.
    A progress line is printed after each episode.

    Args:
        env: Gymnasium environment with a discrete ``action_space`` and an
            ``observation_space`` whose ``shape`` gives the observation size.
        episodes: Number of episodes to run.
        gamma: Discount factor.
        lr: Adam learning rate.
        epsilon_start: Initial exploration rate.
        epsilon_min: Lower bound for the exploration rate.
        epsilon_decay: Multiplicative decay applied to epsilon after each episode.
        batch_size: Minibatch size; updates start once this many transitions
            are stored.
        target_update_freq: Copy policy weights into the target network every
            this many episodes (including episode 0).

    Returns:
        List with the total (undiscounted) reward of each episode.
    """
    obs_shape = env.observation_space.shape
    n_actions = int(env.action_space.n)
    input_dim = int(np.prod(obs_shape))

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    policy_net = DQN(input_dim, n_actions).to(device)
    target_net = DQN(input_dim, n_actions).to(device)
    target_net.load_state_dict(policy_net.state_dict())

    optimizer = torch.optim.Adam(policy_net.parameters(), lr=lr)
    memory = ReplayBuffer()

    epsilon = epsilon_start
    all_rewards = []

    for ep in range(episodes):
        obs, _ = env.reset()
        state = _obs_to_state(obs, device)
        done = False
        total_reward = 0.0

        while not done:
            # Epsilon-greedy action selection.
            if random.random() < epsilon:
                action = int(env.action_space.sample())
            else:
                with torch.no_grad():
                    action = policy_net(state).argmax().item()

            next_obs, reward, terminated, truncated, _ = env.step(action)
            # The episode is over on termination (e.g. a crash) OR truncation
            # (time limit); ignoring `truncated` would keep stepping past the end.
            done = terminated or truncated
            next_state = _obs_to_state(next_obs, device)
            # Store only `terminated` as the done flag: a time-limit cut-off is
            # not a true terminal state, so its target must still bootstrap.
            memory.push(state, action, reward, next_state, bool(terminated))

            state = next_state
            total_reward += reward

            if len(memory) >= batch_size:
                _optimize_step(policy_net, target_net, optimizer, memory, batch_size, gamma, device)

        all_rewards.append(total_reward)

        epsilon = max(epsilon_min, epsilon * epsilon_decay)

        # Hard-sync the target network.
        if ep % target_update_freq == 0:
            target_net.load_state_dict(policy_net.state_dict())

        print(f"Ep {ep} | Reward: {total_reward:.2f} | Epsilon: {epsilon:.3f}")

    return all_rewards

### Plotting

In [None]:
import matplotlib.pyplot as plt

# Train the agent, then draw its learning curve (total reward per episode)
# on the current Axes.
rewards = train(env)

ax = plt.gca()
ax.plot(rewards)
ax.set_title("DQN - Reward per Episode")
ax.set_xlabel("Episode")
ax.set_ylabel("Total Reward")
ax.grid()
plt.show()

### Visualization

Still to be implemented: a video showing the simulation.