In [29]:
import gymnasium as gym
from gymnasium.spaces import Discrete, Tuple, Box
import numpy as np
import pygame
import random
import math

In [30]:
def human_render_loop(agent, env, transpose=False, print_step=False):
    """Play one episode of ``env`` driven by ``agent`` and show it in a pygame window.

    Parameters
    ----------
    agent : callable
        Maps the current observation to an action accepted by ``env.step``.
    env : gymnasium.Env
        Must return an RGB frame from ``env.render()`` and expose
        ``metadata["render_fps"]``.
    transpose : bool
        Swap the first two axes of every frame before blitting.
        ``pygame.surfarray.blit_array`` expects arrays indexed ``[x][y]``, so
        set this for environments whose frames are indexed ``[row][column]``.
    print_step : bool
        Print the initial state, every transition and the total reward.

    The episode ends on termination, truncation, pressing Q or closing the
    window. The pygame display is shut down even if an exception is raised.
    """
    obs, info = env.reset()

    clock = pygame.time.Clock()
    fps = env.metadata["render_fps"]

    pygame.display.init()
    try:
        size = env.render().shape[:2]
        if transpose:
            size = size[::-1]

        display = pygame.display.set_mode(size)

        term = trunc = False
        step = 0
        total_reward = 0

        if print_step:
            print("initial state")
            print("\tobservation", obs)
            print("\tinfo", info)

        while not (term or trunc):
            # event.get() also pumps the event queue: this keeps the window
            # responsive and refreshes the keyboard state read just below.
            window_closed = bool(pygame.event.get(pygame.QUIT))
            if window_closed or pygame.key.get_pressed()[pygame.K_q]:
                break

            action = agent(obs)
            obs, reward, term, trunc, info = env.step(action)

            step += 1
            total_reward += reward

            if print_step:
                print("step", step)
                print("\taction", action)
                print("\tobservation", obs)
                print("\treward", reward)
                print("\tterminated", term)
                print("\ttruncated", trunc)
                print("\tinfo", info)

            frame = env.render()
            if transpose:
                frame = frame.transpose((1, 0, 2))

            pygame.surfarray.blit_array(display, frame)
            pygame.display.flip()
            clock.tick(fps)

        if print_step:
            print("Total reward", total_reward)
    finally:
        # Always release the window, e.g. after a KeyboardInterrupt in the notebook.
        pygame.display.quit()

Game

In [31]:
HEIGHT = 600
WIDTH = 800

ASTEROID_NBR = 10
ASTEROID_SPEED = 5
ASTEROID_MIN_RADIUS = 50
ASTEROID_MAX_RADIUS = 150
SHIP_SPEED = 5
MISSILE_SPEED = 10

# Colours (RGB)
# NOTE(review): despite its name, WHITE is (150, 0, 0), i.e. dark red. It is not
# referenced in this part of the notebook; confirm intent before renaming it.
WHITE = (150, 0, 0)
BLACK = (0, 0, 0)
YELLOW = (255, 255, 0)

# Sprite images, loaded from paths relative to the notebook's working directory
ship_image = pygame.image.load("ship.png")
asteroid_base_image = pygame.image.load("asteroid.png")

# Rotate the ship 90 degrees clockwise (negative angle), then shrink it by 2.5x.
# Floor-divide then int(): pygame.transform.scale takes a size in whole pixels.
ship_image = pygame.transform.rotate(ship_image, -90)
ship_image = pygame.transform.scale(
    ship_image,
    (int(ship_image.get_width() // 2.5), int(ship_image.get_height() // 2.5)),
)

# Player ship
class Ship:
    """The player's ship: fixed horizontal position, moves only vertically.

    Its size comes from the module-level ``ship_image`` sprite. Vertical
    movement is clamped so the whole sprite stays within ``[0, HEIGHT]``.
    """

    def __init__(self):
        self.x = 50
        self.y = HEIGHT // 2
        self.speed = SHIP_SPEED
        self.vy = 0  # vertical velocity; assigned by the caller before update()
        self.width = ship_image.get_width()
        self.height = ship_image.get_height()

    def update(self, dt):
        """Move by ``vy * dt * 100`` pixels, then clamp the sprite to the screen."""
        self.y += self.vy * dt * 100
        if self.y < 0:
            self.y = 0
        elif self.y > HEIGHT - self.height:
            self.y = HEIGHT - self.height

    def draw(self, screen):
        """Blit the ship sprite with its top-left corner at (x, y)."""
        screen.blit(ship_image, (self.x, self.y))

    def get_collision_boxes(self):
        """Approximate the ship's outline with four rectangles.

        The sprite is split into thirds horizontally. From left to right the
        boxes get shorter (full, 1/2, 1/4 and 1/8 of the sprite height) and are
        roughly vertically centred. Together they follow a shape that tapers
        towards the right-hand side, from which missiles are fired.
        """
        third = self.width // 3
        return [
            # Left third, full height
            pygame.Rect(self.x, self.y, third, self.height),
            # Middle third, half height, vertically centred
            pygame.Rect(self.x + third, self.y + self.height / 2 / 2, third, self.height / 2),
            # First 70% of the right third, quarter height
            pygame.Rect(self.x + 2 * third, self.y + self.height / 2.65, third * 0.7, self.height / 4),
            # Tip: last 30% of the right third, eighth height
            pygame.Rect(self.x + 2.7 * third, self.y + self.height / 2.265, third * 0.3, self.height / 8),
        ]

class Missile:
    """A projectile fired by the ship that travels horizontally to the right."""

    def __init__(self, x, y):
        self.x, self.y = x, y
        # Horizontal speed; update() scales it by dt * 100 like the other sprites.
        self.vx = MISSILE_SPEED

    def update(self, dt):
        """Advance the missile to the right by ``vx * dt * 100`` pixels."""
        self.x += self.vx * dt * 100

    def draw(self, screen):
        """Paint the missile as a 10x4 yellow bar vertically centred on ``y``."""
        body = (self.x, self.y - 2, 10, 4)
        pygame.draw.rect(screen, YELLOW, body)

class Asteroid:
    """Circular obstacle that scrolls from right to left across the world.

    Each asteroid waits a random 1-3 seconds of simulated time before it starts
    moving. It is recycled to the right edge, with a new radius and height,
    once it has gone far enough past the left edge or has been shot down below
    a 20 px radius.
    """

    def __init__(self):
        self.vx = ASTEROID_SPEED
        self.reset_position()
        self.spawn_delay = random.uniform(1, 3)  # simulated seconds before moving
        self.elapsed = 0.0  # simulated seconds accumulated by update()
        # Wall-clock creation time. It is no longer used to decide when the
        # asteroid starts moving; it is kept only for code that may still read it.
        self.spawn_time = pygame.time.get_ticks()
        # Sprite scaled to the radius it was last drawn with (see draw()).
        self._image = None
        self._image_radius = None

    def reset_position(self):
        """Pick a new random radius and height and place the asteroid just off the right edge."""
        self.radius = random.randint(ASTEROID_MIN_RADIUS, ASTEROID_MAX_RADIUS)
        self.x = WIDTH + self.radius
        self.y = random.randint(0, HEIGHT)

    def update(self, dt):
        """Advance the asteroid by ``dt`` simulated seconds.

        The spawn delay is measured in simulated time (the sum of ``dt``)
        rather than with ``pygame.time.get_ticks()``. Using the wall clock made
        spawning depend on how fast the environment is stepped (e.g. training
        without rendering vs. 144 fps play). pygame also documents get_ticks()
        as returning 0 before ``pygame.init()`` has been called.
        """
        self.elapsed += dt
        if self.elapsed > self.spawn_delay:
            self.x -= self.vx * dt * 100
            if self.x < -self.radius * 2:
                self.reset_position()

    def draw(self, screen):
        """Blit the asteroid sprite, scaled to its diameter and centred on (x, y)."""
        # Rescale only when the radius changed (shrink/reset), not on every frame.
        if self._image_radius != self.radius:
            self._image = pygame.transform.scale(asteroid_base_image, (self.radius * 2, self.radius * 2))
            self._image_radius = self.radius
        screen.blit(self._image, (int(self.x - self.radius), int(self.y - self.radius)))

    def shrink(self):
        """Reduce the radius to 75% (truncated to int); recycle the asteroid if it drops below 20."""
        self.radius = int(self.radius * 0.75)
        if self.radius < 20:
            self.reset_position()

    def get_info(self):
        """Return ``(x, y, radius)``, the asteroid's part of the observation."""
        return self.x, self.y, self.radius

def check_ship_collision(ship, asteroids):
    """Return True as soon as any asteroid circle touches any of the ship's hit boxes.

    Asteroids are scanned in order and, for each one, every collision box of
    the ship; the search stops at the first overlap.
    """
    hit_boxes = ship.get_collision_boxes()
    return any(
        Collision_cicle_rect((rock.x, rock.y), rock.radius, (box.centerx, box.centery), box.width, box.height)
        for rock in asteroids
        for box in hit_boxes
    )

def Collision_cicle_rect(circle_center, circle_radius, rect_center, rect_width, rect_height):
    """Return True if a circle and an axis-aligned rectangle overlap (touching counts).

    The rectangle is described by its centre and its full width and height.
    The circle centre is clamped onto the rectangle to get the nearest point.
    The squared distance to that point is compared with the squared radius,
    which avoids a square root.
    """
    cx, cy = circle_center
    rx, ry = rect_center
    half_w = rect_width / 2
    half_h = rect_height / 2

    # Nearest point of the rectangle to the circle centre: clamp each axis.
    nearest_x = max(rx - half_w, min(cx, rx + half_w))
    nearest_y = max(ry - half_h, min(cy, ry + half_h))

    dist_sq = (cx - nearest_x) ** 2 + (cy - nearest_y) ** 2
    return dist_sq <= circle_radius ** 2

def check_missile_collision(missiles, asteroids):
    """Resolve missile/asteroid hits, mutating ``missiles`` in place.

    Each missile is tested against the asteroids in order. The first asteroid
    whose centre is strictly closer than its radius is shrunk, and the missile
    is removed from ``missiles``. A missile can hit at most one asteroid.
    """
    for shot in list(missiles):  # iterate a snapshot because we remove from the list
        target = next(
            (rock for rock in asteroids
             if math.hypot(shot.x - rock.x, shot.y - rock.y) < rock.radius),
            None,
        )
        if target is None:
            continue
        target.shrink()
        missiles.remove(shot)

Environnement

In [32]:


class AsteroidEnv(gym.Env):
    """Gymnasium environment: keep the ship alive among asteroids scrolling right to left.

    Observation (float32, length ``1 + 3 * asteroidNbr``):
        ``[ship_y, a0_x, a0_y, a0_r, a1_x, a1_y, a1_r, ...]``.
        Ship x, ship speed and asteroid speed are constants and not observed.
    Action: three 0/1 flags ``(up, down, shoot)``. ``up`` wins if both ``up``
        and ``down`` are set. ``shoot`` fires one missile per step.
    Reward: +1 for every step.
    Termination: the ship's collision boxes touch an asteroid.
    """
    metadata = {"render_modes": ["rgb_array"], "render_fps": 144}

    def __init__(self, worldSize=(WIDTH, HEIGHT), asteroidNbr=ASTEROID_NBR, render_mode=None):
        super().__init__()

        self.worldSize = worldSize
        self.width = worldSize[0]
        self.height = worldSize[1]
        self.asteroidNbr = asteroidNbr
        self.dt = 1 / self.metadata["render_fps"]  # one simulated step per rendered frame

        self.t = 0
        # State vector: ship y followed by (x, y, radius) per asteroid.
        # float32 so that observations match observation_space.dtype.
        self.x = np.zeros(1 + 3 * self.asteroidNbr, dtype=np.float32)
        self.u = np.zeros(3)  # last action: up, down, shoot

        self.ship = Ship()
        self.asteroids = [Asteroid() for _ in range(asteroidNbr)]
        self.missiles = []

        # Asteroids keep moving left until x < -2 * radius before being
        # recycled, so their x component can be negative.
        low = np.array([0.0] + [-2.0 * ASTEROID_MAX_RADIUS, 0.0, 0.0] * self.asteroidNbr)
        high = np.array(
            [self.height - self.ship.height]
            + [self.width + ASTEROID_MAX_RADIUS * 2, self.height, ASTEROID_MAX_RADIUS] * self.asteroidNbr
        )
        self.observation_space = Box(
            low=low,
            high=high,
            shape=(1 + 3 * self.asteroidNbr,),
            dtype=np.float32,
        )
        self.action_space = Tuple((Discrete(2), Discrete(2), Discrete(2)))  # up, down, shoot

        self.render_mode = render_mode

    def _sync_state(self):
        """Copy the ship's y and every asteroid's (x, y, radius) into ``self.x``."""
        self.x[0] = self.ship.y
        self.x[1:] = [coord for a in self.asteroids for coord in a.get_info()]

    def _dynamic_step(self, action):
        """Advance the simulation by one time step ``self.dt`` under ``action``."""
        self.t += self.dt
        self.u = action

        # Move ship: up has priority over down; neither means hover.
        if self.u[0]:
            self.ship.vy = -self.ship.speed
        elif self.u[1]:
            self.ship.vy = self.ship.speed
        else:
            self.ship.vy = 0
        self.ship.update(self.dt)

        # Move asteroids
        for asteroid in self.asteroids:
            asteroid.update(self.dt)

        # Fire from the ship's right edge, at mid-height
        if self.u[2]:
            self.missiles.append(Missile(self.ship.x + self.ship.width, self.ship.y + self.ship.height // 2))

        # Move every missile, then drop those that left the world. Filtering
        # after the loop avoids list.remove() during iteration, which skipped
        # (did not move) the missile following each removed one.
        for missile in self.missiles:
            missile.update(self.dt)
        self.missiles[:] = [m for m in self.missiles if m.x <= self.width]

        # Missile hits may shrink or recycle asteroids
        check_missile_collision(self.missiles, self.asteroids)

        # Sync last so the observation reflects post-collision asteroid state,
        # i.e. the same state that _should_terminate() inspects.
        self._sync_state()

    def _g(self):
        """Per-step reward: a constant +1 for every step (survival reward)."""
        return 1

    def _get_observation(self):
        """Return a copy of the state vector.

        A copy is returned so observations kept by the caller (e.g. in a replay
        buffer) are not overwritten by later steps.
        """
        return self.x.copy()

    def _should_truncate(self):
        """The environment never truncates by itself.

        Episode-length limits are expected to come from wrappers, such as the
        one gym.make adds for ``max_episode_steps``.
        """
        return False

    def _should_terminate(self):
        """Terminate when the ship collides with any asteroid."""
        return check_ship_collision(self.ship, self.asteroids)

    def _get_info(self):
        """Auxiliary info: the simulated time ``t`` in seconds."""
        return {
            "t": self.t
        }

    def reset(self, seed=None, options=None):
        """Start a new episode with a fresh ship, fresh asteroids and no missiles.

        Returns ``(observation, info)``.
        """
        super().reset(seed=seed)
        if seed is not None:
            # Asteroids draw radius, height and spawn delay from the stdlib
            # ``random`` module, so seed it too to make reset(seed=...) reproducible.
            random.seed(seed)

        self.t = 0

        self.ship = Ship()
        self.asteroids = [Asteroid() for _ in range(self.asteroidNbr)]
        self.missiles = []
        self.u = np.zeros(3)

        self._sync_state()

        return self._get_observation(), self._get_info()

    def step(self, action):
        """Apply ``action`` for one step.

        Returns ``(observation, reward, terminated, truncated, info)``.
        """
        self._dynamic_step(action)
        observation = self._get_observation()
        reward = self._g()
        terminate = self._should_terminate()
        truncate = self._should_truncate()
        info = self._get_info()

        return observation, reward, terminate, truncate, info

    def render(self):
        """Draw the current frame when ``render_mode == "rgb_array"``.

        Returns a ``pygame.surfarray.pixels3d`` array of a freshly drawn
        surface, indexed ``[x][y]`` (shape ``(width, height, 3)``). Returns
        None for any other render mode.
        """
        if self.render_mode == "rgb_array":

            size = pygame.Vector2(self.worldSize)
            surface = pygame.Surface(size, flags=pygame.SRCALPHA)

            # Drawing the ship
            self.ship.draw(surface)

            # Drawing the asteroids
            for asteroid in self.asteroids:
                asteroid.draw(surface)

            # Drawing the missiles
            for missile in self.missiles:
                missile.draw(surface)

            return pygame.surfarray.pixels3d(surface)

In [33]:
# Register the environment under a Gymnasium id. The string entry point is
# resolved when gym.make is called, so it picks up the latest class definition.
gym.register(
    id="AsteroidEnv-v0",
    entry_point="__main__:AsteroidEnv",
)
# RGB-array rendering; gym.make caps episodes at 1000 steps.
env = gym.make(
    "AsteroidEnv-v0",
    render_mode="rgb_array",
    max_episode_steps=1000,
)

In [34]:
def human_agent(_obs):
    """Map the current keyboard state to an ``[up, down, shoot]`` action.

    The observation is ignored. Returns a length-3 integer numpy array of 0/1
    flags for the UP arrow, DOWN arrow and SPACE keys.
    """
    # Pump before reading so get_pressed() reflects the most recent key events
    # instead of the state left by the previous pump.
    pygame.event.pump()
    keys = pygame.key.get_pressed()

    up = int(keys[pygame.K_UP])
    down = int(keys[pygame.K_DOWN])
    shoot = int(keys[pygame.K_SPACE])

    return np.array([up, down, shoot])

In [35]:
# Play one episode with the keyboard (UP/DOWN to move, SPACE to shoot, Q to quit),
# printing every transition.
human_render_loop(agent=human_agent, env=env, print_step=True)

initial state
	observation [300. 864. 399.  64. 911. 171. 111. 914. 305. 114. 925. 122. 125. 924.
 408. 124. 886. 533.  86. 947. 329. 147. 863. 252.  63. 857. 145.  57.
 878. 441.  78.]
	info {'t': 0}
step 1
	action [0 0 0]
	observation [300. 864. 399.  64. 911. 171. 111. 914. 305. 114. 925. 122. 125. 924.
 408. 124. 886. 533.  86. 947. 329. 147. 863. 252.  63. 857. 145.  57.
 878. 441.  78.]
	reward 1
	terminated False
	truncated False
	info {'t': 0.006944444444444444}
step 2
	action [0 0 0]
	observation [300. 864. 399.  64. 911. 171. 111. 914. 305. 114. 925. 122. 125. 924.
 408. 124. 886. 533.  86. 947. 329. 147. 863. 252.  63. 857. 145.  57.
 878. 441.  78.]
	reward 1
	terminated False
	truncated False
	info {'t': 0.013888888888888888}
step 3
	action [0 0 0]
	observation [300. 864. 399.  64. 911. 171. 111. 914. 305. 114. 925. 122. 125. 924.
 408. 124. 886. 533.  86. 947. 329. 147. 863. 252.  63. 857. 145.  57.
 878. 441.  78.]
	reward 1
	terminated False
	truncated False
	info {'t':