<a href="https://colab.research.google.com/github/ezzeddinegasmi/DRL_comparative_study/blob/main/sac_breakout_colab_dynamic.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [5]:

# --- Installation et GPU ---
!pip install stable-baselines3[extra] pygame -q

import torch
print("GPU disponible ?", torch.cuda.is_available())
if torch.cuda.is_available():
    print("Nom du GPU :", torch.cuda.get_device_name(0))

# --- Environnement ---
import gym
from gym import spaces
import numpy as np
import random

class BreakoutContinuousEnv(gym.Env):
    metadata = {'render.modes': ['human']}

    def __init__(self):
        super(BreakoutContinuousEnv, self).__init__()
        self.screen_width = 400
        self.screen_height = 300
        self.paddle_width = 60
        self.paddle_height = 10
        self.ball_size = 8

        self.action_space = spaces.Box(low=-1.0, high=1.0, shape=(1,), dtype=np.float32)
        self.observation_space = spaces.Box(
            low=np.array([0, 0, -5, -5, 0]),
            high=np.array([self.screen_width, self.screen_height, 5, 5, self.screen_width]),
            dtype=np.float32
        )
        self.reset()

    def reset(self):
        self.paddle_x = self.screen_width / 2
        self.ball_x = self.screen_width / 2
        self.ball_y = self.screen_height / 2
        self.ball_vx = random.choice([-3, 3])
        self.ball_vy = -3
        self.score = 0
        self.done = False
        return self._get_obs()

    def _get_obs(self):
        return np.array([self.ball_x, self.ball_y, self.ball_vx, self.ball_vy, self.paddle_x], dtype=np.float32)

    def step(self, action):
        move = float(action[0]) * 10
        self.paddle_x = np.clip(self.paddle_x + move, 0, self.screen_width - self.paddle_width)

        self.ball_x += self.ball_vx
        self.ball_y += self.ball_vy

        if self.ball_x <= 0 or self.ball_x >= self.screen_width:
            self.ball_vx *= -1
        if self.ball_y <= 0:
            self.ball_vy *= -1

        if (self.ball_y + self.ball_size >= self.screen_height - self.paddle_height) and            (self.paddle_x <= self.ball_x <= self.paddle_x + self.paddle_width):
            self.ball_vy *= -1
            self.score += 1
            reward = 1.0
        elif self.ball_y > self.screen_height:
            self.done = True
            reward = -10.0
        else:
            reward = -0.01

        return self._get_obs(), reward, self.done, {}

    def render(self, mode='human'):
        import pygame
        if not hasattr(self, 'screen'):
            pygame.init()
            self.screen = pygame.display.set_mode((self.screen_width, self.screen_height))
            pygame.display.set_caption("Breakout Continuous")
            self.clock = pygame.time.Clock()

        self.screen.fill((0, 0, 0))
        pygame.draw.circle(self.screen, (255, 255, 255), (int(self.ball_x), int(self.ball_y)), self.ball_size)
        pygame.draw.rect(
            self.screen,
            (0, 255, 0),
            pygame.Rect(int(self.paddle_x), self.screen_height - self.paddle_height, self.paddle_width, self.paddle_height)
        )
        pygame.display.flip()
        self.clock.tick(60)

    def close(self):
        import pygame
        if hasattr(self, 'screen'):
            pygame.quit()

# --- Entraînement ---
from stable_baselines3.common.monitor import Monitor
from stable_baselines3 import SAC
import os

log_dir = "./logs/"
os.makedirs(log_dir, exist_ok=True)

env = BreakoutContinuousEnv()
env = Monitor(env, log_dir)

model = SAC("MlpPolicy", env, verbose=1, device="cuda")
model.learn(total_timesteps=50000)

# --- Affichage en temps réel des récompenses ---
import matplotlib.pyplot as plt
from stable_baselines3.common.results_plotter import load_results, ts2xy
from IPython.display import clear_output
import time

results = load_results(log_dir)

plt.ion()
fig, ax = plt.subplots(figsize=(12, 6))
for _ in range(10):
    results = load_results(log_dir)
    x, y = ts2xy(results, 'timesteps')
    ax.clear()
    ax.plot(x, y, label="Reward")
    ax.set_xlabel("Timesteps")
    ax.set_ylabel("Reward")
    ax.set_title("SAC Training Rewards")
    ax.legend()
    ax.grid(True)
    clear_output(wait=True)
    display(fig)
    time.sleep(1)

plt.ioff()
plt.show()

# --- Sauvegarde ---
model.save("sac_breakout_continuous")

# --- Test avec Pygame ---
env = BreakoutContinuousEnv()
model = SAC.load("sac_breakout_continuous", env=env, device="cuda")

obs = env.reset()
total_reward = 0
for step in range(300):
    env.render()
    action, _ = model.predict(obs, deterministic=True)
    obs, reward, done, _ = env.step(action)
    total_reward += reward
    if done:
        break
env.close()
print(f"Total reward: {total_reward}")


GPU disponible ? False


AssertionError: Expected env to be a `gymnasium.Env` but got <class '__main__.BreakoutContinuousEnv'>