In [None]:
from google.colab import drive
import os

# Mount Google Drive inside the Colab VM so the project files are reachable.
drive.mount('/content/drive')
# Project root inside the mounted Drive; edit this to match your own layout.
YourPATH = '/content/drive/MyDrive/Colab Notebooks/YAI-quoridor'

# Work from the project root from here on.
os.chdir(YourPATH)
# IPython shell escape: list the directory as a quick check of the working directory.
!dir

In [None]:
'''Pre-activation ResNet in PyTorch.
Reference:
[1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun
    Identity Mappings in Deep Residual Networks. arXiv:1603.05027
'''
from pip import main
import torch
import torch.nn as nn
import torch.nn.functional as F



class PreActBlock(nn.Module):
    """Pre-activation basic residual block (BN -> ReLU -> conv, twice).

    Args:
        in_planes: number of input channels.
        planes: number of channels produced by both 3x3 convolutions.
        stride: stride of the first convolution (and of the projection
            shortcut when one is needed).

    A 1x1 projection ``shortcut`` is only created when the stride or the
    channel count changes; otherwise the input itself is the residual.
    """
    expansion = 1

    def __init__(self, in_planes, planes, stride=1):
        super().__init__()
        out_planes = self.expansion * planes

        self.bn1 = nn.BatchNorm2d(in_planes)
        self.conv1 = nn.Conv2d(
            in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False
        )
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(
            planes, planes, kernel_size=3, stride=1, padding=1, bias=False
        )

        shape_preserved = stride == 1 and in_planes == out_planes
        if not shape_preserved:
            projection = nn.Conv2d(
                in_planes, out_planes, kernel_size=1, stride=stride, bias=False
            )
            self.shortcut = nn.Sequential(projection)

    def forward(self, x):
        """Return ``conv2(relu(bn2(conv1(relu(bn1(x)))))) + residual``."""
        activated = F.relu(self.bn1(x))
        # The projection (when present) consumes the pre-activated tensor.
        if hasattr(self, 'shortcut'):
            residual = self.shortcut(activated)
        else:
            residual = x
        hidden = self.conv1(activated)
        hidden = F.relu(self.bn2(hidden))
        result = self.conv2(hidden)
        result += residual
        return result

class GlobalPooling(nn.Module):
    """Collapse the two spatial axes of a 4-D tensor into three per-channel features.

    ``forward`` returns ``(mean1, mean2, third)``:
      * ``mean1`` - mean over dims 2 and 3;
      * ``mean2`` - ``mean1`` scaled by ``(x.shape[2] - b_avg) / 10``;
      * ``third`` - the spatial maximum, or, when ``value_head`` is true,
        ``mean1`` scaled by ``((x.shape[2] - b_avg)**2 - sigma**2) / 100``.
    """

    def __init__(self,):
        super().__init__()
        self.b_avg = 9
        self.sigma = 0

    def forward(self, x, value_head=False):
        size_offset = x.shape[2] - self.b_avg
        mean1 = x.mean(dim=2).mean(dim=2)
        mean2 = mean1 * (size_offset / 10)
        if not value_head:
            column_max = x.max(dim=2).values
            third = column_max.max(dim=2).values
        else:
            third = mean1 * ((size_offset ** 2 - self.sigma ** 2) / 100)
        return mean1, mean2, third

class GlobalPoolingBias(nn.Module):
    """Add a per-channel bias to ``X`` computed from globally pooled ``G``.

    Args:
        cX: number of channels of ``X`` (the tensor that receives the bias).
        cG: number of channels of ``G`` (the tensor that is pooled).
    """

    def __init__(self,cX, cG):
        super(GlobalPoolingBias, self).__init__()
        self.bn = nn.BatchNorm2d(cG)
        self.GP = GlobalPooling()
        self.fc = nn.Linear(3*cG, cX)

    def forward(self, X, G):
        '''
        X: (batch, cX, b, b) spatial features that receive the bias
        G: (batch, cG, b, b) features that are batch-normed, ReLU'd and pooled

        Returns ``(X + bias, pooled)`` where ``pooled`` is the (batch, 3*cG)
        concatenation of the three GlobalPooling outputs and the bias is
        ``fc(pooled)`` broadcast over both spatial axes.
        '''
        pooled = torch.cat(self.GP(F.relu(self.bn(G))), dim=1)
        channel_bias = self.fc(pooled)
        # Out-of-place on purpose: callers pass slice views of a shared conv
        # output, and an in-place add would both mutate the caller's tensor
        # and bump the version counter that ``G`` (a view of the same tensor)
        # shares with it.
        biased = X + channel_bias.unsqueeze(-1).unsqueeze(-1)
        return biased, pooled

class GlobalPoolingBlock(nn.Module):
    """Residual block whose first conv output is split into a spatial part
    (first ``c_pool`` channels) and a pooled part (the remaining channels).

    The pooled part is reduced by ``GlobalPoolingBias`` to a per-channel bias
    that is added to the spatial part before the second convolution. The
    block input is added back unchanged, so ``in_planes`` has to equal
    ``out_channels`` for the residual sum to be shape-compatible.
    """

    def __init__(self, in_planes, out_channels, c_pool=32):
        super().__init__()
        self.bn1 = nn.BatchNorm2d(in_planes)
        self.conv1 = nn.Conv2d(
            in_planes, out_channels, kernel_size=3, stride=1, padding=1, bias=False
        )
        self.c_pool = c_pool
        self.GPB = GlobalPoolingBias(c_pool, out_channels - c_pool)
        self.bn2 = nn.BatchNorm2d(c_pool)
        self.conv2 = nn.Conv2d(
            c_pool, out_channels, kernel_size=3, stride=1, padding=1, bias=False
        )

    def forward(self, x):
        features = self.conv1(F.relu(self.bn1(x)))
        spatial_part = features[:, :self.c_pool]
        pooled_part = features[:, self.c_pool:]
        biased, _ = self.GPB(spatial_part, pooled_part)
        result = self.conv2(F.relu(self.bn2(biased)))
        result += x
        return result

class PolicyHead(nn.Module):
    """Produce two 140-way distributions: own policy and predicted opponent policy.

    Each distribution concatenates 128 wall logits (a 2-channel map resized
    to 8x8 and flattened) with 12 move logits (a 2-channel map flattened and
    passed through a linear layer) and applies a softmax over the result.
    The linear layers are sized ``2*9*9``, so the incoming feature map has to
    be 9x9.
    """

    def __init__(self, in_channels, c_head=32):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, c_head, kernel_size=1, bias=False)
        self.conv2 = nn.Conv2d(in_channels, c_head, kernel_size=1, bias=False)
        self.GPB = GlobalPoolingBias(c_head, c_head)
        self.bn = nn.BatchNorm2d(c_head)

        # Own policy: wall map, move map, move projection.
        self.final_conv1 = nn.Conv2d(c_head, 2, kernel_size=1, bias=False)
        self.final_conv1_1 = nn.Conv2d(c_head, 2, kernel_size=1, bias=False)
        self.final_fc1 = nn.Linear(2*9*9, 12)

        # Opponent policy: same three layers with separate weights.
        self.final_conv2 = nn.Conv2d(c_head, 2, kernel_size=1, bias=False)
        self.final_conv2_1 = nn.Conv2d(c_head, 2, kernel_size=1, bias=False)
        self.final_fc2 = nn.Linear(2*9*9, 12)

    def _distribution(self, feats, wall_conv, move_conv, move_fc):
        """Softmax over [128 wall logits | 12 move logits] for one branch."""
        wall_map = F.interpolate(
            wall_conv(feats), size=(8, 8), mode='bicubic', align_corners=True
        )
        wall_logits = torch.flatten(wall_map, 1)
        move_logits = move_fc(torch.flatten(move_conv(feats), 1))
        joint = torch.cat([wall_logits, move_logits], dim=1)
        return F.softmax(joint, dim=1)

    def forward(self, x):
        """Return ``(policy_pred, opp_pred)``, each of shape (batch, 140)."""
        spatial = self.conv1(x)
        pooled_src = self.conv2(x)
        feats, _ = self.GPB(spatial, pooled_src)
        feats = F.relu(self.bn(feats))

        policy_pred = self._distribution(
            feats, self.final_conv1, self.final_conv1_1, self.final_fc1
        )
        opp_pred = self._distribution(
            feats, self.final_conv2, self.final_conv2_1, self.final_fc2
        )
        return policy_pred, opp_pred

class ValueHead(nn.Module):
    """Predict the game outcome and a distribution over path-length differences.

    Args:
        in_channels: channels of the incoming feature map.
        c_head: channels after the 1x1 convolution (pooled to ``3*c_head``).
        c_val: width of the hidden linear layers.
    """

    def __init__(self, in_channels, c_head=32, c_val=48):
        super(ValueHead, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, c_head, kernel_size=1, bias=False)
        self.GP = GlobalPooling()
        # game outcome
        self.fc1 = nn.Linear(3*c_head, c_val)
        self.fc2 = nn.Linear(c_val, 1)

        # final score (distance difference) distribution
        self.fc3 = nn.Linear(3*c_head, c_val)
        self.scaling_component = nn.Linear(c_val, 1)
        # Candidate differences -20 .. 20 inclusive (41 bins, 0 included).
        self.possible_dist = list(range(-20, 21))
        self.fc4 = nn.Linear(3*c_head+1, c_val)
        self.fc5 = nn.Linear(c_val, 1)

    def forward(self, x):
        """Return ``(game_out, dist_probs)``.

        ``game_out`` is (batch, 1) in [-1, 1] (tanh). ``dist_probs`` is
        (batch, len(possible_dist)): a softmax over one score per candidate
        distance, each score multiplied by a learned per-sample scale.
        """
        V = self.conv1(x)
        # NOTE(review): GlobalPooling is called with its default
        # value_head=False here even though this is the value head - confirm
        # that the max-pooling variant is the intended one.
        V_pooled = torch.cat(self.GP(V), dim=1)

        game_out = torch.tanh(self.fc2(F.relu(self.fc1(V_pooled))))

        scaling_factor = self.scaling_component(F.relu(self.fc3(V_pooled)))

        # Score every candidate distance with the shared fc4/fc5 stack. The
        # distance column is created directly with V_pooled's device and
        # dtype, and the per-distance scores are concatenated once at the end
        # instead of growing a tensor inside the loop.
        batch = V_pooled.shape[0]
        scores = []
        for d in self.possible_dist:
            dist_column = V_pooled.new_full((batch, 1), float(d))
            hidden = F.relu(self.fc4(torch.cat((V_pooled, dist_column), dim=1)))
            scores.append(self.fc5(hidden))
        logits = torch.cat(scores, dim=1)

        return game_out, F.softmax(logits * scaling_factor, dim=1)
        
class ResNetb5c64(nn.Module):
    """Five-block, 64-channel trunk followed by a policy head and a value head.

    Trunk order: Res, GlobalPool, Res, GlobalPool, Res.
    """

    def __init__(self, input_channels=16, start_channels=64, c_pool=16, c_head=16, c_val=32):
        super().__init__()
        width = start_channels

        self.start_conv = nn.Conv2d(
            input_channels, width, kernel_size=5, stride=1, padding=2, bias=False
        )

        self.ResBlock1 = PreActBlock(width, width, stride=1)
        self.GPBlock1 = GlobalPoolingBlock(width, width, c_pool)
        self.ResBlock2 = PreActBlock(width, width, stride=1)
        self.GPBlock2 = GlobalPoolingBlock(width, width, c_pool)
        self.ResBlock3 = PreActBlock(width, width, stride=1)

        self.last_bn = nn.BatchNorm2d(width)
        self.P_head = PolicyHead(width, c_head)
        self.V_head = ValueHead(width, c_head, c_val)

    def forward(self, board):
        """Run the trunk on ``board`` and return ``(policy_out, value_out)``.

        ``board`` is a (batch, input_channels, 9, 9) tensor; the 9x9 size is
        required by the linear layers inside the policy head. Each returned
        element is itself the tuple produced by the corresponding head.
        """
        out = self.start_conv(board)

        trunk = (
            self.ResBlock1,
            self.GPBlock1,
            self.ResBlock2,
            self.GPBlock2,
            self.ResBlock3,
        )
        for layer in trunk:
            out = layer(out)

        out = F.relu(self.last_bn(out))

        policy_out = self.P_head(out)
        value_out = self.V_head(out)
        return policy_out, value_out

class ResNetb6c96(nn.Module):
    """Six-block, 96-channel trunk followed by a policy head and a value head.

    Trunk order: Res, Res, GlobalPool, Res, Res, GlobalPool.
    """

    def __init__(self, input_channels=16, start_channels=96, c_pool=32, c_head=32, c_val=48):
        super().__init__()
        width = start_channels

        self.start_conv = nn.Conv2d(
            input_channels, width, kernel_size=5, stride=1, padding=2, bias=False
        )

        self.ResBlock1 = PreActBlock(width, width, stride=1)
        self.ResBlock2 = PreActBlock(width, width, stride=1)
        self.GPBlock1 = GlobalPoolingBlock(width, width, c_pool)
        self.ResBlock3 = PreActBlock(width, width, stride=1)
        self.ResBlock4 = PreActBlock(width, width, stride=1)
        self.GPBlock2 = GlobalPoolingBlock(width, width, c_pool)

        self.last_bn = nn.BatchNorm2d(width)
        self.P_head = PolicyHead(width, c_head)
        self.V_head = ValueHead(width, c_head, c_val)

    def forward(self, board):
        """Run the trunk on ``board`` and return ``(policy_out, value_out)``.

        ``board`` is a (batch, input_channels, 9, 9) tensor; the 9x9 size is
        required by the linear layers inside the policy head. Each returned
        element is itself the tuple produced by the corresponding head.
        """
        out = self.start_conv(board)

        trunk = (
            self.ResBlock1,
            self.ResBlock2,
            self.GPBlock1,
            self.ResBlock3,
            self.ResBlock4,
            self.GPBlock2,
        )
        for layer in trunk:
            out = layer(out)

        out = F.relu(self.last_bn(out))

        policy_out = self.P_head(out)
        value_out = self.V_head(out)
        return policy_out, value_out

class ResNetb8c128(nn.Module):
    """Eight-block, 128-channel trunk followed by a policy head and a value head.

    Trunk order: Res x3, GlobalPool, Res x3, GlobalPool.
    """

    def __init__(self, input_channels=16, start_channels=128, c_pool=32, c_head=32, c_val=64):
        super().__init__()
        width = start_channels

        self.start_conv = nn.Conv2d(
            input_channels, width, kernel_size=5, stride=1, padding=2, bias=False
        )

        self.ResBlock1 = PreActBlock(width, width, stride=1)
        self.ResBlock2 = PreActBlock(width, width, stride=1)
        self.ResBlock3 = PreActBlock(width, width, stride=1)
        self.GPBlock1 = GlobalPoolingBlock(width, width, c_pool)
        self.ResBlock4 = PreActBlock(width, width, stride=1)
        self.ResBlock5 = PreActBlock(width, width, stride=1)
        self.ResBlock6 = PreActBlock(width, width, stride=1)
        self.GPBlock2 = GlobalPoolingBlock(width, width, c_pool)

        self.last_bn = nn.BatchNorm2d(width)
        self.P_head = PolicyHead(width, c_head)
        self.V_head = ValueHead(width, c_head, c_val)

    def forward(self, board):
        """Run the trunk on ``board`` and return ``(policy_out, value_out)``.

        ``board`` is a (batch, input_channels, 9, 9) tensor; the 9x9 size is
        required by the linear layers inside the policy head. Each returned
        element is itself the tuple produced by the corresponding head.
        """
        out = self.start_conv(board)

        trunk = (
            self.ResBlock1,
            self.ResBlock2,
            self.ResBlock3,
            self.GPBlock1,
            self.ResBlock4,
            self.ResBlock5,
            self.ResBlock6,
            self.GPBlock2,
        )
        for layer in trunk:
            out = layer(out)

        out = F.relu(self.last_bn(out))

        policy_out = self.P_head(out)
        value_out = self.V_head(out)
        return policy_out, value_out

# Available configurations, keyed by (blocks B, channels C):
#   cfg              [5, 64]   [6, 96]   [8, 128]
#   blocks              5         6         8
#   start_channels     64        96       128
#   c_pool             16        32        32
#   c_head             16        32        32
#   c_val              32        48        64
def getModel(B,C) :
    """Return a freshly constructed network for ``B`` blocks and ``C`` channels.

    Supported (B, C) pairs are (5, 64), (6, 96) and (8, 128); any other pair
    raises ``KeyError``. Only the requested network is instantiated - the
    lookup table stores the classes, not instances.
    """
    key = f"b{B}c{C}"
    model_classes = {
        'b5c64': ResNetb5c64,
        'b6c96': ResNetb6c96,
        'b8c128': ResNetb8c128,
    }
    return model_classes[key]()

In [None]:
'''
YAI-quoridor/src/player.py
'''

import torch
import torch.nn as nn
import torch.nn.functional as F


class Player:
    """Drive one self-play game with one MCTS tree per colour.

    ``playout`` advances the current tree by one playout and returns the
    encoded position for the network; ``step`` feeds the network output back,
    plays a move once the playout budget is reached and, when the game is
    over, fills in the training targets of the collected samples.

    Each sample in ``self.data[colour]`` is a list
    ``[nn_input, Narr, opp_one_hot, outcome, path_len_diff]``; the last three
    entries stay ``None`` until the game ends.
    """

    def __init__(self, tparameters):
        """Create a fresh game; ``tparameters['player_no']`` is read as this
        player's id and then incremented for the next player."""
        self.game = Board()
        self.mcts = {0: MCTS(), 1: MCTS()}
        self.data = {0: [], 1: []}
        self.playout_num = 0
        self.playout_max = PLAYOUT_FULL if random.random() < 0.25 else PLAYOUT
        self.currentMCTS = 0
        self.currentBOARD_CNT = 0
        self.player_no = tparameters['player_no']
        tparameters['player_no'] += 1

    @staticmethod
    def _upsample_walls(blocks):
        """Resize an 8x8 wall grid to a 9x9 plane (bicubic, aligned corners)."""
        grid = torch.as_tensor(blocks, dtype=torch.get_default_dtype())
        grid = grid.reshape(1, 1, 8, 8)
        plane = F.interpolate(grid, size=(9, 9), mode='bicubic', align_corners=True)
        return plane[0, 0]

    def _write_snapshot(self, planes, board, pos_ch, remain_ch):
        """Write one board state into ``planes``.

        Channels ``pos_ch .. pos_ch+3`` receive pawn a, pawn b, horizontal
        walls and vertical walls; ``remain_ch`` and ``remain_ch+1`` receive
        the remaining wall counts of a and b as constant planes.
        """
        planes[0, pos_ch, board.a[0], board.a[1]] = 1
        planes[0, pos_ch + 1, board.b[0], board.b[1]] = 1
        planes[0, pos_ch + 2] = self._upsample_walls(board.blocks_horizon)
        planes[0, pos_ch + 3] = self._upsample_walls(board.blocks_vertical)
        planes[0, remain_ch] = board.a_remain
        planes[0, remain_ch + 1] = board.b_remain

    def toNNinput(self, board):
        """Encode ``board`` as a (1, 16, 9, 9) float tensor.

        Channel layout:
          0-3   current pawn a / pawn b / horizontal walls / vertical walls
          4-7   the same four planes one move earlier
          8-9   current remaining walls of a / b
          10-11 remaining walls one move earlier
          12-13 path lengths of a / b
          14    move counter
          15    colour to move

        At move 0 the "previous" planes repeat the current ones. Otherwise
        the board is temporarily rewound with ``board.undo()`` and the undone
        action is replayed afterwards.
        """
        NNinput = torch.zeros(1, 16, 9, 9)
        self._write_snapshot(NNinput, board, 0, 8)

        if board.count == 0:
            self._write_snapshot(NNinput, board, 4, 10)
        else:
            action = board.undo()
            try:
                self._write_snapshot(NNinput, board, 4, 10)
            finally:
                # Always replay the move so the caller's board is never left
                # one move behind, even if encoding fails.
                board.doAction(action)

        NNinput[0, 12] = board.a_path_len
        NNinput[0, 13] = board.b_path_len
        NNinput[0, 14] = board.count
        NNinput[0, 15] = board.colour
        return NNinput

    def playout(self):
        """Run one playout on the current tree and return the position to evaluate."""
        colour = self.currentMCTS
        if self.playout_num == 1:
            assert self.game.colour == self.currentMCTS, (
                f"new playout started : {self.playout_num}/{self.playout_max}, player,{self.player_no}"
            )
            # Noise is applied on full-budget searches, and on every search
            # during the first moves of the game to randomise openings.
            if self.playout_max == PLAYOUT_FULL or self.currentBOARD_CNT <= 4:
                self.mcts[self.game.colour].makeNoise()
        self.mcts[colour].playout(self.game)
        self.playout_num += 1
        return self.toNNinput(self.game)

    def _label_samples(self, colour):
        """Fill in the end-of-game targets for ``self.data[colour]``.

        A sample whose position has no recorded reply in ``game.history``
        (the game ended on that move) is dropped. The list object itself is
        kept and updated in place.
        """
        game = self.game
        history = game.history
        # Targets are expressed from the point of view of ``colour``.
        sign = 1 if colour == 0 else -1
        outcome = sign * (1 - 2 * game.winner)
        path_diff = sign * (game.b_path_len - game.a_path_len)

        kept = []
        for sample in self.data[colour]:
            cnt = int(sample[0][0, 14, 0, 0])
            if cnt + 1 >= len(history):
                continue
            opp = [0 for _ in range(140)]
            opp[history[cnt + 1]] = 1
            sample[2] = opp
            sample[3] = outcome
            sample[4] = path_diff
            kept.append(sample)
        self.data[colour][:] = kept

    def step(self, Parr, v):
        """Back-propagate a network evaluation; play a move when the budget is used up.

        Args:
            Parr: policy output for the position returned by ``playout``.
            v: value output for the same position.
        """
        colour = self.currentMCTS
        self.mcts[colour].backprop(self.game, Parr, v)

        if self.playout_num >= self.playout_max:
            action, Narr = self.mcts[colour].getAction()
            print(f"Next action chosen// player_no {self.player_no},\t\t[board.count = {self.game.count} -> {self.game.count + 1}],\
          \t\t[playout_num = {self.playout_num}],\t\t[doaction = {action}],\t\tNarr : {Narr}")
            # Only full-budget searches are recorded as training samples.
            if self.playout_max == PLAYOUT_FULL:
                self.data[colour].append([self.toNNinput(self.game), Narr, None, None, None])
            self.game.doAction(action)
            self.mcts[0].update(action)
            self.mcts[1].update(action)
            self.playout_num = 0
            self.playout_max = PLAYOUT_FULL if random.random() < 0.25 else PLAYOUT
            self.currentMCTS = 1 - self.currentMCTS
            self.currentBOARD_CNT += 1

        if self.game.gameend():
            print("====================================")
            print(f"player_no {self.player_no} reach gameend.  winner is {self.game.winner}")
            if any(self.data[0]) and any(self.data[1]):
                print(f"collected len(data[0]):{len(self.data[0])},  len(data[1]):{len(self.data[1])}")
            print("   Current opponent State : ")
            self.game.printboard()
            self._label_samples(0)
            self._label_samples(1)



In [None]:
'''
YAI-quoridor/src/board.py
'''
import heapq
from queue import PriorityQueue

class AstarNode():
    """Search-tree node for the path searches: a cell plus how it was reached.

    Attributes:
        x, y: cell coordinates.
        parent: the node this one was expanded from (``None`` for the start).
        direction: code of the move that led here (``-1`` for the start).

    Equality and ordering look at ``(x, y)`` only, which lets nodes act as
    tie-breakers inside ``(priority, node)`` heap entries. Comparing with a
    non-node returns ``NotImplemented`` instead of raising AttributeError.
    """

    def __init__(self, x, y, parent, direction):
        self.x = x
        self.y = y
        self.direction = direction
        self.parent = parent

    def __eq__(self, other):
        if not isinstance(other, AstarNode):
            return NotImplemented
        return self.x == other.x and self.y == other.y

    # Defining __eq__ already makes instances unhashable; say so explicitly.
    __hash__ = None

    def __lt__(self, other):
        if not isinstance(other, AstarNode):
            return NotImplemented
        return (self.x, self.y) < (other.x, other.y)

class Board:

    def __init__(self):
        #redo-able
        self.blocks_vertical = [[0 for i in range(8)] for i in range(8)]
        self.blocks_horizon = [[0 for i in range(8)] for i in range(8)]
        self.a = [8,4]
        self.b = [0,4]
        self.a_remain = 10
        self.b_remain = 10
        self.count = 0
        self.colour = 0
        self.winner = 0

        #redo-unable
        self.path_vertical = [[0 for i in range(8)] for i in range(8)]
        self.path_horizon = [[0 for i in range(8)] for i in range(8)]
        self.visited = [[0 for i in range(9)] for i in range(9)]
        self.a_path_len = 8
        self.b_path_len = 8

        self.history = []

    def blocks_vertical_get(self,x,y):
        if y <= -1 or y >= 8:
            return 1
        elif x <= -1 or x >= 8:
            return 0
        else:
            return self.blocks_vertical[x][y]

    def blocks_horizon_get(self,x,y):
        if x <= -1 or x >= 8:
            return 1
        elif y <= -1 or y >= 8:
            return 0
        else:
            return self.blocks_horizon[x][y]

    def blocks_vertical_record(self,x,y):
        if y <= -1 or y >= 8 or x <= -1 or x >= 8:
            return 
        else:
            self.path_vertical[x][y] = 1

    def blocks_horizon_record(self,x,y):
        if x <= -1 or x >= 8 or y <= -1 or y >= 8:
            return 
        else:
            self.path_horizon[x][y] = 1

    def dfs_a(self, x, y):
        self.visited[x][y] = 1
        if x == 0:
            return 0
        #위쪽
        if x != 0 and self.visited[x-1][y]==0 and self.blocks_horizon_get(x-1,y-1)==0 and self.blocks_horizon_get(x-1,y)==0:
            value = self.dfs_a(x-1, y)
            if value != -1:
                return value+1
        #오른쪽
        if y != 8 and self.visited[x][y+1]==0 and self.blocks_vertical_get(x,y)==0 and self.blocks_vertical_get(x-1,y)==0:
            value = self.dfs_a(x, y+1)
            if value != -1:
                return value+1
        #왼쪽
        if y != 0 and self.visited[x][y-1]==0 and self.blocks_vertical_get(x-1,y-1)==0 and self.blocks_vertical_get(x,y-1)==0:
            value = self.dfs_a(x, y-1)
            if value != -1:
                return value+1
        #아래쪽
        if x != 8 and self.visited[x+1][y]==0 and self.blocks_horizon_get(x,y)==0 and self.blocks_horizon_get(x,y-1)==0:
            value = self.dfs_a(x+1, y)
            if value != -1:
                return value+1
        return -1

    def dfs_b(self, x, y):
        self.visited[x][y] = 1
        if x == 8:
            return 0
        #아래쪽
        if x != 8 and self.visited[x+1][y]==0 and self.blocks_horizon_get(x,y)==0 and self.blocks_horizon_get(x,y-1)==0:
            value = self.dfs_b(x+1, y)
            if value != -1:
                return value+1
        #왼쪽
        if y != 0 and self.visited[x][y-1]==0 and self.blocks_vertical_get(x-1,y-1)==0 and self.blocks_vertical_get(x,y-1)==0:
            value = self.dfs_b(x, y-1)
            if value != -1:
                return value+1
        #오른쪽
        if y != 8 and self.visited[x][y+1]==0 and self.blocks_vertical_get(x,y)==0 and self.blocks_vertical_get(x-1,y)==0:
            value = self.dfs_b(x, y+1)
            if value != -1:
                return value+1
        #위쪽
        if x != 0 and self.visited[x-1][y]==0 and self.blocks_horizon_get(x-1,y-1)==0 and self.blocks_horizon_get(x-1,y)==0:
            value = self.dfs_b(x-1, y)
            if value != -1:
                return value+1
        return -1

    def dfs(self, colour):
        ret = -1
        if colour == 0:
            ret = self.dfs_a(self.a[0], self.a[1])
        elif colour == 1:
            ret = self.dfs_b(self.b[0], self.b[1])
        for i in range(9):
            for j in range(9):
                self.visited[i][j]=0
        return ret

    def Astar_a(self, x, y):
        """Best-first search from (x, y) to row 0 for player a.

        Heap entries are ``(h, node)``. The start is pushed with ``h = x``;
        stepping up keeps ``h``, stepping sideways adds 1 and stepping down
        adds 2, so ``h`` equals steps taken so far plus rows still to climb.

        Returns the ``h`` of the first node popped on row 0, which is the
        number of steps of the route that reached it. As side effects the
        wall slots crossed by that route are flagged in ``path_vertical`` /
        ``path_horizon`` (they are not cleared here), and ``visited`` is
        reset to zeros before returning.

        NOTE(review): there is no exit for an empty heap - if row 0 cannot
        be reached, ``heapq.heappop`` raises IndexError.
        """
        ret = 0
        firstnode = AstarNode(x, y, None, -1)
        finalnode = None
        # q = PriorityQueue()
        # q.put((x,firstnode))
        q = []
        heapq.heappush(q, (x,firstnode))
        while True:
            h, node = heapq.heappop(q)
            #h, node = q.get()
            x = node.x
            y = node.y
            # Cells are marked when popped, so one cell can sit in the heap
            # several times.
            self.visited[x][y] = 1

            if x == 0:
                ret = h
                finalnode = node
                break

            # up (direction 0): same priority
            if x != 0 and self.visited[x-1][y]==0 and self.blocks_horizon_get(x-1,y-1)==0 and self.blocks_horizon_get(x-1,y)==0:
                # q.put((h, AstarNode(x-1, y, node, 0)))
                heapq.heappush(q, (h, AstarNode(x-1, y, node, 0)))
            # right (direction 1): priority + 1
            if y != 8 and self.visited[x][y+1]==0 and self.blocks_vertical_get(x,y)==0 and self.blocks_vertical_get(x-1,y)==0:
                # q.put((h+1, AstarNode(x, y+1, node, 1)))
                heapq.heappush(q, (h+1, AstarNode(x, y+1, node, 1)))
            # left (direction 2): priority + 1
            if y != 0 and self.visited[x][y-1]==0 and self.blocks_vertical_get(x-1,y-1)==0 and self.blocks_vertical_get(x,y-1)==0:
                # q.put((h+1, AstarNode(x, y-1, node, 2)))
                heapq.heappush(q, (h+1, AstarNode(x, y-1, node, 2)))
            # down (direction 3): priority + 2
            if x != 8 and self.visited[x+1][y]==0 and self.blocks_horizon_get(x,y)==0 and self.blocks_horizon_get(x,y-1)==0:
                # q.put((h+2, AstarNode(x+1, y, node, 3)))
                heapq.heappush(q, (h+2, AstarNode(x+1, y, node, 3)))

        # Clear the scratch grid for the next search.
        for i in range(9):
            for j in range(9):
                self.visited[i][j]=0

        # Walk back from the goal to the start; for every step flag the two
        # wall slots lying on the cell edge that step crossed.
        node = finalnode
        while not node.parent is None:
            x = node.x
            y = node.y
            if node.direction == 3:
                self.blocks_horizon_record(x-1,y-1)
                self.blocks_horizon_record(x-1,y)
            elif node.direction == 2:
                self.blocks_vertical_record(x,y)
                self.blocks_vertical_record(x-1,y)
            elif node.direction == 1:
                self.blocks_vertical_record(x-1,y-1)
                self.blocks_vertical_record(x,y-1)
            elif node.direction == 0:
                self.blocks_horizon_record(x,y)
                self.blocks_horizon_record(x,y-1)

            node = node.parent

        return ret

    def Astar_b(self, x, y):
        """Best-first search from (x, y) to row 8 for player b.

        Heap entries are ``(h, node)``. The start is pushed with
        ``h = 8 - x``; stepping down keeps ``h``, stepping sideways adds 1
        and stepping up adds 2, so ``h`` equals steps taken so far plus rows
        still to descend.

        Returns the ``h`` of the first node popped on row 8, which is the
        number of steps of the route that reached it. As side effects the
        wall slots crossed by that route are flagged in ``path_vertical`` /
        ``path_horizon`` (they are not cleared here), and ``visited`` is
        reset to zeros before returning.

        NOTE(review): there is no exit for an empty heap - if row 8 cannot
        be reached, ``heapq.heappop`` raises IndexError.
        """
        #https://slowsure.tistory.com/130
        # use heapq instead of priorityqueue()
        ret = 0
        firstnode = AstarNode(x, y, None, -1)
        finalnode = None
        q = []
        heapq.heappush(q, (8-x,firstnode))
        #q = PriorityQueue()
        #q.put((8-x,firstnode))
        while True:
            h, node = heapq.heappop(q)
            #h, node = q.get()
            x = node.x
            y = node.y
            # Cells are marked when popped, so one cell can sit in the heap
            # several times.
            self.visited[x][y] = 1
            if x == 8:
                ret = h
                finalnode = node
                break

            # up (direction 0): priority + 2
            if x != 0 and self.visited[x-1][y]==0 and self.blocks_horizon_get(x-1,y-1)==0 and self.blocks_horizon_get(x-1,y)==0:
                #q.put((h+2, AstarNode(x-1, y, node, 0)))
                heapq.heappush(q, (h+2, AstarNode(x-1, y, node, 0)))
            # right (direction 1): priority + 1
            if y != 8 and self.visited[x][y+1]==0 and self.blocks_vertical_get(x,y)==0 and self.blocks_vertical_get(x-1,y)==0:
                #q.put((h+1, AstarNode(x, y+1, node, 1)))
                heapq.heappush(q, (h+1, AstarNode(x, y+1, node, 1)) )
            # left (direction 2): priority + 1
            if y != 0 and self.visited[x][y-1]==0 and self.blocks_vertical_get(x-1,y-1)==0 and self.blocks_vertical_get(x,y-1)==0:
                #q.put((h+1, AstarNode(x, y-1, node, 2)))
                heapq.heappush(q, (h+1, AstarNode(x, y-1, node, 2)) )
            # down (direction 3): same priority
            if x != 8 and self.visited[x+1][y]==0 and self.blocks_horizon_get(x,y)==0 and self.blocks_horizon_get(x,y-1)==0:
                #q.put((h, AstarNode(x+1, y, node, 3)))
                heapq.heappush(q, (h, AstarNode(x+1, y, node, 3)))

        # Clear the scratch grid for the next search.
        for i in range(9):
            for j in range(9):
                self.visited[i][j]=0

        # Walk back from the goal to the start; for every step flag the two
        # wall slots lying on the cell edge that step crossed.
        node = finalnode
        while not node.parent is None:
            x = node.x
            y = node.y
            if node.direction == 3:
                self.blocks_horizon_record(x-1,y-1)
                self.blocks_horizon_record(x-1,y)
            elif node.direction == 2:
                self.blocks_vertical_record(x,y)
                self.blocks_vertical_record(x-1,y)
            elif node.direction == 1:
                self.blocks_vertical_record(x-1,y-1)
                self.blocks_vertical_record(x,y-1)
            elif node.direction == 0:
                self.blocks_horizon_record(x,y)
                self.blocks_horizon_record(x,y-1)

            node = node.parent

        return ret

    def simulate_vertical(self, x, y):
        self.blocks_vertical[x][y] = 1
        ret = self.dfs(0) != -1 and self.dfs(1) != -1
        self.blocks_vertical[x][y] = 0
        return ret

    def simulate_horizon(self, x, y):
        self.blocks_horizon[x][y] = 1
        ret = self.dfs(0) != -1 and self.dfs(1) != -1
        self.blocks_horizon[x][y] = 0
        return ret

    def reset_path(self):  # c로 구현하면서 memset으로 변경
        for i in range(8):
            for j in range(8):
                self.path_vertical[i][j] = 0
        for i in range(8):
            for j in range(8):
                self.path_horizon[i][j] = 0
        self.a_path_len = self.Astar_a(self.a[0], self.a[1])
        self.b_path_len = self.Astar_b(self.b[0], self.b[1])
        
    def getAvailableAction(self):
        actions = []

        if (self.colour == 0 and self.a_remain != 0) or (self.colour == 1 and self.b_remain != 0):

            #세로
            for i in range(8):
                for j in range(8):
                    if self.blocks_vertical[i][j]==1 or self.blocks_horizon[i][j]==1 or self.blocks_vertical_get(i-1,j)==1 or self.blocks_vertical_get(i+1,j)==1:
                        continue
                    if self.path_vertical[i][j]==1:
                        num = 0
                        if self.blocks_horizon_get(i-1,j-1)==1 or self.blocks_horizon_get(i-1,j)==1 or self.blocks_horizon_get(i-1,j+1)==1 or self.blocks_vertical_get(i-2,j)==1:
                            num += 1
                        if self.blocks_horizon_get(i+1,j-1)==1 or self.blocks_horizon_get(i+1,j)==1 or self.blocks_horizon_get(i+1,j+1)==1 or self.blocks_vertical_get(i+2,j)==1:
                            num += 1
                        if self.blocks_horizon_get(i,j-1)==1 or self.blocks_horizon_get(i,j+1)==1:
                            num += 1
                        if num >= 2 and not self.simulate_vertical(i,j):
                            continue
                    actions.append(8*i+j)

            #가로
            for i in range(8):
                for j in range(8):
                    if self.blocks_vertical[i][j]==1 or self.blocks_horizon[i][j]==1 or self.blocks_horizon_get(i,j-1)==1 or self.blocks_horizon_get(i,j+1)==1:
                        continue
                    if self.path_horizon[i][j]==1:
                        num = 0
                        if self.blocks_vertical_get(i-1,j-1)==1 or self.blocks_vertical_get(i,j-1)==1 or self.blocks_vertical_get(i+1,j-1)==1 or self.blocks_horizon_get(i,j-2)==1:
                            num += 1
                        if self.blocks_vertical_get(i-1,j+1)==1 or self.blocks_vertical_get(i,j+1)==1 or self.blocks_vertical_get(i+1,j+1)==1 or self.blocks_horizon_get(i,j+2)==1:
                            num += 1
                        if self.blocks_vertical_get(i-1,j)==1 or self.blocks_vertical_get(i+1,j)==1:
                            num += 1
                        if num >= 2 and not self.simulate_horizon(i,j):
                            continue
                    actions.append(64+8*i+j)

        player = self.a if self.colour == 0 else self.b
        x = player[0]
        y = player[1]
        opponent = self.a if self.colour == 1 else self.b
        #위쪽
        if self.blocks_horizon_get(x-1,y-1)==0 and self.blocks_horizon_get(x-1,y)==0:
            if [x-1,y] == opponent:
                if self.blocks_horizon_get(x-2,y-1)==0 and self.blocks_horizon_get(x-2,y)==0:
                    actions.append(132)
                elif self.blocks_vertical_get(x-1,y-1)==0 and self.blocks_vertical_get(x-2,y-1)==0:
                    actions.append(137)
                elif self.blocks_vertical_get(x-1,y)==0 and self.blocks_vertical_get(x-2,y)==0:
                    actions.append(136)
            else:
                actions.append(128)
        #오른쪽
        if self.blocks_vertical_get(x,y)==0 and self.blocks_vertical_get(x-1,y)==0:
            if [x,y+1] == opponent:
                if self.blocks_vertical_get(x-1,y+1)==0 and self.blocks_vertical_get(x,y+1)==0:
                    actions.append(133)
                elif self.blocks_horizon_get(x-1,y+1)==0 and self.blocks_horizon_get(x-1,y)==0:
                    actions.append(136)
                elif self.blocks_horizon_get(x,y+1)==0 and self.blocks_horizon_get(x,y)==0:
                    actions.append(138)
            else:
                actions.append(129)
        #왼쪽
        if self.blocks_vertical_get(x-1,y-1)==0 and self.blocks_vertical_get(x,y-1)==0:
            if [x,y-1] == opponent:
                if self.blocks_vertical_get(x-1,y-2)==0 and self.blocks_vertical_get(x,y-2)==0:
                    actions.append(134)
                elif self.blocks_horizon_get(x-1,y-1)==0 and self.blocks_horizon_get(x-1,y-2)==0:
                    actions.append(137)
                elif self.blocks_horizon_get(x,y-2)==0 and self.blocks_horizon_get(x,y-1)==0:
                    actions.append(139)
            else:
                actions.append(130)
        #아래쪽
        if self.blocks_horizon_get(x,y)==0 and self.blocks_horizon_get(x,y-1)==0:
            if [x+1,y] == opponent:
                if self.blocks_horizon_get(x+1,y-1)==0 and self.blocks_horizon_get(x+1,y)==0:
                    actions.append(135)
                elif self.blocks_vertical_get(x,y-1)==0 and self.blocks_vertical_get(x+1,y-1)==0:
                    actions.append(139)
                elif self.blocks_vertical_get(x,y)==0 and self.blocks_vertical_get(x+1,y)==0:
                    actions.append(138)
            else:
                actions.append(131)

        return actions

    def doAction(self, num):
        """Apply action ``num`` for the side to move, then pass the turn.

        Action ranges handled here:
            0..63     set ``blocks_vertical[num // 8][num % 8]`` to 1
            64..127   set ``blocks_horizon[(num - 64) // 8][(num - 64) % 8]`` to 1
            128..139  shift the mover's pawn: 128-131 by one cell, 132-135 by
                      two cells in a straight line, 136-139 by one cell
                      diagonally

        A wall action also takes one wall from the mover's remaining count.
        No legality check is performed. Every call flips ``colour``,
        increments ``count`` and appends ``num`` to ``history``.
        """
        mover_is_a = self.colour == 0

        if num < 128:
            # Wall placement: the action range selects the grid, the offset
            # inside the range selects the slot (row-major, 8 per row).
            if num < 64:
                grid, slot = self.blocks_vertical, num
            else:
                grid, slot = self.blocks_horizon, num - 64
            row, col = divmod(slot, 8)
            grid[row][col] = 1
            if mover_is_a:
                self.a_remain -= 1
            else:
                self.b_remain -= 1
        else:
            pawn = self.a if mover_is_a else self.b
            # (row delta, column delta) per pawn action.
            shifts = {
                128: (-1, 0), 129: (0, 1), 130: (0, -1), 131: (1, 0),
                132: (-2, 0), 133: (0, 2), 134: (0, -2), 135: (2, 0),
                136: (-1, 1), 137: (-1, -1), 138: (1, 1), 139: (1, -1),
            }
            d_row, d_col = shifts.get(num, (0, 0))
            if d_row:
                pawn[0] += d_row
            if d_col:
                pawn[1] += d_col

        self.colour = 1 - self.colour
        self.count += 1

        self.history.append(num)

    def undo(self):
        """Revert the most recent action recorded in ``history`` and return it.

        The turn is handed back first (``colour`` flipped, ``count``
        decremented) so that the player who made the action is the one whose
        wall count or pawn position is restored. ``winner`` is reset to 0.
        """
        self.winner = 0
        self.colour = 1 - self.colour
        self.count -= 1

        num = self.history.pop()
        mover_is_a = self.colour == 0

        if num < 128:
            # Wall removal: clear the slot and give the wall back.
            if num < 64:
                grid, slot = self.blocks_vertical, num
            else:
                grid, slot = self.blocks_horizon, num - 64
            row, col = divmod(slot, 8)
            grid[row][col] = 0
            if mover_is_a:
                self.a_remain += 1
            else:
                self.b_remain += 1
            return num

        pawn = self.a if mover_is_a else self.b
        # Forward (row, column) shift of each pawn action; undoing subtracts it.
        shifts = {
            128: (-1, 0), 129: (0, 1), 130: (0, -1), 131: (1, 0),
            132: (-2, 0), 133: (0, 2), 134: (0, -2), 135: (2, 0),
            136: (-1, 1), 137: (-1, -1), 138: (1, 1), 139: (1, -1),
        }
        d_row, d_col = shifts.get(num, (0, 0))
        if d_row:
            pawn[0] -= d_row
        if d_col:
            pawn[1] -= d_col

        return num


    def printboard(self):
        """Print an ASCII rendering of the board followed by a status line.

        The drawing is a 17x17 lattice: even rows/columns are pawn cells, odd
        rows/columns are the gaps where walls are drawn. A wall stored at
        ``[r][c]`` is drawn across the two cells it touches plus the crossing
        point between them.
        """
        vertical = self.blocks_vertical
        horizon = self.blocks_horizon
        border = "  · · · · · · · · · ·"

        print()
        print("   0 1 2 3 4 5 6 7 8 ")
        print(border)
        for i in range(17):
            r = i // 2
            parts = []
            if i % 2 == 0:
                # Pawn row: row label, then cells alternating with vertical gaps.
                parts.append(f"{r}  ")
                for j in range(17):
                    c = j // 2
                    if j % 2 == 0:
                        if self.a == [r, c]:
                            parts.append("a")
                        elif self.b == [r, c]:
                            parts.append("b")
                        else:
                            parts.append(" ")
                    else:
                        # A vertical wall covers this row if it starts here
                        # or one row above; the guards skip the edge lookups.
                        walled = (i != 16 and vertical[r][c] == 1) or \
                                 (i != 0 and vertical[r - 1][c] == 1)
                        parts.append("┃" if walled else " ")
                parts.append(" ")
            else:
                # Gap row: horizontal wall segments and crossing points.
                parts.append("  ·")
                for j in range(17):
                    c = j // 2
                    if j % 2 == 0:
                        walled = (j != 16 and horizon[r][c] == 1) or \
                                 (j != 0 and horizon[r][c - 1] == 1)
                        parts.append("━" if walled else " ")
                    elif vertical[r][c] == 1:
                        parts.append("┃")
                    elif horizon[r][c] == 1:
                        parts.append("━")
                    else:
                        parts.append("·")
                parts.append("·")
            print("".join(parts))

        print(border)
        print(f"colour is {self.colour}, count is {self.count}, a_remain is {self.a_remain}, b_remain is {self.b_remain}")
        print()

    def gameend(self):
        """Return True when a pawn stands on its goal row, recording the winner.

        Pawn ``a`` wins (``winner = 0``) on row 0 and is checked first; pawn
        ``b`` wins (``winner = 1``) on row 8. ``winner`` is left untouched
        when nobody has arrived.
        """
        for pawn, goal_row, side in ((self.a, 0, 0), (self.b, 8, 1)):
            if pawn[0] == goal_row:
                self.winner = side
                return True
        return False

In [None]:
'''
YAI-quoridor/src/MCTS.py
'''

import random
import math
import numpy as np
# from board import Board

# Search hyper-parameters used by Node / MCTS below.
# Weight of the prior/visit-count exploration term in the PUCT score.
CPUCT = 5
# Concentration parameter handed to np.random.dirichlet in Node.makeNoise.
DIR_ALPHA = 0.15
# Share of the noise when it is blended into a child's prior (1-EPSILON stays prior).
EPSILON = 0.3
# Default temperature copied into MCTS.TEMP when an MCTS object is created.
TEMP = 0.8
# Subtracted from the negated parent Q to initialise a new child's Q in Node.expand.
FPU = 0.2
# Factor inside the square-root visit threshold of the forced-playout rule in Node.select.
FORCE_K = 2

class Node:
    """One position in the search tree together with its visit statistics.

    Attributes:
        parent   -- node one ply above, or ``None`` for a root
        childs   -- dict mapping action id -> child ``Node``
        expanded -- True once ``expand`` has created the children
        N        -- number of values backed up through this node
        N_forced -- how often ``select`` chose this node via the forced-playout rule
        Q        -- running mean of the backed-up values (``select`` ranks a
                    child by ``-child.Q``)
        P        -- prior probability assigned to the action leading here
    """

    def __init__(self, parent, Q, P):
        self.parent = parent
        self.childs = {}
        self.expanded = False
        self.N = 0
        self.N_forced = 0
        self.Q = Q
        self.P = P

    def expand(self, availables, Ps):
        """Create one child per action in ``availables``.

        ``Ps`` is indexed by action id and supplies each child's prior. A new
        child starts with ``Q = -self.Q - FPU``.
        """
        self.expanded = True
        for action in availables:
            child = Node(self, -self.Q-FPU, Ps[action])
            self.childs[action] = child

    def makeNoise(self):
        """Blend Dirichlet noise into the priors of the existing children."""
        # NOTE(review): the noise vector spans all 140 action slots, not only
        # the actions that have children, so part of its mass lands on slots
        # that are never read and the blended priors need not sum to 1.
        # Confirm this is intended.
        noise = np.random.dirichlet(DIR_ALPHA*np.ones(140))
        for action, child in self.childs.items():
            child.P = (1-EPSILON)*child.P+EPSILON*noise[action]

    def update(self, v):
        """Back up value ``v`` from this node to the root.

        Each node on the path gets one more visit and folds the value into its
        running mean; the sign flips at every step up. Implemented as a loop
        (not recursion) so the path length is not bounded by the interpreter's
        recursion limit.
        """
        node, value = self, v
        while node is not None:
            node.N += 1
            node.Q += (value-node.Q)/node.N
            node, value = node.parent, -value

    def select(self, forced_playout):
        """Pick the child to descend into and return ``(action, child)``.

        With ``forced_playout`` set, the first already-visited child whose
        visit count is below ``sqrt(FORCE_K * P * N_parent)`` is taken
        immediately and its ``N_forced`` is incremented. Otherwise the child
        with the highest PUCT score wins.

        Raises KeyError if the node has no children.
        """
        bestpuct = -100000
        bestaction = -1
        for action, child in self.childs.items():
            if forced_playout and child.N != 0 and child.N<math.sqrt(FORCE_K*child.P*self.N):
                bestaction = action
                child.N_forced += 1
                break
            else:
                puct = -child.Q + CPUCT*child.P*math.sqrt(self.N)/(child.N+1)
                if puct>bestpuct:
                    bestpuct = puct
                    bestaction = action
        return bestaction, self.childs[bestaction]


class MCTS:
    """Tree search driver working on a shared board and an external evaluator.

    One simulation is split in two calls: ``playout`` walks the tree down to
    an unexpanded node while applying the actions to the board, and
    ``backprop`` (called once the evaluator has produced priors and a value)
    expands/updates the tree and rewinds the board.
    """

    def __init__(self):
        self.rootnode = Node(None, 0, 0)
        self.searchingnode = self.rootnode
        self.searching_depth = 0

        self.getAction_cnt = 0   # number of getAction calls so far; drives the temperature schedule
        self.TEMP = TEMP         # current temperature (modified by getAction)

    def makeNoise(self):
        """Add exploration noise to the priors of the root's children."""
        self.rootnode.makeNoise()

    def playout(self, board):
        """Descend from the root to the first unexpanded node.

        Every selected action is applied to ``board`` and counted in
        ``searching_depth`` so ``backprop`` can undo them. Forced playouts are
        only requested for the first step (depth 0).
        """
        self.searchingnode = self.rootnode
        while self.searchingnode.expanded:
            action, node = self.searchingnode.select(self.searching_depth == 0)
            board.doAction(action)
            self.searchingnode = node
            self.searching_depth += 1
        board.reset_path()

    def backprop(self, board, Ps, V):
        """Finish the simulation started by ``playout``.

        If the board is terminal the backed-up value is +1 when
        ``board.winner == board.colour`` and -1 otherwise; else the node is
        expanded with priors ``Ps`` and ``V`` is backed up. The board is then
        rewound by ``searching_depth`` moves.
        """
        # Queried before the terminal test, as in the original call order.
        availables = board.getAvailableAction()
        v = V

        if board.gameend():
            if board.winner == board.colour:
                v = 1
            else:
                v = -1
        else:
            self.searchingnode.expand(availables, Ps)

        self.searchingnode.update(v)
        for _ in range(self.searching_depth):
            board.undo()
        self.searching_depth = 0

    def _root_puct(self, child):
        """PUCT score of a root child under the root's current visit count."""
        return (-child.Q) + CPUCT*child.P*math.sqrt(self.rootnode.N)/(1+child.N)

    def getAction(self, deterministic=False):
        """Choose a move at the root and return ``(action, Narr)``.

        ``Narr`` is a 140-long list holding the temperature-scaled, normalised
        visit distribution over the root's children (0 elsewhere). With
        ``deterministic`` the most visited child is returned; otherwise the
        action is sampled from ``Narr``.

        Steps:
          1. Policy target pruning (arXiv:1902.10565): for every child except
             the most visited one, forced visits are taken back one at a time
             as long as the child's PUCT stays below the best child's PUCT.
          2. Temperature schedule: reset on the first call (0.8 for sampling,
             0.5 for deterministic play), halved on every 9th call, floored at
             0.2 / 0.1 respectively.
          3. Visit counts are raised to ``1 / self.TEMP`` and normalised.

        Raises KeyError if the root has no children.
        """
        self.getAction_cnt += 1
        root = self.rootnode
        children = root.childs

        # Most visited child (first one wins ties).
        bestN = -1
        bestaction = -1
        for action, child in children.items():
            if child.N > bestN:
                bestN = child.N
                bestaction = action
        bestchild = children[bestaction]
        bestpuct = self._root_puct(bestchild)

        # [Policy Target Pruning step]
        # Remove a forced visit, then check; if the removal lifted the child
        # to or above the best PUCT it went one step too far, so that single
        # visit is handed back (to the child AND to the root) and we stop.
        for child in children.values():
            if child is bestchild:
                continue
            for _ in range(child.N_forced):
                child.N -= 1
                root.N -= 1
                if self._root_puct(child) >= bestpuct:
                    child.N += 1
                    root.N += 1
                    break
        selectaction = bestaction

        # Training temperature vs. real-game (gating) temperature.
        if not deterministic:
            start_temp, min_temp = 0.8, 0.2
        else:
            start_temp, min_temp = 0.5, 0.1
        if self.getAction_cnt == 1:
            self.TEMP = start_temp
        elif self.getAction_cnt % 9 == 0:
            self.TEMP = max(self.TEMP / 2, min_temp)

        # Use the scheduled temperature (self.TEMP), not the module default.
        Narr = [0 for _ in range(140)]
        sum_N = 0
        for action, child in children.items():
            Narr[action] = math.pow(child.N, 1/self.TEMP)
            sum_N += Narr[action]
        if sum_N > 0:
            for action in children:
                Narr[action] = Narr[action]/sum_N
        else:
            # No child has been visited yet: fall back to a uniform
            # distribution instead of dividing by zero.
            uniform = 1.0/len(children)
            for action in children:
                Narr[action] = uniform

        if not deterministic:
            randnum = random.random()
            sum_ = 0.0
            for action in children:
                sum_ += Narr[action]
                if sum_ > randnum:
                    selectaction = action
                    break

        self.rootnode.parent = None

        return selectaction, Narr

    def update(self, action):
        """Advance the root along ``action``.

        The matching subtree is reused when it exists; otherwise a fresh,
        unexpanded root is created with the negated Q of the old root.
        """
        if not self.rootnode.expanded or action not in self.rootnode.childs:
            child = Node(None, -self.rootnode.Q, 0)
            self.rootnode = child
        else:
            self.rootnode = self.rootnode.childs[action]
            self.rootnode.parent = None

    def _root_stat(self, attr):
        """Return a 140-long list with ``attr`` of each root child at its action index."""
        values = [0 for _ in range(140)]
        for action, child in self.rootnode.childs.items():
            values[action] = getattr(child, attr)
        return values

    @staticmethod
    def _print_wall_grid(title, values, offset):
        """Print the 8x8 block of ``values`` starting at ``offset`` as a framed grid."""
        print(title)
        print("    0    1    2    3    4    5    6    7    8 ")
        print("  ┌────────────────────────────────────────────┐")
        for i in range(17):
            if i % 2 == 0:
                print(i // 2, end=' │')
                for j in range(17):
                    if j % 2 == 0:
                        if i < 16 and j < 16:
                            print("{:^4}".format(values[offset+8*(i//2)+(j//2)]), end='')
                        else:
                            print("    ", end='')
                    else:
                        print(" ", end='')
                print("│")
            else:
                print(" ", end=' │   ')
                for j in range(17):
                    if j % 2 == 0:
                        print(" ", end='')
                    else:
                        print("┼   ", end='')

                print("│")

        print("  └────────────────────────────────────────────┘")

    @staticmethod
    def _print_move_table(values):
        """Print the twelve pawn-move entries (indices 128..139) of ``values``."""
        print(f" ↑:{values[128]},  →:{values[129]},  ←:{values[130]},  ↓:{values[131]}")
        print(f"2↑:{values[132]}, 2→:{values[133]}, 2←:{values[134]}, 2↓:{values[135]}")
        print(f"↗:{values[136]}, ↖:{values[137]}, ↘:{values[138]}, ↙:{values[139]}")

    def visualize(self):
        """Print the visit counts of all root children (walls as grids, moves as a table)."""
        Ns = self._root_stat('N')

        print("Visualize N_arr")
        print()
        self._print_wall_grid("vertical", Ns, 0)
        self._print_wall_grid("horizon", Ns, 64)
        print()
        self._print_move_table(Ns)
        print()

    def visualize_Q(self):
        """Print visit counts and Q values of the pawn-move children of the root."""
        print("visualize movement's  Narr and Q")
        Ns = self._root_stat('N')
        Qs = self._root_stat('Q')
        print("child.N")
        self._print_move_table(Ns)
        print()
        print("child.Q")
        self._print_move_table(Qs)
        print()


In [None]:
from torch.utils.data import DataLoader
from torch import nn
from math import floor
from math import ceil
from tqdm import tqdm
# from src import Model

# Device string for models and batches: the GPU when CUDA is available, else the CPU.
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'


class Board_Dataset(torch.utils.data.Dataset):
    """Replay-buffer dataset built from the newest self-play packages on disk.

    Packages are ``*.pkl`` files named like
    ``package-00001[1500,b6c96].pkl`` (index, sample count, model tag). Each
    one unpickles to a list of samples; a sample is indexed as
    ``(s, pi, opp, z, score)``.
    """

    def __init__(self, data_root = './data',  hparameters= None, transforms_=None):
        """Load packages newest-first until the sample window is filled.

        Args:
            data_root: directory containing the ``*.pkl`` packages.
            hparameters: optional dict; its ``'DATASET_WINDOW'`` entry sets how
                many samples to collect before stopping (default 250000).
            transforms_: accepted for interface compatibility; not used.

        Raises:
            ValueError: if a package file name does not contain exactly four
                digit groups.
        """
        window = 250000
        if hparameters:
            window = hparameters.get('DATASET_WINDOW', window)

        self.root_folder = sorted(glob.glob(data_root + '/*.pkl'))
        self.Replay_buffer = []

        # Walk the sorted package list from the newest file backwards. The
        # sample count comes from the file NAME only, so digits in the
        # directory part of the path cannot break the parse.
        samples_sum = 0
        for package_path in reversed(self.root_folder):
            _, samples_num, _, _ = re.findall(r'\d+', os.path.basename(package_path))
            samples_sum += int(samples_num)
            # NOTE: pickle executes code on load; only read packages produced
            # by this project.
            with open(package_path, 'rb') as f:
                self.Replay_buffer += pickle.load(f)
            if samples_sum >= window:
                break

        self.memory_s = [sample[0] for sample in self.Replay_buffer]
        self.memory_pi = [sample[1] for sample in self.Replay_buffer]
        self.memory_opp = [sample[2] for sample in self.Replay_buffer]
        self.memory_z = [sample[3] for sample in self.Replay_buffer]
        self.memory_score = [sample[4] for sample in self.Replay_buffer]
        print("=========================================================================")
        print("board_dataset __init__ , data len = ", len(self.memory_s),"samples")
        print("=========================================================================")

    def __getitem__(self, index):
        """Return sample ``index`` as a dict with keys 's', 'p', 'opp', 'z', 'score'.

        's' is the stored state with its leading axis squeezed, 'z' is 1 when
        the stored outcome equals 1 and 0 otherwise, and 'score' is a 41-way
        one-hot vector of the stored score clamped to [-20, 20] and shifted
        by +20.
        """
        squez_s = self.memory_s[index].squeeze(0)

        one_hot_z = 1 if self.memory_z[index] == 1 else 0

        score = self.memory_score[index]
        score = max(-20, min(20, score)) + 20
        one_hot_score = torch.zeros(41)
        one_hot_score[score] = 1

        return {'s' : squez_s,
                'p' : torch.Tensor(self.memory_pi[index]),
                'opp' : torch.Tensor(self.memory_opp[index]),
                'z' : one_hot_z,
                'score' : one_hot_score}

    def __len__(self):
        """Number of samples currently loaded."""
        return len(self.memory_s)


class Trainer:
    """Retrains the dual network on recent self-play data and saves checkpoints.

    The trainer reads its configuration from the module-level ``HPARAMS`` and
    ``TPARAMS`` dictionaries. After ``train`` has run, ``self.dualnet`` and
    ``self.optimizer`` refer to the freshly trained model and its optimizer.
    """

    def __init__(self):
        self.dualnet = TPARAMS['dual_net']
        self.model_ver = TPARAMS['model_ver']
        self.optimizer = TPARAMS['optimizer']

        print("=========================================================================")
        print("trainer __init__, model ver = ", self.model_ver)
        print("=========================================================================")

    def train(self, B, C):
        """Load the newest checkpoint and train it on the newest data packages.

        ``B`` and ``C`` are accepted for interface compatibility; the network
        size is taken from the checkpoint file name by ``load_checkpoint``.

        Returns:
            dict with ``'loss_sum'`` (mean total loss per optimisation step
            over the whole run) and the individual loss terms of the LAST
            batch as detached tensors.

        Raises:
            RuntimeError: if the data on disk does not fill a single batch.
        """
        # Take the newest model from ./models and train it on the most recent data.
        self.dualnet = self.load_checkpoint().to(DEVICE)
        self.dualnet.train()
        self.optimizer = optim.SGD(self.dualnet.parameters(),
                                   lr=HPARAMS['LR'], momentum=HPARAMS['MOMENTUM_DECAY'],
                                   weight_decay=HPARAMS['C_L2'])
        # NOTE(review): TPARAMS['dual_net'] is not replaced here, so code that
        # reads it keeps using the previous network - confirm the caller
        # reloads the saved checkpoint.
        TPARAMS['train_dataset'] = Board_Dataset('./data', hparameters = HPARAMS)
        TPARAMS['train_loader'] = torch.utils.data.DataLoader(TPARAMS['train_dataset'],
                                        batch_size=HPARAMS['BATCH_SIZE'],
                                        shuffle=True,
                                        drop_last=True)
        train_loader = TPARAMS['train_loader']
        batch_size = HPARAMS['BATCH_SIZE']
        C_g = HPARAMS['C_g']
        W_opp = HPARAMS['W_OPP']
        W_spdf = HPARAMS['W_SPDF']
        W_scdf = HPARAMS['W_SCDF']
        W_sbreg = HPARAMS['W_SBREG']
        W_scale = HPARAMS['W_SCALE']

        # drop_last=True discards a partial batch, so too little data means
        # zero steps; fail with a clear message instead of a ZeroDivisionError
        # or NameError further down.
        if len(train_loader) == 0:
            self.dualnet.eval()
            raise RuntimeError(
                f"not enough training data: {len(TPARAMS['train_dataset'])} samples "
                f"do not fill one batch of {batch_size}")

        n_epoches = min(ceil(250000/len(TPARAMS['train_dataset'])) , 10 )
        print("+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++")
        print(f"training start. training dataset len: {len(TPARAMS['train_dataset'])}samples, Total Epoches : {n_epoches}" )
        print("If dataset >= 250000 samples, Total Epoches = 1  Else  Total Epoches = min[ 10, 250000/len(dataset) ]")
        print("+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++")
        training_start_time = time.time()
        result = {}
        result['loss_sum'] = 0
        steps = 0
        for epoch in range(n_epoches):
            print(f"Processing epoch {epoch+1}/{n_epoches}.. ")
            for i_batch, batch in enumerate(tqdm(train_loader)):
                self.optimizer.zero_grad()

                s =   batch['s'].to(DEVICE)
                p =   batch['p'].to(DEVICE)
                opp = batch['opp'].to(DEVICE)
                z =   batch['z'].unsqueeze(-1).to(DEVICE)
                score = batch['score'].to(DEVICE)

                ps_, vs_ = self.dualnet(s)
                p_, opp_  = ps_[0], ps_[1]
                v_, score_ = vs_[0], vs_[1]
                v_ = torch.sigmoid(v_)

                # Game outcome value loss, Policy loss:
                # NOTE(review): the value term only contains z*log(v); samples
                # with z == 0 contribute nothing (no (1-z)*log(1-v) term).
                # Left unchanged - confirm this is the intended objective.
                Zloss = C_g * (-torch.sum(torch.log(v_)*z)/batch_size)
                Ploss = -torch.sum(torch.log(p_)*p)/batch_size

                # Auxiliary Policy Targets loss (Opponent policy loss) :
                OPPloss = W_opp * (-torch.sum(torch.log(opp_)*opp)/batch_size)

                # Score belief loss ("pdf"):
                SBPloss = W_spdf * (-torch.sum(torch.log(score_)*score)/batch_size)

                # Score belief loss ("cdf"): squared difference of the two
                # cumulative distributions, summed over every score bucket
                # (the last bucket contributes nothing once both sum to 1).
                SBCloss = 0
                for CDF_x in range(score.shape[1]):
                    SBCloss += torch.pow(torch.sum(score[:,:CDF_x+1].clone()- score_[:,:CDF_x+1],dim= 1,keepdim=True), 2.0)
                SBCloss = W_scdf * (torch.sum(SBCloss)/batch_size)

                # Score belief mean self-prediction:
                M_p = 0
                M_phat = 0
                for PDF_x in range(score.shape[1]):  # bucket 0..40 <-> score -20..20
                    M_p += score[:,PDF_x].clone() * (PDF_x - 20)
                    M_phat += score_[:,PDF_x] * (PDF_x - 20)
                HUMloss = nn.HuberLoss(reduction='mean', delta=10.0)(M_p.clone(), M_phat)
                HUMloss = W_sbreg * HUMloss

                # Score belief standard deviation self-prediction:
                SIGMA_p = 0
                SIGMA_phat = 0
                for PDF_x in range(score.shape[1]):
                    # Zero for a one-hot target: all mass sits on the mean.
                    SIGMA_p += score[:,PDF_x].clone() * torch.pow((PDF_x - 20  - M_p.clone()), 2.0)
                    SIGMA_phat += score_[:,PDF_x] * torch.pow((PDF_x - 20  - M_phat), 2.0)
                SIGMA_phat = torch.sqrt(SIGMA_phat)
                HUSloss = nn.HuberLoss(reduction='mean', delta=10.0)(SIGMA_p.clone(), SIGMA_phat)
                HUSloss = W_sbreg * HUSloss

                # Score belief scaling penalty. It must be taken on the model
                # being optimised (self.dualnet), otherwise it has no gradient
                # effect on the trained weights.
                V_head_scaling_params = []
                for param in self.dualnet.V_head.scaling_component.parameters():
                    V_head_scaling_params.append(param.view(-1))
                Params = torch.cat(V_head_scaling_params)
                Penalty_Score = W_scale * (torch.sum(torch.pow(Params,2.0)))

                loss = Zloss + Ploss + OPPloss + SBPloss + SBCloss + HUMloss + HUSloss + Penalty_Score
                loss.backward()
                result['loss_sum'] += loss.item()
                self.optimizer.step()
                steps += 1

        # Average over every step actually taken (all epochs), and log the
        # last-batch terms without their autograd graphs.
        result['loss_sum'] = result['loss_sum']/steps
        result['Zloss'] = Zloss.detach()
        result['Ploss'] = Ploss.detach()
        result['OPPloss'] = OPPloss.detach()
        result['SBPloss'] = SBPloss.detach()
        result['SBCloss'] = SBCloss.detach()
        result['HUMloss'] = HUMloss.detach()
        result['HUSloss'] = HUSloss.detach()
        result['Penalty_Score'] = Penalty_Score.detach()

        self.dualnet.eval()
        print("+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++")
        print(f"training ended. ")
        print(f"    training time :{floor((time.time() - training_start_time)/60)} min {round(time.time() - training_start_time) % 60} secs")
        return result

    def save_checkpoint(self, result = None, folder = "./models", filename = None) :
        """Write the trainer's model and optimizer state to ``folder``.

        The trainer's own ``dualnet`` / ``optimizer`` are saved (these are the
        objects ``train`` optimised); the ``TPARAMS`` entries are only used as
        a fallback when the trainer holds none.

        Returns the incremented ``TPARAMS['model_ver']``.
        """
        recent_package_idx = self.check_data()
        name = filename
        if filename is None :
            # NOTE(review): the name uses the version number BEFORE the
            # increment below, i.e. the number parsed from the newest existing
            # checkpoint - confirm overwriting that file is intended.
            name = f"dual_net[{TPARAMS['B']},{TPARAMS['C']}]-{TPARAMS['model_ver']:05d}.pth"
        model = self.dualnet if self.dualnet is not None else TPARAMS['dual_net']
        optimizer = self.optimizer if self.optimizer is not None else TPARAMS['optimizer']
        torch.save({'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'recent_package_idx': recent_package_idx,
                    'Training Loss log' : result,
                    'HPARAMS' : HPARAMS,
                    'comment' : f"ResNetb{TPARAMS['B']}c{TPARAMS['C']}",  #"ResNetb5c64"
                    }, f"{folder}/{name}")

        print(f"    model saved   {folder}/{name}")
        TPARAMS['model_ver'] += 1
        return TPARAMS['model_ver']

    def load_checkpoint(self, folder = "./models", filename = None) :
        """Build a model from the lexicographically last ``*.pth`` in ``folder``.

        The block/channel sizes and the version are parsed from the file name
        (e.g. ``dual_net[6,96]-00001.pth``). ``filename`` is currently unused.

        Raises:
            FileNotFoundError: if ``folder`` holds no ``*.pth`` file.
        """
        checkpoints = glob.glob(f'{folder}/*.pth')
        if not checkpoints:
            raise FileNotFoundError(f"no model checkpoint (*.pth) found in {folder}")
        recent_model = max(checkpoints)
        # Parse the file name only, so digits in the folder path are ignored.
        B, C, ver =  re.findall(r'\d+', os.path.basename(recent_model))

        print(f"LOADING the most recent model checkpoint.. dual_net[b{B}c{C}]-{ver}")
        model = getModel(int(B),int(C))
        # map_location lets a checkpoint written on a GPU load on a CPU-only host.
        checkpoint = torch.load(recent_model, map_location=DEVICE)
        model.load_state_dict(checkpoint['model_state_dict'])
        return model

    def update(self,B,C):
        """Train, save the result as a new checkpoint and print a summary."""
        result = self.train(B,C)
        print(f"Updating model...")
        new_model_ver = self.save_checkpoint(result)
        print(f"   model Version updated : {new_model_ver-1} -> {new_model_ver}")
        print(result)
        print("+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++")

    def check_data(self,):
        """Return the index of the newest data package in ./data.

        The value is the digit string from the file name (e.g. ``'00012'``),
        or the integer 0 when there is no package.
        """
        data_list = glob.glob('./data/*.pkl')
        recent_package_idx = 0
        if data_list :
            recent_package_idx, _, _, _ = re.findall(r'\d+', os.path.basename(max(data_list)))
        return recent_package_idx

    def step(self,B,C):
        """Run a model update when (newest package index + 1) is a multiple of 250."""
        recent_package_idx = self.check_data()
        print("current data idx:", recent_package_idx)
        if (int(recent_package_idx)+1) % 250 == 0 : # roughly every 250k samples
            self.update(B,C)
        else :
            print(f"not enough data. next model update is when (current data idx)%250 ==0")



In [None]:
import torch
import os
import time
import datetime
from math import floor
from torch import nn
from torch import optim
from torchsummary import summary
# from src import board
# from src import mcts
import numpy as np
import glob
import re
import pickle
from multiprocessing import Process, Queue

# Number CUDA devices in PCI bus order and expose only device "0" to this process.
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "0"    
NOTES = 'Qzero'
# Timestamp of this run, formatted YYYYMMDD-HHMMSS.
START_DATE = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
# Hyper-parameters, looked up by key elsewhere in this file.
HPARAMS = {
    #player
    'PLAYOUT_N_FULL' : 600,
    'PLAYOUT_n_SMALL' : 100, 
    #trainer
    'GAME_BATCH_NUM' : 100,
    'BATCH_SIZE': 256,
    'INFERENCE_BATCH_NUM' : 16,      ##############player num
    'DATA_TRAIN_NUM': 250000,
    'DATASET_WINDOW' : 250000,
    #optimizer
    'MOMENTUM_DECAY': 0.9,
    'LR': 6e-5,
    'C_L2'  : 3e-5,       
    #loss
    'C_g'   : 1.5,
    'W_OPP' : 0.15,
    'W_SPDF' :0.02,
    'W_SCDF' :0.02,
    'W_SBREG':0.004,
    'W_SCALE' : 0.0005,
    #gating
    'GATING_weight_average_snapshots_no' : 4,
    'GATING_weight_average_decay' : 0.75,
    #'NUM_WORKERS': 16,  
}
# Module-level shortcuts for the two playout budgets.
PLAYOUT_FULL = HPARAMS['PLAYOUT_N_FULL']
PLAYOUT = HPARAMS['PLAYOUT_n_SMALL']

# Mutable run state shared across the file; the None entries are placeholders
# that are assigned at runtime.
TPARAMS = {
    'training_date' : START_DATE,
    'model_ver' : None,
    'B' : None,
    'C' : None,    
    'dual_net' : None,
    'models' : None,

    'optimizer' : None,
    'train_dataset' : None,
    'train_loader' : None,
    'trainer' : None,

    'games_count' : None,
    'player_no' : 0 ,
}

def save_data(data,B,C,ver):
    """Pickle ``data`` as the next numbered package under ./data.

    The new file is named ``package-<idx>[<len(data)>,b<B>c<C>].pkl`` where
    ``idx`` is one more than the index of the lexicographically last existing
    package (1 when there is none).

    Args:
        data: non-empty sequence of samples.
        B, C: model size tags written into the file name.
        ver: model version, used only in the log line.

    Returns:
        The index given to the new package.

    Raises:
        SystemExit: with status 1 when ``data`` is empty.
    """
    data_list = glob.glob('./data/*.pkl')                            # ['./data/package-00001[1500,b6c96].pkl', './data/package-00002[1500,b6c96].pkl',]
    recent_package_idx = 0
    if data_list :
        # Parse the file name only, so digits in the directory part are ignored.
        recent_package_idx, _, _, _ = re.findall(r'\d+', os.path.basename(max(data_list)))
    new_idx = int(recent_package_idx) + 1

    # Test for emptiness itself (not the truthiness of the elements), and
    # report the failure with a non-zero exit status.
    if len(data) == 0 :
        print("Fatal Error, databuffer is empty list !")
        raise SystemExit(1)
    samples_num = len(data)

    model_type = f"b{B}c{C}"
    filename = f'./data/package-{new_idx:05d}[{samples_num},{model_type}].pkl'
    with open(filename,'wb') as f:
        pickle.dump(data, f, pickle.HIGHEST_PROTOCOL)

    print(f"Saved data : {filename}, made by resnet{model_type}-{ver}")
    return new_idx

def _split_min_sec(seconds):
    """Return ``(minutes, seconds)`` for a duration rounded to whole seconds."""
    return divmod(int(round(seconds)), 60)


def main():
    """Run the self-play loop using the newest checkpoint in ``./models``.

    Setup: loads the lexicographically last ``./models/*.pth`` checkpoint into
    the network returned by ``getModel``, switches it to eval mode, and stores
    the network, a ``Trainer`` and an SGD optimizer in ``TPARAMS``. It then
    creates ``HPARAMS['INFERENCE_BATCH_NUM']`` ``Player`` objects.

    Loop (runs until interrupted): every iteration concatenates one
    ``playout()`` tensor per player into a batch, evaluates the network once,
    and hands each player its row of the policy output and its value via
    ``step``. Games whose ``game.count`` reaches 150 are discarded and
    restarted. Finished games append their samples to an in-memory buffer,
    which is written with ``save_data`` and cleared once it holds at least
    1000 samples. ``Trainer.step`` is called once each time the finished-game
    total reaches a multiple of 1000.

    Raises:
        FileNotFoundError: If no ``./models/*.pth`` checkpoint exists.
        ValueError: If the checkpoint file name does not contain exactly
            three digit runs (B, C, version).
    """
    TPARAMS['models'] = sorted(glob.glob('./models/*.pth'))       # ['./models/dual_net[6,96]-0001.pth', './models/dual_net[6,96]-0002.pth',]
    if not TPARAMS['models']:
        raise FileNotFoundError("no checkpoint matching './models/*.pth' was found")
    latest_checkpoint = TPARAMS['models'][-1]

    # B, C and ver stay strings (e.g. '6', '96', '0001'); they are used that way
    # in log messages, getModel() and save_data(). TPARAMS gets the int values.
    B, C, ver = re.findall(r'\d+', latest_checkpoint)
    TPARAMS['B'], TPARAMS['C'], TPARAMS['model_ver'] = int(B), int(C), int(ver)

    TPARAMS['dual_net'] = getModel(B,C)
    # map_location lets a checkpoint saved on a GPU load when DEVICE is 'cpu'.
    checkpoint = torch.load(latest_checkpoint, map_location=DEVICE)
    TPARAMS['dual_net'].load_state_dict(checkpoint['model_state_dict'])
    TPARAMS['dual_net'] = TPARAMS['dual_net'].to(DEVICE)
    TPARAMS['dual_net'].eval()

    TPARAMS['trainer'] = Trainer()
    TPARAMS['optimizer'] = optim.SGD(
        TPARAMS['dual_net'].parameters(),
        lr=HPARAMS['LR'],
        momentum=HPARAMS['MOMENTUM_DECAY'],
        weight_decay=HPARAMS['C_L2'],
    )

    start_time = time.time()
    inference_total_time = 0

    playoutCount = 0                          # total # of playouts (tally only; not read in this function)
    data_buffer = []                          # Replay Buffer
    TPARAMS['games_count'] = 0                # total number of games played
    player_num = HPARAMS['INFERENCE_BATCH_NUM']
    TPARAMS['players'] = [Player(TPARAMS) for _ in range(player_num)]
    print("=========================================================================")
    print(f"selfplay init. dual_net[{B},{C}], model ver: {ver}, {HPARAMS['INFERENCE_BATCH_NUM']} players,  {DEVICE}")
    print("=========================================================================")
    while True :
        batch = [player.playout() for player in TPARAMS['players']]
        batch = torch.cat(batch, dim=0).float().to(DEVICE)

        # The device-to-host copies are inside the timed region: CUDA kernels
        # run asynchronously, so the forward pass is only known to be finished
        # once its results have been copied back.
        delta_time = time.time()
        with torch.no_grad():
            output1, output2 = TPARAMS['dual_net'](batch)
            p, opp = output1
            v, score = output2
            # One transfer per output tensor instead of one per player.
            policy_cpu = p.cpu().numpy()
            value_cpu = v.cpu()
        inference_total_time += time.time()-delta_time
        Ps = [policy_cpu[i] for i in range(player_num)]
        Vs = [value_cpu[i].item() for i in range(player_num)]

        train_due = False
        for i, player in enumerate(TPARAMS['players']):
            playoutCount += 1
            player.step(Ps[i],Vs[i])

            # Over-long games are thrown away (not counted, no samples kept).
            if player.game.count>=150 :
                print("====================================")
                print(f"player_no {player.player_no} exceed MAX_EPISODE_DURATION.  Abort game")
                print("   Current opponent State : ")
                player.game.printboard()
                TPARAMS['players'][i] = Player(TPARAMS)
                continue

            if not player.game.gameend():
                continue

            TPARAMS['games_count'] += 1
            data_buffer += player.data[1]
            data_buffer += player.data[0]
            TPARAMS['players'][i] = Player(TPARAMS)
            # Schedule training exactly once when the total number of finished
            # games reaches a multiple of 1000.
            if TPARAMS['games_count'] % 1000 == 0:
                train_due = True

            print("==============gameend===============  gamecount = ",TPARAMS['games_count'] )
            elapsed_total_time = time.time() - start_time
            MCTS_time = round(elapsed_total_time - inference_total_time, 2)
            average_game_time = round(elapsed_total_time / TPARAMS['games_count'])
            if data_buffer:
                print(f"data saved --> Replay_Buffer now has {len(data_buffer)} samples, Replay_Buffer[-1] =  [{data_buffer[-1][0].shape}, {len(data_buffer[-1][1])}, {len(data_buffer[-1][2])}, {data_buffer[-1][3]}, {data_buffer[-1][4]}]")
            else :
                print(f"fatal error : no data ")
                # NOTE(review): this only leaves the per-player loop; the outer
                # self-play loop keeps running. Confirm whether a hard stop is
                # intended here.
                break

            # Minutes and seconds are derived from the same rounded value so
            # they cannot disagree (e.g. 119.6 s prints as 2 min 0 secs).
            total_min, total_sec = _split_min_sec(elapsed_total_time)
            avg_min, avg_sec = _split_min_sec(average_game_time)
            mcts_min, mcts_sec = _split_min_sec(MCTS_time)
            infer_min, infer_sec = _split_min_sec(inference_total_time)
            print(f"elapsed total time : {total_min} min {total_sec} secs / {TPARAMS['games_count']} games   =  average {avg_min}min {avg_sec} sec per game")
            print(f"MCTS sampling time : {mcts_min} min {mcts_sec} secs     NeuralNet inference time : {infer_min} min {infer_sec} secs")
            print("====================================")

            # Save the buffer to disk every 1000 samples.
            if len(data_buffer) >= 1000 :
                dataver = save_data(data_buffer, B, C, ver)
                saved_min, saved_sec = _split_min_sec(time.time() - start_time)
                print("+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++")
                print("===========================1000 samples complete, saving data... ===================================")
                print(f"{TPARAMS['games_count']} games finished,  elapsed time :{saved_min} min {saved_sec} secs")
                print(f"Sampled data saved.  data len : {len(data_buffer)}, data ver : {dataver}, model ver : {ver}")
                print("====================================================================================================")
                print("+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++")
                data_buffer.clear()  #del data_buffer[:]

        # Every 1000 games summed over all players, let the trainer check
        # whether enough data has accumulated on disk and start training. It
        # runs once, right after the batch in which the milestone game ended.
        if train_due:
            TPARAMS['trainer'].step(int(B), int(C))

# Script entry point: start the self-play loop only when this file is run
# directly, not when it is imported.
if __name__ == "__main__":
    main()