# ライブラリ

In [None]:
%%capture
!pip install creversi

In [None]:
from creversi import *

import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm.notebook import tqdm
from copy import copy

import torch
import torch.nn as nn

from IPython.display import HTML

# ネットワーク構造

In [None]:
class PolicyNetwork(nn.Module):
    """Fully convolutional network producing one score per board square.

    The network takes a tensor with 8 input channels. Every convolution keeps
    the spatial size (5x5 with padding 2, 3x3 with padding 1, 1x1 without
    padding), and the final 1-channel map is flattened, so an input of shape
    (N, 8, 8, 8) yields an output of shape (N, 64).
    """

    def __init__(self):
        super().__init__()
        n_filters = 100
        n_hidden_blocks = 11

        # Stem: lift the 8 input planes to n_filters feature maps.
        self.input_layer = nn.Sequential(
            nn.Conv2d(8, n_filters, kernel_size=5, padding=2),
            nn.ReLU(),
        )

        # Trunk: identical Conv -> BatchNorm -> ReLU triples. They are kept in
        # ONE flat Sequential so the sub-module indices (and therefore the
        # state_dict keys "hidden_layer.0" ... "hidden_layer.32") do not change.
        trunk = []
        for _ in range(n_hidden_blocks):
            trunk.append(nn.Conv2d(n_filters, n_filters, kernel_size=3, padding=1))
            trunk.append(nn.BatchNorm2d(n_filters))
            trunk.append(nn.ReLU())
        self.hidden_layer = nn.Sequential(*trunk)

        # Head: collapse to a single channel and flatten to one value per square.
        self.output_layer = nn.Sequential(
            nn.Conv2d(n_filters, 1, kernel_size=1),
            nn.Flatten(),
        )

    def forward(self, x):
        """Run stem, trunk and head in sequence and return the flattened scores."""
        features = self.hidden_layer(self.input_layer(x))
        return self.output_layer(features)

In [None]:
class ValueNetwork(nn.Module):
    """Convolutional network mapping a 9-channel board tensor to one value in [-1, 1].

    Layer order inside the trunk is Conv -> ReLU -> BatchNorm (note: ReLU comes
    before BatchNorm here). The trunk ends with a 3x3 convolution followed by
    a 1x1 convolution that uses padding=1; with 8x8 inputs that last layer
    produces 10x10 maps, which is why the first Linear layer expects
    n_filters * 100 features.
    """

    def __init__(self):
        super().__init__()
        n_filters = 10
        n_hidden_blocks = 10

        # Stem: 9 input planes -> n_filters feature maps.
        self.input_layer = nn.Sequential(
            nn.Conv2d(9, n_filters, kernel_size=5, padding=2),
            nn.ReLU(),
        )

        # Trunk kept as one flat Sequential so that sub-module indices (and the
        # matching state_dict keys) stay exactly as in the saved checkpoint.
        trunk = []
        for _ in range(n_hidden_blocks):
            trunk.append(nn.Conv2d(n_filters, n_filters, kernel_size=3, padding=1))
            trunk.append(nn.ReLU())
            trunk.append(nn.BatchNorm2d(n_filters))
        trunk.append(nn.Conv2d(n_filters, n_filters, kernel_size=3, padding=1))
        # padding=1 on a 1x1 kernel grows each spatial side by 2 (8x8 -> 10x10).
        trunk.append(nn.Conv2d(n_filters, n_filters, kernel_size=1, padding=1))
        trunk.append(nn.Flatten())
        self.hidden_layer = nn.Sequential(*trunk)

        # Head: fully connected regression to a single scalar.
        self.output_layer = nn.Sequential(
            nn.Linear(n_filters*100, 256),
            nn.ReLU(),
            nn.Linear(256, 1),
        )

    def forward(self, x):
        """Return tanh-squashed scalar predictions, one per sample."""
        flat_features = self.hidden_layer(self.input_layer(x))
        raw_value = self.output_layer(flat_features)
        return raw_value.tanh()

# 関数

In [None]:
def board_to_array(board, return_torch=False):
    """
    Convert a board object into an (8, 8, 8) float32 feature array (PolicyNetwork input).

    Plane layout, following the original author's description:
      0: black stones, 1: white stones (the two planes written by
         ``board.piece_planes`` are swapped when ``board.turn`` is falsy),
      2: empty squares (1 where planes 0 and 1 do not sum to 1),
      3: legal-move squares, 4: number of stones each legal move flips,
      5: 1 on the 2x2 block in each of the four corners,
      6: all ones, 7: left as allocated (zeros).

    Planes 3 and 4 are only filled when the legal-move list is not ``[64]``
    (64 is the pass move). Returns a torch tensor when ``return_torch`` is true,
    otherwise a numpy array.
    """
    planes = np.zeros((8, 8, 8), dtype=np.float32)
    board.piece_planes(planes)
    if not board.turn:
        # Fancy indexing returns a copy with the first two planes exchanged.
        planes = planes[[1, 0, 2, 3, 4, 5, 6, 7], :, :]
    planes[2] = np.where(planes[0] + planes[1] == 1, 0, 1)

    legal_moves = list(board.legal_moves)
    if legal_moves != [64]:
        flip_counts = []
        for move in legal_moves:
            # Play the move on a copy; the opponent count before the move minus
            # piece_num() after it is what gets stored for this square.
            trial = copy(board)
            opponent_before = trial.opponent_piece_num()
            trial.move(move)
            flip_counts.append(opponent_before - trial.piece_num())
        per_square = np.zeros(64)
        per_square[legal_moves] = flip_counts
        per_square = per_square.reshape(8, 8)
        planes[3] = np.where(per_square > 0, 1, 0)
        planes[4] = per_square

    # Corner mask: ones on rows/cols {0, 1, 6, 7} x {0, 1, 6, 7}.
    corner_mask = np.zeros((8, 8))
    corner_mask[:2, :2] = 1.
    corner_mask[:2, 6:] = 1.
    corner_mask[6:, :2] = 1.
    corner_mask[6:, 6:] = 1.
    planes[5] = corner_mask
    planes[6] = 1

    return torch.from_numpy(planes) if return_torch else planes

def board_to_array2(board, return_torch=False):
    """
    Convert a board object into a (9, 8, 8) float32 feature array (ValueNetwork input).

    Plane layout, following the original author's description:
      0: black stones, 1: white stones (the two planes written by
         ``board.piece_planes`` are swapped when ``board.turn`` is falsy),
      2: empty squares (1 where planes 0 and 1 do not sum to 1),
      3: legal-move squares, 4: number of stones each legal move flips,
      5: 1 on the 2x2 block in each of the four corners,
      6: all ones, 7: left as allocated (zeros),
      8: side-to-move flag, set to all ones when ``board.turn`` is falsy
         (described by the author as black to move = 0, white to move = 1).

    Planes 3 and 4 are only filled when the legal-move list is not ``[64]``
    (64 is the pass move). Returns a torch tensor when ``return_torch`` is true,
    otherwise a numpy array.
    """
    planes = np.zeros((9, 8, 8), dtype=np.float32)
    board.piece_planes(planes)
    if not board.turn:
        # Fancy indexing returns a copy with the first two planes exchanged.
        planes = planes[[1, 0, 2, 3, 4, 5, 6, 7, 8], :, :]
        planes[8] = 1
    planes[2] = np.where(planes[0] + planes[1] == 1, 0, 1)

    legal_moves = list(board.legal_moves)
    if legal_moves != [64]:
        flip_counts = []
        for move in legal_moves:
            # Play the move on a copy; the opponent count before the move minus
            # piece_num() after it is what gets stored for this square.
            trial = copy(board)
            opponent_before = trial.opponent_piece_num()
            trial.move(move)
            flip_counts.append(opponent_before - trial.piece_num())
        per_square = np.zeros(64)
        per_square[legal_moves] = flip_counts
        per_square = per_square.reshape(8, 8)
        planes[3] = np.where(per_square > 0, 1, 0)
        planes[4] = per_square

    # Corner mask: ones on rows/cols {0, 1, 6, 7} x {0, 1, 6, 7}.
    corner_mask = np.zeros((8, 8))
    corner_mask[:2, :2] = 1.
    corner_mask[:2, 6:] = 1.
    corner_mask[6:, :2] = 1.
    corner_mask[6:, 6:] = 1.
    planes[5] = corner_mask
    planes[6] = 1

    return torch.from_numpy(planes) if return_torch else planes

def board_to_array_aug2(board, return_torch=False):
    """
    Convert a board object into a stack of 8 augmented feature arrays (ValueNetwork input).

    The base array comes from ``board_to_array2``. For each of the four
    quarter-turn rotations over the two spatial axes (k = 0, 1, 2, 3) the
    rotated array and its mirror along the last axis are emitted, in that
    order, giving a result of shape (8, 9, 8, 8). Returns a torch tensor when
    ``return_torch`` is true, otherwise a numpy array.
    """
    base = board_to_array2(board)
    variants = []
    for quarter_turns in range(4):
        # k=0 reproduces the unrotated array, so the first pair is
        # (original, mirrored original).
        rotated = np.rot90(base, k=quarter_turns, axes=(1, 2)).copy()
        variants.append(rotated)
        variants.append(np.flip(rotated, axis=2).copy())
    stacked = np.array(variants)
    return torch.from_numpy(stacked) if return_torch else stacked

In [None]:
class Node:
    """Game-tree node: a board position plus statistics aggregated from its leaves.

    ``data`` is a 4-tuple ``(min_v, max_v, count, N)``. For a leaf it is
    ``(evaluation, evaluation, count, 1)`` as returned by ``evaluate``; for an
    inner node it is the minimum of the children's ``|data[0]|``, the maximum of
    the children's ``|data[1]|``, and the sums of the children's ``count`` and
    ``N`` fields. Until ``backward`` runs, all four fields are ``None``.
    """

    def __init__(self, board):
        self.data = (None,) * 4
        self.board = board
        self.children = []

    def add(self, child):
        """Append ``child`` to this node's children."""
        self.children.append(child)

    def backward(self):
        """Recursively fill ``data`` for this node and its whole subtree."""
        if not self.children:
            # Leaf: score the position directly.
            self.data = self.evaluate()
            return

        for child in self.children:
            child.backward()
        child_stats = [child.data for child in self.children]
        self.data = (
            min(abs(stats[0]) for stats in child_stats),
            max(abs(stats[1]) for stats in child_stats),
            sum(stats[2] for stats in child_stats),
            sum(stats[3] for stats in child_stats),
        )

    def evaluate(self):
        """Score a finished game.

        The evaluation is ``board.diff_num()``, negated when ``board.turn`` is
        falsy. ``count`` is 1 when that evaluation is exactly 0 and 0 otherwise.
        Asserts that the game is over; the assertion message is the board line.
        """
        assert self.board.is_game_over(), self.board.to_line()
        diff = self.board.diff_num()
        evaluation = diff if self.board.turn else -diff
        count = int(evaluation == 0)
        return (evaluation, evaluation, count, 1)

In [None]:
def apply_move(board, move):
    """Return a copy of ``board`` on which ``move`` has been played.

    The move is applied to the copy made with ``copy``, not to ``board`` itself.
    """
    successor = copy(board)
    successor.move(move)
    return successor

def create_tree(node, depth):
    """Grow a game tree below ``node``, at most ``depth`` plies deep.

    One child is added per entry of ``node.board.legal_moves``, in iteration
    order. Children whose position is already game-over are not expanded
    further.
    """
    if depth == 0:
        return
    parent_board = node.board
    for move in list(parent_board.legal_moves):
        child = Node(apply_move(parent_board, move))
        node.add(child)
        if child.board.is_game_over():
            continue
        create_tree(child, depth - 1)
            
def value_func(v, V, n, N, alpha=0.9, beta=0.5, lam=0.5):
    """Score a subtree from its aggregated statistics (used by the exhaustive search).

    Callers pass a node's ``data`` tuple, i.e. ``v``/``V`` are the smallest and
    largest absolute final evaluations below the node, ``n`` is the summed
    ``count`` field and ``N`` the summed leaf count. When ``V`` is 0 the
    constant 10 is returned; otherwise the score is a tangent-shaped function
    of ``n / N`` plus a ``lam``-weighted mix of ``exp(-alpha * v)`` and
    ``exp(-beta * V)``.
    """
    if V == 0:
        return 10
    # Accumulate left to right so the floating-point result is unchanged.
    score = 0.519 * np.tan(1.64*n/N - 0.126) + 0.0658
    score = score + lam*np.exp(-alpha*v)
    score = score + (1-lam)*np.exp(-beta*V)
    return score

In [None]:
def search_based_value(board, legal_moves, model=None, device=None):
    """Pick the move that minimises the predicted absolute disc difference.

    Each legal move is played on a copy of ``board``; the value network is run
    on the 8 augmented views of the resulting position, the outputs are
    averaged and scaled by 64, and the move with the smallest absolute value
    wins (first one on ties).

    Args:
        board: current position.
        legal_moves: candidate moves to evaluate.
        model: value network to use. Defaults to the notebook-level
            ``model_v``, which keeps the original two-argument call working.
        device: device the input batch is moved to. Defaults to the device of
            the model's first parameter.

    Returns:
        The selected move, or ``None`` when ``legal_moves`` is empty.
    """
    if model is None:
        model = model_v
    if device is None:
        device = next(model.parameters()).device

    vbest, move = np.inf, None
    with torch.no_grad():
        for if_move in legal_moves:
            if_board = apply_move(board, if_move)
            # The batch must live on the same device as the model; feeding a
            # CPU tensor to a CUDA model raises a device-mismatch error.
            batch = board_to_array_aug2(if_board, True).to(device)
            if_v = model(batch).mean().item() * 64
            if abs(if_v) < vbest:
                vbest = abs(if_v)
                move = if_move
    return move

def search_all(board, legal_moves, root):
    """Pick a move using the pre-built exhaustive game tree.

    For every candidate move the game is played out to the end on a copy of
    ``board`` while walking the matching branch of ``root``: a pass (64) is
    forced when it is among the legal moves; otherwise the child with the
    highest ``value_func`` score is chosen when ``if_board.turn`` is falsy and
    the child with the lowest score when it is truthy. The candidate whose
    play-out ends with the smallest absolute disc difference is returned
    (first one on ties).

    Args:
        board: current position.
        legal_moves: legal moves of ``board``; ``root.children`` must be in
            the same order.
        root: tree node for ``board`` whose subtree ``data`` has been filled.

    Returns:
        The selected move, or ``None`` when ``legal_moves`` is empty.
    """
    zbest, move_best = np.inf, None
    for child_idx, if_move in enumerate(legal_moves):
        if_board = apply_move(board, if_move)
        if_root = root.children[child_idx]
        # Minimax-style play-out along the tree.
        while not if_board.is_game_over():
            if_legal_moves = list(if_board.legal_moves)
            if 64 in if_legal_moves:
                next_idx = if_legal_moves.index(64)
            else:
                values = np.array([value_func(*child.data) for child in if_root.children])
                next_idx = int(values.argmin() if if_board.turn else values.argmax())
            if_board.move(if_legal_moves[next_idx])
            if_root = if_root.children[next_idx]

        z = abs(if_board.diff_num())
        if z < zbest:
            zbest = z
            move_best = if_move
    return move_best

In [None]:
def receive_input(legal_moves):
    """Prompt the user for a move until a legal one (or the quit code) is entered.

    The legal moves are printed in string form. Entering ``0`` at any prompt,
    including the retry prompts, quits. Surrounding whitespace in the input
    is ignored.

    Args:
        legal_moves: moves the user may choose from.

    Returns:
        The chosen move converted with ``move_from_str``, or the string
        ``'0'`` when the user asked to quit.
    """
    legal_moves_str = [move_to_str(move) for move in legal_moves]
    print('legal moves -->', legal_moves_str)
    move_str = input('Your turn : ').strip()
    # Keep asking until the answer is either the quit code or a legal move.
    while move_str != '0' and move_str not in legal_moves_str:
        move_str = input('Invalid input. Try again : ').strip()
    if move_str == '0':
        return '0'
    return move_from_str(move_str)

In [None]:
import io
import base64

def img2html(fig, svg):
    """Build an HTML page showing ``svg`` next to ``fig`` rendered as an inline PNG.

    ``fig`` is saved as PNG into an in-memory buffer and embedded as a base64
    data URI; ``svg`` is inserted verbatim into the page. Returns the HTML
    string.
    """
    # The head is kept apart from the template passed to str.format because
    # its CSS braces would otherwise be read as format fields.
    page_head = '''
    <html><head><style>
      .container {
        display: flex;
      }
    </style></head><body>
    '''
    page_body = '''
      <div class="container">
      {svg_board}
      <img src="data:image/png;base64,{image_bin}">
      </div>
    </body>
    </html>
    '''
    png_buffer = io.BytesIO()
    fig.savefig(png_buffer, format='png')
    encoded_png = base64.b64encode(png_buffer.getvalue()).decode('ascii')
    return page_head + page_body.format(image_bin=encoded_png, svg_board=svg)

In [None]:
def get_html(board, move, value_list, ai_idx, xmax):
    """Render the board and the value history as one HTML snippet.

    Args:
        board: position to draw; ``board.to_svg(move)`` supplies the SVG.
        move: move passed to ``board.to_svg``.
        value_list: one array of value predictions per ply; mean, min and max
            are taken over axis 1 of the stacked list.
        ai_idx: ply indices marked with an 'x' on the curve.
        xmax: right limit of the x axis.

    Returns:
        HTML string produced by ``img2html``.
    """
    fig = plt.figure(figsize=(7,3))
    try:
        # Stack once and derive all three curves from the same array.
        values = np.array(value_list)
        value_arr = values.mean(axis=1).ravel()
        max_arr = values.max(axis=1).ravel()
        min_arr = values.min(axis=1).ravel()
        plt.plot(value_arr, c='r', marker='o', markersize=4)
        plt.scatter(ai_idx, value_arr[ai_idx], marker='x', zorder=2)
        plt.fill_between(range(len(value_list)), min_arr, max_arr, alpha=0.5)
        plt.xlim(0,xmax)
        plt.ylim(-64,64)
        plt.axhline(0,c='black',ls='--')
        plt.yticks(range(-60,70,10))
        plt.grid()
        return img2html(fig, str(board.to_svg(move)))
    finally:
        # Close this specific figure even if plotting or rendering raised, so
        # figures do not pile up over a long game.
        plt.close(fig)

# Policy関数

In [None]:
def strong_ai(board, legal_moves, model, device):
    """Return the legal move with the highest policy-network output.

    The board is encoded with ``board_to_array``, given a batch dimension and
    moved to ``device``. The model output for that single sample is indexed
    by ``legal_moves`` and the best-scoring move is returned.
    """
    with torch.no_grad():
        features = board_to_array(board, True).unsqueeze(0).to(device)
        scores = model(features).cpu()[0]
        best_pos = scores[legal_moves].argmax().item()
    return legal_moves[best_pos]

In [None]:
def draw_ai(board, legal_moves, root, model_v, device):
    """Choose a move for the draw-seeking AI and return ``(move, root)``.

    While ``board.piece_sum()`` is below 55 the move comes from
    ``search_based_value`` and the returned root is ``None``. From 55 on, a
    game tree is built once (depth limit 100, then ``backward``) if ``root``
    is ``None``, and the move comes from ``search_all``.

    NOTE(review): ``model_v`` and ``device`` are accepted but not referenced
    in this body.
    """
    if board.piece_sum() >= 55:
        if root is None:
            # First endgame call: build and score the full tree for this position.
            root = Node(board)
            create_tree(root, 100)
            root.backward()
        return search_all(board, legal_moves, root), root

    return search_based_value(board, legal_moves), None

# モデル読み込み

In [None]:
# Load the trained models.
# The device is chosen first so that both checkpoints can be remapped onto it:
# without map_location, a checkpoint saved from CUDA fails to load on a
# CPU-only host.
# NOTE: torch.load unpickles the file; only load checkpoints from a trusted source.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = torch.load('/kaggle/input/reversi-datasets/SL-PolicyNetwork-v3-checkpoint-5epoch-subdata99.pth', map_location=device)
model_v = ValueNetwork()
model_v.load_state_dict(torch.load('/kaggle/input/reversi-datasets/value-network-v2.pth', map_location=device))
model = model.to(device).eval()
model_v = model_v.to(device).eval()

# 強いAI

In [None]:
### config ###
you = 'black'
##############

# From here on `you` is a bool that is compared against board.turn.
you = (you == 'black')
value_list, ai_idx = [], []
xmax = 60
idx = 0
board = Board()
display(board)
while not board.is_game_over():
    legal_moves = list(board.legal_moves)
    if 64 in legal_moves:
        # 64 is the pass move; widen the plot by one ply for it.
        move = 64
        xmax += 1
        print(f"{'You' if you==board.turn else 'AI'} Passed.")
    elif you == board.turn:
        move = receive_input(legal_moves)
        if move == '0':
            # The user asked to quit.
            break
    else:
        move = strong_ai(board, legal_moves, model, device)
        ai_idx.append(idx)
    board.move(move)
    idx += 1

    # Value prediction for the new position over its 8 augmented views, scaled by 64.
    with torch.no_grad():
        value = model_v(board_to_array_aug2(board,True).to(device)).cpu().numpy() * 64
    value_list.append(value)
    html = get_html(board, move, value_list, ai_idx, xmax)
    display(HTML(html))


print('-----Finished ----')
print('YOU :', board.opponent_piece_num() if board.turn!=you else board.piece_num())
print('AI :', board.opponent_piece_num() if board.turn==you else board.piece_num())
print('------------------')

# 忖度AI

In [None]:
### config ###
you = 'white'
##############

# From here on `you` is a bool that is compared against board.turn.
you = (you == 'black')
value_list, ai_idx = [], []
xmax = 60
idx = 0
root = None
board = Board()
display(board)
while not board.is_game_over():
    legal_moves = list(board.legal_moves)
    if 64 in legal_moves:
        # 64 is the pass move; widen the plot by one ply for it.
        move = 64
        xmax += 1
        print(f"{'You' if you==board.turn else 'AI'} Passed.")
    elif you == board.turn:
        move = receive_input(legal_moves)
        if move == '0':
            # The user asked to quit.
            break
    else:
        move, root = draw_ai(board, legal_moves, root, model_v, device)
        ai_idx.append(idx)
    board.move(move)
    if root is not None:
        # Keep the search tree in sync: descend to the child for the move just played.
        root = root.children[legal_moves.index(move)]
    idx += 1

    # Value prediction for the new position over its 8 augmented views, scaled by 64.
    with torch.no_grad():
        value = model_v(board_to_array_aug2(board,True).to(device)).cpu().numpy() * 64
    value_list.append(value)
    html = get_html(board, move, value_list, ai_idx, xmax)
    display(HTML(html))


print('-----Finished ----')
print('YOU :', board.opponent_piece_num() if board.turn!=you else board.piece_num())
print('AI :', board.opponent_piece_num() if board.turn==you else board.piece_num())
print('------------------')