In [1]:
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' # Disable tensorflow debugging logs
import tensorflow as tf
import keras_nlp
import tensorflow_text as tf_text
from model import GPT
from utils import *

- Download weights

In [2]:
ckpt_dir = 'openwt_512_d_512/best-ckpt'
loader = Loader()
loader.download(ckpt_dir)

In [3]:
config = loader.config
config

{'batch_size': 16,
 'buffer_size': 40000,
 'shuffle_seed': 32,
 'vocab_file': 'wiki_en_vocab',
 'min_seq_len': False,
 'ckpt_interval': 2000,
 'val_steps': 1000,
 'train_size': 95,
 'vocab_size': 50257,
 'seq_len': 512,
 'learning_rate': 0.001,
 'beta_1': 0.9,
 'beta_2': 0.95,
 'decay_lr': False,
 'decay_steps': 400000,
 'alpha': 0.1,
 'emb_dim': 512,
 'heads': 8,
 'mlp_dim': 512,
 'depth': 10,
 'dropout': 0.0,
 'initializer': 'glorot_uniform',
 'embedding_initializer': 'glorot_uniform',
 'eps': 1e-06,
 'mlp_activation': 'gelu'}

In [4]:
tokenizer = keras_nlp.models.GPT2Tokenizer.from_preset("gpt2_base_en", 
                                                       sequence_length=config['seq_len'])

In [5]:
model = GPT(vocab_size=config['vocab_size'], 
            seq_len=config['seq_len'], emb_dim=config['emb_dim'],
            heads=config['heads'], mlp_dim=config['mlp_dim'],
            depth=config['depth'], rate=config['dropout'], 
            initializer=config['initializer'])

- Initialize the model with a tokenized input

In [6]:
context = 'The silver wolf is'
t_context = tokenizer(tf_text.normalize_utf8(context, 'NFKD'))[tf.newaxis, :]

In [7]:
model(t_context)

<tf.Tensor: shape=(1, 512, 50257), dtype=float32, numpy=
array([[[ 0.01578266, -0.01593508,  0.08134665, ...,  0.16397418,
         -0.07980248,  0.05148029],
        [-0.0891059 , -0.00203854,  0.07782   , ...,  0.15241538,
         -0.00872427,  0.00942059],
        [-0.09127087,  0.10816865,  0.07026106, ...,  0.07882239,
         -0.0064379 ,  0.02395548],
        ...,
        [-0.08796776,  0.08953027,  0.12804441, ...,  0.0775993 ,
          0.06395972, -0.0393813 ],
        [-0.16385601,  0.02793138,  0.10487607, ...,  0.04041716,
         -0.00516192,  0.07077672],
        [-0.10391336,  0.06668523,  0.07727648, ...,  0.08824315,
         -0.00982571,  0.02639535]]], dtype=float32)>

In [8]:
model.summary()

Model: "gpt"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 token_embedding (TokenEmbed  multiple                 25993728  
 ding)                                                           
                                                                 
 dropout_1 (Dropout)         multiple                  0         
                                                                 
 transformer_block (Transfor  multiple                 1577984   
 merBlock)                                                       
                                                                 
 transformer_block_1 (Transf  multiple                 1577984   
 ormerBlock)                                                     
                                                                 
 transformer_block_2 (Transf  multiple                 1577984   
 ormerBlock)                                                   

- Restore weights

In [9]:
model.restore(ckpt_dir)

Checkpoint restored from openwt_512_d_512/best-ckpt/ckpt-1760000 at step 1760000


In [10]:
text = sample(model, 'The black wolf is important for the planet', max_len=128, k=10)
print(text)

The black wolf is important for the planet in our present day and we are in this situation where the black wolf can be found at the center
 The black wolf has no black wolf, but is one big black wolf to be found there because its hunting is in an entirely different form — not just because there are many black wolves in this picture? As you will, there is the opportunity to see black wolves at some very small location in the North America with a black wolf and they are all a part or part, in part one or more people in our modern society are not in any form. We need you in the wild for all to
