# DSA - Deep Learning [5] - Reinforcement learning

In [13]:
# Install necessary libraries
!pip install flappy-bird-gymnasium pygame
!apt-get install -y xvfb python3-opengl ffmpeg
!pip install pyvirtualdisplay
!pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu

Reading package lists... Done
Building dependency tree... Done
Reading state information... Done
python3-opengl is already the newest version (3.1.5+dfsg-1).
ffmpeg is already the newest version (7:4.4.2-0ubuntu0.22.04.1).
xvfb is already the newest version (2:21.1.4-2ubuntu1.7~22.04.12).
0 upgraded, 0 newly installed, 0 to remove and 49 not upgraded.
Looking in indexes: https://download.pytorch.org/whl/cpu


In [14]:
# Import necessary libraries
import os
import torch
import random
import numpy as np
import pygame
import imageio
from IPython.display import display, Image
from PIL import Image as PILImage  # Importing PIL for image manipulation
from flappy_bird_gymnasium.envs.flappy_bird_env import FlappyBirdEnv

# Set environment variables for rendering and audio in Colab
os.environ["SDL_VIDEODRIVER"] = "dummy"
os.environ["SDL_AUDIODRIVER"] = "dummy"


In [15]:
class CustomFlappyBirdEnv(FlappyBirdEnv):
    def __init__(self):
        super().__init__()

        # Initialize pygame and enforce dummy display
        pygame.init()
        if not pygame.display.get_init():
            pygame.display.init()
        pygame.display.set_mode((1, 1))  # Enforce dummy video mode

        # Initialize pygame mixer for audio
        if not pygame.mixer.get_init():
            pygame.mixer.init()

        # Initialize game surface
        self._surface = pygame.Surface((288, 512))  # Game surface dimensions

        # Initialize display surface (required for FlappyBirdEnv rendering)
        self._display = pygame.display.set_mode((288, 512))  # Create display window of appropriate size

        # Initialize the FPS clock for controlling the frame rate
        self._fps_clock = pygame.time.Clock()  # Initialize the FPS clock

        # Initialize image assets
        self._images = {}

        # Load images required for the game
        self._images["background"] = self._load_image("background-day.png")
        self._images["pipe"] = [
            self._load_image("pipe-green.png"),  # Top pipe
            pygame.transform.flip(self._load_image("pipe-green.png"), False, True)  # Bottom pipe (flipped)
        ]
        self._images["base"] = self._load_image("base.png")
        self._images["player"] = [
            self._load_image("yellowbird-upflap.png"),
            self._load_image("yellowbird-midflap.png"),
            self._load_image("yellowbird-downflap.png"),
        ]
        self._images["numbers"] = {
            i: self._load_image(f"{i}.png") for i in range(10)  # Load images for digits 0-9
        }

        # Load audio assets if needed
        self._audio = {
            "wing": self._load_audio("wing.wav"),
            "point": self._load_audio("point.wav"),
            "hit": self._load_audio("hit.wav"),
            "die": self._load_audio("die.wav"),
        }

        # Additional attributes required by the parent class
        self._score = 0
        self._player_index = 0
        self._base_shift = self._images["base"].get_width() - self._surface.get_width()
        self._pipes = []
        self._player_y = 256
        self._player_velocity_y = 0
        self._gravity = 1
        self._pipe_gap = 100

    def _load_image(self, filename):
        """
        Load an image from the assets directory.
        Args:
            filename: Name of the image file.
        Returns:
            Loaded pygame image.
        """
        assets_path = "/usr/local/lib/python3.10/dist-packages/flappy_bird_gymnasium/assets/sprites"
        filepath = os.path.join(assets_path, filename)
        return pygame.image.load(filepath).convert_alpha()

    def _load_audio(self, filename):
        """
        Load an audio file from the assets directory.
        Args:
            filename: Name of the audio file.
        Returns:
            Loaded pygame audio sound.
        """
        assets_path = "/usr/local/lib/python3.10/dist-packages/flappy_bird_gymnasium/assets/audio"
        filepath = os.path.join(assets_path, filename)
        return pygame.mixer.Sound(filepath)

    def render(self):
        """
        Render the game screen to the display and capture the frame for Colab visualization.
        """
        super().render()  # Call the parent class's render method

        # Capture the screen as an array
        frame = pygame.surfarray.array3d(pygame.display.get_surface())
        self.frames.append(frame)  # Save the frame for GIF creation

        # Control frame rate
        self._fps_clock.tick(self.metadata["render_fps"])

    def create_gif(self, gif_name="flappy_bird_game.gif"):
        """
        Create and display a GIF from the captured frames.
        """
        flipped_frames = []
        for frame in self.frames:
            pil_frame = PILImage.fromarray(frame)
            flipped_frame = pil_frame.rotate(270, expand=True)  # Rotate 270 degrees
            flipped_frames.append(flipped_frame)

        # Save and display the GIF
        flipped_gif_name = gif_name.replace(".gif", "_flipped.gif")
        imageio.mimsave(flipped_gif_name, flipped_frames, duration=1 / self.metadata["render_fps"])
        display(Image(flipped_gif_name))

    def reset(self):
        """
        Reset the environment and clear the stored frames.
        """
        self.frames = []  # Clear captured frames
        return super().reset()


In [16]:
# define the initializing function --> refer to DQN Learning slide
class DQN(torch.nn.Module):
  def __init__(self, state_size, action_size):
      super(DQN, self).__init__()
      self.fc1 = torch.nn.Linear(state_size, 128)
      self.fc2 = torch.nn.Linear(128, 128)
      self.fc3 = torch.nn.Linear(128, action_size)

  def forward(self, x):  # forward propogation function passes data into neural network layers to produce an input
      x = torch.relu(self.fc1(x))
      x = torch.relu(self.fc2(x))
      return self.fc3(x)

In [17]:
def preprocess_state(state): # collect every single frame of the game
  if isinstance(state, tuple):
    observation = state[0] # state is in a cuboid format (type of data format) and convert into an array
  else:
    observation = state

  observation = np.array(observation, dtype=np.float32) # where the position of the pipes, players are etc.

  if observation.max() > 1.0:
    observation = observation / 255.0  # 255.0 is the normalising algorithm (for data points to have the same scale) for efficiency purposes

  return observation.flatten() # becomes a 1D array so that it's easier to be used for other data further on. less complex for calculations also.

In [18]:
# making a training policy for the agent, so that don't have to train the agent every time you open the game
def train_dqn(env, num_episodes, model_save_path="flappy_bird_dqn_final.pth", checkpoint_interval=100):
    """
    Train the DQN model on the environment.
    Args:
        env: The Flappy Bird environment.
        num_episodes: Number of training episodes.
        model_save_path: File path to save the final trained model.
        checkpoint_interval: Number of episodes between saving model checkpoints.
    Returns:
        Trained model.
    """
    # Determine state and action sizes dynamically
    state = preprocess_state(env.reset())
    state_size = state.shape[0]
    action_size = env.action_space.n

    # Initialize the DQN model
    model = DQN(state_size, action_size)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    loss_fn = torch.nn.MSELoss()

    # Replay buffer for storing experiences
    replay_buffer = []
    max_buffer_size = 1000

    # Hyperparameters
    gamma = 0.99  # Discount factor
    epsilon = 1.0  # Initial exploration rate
    epsilon_min = 0.01  # Minimum exploration rate
    epsilon_decay = 0.995  # Decay factor for exploration rate

    # Metrics
    rewards_per_episode = []  # Track rewards per episode
    steps_per_episode = []  # Track steps per episode

    for episode in range(num_episodes):
        state = preprocess_state(env.reset())  # Preprocess the initial state
        done = False
        total_reward = 0
        steps = 0

        while not done:
            # Convert state to tensor
            state_tensor = torch.FloatTensor(state).unsqueeze(0)

            # Select an action (epsilon-greedy policy)
            if random.random() < epsilon:  # Explore
                action = random.choice(range(action_size))
            else:  # Exploit the learned policy
                with torch.no_grad():
                    action = torch.argmax(model(state_tensor)).item()

            # Perform the action in the environment
            next_state, reward, done, _, _ = env.step(action)
            next_state = preprocess_state(next_state)  # Preprocess the next state
            total_reward += reward
            steps += 1

            # Store the experience in the replay buffer
            replay_buffer.append((state, action, reward, next_state, done))
            if len(replay_buffer) > max_buffer_size:
                replay_buffer.pop(0)  # Remove the oldest experience

            # Sample a random batch from the replay buffer for training
            if len(replay_buffer) >= 32:
                batch = random.sample(replay_buffer, 32)
                states, actions, rewards, next_states, dones = zip(*batch)

                # Convert batch to tensors
                states_tensor = torch.FloatTensor(states)
                actions_tensor = torch.LongTensor(actions).unsqueeze(1)
                rewards_tensor = torch.FloatTensor(rewards).unsqueeze(1)
                next_states_tensor = torch.FloatTensor(next_states)
                dones_tensor = torch.FloatTensor(dones).unsqueeze(1)

                # Compute Q values for the current states
                q_values = model(states_tensor).gather(1, actions_tensor)

                # Compute target Q values
                with torch.no_grad():
                    next_q_values = model(next_states_tensor).max(1, keepdim=True)[0]
                    targets = rewards_tensor + gamma * next_q_values * (1 - dones_tensor)

                # Compute loss
                loss = loss_fn(q_values, targets)

                # Backpropagation
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            # Move to the next state
            state = next_state

        # Decay exploration rate
        if epsilon > epsilon_min:
            epsilon *= epsilon_decay

        # Track metrics
        rewards_per_episode.append(total_reward)
        steps_per_episode.append(steps)

In [19]:
def load_model(env, model_path="flappy_bird_dqn_final.pth"):
    """
    Load the trained RL model.
    Args:
        env: The Flappy Bird environment.
        model_path: Path to the saved model file.
    Returns:
        Loaded model.
    """
    # Preprocess the initial state to determine its size
    state = preprocess_state(env.reset())
    state_size = state.shape[0]
    action_size = env.action_space.n

    # Initialize the DQN model with the correct dimensions
    model = DQN(state_size, action_size)
    model.load_state_dict(torch.load(model_path))
    model.eval()  # Set the model to evaluation mode
    return model

In [20]:
# each step is a frame in the game
def test_rl_agent_playing(env, model, step_limit = 500):
  state = preprocess_state(env.reset())
  for step in range(step_limit):
    state_tensor = torch.FloatTensor(state).unsqueeze(0)

    with torch.no_grad():
      action = torch.argmax(model(state_tensor)).item()

      next_state, reward, done, _. _ = env.step(action)
      next_state = preprocess_state(next_state)

      env.render()

      if done:
        break # ends the game

      state = next_state # if game hasn't ended, enters the next stafe

    env.create_gif()

In [21]:
env = CustomFlappyBirdEnv()

In [28]:
model = train_dqn(env, num_episodes = 500, model_save_path = "flappy_bird_dqn_final.pth", checkpoint_interval = 1000)

In [29]:
model = load_model(env, "flappy_bird_dqn_final.pth")

test_rl_agent_playing(env, model)

  model.load_state_dict(torch.load(model_path))


FileNotFoundError: [Errno 2] No such file or directory: 'flappy_bird_dqn_final.pth'