# 概要
- このノートブックでは、ValueNetworkを学習させる。
- ネットワークの構造は以下のとおり。
    - 入力層：9チャネル
        - 黒石の位置(1)
        - 白石の位置(1)
        - 空白の位置(1)
        - 合法手の位置(1)
        - そこに打った場合、何個石を返せるか(1)
        - 隅の危険領域4マス×4隅をすべて1で埋める(1)
        - すべて1で埋める(1)
        - すべて0で埋める(1)
        - **手番情報：黒番ならすべて0で埋め、白番ならすべて1で埋める**(1)
    - 第1層：5x5のn_filters種類のフィルターとReLU関数
    - 第2-11層：3x3のn_filters種類のフィルターとReLU関数（各層のReLUの後にBatchNormを適用）
    - 第12層：3x3のn_filters種類のフィルター
    - 第13層：1x1のn_filters種類のフィルター（padding=1のため出力は10x10になり、平坦化後の次元はn_filters×100）
    - 第14層：出力256個の全結合ネットワークとReLU関数
    - 第15層：出力1個の全結合ネットワークとtanh関数
- 山名琢翔さんが公開している[自己対戦棋譜](https://www.egaroucid.nyanyan.dev/ja/technology/transcript/) を利用して、20手目以降の盤面の評価関数をモデリング
- 1~19手目は自前のPolicyによる棋譜でモデリング

In [None]:
!pip install creversi

In [None]:
from creversi import *

import random
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
from copy import copy
import os
import gc

import torch
import torch.nn as nn
from sklearn.model_selection import train_test_split

In [None]:
def board_to_array(board):
    """
    Convert a board object into an (8, 8, 8) float32 ndarray (PolicyNetwork input).

    Plane layout:
        0: black stones          1: white stones
        2: empty squares         3: legal-move squares
        4: number of stones each legal move would flip
        5: 2x2 corner regions set to 1
        6: all ones              7: all zeros
    """
    planes = np.zeros((8, 8, 8), dtype=np.float32)
    board.piece_planes(planes)
    if not board.turn:
        # Swap the two stone planes so plane 0 is black and plane 1 is white,
        # as the layout above requires.
        planes[[0, 1]] = planes[[1, 0]]

    # A square is empty when neither stone plane marks it.
    planes[2] = (planes[0] + planes[1]) != 1

    candidate_moves = list(board.legal_moves)
    # NOTE(review): [64] is presumably creversi's "pass only" marker - confirm.
    if candidate_moves != [64]:
        flips = np.zeros(64)
        for mv in candidate_moves:
            trial = copy(board)
            opponent_before = trial.opponent_piece_num()
            trial.move(mv)
            # After the move the side to act has changed, so piece_num() now
            # counts the former opponent; the drop is the number flipped.
            flips[mv] = opponent_before - trial.piece_num()
        flips = flips.reshape(8, 8)
        planes[3] = flips > 0
        planes[4] = flips

    # Mark the 2x2 block in each of the four corners.
    corner_mask = np.zeros((8, 8))
    for rows in (slice(0, 2), slice(6, 8)):
        for cols in (slice(0, 2), slice(6, 8)):
            corner_mask[rows, cols] = 1.
    planes[5] = corner_mask

    planes[6] = 1
    return planes

In [None]:
def board_to_array2(board, return_torch=False):
    """
    Convert a board object into a (9, 8, 8) float32 array (ValueNetwork input).

    Plane layout:
        0: black stones          1: white stones
        2: empty squares         3: legal-move squares
        4: number of stones each legal move would flip
        5: 2x2 corner regions set to 1
        6: all ones              7: all zeros
        8: side to move (all 0 for black to move, all 1 for white to move)

    When ``return_torch`` is true the result is wrapped with
    ``torch.from_numpy``; otherwise the ndarray itself is returned.
    """
    planes = np.zeros((9, 8, 8), dtype=np.float32)
    board.piece_planes(planes)
    white_to_move = not board.turn
    if white_to_move:
        # Swap the two stone planes so plane 0 is black and plane 1 is white,
        # then flag the turn plane.
        planes[[0, 1]] = planes[[1, 0]]
        planes[8] = 1

    # A square is empty when neither stone plane marks it.
    planes[2] = (planes[0] + planes[1]) != 1

    candidate_moves = list(board.legal_moves)
    # NOTE(review): [64] is presumably creversi's "pass only" marker - confirm.
    if candidate_moves != [64]:
        flips = np.zeros(64)
        for mv in candidate_moves:
            trial = copy(board)
            opponent_before = trial.opponent_piece_num()
            trial.move(mv)
            # After the move the side to act has changed, so piece_num() now
            # counts the former opponent; the drop is the number flipped.
            flips[mv] = opponent_before - trial.piece_num()
        flips = flips.reshape(8, 8)
        planes[3] = flips > 0
        planes[4] = flips

    # Mark the 2x2 block in each of the four corners.
    corner_mask = np.zeros((8, 8))
    for rows in (slice(0, 2), slice(6, 8)):
        for cols in (slice(0, 2), slice(6, 8)):
            corner_mask[rows, cols] = 1.
    planes[5] = corner_mask

    planes[6] = 1
    return torch.from_numpy(planes) if return_torch else planes

In [None]:
def board_to_array_aug(board, return_torch=False):
    """
    Encode ``board`` with ``board_to_array`` and return its 8 symmetric variants.

    The variants are stacked along a new leading axis in this order: for each
    number of quarter turns k = 0, 1, 2, 3 (``np.rot90`` over the two board
    axes), first the rotated array and then that array flipped along its last
    axis.

    Returns an ndarray, or a tensor built with ``torch.from_numpy`` when
    ``return_torch`` is true.
    """
    base = board_to_array(board)
    variants = []
    for quarter_turns in range(4):
        rotated = np.rot90(base, k=quarter_turns, axes=(1, 2)).copy()
        variants.append(rotated)
        variants.append(np.flip(rotated, axis=2).copy())
    stacked = np.array(variants)
    return torch.from_numpy(stacked) if return_torch else stacked

In [None]:
def board_to_array_aug2(board, return_torch=False):
    """
    Encode ``board`` with ``board_to_array2`` and return its 8 symmetric variants.

    The variants are stacked along a new leading axis in this order: for each
    number of quarter turns k = 0, 1, 2, 3 (``np.rot90`` over the two board
    axes), first the rotated array and then that array flipped along its last
    axis.

    Returns an ndarray, or a tensor built with ``torch.from_numpy`` when
    ``return_torch`` is true.
    """
    base = board_to_array2(board)
    variants = []
    for quarter_turns in range(4):
        rotated = np.rot90(base, k=quarter_turns, axes=(1, 2)).copy()
        variants.append(rotated)
        variants.append(np.flip(rotated, axis=2).copy())
    stacked = np.array(variants)
    return torch.from_numpy(stacked) if return_torch else stacked

In [None]:
class ValueNetwork(nn.Module):
    """
    Convolutional value network mapping a batch of encoded boards to a scalar
    in [-1, 1].

    Architecture (with the defaults):
        input_layer  : 5x5 conv (9 -> n_filters) + ReLU
        hidden_layer : n_hidden_blocks x [3x3 conv + ReLU + BatchNorm2d],
                       then a 3x3 conv, a 1x1 conv and Flatten
        output_layer : Linear(n_filters*100 -> 256) + ReLU + Linear(256 -> 1)
        forward      : tanh of the output layer

    Args:
        n_filters: number of channels used by every convolution.
        n_hidden_blocks: number of conv/ReLU/BatchNorm blocks in the hidden
            stack.

    With the default arguments the module layout (and therefore the
    ``state_dict`` keys and the order in which parameters are initialised)
    is the same as the original hand-unrolled definition.
    """

    def __init__(self, n_filters=10, n_hidden_blocks=10):
        super().__init__()
        self.input_layer = nn.Sequential(
            nn.Conv2d(9, n_filters, kernel_size=5, padding=2),
            nn.ReLU(),
        )

        hidden = []
        for _ in range(n_hidden_blocks):
            hidden.extend([
                nn.Conv2d(n_filters, n_filters, kernel_size=3, padding=1),
                nn.ReLU(),
                nn.BatchNorm2d(n_filters),
            ])
        hidden.extend([
            nn.Conv2d(n_filters, n_filters, kernel_size=3, padding=1),
            # kernel_size=1 combined with padding=1 grows the 8x8 map to
            # 10x10; this is why the first Linear layer below takes
            # n_filters*100 inputs. Kept as-is so saved weights stay loadable.
            nn.Conv2d(n_filters, n_filters, kernel_size=1, padding=1),
            nn.Flatten(),
        ])
        self.hidden_layer = nn.Sequential(*hidden)

        self.output_layer = nn.Sequential(
            nn.Linear(n_filters * 100, 256),
            nn.ReLU(),
            nn.Linear(256, 1),
        )

    def forward(self, x):
        """Return tanh-squashed values of shape (N, 1) for input (N, 9, 8, 8)."""
        out = self.input_layer(x)
        out = self.hidden_layer(out)
        out = self.output_layer(out)
        return out.tanh()

In [None]:
def seed_everything(seed):
    """
    Seed every random number generator this notebook relies on.

    Seeds Python's ``random``, NumPy's global generator and PyTorch (CPU and
    all CUDA devices), and asks cuDNN for deterministic behaviour.

    Args:
        seed: integer seed applied to all generators.
    """
    random.seed(seed)
    # PYTHONHASHSEED is read at interpreter start-up, so this only influences
    # processes launched after this point, not the running kernel.
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    # Cover every visible GPU; silently ignored when CUDA is unavailable.
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # Benchmark mode may pick different convolution algorithms between runs,
    # so pin it off for reproducibility.
    torch.backends.cudnn.benchmark = False

## データ読み込み

In [None]:
def load_data(idx, data_dir=None, min_move=20):
    """
    Build training samples from one half of an Egaroucid transcript file.

    Transcript file ``0000{idx//2:03d}.txt`` is read; even ``idx`` selects the
    first half of its games and odd ``idx`` the second half. Every game is
    replayed move by move, and each position reached after move number
    ``min_move`` or later is encoded with ``board_to_array_aug2`` (8 symmetric
    variants per position) and labelled with the game's final stone
    difference.

    Args:
        idx: sub-dataset index (two consecutive indices share one file).
        data_dir: directory holding the transcript files. Defaults to the
            Kaggle dataset path used by this notebook.
        min_move: 1-based move number from which positions are collected.
            The notebook header states the model is trained on positions from
            move 20 onward. The previous condition ``i > 19`` compared the
            *character offset* into the transcript (two characters per move)
            and therefore started collecting at move 11.

    Returns:
        (states, results): ``states`` is a float16 array of the stacked
        encoded boards; ``results`` is an int8 column vector with one label
        per row of ``states``. Both have zero rows when nothing qualifies.
    """
    if data_dir is None:
        data_dir = '/kaggle/input/reversi-datasets/Egaroucid_Transcript/Egaroucid_Transcript/0000_egaroucid_6_3_0_lv11/'
    path = os.path.join(data_dir, f'0000{str(idx//2).zfill(3)}.txt')

    with open(path, 'r') as f:
        # splitlines() + filtering keeps the final game even when the file
        # does not end with a newline (split('\n')[:-1] would drop it).
        transcripts = [line for line in f.read().splitlines() if line]

    half = len(transcripts) // 2
    transcripts = transcripts[:half] if idx % 2 == 0 else transcripts[half:]

    states = []
    results = []
    for transcript in transcripts:
        board = Board()
        game_states = []
        for i in range(0, len(transcript), 2):
            board.move(move_from_str(transcript[i:i+2]))
            move_number = i // 2 + 1
            if move_number >= min_move:
                game_states.append(board_to_array_aug2(board))
        if not game_states:
            continue
        # NOTE(review): this assumes diff_num() is relative to the side to
        # move, so the sign flip yields a black-relative result - confirm
        # against the creversi documentation.
        z = board.diff_num() if board.turn else -board.diff_num()
        game_states = np.concatenate(game_states)
        results += [z] * game_states.shape[0]
        states.append(game_states)

    if not states:
        # np.concatenate raises on an empty list; return empty, well-typed
        # arrays instead.
        return (np.zeros((0, 9, 8, 8), dtype=np.float16),
                np.zeros((0, 1), dtype=np.int8))

    states = np.concatenate(states, dtype=np.float16)
    results = np.array(results, dtype=np.int8).reshape(-1, 1)
    return states, results

# 学習

In [None]:
# Run on the GPU when one is visible to PyTorch, otherwise fall back to CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
print(device)

In [None]:
# n_epoch = 10
# n_batch = 256
# lr = 0.001

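# import tensorflow as tf
# # Set up the TPU.
# # NOTE: tf.distribute.TPUStrategy only manages TensorFlow objects; entering
# # its scope does not place the PyTorch model below on the TPU.
# resolver = tf.distribute.cluster_resolver.TPUClusterResolver.connect()
# strategy = tf.distribute.TPUStrategy(resolver)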

# with strategy.scope():
#     seed_everything(1234)
#     model = ValueNetwork()
#     criterion = nn.HuberLoss()
#     optim = torch.optim.Adam(model.parameters(),lr=lr)

In [None]:
# Hyperparameters for this run.
n_epoch, n_batch, lr = 1, 256, 0.001

# Seed before constructing the network so its initial weights are reproducible.
seed_everything(1234)
model = ValueNetwork()
model = model.to(device)
criterion = nn.HuberLoss()
optim = torch.optim.Adam(params=model.parameters(), lr=lr)

In [None]:
# Training.
# Each of the 400 sub-datasets (half a transcript file) is loaded, shuffled and
# consumed in mini-batches of n_batch samples; the mean batch loss of each
# sub-dataset is appended to loss_list.
n_subdata = 400
loss_list = []
# Make sure the model is on the training device and in training mode even if
# this cell is re-run after the model was moved to the CPU or put in eval mode.
model.to(device)
model.train()
for epoch in range(n_epoch):
    np.random.seed(epoch)
    subdata_idx = np.random.permutation(n_subdata)
    for j, idx in enumerate(subdata_idx):
        print(f'-------Epoch:{epoch+1}/{n_epoch} Data:{j+1}/{n_subdata}-------')
        states, results = load_data(idx)
        np.random.seed(epoch*idx)
        random_idx = np.random.permutation(states.shape[0])
        # Only full batches are used; a trailing partial batch is dropped.
        n_steps = len(random_idx) // n_batch
        if n_steps == 0:
            # Fewer than n_batch samples: nothing to train on, and dividing
            # the accumulated loss by zero below would raise.
            print('train loss: skipped (fewer samples than one batch)')
            del states, results
            gc.collect()
            continue

        train_loss = 0.
        for i in range(n_steps):
            batch_idx = random_idx[n_batch*i:n_batch*(i+1)]
            X_batch = torch.from_numpy(states[batch_idx]).to(torch.float32).to(device)
            # Labels are stone differences; dividing by 64 maps them into the
            # [-1, 1] range of the network's tanh output.
            y_batch = torch.from_numpy(results[batch_idx]).to(torch.float32).to(device) / 64

            optim.zero_grad()
            output = model(X_batch)
            loss = criterion(output, y_batch)
            loss.backward()
            optim.step()
            train_loss += loss.item()

        train_loss /= n_steps
        loss_list.append(train_loss)
        print(f'train loss:{train_loss:.8f}')
        # Free the sub-dataset before loading the next one to limit peak memory.
        del states, results
        gc.collect()

In [None]:
# model.cpu() moves the model to the CPU in place; the cells below feed it CPU
# tensors, so they rely on this side effect as well as on the saved file.
cpu_state_dict = model.cpu().state_dict()
torch.save(cpu_state_dict, 'value-network-v4.pth')

In [None]:
# Plot the mean training loss recorded for each sub-dataset.
fig, ax = plt.subplots(figsize=(4, 4))
ax.plot(loss_list)
plt.show()

In [None]:
# Reference games for a visual check of the trained model.
# moves_list: one transcript string per game; the cell below reads each string
# two characters at a time and converts every pair with move_from_str.
# v_true_list: one sequence of reference values per game, plotted in blue
# against the model's predictions on a -64..64 axis in the cell below.
# NOTE(review): the origin of these reference values is not visible in this
# notebook - presumably per-move evaluations from a stronger engine; confirm.
moves_list = ['d3e3f2e2f5c5b6e6f6c6d6c4f3f7d7e7f4b5c3g5g6b4c7d2a6a5a3a4b3d8h6h5h4g4h3g3c2f1e1d1g2g1c1b7h1b1h2a2a8a7a1b2b8c8e8g8f8g7h8h7',
         'f5d6c7f3d3c6c5c4c3e6f4e3d7b6f6b3b5g5g4b4a3h5h4h3a4a6a2a5a7e7f8g3g6f7c2d8g8b8c8e8g7h8a8b7d2e2f1b1c1e1g2d1f2h7h6g1h1h2a1b2',
         'f5f4c3e6d3f6g4f3g5e3f2g3f7h4h3b2h5f8d6d7g6d2e7h6h7h2h1g1c7g2d1e8g7d8a1c8e1f1e2c1c2b8b1g8h8b3c6c4b7a8a7c5b5a6a3b6a2b4a4a5']
v_true_list = [[0,0,-6,0,-8,0,-19,0,-11,-9,-16,-11,-20,-6,-10,0,-8,0,-18,-9,-10,-8,-7,0,0,0,0,0,-9,0,0,0,6,12,11,12,0,5,0,0,5,15,12,30,31,44,36,54,51,59,44,44,40,40,24,24,24,32,32,32],
     [0,0,-9,-9,-9,-9,-9,-9,-9,-9,-15,-15,-15,-8,-8,-1,-1,1,1,1,1,1,-15,-15,-15,-15,-15,-15,-15,0,0,2,2,2,2,2,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,-2,-2,-2,-2,-2,-2,-2,-2],
     [0,6,-6,-6,-13,-13,-20,-20,-19,-19,-27,-27,-46,-46,-59,-6,-6,12,12,24,24,27,27,37,37,53,53,63,36,36,18,18,4,4,-13,-13,-14,-14,-14,18,18,44,44,64,64,64,46,46,26,26,10,10,4,4,2,2,20,20,20,26,26]]

In [None]:
# Replay each reference game and compare the model's value estimates (red: one
# line per symmetric variant, orange: their mean) with the reference values
# (blue).
# Switching to eval mode is loop-invariant, so do it once up front.
model.eval()
for moves, v_true in zip(moves_list, v_true_list):
    v_list = []
    board = Board()
    for move in [move_from_str(moves[i:i+2]) for i in range(0, len(moves), 2)]:
        num = 64 - board.piece_sum()
        if num != 60:
            with torch.no_grad():
                # One prediction per symmetric variant of the position.
                v = model(board_to_array_aug2(board, True)).numpy().T[0]
            # Undo the /64 label scaling used during training.
            v_list.append(v*64)
        else:
            # 60 empty squares means the starting position; use 0 instead of
            # querying the model.
            v_list.append(np.array([0]*8))
        board.move(move)
    plt.figure(figsize=(4,1))
    plt.plot(v_list, c='red')
    plt.plot(np.array(v_list).mean(axis=1), c='orange')
    plt.plot(v_true, c='blue')
    plt.ylim(-64,64)
    plt.axhline(0, c='black', ls='--')
    plt.show()

In [None]:
# Sanity check on three degenerate boards: every square set to 'O', '-' or 'X'.
# Prints the mean prediction over the 8 symmetric variants, rescaled by 64.
# Force eval mode and disable autograd so that running this cell on its own
# cannot update the BatchNorm running statistics or build a graph.
model.eval()
for fill in ('O', '-', 'X'):
    b = Board()
    # NOTE(review): the second argument is presumably the side to move - confirm
    # against the creversi documentation.
    b.set_line(fill*64, True)
    with torch.no_grad():
        print(model(board_to_array_aug2(b, True)).mean().item()*64)