In [57]:
import numpy as np
from itertools import product
from collections import defaultdict
from tqdm import tqdm

In [58]:
class TicTacToe:
    def __init__(self):
        self.reset()
    def reset(self):
        self.board = np.zeros((3,3))
        # 1 for X, -1 for O
        self.current_player=1
        return self.get_state()
    
    def get_state(self):
        return tuple(self.board.flatten())
    
    def can_move(self):
        return [(i,j) for i,j in product(range(3),range(3))
            if self.board[i,j]==0]
    
    def check_winner(self):
        for i in range(3):
            if abs(sum(self.board[i,:]))==3:
                return self.board[i,0]
            if abs(sum(self.board[:,i]))==3:
                return self.board[0,i]

        if abs(sum(np.diag(self.board)))==3:
            return self.board[0,0]
        if abs(sum(np.diag(np.rot90(self.board))))==3:
            return self.board[0,2]

        if len(self.can_move())==0:
            return 0

        return None
    def step(self,action):
        i,j=action

        if self.board[i,j] !=0:
            return self.get_state(), -10, True

        self.board[i,j]= self.current_player
        
        winner = self.check_winner()
        
        reward=0
        finish=False
        
        if winner is not None:
            finish = True
            if winner == 0:
                reward = 0
            elif winner == self.current_player:
                reward =1
            else:
                reward = -1
                
        self.current_player *= -1
        
        return self.get_state(), reward, finish

In [59]:
def print_board(board):
    symbols = {0: ' ', 1: 'X', -1: '0'}
    for i in range(3):
        print('-------------')
        row='| '
        for j in range(3):
            row+=symbols[board[i,j]]+' | '
        print(row)
    print('-------------')


In [60]:
env = TicTacToe()
state = env.reset()
print_board(env.board)

-------------
|   |   |   | 
-------------
|   |   |   | 
-------------
|   |   |   | 
-------------


In [61]:
class ValueIteration:
    def __init__(self, gamma=0.9):
        self.V = defaultdict(float)
        self.policy={}
        self.gamma = gamma
        
    def get_action_value(self,game,state,action):
        board = np.array(state).reshape(3,3)
        game.board = board.copy()
        
        next_state, reward, finish = game.step(action)
        
        if finish:
            return reward
        r_next = reward+self.gamma*self.V[next_state]
        return r_next
    def generate_all_states(self):
        states=[]
        for board in product([-1,0,1],repeat=9):
            if abs(sum(board))>1:
                continue
            states.append(tuple(board))
        return states
    
    def train(self,episodes=1000,e=1e-4):
        game=TicTacToe()
        
        for _ in tqdm(range(episodes)):
            delta=0;
            all_states=self.generate_all_states()
            
            for state in all_states:
                board = np.array(state).reshape(3,3)
                game.board = board.copy()
                move = game.can_move()
                
                if not move:
                    continue
                    
                old_v = self.V[state]
                
                action_values = [self.get_action_value(game, state, action) 
                                 for action in move]
                
                self.V[state] = max(action_values)
                self.policy[state]=move[np.argmax(action_values)]
                
                delta = max(delta, abs(old_v-self.V[state]))
                
            if delta<e:
                break
                
    
        

In [62]:
vi_agent = ValueIteration(gamma=0.9)
print("Value Iteration agent...")
vi_agent.train(episodes=1000)

def play_game(agent, player_first=True):
    game = TicTacToe()
    state = game.reset()
    done = False
    
    while not done:
        print("\nCurrent board: ")
        print_board(game.board)
        
        if game.current_player == (1 if player_first else -1):
            while True:
                try:
                    i=int(input("Enter row (0-2): "))
                    j=int(input("Enter col (0-2): "))
                    if 0<=i<=2 and 0<=j<=2 and game.board[i,j]==0:
                        break
                    print("invalid move")
                except ValueError:
                    print("invalid move")
            action=(i,j)
            
        else:
            action = agent.policy.get(state,game.can_move()[0])
        state,reward, done = game.step(action)
    print("\n Final board: ")
    print_board(game.board)
    winner = game.check_winner()
    if winner == 0:
        print("Draw")
    elif winner ==1:
        print("X wins")
    else:
        print("O wins")
        
play_game(vi_agent,player_first=True)

Value Iteration agent...


100%|█████████████████████████████████████████████████████████████████| 1000/1000 [09:57<00:00,  1.67it/s]



Current board: 
-------------
|   |   |   | 
-------------
|   |   |   | 
-------------
|   |   |   | 
-------------
Enter row (0-2): 1
Enter col (0-2): 1

Current board: 
-------------
|   |   |   | 
-------------
|   | X |   | 
-------------
|   |   |   | 
-------------

Current board: 
-------------
|   | 0 |   | 
-------------
|   | X |   | 
-------------
|   |   |   | 
-------------
Enter row (0-2): 2
Enter col (0-2): 2

Current board: 
-------------
|   | 0 |   | 
-------------
|   | X |   | 
-------------
|   |   | X | 
-------------

Current board: 
-------------
| 0 | 0 |   | 
-------------
|   | X |   | 
-------------
|   |   | X | 
-------------
Enter row (0-2): 0
Enter col (0-2): 2

Current board: 
-------------
| 0 | 0 | X | 
-------------
|   | X |   | 
-------------
|   |   | X | 
-------------

Current board: 
-------------
| 0 | 0 | X | 
-------------
|   | X | 0 | 
-------------
|   |   | X | 
-------------
Enter row (0-2): 2
Enter col (0-2): 0

 Final board: 
------