<a href="https://colab.research.google.com/github/mahiinn62/DLRL_assignment/blob/main/TicTacToe.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# tic_tac_toe_rl.py

from __future__ import annotations
import numpy as np
import pickle
from pathlib import Path
import argparse
import sys

# Board dimensions for classic 3x3 tic-tac-toe.
BOARD_ROWS, BOARD_COLS = 3, 3


class State:
    """Tic-tac-toe board plus the loops that drive a game between two players.

    Board cells hold 1 for the first player (``p1``, drawn as 'x'), -1 for the
    second player (``p2``, drawn as 'o') and 0 for empty cells.
    """

    def __init__(self, p1, p2):
        self.board = np.zeros((BOARD_ROWS, BOARD_COLS), dtype=int)
        self.p1 = p1
        self.p2 = p2
        self.isEnd = False
        self.boardHash = None
        # p1 always moves first and plays the +1 symbol.
        self.playerSymbol = 1

    def getHash(self) -> str:
        """Return (and cache in ``boardHash``) the string key for the current board."""
        # Flattening with -1 yields the same string as reshape(ROWS * COLS).
        self.boardHash = str(self.board.reshape(-1))
        return self.boardHash

    def winner(self):
        """Return the game result and update ``isEnd`` accordingly.

        Returns 1 if p1 owns a full line, -1 if p2 does, 0 for a draw (board
        full without a line) and None while the game is still running.
        """
        # A line wins when all its cells carry the same symbol; assumes a square board.
        size = self.board.shape[0]
        # Sums checked in the order rows, columns, main diagonal, anti-diagonal.
        # The two diagonals share the centre cell, so they can never show
        # opposite winners and checking them one after the other is safe.
        line_sums = np.concatenate((
            self.board.sum(axis=1),
            self.board.sum(axis=0),
            [np.trace(self.board), np.trace(np.fliplr(self.board))],
        ))
        for total in line_sums:
            if total == size:
                self.isEnd = True
                return 1
            if total == -size:
                self.isEnd = True
                return -1

        if not self.availablePositions():
            self.isEnd = True
            return 0

        self.isEnd = False
        return None

    def availablePositions(self):
        """Return the empty cells as a row-major list of ``(row, col)`` int tuples."""
        rows, cols = np.nonzero(self.board == 0)
        return [(int(r), int(c)) for r, c in zip(rows, cols)]

    def updateState(self, position):
        """Place the current player's symbol at ``position`` and hand the turn over."""
        self.board[position] = self.playerSymbol
        self.playerSymbol = -1 if self.playerSymbol == 1 else 1

    def giveReward(self):
        """Feed the end-of-game reward to both players; call only once the game is over.

        A p1 win gives (1, 0), a p2 win gives (0, 1); any other result is scored
        as a draw (0.5, 0.5).
        """
        result = self.winner()
        if result == 1:
            p1_reward, p2_reward = 1, 0
        elif result == -1:
            p1_reward, p2_reward = 0, 1
        else:  # draw
            p1_reward, p2_reward = 0.5, 0.5
        self.p1.feedReward(p1_reward)
        self.p2.feedReward(p2_reward)

    def reset(self):
        """Clear the board and give the first move back to p1."""
        self.board = np.zeros((BOARD_ROWS, BOARD_COLS), dtype=int)
        self.boardHash = None
        self.isEnd = False
        self.playerSymbol = 1

    def _take_turn(self, player) -> bool:
        """Let ``player`` make one training move; return True if it ended the game.

        When the game ends, rewards are handed out and both players and the
        board are reset so the next round starts fresh.
        """
        positions = self.availablePositions()
        action = player.chooseAction(positions, self.board, self.playerSymbol)
        self.updateState(action)
        player.addState(self.getHash())

        if self.winner() is None:
            return False
        self.giveReward()
        self.p1.reset()
        self.p2.reset()
        self.reset()
        return True

    def play(self, rounds=5000, verbose_every: int = 1000):
        """Self-play training loop: p1 and p2 alternate moves and learn from rewards.

        rounds: number of games to play.
        verbose_every: print progress every N games; 0 disables the messages.
        """
        for i in range(rounds):
            if verbose_every and i % verbose_every == 0:
                print(f"Training round: {i}/{rounds}")
            while not self.isEnd:
                # p1 moves first; 'or' skips p2's move when p1 already ended the game.
                # reset() clears isEnd, so the explicit break is what ends this round.
                if self._take_turn(self.p1) or self._take_turn(self.p2):
                    break

    def play2(self):
        """Play one interactive game: p1 (the computer) moves first, p2 is the human.

        p2's ``chooseAction`` is called with only the list of free positions.
        The board is reset when the game ends.
        """
        while not self.isEnd:
            # Computer or player1
            positions = self.availablePositions()
            p1_action = self.p1.chooseAction(positions, self.board, self.playerSymbol)
            self.updateState(p1_action)
            self.showBoard()
            win = self.winner()
            if win is not None:
                # Right after p1's move only a p1 win or a draw is possible.
                if win == 1:
                    print(f"{self.p1.name} wins!")
                else:
                    print("Tie!")
                self.reset()
                break

            # Human (p2)
            positions = self.availablePositions()
            p2_action = self.p2.chooseAction(positions)
            self.updateState(p2_action)
            self.showBoard()
            win = self.winner()
            if win is not None:
                if win == -1:
                    print(f"{self.p2.name} wins!")
                else:
                    print("Tie!")
                self.reset()
                break

    def showBoard(self):
        """Print the board: 'x' for p1 (1), 'o' for p2 (-1), blank for empty cells."""
        separator = '-------------'
        symbols = {1: 'x', -1: 'o'}
        for row in self.board:
            print(separator)
            print('| ' + ''.join(symbols.get(int(cell), ' ') + ' | ' for cell in row))
        print(separator)


class Player:
    """Tabular state-value learner that picks moves epsilon-greedily.

    name: used to name the saved policy file (``policy_<name>.pkl``).
    exp_rate: probability of a random exploratory move; 0 means always greedy.
    lr: learning rate of the value update in ``feedReward``.
    decay_gamma: discount applied to the reward passed back to earlier states.
    """

    def __init__(self, name, exp_rate=0.3, lr=0.2, decay_gamma=0.9):
        self.name = name
        self.states = []  # board hashes recorded via addState, oldest first
        self.lr = lr
        self.exp_rate = exp_rate
        self.decay_gamma = decay_gamma
        self.states_value = {}  # board hash -> estimated value

    def getHash(self, board):
        """Return the string key used for ``board`` in ``states_value``."""
        return str(board.reshape(-1))

    def chooseAction(self, positions, current_board=None, symbol=1):
        """Choose a move epsilon-greedily from the learned state values.

        positions: non-empty list of free ``(row, col)`` tuples.
        current_board: board array before the move; required for greedy choices.
        symbol: value this player writes into the board (1 or -1).
        Returns one element of ``positions``. Unseen resulting boards are valued 0;
        among equally valued moves the last one in ``positions`` wins.
        Raises ValueError if ``positions`` is empty, or if a greedy choice is
        needed and ``current_board`` is None.
        """
        if not positions:
            raise ValueError("chooseAction needs at least one available position")

        # Strict '<': uniform() can return exactly 0.0, so '<=' would let an
        # agent with exp_rate=0 explore occasionally.
        if np.random.uniform(0, 1) < self.exp_rate:
            return positions[np.random.choice(len(positions))]

        if current_board is None:
            raise ValueError("current_board is required for greedy action selection")

        value_max = -np.inf
        action = positions[0]
        for p in positions:
            next_board = current_board.copy()
            next_board[p] = symbol
            value = self.states_value.get(self.getHash(next_board), 0)
            if value >= value_max:
                value_max = value
                action = p
        return action

    def addState(self, state_hash):
        """Record a board hash reached during the current game."""
        self.states.append(state_hash)

    def feedReward(self, reward):
        """Propagate ``reward`` back through the recorded states, newest first.

        Each state's value moves a fraction ``lr`` towards the discounted target
        ``decay_gamma * reward``; the updated value becomes the target for the
        state recorded before it.
        """
        for st in reversed(self.states):
            current = self.states_value.setdefault(st, 0)
            updated = current + self.lr * (self.decay_gamma * reward - current)
            self.states_value[st] = updated
            reward = updated

    def reset(self):
        """Forget the states recorded for the current game."""
        self.states = []

    def savePolicy(self, folder: str = "."):
        """Pickle ``states_value`` to ``<folder>/policy_<name>.pkl``, creating the folder."""
        folder_path = Path(folder)
        folder_path.mkdir(parents=True, exist_ok=True)
        file_path = folder_path / f"policy_{self.name}.pkl"
        with open(file_path, 'wb') as fw:
            pickle.dump(self.states_value, fw)
        print(f"Policy saved to {file_path}")

    def loadPolicy(self, file: str):
        """Replace ``states_value`` with the pickled policy in ``file``, if it exists.

        A missing file is reported and ignored, keeping the current values.
        """
        file_path = Path(file)
        if not file_path.exists():
            print(f"Policy file {file} not found. Continuing without loading.")
            return
        # SECURITY: pickle.load can execute arbitrary code; only load policy
        # files this program wrote itself, never ones from untrusted sources.
        with open(file_path, 'rb') as fr:
            self.states_value = pickle.load(fr)
        print(f"Policy loaded from {file_path}")


class HumanPlayer:
    """Player whose moves are typed in by a person; it never learns."""

    def __init__(self, name):
        self.name = name

    def chooseAction(self, positions, current_board=None, symbol=None):
        """Prompt until the user enters a free cell, then return it as ``(row, col)``.

        Accepts zero-based integer indices. ``current_board`` and ``symbol`` are
        accepted and ignored, so a HumanPlayer can be called exactly like
        ``Player.chooseAction`` from either seat.
        """
        while True:
            try:
                row = int(input("Input your action row (0/1/2): "))
                col = int(input("Input your action col (0/1/2): "))
            except ValueError:
                print("Invalid input. Please input integer indices 0,1 or 2.")
                continue
            action = (row, col)
            if action in positions:
                return action
            print("Invalid position or already taken. Try again.")

    def addState(self, state):
        """No-op: a human player keeps no state history."""
        pass

    def feedReward(self, reward):
        """No-op: a human player does not learn from rewards."""
        pass

    def reset(self):
        """No-op: nothing to clear between games."""
        pass


def main(args):
    """Train two self-play agents, save p1's policy, then pit it against a human.

    args: object with ``rounds``, ``verbose_every`` and ``policy_dir`` attributes.
    """
    # --- self-play training ---
    print("Creating players and state...")
    trainer_p1, trainer_p2 = Player(name="p1"), Player(name="p2")
    trainer = State(trainer_p1, trainer_p2)
    print(f"Training for {args.rounds} rounds... (this may take a moment)")
    trainer.play(rounds=args.rounds, verbose_every=args.verbose_every)

    # Only the first player's values are persisted; they are reloaded below.
    trainer_p1.savePolicy(folder=args.policy_dir)

    # --- interactive play ---
    print("\nNow you can play against the trained agent (p1).")
    policy_file = Path(args.policy_dir) / "policy_p1.pkl"
    computer = Player("computer", exp_rate=0)  # exp_rate=0: greedy, deterministic agent
    computer.loadPolicy(str(policy_file))
    game = State(computer, HumanPlayer("human"))

    answer = 'y'
    while answer.lower() == 'y':
        game.play2()
        # An empty reply counts as "no".
        answer = input("Play again? (y/n): ").strip() or 'n'
    print("Thanks for playing!")


if __name__ == "__main__":
    # Single source of truth for the defaults, shared by argparse and the fallback.
    defaults = {"rounds": 5000, "verbose_every": 1000, "policy_dir": "policies"}
    try:
        parser = argparse.ArgumentParser(description="Tic-Tac-Toe RL demo", add_help=True)
        parser.add_argument("--rounds", type=int, default=defaults["rounds"], help="Number of self-play training rounds")
        parser.add_argument("--verbose_every", type=int, default=defaults["verbose_every"], help="Print training progress every N rounds")
        parser.add_argument("--policy_dir", type=str, default=defaults["policy_dir"], help="Directory to save/load policies")
        # parse_known_args ignores unrecognised arguments (such as extra flags a
        # notebook kernel may pass) instead of failing on them.
        args, _unknown = parser.parse_known_args()
    except Exception:
        # Last-resort fallback: run with the defaults. argparse reports invalid
        # values and --help by raising SystemExit, which is not an Exception
        # and is intentionally not caught, so those still stop the script.
        args = argparse.Namespace(**defaults)

    main(args)


Creating players and state...
Training for 5000 rounds... (this may take a moment)
Training round: 0/5000
Training round: 1000/5000
Training round: 2000/5000
Training round: 3000/5000
Training round: 4000/5000
Policy saved to policies/policy_p1.pkl

Now you can play against the trained agent (p1).
Policy loaded from policies/policy_p1.pkl
-------------
|   |   |   | 
-------------
|   |   |   | 
-------------
|   | x |   | 
-------------
Input your action row (0/1/2): 0
Input your action col (0/1/2): 1
-------------
|   | o |   | 
-------------
|   |   |   | 
-------------
|   | x |   | 
-------------
-------------
|   | o |   | 
-------------
|   |   |   | 
-------------
| x | x |   | 
-------------
Input your action row (0/1/2): 1
Input your action col (0/1/2): 2
-------------
|   | o |   | 
-------------
|   |   | o | 
-------------
| x | x |   | 
-------------
-------------
|   | o |   | 
-------------
|   |   | o | 
-------------
| x | x | x | 
-------------
computer wins!
Play ag