In [1]:
import time
import functools
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F

# Torch device on which the notebook's constant tensors are allocated.
DEVICE = "cpu"

In [7]:
def plot(board, handle=None, sleep_sec=0.8):
    """Display a board as a table of '○' / '×' marks.

    Args
      board : tensor of stones. +1 is drawn as '○', -1 as '×', anything
              else as an empty cell. Size-1 dims are squeezed away, so a
              (1,1,3,3) board is shown as a 3x3 table.
      handle : handle returned by a previous call. When given, that output
               is updated in place (``handle.update``) instead of creating
               a new display.
      sleep_sec : seconds to pause after drawing, so a sequence of moves
                  can be followed visually.
    Return
      the display handle, to be passed back in on the next call.
    """
    # Work on a real ndarray: masking a numpy array with a torch tensor
    # depends on numpy's implicit conversion of foreign index objects.
    stones = board.detach().cpu().numpy()
    marks = np.zeros(stones.shape, dtype='<U1')  # '' for empty squares
    marks[stones == +1] = '○'
    marks[stones == -1] = '×'
    df = pd.DataFrame(marks.squeeze())
    if handle is None:
        # `display` is the IPython/Jupyter builtin; display_id=True makes it
        # return a handle whose output can be updated later.
        handle = display(df, display_id=True)
    else:
        handle.update(df)
    time.sleep(sleep_sec)
    return handle


In [8]:
# Line-detection kernels for conv2d, one output channel per direction
# (horizontal, vertical, main diagonal, anti-diagonal). Each kernel marks
# the three cells of a line through its centre. Written as flat 3x3 grids
# in row-major order, then viewed as shape (4, 1, 3, 3).
FILTERS = torch.tensor(
    [
        [0, 0, 0,  1, 1, 1,  0, 0, 0],  # horizontal: middle row
        [0, 1, 0,  0, 1, 0,  0, 1, 0],  # vertical: middle column
        [1, 0, 0,  0, 1, 0,  0, 0, 1],  # diagonal: top-left -> bottom-right
        [0, 0, 1,  0, 1, 0,  1, 0, 0],  # anti-diagonal: top-right -> bottom-left
    ],
    dtype=torch.float32,
    device=DEVICE,
).view(4, 1, 3, 3)


In [10]:
def next_move(board, player, nv, out=None):
    """Place ``player``'s stone at flat index ``nv`` and judge the result.

    Args
      board : current board with 9 elements (handled as shape (1,1,3,3)).
              0 = empty square, +1 / -1 = stones. It is not modified unless
              it is also passed as ``out``.
      player : +1 or -1, the side making this move.
      nv : flat (row-major) coordinate of the next move, in 0..8.
      out : optional contiguous 9-element tensor that receives the board
            after the move. Defaults to a fresh detached copy of ``board``.
    Return
      state : +1 or -1 if that player has three stones in a line after the
              move, otherwise 0 (no line yet, which includes a draw), as a
              0-dim tensor. If square ``nv`` is already occupied the move is
              illegal: the plain int ``-player`` is returned (the mover
              loses) and ``out`` is left untouched.
    """
    assert board.numel() == 9
    assert player in (-1, +1)
    assert nv in range(board.numel())
    if board.flatten()[nv] != 0:
        return -player
    # Force a contiguous copy: a plain clone() keeps the strides of a
    # non-contiguous board, and the in-place write below needs a real view.
    out = board.detach().clone(memory_format=torch.contiguous_format) if out is None else out
    # view(-1) rather than flatten(): flatten() silently returns a copy for a
    # non-contiguous tensor (dropping the move); view() raises instead.
    out.view(-1)[nv] = player  # input next move
    # conv2d requires the input dtype to match the weights; this also lets
    # integer or float64 boards be judged.
    grid = out.view(1, 1, 3, 3).to(FILTERS.dtype)
    # With padding=1, each output cell sums the squares covered by one
    # 3-long line centred on it (padding counts as 0), so |sum| == 3 only
    # where a complete line of one colour exists.
    n_match = F.conv2d(grid, FILTERS, stride=1, padding=1)
    mask = n_match.abs() == 3
    state = n_match[mask].sign().sum().clamp(-1, 1)  # winner, or 0
    return state.detach()