In [1]:
# The MIT License (MIT) Copyright (c) 2022 milmor
#
# Permission is hereby granted, free of charge, to any person obtaining a copy of 
# this software and associated documentation files (the "Software"), to deal in the Software without 
# restriction, including without limitation the rights to use, copy, modify, merge, publish, 
# distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the 
# Software is furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all copies or 
# substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, 
# INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND 
# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES 
# OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN 
# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.

In [2]:
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' # Disable tensorflow debugging logs
import json
import tensorflow as tf
import keras_nlp
import tensorflow_text as tf_text
from huggingface_hub import hf_hub_download
from model import GPT
from utils import sample

- Download weights

In [3]:
ckpt_dir = 'openwt_512/best-ckpt'

hf_hub_download(repo_id="milmor/gpt-mini", 
                filename=f"{ckpt_dir}/ckpt-934000.data-00000-of-00001",
                local_dir='./')

hf_hub_download(repo_id="milmor/gpt-mini", 
                filename=f"{ckpt_dir}/ckpt-934000.index",
                local_dir='./')

hf_hub_download(repo_id="milmor/gpt-mini", 
                filename=f"{ckpt_dir}/checkpoint",
                local_dir='./')

config_file = hf_hub_download(repo_id="milmor/mini-gpt", 
                filename="openwt_512/openwt_512_config.json",
                local_dir='./')

In [4]:
with open(config_file) as f:
    config = json.load(f)

In [5]:
tokenizer = keras_nlp.tokenizers.WordPieceTokenizer(
    vocabulary=config['vocab_file'],
    sequence_length=config['seq_len'] + 1,
    lowercase=False,
)

In [6]:
model = GPT(vocab_size=config['vocab_size'], 
            maxlen=config['seq_len'], emb_dim=config['emb_dim'],
            heads=config['heads'], mlp_dim=config['mlp_dim'],
            depth=config['depth'], rate=config['dropout'], 
            initializer=config['initializer'])

- Initialize the model with a tokenized input

In [7]:
context = 'I love the wolf'
t_context = tokenizer(tf_text.normalize_utf8(context, 'NFKD'))[tf.newaxis, :config['seq_len']]

In [8]:
model(t_context)

<tf.Tensor: shape=(1, 512, 30000), dtype=float32, numpy=
array([[[-0.1073962 ,  0.12308779, -0.11740228, ...,  0.13619076,
          0.06594768,  0.0167977 ],
        [-0.09143426,  0.06127262, -0.14263633, ...,  0.07504728,
          0.05748158,  0.0266253 ],
        [-0.10384098,  0.10336254, -0.14206518, ...,  0.13705155,
          0.05703758,  0.0183858 ],
        ...,
        [-0.10039534,  0.08161308, -0.13666598, ...,  0.10743341,
          0.05216627,  0.07849834],
        [-0.1001814 ,  0.11536416, -0.0875769 , ...,  0.14495653,
          0.02732199,  0.0766139 ],
        [-0.06502417,  0.08763284, -0.10016494, ...,  0.13450669,
          0.07237849,  0.04055046]]], dtype=float32)>

In [9]:
model.summary()

Model: "gpt"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 token_embedding (TokenEmbed  multiple                 7811072   
 ding)                                                           
                                                                 
 dropout_1 (Dropout)         multiple                  0         
                                                                 
 transformer_block (Transfor  multiple                 527104    
 merBlock)                                                       
                                                                 
 transformer_block_1 (Transf  multiple                 527104    
 ormerBlock)                                                     
                                                                 
 transformer_block_2 (Transf  multiple                 527104    
 ormerBlock)                                                   

- Restore weights

In [10]:
ckpt = tf.train.Checkpoint(model=model, step=tf.Variable(0))
ckpt_manager = tf.train.CheckpointManager(ckpt, directory=ckpt_dir, 
                                          max_to_keep=1)
ckpt.restore(ckpt_manager.latest_checkpoint)
print(f'Checkpoint restored from {ckpt_manager.latest_checkpoint} at step {int(ckpt.step)}')

Checkpoint restored from openwt_512/best-ckpt/ckpt-934000 at step 934000


In [11]:
text = sample(model, context, config['seq_len'], config['vocab_file'], k=10)
text

'I love the wolf - mongered , but the only problem we find most likely about us . A couple hundred feet away at our last hour of work , a half dozen people are on your back with an eye , and a small amount will take off at least the next one . I was able to keep a small , but still - tall bougin in the middle , in my hand , in an attempt for the best to be a bit overdrching for his work to come back . It was not a surprise I found it on his phone as a small , buffet . And he got to the front seat of one man to the front door on my head to stumble on his head . A baunt with a stomping paper is probably the biggest pain of his illness that has led us to an illnesses in a single cell which means no need for a doctor , so there ’ ve only gotten a couple days in my life to go to a good care unit at his home , and the one on a bed where his son died was in his late life and that ’ d been his best friend for over a long time ! But the bro ! The bro is one and that ’ chuckles in your face . Th