In [113]:
import numpy as np
import torch 
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions.categorical import Categorical
from torch.optim import Adam

BOARD_ROWS = 3
BOARD_COLS = 3


In [114]:
class State:
    def __init__(self):
        self.board = np.zeros((BOARD_ROWS, BOARD_COLS))
        #self.p1 = p1
        #self.p2 = p2
        self.isEnd = False
        self.boardHash = None
        # init p1 plays first
        self.playerSymbol = 1

    # get unique hash of current board state
    def getHash(self):
        t = self.board.copy()
        t = t.reshape(1,1,3,3)
        self.boardHash = torch.from_numpy(t).double()
        return self.boardHash

    def winner(self):
        # row
        for i in range(BOARD_ROWS):
            if sum(self.board[i, :]) == 3:
                self.isEnd = True
                return 1
            if sum(self.board[i, :]) == -3:
                self.isEnd = True
                return -1
        # col
        for i in range(BOARD_COLS):
            if sum(self.board[:, i]) == 3:
                self.isEnd = True
                return 1
            if sum(self.board[:, i]) == -3:
                self.isEnd = True
                return -1
        # diagonal
        diag_sum1 = sum([self.board[i, i] for i in range(BOARD_COLS)])
        diag_sum2 = sum([self.board[i, BOARD_COLS - i - 1] for i in range(BOARD_COLS)])
        diag_sum = max(abs(diag_sum1), abs(diag_sum2))
        if diag_sum == 3:
            self.isEnd = True
            if diag_sum1 == 3 or diag_sum2 == 3:
                return 1
            else:
                return -1

        # tie
        # no available positions
        if len(self.availablePositions()) == 0:
            self.isEnd = True
            return 0
        # not end
        self.isEnd = False
        return 0

    def availablePositions(self):
        positions = []
        for i in range(BOARD_ROWS):
            for j in range(BOARD_COLS):
                if self.board[i, j] == 0:
                    positions.append((i, j))  # need to be tuple
        return positions


    def updateState(self, position0, position1):
        if self.board[position0][position1] == 0:
            #print(self.board[position0][position1])
            self.board[position0][position1] = self.playerSymbol
            self.playerSymbol = -1 if self.playerSymbol == 1 else 1
            return False
        else: 
            self.board[position0][position1] = self.board[position0][position1]
            self.playerSymbol = -1 if self.playerSymbol == 1 else 1
            self.isEnd = True
            return True
            
    

    # board reset
    def reset(self):
        self.board = np.zeros((BOARD_ROWS, BOARD_COLS))
        self.boardHash = self.getHash()
        self.isEnd = False
        self.playerSymbol = 1
        return self.boardHash

    def play(self, rounds=100):
        for i in range(rounds):
            if i % 1000 == 0:
                print("Rounds {}".format(i))
            while not self.isEnd:
                # Player 1
                positions = self.availablePositions()
                p1_action = self.p1.chooseAction(positions, self.board, self.playerSymbol)
                # take action and upate board state
                self.updateState(p1_action)
                board_hash = self.getHash()
                self.p1.addState(board_hash)
                # check board status if it is end

                win = self.winner()
                if win is not None:
                    # self.showBoard()
                    # ended with p1 either win or draw
                    self.giveReward()
                    self.p1.reset()
                    self.p2.reset()
                    self.reset()
                    break

                else:
                    # Player 2
                    positions = self.availablePositions()
                    p2_action = self.p2.chooseAction(positions, self.board, self.playerSymbol)
                    self.updateState(p2_action)
                    board_hash = self.getHash()
                    self.p2.addState(board_hash)

                    win = self.winner()
                    if win is not None:
                        # self.showBoard()
                        # ended with p2 either win or draw
                        self.giveReward()
                        self.p1.reset()
                        self.p2.reset()
                        self.reset()
                        break

    # play with human
    def play2(self):
        while not self.isEnd:
            # Player 1
            positions = self.availablePositions()
            p1_action = self.p1.chooseAction(positions, self.board, self.playerSymbol)
            # take action and upate board state
            self.updateState(p1_action)
            self.showBoard()
            # check board status if it is end
            win = self.winner()
            if win is not None:
                if win == 1:
                    print(self.p1.name, "wins!")
                else:
                    print("tie!")
                self.reset()
                break

            else:
                # Player 2
                positions = self.availablePositions()
                p2_action = self.p2.chooseAction(positions)

                self.updateState(p2_action)
                self.showBoard()
                win = self.winner()
                if win is not None:
                    if win == -1:
                        print(self.p2.name, "wins!")
                    else:
                        print("tie!")
                    self.reset()
                    break

    def showBoard(self):
        # p1: x  p2: o
        for i in range(0, BOARD_ROWS):
            print('-------------')
            out = '| '
            for j in range(0, BOARD_COLS):
                if self.board[i, j] == 1:
                    token = 'x'
                if self.board[i, j] == -1:
                    token = 'o'
                if self.board[i, j] == 0:
                    token = ' '
                out += token + ' | '
            print(out)
        print('-------------')


In [115]:
def get_policy(obs):
    logits = logits_net(obs)
    #print(Categorical(logits=logits))
    return Categorical(logits=logits)



# make action selection function (outputs int actions, sampled from policy)
def chooseAction(obs, st):
    if random.uniform(0, 1) < .09 or st.playerSymbol == -1:
        t = random.randint(0, 8)
    else:
        t = get_policy(obs).sample().item()
    #print(t)
    if t == 0:
        a = 0
        b = 0
    else:
        a = t//3
        if a ==0 :
            b = t%(3)
        else:
            b = t%(a*3)
    while st.board[a][b] != 0 and st.playerSymbol == -1:
        t = random.randint(0, 8)
        if t == 0:
            a = 0
            b = 0
        else:
            a = t//3
            if a ==0 :
                b = t%(3)
            else:
                b = t%(a*3)
    #print(a,b)
    return a,b, t


def compute_loss(obs, act, weights):
    logp = get_policy(obs).log_prob(act)
    return -(logp * weights).mean()


class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 19683, kernel_size=3, stride=1, padding=0)
        self.fc1 = nn.Linear(19683, 81)
        self.fc2 = nn.Linear(81, 9)


    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = torch.flatten(x, 1) # flatten all dimensions except batch
        x = F.relu(self.fc1(x))
        x = torch.tanh(self.fc2(x))
        return x

In [116]:
logits_net = Net().double()
optimizer = Adam(logits_net.parameters(), lr=1e-2)

def train():
    st = State()
    obs = st.reset()
    batch_size = 10000
    # make some empty lists for logging.
    batch_obs = []          # for observations
    batch_acts = []         # for actions
    batch_weights = []      # for R(tau) weighting in policy gradient
    batch_rets = []         # for measuring episode returns
    batch_lens = []         # for measuring episode lengths
    
    epi = 0
    
    # reset episode-specific variables
    obs = st.reset()       # first obs comes from starting distribution
    done = False            # signal from environment that episode is over
    ep_rews = []            # list for rewards accrued throughout ep

    # render first episode of each epoch
    finished_rendering_this_epoch = False

    # collect experience by acting in the environment with current policy
    while True:

        # rendering
        if (not finished_rendering_this_epoch):
            st.showBoard

        # save obs
        batch_obs.append(obs)

        # act in the environment
        act0, act1, act = chooseAction(obs, st)#change chooseAction to be made using NN
        illegal = st.updateState(act0, act1)
        
        obs = st.getHash()# change method so that it returns NN ingestible array
        rew = st.winner()
        
        if illegal:
            rew = -1
            st.isEnd = True

        # save action, reward
        batch_acts.append(act)
        ep_rews.append(rew)
        
        epi +=1 
        if epi >=9999:
            st.showBoard()
            print(rew)
            epi = 0
        
        if st.isEnd:
            #print(rew)
            # if episode is over, record info about episode
            ep_ret, ep_len = sum(ep_rews), len(ep_rews)
            batch_rets.append(ep_ret)
            batch_lens.append(ep_len)

            # the weight for each logprob(a|s) is R(tau)
            batch_weights += [ep_ret] * ep_len

            # reset episode-specific variables
            obs, done, ep_rews = st.reset(), False, []

            # won't render again this epoch
            finished_rendering_this_epoch = True
            
            #epi = 0

            # end experience loop if we have enough of it
            if len(batch_obs) > batch_size:
                print(len(batch_obs))
                break
    
    #print('p')
    # take a single policy gradient update step
    optimizer.zero_grad()
    #batch_obs = torch.cat(batch_obs, 0)
    #print(batch_obs)
    #batch_obs = torch.from_numpy(np.asarray(batch_obs))
    batch_loss = compute_loss(obs=torch.cat(batch_obs, 0),
                              act=torch.as_tensor(batch_acts, dtype=torch.int32),
                              weights=torch.as_tensor(batch_weights, dtype=torch.float32)
                              )
    print(batch_loss)
    batch_loss.backward()
    optimizer.step()

In [None]:
from IPython.display import clear_output
import random
for i in range(1000):
    #clear_output(wait=True)
    #print(i)
    train()

-------------
|   |   |   | 
-------------
|   |   | x | 
-------------
| o |   |   | 
-------------
0
10004
tensor(-1.7880, dtype=torch.float64, grad_fn=<NegBackward>)
-------------
| x | x | o | 
-------------
|   | o | x | 
-------------
|   |   |   | 
-------------
0
10001
tensor(-2.0060, dtype=torch.float64, grad_fn=<NegBackward>)
-------------
| x |   | o | 
-------------
|   | x |   | 
-------------
|   |   |   | 
-------------
0
10001
tensor(-1.9619, dtype=torch.float64, grad_fn=<NegBackward>)
-------------
| o |   |   | 
-------------
|   | x |   | 
-------------
|   |   |   | 
-------------
0
10002
tensor(-1.8211, dtype=torch.float64, grad_fn=<NegBackward>)
-------------
| o |   |   | 
-------------
|   |   | x | 
-------------
|   |   |   | 
-------------
0
10003
tensor(-1.9305, dtype=torch.float64, grad_fn=<NegBackward>)
-------------
|   |   | o | 
-------------
|   | x | x | 
-------------
|   | o |   | 
-------------
0
10007
tensor(-1.9557, dtype=torch.float64, grad_fn=<

tensor(-1.9449, dtype=torch.float64, grad_fn=<NegBackward>)
-------------
| o | x |   | 
-------------
| o | o | x | 
-------------
| x |   |   | 
-------------
-1
10002
tensor(-1.9344, dtype=torch.float64, grad_fn=<NegBackward>)
-------------
| o |   |   | 
-------------
|   | x | x | 
-------------
|   | o |   | 
-------------
-1
10004
tensor(-1.9194, dtype=torch.float64, grad_fn=<NegBackward>)
-------------
|   |   |   | 
-------------
|   | x |   | 
-------------
|   |   |   | 
-------------
0
10003
tensor(-1.9288, dtype=torch.float64, grad_fn=<NegBackward>)
-------------
|   |   |   | 
-------------
|   |   | x | 
-------------
| o |   |   | 
-------------
0
10002
tensor(-1.9509, dtype=torch.float64, grad_fn=<NegBackward>)
-------------
|   |   | x | 
-------------
| o |   | x | 
-------------
| o |   |   | 
-------------
0
10004
tensor(-2.0166, dtype=torch.float64, grad_fn=<NegBackward>)
-------------
|   | o |   | 
-------------
|   |   | x | 
-------------
|   |   |   | 
------

tensor(-1.9460, dtype=torch.float64, grad_fn=<NegBackward>)
-------------
|   |   |   | 
-------------
|   |   | x | 
-------------
|   |   |   | 
-------------
0
10005
tensor(-1.9362, dtype=torch.float64, grad_fn=<NegBackward>)
-------------
|   |   |   | 
-------------
|   |   |   | 
-------------
|   |   | x | 
-------------
0
10003
tensor(-1.9446, dtype=torch.float64, grad_fn=<NegBackward>)
-------------
| x |   |   | 
-------------
|   |   | x | 
-------------
|   |   | o | 
-------------
0
10001
tensor(-1.9618, dtype=torch.float64, grad_fn=<NegBackward>)
-------------
| x | o |   | 
-------------
|   | x |   | 
-------------
|   |   |   | 
-------------
0
10001
tensor(-1.9482, dtype=torch.float64, grad_fn=<NegBackward>)
-------------
| x |   |   | 
-------------
| o |   |   | 
-------------
|   |   | x | 
-------------
0
10001
tensor(-1.9582, dtype=torch.float64, grad_fn=<NegBackward>)
-------------
|   |   |   | 
-------------
|   | x | o | 
-------------
|   |   |   | 
--------

tensor(-1.9624, dtype=torch.float64, grad_fn=<NegBackward>)
-------------
| x | x |   | 
-------------
|   | o |   | 
-------------
| o |   | x | 
-------------
0
10001
tensor(-1.9266, dtype=torch.float64, grad_fn=<NegBackward>)
-------------
|   |   |   | 
-------------
|   | x |   | 
-------------
|   |   |   | 
-------------
0
10001
tensor(-1.9226, dtype=torch.float64, grad_fn=<NegBackward>)
-------------
| x | x |   | 
-------------
|   | x |   | 
-------------
| o |   | o | 
-------------
0
10001
tensor(-1.9362, dtype=torch.float64, grad_fn=<NegBackward>)
-------------
| o | x |   | 
-------------
|   |   |   | 
-------------
|   |   |   | 
-------------
0
10002
tensor(-1.9059, dtype=torch.float64, grad_fn=<NegBackward>)
-------------
| o |   | x | 
-------------
|   |   | x | 
-------------
| o |   |   | 
-------------
-1
10004
tensor(-1.9665, dtype=torch.float64, grad_fn=<NegBackward>)
-------------
|   |   | x | 
-------------
|   |   |   | 
-------------
|   |   |   | 
-------

tensor(-1.9428, dtype=torch.float64, grad_fn=<NegBackward>)
-------------
| x | o |   | 
-------------
| o | x | x | 
-------------
|   |   | o | 
-------------
0
10002
tensor(-1.9927, dtype=torch.float64, grad_fn=<NegBackward>)
-------------
|   |   |   | 
-------------
|   |   | x | 
-------------
|   |   |   | 
-------------
0
10003
tensor(-1.8886, dtype=torch.float64, grad_fn=<NegBackward>)
-------------
| o | o |   | 
-------------
|   | x |   | 
-------------
|   | x |   | 
-------------
0
10004
tensor(-1.9178, dtype=torch.float64, grad_fn=<NegBackward>)
-------------
| x | o | x | 
-------------
|   |   | o | 
-------------
|   | x | o | 
-------------
0
10005
tensor(-1.9787, dtype=torch.float64, grad_fn=<NegBackward>)
-------------
|   |   | o | 
-------------
| x | x |   | 
-------------
|   |   |   | 
-------------
0
10001
tensor(-1.9670, dtype=torch.float64, grad_fn=<NegBackward>)
-------------
|   |   |   | 
-------------
| o | x | x | 
-------------
|   |   |   | 
--------

tensor(-1.8869, dtype=torch.float64, grad_fn=<NegBackward>)
-------------
| x | x | o | 
-------------
|   | x |   | 
-------------
|   | o |   | 
-------------
0
10001
tensor(-1.9845, dtype=torch.float64, grad_fn=<NegBackward>)
-------------
| x |   |   | 
-------------
|   |   | o | 
-------------
|   | x | o | 
-------------
0
10005
tensor(-1.9631, dtype=torch.float64, grad_fn=<NegBackward>)
-------------
|   |   | o | 
-------------
| o | x | x | 
-------------
| x |   |   | 
-------------
0
10001
tensor(-1.9447, dtype=torch.float64, grad_fn=<NegBackward>)
-------------
|   |   |   | 
-------------
| x | o | x | 
-------------
| o | x | o | 
-------------
-1
10004
tensor(-1.9556, dtype=torch.float64, grad_fn=<NegBackward>)
-------------
| x |   | x | 
-------------
|   | x |   | 
-------------
|   | o | o | 
-------------
0
10002
tensor(-1.9346, dtype=torch.float64, grad_fn=<NegBackward>)
-------------
| x |   |   | 
-------------
|   |   |   | 
-------------
|   |   | o | 
-------

In [None]:
logits_net = Net().double()
st = State()
obs = st.reset()
act0, act1, act = chooseAction(obs)#change chooseAction to be made using NN
st.updateState(act0, act1)
obs = st.getHash()# change method so that it returns NN ingestible array
st.showBoard()

act0, act1, act = chooseAction(obs)#change chooseAction to be made using NN
st.updateState(act0, act1)
obs = st.getHash()# change method so that it returns NN ingestible array
st.showBoard()

In [None]:
batch_size = 1000
# make some empty lists for logging.
batch_obs = []          # for observations
batch_acts = []         # for actions
batch_weights = []      # for R(tau) weighting in policy gradient
batch_rets = []         # for measuring episode returns
batch_lens = []         # for measuring episode lengths

# reset episode-specific variables
obs = st.reset()       # first obs comes from starting distribution
done = False            # signal from environment that episode is over
ep_rews = []            # list for rewards accrued throughout ep

# render first episode of each epoch
finished_rendering_this_epoch = False

# collect experience by acting in the environment with current policy
while True:

    # rendering
    if (not finished_rendering_this_epoch):
        st.showBoard

    # save obs
    batch_obs.append(obs)
    
    c = False
    # act in the environment
    while c==False:
        act0, act1, act = chooseAction(obs)#change chooseAction to be made using NN
        c = st.updateState(act0, act1)
        
    obs = st.getHash()# change method so that it returns NN ingestible array
    rew = st.winner()

    # save action, reward
    batch_acts.append(act)
    ep_rews.append(rew)

    st.showBoard()

    if st.isEnd:
        # if episode is over, record info about episode
        ep_ret, ep_len = sum(ep_rews), len(ep_rews)
        batch_rets.append(ep_ret)
        batch_lens.append(ep_len)

        # the weight for each logprob(a|s) is R(tau)
        batch_weights += [ep_ret] * ep_len

        # reset episode-specific variables
        obs, done, ep_rews = st.reset(), False, []

        # won't render again this epoch
        finished_rendering_this_epoch = True

        # end experience loop if we have enough of it
        if len(batch_obs) > batch_size:
            break

# take a single policy gradient update step
optimizer.zero_grad()
#batch_obs = torch.cat(batch_obs, 0)
#print(batch_obs)
#batch_obs = torch.from_numpy(np.asarray(batch_obs))
batch_loss = compute_loss(obs=torch.cat(batch_obs, 0),
                          act=torch.as_tensor(batch_acts, dtype=torch.int32),
                          weights=torch.as_tensor(batch_weights, dtype=torch.float32)
                          )
batch_loss.backward()
optimizer.step()