In [1]:
# The MIT License (MIT) Copyright (c) 2022 milmor
#
# Permission is hereby granted, free of charge, to any person obtaining a copy of 
# this software and associated documentation files (the "Software"), to deal in the Software without 
# restriction, including without limitation the rights to use, copy, modify, merge, publish, 
# distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the 
# Software is furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all copies or 
# substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, 
# INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND 
# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES 
# OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN 
# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.

In [2]:
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' # Disable tensorflow debugging logs
import json
import tensorflow as tf
import keras_nlp
import tensorflow_text as tf_text
from huggingface_hub import hf_hub_download
from model import GPT
from utils import sample

- Download the checkpoint weights and the model config from the Hugging Face Hub

In [3]:
# Directory of the checkpoint inside the Hub repo; the same relative path is
# used further down as the local checkpoint directory.
ckpt_dir = 'openwt_512/best-ckpt'

# Files fetched for the TensorFlow checkpoint, in download order.
ckpt_files = (
    f"{ckpt_dir}/ckpt-934000.data-00000-of-00001",
    f"{ckpt_dir}/ckpt-934000.index",
    f"{ckpt_dir}/checkpoint",
)
for ckpt_file in ckpt_files:
    hf_hub_download(repo_id="milmor/gpt-mini",
                    filename=ckpt_file,
                    local_dir='./')

# NOTE(review): the config is requested from "milmor/mini-gpt" while the
# checkpoint files come from "milmor/gpt-mini" — confirm both ids are intended.
config_file = hf_hub_download(repo_id="milmor/mini-gpt",
                              filename="openwt_512/openwt_512_config.json",
                              local_dir='./')

In [4]:
# Load the hyper-parameters and tokenizer settings from the downloaded JSON config.
# The encoding is pinned to UTF-8 so parsing does not depend on the platform's
# default locale encoding.
with open(config_file, encoding='utf-8') as f:
    config = json.load(f)

In [5]:
# WordPiece tokenizer configured from the values stored in the config.
vocab_path = config['vocab_file']
# Tokenized sequences are one position longer than the model's `seq_len`.
tokenizer_len = config['seq_len'] + 1
tokenizer = keras_nlp.tokenizers.WordPieceTokenizer(
    vocabulary=vocab_path,
    sequence_length=tokenizer_len,
    lowercase=False,  # lowercasing is disabled
)

In [6]:
# Instantiate the GPT network from the hyper-parameters stored in the config.
model = GPT(
    vocab_size=config['vocab_size'],
    maxlen=config['seq_len'],
    emb_dim=config['emb_dim'],
    heads=config['heads'],
    mlp_dim=config['mlp_dim'],
    depth=config['depth'],
    rate=config['dropout'],  # the config key 'dropout' feeds the `rate` argument
    initializer=config['initializer'],
)

- Initialize the model with a tokenized input

In [7]:
# Prompt text; it is reused later as the seed for sampling.
context = 'I love the wolf'

# Normalize the prompt to NFKD, tokenize it, then add a leading axis and keep
# at most `seq_len` positions.
normalized_context = tf_text.normalize_utf8(context, 'NFKD')
context_tokens = tokenizer(normalized_context)
t_context = context_tokens[tf.newaxis, :config['seq_len']]

In [8]:
# Run one forward pass on the tokenized context (the "initialize the model" step).
# NOTE(review): presumably this first call is what creates the model's variables,
# so that summary() and the checkpoint restore further down have weights to act
# on — confirm against model.py.
model(t_context)

<tf.Tensor: shape=(1, 512, 30000), dtype=float32, numpy=
array([[[ 0.04053944,  0.02580857, -0.02134609, ...,  0.06464612,
          0.00159945, -0.18666302],
        [ 0.05024419,  0.07753551,  0.01507949, ...,  0.1112728 ,
          0.02672572, -0.17179757],
        [-0.01681338,  0.03965631,  0.12883598, ..., -0.02745748,
          0.05715263, -0.18807063],
        ...,
        [-0.04165641,  0.03784008,  0.086284  , ...,  0.07254027,
          0.03297675, -0.1646466 ],
        [-0.0728814 ,  0.00634096,  0.11156805, ...,  0.07823642,
          0.0475099 , -0.14655234],
        [-0.06854442, -0.00762512,  0.09210709, ...,  0.07022417,
          0.03737502, -0.15992373]]], dtype=float32)>

In [9]:
# Print the per-layer overview of the model (layer types, output shapes, parameter counts).
model.summary()

Model: "gpt"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 token_embedding (TokenEmbed  multiple                 7811072   
 ding)                                                           
                                                                 
 dropout_1 (Dropout)         multiple                  0         
                                                                 
 transformer_block (Transfor  multiple                 527104    
 merBlock)                                                       
                                                                 
 transformer_block_1 (Transf  multiple                 527104    
 ormerBlock)                                                     
                                                                 
 transformer_block_2 (Transf  multiple                 527104    
 ormerBlock)                                                   

- Restore weights

In [10]:
# Restore the downloaded weights into `model`. `step` is tracked alongside it so
# the training step stored in the checkpoint can be reported.
ckpt = tf.train.Checkpoint(model=model, step=tf.Variable(0))
ckpt_manager = tf.train.CheckpointManager(ckpt, directory=ckpt_dir,
                                          max_to_keep=1)
latest_ckpt = ckpt_manager.latest_checkpoint
# `latest_checkpoint` is None when no checkpoint is found in `ckpt_dir`. Restoring
# from None loads nothing, so fail loudly instead of reporting a restore that
# never happened and continuing with freshly initialized weights.
if latest_ckpt is None:
    raise FileNotFoundError(f'No checkpoint found in {ckpt_dir!r}')
ckpt.restore(latest_ckpt)
print(f'Checkpoint restored from {latest_ckpt} at step {int(ckpt.step)}')

Checkpoint restored from openwt_512/best-ckpt/ckpt-934000 at step 934000


In [11]:
# Generate text from the prompt in `context` using the restored model.
# NOTE(review): `k=40` looks like a top-k sampling cutoff — confirm in utils.sample.
text = sample(
    model,
    context,
    config['seq_len'],
    config['vocab_file'],
    k=40,
)
# Show the generated string as the cell output.
text

"I love the wolfy spot I hate a year , one month and you were sweatbing a cliff like some roxcy peel . ( FUTY MEKLITES ASSID AF EY . PLVENATT STOP : OLSLAS OLD ANNOILS : The Arre muckin ( borating its like that ’ a true - fun way with which she likes ! . So good ) There are plenty smile of like me like them all out here - there the IH . That may surprise for much that while this article in full and there for its most popular IM had on display since earlier to begin drawing as this version on my iPhone at IRXM / xY3Hm IS ) have , we ' had been doing everything together to share about and even blisss through IN in front of 2 , 800 students per a ( which does more to snubs me at these end people though that should end though in most things that will likely change ) because if no person got me through into these tpenie they had it ' d not matter I want to throw down every couple in every single place there and my first time with AF would actually do just little I was on my mind . So the sa