In [None]:
!pip install -q -U keras-nlp tensorflow-text
!pip install -q -U tensorflow-cpu

In [None]:
import jax
# Display the devices JAX can see, as a sanity check that the expected
# accelerators are attached before building the device mesh below.
jax.devices()

In [None]:
import os

# Both settings are plain environment variables:
#   KERAS_BACKEND                  - tells Keras to run on JAX; it is set here,
#                                    before the cell that imports keras.
#   XLA_PYTHON_CLIENT_MEM_FRACTION - fraction of device memory the XLA client
#                                    may pre-allocate ("1.0" = all of it).
# NOTE(review): an earlier cell already imported jax and called jax.devices();
# confirm the memory-fraction setting still takes effect when set this late.
os.environ.update(
    {
        "KERAS_BACKEND": "jax",
        "XLA_PYTHON_CLIENT_MEM_FRACTION": "1.0",
    }
)

In [None]:
import keras
import keras_nlp

In [None]:
# Build a (1, N) mesh over every attached device: a "batch" axis of size 1
# (no data parallelism) and a "model" axis that spans all N devices.
# N is taken from the runtime instead of being hard-coded to 8, so the
# notebook also runs on hosts exposing a different number of accelerators;
# with 8 devices the resulting mesh is exactly the original (1, 8).
devices = keras.distribution.list_devices()
device_mesh = keras.distribution.DeviceMesh(
    (1, len(devices)),
    ["batch", "model"],
    devices=devices,
)

In [None]:
model_dim = "model"

layout_map = keras.distribution.LayoutMap(device_mesh)

# Weight-path regex -> per-axis layout. An entry of `model_dim` shards that
# tensor axis across the "model" mesh axis; `None` leaves the axis unsharded.
weight_layouts = {
    "token_embedding/embeddings": (model_dim, None),
    "decoder_block.*attention.*(query|key|value)/kernel": (model_dim, None, None),
    "decoder_block.*attention_output/kernel": (model_dim, None, None),
    "decoder_block.*ffw_gating.*/kernel": (None, model_dim),
    "decoder_block.*ffw_linear/kernel": (model_dim, None),
}

# Register the rules in the order they are listed above.
for weight_pattern, weight_layout in weight_layouts.items():
    layout_map[weight_pattern] = weight_layout

In [None]:
# Model-parallel strategy driven by the layout map built above; "batch" names
# the mesh axis along which input batches are distributed.
model_parallel = keras.distribution.ModelParallel(
    layout_map=layout_map,
    batch_dim_name="batch",
)

# Order matters: the distribution is installed globally first, and only then
# is the Gemma 2 9B preset loaded.
keras.distribution.set_distribution(model_parallel)
gemma_lm = keras_nlp.models.GemmaCausalLM.from_preset("gemma2_9b_en")

In [None]:
# Sanity check of the sharding: print the type of one decoder block and, for
# each of its weights, the variable path, shape and sharding spec.
decoder_block_1 = gemma_lm.backbone.get_layer('decoder_block_1')
print(type(decoder_block_1))
for variable in decoder_block_1.weights:
    print(f'{variable.path:<48}  {str(variable.shape):<14}  {str(variable.value.sharding.spec)}')

In [None]:
import re
import json

In [None]:
def process_whatsapp_chat(file_path):
    """Convert a WhatsApp chat export into instruction/response training strings.

    A message line is expected to look like
    ``<d>/<d>/<yy>, <h>:<mm> am - <name>: <text>`` (the am/pm marker is matched
    case-insensitively). Lines that follow a message and do not start a new
    timestamped entry are treated as continuation lines of that message and
    are appended to it, separated by newlines, so multi-line messages are kept
    whole instead of being cut off after their first line.

    Pairing rules:
      * every message whose sender name contains "B" is queued as part of the
        pending instruction (prefixed with "S: ");
      * the first message whose sender name contains "A" (and not "B") that
        arrives while an instruction is pending closes the sample: the queued
        parts are joined with single spaces and the message becomes the
        response (prefixed with "A: ");
      * any other message, including further "A" messages with nothing
        pending, is ignored.

    Args:
        file_path: Path to the UTF-8 encoded chat export.

    Returns:
        A list of strings, one per instruction/response pair, in file order.
        The list is empty when no line matches the expected format.

    Raises:
        OSError: If the file cannot be opened.
        UnicodeDecodeError: If the file is not valid UTF-8.
    """
    # Compiled once up front. IGNORECASE only affects the am/pm marker, letting
    # "AM"/"PM" exports match as well as "am"/"pm".
    message_pattern = re.compile(
        r'(\d{1,2}/\d{1,2}/\d{2,4}), \d{1,2}:\d{2}\s?[apm]{2} - (.*?): (.*)',
        re.IGNORECASE,
    )
    # Timestamp prefix alone: identifies timestamped entries that have no
    # "name: text" part, so they are not mistaken for continuation lines.
    entry_pattern = re.compile(
        r'\d{1,2}/\d{1,2}/\d{2,4}, \d{1,2}:\d{2}\s?[apm]{2} - ',
        re.IGNORECASE,
    )

    # Pass 1: collect (sender name, [text lines]) for every message, folding
    # continuation lines into the message they belong to.
    messages = []
    current_lines = None
    with open(file_path, 'r', encoding='utf-8') as file:
        for line in file:
            match = message_pattern.match(line)
            if match:
                current_lines = [match.group(3)]
                messages.append((match.group(2), current_lines))
            elif entry_pattern.match(line):
                # A timestamped entry without a sender ends the open message.
                current_lines = None
            elif current_lines is not None:
                current_lines.append(line.rstrip("\r\n"))

    # Pass 2: pair queued "B" messages with the next "A" message.
    template = "Instruction:\n{instruction}\n\nResponse:\n{response}"
    processed_data = []
    current_instruction = []

    for name, text_lines in messages:
        # Blank lines trailing a message (e.g. at the end of the file) carry
        # no content; the first line is always kept as-is.
        while len(text_lines) > 1 and not text_lines[-1].strip():
            text_lines.pop()
        message = "\n".join(text_lines)

        if "B" in name:
            current_instruction.append(f"S: {message}")
        elif "A" in name and current_instruction:
            processed_data.append(
                template.format(
                    instruction=" ".join(current_instruction),
                    response=f"A: {message}",
                )
            )
            current_instruction = []

    return processed_data

chat_file_path = '/kaggle/working/ChatData.txt'
data = process_whatsapp_chat(chat_file_path)

# The parser skips every line it does not recognise, so an export in a
# different timestamp format yields an empty list. Fail here with a clear
# message rather than passing an empty dataset on to training.
if not data:
    raise ValueError(
        f"No instruction/response pairs were parsed from {chat_file_path!r}; "
        "check the export's timestamp format and the sender names."
    )

# Fine-tuning setup: enable rank-8 LoRA on the backbone, set the
# preprocessor's sequence length to 512, then compile with a from-logits
# sparse cross-entropy loss and Adam at a 5e-5 learning rate.
# NOTE(review): re-running this cell calls enable_lora() on the same model a
# second time — confirm whether that is tolerated or restart the kernel.
gemma_lm.backbone.enable_lora(rank=8)
gemma_lm.preprocessor.sequence_length = 512
gemma_lm.compile(
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    optimizer=keras.optimizers.Adam(learning_rate=5e-5),
    weighted_metrics=[keras.metrics.SparseCategoricalAccuracy()],
)

In [None]:
# Print the model summary after LoRA has been enabled and the model compiled.
gemma_lm.summary()

In [None]:
# Train on the parsed chat samples: a single epoch with a batch size of 4.
gemma_lm.fit(data, epochs=1, batch_size=4)

In [None]:
# Generate from a prompt that follows the Instruction/Response template and
# print the result (max_length=512).
# NOTE(review): the training samples prefix the instruction with "S: " and the
# response with "A: "; this prompt carries neither — confirm that is intended.
prompt = "Instruction:\nHow was your day?\n\nResponse:\n"
completion = gemma_lm.generate(prompt, max_length=512)
print(completion)