#### 분산 강화학습으로 CartPole을 DQN을 이용하여 구현해보겠습니다. <br>기본적인 방식은 다음과 같습니다. <br>  
    1. Replay Buffer: Actor로부터 data를 받고, Learner에게 data를 전달하는 역할
    2. Parameter Server: Learner로부터 parameter를 받고, Actor에게 parameter를 전달하는 역할.
    3. Learner: Replay Buffer로 부터 데이터를 받아 학습을 진행하고, Parameter Server로 Learner 모델의 parameter를 전달하는 역할.
    4. Actor: Environment와 상호작용하며 data를 Replay Buffer에 전달하고, Parameter Server로부터 Learner 모델의 parameter를 받아 자신의 모델 parameter를 update.

#### ISSUE<br>

    CUDA 장치에 올라간 객체를 @ray.remote로 선언된 class의 인자로 전달하면 에러가 발생합니다. 
    그래서 예를 들어 Q-network은 @ray.remote로 데코레이션하지 않습니다. 같은 이유로 Learner 또한 @ray.remote로 선언하지 않습니다.

In [1]:
import ray 
import gym
import time 
import numpy as np 
from copy import deepcopy
import matplotlib.pyplot as plt
from IPython.display import clear_output

import wandb

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

In [2]:
# Start the local Ray runtime. `ignore_reinit_error=True` lets this notebook
# cell be executed again without raising when Ray is already initialized.
ray.init(ignore_reinit_error=True)

2021-01-28 13:48:46,572	INFO services.py:1173 -- View the Ray dashboard at [1m[32mhttp://127.0.0.1:8265[39m[22m


{'node_ip_address': '192.168.0.61',
 'raylet_ip_address': '192.168.0.61',
 'redis_address': '192.168.0.61:6379',
 'object_store_address': '/tmp/ray/session_2021-01-28_13-48-46_133084_45380/sockets/plasma_store',
 'raylet_socket_name': '/tmp/ray/session_2021-01-28_13-48-46_133084_45380/sockets/raylet',
 'webui_url': '127.0.0.1:8265',
 'session_dir': '/tmp/ray/session_2021-01-28_13-48-46_133084_45380',
 'metrics_export_port': 57299,
 'node_id': '1c0e785f294956bf5a4f7b6b084c720d28dba159'}

In [3]:
# Buffer를 정의합니다. 
@ray.remote 
class ReplayBuffer:
    """Fixed-size circular replay buffer shared by the actors and the learner.

    Transitions are written into preallocated NumPy arrays. Once
    ``buffer_size`` transitions have been stored, the write index wraps around
    and the oldest entries are overwritten.

    Besides the transitions, the buffer holds a boolean ``batch_update_status``
    flag that other objects read and toggle through the methods below.
    """

    def __init__(self, buffer_size: int, state_dim: tuple):
        """Allocate storage for ``buffer_size`` transitions.

        Args:
            buffer_size: Maximum number of transitions kept at once.
            state_dim: Shape of a single state. It must be a tuple even for
                1-D states, e.g. ``(4,)``, because it is concatenated with
                ``(buffer_size,)`` to build the array shape.

        Raises:
            TypeError: If ``state_dim`` is not a tuple.
        """
        # NOTE (original author): when this class is used through Ray, a
        # failure raised here is not reported at creation time; it only shows
        # up once a method of the remote object is executed.
        if not isinstance(state_dim, tuple):
            raise TypeError(
                f"state_dim must be a tuple, got {type(state_dim).__name__}")

        self.buffer_dim = (buffer_size, ) + state_dim
        self.buffer_size = buffer_size
        self.batch_update_status = True

        self.state_buffer = np.zeros(self.buffer_dim)
        self.action_buffer = np.zeros(buffer_size)
        self.reward_buffer = np.zeros(buffer_size)
        self.next_state_buffer = np.zeros(self.buffer_dim)
        self.done_buffer = np.zeros(buffer_size)
        self.act_idx_buffer = np.zeros(buffer_size)

        self.store_idx = 0          # next slot to write
        self.current_size = 0       # number of valid transitions (<= buffer_size)
        self.total_store_count = 0  # transitions ever stored (never wraps)

    def store(self, state, action, next_state, reward, done, actor_idx):
        """Write one transition at the current write index and advance it.

        ``actor_idx`` is not needed for learning; it is recorded so that one
        can check which actor produced each stored sample.
        """
        slot = self.store_idx
        self.state_buffer[slot] = state
        self.action_buffer[slot] = action
        self.reward_buffer[slot] = reward
        self.next_state_buffer[slot] = next_state
        self.done_buffer[slot] = done
        self.act_idx_buffer[slot] = actor_idx

        # Counts every stored step over the whole run.
        self.total_store_count += 1
        self.store_idx = (self.store_idx + 1) % self.buffer_size
        self.current_size = min(self.current_size + 1, self.buffer_size)

    def batch_load(self, batch_size):
        """Return ``batch_size`` transitions sampled uniformly with replacement.

        Returns:
            dict with keys ``states``, ``actions``, ``rewards``,
            ``next_states``, ``dones`` and ``actindices``; each value is a
            NumPy array whose first dimension is ``batch_size``.

        Raises:
            ValueError: If nothing has been stored yet.
        """
        if self.current_size == 0:
            raise ValueError("cannot sample a batch from an empty replay buffer")

        # Only the filled part of the arrays is eligible for sampling.
        indices = np.random.randint(self.current_size, size=batch_size)
        return dict(
                states=self.state_buffer[indices],
                actions=self.action_buffer[indices],
                rewards=self.reward_buffer[indices],
                next_states=self.next_state_buffer[indices],
                dones=self.done_buffer[indices],
                actindices=self.act_idx_buffer[indices])

    # The accessors below exist so that other objects can read current_size,
    # store_idx and total_store_count through method calls (as needed with Ray).
    def return_current_size(self):
        """Return the number of valid transitions currently held."""
        return self.current_size

    def return_store_idx(self):
        """Return the index of the slot that the next ``store`` will write."""
        return self.store_idx

    def return_total_store_count(self):
        """Return how many transitions have been stored in total."""
        return self.total_store_count

    def batch_update_on(self):
        """Set the ``batch_update_status`` flag to True."""
        self.batch_update_status = True

    def batch_update_off(self):
        """Set the ``batch_update_status`` flag to False."""
        self.batch_update_status = False

    def return_batch_update_status(self):
        """Return the current value of the ``batch_update_status`` flag."""
        return self.batch_update_status
    
# # test
# buffer_size = 1000
# batch_size = 16
# state_dim = (4, )
# temp_buffer = ReplayBuffer.remote(buffer_size, state_dim)

# for i in range(50):
#     temp_buffer.store.remote(np.array(state_dim), 1, np.array(state_dim), 1, 1, 1)

# batch = temp_buffer.batch_load.remote(batch_size)
# print("Batch Size:", ray.get(batch)['actindices'].shape) 

# current_size = temp_buffer.return_current_size.remote()
# print("Current Size: ", ray.get(current_size))

# return_store_idx = temp_buffer.return_store_idx.remote()
# print("Store Index: ", ray.get(return_store_idx))

In [4]:
class QNetwork(nn.Module):
    """Two-layer MLP mapping a state vector to one Q-value per action.

    Args:
        state_size: Tuple whose first element is the state vector length.
        action_size: Number of outputs of the network.
        hidden: Width of the single hidden layer.
    """

    def __init__(self, state_size, action_size, hidden=32):
        super().__init__()

        in_features = state_size[0]
        self.fc1 = nn.Linear(in_features, hidden)
        self.fc2 = nn.Linear(hidden, action_size)

    def forward(self, x):
        """Apply linear -> ReLU -> linear and return the raw outputs."""
        hidden_activation = F.relu(self.fc1(x))
        return self.fc2(hidden_activation)

# state_size = (4, ) 
# action_size = 2 
# temp_net = QNetwork(state_size, action_size, 32) 
# test = torch.randn(size=(4,)) 
# temp_net(test), temp_net(test).shape 

In [5]:
@ray.remote
class Network_parameter_server:
    """Holds the most recent network parameters handed over by the learner.

    The learner stores a parameter snapshot with ``update_parameters``; actors
    check ``return_saving_status`` and then fetch the snapshot with
    ``return_parameters``.
    """

    def __init__(self):
        # Initialised up front so that return_parameters() cannot fail with an
        # AttributeError when it is called before the first push.
        self.learner_params = None
        # Tells actors from when on a parameter snapshot is available.
        self.is_saved = False

    def update_parameters(self, learner_params):
        """Replace the stored snapshot and mark parameters as available."""
        self.learner_params = learner_params
        self.is_saved = True

    def return_parameters(self):
        """Return the latest stored snapshot (None if nothing was pushed yet).

        The availability flag is deliberately left untouched: clearing it on
        read would let only the first caller see each snapshot, so the other
        actors polling the same server would skip their refresh.
        """
        return self.learner_params

    def return_saving_status(self):
        """Return True once at least one snapshot has been stored."""
        return self.is_saved

In [6]:
# actor의 역할은 각각 env에서 경험한 것을 buffer에 넘겨주는 역할을 합니다.
@ray.remote
class Actor:
    """Environment worker: collects transitions and mirrors the learner's network.

    The actor steps its own environment with an epsilon-greedy policy, sends
    every transition to the shared replay buffer, and periodically replaces
    the weights of its local Q-network with those held by the parameter server.
    """

    def __init__(self,
                 params_server,
                 memory,
                 env_name: str,
                 actor_idx: int,
                 actor_update_freq: int,
                 update_buf_start: int,
                 epsilon: float,
                 eps_decay: float,
                 eps_min: float,
                 hidden: int,
                 device,
                 is_wandb: bool,
                 plot_mode: str,
                 WANDB_GROUP_NAME: str):
        """Create the environment and the local Q-network.

        Args:
            params_server: Remote handle of the network parameter server.
            memory: Remote handle of the shared replay buffer.
            env_name: Gym environment id.
            actor_idx: Index of this actor; stored with each transition and
                used to label the wandb run and metrics.
            actor_update_freq: Number of environment steps between checks for
                new parameters on the parameter server.
            update_buf_start: Total stored-transition count from which the
                actor starts waiting for the buffer's update flag.
            epsilon: Initial epsilon of the epsilon-greedy policy.
            eps_decay: Amount subtracted from epsilon after every step.
            eps_min: Lower bound of epsilon.
            hidden: Hidden-layer width passed to ``QNetwork``.
            device: Device the local network is moved to (e.g. ``"cpu"``).
            is_wandb: Whether to start a wandb run for this actor.
            plot_mode: ``'wandb'`` to log episode statistics to wandb.
            WANDB_GROUP_NAME: wandb group shared by all runs of one experiment.
        """
        self.is_wandb = is_wandb
        if is_wandb:
            entity = 'rl_flip_school_team'
            project_name = 'Distributed_DQN'
            wandb.init(
                    group=WANDB_GROUP_NAME,
                    project=project_name,
                    entity=entity,
                    name=f'{actor_idx}_Distributed_DQN'
                    )

        self.env = gym.make(env_name)
        self.params_server = params_server
        self.memory = memory        # replay buffer shared through Ray
        self.actor_idx = actor_idx  # identifies which actor produced a sample
        self.actor_update_freq = actor_update_freq
        # Stored on the instance so that explore() does not depend on a
        # module-level variable of the same name being visible in the worker.
        self.update_buf_start = update_buf_start
        self.plot_mode = plot_mode
        self.device = device

        # DQN hyperparameters
        self.epsilon = epsilon
        self.eps_decay = eps_decay
        self.eps_min = eps_min

        # Network parameters
        self.state_dim = (self.env.observation_space.shape[0], )
        try:
            self.action_dim = self.env.action_space.n  # discrete action space
        except AttributeError:
            # No `.n` attribute: fall back to the first dimension of the
            # action-space shape (continuous action space).
            self.action_dim = self.env.action_space.shape[0]
        self.q_behave = QNetwork(self.state_dim, self.action_dim, hidden).to(self.device)

    def select_action(self, state):
        """Pick an action epsilon-greedily.

        Returns:
            ``(Qs, action)``. On a random draw ``Qs`` is a zero vector of
            length ``action_dim`` and ``action`` is sampled from the action
            space; otherwise ``Qs`` is the network output as a NumPy array of
            shape ``(1, action_dim)`` and ``action`` is its argmax as an int.
        """
        if np.random.random() < self.epsilon:
            return np.zeros(self.action_dim), self.env.action_space.sample()

        state_tensor = torch.as_tensor(
            state, dtype=torch.float32, device=self.device).unsqueeze(0)
        # Inference only: no autograd graph is needed here.
        with torch.no_grad():
            q_values = self.q_behave(state_tensor)
        return q_values.detach().cpu().numpy(), q_values.argmax().item()

    def explore(self):
        """Interact with the environment forever, feeding the replay buffer.

        This method never returns. After each step it waits for the buffer's
        update flag (see ``_wait_for_update_turn``) and, every
        ``actor_update_freq`` steps, pulls new parameters if the parameter
        server reports that some are available.
        """
        score = 0
        steps_since_check = 0
        state = self.env.reset()

        while True:
            _, action = self.select_action(state)
            next_state, reward, done, _ = self.env.step(action)

            # Store the transition in the shared replay buffer.
            self.memory.store.remote(state, action, next_state, reward, done, self.actor_idx)

            score += reward
            state = next_state
            self.epsilon = max(self.epsilon - self.eps_decay, self.eps_min)

            self._wait_for_update_turn()

            steps_since_check = (steps_since_check + 1) % self.actor_update_freq
            if steps_since_check == 0:
                if ray.get(self.params_server.return_saving_status.remote()):
                    self._pull_parameters()

            if done:
                state = self.env.reset()
                self._plot_status(score)
                score = 0

    def _wait_for_update_turn(self):
        """Block until this actor may take its next environment step.

        While fewer than ``update_buf_start`` transitions have been stored in
        total, the actor continues immediately. Afterwards it polls the
        buffer's update flag, and as soon as the flag is on it switches the
        flag off and continues.
        NOTE(review): the read of the flag and the switch-off are two separate
        remote calls, so several actors may pass on the same flag value.
        """
        while True:
            if ray.get(self.memory.return_total_store_count.remote()) < self.update_buf_start:
                return
            if ray.get(self.memory.return_batch_update_status.remote()):
                self.memory.batch_update_off.remote()
                return

    def _pull_parameters(self):
        """Load the parameter server's current weights into the local network."""
        updated_params = ray.get(self.params_server.return_parameters.remote())
        self.q_behave.load_state_dict(updated_params)

    def _plot_status(self, score):
        """Log the finished episode's score and the current epsilon to wandb."""
        # Only log when a wandb run was actually started in __init__.
        if self.is_wandb and self.plot_mode == 'wandb':
            wandb.log({'Score': score,
                       f"Epsilon_{self.actor_idx}": self.epsilon,
                       f'Score_{self.actor_idx}': score})

In [7]:
# Learner는 buffer에 있는 샘플을 이용하여 network parameter를 업데이트를 하며, parameter server에 network weight을 전달합니다.
# Learner는 network update 등 cuda 연산을 하고 cpu로 병렬처리하는 것이 없으므로 ray를 이용하여 선언하지 않습니다.
class Learner:
    """Trains the Q-network from replay-buffer samples and publishes its weights.

    The learner keeps a behaviour network (``q_behave``) that is optimised and
    a target network (``q_target``) that is periodically overwritten with the
    behaviour network's weights. It also steps its own environment, purely so
    that the learner's episode score can be plotted.
    """

    def __init__(self,
                 env_name: str,
                 params_server,
                 memory,
                 gamma: float,
                 epsilon: float,
                 eps_decay: float,
                 eps_min: float,
                 update_buf_start: int,
                 update_freq: int,
                 update_target_freq: int,
                 update_push_freq: int,
                 hidden: int,
                 batch_size: int,
                 learning_rate: float,
                 device,
                 is_wandb: bool,
                 plot_mode: str,
                 WANDB_GROUP_NAME: str):
        """Create the environment, both networks and the optimizer.

        Args:
            env_name: Gym environment id.
            params_server: Remote handle of the network parameter server.
            memory: Remote handle of the shared replay buffer.
            gamma: Discount rate.
            epsilon: Initial epsilon of the epsilon-greedy policy.
            eps_decay: Amount subtracted from epsilon after every step.
            eps_min: Lower bound of epsilon.
            update_buf_start: Buffer fill level from which updates start
                (stored for use by the caller).
            update_freq: Update period of ``q_behave`` (stored only).
            update_target_freq: Number of ``q_behave`` updates between two
                target-network refreshes.
            update_push_freq: Number of ``q_behave`` updates between two
                parameter pushes to the parameter server.
            hidden: Hidden-layer width passed to ``QNetwork``.
            batch_size: Number of transitions sampled per update.
            learning_rate: Adam learning rate for ``q_behave``.
            device: Device both networks are moved to.
            is_wandb: Whether to start a wandb run for the learner.
            plot_mode: ``'wandb'`` to log episode statistics to wandb.
            WANDB_GROUP_NAME: wandb group shared by all runs of one experiment.
        """
        self.is_wandb = is_wandb
        if is_wandb:
            entity = 'rl_flip_school_team'
            project_name = 'Distributed_DQN'
            wandb.init(
                    group=WANDB_GROUP_NAME,
                    project=project_name,
                    entity=entity,
                    name='Learner_Distributed_DQN'
                    )

        self.env = gym.make(env_name)
        self.params_server = params_server
        self.memory = memory
        self.gamma = gamma
        self.plot_mode = plot_mode

        # DQN hyperparameters
        self.epsilon = epsilon
        self.eps_decay = eps_decay
        self.eps_min = eps_min

        self.state_dim = (self.env.observation_space.shape[0], )
        try:
            self.action_dim = self.env.action_space.n  # discrete action space
        except AttributeError:
            # No `.n` attribute: fall back to the first dimension of the
            # action-space shape (continuous action space).
            self.action_dim = self.env.action_space.shape[0]

        self.batch_size = batch_size
        self.update_cnt = 0                           # number of q_behave updates so far
        self.update_freq = update_freq                # q_behave update period
        self.update_buf_start = update_buf_start      # buffer size at which updates start
        self.update_target_freq = update_target_freq  # q_target refresh period
        self.update_push_freq = update_push_freq      # parameter-server push period
        self.device = device
        self.total_steps = 0
        self.scores = []
        self.losses = [0]

        self.q_behave = QNetwork(self.state_dim, self.action_dim, hidden).to(self.device)
        self.q_target = QNetwork(self.state_dim, self.action_dim, hidden).to(self.device)
        self.q_target.load_state_dict(self.q_behave.state_dict())
        self.q_target.eval()

        self.optimizer = optim.Adam(self.q_behave.parameters(), lr=learning_rate)

    def update_q_network(self):
        """Sample one batch from the buffer and take one optimizer step.

        Increments ``update_cnt`` (used to schedule target refreshes and
        parameter pushes), records the loss, and finally switches the buffer's
        update flag back on.
        """
        self.update_cnt += 1
        batch = ray.get(self.memory.batch_load.remote(self.batch_size))
        loss = self._compute_loss(batch)

        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        self.losses.append(loss.item())  # kept for plotting
        self.memory.batch_update_on.remote()

    def target_hard_update(self):
        """Hard update: copy the behaviour network's weights into the target."""
        self.q_target.load_state_dict(self.q_behave.state_dict())

    def eval_select_action(self, state):
        """Return ``(Qs, action)`` for the greedy action in ``state``.

        ``Qs`` is the behaviour network's output as a NumPy array of shape
        ``(1, action_dim)``; ``action`` is its argmax as an int.
        """
        state_tensor = torch.as_tensor(
            state, dtype=torch.float32, device=self.device).unsqueeze(0)
        # Inference only: no autograd graph is needed here.
        with torch.no_grad():
            q_values = self.q_behave(state_tensor)
        return q_values.detach().cpu().numpy(), q_values.argmax().item()

    def push_parameters(self):
        """Send a CPU copy of the behaviour network's weights to the server."""
        # Copy first so that moving to CPU does not touch the network being trained.
        copied_model = deepcopy(self.q_behave).cpu()
        self.params_server.update_parameters.remote(copied_model.state_dict())

    def train(self):
        """Run the training loop forever (no termination condition is defined).

        Each iteration takes one step in the learner's own environment. The
        learner does not need to interact with an environment; this is done
        only so that the learner's score can be plotted. It then performs one
        network update via ``_update_when_requested``.
        """
        print("training start..")

        score = 0
        state = self.env.reset()
        while True:
            _, action = self.select_action(state)
            next_state, reward, done, _ = self.env.step(action)
            score += reward
            state = next_state
            self.epsilon = max(self.epsilon - self.eps_decay, self.eps_min)

            self._update_when_requested()

            if done:
                state = self.env.reset()
                self._plot_status(score)
                score = 0

    def _update_when_requested(self):
        """Wait until the buffer's update flag is off, then update once.

        After the update, the parameters are pushed every ``update_push_freq``
        updates and the target network is refreshed every
        ``update_target_freq`` updates. ``update_q_network`` itself switches
        the flag back on, so it is not signalled a second time here.
        """
        # Poll until the flag has been switched off.
        while ray.get(self.memory.return_batch_update_status.remote()):
            pass

        self.update_q_network()

        # Push to the parameter server on its own schedule.
        if self.update_cnt % self.update_push_freq == 0:
            self.push_parameters()

        # Refresh the target network on its own schedule.
        if self.update_cnt % self.update_target_freq == 0:
            self.target_hard_update()

    def select_action(self, state):
        """Pick an action epsilon-greedily.

        Returns ``(zeros(action_dim), random_action)`` on a random draw and the
        result of ``eval_select_action`` otherwise.
        """
        if np.random.random() < self.epsilon:
            return np.zeros(self.action_dim), self.env.action_space.sample()
        return self.eval_select_action(state)

    def _compute_loss(self, batch):
        """Return the smooth-L1 loss between Q(s, a) and the one-step target.

        Args:
            batch: Dict with array values under the keys ``states``,
                ``next_states``, ``actions``, ``rewards`` and ``dones``.

        The target is ``r + (1 - done) * gamma * max_a' Q_target(s', a')``;
        the target-network term is detached so that gradients only flow
        through the behaviour network.
        """
        def as_column(values, dtype):
            # Per-sample scalars become (batch, 1) columns.
            return torch.as_tensor(values, dtype=dtype, device=self.device).reshape(-1, 1)

        states = torch.as_tensor(batch['states'], dtype=torch.float32, device=self.device)
        next_states = torch.as_tensor(batch['next_states'], dtype=torch.float32, device=self.device)
        actions = as_column(batch['actions'], torch.long)
        rewards = as_column(batch['rewards'], torch.float32)
        dones = as_column(batch['dones'], torch.float32)

        current_q = self.q_behave(states).gather(1, actions)
        next_q = self.q_target(next_states).max(dim=1, keepdim=True)[0].detach()
        not_done = 1 - dones  # terminal transitions get no bootstrap term
        target = rewards + not_done * self.gamma * next_q

        # Conventional (input, target) order; the value is unchanged because
        # the loss depends only on the absolute difference.
        return F.smooth_l1_loss(current_q, target)

    def _plot_status(self, score):
        """Log the learner's episode score, epsilon, recent loss and frame count."""
        # Only log when a wandb run was actually started in __init__.
        if self.is_wandb and self.plot_mode == 'wandb':
            wandb.log({"Learner Score": score,
                       "Learner Epsilon": self.epsilon,
                       "loss(10 frames avg)": np.mean(self.losses[-10:]),
                       "Number of frames": ray.get(self.memory.return_total_store_count.remote())
                      })

In [8]:
# Experiment configuration: choose the environment, look up its
# hyperparameters, then create the parameter server, replay buffer and learner.
env_lists = ['CartPole-v0', 'LunarLander-v2']
env_name = env_lists[1]
gamma = 0.995

# Environment-specific hyperparameters, keyed by environment id, so that
# changing `env_name` above also switches every dependent value.
env_hyperparams = {
    'CartPole-v0': dict(
        buffer_size=3000,         # replay buffer size
        update_buf_start=100,
        eps_decay=1/2000,
        hidden=128,
        learning_rate=0.001,
    ),
    'LunarLander-v2': dict(
        buffer_size=150000,       # replay buffer size
        update_buf_start=100,
        eps_decay=0.000005,
        hidden=256,
        learning_rate=0.00002,
    ),
}
selected_hyperparams = env_hyperparams[env_name]

buffer_size = selected_hyperparams['buffer_size']
batch_size = 32    # number of samples drawn from the replay buffer per update
env = gym.make(env_name)
state_dim = (env.observation_space.shape[0], )

update_buf_start = selected_hyperparams['update_buf_start']
update_freq = 1
update_target_freq = 100
update_push_freq = 1

epsilon = 1.
eps_decay = selected_hyperparams['eps_decay']
eps_min = 0.1

hidden = selected_hyperparams['hidden']
learning_rate = selected_hyperparams['learning_rate']

# Prefer the second GPU when the machine has one; otherwise fall back to the
# first GPU, or to the CPU when CUDA is unavailable.
if torch.cuda.is_available():
    learner_device = "cuda:1" if torch.cuda.device_count() > 1 else "cuda:0"
else:
    learner_device = "cpu"
is_wandb = True     # whether to use wandb or not
plot_mode = 'wandb' # plot options: 'wandb' or 'inline'
WANDB_GROUP_NAME = 'Distributed_DQN_' + str(np.random.randint(10000))

params_server = Network_parameter_server.remote()
memory = ReplayBuffer.remote(buffer_size, state_dim)
learner = Learner(env_name, params_server, memory, gamma, epsilon, eps_decay, eps_min,
                  update_buf_start, update_freq, update_target_freq, update_push_freq,
                  hidden, batch_size, learning_rate, learner_device, is_wandb, plot_mode, WANDB_GROUP_NAME)

Failed to query for notebook name, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable
[34m[1mwandb[0m: Wandb version 0.10.15 is available!  To upgrade, please run:
[34m[1mwandb[0m:  $ pip install wandb --upgrade


In [9]:
# Create `num_actors` actors and start exploration on each of them.
# The actor handles and the object refs returned by `explore.remote()` are
# kept in lists so that the actors can be inspected or killed later and so
# that a failure inside `explore` can be surfaced with `ray.get`.
num_actors = 5  # number of actors
actor_device = "cpu"
actor_update_freq = 100

actors = []
explore_refs = []
for actor_idx in range(num_actors):
    actor = Actor.remote(params_server, memory, env_name, actor_idx, actor_update_freq, update_buf_start,
                           epsilon, eps_decay, eps_min, hidden, actor_device,
                           is_wandb, plot_mode, WANDB_GROUP_NAME)
    actors.append(actor)
    explore_refs.append(actor.explore.remote())

[2m[36m(pid=45561)[0m wandb: Tracking run with wandb version 0.9.4
[2m[36m(pid=45510)[0m wandb: Tracking run with wandb version 0.9.4
[2m[36m(pid=45527)[0m wandb: Tracking run with wandb version 0.9.4
[2m[36m(pid=45550)[0m wandb: Tracking run with wandb version 0.9.4
[2m[36m(pid=45556)[0m wandb: Tracking run with wandb version 0.9.4
[2m[36m(pid=45561)[0m wandb: Wandb version 0.10.15 is available!  To upgrade, please run:
[2m[36m(pid=45561)[0m wandb:  $ pip install wandb --upgrade
[2m[36m(pid=45561)[0m wandb: Run data is saved locally in wandb/run-20210128_134857-1dmd1lhi
[2m[36m(pid=45527)[0m wandb: Wandb version 0.10.15 is available!  To upgrade, please run:
[2m[36m(pid=45527)[0m wandb:  $ pip install wandb --upgrade
[2m[36m(pid=45527)[0m wandb: Run data is saved locally in wandb/run-20210128_134857-znz72mrz
[2m[36m(pid=45550)[0m wandb: Wandb version 0.10.15 is available!  To upgrade, please run:
[2m[36m(pid=45550)[0m wandb:  $ pip install wandb 

In [None]:
# Wait until the replay buffer holds more than `update_buf_start` transitions,
# then hand control to the learner (its training loop has no end condition).
while True:
    buffer_saved_cnt = ray.get(memory.return_total_store_count.remote())
    if buffer_saved_cnt > learner.update_buf_start:
        learner.train()
    else:
        # Back off briefly instead of polling the buffer in a hot spin.
        time.sleep(0.1)

training start..


Failed to query for notebook name, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable
[34m[1mwandb[0m: Wandb version 0.10.15 is available!  To upgrade, please run:
[34m[1mwandb[0m:  $ pip install wandb --upgrade


    Single step은 성공
    
    Distributed_DQN_4027