# ライブラリ

In [None]:
%%capture
!pip install creversi

In [None]:
from creversi import *

import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm.notebook import tqdm
from copy import copy

import torch
import torch.nn as nn

from IPython.display import HTML

# ネットワーク構造

In [None]:
class PolicyNetwork(nn.Module):
    """Fully convolutional policy network over the 8x8 board.

    Input: a batch of 8-plane features (as produced by ``board_to_array``,
    i.e. ``(N, 8, 8, 8)``). Output: ``(N, 64)`` unnormalized scores, one per
    square, produced by a 1x1 conv to a single channel followed by Flatten.

    The attribute names and the order of modules inside each ``nn.Sequential``
    are part of the contract: pickled checkpoints / state dicts address
    parameters as ``hidden_layer.<index>.*``.
    """

    def __init__(self):
        super().__init__()
        width = 100       # channels in every hidden conv
        n_blocks = 11     # Conv3x3 -> BatchNorm -> ReLU repetitions

        self.input_layer = nn.Sequential(
            nn.Conv2d(8, width, kernel_size=5, padding=2),
            nn.ReLU(),
        )

        trunk = []
        for _ in range(n_blocks):
            trunk += [
                nn.Conv2d(width, width, kernel_size=3, padding=1),
                nn.BatchNorm2d(width),
                nn.ReLU(),
            ]
        self.hidden_layer = nn.Sequential(*trunk)

        self.output_layer = nn.Sequential(
            nn.Conv2d(width, 1, kernel_size=1),
            nn.Flatten(),
        )

    def forward(self, x):
        """Run input, hidden and output stages in sequence."""
        for stage in (self.input_layer, self.hidden_layer, self.output_layer):
            x = stage(x)
        return x

In [None]:
class ValueNetwork(nn.Module):
    """Convolutional value network with a tanh-bounded scalar output.

    Input: a batch of 9-plane features (as produced by ``board_to_array2`` /
    ``board_to_array_aug2``, i.e. ``(N, 9, 8, 8)``). Output: ``(N, 1)`` in
    ``[-1, 1]``.

    Module order inside each ``nn.Sequential`` must stay exactly as is so that
    saved state dicts (``hidden_layer.<index>.*``) keep loading.
    """

    def __init__(self):
        super().__init__()
        width = 10       # channels in every hidden conv
        n_blocks = 10    # Conv3x3 -> ReLU -> BatchNorm repetitions

        self.input_layer = nn.Sequential(
            nn.Conv2d(9, width, kernel_size=5, padding=2),
            nn.ReLU(),
        )

        trunk = []
        for _ in range(n_blocks):
            # Note: ReLU precedes BatchNorm here (unlike PolicyNetwork).
            trunk.extend((
                nn.Conv2d(width, width, kernel_size=3, padding=1),
                nn.ReLU(),
                nn.BatchNorm2d(width),
            ))
        trunk.append(nn.Conv2d(width, width, kernel_size=3, padding=1))
        # A 1x1 conv with padding=1 grows an 8x8 map to 10x10, which is why
        # the first Linear below expects width * 100 inputs.
        trunk.append(nn.Conv2d(width, width, kernel_size=1, padding=1))
        trunk.append(nn.Flatten())
        self.hidden_layer = nn.Sequential(*trunk)

        self.output_layer = nn.Sequential(
            nn.Linear(width * 100, 256),
            nn.ReLU(),
            nn.Linear(256, 1),
        )

    def forward(self, x):
        """Return tanh of the scalar head, shape (N, 1)."""
        features = self.hidden_layer(self.input_layer(x))
        return torch.tanh(self.output_layer(features))

# 関数

In [None]:
# Move index that appears to denote "pass" (the notebook plays 64 whenever it
# is in legal_moves and skips move features when legal_moves == [64]).
_PASS_MOVE = 64

# Plane 5: the 2x2 block in each corner of the board is 1, everything else 0.
_CORNER_PLANE = np.zeros((8, 8), dtype=np.float32)
_CORNER_PLANE[np.ix_([0, 1, 6, 7], [0, 1, 6, 7])] = 1.0


def _flip_counts(board, legal_moves):
    """Return an (8, 8) array holding, at each move in *legal_moves*, the
    number of discs that move flips (0 on every other square)."""
    counts = np.zeros(64)
    for move in legal_moves:
        after = copy(board)  # play on a copy; the caller's board is untouched
        before = after.opponent_piece_num()
        after.move(move)
        # Presumably the side to move switches after move(), so piece_num()
        # now counts the former opponent's remaining discs.
        counts[move] = before - after.piece_num()
    return counts.reshape(8, 8)


def _board_features(board, with_turn_plane):
    """Build the feature stack shared by board_to_array / board_to_array2.

    Returns a float32 array of shape (9, 8, 8) if *with_turn_plane* else
    (8, 8, 8); see the public functions for the plane layout.
    """
    b = np.zeros((9 if with_turn_plane else 8, 8, 8), dtype=np.float32)
    board.piece_planes(b)
    if not board.turn:
        # piece_planes seems to emit (side to move, opponent); reorder so that
        # plane 0 is Black and plane 1 is White regardless of whose turn it is.
        b[[0, 1]] = b[[1, 0]]
        if with_turn_plane:
            b[8] = 1
    b[2] = np.where(b[0] + b[1] == 1, 0, 1)
    legal_moves = list(board.legal_moves)
    if legal_moves != [_PASS_MOVE]:
        flips = _flip_counts(board, legal_moves)
        b[3] = flips > 0
        b[4] = flips
    b[5] = _CORNER_PLANE
    b[6] = 1
    # b[7] stays all zeros.
    return b


def board_to_array(board, return_torch=False):
    """Convert a board into the 8-plane input of PolicyNetwork.

    Planes: 0 Black discs, 1 White discs, 2 empty squares, 3 legal moves,
    4 number of discs each legal move flips, 5 the 2x2 corner regions (=1),
    6 all ones, 7 all zeros.

    Returns:
        float32 ndarray of shape (8, 8, 8), or a torch.Tensor sharing its
        memory when *return_torch* is true.
    """
    b = _board_features(board, with_turn_plane=False)
    return torch.from_numpy(b) if return_torch else b


def board_to_array2(board, return_torch=False):
    """Convert a board into the 9-plane input of ValueNetwork.

    Planes 0-7 are identical to ``board_to_array``; plane 8 encodes the side
    to move (all zeros when ``board.turn`` is true / Black, all ones otherwise).

    Returns:
        float32 ndarray of shape (9, 8, 8), or a torch.Tensor sharing its
        memory when *return_torch* is true.
    """
    b = _board_features(board, with_turn_plane=True)
    return torch.from_numpy(b) if return_torch else b


def board_to_array_aug2(board, return_torch=False):
    """Return ``board_to_array2(board)`` under all 8 board symmetries.

    Order: for k = 0..3, the k-times rotated planes followed by that rotation
    mirrored along the last axis. Shape (8, 9, 8, 8), float32.
    """
    base = board_to_array2(board)
    views = []
    for k in range(4):
        rotated = np.rot90(base, k=k, axes=(1, 2))
        views.append(rotated)
        views.append(np.flip(rotated, axis=2))
    # np.array copies the (possibly strided) views into one contiguous block.
    stacked = np.array(views)
    return torch.from_numpy(stacked) if return_torch else stacked

In [None]:
class Node:
    """Game-tree node wrapping one board position.

    Attributes:
        data: values gathered by ``backward``/``backward2`` -- one entry per
            leaf in this node's subtree (the node's own value if it is a leaf),
            in depth-first child order.
        board: the position at this node.
        children: child nodes, appended via ``add`` (``create_tree`` adds them
            in ``board.legal_moves`` order).
    """

    def __init__(self, board):
        self.data = []
        self.board = board
        self.children = []

    def add(self, child):
        """Append *child* as the next child of this node."""
        self.children.append(child)

    def _collect(self, leaf_value):
        """Refill ``data`` for this subtree, scoring each leaf with ``leaf_value(leaf)``."""
        # Reset first so that running backward()/backward2() twice does not
        # duplicate entries (the original appended onto existing data).
        self.data = []
        if not self.children:
            self.data.append(leaf_value(self))
            return
        for child in self.children:
            child._collect(leaf_value)
            self.data.extend(child.data)

    def backward(self):
        """Gather exact final scores (``evaluate``) of every leaf below this node."""
        self._collect(lambda leaf: leaf.evaluate())

    def backward2(self, model_v):
        """Gather value-network estimates (``evaluate2``) of every leaf below this node."""
        self._collect(lambda leaf: leaf.evaluate2(model_v))

    def evaluate(self):
        """Exact score of a finished game.

        ``diff_num()`` is negated when ``board.turn`` is false, which makes the
        score side-independent -- presumably black minus white, since
        ``turn == True`` is treated as Black elsewhere in this notebook
        (TODO confirm against creversi's ``diff_num``).

        Raises:
            AssertionError: if the game at this node is not over.
        """
        assert self.board.is_game_over(), self.board.to_line()
        z = self.board.diff_num() if self.board.turn else -self.board.diff_num()
        return z

    def evaluate2(self, model_v):
        """Value-network estimate of this position.

        Averages ``model_v`` over the 8 symmetric views from
        ``board_to_array_aug2`` and scales by 64, presumably to match the
        +-64 disc-difference range of ``evaluate``.
        """
        # Feed the input on the same device as the model; previously a CPU
        # tensor was passed even when model_v lived on CUDA.
        param = next(model_v.parameters(), None)
        device = param.device if param is not None else torch.device('cpu')
        with torch.no_grad():
            x = board_to_array_aug2(self.board, True).to(device)
            z = model_v(x).mean().item() * 64
        return z

    def get_data(self):
        """Return the gathered leaf values as an ndarray."""
        return np.array(self.data)

In [None]:
def apply_move(board, move):
    """Return the position reached by playing *move* on a copy of *board*.

    *board* itself is not passed to ``move``; only the copy is.
    """
    next_board = copy(board)
    next_board.move(move)
    return next_board

def create_tree(node, depth):
    """Expand *node* into a game tree up to *depth* plies.

    One child per move in ``node.board.legal_moves`` is appended, in that
    order. Children whose game is already over are kept as leaves and not
    expanded further.
    """
    if depth == 0:
        return
    for move in list(node.board.legal_moves):
        child = Node(apply_move(node.board, move))
        node.add(child)
        if child.board.is_game_over():
            continue
        create_tree(child, depth - 1)

# Policy関数

In [None]:
def minimax_draw(node, turn, depth):
    """Minimax search for the move that keeps the game closest to a draw.

    The side ``turn`` minimizes the absolute disc difference; the other side
    is assumed to maximize it.

    Args:
        node: root of a tree built by ``create_tree`` whose ``data`` has been
            filled by ``backward``/``backward2``; its children follow
            ``node.board.legal_moves`` order.
        turn: the searching side, compared with ``node.board.turn``.
        depth: remaining plies to search.

    Returns:
        ``(value, move)``. At the horizon, or when the subtree holds a single
        leaf value, ``value`` is the mean of ``|data|`` and ``move`` is the
        first legal move.

    Raises:
        AssertionError: if no child produced a finite value below 100.
    """
    moves = list(node.board.legal_moves)
    if depth == 0 or len(node.data) == 1:
        return abs(node.get_data()).mean(), moves[0]

    minimizing = node.board.turn == turn
    best_value = np.inf if minimizing else -np.inf
    best_move = None
    for idx, child in enumerate(node.children):
        value, _ = minimax_draw(child, turn, depth - 1)
        improved = value < best_value if minimizing else value > best_value
        if improved:
            best_value = value
            best_move = moves[idx]
    assert (abs(best_value)<100) and (best_move is not None), f'zbest={best_value}, mbest={best_move}'
    return best_value, best_move

def minimax_strong(node, turn, depth):
    """Minimax search for the strongest move.

    Leaf values are the side-independent scores stored by ``Node.evaluate``
    (``diff_num()`` negated when White is to move -- black minus white, the
    same convention the notebook's result tables use). Black to move
    maximizes, White to move minimizes.

    Fixes vs. the previous version:
      * it recursed into ``minimax_draw``, so every level below the root was
        scored with draw (|z|) semantics;
      * its max/min test reduced to "maximize when White is to move", i.e.
        each side minimized its own score.

    Args:
        node: root of a tree built by ``create_tree`` with ``data`` filled by
            ``backward``; children follow ``node.board.legal_moves`` order.
        turn: side the caller searches for. Kept for signature compatibility;
            with both sides modeled as optimal it does not change the result.
        depth: remaining plies to search.

    Returns:
        ``(value, move)``; at the horizon or a single-leaf subtree, ``value``
        is the mean of ``data`` and ``move`` the first legal move (``None`` if
        the board reports no moves).

    Raises:
        AssertionError: if no child produced a finite value below 100.
    """
    moves = list(node.board.legal_moves)
    if depth == 0 or len(node.data) == 1:
        return node.get_data().mean(), (moves[0] if moves else None)

    maximizing = bool(node.board.turn)  # Black to move
    zbest = -np.inf if maximizing else np.inf
    mbest = None
    for move, child in zip(moves, node.children):
        z, _ = minimax_strong(child, turn, depth - 1)
        if (z > zbest) if maximizing else (z < zbest):
            zbest, mbest = z, move
    assert (abs(zbest)<100) and (mbest is not None), f'zbest={zbest}, mbest={mbest}'
    return zbest, mbest

In [None]:
def strong_ai(board, legal_moves, root, model, device, end_to_end=False):
    """Choose a move for the 'strong' player.

    With a pre-built search tree (*root*) and ``end_to_end`` false, the move
    comes from ``minimax_strong`` over that tree. Otherwise the policy network
    *model* scores the position and the highest-scoring entry among
    *legal_moves* is played.

    Args:
        board: current position.
        legal_moves: candidate move indices used to index the policy output;
            presumably excludes the pass move 64 (callers in this notebook
            handle it beforehand).
        root: game tree for the current position, or None.
        model: policy network; assumed to return (1, 64) scores.
        device: device the policy input is moved to.
        end_to_end: force the policy network even when *root* is available.
    """
    if root is not None and not end_to_end:
        _, move = minimax_strong(root, root.board.turn, 100)
        return move

    with torch.no_grad():
        x = board_to_array(board, True).unsqueeze(0).to(device)
        scores = model(x).cpu()[0]
        best = scores[legal_moves].argmax().item()
        choice = legal_moves[best]
    return choice


def draw_ai(board, legal_moves, root, model_v, device, num=54, depth=3):
    """Choose a move that steers the game toward a draw.

    While ``board.piece_sum() < num``, a fresh tree of *depth* plies is built,
    its leaves are scored by the value network, and ``minimax_draw`` picks the
    move; no tree is kept (``None`` is returned as the tree).

    From *num* pieces on, a full tree to the end of the game is built once
    (when *root* is None) from exact final scores and returned so the caller
    can descend into it on later moves.

    ``legal_moves`` and ``device`` are accepted for interface symmetry with
    ``strong_ai`` but are not used here.

    Returns:
        ``(move, root)`` where *root* is the reusable end-game tree or None.
    """
    if board.piece_sum() >= num:
        if root is None:
            root = Node(board)
            create_tree(root, 100)
            root.backward()
        _, move = minimax_draw(root, root.board.turn, 100)
        return move, root

    shallow = Node(board)
    create_tree(shallow, depth)
    shallow.backward2(model_v)
    _, move = minimax_draw(shallow, shallow.board.turn, depth)
    return move, None

# モデル読み込み

In [None]:
# Load the trained models.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# The policy checkpoint is a whole pickled nn.Module (it is used directly as a
# model below), not a state_dict. PyTorch >= 2.6 defaults torch.load to
# weights_only=True, which rejects such pickles, so opt out explicitly.
# NOTE: unpickling can execute arbitrary code -- only load trusted checkpoints.
# map_location lets a checkpoint saved on GPU load on a CPU-only machine.
model = torch.load(
    '/kaggle/input/reversi-datasets/SL-PolicyNetwork-v3-checkpoint-5epoch-subdata99.pth',
    map_location=device,
    weights_only=False,
)

# The value network is stored as a plain state_dict, so the safe loader works.
model_v = ValueNetwork()
model_v.load_state_dict(torch.load(
    '/kaggle/input/reversi-datasets/value-network-v2.pth',
    map_location=device,
    weights_only=True,
))

model = model.to(device).eval()
model_v = model_v.to(device).eval()

# 強いAI

## Minimax探索あり

In [None]:
# ### config ###
# you = 'black'
# ##############

# you = True if you=='black' else False
# Z = []

# for i in range(1000):
#     root = None
#     board = Board()
#     while not board.is_game_over():
#         legal_moves = list(board.legal_moves)
#         if 64 in legal_moves:
#             move = 64
#         elif you == board.turn:
#             move = np.random.choice(legal_moves)
#         else:
#             move = strong_ai(board, legal_moves, root, model, device, end_to_end=False)
#         board.move(move)

#         if board.piece_sum() >= 53:
#             root = Node(board)
#             create_tree(root, 100)
#             root.backward()
#     z = board.diff_num() if board.turn else -board.diff_num()
#     Z.append(z)

# Z = np.array(Z)

# cnt = []
# for d in range(-64,65):
#     cnt.append((Z==d).sum()/len(Z))
# plt.figure(figsize=(5,3))
# plt.bar(range(-64,65), cnt)
# plt.grid()
# plt.show()

# for d in range(65):
#     p = (Z<=-d).sum()/len(Z)
#     w = 1.96 * (p*(1-p)/len(Z))**0.5
#     print(f'diff>={d} : {p*100:.1f}%  ({(p-w)*100:.1f}%,{(p+w)*100:.1f}%)')

## Minimax探索なし

In [None]:
# ### config ###
# you = 'black'
# ##############

# you = True if you=='black' else False
# Z = []

# for i in range(1000):
#     root = None
#     board = Board()
#     while not board.is_game_over():
#         legal_moves = list(board.legal_moves)
#         if 64 in legal_moves:
#             move = 64
#         elif you == board.turn:
#             move = np.random.choice(legal_moves)
#         else:
#             move = strong_ai(board, legal_moves, root, model, device, end_to_end=True)
#         board.move(move)

#     z = board.diff_num() if board.turn else -board.diff_num()
#     Z.append(z)

# Z = np.array(Z)

# cnt = []
# for d in range(-64,65):
#     cnt.append((Z==d).sum()/len(Z))
# plt.figure(figsize=(5,3))
# plt.bar(range(-64,65), cnt)
# plt.grid()
# plt.show()

# for d in range(65):
#     p = (Z<=-d).sum()/len(Z)
#     w = 1.96 * (p*(1-p)/len(Z))**0.5
#     print(f'diff>={d} : {p*100:.1f}%  ({(p-w)*100:.1f}%,{(p+w)*100:.1f}%)')

# 忖度AI

In [None]:
## config ###
you = 'white'
##############

# Side played by the random opponent ("you"): True for 'black', False otherwise.
# It is compared directly with board.turn below; draw_ai plays the other side.
you = True if you=='black' else False
# Final score of each game: diff_num() negated when White is to move --
# presumably black minus white (TODO confirm against creversi's diff_num).
Z = []

# Play 300 games of the random opponent against the draw-seeking AI.
for i in range(300):
    # End-game tree returned by draw_ai and reused across moves (None until built).
    root = None
    board = Board()
    while not board.is_game_over():
        legal_moves = list(board.legal_moves)
        # 64 appears to be the pass move (cf. the `!= [64]` checks in board_to_array).
        if 64 in legal_moves:
            move = 64
        elif you == board.turn:
            move = np.random.choice(legal_moves)  # random opponent
        else:
            # num=53, depth=3: value-net minimax of depth 3 while piece_sum() < 53,
            # exhaustive end-game search afterwards.
            move, root = draw_ai(board, legal_moves, root, model_v, device, 53, 3)
        board.move(move)
        # Descend the reused tree into the subtree of the move just played.
        # Assumes children are in legal_moves order, as create_tree builds them.
        if root is not None:
            root = root.children[legal_moves.index(move)]

    z = board.diff_num() if board.turn else -board.diff_num()
    Z.append(z)
    # Progress indicator every 10 games.
    if (i+1)%10==0:
        print(i+1)
    
Z = np.array(Z)

# Histogram: fraction of games ending with each score in [-64, 64].
cnt = []
for d in range(-64,65):
    cnt.append((Z==d).sum()/len(Z))
plt.figure(figsize=(5,3))
plt.bar(range(-64,65), cnt)
plt.grid()
plt.show()

# Fraction of games with |score| <= d, with a 95% normal-approximation
# (Wald) confidence interval: p +- 1.96 * sqrt(p(1-p)/n).
for d in range(65):
    p = (abs(Z)<=d).sum()/len(Z)
    w = 1.96 * (p*(1-p)/len(Z))**0.5
    print(f'|diff|<={d} : {p*100:.1f}%  ({(p-w)*100:.1f}%,{(p+w)*100:.1f}%)')