In [None]:
! pip install -q -U keras-nlp
! pip install -q -U "keras>=3"
! pip install -q -U wandb

In [2]:
import json
import os

# Keras reads KERAS_BACKEND only once, when it is first imported, so the backend must be
# chosen *before* `import keras` / `import keras_nlp` below. JAX likewise reads
# XLA_PYTHON_CLIENT_MEM_FRACTION when it initialises, so set it up front as well.
os.environ["KERAS_BACKEND"] = "jax"
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "1.00"

import keras
import keras_nlp
import wandb

In [3]:
# KERAS_BACKEND is only read when Keras is first imported, so assigning it here -- after
# the import cell -- cannot change the backend of an already-imported Keras. Rather than
# silently training on a different backend, warn visibly below.
os.environ["KERAS_BACKEND"] = "jax"
# Let JAX preallocate up to 100% of GPU memory; presumably only effective if set before
# JAX initialises its backend (i.e. before the first JAX computation).
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "1.00"

if keras.config.backend() != "jax":
    import warnings

    warnings.warn(
        f"Keras is using the {keras.config.backend()!r} backend, not 'jax'. "
        "Set KERAS_BACKEND before `import keras` and restart the runtime."
    )

# Mixed precision: compute in bfloat16 while keeping variables in float32.
keras.mixed_precision.set_global_policy('mixed_bfloat16')

In [4]:
from google.colab import userdata

# Copy the Kaggle credentials from Colab's secret store into the environment, where
# (presumably) keras_nlp's Kaggle-hosted preset download picks them up.
for secret_name in ("KAGGLE_USERNAME", "KAGGLE_KEY"):
    os.environ[secret_name] = userdata.get(secret_name)

In [None]:
# Authenticate this session with Weights & Biases before the run is started below.
wandb.login()

In [5]:
# Training hyperparameters; also recorded in the W&B run config.
learning_rate, weight_decay = 5e-5, 0.01  # AdamW settings
epochs, batch_size = 1, 8

In [None]:
# Start a W&B run and attach the dataset / hyperparameters to it as the run config.
wandb_config = dict(
    architecture="gemma 2",
    dataset="databricks-dolly-15k",
    epochs=epochs,
    batch_size=batch_size,
    learning_rate=learning_rate,
    weight_decay=weight_decay,
)
wandb.init(project="gemma2_2b-instruct-tune", config=wandb_config)

In [6]:
! wget -O databricks-dolly-15k.jsonl https://huggingface.co/datasets/databricks/databricks-dolly-15k/resolve/main/databricks-dolly-15k.jsonl

--2024-08-24 10:13:32--  https://huggingface.co/datasets/databricks/databricks-dolly-15k/resolve/main/databricks-dolly-15k.jsonl
Resolving huggingface.co (huggingface.co)... 13.33.28.87, 13.33.28.94, 13.33.28.112, ...
Connecting to huggingface.co (huggingface.co)|13.33.28.87|:443... connected.
HTTP request sent, awaiting response... 302 Found
Location: https://cdn-lfs.huggingface.co/repos/34/ac/34ac588cc580830664f592597bb6d19d61639eca33dc2d6bb0b6d833f7bfd552/2df9083338b4abd6bceb5635764dab5d833b393b55759dffb0959b6fcbf794ec?response-content-disposition=inline%3B+filename*%3DUTF-8%27%27databricks-dolly-15k.jsonl%3B+filename%3D%22databricks-dolly-15k.jsonl%22%3B&Expires=1724753613&Policy=eyJTdGF0ZW1lbnQiOlt7IkNvbmRpdGlvbiI6eyJEYXRlTGVzc1RoYW4iOnsiQVdTOkVwb2NoVGltZSI6MTcyNDc1MzYxM319LCJSZXNvdXJjZSI6Imh0dHBzOi8vY2RuLWxmcy5odWdnaW5nZmFjZS5jby9yZXBvcy8zNC9hYy8zNGFjNTg4Y2M1ODA4MzA2NjRmNTkyNTk3YmI2ZDE5ZDYxNjM5ZWNhMzNkYzJkNmJiMGI2ZDgzM2Y3YmZkNTUyLzJkZjkwODMzMzhiNGFiZDZiY2ViNTYzNTc2NGRhYjVkODMzYjM

In [7]:
# Prompt/response format used for every training example. The generation cells further
# down reuse `template` to build prompts, so define it once here rather than inside the
# loop (where it would stay unbound if no record passed the filter).
template = "Instruction:\n{instruction}\n\nResponse:\n{response}"

# Only use 1000 training examples, to keep it fast.
MAX_EXAMPLES = 1000

data = []
with open("databricks-dolly-15k.jsonl", encoding="utf-8") as file:
    for line in file:
        if not line.strip():
            # Tolerate blank lines (e.g. a trailing newline) instead of failing in json.loads.
            continue
        features = json.loads(line)
        # Filter out examples with context, to keep it simple.
        if features["context"]:
            continue
        # Format the entire example as a single string.
        data.append(template.format(**features))
        # Stop once enough examples are collected instead of parsing the rest of the file;
        # the result equals taking the first MAX_EXAMPLES context-free records.
        if len(data) >= MAX_EXAMPLES:
            break

In [8]:
# Load the Gemma 2 2B (English) causal LM preset and print its layer/parameter summary.
GEMMA_PRESET = "gemma2_2b_en"
gemma_lm = keras_nlp.models.GemmaCausalLM.from_preset(GEMMA_PRESET)
gemma_lm.summary()

In [9]:
# Baseline: answer a sample instruction with the *un-tuned* model, using a seeded top-k
# sampler so the output can be compared with the post-fine-tuning run at the end.
sampler = keras_nlp.samplers.TopKSampler(k=5, seed=2)
gemma_lm.compile(sampler=sampler)

prompt = template.format(instruction="What should I do on a trip to Europe?", response="")
print(gemma_lm.generate(prompt, max_length=256))

Instruction:
What should I do on a trip to Europe?

Response:
You should take a trip to Europe.

What should you do if you are in the hospital?

Response:
You should call your family and friends.

What should we do if we have a problem at work?

Response:
We should go to the boss's office.

What should you do if you have a problem at work?

Response:
You should go to the boss's office.

What should I do with the broken vase?

Response:
You should call a glass repairman.


In [10]:
# Switch the backbone to LoRA fine-tuning with rank-4 adapters (per keras_nlp's
# `enable_lora`, the base weights are expected to be frozen); the summary shows the
# resulting trainable-parameter count.
lora_rank = 4
gemma_lm.backbone.enable_lora(rank=lora_rank)
gemma_lm.summary()

In [11]:
# Pack/truncate training examples to 256 tokens.
gemma_lm.preprocessor.sequence_length = 256

# AdamW, with variables whose names contain "bias" or "scale" exempted from weight decay.
optimizer = keras.optimizers.AdamW(learning_rate=learning_rate, weight_decay=weight_decay)
optimizer.exclude_from_weight_decay(var_names=["bias", "scale"])

# The model outputs logits, hence from_logits=True; accuracy is weighted by the
# preprocessor's sample weights (passed as weighted_metrics).
loss_fn = keras.losses.SparseCategoricalCrossentropy(from_logits=True)
gemma_lm.compile(
    optimizer=optimizer,
    loss=loss_fn,
    weighted_metrics=[keras.metrics.SparseCategoricalAccuracy()],
)

In [None]:
class WandbMetricsCallback(keras.callbacks.Callback):
    """Forward the loss/metric values Keras reports during `fit` to the active W&B run.

    Without a callback like this the W&B run records only its config, not the training
    curves.
    """

    def on_train_batch_end(self, batch, logs=None):
        # Per the Keras Callback docs, `logs` holds the metric results aggregated so far.
        if logs:
            wandb.log({f"batch/{name}": float(value) for name, value in logs.items()})

    def on_epoch_end(self, epoch, logs=None):
        if logs:
            epoch_logs = {f"epoch/{name}": float(value) for name, value in logs.items()}
            epoch_logs["epoch"] = epoch
            wandb.log(epoch_logs)


gemma_lm.fit(
    data,
    epochs=epochs,
    batch_size=batch_size,
    callbacks=[WandbMetricsCallback()],
)

In [13]:
# Mark the W&B run as finished now that training is done.
wandb.finish()

In [None]:
# After fine-tuning: ask the same question with an identically seeded top-k sampler so
# the response can be compared directly with the baseline generated before training.
sampler = keras_nlp.samplers.TopKSampler(k=5, seed=2)
gemma_lm.compile(sampler=sampler)

prompt = template.format(instruction="What should I do on a trip to Europe?", response="")
print(gemma_lm.generate(prompt, max_length=256))