In [1]:
import time
import functools
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F

DEVICE = 'cpu'

In [2]:
def plot(board, handle=None, sleep_sec=0.8):
    """ display board """
    board = board.detach().cpu()
    s_booard = np.zeros(board.shape, dtype='U')
    s_booard[board == +1] = '○'
    s_booard[board == -1] = '×'
    df = pd.DataFrame(s_booard.squeeze())
    if handle is None:
        handle = display(df, display_id=True)
    else:
        handle.update(df)
    time.sleep(sleep_sec)
    return handle


In [3]:
# Filter which inspect three stone conteniusly. shape is (4,1,3,3)
FILTERS = torch.tensor([
    # holizontal
    [[[0,0,0],
      [1,1,1],
      [0,0,0]]],
    # vertival
    [[[0,1,0],
      [0,1,0],
      [0,1,0]]],
    # to top left
    [[[1,0,0],
      [0,1,0],
      [0,0,1]]],
    # to top right
    [[[0,0,1],
      [0,1,0],
      [1,0,0]]],
], dtype=torch.float32, device=DEVICE)


In [11]:
def next_move(board, player, nv, out=None):
    """ judge win/lose of next move nv
    Args
      board : current board. shape is (1,1,3,3)
      player : +1 or -1
      nv :  flat coodinate of next move
      out : board which store the next move. shape is (1,1,3,3)
    Return
      state : winner +1 or -1. if draw,  0
    """
    assert board.numel() == 9
    assert player in (-1, +1)
    assert nv in range(board.numel())
    if board.flatten()[nv] != 0:
        return -player
    out = board.detach().clone() if out is None else out
    out.flatten()[nv] = player # inpout next move
    n_match = F.conv2d(out.view(1,1,3,3), FILTERS, stride=1, padding=1) # count consecutive stones
    mask = n_match.abs() == 3 # coodinate that stones are 3 consecutives
    state = n_match[mask].sign().sum().clamp(-1,1) # winner
    return state.detach()

In [12]:
def auto_play(board, policies, display_handle=None):
    policies = np.broadcast_to(policies,2) # policies of players
    board[:] = 0 # reset board
    if display_handle:
        plot(board, display_handle) # display board
    player = torch.tensor(1).float()
    for turn in range(board.numel()):
        bin = int(player == -1)
        nv = policies[bin](board, player) # select next move
        state = next_move(board, player, nv, out=board) # hit next move
        if display_handle:
            plot(board, display_handle)
        if state != 0: # finish if w/l has been decided
            break
        player = -player # change players
    return state

In [13]:
def policy_random(board, player):
    blanks, = torch.where(board.flatten()== 0)
    nv = torch.randint(len(blanks), (1,))
    return blanks[nv]

In [14]:
def run():
    handle = display(None, display_id=True)
    board = torch.zeros((3,3), dtype=torch.float32, device=DEVICE)
    state = auto_play(board, policy_random, display_handle=handle)
    msg = {-1: '× wins', 0: 'draw', 1: '○ wins'}
    return msg[int(state)]

run()

Unnamed: 0,0,1,2
0,○,,×
1,,○,×
2,○,×,○


'○ wins'

In [15]:
def repeat_play(board, policy, N=100):
    """ repeat games N times
    Return
    counts : [count wins of later, count of draws, count wins of former]
    """
    states = [int(auto_play(board, policy)) for i in range(N)]
    u, c = np.unique(states, return_counts=True)
    assert set(u).issubset({-1, 0, 1})
    counts = np.zeros(3, 'i')
    counts [u+1] = c
    return counts
    

In [16]:
def run(N=100):
    board = torch.zeros((3,3), dtype=torch.float32, device=DEVICE)
    counts = repeat_play(board, policy_random, N=N)
    with np.printoptions(precision=2, floatmode='fixed'):
        print('[later draw former]')
        print(counts / counts.sum(), '=', counts, '/', counts.sum())

run(N=1000)

[later draw former]
[0.29 0.12 0.59] = [293 119 588] / 1000


In [29]:
def estimate(net, board, player):
    """ AI estimate """
    board = player * board
    board = board.view(1,1, *board.size())
    return net(board).flatten()

def policy_ai(board, player, net=None):
    net.eval()
    with torch.no_grad():
        blanks, = torch.where(board.flatten() == 0)
        Q = estimate(net, board, player)
        action = blanks[Q[blanks].argmax()]
    return action

In [30]:
def policy_train(board, player, net=None, optimizer=None,epsilon=0.5,gamma=1., report=None):
    """ train AI """
    ######################
    # estimated number of action values
    ######################
    net.train()
    optimizer.zero_grad()
    Q = estimate(net, board, player) # estimate action value
    blanks, = torch.where(board.flatten() == 0) # get blank cell
    if torch.rand((1,)) < epsilon:
        i = torch.randint(len(blanks), (1,)) # select random move from blanks
        action = blanks[i.item()]
    else:
        action = blanks[Q[blanks].argmax()] # select next move from the maximum value from blanks

    Q_action = Q[action] # value of next action

    ######################
    # target number of action values
    ######################
    net.eval()
    with torch.no_grad():
        post = board.detach().clone()
        state = next_move(board, player, action, out=post) # judege win/lose of next move
        blanks_next, = torch.where(post.flatten() == 0) # get blank cells of next move
        reward, Q_next = 0, 0
        if state == player: # player wins
            reward = 1 
        elif state == -player: # player loses
            reward = -1
        elif len(blanks_next) == 0: # draw
            reward = 0
        else:                       # win/lose is not decided
            Q = estimate(net, post, -player) # estimate action value of next of next move 
            Q_next = Q[blanks_next].max()    # select maximum value of action

        Q_target = reward - gamma * Q_next
        Q_target = torch.as_tensor(Q_target, dtype=torch.float32, device=Q_action.device)

    ######################
    # update paramaters
    ######################
    net.train()
    loss = F.mse_loss(Q_action, Q_target)
    loss.backward()
    optimizer.step()
    if report is not None:
        report['loss'] += loss.item()

    return action
    

In [31]:
def run(n_episode=10000, interval=500, lr=0.01, N=100):
    dim = 128
    cnn = nn.Sequential(
        nn.Conv2d(1, dim, kernel_size=3, padding=0, bias=False),
        nn.Flatten(),
        nn.Linear(dim, dim, bias=True),
        nn.ReLU(True),
        nn.Linear(dim, dim, bias=True),
        nn.ReLU(True),
        nn.Linear(dim, 3*3, bias=True),
        nn.Tanh(),
    ).to(DEVICE)

    # policies while evaluation
    ai_vs_random = (
        functools.partial(policy_ai, net=cnn),
        policy_random,
    )
    random_vs_ai = (
        policy_random,
        functools.partial(policy_ai, net=cnn),
    )

    # policies while learning
    op = torch.optim.SGD(cnn.parameters(), lr=lr)
    rp = {'loss': 0.}
    policy = functools.partial(policy_train, net=cnn, optimizer=op, report=rp)
    
    print('#[later draw former]')
    n_train = np.array([0,0,0])
    board = torch.zeros((3,3), dtype=torch.float32, device=DEVICE)
    start_tm = time.time()
    for i in range(n_episode):
        winner = int(auto_play(board, policy))
        n_train[winner+1] += 1

        if i==0 or (i+1) & interval == 0 or i+1 == n_episode:
            n_1st = repeat_play(board, ai_vs_random, N=N)
            n_2nd = repeat_play(board, random_vs_ai, N=N)
            loss = rp['loss']
            current_tm =  time.time()

            with np.printoptions(formatter={'float': '{:02.0f}'.format}):
                print('[{}/{}] loss:{:.3f} %Train:{} %1st:{} %2nd:{} {:.3f}s'.format(
                    i+1, n_episode, loss,
                    100 * n_train / n_train.sum(),
                    100 * n_1st / n_1st.sum(),
                    100 * n_2nd / n_2nd.sum(),
                    current_tm - start_tm,
                ))

            n_train[:] = 0
            rp['loss'] = 0
            start_tm = current_tm

    return cnn

cnn = run(n_episode=10000, interval=500, lr=0.01, N=100)

#[later draw former]
[1/10000] loss:1.175 %Train:[00 00 100] %1st:[35 11 54] %2nd:[25 02 73] 0.368s
[2/10000] loss:1.187 %Train:[00 00 100] %1st:[49 07 44] %2nd:[25 02 73] 0.393s
[3/10000] loss:1.283 %Train:[00 00 100] %1st:[51 05 44] %2nd:[34 11 55] 0.339s
[8/10000] loss:4.420 %Train:[20 20 60] %1st:[46 05 49] %2nd:[27 05 68] 0.458s
[9/10000] loss:1.073 %Train:[100 00 00] %1st:[36 09 55] %2nd:[34 08 58] 0.434s
[10/10000] loss:1.132 %Train:[00 00 100] %1st:[36 07 57] %2nd:[40 07 53] 0.310s
[11/10000] loss:0.082 %Train:[00 100 00] %1st:[43 06 51] %2nd:[36 09 55] 0.377s
[512/10000] loss:464.851 %Train:[33 02 65] %1st:[03 00 97] %2nd:[68 01 31] 4.238s
[513/10000] loss:0.641 %Train:[00 00 100] %1st:[05 00 95] %2nd:[70 01 29] 0.361s
[514/10000] loss:0.172 %Train:[00 100 00] %1st:[05 00 95] %2nd:[64 02 34] 0.342s
[515/10000] loss:1.124 %Train:[00 00 100] %1st:[08 00 92] %2nd:[77 01 22] 0.256s
[520/10000] loss:3.856 %Train:[20 00 80] %1st:[06 02 92] %2nd:[66 01 33] 0.355s
[521/10000] loss:0.5