In [2]:
# environment:
# pip3 install torch
# 解决五子棋计算五联子碰到索引出界，需要pad矩阵
import numpy as np
a = np.array([[1, 2], [3, 4]])
np.pad(a,(4,4))
# print(a)

array([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 1, 2, 0, 0, 0, 0],
       [0, 0, 0, 0, 3, 4, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0, 0, 0, 0, 0]])

In [1]:
# 用五字棋尝试一下
# You can change this to another two-player game.
# 给状态张量增加一个channel表示当前行棋方

# TODO 1.实现 train Nets on GPU. Done on 2023.1.7
# TODO 2.implement a memoryband storing trails for training network. Done on 2023.1.11
# TODO 3.实现训练sudoindex（比如轨迹长度的分布等等）收集至wandb


from typing import Union, List
import numpy as np

BLACK, WHITE = 1, -1  # 颜色的先后手,黑圈O先,白叉X后

class State:
    '''实现 9 x 9 wuziqi 的棋盘'''
    X, Y = 'ABCDEFGHI',  '123456789'
    C = {0: '_', BLACK: 'O', WHITE: 'X'}

    def __init__(self):
        self.board = np.zeros((9, 9)) # (x, y)
        self.color = 1
        self.win_color = 0
        self.record = []

    def action2str(self, a:int):
        """用0-80编码落子位
            '1' '2' '3' ... '9'
        'A'  0   1   2  ...  8
        'B'  9   10  11 ...  17
        'C'  18  19  20 ...  26
        ...             ...
        'I'  72  73  74 ...  80
        """
        return self.X[a // 9] + self.Y[a % 9]

    def str2action(self, s:str):
        return self.X.find(s[0]) * 9 + self.Y.find(s[1])

    def record_string(self):
        """记录动作的列表组装成字符串序列，用空格分隔

        Returns
        -------
            一条动作的（轨迹）字符串，空格分隔方便split
        """
        return ' '.join([self.action2str(a) for a in self.record])

    def __str__(self):
        # 打印棋盘
        s = '   ' + ' '.join(self.Y) + '\n'
        for i in range(9):
            s += self.X[i] + ' ' + ' '.join([self.C[self.board[i, j]] for j in range(9)]) + '\n'
        s += 'record = ' + self.record_string()
        return s

    def check_win(self, x:int, y:int):
        # check whether 5 stones are on the line , pad zero around board by (4,4) then compute if sum to 5*self.color
        x_tmp, y_tmp = x + 4, y + 4
        boardex4 = np.pad(self.board,(4,4))
        for i in range(5):
            if sum(boardex4[x_tmp-4+i:x_tmp+1+i, y_tmp])==5*self.color:
                return True
            elif sum(boardex4[x_tmp, y_tmp-4+i:y_tmp+1+i])==5*self.color:
                return True
            elif boardex4[x_tmp+i-4,y_tmp+i-4]+boardex4[x_tmp+i-3,y_tmp+i-3]+boardex4[x_tmp+i-2,y_tmp+i-2]+\
                    boardex4[x_tmp+i-1,y_tmp+i-1]+boardex4[x_tmp+i,y_tmp+i]==5*self.color:
                return True
            elif boardex4[x_tmp+i-4,y_tmp-i+4]+boardex4[x_tmp+i-3,y_tmp-i+3]+boardex4[x_tmp+i-2,y_tmp-i+2]+\
                    boardex4[x_tmp+i-1,y_tmp-i+1]+boardex4[x_tmp+i,y_tmp-i]==5*self.color:
                return True
        return False

    def play(self, action:Union[str, int]) -> 'State':
        # 关于type hint : Python中的类是在读取完整个类之后才被定义的，因此在类中无法通过正常方式表示这个类本身。
        # 替代方法是使用一个和类同名的字符串，这被称为自引用类型。
        """状态转移
        Parameters
        ----------
            action : 0-80的落子位置int,或者动作的用空格分隔的字符串序列str
        Returns
        -------
            self
        """
        # 如果是一条（轨迹）字符串（该对象对应有数据结构的设计，一定程度上组织成有可遍历的特征）
        # 那么，可直接设计递归的调用，归约为仅需实现单次的int输入的动作状态转移
        if isinstance(action, str):
            for astr in action.split(): # 默认用空格分隔字符串
                self.play(self.str2action(astr))
            return self

        x, y = action // 9, action % 9
        self.board[x, y] = self.color

        # # 检查是否5子连线
        if self.check_win(x , y):
            self.win_color = self.color

        self.color = -self.color
        self.record.append(action)
        return self

    def terminal(self):
        # 终止状态检查，用于selfplay循环条件
        return self.win_color != 0 or len(self.record) == 9 * 9

    def terminal_reward(self):
        # 返回终局奖励 1，-1
        return self.win_color if self.color == BLACK else -self.win_color

    def legal_actions(self) -> List[int]:
        # 返回根节点下的合法走子位，List of int
        return [a for a in range(9 * 9) if self.board[a // 9, a % 9] == 0]

    def feature(self, to_cuda:bool = False):
        # making input ndarray for NN_state
        # 堆个ndarry用作神经网络输入 : [当前行动者 ,我方视角棋盘，对方视角棋盘]
        # support sending ndarry to cuda tensor with added a batch_dim
        now_mover = np.ones((9, 9)) * self.color # 加1通道进卷积
        s = np.stack([now_mover, self.board == self.color, self.board == -self.color]).astype(np.float32)
        if to_cuda:
            return torch.from_numpy(s).unsqueeze(0).cuda()
        return s

    def action_feature(self, action, to_cuda:bool = False):
        # 制作动作矩阵
        # support sending ndarry to cuda tensor with added a batch_dim
        a = np.zeros((1, 9, 9), dtype=np.float32)
        a[0, action // 9, action % 9] = 1
        if to_cuda:
            return torch.from_numpy(a).unsqueeze(0).cuda()
        return a

state = State().play('A2')
print(state)
print('input feature')
print(state.feature())
state = State().play('B2 A1 I2')
print(state)
print('input feature')
print(state.feature())

   1 2 3 4 5 6 7 8 9
A _ O _ _ _ _ _ _ _
B _ _ _ _ _ _ _ _ _
C _ _ _ _ _ _ _ _ _
D _ _ _ _ _ _ _ _ _
E _ _ _ _ _ _ _ _ _
F _ _ _ _ _ _ _ _ _
G _ _ _ _ _ _ _ _ _
H _ _ _ _ _ _ _ _ _
I _ _ _ _ _ _ _ _ _
record = A2
input feature
[[[-1. -1. -1. -1. -1. -1. -1. -1. -1.]
  [-1. -1. -1. -1. -1. -1. -1. -1. -1.]
  [-1. -1. -1. -1. -1. -1. -1. -1. -1.]
  [-1. -1. -1. -1. -1. -1. -1. -1. -1.]
  [-1. -1. -1. -1. -1. -1. -1. -1. -1.]
  [-1. -1. -1. -1. -1. -1. -1. -1. -1.]
  [-1. -1. -1. -1. -1. -1. -1. -1. -1.]
  [-1. -1. -1. -1. -1. -1. -1. -1. -1.]
  [-1. -1. -1. -1. -1. -1. -1. -1. -1.]]

 [[ 0.  0.  0.  0.  0.  0.  0.  0.  0.]
  [ 0.  0.  0.  0.  0.  0.  0.  0.  0.]
  [ 0.  0.  0.  0.  0.  0.  0.  0.  0.]
  [ 0.  0.  0.  0.  0.  0.  0.  0.  0.]
  [ 0.  0.  0.  0.  0.  0.  0.  0.  0.]
  [ 0.  0.  0.  0.  0.  0.  0.  0.  0.]
  [ 0.  0.  0.  0.  0.  0.  0.  0.  0.]
  [ 0.  0.  0.  0.  0.  0.  0.  0.  0.]
  [ 0.  0.  0.  0.  0.  0.  0.  0.  0.]]

 [[ 0.  1.  0.  0.  0.  0.  0.  0.  0.]
  [ 0.  0

In [2]:
# 定义组件网络Res&Conv

import torch
import torch.nn as nn
import torch.nn.functional as F

class Conv(nn.Module):
    def __init__(self, filters0, filters1, kernel_size, bn=False):
        super().__init__()
        self.conv = nn.Conv2d(filters0, filters1, kernel_size, stride=1, padding=kernel_size//2, bias=False)
        self.bn = None
        if bn:
            self.bn = nn.BatchNorm2d(filters1)

    def forward(self, x):
        h = self.conv(x)
        if self.bn is not None:
            h = self.bn(h)
        return h

class ResidualBlock(nn.Module):
    def __init__(self, filters):
        super().__init__()
        self.conv = Conv(filters, filters, 3, True)

    def forward(self, x):
        return F.relu(x + (self.conv(x)))

In [3]:
num_filters = 64
num_blocks = 6

class Representation(nn.Module):
    ''' Conversion from observation to inner abstract state '''
    def __init__(self, input_shape):
        super().__init__()
        self.input_shape = input_shape # (c, 9, 9)
        self.board_size = self.input_shape[1] * self.input_shape[2]
        # 初始化nn.Conv2d inputchannels，outputchannels
        self.layer0 = Conv(self.input_shape[0], num_filters, 3, bn=True)
        self.blocks = nn.ModuleList([ResidualBlock(num_filters) for _ in range(num_blocks)])

    def forward(self, x):
        h = F.relu(self.layer0(x))
        for block in self.blocks:
            h = block(h)
        return h # torch.Size([1, 16, 9, 9])

    def inference(self, x, pass_to_cpu:bool = True):
        self.eval()
        with torch.no_grad():
            # rp = self(torch.from_numpy(x).unsqueeze(0)) # cpu_only版本: conv2d的输入tensor需要四维，多加一维度在0位置
            rp = self(x)
        if not pass_to_cpu:
            return rp # print('rp tensor shape' , rp.shape) # torch.Size([1, 16, 9, 9])
        return rp.cpu().numpy()[0]  

class Prediction(nn.Module):
    ''' Policy and value prediction from inner abstract state '''
    def __init__(self, action_shape):
        super().__init__()
        self.board_size = np.prod(action_shape[1:]) # 9 x 9 = 81
        self.action_size = action_shape[0] * self.board_size # 1 x 81 = 81

        self.conv_p1 = Conv(num_filters, 4, 1, bn=True)
        self.conv_p2 = Conv(4, 1, 1)

        self.conv_v = Conv(num_filters, 4, 1, bn=True)
        self.fc_v = nn.Linear(self.board_size * 4, 1, bias=False)

    def forward(self, rp):
        h_p = F.relu(self.conv_p1(rp))
        # print('过第一层p卷积', h_p.shape) # torch.Size([1, 4, 9, 9])
        h_p = self.conv_p2(h_p).view(-1, self.action_size)
        # print('过第二层p卷积', h_p.shape) # torch.Size([1, 81])
        h_v = F.relu(self.conv_v(rp))
        # print('过第一层fc卷积', h_v.shape) # torch.Size([1, 4, 9, 9])
        h_v = self.fc_v(h_v.view(-1, self.board_size * 4))
        # print('过第二层fc层', h_v.shape) # torch.Size([1, 1])
        # range of value is -1 ~ 1
        return F.softmax(h_p, dim=-1), torch.tanh(h_v)

    def inference(self, rp, pass_to_cpu:bool = True):
        self.eval()
        with torch.no_grad():
            # p, v = self(torch.from_numpy(rp).unsqueeze(0))
            p, v = self(rp) # print('p shape is ', p.shape) # torch.Size([1, 81])
        if not pass_to_cpu:
            return p, v
        return p.cpu().numpy()[0], v.cpu().numpy()[0][0]

class Dynamics(nn.Module):
    '''Abstract state transition'''
    def __init__(self, rp_shape, act_shape):
        super().__init__()
        self.rp_shape = rp_shape
        self.layer0 = Conv(rp_shape[0] + act_shape[0], num_filters, 3, bn=True)
        self.blocks = nn.ModuleList([ResidualBlock(num_filters) for _ in range(num_blocks)])

    def forward(self, rp, a):
        h = torch.cat([rp, a], dim=1)
        # print('dim=1 cat shape h ' , h.shape) # torch.Size([1, 17, 9, 9])
        h = self.layer0(h)
        for block in self.blocks:
            h = block(h)
        return h

    def inference(self, rp, a, pass_to_cpu:bool = True):
        self.eval()
        with torch.no_grad():
            # cpu_only版本: rp = self(torch.from_numpy(rp).unsqueeze(0), torch.from_numpy(a).unsqueeze(0))
            rp = self(rp, a)
        if not pass_to_cpu:
            return rp
        return rp.cpu().numpy()[0]

class Net(nn.Module):
    '''Whole net'''
    def __init__(self):
        super().__init__()
        state = State()
        input_shape = state.feature().shape # state (c, 9, 9)
        action_shape = state.action_feature(0).shape # action (1, 9, 9)
        rp_shape = (num_filters, *input_shape[1:]) # hidden space (16, 9, 9)

        self.representation = Representation(input_shape)
        self.prediction = Prediction(action_shape)
        self.dynamics = Dynamics(rp_shape, action_shape)

    def predict(self, state0, path):
        '''Predict p and v from original state and path'''
        outputs = []
        x = state0.feature(to_cuda=True)
        rp = self.representation.inference(x, pass_to_cpu= False)
        outputs.append(self.prediction.inference(rp, pass_to_cpu = True))
        for action in path:
            a = state0.action_feature(action, to_cuda=True)
            rp = self.dynamics.inference(rp, a, pass_to_cpu = False)
            outputs.append(self.prediction.inference(rp, pass_to_cpu = True))
        return outputs

In [4]:
# 给上面的三个网络做单元测试用
def show_net(net, state):
    '''Display policy (p) and value (v)'''
    print(state)
    p, v = net.predict(state, [])[-1]
    print('p = ')
    print((p * 10000).astype(int).reshape((-1, *net.representation.input_shape[1:3])))
    print('v = ', v)
    print()

#  Outputs before training
show_net(Net().cuda(), State())

   1 2 3 4 5 6 7 8 9
A _ _ _ _ _ _ _ _ _
B _ _ _ _ _ _ _ _ _
C _ _ _ _ _ _ _ _ _
D _ _ _ _ _ _ _ _ _
E _ _ _ _ _ _ _ _ _
F _ _ _ _ _ _ _ _ _
G _ _ _ _ _ _ _ _ _
H _ _ _ _ _ _ _ _ _
I _ _ _ _ _ _ _ _ _
record = 
p = 
[[[136 131 128 127 127 127 126 129 136]
  [130 125 118 116 115 115 114 120 123]
  [130 124 118 117 117 117 116 122 125]
  [130 124 118 117 117 116 115 122 125]
  [130 124 118 117 117 117 116 122 125]
  [130 124 118 117 117 116 116 121 124]
  [129 123 118 117 118 117 116 122 126]
  [125 122 118 118 118 118 117 123 127]
  [133 134 131 132 132 132 132 137 139]]]
v =  0.045420434



In [5]:
# 实现蒙特卡洛树搜索MCTS

class Node:
    '''Search result of one abstract (or root) state'''
    def __init__(self, p, v):
        self.p, self.v = p, v
        self.n, self.q_sum = np.zeros_like(p), np.zeros_like(p)
        self.n_all, self.q_sum_all = 1, v / 2 # prior

    def update(self, action, q_new):
        # Update
        self.n[action] += 1
        self.q_sum[action] += q_new

        # Update overall stats
        self.n_all += 1
        self.q_sum_all += q_new

In [19]:
import time
import copy

class Tree:
    '''Monte Carlo Tree'''
    def __init__(self, net):
        self.net = net
        self.nodes = {}

    def search(self, state, path, rp, depth):
        # Return predicted value from new state
        # rp is a Tensor on Gpu
        key = state.record_string()
        if len(path) > 0:
            key += '|' + ' '.join(map(state.action2str, path))
        if key not in self.nodes:
            p, v = self.net.prediction.inference(rp, pass_to_cpu = True)
            self.nodes[key] = Node(p, v)
            return v

        # State transition by an action selected from bandit
        node = self.nodes[key]
        p = node.p
        mask = np.zeros_like(p)
        if depth == 0:
            # Add noise to policy on the root node
            p = 0.75 * p + 0.25 * np.random.dirichlet([0.15] * len(p))
            # On the root node, we choose action only from legal actions
            mask[state.legal_actions()] = 1
            p *= mask
            p /= p.sum() + 1e-16

        n, q_sum = 1 + node.n, node.q_sum_all / node.n_all + node.q_sum
        ucb = q_sum / n + 2.0 * np.sqrt(node.n_all) * p / n + mask * 4 # PUCB formula
        best_action = np.argmax(ucb)

        # Search next state by recursively calling this function
        rp_next = self.net.dynamics.inference(rp, state.action_feature(best_action, to_cuda=True), pass_to_cpu=False)
        path.append(best_action)
        q_new = -self.search(state, path, rp_next, depth + 1) # With the assumption of changing player by turn
        node.update(best_action, q_new)

        return q_new

    def think(self, state, num_simulations, temperature = 0, show=False):
        # End point of MCTS
        if show:
            print(state)
        start, prev_time = time.time(), 0
        project_once = self.net.representation.inference(state.feature(to_cuda=True), pass_to_cpu=False) # Muzero在根节点状态只需rp一次，因为rpnet是固定的
        for _ in range(num_simulations):
            self.search(state, [], project_once, depth=0)
            # Display search result on every second
            if show:
                tmp_time = time.time() - start
                if int(tmp_time) > int(prev_time):
                    prev_time = tmp_time
                    root, pv = self.nodes[state.record_string()], self.pv(state)
                    print('%.2f sec. best %s. q = %.4f. n = %d / %d. pv = %s'
                          % (tmp_time, state.action2str(pv[0]), root.q_sum[pv[0]] / root.n[pv[0]],
                             root.n[pv[0]], root.n_all, ' '.join([state.action2str(a) for a in pv])))

        #  Return probability distribution weighted by the number of simulations
        root = self.nodes[state.record_string()]
        n = root.n + 1
        n = (n / np.max(n)) ** (1 / (temperature + 1e-8))
        return n / n.sum() # teacher--MCTS

    def pv(self, state):
        # Return principal variation (action sequence which is considered as the best)
        s, pv_seq = copy.deepcopy(state), []
        while True:
            key = s.record_string()
            if key not in self.nodes or self.nodes[key].n.sum() == 0:
                break
            best_action = sorted([(a, self.nodes[key].n[a]) for a in s.legal_actions()], key=lambda x: -x[1])[0][0]
            pv_seq.append(best_action)
            s.play(best_action)
        return pv_seq

In [20]:
# Search with initialized net

tree = Tree(Net().cuda())
next_step_0 = tree.think(State(), 100, show=True)

tree = Tree(Net().cuda())
next_step_n = tree.think(State().play('E4 F5 E5 F6 E6 F7 E7'), 200, show=True)
print(next_step_n.reshape((9, 9)))

# tree = Tree(Net().cuda())
# tree.think(State().play('F4 D5 F5 D6 F6 D7 F7'), 200, show=True)

# tree = Tree(Net().cuda())
# tree.think(State().play('B2 A2 A3 C1'), 200, show=True)

   1 2 3 4 5 6 7 8 9
A _ _ _ _ _ _ _ _ _
B _ _ _ _ _ _ _ _ _
C _ _ _ _ _ _ _ _ _
D _ _ _ _ _ _ _ _ _
E _ _ _ _ _ _ _ _ _
F _ _ _ _ _ _ _ _ _
G _ _ _ _ _ _ _ _ _
H _ _ _ _ _ _ _ _ _
I _ _ _ _ _ _ _ _ _
record = 
   1 2 3 4 5 6 7 8 9
A _ _ _ _ _ _ _ _ _
B _ _ _ _ _ _ _ _ _
C _ _ _ _ _ _ _ _ _
D _ _ _ _ _ _ _ _ _
E _ _ _ O O O O _ _
F _ _ _ _ X X X _ _
G _ _ _ _ _ _ _ _ _
H _ _ _ _ _ _ _ _ _
I _ _ _ _ _ _ _ _ _
record = E4 F5 E5 F6 E6 F7 E7
[[0.07692308 0.07692308 0.         0.         0.         0.
  0.         0.         0.        ]
 [0.         0.         0.         0.         0.         0.
  0.         0.         0.07692308]
 [0.         0.         0.         0.         0.07692308 0.07692308
  0.         0.         0.        ]
 [0.         0.         0.07692308 0.         0.         0.
  0.         0.         0.        ]
 [0.         0.         0.         0.         0.         0.
  0.         0.         0.        ]
 [0.         0.         0.07692308 0.         0.         0.
  0.      

In [21]:
# Training of neural net
from tqdm import tqdm
import torch.optim as optim

batch_size = 32
num_steps = 100

def gen_target(ep, k):
    '''Generate inputs and targets for training'''
    # path, reward, observation, action, policy
    turn_idx = np.random.randint(len(ep[0]))
    ps, vs, ax = [], [], []
    for t in range(turn_idx, turn_idx + k + 1):
        if t < len(ep[0]):
            p = ep[4][t]
            a = ep[3][t]
        else: # state after finishing game
            # p is 0 (loss is 0)
            p = np.zeros_like(ep[4][-1])
            # random action selection
            a = np.zeros(np.prod(ep[3][-1].shape), dtype=np.float32)
            a[np.random.randint(len(a))] = 1
            a = a.reshape(ep[3][-1].shape)
        vs.append([ep[1] if t % 2 == 0 else -ep[1]])
        ps.append(p)
        ax.append(a)
        
    return ep[2][turn_idx], ax, ps, vs

def train(episodes, net, optimizer):
    #     episodes = List(record:List[int], 
    #                     reward:int(0,1,-1), 
    #                     features:state.feature(), 
    #                     action_features:state.action_feature(action) from random.choice based on distribution p_targets, 
    #                     p_targets:teacher--MCTS))
    '''Train neural net on GPU'''
    p_loss_sum, v_loss_sum = torch.as_tensor(0, dtype=torch.float32).cuda() , torch.as_tensor(0, dtype=torch.float32).cuda()
    net.train()
    k = 3 # 与alpha不同（2个状态对就可以），至少3个连续状态用来训练Muzero
    for _ in tqdm(range(num_steps)):
        x, ax, p_target, v_target = zip(*[gen_target(episodes[np.random.randint(len(episodes))], k) for j in range(batch_size)])
        x = torch.from_numpy(np.array(x)).cuda()
        ax = torch.from_numpy(np.array(ax))
        p_target = torch.from_numpy(np.array(p_target))
        v_target = torch.FloatTensor(np.array(v_target))

        # Change the order of axis as [time step, batch, ...]
        ax = torch.transpose(ax, 0, 1).cuda()
        p_target = torch.transpose(p_target, 0, 1).cuda()
        v_target = torch.transpose(v_target, 0, 1).cuda()

        # Compute losses for k (+ current) steps
        p_loss, v_loss = torch.as_tensor(0, dtype=torch.float32).cuda() , torch.as_tensor(0, dtype=torch.float32).cuda()
        for t in range(k + 1):
            rp = net.representation(x) if t == 0 else net.dynamics(rp, ax[t - 1])
            p, v = net.prediction(rp)
            p_loss += F.kl_div(torch.log(p), p_target[t], reduction='sum')
            v_loss += torch.sum(((v_target[t] - v) ** 2) / 2)

        p_loss_sum += p_loss.item()
        v_loss_sum += v_loss.item()

        optimizer.zero_grad()
        (p_loss + v_loss).backward()
        optimizer.step()

    num_train_datum = num_steps * batch_size
    print('p_loss %f v_loss %f' % (p_loss_sum.cpu().numpy() / num_train_datum, v_loss_sum.cpu().numpy() / num_train_datum))
    return net

In [22]:
#  Battle against random agents

def vs_random(net, n=100):
    results = { 0 : 0 , -1 : 0 , 1 : 0}
    for i in range(n):
        first_turn = i % 2 == 0
        turn = first_turn
        state = State()
        while not state.terminal():
            if turn:
                p, _ = net.predict(state, [])[-1]
                action = sorted([(a, p[a]) for a in state.legal_actions()], key=lambda x:-x[1])[0][0]
            else:
                action = np.random.choice(state.legal_actions())
            state.play(action)
            turn = not turn
        r = state.terminal_reward() if turn else -state.terminal_reward()
        results[r] = results.get(r, 0) + 1
    return results

In [25]:
# Main algorithm of MuZero
from collections import deque

PATH = '.model/checkpoints/'

num_games = 100          # selfplay的总局数
num_games_one_epoch = 50 # 每selfplay多少轮，训练模型一次
num_simulations = 100    # 每个Node的search搜索次数

# load net parameters train&save on GPU
net = Net()
net.load_state_dict(torch.load(PATH + '1000r.pt'))
net.cuda()

# net = Net().cuda() # initial training

optimizer = optim.SGD(net.parameters(), lr=3e-4, weight_decay=3e-5, momentum=0.8)

# Display battle results as {-1: lose 0: draw 1: win} (for episode generated for training, 1 means that the first player won)
vs_random_sum = vs_random(net)
print('vs_random = ', sorted(vs_random_sum.items()))

# episodes = []
episodes = deque([], maxlen=200) # 移动队列用于存储最新的模型游戏

result_distribution = {1: 0, 0: 0, -1: 0}

for g in tqdm(range(num_games)):
    # Generate one episode
    record, p_targets, features, action_features = [], [], [], []
    state = State()
    # temperature using to make policy targets from search results
    temperature = 0.7

    tree = Tree(net) # 每轮游戏维护同一个树

    while not state.terminal():
        # tree = Tree(net) # 有必要每走一步都新建一个搜索树么？移到循环外面可以么？
        p_target = tree.think(state, num_simulations, temperature)
        p_targets.append(p_target)
        features.append(state.feature())

        # Select action with generated distribution, and then make a transition by that action
        action = np.random.choice(np.arange(len(p_target)), p=p_target)
        record.append(action)
        action_features.append(state.action_feature(action))
        state.play(action)
        temperature *= 0.8 # 这个温度参数感觉意义不明。。。？

    # reward seen from the first turn player
    reward = state.terminal_reward() * (1 if len(record) % 2 == 0 else -1)
    result_distribution[reward] += 1
    episodes.append((record, reward, features, action_features, p_targets))

    if g % num_games_one_epoch == 0:
        print('game ', end='')
    print(g, ' ', end='')

    # Training of neural net
    if (g + 1) % num_games_one_epoch == 0:
        # Show the result distributiuon of generated episodes
        print('generated = ', sorted(result_distribution.items()))
        epi = list(episodes.copy())
        net = train(episodes=epi, net=net, optimizer=optimizer)
        vs_random_once = vs_random(net)
        print('vs_random = ', sorted(vs_random_once.items()), end='')
        for r, n in vs_random_once.items():
            vs_random_sum[r] += n
        print(' sum = ', sorted(vs_random_sum.items()))
        # new_checkpoint_path = f'/.model/checkpoints/{g}.pt'
        # torch.save(net.state_dict(), new_checkpoint_path)

print('finished')

  0%|          | 0/100 [00:00<?, ?it/s]

vs_random =  [(-1, 2), (0, 0), (1, 98)]


  1%|          | 1/100 [00:03<05:20,  3.23s/it]

game 0  

  2%|▏         | 2/100 [00:05<04:45,  2.92s/it]

1  

  3%|▎         | 3/100 [00:10<05:59,  3.70s/it]

2  

  4%|▍         | 4/100 [00:15<06:22,  3.98s/it]

3  

  5%|▌         | 5/100 [00:19<06:15,  3.95s/it]

4  

  6%|▌         | 6/100 [00:21<05:20,  3.41s/it]

5  

  7%|▋         | 7/100 [00:30<07:46,  5.02s/it]

6  

  8%|▊         | 8/100 [00:34<07:04,  4.62s/it]

7  

  9%|▉         | 9/100 [00:37<06:26,  4.25s/it]

8  

 10%|█         | 10/100 [00:43<07:02,  4.69s/it]

9  

 11%|█         | 11/100 [00:51<08:23,  5.66s/it]

10  

 12%|█▏        | 12/100 [00:56<08:18,  5.66s/it]

11  

 13%|█▎        | 13/100 [01:01<07:40,  5.29s/it]

12  

 14%|█▍        | 14/100 [01:04<06:34,  4.58s/it]

13  

 15%|█▌        | 15/100 [01:06<05:42,  4.03s/it]

14  

 16%|█▌        | 16/100 [01:12<06:27,  4.61s/it]

15  

 17%|█▋        | 17/100 [01:16<06:08,  4.44s/it]

16  

 18%|█▊        | 18/100 [01:26<08:16,  6.05s/it]

17  

 19%|█▉        | 19/100 [01:32<07:54,  5.85s/it]

18  

 20%|██        | 20/100 [01:37<07:47,  5.85s/it]

19  

 21%|██        | 21/100 [01:40<06:25,  4.88s/it]

20  

 22%|██▏       | 22/100 [01:44<06:04,  4.67s/it]

21  

 23%|██▎       | 23/100 [01:48<05:48,  4.52s/it]

22  

 24%|██▍       | 24/100 [01:52<05:26,  4.30s/it]

23  

 25%|██▌       | 25/100 [01:57<05:32,  4.44s/it]

24  

 26%|██▌       | 26/100 [02:00<04:51,  3.93s/it]

25  

 27%|██▋       | 27/100 [02:07<05:55,  4.87s/it]

26  

 28%|██▊       | 28/100 [02:10<05:21,  4.47s/it]

27  

 29%|██▉       | 29/100 [02:24<08:28,  7.16s/it]

28  

 30%|███       | 30/100 [02:27<07:02,  6.03s/it]

29  

 31%|███       | 31/100 [02:31<06:07,  5.32s/it]

30  

 32%|███▏      | 32/100 [02:36<05:55,  5.23s/it]

31  

 33%|███▎      | 33/100 [02:40<05:20,  4.79s/it]

32  

 34%|███▍      | 34/100 [02:47<06:12,  5.64s/it]

33  

 35%|███▌      | 35/100 [02:50<05:10,  4.78s/it]

34  

 36%|███▌      | 36/100 [02:55<05:02,  4.73s/it]

35  

 37%|███▋      | 37/100 [02:58<04:33,  4.34s/it]

36  

 38%|███▊      | 38/100 [03:03<04:43,  4.56s/it]

37  

 39%|███▉      | 39/100 [03:13<06:14,  6.14s/it]

38  

 40%|████      | 40/100 [03:20<06:24,  6.40s/it]

39  

 41%|████      | 41/100 [03:28<06:50,  6.96s/it]

40  

 42%|████▏     | 42/100 [03:33<06:06,  6.31s/it]

41  

 43%|████▎     | 43/100 [03:39<05:47,  6.09s/it]

42  

 44%|████▍     | 44/100 [03:51<07:34,  8.11s/it]

43  

 45%|████▌     | 45/100 [03:55<06:15,  6.83s/it]

44  

 46%|████▌     | 46/100 [03:58<05:04,  5.64s/it]

45  

 47%|████▋     | 47/100 [04:02<04:24,  4.99s/it]

46  

 48%|████▊     | 48/100 [04:06<04:12,  4.86s/it]

47  

 49%|████▉     | 49/100 [04:12<04:25,  5.20s/it]

48  



49  generated =  [(-1, 16), (0, 0), (1, 34)]


100%|██████████| 100/100 [00:01<00:00, 54.87it/s]


p_loss 2.586219 v_loss 0.583698


 50%|█████     | 50/100 [04:22<05:26,  6.54s/it]

vs_random =  [(-1, 1), (0, 0), (1, 99)] sum =  [(-1, 3), (0, 0), (1, 197)]


 51%|█████     | 51/100 [04:27<04:54,  6.02s/it]

game 50  

 52%|█████▏    | 52/100 [04:35<05:21,  6.70s/it]

51  

 53%|█████▎    | 53/100 [04:40<04:53,  6.25s/it]

52  

 54%|█████▍    | 54/100 [04:46<04:44,  6.18s/it]

53  

 55%|█████▌    | 55/100 [05:03<07:03,  9.40s/it]

54  

 56%|█████▌    | 56/100 [05:08<06:01,  8.22s/it]

55  

 57%|█████▋    | 57/100 [05:15<05:27,  7.61s/it]

56  

 58%|█████▊    | 58/100 [05:22<05:16,  7.53s/it]

57  

 59%|█████▉    | 59/100 [05:28<04:46,  6.98s/it]

58  

 60%|██████    | 60/100 [05:34<04:28,  6.72s/it]

59  

 61%|██████    | 61/100 [05:38<03:52,  5.96s/it]

60  

 62%|██████▏   | 62/100 [05:44<03:42,  5.85s/it]

61  

 63%|██████▎   | 63/100 [05:50<03:42,  6.00s/it]

62  

 64%|██████▍   | 64/100 [05:56<03:32,  5.89s/it]

63  

 65%|██████▌   | 65/100 [05:59<02:55,  5.02s/it]

64  

 66%|██████▌   | 66/100 [06:03<02:49,  4.99s/it]

65  

 67%|██████▋   | 67/100 [06:07<02:34,  4.67s/it]

66  

 68%|██████▊   | 68/100 [06:11<02:22,  4.45s/it]

67  

 69%|██████▉   | 69/100 [06:15<02:09,  4.16s/it]

68  

 70%|███████   | 70/100 [06:22<02:29,  4.99s/it]

69  

 71%|███████   | 71/100 [06:26<02:17,  4.73s/it]

70  

 72%|███████▏  | 72/100 [06:29<01:59,  4.26s/it]

71  

 73%|███████▎  | 73/100 [06:35<02:07,  4.73s/it]

72  

 74%|███████▍  | 74/100 [06:48<03:10,  7.32s/it]

73  

 75%|███████▌  | 75/100 [06:58<03:21,  8.08s/it]

74  

 76%|███████▌  | 76/100 [07:03<02:48,  7.01s/it]

75  

 77%|███████▋  | 77/100 [07:06<02:19,  6.07s/it]

76  

 78%|███████▊  | 78/100 [07:11<02:03,  5.61s/it]

77  

 79%|███████▉  | 79/100 [07:15<01:46,  5.08s/it]

78  

 80%|████████  | 80/100 [07:26<02:15,  6.78s/it]

79  

 81%|████████  | 81/100 [07:31<02:02,  6.45s/it]

80  

 82%|████████▏ | 82/100 [07:36<01:46,  5.93s/it]

81  

 83%|████████▎ | 83/100 [07:42<01:42,  6.02s/it]

82  

 84%|████████▍ | 84/100 [07:46<01:23,  5.25s/it]

83  

 85%|████████▌ | 85/100 [07:51<01:20,  5.40s/it]

84  

 86%|████████▌ | 86/100 [07:55<01:07,  4.79s/it]

85  

 87%|████████▋ | 87/100 [07:58<00:55,  4.27s/it]

86  

 88%|████████▊ | 88/100 [08:01<00:48,  4.03s/it]

87  

 89%|████████▉ | 89/100 [08:07<00:49,  4.51s/it]

88  

 90%|█████████ | 90/100 [08:11<00:42,  4.26s/it]

89  

 91%|█████████ | 91/100 [08:17<00:45,  5.04s/it]

90  

 92%|█████████▏| 92/100 [08:32<01:03,  7.89s/it]

91  

 93%|█████████▎| 93/100 [08:35<00:45,  6.43s/it]

92  

 94%|█████████▍| 94/100 [08:39<00:33,  5.55s/it]

93  

 95%|█████████▌| 95/100 [08:55<00:44,  8.86s/it]

94  

 96%|█████████▌| 96/100 [09:00<00:30,  7.72s/it]

95  

 97%|█████████▋| 97/100 [09:12<00:26,  8.83s/it]

96  

 98%|█████████▊| 98/100 [09:15<00:14,  7.26s/it]

97  

 99%|█████████▉| 99/100 [09:19<00:06,  6.25s/it]

98  



99  generated =  [(-1, 27), (0, 0), (1, 73)]


100%|██████████| 100/100 [00:01<00:00, 50.84it/s]


p_loss 2.656111 v_loss 0.560035


100%|██████████| 100/100 [09:30<00:00,  5.70s/it]

vs_random =  [(-1, 5), (0, 0), (1, 95)] sum =  [(-1, 8), (0, 0), (1, 292)]
finished





In [16]:
new_checkpoint_path = f'.model/checkpoints/1000r.pt'
torch.save(net.state_dict(), new_checkpoint_path)

# related tips
# 4. Trained and Save on GPU, Load on GPU
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# 
# When loading a model on a GPU that was trained and saved on GPU, 
# simply
# convert the initialized model to a CUDA optimized model using
# ``model.to(torch.device('cuda'))``.
# 
# Be sure to use the ``.to(torch.device('cuda'))`` function 
# on all model inputs 
# to prepare the data for the model.

# PATH = "model.pt"
# net.cuda()
# # Save
# torch.save(net.state_dict(), PATH)

# # Load
# device = torch.device("cuda")
# model = Net()
# model.load_state_dict(torch.load(PATH))
# model.to(device)

In [27]:
# Search with trained net

tree = Tree(net)
state = State()

next_step = tree.think(state.play('E4 F5 E5 F6 E6 F7 E7'), 2000, show=True)
print(next_step.reshape((9 , 9)))

   1 2 3 4 5 6 7 8 9
A _ _ _ _ _ _ _ _ _
B _ _ _ _ _ _ _ _ _
C _ _ _ _ _ _ _ _ _
D _ _ _ _ _ _ _ _ _
E _ _ _ O O O O _ _
F _ _ _ _ X X X _ _
G _ _ _ _ _ _ _ _ _
H _ _ _ _ _ _ _ _ _
I _ _ _ _ _ _ _ _ _
record = E4 F5 E5 F6 E6 F7 E7
1.00 sec. best D8. q = -0.6277. n = 32 / 245. pv = D8
2.00 sec. best D8. q = -0.6442. n = 38 / 414. pv = D8
3.00 sec. best D8. q = -0.6584. n = 47 / 552. pv = D8
4.00 sec. best D8. q = -0.6583. n = 61 / 675. pv = D8
5.00 sec. best D8. q = -0.6892. n = 75 / 797. pv = D8
6.00 sec. best D8. q = -0.6955. n = 89 / 956. pv = D8
7.00 sec. best D8. q = -0.7111. n = 99 / 1104. pv = D8
8.01 sec. best D8. q = -0.7193. n = 109 / 1233. pv = D8
9.00 sec. best D8. q = -0.7177. n = 119 / 1345. pv = D8
10.01 sec. best D8. q = -0.7185. n = 129 / 1463. pv = D8
11.00 sec. best D8. q = -0.7191. n = 141 / 1588. pv = D8
12.00 sec. best D8. q = -0.7202. n = 145 / 1732. pv = D8
13.00 sec. best D8. q = -0.7243. n = 149 / 1882. pv = D8
[[0. 0. 0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0. 

In [28]:
next_step = tree.think(state.play('F8'), 800, show=True)
print('是否已经终局: ', state.terminal())
print(next_step.reshape((9 , 9)))

   1 2 3 4 5 6 7 8 9
A _ _ _ _ _ _ _ _ _
B _ _ _ _ _ _ _ _ _
C _ _ _ _ _ _ _ _ _
D _ _ _ _ _ _ _ _ _
E _ _ _ O O O O _ _
F _ _ _ _ X X X X _
G _ _ _ _ _ _ _ _ _
H _ _ _ _ _ _ _ _ _
I _ _ _ _ _ _ _ _ _
record = E4 F5 E5 F6 E6 F7 E7 F8
1.00 sec. best E8. q = 0.9508. n = 92 / 203. pv = E8
2.00 sec. best E8. q = 0.9554. n = 148 / 381. pv = E8
3.01 sec. best E8. q = 0.9560. n = 207 / 528. pv = E8
4.01 sec. best E8. q = 0.9579. n = 253 / 667. pv = E8
是否已经终局:  False
[[0. 0. 0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0. 0. 1. 0.]
 [0. 0. 0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0. 0. 0. 0.]]


In [15]:
next_step = tree.think(state.play('E8'), 800, show=True)
print('是否已经终局: ', state.terminal())
print(next_step.reshape((9 , 9)))

   1 2 3 4 5 6 7 8 9
A _ _ _ _ _ _ _ _ _
B _ _ _ _ _ _ _ _ _
C _ _ _ _ _ _ _ _ _
D _ _ _ _ _ _ _ _ _
E _ _ _ O O O O O _
F _ _ _ _ X X X X _
G _ _ _ _ _ _ _ _ _
H _ _ _ _ _ _ _ _ _
I _ _ _ _ _ _ _ _ _
record = E4 F5 E5 F6 E6 F7 E7 F8 E8
1.00 sec. best F9. q = -0.8868. n = 36 / 189. pv = F9
2.01 sec. best F9. q = -0.8929. n = 62 / 330. pv = F9
3.01 sec. best F9. q = -0.8966. n = 76 / 449. pv = F9
4.00 sec. best F9. q = -0.9030. n = 91 / 548. pv = F9
5.01 sec. best F9. q = -0.9080. n = 107 / 640. pv = F9
6.01 sec. best F9. q = -0.9083. n = 120 / 726. pv = F9
是否已经终局:  True
[[0. 0. 0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0. 0. 0. 1.]
 [0. 0. 0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0. 0. 0. 0.]]
