In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

import random
import math
import numpy as np

from collections import deque

import gymnasium as gym

In [11]:
class ReplayBuffer:
    """Bounded FIFO store of (state, action, reward, next_state, done) tuples.

    Transitions are kept in a ``deque`` created with ``maxlen=capacity``, so
    appending to a full buffer discards the oldest transition.
    """

    def __init__(self, capacity):
        """Create an empty buffer that holds at most ``capacity`` transitions."""
        self.memory = deque(maxlen=capacity)

    def __len__(self):
        """Return the number of transitions currently stored."""
        return len(self.memory)

    def push(self, state, action, reward, next_state, done):
        """Append one transition to the buffer."""
        self.memory.append((state, action, reward, next_state, done))

    def sample(self, batch_size):
        """Draw ``batch_size`` distinct transitions uniformly at random.

        Returns:
            A 5-tuple of tensors ``(states, actions, rewards, next_states, dones)``:
            states and next_states are float32 reshaped to ``(batch_size, -1)``,
            actions is int64 reshaped to ``(batch_size, -1)``, and rewards and
            dones are float32 reshaped to ``(batch_size,)``.

        Raises:
            ValueError: propagated from ``random.sample`` when ``batch_size``
                exceeds the number of stored transitions.
        """
        drawn = random.sample(self.memory, batch_size)
        # Transpose the list of transitions into one tuple per field.
        state_col, action_col, reward_col, next_state_col, done_col = zip(*drawn)

        def float_matrix(rows):
            # Stack through numpy first, then flatten each row into one vector.
            stacked = np.array(rows)
            return torch.tensor(stacked, dtype=torch.float32).reshape(batch_size, -1)

        def float_vector(values):
            return torch.tensor(values, dtype=torch.float32).reshape(batch_size)

        # Keep actions as a 2-D (batch_size, -1) tensor so they can be passed
        # straight to torch's .gather as an index.
        action_index = torch.tensor(action_col, dtype=torch.long).reshape(batch_size, -1)

        return (
            float_matrix(state_col),
            action_index,
            float_vector(reward_col),
            float_matrix(next_state_col),
            float_vector(done_col),
        )

In [12]:
class QNetwork(nn.Module):
    """Three-layer fully connected network producing one value per action.

    Layout: ``state_dims -> 64 -> 64 -> action_dims`` with a ReLU after each of
    the two hidden layers and no activation on the output layer.
    """

    def __init__(self, state_dims, action_dims):
        """Build the three linear layers.

        Args:
            state_dims: Number of input features of the first layer.
            action_dims: Number of output features of the last layer.
        """
        super().__init__()
        hidden_width = 64
        # Layers are created in forward order under the names fc1/fc2/fc3.
        self.fc1 = nn.Linear(state_dims, hidden_width)
        self.fc2 = nn.Linear(hidden_width, hidden_width)
        self.fc3 = nn.Linear(hidden_width, action_dims)

    def forward(self, x):
        """Run ``x`` through both hidden layers (ReLU) and the linear output layer."""
        activations = x
        for hidden_layer in (self.fc1, self.fc2):
            activations = F.relu(hidden_layer(activations))
        return self.fc3(activations)

In [5]:
# Create the Gymnasium environment registered under the id "LunarLander-v2";
# the training cell below interacts with it through env.reset() / env.step().
# NOTE(review): the networks are built as QNetwork(8, 4) — presumably matching
# this environment's observation and action sizes; confirm against
# env.observation_space / env.action_space.
# NOTE(review): newer Gymnasium releases may register this task under a
# different version suffix — confirm the id against the installed version.
env = gym.make("LunarLander-v2")

In [6]:
def select_action(q_values, step, start, end, decay):
    """Choose an action with an exponentially decaying epsilon-greedy rule.

    Epsilon is ``end + (start - end) * exp(-step / decay)``. With probability
    epsilon a uniformly random action index is returned, otherwise the index
    of the largest entry of ``q_values``.

    Args:
        q_values: Tensor of action values for a single state, shaped
            ``(n_actions,)`` or with leading singleton dimensions such as
            ``(1, n_actions)``.
        step: Step counter that drives the decay.
        start: Epsilon at ``step == 0``.
        end: Value epsilon approaches as ``step`` grows.
        decay: Time constant of the exponential decay, in steps.

    Returns:
        int: Index of the selected action.
    """
    epsilon = end + (start - end) * math.exp(-step / decay)

    # Count actions on the LAST axis. len() reports the first axis, which is 1
    # for a (1, n_actions) network output and would make exploration always
    # return action 0.
    n_actions = q_values.shape[-1]

    if random.random() < epsilon:
        return random.randrange(n_actions)
    # argmax over the flattened tensor equals the action index whenever all
    # leading dimensions are singletons.
    return torch.argmax(q_values).item()
def soft_update(online, target, tau):
    """Blend the online network's parameters into the target network in place.

    Every parameter of ``target`` is overwritten with
    ``tau * online_param + (1 - tau) * target_param``. Only tensors yielded by
    ``.parameters()`` are touched; ``online`` is left unchanged.

    Args:
        online: Module whose parameters are the source of the update.
        target: Module whose parameters are modified in place. Its parameters
            are paired with ``online``'s in ``.parameters()`` iteration order.
        tau: Interpolation weight given to the online parameters.

    Returns:
        None.
    """
    # torch.no_grad() keeps the in-place write out of the autograd graph
    # without reaching through `.data`, which is discouraged because writes
    # made through it are invisible to autograd's in-place checks.
    with torch.no_grad():
        for target_param, online_param in zip(target.parameters(), online.parameters()):
            target_param.copy_(tau * online_param + (1.0 - tau) * target_param)

In [None]:
# Double-DQN training cell: builds the replay buffer, the online/target
# networks and the optimizer, then trains on `env` for a fixed number of episodes.
replay_buffer = ReplayBuffer(1000)

# QNetwork(state_dims, action_dims): 8 input features, 4 action values.
online_network = QNetwork(8, 4)
target_network = QNetwork(8, 4)
# The target network starts as an exact copy of the online network.
target_network.load_state_dict(online_network.state_dict())
optimizer = optim.Adam(online_network.parameters(), lr=0.0001)

batch_size = 64
gamma = 0.99  # discount factor applied to the bootstrapped next-state value
tau = 0.005   # weight of the online parameters in each soft target update

global_step = 0
update_every = 4  # soft-update the target network when global_step is a multiple of this

num_episodes = 1000
# Exploration schedule passed to select_action.
epsilon_start = 0.9
epsilon_end = 0.05
epsilon_decay = 1000

for episode in range(num_episodes):
    state, info = env.reset()
    done = False

    while not done:
        global_step += 1

        # Acting needs no gradients; skip building an autograd graph here.
        with torch.no_grad():
            policy_q_values = online_network(torch.tensor(state, dtype=torch.float32))
        action = select_action(
            policy_q_values,
            global_step,
            start=epsilon_start,
            end=epsilon_end,
            decay=epsilon_decay,
        )

        next_state, reward, terminated, truncated, _ = env.step(action)
        # The episode loop stops on either signal ...
        done = terminated or truncated
        # ... but only true termination is stored as the "done" flag. A
        # truncated episode (e.g. a time limit) did not reach a terminal state,
        # so its TD target must still bootstrap from next_state.
        replay_buffer.push(state, action, reward, next_state, terminated)

        if len(replay_buffer) >= batch_size:
            states, actions, rewards, next_states, dones = replay_buffer.sample(
                batch_size=batch_size
            )

            # Q_online(s_t, a_t) for the actions that were actually taken.
            predicted_q = online_network(states).gather(1, actions).reshape(batch_size)

            # Double-DQN target: the online network picks the next action,
            #   a* = argmax_a Q_online(s_{t+1}, a)
            # and the target network evaluates it,
            #   y = r + gamma * Q_target(s_{t+1}, a*) * (1 - done)
            with torch.no_grad():
                next_actions = online_network(next_states).argmax(dim=1).reshape(batch_size, -1)
                next_q_values = target_network(next_states).gather(1, next_actions).reshape(batch_size)
                target_q_values = rewards + gamma * next_q_values * (1 - dones)

            loss = F.mse_loss(predicted_q, target_q_values)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if global_step % update_every == 0:
                soft_update(online_network, target_network, tau)

        state = next_state