In [3]:
"""
# Transformer Model

https://www.tensorflow.org/text/tutorials/transformer

This again follows the tutorial on TensorFlow's website. I did a lot of digging to find 
simpler implementations, but breaking the model down into the components below seems to be 
the most popular approach. That may be for the best anyway, since the breakdown makes the 
model easier to learn. 

I plan to go through this part slowly and spend about a day on it to make sure I understand 
it. It is worth mentioning, though, that this should be plug-and-play with the same data 
pipeline I defined in `tf_dataset.py`. Note that the vast majority of this code is copied 
from the tutorial. I did have to depart from the tutorial in places, because it was 
written for a different use case with different data. The departures are mostly in exporting 
and running the model: my tokenizer is different, which caused some issues.
"""

# NOTE: Installing tensorflow_datasets and tensorflow_text updated tf to 2.13.0 from 12.2.1, if it breaks revert and find correct versions

import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' # stop showing tensorflow logs...

import numpy as np
import tensorflow as tf
from scratch_model.dataset import get_datasets

# Check whether a GPU is visible. Prints [] if not.
physical_devices = tf.config.list_physical_devices('GPU')
print(physical_devices)

# Prevent TensorFlow from allocating all GPU memory at once. Memory growth is
# enabled on every visible GPU; on CPU-only machines the loop is simply skipped
# (indexing physical_devices[0] there would raise IndexError).
for gpu in physical_devices:
    tf.config.experimental.set_memory_growth(gpu, True)

# MODEL PARAMS: roughly the base configuration from the original Transformer paper.
# The trailing numbers record earlier, smaller settings (presumably the ones used
# for transformer_v1 -- TODO confirm).
BATCH_SIZE = 64
EPOCHS = 2              # What we used for transformer_v1
NUM_LAYERS = 6          # 4
D_MODEL = 512           # 128
DFF = 2048              # 512
NUM_HEADS = 8           # 8
DROPOUT_RATE = 0.1      # 0.1

# Directory the exported model files (e.g. params.json) are written to.
save_dir = './models/transformer_v3.01'

# Batched train/validation datasets plus the text processor from the project's data pipeline.
train_ds, val_ds, text_processor = get_datasets(batch_size = BATCH_SIZE)


# ---- Defining stuff ----
def positional_encoding(length, depth):
    """Build the sinusoidal positional-encoding table.

    The first half of the feature axis holds sin(pos * rate_k) and the second
    half cos(pos * rate_k) (concatenated, not interleaved), where
    rate_k = 1 / 10000**(k / (depth / 2)).

    Args:
        length: number of positions (rows) to generate.
        depth: width of the feature axis, normally d_model.

    Returns:
        A float32 tensor of shape (length, depth).
    """
    half_depth = depth / 2

    positions = np.arange(length)[:, np.newaxis]                # (length, 1)
    depths = np.arange(half_depth)[np.newaxis, :] / half_depth  # (1, ceil(depth/2))

    angle_rates = 1 / (10000**depths)         # (1, ceil(depth/2))
    angle_rads = positions * angle_rates      # (length, ceil(depth/2))

    pos_encoding = np.concatenate(
        [np.sin(angle_rads), np.cos(angle_rads)],
        axis=-1)

    # For an odd depth the concatenation has depth + 1 columns. Trim it so the
    # table always matches the embedding width it is added to. This is a no-op
    # for even depth.
    pos_encoding = pos_encoding[:, :int(depth)]

    return tf.cast(pos_encoding, dtype=tf.float32)

class PositionalEmbedding(tf.keras.layers.Layer):
    """Token embedding plus a fixed (non-trainable) sinusoidal positional encoding.

    Args:
        vocab_size: number of token ids covered by the embedding table.
        d_model: embedding width.
        max_length: number of positions precomputed in the positional table. It
            defaults to 2048, the value previously hard-coded. Inputs must not be
            longer than this.
    """
    def __init__(self, vocab_size, d_model, max_length=2048):
        super().__init__()
        self.d_model = d_model
        self.max_length = max_length
        # mask_zero=True treats id 0 as padding and lets Keras propagate a mask.
        self.embedding = tf.keras.layers.Embedding(vocab_size, d_model, mask_zero=True)
        self.pos_encoding = positional_encoding(length=max_length, depth=d_model)

    def compute_mask(self, *args, **kwargs):
        # Delegate to the Embedding layer so downstream layers receive the padding mask.
        return self.embedding.compute_mask(*args, **kwargs)

    def call(self, x):
        # x: token ids, presumably (batch, seq_len) -- axis 1 is read as the length.
        length = tf.shape(x)[1]
        x = self.embedding(x)
        # This factor sets the relative scale of the embedding and positional_encoding.
        x *= tf.math.sqrt(tf.cast(self.d_model, tf.float32))
        x = x + self.pos_encoding[tf.newaxis, :length, :]
        return x
  
class BaseAttention(tf.keras.layers.Layer):
    """Shared sub-layers for the attention blocks.

    Holds a MultiHeadAttention layer (every keyword argument is forwarded to it),
    a residual Add layer and a LayerNormalization. Subclasses implement `call`.
    """
    def __init__(self, **kwargs):
        super().__init__()
        # All kwargs (num_heads, key_dim, dropout, ...) configure the attention layer.
        self.mha = tf.keras.layers.MultiHeadAttention(**kwargs)
        self.layernorm = tf.keras.layers.LayerNormalization()
        self.add = tf.keras.layers.Add()

class CrossAttention(BaseAttention):
    """Decoder-to-encoder attention followed by a residual add and layer norm.

    After each call, `last_attn_scores` holds the attention scores for plotting.
    """

    def call(self, x, context):
        # Queries come from the decoder stream; keys and values come from `context`.
        attended, scores = self.mha(
            query=x,
            key=context,
            value=context,
            return_attention_scores=True,
        )
        self.last_attn_scores = scores
        return self.layernorm(self.add([x, attended]))

class GlobalSelfAttention(BaseAttention):
    """Unmasked self-attention (every position sees every other), plus residual add and layer norm."""

    def call(self, x):
        attended = self.mha(query=x, value=x, key=x)
        residual = self.add([x, attended])
        return self.layernorm(residual)
  
class CausalSelfAttention(BaseAttention):
    """Self-attention where each position only attends to itself and earlier positions.

    Followed by a residual add and layer norm.
    """

    def call(self, x):
        attended = self.mha(query=x, value=x, key=x, use_causal_mask=True)
        return self.layernorm(self.add([x, attended]))
  
class FeedForward(tf.keras.layers.Layer):
    """Position-wise feed-forward block inside a residual connection.

    Computes LayerNorm(x + Dropout(Dense(d_model)(Dense(dff, relu)(x)))).
    """

    def __init__(self, d_model, dff, dropout_rate=0.1):
        super().__init__()
        stack = [
            tf.keras.layers.Dense(dff, activation='relu'),
            tf.keras.layers.Dense(d_model),
            tf.keras.layers.Dropout(dropout_rate),
        ]
        self.seq = tf.keras.Sequential(stack)
        self.add = tf.keras.layers.Add()
        self.layer_norm = tf.keras.layers.LayerNormalization()

    def call(self, x):
        residual = self.add([x, self.seq(x)])
        return self.layer_norm(residual)
  
class EncoderLayer(tf.keras.layers.Layer):
    """One encoder block: global self-attention followed by the feed-forward block.

    Args:
        d_model: model width. It is also passed as key_dim, which Keras applies per
            head (the original paper uses d_model // num_heads; changing it would
            change the weight shapes).
        num_heads: number of attention heads.
        dff: hidden width of the feed-forward block.
        dropout_rate: dropout used by both the attention and the feed-forward block.
    """
    def __init__(self,*, d_model, num_heads, dff, dropout_rate=0.1):
        super().__init__()

        self.self_attention = GlobalSelfAttention(
            num_heads=num_heads,
            key_dim=d_model,
            dropout=dropout_rate)

        # Forward dropout_rate. Previously the FFN silently used its own default (0.1).
        self.ffn = FeedForward(d_model, dff, dropout_rate=dropout_rate)

    def call(self, x):
        x = self.self_attention(x)
        x = self.ffn(x)
        return x

class Encoder(tf.keras.layers.Layer):
    """Positional embedding, dropout, then a stack of `num_layers` EncoderLayers."""

    def __init__(self, *, num_layers, d_model, num_heads,
                dff, vocab_size, dropout_rate=0.1):
        super().__init__()

        self.d_model = d_model
        self.num_layers = num_layers

        self.pos_embedding = PositionalEmbedding(
            vocab_size=vocab_size, d_model=d_model)

        blocks = []
        for _ in range(num_layers):
            blocks.append(EncoderLayer(d_model=d_model,
                                       num_heads=num_heads,
                                       dff=dff,
                                       dropout_rate=dropout_rate))
        self.enc_layers = blocks
        self.dropout = tf.keras.layers.Dropout(dropout_rate)

    def call(self, x):
        # Token ids in (presumably (batch, seq_len)); (batch, seq_len, d_model) out.
        hidden = self.dropout(self.pos_embedding(x))

        for block in self.enc_layers:
            hidden = block(hidden)

        return hidden
  
class DecoderLayer(tf.keras.layers.Layer):
    """One decoder block: causal self-attention, cross-attention, then feed-forward.

    Args:
        d_model: model width (also used as the per-head key_dim).
        num_heads: number of attention heads.
        dff: hidden width of the feed-forward block.
        dropout_rate: dropout used by both attention layers and the feed-forward block.

    After each call, `last_attn_scores` holds this layer's cross-attention scores.
    """
    def __init__(self,
                *,
                d_model,
                num_heads,
                dff,
                dropout_rate=0.1):
        super(DecoderLayer, self).__init__()

        self.causal_self_attention = CausalSelfAttention(
            num_heads=num_heads,
            key_dim=d_model,
            dropout=dropout_rate)

        self.cross_attention = CrossAttention(
            num_heads=num_heads,
            key_dim=d_model,
            dropout=dropout_rate)

        # Forward dropout_rate. Previously the FFN silently used its own default (0.1).
        self.ffn = FeedForward(d_model, dff, dropout_rate=dropout_rate)

    def call(self, x, context):
        x = self.causal_self_attention(x=x)
        x = self.cross_attention(x=x, context=context)

        # Cache the last attention scores for plotting later
        self.last_attn_scores = self.cross_attention.last_attn_scores

        x = self.ffn(x)  # Shape `(batch_size, seq_len, d_model)`.
        return x
  
class Decoder(tf.keras.layers.Layer):
    """Positional embedding, dropout, then a stack of `num_layers` DecoderLayers.

    After each call, `last_attn_scores` holds the final layer's cross-attention scores.
    """

    def __init__(self, *, num_layers, d_model, num_heads, dff, vocab_size,
                dropout_rate=0.1):
        super().__init__()

        self.d_model = d_model
        self.num_layers = num_layers

        self.pos_embedding = PositionalEmbedding(vocab_size=vocab_size,
                                                d_model=d_model)
        self.dropout = tf.keras.layers.Dropout(dropout_rate)
        blocks = []
        for _ in range(num_layers):
            blocks.append(DecoderLayer(d_model=d_model, num_heads=num_heads,
                                       dff=dff, dropout_rate=dropout_rate))
        self.dec_layers = blocks

        self.last_attn_scores = None

    def call(self, x, context):
        # Target token ids in (presumably (batch, target_seq_len)).
        hidden = self.dropout(self.pos_embedding(x))  # (batch, target_seq_len, d_model)

        for block in self.dec_layers:
            hidden = block(hidden, context)

        self.last_attn_scores = self.dec_layers[-1].last_attn_scores

        return hidden  # (batch, target_seq_len, d_model)
  
class Transformer(tf.keras.Model):
    """Encoder-decoder Transformer mapping (context ids, target ids) to next-token logits.

    Args:
        num_layers, d_model, num_heads, dff, dropout_rate: forwarded to both the
            Encoder and the Decoder.
        input_vocab_size: vocabulary size of the encoder embedding.
        target_vocab_size: vocabulary size of the decoder embedding and of the
            output projection.
    """
    def __init__(self, *, num_layers, d_model, num_heads, dff,
                input_vocab_size, target_vocab_size, dropout_rate=0.1):
        super().__init__()
        self.encoder = Encoder(num_layers=num_layers, d_model=d_model,
                            num_heads=num_heads, dff=dff,
                            vocab_size=input_vocab_size,
                            dropout_rate=dropout_rate)

        self.decoder = Decoder(num_layers=num_layers, d_model=d_model,
                            num_heads=num_heads, dff=dff,
                            vocab_size=target_vocab_size,
                            dropout_rate=dropout_rate)

        self.final_layer = tf.keras.layers.Dense(target_vocab_size)

    def call(self, inputs):
        """Run the model on an `(context, x)` pair of token-id tensors.

        Keras passes a single `inputs` argument, so both sequences arrive together.

        Returns:
            Logits of shape (batch_size, target_len, target_vocab_size).
        """
        context, x  = inputs

        context = self.encoder(context)  # (batch_size, context_len, d_model)
        x = self.decoder(x, context)  # (batch_size, target_len, d_model)
        logits = self.final_layer(x)  # (batch_size, target_len, target_vocab_size)

        # Drop the padding mask Keras propagated onto the logits so it is not also
        # applied when computing the loss/metrics (masked_loss and masked_accuracy
        # handle padding themselves). Not every tensor carries the attribute.
        try:
            del logits._keras_mask
        except AttributeError:
            pass

        # Only the logits are returned. The final cross-attention scores remain
        # available afterwards as self.decoder.last_attn_scores.
        return logits
  
class CustomSchedule(tf.keras.optimizers.schedules.LearningRateSchedule):
    """Warmup then inverse-square-root decay learning-rate schedule.

    lr(step) = d_model**-0.5 * min(step**-0.5, step * warmup_steps**-1.5)

    Args:
        d_model: model width used to scale the learning rate.
        warmup_steps: number of steps over which the rate increases linearly.
    """
    def __init__(self, d_model, warmup_steps=4000):
        super().__init__()

        # Keep the plain value for get_config(); __call__ uses the float32 tensor.
        self._d_model_value = d_model
        self.d_model = tf.cast(d_model, tf.float32)

        self.warmup_steps = warmup_steps

    def __call__(self, step):
        step = tf.cast(step, dtype=tf.float32)
        arg1 = tf.math.rsqrt(step)
        arg2 = step * (self.warmup_steps ** -1.5)

        return tf.math.rsqrt(self.d_model) * tf.math.minimum(arg1, arg2)

    def get_config(self):
        """Return the constructor arguments.

        Without this, the base class cannot serialize the schedule, which also
        blocks saving an optimizer/model that uses it.
        """
        return {'d_model': self._d_model_value, 'warmup_steps': self.warmup_steps}
  
# ---- Setup loss and metrics ----
def masked_loss(label, pred):
    """Mean sparse categorical cross-entropy over non-padding positions.

    Positions whose label is 0 contribute nothing. The summed per-token loss is
    divided by the number of non-zero labels.

    Args:
        label: integer target ids.
        pred: unnormalized logits with a trailing vocabulary axis.
    """
    per_token = tf.keras.losses.SparseCategoricalCrossentropy(
        from_logits=True, reduction='none')(label, pred)
    weights = tf.cast(label != 0, dtype=per_token.dtype)
    return tf.reduce_sum(per_token * weights) / tf.reduce_sum(weights)


def masked_accuracy(label, pred):
    """Fraction of non-padding positions (label != 0) where argmax(pred) equals the label.

    Args:
        label: integer target ids.
        pred: logits; the argmax is taken over axis 2.
    """
    predicted_ids = tf.argmax(pred, axis=2)
    label = tf.cast(label, predicted_ids.dtype)
    valid = label != 0
    hits = tf.logical_and(label == predicted_ids, valid)
    return (tf.reduce_sum(tf.cast(hits, dtype=tf.float32))
            / tf.reduce_sum(tf.cast(valid, dtype=tf.float32)))

# ---- Training ----
transformer = Transformer(
    num_layers=NUM_LAYERS,
    d_model=D_MODEL,
    num_heads=NUM_HEADS,
    dff=DFF,
    input_vocab_size=5000,  # This is the vocab size used for all the datasets.
    target_vocab_size=5000,
    dropout_rate=DROPOUT_RATE)


# Warmup + inverse-square-root learning rate, with the Adam hyper-parameters
# used in the original Transformer paper.
learning_rate = CustomSchedule(D_MODEL)

optimizer = tf.keras.optimizers.Adam(
    learning_rate,
    beta_1=0.9,
    beta_2=0.98,
    epsilon=1e-9
)

# Both the loss and the metric ignore label id 0 (padding).
transformer.compile(
    loss=masked_loss,
    optimizer=optimizer,
    metrics=[masked_accuracy])

transformer.fit(
    train_ds,
    epochs=EPOCHS,
    validation_data=val_ds)

# summary() prints the table itself and returns None; wrapping it in print()
# only added a stray "None" line to the output.
transformer.summary()

# ---- Save model ----

# id -> token list, used later to turn generated ids back into text.
vocab = text_processor.get_vocabulary()

# Recorded in params.json by the export cell.
MAX_TOKENS = 256


[PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU')]
Epoch 1/2
Epoch 2/2
Model: "transformer"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 encoder (Encoder)           multiple                  65579008  
                                                                 
 decoder (Decoder)           multiple                  115993600 
                                                                 
 dense_24 (Dense)            multiple                  2565000   
                                                                 
Total params: 184,137,608
Trainable params: 184,137,608
Non-trainable params: 0
_________________________________________________________________
None


In [30]:
# Display the model object's repr as a quick check that the training cell ran in this kernel.
transformer

<__main__.Transformer at 0x7ff0f2f2aa10>

In [39]:
# ---- Fresh (untrained) model instance ----
# Builds a second Transformer with the same architecture. No fit() is called, so
# its weights are randomly initialized. It is presumably meant as a target for
# load_weights; TODO confirm or remove.
transformer2 = Transformer(
    num_layers=NUM_LAYERS,
    d_model=D_MODEL,
    num_heads=NUM_HEADS,
    dff=DFF,
    input_vocab_size=5000,  # This is the vocab size used for all the datasets.
    target_vocab_size=5000,
    dropout_rate=DROPOUT_RATE)

In [50]:
import json

# Export the hyper-parameters needed to rebuild the architecture later (weights
# alone do not describe it). This is written into `save_dir` as currently bound
# in the kernel.
params = {
    'num_layers': NUM_LAYERS,
    'd_model': D_MODEL,
    'num_heads': NUM_HEADS,
    'dff': DFF,
    'input_vocab_size': 5000,
    'target_vocab_size': 5000,
    'dropout_rate': DROPOUT_RATE,
    'max_tokens': MAX_TOKENS,
    'batch_size': BATCH_SIZE,
}

# open() does not create missing directories, so make sure save_dir exists first.
os.makedirs(save_dir, exist_ok=True)
with open(os.path.join(save_dir, 'params.json'), 'w') as f:
    json.dump(params, f, indent=2)

In [1]:
# NOTE(review): this rebinds save_dir away from './models/transformer_v3.01' (set in the
# config cell). Depending on execution order, params.json and the saved weights can
# end up in different directories; confirm which directory is intended.
save_dir = "./models/transformer_v3.02/"

In [2]:
# Save only the weights. Rebuilding the model before load_weights needs the
# architecture hyper-parameters (see params.json from the export cell).
# NOTE: `transformer` exists only if the training cell has run in this kernel; the
# recorded NameError came from running this cell after a kernel restart.
transformer.save_weights(save_dir)


NameError: name 'transformer' is not defined

In [23]:
class ScratchModel(tf.Module):
    """Greedy autoregressive text generator wrapping a trained Transformer.

    Args:
        transformer: the trained model, called as transformer([question_ids, output_ids]).
        text_processor: the TextVectorization layer used for the datasets. It is
            assumed to return ragged ids and to wrap every string in [START]/[END]
            tokens, so that encoding '' yields exactly [start_id, end_id].
            TODO confirm.
        vocab: id -> token list, e.g. text_processor.get_vocabulary().
    """
    def __init__(
        self,
        transformer,
        text_processor: tf.keras.layers.TextVectorization,
        vocab
    ):
        super().__init__()
        self.transformer = transformer
        self.text_processor = text_processor
        self.vocab = vocab
        self.vocab_tf = tf.constant(self.vocab)

    def _ids_to_text(self, ids):
        "Joins the vocabulary tokens for a 1-D tensor of ids with single spaces"
        return tf.strings.reduce_join(tf.gather(self.vocab_tf, ids), separator=" ")

    def _predict_next(self, question, output_array, i):
        "Predicts the next token given the question and the output array"
        output = tf.transpose(output_array.stack())  # (1, i + 1)
        # Use the wrapped model, not whatever global happens to be named `transformer`.
        prediction = self.transformer([question, output], training=False)
        prediction = prediction[:, -1:, :]  # logits for the last position only
        prediction_id = tf.argmax(prediction, axis=-1)  # (1, 1)
        output_array = output_array.write(i+1, prediction_id[0])
        output = tf.transpose(output_array.stack())

        text = self._ids_to_text(tf.reshape(output, [-1]))

        return prediction_id, text, output_array

    def _stream_result(self, question, output_array, max_tokens, end):
        "Yields the decoded text after each generated token; stops before yielding [END]"
        for i in tf.range(max_tokens):
            prediction_id, text, output_array = self._predict_next(question, output_array, i)

            if bool(tf.reduce_all(prediction_id == end)):
                break

            yield text

    def _return_result(self, question, output_array, max_tokens, end):
        "Returns the final text, i.e. the last text the streaming path would yield"
        # Start from the start-token-only text so that max_tokens == 0, or an
        # immediate [END], returns that text instead of raising UnboundLocalError.
        # Like the streaming path, the [END] token is not included in the result.
        text = self._ids_to_text(tf.reshape(output_array.stack(), [-1]))
        for partial in self._stream_result(question, output_array, max_tokens, end):
            text = partial
        return text

    def __call__(self, question: str, max_tokens: int = 256, stream: bool = False):
        "Oversees the prediction process. Returns a generator if stream=True"
        question = tf.convert_to_tensor([question])
        question = self.text_processor(question).to_tensor()  # (1, question_len)

        # Encoding the empty string is assumed to yield just the [START] and [END] ids.
        start_end = self.text_processor([''])[0]
        start = start_end[0][tf.newaxis]
        end = start_end[1][tf.newaxis]

        output_array = tf.TensorArray(dtype=tf.int64, size=0, dynamic_size=True)
        output_array = output_array.write(0, start)

        if stream:
            return self._stream_result(question, output_array, max_tokens, end)
        return self._return_result(question, output_array, max_tokens, end)

In [54]:
# Wrap the trained transformer with its text processor and vocabulary for generation.
model = ScratchModel(
    transformer=transformer,
    text_processor=text_processor,
    vocab=vocab,
)

TODO: find out what the export function does.

In [55]:
# Non-streaming generation: returns the full generated text as one scalar string tensor.
model(question="How can I apply?", max_tokens=10, stream=False)

<tf.Tensor: shape=(), dtype=string, numpy=b'[START] you can find the [UNK] , you can be a'>

In [56]:
generated_texts = model(question="How can I apply?", max_tokens=10, stream=True)

# Each item is a scalar string tensor holding the text generated so far;
# .numpy() converts it to Python bytes for printing.
for text in generated_texts:
    print(text.numpy())


b'[START] you'
b'[START] you can'
b'[START] you can find'
b'[START] you can find the'
b'[START] you can find the [UNK]'
b'[START] you can find the [UNK] ,'
b'[START] you can find the [UNK] , you'
b'[START] you can find the [UNK] , you can'
b'[START] you can find the [UNK] , you can be'
b'[START] you can find the [UNK] , you can be a'
