## ライブラリ

In [None]:
!pip install creversi

In [None]:
from creversi import *

import numpy as np
import matplotlib.pyplot as plt
import matplotlib.patches as pat
import seaborn as sns
from copy import copy

import torch
import torch.nn as nn
from torch.autograd import Function

## クラス・関数

In [None]:
class PolicyNetwork(nn.Module):
    """Fully convolutional policy network for 8x8 Reversi boards.

    Architecture:
        input_layer : 5x5 conv (in_channels -> n_filters) + ReLU
        hidden_layer: ``n_hidden`` blocks of 3x3 conv -> BatchNorm2d -> ReLU,
                      stored flat (conv at indices 0, 3, 6, ...)
        output_layer: 1x1 conv (n_filters -> 1) flattened to one score per square

    Args:
        n_filters: channel width of the input/hidden convolutions (default 80).
        n_hidden: number of hidden conv-BN-ReLU blocks (default 11).
        in_channels: number of input feature planes (default 8, the number of
            planes produced by ``board_to_array``).
    """

    def __init__(self, n_filters=80, n_hidden=11, in_channels=8):
        super().__init__()
        self.input_layer = nn.Sequential(
            nn.Conv2d(in_channels, n_filters, kernel_size=5, padding=2),
            nn.ReLU(),
        )
        hidden = []
        for _ in range(n_hidden):
            hidden += [
                nn.Conv2d(n_filters, n_filters, kernel_size=3, padding=1),
                nn.BatchNorm2d(n_filters),
                nn.ReLU(),
            ]
        # One flat Sequential (not nested blocks) keeps the module indices and
        # parameter names (hidden_layer.0, hidden_layer.1, ...) of the
        # hand-written version, so existing checkpoints / state_dicts still match.
        self.hidden_layer = nn.Sequential(*hidden)
        self.output_layer = nn.Sequential(
            nn.Conv2d(n_filters, 1, kernel_size=1),
            nn.Flatten(),
        )

    def forward(self, x):
        """Map a (N, in_channels, H, W) batch to (N, H*W) raw move scores.

        BatchNorm2d requires the batch dimension, so unbatched input is rejected.
        The scores are not normalised; apply softmax if probabilities are needed.
        """
        out = self.input_layer(x)
        out = self.hidden_layer(out)
        out = self.output_layer(out)
        return out

In [None]:
def board_to_array(board, return_torch=False):
    """
    Encode a creversi board as an (8, 8, 8) float32 feature stack for PolicyNetwork.

    Planes (as described by the original author):
        0: black stones
        1: white stones
        2: empty squares (1 wherever planes 0 and 1 do not sum to exactly 1)
        3: legal moves (squares whose move flips at least one stone)
        4: number of stones each legal move flips
        5: corner regions (the 2x2 block in each of the four corners = 1)
        6: filled with ones
        7: filled with zeros

    Args:
        board: creversi ``Board``; moves are simulated on shallow copies.
        return_torch: if True, return a ``torch.Tensor`` sharing the memory
            of the array; otherwise return the ``np.ndarray``.
    """
    features = np.zeros((8, 8, 8), dtype=np.float32)
    board.piece_planes(features)
    if not board.turn:
        # Swap planes 0 and 1 when it is white's turn so the plane order
        # matches the docstring (0 = black, 1 = white).
        features = features[[1, 0, 2, 3, 4, 5, 6, 7], :, :]

    features[2] = features[0] + features[1] != 1

    moves = list(board.legal_moves)
    # A list containing only 64 is treated as "no playable square" (pass).
    if moves != [64]:
        def count_flips(move):
            # After the move the side to move switches, so the new side's
            # piece count equals the former opponent's remaining stones.
            trial = copy(board)
            before = trial.opponent_piece_num()
            trial.move(move)
            return before - trial.piece_num()

        flips = np.zeros(64)
        flips[moves] = [count_flips(m) for m in moves]
        flips = flips.reshape(8, 8)
        features[3] = flips > 0
        features[4] = flips

    corner_idx = [0, 1, 6, 7]
    corner_plane = np.zeros((8, 8))
    corner_plane[np.ix_(corner_idx, corner_idx)] = 1.
    features[5] = corner_plane
    features[6] = 1

    return torch.from_numpy(features) if return_torch else features

In [None]:
class GradCAM:
    """Compute a Grad-CAM heatmap for one layer of a model.

    Hooks are attached to ``target_layer`` only for the duration of each call
    and removed afterwards. Creating many ``GradCAM`` objects, or re-running a
    notebook cell, therefore never leaves stale hooks on a shared model.

    Attributes:
        model: network to explain.
        target_layer: sub-module whose output activations (and the gradient
            of the selected score w.r.t. that output) are used.
        gradient: gradient captured during the most recent call (None before).
        activations: target-layer output from the most recent call.
    """

    def __init__(self, model, target_layer):
        self.model = model
        self.target_layer = target_layer
        self.gradient = None
        self.activations = None

    def _save_gradient(self, grad):
        """Tensor hook: store d(score)/d(target-layer output)."""
        self.gradient = grad

    def _forward_hook(self, module, inputs, output):
        """Store the layer output and arrange to capture its gradient."""
        self.activations = output
        # A tensor hook on the output gives exactly the gradient w.r.t. this
        # layer's output. It avoids the deprecated Module.register_backward_hook.
        if output.requires_grad:
            output.register_hook(self._save_gradient)

    def __call__(self, input_image, target_class=None):
        """Return the Grad-CAM map for ``target_class`` of the first sample.

        Args:
            input_image: batched model input; only sample 0 is explained.
            target_class: index into ``output[0]`` to explain; defaults to its
                argmax.

        Returns:
            np.ndarray (H, W): ReLU of the channel-weighted sum of the target
            layer's activations. Each channel weight is the spatial mean of
            that channel's gradient.

        Raises:
            RuntimeError: if no activations/gradient were captured, i.e. the
                target layer was not used or its output did not require grad.
        """
        self.gradient = None
        self.activations = None
        handle = self.target_layer.register_forward_hook(self._forward_hook)
        try:
            # Grad-CAM needs autograd even if the caller is inside no_grad().
            with torch.enable_grad():
                self.model.zero_grad()
                output = self.model(input_image)
                if target_class is None:
                    # argmax of the first sample only, so the index is a
                    # valid column even for batch sizes > 1.
                    target_class = torch.argmax(output[0])
                output[0, target_class].backward()
        finally:
            handle.remove()

        if self.gradient is None or self.activations is None:
            raise RuntimeError(
                'GradCAM could not capture activations/gradients of target_layer; '
                'make sure it is used in the forward pass and its output requires grad.'
            )

        gradients = self.gradient[0].detach().cpu().numpy()
        activations = self.activations[0].detach().cpu().numpy()

        # Channel importance = spatially averaged gradient.
        weights = np.mean(gradients, axis=(1, 2))
        cam = np.sum(weights[:, np.newaxis, np.newaxis] * activations, axis=0)
        cam = np.maximum(cam, 0)
        return cam

## モデル読み込み

In [None]:
# The .pth file stores a whole pickled nn.Module (its attributes are used
# directly below), not a state_dict. weights_only=False is therefore required
# on PyTorch >= 2.6, where the default became True; the argument exists since 1.13.
# NOTE: this unpickles arbitrary objects - only load checkpoints you trust.
# map_location='cpu': later cells feed CPU tensors and call .numpy() on outputs,
# which would fail if the checkpoint were restored onto a GPU.
model = torch.load('/kaggle/input/reversi-datasets/policy-network-v1.pth',
                   map_location='cpu', weights_only=False)
model.eval()

## 盤面の設定

- `line`：黒石は`X`，白石は`O`，空白は`-`として、左上のマスから右方向へ、上の行から順に並べた64文字の文字列で記述。
- `turn`：次が黒番なら`True`，白番なら`False`
- 下のセルを実行すると、以下の情報が出力されます：
    - モデルが予測した最善手
    - 設定した盤面
    - モデルの予測値の詳細(確率の分布)

In [None]:
board = Board()

#### Board setup ####
# line: 64 characters, 'X' = black, 'O' = white, '-' = empty,
#       listed from the top-left square, left to right, top row first.
# turn: True if black moves next, False if white moves next.
line = 'OXX-OOO-OOOOOO--OOOXOOOOOOOOXOOO-OOXOOOO-OOXOOXO--OOOXOX--OOOOOO'
turn = True
#####################
if len(line) != 64:
    raise ValueError('length of "line" should be 64.')
board.set_line(line, turn)

# PolicyNetwork as defined above contains BatchNorm2d, which requires 4-D
# (N, C, H, W) input, so add the batch dimension (as the Grad-CAM cell does).
with torch.no_grad():
    output = model(board_to_array(board, True).unsqueeze(0))
out = output.argmax().item()
print(f'predicted : {move_to_str(out)}')
display(board)

# Predicted move probabilities (%) laid out on the 8x8 board.
plt.figure(figsize=(3, 3))
sns.heatmap(output.softmax(1).detach().numpy().reshape(8, 8) * 100,
            annot=True, fmt='.0f', cmap='gray_r', cbar=False)
plt.show()

## Grad-CAMで各畳み込み層が注目している箇所を可視化
- 出力層の1×1畳み込みを除く、モデルの12層の畳み込み層(入力層1層＋中間層11層)をすべて可視化しています。
- 番号が小さいほど入力に近い層で、番号が大きいほど出力に近い層です。
- 最後の`Total`は全12層のGrad-CAMの出力の和を取ったものを可視化しています。
- Warningが出るかもしれませんが無視してください。

In [None]:
def _draw_cam_on_board(ax, board, move, cam, title, **title_kwargs):
    """Draw stones, a red marker on ``move`` and ``cam`` as a heatmap on ``ax``.

    Square index j is drawn at column j % 8 and row j // 8 from the top. The
    y-axis points upward, so rows are placed at 7 - row and ``cam`` is flipped
    vertically.
    """
    ax.set_xticks(range(9))
    ax.set_yticks(range(9))
    ax.tick_params(labelbottom=False, labelleft=False)
    ax.set_xlim(0, 8)
    ax.set_ylim(0, 8)
    ax.grid()
    ax.set_title(title, **title_kwargs)
    for j, p in enumerate(board.to_line()):
        center = (1/2 + j % 8, 1/2 + (7 - j // 8))
        if p == 'X':
            ax.add_patch(pat.Circle(xy=center, radius=3/7, color='black'))
        elif p == 'O':
            ax.add_patch(pat.Circle(xy=center, radius=3/7, fc='white', ec='black'))
    # Marker for the predicted move, drawn once per axes.
    ax.add_patch(pat.Rectangle(xy=(1/4 + move % 8, 1/4 + (7 - move // 8)),
                               width=1/2, height=1/2, ec='red', fill=False))
    ax.contourf(np.arange(0.5, 8.5), np.arange(0.5, 8.5), cam[::-1],
                cmap='jet', levels=100, alpha=0.5, antialiased=True)


# Grad-CAM targets: the input convolution plus every convolution in
# hidden_layer (the final 1x1 output conv is excluded). Selecting by type
# guarantees BatchNorm2d / ReLU modules are never picked, whatever the
# layer ordering inside hidden_layer.
target_layers = [model.input_layer[0]] + [
    m for m in model.hidden_layer if isinstance(m, nn.Conv2d)
]

input_image = board_to_array(board, True).unsqueeze(0)
n_cols = 4
n_rows = -(-len(target_layers) // n_cols)  # ceil division
fig = plt.figure(figsize=(12, 3 * n_rows))
cam_total = np.zeros((8, 8))
for i, target_layer in enumerate(target_layers):
    cam = GradCAM(model, target_layer)(input_image)
    cam_total += cam
    ax = fig.add_subplot(n_rows, n_cols, i + 1)
    _draw_cam_on_board(ax, board, out, cam, f'layer{i}', fontsize=10)
plt.show()

# Sum of the Grad-CAM maps of all target layers.
fig = plt.figure(figsize=(3, 3))
ax = fig.add_subplot(111)
_draw_cam_on_board(ax, board, out, cam_total, 'Total')
plt.show()