In [1]:
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' # Disable tensorflow debugging logs
import json
import tensorflow as tf
import keras_nlp
import tensorflow_text as tf_text
from huggingface_hub import hf_hub_download
from model import GPT
from utils import sample

- Download weights

In [2]:
ckpt_dir = 'openwt_512_d_512/best-ckpt'

hf_hub_download(repo_id="milmor/gpt-mini", 
                filename=f"{ckpt_dir}/ckpt-1760000.data-00000-of-00001",
                local_dir='./')

hf_hub_download(repo_id="milmor/gpt-mini", 
                filename=f"{ckpt_dir}/ckpt-1760000.index",
                local_dir='./')

hf_hub_download(repo_id="milmor/gpt-mini", 
                filename=f"{ckpt_dir}/checkpoint",
                local_dir='./')

config_file = hf_hub_download(repo_id="milmor/gpt-mini", 
                filename="openwt_512_d_512/openwt_512_d_512_config.json",
                local_dir='./')

Downloading (…).data-00000-of-00001:   0%|          | 0.00/811M [00:00<?, ?B/s]

Downloading (…)t/ckpt-1760000.index:   0%|          | 0.00/30.8k [00:00<?, ?B/s]

Downloading (…)best-ckpt/checkpoint:   0%|          | 0.00/367 [00:00<?, ?B/s]

In [3]:
config_file

'./openwt_512_d_512/openwt_512_d_512_config.json'

In [4]:
with open(config_file) as f:
    config = json.load(f)

In [5]:
tokenizer = keras_nlp.models.GPT2Tokenizer.from_preset("gpt2_base_en", 
                                                       sequence_length=config['seq_len'])

In [6]:
model = GPT(vocab_size=config['vocab_size'], 
            maxlen=config['seq_len'], emb_dim=config['emb_dim'],
            heads=config['heads'], mlp_dim=config['mlp_dim'],
            depth=config['depth'], rate=config['dropout'], 
            initializer=config['initializer'])

- Initialize the model with a tokenized input

In [7]:
context = 'The silver wolf is'
t_context = tokenizer(tf_text.normalize_utf8(context, 'NFKD'))[tf.newaxis, :]

In [8]:
model(t_context)

<tf.Tensor: shape=(1, 512, 50257), dtype=float32, numpy=
array([[[ 0.32933474, -0.17945187,  0.01195081, ..., -0.00311815,
          0.09691932, -0.13828048],
        [ 0.31565788, -0.08625038,  0.14263943, ..., -0.03063635,
          0.0766515 , -0.03668835],
        [ 0.25610116, -0.06395464,  0.11881614, ..., -0.00789269,
          0.09515308, -0.04538292],
        ...,
        [ 0.29410726, -0.05955829,  0.05035197, ...,  0.0246953 ,
          0.01635138, -0.14428937],
        [ 0.28724444, -0.0832747 ,  0.10423397, ...,  0.02961318,
          0.01783834, -0.12234353],
        [ 0.31137285, -0.05091083,  0.11041143, ..., -0.02893059,
          0.08073275, -0.07353926]]], dtype=float32)>

In [9]:
model.summary()

Model: "gpt"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 token_embedding (TokenEmbed  multiple                 25993728  
 ding)                                                           
                                                                 
 dropout_1 (Dropout)         multiple                  0         
                                                                 
 transformer_block (Transfor  multiple                 1577984   
 merBlock)                                                       
                                                                 
 transformer_block_1 (Transf  multiple                 1577984   
 ormerBlock)                                                     
                                                                 
 transformer_block_2 (Transf  multiple                 1577984   
 ormerBlock)                                                   

- Restore weights

In [10]:
ckpt = tf.train.Checkpoint(model=model, step=tf.Variable(0))
ckpt_manager = tf.train.CheckpointManager(ckpt, directory=ckpt_dir, 
                                          max_to_keep=1)
ckpt.restore(ckpt_manager.latest_checkpoint)
print(f'Checkpoint restored from {ckpt_manager.latest_checkpoint} at step {int(ckpt.step)}')

Checkpoint restored from openwt_512_d_512/best-ckpt/ckpt-1760000 at step 1760000


In [11]:
text = sample(model, 'The silver wolf is', config['seq_len'], max_len=128, k=40)
print(text)

The silver wolf is being driven along a road this will be known through and to be a part from, not for any time before or in the way she is seen coming in about ten years. But despite efforts last but for now for them all as many Americans believe these animals already have a life to live, they are getting all things considered right from one individual point: the black wolf? This idea about why not see cats go on land-based land without a life has evolved significantly faster when in 2014 about a dozen people who were in search – with their parents or grandparents and two children still in arms around this one than that but with
