In [4]:
import torch
import copy
import numpy as np
import matplotlib.pyplot as plt

# Prefer the GPU when PyTorch can see one, otherwise fall back to the CPU.
# NOTE(review): none of the visible cells move the networks or tensors to
# `device`, so everything below appears to run on the CPU regardless.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
class tictactoe():
    """Tic-tac-toe environment on a flat, row-major array of 9 cells.

    A cell holds 1 (drawn as 'x'), -1 (drawn as 'o') or 0 (empty).
    `self.done` is True once a player has completed a line or the board is
    full; `step` refuses to play on a finished game.
    """

    # Every index triple that wins the game: 3 rows, 3 columns, 2 diagonals.
    _WIN_LINES = np.array([
        [0, 1, 2], [3, 4, 5], [6, 7, 8],
        [0, 3, 6], [1, 4, 7], [2, 5, 8],
        [0, 4, 8], [2, 4, 6],
    ])
    # Character printed by render() for each cell value.
    _TOKENS = {1: 'x', 0: ' ', -1: 'o'}

    def __init__(self):
        self.board = np.zeros(9, dtype='int')
        # Defined here so that step() works on a fresh instance without
        # requiring a reset() call first.
        self.done = False

    def reset(self):
        """Empty the board, clear the finished flag and return a copy of the board."""
        self.board[:] = 0
        self.done = False
        return np.copy(self.board)

    def choose_board(self, board):
        """Load an arbitrary position and return a copy of it.

        The position is copied into the environment, so later moves or a
        reset() never modify the array the caller passed in.
        """
        self.board = np.array(board, dtype='int')
        self.done = False
        return np.copy(self.board)

    def legal_moves(self, player):
        """List the empty cells and the board that results from playing each.

        Returns:
            moves:  1-D array with the indices of the empty cells.
            boards: float64 array of shape (len(moves), 9); row k is the
                    current board with `player` placed on moves[k].  On a
                    full board this is an empty (0, 9) array.
        """
        moves = np.where(self.board == 0)[0]
        boards = np.tile(self.board, (len(moves), 1))
        boards[np.arange(len(moves)), moves] = player
        return moves, boards.astype('double')

    def swap_player(self):
        """Negate the board so the two players exchange pieces."""
        self.board = -self.board

    # opponent's random move
    def make_move(self, player=-1):
        """Play a uniformly random legal move for `player`; returns what step() returns."""
        moves, _ = self.legal_moves(player)
        # Draw with size 1 (as before) and unwrap, so step() gets a scalar index.
        move = np.random.choice(moves, 1)[0]
        return self.step(move, player)

    def step(self, move, player=1):
        """Place `player` on cell `move`.

        Returns:
            (board copy, reward, done) where reward is 1 if this move
            completed a line for `player` and 0 otherwise, and done is True
            after a win or when no empty cell is left.

        Raises:
            AssertionError: if the cell is occupied or the game is already over.
        """
        assert self.board[move] == 0, "Tried to play an illegal move, player = %d"%player
        assert not self.done, "Game has finished must call tictactoe.reset()"
        self.board[move] = player
        won = self.iswin(player)
        reward = 1 if won else 0
        self.done = won or not np.any(self.board == 0)
        return np.copy(self.board), reward, self.done

    def iswin(self, player):
        """Return True if `player` occupies a full row, column or diagonal."""
        lines = self.board[self._WIN_LINES]
        return bool(np.any(np.all(lines == player, axis=1)))

    def render(self):
        """Print the board as a 3x3 ASCII grid."""
        for row in self.board.reshape(3, 3):
            print('-------------')
            cells = ''.join(self._TOKENS.get(cell.item(), '') + ' | ' for cell in row)
            print('| ' + cells)
        print('-------------')
        

def reset_graph(seed=42):
    """Seed the NumPy and PyTorch global random generators with `seed`."""
    np.random.seed(seed)
    torch.manual_seed(seed)


# Seed once at start-up so weight initialisation and random moves are repeatable.
reset_graph()


# Layer sizes: D_in inputs (one per board cell), H hidden units, D_out outputs.
D_in, H, D_out = 9, 50, 1


def _two_layer_net(output_activation):
    """Build Linear(D_in, H) -> ReLU -> Linear(H, D_out) -> output_activation."""
    return torch.nn.Sequential(
        torch.nn.Linear(D_in, H),
        torch.nn.ReLU(),
        torch.nn.Linear(H, D_out),
        output_activation,
    )


# Construction order (actor, critic, memory) is kept so that a seeded run
# draws the same initial weights for each network.
# The actor's softmax runs over dim 0, i.e. across the rows of its input.
actor = _two_layer_net(torch.nn.Softmax(dim=0))
critic = _two_layer_net(torch.nn.Tanh())
memory = _two_layer_net(torch.nn.Tanh())

# Snapshot of the transient memory's initial parameters.
initial_memory = [copy.deepcopy(weights) for weights in memory.parameters()]

def get_action_and_value(actor, boards):
    """Sample a row index from the actor's output and return it with its score.

    `boards` is a NumPy array that is converted to a float tensor and fed to
    `actor`.  The flattened output is used as the weights of a single
    multinomial draw.

    Returns:
        (index, actor_output[index]) - the sampled index as an int and the
        corresponding entry of the actor's output tensor.
    """
    probs = actor(torch.from_numpy(boards).float())
    sampled = torch.multinomial(probs.view(1, -1), 1)
    action = int(sampled)
    return action, probs[action]
    
def get_action_value(actor, boards, action):
    """Return the actor's output entry for row `action` of `boards`.

    `boards` is a NumPy array; it is converted to a float tensor before being
    passed through `actor`.
    """
    return actor(torch.from_numpy(boards).float())[action]

def get_action(actor, boards):
    """Sample a row index of `boards`, weighted by the actor's flattened output.

    `boards` is a NumPy array; the result is a plain int.
    """
    weights = actor(torch.from_numpy(boards).float()).view(1, -1)
    return int(torch.multinomial(weights, 1))

def get_state_value(nn_model, after_state):
    """Evaluate `nn_model` on a NumPy board and return the resulting tensor.

    The board is converted to a float tensor first.  Gradient tracking is
    left to the caller's context.
    """
    return nn_model(torch.from_numpy(after_state).float())

def get_composite_value(critic, memory, after_state):
    """Return critic(after_state) + memory(after_state) for a NumPy board.

    The board is converted to a float tensor once and passed to both networks.
    """
    board = torch.from_numpy(after_state).float()
    return critic(board) + memory(board)
    

def epsilon_greedy(critic, possible_boards, epsilon=1):
    """Pick a row index of `possible_boards`, greedily or at random.

    NOTE: `epsilon` is the probability of acting *greedily* (the reverse of
    the usual convention).  With the default of 1 the row with the highest
    critic output is always chosen; otherwise a uniformly random row index
    is returned.  One NumPy random number is consumed on every call.

    Returns:
        The chosen row index as an int.
    """
    candidates = torch.from_numpy(possible_boards).float()
    values = critic(candidates)
    explore = not (np.random.random() < epsilon)
    if explore:
        return int(np.random.randint(0, len(candidates)))
    return int(torch.max(values, 0)[1])

def composite_greedy(critic, memory, possible_boards, epsilon=1):
    """Pick a row index of `possible_boards` using critic + memory values.

    NOTE: `epsilon` is the probability of acting *greedily* (the reverse of
    the usual convention).  With the default of 1 the row maximising
    critic(board) + memory(board) is always chosen; otherwise a uniformly
    random row index is returned.  One NumPy random number is consumed on
    every call.

    Returns:
        The chosen row index as an int.
    """
    candidates = torch.from_numpy(possible_boards).float()
    values = critic(candidates) + memory(candidates)
    explore = not (np.random.random() < epsilon)
    if explore:
        return int(np.random.randint(0, len(candidates)))
    return int(torch.max(values, 0)[1])
        
        
# --- Learning hyperparameters ---------------------------------------------
gamma = .1            # weight on the bootstrapped value in the TD error
critic_alpha = 0.05   # step size for the critic's parameter updates
memory_alpha = 0.05   # step size for the transient memory's parameter updates

critic_lambda = 0.9   # decay applied to the critic's eligibility trace
memory_lambda = 0.9   # decay applied to the memory's eligibility trace

# Eligibility traces: one entry per parameter tensor, starting at zero.
critic_Z = [0] * len(list(critic.parameters()))
memory_Z = [0] * len(list(memory.parameters()))

# --- Bookkeeping for progress plots ---------------------------------------
plt_iter = 1000
rew, rew_plt = [], []

from time import time
tic = time()

# Environment shared by the training loop and by search().
env = tictactoe()



"""
Use: search(pre_state, pre_value, n_dreams, max_steps)
Input: pre_state is the current state,
       memory (module-level) is the transient memory, i.e. the value function that is updated,
       pre_value is the value of the last after-state,
       n_dreams is the number of dreams,
       max_steps is the maximum number of steps per dream,
Output: memory has been updated for the episodes played in the dreams
"""
def search(pre_state, pre_value, n_dreams, max_steps = 1000):
    """Play simulated games ("dreams") from `pre_state` and train the transient memory on them.

    Args:
        pre_state: board every dream starts from, with player 1 to move.
        pre_value: value of the last after-state; used as the baseline of the
            first TD error in each dream.
        n_dreams: number of simulated games to play.
        max_steps: maximum number of player-1 moves per dream.

    Returns:
        None.  Works through module-level state: `memory` is reset to
        `initial_memory` and then updated in place, and `env` is left holding
        a copy of `pre_state` with its finished flag cleared.
    """
    # Clear the transient memory: restore its initial parameters.
    with torch.no_grad():
        for i, param in enumerate(memory.parameters()):
            param.data.copy_(initial_memory[i])

    for _ in range(n_dreams):
        done = False
        # Fresh eligibility trace for this dream, one slot per *memory* parameter.
        memory_Z = [0 for _ in memory.parameters()]
        step = 1
        env.choose_board(np.copy(pre_state))
        old_value = pre_value

        # Play one game, updating the memory after every player-1 move.
        while not (done or step > max_steps):
            possible_moves, possible_boards = env.legal_moves(1)
            # The composite value (critic + memory) chooses the move.
            action = composite_greedy(critic, memory, possible_boards)

            after_state, reward, done = env.step(possible_moves[action])

            if not done:
                critic_value = get_state_value(critic, after_state)
                memory_value = get_state_value(memory, after_state)
                value = critic_value + memory_value
                # Gradient of the composite value; only the memory's part is used.
                memory.zero_grad()
                critic.zero_grad()
                value.backward()
                with torch.no_grad():
                    for i, param in enumerate(memory.parameters()):
                        memory_Z[i] = memory_lambda * memory_Z[i] + param.grad
            else:
                # Terminal after-state: nothing to bootstrap from.
                value = 0

            with torch.no_grad():
                # Opponent replies with a random move; its win counts as -1 for us.
                if not done:
                    _, reward, done = env.make_move()
                    reward = -reward

                # NOTE(review): the trace is decayed by memory_lambda only and
                # already contains the gradient of the *new* after-state when
                # this TD error (measured against old_value) is applied.
                # Confirm this ordering is intended.
                delta = reward + gamma*value - old_value
                old_value = value

                for i, param in enumerate(memory.parameters()):
                    param += memory_alpha * delta * memory_Z[i]

            step += 1

    # Put the real position back.  A copy is loaded so the environment never
    # shares (and later mutates) the caller's array.
    env.choose_board(np.copy(pre_state))
        




In [11]:
# Training loop: player 1 (the learner) plays `forever` games against a
# uniformly random opponent.  From the second move on, search() first fits the
# transient memory with simulated games, then critic + memory pick the move.
# The critic itself is updated from the real game with an eligibility trace.
forever = 2
for episode in range(forever):
    state = env.reset()
    done = False
    # Fresh eligibility trace for the critic, one slot per parameter tensor.
    critic_Z = [0 for layer in critic.parameters()]
    I = 1
    step = 1
    n_dreams = 50   # simulated games per search() call
    max_steps = 10  # cap on player-1 moves inside each simulated game

    while not done:

        possible_moves, possible_boards = env.legal_moves(1)
        if (step > 1):
            # next_state / old_value were set by the previous iteration.
            # search() resets and retrains `memory` and leaves `env` on next_state.
            search(next_state, old_value, n_dreams, max_steps)
            #env.render()
            action = composite_greedy(critic, memory, possible_boards)

            # Diagnostics only.  after_state is still the previous iteration's
            # after-state (the board before the opponent replied).
            print('currnt state')
            env.render()
            critic_value = get_state_value(critic, after_state)
            memory_value = get_state_value(memory, after_state)
            print('critic_value: ', critic_value)
            print('memory_value: ', memory_value)
            # What the critic alone would have played.  This call also consumes
            # one NumPy random number.
            critic_action = epsilon_greedy(critic, possible_boards)
            print('action: ', possible_moves[action])
            print('critic_action: ', possible_moves[critic_action])

        else:
            action = epsilon_greedy(critic, possible_boards) # No search on first step

        after_state, reward, done = env.step(possible_moves[action])
        print('after state')
        env.render()

        if not done:
            value = get_state_value(critic, after_state)
            # Gradient of the critic's value, accumulated into the decayed trace.
            critic.zero_grad()
            value.backward()
            with torch.no_grad():
                for i, param in enumerate(critic.parameters()):
                    critic_Z[i] = critic_lambda * critic_Z[i] + param.grad
        else:
            # Terminal after-state: nothing to bootstrap from.
            value = 0

        with torch.no_grad():
            # Opponent replies with a random move; its win counts as -1 for us.
            if not done:
                next_state, reward, done = env.make_move()
                reward = -reward
                next_value = get_state_value(critic, next_state)
            else:
                next_value = 0

            # TD error against the previous after-state's value; there is no
            # previous value on the first move, so no update happens then.
            if step > 1:
                delta = reward + gamma*value - old_value

            old_value = value

            ###### plot
            # The string literal below is disabled plotting code; it is never executed.
            """
            if episode%plt_iter == 0:
                env.render()
                if done:
                    print('Reward: ',reward)
                    rew_plt.append(np.mean(np.equal(rew,-1)))
                    rew = []
                    plt.plot(rew_plt)
                    plt.show()
                    rnd = False
                    print("Episode: {}".format(episode))
                    toc=time()
                    print('time per',plt_iter,':',toc-tic)
                    tic=toc
            """
            ######

            # Move the critic's parameters along the trace, scaled by the TD error.
            if step > 1:
                for i, param in enumerate(critic.parameters()):
                    param += critic_alpha * delta * critic_Z[i]


        I *= gamma
        step +=1

    # Final reward of the game, from player 1's point of view.
    rew.append(reward)

after state
-------------
|   |   |   | 
-------------
|   |   |   | 
-------------
| x |   |   | 
-------------
currnt state
-------------
|   |   |   | 
-------------
|   |   |   | 
-------------
| x | o |   | 
-------------
critic_value:  tensor([0.3140], grad_fn=<TanhBackward>)
memory_value:  tensor([-0.0975], grad_fn=<TanhBackward>)
action:  2
critic_action:  8
after state
-------------
|   |   | x | 
-------------
|   |   |   | 
-------------
| x | o |   | 
-------------
currnt state
-------------
|   | o | x | 
-------------
|   |   |   | 
-------------
| x | o |   | 
-------------
critic_value:  tensor([0.1909], grad_fn=<TanhBackward>)
memory_value:  tensor([-0.3758], grad_fn=<TanhBackward>)
action:  4
critic_action:  4
after state
-------------
|   | o | x | 
-------------
|   | x |   | 
-------------
| x | o |   | 
-------------
after state
-------------
|   |   |   | 
-------------
|   |   |   | 
-------------
| x |   |   | 
-------------
currnt state
-------------
|   |   |

In [None]:
# Scratch check: load a hand-written position and list player 1's options.
# NOTE: this rebinds the global `env` used by the training loop.
env = tictactoe()
b = np.array([1,  0, -1,  0,  1, -1, 0, -1,  1])
c = env.choose_board(b)
possible_moves, possible_boards = env.legal_moves(1)
possible_boards


In [None]:
c

In [None]:
print(c)

In [None]:
c

In [None]:
u

In [None]:
# Scratch check: load a hand-written position into a fresh environment.
# `u` is created here because no visible cell defines it.
u = tictactoe()
b = np.array([1,  0, -1,  1,  1, -1, 0, -1,  1])
a = u.choose_board(b)