<a href="https://colab.research.google.com/github/eschwarzbeckf/taxi/blob/main/taxi.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# 0 Setup

In [1]:
# Fetch the project repository, switch into it, and install its dependencies.
# NOTE(review): this cell is not idempotent — running it a second time in the
# same session would try to clone into an already existing `taxi` directory.
!git clone https://github.com/eschwarzbeckf/taxi.git
# Make the cloned repository the notebook's working directory so that
# relative paths (requirements.txt, video output) resolve inside it.
%cd taxi
# Install the packages listed in the repository's requirements.txt.
!pip install -r requirements.txt

Cloning into 'taxi'...
remote: Enumerating objects: 15, done.[K
remote: Counting objects: 100% (15/15), done.[K
remote: Compressing objects: 100% (12/12), done.[K
remote: Total 15 (delta 2), reused 3 (delta 0), pack-reused 0 (from 0)[K
Receiving objects: 100% (15/15), 4.11 KiB | 4.11 MiB/s, done.
Resolving deltas: 100% (2/2), done.
/content/taxi
Collecting pyvirtualdisplay (from -r requirements.txt (line 2))
  Downloading PyVirtualDisplay-3.0-py3-none-any.whl.metadata (943 bytes)
Downloading PyVirtualDisplay-3.0-py3-none-any.whl (15 kB)
Installing collected packages: pyvirtualdisplay
Successfully installed pyvirtualdisplay-3.0


# 1 Imports

In [2]:
import gymnasium as gym
import torch.nn as nn
import torch
from collections import deque
import random
import math
import time
import io, glob, base64
from IPython.display import HTML
from IPython import display as ipythondisplay
from pyvirtualdisplay import Display

# 2 Code

## 2.1 Env Setup

## 2.2 Functions

In [3]:
def show_video(name):
  """Display ``./<name>.mp4`` inline in the notebook as a looping HTML5 video.

  The file is looked up with ``glob`` relative to the current working
  directory; if several paths match, the first one is used. The video bytes
  are base64-encoded and embedded in the page as a ``data:`` URI.

  Args:
    name: Base name of the video file, without the ``.mp4`` extension.

  Returns:
    None. Prints "Could not find video" when no matching file exists.
  """
  mp4list = glob.glob(f'./{name}.mp4')
  if not mp4list:
    print("Could not find video")
    return

  # Open read-only inside a context manager: the handle is closed
  # deterministically and no write permission on the file is required
  # (the previous 'r+b' mode asked for write access and leaked the handle).
  with open(mp4list[0], 'rb') as video_file:
    encoded = base64.b64encode(video_file.read())

  ipythondisplay.display(HTML(data='''<video alt="test" autoplay
                loop controls style="height: 400px;">
                <source src="data:video/mp4;base64,{0}" type="video/mp4" />
             </video>'''.format(encoded.decode('ascii'))))

def wrap_env(env, name):
  """Create a video recorder that writes ``./<name>.mp4`` for ``env``.

  Args:
    env: Environment to record; it is returned unchanged.
    name: Base name of the output file, without the ``.mp4`` extension.

  Returns:
    A ``(env, video)`` tuple where ``video`` is the ``VideoRecorder`` instance.

  Raises:
    ImportError: If the installed gymnasium does not provide ``VideoRecorder``.
  """
  # VideoRecorder is not imported by the notebook's import cell, so calling
  # this function used to fail with NameError. Import it locally, and fail
  # with an actionable message on gymnasium releases that removed the class.
  try:
    from gymnasium.wrappers.monitoring.video_recorder import VideoRecorder
  except ImportError as exc:
    raise ImportError(
        "VideoRecorder is not available in the installed gymnasium version; "
        "use gymnasium.wrappers.RecordVideo instead."
    ) from exc
  video = VideoRecorder(env, f'./{name}.mp4')
  return env, video

In [4]:
def select_action(q_values, start, end, decay, step):
  """Choose an action with an exponentially decaying epsilon-greedy policy.

  The exploration rate is
  ``epsilon = end + (start - end) * exp(-step / decay)``.

  Args:
    q_values: Tensor of action values for a single state, shaped either
      ``(n_actions,)`` or ``(1, n_actions)``; the last dimension indexes
      actions.
    start: Exploration rate at ``step == 0``.
    end: Exploration rate approached as ``step`` grows.
    decay: Decay time constant, in steps.
    step: Number of steps taken so far.

  Returns:
    The chosen action index as a Python int.
  """
  epsilon = end + (start - end) * math.exp(-step / decay)
  if random.random() < epsilon:
    # Use the last dimension as the action count. len(q_values) is the batch
    # size for a (1, n_actions) tensor, which made exploration always pick
    # action 0.
    return random.randrange(q_values.shape[-1])

  # argmax over the flattened tensor equals the action index for one state.
  return torch.argmax(q_values).item()

In [5]:
class QNetwork(nn.Module):
  """Maps discrete state indices to one Q-value per action.

  Each state index is looked up in a 64-dimensional embedding table, passed
  through a 64-unit ReLU layer, and projected to ``action_size`` outputs.
  """

  def __init__(self, state_size, action_size):
    """Build the layers.

    Args:
      state_size: Number of distinct state indices (embedding rows).
      action_size: Number of actions (width of the output layer).
    """
    super().__init__()
    width = 64
    # Layers are created in the order embedding -> fc1 -> fc2.
    self.embedding = nn.Embedding(state_size, width)
    self.fc1 = nn.Linear(width, width)
    self.fc2 = nn.Linear(width, action_size)

  def forward(self, state):
    """Return Q-values for an integer tensor of state indices."""
    embedded = self.embedding(state)
    hidden = self.fc1(embedded).relu()
    q_values = self.fc2(hidden)
    return q_values

In [6]:
class ReplayBuffer:
  """Fixed-capacity FIFO store of transitions with uniform random sampling."""

  def __init__(self, capacity):
    # A bounded deque drops its oldest entry once `capacity` is reached.
    self.memory = deque(maxlen=capacity)

  def __len__(self):
    """Number of transitions currently stored."""
    return len(self.memory)

  def push(self, state, action, reward, next_state, done):
    """Append one ``(state, action, reward, next_state, done)`` transition."""
    self.memory.append((state, action, reward, next_state, done))

  def sample(self, batch_size):
    """Draw ``batch_size`` distinct transitions and stack them into tensors.

    Returns:
      ``(states, actions, rewards, next_states, dones)``. Actions are a long
      tensor of shape ``(batch_size, 1)``; every other tensor is float32 with
      shape ``(batch_size,)`` when the stored values are scalars.
    """
    transitions = random.sample(self.memory, batch_size)
    state_col, action_col, reward_col, next_col, done_col = zip(*transitions)

    def as_float(values):
      return torch.tensor(values, dtype=torch.float32)

    return (
        as_float(state_col),
        torch.tensor(action_col, dtype=torch.long).unsqueeze(1),
        as_float(reward_col),
        as_float(next_col),
        as_float(done_col),
    )

In [7]:
def update_target_network(target_network, online_network, tau):
  """Soft-update the target network toward the online network.

  Every entry of the target's state dict is replaced by
  ``tau * online + (1 - tau) * target`` and the result is loaded back into
  ``target_network`` in place.

  Args:
    target_network: Module whose weights are updated.
    online_network: Module providing the new weights; must expose the same
      state-dict keys as ``target_network``.
    tau: Interpolation factor; ``1`` copies the online weights, ``0`` leaves
      the target unchanged.

  Returns:
    None.
  """
  target_net_state_dict = target_network.state_dict()
  online_net_state_dict = online_network.state_dict()
  for key in target_net_state_dict:
    target_net_state_dict[key] = (
        online_net_state_dict[key] * tau + target_net_state_dict[key] * (1 - tau)
    )
  # Load once, after every entry has been blended. Reloading inside the loop
  # copied the whole state dict once per parameter without changing the result.
  target_network.load_state_dict(target_net_state_dict)

In [8]:
# Create the Taxi-v3 environment; 'rgb_array' makes render() return image
# frames instead of opening a window.
env = gym.make('Taxi-v3', render_mode='rgb_array')
# Wrap the environment so recorded episodes are written to ./video_directory.
# NOTE(review): no episode trigger is passed, so which episodes get recorded
# is decided by RecordVideo's defaults — confirm that is intended.
env = gym.wrappers.RecordVideo(env, video_folder="./video_directory")
# Both spaces expose `.n`; these sizes are later passed to QNetwork.
state_size = env.observation_space.n
action_size = env.action_space.n
# Start an invisible 1400x900 virtual display. The variable is named
# `display`; IPython's display module is imported as `ipythondisplay`, so the
# two names do not collide.
display = Display(visible=0, size=(1400, 900))
display.start()

<pyvirtualdisplay.display.Display at 0x784b9dcec690>

In [10]:
# --- Agent setup -----------------------------------------------------------
online_network = QNetwork(state_size, action_size)
target_network = QNetwork(state_size, action_size)
# The target network starts as an exact copy of the online network.
target_network.load_state_dict(online_network.state_dict())
replay_buffer = ReplayBuffer(capacity=10000)
optimizer = torch.optim.Adam(online_network.parameters(), lr=0.001)
batch_size = 64
gamma = 0.95      # discount factor applied to the bootstrapped next-state value
tau = 0.005       # soft-update rate for the target network
total_steps = 0   # environment steps across all episodes; drives epsilon decay

# --- Training --------------------------------------------------------------
for episode in range(20000):
  state, info = env.reset()
  done = False
  step = 0            # steps taken in the current episode
  episode_reward = 0  # undiscounted return of the current episode
  while not done:
    step += 1
    total_steps += 1
    state_tensor = torch.tensor([state], dtype=torch.long)
    # Acting needs no gradients. Dropping the batch dimension hands
    # select_action a (n_actions,) tensor, so its random branch samples from
    # every action rather than from the batch size of 1.
    with torch.no_grad():
      q_values = online_network(state_tensor).squeeze(0)
    # Decay epsilon with the global step count. The per-episode counter resets
    # to 0 on every reset, which kept epsilon near its starting value forever.
    action = select_action(q_values, 0.9, 0.05, 1000, total_steps)
    next_state, reward, terminated, truncated, _ = env.step(action)
    done = terminated or truncated
    # Store only true termination. A time-limit truncation is not a terminal
    # state, so its transition must still bootstrap from next_state.
    replay_buffer.push(state, action, reward, next_state, terminated)
    if len(replay_buffer) >= batch_size:
      states, actions, rewards, next_states, dones = replay_buffer.sample(batch_size)
      # The buffer returns float tensors; the embedding layer needs indices.
      states = states.long()
      actions = actions.long()
      next_states = next_states.long()

      # Q(s, a) for the actions that were actually taken.
      q_values = online_network(states).gather(1, actions).squeeze(1)
      with torch.no_grad():
        # Greedy value of the next state according to the target network.
        next_q_values = target_network(next_states).amax(1)
        # One-step TD target; the bootstrap term is zeroed on terminal states.
        target_q_values = rewards + gamma * next_q_values * (1 - dones)
      loss = torch.nn.functional.mse_loss(q_values, target_q_values)
      optimizer.zero_grad()
      loss.backward()
      optimizer.step()
      update_target_network(target_network, online_network, tau)
    state = next_state
    episode_reward += reward

env.close()