# 方針
- このノートブックでは、SLポリシーネットワークを学習させる。
- パスは予測値に含めず、出力は0~63
- ネットワークの構造は以下のとおり。
    - 入力層：8チャネル
        - 黒石の位置(1)
        - 白石の位置(1)
        - 空白の位置(1)
        - 合法手の位置(1)
        - そこに打った場合、何個石を返せるか(1)
        - 隅の危険領域4マス×4隅をすべて1で埋める(1)
        - すべて1で埋める(1)
        - すべて0で埋める(1)
    - 第1層：5x5のn_filters種類のフィルターとReLU関数
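    - 第2-12層：3x3のn_filters種類のフィルター、Batch Normalization、ReLU関数
    - 第13層：1x1の1種類のフィルター＋softmax関数（softmaxはネットワーク内部には含めず、学習時は`nn.CrossEntropyLoss`、推論時は`.softmax()`で適用する。位置に依存するバイアス項は本ノートブックの実装には含まれていない）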

In [None]:
%%capture
!pip install creversi

In [None]:
# リバーシ用ライブラリ
from creversi import Board,move_to_str,move_from_str,move_rotate90,move_rotate180,move_rotate270
import creversi
# 基礎ライブラリ
import random
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm.notebook import tqdm
from copy import copy
import gc
import psutil

# 学習用ライブラリ
import torch
import torch.nn as nn
from sklearn.model_selection import train_test_split

`piece_planes()`の仕様
- 自分の石が第1チャンネル、相手の石が第2チャンネルに出力される。

`legal_moves`の仕様
- 石を置けない場合は、64(=pass)のみが返される。

# 学習データの準備

In [None]:
def parse(x):
    """Convert a transcript string into a fixed-length array of moves.

    The transcript is consumed two characters at a time and every chunk is
    converted with ``move_from_str``.

    Args:
        x: Transcript string in which each 2-character chunk is one move.

    Returns:
        Integer ndarray of length 60 holding the converted moves in order;
        unused trailing slots are filled with -1.
    """
    moves = [move_from_str(x[pos:pos + 2]) for pos in range(0, len(x), 2)]
    n_moves = len(moves)
    # Start from an all -1 array so only the played moves need writing.
    padded = np.full(60, -1, dtype=int)
    padded[:n_moves] = moves
    return padded

def read_data(year):
    """Load one year of game records.

    Reads ``wthor_{year}.csv`` from the Kaggle input directory and converts
    its ``transcript`` column with ``parse``.

    Args:
        year: Year used to build the CSV file name.

    Returns:
        2-D ndarray with one row per CSV row and one column per slot of the
        array returned by ``parse``.
    """
    csv_path = f"/kaggle/input/reversi-datasets/wthor_{year}.csv"
    transcripts = pd.read_csv(csv_path)["transcript"]
    moves_per_game = transcripts.apply(parse)
    # Expanding each per-game array into columns yields a 2-D table.
    return moves_per_game.apply(pd.Series).values

In [None]:
# Load every yearly file from 1977 through 2022 (the upper bound of range() is
# exclusive, so this is 46 files) and stack all games into a single array.
# The per-year arrays are collected first and concatenated once; concatenating
# inside the loop would re-copy the ever-growing array on every iteration.
yearly_games = [read_data(y) for y in tqdm(range(1977, 2023))]
data = np.concatenate(yearly_games)
del yearly_games  # the pieces are no longer needed once `data` is built
data

In [None]:
def board_to_array(board):
    """Encode a board position as an (8, 8, 8) float32 feature array.

    Planes, in order:
        0: black stones
        1: white stones
        2: empty squares
        3: legal moves for the side to move
        4: number of stones each legal move would flip
        5: the 2x2 block of squares at each of the four corners
        6: all ones
        7: all zeros

    Args:
        board: Board object providing ``piece_planes``, ``turn``,
            ``legal_moves``, ``opponent_piece_num``, ``piece_num`` and
            ``move``.

    Returns:
        float32 ndarray of shape (8, 8, 8), indexed (plane, row, col).
    """
    planes = np.zeros((8, 8, 8), dtype=np.float32)
    # piece_planes writes the side to move into plane 0 and the opponent into
    # plane 1 (as noted in this notebook's markdown).
    board.piece_planes(planes)
    if not board.turn:
        # Swap the first two planes so that their colour order is fixed no
        # matter whose turn it is.
        planes = planes[[1, 0, 2, 3, 4, 5, 6, 7], :, :]
    occupied = planes[0] + planes[1]
    planes[2] = np.where(occupied == 1, 0, 1)

    candidates = list(board.legal_moves)
    if candidates != [64]:  # [64] means the only option is to pass
        flips = np.zeros(64)
        for square in candidates:
            trial = copy(board)
            opponent_before = trial.opponent_piece_num()
            trial.move(square)
            # After the move the side to move has changed, so piece_num() now
            # counts the former opponent; the drop is the number of flips.
            flips[square] = opponent_before - trial.piece_num()
        flips = flips.reshape(8, 8)
        planes[3] = np.where(flips > 0, 1, 0)
        planes[4] = flips

    # Rows/columns 0, 1, 6, 7 intersect in the four 2x2 corner blocks.
    near_edge = [0, 1, 6, 7]
    corner_zone = np.zeros((8, 8))
    corner_zone[np.ix_(near_edge, near_edge)] = 1.
    planes[5] = corner_zone
    planes[6] = 1
    return planes

In [None]:
def move_rotate(move, k):
    """Rotate a move index by ``k`` quarter turns.

    ``k`` follows the convention used by the callers in this notebook, which
    pair it with ``np.rot90(board_array, k=k, axes=(1, 2))``: k=1 maps to
    ``move_rotate270``, k=2 to ``move_rotate180`` and k=3 to
    ``move_rotate90``.

    Args:
        move: Move index to rotate.
        k: Number of quarter turns (1, 2 or 3).

    Returns:
        The rotated move index, or None when ``k`` is not 1, 2 or 3.
    """
    rotated = None
    if k == 3:
        rotated = move_rotate90(move)
    elif k == 2:
        rotated = move_rotate180(move)
    elif k == 1:
        rotated = move_rotate270(move)
    return rotated
    
def move_fliplr(move):
    """Mirror a move index left-to-right on a board with 8 columns.

    Args:
        move: Move index laid out row-major (``row * 8 + col``).

    Returns:
        The index of the square in the same row with the column mirrored
        (``col`` becomes ``7 - col``).
    """
    row, col = divmod(move, 8)
    return row * 8 + (7 - col)

In [None]:
def _augment_sample(board_array, move):
    """Return the 8 flipped/rotated variants of one (position, move) sample."""
    samples = [
        (board_array, move),
        (np.flip(board_array, axis=2).copy(), move_fliplr(move)),
    ]
    for k in range(1, 4):
        rotated = np.rot90(board_array, k=k, axes=(1, 2)).copy()
        rotated_move = move_rotate(move, k=k)
        samples.append((rotated, rotated_move))
        samples.append((np.flip(rotated, axis=2).copy(), move_fliplr(rotated_move)))
    return samples


def load_sub_data(data, n, n_split, shuffle=True):
    """Build training samples from the n-th of ``n_split`` blocks of games.

    The games in ``data`` are split into ``n_split`` equally sized blocks
    (any remainder games are left out). Every game of block ``n`` is replayed
    from the initial position and each non-pass position yields eight samples:
    the position plus its flipped/rotated variants, each paired with the
    correspondingly transformed move.

    Args:
        data: 2-D array of move indices, one game per row, padded with -1.
        n: Zero-based index of the block to load.
        n_split: Number of blocks the games are divided into.
        shuffle: If True, games are assigned to blocks through a fixed
            (seed 0) permutation; otherwise in their original order.

    Returns:
        Tuple ``(X, y)``: ``X`` is a float32 array of encoded positions and
        ``y`` an int64 array of the moves played in them.

    Raises:
        ValueError: If ``n_split`` is not in ``1..len(data)`` or ``n`` is not
            in ``0..n_split-1``.
    """
    n_games = data.shape[0]
    if not 0 < n_split <= n_games:
        raise ValueError(f'n_split must be in 1..{n_games}, got {n_split}')
    if not 0 <= n < n_split:
        raise ValueError(f'n must be in 0..{n_split - 1}, got {n}')

    if shuffle:
        # A private generator seeded with 0 yields the same permutation as
        # seeding the global NumPy RNG with 0, but leaves the global state
        # untouched, so random draws made after this call are not reset.
        order = np.random.RandomState(0).permutation(n_games)
    else:
        order = np.arange(n_games)
    block_size = n_games // n_split
    X = []
    y = []

    for i in tqdm(range(block_size * n, block_size * (n + 1))):
        board = Board()
        for move in data[order[i]]:
            if move == -1:  # padding: end of this game's record
                break
            if list(board.legal_moves) == [64]:
                # The side to move has to pass.
                board.move_pass()
                if move == 64:
                    # The record spells the pass out; nothing else to play.
                    continue
                # NOTE(review): assumes transcripts list only stone placements
                # (parse allots 60 slots), so `move` belongs to the side that
                # moves after the pass and must still be played - confirm
                # against the dataset.
                if list(board.legal_moves) == [64]:
                    # Neither side can move: the rest cannot be replayed.
                    break
            for sample, label in _augment_sample(board_to_array(board), move):
                X.append(sample)
                y.append(label)
            board.move(move)
    X = np.array(X, dtype=np.float32)
    y = np.array(y, dtype=np.int64)
    print(f'X:{X.shape}, y:{y.shape}')
    return X, y

# ネットワーク

In [None]:
class PolicyNetwork(nn.Module):
    """Fully convolutional policy network that scores every board square.

    Layout:
        * input layer: 5x5 convolution (padding 2) + ReLU
        * hidden layers: ``n_hidden_layers`` repetitions of 3x3 convolution
          (padding 1) + BatchNorm + ReLU
        * output layer: 1x1 convolution down to one channel, flattened

    All convolutions preserve the spatial size, so an input of shape
    (batch, in_channels, H, W) produces (batch, H * W) raw scores. No softmax
    is applied here.

    Args:
        n_filters: Number of channels in the input and hidden layers.
        n_hidden_layers: Number of conv/BatchNorm/ReLU groups.
        in_channels: Number of input feature planes.
    """

    def __init__(self, n_filters=100, n_hidden_layers=11, in_channels=8):
        super().__init__()
        self.input_layer = nn.Sequential(
            nn.Conv2d(in_channels, n_filters, kernel_size=5, padding=2),
            nn.ReLU(),
        )
        # Built in a flat list so the sub-module indices (and therefore the
        # state_dict keys) are 0, 1, 2, ... exactly as if written out by hand.
        hidden = []
        for _ in range(n_hidden_layers):
            hidden.append(nn.Conv2d(n_filters, n_filters, kernel_size=3, padding=1))
            hidden.append(nn.BatchNorm2d(n_filters))
            hidden.append(nn.ReLU())
        self.hidden_layer = nn.Sequential(*hidden)
        self.output_layer = nn.Sequential(
            nn.Conv2d(n_filters, 1, kernel_size=1),
            nn.Flatten(),
        )

    def forward(self, x):
        """Return the per-square scores for a batch of encoded positions."""
        out = self.input_layer(x)
        out = self.hidden_layer(out)
        out = self.output_layer(out)
        return out

# 学習

In [None]:
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
print(device)

In [None]:
# Training configuration.
#   n_epoch: passes over the training blocks
#   n_batch: mini-batch size
#   n_split: number of blocks the games are divided into
n_epoch, n_batch, n_split = 5, 256, 300
lr = 0.001  # learning rate handed to AdamW

# Model, loss and optimiser.
model = PolicyNetwork()
model = model.to(device)
criterion = nn.CrossEntropyLoss()
optim = torch.optim.AdamW(model.parameters(), lr=lr)

# Metric histories filled in by the training loop.
train_loss_list, valid_loss_list, acc_list = [], [], []

In [None]:
# Validation set: the last block (index n_split-1), converted to tensors.
X_va, y_va = (
    torch.from_numpy(arr) for arr in load_sub_data(data, n_split - 1, n_split)
)
n_valid_data = y_va.shape[0]

In [None]:
# Memory check: report usage, force a garbage collection, then report again.
def _print_memory_usage():
    """Print psutil's used-memory figure (divided by 1e9) and its percentage."""
    snapshot = psutil.virtual_memory()
    print(f'{snapshot.used/1e9} Used ({snapshot.percent}%)')
    return snapshot


memory = _print_memory_usage()
print(gc.collect())  # gc.collect() returns the number of objects it found
memory = _print_memory_usage()

In [None]:
# Best validation metrics seen so far; a checkpoint is written whenever
# either of them improves.
bst_loss = np.inf
bst_acc = -np.inf

# Only complete mini-batches are used; a trailing partial batch is dropped.
n_valid_batches = n_valid_data // n_batch

for epoch in range(n_epoch):
    # Blocks 0..n_split-2 are used for training (block n_split-1 is the
    # validation set).
    for n in range(n_split-1):
        # Load one block of training data.
        print(f'----Epoch{epoch+1}, SubData{n+1}----')
        X_tr, y_tr = load_sub_data(data,n,n_split)
        X_tr, y_tr = torch.from_numpy(X_tr), torch.from_numpy(y_tr)
        n_train_data = len(y_tr)
        n_train_batches = n_train_data // n_batch

        # Training pass over the block in random mini-batches.
        model.train()
        train_loss = 0.
        random_idx = np.random.permutation(n_train_data)
        for i in range(n_train_batches):
            batch_idx = random_idx[n_batch*i:n_batch*(i+1)]
            X_batch = X_tr[batch_idx].to(device)
            y_batch = y_tr[batch_idx].to(device)

            optim.zero_grad()
            output = model(X_batch)
            loss = criterion(output, y_batch)
            loss.backward()
            optim.step()
            train_loss += loss.item()
        train_loss /= n_train_batches
        train_loss_list.append(train_loss)

        # Validation pass. Gradients are not needed here, so autograd is
        # switched off to avoid building graphs for every batch.
        model.eval()
        valid_loss = 0.
        correct = 0
        with torch.no_grad():
            for i in range(n_valid_batches):
                X_batch = X_va[n_batch*i:n_batch*(i+1)].to(device)
                y_batch = y_va[n_batch*i:n_batch*(i+1)].to(device)
                pred = model(X_batch)
                valid_loss += criterion(pred, y_batch).item()
                correct += (pred.argmax(dim=1) == y_batch).sum().item()
        valid_loss /= n_valid_batches
        valid_loss_list.append(valid_loss)
        acc = correct / (n_valid_batches*n_batch) * 100
        acc_list.append(acc)
        print(f'Epoch:{epoch+1}/{n_epoch}, SubData:{n+1}/{n_split-1}, train loss:{train_loss:.5f} valid loss:{valid_loss:.5f} valid acc:{acc:.3f}%')

        if valid_loss < bst_loss or acc > bst_acc:
            # The whole module is saved from the CPU, then moved back.
            torch.save(model.cpu(), f'SL-PolicyNetwork-v3-checkpoint-{epoch+1}epoch-subdata{n+1}.pth')
            model.to(device)
            # Track each best independently: when only one metric improved,
            # the other record must not be replaced by a worse value.
            bst_loss = min(bst_loss, valid_loss)
            bst_acc = max(bst_acc, acc)
        # Free the block before the next one is loaded.
        del X_tr, y_tr
        gc.collect()

X:(202520, 8, 8, 8), y:(202520,)  
Epoch:1/5, SubData:10/299, train loss:1.40540 valid loss:1.38548 valid acc:48.246%

In [None]:
# Learning curves. The histories gain one point per trained sub-data block
# (not one per epoch), so the x-axis is labelled accordingly.
plt.figure(figsize=(8,3))
plt.subplot(1,2,1)
plt.plot(train_loss_list,label='train')
plt.plot(valid_loss_list,label='valid')
plt.xlabel('Sub-data block')
plt.title('Loss')
plt.legend()

plt.subplot(1,2,2)
plt.plot(acc_list)
plt.xlabel('Sub-data block')
plt.title('Accuracy(%)')
plt.tight_layout()  # keep the two panels' labels from overlapping
plt.show()

In [None]:
# Playing strength (greedy strategy): when board.turn is truthy the network
# plays its highest-scoring legal move; otherwise a uniformly random legal
# move is played.
n_trial = 1000
results = []

model.eval()  # inference mode for the whole evaluation
for trial in tqdm(range(n_trial)):
    board = Board()
    while not board.is_game_over():
        legal_moves = list(board.legal_moves)
        if 64 in legal_moves:
            # No placement available: pass.
            move = 64
        elif board.turn:
            x = torch.from_numpy(board_to_array(board)).unsqueeze(0).to(device)
            with torch.no_grad():  # no gradients needed for inference
                p = model(x).cpu()
            # Restrict the scores to legal squares before taking the argmax.
            p_legal = p[0][legal_moves]
            move = legal_moves[p_legal.argmax().item()]
            if trial == 0:
                # Show the predicted distribution (in %) for the first game.
                plt.figure(figsize=(3,3))
                p = p.softmax(dim=1).numpy().reshape(8,8) *100
                sns.heatmap(p, cmap='gray_r',annot=True, fmt='.0f',cbar=False)
                plt.show()
        else:
            move = np.random.choice(legal_moves)
        board.move(move)
        if trial == 0:
            display(board)
    diff = board.diff_num()
    # NOTE(review): presumably diff_num() is relative to the side to move, so
    # the sign is flipped when board.turn is falsy to keep the result from the
    # network's point of view - confirm against the creversi documentation.
    results.append(diff if board.turn else -diff)

In [None]:
# Histogram of the final stone differences; the title shows the percentage of
# the n_trial games that ended with a positive difference.
outcomes = np.array(results)
win_rate = np.count_nonzero(outcomes > 0) / n_trial * 100
plt.figure(figsize=(5,3))
plt.hist(results,bins=30)
plt.title(f'Normal, {win_rate:.1f}%')
plt.show()

In [None]:
# Playing strength (ensemble): like the greedy evaluation, but the network's
# prediction is averaged over the 8 flipped/rotated views of the position.
n_trial = 1000
results_ensemble = []

model.eval()  # inference mode for the whole evaluation
for trial in tqdm(range(n_trial)):
    board = Board()
    while not board.is_game_over():
        legal_moves = list(board.legal_moves)
        if 64 in legal_moves:
            # No placement available: pass.
            move = 64
        elif board.turn:
            # Build the 8 views: index 2k is the position rotated by k quarter
            # turns, index 2k+1 is that rotation mirrored left-to-right.
            board_array = board_to_array(board)
            boards = [board_array, np.flip(board_array,axis=2).copy()]
            for k in range(1,4):
                board_array_rot = np.rot90(board_array, k=k, axes=(1,2)).copy()
                boards.append(board_array_rot)
                boards.append(np.flip(board_array_rot, axis=2).copy())
            # Predict all views in one batch; no gradients needed.
            batch = torch.from_numpy(np.array(boards)).to(device)
            with torch.no_grad():
                probs = model(batch).softmax(1).cpu().numpy()
            # Map every prediction back to the original orientation by undoing
            # the mirror first and then the rotation.
            grids = probs.reshape(8,8,8)  # (view, row, col)
            probs_org = [grids[0], np.fliplr(grids[1])]
            for k in range(1,4):
                probs_org.append(np.rot90(grids[2*k], k=-k))
                probs_org.append(np.rot90(np.fliplr(grids[2*k+1]), k=-k))
            # Average over the views and pick the best legal move.
            p = np.mean(probs_org, axis=0).flatten()
            p_legal = p[legal_moves]
            move = legal_moves[p_legal.argmax()]
            if trial == 0:
                # Show the averaged distribution (in %) for the first game.
                plt.figure(figsize=(3,3))
                sns.heatmap(p.reshape(8,8)*100, cmap='gray_r',annot=True, fmt='.0f',cbar=False)
                plt.show()
        else:
            move = np.random.choice(legal_moves)
        board.move(move)
        if trial == 0:
            display(board)
    diff = board.diff_num()
    # NOTE(review): presumably diff_num() is relative to the side to move, so
    # the sign is flipped when board.turn is falsy to keep the result from the
    # network's point of view - confirm against the creversi documentation.
    results_ensemble.append(diff if board.turn else -diff)

In [None]:
# Histogram of the ensemble's final stone differences; the title shows the
# percentage of the n_trial games that ended with a positive difference.
ensemble_outcomes = np.array(results_ensemble)
ensemble_win_rate = np.count_nonzero(ensemble_outcomes > 0) / n_trial * 100
plt.figure(figsize=(5,3))
plt.hist(results_ensemble,bins=30)
plt.title(f'Ensemble, {ensemble_win_rate:.1f}%')
plt.show()