<a href="https://colab.research.google.com/github/frank-morales2020/MLxDL/blob/main/LLM_Fine_Tuning_Demo_on_TPU_with_JAX_and_Flax.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install flax optax tensorflow tensorflow_datasets jax jaxlib -q

In [1]:
!pip show jax
print('\n')
!pip show flax
print('\n')
!pip show optax
print('\n')
!pip show tensorflow
print('\n')
!pip show tensorflow_datasets
print('\n')
!pip show optax
print('\n')
!pip show tensorflow
print('\n')
!pip show tensorflow_datasets
print('\n')

Name: jax
Version: 0.5.3
Summary: Differentiate, compile, and transform Numpy code.
Home-page: https://github.com/jax-ml/jax
Author: JAX team
Author-email: jax-dev@google.com
License: Apache-2.0
Location: /usr/local/lib/python3.11/dist-packages
Requires: jaxlib, ml_dtypes, numpy, opt_einsum, scipy
Required-by: chex, distrax, flax, optax, orbax-checkpoint


Name: flax
Version: 0.10.6
Summary: Flax: A neural network library for JAX designed for flexibility
Home-page: https://github.com/google/flax
Author: 
Author-email: Flax team <flax-dev@google.com>
License: 
Location: /usr/local/lib/python3.11/dist-packages
Requires: jax, msgpack, numpy, optax, orbax-checkpoint, PyYAML, rich, tensorstore, treescope, typing_extensions
Required-by: 


Name: optax
Version: 0.2.5
Summary: A gradient processing and optimization library in JAX.
Home-page: https://github.com/google-deepmind/optax
Author: 
Author-email: Google DeepMind <optax-dev@google.com>
License: 
Location: /usr/local/lib/python3.11/dist-

In [None]:
!pip install datasets -q

In [3]:
!pip show datasets

Name: datasets
Version: 4.0.0
Summary: HuggingFace community-driven open-source library of datasets
Home-page: https://github.com/huggingface/datasets
Author: HuggingFace Inc.
Author-email: thomas@huggingface.co
License: Apache 2.0
Location: /usr/local/lib/python3.11/dist-packages
Requires: dill, filelock, fsspec, huggingface-hub, multiprocess, numpy, packaging, pandas, pyarrow, pyyaml, requests, tqdm, xxhash
Required-by: 


In [None]:
!pip uninstall -y tensorflow
!pip install tensorflow-cpu -q

In [6]:
print('\n')
!pip show tensorflow-cpu



Name: tensorflow_cpu
Version: 2.20.0
Summary: TensorFlow is an open source machine learning framework for everyone.
Home-page: https://www.tensorflow.org/
Author: Google Inc.
Author-email: packages@tensorflow.org
License: Apache 2.0
Location: /usr/local/lib/python3.11/dist-packages
Requires: absl-py, astunparse, flatbuffers, gast, google_pasta, grpcio, h5py, keras, libclang, ml_dtypes, numpy, opt_einsum, packaging, protobuf, requests, setuptools, six, tensorboard, termcolor, typing_extensions, wrapt
Required-by: 


In [7]:
import jax
from jax import numpy as jnp
from jax.sharding import Mesh, PartitionSpec as P, NamedSharding
from jax import grad, jit, random
from functools import partial
import optax
import time

from transformers import AutoTokenizer, FlaxAutoModelForCausalLM
from datasets import load_dataset
from flax.training import train_state

import warnings
warnings.filterwarnings('ignore')

# 1. TPU Initialization (JAX style)
try:
    devices = jax.devices()
    tpu_devices = [d for d in devices if d.platform == 'tpu']
    if not tpu_devices:
        raise ValueError("No TPU devices found.")
    print(f"Found JAX devices: {devices}")
    print(f"Number of TPU devices available: {len(tpu_devices)}")
except ValueError as e:
    print(f"ERROR: {e}. Please ensure your Colab runtime is set to TPU.")
    raise SystemExit('TPU not found or not initialized for JAX.')

# Define a mesh for sharding across TPU cores
num_tpu_cores = len(tpu_devices)
mesh = Mesh(tpu_devices, axis_names=('data',))
print(f"JAX Mesh created with axis_names: {mesh.axis_names}")

Found JAX devices: [TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0), TpuDevice(id=1, process_index=0, coords=(0,0,0), core_on_chip=1), TpuDevice(id=2, process_index=0, coords=(1,0,0), core_on_chip=0), TpuDevice(id=3, process_index=0, coords=(1,0,0), core_on_chip=1), TpuDevice(id=4, process_index=0, coords=(0,1,0), core_on_chip=0), TpuDevice(id=5, process_index=0, coords=(0,1,0), core_on_chip=1), TpuDevice(id=6, process_index=0, coords=(1,1,0), core_on_chip=0), TpuDevice(id=7, process_index=0, coords=(1,1,0), core_on_chip=1)]
Number of TPU devices available: 8
JAX Mesh created with axis_names: ('data',)


In [None]:
# 2. Load JAX-based LLM and Tokenizer
print('\nLoading LLM and tokenizer...')
model_id = 'EleutherAI/gpt-neo-125M'
tokenizer = AutoTokenizer.from_pretrained(model_id)
tokenizer.pad_token = tokenizer.eos_token
model = FlaxAutoModelForCausalLM.from_pretrained(model_id)
print('Model loaded.')

# 3. Prepare the Dataset
print('\nLoading and preprocessing dataset...')
dataset = load_dataset('databricks/databricks-dolly-15k', split='train').train_test_split(test_size=0.1)
train_dataset = dataset['train']
eval_dataset = dataset['test']

def preprocess_function(examples):
    return tokenizer(examples['response'], max_length=128, truncation=True, padding='max_length')

train_tokenized_dataset = train_dataset.map(preprocess_function, batched=True)
eval_tokenized_dataset = eval_dataset.map(preprocess_function, batched=True)

train_tokenized_dataset.set_format(type='jax', columns=['input_ids', 'attention_mask'])
eval_tokenized_dataset.set_format(type='jax', columns=['input_ids', 'attention_mask'])
print('Datasets preprocessed and formatted for JAX.')

# 4. Define the Training and Evaluation Steps
class TrainState(train_state.TrainState):
    pass

learning_rate = 1e-5
optimizer = optax.adamw(learning_rate=learning_rate)

@jit
def train_step(state, batch):
    def loss_fn(params):
        variables = {'params': params}
        logits = state.apply_fn(variables, **batch)[0]
        labels = batch['input_ids']
        one_hot_labels = jax.nn.one_hot(labels, num_classes=logits.shape[-1])
        loss = -jnp.sum(one_hot_labels * jax.nn.log_softmax(logits), axis=-1)
        return jnp.mean(loss)

    grad_fn = grad(loss_fn)
    grads = grad_fn(state.params)
    return state.apply_gradients(grads=grads)

@jit
def eval_step(params, batch):
    variables = {'params': params}
    logits = model.module.apply(variables, **batch)[0]
    labels = batch['input_ids']
    one_hot_labels = jax.nn.one_hot(labels, num_classes=logits.shape[-1])
    loss = -jnp.sum(one_hot_labels * jax.nn.log_softmax(logits), axis=-1)
    return jnp.mean(loss)

# 5. Initialize Model State and Shard Parameters
print('\nInitializing and sharding model state...')
params = model.params

The training process has run into a XlaRuntimeError with the message RESOURCE_EXHAUSTED. This means that the TPU has run out of memory, so it can't load and run the program.

* The Problem

The error message specifically says, "Attempting to reserve 4.24G at the bottom of memory. That was not possible. There are 4.21G free...". This indicates that the batch size and model size combined are too large for the TPU's memory. The TPU has a finite amount of memory, and your current configuration requires more than is available.

* The Solution

To resolve this, you need to reduce the amount of memory consumed during each training step. Here are the most effective ways to do that:

1. Decrease the Global Batch Size: The global batch size is the total number of examples processed by all TPU cores in a single step. Reducing this will directly decrease the memory required for the input data and the gradients. In your code, you can change global_batch_size = 128 to a smaller number, like 64 or 32.

2. Use a Smaller Model: If reducing the batch size isn't enough, consider using a model with fewer parameters, as model weights are a significant portion of memory usage.

3. Use bfloat16 Precision: While your code already uses this implicitly, ensuring you are using bfloat16 precision is crucial. TPUs are highly optimized for this data type, which uses half the memory of standard 32-bit floats.



In [9]:
replicated_params = jax.device_put(params, NamedSharding(mesh, P()))
state = TrainState.create(apply_fn=model.module.apply, params=replicated_params, tx=optimizer)
print('Model state sharded.')

num_epochs = 10
# 6. Training and Evaluation Loop
print(f'\nStarting training for {num_epochs} epochs...')
global_batch_size = 64
per_device_batch_size = global_batch_size // num_tpu_cores

start_time = time.time()
for epoch in range(num_epochs):
    epoch_start_time = time.time()
    total_train_loss = 0.0
    num_train_batches = 0

    for batch in train_tokenized_dataset.iter(batch_size=global_batch_size):
        input_shape = batch['input_ids'].shape
        position_ids = jnp.broadcast_to(jnp.arange(input_shape[-1])[None, :], input_shape)
        batch['position_ids'] = position_ids

        state = train_step(state, batch)

        batch_loss = eval_step(state.params, batch)
        total_train_loss += batch_loss.item()
        num_train_batches += 1

    avg_train_loss = total_train_loss / num_train_batches

    total_eval_loss = 0.0
    num_eval_batches = 0

    for batch in eval_tokenized_dataset.iter(batch_size=global_batch_size):
        input_shape = batch['input_ids'].shape
        position_ids = jnp.broadcast_to(jnp.arange(input_shape[-1])[None, :], input_shape)
        batch['position_ids'] = position_ids

        batch_eval_loss = eval_step(state.params, batch)
        total_eval_loss += batch_eval_loss.item()
        num_eval_batches += 1

    avg_eval_loss = total_eval_loss / num_eval_batches

    epoch_end_time = time.time()
    print(f"Epoch {epoch+1}: Train Loss = {avg_train_loss:.4f}, Eval Loss = {avg_eval_loss:.4f} (Time: {epoch_end_time - epoch_start_time:.2f}s)")

end_time = time.time()
print(f'\nTraining complete in {end_time - start_time:.2f} seconds.')

Model state sharded.

Starting training for 10 epochs...
Epoch 1: Train Loss = 2.4240, Eval Loss = 0.1080 (Time: 152.73s)
Epoch 2: Train Loss = 0.0742, Eval Loss = 0.0542 (Time: 105.37s)
Epoch 3: Train Loss = 0.0460, Eval Loss = 0.0387 (Time: 105.27s)
Epoch 4: Train Loss = 0.0326, Eval Loss = 0.0282 (Time: 105.37s)
Epoch 5: Train Loss = 0.0232, Eval Loss = 0.0204 (Time: 105.38s)
Epoch 6: Train Loss = 0.0165, Eval Loss = 0.0152 (Time: 105.26s)
Epoch 7: Train Loss = 0.0118, Eval Loss = 0.0117 (Time: 105.38s)
Epoch 8: Train Loss = 0.0088, Eval Loss = 0.0092 (Time: 105.44s)
Epoch 9: Train Loss = 0.0067, Eval Loss = 0.0075 (Time: 105.44s)
Epoch 10: Train Loss = 0.0053, Eval Loss = 0.0063 (Time: 105.41s)

Training complete in 1101.07 seconds.
