# 概要
- このノートブックでは、ValueNetworkを学習させる。
- ネットワークの構造は以下のとおり。
    - 入力層：9チャネル
        - 黒石の位置(1)
        - 白石の位置(1)
        - 空白の位置(1)
        - 合法手の位置(1)
        - そこに打った場合、何個石を返せるか(1)
        - 隅の危険領域4マス×4隅をすべて1で埋める(1)
        - すべて1で埋める(1)
        - すべて0で埋める(1)
        - **手番情報：黒番ならすべて0で埋め、白番ならすべて1で埋める**(1)
    - 第1層：5x5のn_filters種類のフィルターとReLU関数
    - 第2-11層：3x3のn_filters種類のフィルターとReLU関数
    - 第12層：3x3のn_filters種類のフィルター
    - 第13層：1x1のn_filters種類のフィルター
    - 第14層：出力256個の全結合ネットワークとReLU関数
    - 第15層：出力1個の全結合ネットワークとtanh関数
- 学習データの作成方法は以下のとおり。（cf. 『AlphaGo解体新書』p.171）
    - 1以上60以下の整数からランダムに数字を選択し、これをUとする。
    - ~~SL-PolicyNetworkをU-1回使って~~ランダムに手を選んで、U-1手目まで局面を進める。
    - 次のU手目は合法手の中からランダムに選択し局面を進め、この局面をSとする。
    - 局面Sからは、~~RL~~SLポリシーネットワークを使って、終局まで手を進める。最終的な勝敗をzとする。
    - 組(S,z)を学習データとする。

In [1]:
%%capture
!pip install creversi

In [2]:
from creversi import *

import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm import tqdm
from copy import copy
import gc

import torch
import torch.nn as nn
from sklearn.model_selection import train_test_split



In [3]:
def board_to_array(board):
    """
    Encode a board as an (8, 8, 8) float32 feature array (PolicyNetwork input).

    Planes, as filled in by this function:
        0, 1 : the stone planes written by ``board.piece_planes``; the two
               are swapped when ``board.turn`` is falsy (the original note
               describes plane 0 as black stones and plane 1 as white stones).
        2    : 1 where planes 0 and 1 do not sum to 1 (empty squares).
        3    : 1 where plane 4 is positive (described as the legal-move plane).
        4    : for each legal move, the number of stones it would flip.
        5    : 1 on the 2x2 block of squares at each of the four corners.
        6    : all ones.
        7    : left untouched (zero-initialised).

    Planes 3 and 4 are not written when the only legal move is the pass
    move (``list(board.legal_moves) == [64]``).
    """
    planes = np.zeros((8, 8, 8), dtype=np.float32)
    board.piece_planes(planes)
    if not board.turn:
        # Swap the two stone planes; fancy indexing yields a fresh array.
        planes = planes[[1, 0] + list(range(2, 8))]

    stones = planes[0] + planes[1]
    planes[2] = stones != 1

    candidates = list(board.legal_moves)
    if candidates != [64]:
        flips = np.zeros(64)
        for square in candidates:
            trial = copy(board)
            opponent_before = trial.opponent_piece_num()
            trial.move(square)
            # NOTE(review): presumably piece_num() now refers to the colour
            # that opponent_piece_num() referred to before the move, so the
            # difference is the number of flipped stones - confirm in creversi.
            flips[square] = opponent_before - trial.piece_num()
        flips = flips.reshape(8, 8)
        planes[3] = flips > 0
        planes[4] = flips

    # Outer product of the edge pattern marks the 2x2 squares in each corner.
    edge = np.array([1., 1., 0., 0., 0., 0., 1., 1.])
    planes[5] = np.outer(edge, edge)
    planes[6] = 1
    return planes

In [4]:
def board_to_array2(board):
    """
    Encode a board as a (9, 8, 8) float32 feature array (ValueNetwork input).

    Planes, as filled in by this function:
        0, 1 : the stone planes written by ``board.piece_planes``; the two
               are swapped when ``board.turn`` is falsy (the original note
               describes plane 0 as black stones and plane 1 as white stones).
        2    : 1 where planes 0 and 1 do not sum to 1 (empty squares).
        3    : 1 where plane 4 is positive (described as the legal-move plane).
        4    : for each legal move, the number of stones it would flip.
        5    : 1 on the 2x2 block of squares at each of the four corners.
        6    : all ones.
        7    : left untouched (zero-initialised).
        8    : side-to-move plane: all ones when ``board.turn`` is falsy
               (white to move, per the original note), otherwise all zeros.

    Planes 3 and 4 are not written when the only legal move is the pass
    move (``list(board.legal_moves) == [64]``).
    """
    n_planes = 9
    features = np.zeros((n_planes, 8, 8), dtype=np.float32)
    board.piece_planes(features)
    if not board.turn:
        # Swap the two stone planes (fancy indexing copies), then flag the turn.
        plane_order = [1, 0, *range(2, n_planes)]
        features = features[plane_order]
        features[8] = 1

    features[2] = (features[0] + features[1]) != 1

    moves = list(board.legal_moves)
    if moves != [64]:
        flip_counts = np.zeros(64)
        for mv in moves:
            scratch = copy(board)
            before = scratch.opponent_piece_num()
            scratch.move(mv)
            # NOTE(review): presumably piece_num() after the move counts the
            # same colour that opponent_piece_num() counted before it, making
            # the difference the number of flipped stones - confirm in creversi.
            flip_counts[mv] = before - scratch.piece_num()
        flip_counts = flip_counts.reshape(8, 8)
        features[3] = flip_counts > 0
        features[4] = flip_counts

    # 1 on rows/columns {0, 1, 6, 7} simultaneously, i.e. the four 2x2 corners.
    border = np.array([1., 1., 0., 0., 0., 0., 1., 1.])
    features[5] = border[:, None] * border[None, :]
    features[6] = 1
    return features

In [5]:
def move_rotate(move, k):
    """Rotate a move index by ``k`` quarter turns.

    Args:
        move: Square index to rotate.
        k: Number of quarter turns. It is reduced modulo 4, so 0, 4, -4, ...
            leave the move unchanged (previously those values fell through
            and returned ``None``).

    Returns:
        The rotated move index. ``k % 4 == 1`` delegates to
        ``move_rotate270``, ``2`` to ``move_rotate180`` and ``3`` to
        ``move_rotate90``.

    NOTE(review): the 1 -> 270 / 3 -> 90 mapping is kept from the original;
    presumably it matches the direction convention of the array rotation it
    is paired with - confirm against the caller.
    """
    k = k % 4
    if k == 0:
        # Zero quarter turns: identity.
        return move
    if k == 1:
        return move_rotate270(move)
    if k == 2:
        return move_rotate180(move)
    return move_rotate90(move)
    
def move_fliplr(move):
    """Mirror a square index left-to-right on an 8-column board.

    The index is split as ``row * 8 + col``; the row is kept and the column
    ``col`` is replaced by ``7 - col``.
    """
    rank, file_ = divmod(move, 8)
    return 8 * rank + (7 - file_)

In [6]:
class PolicyNetwork(nn.Module):
    """Fully convolutional policy network.

    Layout:
        input_layer  : 5x5 convolution (8 -> 100 channels, padding 2) + ReLU.
        hidden_layer : 11 repetitions of
                       3x3 convolution (100 -> 100, padding 1) + BatchNorm + ReLU.
        output_layer : 1x1 convolution (100 -> 1 channel) followed by Flatten.

    The paddings keep the spatial size, so an input of shape (B, 8, H, W)
    produces an output of shape (B, H * W); no softmax is applied here.
    """

    def __init__(self):
        super().__init__()
        n_filters = 100
        n_hidden_blocks = 11

        self.input_layer = nn.Sequential(
            nn.Conv2d(8, n_filters, kernel_size=5, padding=2),
            nn.ReLU(),
        )

        # Modules are created in the same order as a hand-written list, so the
        # Sequential indices (and therefore the state-dict keys) are 0..32.
        hidden_modules = []
        for _ in range(n_hidden_blocks):
            hidden_modules.extend((
                nn.Conv2d(n_filters, n_filters, kernel_size=3, padding=1),
                nn.BatchNorm2d(n_filters),
                nn.ReLU(),
            ))
        self.hidden_layer = nn.Sequential(*hidden_modules)

        self.output_layer = nn.Sequential(
            nn.Conv2d(n_filters, 1, kernel_size=1),
            nn.Flatten(),
        )

    def forward(self, x):
        """Return one raw score per board square for each sample in ``x``."""
        features = self.hidden_layer(self.input_layer(x))
        return self.output_layer(features)

# 学習データ作成

In [7]:
# Use the GPU when one is available, otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(device)

# The checkpoint is a fully pickled module, so the PolicyNetwork class must be
# defined before this line runs. torch.load unpickles the file: only load
# checkpoints from a trusted source.
# map_location lets a checkpoint saved on a GPU be opened on a CPU-only runtime.
# NOTE(review): recent PyTorch releases default to weights_only=True, which
# rejects pickled modules; pass weights_only=False if this load fails there.
model_SL = torch.load(
    '/kaggle/input/reversi-datasets/SL-PolicyNetwork-v3-checkpoint-5epoch-subdata99.pth',
    map_location=device,
).to(device)

# The model is only used for inference below. Switch to eval mode so the
# BatchNorm layers use their stored running statistics instead of per-batch
# statistics (and stop updating them on every forward pass).
model_SL.eval()

cpu


In [8]:
def model_move(board, legal_moves, model):
    """Sample one of ``legal_moves`` according to the model's scores.

    The board is encoded with ``board_to_array``, passed through ``model`` as
    a batch of one on ``device``, and the scores at the ``legal_moves``
    indices are turned into probabilities with a softmax restricted to those
    moves. The returned move is drawn with ``np.random.choice`` using these
    probabilities.
    """
    features = torch.from_numpy(board_to_array(board)).unsqueeze(0).to(device)
    with torch.no_grad():
        scores = model(features).cpu()
    # Softmax over the legal squares only, so illegal squares get no mass.
    legal_probs = scores[0][legal_moves].softmax(0).numpy()
    return np.random.choice(legal_moves, p=legal_probs)

In [9]:
# 1ファイル当たりの対局数
N = 25000
N_file = 7

for i in range(N_file):
    print(f'----file{i+1}----')
    S,z = [],[]
    for n in tqdm(range(N)):
        board = Board()
        U = np.random.randint(1,61)
        u = 1  # 手数のカウンター
        while not board.is_game_over():
            legal_moves = list(board.legal_moves)
            if 64 in legal_moves: # パスの処理
                board.move_pass()
                if u == U:
                    S.append(board_to_array2(board))
                    turn = board.turn
            elif u < U:
                move = np.random.choice(legal_moves)
                board.move(move)
            elif u == U:
                move = np.random.choice(legal_moves)
                board.move(move)
                S.append(board_to_array2(board))
                turn = board.turn
            else:
                move = model_move(board, legal_moves, model_SL)
                board.move(move)
            u += 1
        if len(S) == len(z):
            print('[WARNING] early gameover.')
            continue

        if board.turn == turn:
            z.append(1 if board.diff_num()>0 else (-1 if board.diff_num()<0 else 0))
        else:
            z.append(1 if board.diff_num()<0 else (-1 if board.diff_num()>0 else 0))

    S = np.array(S)
    z = np.array(z).astype(np.float32)
    np.save(f'S-{str(i+1).zfill(3)}-data-for-ValueNetwork.npy', S)
    np.save(f'z-{str(i+1).zfill(3)}-data-for-ValueNetwork.npy', z)
    print(f'S:{S.shape}, z:{z.shape}')
    print(f'pos:{np.where(z>0,1,0).sum()}, neg:{np.where(z<0,1,0).sum()}, draw:{np.where(z==0,1,0).sum()}')

----file1----


  2%|▏         | 449/25000 [00:42<37:08, 11.01it/s]



  9%|▉         | 2298/25000 [03:37<21:24, 17.67it/s]



 16%|█▌        | 3970/25000 [06:19<26:11, 13.38it/s]



 18%|█▊        | 4460/25000 [07:07<36:54,  9.28it/s]



 35%|███▌      | 8852/25000 [14:05<18:25, 14.61it/s]



 37%|███▋      | 9294/25000 [14:48<19:57, 13.11it/s]



 60%|██████    | 15016/25000 [23:58<18:04,  9.21it/s]



 73%|███████▎  | 18216/25000 [29:14<10:56, 10.33it/s]



 77%|███████▋  | 19341/25000 [31:06<07:14, 13.04it/s]



 87%|████████▋ | 21665/25000 [34:55<03:34, 15.58it/s]



 93%|█████████▎| 23179/25000 [37:20<02:14, 13.59it/s]



 97%|█████████▋| 24162/25000 [38:53<00:54, 15.48it/s]



100%|██████████| 25000/25000 [40:14<00:00, 10.35it/s]


S:(24988, 9, 8, 8), z:(24988,)
pos:12696, neg:11411, draw:881
----file2----


 24%|██▎       | 5916/25000 [09:47<34:35,  9.19it/s]



 33%|███▎      | 8264/25000 [13:43<20:10, 13.82it/s]



 76%|███████▋  | 19095/25000 [31:14<07:43, 12.73it/s]



 89%|████████▉ | 22249/25000 [36:24<05:36,  8.17it/s]



100%|██████████| 25000/25000 [40:51<00:00, 10.20it/s]


S:(24996, 9, 8, 8), z:(24996,)
pos:12795, neg:11300, draw:901
----file3----


 18%|█▊        | 4474/25000 [07:09<22:17, 15.34it/s]



 20%|██        | 5019/25000 [08:02<32:34, 10.22it/s]



 21%|██▏       | 5328/25000 [08:31<28:36, 11.46it/s]



 41%|████      | 10154/25000 [16:19<23:56, 10.34it/s]



 47%|████▋     | 11755/25000 [18:55<17:59, 12.27it/s]



 57%|█████▋    | 14361/25000 [23:12<13:37, 13.02it/s]



 64%|██████▍   | 16110/25000 [26:03<12:24, 11.95it/s]



 67%|██████▋   | 16820/25000 [27:14<12:08, 11.22it/s]



 88%|████████▊ | 21939/25000 [35:38<04:25, 11.51it/s]



100%|██████████| 25000/25000 [40:41<00:00, 10.24it/s]


S:(24991, 9, 8, 8), z:(24991,)
pos:12726, neg:11390, draw:875
----file4----


  1%|          | 153/25000 [00:14<38:18, 10.81it/s]



  3%|▎         | 690/25000 [01:04<27:45, 14.59it/s]



  5%|▌         | 1348/25000 [02:07<37:43, 10.45it/s]



 21%|██        | 5182/25000 [08:25<20:58, 15.75it/s]



 33%|███▎      | 8140/25000 [13:15<23:21, 12.03it/s]



 35%|███▍      | 8628/25000 [14:02<24:03, 11.34it/s]



 46%|████▌     | 11535/25000 [18:51<19:27, 11.54it/s]



 58%|█████▊    | 14416/25000 [23:34<11:16, 15.64it/s]



 60%|██████    | 15055/25000 [24:36<13:24, 12.36it/s]



 65%|██████▌   | 16326/25000 [26:41<08:24, 17.19it/s]



 68%|██████▊   | 16986/25000 [27:42<08:11, 16.29it/s]



 84%|████████▎ | 20928/25000 [34:06<06:48,  9.98it/s]



 91%|█████████▏| 22836/25000 [37:17<02:51, 12.65it/s]



 91%|█████████▏| 22843/25000 [37:17<03:05, 11.64it/s]



 95%|█████████▌| 23776/25000 [38:46<01:38, 12.37it/s]



100%|██████████| 25000/25000 [40:44<00:00, 10.23it/s]


S:(24985, 9, 8, 8), z:(24985,)
pos:12748, neg:11439, draw:798
----file5----


  2%|▏         | 568/25000 [00:54<36:35, 11.13it/s]



  4%|▍         | 1099/25000 [01:46<34:47, 11.45it/s]



  8%|▊         | 2055/25000 [03:20<32:28, 11.78it/s]



 12%|█▏        | 3095/25000 [05:03<24:54, 14.66it/s]



 27%|██▋       | 6797/25000 [11:01<27:41, 10.95it/s]



 34%|███▍      | 8575/25000 [13:51<26:43, 10.24it/s]



 52%|█████▏    | 12877/25000 [20:58<29:59,  6.74it/s]



 55%|█████▍    | 13628/25000 [22:10<16:11, 11.71it/s]



 64%|██████▍   | 15988/25000 [26:02<19:47,  7.59it/s]



 92%|█████████▏| 22910/25000 [37:30<02:31, 13.76it/s]



100%|██████████| 25000/25000 [40:54<00:00, 10.18it/s]


S:(24990, 9, 8, 8), z:(24990,)
pos:12794, neg:11313, draw:883
----file6----


  8%|▊         | 2020/25000 [03:18<32:41, 11.71it/s]



  8%|▊         | 2107/25000 [03:27<34:13, 11.15it/s]



 13%|█▎        | 3209/25000 [05:14<26:22, 13.77it/s]



 14%|█▍        | 3600/25000 [05:54<29:48, 11.97it/s]



 25%|██▌       | 6263/25000 [10:17<30:07, 10.36it/s]



 29%|██▊       | 7127/25000 [11:41<26:54, 11.07it/s]



 41%|████      | 10288/25000 [16:54<18:40, 13.13it/s]



 45%|████▍     | 11158/25000 [18:18<18:57, 12.17it/s]



 49%|████▉     | 12289/25000 [20:08<17:40, 11.99it/s]



 53%|█████▎    | 13272/25000 [21:46<18:37, 10.49it/s]



 58%|█████▊    | 14515/25000 [23:50<11:22, 15.37it/s]



 69%|██████▉   | 17299/25000 [28:22<08:03, 15.94it/s]



 72%|███████▏  | 18093/25000 [29:42<06:19, 18.22it/s]



 72%|███████▎  | 18125/25000 [29:44<10:17, 11.14it/s]



 85%|████████▍ | 21234/25000 [34:54<05:41, 11.04it/s]



 87%|████████▋ | 21643/25000 [35:36<08:23,  6.66it/s]



 92%|█████████▏| 22928/25000 [37:44<03:07, 11.08it/s]



100%|██████████| 25000/25000 [41:18<00:00, 10.09it/s]


S:(24983, 9, 8, 8), z:(24983,)
pos:12581, neg:11512, draw:890
----file7----


 19%|█▉        | 4839/25000 [07:55<27:31, 12.21it/s]



 22%|██▏       | 5448/25000 [08:56<19:36, 16.62it/s]



 26%|██▌       | 6406/25000 [10:33<37:42,  8.22it/s]



 27%|██▋       | 6674/25000 [10:58<27:59, 10.91it/s]



 29%|██▉       | 7317/25000 [12:03<24:35, 11.99it/s]



 32%|███▏      | 8013/25000 [13:11<19:55, 14.21it/s]



 37%|███▋      | 9126/25000 [15:03<24:13, 10.92it/s]



 38%|███▊      | 9433/25000 [15:32<21:58, 11.80it/s]



 44%|████▍     | 11051/25000 [18:12<18:45, 12.40it/s]



 60%|█████▉    | 14918/25000 [24:41<12:48, 13.12it/s]



 68%|██████▊   | 17068/25000 [28:18<12:45, 10.36it/s]



 70%|███████   | 17613/25000 [29:15<13:27,  9.15it/s]



 72%|███████▏  | 17903/25000 [29:46<07:24, 15.98it/s]



 81%|████████▏ | 20333/25000 [33:53<06:13, 12.51it/s]



 97%|█████████▋| 24360/25000 [40:27<00:55, 11.56it/s]



 99%|█████████▊| 24641/25000 [40:55<00:21, 16.70it/s]



100%|██████████| 25000/25000 [41:30<00:00, 10.04it/s]

S:(24984, 9, 8, 8), z:(24984,)
pos:12749, neg:11346, draw:889



