In [5]:
import gymnasium as gym
import torch
import torch.nn as nn
import torch.optim as optim
import random
import numpy as np
from collections import deque
import flappy_bird_gymnasium

# Define the neural network for approximating Q-values.
class DQN(nn.Module):
    """Fully connected network that maps an observation vector to Q-values.

    The network has two hidden layers with ReLU activations followed by a
    linear output layer that produces one value per action.

    Args:
        input_dim: Number of features in a single observation.
        output_dim: Number of discrete actions (size of the output layer).
        hidden_dim: Width of both hidden layers. Defaults to 128.
    """

    def __init__(self, input_dim, output_dim, hidden_dim=128):
        super().__init__()
        # Layer attribute names stay fc1/fc2/fc3 so state_dict keys are stable.
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.fc3 = nn.Linear(hidden_dim, output_dim)

    def forward(self, x):
        """Return the output-layer activations (Q-values) for input ``x``.

        The last dimension of ``x`` must equal ``input_dim``; the last
        dimension of the result equals ``output_dim``.
        """
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        # No activation on the output: Q-values are unbounded.
        return self.fc3(x)

# Experience Replay Memory to store transitions.
class ReplayMemory:
    """Fixed-capacity buffer of transitions with uniform random sampling.

    Backed by a ``deque`` with ``maxlen=capacity``: once the buffer is full,
    every new transition pushes out the oldest one.
    """

    def __init__(self, capacity):
        self.memory = deque(maxlen=capacity)

    def __len__(self):
        """Number of transitions currently held."""
        return len(self.memory)

    def push(self, transition):
        """Append ``transition`` to the newest end of the buffer."""
        self.memory.append(transition)

    def sample(self, batch_size):
        """Return a list of ``batch_size`` distinct stored transitions.

        Sampling is without replacement, so ``random.sample`` raises
        ``ValueError`` when ``batch_size`` exceeds the current length.
        """
        batch = random.sample(self.memory, batch_size)
        return batch

def _optimize_step(policy_net, target_net, optimizer, loss_fn, transitions, gamma):
    """Apply one gradient step of the one-step TD loss to ``policy_net``.

    Args:
        policy_net: Network being trained.
        target_net: Network used to evaluate next-state values (not updated here).
        optimizer: Optimizer holding ``policy_net``'s parameters.
        loss_fn: Callable taking (prediction, target) and returning a scalar loss.
        transitions: Sequence of ``(state, action, reward, next_state, terminal)``
            tuples, where ``terminal`` is truthy when no bootstrapping should
            happen from ``next_state``.
        gamma: Discount factor.
    """
    states, actions, rewards, next_states, terminals = zip(*transitions)

    # Stack each field into one ndarray before handing it to torch: building a
    # tensor straight from a tuple of ndarrays is a known slow path.
    states = torch.as_tensor(np.asarray(states, dtype=np.float32))
    actions = torch.as_tensor(np.asarray(actions, dtype=np.int64)).unsqueeze(1)
    rewards = torch.as_tensor(np.asarray(rewards, dtype=np.float32)).unsqueeze(1)
    next_states = torch.as_tensor(np.asarray(next_states, dtype=np.float32))
    terminals = torch.as_tensor(np.asarray(terminals, dtype=np.float32)).unsqueeze(1)

    # Q-values of the actions that were actually taken.
    current_q = policy_net(states).gather(1, actions)

    # The target is a constant w.r.t. the policy parameters, so skip autograd.
    with torch.no_grad():
        next_q = target_net(next_states).max(1)[0].unsqueeze(1)
        target_q = rewards + gamma * next_q * (1 - terminals)

    loss = loss_fn(current_q, target_q)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()


def train_dqn(num_episodes=2000, seed=None):
    """Train a DQN agent on the "FlappyBird-v0" environment.

    Uses epsilon-greedy exploration with per-episode decay, a replay buffer,
    and a target network that is synchronised every ``target_update`` episodes.
    Prints one progress line per episode and writes the final policy weights
    to ``dqn_flappy_bird.pth``.

    Args:
        num_episodes: Number of episodes to train for. Defaults to 2000.
        seed: Optional integer seed. When given, seeds ``random``, NumPy,
            PyTorch, the action space, and the first ``env.reset`` call.

    Returns:
        The trained policy network (a ``DQN`` instance).
    """
    # Hyperparameters
    batch_size = 64
    gamma = 0.99
    learning_rate = 1e-3
    epsilon_start = 1.0
    epsilon_end = 0.01
    epsilon_decay = 0.995  # Decay factor per episode
    target_update = 10     # Update target network every 10 episodes
    memory_capacity = 10000

    if seed is not None:
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)

    # Create the environment.
    # Replace "FlappyBird-v0" with the correct id for your environment.
    env = gym.make("FlappyBird-v0", render_mode=None, use_lidar=False)

    # try/finally guarantees the environment is closed even if training fails.
    try:
        if seed is not None:
            env.action_space.seed(seed)

        # Determine the sizes from the environment.
        obs_size = env.observation_space.shape[0]
        n_actions = env.action_space.n

        # Initialize the policy network and the target network.
        policy_net = DQN(obs_size, n_actions)
        target_net = DQN(obs_size, n_actions)
        target_net.load_state_dict(policy_net.state_dict())
        target_net.eval()  # Set target network to evaluation mode

        optimizer = optim.Adam(policy_net.parameters(), lr=learning_rate)
        loss_fn = nn.MSELoss()

        memory = ReplayMemory(memory_capacity)
        epsilon = epsilon_start

        # Main training loop.
        for episode in range(num_episodes):
            # Seed only the first reset; later resets continue the same RNG stream.
            state, _ = env.reset(seed=seed if episode == 0 else None)
            done = False
            total_reward = 0

            while not done:
                # Epsilon-greedy action selection.
                if random.random() < epsilon:
                    action = env.action_space.sample()
                else:
                    with torch.no_grad():
                        state_tensor = torch.as_tensor(
                            np.asarray(state, dtype=np.float32)
                        ).unsqueeze(0)
                        action = policy_net(state_tensor).argmax().item()

                # Take the action in the environment.
                next_state, reward, terminated, truncated, info = env.step(action)
                done = terminated or truncated
                total_reward += reward

                # Store `terminated`, not `done`, as the terminal flag: a
                # truncation (e.g. a time limit) ends the episode, but the next
                # state is not terminal, so the TD target must still bootstrap
                # from it.
                memory.push((state, action, reward, next_state, terminated))
                state = next_state

                # Update the network if we have enough samples.
                if len(memory) >= batch_size:
                    _optimize_step(
                        policy_net,
                        target_net,
                        optimizer,
                        loss_fn,
                        memory.sample(batch_size),
                        gamma,
                    )

            # Decay epsilon to reduce exploration over time.
            epsilon = max(epsilon_end, epsilon * epsilon_decay)

            # Periodically update the target network.
            if episode % target_update == 0:
                target_net.load_state_dict(policy_net.state_dict())

            print(f"Episode {episode} - Total Reward: {total_reward} - Epsilon: {epsilon:.3f}")

        # Save the trained model.
        torch.save(policy_net.state_dict(), "dqn_flappy_bird.pth")

        return policy_net
    finally:
        # Close the environment.
        env.close()

# Start training when this cell (or module) is executed as the entry point.
if __name__ == "__main__":
    train_dqn()

Episode 0 - Total Reward: -7.499999999999998 - Epsilon: 0.995
Episode 1 - Total Reward: -7.499999999999998 - Epsilon: 0.990
Episode 2 - Total Reward: -8.7 - Epsilon: 0.985
Episode 3 - Total Reward: -8.099999999999998 - Epsilon: 0.980
Episode 4 - Total Reward: -8.099999999999998 - Epsilon: 0.975
Episode 5 - Total Reward: -8.099999999999998 - Epsilon: 0.970
Episode 6 - Total Reward: -8.099999999999998 - Epsilon: 0.966
Episode 7 - Total Reward: -7.499999999999998 - Epsilon: 0.961
Episode 8 - Total Reward: -8.099999999999998 - Epsilon: 0.956
Episode 9 - Total Reward: -8.099999999999998 - Epsilon: 0.951
Episode 10 - Total Reward: -7.499999999999998 - Epsilon: 0.946
Episode 11 - Total Reward: -8.099999999999998 - Epsilon: 0.942
Episode 12 - Total Reward: -8.099999999999998 - Epsilon: 0.937
Episode 13 - Total Reward: -8.099999999999998 - Epsilon: 0.932
Episode 14 - Total Reward: -8.099999999999998 - Epsilon: 0.928
Episode 15 - Total Reward: -5.699999999999998 - Epsilon: 0.923
Episode 16 - Tot

In [None]:
# `policy_net` is a local variable of train_dqn() and is not defined at
# notebook scope, so referencing it here raises NameError on a fresh kernel.
# Re-save the weights from the checkpoint that train_dqn() writes instead.
state_dict = torch.load("dqn_flappy_bird.pth")
torch.save(state_dict, "trained_dqn.pth")
print("Trained model saved as trained_dqn.pth")