In [115]:
import numpy as np
import tensorflow as tf
from tensorflow.keras.layers import GRU, Dense, Embedding, StringLookup
from tensorflow.keras.models import Model

In [161]:
class TokenDistribution(Model):
    def __init__(self, tokens, embedding_dim, rnn_units):
        super().__init__(self)

        vocab_size = len(tokens)
        self.vocabulary = [token.symbol for token in tokens if token.symbol != "empty"]

        self.token_to_int = StringLookup(vocabulary=self.vocabulary, mask_token="empty")
        self.int_to_token = StringLookup(
            vocabulary=self.vocabulary,
            invert=True,
            mask_token="empty",
        )
        self.embedding = Embedding(input_dim=vocab_size, output_dim=embedding_dim)
        self.gru = GRU(rnn_units, return_sequences=True, return_state=True)
        self.dense = Dense(len(self.vocabulary))

    def call(self, inputs, states=None, return_state=False, training=False):
        x = inputs
        x = self.token_to_int(x)
        x = self.embedding(x, training=training)
        if states is None:
            states = self.gru.get_initial_state(x)
        x, states = self.gru(x, initial_state=states, training=training)
        x = self.dense(x, training=training)

        if return_state:
            return x, states
        else:
            return x

In [162]:
from dataclasses import dataclass


@dataclass
class Token:
    symbol: str
    arity: float

In [163]:
tokens = [Token(char, 2) for char in "+-*/"]
tokens.extend([Token(symbol, 1) for symbol in ["sin", "cos", "exp", "log"]])
tokens.extend([Token(symbol, 0) for symbol in ("const", "var", "empty")])

In [175]:
model = TokenDistribution(tokens, 16, 16)

In [176]:
model(np.array([["+"]]))

<tf.Tensor: shape=(1, 1, 10), dtype=float32, numpy=
array([[[-0.00641408, -0.00216806, -0.01779224, -0.00205745,
          0.00233911,  0.01203147,  0.00506462,  0.00411894,
          0.00392219, -0.00754305]]], dtype=float32)>

In [201]:
class Sampler:
    def __init__(self, tokens, embedding_dim, rnn_units):
        self.model = TokenDistribution(tokens=tokens, embedding_dim=embedding_dim, rnn_units=rnn_units)
        
    def sample(self):
        inputs = np.array([["empty"], ["empty"]])
        logits = model(inputs)
        next_index = tf.random.categorical(logits[:, -1, :], 1)
        return model.int_to_token(next_index)

In [202]:
model.int_to_token(tf.random.categorical(model(np.array([["+"], ["sin"]]))[:, -1, :], 1))

<tf.Tensor: shape=(2, 1), dtype=string, numpy=
array([[b'+'],
       [b'cos']], dtype=object)>

In [203]:
sampler = Sampler(tokens, 16, 16)

In [226]:
sampler.sample()

<tf.Tensor: shape=(2, 1), dtype=string, numpy=
array([[b'[UNK]'],
       [b'+']], dtype=object)>