In [1]:
# imports
import gym
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
from collections import namedtuple, deque
import random
import time

In [2]:
# Define the QNetwork class
class QNetwork(nn.Module):
    def __init__(self, state_size, action_size):
        super(QNetwork, self).__init__()
        self.fc1 = nn.Linear(state_size, 64)
        self.fc2 = nn.Linear(64, 64)
        self.fc3 = nn.Linear(64, action_size)

    def forward(self, x):
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        x = self.fc3(x)
        return x

In [3]:
# Replay buffer for experience replay
class ReplayBuffer:
    def __init__(self, buffer_size, batch_size):
        self.buffer = deque(maxlen=buffer_size)
        self.batch_size = batch_size
        self.experience = namedtuple("Experience", field_names=["state", "action", "reward", "next_state", "done"])
    
    def add(self, experience):
        self.buffer.append(experience)
    
    def sample(self):
        return random.sample(self.buffer, self.batch_size)
    
    def __len__(self):
        return len(self.buffer)

In [11]:
# Hyperparameters
buffer_size = 10000
batch_size = 128
gamma = 0.9
lr = 0.0025
target_update = 10
episodes = 500
epsilon_start = 1.0
epsilon_decay = 0.99
epsilon_min = 0.2
tau = 1e-3

# Create the environment
env = gym.make('CartPole-v1', render_mode='rgb_array')

# Define state and action sizes
state_size = env.observation_space.shape[0]
action_size = env.action_space.n

# Initialize the Q-network and target network
q_network = QNetwork(state_size, action_size)
target_network = QNetwork(state_size, action_size)
target_network.load_state_dict(q_network.state_dict())
# # target_network.eval()

# Define optimizer and loss function
optimizer = optim.Adam(q_network.parameters(), lr=lr)
loss_fn = nn.MSELoss()
replay_buffer = ReplayBuffer(buffer_size, batch_size)

  deprecation(
  deprecation(


In [5]:
def train_dqn():

    all_rewards = []

    for episode in range(episodes):
        state = env.reset()
        state = torch.tensor(state, dtype=torch.float32)
        episode_reward = 0
        # Decay epsilon
        epsilon = max(epsilon_min, epsilon_start * (epsilon_decay ** episode))

        terminated = False
        while not terminated:
            # Choose action using epsilon-greedy policy
            if np.random.random() > epsilon:
                with torch.no_grad():
                    q_values = q_network(state)
                    action = torch.argmax(q_values).item()
            else:
                action = np.random.randint(action_size)

            # Take action and observe next state and reward
            next_state, reward, terminated, _ = env.step(action)
            next_state = torch.tensor(next_state, dtype=torch.float32)
            replay_buffer.add((state, action, reward, next_state, terminated))
            state = next_state
            episode_reward += reward

            # Train Q-network if replay buffer has enough samples
            if len(replay_buffer) > batch_size:
                experiences = replay_buffer.sample()
                states, actions, rewards, next_states, dones = zip(*experiences)

                states = torch.stack(states)
                actions = torch.tensor(actions, dtype=torch.int64).unsqueeze(-1)
                rewards_tensor = torch.tensor(rewards, dtype=torch.float32).unsqueeze(-1)
                next_states = torch.stack(next_states)
                dones = torch.tensor(dones, dtype=torch.float32).unsqueeze(-1)

                # Compute target Q-values
                # with torch.no_grad():
                max_next_q_values = target_network(next_states).max(dim=1)[0].unsqueeze(1)
                targets = rewards_tensor + (gamma * max_next_q_values * (1 - dones))

                # Compute current Q-values
                q_values = q_network(states).gather(1, actions).squeeze()

                # Compute loss and update Q-network
                loss = nn.functional.mse_loss(q_values, targets)
                # print(f"Q-values: {q_values}, targets: {targets}, Action: {actions}")
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
        
        # Update target network
        # if episode % target_update == 0:
        #     target_network.load_state_dict(q_network.state_dict())

        for target_param, local_param in zip(target_network.parameters(), q_network.parameters()):
            target_param.data.copy_(tau * local_param.data + (1.0 - tau) * target_param.data)

        all_rewards.append(episode_reward)

        print(f'Episode {episode}, Reward: {episode_reward}, Epsilon: {epsilon}')


    env.close()
    torch.save(q_network.state_dict(), 'dqn_cartpole.pth')
    return all_rewards

In [6]:
rewards = train_dqn()

  if not isinstance(terminated, (bool, np.bool8)):
  loss = nn.functional.mse_loss(q_values, targets)


Episode 0, Reward: 11.0, Epsilon: 1.0
Episode 1, Reward: 11.0, Epsilon: 0.99
Episode 2, Reward: 20.0, Epsilon: 0.9801
Episode 3, Reward: 22.0, Epsilon: 0.970299
Episode 4, Reward: 34.0, Epsilon: 0.96059601
Episode 5, Reward: 12.0, Epsilon: 0.9509900498999999
Episode 6, Reward: 11.0, Epsilon: 0.941480149401
Episode 7, Reward: 22.0, Epsilon: 0.9320653479069899
Episode 8, Reward: 15.0, Epsilon: 0.9227446944279201
Episode 9, Reward: 20.0, Epsilon: 0.9135172474836408
Episode 10, Reward: 42.0, Epsilon: 0.9043820750088044
Episode 11, Reward: 12.0, Epsilon: 0.8953382542587164
Episode 12, Reward: 13.0, Epsilon: 0.8863848717161292
Episode 13, Reward: 28.0, Epsilon: 0.8775210229989678
Episode 14, Reward: 38.0, Epsilon: 0.8687458127689782
Episode 15, Reward: 15.0, Epsilon: 0.8600583546412884
Episode 16, Reward: 15.0, Epsilon: 0.8514577710948755
Episode 17, Reward: 49.0, Epsilon: 0.8429431933839268
Episode 18, Reward: 28.0, Epsilon: 0.8345137614500875
Episode 19, Reward: 22.0, Epsilon: 0.8261686238

In [16]:
# Load trained model
q_network = QNetwork(state_size=state_size, action_size=action_size)
q_network.load_state_dict(torch.load('dqn_cartpole_best.pth'))
q_network.eval()

QNetwork(
  (fc1): Linear(in_features=4, out_features=64, bias=True)
  (fc2): Linear(in_features=64, out_features=64, bias=True)
  (fc3): Linear(in_features=64, out_features=2, bias=True)
)

In [17]:
def render_cartpole(render_speed=0.05):
    state = env.reset()
    state = torch.tensor(state, dtype=torch.float32)
    terminated = False
    total_reward = 0
    while not terminated:
        env.render()
        with torch.no_grad():
            q_values = q_network(state)
            action = torch.argmax(q_values).item()
            print(f"State: {state}, Q-values: {q_values}, Action: {action}")
        next_state, reward, terminated, _ = env.step(action)
        state = torch.tensor(next_state, dtype=torch.float32)
        total_reward += reward

        time.sleep(render_speed)

    print('Total reward:', total_reward)
    env.close()

In [18]:
render_cartpole()

State: tensor([0.0089, 0.0447, 0.0237, 0.0168]), Q-values: tensor([1.2048, 1.2041]), Action: 0
State: tensor([ 0.0098, -0.1507,  0.0240,  0.3169]), Q-values: tensor([1.2029, 1.2044]), Action: 1
State: tensor([0.0068, 0.0441, 0.0303, 0.0319]), Q-values: tensor([1.2048, 1.2041]), Action: 0
State: tensor([ 0.0077, -0.1515,  0.0310,  0.3340]), Q-values: tensor([1.2029, 1.2044]), Action: 1
State: tensor([0.0046, 0.0432, 0.0376, 0.0512]), Q-values: tensor([1.2047, 1.2042]), Action: 0
State: tensor([ 0.0055, -0.1525,  0.0387,  0.3555]), Q-values: tensor([1.2029, 1.2045]), Action: 1
State: tensor([0.0024, 0.0421, 0.0458, 0.0753]), Q-values: tensor([1.2046, 1.2042]), Action: 0
State: tensor([ 0.0033, -0.1537,  0.0473,  0.3820]), Q-values: tensor([1.2029, 1.2045]), Action: 1
State: tensor([0.0002, 0.0408, 0.0549, 0.1046]), Q-values: tensor([1.2046, 1.2042]), Action: 0
State: tensor([ 0.0010, -0.1551,  0.0570,  0.4141]), Q-values: tensor([1.2029, 1.2045]), Action: 1
State: tensor([-0.0021,  0.039