### Imports

In [2]:
import numpy as np
import torch
import pickle
import gymnasium as gym
import highway_env

### Setting up environement

In [19]:
# Making highway environment
env = gym.make("highway-fast-v0", render_mode="human")

# Importing config
with open('1-highway-discrete-config.pkl', 'rb') as pf:
    config_dict = pickle.load(pf)
env.unwrapped.configure(config_dict)

### Creating DQN Network

In [31]:
import torch.nn as nn

class DQNNet(nn.Sequential):
    def __init__(self, state_dim, d_hid, n_actions):
        super(DQNNet,self).__init__()

        # takes as input the dimension of our state space
        self.fc1 = nn.Linear(state_dim,d_hid)
        self.relu1 = nn.ReLU()
        self.fc2 = nn.Linear(d_hid, d_hid)
        self.relu2 = nn.ReLU()
        # Output the expected cumulative return for each action
        self.fc3 = nn.Linear(d_hid, n_actions)

### Creating Replay Buffer

In [48]:
from collections import deque
import random

class ReplayBuffer:
    def __init__(self, maxsize):
        self.buffer = deque([], maxlen=maxsize)
        self.max_size = maxsize

    def push(self, s, a, r, sn, done):
        self.buffer.append((s,a,r,sn, done))

    def sample(self, batch_size):
        return random.sample(self.buffer, min(len(self.buffer), batch_size))

### Creating DQN algorithm

In [53]:
from tqdm import tqdm
import random
class DQN:
    def __init__(self, env, buffer_size):

        self.env = env
        obs, info = self.env.reset()
        self.curr_obs = torch.tensor(obs.flatten())
        self.action_space = self.env.action_space
        self.q_net = DQNNet(len(self.curr_obs.flatten()), 128, self.action_space.n)
        self.target_net = DQNNet(len(self.curr_obs.flatten()), 128, self.action_space.n)
        self.target_net.load_state_dict(self.q_net.state_dict())
        self.replay_buffer = ReplayBuffer(maxsize = 1000)
        self.gamma = 0.95
        self.batchsize = 32
        self.criterion = nn.MSELoss()
        self.optimizer = torch.optim.Adam(self.q_net.parameters())
        self.epsilon = 0.5

    def sample_action(self, epsilon):

        if random.random() < epsilon:
            return self.action_space.sample()
        else:
            return torch.argmax(self.q_net(self.curr_obs))
    
    def update_buffer(self,s,a,r,sn, done):
        s_torch = torch.tensor(s.flatten())
        a_torch = torch.tensor(a)
        r_torch = torch.tensor(r)
        sn_torch = torch.tensor(sn.flatten())
        done_torch = torch.tensor(done,dtype=torch.float32)
        self.replay_buffer.push(s_torch,a_torch,r_torch,sn_torch,done_torch)
        
    
    def compute_target(self, done, reward, obs):
        
        return reward + (1-done) * self.gamma * torch.max(self.target_net(obs))

    def sample_minibatch(self, batchsize):      
        batch = self.replay_buffer.sample(batchsize)
        batch_curr_obs, batch_a, batch_r, batch_obs, batch_done = map(list, zip(*batch))
        
        return torch.tensor(batch_curr_obs).reshape(batchsize,-1), torch.tensor(batch_a).reshape(batchsize,-1),torch.tensor(batch_r).reshape(batchsize,-1),torch.tensor(batch_obs).reshape(batchsize,-1),torch.tensor(batch_done).reshape(batchsize,-1)
    
    def step(self):
        
        # Sample an action
        a = self.sample_action(self.epsilon)
        
        # Perform one transition
        obs, reward, done, truncated, _ = self.env.step(a)
        
        # Update the replay buffer
        self.update_buffer(self.curr_obs, a, reward, obs, done)
        
        # Sample a minibatch of transitions from the replay buffer
        self.curr_obs = torch.tensor(obs.flatten())
        batch_curr_obs, batch_a, batch_r, batch_obs, batch_done = self.sample_minibatch(self.batchsize)
        
        # Compute the target
        with torch.no_grad():
            ys = self.compute_target(batch_done, batch_r, batch_obs)
        
        # Compute the predictions on the return
        preds = self.q_net(batch_curr_obs).gather(1, batch_a)

        loss = self.criterion(preds, ys)

        self.optimizer.zero_grad()

        loss.backward()

        self.optimizer.step()

        return done

    def train(self):

        done = 0
        for episode in tqdm(range(1, 500)):
            
            self.curr_obs, _ = self.env.reset()
            self.curr_obs = torch.tensor(self.curr_obs.flatten())
            while not done:
                done = self.step()

            if episode % 10 == 0:
                self.curr_obs = env.reset()

                done = 0
                cum_ret = 0
                it = 0
                with torch.no_grad():
                    self.q_net.eval()
                    
                    while not done:
                
                        a = self.sample_action(epsilon=0)
                        self.curr_obs, reward, done, truncated, _ = self.env.step(a)
                        cum_ret += self.gamma**it * reward
                        it += 1

                    self.q_net.train()
                print(f'Reward on a test environnement at episode {episode}: {cum_ret}')

                self.target_net.load_state_dict(self.q_net.state_dict())
   


In [54]:
dqn = DQN(env=env, buffer_size=500)

In [55]:
dqn.train()

  s_torch = torch.tensor(s.flatten())
  0%|          | 0/499 [00:00<?, ?it/s]


ValueError: only one element tensors can be converted to Python scalars