#### 여기서는 Tutorial에서 배운 개념을 이용하여 간단하게 Network parameter를 분산 환경에서 서로 복사해오는 작업을 해보겠습니다. <br>
    1. 각 actor는 network를 가지고 있다. 
    2. 학습을 하면서 일정 간격으로 actor는 learner의 파라미터를 복제해온다.

In [1]:
import ray 
import time 
import numpy as np 

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

In [2]:
# Start Ray for this notebook session using the default settings.
ray.init() 

2021-01-22 21:16:37,642	INFO services.py:1173 -- View the Ray dashboard at [1m[32mhttp://127.0.0.1:8265[39m[22m


{'node_ip_address': '192.168.0.61',
 'raylet_ip_address': '192.168.0.61',
 'redis_address': '192.168.0.61:6379',
 'object_store_address': '/tmp/ray/session_2021-01-22_21-16-37_188833_102412/sockets/plasma_store',
 'raylet_socket_name': '/tmp/ray/session_2021-01-22_21-16-37_188833_102412/sockets/raylet',
 'webui_url': '127.0.0.1:8265',
 'session_dir': '/tmp/ray/session_2021-01-22_21-16-37_188833_102412',
 'metrics_export_port': 52487,
 'node_id': 'd7c04ed36ce43a571c20c4cc2c259ac313c260c8'}

In [3]:
class QNetwork(nn.Module):
    """Two-layer fully connected network: state features in, one value per action out.

    Args:
        state_size: Number of input features, given either as a shape
            tuple/list whose first entry is the feature count (e.g. ``(4,)``)
            or as a plain ``int``.
        action_size: Number of outputs (one per action).
        hidden: Width of the single hidden layer.
    """

    def __init__(self, state_size, action_size, hidden=32):
        super().__init__()

        # Accept a bare int as well as the shape-tuple form used by callers.
        in_features = state_size if isinstance(state_size, int) else state_size[0]
        self.fc1 = nn.Linear(in_features, hidden)
        self.fc2 = nn.Linear(hidden, action_size)

    def forward(self, x):
        """Return ``fc2(relu(fc1(x)))`` for the input tensor ``x``."""
        return self.fc2(F.relu(self.fc1(x)))

# Smoke test: build a small network and push one random 4-feature state
# through it; the cell displays the output tensor together with its shape.
state_size = (4,)
action_size = 2
temp_net = QNetwork(state_size, action_size, hidden=32)
test = torch.randn(4)
temp_net(test), temp_net(test).shape

(tensor([-0.0516,  0.1979], grad_fn=<AddBackward0>), torch.Size([2]))

In [4]:
# Each actor keeps its own copy of the Q-network and periodically refreshes it
# with the weights held by the learner.
@ray.remote
class Actor:
    """Ray actor that mirrors the learner's behaviour network.

    Args:
        learner: Handle to the remote learner; it must expose a
            ``return_weights`` remote method returning a state dict.
        actor_idx: Index identifying this actor (only used in log output).
        hidden: Hidden-layer width of the local ``QNetwork``.
        device: Torch device the local network is moved to (e.g. ``"cpu"``).
    """

    def __init__(self, learner, actor_idx: int, hidden: int, device: str):
        self.learner = learner  # learner handle shared through Ray
        self.device = device
        self.actor_idx = actor_idx

        # Network parameters
        self.state_dim = (16, )
        self.action_dim = 3
        self.q_behave = QNetwork(self.state_dim, self.action_dim, hidden).to(self.device)

    def explore(self, sync_interval=1):
        """Loop forever, pulling the learner's weights every ``sync_interval`` seconds.

        This method never returns; it is meant to be launched with
        ``.remote()`` and left running.
        """
        print("exploration start..")
        while True:
            time.sleep(sync_interval)
            self.get_weights()
            print('updates done.')

    def get_weights(self):
        """Fetch the learner's current state dict and load it into ``q_behave``."""
        # ray.get blocks until the learner has answered the request.
        weights = ray.get(self.learner.return_weights.remote())
        print(type(weights), self.actor_idx)
        self.q_behave.load_state_dict(weights)

In [5]:
# The learner owns the reference copy of the Q-network and hands its weights
# to the actors on request.

@ray.remote
class Learner:
    """Ray actor holding the behaviour and target Q-networks.

    Args:
        hidden: Hidden-layer width of both networks.
        device: Torch device both networks are moved to (e.g. ``"cpu"``).
    """

    def __init__(self, hidden: int, device: str):
        self.state_dim = (16, )
        self.action_dim = 3
        self.device = device

        self.q_behave = QNetwork(self.state_dim, self.action_dim, hidden).to(self.device)
        self.q_target = QNetwork(self.state_dim, self.action_dim, hidden).to(self.device)
        # The target network starts as an exact copy of the behaviour network
        # and is put in eval mode.
        self.q_target.load_state_dict(self.q_behave.state_dict())
        self.q_target.eval()

    def return_weights(self):
        """Return the behaviour network's state dict with every tensor on the CPU.

        Moving the tensors to the CPU keeps the returned mapping usable by
        processes that cannot see the learner's device; for a CPU learner
        ``Tensor.cpu()`` returns the tensor unchanged. The mapping type of
        ``state_dict()`` (an ordered dict) is preserved.
        """
        weights = self.q_behave.state_dict()
        return type(weights)(
            (name, tensor.cpu()) for name, tensor in weights.items()
        )

In [6]:
hidden = 32  # hidden-layer width passed to the learner and to every actor
# This demo runs entirely on the CPU.
device = "cpu"

# Create the single remote learner that all actors pull their weights from.
learner = Learner.remote(hidden, device)

In [7]:
num_actors = 16  # number of actors to launch

# Create num_actors actors and start each one's explore loop (which never
# returns). The handles and the explore() object refs are kept in lists so
# the actors can still be reached later, e.g. to stop them with ray.kill().
actors = []
explore_refs = []
for actor_idx in range(num_actors):
    actor = Actor.remote(learner, actor_idx, hidden, device)
    actors.append(actor)
    explore_refs.append(actor.explore.remote())

[2m[36m(pid=103442)[0m exploration start..
[2m[36m(pid=103434)[0m exploration start..
[2m[36m(pid=103424)[0m exploration start..
[2m[36m(pid=103438)[0m exploration start..
[2m[36m(pid=103425)[0m exploration start..
[2m[36m(pid=103471)[0m exploration start..
[2m[36m(pid=103474)[0m exploration start..
[2m[36m(pid=103423)[0m exploration start..
[2m[36m(pid=103432)[0m exploration start..
[2m[36m(pid=103429)[0m exploration start..
[2m[36m(pid=103426)[0m exploration start..
[2m[36m(pid=103431)[0m exploration start..
[2m[36m(pid=103479)[0m exploration start..
[2m[36m(pid=103433)[0m exploration start..
[2m[36m(pid=103428)[0m exploration start..
[2m[36m(pid=103427)[0m exploration start..
[2m[36m(pid=103442)[0m <class 'collections.OrderedDict'> 6
[2m[36m(pid=103442)[0m updates done.
[2m[36m(pid=103434)[0m <class 'collections.OrderedDict'> 3
[2m[36m(pid=103434)[0m updates done.
[2m[36m(pid=103424)[0m <class 'collections.OrderedDict'> 

[2m[36m(pid=103442)[0m <class 'collections.OrderedDict'> 6
[2m[36m(pid=103442)[0m updates done.
[2m[36m(pid=103434)[0m <class 'collections.OrderedDict'> 3
[2m[36m(pid=103434)[0m updates done.
[2m[36m(pid=103438)[0m <class 'collections.OrderedDict'> 0
[2m[36m(pid=103438)[0m updates done.
[2m[36m(pid=103425)[0m <class 'collections.OrderedDict'> 5
[2m[36m(pid=103425)[0m updates done.
[2m[36m(pid=103474)[0m <class 'collections.OrderedDict'> 2
[2m[36m(pid=103474)[0m updates done.
[2m[36m(pid=103424)[0m <class 'collections.OrderedDict'> 7
[2m[36m(pid=103424)[0m updates done.
[2m[36m(pid=103471)[0m <class 'collections.OrderedDict'> 4
[2m[36m(pid=103471)[0m updates done.
[2m[36m(pid=103423)[0m <class 'collections.OrderedDict'> 11
[2m[36m(pid=103423)[0m updates done.
[2m[36m(pid=103432)[0m <class 'collections.OrderedDict'> 10
[2m[36m(pid=103432)[0m updates done.
[2m[36m(pid=103429)[0m <class 'collections.OrderedDict'> 1
[2m[36m(pid=1034

[2m[36m(pid=103442)[0m <class 'collections.OrderedDict'> 6
[2m[36m(pid=103442)[0m updates done.
[2m[36m(pid=103434)[0m <class 'collections.OrderedDict'> 3
[2m[36m(pid=103434)[0m updates done.
[2m[36m(pid=103425)[0m <class 'collections.OrderedDict'> 5
[2m[36m(pid=103425)[0m updates done.
[2m[36m(pid=103474)[0m <class 'collections.OrderedDict'> 2
[2m[36m(pid=103474)[0m updates done.
[2m[36m(pid=103424)[0m <class 'collections.OrderedDict'> 7
[2m[36m(pid=103424)[0m updates done.
[2m[36m(pid=103438)[0m <class 'collections.OrderedDict'> 0
[2m[36m(pid=103438)[0m updates done.
[2m[36m(pid=103471)[0m <class 'collections.OrderedDict'> 4
[2m[36m(pid=103471)[0m updates done.
[2m[36m(pid=103423)[0m <class 'collections.OrderedDict'> 11
[2m[36m(pid=103423)[0m updates done.
[2m[36m(pid=103432)[0m <class 'collections.OrderedDict'> 10
[2m[36m(pid=103432)[0m updates done.
[2m[36m(pid=103429)[0m <class 'collections.OrderedDict'> 1
[2m[36m(pid=1034

In [None]:
n_updates = 100  # number of times the learner's update_network method is invoked

# NOTE(review): the Learner class defined in this notebook only exposes
# return_weights; update_network is not defined there, so this cell presumably
# belongs to a version of Learner with a replay buffer -- confirm before running.
for update_idx in range(n_updates):
    time.sleep(1)
    # Block until the learner reports the statistics of one update step.
    update_ref = learner.update_network.remote()
    loss, batch_stat_shape, act_indices, buf_size = ray.get(update_ref)
    print(
        f'Number of updates: {update_idx}',
        f'Loss: {loss}',
        f'State shape in Batch: {batch_stat_shape}',
        f'Actor index: {act_indices}',
        f'Buffer store index: {buf_size}\n',
        sep='\n',
    )


- Loss: random한 실수값 <br>
- State shape: (batch, state[0], state[1])의 차원을 가지는 출력 <br>
- Actor index: batch 안의 각 sample이 어느 actor에게 나온 것인지 출력 <br>
- Buffer store index: Buffer에 저장되는 현재 store index(각 update 사이에 얼마나 저장되었는지)를 출력  <br><br>

#### 대략 아래와 같은 결과가 나오면 의도대로 나온 것입니다. 

    Number of updates: 9
    Loss: -1.7283143861676746
    State shape in Batch: (16, 2, 2)
    Actor index: [ 4. 12.  1.  3.  4.  4.  1. 14.  2. 15. 11.  0.  1. 15. 15.  9.]
    Buffer store index: 1863

    Number of updates: 10
    Loss: -1.3466382853532786
    State shape in Batch: (16, 2, 2)
    Actor index: [ 9.  8. 13. 15. 14.  9.  0.  4.  2.  8. 13.  7.  2.  2.  0. 11.]
    Buffer store index: 2023

    Number of updates: 11
    Loss: -0.8023523911669711
    State shape in Batch: (16, 2, 2)
    Actor index: [ 3.  9.  9.  7. 12.  3. 12.  6. 12.  5. 10.  7.  0. 11.  3.  6.]
    Buffer store index: 2181