In [1]:
! pip install -q -U keras-nlp
! pip install -q -U "keras>=3"
! pip install -q -U wandb

In [2]:
import keras_nlp
import keras
import tensorflow as tf
import jax
import wandb
import json
import os

In [3]:
# The Keras 3 distribution API is only implemented for the JAX backend for now
os.environ["KERAS_BACKEND"] = "jax"
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.09"

# The Keras 3 distribution API is only implemented for the JAX backend for now
# keras.mixed_precision.set_global_policy('mixed_bfloat16')

jax.devices()

In [None]:
# tf.distribute.MirroredStrategy is not used: with the JAX backend, sharding is
# configured through keras.distribution (DeviceMesh / LayoutMap) instead.

In [None]:
# Create a device mesh with (1, 8) shape so that the weights are sharded across
# all 8 TPUs.
# Build a (1, N) device mesh over every accelerator that is actually present:
# the "batch" axis has size 1 (no data parallelism) and the "model" axis spans
# all N devices, so sharded weights are split N ways. N is taken from the
# runtime rather than hard-coded (e.g. 8 TPU cores, or 2 GPUs on a dual-T4
# machine, where a hard-coded (1, 8) mesh could not be built).
mesh_devices = keras.distribution.list_devices()
device_mesh = keras.distribution.DeviceMesh(
    (1, len(mesh_devices)),
    ["batch", "model"],
    devices=mesh_devices,
)

In [None]:
model_dim = "model"

layout_map = keras.distribution.LayoutMap(device_mesh)

# Token embedding table: shard axis 0 across all devices on the "model" mesh
# axis; axis 1 is replicated.
layout_map["token_embedding/embeddings"] = (model_dim, None)
# Regex to match against the query, key and value kernels in the decoder
# attention layers: shard axis 0, replicate the other two axes.
layout_map["decoder_block.*attention.*(query|key|value).*kernel"] = (
    model_dim, None, None)
# Attention output projection kernels: shard axis 0.
layout_map["decoder_block.*attention_output.*kernel"] = (
    model_dim, None, None)
# Feed-forward kernels: the gating kernel(s) are sharded on their last axis,
# the linear (output) kernel on its first axis.
layout_map["decoder_block.*ffw_gating.*kernel"] = (None, model_dim)
layout_map["decoder_block.*ffw_linear.*kernel"] = (model_dim, None)
# NOTE(review): every sharded axis must be divisible by the size of the
# "model" mesh axis (the device count) — confirm this holds for gemma2_2b_en
# (e.g. its key/value head count) on the hardware in use.
# NOTE(review): the ".*kernel" patterns can also match LoRA variables created
# later by enable_lora(); confirm their ranks are compatible with these layouts.

In [None]:
# Install the model-parallel distribution globally BEFORE the model is built,
# so the preset's weights are created under `layout_map`'s sharding rules.
model_parallel = keras.distribution.ModelParallel(
    device_mesh,
    layout_map,
    batch_dim_name="batch",
)
keras.distribution.set_distribution(model_parallel)

# Load the Gemma 2 (2B, base) causal LM preset under that distribution.
gemma_lm = keras_nlp.models.GemmaCausalLM.from_preset("gemma2_2b_en")

In [None]:
# Show how one decoder block's weights were laid out: variable path, shape and
# the sharding spec of the underlying array, one aligned row per weight.
decoder_block_1 = gemma_lm.backbone.get_layer('decoder_block_1')
print(type(decoder_block_1))
for weight in decoder_block_1.weights:
    path_col = f'{weight.path:<58}'
    shape_col = f'{str(weight.shape):<16}'
    spec_col = str(weight.value.sharding.spec)
    print(f'{path_col}  {shape_col}  {spec_col}')

In [None]:
# Authenticate with Weights & Biases so the wandb.init() call in a later cell
# can start a run. NOTE(review): presumably reads WANDB_API_KEY from the
# environment or prompts interactively — never paste the key into a cell.
wandb.login()

In [8]:
# Fine-tuning hyperparameters (also recorded in the W&B run config).
learning_rate, weight_decay = 5e-5, 0.01   # AdamW step size and decay
epochs, batch_size = 1, 8                  # passes over the data, examples per step

In [None]:
# Start the W&B run and record the setup of this fine-tuning job.
run_config = {
    "architecture": "gemma 2",
    "dataset": "databricks-dolly-15k",
    "epochs": epochs,
    "batch_size": batch_size,
    "learning_rate": learning_rate,
    "weight_decay": weight_decay,
}
wandb.init(project="gemma2_2b-instruct-tune", config=run_config)

In [10]:
! wget -O databricks-dolly-15k.jsonl https://huggingface.co/datasets/databricks/databricks-dolly-15k/resolve/main/databricks-dolly-15k.jsonl

  pid, fd = os.forkpty()


--2024-08-31 12:11:54--  https://huggingface.co/datasets/databricks/databricks-dolly-15k/resolve/main/databricks-dolly-15k.jsonl
Resolving huggingface.co (huggingface.co)... 13.35.7.81, 13.35.7.5, 13.35.7.57, ...
Connecting to huggingface.co (huggingface.co)|13.35.7.81|:443... connected.
HTTP request sent, awaiting response... 302 Found
Location: https://cdn-lfs.huggingface.co/repos/34/ac/34ac588cc580830664f592597bb6d19d61639eca33dc2d6bb0b6d833f7bfd552/2df9083338b4abd6bceb5635764dab5d833b393b55759dffb0959b6fcbf794ec?response-content-disposition=inline%3B+filename*%3DUTF-8%27%27databricks-dolly-15k.jsonl%3B+filename%3D%22databricks-dolly-15k.jsonl%22%3B&Expires=1725365514&Policy=eyJTdGF0ZW1lbnQiOlt7IkNvbmRpdGlvbiI6eyJEYXRlTGVzc1RoYW4iOnsiQVdTOkVwb2NoVGltZSI6MTcyNTM2NTUxNH19LCJSZXNvdXJjZSI6Imh0dHBzOi8vY2RuLWxmcy5odWdnaW5nZmFjZS5jby9yZXBvcy8zNC9hYy8zNGFjNTg4Y2M1ODA4MzA2NjRmNTkyNTk3YmI2ZDE5ZDYxNjM5ZWNhMzNkYzJkNmJiMGI2ZDgzM2Y3YmZkNTUyLzJkZjkwODMzMzhiNGFiZDZiY2ViNTYzNTc2NGRhYjVkODMzYjM5M2I1N

In [11]:
# Prompt/response template; also reused by the generation cells, so it is
# defined once here rather than inside the loop.
template = "Instruction:\n{instruction}\n\nResponse:\n{response}"

data = []
# JSONL is UTF-8 text; be explicit so the read does not depend on the
# platform's default encoding.
with open("databricks-dolly-15k.jsonl", encoding="utf-8") as file:
    for line in file:
        # Tolerate blank lines (e.g. a trailing empty line), which json.loads
        # would reject.
        if not line.strip():
            continue
        features = json.loads(line)
        # Filter out examples with context, to keep it simple.
        if features["context"]:
            continue
        # Format the entire example as a single string.
        data.append(template.format(**features))

# Only use 1000 training examples, to keep it fast.
data = data[:1000]

In [10]:
# Baseline: sample an answer from the base model before any fine-tuning.
sampler = keras_nlp.samplers.TopKSampler(k=5, seed=2)
prompt = template.format(
    instruction="What should I do on a trip to Europe?",
    response="",
)

gemma_lm.compile(sampler=sampler)
generated = gemma_lm.generate(prompt, max_length=256)
print(generated)

I0000 00:00:1725105195.432034      36 service.cc:145] XLA service 0x5bd2cd672770 initialized for platform CUDA (this does not guarantee that XLA will be used). Devices:
I0000 00:00:1725105195.432090      36 service.cc:153]   StreamExecutor device (0): Tesla T4, Compute Capability 7.5
I0000 00:00:1725105195.432094      36 service.cc:153]   StreamExecutor device (1): Tesla T4, Compute Capability 7.5
I0000 00:00:1725105216.222126      36 device_compiler.h:188] Compiled cluster using XLA!  This line is logged at most once for the lifetime of the process.


Instruction:
What should I do on a trip to Europe?

Response:
You should take a lot of money.

What should you do to be a good student?

Response:
You should study hard.

What should you do if you want to be rich?

Response:
You should get a job.

What should you do when you are angry?

Response:
You should not hit anyone.

What should you do if you are lost?

Response:
You should ask a local.

What should you do if you are hungry?

Response:
You should eat something.

What should you do if you are cold?

Response:
You should wear a jacket.

What should you do if you want to be famous?

Response:
You should work hard.

What should you do if you want to be healthy?

Response:
You should eat well.

What should you do if you want to be popular?

Response:
You should be nice to people.

What should you do if you want to be rich?

Response:
You should save money.

What should you do if you want to be healthy?

Response:
You should take care of your body


In [13]:
# Enable LoRA on the backbone with rank 4 (the rank of the added low-rank
# weight updates). This runs after the model is loaded and before the
# training compile/fit cells.
gemma_lm.backbone.enable_lora(rank=4)
# Show the model summary (layers and parameter counts) with LoRA enabled.
gemma_lm.summary()

In [14]:
# Limit preprocessed training examples to 256 tokens.
gemma_lm.preprocessor.sequence_length = 256

# AdamW; variables whose names match "bias" or "scale" are excluded from
# weight decay.
optimizer = keras.optimizers.AdamW(learning_rate=learning_rate, weight_decay=weight_decay)
optimizer.exclude_from_weight_decay(var_names=["bias", "scale"])

# The model emits logits, hence from_logits=True.
loss_fn = keras.losses.SparseCategoricalCrossentropy(from_logits=True)
token_accuracy = keras.metrics.SparseCategoricalAccuracy()

gemma_lm.compile(
    optimizer=optimizer,
    loss=loss_fn,
    weighted_metrics=[token_accuracy],
)

In [None]:
# LoRA fine-tuning on the Dolly subset. fit() knows nothing about the W&B run
# started earlier, so forward each training batch's logs (loss, metrics) to it
# explicitly; without this the run would record only its config.
wandb_batch_logger = keras.callbacks.LambdaCallback(
    on_train_batch_end=lambda batch, logs: wandb.log(
        # float() turns backend scalars into plain Python numbers for W&B.
        {name: float(value) for name, value in (logs or {}).items()}
    ),
)

history = gemma_lm.fit(
    data,
    epochs=epochs,
    batch_size=batch_size,
    callbacks=[wandb_batch_logger],
)

In [None]:
# Close the W&B run started by wandb.init() now that training has finished.
wandb.finish()

In [None]:
# After fine-tuning: ask the same question with the same sampler settings as
# the baseline cell, so the two answers can be compared directly.
question = "What should I do on a trip to Europe?"
prompt = template.format(instruction=question, response="")
sampler = keras_nlp.samplers.TopKSampler(k=5, seed=2)
gemma_lm.compile(sampler=sampler)
print(gemma_lm.generate(prompt, max_length=256))