# "가위바위보 강화학습"
> "두 에이전트가 겨루는 가위바위보 환경에서 강화학습을 통하여 승리 정책을 학습한다."

- toc: true
- badges: true
- author: 단호진
- categories: [rl]

마지막 갱신: 2021-07-04

ray 패키지의 rllib은 사용자 모델을 만들어 쓸 수 있게 설계되어 있다. 하지만 세부적인 사항을 이해하고 사용자 모델을 쓰기가 쉽지 않다. 여기에 내부 코드의 이해를 위하여 몇 가지 코드를 작성하여 시험해 본다. 

In [1]:
import ray
from ray.rllib.utils.framework import try_import_torch
from ray.rllib.agents.pg import PGTorchPolicy, PGTrainer
from ray.rllib.examples.env.rock_paper_scissors import RockPaperScissors

torch, nn = try_import_torch()
torch.__version__, ray.__version__

('1.9.0+cu102', '2.0.0.dev0')

## 가위바위보 환경

* player1과 player2가 취할 수 있는 action은 0, 1, 2이고, 관측치는 상대방의 action에서 나오는 0, 1, 2이다.
* env.reset()은 첫 관측치 {'player1': 0, 'player2': 0}를 내어 준다.
* env.step() 함수는 obs, rewards, done, info를 돌려준다.

In [2]:
# Build the rock-paper-scissors environment from an empty config dict.
env_config = dict()
env = RockPaperScissors(env_config)

In [3]:
obs = env.reset()
obs

{'player1': 0, 'player2': 0}

In [4]:
env.action_space, env.observation_space

(Discrete(3), Discrete(3))

In [5]:
# Both players act in the same step (player1 plays 1, player2 plays 2); the
# action dict is keyed by agent id, as are the returned obs/reward/done dicts.
obs, reward, done, info = env.step(dict(player1=1, player2=2))

In [6]:
print(f'obs   : {obs}')
print(f'reward: {reward}')
print(f'done  : {done}')
print(f'info  : {info}')

obs   : {'player1': 2, 'player2': 1}
reward: {'player1': -1, 'player2': 1}
done  : {'__all__': False}
info  : {}


In [7]:
# Same move from both players, to inspect what a tied round returns.
obs, reward, done, info = env.step(dict(player1=1, player2=1))

In [8]:
print(f'obs   : {obs}')
print(f'reward: {reward}')
print(f'done  : {done}')
print(f'info  : {info}')

obs   : {'player1': 1, 'player2': 1}
reward: {'player1': 0, 'player2': 0}
done  : {'__all__': False}
info  : {}


## 정책

정책(Policy)은 에이전트가 관측된 환경, 이전 보상 이력을 참고하여 어떤 행동을 취하면 좋은지 결정한다. 에이전트가 둘인 가위바위보 게임에서 학습 가능한 learned 정책과 RandomMove, BeatLastHeuristic, AlwaysSameHeuristic 정책에서 하나를 뽑아 강화학습을 수행하였다. 가위바위보의 최고 전략은 RandomMove로 learned 정책은 이에 도달하는 결과를 보일 것이고, 그 밖에 BeatLastHeuristic, AlwaysSameHeuristic에 대해서는 쉽게 이길 수 있을 것이다.

새로운 정책을 개발한다면 Policy 클래스를 상속하고 compute_actions에 필요한 로직을 구현하면 된다. RandomMove 정책은 obs_batch에 대해서만 고려하여 배치 크기만큼의 임의 행동을 돌려준다. kwargs에 필수 생성 인자 내용을 정리하였다. 환경의 action_space와 observation_space가 그것이다.

In [9]:
import random
from ray.rllib.policy.policy import Policy
# from ray.rllib.policy.view_requirement import ViewRequirement


class RandomMove(Policy):
    """Non-learning baseline policy that answers every observation with a random move."""

    def __init__(self, obs_space, act_space, config):
        # Nothing extra to set up; the base Policy stores the spaces and config.
        super().__init__(obs_space, act_space, config)

    def random_action(self):
        """Return one of ROCK, PAPER or SCISSORS, picked with ``random.choice``."""
        moves = [
            RockPaperScissors.ROCK,
            RockPaperScissors.PAPER,
            RockPaperScissors.SCISSORS,
        ]
        return random.choice(moves)

    def compute_actions(
            self,
            obs_batch,
            state_batches=None,
            prev_action_batch=None,
            prev_reward_batch=None,
            info_batch=None,
            episodes=None,
            **kwargs):
        """Produce one random action per entry of ``obs_batch``.

        Only the length of ``obs_batch`` is used; the observation values and
        every other argument are ignored.

        Returns:
            Tuple of ``(actions, state_outs, extra_info)`` where ``actions`` is
            a list with one move per observation, ``state_outs`` is an empty
            list and ``extra_info`` is an empty dict.
        """
        actions = []
        for _ in obs_batch:
            actions.append(self.random_action())
        return actions, [], {}

In [10]:
# Quick check: request actions for a dummy batch of three observations.
random_policy = RandomMove(env.observation_space, env.action_space, {})
random_policy.compute_actions(list(range(3)))

([1, 2, 2], [], {})

## Custom policy with template

In [11]:
# https://github.com/ray-project/ray/blob/master/rllib/examples/custom_torch_policy.py

from ray.rllib.policy.torch_policy_template import build_policy_class
from ray.rllib.policy.sample_batch import SampleBatch

def policy_gradient_loss(policy, model, dist_class, train_batch):
    """Reward-weighted negative log-likelihood loss for a policy-gradient update.

    Args:
        policy: The policy being trained (not used by this loss).
        model: Model whose ``from_batch`` yields the action-distribution inputs.
        dist_class: Action distribution class, constructed from those inputs.
        train_batch: Batch providing the ``SampleBatch.ACTIONS`` and
            ``SampleBatch.REWARDS`` columns.

    Returns:
        Scalar tensor: minus the dot product of the rewards (cast to float)
        with the log-probabilities of the actions stored in the batch.
    """
    dist_inputs, _ = model.from_batch(train_batch)
    dist = dist_class(dist_inputs)
    taken_logp = dist.logp(train_batch[SampleBatch.ACTIONS])
    # Rewards may arrive with a non-float dtype; cast so the dot product with
    # the float log-probabilities is well-defined.
    rewards = train_batch[SampleBatch.REWARDS].float()
    return -rewards.dot(taken_logp)

# Torch policy class built from the template; the loss function is the only
# customisation passed in.
MyTorchPolicy = build_policy_class(
    name='MyTorchPolicy',
    framework='torch',
    loss_fn=policy_gradient_loss,
)

## Custom loss model

In [12]:
# https://github.com/ray-project/ray/blob/master/rllib/examples/models/custom_loss_model.py
from typing import Dict
import numpy as np

from ray.rllib.models.modelv2 import ModelV2
from ray.rllib.models import ModelCatalog
from ray.rllib.models.torch.torch_action_dist import TorchCategorical
from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
from ray.rllib.models.torch.fcnet import FullyConnectedNetwork
from ray.rllib.utils.annotations import override
from ray.rllib.utils.framework import TensorType


class TorchCustomLossModel(TorchModelV2, nn.Module):
    """Fully connected torch model that adds an extra log-likelihood loss term.

    Forward pass and value function are delegated to an inner
    ``FullyConnectedNetwork``. ``custom_loss`` adds ten times the mean negative
    log-probability of the batch actions to every policy loss term and records
    both quantities so that ``metrics`` can report them.
    """

    def __init__(self, obs_space, action_space, num_outputs, model_config, name):
        TorchModelV2.__init__(self, obs_space, action_space, num_outputs, model_config, name)
        nn.Module.__init__(self)

        print(f'model config: {model_config}')

        # Filled in by custom_loss(). Starting at None lets metrics() be called
        # safely before the first loss computation instead of raising
        # AttributeError.
        self.policy_loss_metric = None
        self.imitation_loss_metric = None

        self.fcnet = FullyConnectedNetwork(
            self.obs_space,
            self.action_space,
            num_outputs,
            model_config,
            name="fcnet"
        )

    @override(ModelV2)
    def forward(self, input_dict, state, seq_lens):
        """Delegate the forward pass to the wrapped fully connected network."""
        return self.fcnet(input_dict, state, seq_lens)

    @override(ModelV2)
    def value_function(self):
        """Return the value estimate of the wrapped network's last forward pass."""
        return self.fcnet.value_function()

    @override(ModelV2)
    def custom_loss(self,
                    policy_loss: TensorType,
                    loss_inputs: Dict[str, TensorType]) -> TensorType:
        """Add a weighted log-likelihood term to each policy loss.

        Args:
            policy_loss: Iterable of loss tensors from the policy (it is
                indexed and iterated here, so a list is expected despite the
                annotation).
            loss_inputs: Batch dict; the ``'obs'`` and ``'actions'`` entries
                are read.

        Returns:
            A list with one tensor per entry of ``policy_loss``, each equal to
            the original loss plus ``10 * imitation_loss``.
        """
        # Re-run the model on the batch observations to get fresh logits.
        logits, _ = self.forward({'obs': loss_inputs['obs']}, [], None)
        action_dist = TorchCategorical(logits, self.model_config)
        # Mean negative log-probability of the batch actions; the actions are
        # moved to the device the policy loss lives on.
        imitation_loss = torch.mean(
            -action_dist.logp(loss_inputs['actions'].to(policy_loss[0].device))
        )
        self.imitation_loss_metric = imitation_loss.item()
        self.policy_loss_metric = np.mean([
            loss.item() for loss in policy_loss
        ])

        return [loss_ + 10 * imitation_loss for loss_ in policy_loss]

    @override(ModelV2)
    def metrics(self):
        """Report the most recent policy and imitation loss values.

        Returns:
            Dict with ``'policy_loss'`` and ``'imitation_loss'``, or an empty
            dict if ``custom_loss`` has not been called yet.
        """
        if self.policy_loss_metric is None or self.imitation_loss_metric is None:
            return {}
        return {
            'policy_loss': self.policy_loss_metric,
            'imitation_loss': self.imitation_loss_metric,
        }
    
    
# Register the model under the name that the 'custom_model' config key refers to.
ModelCatalog.register_custom_model('my_torch_model', TorchCustomLossModel)

## 트레이너 설정 및 학습

앞서 정의한 RandomMove 외에 ray에서 제공하는 BeatLastHeuristic, AlwaysSameHeuristic 정책을 추가하였다.

In [13]:
import ray
from gym.spaces import Discrete
from ray.rllib.agents.registry import get_trainer_class
from ray.rllib.examples.policy.rock_paper_scissors_dummies import (
    BeatLastHeuristic, AlwaysSameHeuristic
)


# Trainer configuration: several candidate policies are declared, but the
# mapping function below decides which two actually play each other.
config = {
    'env': RockPaperScissors,
    'gamma': 0.9,
    'num_gpus': 0,
    'num_workers': 0,
    'num_envs_per_worker': 4,
    'train_batch_size': 200,  # for the policy model
    'multiagent': {
        # Each entry: (policy class, observation space, action space, config overrides).
        # NOTE(review): a None policy class presumably selects the trainer's
        # default policy - confirm against the RLlib docs.
        'policies': {
            'random_move': (RandomMove, Discrete(3), Discrete(3), {}),
            'beat_last': (BeatLastHeuristic, Discrete(3), Discrete(3), {}),
            'always_same': (AlwaysSameHeuristic, Discrete(3), Discrete(3), {}),
            'learned0': (None, Discrete(3), Discrete(3), {
                'framework': 'torch',
                'model': {},  # use default
            }),
            'learned1': (MyTorchPolicy, Discrete(3), Discrete(3), {
                'framework': 'torch',
                'model': {},  # use default
            }),
            'learned2': (None, Discrete(3), Discrete(3), {
                'framework': 'torch',
                'model': {
                    'custom_model': 'my_torch_model',
                    'custom_model_config': {},
                },  # custom model registered as 'my_torch_model', not the default
            }),
        },
        # player1 is driven by 'learned2'; any other agent id gets 'beat_last'.
        'policy_mapping_fn': lambda agent_id, episode, **kwargs: (
            'learned2' if agent_id == 'player1' else 'beat_last'),
        # Only the learned* policies are listed for training.
        'policies_to_train': ['learned0', 'learned1', 'learned2'],
    },
    'framework': 'torch',
}

# Start Ray, or reuse the running instance: ignore_reinit_error keeps this cell
# re-runnable in the same kernel instead of raising because Ray is already up.
ray.init(ignore_reinit_error=True)
# Look up the trainer class registered under 'PG' and build it from `config`.
trainer = get_trainer_class('PG')(config=config)

2021-07-04 11:54:03,897	INFO services.py:1330 -- View the Ray dashboard at http://127.0.0.1:8265
2021-07-04 11:54:04,893	INFO trainer.py:714 -- Current log_level is WARN. For more information, set 'log_level': 'INFO' / 'DEBUG' or use the -v and -vv flags.


model config: {'_use_default_native_models': False, 'fcnet_hiddens': [256, 256], 'fcnet_activation': 'tanh', 'conv_filters': None, 'conv_activation': 'relu', 'post_fcnet_hiddens': [], 'post_fcnet_activation': 'relu', 'free_log_std': False, 'no_final_linear': False, 'vf_share_layers': True, 'use_lstm': False, 'max_seq_len': 20, 'lstm_cell_size': 256, 'lstm_use_prev_action': False, 'lstm_use_prev_reward': False, '_time_major': False, 'use_attention': False, 'attention_num_transformer_units': 1, 'attention_dim': 64, 'attention_num_heads': 1, 'attention_head_dim': 32, 'attention_memory_inference': 50, 'attention_memory_training': 50, 'attention_position_wise_mlp_dim': 32, 'attention_init_gru_gate_bias': 2.0, 'attention_use_n_prev_actions': 0, 'attention_use_n_prev_rewards': 0, 'num_framestacks': 0, 'dim': 84, 'grayscale': False, 'zero_mean': True, 'custom_model': 'my_torch_model', 'custom_model_config': {}, 'custom_action_dist': None, 'custom_preprocessor': None, 'lstm_use_prev_action_rewa

In [14]:
# One training iteration (raise the range to run more before inspecting).
for _step in range(1):
    results = trainer.train()

# Print every entry of the most recent result dict.
for key, value in results.items():
    print(f'{key}: {value}')

episode_reward_max: 0.0
episode_reward_min: 0.0
episode_reward_mean: 0.0
episode_len_mean: 10.0
episode_media: {}
episodes_this_iter: 80
policy_reward_min: {'learned2': -6.0, 'beat_last': -6.0}
policy_reward_max: {'learned2': 6.0, 'beat_last': 6.0}
policy_reward_mean: {'learned2': 0.0625, 'beat_last': -0.0625}
custom_metrics: {}
hist_stats: {'episode_reward': [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], 'episode_lengths': [10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 

In [15]:
# Run 200 training iterations, printing a progress line on every tenth one.
for iteration in range(1, 201):
    results = trainer.train()
    if iteration % 10 != 0:
        continue
    print(
        iteration,
        results['episode_reward_mean'],
        results['policy_reward_mean'],
        '\t', results['timesteps_total'],
        '\t', results['episodes_total'],
    )

10 0.0 {'learned2': -0.03, 'beat_last': 0.03} 	 8800 	 880
20 0.0 {'learned2': -0.02, 'beat_last': 0.02} 	 16800 	 1680
30 0.0 {'learned2': 0.1, 'beat_last': -0.1} 	 24800 	 2480
40 0.0 {'learned2': 0.0, 'beat_last': 0.0} 	 32800 	 3280
50 0.0 {'learned2': 0.08, 'beat_last': -0.08} 	 40800 	 4080
60 0.0 {'learned2': -0.06, 'beat_last': 0.06} 	 48800 	 4880
70 0.0 {'learned2': 0.02, 'beat_last': -0.02} 	 56800 	 5680
80 0.0 {'learned2': 0.39, 'beat_last': -0.39} 	 64800 	 6480
90 0.0 {'learned2': 0.8, 'beat_last': -0.8} 	 72800 	 7280
100 0.0 {'learned2': 1.31, 'beat_last': -1.31} 	 80800 	 8080
110 0.0 {'learned2': 1.74, 'beat_last': -1.74} 	 88800 	 8880
120 0.0 {'learned2': 2.74, 'beat_last': -2.74} 	 96800 	 9680
130 0.0 {'learned2': 4.71, 'beat_last': -4.71} 	 104800 	 10480
140 0.0 {'learned2': 5.12, 'beat_last': -5.12} 	 112800 	 11280
150 0.0 {'learned2': 5.38, 'beat_last': -5.38} 	 120800 	 12080
160 0.0 {'learned2': 5.67, 'beat_last': -5.67} 	 128800 	 12880
170 0.0 {'learned2

학습이 진행되면서 learned 정책이 BeatLastHeuristic 정책을 쉽게 이기는 것을 알 수 있다.

## 맺으며

ray의 rllib은 다양한 강화학습 방식을 지원할 뿐만 아니라 병렬 학습에 대한 처리가 매우 우수하다. 게다가 활발하게 코드가 관리되고 있다. 다만, 사용자 환경이나 모델을 잘 정의하여 활용하기 위해선 이론적 배경과 rllib의 내부 호출 구조를 잘 알아야 한다.