In [None]:
import json
import keras
import keras_nlp
import pickle
import tensorflow as tf

In [None]:
# Trained recipe model, stored in the native `.keras` format inside the
# attached Kaggle input dataset.
model_filepath = "/kaggle/input/ann_model/keras/custom-model/1/custom_model.keras"
model = keras.models.load_model(filepath=model_filepath)

In [None]:
# Context size in tokens: length of the training sequences, and the length
# every prompt is tokenized and packed to at inference time.
SEQ_LEN = 512

# Pickled WordPiece vocabulary used to rebuild the tokenizer.
VOCAB_FILE = "vocab.pickle"

# Special tokens wrapping/padding each recipe.
PAD = "<|pad|>"
START_OF_RECIPE = "<|recipe_start|>"
END_OF_RECIPE = "<|recipe_end|>"
OOV = "<|oov|>"
# NOTE(review): presumably this order must match the one used when the
# vocabulary was built (PAD first) — do not reorder without confirming.
SPECIAL_TOKENS = [PAD, START_OF_RECIPE, END_OF_RECIPE, OOV]

In [None]:
# Load the vocabulary saved at training time.
# NOTE(review): pickle.load can execute arbitrary code; only load a
# vocab.pickle from a trusted source.
with open(VOCAB_FILE, "rb") as f:
    vocab = pickle.load(f)

# WordPiece tokenizer over the loaded vocabulary. Special tokens written
# literally in input strings are recognized (special_tokens_in_strings=True);
# unknown pieces map to OOV; output length is fixed to SEQ_LEN.
tokenizer = keras_nlp.tokenizers.WordPieceTokenizer(
    vocabulary=vocab,
    special_tokens=SPECIAL_TOKENS,
    special_tokens_in_strings=True,
    oov_token=OOV,
    sequence_length=SEQ_LEN,
)

# Packer framing each token sequence with the START/END ids and padding it
# with the PAD id up to SEQ_LEN.
packer = keras_nlp.layers.StartEndPacker(
    start_value=tokenizer.token_to_id(START_OF_RECIPE),
    end_value=tokenizer.token_to_id(END_OF_RECIPE),
    pad_value=tokenizer.token_to_id(PAD),
    sequence_length=SEQ_LEN,
)


In [None]:
class TextGenerator():
    """Generate recipe text from a trained next-token model with top-p sampling.

    Relies on the module-level ``tokenizer`` and ``packer`` and on the special
    token constants (``START_OF_RECIPE``, ``END_OF_RECIPE``, ``PAD``, ``OOV``)
    defined in earlier cells.
    """

    def __init__(self, model, p, k=1024):
        """
        Args:
            model: Keras model called on a batch of token ids; its output is
                indexed as ``model(prompt)[:, position, :]`` to get logits
                (presumably shaped (batch, seq_len, vocab_size) — confirm).
            p: probability mass for ``TopPSampler`` (nucleus sampling).
            k: number of top tokens ``TopPSampler`` considers before the top-p
                cut. Defaults to 1024, the previously hard-coded value.
        """
        self.model = model
        self.sampler = keras_nlp.samplers.TopPSampler(p=p, k=k)

    def _tokenize_str(self, text):
        """Tokenize ``text`` and pack it (START/END/PAD) as a batch of one."""
        return packer(tokenizer([text]))

    def _next(self, prompt, cache, index):
        """Sampler callback returning the logits for the token at ``index``.

        The whole prompt is re-run through the model at each step (``cache``
        is passed through unchanged). The logits at position ``index - 1``
        are used to predict the token at ``index``.
        """
        logits = self.model(prompt)[:, index - 1, :]
        # Previously written as ``None,``, a stray comma that made this the
        # 1-tuple ``(None,)`` instead of ``None``.
        hidden_states = None
        return logits, hidden_states, cache

    def _normalize_output(self, txt):
        """Clean detokenized output and pretty-print it if it is valid JSON.

        Keeps the text before the first END_OF_RECIPE and up to its first
        ``}``, removes the START/PAD/OOV tokens, and collapses the spaces
        around double quotes (presumably inserted by detokenization). If the
        result does not parse as JSON, the cleaned text is returned unchanged.
        """
        txt = txt.split(END_OF_RECIPE)[0].split('}')[0] + '}'
        for token in (START_OF_RECIPE, PAD, OOV):
            txt = txt.replace(token, "")
        txt = txt.replace(' " ', '"')
        try:
            txt = json.dumps(json.loads(txt), indent=4)
        except json.JSONDecodeError:
            # Best effort: sampled text is not guaranteed to be valid JSON.
            pass
        return txt

    def generate(self, seed_text, logs=None):
        """Sample a continuation of ``seed_text`` and return normalized text.

        Args:
            seed_text: prompt text to continue.
            logs: unused; kept so existing callers keep working.

        Returns:
            The generated text after ``_normalize_output``.
        """
        seed_tokens = self._tokenize_str(seed_text)
        pad_id = tokenizer.token_to_id(PAD)
        end_id = tokenizer.token_to_id(END_OF_RECIPE)
        # Count non-pad tokens in int32. The former int8 sum overflowed for
        # prompts longer than 127 tokens. The pad id is looked up rather than
        # assumed to be 0.
        seed_length = int(
            tf.reduce_sum(tf.cast(tf.not_equal(seed_tokens, pad_id), tf.int32))
        )
        # The packer appends END_OF_RECIPE after the prompt. Start sampling on
        # that slot so it gets overwritten. Otherwise the prompt ends with an
        # end-of-recipe token, and _normalize_output (which cuts at the first
        # END_OF_RECIPE) throws away everything that was generated.
        if seed_length > 0 and int(seed_tokens[0, seed_length - 1]) == end_id:
            seed_length -= 1
        output_tokens = self.sampler(
            next=self._next,
            prompt=seed_tokens,
            index=seed_length,
        )
        txt = tokenizer.detokenize(output_tokens).numpy()[0].decode("utf-8")
        return self._normalize_output(txt)

    def generate_recipe(self, ingredients):
        """Generate a recipe seeded with a list of ingredient names.

        The prompt opens a JSON object whose ``"ner"`` list holds the given
        ingredients. It ends with ``,"``, which opens one more list entry for
        the model to complete.
        NOTE(review): confirm that prompting for an extra ingredient, rather
        than closing the list with ``]``, is intended.
        With no ingredients, the model is prompted with empty text.
        """
        if not ingredients:
            return self.generate('')
        quoted = ', '.join(f'"{ingredient}"' for ingredient in ingredients)
        seed_text = '{"ner": [' + quoted + ',"'
        return self.generate(seed_text)
    
# Generator over the loaded model, sampling with top-p mass 0.9.
text_generator = TextGenerator(model=model, p=0.9)