# 概要
- このノートブックでは、DQNを学習させる。
- 現在の局面sを入力すると、各手aを選択した時の期待報酬Q(s,a)が出力される。
- ネットワークの構造は以下のとおり。
    - 入力層：9チャネル
        - 黒石の位置(1)
        - 白石の位置(1)
        - 空白の位置(1)
        - 合法手の位置(1)
        - そこに打った場合、何個石を返せるか(1)
        - 隅の危険領域4マス×4隅をすべて1で埋める(1)
        - すべて1で埋める(1)
        - すべて0で埋める(1)
        - 手番情報：黒番ならすべて0で埋め、白番ならすべて1で埋める(1)
    - 第1層：5x5のn_filters種類のフィルターとReLU関数
    - 第2-11層：3x3のn_filters種類のフィルター、ReLU関数、BatchNorm
    - 第12層：3x3のn_filters種類のフィルター
    - 第13層：1x1のn_filters種類のフィルター(padding=1のため出力は10x10)
    - 第14層：出力256個の全結合ネットワークとReLU関数
    - 第15層：出力64個の全結合ネットワークとtanh関数

In [None]:
!python -m pip install --no-index --find-links=/kaggle/input/reversi-datasets/ creversi

In [None]:
from creversi import *
import creversi

import random
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm import tqdm
from copy import copy
import os

import torch
import torch.nn as nn
from sklearn.model_selection import train_test_split

In [None]:
def board_to_array2(board, return_torch=False):
    """
    Convert a board object into a (9, 8, 8) float32 feature array.

    Channel layout:
        0: black stones          1: white stones       2: empty squares
        3: legal moves           4: number of stones each legal move flips
        5: 2x2 zone around each of the four corners
        6: all ones              7: all zeros
        8: side to move (zeros when ``board.turn`` is truthy, ones otherwise)

    Returns a torch tensor instead of an ndarray when ``return_torch`` is set.
    """
    planes = np.zeros((9, 8, 8), dtype=np.float32)
    board.piece_planes(planes)
    if not board.turn:
        # Swap the two stone planes and mark the side-to-move plane.
        # The fancy-indexed right-hand side is a copy, so this swap is safe.
        planes[[0, 1]] = planes[[1, 0]]
        planes[8] = 1

    # A square is empty unless exactly one of the two stone planes is set.
    planes[2] = (planes[0] + planes[1]) != 1

    candidates = list(board.legal_moves)
    if candidates != [64]:  # [64] is how a forced pass is reported
        flipped = []
        for move in candidates:
            trial = copy(board)
            before = trial.opponent_piece_num()
            trial.move(move)
            # After the move the side to move has switched, so piece_num()
            # now counts the stones of the side that was just played against.
            flipped.append(before - trial.piece_num())
        flip_counts = np.zeros(64)
        flip_counts[candidates] = flipped
        flip_counts = flip_counts.reshape(8, 8)
        planes[3] = flip_counts > 0
        planes[4] = flip_counts

    # 2x2 blocks at the four corners of the board.
    border = [0, 1, 6, 7]
    corner_zone = np.zeros((8, 8))
    corner_zone[np.ix_(border, border)] = 1
    planes[5] = corner_zone

    planes[6] = 1
    return torch.from_numpy(planes) if return_torch else planes

In [None]:
def board_to_array_aug2(board, return_torch=False):
    """
    Return the 8 rotated/mirrored variants of the board's feature array.

    The variants are stacked along a new leading axis in the order
    [rot0, flip(rot0), rot90, flip(rot90), rot180, flip(rot180), rot270,
    flip(rot270)], where rotations act on the two spatial axes and ``flip``
    mirrors the last axis. Returns a torch tensor when ``return_torch`` is set.
    """
    base = board_to_array2(board)
    variants = []
    for quarter_turns in range(4):
        # k=0 leaves the array unchanged, so the first pair is the identity
        # view and its mirror image.
        rotated = np.rot90(base, k=quarter_turns, axes=(1, 2)).copy()
        variants.append(rotated)
        variants.append(np.flip(rotated, axis=2).copy())
    stacked = np.array(variants)
    return torch.from_numpy(stacked) if return_torch else stacked

In [None]:
class ValueNetwork(nn.Module):
    """Convolutional network mapping a 9-channel 8x8 board to one tanh value."""

    def __init__(self):
        super().__init__()
        n_filters = 10
        n_blocks = 10  # number of Conv3x3 -> ReLU -> BatchNorm repetitions

        self.input_layer = nn.Sequential(
            nn.Conv2d(9, n_filters, kernel_size=5, padding=2),
            nn.ReLU(),
        )

        # The modules are created in exactly this order so that the indices
        # inside the Sequential (and therefore the state_dict keys) stay put.
        trunk = []
        for _ in range(n_blocks):
            trunk.append(nn.Conv2d(n_filters, n_filters, kernel_size=3, padding=1))
            trunk.append(nn.ReLU())
            trunk.append(nn.BatchNorm2d(n_filters))
        trunk.append(nn.Conv2d(n_filters, n_filters, kernel_size=3, padding=1))
        # NOTE: padding=1 on a 1x1 kernel grows the 8x8 map to 10x10, which is
        # why the first Linear layer below expects n_filters * 100 inputs.
        trunk.append(nn.Conv2d(n_filters, n_filters, kernel_size=1, padding=1))
        trunk.append(nn.Flatten())
        self.hidden_layer = nn.Sequential(*trunk)

        self.output_layer = nn.Sequential(
            nn.Linear(n_filters * 100, 256),
            nn.ReLU(),
            nn.Linear(256, 1),
        )

    def forward(self, x):
        """Run the network on a (batch, 9, 8, 8) input; output is tanh-squashed."""
        features = self.hidden_layer(self.input_layer(x))
        return self.output_layer(features).tanh()

In [None]:
class DQN(nn.Module):
    """Convolutional Q-network: 9-channel 8x8 board in, 64 tanh outputs out."""

    def __init__(self):
        super().__init__()
        n_filters = 10

        def conv3x3():
            return nn.Conv2d(n_filters, n_filters, kernel_size=3, padding=1)

        self.input_layer = nn.Sequential(
            nn.Conv2d(9, n_filters, kernel_size=5, padding=2),
            nn.ReLU(),
        )

        # Ten Conv3x3 -> ReLU -> BatchNorm groups, created in order so the
        # positions inside the Sequential match the original layout.
        body = []
        for _ in range(10):
            body += [conv3x3(), nn.ReLU(), nn.BatchNorm2d(n_filters)]
        body += [
            conv3x3(),
            # NOTE: a 1x1 kernel with padding=1 yields a 10x10 map, hence the
            # n_filters * 100 input size of the first Linear layer.
            nn.Conv2d(n_filters, n_filters, kernel_size=1, padding=1),
            nn.Flatten(),
        ]
        self.hidden_layer = nn.Sequential(*body)

        self.output_layer = nn.Sequential(
            nn.Linear(n_filters * 100, 256),
            nn.ReLU(),
            nn.Linear(256, 64),
        )

    def forward(self, x):
        """Return one tanh-squashed value per board square for each input."""
        hidden = self.input_layer(x)
        hidden = self.hidden_layer(hidden)
        return torch.tanh(self.output_layer(hidden))

In [None]:
def seed_everything(seed):
    """Seed every random number generator this notebook relies on.

    Seeds Python's ``random``, NumPy and PyTorch (CPU and all CUDA devices),
    exports ``PYTHONHASHSEED``, and puts cuDNN into deterministic mode.

    Args:
        seed: integer seed shared by all generators.
    """
    random.seed(seed)
    # Only affects hashing in processes started after this point.
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Cover every visible GPU, not only the current one; a no-op without CUDA.
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # The cuDNN auto-tuner may pick different kernels from run to run, so it
    # has to be off for repeatable results.
    torch.backends.cudnn.benchmark = False

# ValueNetworkの重みを初期値に設定する

In [None]:
# Fix the seeds before any module is built so the random initialisation of
# both networks is repeatable.
seed_everything(1234)
model_v = ValueNetwork()
# map_location='cpu' lets the checkpoint load on machines without a GPU even
# if it was saved from CUDA tensors.
model_v.load_state_dict(
    torch.load('/kaggle/input/reversi-datasets/value-network-v2.pth', map_location='cpu')
)
model = DQN()

In [None]:
# Initialise the DQN from the pretrained ValueNetwork.
#
# load_state_dict copies the values (the two models do not end up sharing
# Parameter objects) and also transfers the BatchNorm running statistics,
# which a plain weight/bias assignment would leave at their defaults.
model.input_layer.load_state_dict(model_v.input_layer.state_dict())
model.hidden_layer.load_state_dict(model_v.hidden_layer.state_dict())

# Only the first fully connected layer has the same shape in both networks;
# the final layer (1 output vs 64 outputs) keeps its random initialisation.
model.output_layer[0].load_state_dict(model_v.output_layer[0].state_dict())

# 学習データの作成

In [None]:
# For now, generate the data from purely random play
N = 30000

states, actions, rewards, turns = [], [], [], []
pos_idx, neg_idx = [], []
i = -1  # flat index of the last position recorded so far

for n in tqdm(range(N)):
    board = Board()
    S, A, T = [], [], []
    while not board.is_game_over():
        legal_moves = list(board.legal_moves)
        move = random.choice(legal_moves)
        # Snapshot the position before the move is played.
        b = np.empty(1, creversi.dtypeBitboard)
        board.to_bitboard(b)
        S.append(b[0])
        A.append(move)
        T.append(board.turn)
        board.move(move)
    # Every step starts with reward 0; the last step of each game is
    # overwritten further down.
    R = [0] * len(S)
    i += len(S)

    diff = board.diff_num()
    result = diff if board.turn else -diff

    # A drawn game is the positive outcome here; anything else is negative.
    (pos_idx if result == 0 else neg_idx).append(i)

    states.append(S)
    actions.append(A)
    rewards.append(R)
    turns.append(T)

# Flatten the per-game lists into one array per field.
states = np.concatenate(states, axis=0)
actions = np.concatenate(actions, axis=0)
rewards = np.concatenate(rewards, axis=0).astype(float)
turns = np.concatenate(turns, axis=0)

rewards[pos_idx] = 1
rewards[neg_idx] = -1

In [None]:
# Hold out 0.4% of all recorded positions for validation.
N_all = len(states)
N_valid = int(N_all*0.004)
N_train = N_all - N_valid
print(f'N_train : {N_train},  N_test : {N_valid}')

# Fixed seed so the split is the same on every run.
np.random.seed(1234)
split_idx = np.random.permutation(N_all)
train_idx, valid_idx = split_idx[:N_train], split_idx[N_train:]

In [None]:
# N = 10000
# count = [0,0,0] # win,lose,draw

# len_list = []

# data = []
# for n in tqdm(range(N)):
#     episode = []
#     board = Board()
#     while True:
#         step = {"board":np.empty(1,creversi.dtypeBitboard),"turn":None,"move":None,"reward":0,"done":None,"next_pass":False}
#         board.to_bitboard(step['board'])
#         step["turn"] = board.turn
#         legal_moves = list(board.legal_moves)
#         move = random.choice(legal_moves)
#         board.move(move)
#         if board.puttable_num() == 0:
#             step["next_pass"] = True
#         step["move"] = move
#         if board.is_game_over():
#             if not board.turn:
#                 board.move_pass()
#             diff = board.diff_num()
#             step["done"] = True
#             step["reward"] = 1 if diff==0 else -1
#             episode.append(step)
#             break
#         else:
#             step["done"] = False
#             episode.append(step)
    
#     if diff > 0:
#         count[0] += 1
#         data += episode
#     elif diff < 0:
#         count[1] += 1
#         data += episode
#     else:
#         count[2] += 1
#         data += episode

# print(f"{len(data)} boards")

# 学習

In [None]:
# Use the GPU when one is visible to PyTorch, otherwise stay on the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
print(device)

In [None]:
# Training hyper-parameters
n_epoch, n_batch = 1, 256
lr, gamma = 0.001, 0.99

# Move the model first so the optimiser is built on the device-resident
# parameters.
model.to(device)
criterion = nn.HuberLoss()
optim = torch.optim.Adam(params=model.parameters(), lr=lr)

In [None]:
# loss_sum = 0
# log_interval = 2500
# loss_list = []
# board = Board()

# for epoch in range(n_epoch):
#     random_idx = np.random.permutation(len(data))
#     for it in tqdm(range(len(data)//n_batch)):
#         trajectory = [data[k] for k in random_idx[it*n_batch:(it+1)*n_batch]]
#         states = np.empty((n_batch,9,8,8),np.float32)
#         moves = []
#         not_done_next_states = np.empty((n_batch,9,8,8),np.float32)
#         not_done_flag = []
#         not_done_next_actions = []

#         for i, record in enumerate(trajectory):
#             board.set_bitboard(record["board"], record["turn"])
#             states[i] = board_to_array2(board)
#             move = record["move"]
#             moves.append(move)

#             if not record["done"]:
#                 not_done_flag.append(True)
#                 board.move(move)
#                 not_done_next_states[len(not_done_next_actions)] = board_to_array2(board)
#                 legal_moves = list(board.legal_moves)
#                 not_done_next_actions.append(legal_moves + [legal_moves[0]] * (30 - len(legal_moves)))
#                 trajectory[i]["reward"] = 0
#             else:
#                 not_done_flag.append(False)

#         not_done_flag = torch.tensor(not_done_flag).to(torch.bool).to(device)
#         not_done_next_states = torch.from_numpy(not_done_next_states[:len(not_done_next_actions)]).to(device)

#         states = torch.from_numpy(states).to(device)
#         actions = torch.tensor(moves).view(-1,1).to(torch.long).to(device)
#         rewards = torch.tensor([r["reward"] for r in trajectory]).to(torch.float32).to(device)

#         next_state_values = torch.zeros(n_batch).to(device)
#         not_done_next_actions = torch.tensor(not_done_next_actions).to(torch.long).to(device)
#         next_q = model(not_done_next_states)
#         next_state_values[not_done_flag] = next_q.gather(1, not_done_next_actions).max(1)[0].detach()

#         expected_state_action_values = (next_state_values * gamma) + rewards

#         model.eval()
#         state_action_values = model(states).gather(1, actions)

#         loss = criterion(state_action_values, expected_state_action_values.unsqueeze(1))

#         # Optimize the model
#         optimizer.zero_grad()
#         loss.backward()
#         optimizer.step()

#         loss_sum += loss.item()
#         loss_list.append(loss.item())
#         if (it + 1) % log_interval == 0:
#             print(f"[{it+1}] loss = {loss_sum / log_interval}")
#             vs_random(model, 100)
#             torch.save(model.state_dict(), f"value-net_{it+1}.pth")
#             loss_sum = 0

In [None]:
# states[0] <---> ( rewards[0] + max Q(sからa1を選んだ時の盤面, a'), ...... ,rewards[0] + max Q(sからh8を選んだ時の盤面, a') )
# states[0] <---> ( rewards[0] + model(sからa1を選んだ時の盤面).max(), ...... , rewards[0] + model(sからh8を選んだ時の盤面) )
# ただし「sからxxを選んだ時の盤面」が存在しない(＝xxが合法手でない)場合は、model(...)=0とする

In [None]:
# Training
train_loss_list = []
valid_loss_list = []
board = Board()


def _make_batch(sample_idx):
    """Build network inputs and one-step Bellman targets for the given samples.

    For every sampled position s and every legal move a in s, the target is
    ``reward + gamma * max_a' Q(s_a, a')`` where s_a is the position reached
    by playing a. Entries for illegal moves, and for moves after which the
    game is over (nothing left to bootstrap from), receive the reward alone.

    The recorded move is deliberately not replayed here: the target vector is
    indexed by the moves available in s itself, and every successor is built
    from s with the side to move that was stored alongside it.

    Args:
        sample_idx: integer indices into ``states`` / ``rewards`` / ``turns``.

    Returns:
        (X, y): float32 tensors on ``device`` with shapes (n, 9, 8, 8) and
        (n, 64), where n = len(sample_idx).
    """
    n_samples = len(sample_idx)
    S_arr = np.zeros((n_samples, 9, 8, 8), np.float32)
    next_q = np.zeros((n_samples, 64), np.float32)
    rows, cols, next_arrays = [], [], []

    for j, (bitboard, turn) in enumerate(zip(states[sample_idx], turns[sample_idx])):
        board.set_bitboard(bitboard, turn)
        S_arr[j] = board_to_array2(board)
        # Materialise the move list first: the board is mutated below.
        # Index 64 denotes a pass and has no output unit, so it is skipped.
        for move_next in [m for m in board.legal_moves if m < 64]:
            board.set_bitboard(bitboard, turn)
            board.move(move_next)
            if board.is_game_over():
                continue
            rows.append(j)
            cols.append(move_next)
            next_arrays.append(board_to_array2(board))

    if next_arrays:
        # One forward pass for all successor positions of the whole batch.
        model.eval()
        with torch.no_grad():
            q_all = model(torch.from_numpy(np.stack(next_arrays)).to(device))
        next_q[rows, cols] = q_all.max(dim=1).values.cpu().numpy()

    target = rewards[sample_idx].reshape(-1, 1) + gamma * next_q
    X = torch.from_numpy(S_arr).to(device)
    y = torch.from_numpy(target).to(device).to(torch.float)
    return X, y


for epoch in range(n_epoch):
    # Reseed per epoch so the shuffling order is reproducible.
    np.random.seed(epoch)
    random_idx = np.random.permutation(N_train)
    n_train_batches = N_train // n_batch
    train_loss = 0.
    for i in tqdm(range(n_train_batches)):
        # Go through train_idx so the random train/valid split is honoured.
        batch_idx = train_idx[random_idx[n_batch*i:n_batch*(i+1)]]
        X_batch, y_batch = _make_batch(batch_idx)

        # _make_batch leaves the model in eval mode; switch back for the update.
        model.train()
        optim.zero_grad()
        output = model(X_batch)
        loss = criterion(output, y_batch)
        loss.backward()
        optim.step()
        train_loss += loss.item()
    train_loss = train_loss / n_train_batches if n_train_batches else float('nan')

    # Validation
    n_valid_batches = N_valid // n_batch
    valid_loss = 0.
    for i in range(n_valid_batches):
        X_batch, y_batch = _make_batch(valid_idx[n_batch*i:n_batch*(i+1)])
        model.eval()
        with torch.no_grad():
            valid_loss += criterion(model(X_batch), y_batch).item()
    valid_loss = valid_loss / n_valid_batches if n_valid_batches else float('nan')

    train_loss_list.append(train_loss)
    valid_loss_list.append(valid_loss)
    print(f'Epoch:{epoch+1}/{n_epoch}, train loss:{train_loss:.5f}, valid loss:{valid_loss:.5f}')
    # Save from the CPU so the checkpoint loads on any machine, then move back.
    torch.save(model.cpu().state_dict(), f'DQN-v1-checkpoint-{epoch+1}.pth')
    model.to(device)

In [None]:
# Show the predicted Q-values for the initial position, averaged over the
# eight board symmetries.
board = Board()
model.eval()
with torch.no_grad():
    q_aug = model(board_to_array_aug2(board, True).to(device)).cpu().numpy().reshape(-1, 8, 8)

# board_to_array_aug2 emits, for k = 0..3, [rot90(k), mirrored rot90(k)].
# Each prediction lives in the orientation of its own input, so it has to be
# mapped back to the original orientation before the predictions are averaged.
aligned = []
for variant, q in enumerate(q_aug):
    quarter_turns, mirrored = divmod(variant, 2)
    if mirrored:
        q = np.flip(q, axis=1)
    aligned.append(np.rot90(q, k=-quarter_turns))
q_pred = np.mean(aligned, axis=0)

sns.heatmap(q_pred*100, cmap='gray_r', annot=True, fmt='.2f', cbar=False)