In [15]:
import gymnasium as gym
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import random
from collections import deque
import matplotlib.pyplot as plt

In [16]:
env = gym.make("CartPole-v1")

In [17]:
class QNetwork(nn.Module):
    def __init__(self, input_dim, output_dim):
        super(QNetwork, self).__init__()
        self.fc1 = nn.Linear(input_dim, 24)
        self.fc2 = nn.Linear(24, 24)
        self.fc3 = nn.Linear(24, output_dim)

    def forward(self, x):
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        return self.fc3(x)

policy_net = QNetwork(env.observation_space.shape[0], env.action_space.n)
target_net = QNetwork(env.observation_space.shape[0], env.action_space.n)
target_net.load_state_dict(policy_net.state_dict())

<All keys matched successfully>

In [18]:
num_episodes = 500
batch_size = 64
gamma = 0.99
epsilon = 1.0
epsilon_min = 0.01
epsilon_decay = 0.995
target_update_interval = 10

In [19]:
def select_action(state, policy_net, epsilon):
    if random.random() < epsilon:
        return env.action_space.sample()
    else:
        with torch.no_grad():
            return torch.argmax(policy_net(torch.FloatTensor(state))).item()

In [20]:
class ReplayBuffer:
    def __init__(self, max_size):
        self.buffer = deque(maxlen=max_size)

    def add(self, experience):
        self.buffer.append(experience)

    def sample(self, batch_size):
        return random.sample(self.buffer, batch_size)

    def __len__(self):
        return len(self.buffer)

replay_buffer = ReplayBuffer(max_size=10000)

In [23]:
optimizer = optim.Adam(policy_net.parameters())
rewards = []

for episode in range(num_episodes):
    state, info = env.reset()
    total_reward = 0
    done = False

    while not done:
        action = select_action(state, policy_net, epsilon)
        next_state, reward, done, truncated, info = env.step(action)

        replay_buffer.add((state, action, reward, next_state, done))
        state = next_state
        total_reward += reward

        if len(replay_buffer) > batch_size:
            experiences = replay_buffer.sample(batch_size)
            batch = list(zip(*experiences))
            states, actions, rewards_batch, next_states, dones = [torch.FloatTensor(x) for x in batch]

            q_values = policy_net(states)
            next_q_values = target_net(next_states)
            targets = rewards_batch + (1 - dones) * gamma * next_q_values.max(1)[0]

            optimizer.zero_grad()
            loss = nn.MSELoss()(q_values.gather(1, actions.unsqueeze(1)).squeeze(), targets.detach())
            loss.backward()
            optimizer.step()

    rewards.append(total_reward)

    if epsilon > epsilon_min:
        epsilon *= epsilon_decay

    if episode % target_update_interval == 0:
        target_net.load_state_dict(policy_net.state_dict())

    print(f"Episode {episode + 1}, Total Reward: {total_reward}")

RuntimeError: gather(): Expected dtype int64 for index