# 概要
- このノートブックでは、ValueNetworkを学習させる。
- ネットワークの構造は以下のとおり。
    - 入力層：9チャネル
        - 黒石の位置(1)
        - 白石の位置(1)
        - 空白の位置(1)
        - 合法手の位置(1)
        - そこに打った場合、何個石を返せるか(1)
        - 隅の危険領域4マス×4隅をすべて1で埋める(1)
        - すべて1で埋める(1)
        - すべて0で埋める(1)
        - **手番情報：黒番ならすべて0で埋め、白番ならすべて1で埋める**(1)
    - 第1層：5x5のn_filters種類のフィルターとReLU関数
    - 第2-11層：3x3のn_filters種類のフィルター、ReLU関数、BatchNormalization（この順）
    - 第12層：3x3のn_filters種類のフィルター
    - 第13層：1x1のn_filters種類のフィルター（padding=1のため出力サイズは10x10）
    - 第14層：出力256個の全結合ネットワークとReLU関数
    - 第15層：出力1個の全結合ネットワークとtanh関数
- 学習データには、[GitHub](https://github.com/Nyanyan/OthelloAI_Textbook/tree/main/evaluation/self_play)で公開されている最強AIの自己対戦棋譜を使用する。

In [None]:
# Install creversi from the local Kaggle dataset directory only (--no-index: PyPI is not contacted).
!python -m pip install --no-index --find-links=/kaggle/input/reversi-datasets/ creversi

In [None]:
from creversi import *

import random
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm import tqdm
from copy import copy
import os

import torch
import torch.nn as nn
from sklearn.model_selection import train_test_split

In [None]:
def board_to_array(board):
    """
    Encode a creversi Board as an (8, 8, 8) float32 feature array (PolicyNetwork input).

    Channels:
        0: black stones
        1: white stones
        2: empty squares
        3: legal moves (1 where a move flips at least one stone)
        4: number of stones flipped by playing on that square
        5: the 2x2 region in each of the four corners (16 squares) set to 1
        6: all ones
        7: all zeros
    Channels 3 and 4 stay zero when the only legal move is the pass move (64).
    """
    planes = np.zeros((8, 8, 8), dtype=np.float32)
    board.piece_planes(planes)
    if not board.turn:
        # The original layout only needs a swap on white's turn, which suggests
        # piece_planes writes the side to move first (TODO confirm against creversi);
        # swap so channel 0 is black and channel 1 is white, as documented above.
        planes[[0, 1]] = planes[[1, 0]]

    occupied = planes[0] + planes[1]
    planes[2] = occupied != 1

    moves = list(board.legal_moves)
    if moves != [64]:
        flip_counts = np.zeros(64)
        for mv in moves:
            # Play the move on a copy; the drop in the opponent's stone count is
            # the number of stones this move turns over.
            child = copy(board)
            before = child.opponent_piece_num()
            child.move(mv)
            flip_counts[mv] = before - child.piece_num()
        flip_counts = flip_counts.reshape(8, 8)
        planes[3] = flip_counts > 0
        planes[4] = flip_counts

    edge = [0, 1, 6, 7]
    corner_mask = np.zeros((8, 8), dtype=np.float32)
    corner_mask[np.ix_(edge, edge)] = 1
    planes[5] = corner_mask
    planes[6] = 1
    return planes

In [None]:
def board_to_array2(board):
    """
    Encode a creversi Board as a (9, 8, 8) float32 feature array (ValueNetwork input).

    Channels 0-7 follow the PolicyNetwork encoding:
        0 black stones, 1 white stones, 2 empty squares, 3 legal moves,
        4 stones flipped by a move there, 5 the 2x2 corner regions (=1),
        6 all ones, 7 all zeros.
    Channel 8 holds the side to move: all 0 on black's turn, all 1 on white's turn.
    Channels 3 and 4 stay zero when the only legal move is the pass move (64).
    """
    def flipped_by(move):
        # Stones turned over by `move`: the opponent's count before playing it on a
        # copy, minus the side-to-move count on that copy afterwards.
        child = copy(board)
        opp_before = child.opponent_piece_num()
        child.move(move)
        return opp_before - child.piece_num()

    features = np.zeros((9, 8, 8), dtype=np.float32)
    board.piece_planes(features)
    white_to_move = not board.turn
    if white_to_move:
        # Reorder the two stone planes so channel 0 is black and channel 1 is white
        # (assumes piece_planes writes the side to move first — TODO confirm).
        features[[0, 1]] = features[[1, 0]]
        features[8] = 1

    features[2] = np.where(features[0] + features[1] == 1, 0, 1)

    legal = list(board.legal_moves)
    if legal != [64]:
        counts = np.zeros(64)
        counts[legal] = [flipped_by(m) for m in legal]
        counts = counts.reshape(8, 8)
        features[3] = np.where(counts > 0, 1, 0)
        features[4] = counts

    corner = np.zeros((8, 8), dtype=np.float32)
    corner[:2, :2] = corner[:2, 6:] = corner[6:, :2] = corner[6:, 6:] = 1
    features[5] = corner
    features[6] = 1
    return features

In [None]:
def board_to_array_aug(board, return_torch=False):
    """
    Encode `board` with board_to_array and return its 8 dihedral symmetries.

    Order: identity, its mirror, then for k = 1, 2, 3 the k-quarter-turn rotation
    followed by its mirror. Rotations act on the (row, column) axes and mirroring
    flips the column axis, so the channel axis is left untouched.

    Returns an (8, 8, 8, 8) array: a torch tensor when `return_torch` is true,
    otherwise a NumPy array.
    """
    base = board_to_array(board)
    variants = []
    for quarter_turns in range(4):
        turned = np.rot90(base, k=quarter_turns, axes=(1, 2))
        variants.append(turned)
        variants.append(np.flip(turned, axis=2))
    stacked = np.array(variants)  # copies the views into one contiguous array
    if not return_torch:
        return stacked
    return torch.from_numpy(stacked)

In [None]:
def board_to_array_aug2(board, return_torch=False):
    """
    Encode `board` with board_to_array2 and return its 8 dihedral symmetries.

    Order: identity, its mirror, then for k = 1, 2, 3 the k-quarter-turn rotation
    followed by its mirror (rotation on axes (1, 2), mirror on axis 2).

    Returns an (8, 9, 8, 8) array: a torch tensor when `return_torch` is true,
    otherwise a NumPy array.
    """
    base = board_to_array2(board)
    variants = [
        view
        for k in range(4)
        for rotated in (np.rot90(base, k=k, axes=(1, 2)),)
        for view in (rotated, np.flip(rotated, axis=2))
    ]
    stacked = np.array(variants)
    return torch.from_numpy(stacked) if return_torch else stacked

In [None]:
def test(moves_str, v_true, model, device):
    """
    Plot the model's value estimates along a recorded game next to reference values.

    The position *before* each move in ``moves_str`` is encoded with
    ``board_to_array_aug2`` (8 symmetric variants) and evaluated. Outputs are
    multiplied by 64, mirroring the /64 scaling applied to the training targets.
    Red lines are the 8 per-variant predictions, orange is their mean, and blue
    is ``v_true``.

    The model is run in eval mode under ``torch.no_grad()`` and then restored to
    the train/eval mode it had on entry. Calling this between training steps
    therefore builds no autograd graphs and does not let BatchNorm layers update
    their running statistics from the 8-board test batches.

    Args:
        moves_str: comma-separated move list, e.g. "d3,e3,f2,...".
        v_true: reference values, one per position; plotted as-is.
        model: network expected to return one value per input row
            (e.g. ValueNetwork's (N, 1) output).
        device: device the model's parameters live on.

    Example:
     moves_str = "d3,e3,f2,e2,f5,c5,b6,e6,f6,c6,d6,c4,f3,f7,d7,e7,f4,b5,c3,g5,g6,b4,c7,d2,a6,a5,a3,a4,b3,d8,h6,h5,h4,g4,h3,g3,c2,f1,e1,d1,g2,g1,c1,b7,h1,b1,h2,a2,a8,a7,a1,b2,b8,c8,e8,g8,f8,g7,h8,h7"
     v_true = [0,0,-6,0,-8,0,-19,0,-11,-9,-16,-11,-20,-6,-10,0,-8,0,-18,-9,-10,-8,-7,0,0,0,0,0,-9,0,0,0,6,12,11,12,0,5,0,0,5,15,12,30,31,44,36,54,51,59,44,44,40,40,24,24,24,32,32,32]
    """
    moves = [move_from_str(move_str) for move_str in moves_str.split(',')]
    v_list = []
    board = Board()

    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            for move in moves:
                x = board_to_array_aug2(board, True).to(device)
                v = model(x).cpu().numpy().T[0]
                v_list.append(v * 64)
                board.move(move)
    finally:
        # Restore the caller's mode even if evaluation fails part-way.
        model.train(was_training)

    plt.figure(figsize=(4, 1))
    plt.plot(v_list, c='red')
    plt.plot(np.array(v_list).mean(axis=1), c='orange')
    plt.plot(v_true, c='blue')
    plt.ylim(-64, 64)
    plt.axhline(0, c='black', ls='--')
    plt.show()

In [None]:
class ValueNetwork(nn.Module):
    """
    Convolutional value network for 9-channel 8x8 positions.

    Input:  (batch, 9, 8, 8) tensor, e.g. from board_to_array2 / board_to_array_aug2.
    Output: (batch, 1) tensor squashed into [-1, 1] by tanh.

    Layout:
        input_layer  : 5x5 conv (9 -> n_filters, padding 2) + ReLU
        hidden_layer : 10 x [3x3 conv + ReLU + BatchNorm2d],
                       then a 3x3 conv and a 1x1 conv (no activations), then Flatten
        output_layer : Linear(n_filters*100 -> 256) + ReLU + Linear(256 -> 1)
    """

    def __init__(self):
        super().__init__()
        n_filters = 10
        n_hidden_blocks = 10

        self.input_layer = nn.Sequential(
            nn.Conv2d(9, n_filters, kernel_size=5, padding=2),
            nn.ReLU(),
        )

        # Modules are created in the same order as a hand-written list would create
        # them, so parameter initialisation and state_dict keys are unchanged.
        hidden = []
        for _ in range(n_hidden_blocks):
            hidden += [
                nn.Conv2d(n_filters, n_filters, kernel_size=3, padding=1),
                nn.ReLU(),
                nn.BatchNorm2d(n_filters),
            ]
        hidden += [
            nn.Conv2d(n_filters, n_filters, kernel_size=3, padding=1),
            # padding=1 on a 1x1 conv grows the 8x8 map to 10x10,
            # which is why the first Linear layer takes n_filters*100 inputs.
            nn.Conv2d(n_filters, n_filters, kernel_size=1, padding=1),
            nn.Flatten(),
        ]
        self.hidden_layer = nn.Sequential(*hidden)

        self.output_layer = nn.Sequential(
            nn.Linear(n_filters * 100, 256),
            nn.ReLU(),
            nn.Linear(256, 1),
        )

    def forward(self, x):
        features = self.hidden_layer(self.input_layer(x))
        return self.output_layer(features).tanh()

In [None]:
def seed_everything(seed):
    """
    Seed Python's `random`, NumPy and PyTorch (CPU and current CUDA device) with `seed`.

    Also sets PYTHONHASHSEED in the environment and switches cuDNN to deterministic
    algorithms. Note: Python reads PYTHONHASHSEED at interpreter start-up, so setting
    it here only affects child processes launched afterwards.
    """
    os.environ['PYTHONHASHSEED'] = str(seed)
    random.seed(seed)
    np.random.seed(seed)
    for seed_fn in (torch.manual_seed, torch.cuda.manual_seed):
        seed_fn(seed)
    torch.backends.cudnn.deterministic = True

# 学習

In [None]:
# Train on the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
print(device)

In [None]:
# Build the training set from the self-play records.
# Each non-empty line of a record file is one game written as concatenated
# 2-character moves. For every game, the feature planes of the position after
# each move are collected. Every one of those positions is labelled with the same
# final value: board.diff_num() at the end of the game, negated when board.turn is
# false (presumably black-minus-white disc difference — confirm diff_num() is from
# the side to move's view). Labels are divided by 64 below.
data_dir = '/kaggle/input/reversi-datasets/self_play'
states = []
results = []
for file in sorted(os.listdir(data_dir)):
    print(file)
    with open(os.path.join(data_dir, file), 'r') as f:
        # Keep every non-blank line: this keeps a final game that has no trailing
        # newline and avoids empty "games" from blank lines.
        games = [line for line in (raw.strip() for raw in f) if line]
    moves_str = [[game[x:x+2] for x in range(0, len(game), 2)] for game in games]
    for game in moves_str:
        board = Board()
        for move in game:
            board.move_from_str(move)
            states.append(board_to_array2(board).reshape(1, 9, 8, 8))
        final_diff = board.diff_num() if board.turn else -board.diff_num()
        # Exactly one label per appended state, so states[i] and results[i]
        # always refer to the same game.
        results.append(np.full(len(game), final_diff, dtype=np.float32))

states = np.concatenate(states)
results = np.concatenate(results).reshape(-1, 1) / 64
N = states.shape[0]
if results.shape[0] != N:
    raise ValueError(f"label count {results.shape[0]} does not match state count {N}")
print(f"N : {N}")

In [None]:
# Sequential hold-out split: the training loop uses rows [0, N_tr) for training
# and rows [N_tr, N) for validation.
test_size = 0.1                  # fraction of samples held out for validation
N_tr = int(N * (1 - test_size))  # number of training samples
N_va = N - N_tr                  # number of validation samples

In [None]:
# Hyperparameters.
n_epoch = 5     # passes over the training split
n_batch = 256   # mini-batch size
lr = 1e-3       # Adam learning rate

criterion = nn.HuberLoss()

# Seed before building the model so its initial weights are reproducible.
seed_everything(1234)
model = ValueNetwork().to(device)
optim = torch.optim.Adam(model.parameters(), lr=lr)

In [None]:
# 学習
train_loss_list = []
valid_loss_list = []
for epoch in range(n_epoch):
    train_loss = 0.
    np.random.seed(epoch)
    random_idx = np.random.permutation(N_tr)
    for i in tqdm(range(N_tr//n_batch)):
        X_batch = torch.from_numpy(states[random_idx[n_batch*i:n_batch*(i+1)]]).to(device)
        y_batch = torch.from_numpy(results[random_idx[n_batch*i:n_batch*(i+1)]]).to(device)

        model.train()
        optim.zero_grad()
        output = model(X_batch)
        loss = criterion(output, y_batch)
        loss.backward()
        optim.step()
        train_loss += loss.item()
    train_loss /= N_tr//n_batch

    # 評価
    model.eval()
    valid_loss = 0.0
    idx = np.array(range(N_tr,N))
    for i in range(N_va//n_batch):
        X_batch = torch.from_numpy(states[idx[n_batch*i:n_batch*(i+1)]]).to(device)
        y_batch = torch.from_numpy(results[idx[n_batch*i:n_batch*(i+1)]]).to(device)
        output = model(X_batch)
        valid_loss += criterion(output, y_batch).item()
    valid_loss /= N_va//n_batch
    train_loss_list.append(train_loss)
    valid_loss_list.append(valid_loss)
    print(f'Epoch:{epoch+1}/{n_epoch}, train loss:{train_loss:.5f}, valid loss:{valid_loss:.5f}')
    torch.save(model.cpu(), f'ValueNetwork-v3-checkpoint-{epoch+1}.pth')
    model.to(device)
    
    moves_str = "d3,e3,f2,e2,f5,c5,b6,e6,f6,c6,d6,c4,f3,f7,d7,e7,f4,b5,c3,g5,g6,b4,c7,d2,a6,a5,a3,a4,b3,d8,h6,h5,h4,g4,h3,g3,c2,f1,e1,d1,g2,g1,c1,b7,h1,b1,h2,a2,a8,a7,a1,b2,b8,c8,e8,g8,f8,g7,h8,h7"
    v_true = [0,0,-6,0,-8,0,-19,0,-11,-9,-16,-11,-20,-6,-10,0,-8,0,-18,-9,-10,-8,-7,0,0,0,0,0,-9,0,0,0,6,12,11,12,0,5,0,0,5,15,12,30,31,44,36,54,51,59,44,44,40,40,24,24,24,32,32,32]
    test(moves_str, v_true, model, device)
    
    moves_str = "f5,d6,c7,f3,d3,c6,c5,c4,c3,e6,f4,e3,d7,b6,f6,b3,b5,g5,g4,b4,a3,h5,h4,h3,a4,a6,a2,a5,a7,e7,f8,g3,g6,f7,c2,d8,g8,b8,c8,e8,g7,h8,a8,b7,d2,e2,f1,b1,c1,e1,g2,d1,f2,h7,h6,g1,h1,h2,a1,b2"
    v_true = [0,0,-9,-9,-9,-9,-9,-9,-9,-9,-15,-15,-15,-8,-8,-1,-1,1,1,1,1,1,-15,-15,-15,-15,-15,-15,-15,0,0,2,2,2,2,2,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,-2,-2,-2,-2,-2,-2,-2,-2]
    test(moves_str, v_true, model, device)
    
    moves_str = "f5,f4,c3,e6,d3,f6,g4,f3,g5,e3,f2,g3,f7,h4,h3,b2,h5,f8,d6,d7,g6,d2,e7,h6,h7,h2,h1,g1,c7,g2,d1,e8,g7,d8,a1,c8,e1,f1,e2,c1,c2,b8,b1,g8,h8,b3,c6,c4,b7,a8,a7,c5,b5,a6,a3,b6,a2,b4,a4,a5"
    v_true = [0,6,-6,-6,-13,-13,-20,-20,-19,-19,-27,-27,-46,-46,-59,-6,-6,12,12,24,24,27,27,37,37,53,53,63,36,36,18,18,4,4,-13,-13,-14,-14,-14,18,18,44,44,64,64,64,46,46,26,26,10,10,4,4,2,2,20,20,20,26,26]
    test(moves_str, v_true, model, device)

In [None]:
# Distribution of the training targets, rescaled back from the /64 normalisation.
plt.hist(64 * results.ravel(), bins=64)
plt.show()

In [None]:
# Replay a full game record to its last position, print board.diff_num() and
# board.turn, and show the model's value for that position: the mean over the
# 8 symmetric encodings, rescaled by 64.
moves_str = "f5,d6,c7,f3,d3,c6,c5,c4,c3,e6,f4,e3,d7,b6,f6,b3,b5,g5,g4,b4,a3,h5,h4,h3,a4,a6,a2,a5,a7,e7,f8,g3,g6,f7,c2,d8,g8,b8,c8,e8,g7,h8,a8,b7,d2,e2,f1,b1,c1,e1,g2,d1,f2,h7,h6,g1,h1,h2,a1,b2"
board = Board()
moves = [move_from_str(move_str) for move_str in moves_str.split(',')]
for move in moves:
    board.move(move)
print(board.diff_num(), board.turn)
64 * model(board_to_array_aug2(board, return_torch=True).to(device)).mean().item()

In [None]:
# Model's value for the starting position (mean over the 8 symmetric encodings, x64).
board = Board()
64 * model(board_to_array_aug2(board, return_torch=True).to(device)).mean().item()

In [None]:
# Learning curves: mean training loss and validation loss per epoch (criterion: HuberLoss).
plt.figure(figsize=(4, 4))
plt.plot(train_loss_list, label='train')
plt.plot(valid_loss_list, label='valid')
plt.xlabel('epoch (0-based)')
plt.ylabel('loss')
plt.legend()  # without this the 'train'/'valid' labels are never displayed
plt.show()