<a href="https://colab.research.google.com/github/datapirate09/Tic-Tac-Toe-Game-using-Reinforcement-Learning-Methods/blob/main/sarsa_algo.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [10]:
import copy
import random
from collections import deque

# Q-table shared by both players: maps (board_tuple, (row, col)) -> value.
# Entries are created by initialize_states() and updated in place by sarsa_algo().
state_action_pairs = dict()

def board_to_tuple(board):
    """Return a hashable snapshot of ``board``: a tuple whose items are the rows as tuples."""
    return tuple(map(tuple, board))

def getCurrentTurn(board):
    """Return the player to move on ``board``.

    Player 0 moves first, so it is player 1's turn exactly when player 0 has
    strictly more marks on the board; otherwise it is player 0's turn.
    """
    cells = [cell for row in board for cell in row]
    zeros, ones = cells.count(0), cells.count(1)
    if zeros > ones:
        return 1
    return 0

def getNextStates(board, turn):
    """List every board reachable by ``turn`` placing one mark on an empty (-1) cell.

    Returns a list of ``[new_board, (row, col)]`` pairs in row-major order.
    Each ``new_board`` is a fresh row-by-row copy, so ``board`` is not mutated.
    """
    open_cells = [(r, c) for r in range(3) for c in range(3) if board[r][c] == -1]
    successors = []
    for r, c in open_cells:
        child = [line[:] for line in board]
        child[r][c] = turn
        successors.append([child, (r, c)])
    return successors

# Every winning line as three (row, col) cells, in the order they are checked:
# rows top-to-bottom, columns left-to-right, main diagonal, anti-diagonal.
_WINNING_LINES = (
    tuple(tuple((r, c) for c in range(3)) for r in range(3))
    + tuple(tuple((r, c) for r in range(3)) for c in range(3))
    + (((0, 0), (1, 1), (2, 2)), ((0, 2), (1, 1), (2, 0)))
)


def getReward(board, player):
    """Score ``board`` from ``player``'s point of view.

    Returns 1 if a completed line (three equal, non-empty cells) belongs to
    ``player``, -1 if it belongs to anyone else, and 0 if no line is complete.
    The first completed line in ``_WINNING_LINES`` order decides the result.
    """
    for (r0, c0), (r1, c1), (r2, c2) in _WINNING_LINES:
        owner = board[r0][c0]
        if owner == board[r1][c1] == board[r2][c2] != -1:
            return 1 if owner == player else -1
    return 0

def isEndOfGame(board):
    """Return True when ``board`` is terminal.

    A board is terminal when any line is complete, or when no cell is
    empty (-1).
    """
    # getReward returns +/-1 for a completed line whoever owns it, so a single
    # call from player 0's perspective detects a win by either player.
    if getReward(board, 0) != 0:
        return True
    return all(cell != -1 for row in board for cell in row)

def initialize_states():
    """Add a 0-valued ``state_action_pairs`` entry for every legal move.

    Does a breadth-first search over all boards reachable from the empty board.
    The mover on each board is chosen by ``getCurrentTurn``. Terminal boards
    get no entries. Existing entries are left untouched, so calling this again
    does not reset learned values.
    """
    initial_board = [[-1 for _ in range(3)] for _ in range(3)]
    queue = deque([initial_board])
    visited = {board_to_tuple(initial_board)}

    while queue:
        board = queue.popleft()  # O(1), unlike list.pop(0)
        if isEndOfGame(board):
            continue

        board_tuple = board_to_tuple(board)
        turn = getCurrentTurn(board)
        for next_board, action in getNextStates(board, turn):
            next_tuple = board_to_tuple(next_board)
            if next_tuple not in visited:
                visited.add(next_tuple)
                queue.append(next_board)
            # Actions are inserted in getNextStates' row-major order.
            state_action_pairs.setdefault((board_tuple, action), 0)

def epsilon_greedy(state, epsilon):
    """Choose an action for ``state`` with an epsilon-greedy policy over the Q-table.

    Args:
        state: a board as produced by ``board_to_tuple``.
        epsilon: probability of picking a uniformly random known action instead
            of a greedy one.

    Returns:
        A ``(row, col)`` action, or None when ``state_action_pairs`` has no
        entry for ``state`` (e.g. a terminal board). Greedy ties are broken
        uniformly at random.
    """
    # Probe the nine candidate keys directly instead of scanning the whole
    # table on every call. Row-major order matches the order in which
    # initialize_states inserts them, so seeded runs choose identically.
    actions = [
        (i, j)
        for i in range(3)
        for j in range(3)
        if (state, (i, j)) in state_action_pairs
    ]
    if not actions:
        return None

    if random.random() < epsilon:
        return random.choice(actions)

    values = [state_action_pairs[(state, a)] for a in actions]
    best_value = max(values)
    return random.choice([a for a, v in zip(actions, values) if v == best_value])

def sarsa_algo(num_episodes=10000, learning_rate=0.1, discount_factor=0.9, epsilon=0.1):
    """Train ``state_action_pairs`` with on-policy SARSA through epsilon-greedy self-play.

    Both players share one Q-table. Each entry is valued from the point of view
    of the player who moves in that state, since the reward is
    ``getReward(grid, turn)`` for the mover. Consecutive states belong to
    opposite players, so the bootstrapped next value is negated
    (negamax-style TD target: ``r - gamma * Q(s', a')``).

    Args:
        num_episodes: number of self-play games to run.
        learning_rate: TD step size (alpha).
        discount_factor: discount (gamma) applied to the opponent's next value.
        epsilon: exploration rate passed to ``epsilon_greedy``.
    """
    initialize_states()

    for episode in range(num_episodes):
        if episode % 1000 == 0:
            print("Episode ", episode)
        grid = [[-1 for _ in range(3)] for _ in range(3)]
        state_tuple = board_to_tuple(grid)
        turn = 0
        action = epsilon_greedy(state_tuple, epsilon)
        if action is None:
            continue

        while not isEndOfGame(grid):
            grid[action[0]][action[1]] = turn
            # The game was not over before this move, so this is +1 if the
            # mover just won and 0 otherwise.
            immediate_reward = getReward(grid, turn)
            key = (state_tuple, action)

            if isEndOfGame(grid):
                # Terminal transition: no bootstrap term.
                state_action_pairs[key] += learning_rate * (
                    immediate_reward - state_action_pairs[key]
                )
                break

            turn = 1 - turn
            next_state_tuple = board_to_tuple(grid)
            next_action = epsilon_greedy(next_state_tuple, epsilon)

            if next_action is None:
                break

            # Q(s', a') is scored for the opponent: what is good for them is
            # bad for the current mover, hence the subtraction.
            target = (
                immediate_reward
                - discount_factor * state_action_pairs[(next_state_tuple, next_action)]
            )
            state_action_pairs[key] += learning_rate * (target - state_action_pairs[key])

            state_tuple = next_state_tuple
            action = next_action

def get_best_action(board):
    """Return the greedy (highest-valued) action for ``board``.

    Among equal values, the first cell in row-major order wins. This is the
    order in which initialize_states inserts a board's actions. If the table
    has no entry for this board, a random empty cell is returned instead, or
    None when there is no empty cell.
    """
    board_tuple = board_to_tuple(board)
    # Probe the nine candidate keys instead of scanning the whole table.
    actions = [
        (i, j)
        for i in range(3)
        for j in range(3)
        if (board_tuple, (i, j)) in state_action_pairs
    ]

    if not actions:
        empty_cells = [(i, j) for i in range(3) for j in range(3) if board[i][j] == -1]
        if empty_cells:
            return random.choice(empty_cells)
        return None

    # max() returns the first maximal item, preserving first-best tie-breaking.
    return max(actions, key=lambda a: state_action_pairs[(board_tuple, a)])

def play_game():
    """Play one interactive game in the console.

    The trained agent is player 0 ("X") and moves first using
    ``get_best_action``. The human is player 1 ("O") and is prompted for a
    1-based row and column until a free, in-bounds cell is given. The board is
    printed after every move and the result is announced when the game ends.
    """

    def prompt_human(grid):
        # Keep asking until the input names an in-bounds, empty cell.
        while True:
            try:
                r = int(input("Enter row (1-3): ")) - 1
                c = int(input("Enter column (1-3): ")) - 1
            except ValueError:
                print("Please enter numbers between 1 and 3.")
                continue
            if 0 <= r <= 2 and 0 <= c <= 2 and grid[r][c] == -1:
                return r, c
            print("Invalid move. That position is either occupied or out of bounds.")

    board = [[-1] * 3 for _ in range(3)]
    player = 0

    print("Game starts! Computer plays X, you play O")
    print_board(board)

    while not isEndOfGame(board):
        if player == 1:
            print("\nYour turn:")
            r, c = prompt_human(board)
            board[r][c] = player
        else:
            print("\nComputer's turn:")
            move = get_best_action(board)
            if not move:
                print("Computer couldn't find a valid move!")
                break
            board[move[0]][move[1]] = player
            print(f"Computer placed at position ({move[0]+1}, {move[1]+1})")

        print_board(board)

        if isEndOfGame(board):
            outcome = getReward(board, 0)  # scored from the computer's side
            verdicts = {1: "Computer wins!", -1: "You win!"}
            print(verdicts.get(outcome, "It's a draw!"))
            break

        player = 1 - player

def print_board(board):
    """Print ``board`` as a 3x3 grid with 1-3 row and column labels.

    Empty cells (-1) are shown blank, player 0 as "X" and player 1 as "O".
    """
    glyphs = {-1: " ", 0: "X", 1: "O"}
    print("  1 2 3")
    for r in range(3):
        row_text = "|".join(glyphs[board[r][c]] for c in range(3))
        print(f"{r + 1} {row_text}")
        if r != 2:
            print("  -+-+-")

# Train the shared Q-table through self-play, then play one interactive game.
TRAINING_EPISODES = 10000

print("Training")
sarsa_algo(num_episodes=TRAINING_EPISODES)
print("Training complete!")
play_game()

Training
Episode  0
Episode  1000
Episode  2000
Episode  3000
Episode  4000
Episode  5000
Episode  6000
Episode  7000
Episode  8000
Episode  9000
Training complete!
Game starts! Computer plays X, you play O
  1 2 3
1  | | 
  -+-+-
2  | | 
  -+-+-
3  | | 

Computer's turn:
Computer placed at position (2, 2)
  1 2 3
1  | | 
  -+-+-
2  |X| 
  -+-+-
3  | | 

Your turn:
Enter row (1-3): 1
Enter column (1-3): 1
  1 2 3
1 O| | 
  -+-+-
2  |X| 
  -+-+-
3  | | 

Computer's turn:
Computer placed at position (2, 1)
  1 2 3
1 O| | 
  -+-+-
2 X|X| 
  -+-+-
3  | | 

Your turn:
Enter row (1-3): 2
Enter column (1-3): 3
  1 2 3
1 O| | 
  -+-+-
2 X|X|O
  -+-+-
3  | | 

Computer's turn:
Computer placed at position (1, 2)
  1 2 3
1 O|X| 
  -+-+-
2 X|X|O
  -+-+-
3  | | 

Your turn:
Enter row (1-3): 3
Enter column (1-3): 2
  1 2 3
1 O|X| 
  -+-+-
2 X|X|O
  -+-+-
3  |O| 

Computer's turn:
Computer placed at position (1, 3)
  1 2 3
1 O|X|X
  -+-+-
2 X|X|O
  -+-+-
3  |O| 

Your turn:
Enter row (1-3): 3
Enter c