<a href="https://colab.research.google.com/github/frank-morales2020/MLxDL/blob/main/SFT_TPU_DEMO.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [2]:
# 2. INSTALL LIBRARIES
print("Installing/Updating libraries...")
!pip install -q --upgrade datasets huggingface_hub tensorflow keras-hub keras jax jaxlib


Installing/Updating libraries...


In [2]:
import logging
import os
import warnings

# Registered before `import jax` so the filter is already in place when JAX loads.
warnings.filterwarnings("ignore", category=UserWarning, module="jax._src.cloud_tpu_init")  # ← Kills the hugepages warning

# Runtime configuration through environment variables. These assignments are
# kept ahead of the framework imports below; KERAS_BACKEND in particular has to
# be set before `import keras` for the backend selection to take effect.
os.environ["KERAS_BACKEND"] = "jax"
os.environ["XLA_PYTHON_CLIENT_ALLOCATOR"] = "platform"
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.8"
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
os.environ["JAX_DISABLE_FLASH_ATTENTION"] = "1"   # Critical fix

# Framework imports (must stay below the environment setup above).
import jax
import keras
import keras_hub
import tensorflow as tf
from datasets import load_dataset

# Silence all remaining JAX/Abseil logs
logging.getLogger("absl").setLevel(logging.CRITICAL)
jax.config.update("jax_debug_nans", False)

# Accurate TPU Detection
def describe_tpu_kind(kind):
    """Map a JAX ``device_kind`` string to the label shown in the banner.

    Matching is case-insensitive and substring-based. Strings that do not
    match a known TPU generation are returned unchanged, so non-TPU devices
    are reported with whatever name JAX gave them.

    Args:
        kind: Device kind string, e.g. ``"TPU v6 lite"`` or ``"TPU v4"``.

    Returns:
        A human-readable label for the device generation.
    """
    kind_lc = kind.lower()
    if "v6" in kind_lc:
        # v6e ("v6 lite") is the generation code-named Trillium.
        return "TPU v6 lite"
    # v5e can be reported as either "TPU v5e" or "TPU v5 lite".
    if "v5e" in kind_lc or "v5 lite" in kind_lc:
        return "TPU v5e"
    if "v5p" in kind_lc:
        # Previously labelled "(Trillium)", which is the v6e code name, not v5p.
        return "TPU v5p"
    if "v4" in kind_lc:
        return "TPU v4"
    return kind


device = jax.devices()[0]
# Fall back to the upper-cased platform name when the device has no `device_kind`.
kind = device.device_kind if hasattr(device, 'device_kind') else device.platform.upper()
tpu_type = describe_tpu_kind(kind)

print("="*70)
print("TPU TYPE IS USED IT".center(70))
print(tpu_type.center(70))
print("="*70)
print(f"Devices: {jax.devices()} | Cores: {jax.device_count()}\n")


# Global precision (best for TPU)
# Both calls change process-wide Keras defaults, so they are made before the
# model is created further down.
# Default float dtype used by Keras when none is given explicitly.
keras.config.set_floatx("bfloat16")
# Global dtype policy applied to layers created after this point.
# NOTE(review): the model below is loaded with an explicit dtype="bfloat16";
# confirm how that interacts with this global policy.
keras.mixed_precision.set_global_policy("mixed_bfloat16")


# ----------------- DATA -----------------
print("\nLoading dataset...")
hf_dataset = load_dataset("databricks/databricks-dolly-15k", split="train")

# Demo-sized subset: the first 50 rows, each rendered as one
# "Instruction ... Response ..." training string.
tiny_data = [
    f"Instruction: {record['instruction']}\nResponse: {record['response']}"
    for record in hf_dataset.select(range(50))
]

# Batches of 2 strings (incomplete final batch dropped), prefetched automatically.
train_ds = (
    tf.data.Dataset.from_tensor_slices(tiny_data)
    .batch(2, drop_remainder=True)
    .prefetch(tf.data.AUTOTUNE)
)

# ----------------- MODEL -----------------
print("Loading Gemma 2B IT...")
gemma_lm = keras_hub.models.GemmaCausalLM.from_preset("hf://google/gemma-1.1-2b-it-keras", dtype="bfloat16")

# Turn on LoRA adapters on the backbone with rank 4.
gemma_lm.backbone.enable_lora(rank=4)

# Preprocessor sequence length used for the training strings.
gemma_lm.preprocessor.sequence_length = 256

# ----------------- COMPILE -----------------
# Build optimizer, loss and metric first, then hand them to compile().
sft_optimizer = keras.optimizers.AdamW(learning_rate=5e-5, weight_decay=0.01)
token_loss = keras.losses.SparseCategoricalCrossentropy(from_logits=True)
token_accuracy = keras.metrics.SparseCategoricalAccuracy()

gemma_lm.compile(
    optimizer=sft_optimizer,
    loss=token_loss,
    weighted_metrics=[token_accuracy],
)

# ----------------- TRAIN -----------------
rule = "=" * 60  # horizontal rule framing the status messages

print(f"\n{rule}")
print("STARTING CLEAN & SILENT TRAINING ON TPU")
print(rule)

# One pass over the demo dataset with the default progress bar.
gemma_lm.fit(train_ds, epochs=1, verbose=1)

print(rule)
print("TRAINING COMPLETED SUCCESSFULLY – 100% CLEAN LOGS!")
print(rule)

# Optional: Quick test
print("\nQuick generation test:")
prompt = "Question: What is the capital of France?\nAnswer:"
# Generate with max_length=50 and print the returned text.
completion = gemma_lm.generate(prompt, max_length=50)
print(completion)

                         TPU TYPE IS USED IT                          
                             TPU v6 lite                              
Devices: [TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0)] | Cores: 1


Loading dataset...
Loading Gemma 2B IT...

STARTING CLEAN & SILENT TRAINING ON TPU
[1m25/25[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m36s[0m 13ms/step - loss: 1.5000 - sparse_categorical_accuracy: 0.4414
TRAINING COMPLETED SUCCESSFULLY – 100% CLEAN LOGS!

Quick generation test:
Question: What is the capital of France?
Answer: Paris.

The answer is correct. Paris is the capital of France. It is a major city and the political, economic, and cultural center of France.
