<a href="https://colab.research.google.com/github/coregvy/Con4AI/blob/master/handson1_sb3.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# コネクトフォー の ゲームAIを作ろう(SB3)

Stable Baselines3 を使用したハンズオンです。


## Install Dependencies and Stable Baselines3 Using Pip


In [None]:
# Install system packages used for visualization (ffmpeg, freeglut, Xvfb)
# and the Python packages imported below.
# NOTE(review): this notebook uses the classic `gym` API (reset() returns only
# the observation, step() returns 4 values). Stable Baselines3 >= 2.0 targets
# Gymnasium instead, so a pin such as `stable-baselines3[extra]<2` may be
# required for the cells below to run — confirm before use.
!apt-get update && apt-get install -y -q ffmpeg freeglut3-dev xvfb  # For visualization
!pip install stable-baselines3[extra] --quiet
!pip install pyglet==1.4 --quiet


## Imports

In [None]:
import gym
import numpy as np
from stable_baselines3 import PPO
from stable_baselines3.ppo.policies import MlpPolicy
import random
import re


In [None]:
class Dist(gym.Env):
  """Warm-up environment: walk the player P onto the goal G on a small grid.

  Each episode places P and G on random cells of the MAX_X x MAX_Y grid at a
  Manhattan distance of at least 3. Actions move P one cell up, right, down
  or left. The episode ends with reward 1 when P reaches G, or with reward 0
  as soon as P leaves the grid.

  Observation: [P.x, P.y, G.x, G.y, Manhattan distance between P and G].
  """
  MAX_X = 7
  MAX_Y = 7
  ACTION_TOP = 0
  ACTION_RIGHT = 1
  ACTION_BOTTOM = 2
  ACTION_LEFT = 3
  # Previous move as [x, y, action]; render() draws an arrow there.
  # initState() resets it to [-1, -1, 0] ("no move yet").
  PRE = [0,0,0]
  ACTION_MARK = ['↑','→','↓','←']
  POS_MY = [0,0]
  POS_GOAL = [MAX_X, MAX_Y]

  def __init__(self) -> None:
    super(Dist, self).__init__()
    self.initState()
    self.action_space = gym.spaces.Discrete(4)
    # A single move can put P one cell outside the grid (coordinate -1, or
    # MAX_X / MAX_Y) just before the episode ends. The space therefore has to
    # admit -1 with a signed dtype; a uint8 space cannot represent it.
    self.observation_space = gym.spaces.Box(low=-1, high=self.MAX_X * self.MAX_Y - 1, shape=(5,), dtype=np.int64)

  def reset(self):
    """Start a new episode and return the initial observation."""
    self.initState()
    return self.state()

  def step(self, action):
    """Move P by one cell according to ``action``.

    Returns:
        tuple: (observation, reward, done, info). reward is 1 when P reaches
        G and 0 otherwise; done is True on reaching G or leaving the grid.
    """
    # Remember where we came from so render() can draw the move arrow.
    self.PRE = [self.POS_MY[0], self.POS_MY[1], action]
    # Move; y grows downward, so "top" decreases y.
    if action == self.ACTION_TOP:
      self.POS_MY[1] -= 1
    elif action == self.ACTION_RIGHT:
      self.POS_MY[0] += 1
    elif action == self.ACTION_BOTTOM:
      self.POS_MY[1] += 1
    elif action == self.ACTION_LEFT:
      self.POS_MY[0] -= 1

    # Distance to the goal after the move.
    dist = self.distance(self.POS_MY, self.POS_GOAL)
    # The episode ends when the goal is reached or P left the grid.
    done = (dist == 0) or self.isOver()
    reward = 0
    if dist == 0:
      reward = 1
    return self.state(), reward, done, {}

  def render(self, mode='console', close=False):
    """Print the observation followed by the grid (P, G and last-move arrow)."""
    print(str(self.state()))
    for y in range(self.MAX_Y):
      print('| ', end='')
      for x in range(self.MAX_X):
        print(self.renderMark(x, y), end='')
      print(' |')

  def renderMark(self, x, y):
    """Return the character drawn at cell (x, y)."""
    if (x == self.POS_MY[0]) and (y == self.POS_MY[1]):
      return 'P'
    elif (x == self.POS_GOAL[0]) and (y == self.POS_GOAL[1]):
      return 'G'
    elif (x == self.PRE[0]) and (y == self.PRE[1]):
      return self.ACTION_MARK[self.PRE[2]]
    else:
      return '.'

  def initState(self):
    """Reset the game: place P and G on random cells at least 3 apart."""
    # Forget the previous episode's last move; otherwise render() keeps
    # drawing a stale arrow. -1 never matches a grid cell.
    self.PRE = [-1, -1, 0]
    while True:
      # randrange(MAX) covers every cell 0..MAX-1, including the last
      # row/column (randrange(MAX - 1) never used them).
      self.POS_MY = [random.randrange(self.MAX_X), random.randrange(self.MAX_Y)]
      self.POS_GOAL = [random.randrange(self.MAX_X), random.randrange(self.MAX_Y)]
      # Re-draw when the goal is too close to the start.
      if self.distance(self.POS_MY, self.POS_GOAL) >= 3:
        break

  def isOver(self):
    """Return True when P is outside the grid."""
    if self.POS_MY[0] < 0 or self.POS_MY[1] < 0 or self.POS_MY[0] >= self.MAX_X or self.POS_MY[1] >= self.MAX_Y:
      return True
    return False

  def state(self):
    """Return the observation [P.x, P.y, G.x, G.y, distance]."""
    return [self.POS_MY[0], self.POS_MY[1], self.POS_GOAL[0], self.POS_GOAL[1], self.distance(self.POS_MY, self.POS_GOAL)]

  def distance(self, a, b):
    """Return the Manhattan distance between points ``a`` and ``b``."""
    return abs(a[0] - b[0]) + abs(a[1] - b[1])


In [None]:
env0 = Dist()

# Create the model.
#  verbose: log detail (0: no output, 1: training info, 2: debug info)
#  tensorboard_log: directory for TensorBoard training logs
model0 = PPO(MlpPolicy, env0, verbose=0, tensorboard_log="./log/step0")

# Number of environment steps to train for (total_timesteps).
sample = 10000
model0.learn(total_timesteps=sample, tb_log_name="first_run")
# Save the trained model to a file.
model0.save(f'dist_model_{sample}')

print('training end')


In [None]:
# Load the TensorBoard notebook extension and show every log under ./log/
# (the Dist training cell writes to ./log/step0).
%load_ext tensorboard
%tensorboard --logdir=./log/

In [None]:
# 状態をリセット
env0.reset()
i = 0
while True:
  i += 1
  # 次のAIの行動を取得する
  action, _ = model0.predict(env0.state())
  state, reward, done, info = env0.step(action)
  # 現在の状態を描画
  env0.render()

  if done:
    print('end: ', i)
    break
  print('next action:', done, action, reward)


### game_util.py

コネクトフォー独自のルール(コインの落下、勝敗判定など)や盤面操作のロジックをまとめたクラスです。

In [None]:
class GameUtil:
  """Connect Four rules and board helpers.

  A board ("state") is a list of rows. Row 0 is the top row, and coins land
  in the last (bottom) row first (see fallCoin). Each cell holds a player
  mark or the ``blank`` value.
  """

  @staticmethod
  def stdinToState(stdin, blank='0', my='1', your='2'):
    """Parse the board text received from the server.

    The first line is a space-separated header whose third field
    (``meta[2]``) is this player's mark. Every following line is one board
    row: ``'.'`` is an empty cell, ``meta[2]`` is our coin, and any other
    character is treated as the opponent's coin.

    NOTE(review): cells are first mapped to the placeholder letters
    'B'/'M'/'Y'. A server mark equal to one of those letters would be
    misread — confirm against the server protocol.

    Returns:
        list[list[str]]: one list of single-character cells per row.
    """
    rows = stdin.splitlines()
    meta = rows.pop(0).split(' ')
    board = []
    for line in rows:
      # Map to placeholders first: '.' -> B(lank), own mark -> M(ine),
      # anything else -> Y(ours, i.e. the opponent). Then substitute the
      # requested output marks.
      tmp = line.replace('.', 'B').replace(meta[2], 'M')
      tmp = re.sub(r'[^MB]', 'Y', tmp)
      tmp = tmp.replace('Y', your).replace('B', blank).replace('M', my)
      board.append(list(tmp))
    return board

  @staticmethod
  def listToState(ao, meta, blank=0, my=1, your=2):
    """Convert a 2-D board of raw server characters to marks, in place.

    Args:
        ao (list[list[str]]): rows of raw cells; modified in place.
        meta (list[str]): header fields; meta[0] is the column count,
            meta[1] the row count, and meta[2] this player's mark.
        blank, my, your: values written for '.', our mark and any other cell.

    Returns:
        list[list]: ``ao`` itself after conversion.
    """
    for row in range(int(meta[1])):
      for col in range(int(meta[0])):
        cell = ao[row][col]
        if cell == '.':
          ao[row][col] = blank
        elif cell == meta[2]:
          ao[row][col] = my
        else:
          ao[row][col] = your
    return ao

  @staticmethod
  def resetState(row = 6, col = 7):
    """Return a new empty board with ``row`` rows and ``col`` columns of 0."""
    return [[0] * col for i in range(row)]

  @staticmethod
  def fallCoin(state, action, mark=1, blank=0):
    """Drop a coin with ``mark`` into column ``action``.

    The coin lands on the lowest blank cell of the column. ``state`` is
    modified in place and also returned.

    Args:
        state (list[list]): board, top row first.
        action (int): column number (0-based).
        mark (optional): mark to place. Defaults to 1.
        blank (optional): value of an empty cell. Defaults to 0.

    Returns:
        tuple: ``(state, fallNg)``. ``fallNg`` is True when no coin was
        placed, either because the column is full or because ``action`` is
        not a valid column index. Negative indexes are rejected rather than
        wrapping around to the right-hand columns.
    """
    if not state or not 0 <= action < len(state[0]):
      return state, True
    # Scan from the bottom row upward for the first free cell.
    for row in reversed(state):
      if row[action] == blank:
        row[action] = mark
        return state, False
    return state, True

  @staticmethod
  def checkEnd(state, goal=4, blank=0):
    """Check whether some player has ``goal`` marks in a line.

    Args:
        state (list[list]): game board.
        goal (int, optional): marks needed in a line. Defaults to 4.
        blank (optional): blank value. Defaults to 0.

    Returns:
        The winning mark, or ``blank`` when nobody has won (also for an
        empty board).
    """
    if not state:
      return blank
    # check row(-)
    for row in range(len(state)):
      for col in range(len(state[row]) - goal + 1):
        tmpMark = state[row][col]
        if tmpMark == blank:
          continue
        for p in range(goal - 1):
          if tmpMark != state[row][col + p + 1]:
            tmpMark = blank
            break

        if tmpMark != blank:
          return tmpMark

    # check col(|)
    for col in range(len(state[0])):
      for row in range(len(state) - goal + 1):
        tmpMark = state[row][col]
        if tmpMark == blank:
          continue
        for p in range(goal - 1):
          if tmpMark != state[row + p + 1][col]:
            tmpMark = blank
            break

        if tmpMark != blank:
          return tmpMark

    # check / (starting cell at the bottom-left end, walking up-right)
    for row in range(goal - 1, len(state)):
      for col in range(0, len(state[row]) - goal + 1):
        tmp = state[row][col]
        if tmp == blank:
          continue
        for r in range(1, goal):
          if tmp != state[row - r][col + r]:
            tmp = blank
            break
        if tmp != blank:
          return tmp

    # check \ (starting cell at the top-left end, walking down-right)
    for row in range(len(state) - goal + 1):
      for col in range(len(state[row]) - goal + 1):
        tmp = state[row][col]
        if tmp == blank:
          continue
        for r in range(1, goal):
          if tmp != state[row + r][col + r]:
            tmp = blank
            break
        if tmp != blank:
          return tmp

    return blank

  @staticmethod
  def checkReach(state, mark = 1, goal=4, blank=0):
    """List the columns where dropping ``mark`` would end the game.

    Every playable column is tried on a copy of the board; full columns are
    skipped because no coin can be dropped there.

    Args:
        state (list[list]): game board (not modified).
        mark (optional): mark to try. Defaults to 1.
        goal (int, optional): marks needed in a line. Defaults to 4.
        blank (optional): blank value. Defaults to 0.

    Returns:
        list: ``[winning_mark, column]`` pairs.
    """
    ret = []
    for i in range(len(state[0])):
      ps, fallNg = GameUtil.fallCoin(GameUtil.stateCopy(state), i, mark, blank)
      if fallNg:
        continue
      ec = GameUtil.checkEnd(ps, goal, blank)
      if ec != blank:
        ret.append([ec, i])
    return ret

  @staticmethod
  def stateCopy(state):
    """Return a copy of the board (new row lists, same cell values)."""
    return [list(row) for row in state]

  @staticmethod
  def render(state, my = 1, blank = 0):
    """Print the board: ◆ = ``my``, ・ = ``blank``, ☆ = any other mark."""
    cols = len(state[0]) if state else 0
    # Column-number header, e.g. '-0-1-2-3-4-5-6-' for 7 columns.
    print('-' + ''.join(str(c) + '-' for c in range(cols)))
    for row in state:
      print(' ', end='')
      for cell in row:
        mark = '☆'
        if cell == my:
          mark = '◆'
        elif cell == blank:
          mark = '・'
        print(mark, end='')
      print()
    print('--------------')

  @staticmethod
  def enemyPlay(state, blank=0):
    """Return a random column whose top cell is still ``blank``.

    TODO: replace the random choice with a real strategy.

    Raises:
        ValueError: when every column is full.
    """
    openCols = [col for col, cell in enumerate(state[0]) if cell == blank]
    if not openCols:
      raise ValueError('enemyPlay: no playable column (board is full)')
    return random.choice(openCols)


In [None]:
# Smoke test for GameUtil on a fresh 6x7 board.
state = GameUtil.resetState(6, 7)
GameUtil.render(state)                  # empty board
state, _ = GameUtil.fallCoin(state, 2)  # drop a coin (mark 1) into column 2
GameUtil.render(state)                  # board after the drop
done = GameUtil.checkEnd(state)         # winner mark; 0 means nobody has won
print('check: ', done)


### environment.py

Stable Baselines3 で学習させるための環境クラス(gym.Env)です。相手の手は GameUtil.enemyPlay によりランダムに選ばれます。

In [None]:

class Con4(gym.Env):
  """Connect Four environment for Stable Baselines3.

  The agent drops coins with MY_MARK. After each legal agent move, the
  opponent (GameUtil.enemyPlay, a random legal column) replies with
  ENEMY_MARK.

  Rewards per step:
    * -10000, episode ends: the agent chose a column it cannot drop into.
    * 1.0, episode ends: the agent connected four.
    * -1, episode ends: the opponent connected four.
    * 0 otherwise; a full board without a winner ends the episode (draw).
  """
  MY_MARK = 1
  ENEMY_MARK = 2
  BLANK_MARK = 0
  MAX_ROW = 6
  MAX_COL = 7

  def __init__(self):
    super(Con4, self).__init__()
    self.board = GameUtil.resetState(self.MAX_ROW, self.MAX_COL)
    self.action_space = gym.spaces.Discrete(self.MAX_COL)
    # Each cell holds BLANK_MARK (0), MY_MARK (1) or ENEMY_MARK (2).
    self.observation_space = gym.spaces.Box(low=0, high=2, shape=(self.MAX_ROW, self.MAX_COL))

  def reset(self):
    """Start a new game on an empty board and return it as the observation."""
    self.board = GameUtil.resetState(self.MAX_ROW, self.MAX_COL)
    return self.board

  def step(self, action):
    """Play the agent's coin in column ``action``, then the opponent's reply.

    Returns:
        tuple: (board, reward, done, info) in the classic gym API.
    """
    # Agent's move.
    self.board, stepNg = GameUtil.fallCoin(self.board, action, self.MY_MARK, self.BLANK_MARK)
    if stepNg:
      # No more coins fit in this column: illegal move, heavy penalty.
      return self.board, -10000, True, {}
    # Check immediately: once the agent has connected four the game is over
    # and the opponent must not move again.
    if GameUtil.checkEnd(self.board, blank=self.BLANK_MARK) == self.MY_MARK:
      return self.board, 1.0, True, {}
    if self.isFull():
      # Draw.
      return self.board, 0, True, {}

    # Opponent's reply.
    self.board, _ = GameUtil.fallCoin(self.board, GameUtil.enemyPlay(self.board), self.ENEMY_MARK, self.BLANK_MARK)
    if GameUtil.checkEnd(self.board, blank=self.BLANK_MARK) != self.BLANK_MARK:
      # The agent had no line before this move, so the line is the opponent's.
      return self.board, -1, True, {}
    if self.isFull():
      # Draw. Without this, the agent's next step would hit a full column
      # and be punished as an illegal move.
      return self.board, 0, True, {}
    return self.board, 0, False, {}

  def isFull(self):
    """Return True when no column can take another coin (top row has no blank)."""
    return all(cell != self.BLANK_MARK for cell in self.board[0])

  def render(self, mode='console', close=False):
    """Print the current board to the console."""
    GameUtil.render(self.board, self.MY_MARK, self.BLANK_MARK)

  def initState(self):
    """Return a new empty board.

    Returns:
        list: MAX_ROW x MAX_COL 2-D list filled with BLANK_MARK.
    """
    return [[self.BLANK_MARK] * self.MAX_COL for i in range(self.MAX_ROW)]

### training

指定したステップ数(total_timesteps)だけ学習し、結果をモデルファイルとして保存します。

In [None]:
env = Con4()

# Create the model.
#  verbose: log detail (0: no output, 1: training info, 2: debug info)
#  tensorboard_log: directory for TensorBoard training logs
model = PPO(MlpPolicy, env, verbose=0, tensorboard_log='./log/con4')
# Number of environment steps to train for (total_timesteps).
sample = 20000
model.learn(total_timesteps=sample)
# Save the trained model to a file.
model.save(f'con4_model_{sample}')

print('training end', sample)


### 学習結果の確認

TensorBoard を使用して、学習の様子を確認します。\
パラメータや報酬ロジックを変更した際には違いを確認し、より強いAIになるよう調整しましょう。


In [None]:
# Show the Con4 training logs written to ./log/con4 by the training cell.
%load_ext tensorboard
%tensorboard --logdir=./log/con4

### AIのテスト

作ったAIが想定通りに動くか、まずはコンソールで試してみましょう。AIが先手で、あなたは列番号(0〜6)を入力してコインを落とします。

In [None]:
# Play against the trained model in the console. The AI (mark 1) moves
# first, then you type a column number for your coin (mark 2).
state = GameUtil.resetState(6, 7)
i = 0


def isBoardFull(board):
  """Return True when no column can take another coin (top row has no 0)."""
  return all(cell != 0 for cell in board[0])


while True:
  i += 1
  # --- AI turn ---
  action, _ = model.predict(state, deterministic=True)
  state, fallNg = GameUtil.fallCoin(state, action)
  if fallNg:
    # The AI picked a full column: an illegal move ends the game.
    print('AI failed fall: ', action)
    GameUtil.render(state)
    break
  GameUtil.render(state)
  winner = GameUtil.checkEnd(state)
  if winner != 0:
    print('win ai: ', i)
    break
  if isBoardFull(state):
    print('draw: ', i)
    break
  print('AI action:', action)

  # --- player turn: re-prompt until a valid column number is entered ---
  while True:
    text = input('input action > ')
    try:
      col = int(text)
    except ValueError:
      print('please input a column number 0-' + str(len(state[0]) - 1))
      continue
    # Reject out-of-range numbers; a negative index would otherwise
    # silently address a column from the right-hand side.
    if 0 <= col < len(state[0]):
      break
    print('please input a column number 0-' + str(len(state[0]) - 1))
  state, fallNg = GameUtil.fallCoin(state, col, mark = 2)
  if fallNg:
    print('failed fall: ', col)
    GameUtil.render(state)
    break
  winner = GameUtil.checkEnd(state)
  if winner != 0:
    GameUtil.render(state)
    print('win player: ', i)
    break
  if isBoardFull(state):
    GameUtil.render(state)
    print('draw: ', i)
    break