In [4]:
import equinox as eqx
import jax.random as jr
import os
from tqdm import tqdm
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt

from tic_tac import TicTacToe
# PRNG key seeded with 420. The model cell further down re-creates `key`
# from the same seed before splitting it.
key = jr.key(420)

In [183]:

class Model(eqx.Module):
    """Two-branch self-attention classifier with 9 outputs.

    Each token (a length-9 vector) is embedded to 256 features. Two independent
    single-head scaled dot-product attention branches read the same embedding;
    each branch gets a residual connection and its own LayerNorm. A learned
    sigmoid gate blends the two branches per token, the blend is projected,
    then the mean over tokens and the last token are concatenated, normalised
    and passed through a two-layer classifier.
    """

    # Token embedding (9 -> 256, no bias).
    embedd: eqx.nn.Linear

    # Query / key / value projections of the first attention branch.
    Q1: eqx.nn.Linear
    K1: eqx.nn.Linear
    V1: eqx.nn.Linear

    # Query / key / value projections of the second attention branch.
    Q2: eqx.nn.Linear
    K2: eqx.nn.Linear
    V2: eqx.nn.Linear

    # Projection applied after the gated merge, and the gate itself (512 -> 1).
    post_att: eqx.nn.Linear
    aggregator: eqx.nn.Linear

    # Classifier head: 512 -> 128 -> 9.
    classifier1: eqx.nn.Linear
    classifier2: eqx.nn.Linear

    # LayerNorms: one per attention branch, one on the pooled 512-vector.
    att_1_norm: eqx.nn.LayerNorm
    att_2_norm: eqx.nn.LayerNorm
    norm: eqx.nn.LayerNorm

    # Dropout on each branch's attention weights and on the pooled vector.
    drop_1: eqx.nn.Dropout
    drop_2: eqx.nn.Dropout
    drop: eqx.nn.Dropout

    def __init__(self, key):
        """Create all layers, drawing each Linear's init key from `key`."""
        d_model = 256

        # 16 sub-keys are requested although only slots 0-10 are consumed.
        # The count is left untouched: changing it could alter the derived
        # keys and therefore the initial weights.
        subkeys = jr.split(key, 16)

        def linear(slot, n_in, n_out, use_bias=True):
            # Linear layer initialised from the sub-key stored at `slot`.
            return eqx.nn.Linear(n_in, n_out, use_bias=use_bias, key=subkeys[slot])

        self.embedd = linear(0, 9, d_model, use_bias=False)

        self.Q1 = linear(1, d_model, d_model)
        self.K1 = linear(2, d_model, d_model)
        self.V1 = linear(3, d_model, d_model)

        self.Q2 = linear(4, d_model, d_model)
        self.K2 = linear(5, d_model, d_model)
        self.V2 = linear(6, d_model, d_model)

        self.post_att = linear(7, d_model, d_model)
        self.aggregator = linear(8, 2 * d_model, 1, use_bias=False)

        self.classifier1 = linear(9, 2 * d_model, 128)
        self.classifier2 = linear(10, 128, 9, use_bias=False)

        self.att_1_norm = eqx.nn.LayerNorm(d_model)
        self.att_2_norm = eqx.nn.LayerNorm(d_model)
        self.norm = eqx.nn.LayerNorm(2 * d_model)

        self.drop_1 = eqx.nn.Dropout(0.2)
        self.drop_2 = eqx.nn.Dropout(0.2)
        self.drop = eqx.nn.Dropout(0.4)

    def __call__(self, inputs, *, key, training=False, KV_K = False):
        """Run the forward pass on a batch.

        Args:
            inputs: array-like of shape (batch, seq_len, 9).
            key: PRNG key; it is split three ways for the dropout layers and
                must be supplied even when `training` is False.
            training: when True, dropout is active and a causal mask is added
                to the attention scores. When False, dropout is disabled and
                the mask is all zeros.
            KV_K: when True, the normalised outputs of both attention
                branches are returned alongside the classifier output.

        Returns:
            A dict with key "res" (classifier2 output, shape (batch, 9), no
            softmax applied) and, if `KV_K`, "att1" and "att2"
            (each (batch, seq_len, 256)).
        """
        drop_key_1, drop_key_2, drop_key_out = jr.split(key, 3)
        deterministic = not training

        def per_token(layer, tokens):
            # Apply a single-vector layer over the batch and sequence axes.
            return jax.vmap(jax.vmap(layer))(tokens)

        inputs = jnp.array(inputs)
        seq_len = inputs.shape[1]

        # ---- Embedding: (batch, seq_len, 9) -> (batch, seq_len, 256) ----
        embedded = per_token(self.embedd, inputs)

        # ---- Causal mask ----
        # NOTE(review): the mask is causal only while training; at inference
        # every position attends to every other one. Confirm this asymmetry
        # is intended.
        if training:
            blocked = 1.0 - jnp.tril(jnp.ones((seq_len, seq_len)))
            mask = jnp.expand_dims(blocked * -1e9, axis=0)
        else:
            mask = jnp.zeros((seq_len, seq_len))

        scale = jnp.sqrt(256.0)

        def attend(q_proj, k_proj, v_proj, dropout, drop_key):
            # One single-head scaled dot-product attention branch.
            queries = per_token(q_proj, embedded)
            keys_ = per_token(k_proj, embedded)
            values = per_token(v_proj, embedded)
            scores = jnp.matmul(queries, jnp.swapaxes(keys_, -1, -2)) / scale
            scores = scores + mask
            weights = dropout(
                jax.nn.softmax(scores),
                key=drop_key,
                inference=deterministic,
            )
            return jnp.matmul(weights, values)

        context_1 = attend(self.Q1, self.K1, self.V1, self.drop_1, drop_key_1)
        context_2 = attend(self.Q2, self.K2, self.V2, self.drop_2, drop_key_2)

        # ---- Residual + Norm ----
        att_1 = per_token(self.att_1_norm, context_1 + embedded)
        att_2 = per_token(self.att_2_norm, context_2 + embedded)

        # ---- Gated merge: one scalar gate per token ----
        both_branches = jnp.concatenate([att_1, att_2], axis=-1)
        lamb = jax.nn.sigmoid(per_token(self.aggregator, both_branches))
        att = lamb * att_1 + (1 - lamb) * att_2

        # ---- Post attention ----
        fin_att = per_token(self.post_att, att)

        # Sequence summary: mean over tokens concatenated with the last token.
        summary = jnp.concatenate(
            [jnp.mean(fin_att, axis=1), fin_att[:, -1, :]],
            axis=-1,
        )
        summary = self.drop(summary, key=drop_key_out, inference=deterministic)

        att_out = jax.vmap(self.norm)(summary)

        # ---- Classifier head ----
        hidden = jax.nn.relu(jax.vmap(self.classifier1)(att_out))
        logits = jax.vmap(self.classifier2)(hidden)

        result = {"res" : logits}
        if KV_K:
            result["att1"] = att_1
            result["att2"] = att_2
        return result


# Re-seed the PRNG key, hand one half to the model initialiser and keep the
# other half as the running `key` used by later cells.
key = jr.key(420)
k, key = jr.split(key)
model = Model(key=k)

In [10]:
import json
import ast
import pandas as pd

def load_game_dict(filename="move_dict.json"):
    """Read a JSON object and return its entries keyed by position.

    Each JSON key is a string holding a Python literal (e.g. a tuple); it is
    parsed back with `ast.literal_eval`.

    Args:
        filename: path of the JSON file to read.

    Returns:
        A dict mapping the entry's 0-based position in the file to a
        `(parsed_key, value)` tuple.
    """
    with open(filename, "r") as handle:
        raw_entries = json.load(handle)

    return {
        position: (ast.literal_eval(state_repr), moves)
        for position, (state_repr, moves) in enumerate(raw_entries.items())
    }

# One row per game entry: column "state" holds the parsed key, "best" the
# value stored for it in the JSON file.
df = pd.DataFrame(load_game_dict()).T
df.columns = ("state", "best")

# Keep only rows whose "best" collection contains no None entry.
has_no_missing_move = df["best"].apply(lambda moves: None not in moves)
df = df[has_no_missing_move]


def one_hot_vec(inp, key):
    """Build a (6, 9) target holding a uniform distribution on one random row.

    The indices in `inp` each receive an equal share `1 / len(inp)` inside a
    length-9 vector; that vector is written to a randomly chosen row of an
    otherwise-zero (6, 9) array.

    Args:
        inp: non-empty sequence of integer indices into the length-9 vector.
        key: JAX PRNG key used to pick the row.

    Returns:
        `(layer, out)` where `layer` is an int array of shape (1,) with the
        chosen row (0-5) and `out` is the (6, 9) target array.

    Raises:
        ZeroDivisionError: if `inp` is empty.
    """
    moves = jnp.array(inp)

    # Use the exact share. Rounding it to 4 decimals made the distribution
    # drift away from 1 (three moves summed to 0.9999, six to 1.0002).
    share = 1.0 / len(moves)
    distribution = jnp.zeros(9).at[moves].set(share)

    # The shape is spelled as a real tuple; `(1)` is just the integer 1.
    layer = jr.randint(key, (1,), 0, 6)

    out = jnp.zeros((6, 9)).at[layer,].set(distribution)
    return layer, out

def board_fixer(inp, memory):
    """Write a board vector into the selected row of a zero (6, 9) array.

    Args:
        inp: array-like written into the selected row (broadcast by JAX's
            indexed update against the 9 columns).
        memory: row index (or index array) selecting where `inp` is written.

    Returns:
        A (6, 9) array that is zero everywhere except the selected row(s).
    """
    board = jnp.array(inp)
    row_index = (memory,)
    return jnp.zeros((6, 9)).at[row_index].set(board)

# Build one target per row of `df`, remembering which of the 6 rows each
# target was written to so the board can be placed on the same row.
memory = []
target_rows = []
for best_moves in df["best"]:
    # A fresh sub-key per sample; `key` keeps advancing across iterations.
    k, key = jr.split(key)
    chosen_row, encoded = one_hot_vec(best_moves, k)
    memory.append(chosen_row)
    target_rows.append(encoded)

target = jnp.array(target_rows)

# Place every board on the row recorded for its target.
board_fixed = jnp.array(
    [board_fixer(board, chosen_row) for board, chosen_row in zip(df["state"], memory)]
)

In [11]:
# Spot-check sample 4: print its target, then display the matching board.
print(target[4])
board_fixed[4]

[[0.  0.  0.  0.  0.  0.  0.  0.  0. ]
 [0.  0.  0.  0.  0.  0.  0.5 0.  0.5]
 [0.  0.  0.  0.  0.  0.  0.  0.  0. ]
 [0.  0.  0.  0.  0.  0.  0.  0.  0. ]
 [0.  0.  0.  0.  0.  0.  0.  0.  0. ]
 [0.  0.  0.  0.  0.  0.  0.  0.  0. ]]


Array([[ 0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.],
       [ 1., -1.,  1., -1.,  1.,  0.,  0.,  0.,  0.],
       [ 0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.],
       [ 0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.],
       [ 0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.],
       [ 0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.]], dtype=float32)

396.22

In [None]:
def CE_loss(y_preds, y_true):
    """Cross-entropy between predicted probabilities and target weights.

    Computes `-sum(y_true * log(y_preds + 1e-9))` over every element of the
    inputs (a sum, not a mean). The small constant keeps the logarithm finite
    when a predicted value is exactly zero.

    Args:
        y_preds: predicted probabilities.
        y_true: target weights, broadcastable against `y_preds`.

    Returns:
        A scalar array with the summed loss.
    """
    eps = 1e-9
    log_probs = jnp.log(y_preds + eps)
    weighted = y_true * log_probs
    return -jnp.sum(weighted)

Array([0. , 0. , 0. , 0. , 0.5, 0.5, 0. , 0. , 0. ], dtype=float32)