# In [None]:
from __future__ import annotations

from functools import partial
import dataclasses
import serde

import jax
import optax
from jax import random, numpy as jnp
from flax import linen as nn


@serde.serde
@dataclasses.dataclass(frozen=True)
class TransformerConfig:
    """Immutable hyperparameter record from which a `Transformer` is built.

    The dataclass is frozen, so fields cannot be reassigned after
    construction. `Transformer.__hash__` hashes this object.

    NOTE(review): `Transformer.setup` reads `config.vocab_sizes`, but no
    field of that name is declared here -- confirm where that attribute is
    meant to come from.
    """
    # Number of attention heads handed to every TransformerBlock.
    num_heads: int
    # Feature width used by the embeddings and by every TransformerBlock.
    embed_dim: int
    # Number of TransformerBlock layers stacked by Transformer.
    num_hidden_layers: int
    # Not read by any code visible in this file; presumably the maximum
    # sequence length in plies -- TODO confirm.
    max_n_ply: int = 201
    # Not read by any code visible in this file; presumably toggles an extra
    # "strategy" input -- TODO confirm.
    strategy: bool = False

    def create_model(self) -> 'Transformer':
        """Return a `Transformer` module configured by this object."""
        return Transformer(self)


class Embeddings(nn.Module):
    """Sum of per-feature token embeddings, followed by LayerNorm and dropout.

    Attributes:
        embed_dim: Width of every embedding table and of the output.
        vocab_sizes: Vocabulary size of each categorical feature; entry ``i``
            describes the ids found at ``tokens[..., i]``.
    """
    embed_dim: int
    vocab_sizes: list[int]

    @nn.compact
    def __call__(self, tokens: jnp.ndarray, eval: bool):
        """Embed every feature of ``tokens`` and add the results together.

        Args:
            tokens: Integer ids whose last axis indexes the features listed in
                ``vocab_sizes``.
            eval: When true, dropout is disabled.

        Returns:
            Array of shape ``(*tokens.shape[:-1], embed_dim)``.
        """
        summed = jnp.zeros((*tokens.shape[:-1], self.embed_dim))

        # One table per feature, created in feature order so the automatic
        # submodule names stay stable.
        for feature_idx, vocab_size in enumerate(self.vocab_sizes):
            # Out-of-range ids are clamped into the valid range of the table.
            ids = jnp.clip(tokens[..., feature_idx], 0, vocab_size - 1)
            table = nn.Embed(vocab_size, self.embed_dim)
            summed = summed + table(ids)

        normed = nn.LayerNorm(epsilon=1e-12)(summed)
        return nn.Dropout(0.5, deterministic=eval)(normed)


class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product self-attention.

    Attributes:
        num_heads: Number of attention heads.
        embed_dim: Width of the projected queries/keys/values and of the
            output; must be divisible by ``num_heads``.
    """
    num_heads: int
    embed_dim: int

    @nn.compact
    def __call__(self, x, mask):
        """Apply self-attention to ``x``.

        Args:
            x: Input of shape [Batch, SeqLen, Features].
            mask: Boolean array broadcastable to
                [Batch, Head, SeqLen, SeqLen]; true entries mark key positions
                a query is allowed to attend to.

        Returns:
            Array of shape [Batch, SeqLen, embed_dim].

        Raises:
            ValueError: If ``embed_dim`` is not divisible by ``num_heads``.
        """
        # Without this check the reshapes below would either fail with an
        # obscure error or silently fold features into the batch axis.
        if self.embed_dim % self.num_heads != 0:
            raise ValueError(
                f"embed_dim ({self.embed_dim}) must be divisible by "
                f"num_heads ({self.num_heads})"
            )

        seq_len = x.shape[1]
        head_dim = self.embed_dim // self.num_heads

        # Creation order (v, q, k, out) is kept so the automatic parameter
        # names do not change.
        v = nn.Dense(features=self.embed_dim)(x)  # [Batch, SeqLen, Head * Dim]
        q = nn.Dense(features=self.embed_dim)(x)  # [Batch, SeqLen, Head * Dim]
        k = nn.Dense(features=self.embed_dim)(x)  # [Batch, SeqLen, Head * Dim]

        v = v.reshape(-1, seq_len, self.num_heads, head_dim)  # [Batch, SeqLen, Head, Dim]
        q = q.reshape(-1, seq_len, self.num_heads, head_dim)  # [Batch, SeqLen, Head, Dim]
        k = k.reshape(-1, seq_len, self.num_heads, head_dim)  # [Batch, SeqLen, Head, Dim]

        # [Batch, Head, SeqLen, SeqLen]
        scores = jnp.einsum('...qhd,...khd->...hqk', q, k) / jnp.sqrt(head_dim)

        # Use the most negative finite value instead of -inf: masked entries
        # still get exactly zero weight, but a row with no allowed key yields
        # a uniform distribution rather than NaN.
        big_neg = jnp.finfo(scores.dtype).min
        scores = jnp.where(mask, scores, big_neg)
        weights = nn.softmax(scores, axis=-1)

        values = jnp.einsum('...hqk,...khd->...qhd', weights, v)  # [Batch, SeqLen, Head, Dim]
        values = values.reshape(-1, seq_len, self.num_heads * head_dim)  # [Batch, SeqLen, EmbedDim]
        out = nn.Dense(self.embed_dim)(values)

        return out


class FeedForward(nn.Module):
    """Position-wise two-layer MLP with ReLU and output dropout.

    Attributes:
        embed_dim: Width of the output features.
        intermediate_size: Width of the hidden layer.
    """
    embed_dim: int
    intermediate_size: int = 128

    @nn.compact
    def __call__(self, x, eval):
        """Apply Dense -> ReLU -> Dense -> Dropout to ``x``.

        Args:
            x: Input array; the layers act on its last axis.
            eval: When true, dropout is disabled.

        Returns:
            Array with the same leading axes as ``x`` and ``embed_dim``
            features on the last axis.
        """
        # The hidden layer honours `intermediate_size`; it was previously
        # sized by `embed_dim`, which left the attribute without effect.
        # NOTE: this changes the first Dense layer's parameter shapes, so
        # checkpoints trained with the old sizing will not load unchanged.
        hidden = nn.Dense(features=self.intermediate_size)(x)
        hidden = nn.relu(hidden)
        out = nn.Dense(features=self.embed_dim)(hidden)
        out = nn.Dropout(0.1, deterministic=eval)(out)
        return out


class TransformerBlock(nn.Module):
    """One encoder layer: self-attention and feed-forward, each wrapped in a
    residual connection followed by LayerNorm.

    Attributes:
        num_heads: Number of attention heads.
        embed_dim: Feature width passed to both sub-layers.
    """
    num_heads: int
    embed_dim: int

    def setup(self):
        """Create the attention and feed-forward sub-layers."""
        self.attention = MultiHeadAttention(self.num_heads, self.embed_dim)
        self.feed_forward = FeedForward(embed_dim=self.embed_dim)

    @nn.compact
    def __call__(self, x, attention_mask, eval):
        """Run the block on ``x``.

        Args:
            x: Input activations.
            attention_mask: Mask forwarded unchanged to the attention layer.
            eval: Forwarded to the feed-forward layer, where a true value
                disables dropout.

        Returns:
            The transformed activations.
        """
        # Attention sub-layer: residual add, then normalise.
        attended = self.attention(x, attention_mask)
        x = nn.LayerNorm()(x + attended)

        # Feed-forward sub-layer: residual add, then normalise.
        ff_out = self.feed_forward(x, eval)
        return nn.LayerNorm()(x + ff_out)


class Transformer(nn.Module):
    """Causal transformer encoder with three per-position output heads.

    Attributes:
        config: Hyperparameters for the embeddings and the block stack.
    """
    config: TransformerConfig

    def __hash__(self):
        """Hash the module by its configuration."""
        return hash(self.config)

    def setup(self):
        """Create the embedding layer and the stack of transformer blocks."""
        cfg = self.config

        # NOTE(review): `vocab_sizes` is not a field declared on the
        # TransformerConfig visible in this file -- confirm its origin.
        self.embeddings = Embeddings(cfg.embed_dim, cfg.vocab_sizes)
        # NOTE(review): this layer is never applied in `__call__`.
        self.st_dence = nn.Dense(features=cfg.embed_dim)

        self.layers = [
            TransformerBlock(cfg.num_heads, cfg.embed_dim)
            for _ in range(cfg.num_hidden_layers)
        ]

    @nn.compact
    def __call__(self, x: jnp.ndarray, eval=True):
        """Compute the three head outputs for every sequence position.

        Args:
            x: Token ids; the last axis holds the per-position features
                consumed by the embedding layer.
            eval: When true, every dropout layer is disabled.

        Returns:
            Tuple ``(p, v, c)`` of arrays shaped [Batch, SeqLen, 32],
            [Batch, SeqLen, 7] and [Batch, SeqLen, 8].
        """
        hidden = self.embeddings(x, eval)

        # [Batch, 1, SeqLen, SeqLen]; each position sees itself and earlier ones.
        batch_size, seq_len = hidden.shape[0], hidden.shape[1]
        causal_mask = nn.make_causal_mask(jnp.zeros((batch_size, seq_len)), dtype=bool)

        for block in self.layers:
            hidden = block(hidden, causal_mask, eval=eval)

        hidden = nn.Dropout(0.1, deterministic=eval)(hidden)

        p_out = nn.Dense(features=32, name="head_p")(hidden)
        v_out = nn.Dense(features=7, name="head_v")(hidden)
        c_out = nn.Dense(features=8, name="head_c")(hidden)

        return p_out, v_out, c_out  # [Batch, SeqLen, ...]


@jax.jit
def calc_loss(
    x: jnp.ndarray,
    p_pred: jnp.ndarray, v_pred: jnp.ndarray, c_pred: jnp.ndarray,
    p_true: jnp.ndarray, v_true: jnp.ndarray, c_true: jnp.ndarray,
) -> tuple[jnp.ndarray, jnp.ndarray]:
    """Compute the masked training loss of the three output heads.

    Args:
        x: Token features; a position contributes to the loss only if at
            least one entry along its last axis is non-zero.
        p_pred: Logits of the ``p`` head, [Batch, SeqLen, NumP].
        v_pred: Logits of the ``v`` head, [Batch, SeqLen, NumV].
        c_pred: Logits of the ``c`` head, [Batch, SeqLen, NumC].
        p_true: Integer class label for every position, [Batch, SeqLen].
        v_true: One integer class label per sequence, [Batch]; it is reused
            for every position of that sequence.
        c_true: One multi-label target per sequence, [Batch, NumC]; it is
            reused for every position of that sequence.

    Returns:
        ``(loss, losses)`` where ``loss`` is the sum of the three head losses
        and ``losses`` is the array ``[loss_p, loss_v, loss_c]``.
    """
    # Per-position weights, flattened to [Batch * SeqLen].
    mask = jnp.any(x != 0, axis=-1).reshape(-1)

    num_p = p_pred.shape[-1]
    num_v = v_pred.shape[-1]
    num_c = c_pred.shape[-1]

    p_true = p_true.reshape(-1)
    # Repeat each per-sequence target once per position so that it lines up
    # with the row-major flattening of the predictions.
    v_true = jnp.repeat(v_true, v_pred.shape[-2])
    # Repeat whole rows (axis 0). Stacking on the last axis and reshaping, as
    # was done before, interleaves the label columns for SeqLen > 1.
    c_true = jnp.repeat(c_true.reshape(-1, num_c), c_pred.shape[-2], axis=0)

    p_pred = p_pred.reshape(-1, num_p)
    v_pred = v_pred.reshape(-1, num_v)
    c_pred = c_pred.reshape(-1, num_c)

    loss_p = optax.softmax_cross_entropy_with_integer_labels(p_pred, p_true)
    loss_v = optax.softmax_cross_entropy_with_integer_labels(v_pred, v_true)
    loss_c = optax.sigmoid_binary_cross_entropy(c_pred, c_true).mean(axis=-1)

    # Mean over the positions selected by the mask.
    loss_p = jnp.average(loss_p, weights=mask)
    loss_v = jnp.average(loss_v, weights=mask)
    loss_c = jnp.average(loss_c, weights=mask)

    loss = loss_p + loss_v + loss_c
    losses = jnp.array([loss_p, loss_v, loss_c])

    return loss, losses


@partial(jax.jit, static_argnames=['eval'])
def loss_fn(
    params,
    state: TrainStateTransformer,
    x: jnp.ndarray, st: jnp.ndarray,
    p_true: jnp.ndarray,
    v_true: jnp.ndarray,
    c_true: jnp.ndarray,
    dropout_rng,
    eval: bool
) -> tuple[jnp.ndarray, jnp.ndarray]:
    """Run the model through ``state.apply_fn`` and compute its loss.

    Args:
        params: Model parameters, passed to ``apply_fn`` under ``'params'``.
        state: Train state whose ``apply_fn`` performs the forward pass.
        x: Token features; also used by ``calc_loss`` to build the loss mask.
        st: Extra model input forwarded positionally after ``x``.
        p_true: Targets for the ``p`` head.
        v_true: Targets for the ``v`` head.
        c_true: Targets for the ``c`` head.
        dropout_rng: PRNG key supplied as the ``'dropout'`` rng stream.
        eval: Static flag forwarded to the model as ``eval``.

    Returns:
        ``(loss, losses)`` exactly as returned by ``calc_loss``.
    """
    # NOTE(review): the `Transformer.__call__` visible in this file accepts
    # only `(x, eval)`; if `apply_fn` is bound to it, the positional `st`
    # collides with `eval` -- confirm which model `apply_fn` wraps.
    p, v, c = state.apply_fn({'params': params}, x, st, eval=eval, rngs={'dropout': dropout_rng})
    return calc_loss(x, p, v, c, p_true, v_true, c_true)
