In [1]:
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
import imageio
import os
import torch

In [52]:
# One-hot encoding of each of the 64 squares: row i is square i.
states = torch.eye(64)

# Reward per destination square: 0.5 everywhere by default.
rewards = torch.full((64,), 0.5)
# Two goal squares.
rewards[[9, 54]] = 1.0
# Three neighbouring squares of each goal (right, below and diagonal of 9;
# left, above and diagonal of 54 on the 8-wide board) give nothing.
rewards[[10, 17, 18, 45, 46, 53]] = 0.0

class Piece:
    """Base class for a chess piece whose move policy is learned on an empty 8x8 board.

    The policy is a single linear layer followed by a sigmoid: for a board-state
    row ``p`` of length 64, ``sigmoid(p @ weights + biases)`` gives one score per
    move in ``self.moves``.

    Subclasses pass a short list of ``[dx, dy]`` seed moves (``dx`` shifts within
    a row of 8 squares, ``dy`` shifts by whole rows). The constructor expands them
    into all four 90-degree rotations and prepends the null move ``[0, 0]``.
    """

    def __init__(self, moves):
        """Build the move table and the trainable policy parameters.

        Args:
            moves: list of ``[dx, dy]`` seed moves; it is not modified.
        """
        # NOTE(review): not read anywhere in this class; samples() discounts
        # with 0.9 instead. Confirm which value is intended.
        self.discount = 0.95
        self.moves = self.makeMoves(moves)

        n_moves = self.moves.shape[0]
        # Scale first, then mark as requiring grad. Calling requires_grad_()
        # before the multiplication leaves the stored tensor a non-leaf, which
        # torch.optim refuses to optimize (and which never receives .grad).
        self.weights = ((9**0.5) * torch.rand(64, n_moves)).requires_grad_()
        self.biases = ((9**0.5) * torch.rand(n_moves)).requires_grad_()
        self.optimizer = torch.optim.Adam([self.weights, self.biases])

    def makeMoves(self, initMoves):
        """Expand seed moves into all four 90-degree rotations.

        Returns:
            Integer tensor of shape ``(1 + 4 * len(initMoves), 2)``. The first row
            is the null move ``[0, 0]``. It is followed by the seed moves, then the
            seed moves rotated once, twice and three times by
            ``(dx, dy) -> (dy, -dx)``.
        """
        moves = [[0, 0]]
        # Work on a private copy so the caller's list is never touched.
        current = [[dx, dy] for dx, dy in initMoves]
        for _ in range(4):
            moves.extend([dx, dy] for dx, dy in current)
            current = [[dy, -dx] for dx, dy in current]
        return torch.tensor(moves)

    def valid_move(self, x, h):
        """Apply move ``h`` to square ``x`` if the destination stays on the board.

        Args:
            x: square index as an integer tensor. Its column is ``x % 8`` and its
                row is ``x // 8``.
            h: ``[dx, dy]`` move (one row of ``self.moves``).

        Returns:
            ``x + dx + 8 * dy`` when the destination column and row are both in
            ``[0, 8)``. Otherwise ``x`` itself, so an off-board move means
            "stay put".
        """
        # Floor (not trunc) division: with trunc, e.g. (0 - 1) / 8 rounded to
        # row 0, so a leftward move off column 0 of the first row was accepted
        # and wrapped the piece to the opposite edge of the board.
        col = x % 8 + h[0]
        row = torch.div(x, 8, rounding_mode='floor') + h[1]
        if 0 <= col < 8 and 0 <= row < 8:
            return x + h[0] + h[1] * 8
        return x

    def valid_step(self, q, p, trials):
        """Advance every square in ``q`` by its most frequently sampled action.

        Args:
            q: 1-D long tensor of current square indices.
            p: board-state rows aligned with ``q`` (presumably ``states[q]``;
                confirm at call sites).
            trials: number of categorical samples drawn per state. The mode of
                the samples is the action taken.

        Returns:
            ``[s, a, r, nextStates, policies]``:
            - ``s`` is ``q``.
            - ``a`` holds the chosen action indices.
            - ``r`` is the module-level ``rewards`` at each destination, as a
              column vector.
            - ``policies`` are the detached sigmoid scores.
        """
        policies = torch.sigmoid(p @ self.weights + self.biases)
        # Categorical normalises each row of sigmoid scores into a distribution.
        draws = torch.distributions.Categorical(probs=policies).sample((trials,))
        actions = draws.mode(0)[0].long()
        nextStates = torch.tensor(
            [self.valid_move(st, move) for st, move in zip(q, self.moves[actions])]
        ).long()
        # -1 rather than a hard-coded 64 so a subset of squares also works.
        r = rewards[nextStates].reshape((-1, 1))
        return [q, actions, r, nextStates, policies.detach()]

    def samples(self, sts, t, trials):
        """Roll out ``t + 1`` steps from every square and turn rewards into targets.

        Args:
            sts: board-state matrix indexed by square (e.g. ``states``).
            t: number of additional steps after the first one.
            trials: forwarded to ``valid_step``.

        Returns:
            List of ``t + 1`` records ``[s, a, target, nextStates, policies]``.
            ``target`` is built as follows: take the discounted sum of later
            rewards, each placed on the one-hot column of the action taken at
            that step. Divide by the rollout length and add 0.5.
        """
        v = [self.valid_step(torch.arange(64), sts, trials)]
        for _ in range(t):
            v.append(self.valid_step(v[-1][3], sts[v[-1][3]], trials))

        # NOTE(review): self.discount (0.95) exists but 0.9 is used here; confirm.
        gamma = 0.9
        one_hot = torch.eye(self.moves.shape[0])
        # Snapshot the raw per-step rewards before v[j][2] is overwritten below.
        step_rewards = [step[2] for step in v]
        for j in range(len(v)):
            G = 0.0
            for k, r in enumerate(step_rewards[j:]):
                G += r * (gamma**k) * one_hot[v[j + k][1]]
            v[j][2] = (G / (len(v))) + (1 / 2)

        return v

In [53]:
class King(Piece):
    """Single-square steps; Piece rotates the two seeds into all eight directions."""

    def __init__(self):
        # One orthogonal and one diagonal step, before rotation.
        seed_moves = [[0, -1], [-1, -1]]
        super().__init__(seed_moves)
        self.label = "King"


class Knight(Piece):
    """L-shaped jumps; Piece rotates the two seeds into all eight knight moves."""

    def __init__(self):
        # The two L-jumps that point up-left, before rotation.
        seed_moves = [[-1, -2], [-2, -1]]
        super().__init__(seed_moves)
        self.label = "Knight"


class Bishop(Piece):
    """Diagonal slides of 1 to 7 squares; Piece rotates them into all four diagonals."""

    def __init__(self):
        diagonal = []
        for distance in range(1, 8):
            diagonal.append([-distance, -distance])
        super().__init__(diagonal)
        self.label = "Bishop"


class Rook(Piece):
    """Straight slides of 1 to 7 squares; Piece rotates them into all four directions."""

    def __init__(self):
        straight = []
        for distance in range(1, 8):
            straight.append([-distance, 0])
        super().__init__(straight)
        self.label = "Rook"


class Queen(Piece):
    """Union of the bishop's diagonal slides and the rook's straight slides."""

    def __init__(self):
        diagonal = [[-d, -d] for d in range(1, 8)]
        straight = [[-d, 0] for d in range(1, 8)]
        # Diagonals first, then straights, matching the move-table order.
        super().__init__(diagonal + straight)
        self.label = "Queen"


In [54]:
def render(board, piece, result, epoch):
    """Save a three-panel snapshot of training progress and return its path.

    Panels, left to right:
      1. ``result`` (the recorded loss history) as a line plot.
      2. The sigmoid policy score of every move for every row of ``board``.
      3. An 8x8 heatmap of the mean policy score per board state.

    Args:
        board: matrix of board states fed to the policy (e.g. ``torch.eye(64)``).
            It must have 64 rows for the 8x8 heatmap reshape.
        piece: object exposing ``weights``, ``biases`` and ``moves`` (a ``Piece``).
        result: 1-D sequence of loss values.
        epoch: identifier used as the PNG file name.

    Returns:
        ``"./results/{epoch}.png"``, the file that was written.
    """
    path = f"./results/{epoch}.png"
    # savefig does not create missing directories.
    os.makedirs("./results", exist_ok=True)

    # Evaluate the policy once; panels 2 and 3 both use it.
    policy = torch.sigmoid(board @ piece.weights + piece.biases).detach().numpy()
    n_moves = len(piece.moves)

    sns.set(rc={"figure.figsize": (9, 3)})
    sns.set_style("whitegrid")
    fig, axs = plt.subplots(ncols=3)

    # Panel 1: loss history.
    sns.lineplot(data=result, ax=axs[0])
    # NOTE(review): the training cell records the mean of pred - y*log(pred),
    # which is not a mean absolute error; the title is kept unchanged.
    axs[0].set_title("Mean Absolute Error", fontsize=8)
    plt.setp(axs[0].lines, color="#699CB3", linewidth=0.75)

    # Panel 2: one solid line per move across all board states.
    sns.lineplot(
        data=policy,
        palette=sns.color_palette("dark:#347893_r", n_moves),
        dashes={n: "" for n in range(n_moves)},
        ax=axs[1],
    )
    axs[1].legend_.remove()
    axs[1].set_title("Policy value by board state", fontsize=8)
    plt.setp(axs[1].lines, linewidth=0.75)

    # Panel 3: mean policy score per board state, laid out as the 8x8 board.
    heatmap_data = policy.mean(axis=1).reshape((8, 8))
    sns.heatmap(
        data=heatmap_data,
        ax=axs[2],
        cbar=False,
        cmap=sns.light_palette("#205565", as_cmap=True, reverse=True),
    ).invert_yaxis()
    axs[2].set_title("Mean value by board state", fontsize=8)

    fig.savefig(path)
    # Close this specific figure so repeated calls in a loop do not pile up figures.
    plt.close(fig)
    return path


def saveGif(label, files, hold_frames=36):
    """Stitch image files into ``./{label}.gif`` and then delete those files.

    Args:
        label: base name of the GIF written to the current directory.
        files: ordered list of frame image paths; duplicates are allowed.
        hold_frames: extra copies of the final frame appended after the
            sequence, so the end state stays visible (default 36).

    Raises:
        ValueError: if ``files`` is empty. This is checked before any GIF is
            created, so no empty or broken file is left behind.
    """
    if not files:
        raise ValueError("saveGif needs at least one frame file")

    with imageio.get_writer(f"./{label}.gif", mode="I") as writer:
        for file in files:
            writer.append_data(imageio.imread(file))
        # Repeat the last image to observe the result; read it from disk only once.
        last_frame = imageio.imread(files[-1])
        for _ in range(hold_frames):
            writer.append_data(last_frame)

    # A path may appear more than once in ``files``; remove each file only once.
    for file in set(files):
        os.remove(file)



In [60]:
# Fixed seed so weight initialisation and rollout sampling are reproducible.
torch.manual_seed(0)

k = King()
results = []  # mean loss recorded once per outer iteration, as plain floats
files = []    # rendered frame paths, stitched into a GIF at the end

timesteps = 3
trials = 1000
for test in range(300):
    examples = k.samples(states, timesteps, trials)
    # training loop: fit the policy to each rollout step's targets
    for n, c in enumerate(examples):
        # Inputs and targets do not change across epochs; build them once.
        x = states[c[0]]
        y = c[2]
        for epoch in range(10):
            # feed forward
            pred = torch.sigmoid(x @ k.weights + k.biases)
            loss = pred - y * torch.log(pred)
            if n == 0 and epoch == 0:
                # .item() stores a plain float, so the history does not keep
                # a reference to each iteration's autograd graph.
                results.append(torch.mean(loss).item())
                files.append(render(states, k, torch.tensor(results), test))

            # backprop
            k.optimizer.zero_grad()
            # Passing the loss itself as ``gradient`` yields the gradient of
            # 0.5 * sum(loss ** 2), not of the mean loss.
            # NOTE(review): confirm this weighting is intended.
            loss.backward(gradient=loss)
            k.optimizer.step()

saveGif(f"new{k.label}", files)