In [191]:
import pygame as pg
import numpy as np
import random
import matplotlib.pyplot as plt
import os
import gym
from gym import Env
from gym.spaces import Box, Discrete, Dict
from gym.wrappers import FlattenObservation
from gym.envs.registration import register
from setuptools import setup
from stable_baselines3.common import env_checker

Python 3.7.13


In [5]:
class Sprite():
    """Base class for every object in the game.

    A sprite stores its position, its image (or a list of animation frames),
    its pixel size and a reference to the owning game. On construction it
    registers itself in ``game.all_sprites``.
    """

    def __init__(self, img, pos, game):
        """Create the sprite and add it to the game's sprite list.

        Args:
            img: an image object exposing ``get_rect()``, or a list/tuple of
                such images (the first one is used to measure the sprite).
            pos: initial ``(x, y)`` position.
            game: the owning game; must expose ``all_sprites``.
        """
        self.position = np.array(pos, dtype=np.float32)
        self.image = img
        self.game = game
        # subclasses pass a list of frames; measure the sprite with the first one
        first_frame = self.image[0] if isinstance(self.image, (list, tuple)) else self.image
        rect = first_frame.get_rect()
        self.height = rect[3]
        self.width = rect[2]

        # add the instance to the sprite list of the game
        self.game.all_sprites.append(self)

    def update(self):
        """Scroll the sprite left by the game speed and recycle it once off screen."""
        self.position[0] -= self.game.game_speed
        # if the object has moved off screen, reset
        if self.position[0] + self.width < 20:
            self.reset()
        # the image may have been swapped (animation / reset), so re-measure it
        rect = self.image.get_rect()
        self.height = rect[3]
        self.width = rect[2]

    def get_obs(self):
        """Return the sprite's box ``(x, y, w, h)`` as a float32 array.

        ``np.append`` promotes the float32 position and the integer size to
        float64, so the result is cast back to float32, the dtype of
        ``self.position``.
        """
        return np.append(self.position, [self.width, self.height]).astype(np.float32)

    def draw(self, screen):
        """Blit the sprite's current image onto ``screen`` at its position."""
        screen.blit(self.image, self.position)

    def reset(self):
        """Move the sprite to a new x position to the right of the display.

        The new x is placed a random distance (scaled by the game speed) behind
        the right-most obstacle, and never left of x = 1100, so that obstacles
        do not end up too close together to be passable.
        """
        min_dist = int(self.game.game_speed * 35)
        max_dist = int(self.game.game_speed * 55)
        # right-most obstacle x, floored at 0 (also covers an empty obstacle list)
        farthest_x = max([0] + [obstacle.position[0] for obstacle in self.game.obstacles])
        # set the new position
        self.position[0] = max(farthest_x + random.randint(min_dist, max_dist), 1100)

In [6]:
class Background(Sprite):
    """Scrolling ground track.

    The visible background consists of two instances of this class placed
    right behind each other: the primary tile creates its secondary partner.
    """

    def __init__(self, img, pos, game, secondary=False):
        super().__init__(img, pos, game)
        self.secondary = secondary
        if self.secondary:
            return
        # the primary tile spawns the second tile one tile-width to the right;
        # the new instance registers itself with the game, so no reference is kept
        Background(self.image, (self.width, self.game.ground_height-15), self.game, secondary=True)

    def reset(self):
        """Reposition the tile.

        Only the primary tile on a full (episode-end) reset returns to x = 0;
        in every other case the tile is placed one tile-width to the right.
        """
        if self.game.full_reset and not self.secondary:
            self.position[0] = 0
        else:
            self.position[0] = self.width

In [7]:
class Cloud(Sprite):
    """Decorative cloud that drifts across the sky."""

    def __init__(self, img, pos, game):
        super().__init__(img, pos, game)

    def reset(self):
        """Send the cloud far to the right of the display at a new random height."""
        # x is drawn before y, keeping the order of the random draws stable
        x_offset = random.randint(2500,3000)
        self.position[0] = 1100 + x_offset
        new_y = random.randint(50,100)
        self.position[1] = new_y

In [8]:
class Cactus(Sprite):
    """Ground obstacle whose image is drawn at random from a list of cactus images."""

    def __init__(self, img, pos, game):
        """Register the cactus as an obstacle and give it a random image.

        Args:
            img: list of cactus images to choose from.
            pos: initial ``(x, y)``; y is overwritten so the cactus stands on the ground.
            game: the owning game; must expose ``obstacles`` and ``ground_height``.
        """
        super(Cactus, self).__init__(img, pos, game)
        self.game.obstacles.append(self)
        self.imgs = self.image
        self._pick_image()

    def _pick_image(self):
        # Choose a random image, then derive size and y from THAT image. The
        # images differ in size, so measuring before choosing would leave the
        # cactus floating or sunk and report a stale width.
        self.image = random.choice(self.imgs)
        rect = self.image.get_rect()
        self.width = rect[2]
        self.height = rect[3]
        self.position[1] = self.game.ground_height - self.height

    def reset(self):
        """Move the cactus to a new x position and give it a new random image."""
        super(Cactus, self).reset()
        self._pick_image()
        # tell the game that an obstacle has been recycled
        self.game.point_counter = True
    

In [9]:
class Bird(Sprite):
    """Flying obstacle with a two-frame wing animation."""

    def __init__(self, img, pos, game):
        """Register the bird as an obstacle and start its animation.

        Args:
            img: list of the two animation frames.
            pos: initial ``(x, y)``; y is overwritten from the elevation on every update.
            game: the owning game; must expose ``obstacles`` and ``ground_height``.
        """
        super(Bird, self).__init__(img, pos, game)
        self.game.obstacles.append(self)
        self.imgs = self.image
        self.image = self.imgs[0]
        self.step_index = 0
        # NOTE(review): the initial range (100, 250) differs from the one used
        # in reset() (80, 200) -- confirm which is intended
        self.elevation = random.randint(100,250)

    def update(self):
        """Scroll the bird, pin it to its elevation and advance the animation."""
        super(Bird, self).update()
        self.position[1] = self.game.ground_height - self.elevation
        # 10-tick cycle: 5 ticks per frame (same cycle length as the Dino animation)
        if self.step_index >= 10:
            self.step_index = 0
        self.image = self.imgs[self.step_index//5]
        self.step_index += 1

    def reset(self):
        """Move the bird to a new x position and draw a new elevation."""
        super(Bird, self).reset()
        self.elevation = random.randint(80,200)
        # tell the game that an obstacle has been recycled
        self.game.point_counter = True

In [10]:
class Dino(Sprite):
    """The player character.

    ``state`` is a boolean array ``[running, jumping, ducking]``. ``imgs``
    holds the frames in the order run1, run2, jump, duck1, duck2.
    """

    # Take-off velocity. jump() restores this value whenever the dino is on
    # the ground, so it is the velocity every jump actually starts with.
    JUMP_VELOCITY = 10

    def __init__(self, img, pos, game):
        """Create the player standing on the ground in the running state.

        Args:
            img: list of the five player images (see class docstring).
            pos: initial ``(x, y)``; y is overwritten so the dino stands on the ground.
            game: the owning game; must expose ``ground_height``.
        """
        super(Dino, self).__init__(img, pos, game)
        self.imgs = img
        # ducking shifts the sprite down by this offset
        self.duck_offset = np.array([0,33], dtype=np.float32)
        self.jump_vel = self.JUMP_VELOCITY

        # state = [running, jumping, ducking]
        self.state = np.array([True, False, False], dtype=bool)
        # independent copy, so mutating one array never changes the other
        self.new_state = self.state.copy()
        self.step_index = 0

        self.position[1] = self.game.ground_height - self.height
        self.image = self.imgs[0]

    def update(self):
        """No-op: the dino stays in place instead of scrolling with the world."""
        pass

    def reset(self):
        """Put the dino back on the ground in its initial running state.

        The motion state is cleared as well as position and image; otherwise an
        episode that ended mid-jump or mid-duck would leak its airborne flag,
        partial jump velocity or ducking flag into the next episode.
        """
        self.position = np.array([10,self.game.ground_height-self.height], dtype=np.float32)
        self.image = self.imgs[0]
        self.state = np.array([True, False, False], dtype=bool)
        self.new_state = self.state.copy()
        self.jump_vel = self.JUMP_VELOCITY
        self.step_index = 0

    def take_action(self, choice):
        """Perform the chosen action: 0 = run, 1 = jump, 2 = duck.

        While the dino is airborne the choice is ignored and the jump continues.
        """
        # the animation counter cycles through 0..9
        if self.step_index >= 10:
            self.step_index = 0

        actions = [self.run, self.jump, self.duck]
        self.new_state = np.zeros(3, dtype=bool)
        self.new_state[choice] = True

        # unless the player is still in the air (jumping), apply new action
        if not self.state[1]:
            # the action runs BEFORE the state switches, so it still sees the old state
            actions[choice]()
            self.state = self.new_state
        else:
            actions[1]()

    def run(self):
        """Advance the running animation and keep the dino on the ground."""
        if not self.state[1] or self.new_state[2] or self.state[2]:
            # iterates through the images for animation
            self.image = self.imgs[:2][self.step_index // 5]
            self.step_index += 1
            self.position[1] = self.game.ground_height - self.height

    def jump(self):
        """Advance the jump by one tick and handle take-off and landing."""
        self.image = self.imgs[2]
        if self.state[1]:
            # airborne: move by the current velocity, then let it decay
            self.position[1] -= self.jump_vel * 4
            self.jump_vel -= 0.8
        if self.position[1] + self.height >= self.game.ground_height:
            # on (or below) the ground: land and re-arm the jump velocity
            self.state[1] = False
            self.jump_vel = self.JUMP_VELOCITY
            self.position[1] = self.game.ground_height - self.height

    def duck(self):
        """Advance the ducking animation and lower the sprite when ducking starts."""
        if not self.state[1]:
            # iterates through the images for animation
            self.image = self.imgs[3:5][self.step_index // 5]
            self.step_index += 1
        # unless the player was already ducking, we apply an offset to the position
        if not self.state[2]:
            self.position[1] = self.game.ground_height - self.height
            self.position += self.duck_offset


In [204]:
class ChromeDinoEnv(Env):
    """Gym environment for a Chrome-Dino-style runner drawn with pygame.

    Actions (``Discrete(3)``): 0 = run (do nothing), 1 = jump, 2 = duck.
    Observations (``Dict``): the ``(x, y, w, h)`` boxes of the player, the bird
    and the cactus, an ``is_jumping`` flag and the current game speed.
    An episode is done as soon as the player collides with an obstacle.
    """
    metadata = {'render_modes': ['human', 'rgb_array'], 'render_fps':30}

    def __init__(self, render_mode=None):
        """Load the assets, create all game objects and define the spaces.

        Args:
            render_mode: ``None``, ``'human'`` or ``'rgb_array'``.
        """
        super().__init__()
        self.window_size = (1100, 600)

        self.points = 0
        self.ground_height = 380
        # NOTE(review): reset() sets the speed to 14, not 15 -- confirm which is intended
        self.game_speed = 15
        self.full_reset = True
        self.all_sprites = []
        self.obstacles = []
        self.point_counter = False

        imgs = []
        # get all images for the game animation
        assets = ['Assets/'+i for i in ['Cactus', 'Bird', 'Dino', 'Other']]
        all_assets = [['SmallCactus1.png', 'SmallCactus2.png', 'SmallCactus3.png','LargeCactus1.png', 'LargeCactus2.png', 'LargeCactus3.png'], ['Bird1.png', 'Bird2.png'], ['DinoRun1.png', 'DinoRun2.png', 'DinoJump.png', 'DinoDuck1.png', 'DinoDuck2.png'], ['Cloud.png', 'GameOver.png', 'Reset.png', 'Track.png']]
        for idx, asset in enumerate(assets):
            imgs.append([pg.image.load(os.path.join(asset, i)) for i in all_assets[idx]])

        cactus_imgs, bird_imgs, dino_imgs, other_imgs = imgs

        # initialize all Objects in the game (each one registers itself in all_sprites)
        self.player = Dino(dino_imgs, (10,333), self)
        self.cactus= Cactus(cactus_imgs, (0, 300), self)
        self.bird = Bird(bird_imgs, (0, 250), self)
        self.background = Background(other_imgs[3], (0,self.ground_height-15), self)
        self.cloud = Cloud(other_imgs[0], (1100+random.randint(800,1000),(random.randint(50,100))), self)

        # (x,y,w,h) for dino, bird and cactus + bool is_jumping and game speed
        self.observation_space = Dict(
            {
            "player": Box(low=np.array([0,0,10,10], dtype=np.float32), high=np.array([1000,350,300,300], dtype=np.float32), dtype=np.float32),
            "bird": Box(low=np.array([-200,0,10,10], dtype=np.float32), high=np.array([5000,350,300,300], dtype=np.float32), dtype=np.float32),
            "cactus": Box(low=np.array([-200,0,10,10], dtype=np.float32), high=np.array([5000,350,300,300], dtype=np.float32), dtype=np.float32),
            "is_jumping": Box(0, 1, shape=(1,), dtype=bool),
            "speed": Box(10, 100, shape=(1,), dtype=np.float32)
            }
        )

        # we have 3 discrete actions: run(do nothing), jump, duck
        self.action_space = Discrete(3)

        self.game_over = False

        assert render_mode is None or render_mode in self.metadata['render_modes']
        self.render_mode = render_mode

        # pygame objects are created lazily on the first human-mode frame
        self.window = None
        self.clock = None
        self._font = None

    def check_for_collision(self, obstacle):
        """Return True if the obstacle's rectangle overlaps or touches the player's."""
        o_x, o_y = obstacle.position
        p_x, p_y = self.player.position
        o_w, o_h = obstacle.width, obstacle.height
        p_w, p_h = self.player.width, self.player.height
        return o_x + o_w >= p_x and o_x <= p_x + p_w and o_y + o_h >= p_y and o_y <= p_y + p_h

    def _get_obs(self):
        """Build the observation dict with the shapes and dtypes of ``observation_space``."""
        player_obs = np.asarray(self.player.get_obs(), dtype=np.float32)
        return {
            "player": player_obs,
            "bird": np.asarray(self.bird.get_obs(), dtype=np.float32),
            "cactus": np.asarray(self.cactus.get_obs(), dtype=np.float32),
            # the player counts as jumping while its y coordinate is above 285
            "is_jumping": np.array([player_obs[1] < 285], dtype=bool),
            "speed": np.array([self.game_speed], dtype=np.float32),
        }

    def step(self, action):
        """Advance the game by one tick.

        Args:
            action: 0 = run, 1 = jump, 2 = duck.

        Returns:
            ``(observation, reward, done, info)``.
        """
        # perform a player action
        self.player.take_action(action)

        # update all game objects
        for sprite in self.all_sprites:
            sprite.update()

        # get relevant information from the environment
        observation = self._get_obs()
        done = self.get_done()
        # Background.reset() reads this flag to tell an episode end from a wrap-around
        self.full_reset = done

        # Flat survival reward: +1 for every step taken.
        # A shaped alternative (currently disabled) was: +10 when an obstacle has
        # been passed (point_counter set; also raising game_speed by 0.2 while it
        # is below 100), -0.1 while jumping, 0 on the ground and -10 when done.
        reward = 1

        # update the score first, so info and the rendered frame show the current value
        self.points += reward
        info = self._get_info()

        # renders the Environment if render_mode is set to human
        if self.render_mode == 'human':
            self._render_frame()

        return observation, reward, done, info

    def render(self, mode='rgb_array'):
        """Switch the render mode and, for ``'rgb_array'``, return the current frame.

        In ``'human'`` mode nothing is drawn here; step() draws a frame on every
        tick while the render mode is ``'human'``.
        """
        self.render_mode=mode
        if self.render_mode == 'rgb_array':
            return self._render_frame()

    def _render_frame(self):
        """Draw the current frame.

        In ``'human'`` mode the frame and the score are shown in a pygame window
        and the call is throttled to ``render_fps``; otherwise the frame is
        returned as an array with the first two axes of the pygame pixel array
        swapped.
        """
        if self.window is None and self.render_mode == 'human':
            pg.init()
            pg.font.init()
            pg.display.init()
            self.window = pg.display.set_mode(self.window_size)
        if self.clock is None and self.render_mode == 'human':
            self.clock = pg.time.Clock()

        # init the screen
        self.screen = pg.Surface(self.window_size)
        self.screen.fill((255,255,255))

        # draw all game objects on the screen
        for sprite in self.all_sprites:
            sprite.draw(self.screen)

        if self.render_mode == 'human':
            # load the font once instead of on every frame
            if self._font is None:
                self._font = pg.font.Font('freesansbold.ttf', 20)
            text = self._font.render('Score: ' + str(self.points), True, (0,0,0))
            text_rect = text.get_rect()
            text_rect.center = (1000,40)
            self.window.blit(self.screen, self.screen.get_rect())
            # blit at the centred rect, so a long score is not drawn at the raw tuple position
            self.window.blit(text, text_rect)
            pg.event.pump()
            pg.display.update()
            self.clock.tick(self.metadata['render_fps'])
        else:
            return np.transpose(
            np.array(pg.surfarray.pixels3d(self.screen)), axes=(1,0,2))

    def reset(self):
        """Start a new episode and return the initial observation."""
        # Force a full reset so the two background tiles are laid out side by
        # side again, even when reset() is called before the episode was done.
        self.full_reset = True
        # reset all sprites to a semi-random (specific to their class) location
        for sprite in self.all_sprites:
            sprite.reset()
        # reset game stats (obstacle resets set point_counter, so clear it afterwards)
        self.point_counter = False
        self.points = 0
        self.game_speed = 14
        return self._get_obs()

    def close(self):
        """Shut down pygame if a window was opened and forget all pygame objects."""
        if self.window is not None:
            pg.display.quit()
            pg.quit()
        self.window = None
        self.clock = None
        # a font created before pg.quit() must not be reused afterwards
        self._font = None

    def get_done(self):
        """Return True if the player collides with any obstacle (terminal state)."""
        for obstacle in self.obstacles:
            if self.check_for_collision(obstacle):
                return True
        return False

    def _get_info(self):
        """Return the auxiliary info dict holding the current score."""
        return{'score': self.points}

    def play(self):
        """Let a human play one episode: UP = jump, DOWN = duck, otherwise run."""
        self.render_mode = 'human'
        # Start a fresh episode; the first step also creates the pygame display
        # in _render_frame() if it does not exist yet. `done` is always defined
        # here, regardless of whether a window was already open.
        self.reset()
        _, _, done, _ = self.step(0)
        while not done:
            # get the input
            user_input = pg.key.get_pressed()
            # jump
            if user_input[pg.K_UP]:
                action =  1
            # duck
            elif user_input[pg.K_DOWN]:
                action =  2
            # run
            else:
                action = 0
            _, _, done, _ = self.step(action)
        self.close()

In [205]:
# build the game environment and wrap it so the Dict observation becomes one flat array
env = FlattenObservation(ChromeDinoEnv())

Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations


In [206]:
# run the stable-baselines3 environment checker on the wrapped environment
env_checker.check_env(env)