In [1]:
# Thanks to @MaxRohowsky
# copied->adapted from https://github.com/MaxRohowsky/chrome-dinosaur
import gymnasium as gym
import pygame
import os
import random
import numpy as np
from gymnasium import spaces

# Start pygame and its font module (the font draws the score overlay).
pygame.init()
pygame.font.init()

# === Copy-paste asset loading and constants from main.py ===
# Window size in pixels.
SCREEN_WIDTH = 1100
SCREEN_HEIGHT = 600


# Image loader for the game sprites. Relative asset paths are resolved against
# the current working directory (a notebook has no __file__), so either start
# Jupyter next to the `chrome-dinosaur` checkout or pass `assets_dir`.
def load_image(*path, assets_dir="chrome-dinosaur/Assets"):
    """Load one sprite from the game's asset folder.

    Args:
        *path: Path components below ``assets_dir``,
            e.g. ``("Dino", "DinoRun1.png")``.
        assets_dir: Root folder of the assets (keyword-only). A relative value
            is resolved against the current working directory.

    Returns:
        Whatever ``pygame.image.load`` returns for the joined path.
    """
    return pygame.image.load(os.path.join(assets_dir, *path))


# Dinosaur animation frames.
RUNNING = [load_image("Dino", name) for name in ("DinoRun1.png", "DinoRun2.png")]
JUMPING = load_image("Dino", "DinoJump.png")
DUCKING = [load_image("Dino", name) for name in ("DinoDuck1.png", "DinoDuck2.png")]

# Obstacle sprites (three variants per cactus size, two bird frames).
SMALL_CACTUS = [
    load_image("Cactus", name)
    for name in ("SmallCactus1.png", "SmallCactus2.png", "SmallCactus3.png")
]
LARGE_CACTUS = [
    load_image("Cactus", name)
    for name in ("LargeCactus1.png", "LargeCactus2.png", "LargeCactus3.png")
]
BIRD = [load_image("Bird", name) for name in ("Bird1.png", "Bird2.png")]

# Scenery.
CLOUD = load_image("Other", "Cloud.png")
BG = load_image("Other", "Track.png")


# === Game classes adapted from main.py, without global state ===

class Dinosaur:
    """Player sprite: run/duck/jump state, animation frame and hitbox.

    ``dino_rect`` is both the blit position and the collision box, so it is
    rebuilt from whichever image is currently shown.
    """

    X_POS = 80        # fixed horizontal position (px)
    Y_POS = 310       # top edge of the standing / running sprite (px)
    Y_POS_DUCK = 340  # top edge of the ducking sprite (px)
    JUMP_VEL = 8.5    # initial jump velocity; y moves by 4 * jump_vel per frame

    def __init__(self):
        self.duck_img = DUCKING
        self.run_img = RUNNING
        self.jump_img = JUMPING

        self.dino_duck = False
        self.dino_run = True
        self.dino_jump = False

        self.step_index = 0
        self.jump_vel = self.JUMP_VEL
        self._set_pose(self.run_img[0], self.Y_POS)

    def _set_pose(self, image, y):
        """Show ``image`` and rebuild the hitbox from it, anchored at (X_POS, y)."""
        self.image = image
        self.dino_rect = image.get_rect()
        self.dino_rect.x = self.X_POS
        self.dino_rect.y = y

    def update(self, action):
        """Advance the dinosaur by one frame.

        Args:
            action: 0 = nothing, 1 = jump, 2 = duck. Jump and duck requests
                are ignored while a jump is in progress.
        """
        if action == 1 and not self.dino_jump:
            self.dino_duck = False
            self.dino_run = False
            self.dino_jump = True
            # Every jump starts from the standing pose. Previously a jump
            # started while ducking kept the duck hitbox (shorter and 30 px
            # lower) for the whole jump.
            self._set_pose(self.jump_img, self.Y_POS)
        elif action == 2 and not self.dino_jump:
            self.dino_duck = True
            self.dino_run = False
            self.dino_jump = False
        elif not self.dino_jump:
            # Here action is neither 1 nor 2 (both handled above when not jumping).
            self.dino_duck = False
            self.dino_run = True
            self.dino_jump = False

        if self.dino_duck:
            self.duck()
        elif self.dino_run:
            self.run()
        elif self.dino_jump:
            self.jump()

        # Two animation frames, five ticks each.
        if self.step_index >= 10:
            self.step_index = 0

    def duck(self):
        """Show the current ducking frame at the duck height."""
        self._set_pose(self.duck_img[self.step_index // 5], self.Y_POS_DUCK)
        self.step_index += 1

    def run(self):
        """Show the current running frame at ground height."""
        self._set_pose(self.run_img[self.step_index // 5], self.Y_POS)
        self.step_index += 1

    def jump(self):
        """Move along the jump arc; end the jump once velocity drops below -JUMP_VEL."""
        self.image = self.jump_img
        if self.dino_jump:
            self.dino_rect.y -= self.jump_vel * 4
            self.jump_vel -= 0.8
        if self.jump_vel < -self.JUMP_VEL:
            self.dino_jump = False
            self.jump_vel = self.JUMP_VEL

    def draw(self, screen):
        """Blit the current image at the hitbox's top-left corner."""
        screen.blit(self.image, (self.dino_rect.x, self.dino_rect.y))


class Cloud:
    """Background cloud that scrolls left and respawns far off the right edge."""

    def __init__(self):
        spawn_offset = random.randint(800, 1000)
        self.x = SCREEN_WIDTH + spawn_offset
        self.y = random.randint(50, 100)
        self.image = CLOUD
        self.width = self.image.get_width()

    def update(self, game_speed):
        """Scroll by ``game_speed`` px; once fully off-screen, respawn at a new spot."""
        self.x -= game_speed
        if self.x >= -self.width:
            return
        respawn_offset = random.randint(2500, 3000)
        self.x = SCREEN_WIDTH + respawn_offset
        self.y = random.randint(50, 100)

    def draw(self, screen):
        """Blit the cloud at its current position."""
        position = (self.x, self.y)
        screen.blit(self.image, position)


class Obstacle:
    """Base obstacle: a list of sprites, the chosen variant, and a hitbox.

    The hitbox starts at the right edge of the screen at height ``rect_y``.
    """

    def __init__(self, image, type, rect_y):
        self.image = image
        self.type = type
        rect = image[type].get_rect()
        rect.x = SCREEN_WIDTH
        rect.y = rect_y
        self.rect = rect

    def update(self, game_speed):
        """Scroll left by ``game_speed``; True once fully off-screen (remove it)."""
        self.rect.x -= game_speed
        return self.rect.x < -self.rect.width

    def draw(self, screen):
        """Blit the selected sprite variant at the hitbox position."""
        sprite = self.image[self.type]
        screen.blit(sprite, self.rect)


class SmallCactus(Obstacle):
    """Small cactus: a random one of the three SMALL_CACTUS sprites, top at y=325."""

    def __init__(self):
        variant = random.randint(0, 2)
        super().__init__(SMALL_CACTUS, variant, 325)


class LargeCactus(Obstacle):
    """Large cactus: a random one of the three LARGE_CACTUS sprites, top at y=300."""

    def __init__(self):
        variant = random.randint(0, 2)
        super().__init__(LARGE_CACTUS, variant, 300)


class Bird(Obstacle):
    """Flying obstacle at y=250 with a two-frame flapping animation."""

    def __init__(self):
        super().__init__(BIRD, 0, 250)
        self.index = 0  # animation tick, advanced on every draw

    def draw(self, screen):
        """Blit the current flap frame, then advance the animation tick."""
        # Ticks 0-4 show frame 0, ticks 5-8 show frame 1, then the cycle restarts.
        if self.index > 8:
            self.index = 0
        frame = self.image[self.index // 5]
        screen.blit(frame, self.rect)
        self.index += 1


# === Gymnasium Environment ===

class DinoEnv(gym.Env):
    """Gymnasium environment wrapping the Chrome-dino game.

    Actions: 0 = NOOP, 1 = JUMP, 2 = DUCK.
    Observation (float32, 6 values):
        [dino_y, dino_vel_y, nearest_obstacle_dist, obstacle_type, game_speed, is_ducking]
        with obstacle_type 0 = none, 1 = small cactus, 2 = large cactus, 3 = bird.
    Reward: +1.0 per surviving frame; -100.0 on the frame of a collision, which
    also terminates the episode. The env itself never truncates an episode.
    """

    metadata = {"render_modes": ["human", "rgb_array"], "render_fps": 30}

    def __init__(self, render_mode=None):
        self.render_mode = render_mode

        # Action: 0 = NOOP, 1 = JUMP, 2 = DUCK
        self.action_space = spaces.Discrete(3)

        # Observation: [dino_y, dino_vel_y, nearest_obstacle_dist, obstacle_type, game_speed, is_ducking]
        # obstacle_type: 0=none, 1=small cactus, 2=large cactus, 3=bird
        low = np.array([0, -10, 0, 0, 5, 0], dtype=np.float32)
        high = np.array([SCREEN_HEIGHT, 10, SCREEN_WIDTH, 3, 50, 1], dtype=np.float32)
        self.observation_space = spaces.Box(low=low, high=high, dtype=np.float32)

        # Internal state (populated by reset())
        self.dino = None
        self.obstacles = []
        self.cloud = None
        self.game_speed = 0
        self.points = 0
        self.x_pos_bg = 0
        self.y_pos_bg = 380

        # Rendering
        self.screen = None
        self.clock = pygame.time.Clock()
        self._font = None  # created on first render, then reused every frame

    @staticmethod
    def _obstacle_code(obstacle):
        """Map an obstacle instance to its observation code (0 if unrecognised)."""
        if isinstance(obstacle, SmallCactus):
            return 1
        if isinstance(obstacle, LargeCactus):
            return 2
        if isinstance(obstacle, Bird):
            return 3
        return 0

    def _get_obs(self):
        """Build the observation vector for the current frame.

        The nearest obstacle is the closest one that has not yet fully passed
        the dinosaur's left edge. Its distance is 0 while it overlaps the
        dinosaur horizontally. With no such obstacle, the distance is
        SCREEN_WIDTH and the type is 0. Values are clipped to
        ``observation_space`` because game_speed keeps growing past the
        declared upper bound of 50 (reached at 3000 points).
        """
        dino_rect = self.dino.dino_rect
        nearest_dist = SCREEN_WIDTH
        nearest_type = 0  # 0 = none
        for obstacle in self.obstacles:
            # Only skip obstacles entirely behind the dino. Checking the left
            # edge alone would hide an obstacle that is still overlapping it
            # (e.g. a bird halfway over a ducking dino).
            if obstacle.rect.right <= dino_rect.x:
                continue
            dist = max(0, obstacle.rect.x - dino_rect.x)
            if dist < nearest_dist:
                nearest_dist = dist
                nearest_type = self._obstacle_code(obstacle)

        obs = np.array([
            dino_rect.y,
            self.dino.jump_vel if self.dino.dino_jump else 0.0,
            nearest_dist,
            nearest_type,
            self.game_speed,
            float(self.dino.dino_duck),
        ], dtype=np.float32)
        return np.clip(obs, self.observation_space.low, self.observation_space.high)

    def _get_info(self):
        """Auxiliary info: the current score (frames survived)."""
        return {"score": self.points}

    def reset(self, seed=None, options=None):
        """Start a new episode and return ``(observation, info)``.

        Obstacle and cloud placement use the stdlib ``random`` module, so a
        given ``seed`` is applied to it as well to make episodes reproducible.
        """
        super().reset(seed=seed)
        if seed is not None:
            random.seed(seed)

        self.dino = Dinosaur()
        self.obstacles = []
        self.cloud = Cloud()
        self.game_speed = 20
        self.points = 0
        self.x_pos_bg = 0

        if self.render_mode == "human" and self.screen is None:
            pygame.display.init()
            self.screen = pygame.display.set_mode((SCREEN_WIDTH, SCREEN_HEIGHT))
            pygame.display.set_caption("Dino RL")

        return self._get_obs(), self._get_info()

    def step(self, action):
        """Advance one frame.

        Returns:
            ``(observation, reward, terminated, truncated, info)``; ``truncated``
            is always False.
        """
        self.dino.update(action)

        # Keep exactly one obstacle in play: spawn a new one once the last left.
        if not self.obstacles:
            kind = random.randint(0, 2)
            if kind == 0:
                self.obstacles.append(SmallCactus())
            elif kind == 1:
                self.obstacles.append(LargeCactus())
            else:
                self.obstacles.append(Bird())

        # Scroll obstacles. Drop the ones that left the screen and
        # collision-check the rest.
        collision = False
        remaining = []
        for obstacle in self.obstacles:
            if obstacle.update(self.game_speed):
                continue
            remaining.append(obstacle)
            if self.dino.dino_rect.colliderect(obstacle.rect):
                collision = True
        self.obstacles = remaining

        # Cloud & background only matter for rendering.
        self.cloud.update(self.game_speed)
        # Negative modulus keeps the offset in (-width, 0] so two tiles cover the screen.
        self.x_pos_bg = (self.x_pos_bg - self.game_speed) % -BG.get_width()
        self.points += 1

        # Increase difficulty every 100 points.
        if self.points % 100 == 0:
            self.game_speed += 1

        terminated = collision
        reward = -100.0 if terminated else 1.0

        if self.render_mode == "human":
            self._render_frame()

        return self._get_obs(), reward, terminated, False, self._get_info()

    def render(self):
        """Return an RGB frame in "rgb_array" mode (human mode renders in step())."""
        if self.render_mode == "rgb_array":
            return self._render_frame()

    def _render_frame(self):
        """Draw the current frame.

        In "human" mode it is shown in the window and the frame rate is capped.
        In "rgb_array" mode it is returned as an array of shape
        (SCREEN_HEIGHT, SCREEN_WIDTH, 3).
        """
        if self.screen is None and self.render_mode == "human":
            return None

        canvas = pygame.Surface((SCREEN_WIDTH, SCREEN_HEIGHT))
        canvas.fill((255, 255, 255))

        # Two background tiles side by side cover the scrolling offset.
        canvas.blit(BG, (self.x_pos_bg, self.y_pos_bg))
        canvas.blit(BG, (self.x_pos_bg + BG.get_width(), self.y_pos_bg))

        self.dino.draw(canvas)
        for obstacle in self.obstacles:
            obstacle.draw(canvas)
        self.cloud.draw(canvas)

        # Loading a font is comparatively expensive; do it once.
        if self._font is None:
            self._font = pygame.font.Font(pygame.font.get_default_font(), 20)
        text = self._font.render(f"Score: {self.points}", True, (0, 0, 0))
        canvas.blit(text, (1000, 40))

        if self.render_mode == "human":
            # Service the OS event loop so the window stays responsive even when
            # the caller never reads events (e.g. a pure agent loop).
            pygame.event.pump()
            self.screen.blit(canvas, (0, 0))
            pygame.display.flip()
            self.clock.tick(self.metadata["render_fps"])
            return None
        if self.render_mode == "rgb_array":
            return np.transpose(pygame.surfarray.array3d(canvas), axes=(1, 0, 2))
        return None

    def close(self):
        """Close the display window if one was opened."""
        if self.screen:
            pygame.display.quit()
            self.screen = None

  from pkg_resources import resource_stream, resource_exists


## Test play: random agent

In [2]:
# Play one episode with uniformly random actions.
env = DinoEnv(render_mode="human")
try:
    obs, info = env.reset()
    done = False
    while not done:
        action = env.action_space.sample()
        obs, reward, terminated, truncated, info = env.step(action)
        done = terminated or truncated
    print(f"Game Over! Final Score: {info['score']}")
finally:
    # Release the window even if the cell is interrupted.
    env.close()

Game Over! Final Score: 47


## Test play by hand (Up/Space = jump, Down = duck; close the window to quit)

In [3]:
# Play with the keyboard: Up/Space = jump, Down = duck, otherwise run.
# In "human" mode the env already caps the frame rate at metadata["render_fps"].
env = DinoEnv(render_mode="human")
try:
    obs, info = env.reset()
    done = False
    quit_requested = False
    while not done:
        # Closing the window ends the session without taking another step.
        if any(event.type == pygame.QUIT for event in pygame.event.get()):
            quit_requested = True
            break

        keys = pygame.key.get_pressed()
        if keys[pygame.K_UP] or keys[pygame.K_SPACE]:
            action = 1  # jump
        elif keys[pygame.K_DOWN]:
            action = 2  # duck
        else:
            action = 0  # do nothing

        obs, reward, terminated, truncated, info = env.step(action)
        done = terminated or truncated

    if quit_requested:
        print(f"Window closed. Score: {info['score']}")
    else:
        print(f"Game Over! Final Score: {info['score']}")
finally:
    # Release the window even if the cell is interrupted.
    env.close()

Game Over! Final Score: 1713
