In [2]:
!pip install flappy-bird-gymnasium


Collecting flappy-bird-gymnasium
  Downloading flappy_bird_gymnasium-0.4.0-py3-none-any.whl (37.3 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m37.3/37.3 MB[0m [31m23.0 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting gymnasium (from flappy-bird-gymnasium)
  Downloading gymnasium-0.29.1-py3-none-any.whl (953 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m953.9/953.9 kB[0m [31m42.2 MB/s[0m eta [36m0:00:00[0m
Collecting farama-notifications>=0.0.1 (from gymnasium->flappy-bird-gymnasium)
  Downloading Farama_Notifications-0.0.4-py3-none-any.whl (2.5 kB)
Installing collected packages: farama-notifications, gymnasium, flappy-bird-gymnasium
Successfully installed farama-notifications-0.0.4 flappy-bird-gymnasium-0.4.0 gymnasium-0.29.1


In [3]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import gymnasium as gym
from collections import deque
import random
import flappy_bird_gymnasium


In [4]:
env_id = 'FlappyBird-v0'
env = gym.make(env_id, render_mode="human", use_lidar=True)
eval_env = gym.make(env_id, render_mode="human", use_lidar=True)

s_size = env.observation_space.shape[0]
a_size = env.action_space.n

hidden_dim = 256
batch_size = 64
gamma = 0.99
epsilon = 0.1
n_episodes = 10000
max_t = 1000
print_every = 100

*Model*


In [5]:
class DQN(nn.Module):
    def __init__(self, state_dim, action_dim, hidden_dim):
        super(DQN, self).__init__()
        self.fc1 = nn.Linear(state_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.fc3 = nn.Linear(hidden_dim, action_dim)

    def forward(self, x):
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        x = self.fc3(x)
        return x

class ReplayBuffer:
    def __init__(self, buffer_size):
        self.buffer = deque(maxlen=buffer_size)

    def add(self, state, action, reward, next_state, done):
        self.buffer.append((np.array(state), action, reward, np.array(next_state), done))

    def sample(self, batch_size):
        samples = random.sample(self.buffer, batch_size)
        states, actions, rewards, next_states, dones = zip(*samples)
        states = torch.tensor(np.array(states), dtype=torch.float32)
        actions = torch.tensor(actions)
        rewards = torch.tensor(rewards)
        next_states = torch.tensor(np.array(next_states), dtype=torch.float32)
        dones = torch.tensor(dones)
        return states, actions, rewards, next_states, dones

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [None]:
policy = DQN(s_size, a_size, hidden_dim).to(device)
target_policy = DQN(s_size, a_size, hidden_dim).to(device)
optimizer = optim.Adam(policy.parameters(), lr=0.01)
replay_buffer = ReplayBuffer(50000)

scores = []
for i_episode in range(1, n_episodes+1):
    state = env.reset()[0]
    score = 0
    for t in range(max_t):
        state_tensor = torch.tensor(state, dtype=torch.float32).to(device)
        action = torch.argmax(policy(state_tensor))
        next_state, reward, done, info, _ = env.step(action.item())
        score += reward
        replay_buffer.add(state, action.item(), reward, next_state, done)
        state = next_state
        if done:
            break
    scores.append(score)

    if i_episode % 10 == 0:
        target_policy.load_state_dict(policy.state_dict())

    if len(replay_buffer.buffer) > batch_size:
        states, actions, rewards, next_states, dones = replay_buffer.sample(batch_size)
        states, actions, rewards, next_states, dones = states.to(device), actions.to(device), rewards.to(device), next_states.to(device), dones.to(device)
        q_values = policy(states)
        q_values = q_values.gather(1, actions.unsqueeze(1)).squeeze(1)
        next_q_values = target_policy(next_states)
        next_q_values, _ = next_q_values.max(1)
        dones = dones.float()
        targets = rewards + gamma * next_q_values * (1 - dones)
        loss = (q_values - targets).pow(2).mean()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    if i_episode % print_every == 0:
        print('Episode'+ str(i_episode) + '\tAverage Score:'+ str(round(np.mean(scores[-print_every:]), 2)))

  logger.warn(f"{pre} is not within the observation space.")
  logger.warn(f"{pre} is not within the observation space.")


Episode100	Average Score:-1.76
Episode200	Average Score:-1.49
Episode300	Average Score:-0.91
Episode400	Average Score:-1.21
Episode500	Average Score:-2.07
Episode600	Average Score:-1.91
Episode700	Average Score:-1.4
Episode800	Average Score:-0.94
Episode900	Average Score:-1.45
Episode1000	Average Score:-1.79
Episode1100	Average Score:1.67
Episode1200	Average Score:1.63
Episode1300	Average Score:0.26
Episode1400	Average Score:-0.26
Episode1500	Average Score:-0.41
Episode1600	Average Score:0.73
Episode1700	Average Score:-1.34
Episode1800	Average Score:-0.05
Episode1900	Average Score:0.09
Episode2000	Average Score:-0.1
Episode2100	Average Score:-1.39
Episode2200	Average Score:-0.95
Episode2300	Average Score:-0.32
Episode2400	Average Score:2.09
Episode2500	Average Score:-0.02
Episode2600	Average Score:1.65
Episode2700	Average Score:0.33
Episode2800	Average Score:0.85
Episode2900	Average Score:0.96
Episode3000	Average Score:-1.06
Episode3100	Average Score:-1.65
Episode3200	Average Score:0.1

In [None]:
torch.save(policy.state_dict(), 'dqn_policy.pth')