---
**Import and configure libraries**

In [None]:
import os

os.environ["KERAS_BACKEND"] = "tensorflow"

import keras_nlp
import keras
import tensorflow as tf
import tensorflow.data as tf_data
import tensorflow.io as tf_io
import tensorflow.strings as tf_strings
import time

# Apply the "mixed_float16" dtype policy globally, before any model layers are created.
keras.mixed_precision.set_global_policy("mixed_float16")

# Distribution strategy; the model is later built and compiled inside its scope.
strategy = tf.distribute.MirroredStrategy()
print(f"Number of devices: {strategy.num_replicas_in_sync}")

---
**Define model constants**

These constants can be updated as needed, specifically:
* `BATCH_SIZE`, to accommodate the memory restrictions of your machine
* `EPOCHS`, to train for longer/shorter
* `SEQ_LEN`, to change the context window of the model
* `DATASET_FILE`, to use a dataset stored elsewhere

In [None]:
# ---- General ----
# Preset name passed to both the preprocessor and the causal LM `from_preset` calls.
GPT2_PRESET = "gpt2_base_en"

# ---- LoRA ----
# Rank handed to `enable_lora` on the model backbone.
RANK = 4

# ---- Data ----
# Number of records per training batch.
BATCH_SIZE = 2
# Length of training sequences, in tokens (the context size).
SEQ_LEN = 512
# Number of batches taken from the dataset for training.
NUM_BATCHES = 2

# ---- Training ----
# Number of passes over the (truncated) dataset.
EPOCHS = 1

# ---- File names ----
# Location of the CSV file holding the training data.
DATASET_FILE = "RecipeNLG/RecipeNLG_dataset.csv"

---
**Load dataset**

In [None]:
def csv_row_to_json(row):
    """Convert one raw CSV line into a JSON-style training string.

    The line is parsed into 7 string columns. Columns 1, 2, 3 and 6 are used
    as the title, ingredients, directions and NER fields respectively; the
    other columns are ignored.

    Args:
        row: A string tensor holding a single CSV line (the pipeline applies
            this function per line, before batching).

    Returns:
        A string tensor with the keys "ner", "title", "ingredients" and
        "directions", in that order. Only the title is wrapped in double
        quotes; the other three fields are inserted verbatim.
    """
    # Every column uses an empty string default of dtype tf.string.
    # NOTE(review): per the TensorFlow docs an empty default marks the column
    # as required, so short rows raise and are dropped by `ignore_errors()`
    # downstream - confirm.
    columns = tf_io.decode_csv(
        records=row,
        record_defaults=[tf.constant([], dtype=tf.string)] * 7,
    )

    # The title is the only field we quote ourselves, so escape backslashes
    # and double quotes in it; otherwise a title containing a quote would end
    # the JSON string early and produce a malformed record.
    title = tf_strings.regex_replace(columns[1], r'(["\\])', r"\\\1")

    # Presumably these columns already contain JSON-array text, which is why
    # they are embedded without surrounding quotes - TODO confirm against the
    # dataset.
    ingredients = columns[2]
    directions = columns[3]
    ner = columns[6]

    # preserve the semi-structured nature of the dataset
    return tf_strings.join([
        '{"ner": ', ner, ', ',
        '"title": "', title, '", ',
        '"ingredients": ', ingredients, ', ',
        '"directions": ', directions, '}',
    ])


dataset = (
    # Read the CSV file one text line at a time.
    tf_data.TextLineDataset(DATASET_FILE)
    # Drop the header row.
    .skip(1)
    # Shuffle through a 256-record in-memory buffer.
    .shuffle(buffer_size=256)
    # Turn each CSV row into its JSON-formatted string.
    .map(csv_row_to_json)
    # Silently drop rows whose parsing raised an error (misformatted CSV rows).
    .ignore_errors()
    # Group records so training sees BATCH_SIZE of them at once.
    .batch(BATCH_SIZE)
    # Keep only the first NUM_BATCHES batches for training.
    .take(NUM_BATCHES)
)

---
**Helper functions**

In [None]:
def generate_text(model, input_text, max_length=200):
    """Generate text from a prompt, print it, and report the elapsed time.

    Args:
        model: Object exposing ``generate(input_text, max_length=...)``.
        input_text: Prompt forwarded unchanged to ``model.generate``.
        max_length: Forwarded as the ``max_length`` keyword of ``model.generate``.

    Returns:
        None. The generated output and the timing are printed to stdout.
    """
    started_at = time.time()

    generated = model.generate(input_text, max_length=max_length)
    print("\nOutput:", generated, sep="\n")

    # The timer stops after printing, so the reported figure covers both
    # generation and output.
    elapsed = time.time() - started_at
    print(f"Total Time Elapsed: {elapsed:.2f}s")
    

def get_optimizer_and_loss():
    """Build the optimizer and loss used to fine-tune the model.

    Returns:
        A ``(optimizer, loss)`` tuple: an AdamW optimizer with gradient
        clipping and weight decay (bias and layernorm variables excluded from
        the decay), and a sparse categorical cross-entropy loss that expects
        logits.
    """
    optimizer = keras.optimizers.AdamW(
        learning_rate=5e-5,
        weight_decay=0.01,
        epsilon=1e-6,
        global_clipnorm=1.0,  # Gradient clipping.
    )
    # Exclude layernorm and bias terms from weight decay.
    # All name patterns must be passed in a single call: each call to
    # `exclude_from_weight_decay` replaces the previously configured
    # exclusions, so calling it once per pattern would leave only the last
    # pattern ("beta") in effect.
    optimizer.exclude_from_weight_decay(var_names=["bias", "gamma", "beta"])

    loss = keras.losses.SparseCategoricalCrossentropy(from_logits=True)
    return optimizer, loss

---
**Load pre-trained model and enable LoRA**

In [None]:
# Preprocessor for the chosen GPT-2 preset, configured to produce sequences of
# SEQ_LEN tokens.
preprocessor = keras_nlp.models.GPT2CausalLMPreprocessor.from_preset(
    GPT2_PRESET,
    sequence_length=SEQ_LEN,
)

# Build the model inside the distribution strategy's scope.
with strategy.scope():
    # Load the pre-trained causal LM for the same preset and attach the
    # preprocessor so the model can be fed raw strings.
    model = keras_nlp.models.GPT2CausalLM.from_preset(
        GPT2_PRESET,
        preprocessor=preprocessor,
    )
    # Turn on LoRA with rank RANK on the backbone.
    # NOTE(review): presumably this leaves only the LoRA weights trainable -
    # confirm against the KerasNLP docs / the summary printed below.
    model.backbone.enable_lora(rank=RANK)
# Print the model architecture and parameter counts.
model.summary()

---
**Finetune the model**

In [None]:
class SaveLoRACheckpoint(keras.callbacks.Callback):
    """Keras callback that writes the backbone's LoRA weights after each epoch."""

    def on_epoch_end(self, epoch, logs=None):
        """Save the LoRA weights for ``epoch`` under ``transfer_learning/``.

        Args:
            epoch: Epoch index supplied by Keras; used (zero-padded to two
                digits) in the checkpoint file name.
            logs: Metrics dict supplied by Keras; unused here.
        """
        checkpoint_dir = "transfer_learning"
        checkpoint_path = os.path.join(
            checkpoint_dir, f"checkpoint_{epoch:02d}.lora.h5"
        )

        # Make sure the target directory exists before writing into it.
        os.makedirs(checkpoint_dir, exist_ok=True)

        # `end=""` keeps "Done" on the same line as the progress message.
        print(f"\n\nSaving checkpoint to {checkpoint_path}... ", end="")
        self.model.backbone.save_lora_weights(checkpoint_path)
        print("Done\n")

# Callback instance that saves the LoRA weights at the end of every epoch.
checkpoint_callback = SaveLoRACheckpoint()

# Create the optimizer/loss and compile inside the same distribution strategy
# scope the model was built in.
with strategy.scope():
    optimizer, loss = get_optimizer_and_loss()

    # Configure `model` for training; "accuracy" is tracked as a weighted metric.
    model.compile(
        optimizer=optimizer,
        loss=loss,
        weighted_metrics=["accuracy"],
    )

In [None]:
# Fine-tune the LoRA-enabled model. The model built and compiled earlier in
# this notebook is bound to `model`; no `lora_model` name is ever defined, so
# referring to it would raise NameError on a fresh kernel.
model.fit(
    dataset,
    epochs=EPOCHS,
    callbacks=[checkpoint_callback],  # saves the LoRA weights after each epoch
)

In [None]:
# Save the fine-tuned model. The model in this notebook is bound to `model`;
# no `lora_model` name is ever defined.
save_name = os.path.join("transfer_learning", "final_model.keras")
# Ensure the output directory exists even if no checkpoint callback has
# created it yet (e.g. when training ran for zero epochs).
os.makedirs(os.path.dirname(save_name), exist_ok=True)
model.save(save_name)