In [1]:
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' # Disable tensorflow debugging logs
import tensorflow as tf
import keras_nlp
import tensorflow_text as tf_text
from model import GPT
from utils import *

- Download weights

In [2]:
ckpt_dir = 'openwt_512_d_512/best-ckpt'
loader = Loader()
loader.download(ckpt_dir)

In [3]:
config = loader.config
config

{'batch_size': 16,
 'buffer_size': 40000,
 'shuffle_seed': 32,
 'vocab_file': 'wiki_en_vocab',
 'min_seq_len': False,
 'ckpt_interval': 2000,
 'val_steps': 1000,
 'train_size': 95,
 'vocab_size': 50257,
 'seq_len': 512,
 'learning_rate': 0.001,
 'beta_1': 0.9,
 'beta_2': 0.95,
 'decay_lr': False,
 'decay_steps': 400000,
 'alpha': 0.1,
 'emb_dim': 512,
 'heads': 8,
 'mlp_dim': 512,
 'depth': 10,
 'dropout': 0.0,
 'initializer': 'glorot_uniform',
 'embedding_initializer': 'glorot_uniform',
 'eps': 1e-06,
 'mlp_activation': 'gelu'}

In [4]:
tokenizer = keras_nlp.models.GPT2Tokenizer.from_preset("gpt2_base_en", 
                                                       sequence_length=config['seq_len'])

In [5]:
model = GPT(vocab_size=config['vocab_size'], 
            seq_len=config['seq_len'], emb_dim=config['emb_dim'],
            heads=config['heads'], mlp_dim=config['mlp_dim'],
            depth=config['depth'], rate=config['dropout'], 
            initializer=config['initializer'])

- Initialize the model with a tokenized input

In [6]:
context = 'The silver wolf is'
t_context = tokenizer(tf_text.normalize_utf8(context, 'NFKD'))[tf.newaxis, :]

In [7]:
model(t_context)

<tf.Tensor: shape=(1, 512, 50257), dtype=float32, numpy=
array([[[ 0.26018098, -0.099659  , -0.05211959, ..., -0.08389892,
         -0.2226785 , -0.01996342],
        [ 0.24848461, -0.10606524, -0.1408198 , ...,  0.02978064,
         -0.24563235,  0.11638285],
        [ 0.22778201, -0.11750411, -0.11632746, ..., -0.00105385,
         -0.25143325,  0.2077703 ],
        ...,
        [ 0.2554393 , -0.11542156, -0.13028958, ..., -0.05494361,
         -0.25776154,  0.16087352],
        [ 0.24464783, -0.1165686 , -0.10335917, ..., -0.02977747,
         -0.2302949 ,  0.17602484],
        [ 0.29687753, -0.11166702, -0.12161556, ..., -0.04026931,
         -0.32042617,  0.17742355]]], dtype=float32)>

In [8]:
model.summary()

Model: "gpt"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 token_embedding (TokenEmbed  multiple                 25993728  
 ding)                                                           
                                                                 
 dropout_1 (Dropout)         multiple                  0         
                                                                 
 transformer_block (Transfor  multiple                 1577984   
 merBlock)                                                       
                                                                 
 transformer_block_1 (Transf  multiple                 1577984   
 ormerBlock)                                                     
                                                                 
 transformer_block_2 (Transf  multiple                 1577984   
 ormerBlock)                                                   

- Restore weights

In [9]:
model.restore(ckpt_dir)

Checkpoint restored from openwt_512_d_512/best-ckpt/ckpt-1760000 at step 1760000


In [10]:
text = sample(model, 'The silver wolf is', max_len=128, k=40)
print(text)

The silver wolf is actually going around — about 25 years earlier this day this March, as it passed around that morning and it finally got past 11 September for the first straight fall
 violations have led thousands of people back onto the road like a giant black wolf before moving onto the road last December where their car was being driven about 18 miles distant past 6,000 miles, a new kind (and not everyone on Facebook). Today as I drive south down a road I want myself the sound to be heard but like every second day on purpose here (and when an old man will be out Sunday I need just someone on and up here like many
