<a href="https://colab.research.google.com/github/nomomon/drl-js/blob/main/tic-tac-toe/notebook.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import tensorflow as tf
import numpy as np

# Model

In [23]:
class Policy(tf.keras.Model):
    """Small MLP that scores a tic-tac-toe position.

    Input: a batch of flattened 3x3 boards, 9 values per board. Elsewhere in this
    notebook a board is encoded with 1 for the side that just moved, -1 for the
    opponent and 0 for an empty cell.
    Output: one sigmoid value in (0, 1) per board. It is trained (via getData's
    labels) as the probability that the side holding the 1s wins the game.
    """

    def __init__(self):
        super().__init__()
        # No `input_shape` here: for a Dense layer it excludes the batch dimension,
        # so the previous (-1, 9) described a 3-D input. Keras layers infer their
        # input size when first called, which is all a subclassed model needs.
        self.dense1 = tf.keras.layers.Dense(100, activation=tf.nn.relu)
        self.dense2 = tf.keras.layers.Dense(100, activation=tf.nn.relu)
        self.dense3 = tf.keras.layers.Dense(1, activation=tf.nn.sigmoid)

        # A single Dropout layer is reused after both hidden layers (it has no weights).
        self.dropout = tf.keras.layers.Dropout(0.1)

    def call(self, inputs, training=None):
        """Forward pass; `training` toggles dropout (active only while fitting)."""
        x = self.dense1(inputs)
        x = self.dropout(x, training=training)
        x = self.dense2(x)
        x = self.dropout(x, training=training)
        return self.dense3(x)

In [24]:
# Build the network and configure it for binary targets: getData labels each
# board 1 (the player who just moved went on to win) or 0.
policy = Policy()
policy.compile(
    optimizer=tf.keras.optimizers.Adam(),
    loss=tf.keras.losses.BinaryCrossentropy(),
    metrics=[
        tf.keras.metrics.BinaryAccuracy(),
    ],
)

In [25]:
# Smoke test: score one hand-written position (before any training).
samplePosition = [1, 0, -1, 0, -1, 0, 1, 0, 0]
policy.predict([samplePosition])

array([[0.558074]], dtype=float32)

In [26]:
# Per-layer parameter counts. A subclassed Keras model only has weights after it
# has been called once (the predict cell above does that).
policy.summary()

Model: "policy_2"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 dense_6 (Dense)             multiple                  1000      
                                                                 
 dense_7 (Dense)             multiple                  10100     
                                                                 
 dense_8 (Dense)             multiple                  101       
                                                                 
 dropout_1 (Dropout)         multiple                  0         
                                                                 
Total params: 11,201
Trainable params: 11,201
Non-trainable params: 0
_________________________________________________________________


In [6]:
# 0 1 2
# 3 4 5
# 6 7 8

def gameState(board):
    """Return the winner of a 9-cell board, or 0 if nobody has three in a line.

    The winner is the (non-zero) value that fills a whole row, column or
    diagonal; lines are checked rows first, then columns, then diagonals, and
    the first complete one decides the result.
    """
    winningLines = (
        (0, 1, 2), (3, 4, 5), (6, 7, 8),  # rows
        (0, 3, 6), (1, 4, 7), (2, 5, 8),  # columns
        (0, 4, 8), (2, 4, 6),             # diagonals
    )
    for a, b, c in winningLines:
        if board[b] != 0 and board[a] == board[b] == board[c]:
            return board[b]
    return 0

In [7]:
# The middle row is all 1s, so 1 is reported as the winner.
exampleBoard = [0, 0, 0,
                1, 1, 1,
                1, 0, 0]
gameState(exampleBoard)

1

In [27]:
def chooseAction(policy, board):
    """Pick the move the policy rates best for the side to move.

    The side to move is the one whose pieces are 1 in `board`. Every empty cell
    (value 0) is tried by placing a 1 there on a fresh copy of the board. All
    resulting positions are scored in one batched `policy.predict` call, and the
    cell with the highest score is returned. Ties go to the highest-numbered
    cell, matching the original `>=` scan.

    board  -- sequence of 9 cell values (1, -1 or 0), indexed 0..8 row by row.
              It is not modified.
    Returns the chosen cell index (int).
    Raises ValueError if the board has no empty cell.
    """
    emptyCells = [i for i, cell in enumerate(board) if cell == 0]
    if not emptyCells:
        raise ValueError("chooseAction: board has no empty cell")

    # Each candidate gets its own copy. Writing into `board` directly would let
    # earlier trial moves leak into later candidates (and mutate the caller's list).
    candidates = []
    for i in emptyCells:
        after = list(board)
        after[i] = 1
        candidates.append(after)

    # One batched prediction instead of one predict() call per empty cell.
    predictions = policy.predict(np.asarray(candidates, dtype=np.float32), verbose=0)
    scores = np.asarray(predictions)[:, 0]

    bestCell, bestScore = emptyCells[0], scores[0]
    for cell, score in zip(emptyCells, scores):
        if score >= bestScore:
            bestCell, bestScore = cell, score
    return bestCell

In [28]:
# Ask the policy for a move on a partly filled board (side to move owns the 1s).
probeBoard = [0, 0, 1,
              1, 1, 0,
              0, 0, 1]
chooseAction(policy, list(probeBoard))

7

In [41]:
def getData(policy):
    """Generate one batch of training data from three games:

    1. the policy plays both sides;
    2. a random player moves first and the policy replies;
    3. the policy moves first and a random player replies.

    Returns (X, y). X is a list of 9-element boards, one per move, each recorded
    right after the move from the mover's point of view (mover's pieces are 1).
    y holds the matching labels: 1 if that move's player won the game, else 0.
    """
    # Per-turn predicates: True means the policy chooses that move.
    schedules = (
        lambda turn: True,
        lambda turn: turn % 2 == 1,
        lambda turn: turn % 2 == 0,
    )

    X_all, y_all = [], []
    for policyMoves in schedules:
        X, y = _playEpisode(policy, policyMoves)
        X_all.extend(X)
        y_all.extend(y)
    return X_all, y_all


def _playEpisode(policy, policyMoves):
    """Play one game and return (boards, labels) for every move made.

    policyMoves -- predicate on the 0-based turn index. True means chooseAction
                   picks the move; False means a uniformly random empty cell.

    After each move the board is negated, so the next player to move always
    sees its own pieces as 1. Labels are 1 for moves made by the eventual
    winner and 0 otherwise (a draw labels every move 0).
    """
    board = [0] * 9
    boards, movers = [], []
    winner = 0

    for turn in range(9):
        if policyMoves(turn):
            action = chooseAction(policy, board.copy())
        else:
            empty = [cell for cell, value in enumerate(board) if value == 0]
            action = empty[np.random.randint(len(empty))]

        board[action] = 1
        boards.append(list(board))

        mover = (turn % 2) * 2 - 1  # -1 for the first player, 1 for the second
        movers.append(mover)

        if gameState(board) != 0:
            winner = mover
            break

        board = [-value for value in board]

    labels = [int(m == winner) for m in movers]
    return boards, labels

In [42]:
# Inspect one batch: a policy-vs-policy game, then random-first and policy-first
# games against a random opponent. Each board is recorded after a move, and its
# label is 1 if that move's player won.
getData(policy)

([[1, 0, 0, 0, 0, 0, 0, 0, 0],
  [-1, 0, 0, 1, 0, 0, 0, 0, 0],
  [1, 1, 0, -1, 0, 0, 0, 0, 0],
  [-1, -1, 1, 1, 0, 0, 0, 0, 0],
  [1, 1, -1, -1, 1, 0, 0, 0, 0],
  [-1, -1, 1, 1, -1, 0, 0, 0, 1],
  [1, 1, -1, -1, 1, 1, 0, 0, -1],
  [-1, -1, 1, 1, -1, -1, 1, 0, 1],
  [1, 1, -1, -1, 1, 1, -1, 1, -1],
  [0, 0, 0, 0, 0, 0, 0, 1, 0],
  [1, 0, 0, 0, 0, 0, 0, -1, 0],
  [-1, 0, 0, 0, 0, 0, 1, 1, 0],
  [1, 1, 0, 0, 0, 0, -1, -1, 0],
  [-1, -1, 0, 0, 1, 0, 1, 1, 0],
  [1, 1, 1, 0, -1, 0, -1, -1, 0],
  [1, 0, 0, 0, 0, 0, 0, 0, 0],
  [-1, 0, 0, 0, 0, 1, 0, 0, 0],
  [1, 0, 0, 0, 1, -1, 0, 0, 0],
  [-1, 0, 0, 0, -1, 1, 0, 0, 1],
  [1, 1, 0, 0, 1, -1, 0, 0, -1],
  [-1, -1, 0, 1, -1, 1, 0, 0, 1],
  [1, 1, 1, -1, 1, -1, 0, 0, -1]],
 [1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 1, 0, 1, 0, 1, 0, 1])

In [53]:
# Each round plays a fresh batch of games with the current policy and then does
# one gradient pass over the positions from those games.
N_ROUNDS = 100

for i in range(N_ROUNDS):
    X, y = getData(policy)
    # Explicit arrays: Keras documents a Python list for `x` as "one array per
    # model input", so don't rely on its lists-of-scalars fallback.
    hist = policy.fit(
        np.asarray(X, dtype=np.float32),
        np.asarray(y, dtype=np.float32),
        epochs=1,
        verbose=0,
    ).history

    print("epoch {:2} | loss: {:1.5} - acc: {:2.2%}".format(i, hist["loss"][0], hist["binary_accuracy"][0]))

epoch  0 | loss: 0.74351 - acc: 41.67%
epoch  1 | loss: 0.66199 - acc: 60.00%
epoch  2 | loss: 0.67562 - acc: 58.33%
epoch  3 | loss: 0.44128 - acc: 96.15%
epoch  4 | loss: 0.52802 - acc: 90.91%
epoch  5 | loss: 0.74594 - acc: 52.00%
epoch  6 | loss: 0.60236 - acc: 66.67%
epoch  7 | loss: 0.59285 - acc: 75.00%
epoch  8 | loss: 0.46898 - acc: 90.91%
epoch  9 | loss: 0.71668 - acc: 61.90%
epoch 10 | loss: 0.4997 - acc: 86.36%
epoch 11 | loss: 0.57439 - acc: 78.95%
epoch 12 | loss: 0.89863 - acc: 36.36%
epoch 13 | loss: 0.54501 - acc: 86.36%
epoch 14 | loss: 0.57674 - acc: 70.83%
epoch 15 | loss: 0.62228 - acc: 71.43%
epoch 16 | loss: 0.63812 - acc: 60.87%
epoch 17 | loss: 0.54426 - acc: 77.27%
epoch 18 | loss: 0.52966 - acc: 70.00%
epoch 19 | loss: 0.74895 - acc: 60.87%
epoch 20 | loss: 0.61448 - acc: 69.57%
epoch 21 | loss: 0.91174 - acc: 35.00%
epoch 22 | loss: 0.41686 - acc: 95.24%
epoch 23 | loss: 0.33415 - acc: 100.00%
epoch 24 | loss: 0.76074 - acc: 54.55%
epoch 25 | loss: 0.53626 

KeyboardInterrupt: ignored

# Web demo

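Play against the trained AI: run the cell below, choose whether to move first (1) or second (2), then enter the number (1–9) of the cell you want, as shown on the board.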

In [54]:
from IPython.display import clear_output 

def symbol(x):
    """Display mark for a cell value: 1 -> "X", -1 -> "O", anything else -> "?"."""
    return "X" if x == 1 else ("O" if x == -1 else "?")

def printBoard(board):
    """Clear the cell output and draw the 3x3 board.

    Occupied cells show their mark (X/O). Empty cells show their 1-based cell
    number, which is what the human types to move there. Rows are separated by
    a dashed line.
    """
    clear_output()
    cells = []
    for k, mark in enumerate(map(symbol, board)):
        cells.append(str(k + 1) if mark == "?" else mark)

    rows = [" | ".join(cells[r * 3 + c] for c in range(3)) for r in range(3)]
    print("\n---------\n".join(rows))

def play(policy):
    """Play one interactive game against `policy` inside the notebook.

    The human chooses to move first (1) or second (2), then enters cell numbers
    1-9. The board is kept from the human's point of view: human pieces are 1
    (shown as X) and AI pieces are -1 (shown as O). The AI is handed the negated
    board, so its own pieces are 1, which is the convention chooseAction uses.
    """
    player = _askPlayerSlot()  # 0: human moves on even turns (first); 1: on odd turns

    board = [0] * 9
    for turn in range(9):
        if turn % 2 == player:
            printBoard(board)
            board[_askHumanMove(board)] = 1
        else:
            aiView = [-cell for cell in board]
            board[chooseAction(policy, aiView)] = -1

        if gameState(board) != 0:
            break

    printBoard(board)
    result = gameState(board)
    if result != 0:
        print("\nwinner is the", "human" if result == 1 else "ai")
    else:
        print("\nit's a tie!")


def _askPlayerSlot():
    """Ask whether the human plays first or second; return the turn parity (0 or 1)."""
    while True:
        answer = input("which player you want to be? (1 or 2) ").strip()
        if answer in ("1", "2"):
            return (int(answer) + 1) % 2
        print("please answer 1 or 2")


def _askHumanMove(board):
    """Prompt until the human names an empty cell (1-9); return its 0-based index."""
    while True:
        raw = input("what cell? ").strip()
        try:
            cell = int(raw) - 1
        except ValueError:
            print("please enter a number from 1 to 9")
            continue
        # Reject out-of-range numbers (0 would otherwise wrap to index -1) and taken cells.
        if 0 <= cell < 9 and board[cell] == 0:
            return cell
        print("that cell is not available, pick an empty cell from 1 to 9")

In [57]:
# Interactive: this cell waits for keyboard input, so it will block "Run All".
play(policy)

O | O | 3
---------
4 | O | 6
---------
X | X | X

winner is the humen
