## Imports and Model setup

In [1]:
import keras_nlp
import keras
import numpy as np

import tensorflow.data as tf_data
import tensorflow.strings as tf_strings
import tensorflow.keras.optimizers.schedules as schedules
from keras.layers import LeakyReLU

import pickle
import time

#==== Hyperparameters that define the model architecture ====#

SEQ_LEN = 512
EMBED_DIM = 1024
FEED_FORWARD_DIM = 1024
NUM_HEADS = 16
NUM_LAYERS = 20
VOCAB_SIZE = 50000
start_time = time.time()

#============ Build the model ============#

# Restore the pickled custom tokenizer.
# NOTE(review): pickle.load can execute arbitrary code, so only load tokenizer files you trust.
with open('./models/wikitext103_tokenizer50k.pkl', 'rb') as saved_tokenizer:
    tokenizer = pickle.load(saved_tokenizer)

# Packer configured with the tokenizer's "[BOS]" id as its start value and SEQ_LEN as its length
start_packer = keras_nlp.layers.StartEndPacker(
    sequence_length=SEQ_LEN,
    start_value=tokenizer.token_to_id("[BOS]"),
)

# Variable-length sequence of int32 token ids
inputs = keras.layers.Input(shape=(None,), dtype="int32")

# Token + position embedding: VOCAB_SIZE ids -> EMBED_DIM-wide vectors
embedding_layer = keras_nlp.layers.TokenAndPositionEmbedding(
    vocabulary_size=VOCAB_SIZE,
    sequence_length=SEQ_LEN,
    embedding_dim=EMBED_DIM,
    mask_zero=True,
)
x = embedding_layer(inputs)

# Chain NUM_LAYERS decoder blocks, each one consuming the previous block's output
for _ in range(NUM_LAYERS):
    decoder_layer = keras_nlp.layers.TransformerDecoder(
        num_heads=NUM_HEADS,
        intermediate_dim=FEED_FORWARD_DIM,
        activation=LeakyReLU(0.1),
    )
    x = decoder_layer(x)

# Final projection: one logit per vocabulary entry
outputs = keras.layers.Dense(VOCAB_SIZE)(x)

# Assemble the functional model together with its loss, optimizer and metric
model = keras.Model(inputs=inputs, outputs=outputs)
loss_fn = keras.losses.SparseCategoricalCrossentropy(from_logits=True)
opt = keras.optimizers.Adam(learning_rate=5e-5)
perplexity = keras_nlp.metrics.Perplexity(from_logits=True, mask_token_id=0)

# Compile so the model is ready to use
model.compile(optimizer=opt, loss=loss_fn, metrics=[perplexity])


#========== Code to run the model ============#

# Function repeatedly called for each new token generation
def next(prompt, cache, index):
    """Sampler callback returning the logits at position ``index - 1``.

    Runs the global ``model`` over the whole ``prompt`` and slices out the
    logits at sequence position ``index - 1``. No hidden states are produced
    (``None`` is returned in their place) and ``cache`` is handed back unchanged.
    """
    step_logits = model(prompt)[:, index - 1, :]
    return step_logits, None, cache

### Helper function to generate a response to the prompt. Relies on global variable model
# prompt is a python string
# show_verbose displays the output even after the model has said it is complete with end of string token (unpredictable)
# temperature is forwarded to the sampler as its top-p value; higher values give more varied ('creative') output
def generate_response(prompt, show_verbose=False, temperature=0.5):
    # convert prompt to tokens
    prompt_tokens = start_packer(tokenizer([prompt]))
    
    # Check how many user-input tokens are present
    prompt_length = np.sum(prompt_tokens.numpy() != 0)
    
    # Define sampler
    sampler = keras_nlp.samplers.TopPSampler(p=temperature)
    
    # Run the model 512 times because I can't figure out how to get the end token to work properly
    output_tokens = sampler(next=next,prompt=prompt_tokens,index=prompt_length,)
    
    # Convert token list back to string
    txt = tokenizer.detokenize(output_tokens)
    
    # Show the full unshortened output
    # Behavior after the first [PAD] token is undefined, as that signals a new wiki entry, and there is no way to predict the next wikipedia page in the dataset
    if show_verbose:
        print(f"Raw text: \n{txt}\n")
        
    # Shorten text to only show desired output and print
    output_txt = str(txt.numpy()[0]).split("[BOS] ")[-1].split(' [PAD]')[0]
    print(f"Generated text: \n{output_txt.replace('@ - @', '-').replace('@ , @', ',').replace('@ . @', '.')}\n")
# Report how long this setup cell took, measured from start_time
setup_seconds = round(time.time() - start_time, 2)
print(f'Completed setup in {setup_seconds}s')

Using TensorFlow backend
Completed setup in 5.07s


## Loading Cell 1: Load QA Model

In [2]:
# Run this cell to load the question-answer checkpoint into `model`
qa_checkpoint = './checkpoints/Model_240M_50kvocab_30ksSquadfinetune.krs'
model.load_weights(qa_checkpoint);

In [3]:
# Heads-up: generating a response is very slow. Even on an RTX 3090 a single
# generation takes over a minute, so be careful when running on CPU.
start_time = time.time()
generate_response('what genre of music do the beatles play ?')
elapsed_seconds = round(time.time() - start_time, 1)
print(f'Response took {elapsed_seconds}s to complete')

Generated text: 
what genre of music do the beatles play ? rock

Response took 108.1s to complete


In [4]:
generate_response('What is a country that borders Nigeria?')

Generated text: 
what is a country that borders nigeria ? tanzanian republic



In [5]:
generate_response('What are some species that live in the Amazon Rainforest?')

Generated text: 
what are some species that live in the amazon rainforest ? european minke ( eucalyptus ) , italian sparrowhawk ( canisemus ) , and italian otter ( felis )



In [6]:
generate_response('When did World War II begin?')

Generated text: 
when did world war ii begin ? 1939



In [7]:
generate_response('How tall is Mount Everest?')

Generated text: 
how tall is mount everest ? 10 , 000 feet



## Loading Cell 2: Load WikiText Model

In [14]:
# Run this cell to load the WikiText checkpoint into `model`.
# Longer responses are more prone to looping, so a higher temperature is used in the
# calls below, at the cost of less factually accurate responses.
wikitext_checkpoint = './checkpoints/Model_240M_50kvocab_157kseconds.krs'
model.load_weights(wikitext_checkpoint);

In [9]:
generate_response('The Beatles are a band famous for making music in the genre of', temperature=0.7)

Generated text: 
the beatles are a band famous for making music in the genre of rock . they were created to promote their first album in the early 1960s and have played many of their albums since the 1960s . their live show is a satirical piece for people living in the beatles \' 1967 show in west london . they are famous for their lack of musical influence in the beatles \' songs , which were written to honor the composer \' s fame and their heritage . their second album , there were no new musical style and no lyrics were in the mix .



In [10]:
generate_response('Nigeria shares a land border with', temperature=0.7)

Generated text: 
nigeria shares a land border with cameroon , southeast africa and the african republic of africa . the south african republic is a maritime land grant with a cargo to serve the canary islands . they have a maritime presence and their business is related to the african countries of the south african republic , and it is also home to a number of community groups that play a role in the congo ' s economy . these include those of the west indies , and those of the african union , including the m\xc3\xa9tis , < unk > and econ\xc3\xa1 .



In [11]:
generate_response('The most well known species in the Amazon rainforest are', temperature=0.7)

Generated text: 
the most well known species in the amazon rainforest are the mountain sharks ( c . carpantea ) , the small isopod - billed < unk > ( eucalyptus rubidae ) , and the small and highly developed large tunatail shark ( c . argoleps ) . the sharks are capable of coastal life , with their low temperatures and heavy < unk > than the two species .



In [12]:
generate_response('World War II began on', temperature=0.7)

Generated text: 
world war ii began on 4 september 1939 when the united states and united states , the war of the third coalition , and the axis powers became part of the manhattan project , a project that continued until 1942 . construction was completed in january 1943 . it was completed in november 1943 and finished on 8 december 1943 . the project was considered an improvement over the previous budget , and was postponed until the end of the war . it was one of the first actions of the us military to do so , and the first combat combat combat plan . the decision to deploy a combat unit was not implemented until 1944 .



In [18]:
generate_response('Mount Everest reaches a height of', temperature=0.7)

Generated text: 
mount everest reaches a height of 30 metres ( 115 ft ) and features four lava flows , a long , thick dome , two fissures , and a completely flattened , often tephra lava . on average , the volume of ash is 0 . 1 g / cm3 , or 0 . 1 g / kg per cubic meter . the basin has only a high average of 1 . 2 g / kg ( 0 . 6 oz / kg / cu ft ) and an area of 0 . 4 g / kg ( 0 . 2 oz / cu ft ) .

