In [1]:
# import numpy 
import numpy as np
import math
#Deep Learning Neural Network Model
import torch
import torch.nn as nn
import torch.nn.functional as F

In [2]:
# Creating the Tic Tac Toes Game

class TicTacToe:
    def __init__(self):
        self.row_count=3
        self.column_count=3
        self.action_size=self.row_count*self.column_count
        
    def get_init_state(self):
        return np.zeros((self.row_count,self.column_count))
    def get_next_state(self,state,action,player):
        row = action // self.column_count
        column = action % self.column_count
        state[row,column]=player
        return state
    def get_valid_moves(self,state):
        return(state.reshape(-1) == 0).astype(np.uint8)
    
    def check_win(self,state,action):
        
        if action == None:
            return False
        
        row = action // self.column_count
        column = action % self.column_count
        player = state[row,column]
        
        return(
            np.sum(state[row,:]) == player *(self.column_count)
            or np.sum(state[:,column]) == player * (self.row_count)
            or np.sum(np.diag(state))==player * self.row_count
            or np.sum(np.diag(np.flip(state,axis=0 ))) == player * self.row_count
        )
    
    def get_val_and_end(self,state,action):
        if self.check_win(state, action):
            return 1, True
        if np.sum(self.get_valid_moves(state)) == 0:
            return 0, True
        return 0, False
    
    def get_opp(self,player):
        return -player
    
    def get_opp_value(self,value):
        return -value
    
    def change_perspective(self,state,player):
        return state * player
    
    def get_encoded_state(self,state):
        encoded_state=np.stack((state==-1,state==0,state==1)).astype(np.float32)
        return encoded_state

In [3]:
#Simulating tic tac toe playing to ensure game works correctly
ticTac= TicTacToe()
player=1
state = ticTac.get_init_state()

while True:
    print(state)
    valid_moves= ticTac.get_valid_moves(state)
    print("valid moves", [i for i in range(ticTac.action_size) if valid_moves[i]==1])
    
    action = int(input(f"{player:}"))
    
    if valid_moves[action]==0:
        print("action not valid")
        continue
    state = ticTac.get_next_state(state,action,player)
    
    value,is_terminal = ticTac.get_val_and_end(state,action)
    
    if is_terminal:
        print(state)
        if value==1:
            print(player,"won")
        else:
            print("tie")
        break;
    player=ticTac.get_opp(player)
    

[[0. 0. 0.]
 [0. 0. 0.]
 [0. 0. 0.]]
valid moves [0, 1, 2, 3, 4, 5, 6, 7, 8]
10
[[1. 0. 0.]
 [0. 0. 0.]
 [0. 0. 0.]]
valid moves [1, 2, 3, 4, 5, 6, 7, 8]
-11
[[ 1. -1.  0.]
 [ 0.  0.  0.]
 [ 0.  0.  0.]]
valid moves [2, 3, 4, 5, 6, 7, 8]
12
[[ 1. -1.  1.]
 [ 0.  0.  0.]
 [ 0.  0.  0.]]
valid moves [3, 4, 5, 6, 7, 8]
-13
[[ 1. -1.  1.]
 [-1.  0.  0.]
 [ 0.  0.  0.]]
valid moves [4, 5, 6, 7, 8]
14
[[ 1. -1.  1.]
 [-1.  1.  0.]
 [ 0.  0.  0.]]
valid moves [5, 6, 7, 8]
-15
[[ 1. -1.  1.]
 [-1.  1. -1.]
 [ 0.  0.  0.]]
valid moves [6, 7, 8]
16
[[ 1. -1.  1.]
 [-1.  1. -1.]
 [ 1.  0.  0.]]
1 won


In [8]:
# Monte Carlo Tree Search

class Node:
    def __init__(self,game,args,state, parent=None,action_taken=None,prior=0):
        self.game=game
        self.args=args
        self.state=state
        self.parent=parent
        self.action_taken=action_taken
        self.prior=prior
        
        self.children=[]
        
        
        self.visit_count=0
        self.value_sum=0
        
    def is_fully_expanded(self):
        return len(self.children)>0
    
    def select(self):
        best_child=None
        best_ucb=-np.inf
        
        for child in self.children:
            ucb = self.get_ucb(child)
            if(ucb>best_ucb):
                best_child=child
                best_ucb=ucb
        return best_child
    
    def get_ucb(self,child):
        if child.visit_count==0:
            q_value=0
        else:
            q_value = 1 -((child.value_sum/child.visit_count)+1)/2
        return q_value +self.args['C'] * math.sqrt((self.visit_count)/(child.visit_count+1)) * child.prior
    
    def expand(self,policy):
        
        for action, prob in enumerate(policy):
            if prob>0:
                child_state=self.state.copy()
                child_state=self.game.get_next_state(child_state,action,1)
                child_state= self.game.change_perspective(child_state,player=-1)

                child=Node(self.game,self.args,child_state,self,action,prob)
                self.children.append(child)
        return child
    
    def simulate(self):
        value, is_terminal = self.game.get_val_and_end(self.state,self.action_taken)
        value=self.game.get_opp_value(value)
        if is_terminal:
            return value
        rollout_state=self.state.copy()
        rollout_player=1
        
        while True:
            valid_moves=self.game.get_valid_moves(rollout_state)
            action =np.random.choice(np.where(valid_moves==1)[0])
            rollout_state = self.game.get_next_state(rollout_state,action,rollout_player)
            
            value,is_terminal = self.game.get_val_and_end(rollout_state,action)
            if is_terminal:
                if rollout_player==-1:
                    value=self.game.get_opp_value(value)
                return value
            rollout_player=self.game.get_opp(rollout_player)
            
    def backpropagate(self,value):
        self.value_sum += value
        self.visit_count += 1
        
        value = self.game.get_opp_value(value)
        if self.parent is not None:
            self.parent.backpropagate(value)  
        
        
        
class MCTS:
    def __init__(self,game,args,model):
        self.game=game
        self.args=args
        self.model=model
        
    @torch.no_grad()    
    def search(self,state):
        #define root
        root= Node(self.game,self.args,state)
        
        #Search iterations
        for search in range(self.args['num_searches']):
            node = root
            
            while node.is_fully_expanded():
                node=node.select()
                
            value, is_terminal = self.game.get_val_and_end(node.state,node.action_taken)
            value=self.game.get_opp_value(value)
            
            if not is_terminal:
                
                policy,value=self.model(
                    torch.tensor(self.game.get_encoded_state(node.state)).unsqueeze(0)
                )
                policy=torch.softmax(policy,axis=1).squeeze(0).cpu().numpy()
                valid_moves=self.game.get_valid_moves(node.state)
                policy*=valid_moves
                policy/=np.sum(policy)
                
                value=value.item()
                #Expansion
                node=node.expand(policy)
                #simulation
                #value=node.simulate()
            #backPropagation
            node.backpropagate(value)
        
        action_probs = np.zeros(self.game.action_size)
        for child in root.children:
            action_probs[child.action_taken]=child.visit_count
        action_probs/np.sum(action_probs)
        return action_probs
            

In [10]:
class AlphaZero:
    def __init__(self,model,optimizer,game,args):
        self.model=model
        self.optimizer=optimizer
        self.game=game
        self.args = args
        
        self.mcts=MCTS(game,args,model)
    
    def selfPlay(self):
        memory=[]
        player=1
        state=self.game.get_init_state()
        
        while True:
            neutral_state=self.game.change_perspective(state,player)
            action_probs=self.mcts.search(neutral_state)
            
            memory.append((neutral_state,action_probs,player))
            action=np.random.choice(self.game.action_size,p=action_probs)
            
            state=self.game.get_next_state(state,action,player)
            
            value, is_terminal = self.game.get_value_and_end(state,action)
            
            if is_terminal:
                returnMemory=[]
                for hist_neutral_state, hist_action_probs, hist_player in memory:
                    hist_outcome = value if hist_player==play else -value
                
            
    
    def train(self,memory):
        pass
    
    def learn(self):
        for iteration in range(self.args['num_iterations']):
            memory=[]
            
            self.model.eval()
            for selfPlay_iteration in range(self.args['num_selfPlay_iterations']):
                memory+=self.selfPlay()
                
            self.model.train()
            for epoch in range(self.args['num_epochs']):
                self.train(memory)
            
            torch.save(self.model.state_dict(),f"model_{iteration}.pt")
            torch.save(self.optimizer.state_dict(),f"optimizer_{iteration}.pt")
                

IndentationError: expected an indented block (1212841568.py, line 13)

In [9]:
#Simulating tic tac toe playing to ensure game works correctly
ticTac= TicTacToe()
player=1

args={
    'C':2,
    'num_searches':1000
}

model=ResNet(ticTac,4,64)
model.eval()

mcts=MCTS(ticTac,args,model)
state = ticTac.get_init_state()

while True:
    print(state)
    if player==1:
        valid_moves= ticTac.get_valid_moves(state)
        print("valid moves", [i for i in range(ticTac.action_size) if valid_moves[i]==1])

        action = int(input(f"{player:}"))

        if valid_moves[action]==0:
            print("action not valid")
            continue
    else:
        neutral_state=ticTac.change_perspective(state,player)
        mcts_probs=mcts.search(neutral_state)
        action=np.argmax(mcts_probs)
        
    state = ticTac.get_next_state(state,action,player)
    
    value,is_terminal = ticTac.get_val_and_end(state,action)
    
    if is_terminal:
        print(state)
        if value==1:
            print(player,"won")
        else:
            print("tie")
        break;
    player=ticTac.get_opp(player)
    

[[0. 0. 0.]
 [0. 0. 0.]
 [0. 0. 0.]]
valid moves [0, 1, 2, 3, 4, 5, 6, 7, 8]
10
[[1. 0. 0.]
 [0. 0. 0.]
 [0. 0. 0.]]
[[ 1.  0.  0.]
 [ 0. -1.  0.]
 [ 0.  0.  0.]]
valid moves [1, 2, 3, 5, 6, 7, 8]
11
[[ 1.  1.  0.]
 [ 0. -1.  0.]
 [ 0.  0.  0.]]
[[ 1.  1. -1.]
 [ 0. -1.  0.]
 [ 0.  0.  0.]]
valid moves [3, 5, 6, 7, 8]
18
[[ 1.  1. -1.]
 [ 0. -1.  0.]
 [ 0.  0.  1.]]
[[ 1.  1. -1.]
 [-1. -1.  0.]
 [ 0.  0.  1.]]
valid moves [5, 6, 7]
17
[[ 1.  1. -1.]
 [-1. -1.  0.]
 [ 0.  1.  1.]]
[[ 1.  1. -1.]
 [-1. -1. -1.]
 [ 0.  1.  1.]]
-1 won


In [4]:
#Deep Learning Neural Network Model
import torch
import torch.nn as nn
import torch.nn.functional as F

class ResNet(nn.Module):
    def __init__(self,game,num_resBlocks,num_hidden):
        super().__init__()
        
        self.startBlock=nn.Sequential(
            nn.Conv2d(3,num_hidden,kernel_size=3,padding=1),
            nn.BatchNorm2d(num_hidden),
            nn.ReLU() 
        )
        self.backBone=nn.ModuleList(
            [ResBlock(num_hidden) for i in range(num_resBlocks)]
        )
        
        self.policyHead=nn.Sequential(
            nn.Conv2d(num_hidden,32,kernel_size=3,padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.Flatten(),
            nn.Linear(32*game.row_count*game.column_count,game.action_size)
        )
        
        self.valueHead = nn.Sequential(
            nn.Conv2d(num_hidden, 3, kernel_size=3, padding=1),
            nn.BatchNorm2d(3),
            nn.ReLU(),
            nn.Flatten(),
            nn.Linear(3 * game.row_count * game.column_count, 1),
            nn.Tanh()
        )
        
    def forward(self,x):
        x=self.startBlock(x)
        for resBlock in self.backBone:
            x=resBlock(x)
        policy=self.policyHead(x)
        value = self.valueHead(x)
        return policy,value
        
class ResBlock(nn.Module):
    def __init__(self,num_hidden):
        super().__init__()
        
        self.conv1=nn.Conv2d(num_hidden,num_hidden,kernel_size=3,padding=1)
        self.bh1 = nn.BatchNorm2d(num_hidden)
        self.conv2=nn.Conv2d(num_hidden,num_hidden,kernel_size=3,padding=1)
        self.bh2 = nn.BatchNorm2d(num_hidden)
        
    def forward(self,x):
        residual=x
        x=F.relu(self.bh1(self.conv1(x)))
        x=(self.bh2(self.conv2(x)))
        x+=residual
        x=F.relu(x)
        return x
        

In [5]:
tictactoe=TicTacToe()
state = tictactoe.get_init_state()
state=tictactoe.get_next_state(state,2,1)
state=tictactoe.get_next_state(state,7,-1)
print(state)
encoded_state=tictactoe.get_encoded_state(state)
print(encoded_state)

tensor_state = torch.tensor(encoded_state).unsqueeze(0)

model = ResNet(tictactoe, 4, 64)

policy, value = model(tensor_state)

value = value.item()
policy = torch.softmax(policy, axis=1).squeeze(0).detach().cpu().numpy()

print(value, policy)

[[ 0.  0.  1.]
 [ 0.  0.  0.]
 [ 0. -1.  0.]]
[[[0. 0. 0.]
  [0. 0. 0.]
  [0. 1. 0.]]

 [[1. 1. 0.]
  [1. 1. 1.]
  [1. 0. 1.]]

 [[0. 0. 1.]
  [0. 0. 0.]
  [0. 0. 0.]]]
0.5841861367225647 [0.16173169 0.11846372 0.0833969  0.16649143 0.10786711 0.11294491
 0.05666931 0.12728618 0.06514869]
