In [None]:
%%capture
!pip install creversi

In [None]:
from creversi import *

import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm.notebook import tqdm
from copy import copy, deepcopy

import torch
import torch.nn as nn

In [None]:
class PolicyNetwork(nn.Module):
    """Fully convolutional policy network.

    Takes an 8-channel input and produces one logit per spatial position,
    flattened per sample (64 values for an 8x8 input).

    Layout:
      * input_layer : 5x5 convolution (8 -> 100 filters) + ReLU
      * hidden_layer: 11 repetitions of [3x3 convolution, BatchNorm, ReLU]
      * output_layer: 1x1 convolution down to a single channel, then Flatten
    """

    def __init__(self):
        super().__init__()
        n_filters = 100
        n_hidden_blocks = 11

        self.input_layer = nn.Sequential(
            nn.Conv2d(8, n_filters, kernel_size=5, padding=2),
            nn.ReLU(),
        )

        # Modules are created in the order conv -> batch-norm -> relu so the
        # indices inside the Sequential (and hence the state_dict keys) are
        # 0, 1, 2, ... exactly as if every layer had been listed by hand.
        hidden = []
        for _ in range(n_hidden_blocks):
            hidden += [
                nn.Conv2d(n_filters, n_filters, kernel_size=3, padding=1),
                nn.BatchNorm2d(n_filters),
                nn.ReLU(),
            ]
        self.hidden_layer = nn.Sequential(*hidden)

        self.output_layer = nn.Sequential(
            nn.Conv2d(n_filters, 1, kernel_size=1),
            nn.Flatten(),
        )

    def forward(self, x):
        """Run the three stages in order and return the flattened logits."""
        features = self.hidden_layer(self.input_layer(x))
        return self.output_layer(features)

In [None]:
class ValueNetwork(nn.Module):
    """Convolutional value network returning one tanh-squashed scalar per sample.

    Layout:
      * input_layer : 5x5 convolution (9 -> 10 filters) + ReLU
      * hidden_layer: 10 repetitions of [3x3 convolution, ReLU, BatchNorm],
                      then a 3x3 convolution, a 1x1 convolution and Flatten
      * output_layer: Linear(n_filters * 100 -> 256), ReLU, Linear(256 -> 1)
    """

    def __init__(self):
        super().__init__()
        n_filters = 10
        n_hidden_blocks = 10

        self.input_layer = nn.Sequential(
            nn.Conv2d(9, n_filters, kernel_size=5, padding=2),
            nn.ReLU(),
        )

        # Creation order (conv -> relu -> batch-norm) fixes the Sequential
        # indices and therefore the state_dict keys of saved weights.
        hidden = []
        for _ in range(n_hidden_blocks):
            hidden += [
                nn.Conv2d(n_filters, n_filters, kernel_size=3, padding=1),
                nn.ReLU(),
                nn.BatchNorm2d(n_filters),
            ]
        hidden += [
            nn.Conv2d(n_filters, n_filters, kernel_size=3, padding=1),
            # A 1x1 kernel with padding=1 enlarges each spatial side by 2
            # (8x8 -> 10x10 for an 8x8 input); the first Linear layer below
            # is sized n_filters * 100 to match, so this must stay as is.
            nn.Conv2d(n_filters, n_filters, kernel_size=1, padding=1),
            nn.Flatten(),
        ]
        self.hidden_layer = nn.Sequential(*hidden)

        self.output_layer = nn.Sequential(
            nn.Linear(n_filters * 100, 256),
            nn.ReLU(),
            nn.Linear(256, 1),
        )

    def forward(self, x):
        """Run the three stages in order and squash the result with tanh."""
        flat_features = self.hidden_layer(self.input_layer(x))
        return self.output_layer(flat_features).tanh()

In [None]:
def board_to_array(board, return_torch=False):
    """Convert a board object into an (8, 8, 8) float32 array (PolicyNetwork input).

    Channel contents, following the original author's description:
      0: black stone positions
      1: white stone positions
      2: empty squares
      3: legal move squares
      4: number of stones each legal move flips
      5: corner regions (the 2x2 block in every corner is 1)
      6: filled with ones
      7: filled with zeros

    Channels 0 and 1 come from ``board.piece_planes`` and are swapped when
    ``board.turn`` is falsy.  Channels 3 and 4 stay zero when the only legal
    move is 64.

    Returns a ``torch.Tensor`` when ``return_torch`` is true, otherwise the
    NumPy array itself.
    """
    planes = np.zeros((8, 8, 8), dtype=np.float32)
    board.piece_planes(planes)
    if not board.turn:
        # Fancy indexing builds a new array with the first two planes exchanged.
        planes = planes[[1, 0, 2, 3, 4, 5, 6, 7], :, :]

    # A square is empty unless exactly one of the two stone planes marks it.
    planes[2] = np.where(planes[0] + planes[1] == 1, 0, 1)

    moves = list(board.legal_moves)
    if moves != [64]:
        def flipped_by(candidate):
            # Play the move on a copy: the opponent's stone count before the
            # move minus the count reported by piece_num() afterwards.
            trial = copy(board)
            before = trial.opponent_piece_num()
            trial.move(candidate)
            return before - trial.piece_num()

        flips = np.zeros(64)
        flips[moves] = [flipped_by(candidate) for candidate in moves]
        flips = flips.reshape(8, 8)
        planes[3] = np.where(flips > 0, 1, 0)
        planes[4] = flips

    # Rows/columns 0, 1, 6 and 7 intersect in the four 2x2 corner blocks.
    edge = [0, 1, 6, 7]
    corner_mask = np.zeros((8, 8))
    corner_mask[np.ix_(edge, edge)] = 1.
    planes[5] = corner_mask

    planes[6] = 1
    return torch.from_numpy(planes) if return_torch else planes

In [None]:
def board_to_array2(board, return_torch=False):
    """Convert a board object into a (9, 8, 8) float32 array (ValueNetwork input).

    Channel contents, following the original author's description:
      0: black stone positions
      1: white stone positions
      2: empty squares
      3: legal move squares
      4: number of stones each legal move flips
      5: corner regions (the 2x2 block in every corner is 1)
      6: filled with ones
      7: filled with zeros
      8: side to move (zeros when ``board.turn`` is truthy, ones otherwise;
         described by the author as black = 0, white = 1)

    Channels 0 and 1 come from ``board.piece_planes`` and are swapped when
    ``board.turn`` is falsy.  Channels 3 and 4 stay zero when the only legal
    move is 64.

    Returns a ``torch.Tensor`` when ``return_torch`` is true, otherwise the
    NumPy array itself.
    """
    planes = np.zeros((9, 8, 8), dtype=np.float32)
    board.piece_planes(planes)
    if not board.turn:
        # New array with the two stone planes exchanged, then flag the turn.
        planes = planes[[1, 0, 2, 3, 4, 5, 6, 7, 8], :, :]
        planes[8] = 1

    # A square is empty unless exactly one of the two stone planes marks it.
    planes[2] = np.where(planes[0] + planes[1] == 1, 0, 1)

    moves = list(board.legal_moves)
    if moves != [64]:
        def flipped_by(candidate):
            # Play the move on a copy: the opponent's stone count before the
            # move minus the count reported by piece_num() afterwards.
            trial = copy(board)
            before = trial.opponent_piece_num()
            trial.move(candidate)
            return before - trial.piece_num()

        flips = np.zeros(64)
        flips[moves] = [flipped_by(candidate) for candidate in moves]
        flips = flips.reshape(8, 8)
        planes[3] = np.where(flips > 0, 1, 0)
        planes[4] = flips

    # Rows/columns 0, 1, 6 and 7 intersect in the four 2x2 corner blocks.
    edge = [0, 1, 6, 7]
    corner_mask = np.zeros((8, 8))
    corner_mask[np.ix_(edge, edge)] = 1.
    planes[5] = corner_mask

    planes[6] = 1
    return torch.from_numpy(planes) if return_torch else planes

In [None]:
def board_to_array_aug2(board, return_torch=False):
    """Encode ``board`` with ``board_to_array2`` and return 8 augmented views.

    The views are stacked along a new leading axis in this order: for each
    number of quarter turns k = 0, 1, 2, 3 (rotation over the two spatial
    axes), first the rotated array, then that array flipped along its last
    axis.

    Returns a ``torch.Tensor`` when ``return_torch`` is true, otherwise a
    NumPy array.
    """
    base = board_to_array2(board)

    views = []
    for quarter_turns in range(4):
        # k=0 leaves the array unrotated, so the first pair is the plain
        # encoding and its flip.
        rotated = np.rot90(base, k=quarter_turns, axes=(1, 2)).copy()
        views.append(rotated)
        views.append(np.flip(rotated, axis=2).copy())

    stacked = np.array(views)
    return torch.from_numpy(stacked) if return_torch else stacked

In [None]:
# Load the trained models.
# The device is chosen first so that both checkpoints can be mapped onto it:
# without map_location, a checkpoint saved from a GPU cannot be loaded on a
# CPU-only kernel.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# NOTE(review): torch.load unpickles the file, which can execute arbitrary
# code -- only load checkpoints from a trusted source.  The first file is used
# directly as a module below, so it presumably stores a whole pickled model
# (its class must be defined in this notebook) rather than a state_dict;
# newer PyTorch releases may need weights_only=False for that -- confirm
# against the installed version.
model = torch.load(
    '/kaggle/input/reversi-datasets/SL-PolicyNetwork-v3-checkpoint-5epoch-subdata99.pth',
    map_location=device,
)
model_v = ValueNetwork()
model_v.load_state_dict(
    torch.load('/kaggle/input/reversi-datasets/value-network-v2.pth', map_location=device)
)

# Inference only: move to the chosen device and switch to eval mode.
model = model.to(device).eval()
model_v = model_v.to(device).eval()

# 54手目まではValueNetwork使用、55手目以降は全探索
大差をつけられてしまった場合は、差がなくなるまでPolicyで最強の一手がいいかも？

In [None]:
class Node:
    """Game-tree node holding a board, its children and a summary tuple.

    ``data`` is ``(min_v, max_v, count)``.  For a leaf it is
    ``(evaluation, evaluation, 1 if evaluation == 0 else 0)``; for an inner
    node it is the smallest ``abs(child.data[0])``, the largest
    ``abs(child.data[1])`` and the sum of ``child.data[2]`` over its children.
    Before ``backward`` runs it is ``(None, None, 0)``.
    """

    def __init__(self, board):
        self.data = (None, None, 0)
        self.board = board
        self.children = []

    def add(self, child):
        """Attach ``child`` as the last child of this node."""
        self.children.append(child)

    def backward(self):
        """Fill in ``data`` for this node and, recursively, all descendants."""
        if len(self.children) == 0:
            # Leaf: score the final position directly.
            self.data = self.evaluate()
            return

        # Make sure every child has been evaluated before aggregating.
        for kid in self.children:
            kid.backward()

        lows, highs, draws = zip(*(kid.data for kid in self.children))
        self.data = (min(map(abs, lows)), max(map(abs, highs)), sum(draws))

    def evaluate(self):
        """Score a finished game; the board must report ``is_game_over()``.

        The score is ``board.diff_num()``, negated when ``board.turn`` is
        falsy.  The third tuple element is 1 when the score is exactly 0.
        """
        position = self.board
        assert position.is_game_over(), position.to_line()
        diff = position.diff_num()
        score = diff if position.turn else -diff
        return (score, score, int(score == 0))

def apply_move(board, move):
    """Return a copy of ``board`` with ``move`` played; ``board`` is not modified here."""
    successor = copy(board)
    successor.move(move)
    return successor

def create_tree(node, depth):
    """Grow the game tree under ``node`` by ``depth`` plies.

    For each legal move of ``node.board`` a child node holding the resulting
    position is attached, and the expansion recurses with ``depth - 1``.
    Nothing is added when ``depth`` is 0.
    """
    if depth != 0:
        for candidate in list(node.board.legal_moves):
            # Each child owns its own board produced by playing the move.
            successor = Node(apply_move(node.board, candidate))
            node.add(successor)
            create_tree(successor, depth - 1)

In [None]:
def search(board, depth):
    """Pick a legal move by exhaustively expanding the game tree.

    The tree is built ``depth`` plies deep with ``create_tree`` and scored
    with ``Node.backward``.  Each root child then carries, over the leaves
    below it, the smallest absolute final score (``min_v``), the largest
    absolute final score (``max_v``) and the number of leaves scoring exactly
    0 (``count``).

    Selection rule (earlier ties win):
      1. A move whose ``max_v`` is 0 is taken immediately.
      2. Otherwise, if some move has ``min_v`` 0: among those moves keep the
         ones with the smallest ``max_v`` and return the one with the largest
         ``count``.
      3. Otherwise, among the moves with the smallest ``min_v`` return the one
         with the smallest ``max_v``.

    Parameters
    ----------
    board : board object to search from.
    depth : number of plies to expand; leaves must be finished games because
        ``Node.evaluate`` asserts ``is_game_over()``.

    Returns
    -------
    The chosen element of ``board.legal_moves``.
    """
    legal_moves = list(board.legal_moves)
    root = Node(board)
    create_tree(root, depth)
    root.backward()

    # Root children are created in legal_moves order, so index i of these
    # arrays describes legal_moves[i].
    min_v = np.array([child.data[0] for child in root.children])
    max_v = np.array([child.data[1] for child in root.children])
    count = np.array([child.data[2] for child in root.children])
    print(min_v)
    print(max_v)
    print(count)
    print('---------')

    if 0 in max_v:
        # Every leaf below this move scores 0.
        return legal_moves[list(max_v).index(0)]

    if 0 in min_v:
        # Restrict to moves that can still reach 0, then to those with the
        # smallest worst case.  The previous code compared against the whole
        # array, so it could return a move from outside this candidate set.
        cand = np.where(min_v == 0)[0]
        cand = cand[max_v[cand] == max_v[cand].min()]
        return legal_moves[cand[np.argmax(count[cand])]]

    # No move reaches 0: smallest min_v first, smallest max_v as tie-break,
    # again looking only at the candidate moves.
    cand = np.where(min_v == min_v.min())[0]
    return legal_moves[cand[np.argmin(max_v[cand])]]

In [None]:
board = Board()
if_board = Board()   # scratch board used to try out candidate moves
values = []          # value estimate after each move actually played
if_values = []       # [min, max] estimate over all legal moves at each step
AI_idx = []          # indices into `values` of the AI's moves (board.turn falsy)


def _estimate_diff(position):
    """Value-network estimate for ``position``.

    The network output is averaged over the 8 augmented views from
    ``board_to_array_aug2`` and multiplied by 64.
    """
    with torch.no_grad():
        # The input batch has to be on the same device as model_v; leaving it
        # on the CPU raises a device-mismatch error when CUDA is in use.
        batch = board_to_array_aug2(position, True).to(device)
        return model_v(batch).mean().item() * 64


# Play with the value network while fewer than 55 stones are on the board.
# The game-over check prevents an endless loop if the game finishes early.
while (not board.is_game_over()
       and board.piece_num() + board.opponent_piece_num() < 55):
    legal_moves = list(board.legal_moves)
    line, turn = board.to_line(), board.turn
    if not turn:
        AI_idx.append(len(values))

    # Counterfactual evaluation: score the position after every legal move.
    # Computed once and reused for both the AI's choice and the min/max band.
    cand_values = []
    for if_move in legal_moves:
        # Restore the scratch board to the position before the move ---
        if if_board.turn != turn:
            if_board.move_pass()
        if_board.set_line(line, turn)
        # ---------------------------------------------------------------
        if_board.move(if_move)
        cand_values.append(_estimate_diff(if_board))

    if 64 in legal_moves:
        # Pass
        move = 64
    elif not turn:
        # AI's turn: take the (first) move whose estimate is closest to 0.
        move = legal_moves[int(np.argmin(np.abs(cand_values)))]
    else:
        # Random player's turn
        move = np.random.choice(legal_moves)

    # Evaluate the position actually reached.
    board.move(move)
    values.append(_estimate_diff(board))
    if_values.append([min(cand_values), max(cand_values)])

# Summary
# Current stone difference, negated when board.turn is falsy
# (presumably this expresses it from a fixed side's point of view -- confirm).
z = board.diff_num() if board.turn else -board.diff_num()
# Convert the per-move records to arrays so they can be indexed for plotting.
if_values = np.array(if_values)
values = np.array(values)
print(z)
display(board)

# Plot
plt.figure(figsize=(10,3))
# Value estimate after each move actually played.
plt.plot(values,c='r',marker='o',markersize=4)
# Mark the moves recorded in AI_idx.
plt.scatter(AI_idx, values[AI_idx], marker='x', zorder=2, label="AI's turn")
# Band between the lowest and highest estimate over all legal moves at each step.
plt.fill_between(range(len(if_values)), if_values[:,0], if_values[:,1], alpha=0.5)
plt.ylim(-64,64)
# Reference lines: zero (dashed) and the stone difference z computed above (dotted).
plt.axhline(0,c='black',ls='--')
plt.axhline(z,c='blue',ls=':')
plt.yticks(range(-60,70,10))
plt.grid()
plt.legend()
plt.show()

In [None]:
# Finish the game: the AI (board.turn falsy) picks its moves with the
# exhaustive tree search, the other side keeps playing at random.
depth = 15
while not board.is_game_over():
    legal_moves = list(board.legal_moves)
    if 64 in legal_moves:
        # Pass
        move = 64
    elif board.turn:
        # Random player's turn
        move = np.random.choice(legal_moves)
    else:
        # AI's turn; the search depth shrinks by one after each AI move.
        move = search(board, depth)
        depth -= 1
    board.move(move)

# Final stone difference as reported by the board, then show the board itself.
print(board.diff_num())
board