In [1]:
# !pip install tensorflow_datasets
# !pip install ipywidgets

In [2]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [3]:
import tensorflow_datasets as tfds
import tensorflow as tf
import numpy as np
import keras

2024-06-04 18:24:46.991232: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2024-06-04 18:24:47.021265: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 AVX512F AVX512_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [4]:
# Record the TensorFlow / Keras versions this notebook was executed with.
tf.__version__, keras.__version__

('2.16.1', '3.3.3')

In [12]:
# Scratch check: with method="nearest", np.percentile picks an actual element
# of the data instead of interpolating between neighbours.
values = [-1, 0, -1, 0, -1, 0, -1, 0, -1, 1000.0]
np.percentile(a=values, q=10, method="nearest")

-1.0

In [4]:
# import kagglehub
# kagglehub.login()

In [5]:
from keras_llm_light import gemma_2b

# LoRA hyper-parameters. The same constants are passed again to
# build_gemma_decoder_block further down, so both builds share one config.
# NOTE(review): `LORA_APLHA` is a misspelling of "ALPHA"; renaming it requires
# updating every cell that references it.
LORA_RANK = 4
LORA_APLHA = 32.0
# Sequence length assigned to the preprocessor after the model is built.
MAX_SEQ_LENGTH = 512


gemma_llm = gemma_2b.build_gemma_llm_cpu(
    lora_rank=LORA_RANK,
    lora_alpha=LORA_APLHA,
)
gemma_llm.preprocessor.sequence_length = MAX_SEQ_LENGTH

In [6]:
# Switch Keras to the "mixed_float16" global policy before building the
# block-wise pieces, then split out the preprocessing op and the token
# embedding layer from gemma_llm.
# NOTE(review): gemma_llm was built in the previous cell, before this policy
# was set, so presumably it keeps the default policy — confirm that ordering
# is intended.
keras.mixed_precision.set_global_policy("mixed_float16")
preprocessing_op, token_embedding_layer = gemma_2b.build_gemma_preprocessing_fn(gemma_llm)

In [7]:
# Show the summary of the preprocessing op (assumed to be a Keras model).
preprocessing_op.summary()

In [8]:
# Build the reusable decoder block with the same LoRA settings as gemma_llm,
# then display it.
transformer_block = gemma_2b.build_gemma_decoder_block(
    gemma_llm=gemma_llm,
    lora_rank=LORA_RANK,
    lora_alpha=LORA_APLHA,
)
transformer_block

In [9]:
# Build the postprocessing op that runs after the last decoder block, then
# display it.
postprocessing_op = gemma_2b.build_gemma_postprocessing_fn(gemma_llm)
postprocessing_op

In [10]:
# Assemble the block-wise model from the embedding layer, the preprocessing
# op, the shared decoder block and the postprocessing op.
model = gemma_2b.build_block_wise_gemma(
    gemma_llm=gemma_llm,
    token_embeddings=token_embedding_layer,
    preprocessing_fn=preprocessing_op,
    transformer_block=transformer_block,
    postprocessing_fn=postprocessing_op,
)
model

In [11]:
# Report how much memory the per-block weights occupy.
model.block_model.print_weights_memory_usage()

# Train model

In [12]:
def prepare_text(document):
    """Render one QA record as a single training string.

    Args:
        document: Mapping with the keys ``'Problem'``, ``'options'``,
            ``'correct'`` and ``'correct_option'``; presumably scalar string
            tensors, since it is applied per element before batching.

    Returns:
        A string tensor laid out as::

            <Problem>
            Options: <options>
            Answer: <correct> ) <correct_option>
    """
    # Read the fields in record order so a missing key surfaces at the first
    # absent field ('Problem' first).
    problem = document['Problem']
    options = document['options']
    answer_line = tf.strings.join(
        [document['correct'], " ) ", document['correct_option']]
    )
    return tf.strings.join(
        [problem, "\nOptions: ", options, "\nAnswer: ", answer_line]
    )

In [13]:
# Raw dataset. `prepare_text` expects QA-style records with the keys below.
# NOTE(review): the TFDS `imdb_reviews` dataset only exposes 'text' and
# 'label', so `ds.map(prepare_text)` cannot succeed with it. Set DATASET_NAME
# to the intended QA dataset.
DATASET_NAME = "imdb_reviews"
PREPARE_TEXT_FEATURES = ("Problem", "options", "correct", "correct_option")

ds = tfds.load(DATASET_NAME, split="train")

# Fail fast with an actionable message instead of a KeyError raised while
# tf.data traces `prepare_text`.
missing_features = [key for key in PREPARE_TEXT_FEATURES if key not in ds.element_spec]
if missing_features:
    raise ValueError(
        f"Dataset {DATASET_NAME!r} is missing features required by "
        f"prepare_text: {missing_features}. "
        f"Available features: {sorted(ds.element_spec)}"
    )

BATCH_SIZE = 1

# Format each record, repeat indefinitely, shuffle, batch and prefetch.
train_ds = (
    ds.map(prepare_text)
    .repeat(-1)
    .shuffle(2000)
    .batch(BATCH_SIZE)
    .prefetch(tf.data.AUTOTUNE)
)

In [14]:
# Tokenize each batch of formatted strings with the Gemma preprocessor. Each
# element is consumed later as a (features, labels, sample_weight) tuple.
train_ds_processed = (
    train_ds
    .map(gemma_llm.preprocessor, num_parallel_calls=tf.data.AUTOTUNE)
    .prefetch(tf.data.AUTOTUNE)
)

# A single shared iterator, so successive cells pull successive batches.
train_ds_processed_iter = iter(train_ds_processed)

In [15]:
# Peek at one formatted, batched training example. next() raises immediately
# on an empty dataset, whereas a for/break would leave `sample` unbound or
# stale from an earlier cell.
sample = next(iter(train_ds))

sample

In [16]:
# Split the first example of the batch at the answer marker: `query` is the
# prompt including the marker, and parts[1] is the reference answer.
# NOTE: parts[1] raises IndexError if the marker is absent, and it stops at a
# second marker if one occurs.
split_key = "\nAnswer: "
parts = sample[0].numpy().decode("utf-8").split(split_key)
query = parts[0] + split_key
print(query)
print(parts[1])

In [30]:
# Smoke-test free-form generation. The result is a string tensor, decoded
# here for printing. max_length=128 is forwarded to model.generate.
# NOTE(review): the prompt contains typos ("Is is", "JAsne", "Pelne") —
# confirm they are intentional.
response_text = model.generate("Is is Żywiec Jasne or JAsne Pelne?", max_length=128)
print(response_text.numpy().decode("utf-8"))

In [27]:
# Re-print whatever `response_text` currently holds. This cell does not
# generate anything itself.
print(response_text.numpy().decode("utf-8"))

In [None]:
from tqdm import tqdm
# Pull one preprocessed batch to inspect shapes and to drive the manual
# forward/backward cells below. This advances the shared iterator by one batch.
documents = next(train_ds_processed_iter)

features, labels, sample_weight = documents
token_id_input = features["token_ids"]
padding_mask = features["padding_mask"]
padding_mask.shape

In [None]:
# Stand-alone forward pass over the sampled batch; the second return value is
# discarded. `outputs` is recomputed in the training-step cell below.
outputs, _ = model.forward(token_id_input, padding_mask)

In [None]:
from keras_llm_light import blocks_ops

# Optimizer handed to the loss op, which applies gradients itself
# (see model_loss_op.apply_gradients in the next cell).
optimizer = keras.optimizers.Adam(learning_rate=0.0001)

# Block-wise softmax loss. It receives the first token-embedding weight
# (presumably the embedding matrix reused for the output logits — TODO
# confirm), the postprocessing fn, the per-block weights and the optimizer.
# `BockWiseSoftmaxLoss` is spelled as exported by keras_llm_light.blocks_ops.
model_loss_op = blocks_ops.BockWiseSoftmaxLoss(
    model.token_embeddings.weights[0].value,
    postprocessing_op=model.postprocessing_fn,
    blocks_weights=model.block_model.blocks_weights,
    optimizer=optimizer,
)

In [None]:
from itertools import chain

# One manual training step, split into stages so each can be inspected.
# 1) Forward pass through the block-wise model.
outputs, _ = model.forward(token_id_input, padding_mask)

# 2) Loss, accuracy and `initial_gradients`, which seed the backward pass.
#    num_splits=2 presumably chunks the computation to bound memory — TODO confirm.
loss, accuracy_value, initial_gradients = model_loss_op.train_loss_step(
    outputs, labels, sample_weight, num_splits=2
)
# 3) Backward pass. It returns a sequence of gradient lists (presumably one
#    per block), flattened below.
vars_gradients = model.backward(padding_mask, initial_gradients)


# 4) Flatten the gradient lists and apply them in a single optimizer step.
model_loss_op.apply_gradients(list(chain.from_iterable(vars_gradients)))
loss, accuracy_value, initial_gradients.shape, len(vars_gradients), len(vars_gradients[0])

In [None]:
# Full training run: 1 epoch of 100 steps. The returned `metrics` is turned
# into a DataFrame with "loss" and "accuracy" columns by the plotting cell.
metrics = model_loss_op.fit(model, train_ds_processed, epochs=1, steps_per_epoch=100)

In [None]:
import pandas as pd
import matplotlib.pyplot as plt
# Training curves from `model_loss_op.fit`. The x-axis is the metrics record
# index (assumed to be one record per training step — TODO confirm).
metrics_df = pd.DataFrame(metrics)

fig, (loss_ax, acc_ax) = plt.subplots(1, 2, figsize=(10, 5))

loss_ax.plot(metrics_df["loss"], label="loss")
loss_ax.set(title="Training loss", xlabel="step", ylabel="loss")
loss_ax.legend()  # without this the `label=` above is never shown

acc_ax.plot(metrics_df["accuracy"], label="accuracy")
acc_ax.set(title="Training accuracy", xlabel="step", ylabel="accuracy")
acc_ax.legend()

fig.tight_layout()
plt.show()

In [None]:
# Draw a fresh formatted example after training. next() raises on an empty
# dataset instead of silently keeping the `sample` from an earlier cell.
sample = next(iter(train_ds))

sample

In [None]:
# Split the first example of the batch at the answer marker: `query` is the
# prompt including the marker, and parts[1] is the reference answer.
# NOTE: parts[1] raises IndexError if the marker is absent, and it stops at a
# second marker if one occurs.
split_key = "\nAnswer: "
parts = sample[0].numpy().decode("utf-8").split(split_key)
query = parts[0] + split_key
print(query)
print(parts[1])

In [None]:
# Re-run the generation smoke test after training, for comparison with the
# pre-training output. max_length=128 is forwarded to model.generate.
# NOTE(review): the prompt contains typos ("Is is", "JAsne", "Pelne") —
# confirm they are intentional.
response_text = model.generate("Is is Żywiec Jasne or JAsne Pelne?", max_length=128)
print(response_text.numpy().decode("utf-8"))