# CQL-DQN
- Q(s,a)를 표현하는 심층 신경망(DeepNet) 하나를 사용
- 행동 a의 개수가 고정되어 있어야 함: a는 이산(discrete) 공간
- (예) CartPole은 행동이 2개임 (0 또는 1)
- (참고) a가 연속(continuous) 공간이면 SAC 사용
- (ref) https://github.com/BY571/CQL/tree/main

In [None]:
import gymnasium as gym
import numpy as np
from collections import deque
import argparse
import pdb

import torch
import torch.nn as nn
#from networks import DDQN
import torch.optim as optim
import torch.nn.functional as F
from torch.nn.utils import clip_grad_norm_
import random

In [11]:
# config = get_config()

In [15]:
# Environment and compute device.
env = gym.make("CartPole-v1")
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

# Epsilon-greedy exploration state and bookkeeping counters.
eps = 1.0
d_eps = 1 - 0.01  # per-step factor used in the (commented) linear epsilon schedule
steps = 0
total_steps = 0
average10 = deque(maxlen=10)  # bounded window of the 10 most recent values appended

In [55]:
# Online CQL agent sized from the environment's spaces.
agent = CQLAgent(
    state_size=env.observation_space.shape,  # CartPole-v1: (4,)
    action_size=env.action_space.n,          # CartPole-v1: 2 actions (0 or 1)
    device=device,
)

# Replay buffer, pre-filled with 10_000 transitions taken with env.action_space.sample().
buffer = ReplayBuffer(buffer_size=100_000, batch_size=32, device=device)
collect_random(env=env, dataset=buffer, num_samples=10_000)

In [56]:
# One interaction step followed by one learning step, as a smoke test of the pipeline.
# gymnasium's reset() returns (observation, info); keep only the observation so that
# the greedy branch of get_action and the replay buffer receive a plain array.
state, _ = env.reset()
episode_steps = 0
rewards = 0

# Epsilon-greedy: random action with probability eps (1.0 initially), otherwise the
# action with the highest predicted Q-value.
action = agent.get_action(state, epsilon=eps)
print(action)
steps += 1
next_state, reward, done, truncated, info = env.step(action[0])
buffer.add(state, action, reward, next_state, done)
loss, cql_loss, bellmann_error = agent.learn(buffer.sample())

# state = next_state
# rewards += reward
# episode_steps += 1
# eps = max(1 - ((steps*d_eps)/config.eps_frames), config.min_eps)


[np.int64(0)]


In [None]:
# state, _ = env.reset()
# episode_steps = 0
# rewards = 0
# while True:
#     action = agent.get_action(state, epsilon=eps)
#     steps += 1
#     next_state, reward, done, truncated, _ = env.step(action[0])
#     buffer.add(state, action, reward, next_state, done)
#     loss, cql_loss, bellmann_error = agent.learn(buffer.sample())
#     state = next_state
#     rewards += reward
#     episode_steps += 1
#     # Linear epsilon decay: starts at 1 and shrinks toward config.min_eps (0.01 by default)
#     eps = max(1 - ((steps*d_eps)/config.eps_frames), config.min_eps)
#     if done or truncated:
#         break

In [9]:
# def get_config():
#     parser = argparse.ArgumentParser(description='RL')
#     parser.add_argument("--run_name", type=str, default="CQL-DQN", help="Run name, default: CQL-DQN")
#     parser.add_argument("--env", type=str, default="CartPole-v0", help="Gym environment name, default: CartPole-v0")
#     parser.add_argument("--episodes", type=int, default=400, help="Number of episodes, default: 400")
#     parser.add_argument("--buffer_size", type=int, default=100_000, help="Maximal training dataset size, default: 100_000")
#     parser.add_argument("--seed", type=int, default=1, help="Seed, default: 1")
#     parser.add_argument("--min_eps", type=float, default=0.01, help="Minimal Epsilon, default: 0.01")
#     parser.add_argument("--eps_frames", type=int, default=10_000, help="Number of steps for annealing the epsilon value to the min epsilon, default: 10_000")
#     parser.add_argument("--log_video", type=int, default=0, help="Log agent behaviour to wandb when set to 1, default: 0")
#     parser.add_argument("--save_every", type=int, default=100, help="Saves the network every x epochs, default: 100")
    
#     args = parser.parse_args()
#     return args

In [61]:
# Sanity check: compare torch's logsumexp with a hand-written log(sum(exp(x))) per row.
q_values = torch.tensor(
    [
        [1.0, 2.0, 3.0],
        [0.5, 2.5, 1.5],
    ]
)
print(q_values.shape)

# torch.logsumexp is evaluated in a numerically stable way (exp(x) of large x is not
# materialised directly). The result behaves like a "soft maximum": close to the row
# maximum (3.0 -> ~3.41, 2.5 -> ~2.91) but also influenced by the other entries.
logsumexp = q_values.logsumexp(dim=1, keepdim=True)

(
    logsumexp,
    np.log(np.exp(1) + np.exp(2) + np.exp(3)),
    np.log(np.exp(0.5) + np.exp(2.5) + np.exp(1.5)),
)

torch.Size([2, 3])


(tensor([[3.4076],
         [2.9076]]),
 np.float64(3.40760596444438),
 np.float64(2.90760596444438))

In [54]:
class CQLAgent():
    """Discrete-action CQL agent (CQL-DQN) with an online and a target Q-network.

    ``self.network`` maps a state to one Q-value per action. ``learn`` minimises
    ``CQL regulariser + 0.5 * Bellman MSE`` on the online network and then
    soft-updates ``self.target_net`` toward it.

    Args:
        state_size: observation shape forwarded to ``DDQN`` (e.g. ``env.observation_space.shape``).
        action_size: number of discrete actions.
        hidden_size: width of the hidden layers of both networks.
        device: torch device holding the networks; batches are moved here in ``learn``.
        tau: soft-update coefficient for the target network.
        gamma: discount factor of the Bellman target.
        lr: Adam learning rate for the online network.
    """

    def __init__(self, state_size, action_size, hidden_size=256, device="cpu",
                 tau=1e-3, gamma=0.99, lr=1e-3):
        self.state_size = state_size
        self.action_size = action_size
        self.device = device
        self.tau = tau
        self.gamma = gamma

        self.network = DDQN(state_size=self.state_size,
                            action_size=self.action_size,
                            layer_size=hidden_size
                            ).to(self.device)

        self.target_net = DDQN(state_size=self.state_size,
                               action_size=self.action_size,
                               layer_size=hidden_size
                               ).to(self.device)
        # Start the target as an exact copy of the online network. Otherwise the
        # bootstrap targets come from an unrelated random network, whose weights
        # only decay by a factor (1 - tau) per soft update.
        self.target_net.load_state_dict(self.network.state_dict())

        self.optimizer = optim.Adam(params=self.network.parameters(), lr=lr)

    def get_action(self, state, epsilon):
        """Choose an action epsilon-greedily.

        Args:
            state: a single observation (array-like). A gymnasium ``reset()`` result
                ``(obs, info)`` is also accepted and unwrapped to ``obs``.
            epsilon: probability of taking a uniformly random action.

        Returns:
            A length-1 sequence with the chosen action index: a numpy array in the
            greedy branch, a list in the random branch. Callers index it with ``[0]``.
        """
        if random.random() > epsilon:
            # Greedy branch: the action with the largest predicted Q(s, a).
            if isinstance(state, tuple) and len(state) == 2 and isinstance(state[1], dict):
                state = state[0]
            state_t = torch.as_tensor(np.asarray(state), dtype=torch.float32,
                                      device=self.device).unsqueeze(0)  # add batch axis
            self.network.eval()
            try:
                with torch.no_grad():
                    action_values = self.network(state_t)
            finally:
                self.network.train()
            action = np.argmax(action_values.cpu().numpy(), axis=1)
        else:
            # Exploration branch: uniformly random action.
            action = random.choices(np.arange(self.action_size), k=1)
        return action

    def cql_loss(self, q_values, current_action):
        """Compute the discrete CQL regulariser for a batch.

        Returns ``mean(logsumexp_a Q(s, a) - Q(s, a_taken))``. Minimising it lowers
        the soft maximum over all actions while raising the Q-value of the action
        actually taken in the data.

        Args:
            q_values: ``[B, A]`` Q-values of the online network.
            current_action: ``[B, 1]`` integer indices of the taken actions.
        """
        logsumexp = torch.logsumexp(q_values, dim=1, keepdim=True)  # [B, 1] soft maximum
        q_a = q_values.gather(1, current_action)  # [B, 1] Q of the taken action
        return (logsumexp - q_a).mean()

    def learn(self, experiences):
        """Run one gradient step of CQL-DQN on a sampled batch.

        Args:
            experiences: tuple ``(states, actions, rewards, next_states, dones)``,
                presumably from ``ReplayBuffer.sample``. Expected shapes for batch
                size B: states/next_states ``[B, obs_dim]``, actions ``[B, 1]``,
                rewards and dones ``[B, 1]``. Tensors are moved to ``self.device``
                and cast to the dtypes the networks need.

        Returns:
            ``(total_loss, cql_loss, bellman_error)`` as Python floats.
        """
        states, actions, rewards, next_states, dones = experiences
        states = states.to(self.device).float()
        next_states = next_states.to(self.device).float()
        actions = actions.to(self.device).long()
        rewards = rewards.to(self.device).float()
        dones = dones.to(self.device).float()

        with torch.no_grad():
            # Max over actions of the target network: [B, A] -> [B, 1].
            q_targets_next = self.target_net(next_states).max(dim=1, keepdim=True)[0]
            # One-step Bellman target; (1 - dones) removes the bootstrap term on terminal steps.
            q_targets = rewards + self.gamma * q_targets_next * (1 - dones)

        q_a_s = self.network(states)  # [B, A]
        # Pick, per row, the Q-value of the action stored in `actions` ([B, 1] indices).
        q_expected = q_a_s.gather(1, actions)  # [B, 1]

        cql1_loss = self.cql_loss(q_a_s, actions)
        bellman_error = F.mse_loss(q_expected, q_targets)
        q1_loss = cql1_loss + 0.5 * bellman_error

        self.optimizer.zero_grad()
        q1_loss.backward()
        clip_grad_norm_(self.network.parameters(), 1.)
        self.optimizer.step()

        # Move the target network slightly toward the online network.
        self.soft_update(self.network, self.target_net)
        return q1_loss.item(), cql1_loss.item(), bellman_error.item()

    def soft_update(self, local_model, target_model):
        """Polyak update, parameter-wise: ``target <- tau * local + (1 - tau) * target``."""
        with torch.no_grad():
            for target_param, local_param in zip(target_model.parameters(), local_model.parameters()):
                target_param.copy_(self.tau * local_param + (1.0 - self.tau) * target_param)

In [27]:
def save(args, save_name, model, wandb, ep=None):
    """Save ``model.state_dict()`` under ./trained_models/ and register the file with wandb.

    The file name is ``<args.run_name><save_name>[<ep>].pth``. The ``ep`` suffix is
    appended only when ``ep`` is not None, so ``ep=0`` still produces a suffix.

    Args:
        args: object with a ``run_name`` attribute (e.g. an argparse namespace).
        save_name: string appended to the run name.
        model: torch module whose ``state_dict()`` is written with ``torch.save``.
        wandb: object exposing ``save(path)``; presumably the wandb module.
        ep: optional episode/epoch number appended to the file name.
    """
    import os
    save_dir = './trained_models/'
    # exist_ok avoids the check-then-create race of a separate os.path.exists test.
    os.makedirs(save_dir, exist_ok=True)
    suffix = "" if ep is None else str(ep)
    path = os.path.join(save_dir, args.run_name + save_name + suffix + ".pth")
    torch.save(model.state_dict(), path)
    wandb.save(path)

def collect_random(env, dataset, num_samples=200):
    """Add ``num_samples`` transitions taken with random actions to ``dataset``.

    Uses the gymnasium API: ``reset()`` returns ``(obs, info)`` and ``step()``
    returns ``(obs, reward, terminated, truncated, info)``. Only ``terminated`` is
    passed to ``dataset.add`` as ``done``, because truncation (e.g. a time limit)
    is not a true terminal state. The environment is reset on either condition,
    so it is never stepped past the end of an episode.

    Args:
        env: gymnasium-style environment with ``action_space.sample()``.
        dataset: object with ``add(state, action, reward, next_state, done)``.
        num_samples: number of transitions to add.
    """
    state, _ = env.reset()
    for _ in range(num_samples):
        action = env.action_space.sample()
        next_state, reward, terminated, truncated, _ = env.step(action)
        dataset.add(state, action, reward, next_state, terminated)
        if terminated or truncated:
            state, _ = env.reset()
        else:
            state = next_state

In [5]:
class DDQN(nn.Module):
    """Two-hidden-layer ReLU MLP mapping a state vector to one Q-value per action.

    Args:
        state_size: input feature size, given either as an int or as a shape tuple
            whose first entry is the feature count (e.g. ``env.observation_space.shape``).
        action_size: number of discrete actions (output width).
        layer_size: width of both hidden layers.
    """

    def __init__(self, state_size, action_size, layer_size):
        super(DDQN, self).__init__()
        self.input_shape = state_size
        self.action_size = action_size
        # Accept a bare integer as well as a shape tuple such as (4,).
        in_features = int(state_size) if np.ndim(state_size) == 0 else state_size[0]
        self.head_1 = nn.Linear(in_features, layer_size)
        self.ff_1 = nn.Linear(layer_size, layer_size)
        self.ff_2 = nn.Linear(layer_size, action_size)

    def forward(self, input):
        """Return Q-values ``[..., action_size]`` for ``input`` of shape ``[..., in_features]``."""
        x = torch.relu(self.head_1(input))
        x = torch.relu(self.ff_1(x))
        return self.ff_2(x)

In [62]:
import numpy as np
import random
import torch
from collections import deque, namedtuple

class ReplayBuffer:
    """Fixed-size FIFO buffer of experience tuples with uniform random sampling."""

    def __init__(self, buffer_size, batch_size, device):
        """Initialize a ReplayBuffer object.

        Params
        ======
            buffer_size (int): maximum number of stored transitions; once full, the
                oldest entries are discarded (``deque(maxlen=...)``)
            batch_size (int): number of transitions returned by ``sample``
            device: torch device the sampled tensors are moved to
        """
        self.device = device
        self.memory = deque(maxlen=buffer_size)
        self.batch_size = batch_size
        self.experience = namedtuple("Experience", field_names=["state", "action", "reward", "next_state", "done"])

    @staticmethod
    def _obs(state):
        """Unwrap a gymnasium ``reset()`` result ``(obs, info_dict)`` to ``obs``; return anything else unchanged."""
        if isinstance(state, tuple) and len(state) == 2 and isinstance(state[1], dict):
            return state[0]
        return state

    def add(self, state, action, reward, next_state, done):
        """Add a new experience to memory."""
        self.memory.append(self.experience(state, action, reward, next_state, done))

    def sample(self):
        """Draw ``batch_size`` stored transitions uniformly at random, without replacement.

        Returns:
            ``(states, actions, rewards, next_states, dones)`` tensors on ``self.device``:
            states/next_states stacked along a new batch axis as float32, actions
            vertically stacked as int64, rewards and dones vertically stacked as float32.
        """
        experiences = [e for e in random.sample(self.memory, k=self.batch_size) if e is not None]

        states = torch.from_numpy(np.stack([self._obs(e.state) for e in experiences])).float().to(self.device)
        actions = torch.from_numpy(np.vstack([e.action for e in experiences])).long().to(self.device)
        rewards = torch.from_numpy(np.vstack([e.reward for e in experiences])).float().to(self.device)
        next_states = torch.from_numpy(np.stack([self._obs(e.next_state) for e in experiences])).float().to(self.device)
        dones = torch.from_numpy(np.vstack([e.done for e in experiences]).astype(np.uint8)).float().to(self.device)

        return (states, actions, rewards, next_states, dones)

    def __len__(self):
        """Return the current size of internal memory."""
        return len(self.memory)