In [1]:
import gymnasium as gym
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import cv2
import random
import os
from collections import deque
import matplotlib.pyplot as plt
import gym_chrome_dino
from PIL.ImagePath import Path
from gym_chrome_dino.utils.wrappers import make_dino

Gym has been unmaintained since 2022 and does not support NumPy 2.0 amongst other critical functionality.
Please upgrade to Gymnasium, the maintained drop-in replacement of Gym, or contact the authors of your software and request that they upgrade.
Users of this version of Gym should be able to simply replace 'import gym' with 'import gymnasium as gym' in the vast majority of cases.
See the migration guide at https://gymnasium.farama.org/introduction/migration_guide/ for additional information.


# Preprocessing

In [2]:
def preprocess_observation(obs, size=(84, 84)):
    """
    Convert an RGB frame (e.g. (150, 600, 3)) → normalized grayscale image.

    Args:
        obs: Image array of shape (H, W, 3) in RGB order. (H, W, 4) RGBA
            arrays and already-grayscale (H, W) arrays are also accepted.
        size: Output size as (width, height) — the order cv2.resize expects —
            or a single int for a square output. Defaults to (84, 84), which
            is the input size the DQN conv stack's default is built for.

    Returns:
        float32 array of shape (height, width), divided by 255 so 8-bit input
        lands in [0, 1].
    """
    if isinstance(size, int):
        size = (size, size)

    # Collapse colour channels to a single luminance channel.
    if obs.ndim == 2:
        gray = obs
    elif obs.shape[-1] == 4:
        gray = cv2.cvtColor(obs, cv2.COLOR_RGBA2GRAY)
    else:
        gray = cv2.cvtColor(obs, cv2.COLOR_RGB2GRAY)

    # INTER_AREA is the usual choice when downsampling.
    resized = cv2.resize(gray, size, interpolation=cv2.INTER_AREA)

    # Scale 8-bit intensities to [0, 1].
    return resized.astype(np.float32) / 255.0

# Stack 4 frames
class FrameStack:
    """Rolling window over the last ``k`` preprocessed frames.

    ``reset`` fills the whole window with copies of the first frame; ``step``
    pushes one new frame and evicts the oldest. Both return the window stacked
    along a new leading axis (length ``k``).
    """

    def __init__(self, k=4):
        self.k = k
        self.frames = deque([], maxlen=k)

    def _stacked(self):
        # Oldest frame first, newest last.
        return np.stack(self.frames, axis=0)

    def reset(self, obs):
        first = preprocess_observation(obs)
        self.frames = deque(self.k * [first], maxlen=self.k)
        return self._stacked()

    def step(self, obs):
        self.frames.append(preprocess_observation(obs))
        return self._stacked()

# Custom CNN DQN Model (PyTorch)

In [3]:
class DQN(nn.Module):
    """Convolutional Q-network (three conv layers + two fully connected).

    Args:
        action_dim: Number of discrete actions (size of the output layer).
        in_channels: Number of stacked input frames/channels. Defaults to 4,
            matching ``FrameStack``'s default window.
        input_size: Side length of the square input image. Defaults to 84.

    Input: float tensor of shape (batch, in_channels, input_size, input_size).
    Output: Q-values of shape (batch, action_dim).
    """

    def __init__(self, action_dim, in_channels=4, input_size=84):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels, 32, kernel_size=8, stride=4),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=4, stride=2),
            nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size=3, stride=1),
            nn.ReLU(),
        )
        # Derive the flattened conv output size from the configured input
        # rather than hard-coding 64 * 7 * 7, which is only valid for 84x84.
        with torch.no_grad():
            n_flat = self.conv(torch.zeros(1, in_channels, input_size, input_size)).numel()
        self.fc = nn.Sequential(
            nn.Flatten(),
            nn.Linear(n_flat, 512),
            nn.ReLU(),
            nn.Linear(512, action_dim),
        )

    def forward(self, x):
        """Map a batch of stacked frames to per-action Q-values."""
        return self.fc(self.conv(x))

## DQN Agent

In [4]:
class DQNAgent:
    """Epsilon-greedy (Double) DQN agent with an online and a target network.

    Args:
        action_dim: Number of discrete actions.
        lr: Adam learning rate for the online network.
        gamma: Discount factor used in the TD target.
        epsilon: Initial exploration probability used by ``act``.
        eps_min: Floor that epsilon is decayed towards (never below it).
        eps_decay: Multiplicative decay applied after EVERY ``train_step``
            call, i.e. per gradient update, not per episode. With the defaults
            (1.0 -> 0.01 at 0.995) epsilon bottoms out after ~920 updates.
    """

    def __init__(self, action_dim, lr=1e-4, gamma=0.99, epsilon=1.0, eps_min=0.01, eps_decay=0.995):
        self.action_dim = action_dim
        self.gamma = gamma
        self.epsilon = epsilon
        self.eps_min = eps_min
        self.eps_decay = eps_decay
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        self.q_net = DQN(action_dim).to(self.device)
        self.target_net = DQN(action_dim).to(self.device)
        self.optimizer = optim.Adam(self.q_net.parameters(), lr=lr)
        self.update_target()

    def update_target(self):
        """Copy the online network's weights into the target network."""
        self.target_net.load_state_dict(self.q_net.state_dict())

    def act(self, state):
        """Choose an action for a single (unbatched) state, epsilon-greedily."""
        if np.random.random() <= self.epsilon:
            return random.randrange(self.action_dim)
        state_t = self._to_tensor(state, torch.float32).unsqueeze(0)
        # Pure inference: don't record an autograd graph.
        with torch.no_grad():
            q_vals = self.q_net(state_t)
        return q_vals.argmax().item()

    def _to_tensor(self, values, dtype):
        # Stack into one ndarray first: building a tensor from a tuple of
        # ndarrays element-by-element is very slow in PyTorch.
        return torch.as_tensor(np.asarray(values), dtype=dtype, device=self.device)

    def train_step(self, batch, double_dqn=True):
        """Run one gradient step on a replay minibatch and decay epsilon.

        Args:
            batch: Tuple ``(states, actions, rewards, next_states, dones)``,
                each a sequence with one entry per sampled transition.
                ``dones`` marks transitions whose successor must not be
                bootstrapped from.
            double_dqn: If True, pick the next action with the online net and
                evaluate it with the target net (Double DQN); otherwise use
                the target net's max.

        Returns:
            The scalar MSE loss as a Python float.
        """
        states, actions, rewards, next_states, dones = batch
        states = self._to_tensor(states, torch.float32)
        actions = self._to_tensor(actions, torch.int64)
        rewards = self._to_tensor(rewards, torch.float32)
        next_states = self._to_tensor(next_states, torch.float32)
        dones = self._to_tensor(dones, torch.bool)

        curr_q = self.q_net(states).gather(1, actions.unsqueeze(1)).squeeze(1)

        # The TD target is a constant w.r.t. the online parameters; computing
        # it without grad avoids building (and then discarding) extra graphs.
        with torch.no_grad():
            if double_dqn:
                next_actions = self.q_net(next_states).argmax(1)
                next_q = self.target_net(next_states).gather(1, next_actions.unsqueeze(1)).squeeze(1)
            else:
                next_q = self.target_net(next_states).max(1)[0]
            target_q = rewards + self.gamma * next_q * ~dones

        loss = nn.MSELoss()(curr_q, target_q)

        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        # Decay, but clamp so epsilon never undershoots eps_min. An epsilon
        # that already starts at/below eps_min (e.g. 0 for evaluation) is
        # left untouched.
        if self.epsilon > self.eps_min:
            self.epsilon = max(self.eps_min, self.epsilon * self.eps_decay)

        return loss.item()

# Training loop

In [5]:
def _run_training_loop(env, total_steps, batch_size, replay_size, update_freq, log_every):
    """Run the epsilon-greedy DQN loop on ``env``; return (agent, episode_scores, losses)."""
    agent = DQNAgent(env.action_space.n)
    replay_buffer = deque(maxlen=replay_size)
    frame_stack = FrameStack()

    scores = []
    losses = []
    score = 0

    obs, _ = env.reset()
    state = frame_stack.reset(obs)

    for step in range(total_steps):
        action = agent.act(state)
        next_obs, reward, terminated, truncated, _info = env.step(action)
        done = terminated or truncated

        next_state = frame_stack.step(next_obs)
        # Store `terminated`, not `done`: a time-limit truncation is not a
        # real terminal state, so the target must still bootstrap from it.
        replay_buffer.append((state, action, reward, next_state, terminated))
        state = next_state
        score += reward

        if done:
            scores.append(score)
            score = 0
            obs, _ = env.reset()
            state = frame_stack.reset(obs)

        # Train once the buffer can fill a minibatch.
        if len(replay_buffer) >= batch_size:
            batch = random.sample(replay_buffer, batch_size)
            states, actions, rewards, next_states, dones = zip(*batch)
            losses.append(agent.train_step((states, actions, rewards, next_states, dones)))

        # Periodically sync the target network.
        if step % update_freq == 0:
            agent.update_target()

        if step % log_every == 0:
            avg_score = np.mean(scores[-10:]) if scores else 0
            avg_loss = np.mean(losses[-100:]) if losses else 0
            print(f"Step {step}/{total_steps} | Avg Score (last 10): {avg_score:.2f} | ε: {agent.epsilon:.3f} | Loss: {avg_loss:.4f}")

    return agent, scores, losses


def _plot_training_metrics(scores, losses, path="dino_dqn_training.png"):
    """Plot episode scores and per-update losses side by side, save to ``path`` and show."""
    fig, (ax_score, ax_loss) = plt.subplots(1, 2, figsize=(12, 5))
    ax_score.plot(scores)
    ax_score.set(title="Episode Scores", xlabel="Episode", ylabel="Score")
    ax_loss.plot(losses)
    ax_loss.set(title="Training Loss", xlabel="Training Step", ylabel="Loss")
    fig.tight_layout()
    fig.savefig(path)
    plt.show()


def train_dino_dqn(
    env_name="ChromeDino-v0",
    total_steps=50_000,
    batch_size=32,
    replay_size=10_000,
    update_freq=1000,
    save_path="dino_dqn.pth",
    render=False,
    env=None,
    log_every=1000,
):
    """Train a DQNAgent on a Chrome Dino environment, save it and plot metrics.

    Args:
        env_name: Id passed to ``gym.make`` when ``env`` is None. It must be
            registered with *gymnasium* (imported here as ``gym``); otherwise
            ``gym.make`` raises ``NameNotFound``.
        total_steps: Number of environment steps to run.
        batch_size: Replay minibatch size; training starts once the buffer
            holds at least this many transitions.
        replay_size: Maximum number of transitions kept in the replay buffer.
        update_freq: Copy online -> target network every this many steps.
        save_path: File the online network's ``state_dict`` is saved to.
        render: If True, ``gym.make`` is called with ``render_mode="human"``.
            Ignored when ``env`` is supplied.
        env: Optional pre-built environment (e.g. a custom ``gym.Env``). When
            given it is used instead of ``gym.make`` and is NOT closed here —
            the caller owns it.
        log_every: Print a progress line every this many steps.

    Returns:
        The trained DQNAgent.
    """
    owns_env = env is None
    if owns_env:
        env = gym.make(env_name, render_mode="human" if render else None)

    # Close the env (and any browser behind it) even if training raises or is
    # interrupted, e.g. by a KeyboardInterrupt in the notebook.
    try:
        agent, scores, losses = _run_training_loop(
            env, total_steps, batch_size, replay_size, update_freq, log_every
        )
    finally:
        if owns_env:
            env.close()

    torch.save(agent.q_net.state_dict(), save_path)
    print(f"Model saved to {save_path}")

    _plot_training_metrics(scores, losses)
    return agent

In [6]:
# train_dino_dqn(total_steps=10_000, render=False)

In [7]:
# Quick smoke run with a live view.
# NOTE(review): gymnasium's registry does not know this id (see the
# NameNotFound output below) — presumably gym_chrome_dino registers its envs
# with the legacy `gym` package only. The self-contained ChromeDinoEnv defined
# further down sidesteps the registry.
train_dino_dqn(
    env_name="ChromeDinoNoBrowser-v0",
    total_steps=1_000,   # Just 1k steps for a quick smoke test
    render=True,         # Enables the live view (render_mode="human")
    batch_size=16,       # Smaller batch for faster updates
    update_freq=500,     # Sync the target net more often than the default 1000
)

NameNotFound: Environment `ChromeDinoNoBrowser` doesn't exist.

# Set up the environment from scratch

In [None]:
import base64
import io
import numpy as np
from PIL import Image
import gymnasium as gym
from gymnasium import spaces
from selenium import webdriver
from selenium.webdriver.chrome.options import Options
from selenium.webdriver.common.keys import Keys
from selenium.webdriver.chrome.service import Service
from webdriver_manager.chrome import ChromeDriverManager
import cv2
import time
from pathlib import Path

def rgba2rgb(im, background=(255, 255, 255)):
    """Flatten a PIL image with transparency onto a solid background.

    Args:
        im: PIL image. Non-RGBA modes (e.g. plain RGB, palette, LA) are first
            converted to RGBA, so they no longer fail with an IndexError on
            the missing alpha band.
        background: RGB fill colour shown through transparent pixels
            (white by default).

    Returns:
        A new RGB PIL image of the same size.
    """
    if im.mode != "RGBA":
        im = im.convert("RGBA")
    bg = Image.new("RGB", im.size, background)
    bg.paste(im, mask=im.split()[3])  # band 3 is the alpha channel
    return bg

class Timer:
    """Measure elapsed seconds between successive ``tick`` calls.

    Uses ``time.perf_counter`` (monotonic, high resolution) instead of
    ``time.time``: the wall clock can be adjusted, which would produce
    negative or bogus intervals.
    """

    def __init__(self):
        self.t0 = time.perf_counter()

    def tick(self):
        """Return seconds since construction or the previous tick, then restart the interval."""
        t1 = time.perf_counter()
        dt = t1 - self.t0
        self.t0 = t1
        return dt

class DinoGame:
    """Selenium wrapper around the browser T-Rex runner game.

    Starts Chrome via ``webdriver.Chrome`` and talks to the page's JavaScript
    ``Runner`` object through ``execute_script``.

    Args:
        render: If False, Chrome is started with ``--headless=new``.
        accelerate: If False, ``Runner.config.ACCELERATION`` is set to 0.
        autoscale: If False, ``Runner.instance_.setArcadeModeContainerScale``
            is replaced with a no-op function.
    """

    def __init__(self, render=False, accelerate=False, autoscale=False):
        options = Options()
        options.add_argument('--disable-infobars')
        options.add_argument('--mute-audio')
        options.add_argument('--no-sandbox')
        options.add_argument('--window-size=800,600')
        if not render:
            options.add_argument('--headless=new')

        self.driver = webdriver.Chrome(options=options)

        # If page setup fails (network error, Runner not present, ...), quit
        # the browser before re-raising so no Chrome process is leaked.
        try:
            self.driver.get('https://elvisyjlin.github.io/t-rex-runner/')
            # self.driver.get('chrome://dino')
            # self.driver.get(Path("./files/t-rex-runner-gh-pages/index.html").resolve().as_uri())
            self.defaults = self.get_parameters()  # default parameters
            if not accelerate:
                self.set_parameter('config.ACCELERATION', 0)
            if not autoscale:
                self.driver.execute_script('Runner.instance_.setArcadeModeContainerScale = function(){};')
            self.press_space()  # start the game
        except BaseException:
            self.driver.quit()
            raise

    def get_parameters(self):
        """Snapshot the page's tunable Runner parameters (used by restore_parameter)."""
        params = {}
        params['config.ACCELERATION'] = self.driver.execute_script('return Runner.config.ACCELERATION;')
        return params

    def is_crashed(self):
        return self.driver.execute_script('return Runner.instance_.crashed;')

    def is_inverted(self):
        return self.driver.execute_script('return Runner.instance_.inverted;')

    def is_paused(self):
        return self.driver.execute_script('return Runner.instance_.paused;')

    def is_playing(self):
        return self.driver.execute_script('return Runner.instance_.playing;')

    def press_space(self):
        return self.driver.find_element('tag name', 'body').send_keys(Keys.SPACE)

    def press_up(self):
        return self.driver.find_element('tag name', 'body').send_keys(Keys.UP)

    def press_down(self):
        return self.driver.find_element('tag name', 'body').send_keys(Keys.DOWN)

    def pause(self):
        return self.driver.execute_script('Runner.instance_.stop();')

    def resume(self):
        return self.driver.execute_script('Runner.instance_.play();')

    def restart(self):
        return self.driver.execute_script('Runner.instance_.restart();')

    def close(self):
        self.driver.quit()

    def get_score(self):
        """Return the distance-meter score as an int (0 if no digits are shown yet)."""
        digits = self.driver.execute_script('return Runner.instance_.distanceMeter.digits;')
        # ''.join([]) == '' would make int() raise; treat "no digits" as 0.
        return int(''.join(digits) or 0)

    def get_canvas(self):
        """Return the game canvas as base64 PNG data (the 22-char data-URL prefix is stripped)."""
        return self.driver.execute_script('return document.getElementsByClassName("runner-canvas")[0].toDataURL().substring(22);')

    @staticmethod
    def _to_js(value):
        # str() of Python True/False/None gives identifiers that are undefined
        # in JavaScript; translate those. Everything else (numbers, and
        # strings used as raw JS expressions) is formatted exactly as before.
        if value is None:
            return 'null'
        if isinstance(value, (bool, np.bool_)):
            return 'true' if value else 'false'
        return format(value)

    def set_parameter(self, key, value):
        """Execute ``Runner.<key> = <value>;`` in the page."""
        self.driver.execute_script('Runner.{} = {};'.format(key, self._to_js(value)))

    def restore_parameter(self, key):
        """Reset ``key`` to the value captured by get_parameters at startup."""
        self.set_parameter(key, self.defaults[key])

class WarpFrame(gym.ObservationWrapper):
    """Observation wrapper: RGB frame -> resized single-channel grayscale.

    The wrapped observation space is declared as uint8 with shape
    (height, width, 1).
    """

    def __init__(self, env, width, height):
        super().__init__(env)
        self.width, self.height = width, height
        self.observation_space = spaces.Box(
            low=0,
            high=255,
            shape=(height, width, 1),
            dtype=np.uint8,
        )

    def observation(self, frame):
        grey = cv2.cvtColor(frame, cv2.COLOR_RGB2GRAY)
        small = cv2.resize(grey, (self.width, self.height), interpolation=cv2.INTER_AREA)
        # Re-add a trailing channel axis: (H, W) -> (H, W, 1).
        return np.expand_dims(small, -1)

class TimerEnv(gym.Wrapper):
    """Wrapper that reports the elapsed time between consecutive env calls.

    ``reset`` restarts the timer; every ``step`` stores the seconds since the
    previous reset/step in ``info['timedelta']``.
    """

    def __init__(self, env):
        super().__init__(env)
        self.timer = Timer()

    def reset(self, seed=None, options=None):
        observation, reset_info = self.env.reset(seed=seed, options=options)
        self.timer.tick()  # first step is timed from here
        return observation, reset_info

    def step(self, action):
        observation, reward, terminated, truncated, step_info = self.env.step(action)
        step_info['timedelta'] = self.timer.tick()
        return observation, reward, terminated, truncated, step_info

class ChromeDinoEnv(gym.Env):
    """Gymnasium environment for the browser T-Rex runner, driven by DinoGame.

    Observations: RGB numpy arrays decoded from the game canvas; the space is
    declared as (150, 600, 3) uint8.
    Actions (Discrete(4)): 0 NOOP, 1 UP, 2 DOWN, 3 SPACE.
    Reward: ``gametime_reward`` (0.1) per step, or ``gameover_penalty`` (-1)
    on the step where the game reports a crash, which also sets
    ``terminated``. ``truncated`` is always False.
    """

    metadata = {'render_modes': ['rgb_array'], 'render_fps': 10}

    def __init__(self, render=False, accelerate=False, autoscale=False):
        self.game = DinoGame(render, accelerate, autoscale)
        frame_shape = (150, 600, 3)
        self.observation_space = spaces.Box(low=0, high=255, shape=frame_shape, dtype=np.uint8)
        self.action_space = spaces.Discrete(4)  # NOOP, UP, DOWN, SPACE
        self.gametime_reward = 0.1
        self.gameover_penalty = -1
        self.current_frame = np.zeros(frame_shape, dtype=np.uint8)
        self._action_set = list(range(4))

    def _observe(self):
        """Grab the canvas, decode the base64 PNG, flatten alpha and cache the frame."""
        png_bytes = base64.b64decode(self.game.get_canvas())
        rgb_image = rgba2rgb(Image.open(io.BytesIO(png_bytes)))
        self.current_frame = np.array(rgb_image)
        return self.current_frame

    def step(self, action):
        # 0 (NOOP) — and any id not listed — presses nothing.
        key_presses = (
            (1, self.game.press_up),
            (2, self.game.press_down),
            (3, self.game.press_space),
        )
        for code, press in key_presses:
            if action == code:
                press()
                break

        observation = self._observe()
        crashed = self.game.is_crashed()
        reward = self.gameover_penalty if crashed else self.gametime_reward
        return observation, reward, bool(crashed), False, {}

    def reset(self, seed=None, options=None):
        super().reset(seed=seed)
        self.game.restart()
        first_frame = self._observe()
        return first_frame, {}

    def render(self):
        return self.current_frame

    def close(self):
        self.game.close()

    def get_score(self):
        return self.game.get_score()

    def set_acceleration(self, enable):
        if not enable:
            self.game.set_parameter('config.ACCELERATION', 0)
        else:
            self.game.restore_parameter('config.ACCELERATION')

    def get_action_meanings(self):
        return [ACTION_MEANING[action_id] for action_id in self._action_set]

# Human-readable names for ChromeDinoEnv's discrete action ids (0..3).
ACTION_MEANING = dict(enumerate(("NOOP", "UP", "DOWN", "SPACE")))

In [19]:
# Smoke test: drive the environment with uniformly random actions.
N_RANDOM_STEPS = 1000

env = ChromeDinoEnv(render=True, accelerate=False, autoscale=False)
try:
    obs, info = env.reset()
    for _ in range(N_RANDOM_STEPS):
        action = env.action_space.sample()  # Random action
        obs, reward, terminated, truncated, info = env.step(action)
        if terminated or truncated:
            obs, info = env.reset()
finally:
    # Always quit Chrome, even when the loop is interrupted (KeyboardInterrupt).
    env.close()

KeyboardInterrupt: 