#### 여기서는 Tutorial에서 배운 개념을 이용하여 간단하게 Network parameter를 분산 환경에서 서로 복사해오는 작업을 해보겠습니다. <br>
    1. 각 actor는 network를 가지고 있다. 
    2. 학습을 하면서 일정 간격으로 actor는 learner의 파라미터를 복제해온다.

In [1]:
import ray 
import time 
import numpy as np 

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

In [2]:
# Initialize Ray for this session with default settings (no arguments given).
# The returned connection info is what the notebook displays for this cell.
ray.init() 

2021-01-23 07:45:58,963	INFO services.py:1173 -- View the Ray dashboard at [1m[32mhttp://127.0.0.1:8265[39m[22m


{'node_ip_address': '192.168.0.61',
 'raylet_ip_address': '192.168.0.61',
 'redis_address': '192.168.0.61:6379',
 'object_store_address': '/tmp/ray/session_2021-01-23_07-45-58_512576_95995/sockets/plasma_store',
 'raylet_socket_name': '/tmp/ray/session_2021-01-23_07-45-58_512576_95995/sockets/raylet',
 'webui_url': '127.0.0.1:8265',
 'session_dir': '/tmp/ray/session_2021-01-23_07-45-58_512576_95995',
 'metrics_export_port': 60564,
 'node_id': '25569fa6e591529a16a4161f1adbc687ba4cdd71'}

In [3]:
class QNetwork(nn.Module):
    """Two-layer fully connected network: state vector in, one value per action out.

    Args:
        state_size: Indexable shape of the state; only its first entry is
            used, as the number of input features.
        action_size: Number of output units.
        hidden: Width of the single hidden layer.
    """

    def __init__(self, state_size, action_size, hidden=32):
        super().__init__()

        in_features = state_size[0]
        # The attribute names fc1/fc2 are also the state_dict key prefixes,
        # so they must stay stable for load_state_dict() between networks.
        self.fc1 = nn.Linear(in_features, hidden)
        self.fc2 = nn.Linear(hidden, action_size)

    def forward(self, x):
        """Apply fc1, then ReLU, then fc2, and return the raw fc2 output."""
        hidden_act = F.relu(self.fc1(x))
        return self.fc2(hidden_act)

# Sanity check: push one random state through a freshly built network.
state_size = (4,)
action_size = 2
temp_net = QNetwork(state_size=state_size, action_size=action_size, hidden=32)
test = torch.randn(*state_size)
q_values = temp_net(test)
# The final expression is what the notebook cell displays: output and shape.
q_values, q_values.shape

(tensor([-0.1507,  0.1949], grad_fn=<AddBackward0>), torch.Size([2]))

In [4]:
# Each actor keeps its own copy of the Q-network and periodically refreshes it
# with the weights served by the shared learner. (The original note describes
# actors as handing environment experience to a buffer; this snippet only
# implements the weight synchronisation.)
@ray.remote
class Actor:
    """Ray actor holding a local Q-network that is synced from the learner.

    Args:
        learner: Handle of the remote ``Learner`` actor; it must expose a
            ``return_weights`` remote method returning a state dict.
        actor_idx: Index identifying this actor (printed on every sync).
        hidden: Hidden-layer width of the local ``QNetwork``. It has to match
            the learner's network, otherwise ``load_state_dict`` fails on a
            shape mismatch.
        device: Device the local network is moved to via ``nn.Module.to``
            (the caller in this file passes ``"cpu"``).
    """

    def __init__(self,
                 learner: "ray.actor.ActorHandle",
                 actor_idx: int,
                 hidden: int,
                 device: str):

        self.learner = learner  # learner handle shared through Ray
        self.device = device
        self.actor_idx = actor_idx

        # Network parameters (same dimensions as the Learner's networks).
        self.state_dim = (16, )
        self.action_dim = 3
        self.q_behave = QNetwork(self.state_dim, self.action_dim, hidden).to(self.device)

    def explore(self, sync_interval: float = 3):
        """Loop forever, pulling the learner's weights every ``sync_interval`` seconds.

        This method never returns, so the actor stays busy for as long as it
        is alive.

        Args:
            sync_interval: Seconds to sleep between two weight syncs.
        """
        print("exploration start..")
        while True:
            time.sleep(sync_interval)
            self.get_weights()
            print('updates done.')

    def get_weights(self):
        """Fetch the learner's current weights and load them into ``q_behave``."""
        # .remote() only returns an ObjectRef; ray.get blocks until the
        # learner has produced the actual state dict.
        weights_ref = self.learner.return_weights.remote()
        weight_copy = ray.get(weights_ref)
        print(type(weight_copy), self.actor_idx)
        self.q_behave.load_state_dict(weight_copy)

In [5]:
# The learner owns the reference networks and hands their weights to actors.
# (The original note describes it as training from samples in a shared buffer;
# no training step is implemented in this snippet.)

@ray.remote
class Learner:
    """Ray actor that owns the behaviour/target Q-networks and serves weights.

    Args:
        hidden: Hidden-layer width of both ``QNetwork`` instances.
        device: Device the networks are moved to via ``nn.Module.to``
            (the caller in this file passes ``"cpu"``).
    """

    def __init__(self,
                 hidden: int,
                 device: str):

        self.state_dim = (16, )
        self.action_dim = 3
        self.device = device

        self.q_behave = QNetwork(self.state_dim, self.action_dim, hidden).to(self.device)
        self.q_target = QNetwork(self.state_dim, self.action_dim, hidden).to(self.device)
        # Start the target network as an exact copy of the behaviour network,
        # and keep it in eval mode.
        self.q_target.load_state_dict(self.q_behave.state_dict())
        self.q_target.eval()

    def return_weights(self):
        """Return the target network's state dict with every tensor on the CPU.

        The dict is serialized by Ray and deserialized inside the actor
        processes; CUDA tensors cannot be restored in a process that has no
        CUDA device, so the values are moved to the CPU first. On a CPU
        learner ``Tensor.cpu()`` returns the tensor itself, so nothing is
        copied.
        """
        weights = self.q_target.state_dict()
        # state_dict() builds a fresh OrderedDict on every call, so replacing
        # its values in place leaves the module untouched while preserving the
        # mapping type and its metadata.
        for name in list(weights):
            weights[name] = weights[name].cpu()
        return weights

In [6]:
hidden = 32  # hidden-layer width shared by the learner and every actor

# The demo runs on the CPU. An automatic choice
# ("cuda" if torch.cuda.is_available() else "cpu") used to precede this line
# but was unconditionally overwritten, so it had no effect and was removed.
# NOTE(review): presumably CPU is forced because the Ray actors are created
# without GPU resources - confirm before switching to "cuda".
device = "cpu"

learner = Learner.remote(hidden, device) 

In [7]:
num_actors = 16  # number of actors to launch

# Keep every actor handle, plus the ObjectRef of its never-ending explore()
# call, so the actors can later be inspected or stopped (e.g. ray.kill(handle)).
# Rebinding a single variable on each iteration would drop all but the last
# handle.
actors = []
explore_refs = []
for actor_idx in range(num_actors):
    actor = Actor.remote(learner, actor_idx, hidden, device)
    actors.append(actor)
    explore_refs.append(actor.explore.remote())

[2m[36m(pid=96117)[0m exploration start..
[2m[36m(pid=96123)[0m exploration start..
[2m[36m(pid=96122)[0m exploration start..
[2m[36m(pid=96168)[0m exploration start..
[2m[36m(pid=96170)[0m exploration start..
[2m[36m(pid=96171)[0m exploration start..
[2m[36m(pid=96116)[0m exploration start..
[2m[36m(pid=96118)[0m exploration start..
[2m[36m(pid=96125)[0m exploration start..
[2m[36m(pid=96124)[0m exploration start..
[2m[36m(pid=96131)[0m exploration start..
[2m[36m(pid=96110)[0m exploration start..
[2m[36m(pid=96120)[0m exploration start..
[2m[36m(pid=96113)[0m exploration start..
[2m[36m(pid=96133)[0m exploration start..
[2m[36m(pid=96114)[0m exploration start..
[2m[36m(pid=96117)[0m <class 'collections.OrderedDict'> 9
[2m[36m(pid=96117)[0m updates done.
[2m[36m(pid=96168)[0m <class 'collections.OrderedDict'> 4
[2m[36m(pid=96168)[0m updates done.
[2m[36m(pid=96170)[0m <class 'collections.OrderedDict'> 1
[2m[36m(pid=96170