In [1]:
import tensorflow as tf
import tensorflow_text as tf_text
from random import randint
import numpy as np
gpu = tf.config.list_physical_devices('GPU')[0]
tf.config.experimental.set_memory_growth(gpu, True)
import os
os.environ["TF_GPU_ALLOCATOR"] = "cuda_malloc_async"

datafiles = "/home/ericjm24/fimarch/txt/*.txt"
base_dataset = tf.data.TextLineDataset(tf.data.Dataset.list_files(datafiles))
base_dataset = base_dataset.map(lambda x: tf.strings.regex_replace(x, r'''([!(),.?\-'";:])''', r' \1 '))
base_dataset = base_dataset.map(lambda x: tf.strings.regex_replace(x, r'''[^ a-z0-9A-Z!(),.?\-'";:]''', ''))
base_dataset = base_dataset.map(lambda x: tf.strings.regex_replace(x, r'''(.*)''', r'startstorystart \1 endstoryend'))
base_dataset = base_dataset.map(lambda x: tf.strings.regex_replace(x, r'\s+', ' '))

2023-07-09 13:50:01.956429: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
2023-07-09 13:50:03.021105: I tensorflow/compiler/xla/stream_executor/cuda/cuda_gpu_executor.cc:995] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355
2023-07-09 13:50:03.051046: I tensorflow/compiler/xla/stream_executor/cuda/cuda_gpu_executor.cc:995] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355
2023-07-

In [2]:
from tensorflow_text.tools.wordpiece_vocab import bert_vocab_from_dataset as bert_vocab
bert_tokenizer_params=dict(lower_case=True)
reserved_tokens=["[PAD]", "[UNK]", "[START]", "[END]"]

bert_vocab_args = dict(
    # The target vocabulary size
    vocab_size = 10000,
    # Reserved tokens that must be included in the vocabulary
    reserved_tokens=reserved_tokens,
    # Arguments for `text.BertTokenizer`
    bert_tokenizer_params=bert_tokenizer_params,
    # Arguments for `wordpiece_vocab.wordpiece_tokenizer_learner_lib.learn`
    learn_params={},
)

REBUILD_VOCAB = False

if REBUILD_VOCAB == True:
    vocab = bert_vocab.bert_vocab_from_dataset(
        base_dataset.batch(batch_size = 5, num_parallel_calls=tf.data.AUTOTUNE).prefetch(10),
        **bert_vocab_args
    )

    def write_vocab_file(filepath, vocab):
      with open(filepath, 'w') as f:
        for token in vocab:
          print(token, file=f)

    write_vocab_file('vocab.txt', vocab)

In [3]:
tokenizer = tf_text.BertTokenizer('vocab.txt', **bert_tokenizer_params)

In [4]:
for item in base_dataset.take(1):
    print(item)
    print(tokenizer.detokenize(tokenizer.tokenize(item)))

tf.Tensor(b'startstorystart Three days . Rarity hadn \' t slept in three days . Three days of coffee , take - out containers and screaming . That \' s how long it took Sweetie Belle to finally call for help . " - - And she hasn \' t opened the door to her Inspiration Room since , " the freshman concluded , having tried to catch the others up to what had been going on . Applejack nodded grimly . " So she didn \' t say who this mystery client was ? " Arms crossed , she radiated both concern and being already fed up with the drama . Not an unusual response to Rarity \' s antics . Sweetie \' s nod in response was much more rapid . " Yeah , though I think I recognized the voice when they were on speakerphone before . It sounded like Vignette Valencia . " There was a growl from around waist - level . Sunset looked up from her work . " Great . So it \' s possible Vignette found another artifact or we didn \' t completely clean her last time , and she \' s infected Rarity with Equestrian magic

In [5]:
WINDOW = 400
dataset = base_dataset.map(lambda x: tf.strings.split(x, ' '))
dataset = dataset.map(lambda x: tf_text.sliding_window(x, WINDOW)).unbatch()
dataset = dataset.map(lambda x: tf.strings.reduce_join(x, axis=0, separator = ' '))

In [6]:
MAX_TOKENS = WINDOW
def prepare_batch(words):
    temp = tokenizer.tokenize(words).merge_dims(-2,-1)
    input = temp[:,:MAX_TOKENS].to_tensor()
    out = temp[:,1:(MAX_TOKENS+1)].to_tensor()
    return (input,input),  out

In [7]:
BUFFER_SIZE = 10000
BATCH_SIZE = 24
def make_batches(ds):
    return (
        ds
        .shuffle(BUFFER_SIZE)
        .batch(BATCH_SIZE)
        .map(prepare_batch, tf.data.AUTOTUNE)
        .prefetch(buffer_size = tf.data.AUTOTUNE)
    )

In [8]:
for inp, outp in make_batches(dataset).take(1):
    inp = inp[0]
    print(inp.shape)
    print(outp.shape)
print(inp[0][0:10])
print(outp[0][:10])

(24, 400)
(24, 400)
tf.Tensor([  31 1031  915 5949 5416   73   11 6654 2182 6503], shape=(10,), dtype=int64)
tf.Tensor([1031  915 5949 5416   73   11 6654 2182 6503  318], shape=(10,), dtype=int64)


In [9]:
dataset = dataset.shard(43, 0).shuffle(BUFFER_SIZE)
dsize = 500000
#training = dataset.take(dsize)
#validation = dataset.skip(dsize).take(int(dsize*0.2))
validation = dataset.shard(num_shards=5, index = 0).take(int(dsize*0.25))
training = dataset.sample_from_datasets([dataset.shard(num_shards=5, index = 1),dataset.shard(num_shards=5, index = 2),dataset.shard(num_shards=5, index = 3),dataset.shard(num_shards=5, index = 4)]).take(dsize)

training_batches = make_batches(training)
validation_batches = make_batches(validation)

In [10]:
def positional_encoding(length, depth):
    depth = depth/2
    positions = np.arange(length)[:,np.newaxis]
    depths = np.arange(depth)[np.newaxis,:]/depth
    angle_rads = positions * (1 / (10000 ** depths))

    pos_encoding = np.concatenate(
        [np.sin(angle_rads), np.cos(angle_rads)],
        axis = -1
    )

    return tf.cast(pos_encoding, dtype=tf.float32)

In [11]:
class PositionalEncoding(tf.keras.layers.Layer):
    def __init__(self, vocab_size, d_model):
        super().__init__()
        self.d_model = d_model
        self.vocab_size = vocab_size

        self.embedding = tf.keras.layers.Embedding(vocab_size, d_model, mask_zero = True)
        self.pos_encoding = positional_encoding(length = 2048, depth = d_model)
    
    def compute_mask(self, *args, **kwargs):
        return self.embedding.compute_mask(*args, **kwargs)
    
    def call(self, x):
        length = tf.shape(x)[1]
        x = self.embedding(x)
        x *= tf.math.sqrt(tf.cast(self.d_model, dtype=tf.float32))
        x = x + self.pos_encoding[tf.newaxis, :length, :]
        return x

VOCAB_SIZE = 9831

enc_layer = PositionalEncoding(VOCAB_SIZE, 512)

In [12]:
class BaseAttentionLayer(tf.keras.layers.Layer):
    def __init__(self, dropout_rate = 0.1, **kwargs):
        super().__init__()
        self.mha = tf.keras.layers.MultiHeadAttention(**kwargs)
        self.layernorm = tf.keras.layers.LayerNormalization()
        self.add = tf.keras.layers.Add()

class CrossAttentionLayer(BaseAttentionLayer):
    def call(self, x, context):
        attn_output = self.mha(
            query = x,
            key = context,
            value = context
        )

        x = self.add([x, attn_output])
        x = self.layernorm(x)

        return x

class GlobalSelfAttentionLayer(CrossAttentionLayer):
    def call(self, x):
        return super().call(x, x)

class CausalSelfAttentionLayer(BaseAttentionLayer):
    def call(self, x):
        attn_output = self.mha(
            query = x,
            key = x,
            value = x,
            use_causal_mask = True
        )

        x = self.add([x, attn_output])
        x = self.layernorm(x)

        return x

In [13]:
class FeedForwardLayer(tf.keras.layers.Layer):
    def __init__(self, d_model, dff, dropout_rate):
        super().__init__()
        self.seq = tf.keras.Sequential([
            tf.keras.layers.Dense(dff, activation = 'relu'),
            tf.keras.layers.Dense(d_model),
            tf.keras.layers.Dropout(dropout_rate)
        ])
        self.add = tf.keras.layers.Add()
        self.layer_norm = tf.keras.layers.LayerNormalization()
    
    def call(self, x):
        return self.layer_norm(self.add([x, self.seq((x))]))

In [14]:
class EncoderLayer(tf.keras.layers.Layer):
    def __init__(self, *, d_model, num_heads, dff, dropout_rate=0.1):
        super().__init__()
        self.self_attention = GlobalSelfAttentionLayer(num_heads = num_heads, key_dim = d_model, dropout_rate = dropout_rate)
        self.ffn = FeedForwardLayer(d_model = d_model, dff = dff, dropout_rate = dropout_rate)
    
    def call(self, x):
        return self.ffn(self.self_attention(x))

In [15]:
class Encoder(tf.keras.layers.Layer):
    def __init__(self, *, num_layers, d_model, num_heads, dff, vocab_size = 9831, dropout_rate = 0.1):
        super().__init__()

        self.d_model = d_model
        self.num_layers = num_layers

        self.pos_embedding = PositionalEncoding(vocab_size = vocab_size, d_model = d_model)

        self.enc_layers = [
            EncoderLayer(d_model = d_model, num_heads = num_heads, dff = dff, dropout_rate = dropout_rate)
            for _ in range(num_layers)
        ]
        self.dropout = tf.keras.layers.Dropout(dropout_rate)
    
    def call(self, x):
        x = self.pos_embedding(x)
        x = self.dropout(x)
        for i in range(self.num_layers):
            x = self.enc_layers[i](x)
        return x

In [16]:
class DecoderLayer(tf.keras.layers.Layer):
    def __init__(self, *, d_model, num_heads, dff, dropout_rate = 0.1):
        super().__init__()
        self.self_attention = CausalSelfAttentionLayer(num_heads = num_heads, key_dim = d_model, dropout_rate = dropout_rate)
        self.cross_attention = CrossAttentionLayer(num_heads = num_heads, key_dim = d_model, dropout_rate = dropout_rate)
        self.ffn = FeedForwardLayer(d_model, dff, dropout_rate = dropout_rate)
    
    def call(self, x, context):
        x = self.self_attention(x=x)
        x = self.cross_attention(x=x, context=context)
        x = self.ffn(x)
        return x

class Decoder(tf.keras.layers.Layer):
    def __init__(self, *, num_layers, d_model, num_heads, dff, vocab_size = 9831, dropout_rate = 0.1):
        super().__init__()
        self.d_model = d_model
        self.num_layers = num_layers
        self.pos_embedding = PositionalEncoding(vocab_size = vocab_size, d_model = d_model)

        self.dec_layers = [
            DecoderLayer(d_model = d_model, num_heads = num_heads, dff = dff, dropout_rate = dropout_rate)
            for _ in range(num_layers)
        ]
        self.dropout = tf.keras.layers.Dropout(dropout_rate)
    
    def call(self, x, context):
        x = self.pos_embedding(x)
        x = self.dropout(x)
        for i in range(self.num_layers):
            x = self.dec_layers[i](x, context)
        return x

In [17]:
class Transformer(tf.keras.Model):
    def __init__(self, *, num_layers, d_model, num_heads, dff, vocab_size, dropout_rate=0.1):
        super().__init__()
        self.encoder = Encoder(num_layers = num_layers, d_model = d_model, num_heads = num_heads, dff = dff, vocab_size = vocab_size, dropout_rate = dropout_rate)
        self.decoder = Decoder(num_layers = num_layers, d_model = d_model, num_heads = num_heads, dff = dff, vocab_size = vocab_size, dropout_rate = dropout_rate)
        self.final_layer = tf.keras.layers.Dense(vocab_size)
    
    def call(self, inputs):
        # All inputs get passed to the first argument when keras.fit is called
        context = inputs[0]
        x = inputs[1]
        context = self.encoder(context)
        x = self.decoder(x, context)
        logits = self.final_layer(x)

        try:
            del logits._keras_mask
        except AttributeError:
            pass

        return logits

In [18]:
transformer = Transformer(num_layers = 4, d_model = 256, num_heads = 8, dff = 2048, vocab_size = 9831, dropout_rate = 0.15)

In [19]:
transformer((inp, outp))

2023-07-09 10:07:04.015245: I tensorflow/compiler/xla/stream_executor/cuda/cuda_blas.cc:606] TensorFloat-32 will be used for the matrix multiplication. This will only be logged once.
2023-07-09 10:07:04.107213: I tensorflow/compiler/xla/stream_executor/cuda/cuda_dnn.cc:432] Loaded cuDNN version 8800


<tf.Tensor: shape=(24, 400, 9831), dtype=float32, numpy=
array([[[ 3.83149274e-03, -1.45557284e-01, -2.56738245e-01, ...,
         -8.07759631e-03, -1.19989939e-01,  4.08738218e-02],
        [-1.17048360e-02,  1.56108811e-01, -2.81347752e-01, ...,
          1.55833662e-01, -1.13874435e-01, -1.10019501e-02],
        [-5.64794391e-02,  9.52024236e-02, -2.06767246e-01, ...,
          2.91610118e-02,  1.18001483e-01, -6.10530265e-02],
        ...,
        [-8.46779570e-02, -3.38700503e-01, -3.06348622e-01, ...,
          3.29118259e-02, -1.93358269e-02, -1.17504662e-02],
        [-1.35689393e-01, -2.03048468e-01, -1.35522529e-01, ...,
          5.91641851e-02, -1.49230789e-02, -8.11166242e-02],
        [-7.75438175e-02, -2.49089167e-01, -3.37276161e-01, ...,
          2.48936787e-01, -3.29920292e-01, -9.55722388e-03]],

       [[-6.47288933e-02,  3.51304322e-01, -3.07557851e-01, ...,
          1.16411693e-01, -1.30938798e-01, -1.25254810e-01],
        [ 3.21624801e-04,  7.13333637e-02, -2.

In [20]:
inp.shape[1]

400

In [21]:
transformer.summary()

Model: "transformer"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 encoder (Encoder)           multiple                  15138560  
                                                                 
 decoder (Decoder)           multiple                  23554816  
                                                                 
 dense_16 (Dense)            multiple                  2526567   
                                                                 
Total params: 41219943 (157.24 MB)
Trainable params: 41219943 (157.24 MB)
Non-trainable params: 0 (0.00 Byte)
_________________________________________________________________


In [19]:
class CustomSchedule(tf.keras.optimizers.schedules.LearningRateSchedule):
    def __init__(self, d_model, warmup_steps=4000):
        super().__init__()

        self.d_model = d_model
        self.d_model = tf.cast(self.d_model, tf.float32)

        self.warmup_steps = warmup_steps
    
    def __call__(self, step):
        step = tf.cast(step, dtype=tf.float32)
        arg1 = tf.math.rsqrt(step)
        arg2 = step * (self.warmup_steps ** -1.5)

        return tf.math.rsqrt(self.d_model) * tf.math.minimum(arg1, arg2)

learning_rate = CustomSchedule(512)
optimizer = tf.keras.optimizers.Adam(learning_rate, beta_1 = 0.9, beta_2 = 0.98, epsilon = 1e-9)


In [20]:
def masked_loss(label, pred):
    mask = label != 0
    loss_object = tf.keras.losses.SparseCategoricalCrossentropy(from_logits = True, reduction = 'none')
    loss = loss_object(label, pred)

    mask = tf.cast(mask, dtype = loss.dtype)
    loss *= mask
    loss = tf.reduce_sum(loss)/(tf.reduce_sum(mask) + tf.constant(1e-9))

    return loss

def masked_accuracy(label, pred):
  pred = tf.argmax(pred, axis=2)
  label = tf.cast(label, pred.dtype)
  try:
    match = (label == pred)
  except:
     match = 0
     print(label)
     print(pred)

  mask = label != 0

  match = match & mask

  match = tf.cast(match, dtype=tf.float32)
  mask = tf.cast(mask, dtype=tf.float32)
  return tf.reduce_sum(match)/tf.reduce_sum(mask)

In [21]:
transformer.compile(loss = masked_loss, optimizer = optimizer, metrics = [masked_accuracy])

In [25]:
transformer.fit(training_batches, epochs = 1, validation_data = validation_batches)

2023-07-09 10:07:23.365985: I tensorflow/core/kernels/data/shuffle_dataset_op.cc:422] Filling up shuffle buffer (this may take a while): 2 of 10000
2023-07-09 10:07:34.040137: I tensorflow/core/kernels/data/shuffle_dataset_op.cc:422] Filling up shuffle buffer (this may take a while): 9 of 10000
2023-07-09 10:07:41.506308: I tensorflow/core/kernels/data/shuffle_dataset_op.cc:422] Filling up shuffle buffer (this may take a while): 2997 of 10000
2023-07-09 10:07:51.508244: I tensorflow/core/kernels/data/shuffle_dataset_op.cc:422] Filling up shuffle buffer (this may take a while): 6926 of 10000
2023-07-09 10:08:00.054330: I tensorflow/core/kernels/data/shuffle_dataset_op.cc:450] Shuffle buffer filled.
2023-07-09 10:08:00.255309: I tensorflow/compiler/xla/service/service.cc:168] XLA service 0x7fa62003bd20 initialized for platform CUDA (this does not guarantee that XLA will be used). Devices:
2023-07-09 10:08:00.255339: I tensorflow/compiler/xla/service/service.cc:176]   StreamExecutor devic

     59/Unknown - 85s 370ms/step - loss: 9.0846 - masked_accuracy: 0.0197

2023-07-09 10:08:30.023594: W tensorflow/tsl/framework/cpu_allocator_impl.cc:83] Allocation of 28219507200 exceeds 10% of free system memory.


  12432/Unknown - 3333s 263ms/step - loss: 1.4003 - masked_accuracy: 0.7722

2023-07-09 11:02:37.856213: W tensorflow/tsl/framework/cpu_allocator_impl.cc:83] Allocation of 12619718400 exceeds 10% of free system memory.


  20834/Unknown - 5522s 262ms/step - loss: 0.8450 - masked_accuracy: 0.8625

2023-07-09 11:39:07.188760: I tensorflow/core/framework/local_rendezvous.cc:409] Local rendezvous send item cancelled. Key hash: 4409908842706861072
2023-07-09 11:39:07.188794: I tensorflow/core/framework/local_rendezvous.cc:409] Local rendezvous send item cancelled. Key hash: 1552193096531666608
2023-07-09 11:39:07.188799: I tensorflow/core/framework/local_rendezvous.cc:409] Local rendezvous send item cancelled. Key hash: 7095406078362705709
2023-07-09 11:39:18.400641: I tensorflow/core/kernels/data/shuffle_dataset_op.cc:422] Filling up shuffle buffer (this may take a while): 2030 of 10000
2023-07-09 11:39:28.401824: I tensorflow/core/kernels/data/shuffle_dataset_op.cc:422] Filling up shuffle buffer (this may take a while): 6224 of 10000
2023-07-09 11:39:37.143089: I tensorflow/core/kernels/data/shuffle_dataset_op.cc:450] Shuffle buffer filled.
2023-07-09 11:40:40.276462: W tensorflow/tsl/framework/cpu_allocator_impl.cc:83] Allocation of 13103491200 exceeds 10% of free system memory.




2023-07-09 11:49:06.067617: I tensorflow/core/framework/local_rendezvous.cc:409] Local rendezvous send item cancelled. Key hash: 9674532784807802747
2023-07-09 11:49:06.067664: I tensorflow/core/framework/local_rendezvous.cc:409] Local rendezvous send item cancelled. Key hash: 5263979607194460627
2023-07-09 11:49:06.067683: I tensorflow/core/framework/local_rendezvous.cc:405] Local rendezvous recv item cancelled. Key hash: 10050494555370863737
2023-07-09 11:49:06.067688: I tensorflow/core/framework/local_rendezvous.cc:409] Local rendezvous send item cancelled. Key hash: 12311708681543296831
2023-07-09 11:49:06.067696: I tensorflow/core/framework/local_rendezvous.cc:405] Local rendezvous recv item cancelled. Key hash: 4359087690752753260
2023-07-09 11:49:06.067703: I tensorflow/core/framework/local_rendezvous.cc:405] Local rendezvous recv item cancelled. Key hash: 6014387454676997960


<keras.src.callbacks.History at 0x7fa7c01cf820>

In [26]:
transformer.save_weights('my_checkpoint')

In [27]:
for k in range(20):
    transformer.fit(training_batches, epochs = 1, validation_data = validation_batches)
    transformer.save_weights(f'my_checkpoint_{k}')

2023-07-09 11:49:19.516630: I tensorflow/core/kernels/data/shuffle_dataset_op.cc:422] Filling up shuffle buffer (this may take a while): 2 of 10000
2023-07-09 11:49:31.453449: I tensorflow/core/kernels/data/shuffle_dataset_op.cc:422] Filling up shuffle buffer (this may take a while): 9 of 10000
2023-07-09 11:49:37.097824: I tensorflow/core/kernels/data/shuffle_dataset_op.cc:422] Filling up shuffle buffer (this may take a while): 2296 of 10000
2023-07-09 11:49:47.097350: I tensorflow/core/kernels/data/shuffle_dataset_op.cc:422] Filling up shuffle buffer (this may take a while): 6103 of 10000
2023-07-09 11:49:57.082246: I tensorflow/core/kernels/data/shuffle_dataset_op.cc:450] Shuffle buffer filled.


   7107/Unknown - 1902s 261ms/step - loss: 0.0186 - masked_accuracy: 0.9970

2023-07-09 12:20:49.476018: W tensorflow/tsl/framework/cpu_allocator_impl.cc:83] Allocation of 11281718400 exceeds 10% of free system memory.


  15306/Unknown - 4038s 261ms/step - loss: 0.0177 - masked_accuracy: 0.9972

2023-07-09 12:56:25.138461: W tensorflow/tsl/framework/cpu_allocator_impl.cc:83] Allocation of 13103491200 exceeds 10% of free system memory.


  20834/Unknown - 5477s 261ms/step - loss: 0.0173 - masked_accuracy: 0.9974

2023-07-09 13:20:24.564636: I tensorflow/core/framework/local_rendezvous.cc:409] Local rendezvous send item cancelled. Key hash: 4409908842706861072
2023-07-09 13:20:24.564662: I tensorflow/core/framework/local_rendezvous.cc:409] Local rendezvous send item cancelled. Key hash: 1552193096531666608
2023-07-09 13:20:24.564667: I tensorflow/core/framework/local_rendezvous.cc:409] Local rendezvous send item cancelled. Key hash: 7095406078362705709
2023-07-09 13:20:34.691231: I tensorflow/core/kernels/data/shuffle_dataset_op.cc:422] Filling up shuffle buffer (this may take a while): 1220 of 10000
2023-07-09 13:20:44.720928: I tensorflow/core/kernels/data/shuffle_dataset_op.cc:422] Filling up shuffle buffer (this may take a while): 5718 of 10000
2023-07-09 13:20:54.692591: I tensorflow/core/kernels/data/shuffle_dataset_op.cc:422] Filling up shuffle buffer (this may take a while): 7604 of 10000
2023-07-09 13:20:59.446018: I tensorflow/core/kernels/data/shuffle_dataset_op.cc:450] Shuffle buffer



2023-07-09 13:30:28.763165: I tensorflow/core/framework/local_rendezvous.cc:409] Local rendezvous send item cancelled. Key hash: 9674532784807802747
2023-07-09 13:30:28.763192: I tensorflow/core/framework/local_rendezvous.cc:409] Local rendezvous send item cancelled. Key hash: 5263979607194460627
2023-07-09 13:30:28.763198: I tensorflow/core/framework/local_rendezvous.cc:409] Local rendezvous send item cancelled. Key hash: 12311708681543296831
2023-07-09 13:30:40.478131: I tensorflow/core/kernels/data/shuffle_dataset_op.cc:422] Filling up shuffle buffer (this may take a while): 8971 of 10000
2023-07-09 13:30:40.836172: I tensorflow/core/kernels/data/shuffle_dataset_op.cc:450] Shuffle buffer filled.
2023-07-09 13:30:40.837326: I tensorflow/core/kernels/data/shuffle_dataset_op.cc:422] Filling up shuffle buffer (this may take a while): 1 of 10000
2023-07-09 13:30:53.244138: I tensorflow/core/kernels/data/shuffle_dataset_op.cc:422] Filling up shuffle buffer (this may take a while): 5 of 10

   4264/Unknown - 1165s 261ms/step - loss: 0.0156 - masked_accuracy: 0.9976

   4267/Unknown - 1166s 261ms/step - loss: 0.0156 - masked_accuracy: 0.9976

In [22]:
class TextGenerator(tf.Module):
    def __init__(self, tokenizer, transformer, temperature = 0.5):
        self.tokenizer = tokenizer
        self.transformer = transformer
        self.temperature = temperature
        skip_ids = [[1],[2],[3]]
        sparse_mask = tf.SparseTensor(
            # Put a -inf at each bad index.
            values=[-float('inf')]*len(skip_ids),
            indices=skip_ids,
            # Match the shape to the vocabulary
            dense_shape=[9831])
        self.prediction_mask = tf.sparse.to_dense(sparse_mask)
    
    def __call__(self, sentence, max_length=MAX_TOKENS):
        sentence = tf.constant("startstorystart " + sentence)
        assert isinstance(sentence, tf.Tensor)
        if len(sentence.shape) == 0:
            sentence = sentence[tf.newaxis]
        
        sentence = self.tokenizer.tokenize(sentence).merge_dims(-2, -1).to_tensor()
        print(sentence)
        start_end = self.tokenizer.tokenize(['startstorystart endstoryend'])[0]
        start = start_end[0][tf.newaxis]
        end = start_end[1][tf.newaxis]

        for i in tf.range(max_length):
            predictions = self.transformer([sentence, sentence], training = False)
            predictions = predictions [:, -1, :]
            predictions = predictions / self.temperature
            predictions = predictions + self.prediction_mask
            predicted_id = tf.random.categorical(predictions, num_samples = 1)
            next_token = tokenizer.detokenize(predicted_id)
            sentence = tf.concat([sentence, predicted_id], axis = -1)[:, (0-max_length):]
            if predicted_id == end:
                break
        
        text = tokenizer.detokenize(sentence)[0]
        return text

In [25]:
gentext = TextGenerator(tokenizer, transformer, temperature = 0.2)

In [26]:
gentext("This is a story about nothing in particular. I am making up random stuff.", max_length = 100)

tf.Tensor(
[[   4   76   80   27  553  105  265   62 1740   13   35  260  330   77
  1970  729   13]], shape=(1, 17), dtype=int64)


2023-07-09 13:51:04.512155: I tensorflow/compiler/xla/stream_executor/cuda/cuda_blas.cc:606] TensorFloat-32 will be used for the matrix multiplication. This will only be logged once.
2023-07-09 13:51:04.617080: I tensorflow/compiler/xla/stream_executor/cuda/cuda_dnn.cc:432] Loaded cuDNN version 8800


<tf.Tensor: shape=(100,), dtype=string, numpy=
array([b'a', b'story', b'about', b'about', b'i', b'am', b'making', b'am',
       b'making', b'am', b'making', b'stuff', b'about', b'about',
       b'about', b'about', b'i', b'am', b'am', b'am', b'am', b'am', b'am',
       b'am', b'am', b'about', b'about', b'about', b'about', b'about',
       b'about', b'about', b'about', b'i', b'am', b'am', b'am', b'am',
       b'am', b'am', b'am', b'am', b'am', b'am', b'am', b'about',
       b'about', b'about', b'about', b'about', b'about', b'about', b'i',
       b'i', b'am', b'am', b'am', b'am', b'am', b'am', b'about', b'about',
       b'about', b'about', b'about', b'about', b'i', b'am', b'am', b'am',
       b'am', b'am', b'am', b'about', b'about', b'about', b'about',
       b'about', b'about', b'about', b'i', b'am', b'am', b'am', b'am',
       b'am', b'am', b'am', b'about', b'about', b'about', b'about',
       b'about', b'about', b'about', b'i', b'am', b'am', b'am', b'am'],
      dtype=object)>

In [23]:
transformer.load_weights('my_checkpoint_0')

<tensorflow.python.checkpoint.checkpoint.CheckpointLoadStatus at 0x7f5c2c430340>