In [15]:
import numpy as np

class State:
    def __init__(self):
        self.reset()
    
    def reset(self):
        self.board = np.zeros([3, 3])
        self.result = 0 # 1 for p1, 2 for p2, -1 for draw, 0 for ongoing game
        self.turn = 1
        
    def legal_actions(self):
        if self.result != 0:
            return []
        
        board = self.board.reshape(-1)
        indices = np.array(range(len(board)), dtype=int)
        return indices[board == 0]
    
    def update(self, action):
        # assumed action is legal
        self.board[action//3, action%3] = self.turn
        self.turn = 3 - self.turn # turn changes even if game has ended but doesnt matter

        for p in range(1, 3):
            for i in range(3):
                if (self.board[i, :] == p).all():
                    self.result = p
            for j in range(3):
                if (self.board[:, j] == p).all():
                    self.result = p
            
            if self.board[0, 0] == self.board[1, 1] == self.board[2, 2] == p:
                self.result = p
            if self.board[0, 2] == self.board[1, 1] == self.board[2, 0] == p:
                self.result = p
            
        if self.result == 0 and (self.board.reshape(-1) != 0).all():
            self.result = -1
        
        return self.get_index()
    
    def get_index(self):
        res = np.full(27, 0)
        for i in range(3):
            for j in range(3):
                for k in range(3):
                    res[i*9+j*3+k] = (self.board[i, j] == k).astype(float)
#         print(res)
        return torch.tensor(res, dtype=torch.float32).view([1, -1])
#         return torch.tensor(
#             [-1 if i == 2 else i for i in self.board.reshape(-1)],
#             dtype=torch.float32).view(1, -1)
#         board = self.board.reshape(-1)
#         res = 0
#         for i in range(len(board)):
#             res = res*3 + board[i]
#         return res
    
    def get_reward(self):
        if self.result == 0:
            return 0
        elif self.result == -1:
            return 0.5
        else:
            return 1
    
    def print_board(self):
        for i in range(3):
            for j in range(3):
                c = self.board[i, j]
                c = 'X' if c == 1 else 'O' if c == 2 else '-'
                print(c, end=' ')
            print('')

        

In [16]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class DQN(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(27, 243)
#         self.fc2 = nn.Linear(81, 27)
        self.fc3 = nn.Linear(243, 9)
        
    def forward(self, x):
        x = self.fc1(x)
        x = F.relu(x)
#         x = self.fc2(x)
#         x = F.relu(x)
        x = self.fc3(x)
        return x

In [57]:
from collections import deque
import random
class Agent:
    def __init__(self, alpha, gamma, epsilon, epsilon_min, epsilon_decay):
        self.alpha = alpha
        self.gamma = gamma
        self.epsilon = epsilon
        self.epsilon_min = epsilon_min
        self.epsilon_decay = epsilon_decay
        self.reset()
        
    def reset(self):
        self.dqn = DQN()
        self.opt = torch.optim.Adam(self.dqn.parameters(), lr=self.alpha)
        self.criterion = nn.MSELoss()
        self.memory = deque(maxlen = 100000)
        
    def choose_action(self, state, debug=False):
        state_index = state.get_index()
        legal_actions = state.legal_actions()
        #print(legal_actions) #debug
        if np.random.random() <= self.epsilon:
            action = legal_actions[np.random.randint(len(legal_actions))]
        else:
            actions = np.full(9, -5.0)
            with torch.no_grad():
                y = self.dqn(state_index).numpy()
                if debug:
                    print(y)
                for i in legal_actions:
                    actions[i] = y[0, i]
            if debug:
                print(actions)
            action = np.argmax(actions)
            #action = np.random.choice(np.flatnonzero(actions == actions.max()))
        return action
    
    def remember(self, sarsA):
        self.memory.append(sarsA)
        
    def replay(self):
        y_batch, y_target_batch = [], []
        minibatch = random.sample(self.memory, min(len(self.memory), 100))
        state_index_batch = torch.cat([x[0] for x in minibatch])
        y_batch = self.dqn(state_index_batch)
#         y_target_batch = y_batch.clone().detach()
        i = 0
        for state_index, action, reward, next_state_index, legal_actions \
                in minibatch:
#             print(state_index, action, reward, next_state_index, legal_actions)
#             y = self.dqn(state_index)
            y = y_batch[i]
            y_target = y.clone().detach()
            reward = torch.tensor(reward)
            with torch.no_grad():
                y_target[action] = reward if len(legal_actions) == 0 else \
                    reward + self.gamma * torch.max(self.dqn(next_state_index)[0, legal_actions])
#             y_batch.append(y[0])
            y_target_batch.append(y_target)
            i += 1
            
#         y_batch = torch.cat(y_batch)
        y_target_batch = torch.stack(y_target_batch)
        
        self.opt.zero_grad()
        loss = self.criterion(y_batch, y_target_batch)
        loss.backward()
        self.opt.step()
            #self.update_Q(sarsA[0], sarsA[1], sarsA[2], sarsA[3], sarsA[4])
        if self.epsilon > self.epsilon_min:
            self.epsilon = self.epsilon * self.epsilon_decay
        
#         self.memory.clear()
    
    def update_Q(self, state_index, action, reward, next_state_index, legal_actions):
        #print(legal_actions)
        self.Q[state_index, action] = (1 - self.alpha) * self.Q[state_index, action] + \
                                         self.alpha * reward
        if len(legal_actions) != 0:
            #nxt = np.max(self.Q[next_state_index, legal_actions])
            #if nxt != 0:
               # print("nxt:", )
            self.Q[state_index, action] += self.alpha * self.gamma * \
                                            np.max(self.Q[next_state_index, legal_actions])
        

In [58]:
class Environment:
    def __init__(self, ALPHA=0.01, GAMMA=1.0, EPSILON=1.0, EPSILON_DECAY=0.995, LAST_EPSILON=0):
#         ALPHA = 0.003
#         GAMMA = 1.0
#         EPSILON = 0.327
        EPSILON_MIN = 0.01
        self.epsilon = EPSILON
        self.last_epsilon = LAST_EPSILON
#         EPSILON_DECAY = 1.0
        self.p1 = Agent(ALPHA, GAMMA, EPSILON, EPSILON_MIN, EPSILON_DECAY)
        self.p2 = Agent(ALPHA, GAMMA, EPSILON, EPSILON_MIN, EPSILON_DECAY)
        self.state = State()
        
    def reset(self):
        self.p1.reset()
        self.p2.reset()
        self.state.reset()
        
    def learn(self, num_episodes):
        self.p1.epsilon = self.p2.epsilon = self.epsilon
        rs = np.zeros([10])
        for i in range(num_episodes):
            self.state.reset()
            prev_p2_state = None
            while True:
                prev_p1_state = self.state.get_index()
                p1_action = self.p1.choose_action(self.state)
                now_p2_state = self.state.update(p1_action)
                reward = self.state.get_reward()
                if prev_p2_state is not None:
                    legal_actions = self.state.legal_actions()
                    self.p2.remember((prev_p2_state, p2_action, reward if reward != 1 else -1,
                                      now_p2_state, legal_actions))
                if reward != 0:
                    self.p1.remember((prev_p1_state, p1_action, reward, None, []))
                    break
                    
                prev_p2_state = now_p2_state
                p2_action = self.p2.choose_action(self.state)
                now_p1_state = self.state.update(p2_action)
                reward = self.state.get_reward()
                legal_actions = self.state.legal_actions()
                self.p1.remember((prev_p1_state, p1_action, reward if reward != 1 else -1,
                                  now_p1_state, legal_actions))
                if reward != 0:
                    self.p2.remember((prev_p2_state, p2_action, reward, None, []))
                    break
                
#                 action = self.p1.choose_action(self.state)
#                 state_index = self.state.get_index()
#                 next_state_index = self.state.update(action)
#                 reward = self.state.get_reward()
#                 legal_actions = self.state.legal_actions()
#                 self.p1.remember((state_index, action, reward, next_state_index, legal_actions))
#                 if reward != 0:
#                     break
                
#                 action = self.p2.choose_action(self.state)
#                 state_index = self.state.get_index()
#                 next_state_index = self.state.update(action)
#                 reward = self.state.get_reward()
#                 legal_actions = self.state.legal_actions()
#                 self.p2.remember((state_index, action, reward, next_state_index, legal_actions))
#                 if reward != 0:
#                     break
            
            self.p1.replay()
            self.p2.replay()
            if num_episodes - i < 10:
                self.p1.epsilon = self.p2.epsilon = self.last_epsilon
                rs[num_episodes - i] = self.state.result
            
        first_won = sum((rs == 1).astype(int))
        second_won = sum((rs == 2).astype(int))
        draw = sum((rs == -1).astype(int))
        #print(rs == 1)
        return (first_won, second_won, draw)                    
        

In [59]:
env = Environment(ALPHA=0.01, EPSILON=0.3, GAMMA=0.7, EPSILON_DECAY=1.0)

In [66]:
from tqdm.notebook import trange
for i in trange(5):
#      env.p1.epsilon = 0.3
#      env.p2.epsilon = 0.3
    #print(env.p1.epsilon, env.p2.epsilon)
    print(env.learn(800))
# cnt = 0
# for i in range(3**9):
#     if np.sum(env.p1.Q[i, ]) > 2.7:
#         #print(env.p1.Q[i, ])
#         cnt  += 1
# print(cnt)
#for i in range(10):
   # print(env.p1.Q[np.random.randint(3**9)])
#print(np.sum(env.p1.Q[:, :]))

Widget Javascript not detected.  It may not be installed or enabled properly.


(0, 0, 9)
(0, 0, 9)
(3, 2, 4)
(0, 0, 9)
(0, 0, 9)



In [67]:
import copy
p1 = copy.deepcopy(env.p1)
p1.epsilon = 0
print([x for x in p1.dqn.parameters()])
p2 = copy.deepcopy(env.p2)
p2.epsilon = 0
print([x for x in p2.dqn.parameters()])
game = State()

[Parameter containing:
tensor([[-0.2339,  0.1556, -0.8679,  ..., -1.4896, -0.7616,  0.1076],
        [-0.4072,  0.2342, -1.3025,  ..., -0.1416, -0.3626,  0.0964],
        [-0.1299, -0.4486,  0.2227,  ..., -0.3393, -0.6022,  0.1188],
        ...,
        [-0.1619, -0.0657, -0.2739,  ..., -0.0895, -0.2674, -0.0437],
        [-1.4632,  0.1178, -0.3631,  ..., -0.6100,  0.0053, -1.0096],
        [-0.4208, -1.1076,  0.1053,  ..., -0.3191,  0.2866, -0.8766]],
       requires_grad=True), Parameter containing:
tensor([-6.3363e-01, -2.8839e-01, -3.3937e-01, -2.5867e-01, -2.0855e-01,
        -4.0770e-01, -4.2733e-01, -7.4391e-01, -3.4704e-01, -3.2095e-01,
        -6.3928e-01, -3.9254e-01, -8.2338e-01, -2.1011e-01, -8.3252e-01,
        -2.3474e-01, -6.2762e-01, -6.1426e-02, -3.4527e-01, -4.8078e-01,
        -5.9223e-01, -1.4345e-01, -6.3170e-01, -2.1476e-01, -6.1273e-01,
        -2.0477e-01, -2.6297e-01, -2.6964e-01, -2.2097e-01, -1.9344e-01,
        -2.7126e-01, -3.7851e-01, -2.4742e-01, -2.8758e

In [68]:
game.reset()
while True:
#     #cpu first
#     game.print_board()
#     action = p1.choose_action(game, debug=True)
#     #state_index = game.get_index()
#     next_state_index = game.update(action)
#     reward = game.get_reward()
#     if reward != 0:
#         break
#     print()
#     game.print_board()
#     action = int(input())
#     #state_index = self.state.get_index()
#     next_state_index = game.update(action)
#     reward = game.get_reward()
#     #legal_actions = self.state.legal_actions()
#     if reward!= 0:
#         break
        
    # player first
    game.print_board()
    action = int(input())
    #state_index = self.state.get_index()
    next_state_index = game.update(action)
    reward = game.get_reward()
    #legal_actions = self.state.legal_actions()
    if reward != 0:
        break
    game.print_board()
    action = p2.choose_action(game, debug=True)
    #state_index = game.get_index()
    next_state_index = game.update(action)
    reward = game.get_reward()
    if reward != 0:
        break
    print()
            
print(game.result)
    

- - - 
- - - 
- - - 
4
- - - 
- X - 
- - - 
[[ 0.2162276   0.08804827  0.19759193  0.1177718   0.34314924 -0.01184662
   0.13096912 -0.01980584  0.18402025]]
[ 0.21622761  0.08804827  0.19759193  0.1177718  -5.         -0.01184662
  0.13096912 -0.01980584  0.18402025]

O - - 
- X - 
- - - 
1
O X - 
- X - 
- - - 
[[-0.40721548  0.28417298 -1.3313626  -1.3842998   1.5233128  -0.8686765
  -0.9423363   0.22546801 -1.5427141 ]]
[-5.         -5.         -1.33136261 -1.38429976 -5.         -0.86867648
 -0.94233632  0.22546801 -1.54271412]

O X - 
- X - 
- O - 
3
O X - 
X X - 
- O - 
[[ 1.7340415   0.4855237  -0.77528673  0.12298936 -0.05742446  0.36794394
  -0.6440845   0.33592376 -0.8002629 ]]
[-5.         -5.         -0.77528673 -5.         -5.          0.36794394
 -0.64408451 -5.         -0.80026293]

O X - 
X X O 
- O - 
8
O X - 
X X O 
- O X 
[[ 1.5829349   0.32340115  0.55825675  0.69160366 -0.87538636  0.4214168
  -0.22652675  0.6398905  -1.7852392 ]]
[-5.         -5.          0.558256

In [21]:
dummy = np.zeros(200)
print(sum((dummy == 0).astype(int)))

200
