In [None]:
import copy
import random
import collections
import warnings

import numpy as np

!pip install gymnasium
import gymnasium as gym

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch.distributions as distributions
import torch.distributions as dist

from tqdm import tqdm
import matplotlib.pyplot as plt

!pip install omegaconf
from omegaconf import OmegaConf

from gym.wrappers import RecordVideo

Install Ray

In [None]:
!pip install ray
import ray

# **Fix random seeds**

In [None]:
import random

seed = 40
deterministic = True

# Seed every RNG the notebook touches: Python's `random`, NumPy's global
# generator, and PyTorch on the CPU and on every visible CUDA device.
for seed_fn in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
    seed_fn(seed)

if deterministic:
    # Prefer reproducible cuDNN kernels over auto-tuned ones.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

warnings.filterwarnings('ignore')

In [4]:
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

Install MuJoCo

In [None]:
!pip show mujoco
!pip install --upgrade gymnasium[mujoco]

Gymnasium environment: HalfCheetah-v4

In [6]:
# HalfCheetah-v4 with the current positions included in the observation and
# frames rendered as RGB arrays (required later by the video recorder).
env = gym.make(
    'HalfCheetah-v4',
    render_mode='rgb_array',
    exclude_current_positions_from_observation=False,
)

In [7]:
def compute_log_prob(mean, log_std, action_raw):
    """Log-density of a tanh-squashed diagonal Gaussian sample.

    Args:
        mean: Gaussian mean, broadcast-compatible with ``action_raw``.
        log_std: Log standard deviation, broadcast-compatible with ``action_raw``.
        action_raw: Pre-tanh sample ``u``; the executed action is ``tanh(u)``.

    Returns:
        ``log N(u; mean, exp(log_std)) - log|d tanh(u)/du|``, summed over the last
        dimension (one value per sample).
    """
    import math  # local import keeps this helper self-contained

    # Diagonal Gaussian log-density, summed over the action dimensions.
    z_score = (action_raw - mean) / log_std.exp()
    gaussian_log_prob = -0.5 * (z_score.pow(2) + 2.0 * log_std + math.log(2.0 * math.pi))
    gaussian_log_prob = gaussian_log_prob.sum(dim=-1)

    # Change of variables for a = tanh(u). The identity
    #   log(1 - tanh(u)^2) = 2 * (log 2 - u - softplus(-2u))
    # is exact and stays accurate for large |u|, whereas log(1 - tanh(u)^2 + eps)
    # saturates at log(eps) as soon as tanh(u) rounds to +/-1 in float32.
    log_det_jacobian = (2.0 * (math.log(2.0) - action_raw - F.softplus(-2.0 * action_raw))).sum(dim=-1)

    return gaussian_log_prob - log_det_jacobian

In [8]:
@ray.remote(num_cpus=0.8, num_gpus=0.1)
class ReplayBuffer():
    """FIFO experience replay buffer, run as a Ray actor.

    Transitions are lists ``[state, action, next_action, reward, next_state, done]``
    kept in a deque capped at ``config["buffer_limit"]`` entries; when it is full,
    appending evicts the oldest transition.
    """

    def __init__(self, config):
        self.config = config
        self.buffer = collections.deque(maxlen=self.config["buffer_limit"])

    def put(self, transition):
        """Append one transition (evicting the oldest one when the deque is full)."""
        self.buffer.append(transition)

    def sample(self, n):
        """Draw ``n`` distinct transitions uniformly at random and batch them.

        Returns:
            Tuple of float tensors ``(states, actions, next_actions, rewards,
            next_states, done_masks)``. Rewards and masks have shape ``(n, 1)``.
            ``done_masks`` is 0.0 for transitions stored with a truthy done flag
            and 1.0 otherwise, i.e. it is already a "not done" multiplier.

        Raises:
            ValueError: if ``n`` exceeds the number of stored transitions.
        """
        mini_batch = random.sample(self.buffer, n)
        s_lst, a_lst, na_lst, r_lst, next_s_lst, done_mask_lst = [], [], [], [], [], []

        for s, a, na, r, next_s, done in mini_batch:
            s_lst.append(s.tolist())
            a_lst.append(a.tolist())
            na_lst.append(na.tolist())
            r_lst.append([r])
            next_s_lst.append(next_s.tolist())
            done_mask_lst.append([0.0 if done else 1.0])

        # Convert to tensors here so callers receive ready-to-use batches.
        return (
            torch.Tensor(s_lst),
            torch.Tensor(a_lst),
            torch.Tensor(na_lst),
            torch.Tensor(r_lst),
            torch.Tensor(next_s_lst),
            torch.Tensor(done_mask_lst),
        )

    def size(self):
        """Number of transitions currently stored."""
        return len(self.buffer)

    def remove_old_data(self, threshold):
        """Keep only transitions whose last element is greater than ``threshold``.

        NOTE(review): as filled by the training loop the last element is the done
        flag, so this filters on that flag rather than on age — confirm the
        intended field.
        """
        # config is accessed by key in __init__; attribute access
        # (`self.config.buffer_limit`) raised AttributeError on a plain dict.
        self.buffer = collections.deque(
            [transition for transition in self.buffer if transition[-1] > threshold],
            maxlen=self.config["buffer_limit"],
        )


In [9]:
def soft_update(target, source, tau):
    """Polyak-average ``source`` into ``target`` in place.

    Every target parameter becomes ``tau * source + (1 - tau) * target``; the two
    modules' parameters are paired positionally in ``parameters()`` order.
    """
    keep = 1.0 - tau
    for dst, src in zip(target.parameters(), source.parameters()):
        dst.data.copy_(tau * src.data + keep * dst.data)

In [10]:
class Actor(nn.Module):
  """Tanh-squashed Gaussian policy network (SAC-style actor).

  Reads ``state_dim``, ``hidden_dim`` and ``action_dim`` from ``config``, plus the
  optional ``log_std_min`` / ``log_std_max`` bounds on the predicted log standard
  deviation (defaults -20.0 and 2.0).
  """

  def __init__(self, config):
    super().__init__()
    self.config = config
    self.tanh = nn.Tanh()
    self.relu = nn.ReLU()
    self.softmax = nn.Softmax()  # NOTE(review): not used by forward
    self.dropout = nn.Dropout(0.2)  # NOTE(review): not used by forward

    self.actor_l1 = nn.Linear(config["state_dim"], config["hidden_dim"])
    self.actor_l2 = nn.Linear(config["hidden_dim"], config["hidden_dim"])
    self.actor_l3 = nn.Linear(config["hidden_dim"], config["hidden_dim"])
    self.actor_l4 = nn.Linear(config["hidden_dim"], config["hidden_dim"])

    self.actor_mean = nn.Linear(config["hidden_dim"], config["action_dim"])
    # Despite the name, this head predicts log(std); forward exponentiates it.
    self.actor_std = nn.Linear(config["hidden_dim"], config["action_dim"])

    self.log_std_min = config.get("log_std_min", -20.0)
    self.log_std_max = config.get("log_std_max", 2.0)

  def forward(self, state, deterministic=False):
    """Produce a squashed action and its log-probability for ``state``.

    Args:
      state: observation tensor whose last dimension is ``state_dim``.
      deterministic: if True, use the Gaussian mean instead of sampling
        (useful for evaluation rollouts). Defaults to sampling.

    Returns:
      ``(action, log_prob)``: ``action = tanh(u)`` lies in [-1, 1] and has the
      same leading shape as ``state``; ``log_prob`` is the squashed-Gaussian
      log-density from ``compute_log_prob`` (last dimension summed out).
    """
    x = state
    for layer in (self.actor_l1, self.actor_l2, self.actor_l3, self.actor_l4):
      x = self.relu(layer(x))

    mean_x = self.actor_mean(x)
    # Clamp in log space: std stays strictly positive and gradients still flow
    # inside the bounds. Clamping a raw linear output to [1e-10, 0.5] zeroed the
    # gradient for every negative output and produced log-stds of about -23.
    log_std_x = torch.clamp(self.actor_std(x), min=self.log_std_min, max=self.log_std_max)
    std_x = log_std_x.exp()

    if deterministic:
      z = mean_x
    else:
      z = dist.Normal(mean_x, std_x).rsample()  # reparameterization trick

    log_prob = compute_log_prob(mean_x, log_std_x, z)
    action = self.tanh(z)
    return action, log_prob

In [11]:
class Critic(nn.Module):
  """Q-network: maps a concatenated (state, action) vector to one scalar value.

  Reads ``state_dim``, ``action_dim`` and ``hidden_dim`` from ``config``.
  """

  def __init__(self, config):
    super().__init__()
    self.data = []
    self.config = config
    self.tanh = nn.Tanh()
    self.relu = nn.ReLU()
    self.swish = nn.SiLU()
    self.dropout = nn.Dropout(0.1)

    in_features = config["state_dim"] + config["action_dim"]
    width = config["hidden_dim"]
    self.critic_l1 = nn.Linear(in_features, width)
    self.critic_l2 = nn.Linear(width, width)
    self.critic_l3 = nn.Linear(width, width)
    self.critic_l4 = nn.Linear(width, 1)

  def forward(self, input):
    # Three ReLU hidden layers followed by a linear scalar head.
    hidden = input
    for layer in (self.critic_l1, self.critic_l2, self.critic_l3):
      hidden = self.relu(layer(hidden))
    return self.critic_l4(hidden)

In [12]:
config = {
        'critic_class': Critic,
        # NOTE(review): despite the key name this holds the env *instance* created above.
        'env_class': env,
        'lr_actor': 0.003,
        'lr_critic': 0.003,
        'num_episodes': 1000,
        'batch_size': 256,
        'buffer_limit': 3000000,   # max transitions kept by ReplayBuffer (deque maxlen)
        'gamma': 0.99,             # discount factor
        'tau': 0.01,               # soft-update (Polyak) coefficient for target networks
        'buffer_size': 10000,      # NOTE(review): not read by any visible cell
        'num_workers': 64,
        'state_dim': env.observation_space.shape[0],
        'action_dim': int(env.action_space.shape[0]),
        'hidden_dim': 512,
        # SAC hyper-parameters: entropy temperature, number of stored transitions
        # required before gradient updates begin, and gradient updates per env step.
        'alpha': 0.2,
        'learning_starts': 2000,
        'updates_per_step': 2,
    }

In [13]:
# Policy network built from the shared hyper-parameter dict.
actor = Actor(config=config)

In [14]:
def _sac_update(batch, actor, critic1, critic2, target_critic1, target_critic2,
                actor_optimizer, critic1_optimizer, critic2_optimizer, config):
  """Run one Soft Actor-Critic gradient step on a sampled mini-batch.

  ``batch`` is the tuple returned by ``ReplayBuffer.sample``:
  ``(states, actions, next_actions, rewards, next_states, not_done)``, whose last
  entry is already a "not done" mask (0.0 for terminal transitions). The stored
  ``next_actions`` are ignored: the soft target needs next actions drawn from the
  *current* policy, so they are resampled here.
  """
  states, actions, _, rewards, next_states, not_done = batch
  alpha = config.get('alpha', 0.2)

  # Critic step: regress both Q-networks onto the clipped double-Q soft target.
  with torch.no_grad():
    next_actions, next_log_probs = actor(next_states)
    next_inputs = torch.cat([next_states, next_actions], dim=1)
    next_q = torch.min(target_critic1(next_inputs), target_critic2(next_inputs))
    soft_next_value = next_q - alpha * next_log_probs.unsqueeze(-1)
    # not_done is already the mask; a further (1 - mask) would cut off
    # bootstrapping on every non-terminal transition.
    q_target = rewards + config['gamma'] * not_done * soft_next_value

  critic_inputs = torch.cat([states, actions], dim=1)
  critic1_loss = F.mse_loss(critic1(critic_inputs), q_target)
  critic2_loss = F.mse_loss(critic2(critic_inputs), q_target)

  critic1_optimizer.zero_grad()
  critic2_optimizer.zero_grad()
  (critic1_loss + critic2_loss).backward()
  critic1_optimizer.step()
  critic2_optimizer.step()

  # Actor step: the loss must be built from actions produced by the actor,
  # otherwise no gradient ever reaches the policy parameters.
  new_actions, log_probs = actor(states)
  policy_inputs = torch.cat([states, new_actions], dim=1)
  q_new = torch.min(critic1(policy_inputs), critic2(policy_inputs))
  actor_loss = (alpha * log_probs.unsqueeze(-1) - q_new).mean()

  actor_optimizer.zero_grad()
  actor_loss.backward()
  actor_optimizer.step()

  # Let each target network slowly track its online critic.
  soft_update(target_critic1, critic1, config['tau'])
  soft_update(target_critic2, critic2, config['tau'])


@ray.remote(num_cpus=1.4, num_gpus=0.1)
def train_loop_per_worker(config, actor, return_actor=False):
  """Train one independent SAC agent on the global ``env``.

  Args:
    config: hyper-parameter dict. Optional keys: ``alpha`` (entropy temperature,
      default 0.2), ``learning_starts`` (stored transitions required before
      updates, default 2000) and ``updates_per_step`` (default 2).
    actor: policy network; the Ray worker trains its own deserialized copy.
    return_actor: if True, also return the trained actor's ``state_dict`` so the
      driver can load it (the driver-side ``actor`` object is never modified).

  Returns:
    List of per-episode returns, or ``(returns, state_dict)`` when
    ``return_actor`` is True.
  """
  critic1 = config['critic_class'](config)
  critic2 = config['critic_class'](config)
  # Clipped double-Q: one frozen target network per critic, starting as exact copies.
  target_critic1 = copy.deepcopy(critic1)
  target_critic2 = copy.deepcopy(critic2)
  for target in (target_critic1, target_critic2):
    for param in target.parameters():
      param.requires_grad_(False)

  actor_optimizer = optim.Adam(actor.parameters(), lr=config['lr_actor'])
  critic1_optimizer = optim.Adam(critic1.parameters(), lr=config['lr_critic'])
  critic2_optimizer = optim.Adam(critic2.parameters(), lr=config['lr_critic'])

  memory = ReplayBuffer.remote(config)
  # Count stored transitions locally instead of a blocking ray.get(size) per step.
  num_stored = 0
  # random.sample needs at least batch_size stored transitions.
  warmup = max(config.get('learning_starts', 2000), config['batch_size'])
  updates_per_step = config.get('updates_per_step', 2)

  epi_rews = []

  for _ in tqdm(range(config['num_episodes'])):
    state, _ = env.reset()
    terminated, truncated = False, False
    epi_rew = 0.0

    while not (terminated or truncated):
      state_t = torch.as_tensor(state, dtype=torch.float32)
      with torch.no_grad():  # acting does not need an autograd graph
        action, _ = actor(state_t)
      next_state, reward, terminated, truncated, _ = env.step(action.numpy())
      next_state_t = torch.as_tensor(next_state, dtype=torch.float32)
      with torch.no_grad():
        # Kept only to preserve the buffer's transition layout; updates resample it.
        next_action, _ = actor(next_state_t)

      # Store `terminated` only: a time-limit truncation is not a true terminal
      # state, so the target must still bootstrap through it.
      # Fire-and-forget: calls from this worker to the actor run in submission
      # order, so the sample() calls below still observe this transition.
      memory.put.remote([state_t, action, next_action, reward, next_state_t, terminated])
      num_stored += 1

      if min(num_stored, config['buffer_limit']) > warmup:
        for _ in range(updates_per_step):
          batch = ray.get(memory.sample.remote(config['batch_size']))
          _sac_update(batch, actor, critic1, critic2, target_critic1, target_critic2,
                      actor_optimizer, critic1_optimizer, critic2_optimizer, config)

      # Advance exactly once per environment step (this used to sit inside the
      # update loop, freezing the state during warm-up and double-counting rewards).
      epi_rew += reward
      state = next_state

    epi_rews.append(epi_rew)

  if return_actor:
    return epi_rews, {k: v.detach().cpu() for k, v in actor.state_dict().items()}
  return epi_rews

In [15]:
# Run the training tasks in parallel with Ray.
def train_distributed(config, actor):
    """Launch ``config['num_workers']`` independent training tasks and wait for all.

    Returns:
        List holding one ``train_loop_per_worker`` result per worker, in launch order.
    """
    num_workers = config['num_workers']
    # Put the shared arguments into the object store once instead of serializing
    # them again for every task; Ray resolves top-level ObjectRef arguments to values.
    config_ref = ray.put(config)
    actor_ref = ray.put(actor)
    futures = [train_loop_per_worker.remote(config_ref, actor_ref) for _ in range(num_workers)]
    return ray.get(futures)

In [16]:
# Start a local Ray instance; ignore_reinit_error lets this cell be re-run
# without raising because Ray is already initialised.
# NOTE(review): each worker task requests 1.4 CPUs and its ReplayBuffer actor
# 0.8 CPUs, so only a few of the 64 workers can run at once on 8 CPUs.
ray.init(num_cpus=8, num_gpus=1, ignore_reinit_error=True)

2024-12-08 15:24:13,539	INFO worker.py:1821 -- Started a local Ray instance.


0,1
Python version:,3.10.12
Ray version:,2.40.0


In [17]:
# Blocks until every worker finishes; one entry per worker.
results = train_distributed(config=config, actor=actor)
print("Training results:", results)

  0%|          | 0/1000 [00:00<?, ?it/s]
  0%|          | 1/1000 [00:08<2:29:35,  8.98s/it]
  0%|          | 0/1000 [00:00<?, ?it/s][32m [repeated 5x across cluster] (Ray deduplicates logs by default. Set RAY_DEDUP_LOGS=0 to disable log deduplication, or see https://docs.ray.io/en/master/ray-observability/user-guides/configure-logging.html#log-deduplication for more options.)[0m
  0%|          | 2/1000 [00:16<2:10:44,  7.86s/it][32m [repeated 4x across cluster][0m
  0%|          | 3/1000 [03:17<24:02:56, 86.84s/it][32m [repeated 4x across cluster][0m
  0%|          | 4/1000 [06:24<35:00:32, 126.54s/it][32m [repeated 4x across cluster][0m
  0%|          | 5/1000 [09:32<41:06:06, 148.71s/it][32m [repeated 4x across cluster][0m
  1%|          | 6/1000 [12:40<44:45:49, 162.12s/it][32m [repeated 4x across cluster][0m
  1%|          | 7/1000 [15:49<47:07:26, 170.84s/it][32m [repeated 4x across cluster][0m
  1%|          | 8/1000 [18:57<48:35:16, 176.33s/it][32m [repeated 4x ac

KeyboardInterrupt: 

In [None]:
# Tear down the local Ray instance started by ray.init, releasing its resources.
ray.shutdown()

# **Visualization**

In [18]:
import os

# NOTE(review): MuJoCo generally selects its GL backend when it is first
# imported/initialised, and `env` was created earlier, so this may have no
# effect here — setting MUJOCO_GL before creating the env would be safer.
os.environ['MUJOCO_GL'] = 'egl'

# Wrap the existing env so recorded episodes are written under ./videos.
env = gym.wrappers.RecordVideo(env=env, video_folder='./videos')

In [19]:
# Roll out one episode so the RecordVideo wrapper captures it.
# NOTE(review): this is the driver-side `actor`; training updates copies inside
# the Ray workers, so this policy is still the untrained one.
state, _ = env.reset()
terminated = truncated = False
for i in range(1):  # a single evaluation episode
    while not terminated and not truncated:
        action, policy = actor(torch.Tensor(state))
        action = action.detach().numpy()
        next_state, reward, terminated, truncated, _ = env.step(action)
        state = next_state

env.close()

  1%|          | 11/1000 [28:23<50:43:07, 184.62s/it]
  1%|          | 11/1000 [28:30<51:02:30, 185.79s/it]
  1%|          | 11/1000 [28:34<51:15:39, 186.59s/it]


In [20]:
from IPython.display import Video
# Embed the first recorded episode's video directly in the notebook output.
Video("./videos/rl-video-episode-0.mp4", embed=True)