---
**Import and configure libraries**

In [None]:
import os

os.environ["KERAS_BACKEND"] = "tensorflow"

import keras_nlp
import keras
import tensorflow as tf
import tensorflow.data as tf_data
import tensorflow.io as tf_io
import tensorflow.strings as tf_strings
import time

# Use Keras' "mixed_float16" dtype policy for every layer created after this
# point (per the Keras mixed-precision docs: float16 compute, float32 variables).
keras.mixed_precision.set_global_policy("mixed_float16")

# Distribution strategy; the model and the optimizer are created inside
# `strategy.scope()` further down in this notebook.
strategy = tf.distribute.MirroredStrategy()

---
**Define model constants**

Can be updated as needed, specifically:
* `BATCH_SIZE`, to accommodate the memory restrictions of your machine
* `EPOCHS`, to train for longer/shorter
* `SEQ_LEN`, to change the context window of the model
* `DATASET_FILE`, to use a dataset stored elsewhere

In [None]:
# General
GPT2_PRESET = "gpt2_base_en"  # KerasNLP preset name of the pre-trained GPT-2 model

# LoRA
RANK = 4  # Rank of the low-rank update (number of units of the LoRA `A` projection)
ALPHA = 32.0  # LoRA scaling numerator; the LoRA output is scaled by ALPHA / RANK

# Data
BATCH_SIZE = 2 # Number of records per training batch
SEQ_LEN = 512  # Length of training sequences, in tokens. AKA the context size
NUM_BATCHES = 2  # Number of batches taken from the dataset (BATCH_SIZE * NUM_BATCHES records)

# Training
EPOCHS = 1  # Number of passes over the NUM_BATCHES selected batches

# File names
DATASET_FILE = "RecipeNLG/RecipeNLG_dataset.csv" # where the training data is stored

---
**Load dataset**

In [None]:
def csv_row_to_json(row):
    """Convert one raw CSV line of the dataset into a JSON-formatted string.

    Args:
        row: scalar string tensor holding a single CSV line with 7 columns.

    Returns:
        Scalar string tensor shaped like
        `{"ner": ..., "title": "...", "ingredients": ..., "directions": ...}`.

    Note:
        `decode_csv` is given empty defaults for all 7 columns, so rows that do
        not parse raise an error; the dataset pipeline is expected to drop
        those rows.
    """
    fields = tf_io.decode_csv(
        records=row,
        record_defaults=[tf.constant([], dtype=tf.string)] * 7,
    )

    # Columns used: 1 = title, 2 = ingredients, 3 = directions, 6 = NER.
    # ingredients / directions / NER are inserted verbatim -- assumes the CSV
    # already stores them as JSON arrays (TODO confirm against the dataset).
    ingredients = fields[2]
    directions = fields[3]
    ner = fields[6]

    # The title is free text that gets wrapped in JSON quotes below, so any
    # backslash or double quote inside it must be escaped or the resulting
    # record is not valid JSON.
    title = tf_strings.regex_replace(fields[1], r'(["\\])', r'\\\1')

    # preserve the semi-structured nature of the dataset
    return tf_strings.join([
        '{"ner": ', ner, ', ',
        '"title": "', title, '", ',
        '"ingredients": ', ingredients, ', ',
        '"directions": ', directions, '}',
    ])


# Read the CSV line by line and drop the header row.
csv_lines = tf_data.TextLineDataset(DATASET_FILE).skip(1)

# Shuffle through a 256-record buffer, then turn every line into a
# JSON-formatted string.
json_records = csv_lines.shuffle(buffer_size=256).map(csv_row_to_json)

# Drop the records whose CSV line raised an error while being converted.
valid_records = json_records.apply(tf.data.experimental.ignore_errors())

# Group records into batches and keep only the first NUM_BATCHES of them.
dataset = valid_records.batch(BATCH_SIZE).take(NUM_BATCHES)

---
**Helper functions**

In [None]:
def generate_text(model, input_text, max_length=200):
    """Generate text from a prompt, print it, and report the elapsed time.

    Args:
        model: object exposing `generate(input_text, max_length=...)`.
        input_text: prompt passed to `model.generate`.
        max_length: forwarded to `model.generate` as `max_length`.

    Returns:
        None. The generated output and the timing are printed.
    """
    started_at = time.time()

    generated = model.generate(input_text, max_length=max_length)
    print("\nOutput:")
    print(generated)

    # The measured span covers generation and the two prints above.
    elapsed = time.time() - started_at
    print(f"Total Time Elapsed: {elapsed:.2f}s")
    

def get_optimizer_and_loss():
    """Build the optimizer and loss used for fine-tuning.

    Returns:
        A `(optimizer, loss)` tuple: an `AdamW` optimizer with global-norm
        gradient clipping, and a sparse categorical cross-entropy loss that
        expects logits.
    """
    optimizer = keras.optimizers.AdamW(
        learning_rate=5e-5,
        weight_decay=0.01,
        epsilon=1e-6,
        global_clipnorm=1.0,  # Gradient clipping.
    )
    # Exclude layernorm and bias terms from weight decay.
    # `exclude_from_weight_decay` replaces the exclusion list on every call, so
    # all name patterns must be passed in a single call; calling it once per
    # name would keep only the last one.
    optimizer.exclude_from_weight_decay(var_names=["bias", "gamma", "beta"])

    loss = keras.losses.SparseCategoricalCrossentropy(from_logits=True)
    return optimizer, loss

---
**Create a LoRA layer**

In [None]:
import math

class LoraLayer(keras.layers.Layer):
    """Wrap a frozen projection layer with a low-rank (LoRA) update branch.

    While this wrapper is trainable, `call` returns
    `original_layer(x) + B(A(x)) * (alpha / rank)`; otherwise it returns the
    wrapped layer's output unchanged.

    Args:
        original_layer: layer to wrap. Its config must provide `name`,
            `equation` and `output_shape` (assumed to be an `EinsumDense`
            attention projection -- confirm before wrapping anything else).
        rank: number of units of the `A` down-projection.
        alpha: scaling numerator; the LoRA output is multiplied by
            `alpha / rank`.
        trainable: whether the wrapper and its `A` / `B` sublayers are
            trainable. The wrapped layer is always frozen.
        **kwargs: forwarded to `keras.layers.Layer`. A `name` entry is
            discarded because the wrapped layer's name is reused.
    """

    def __init__(
        self,
        original_layer,
        rank=8,
        alpha=32,
        trainable=False,
        **kwargs,
    ):
        base_config = original_layer.get_config()

        # The wrapper takes over the name of the layer it wraps, so any name
        # supplied by the caller is dropped.
        kwargs.pop("name", None)
        super().__init__(name=base_config["name"], trainable=trainable, **kwargs)

        self.rank = rank
        self.alpha = alpha
        self._scale = alpha / rank

        projection_shape = base_config["output_shape"]
        self._num_heads = projection_shape[-2]
        self._hidden_dim = self._num_heads * projection_shape[-1]

        # The wrapped layer stays frozen both while fine-tuning and at
        # inference time.
        self.original_layer = original_layer
        self.original_layer.trainable = False

        # Down-projection A. The original paper mentions a normal
        # distribution for initialization, but the official LoRA
        # implementation uses "Kaiming/He Initialization".
        kaiming_uniform = keras.initializers.VarianceScaling(
            scale=math.sqrt(5), mode="fan_in", distribution="uniform"
        )
        self.A = keras.layers.Dense(
            units=rank,
            use_bias=False,
            kernel_initializer=kaiming_uniform,
            trainable=trainable,
            name="lora_A",
        )

        # Up-projection B reuses the wrapped layer's `equation` and
        # `output_shape`. With `equation = abc,cde->abde`: `a` is the batch
        # size, `b` the sequence length, `d` is `num_heads` and `e` is
        # `hidden_dim // num_heads`; in B, `c` is `rank` instead of
        # `hidden_dim`. It starts at zero.
        self.B = keras.layers.EinsumDense(
            equation=base_config["equation"],
            output_shape=projection_shape,
            kernel_initializer="zeros",
            trainable=trainable,
            name="lora_B",
        )

    def call(self, inputs):
        """Return the wrapped layer's output, plus the LoRA update if trainable."""
        frozen_output = self.original_layer(inputs)

        # Not fine-tuning: the LoRA weights are expected to be merged into the
        # wrapped layer's kernel separately, so only its output is returned.
        if not self.trainable:
            return frozen_output

        # Fine-tuning: add the scaled low-rank branch to the frozen output.
        low_rank = self.A(inputs)
        lora_output = self.B(low_rank) * self._scale
        return frozen_output + lora_output


---
**Load pre-trained model and tokenizer**

In [None]:
# Build the preprocessor and the model from the constants defined in the
# "Define model constants" cell, so that changing GPT2_PRESET or SEQ_LEN there
# actually takes effect here.
preprocessor = keras_nlp.models.GPT2CausalLMPreprocessor.from_preset(
    GPT2_PRESET,
    sequence_length=SEQ_LEN,
)
# Create the model inside the strategy scope, like the optimizer later on.
with strategy.scope():
    lora_model = keras_nlp.models.GPT2CausalLM.from_preset(
        GPT2_PRESET,
        preprocessor=preprocessor,
    )
lora_model.summary()

---
**Inject LoRA into the model**

In [None]:
for layer_idx in range(lora_model.backbone.num_layers):
    # Self-attention block of the current transformer layer.
    attention = lora_model.backbone.get_layer(
        f"transformer_layer_{layer_idx}"
    )._self_attention_layer
    # Allow mutation to Keras layer state.
    attention._tracker.locked = False

    # Wrap the query projection first, then the value projection.
    for projection_attr in ("_query_dense", "_value_dense"):
        wrapped_projection = LoraLayer(
            getattr(attention, projection_attr),
            rank=RANK,
            alpha=ALPHA,
            trainable=True,
        )
        setattr(attention, projection_attr, wrapped_projection)

In [None]:
# Freeze every leaf layer except the LoRA A / B projections.
for layer in lora_model._flatten_layers():
    # A layer whose flattened layer list contains a single entry is treated
    # as a "leaf of the model".
    is_leaf = len(list(layer._flatten_layers())) == 1
    if is_leaf:
        layer.trainable = layer.name in ["lora_A", "lora_B"]

In [None]:
# Print the summary again so it can be compared with the one printed right
# after loading the model (before LoRA injection and freezing).
lora_model.summary()

---
**Finetune the model**

In [None]:
# Write a checkpoint file for every epoch (not only the best one); the epoch
# number is part of the file name.
checkpoint_callback = keras.callbacks.ModelCheckpoint(
    save_best_only=False,
    filepath='transfer_learning/checkpoint_{epoch:02d}.keras',
)

# The optimizer is created and the model compiled inside the strategy scope.
with strategy.scope():
    optimizer, loss = get_optimizer_and_loss()
    lora_model.compile(loss=loss, optimizer=optimizer, weighted_metrics=["accuracy"])

In [None]:
# Fine-tune on `dataset` for EPOCHS epochs, with the checkpoint callback
# attached to the run.
lora_model.fit(
    dataset,
    epochs=EPOCHS,
    callbacks=[checkpoint_callback],
)

---
**Merge weights**

In [None]:
def _merge_lora_layer(lora_layer):
    """Fold a LoraLayer's low-rank update into its wrapped layer.

    Adds `einsum(A, B) * scale` to the wrapped layer's kernel in place, using
    the scale stored on the LoRA layer itself, so the merge always matches
    the scale that was applied during training.

    Args:
        lora_layer: a `LoraLayer` instance.

    Returns:
        The wrapped original layer, with its kernel updated.
    """
    a_kernel = lora_layer.A.kernel  # "ab": (input_dim, rank)
    b_kernel = lora_layer.B.kernel  # "bcd": (rank, num_heads, head_dim)
    increment = tf.einsum("ab,bcd->acd", a_kernel, b_kernel) * lora_layer._scale
    lora_layer.original_layer.kernel.assign_add(increment)
    return lora_layer.original_layer


for layer_idx in range(lora_model.backbone.num_layers):
    self_attention_layer = lora_model.backbone.get_layer(
        f"transformer_layer_{layer_idx}"
    )._self_attention_layer

    # Merge the query projection, then the value projection, and put the
    # original layers (with updated weights) back in place.
    for projection_attr in ("_query_dense", "_value_dense"):
        projection = getattr(self_attention_layer, projection_attr)
        # Projections that are no longer LoRA layers were merged already;
        # skipping them makes re-running this cell a no-op instead of an error
        # or a double merge.
        if isinstance(projection, LoraLayer):
            setattr(
                self_attention_layer,
                projection_attr,
                _merge_lora_layer(projection),
            )

---
**Generate text**

In [None]:
# Prompt with the opening of the JSON record format built by `csv_row_to_json`.
generate_text(lora_model, '{"ner": ')

---
**Save the final model**

In [None]:
# Save the model after the "Merge weights" cell has put the original
# (updated) layers back in place of the LoRA wrappers.
lora_model.save("transfer_learning/final_model.keras")