In [1]:
import time
import functools
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F

# Device on which this notebook allocates its tensors (passed as `device=` below).
DEVICE = 'cpu'

In [7]:
def plot(board, handle=None, sleep_sec=0.8):
    """Render the board as a table of stone symbols, then pause briefly.

    Args
      board : torch tensor holding +1 where '○' is drawn and -1 where '×' is
              drawn; every other value is shown as an empty cell.
              Size-1 dimensions are squeezed away before display.
      handle : display handle to refresh in place. When None, a new display
               is created via ``display(..., display_id=True)``.
      sleep_sec : seconds to sleep after drawing so successive frames stay visible
    Return
      handle : the handle that now shows the board
    """
    cells = board.detach().cpu()
    # Build the symbol grid from two boolean masks; untouched cells stay ''.
    is_circle = (cells == +1).numpy()
    is_cross = (cells == -1).numpy()
    symbols = np.where(is_circle, '○', np.where(is_cross, '×', ''))
    frame = pd.DataFrame(symbols.squeeze())
    if handle is None:
        handle = display(frame, display_id=True)
    else:
        handle.update(frame)
    time.sleep(sleep_sec)
    return handle


In [8]:
# Convolution kernels that each sum three cells lying on one line through the
# kernel center. Stacked shape is (4,1,3,3): one single-channel 3x3 kernel per
# line direction.
FILTERS = torch.stack([
    # horizontal
    torch.tensor([[0, 0, 0],
                  [1, 1, 1],
                  [0, 0, 0]]),
    # vertical
    torch.tensor([[0, 1, 0],
                  [0, 1, 0],
                  [0, 1, 0]]),
    # diagonal running from top left to bottom right
    torch.tensor([[1, 0, 0],
                  [0, 1, 0],
                  [0, 0, 1]]),
    # diagonal running from top right to bottom left
    torch.tensor([[0, 0, 1],
                  [0, 1, 0],
                  [1, 0, 0]]),
]).unsqueeze(1).to(dtype=torch.float32, device=DEVICE)


In [14]:
def next_move(board, player, nv, out=None):
    """Place ``player``'s stone at flat index ``nv`` and report the resulting state.

    Args
      board : current board with exactly 9 cells (e.g. shape (1,1,3,3) or (3,3));
              0 marks an empty cell
      player : +1 or -1, the side making the move
      nv : flat coordinate (0..8) of the move
      out : tensor that receives the board after the move. When None, a detached
            copy of ``board`` is used and ``board`` itself is left unchanged.
            Pass ``out=board`` to play in place.
    Return
      state : +1 or -1 for the side owning a completed line of three after the
              move, 0 when no line is complete. If cell ``nv`` is already
              occupied, no stone is placed and ``-player`` is returned
              (the opponent wins).
    """
    n_cells = board.numel()
    assert n_cells == 9
    assert player in (-1, +1)
    assert nv in range(n_cells)

    # Playing on an occupied cell forfeits the game to the opponent.
    if board.flatten()[nv] != 0:
        return -player

    if out is None:
        out = board.detach().clone()
    out.flatten()[nv] = player  # record the move

    # Each output element is the sum of three cells along one line direction;
    # with zero padding only a full line of equal stones reaches |sum| == 3.
    line_sums = F.conv2d(out.view(1, 1, 3, 3), FILTERS, stride=1, padding=1)
    complete = line_sums.abs() == 3
    winner = line_sums[complete].sign().sum().clamp(-1, 1)
    return winner.detach()

In [15]:
def auto_play(board, policies, display_handle=None):
    """Play one game on ``board`` from an empty position, alternating players.

    Args
      board : board tensor; it is zeroed first and then modified in place
      policies : one policy used by both players, or a pair
                 (first player, second player). A policy is called as
                 ``policy(board, player)`` and returns the flat index to play.
      display_handle : when truthy, the board is drawn through ``plot`` before
                       the first move and after every move
    Return
      state : result of the last ``next_move`` call; non-zero as soon as the
              game is decided, 0 if every turn passes without a decision
    """
    policies = np.broadcast_to(policies, 2)  # index 0: first player, index 1: second
    board[:] = 0  # start from an empty board
    if display_handle:
        plot(board, display_handle)

    player = torch.tensor(1).float()  # +1 moves first
    for _ in range(board.numel()):
        choose = policies[int(player == -1)]
        state = next_move(board, player, choose(board, player), out=board)
        if display_handle:
            plot(board, display_handle)
        if state != 0:  # decided: stop immediately
            return state
        player = -player  # hand the turn to the other side
    return state

In [16]:
def policy_random(board, player):
    """Pick one empty cell uniformly at random.

    Args
      board : board tensor; cells equal to 0 count as empty
      player : unused; present so the function matches the policy call signature
    Return
      tensor of shape (1,) holding the flat index of the chosen empty cell
    """
    empty_cells = torch.where(board.flatten() == 0)[0]
    pick = torch.randint(empty_cells.numel(), (1,))
    return empty_cells[pick]

In [21]:
def run():
    """Play one random-vs-random game with a live board and return the result text.

    Return
      '○ wins', '× wins' or 'draw', according to the state reported by ``auto_play``
    """
    handle = display(None, display_id=True)  # placeholder output, refreshed during play
    board = torch.zeros((3, 3), dtype=torch.float32, device=DEVICE)
    outcome = int(auto_play(board, policy_random, display_handle=handle))
    labels = {-1: '× wins', 0: 'draw', 1: '○ wins'}
    return labels[outcome]

# Play a single random game; the final board and the result message are shown below.
run()

Unnamed: 0,0,1,2
0,×,×,○
1,×,○,○
2,○,○,×


'○ wins'

In [26]:
def repeat_play(board, policy, N=100):
    """Play N games with ``auto_play`` and tally the outcomes.

    Args
      board : board tensor reused for every game (``auto_play`` resets it)
      policy : policy, or pair of policies, forwarded to ``auto_play``
      N : number of games to play; N=0 yields an all-zero tally
    Return
      counts : integer array (dtype 'i') laid out as
               [wins of the second player (-1), draws (0), wins of the first player (+1)]
    """
    outcomes = np.asarray(
        [int(auto_play(board, policy)) for _ in range(N)], dtype=np.int64)
    assert np.isin(outcomes, (-1, 0, 1)).all()
    # Shift -1/0/+1 to bins 0/1/2. bincount with minlength stays well-defined
    # for an empty sample, whereas indexing with np.unique's result fails
    # there because the unique values of an empty list come back as floats.
    return np.bincount(outcomes + 1, minlength=3).astype('i')
    

In [29]:
def run(N=100):
    """Play N random-vs-random games and print the outcome frequencies.

    Prints a header line, then the fractions, the raw counts and the total,
    in the order [second player wins, draws, first player wins].

    NOTE: this definition reuses the name ``run``; once this cell executes,
    the single-game ``run()`` defined in an earlier cell is shadowed.

    Args
      N : number of games to play
    """
    board = torch.zeros((3, 3), dtype=torch.float32, device=DEVICE)
    counts = repeat_play(board, policy_random, N=N)
    total = counts.sum()
    with np.printoptions(precision=2, floatmode='fixed'):
        print('[later draw former]')
        print(counts / total, '=', counts, '/', total)

# Tally outcome frequencies over 1000 random-vs-random games.
run(N=1000)

[later draw former]
[0.28 0.12 0.59] = [284 121 595] / 1000
