# 18. 8th Attempt: MCTS with Networks

## MCTS

AlphaZero에서 사용한 MCTS 전략은 다음과 같습니다:

- 하나의 network에서 말단을 2개로 나눠서 policy network와
value network로 분리
- Policy network ($\pi$)는 탐색하는 우선순위를 확률로 반환
- Value network ($v$)는 해당 위치에 돌을 두었을 때의 승률을
$-1$에서 $1$ 사이로 반환
- Agent가 학습을 할 때 MCTS를 특정 iteration만큼 실행시키고,
그 때 가장 많이 방문한 노드를 다음 action으로 선택.
(예를 들어서 1600번 정도라든가)
- 탐색 우선순위를 정할 때는 선택지 노드의 승률 누적값 $W$,
방문횟수 $N$, $\pi$의 값 $P$, 부모 노드의 방문횟수 $N_p$,
탐색 조절 상수 $c_e$에 대해 UCB1 score인
$$u = \frac{W}{N} + c_e P \sqrt{\frac{\log N_p}{N}}$$
을 사용.
- 학습은 이전과 같이 policy에 대해서는 cross entropy,
value에 대해서는 MSE 등을 사용.

아래에서는 위 방식과 비슷하게 구현을 합니다.
다만, MCTS를 사용하면 착수에 너무 오랜 시간이 걸리기 때문에
self-play를 할 경우 학습 시간이 끔찍하게 늘어나게 될 것이므로
상대로는 이전까지 써온 `agent_mixed`를 사용합니다.
또한 iteration 횟수는 100번 정도로 제한합니다.

### Environment

In [2]:
!rm -rf mock5.py mock5 gen.cpp read_record.py gen
!git clone https://github.com/lumiknit/mock5.py
!mv mock5.py/mock5 .
!mv mock5.py/gen2_record/gen.cpp .
!mv mock5.py/gen2_record/read_record.py .
!g++ -O2 -fopenmp -o gen gen.cpp

from mock5 import Mock5
from mock5.analysis import Analysis as M5Analysis
import mock5.agent_random as m5rand
import mock5.agent_analysis_based as m5aa
import mock5.agent_ad as m5ad
import mock5.agent_pt as m5pt
import mock5.agent_df as m5df

import matplotlib.pyplot as plt

import numpy as np

import torch
import torch.nn as nn
import torch.optim as optim

import os
import time

device = 'cuda' if torch.cuda.is_available() else 'cpu'
print("Device: {}".format(device))

Cloning into 'mock5.py'...
remote: Enumerating objects: 152, done.[K
remote: Counting objects: 100% (152/152), done.[K
remote: Compressing objects: 100% (101/101), done.[K
remote: Total 152 (delta 74), reused 121 (delta 47), pack-reused 0[K
Receiving objects: 100% (152/152), 40.18 KiB | 10.04 MiB/s, done.
Resolving deltas: 100% (74/74), done.
Device: cuda


### Board Size

In [3]:
# Board Size
W = 15
H = W

### Agent & Policy Transformer

In [4]:
def fn_name(fn):
  if hasattr(fn, 'name'): return fn.name
  else: return repr(fn)
  
def agent(pi, epsilon=0):
  # pi must return array of non-negative values
  def c(game):
    w, h = game.width, game.height
    m, p = np.ones(h * w), np.array(pi(game))
    for i in range(h * w):
      if game.board[i] != 0: m[i], p[i] = 0, 0
    s = p.sum()
    if np.random.uniform() < epsilon or s == 0:
      s = m.sum()
      if s == 0: return None, None # Cannot do anything
      else: idx = np.random.choice(h * w, p=(m / s))
    else: idx = np.random.choice(h * w, p=(p / s))
    return idx // w, idx % w
  c.name = 'stochastic({})'.format(fn_name(pi))
  return c

def softmax(arr, tau=1.0):
  arr = np.array(arr, dtype=np.float64)
  arr /= tau
  m = max(arr)
  z = np.exp(arr - m)
  return z / z.sum()

def pt_softmax(policy, tau=1.0):
  def p(game):
    p = policy(game)
    return softmax(p, tau=tau)
  p.name = 'pt_softmax({},tau={})'.format(fn_name(policy), tau)
  return p

def pt_norm(policy):
  def p(game):
    p = policy(game)
    return p / p.max()
  p.name = 'pt_norm({})'.format(fn_name(policy))
  return p

def policy_uniform(game):
  return np.ones(game.height * game.width)
policy_uniform.name = 'uniform'

def agent_mixed(game):
  a = np.random.uniform()
  if a < 0.4: return m5aa.agent(game)
  elif a < 0.8: return m5ad.agent(game)
  elif a < 0.9: return m5pt.agent(game)
  else: return m5df.agent(game)
agent_mixed.name = 'agent-mixed-analysis-based'

def test_agents(num_game, agent1, agent2):
  w1, w2 = 0, 0
  c1, c2 = 0, 0
  for i in range(num_game):
    g = Mock5(H, W)
    result = g.play(agent1, agent2,
                    print_intermediate_state=False, print_messages=False)
    if result == 1:
      w1 += 1
      c1 += len(g.history)
    elif result == 2:
      w2 += 1
      c2 += len(g.history)
  print("-- Test Result --")
  print("* Agent1 = {} \n* Agent2 = {}".format(
      fn_name(agent1), fn_name(agent2)))
  print("Total : {:5d}".format(num_game))
  print("A1 Win: {:5d} ({:5.3f}) (avg.mov {:6.1f})".format(w1, w1 / num_game, c1 / max(1, w1)))
  print("A2 Win: {:5d} ({:5.3f}) (avg.mov {:6.1f})".format(w2, w2 / num_game, c2 / max(1, w2)))

In [5]:
class Stat:
  def __init__(self, n=0, w=0, l=0, d=0): self.n, self.w, self.l, self.d = n, w, l, d
  def dup(self):
    return Stat(self.n, self.w, self.l, self.d)
  def __sub__(self, other):
    return Stat(self.n - other.n, self.w - other.w, self.l - other.l, self.d - other.d)
  def win(self):
    self.n += 1
    self.w += 1
  def lose(self):
    self.n += 1
    self.l += 1
  def draw(self):
    self.n += 1
    self.d += 1
  def update_by_result(self, result):
    self.n += 1
    if result == 1: self.w += 1
    elif result > 1: self.l += 1
    else: self.d += 1
  def win_rate(self):
    if self.n >= 1: return (self.w + 0.5 * self.d) / self.n
    else: return np.nan

### Neural Network

In [6]:
class Flatten(nn.Module):
  def forward(self, x):
    if len(x.shape) == 3: return x.view(-1)
    else: return x.flatten(1, -1)

class FixedTempSoftmax(nn.Module):
  def __init__(self, dim, tau):
    super(FixedTempSoftmax, self).__init__()
    self.dim = dim
    self.tau = np.exp(tau)
  def forward(self, x):
    return nn.functional.softmax(x / self.tau, self.dim)

class TempSoftmax(nn.Module):
  def __init__(self, dim, tau=0):
    super(TempSoftmax, self).__init__()
    self.dim = dim
    self.tau = nn.Parameter(torch.tensor([tau], dtype=torch.float))
  def forward(self, x):
    return nn.functional.softmax(x / torch.exp(self.tau), self.dim)

In [25]:
## nn
class Expert(nn.Module):
  def __init__(self):
    super().__init__()
    # back
    self.decoder = nn.Sequential(
      nn.Conv2d(3, 256, 5, padding='same'),
      nn.GELU(),
      nn.Conv2d(256, 64, 3, padding='same'),
      nn.GELU(),
      nn.Conv2d(64, 64, 3, padding='same'),
      nn.GELU(),
      nn.Conv2d(64, 64, 3, padding='same'),
      nn.GELU(),
      nn.Conv2d(64, 64, 3, padding='same'),
      nn.GELU(),
      nn.Conv2d(64, 64, 3, padding='same'),
      nn.GELU(),
    )
    # pi
    self.policy = nn.Sequential(
      nn.Conv2d(64, 16, 3, padding='same'),
      nn.GELU(),
      nn.Conv2d(16, 1, 1, padding='same'),
      Flatten(),
      nn.LogSoftmax(dim=-1)
    )
    # v
    self.value = nn.Sequential(
      nn.Conv2d(64, 16, 3, padding='same'),
      nn.GELU(),
      nn.Conv2d(16, 1, 1, padding='same'),
      nn.AvgPool2d(W),
      nn.Tanh()
    )

  def forward(self, x):
    y = self.decoder(x)
    p = self.policy(y)
    v = self.value(y)
    return p, v

  def np_wrapper(self):
    def e(game):
      t = game.tensor(dtype=torch.float).unsqueeze(dim=0).to(device)
      with torch.no_grad():
        p, v = self.forward(t)
      p = np.exp(p.cpu().detach().squeeze().numpy())
      for i in range(H * W):
        if game.board[i] != 0:
          p[i] = -float('inf')
      v = v.squeeze().item()
      return p, v
    return e

In [8]:
UCB_POLICY_FACTOR = 1.4

class MCTSNode:
  def __init__(self, p, policy, wr, n=1):
    self.n = n
    self.wr = wr
    self.p = p
    self.policy = policy
    self.children = [None] * (H * W)

  def ucb(self, n_tot):
    # WR/N + C * P * sqrt(log(N_tot) / (1 + N))
    v = self.wr / self.n
    e = np.sqrt(np.log(1 + n_tot) / (1 + self.n))
    return v + UCB_POLICY_FACTOR * self.p * e
  
  def choose_child(self, game):
    max_u, max_i = -float('inf'), None
    for i in range(H * W):
      if game.board[i] == 0:
        if self.children[i] is None:
          u = UCB_POLICY_FACTOR * self.policy[i] * np.sqrt(np.log(1 + self.n))
        else:
          u = self.children[i].ucb(self.n)
        u += np.random.uniform() * 0.0001
        if u > max_u:
          max_u, max_i = u, i
    return max_i, self.policy[max_i]

class MCTSwExpert:
  def __init__(self, game, expert):
    self.game = game.replay()
    self.t = len(self.game.history)
    self.expert = expert
    policy, wr = self.expert(self.game)
    self.root = MCTSNode(1.0, policy, wr, 0)
  
  def restore_game(self):
    for i in range(len(self.game.history) - self.t):
      self.game.undo()

  def explore(self):
    traj = []
    node = self.root
    while node is not None:
      w = self.game.check_win()
      if w is not None: break
      idx, p = node.choose_child(self.game)
      traj.append(node)
      node = node.children[idx]
      self.game.place_stone_at_index(idx)
    
    if w is not None:
      back_wr = 1
    elif node is None:
      policy, wr = self.expert(self.game)
      node = MCTSNode(traj[-1].policy[idx], policy, wr)
      traj[-1].children[idx] = node
      back_wr = wr
    for i in range(len(traj) - 1, -1, -1):
      back_wr = -back_wr
      traj[i].n += 1
      traj[i].wr += back_wr
    self.restore_game()
  
  def action_wr(self):
    max_n = -1
    max_i = None
    for i in range(H * W):
      c = self.root.children[i] 
      if c is not None and max_n < c.n:
        max_n, max_i = c.n, i
    return max_i, self.root.wr / self.root.n

  def decide(self, n_iter):
    for i in range(n_iter): self.explore()
    return self.action_wr()

  def __str__(self):
    def p(depth, idx, node):
      s = ("  " * depth) + "D {:3d}@{:3d} W/N = {:.3f}/{} = {:.4f}".format(
          depth, idx, node.wr, node.n, node.wr / node.n)
      for i in range(len(node.children)):
        if node.children[i] is not None:
          s += "\n" + p(depth + 1, i, node.children[i])
      return s
    return "[MCTS]\n" + p(0, -1, self.root)

In [23]:
# REINFORCE
def learn(
    expert,
    opt,
    n_episode,
    n_mcts_iter,
    n_opt_iter,
    interval_stat
):
  stat = Stat()
  last_stat = stat.dup()
  last_stat_epi = -1
  loss_acc = 0
  for epi in range(n_episode):
    # Generate episode
    result = 0
    Xs, As = [], []
    t = epi % 2
    game = Mock5(H, W)
    for turn in range(H * W + 1):
      result = game.check_win()
      if result is not None: break
      if turn % 2 == t:
        mcts = MCTSwExpert(game, expert.np_wrapper())
        action, wr = mcts.decide(n_mcts_iter)
        #print("T#{}: action={:3d}, wr={:5f}".format(turn, action, wr))
      else:
        r, c = agent_mixed(game)
        action = r * W + c
      game.place_stone_at_index(action)
    if result == 0: stat.draw()
    elif result == t + 1: stat.win()
    else: stat.lose()
    Rs, Qs = [], []
    for f in range(2):
      for r in range(4):
        g = game.replay(angle=r, flip=f)
        while len(g.history) > 0:
          As.append(g.history[-1])
          g.undo()
          Xs.append(g.tensor(dtype=torch.float))
          if result == 0:
            Rs.append(0.0)
            Qs.append(0.0)
          elif result == g.player:
            Rs.append(1.0)
            Qs.append(1.0)
          else:
            Rs.append(-1.0)
            Qs.append(0.0)
    # Tensor-fy
    x = torch.stack(Xs, dim=0).to(device)
    r = torch.tensor(Rs).to(device)
    q = torch.tensor(Qs).to(device)
    # Learn
    loss_pi_list = []
    loss_v_list = []
    for i in range(n_opt_iter):
      log_p, v = expert(x)
      log_p_a = log_p.gather(1, torch.tensor(As).unsqueeze(dim=1).to(device)).squeeze(dim=1)
      loss_pi = - (log_p_a * q).mean()
      loss_v = nn.SmoothL1Loss()(v.squeeze(), r)
      loss = loss_pi + 1e1 * loss_v

      opt.zero_grad()
      loss.backward()
      opt.step()

      loss_pi_list.append(loss_pi.item())
      loss_v_list.append(loss_v.item())
    
    # Print status and evaluate
    if epi - last_stat_epi >= interval_stat:
      #save_policy()
      print("----------")
      print("Ep #{:<6d} Loss Change Accum {:13.10f}".format(
        epi + 1, loss_acc / (epi - last_stat_epi)))
      print("  Win Rate {:8.4f}% ({}w + {}d + {}l = {})".format(
            100 * (stat.w + stat.d * 0.5) / stat.n,
            stat.w, stat.d, stat.l, stat.n))
      dstat = stat - last_stat
      print("   WR Diff {:8.4f}% ({}w + {}d + {}l = {})".format(
          100 * (dstat.w + dstat.d * 0.5) / dstat.n,
          dstat.w, dstat.d, dstat.l, dstat.n))
      print("LOSS PI {:.6f} -> {:.6f}".format(loss_pi_list[0], loss_pi_list[-1]))
      print("LOSS V  {:.6f} -> {:.6f}".format(loss_v_list[0], loss_v_list[-1]))
      print(game)
      print("RESULT = {}".format(result))
      last_stat_epi = epi
      loss_acc = 0

### Learning

In [26]:
def run():
  expert = Expert().to(device)

  opt = optim.Adam(expert.parameters(),
                   lr=1e-4,
                   weight_decay=1e-6)
  
  learn(
      expert = expert,
      opt = opt,
      n_episode = 200000,
      n_mcts_iter = 100,
      n_opt_iter = 2,
      interval_stat = 50)

  return expert

expert = run()

----------
Ep #50     Loss Change Accum  0.0000000000
  Win Rate   0.0000% (0w + 0d + 50l = 50)
   WR Diff   0.0000% (0w + 0d + 50l = 50)
LOSS PI 1.771570 -> 1.769580
LOSS V  0.493609 -> 0.492900
 [ Turn   9 ; 2P's turn (tone = X) ]
  | 0 1 2 3 4 5 6 7 8 9 A B C D E
--+------------------------------
0 | . . . . . . . . . . . . . . .
1 | . . . . . . . . . . . . . . .
2 | . . . . . . . . . . . . . . .
3 | . . . . . . . . . . . . . . .
4 | . . . . . . . . . . . . . . .
5 | . . . . . . . O . . . . . . .
6 | . . . . . . X O . . . . . . .
7 | . . . . . . X O X . . . . . .
8 | . . . . . . X O . . . . . . .
9 | . . . . . . . O . . . . . . .
A | . . . . . . . . . . . . . . .
B | . . . . . . . . . . . . . . .
C | . . . . . . . . . . . . . . .
D | . . . . . . . . . . . . . . .
E | . . . . . . . . . . . . . . .
RESULT = 1
----------
Ep #100    Loss Change Accum  0.0000000000
  Win Rate   0.0000% (0w + 0d + 100l = 100)
   WR Diff   0.0000% (0w + 0d + 100l = 100)
LOSS PI 1.652407 -> 1.641434
LOSS V 

KeyboardInterrupt: ignored

## 고찰

- 우선 위 진행 결과는 약 8시간 동안 1700번 게임을 진행하여
얻었습니다.
- 전에 policy network만으로 진행하는 경우에는 약 10분이면
1700번 이상을 진행하고도 남는 수준인데, 그에 비해 매우 오래
걸린다는 것을 확인할 수 있습니다.
- 중간에 출력된 기보를 보면 처음 1100판까지는
이전 학습에서도 보였던 '제대로 방어를 못 하고 말리는 수'를
두는 것을 보이다가, 그 이후부터 나름 공방을 벌이는 것을
볼 수 있습니다.
- 사실 결과적으로 학습에는 최종 기보만 쓰이다보니,
처음에 MCTS로 학습하는 대신에, 다른 기보를 넣어서 학습 시간을
줄일 수 있지 않을까 싶습니다.
- 마찬가지 이유로, 초반에는 MCTS의 iteration을 크게 잡아도
좋은 수를 거의 선택하지 못하다보니, iteration 횟수도
학습이 진행됨에 따라 점차 조절해나가는 편이 낫지 않을까 싶습니다.