<a href="https://colab.research.google.com/github/frank-morales2020/MLxDL/blob/main/TPUs_in_Colab.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## Examples: neural networks

We can use `jax.device_put` and `jax.jit`'s computation-follows-sharding features to parallelize computation in neural networks. Here are some simple examples, based on this basic neural network:

In [None]:
import jax
import jax.numpy as jnp

In [None]:
def predict(params, inputs):
  """Run `inputs` through the dense layers described by `params`.

  Each `(W, b)` pair in `params` computes `jnp.dot(x, W) + b`. That result is
  clamped at zero with `jnp.maximum` before it is fed to the next pair; the
  value returned is the last pair's result before the clamp.
  """
  activations = inputs
  for weights, bias in params:
    outputs = jnp.dot(activations, weights) + bias
    activations = jnp.maximum(outputs, 0)
  return outputs

def loss(params, batch):
  """Return the mean of the squared prediction error summed over the last axis.

  `batch` is an `(inputs, targets)` pair; predictions come from `predict`.
  """
  inputs, targets = batch
  residual = predict(params, inputs) - targets
  summed_sq_error = jnp.sum(residual**2, axis=-1)
  return jnp.mean(summed_sq_error)

In [None]:
# Compiled versions of the loss and of its gradient. jax.grad differentiates
# with respect to the first argument of `loss`, i.e. `params`.
loss_jit = jax.jit(loss)
gradfun = jax.jit(jax.grad(loss))

In [None]:
def init_layer(key, n_in, n_out):
  """Sample one layer's `(W, b)` pair from normal distributions.

  `W` has shape `(n_in, n_out)` and is divided by `sqrt(n_in)`; `b` has shape
  `(n_out,)`. `key` is split so that each draw uses its own subkey.
  """
  w_key, b_key = jax.random.split(key)
  bias = jax.random.normal(b_key, (n_out,))
  weights = jax.random.normal(w_key, (n_in, n_out)) / jnp.sqrt(n_in)
  return weights, bias

def init_model(key, layer_sizes, batch_size):
  """Build random parameters and one random `(inputs, targets)` batch.

  One `init_layer` call is made per consecutive pair of entries in
  `layer_sizes`. Inputs have shape `(batch_size, layer_sizes[0])` and targets
  have shape `(batch_size, layer_sizes[-1])`.
  """
  key, *layer_keys = jax.random.split(key, len(layer_sizes))
  params = [
      init_layer(layer_key, n_in, n_out)
      for layer_key, n_in, n_out in zip(
          layer_keys, layer_sizes[:-1], layer_sizes[1:])
  ]

  # The first of the three subkeys is not used for sampling.
  _, input_key, target_key = jax.random.split(key, 3)
  inputs = jax.random.normal(input_key, (batch_size, layer_sizes[0]))
  targets = jax.random.normal(target_key, (batch_size, layer_sizes[-1]))

  return params, (inputs, targets)

# Layer widths, first to last: 784 inputs, three 8192-wide layers, 10 outputs.
layer_sizes = [784, 8192, 8192, 8192, 10]
batch_size = 8192

# Random parameters plus a single random (inputs, targets) batch.
params, batch = init_model(jax.random.key(0), layer_sizes, batch_size)

### 8-way batch data parallelism

In [None]:
# Device mesh with a single axis, named 'batch', of size 8.
mesh = jax.make_mesh((8,), ('batch',))

In [None]:
from jax.sharding import NamedSharding, PartitionSpec as P

# P('batch') maps an array's leading axis onto the 'batch' mesh axis.
sharding = NamedSharding(mesh, P('batch'))
# P() names no mesh axis; used below for the fully replicated parameters.
replicated_sharding = NamedSharding(mesh, P())

In [None]:
# Place the data with the 'batch' sharding and the parameters with the
# replicated sharding.
batch = jax.device_put(batch, sharding)
params = jax.device_put(params, replicated_sharding)

In [None]:
# Evaluate the compiled loss on the placed parameters and batch.
loss_jit(params, batch)

Array(33.335655, dtype=float32)

In [None]:
step_size = 1e-5

# Plain gradient descent for 30 steps. Each step rebuilds `params` as a list
# of (W, b) pairs, pairing every layer with its gradient via zip.
for _ in range(30):
  grads = gradfun(params, batch)
  params = [(W - step_size * dW, b - step_size * db)
            for (W, b), (dW, db) in zip(params, grads)]

# Loss after the updates.
print(loss_jit(params, batch))

10.856516


In [None]:
# Time one gradient evaluation; [0][0] selects the first layer's weight
# gradient and block_until_ready() is called on it inside the timed statement.
%timeit -n 5 -r 5 gradfun(params, batch)[0][0].block_until_ready()

53.4 ms ± 34.2 µs per loop (mean ± std. dev. of 5 runs, 5 loops each)


In [None]:
# Copies of the data and parameters placed on the first device only, for the
# single-device timing comparison.
batch_single = jax.device_put(batch, jax.devices()[0])
params_single = jax.device_put(params, jax.devices()[0])

In [None]:
# Same timing as for the sharded arrays, now with the single-device copies.
%timeit -n 5 -r 5 gradfun(params_single, batch_single)[0][0].block_until_ready()

407 ms ± 190 ms per loop (mean ± std. dev. of 5 runs, 5 loops each)


### 4-way batch data parallelism and 2-way model tensor parallelism

In [None]:
# Two-axis mesh: 'batch' of size 4 and 'model' of size 2.
mesh = jax.make_mesh((4, 2), ('batch', 'model'))

In [None]:
# Re-place the batch on the new mesh: leading axis over 'batch', second axis
# not assigned to any mesh axis.
batch = jax.device_put(batch, NamedSharding(mesh, P('batch', None)))
# Show the resulting layout of the inputs and of the targets.
jax.debug.visualize_array_sharding(batch[0])
jax.debug.visualize_array_sharding(batch[1])

In [None]:
# Rebuild the replicated sharding so that it refers to the new (4, 2) mesh.
replicated_sharding = NamedSharding(mesh, P())

In [None]:
# Unpack the four layers so that each array can be placed individually.
(W1, b1), (W2, b2), (W3, b3), (W4, b4) = params

# Layer 1: replicated.
W1 = jax.device_put(W1, replicated_sharding)
b1 = jax.device_put(b1, replicated_sharding)

# Layer 2: W2's second axis and b2's only axis are mapped onto 'model'.
W2 = jax.device_put(W2, NamedSharding(mesh, P(None, 'model')))
b2 = jax.device_put(b2, NamedSharding(mesh, P('model')))

# Layer 3: W3's first axis is mapped onto 'model'; b3 is replicated.
W3 = jax.device_put(W3, NamedSharding(mesh, P('model', None)))
b3 = jax.device_put(b3, replicated_sharding)

# Layer 4: replicated.
W4 = jax.device_put(W4, replicated_sharding)
b4 = jax.device_put(b4, replicated_sharding)

# Reassemble in the original layer order (now a tuple of pairs).
params = (W1, b1), (W2, b2), (W3, b3), (W4, b4)

In [None]:
# Show how W2 is laid out across the mesh.
jax.debug.visualize_array_sharding(W2)

In [None]:
# Show how W3 is laid out across the mesh.
jax.debug.visualize_array_sharding(W3)

In [None]:
# Loss with the re-placed parameters and batch, before any further updates.
print(loss_jit(params, batch))

10.856519


In [None]:
step_size = 1e-5

# 30 more gradient-descent steps with the parameters placed on the (4, 2) mesh.
for _ in range(30):
    grads = gradfun(params, batch)
    params = [(W - step_size * dW, b - step_size * db)
              for (W, b), (dW, db) in zip(params, grads)]

In [None]:
# Loss after the additional updates.
print(loss_jit(params, batch))

10.848966


In [None]:
# Unpack the updated parameters and display the layouts of W2 and W3 again.
(W1, b1), (W2, b2), (W3, b3), (W4, b4) = params
jax.debug.visualize_array_sharding(W2)
jax.debug.visualize_array_sharding(W3)

In [None]:
# Time one gradient evaluation with the (4, 2)-mesh placement.
%timeit -n 10 -r 10 gradfun(params, batch)[0][0].block_until_ready()

51.6 ms ± 530 µs per loop (mean ± std. dev. of 10 runs, 10 loops each)


### Generating random numbers

JAX comes with a functional, deterministic [random number generator](https://jax.readthedocs.io/en/latest/jep/263-prng.html). It underlies the various sampling functions in the [`jax.random` module](https://jax.readthedocs.io/en/latest/jax.random.html), such as `jax.random.uniform`.

JAX's random numbers are produced by a counter-based PRNG, so in principle, random number generation should be a pure map over counter values. A pure map is a trivially partitionable operation in principle. It should require no cross-device communication, nor any redundant computation across devices.

However, the existing stable RNG implementation is not automatically partitionable, for historical reasons.

Consider the following example, where a function draws random uniform numbers and adds them to the input, elementwise:

In [None]:
from jax.sharding import Mesh # Import the Mesh class from the correct module
from jax.sharding import NamedSharding, PartitionSpec as P

@jax.jit
def f(key, x):
  """Add uniform random draws, one per element of `x`, to `x`."""
  return x + jax.random.uniform(key, x.shape)

key = jax.random.key(42)
# Mesh over every available device, with the axis named 'x'.
mesh = Mesh(jax.devices(), 'x')
x_sharding = NamedSharding(mesh, P('x'))
# 24 consecutive integers, placed with the 'x' sharding.
x = jax.device_put(jnp.arange(24), x_sharding)

On a partitioned input, the function `f` produces output that is also partitioned:

In [None]:
# Show the layout of f's output for the partitioned input.
jax.debug.visualize_array_sharding(f(key, x))

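But if we inspect the compiled computation for `f` on this partitioned input, we may see that it involves some communication. (Recent JAX releases enable `jax_threefry_partitionable` by default; in that case the check below prints `False`, as in the output recorded here.)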

In [None]:
# Lower and compile f for these arguments, then look for a
# 'collective-permute' op in the text of the compiled computation.
f_exe = f.lower(key, x).compile()
print('Communicating?', 'collective-permute' in f_exe.as_text())

Communicating? False


One way to work around this is to configure JAX with the experimental upgrade flag `jax_threefry_partitionable`. With the flag on, the "collective permute" operation is now gone from the compiled computation:

In [None]:
# Turn the flag on, recompile, and repeat the same text search.
jax.config.update('jax_threefry_partitionable', True)
f_exe = f.lower(key, x).compile()
print('Communicating?', 'collective-permute' in f_exe.as_text())

Communicating? False


The output is still partitioned:

In [None]:
# Show the layout of f's output with the flag on.
jax.debug.visualize_array_sharding(f(key, x))

One caveat to the `jax_threefry_partitionable` option, however, is that _the random values produced may be different than without the flag set_, even though they were generated by the same random key:

In [None]:
# Same key and input under both flag settings, printed for comparison.
jax.config.update('jax_threefry_partitionable', False)
print('Stable:')
print(f(key, x))
print()

# The flag is left switched on after this cell.
jax.config.update('jax_threefry_partitionable', True)
print('Partitionable:')
print(f(key, x))

Stable:
[ 0.72503686  1.8532515   2.983416    3.083253    4.0332246   5.4782867
  6.1720605   7.6900277   8.602836    9.810046   10.861367   11.907651
 12.330483   13.456195   14.808557   15.960099   16.067581   17.739723
 18.335474   19.46401    20.390276   21.116539   22.858128   23.223194  ]

Partitionable:
[ 0.48870957  1.6797972   2.6162715   3.561016    4.4506445   5.585866
  6.0748096   7.775133    8.698959    9.818634   10.350306   11.87282
 12.925881   13.86013    14.477554   15.818481   16.711355   17.586697
 18.073738   19.777622   20.404566   21.119123   22.026257   23.63918   ]


## LLM with TPU

In [None]:
!pip install colab-env --quiet

import warnings

# Suppress warnings whose message matches this text.
warnings.filterwarnings("ignore", message="You seem to be using the pipelines sequentially on GPU")

# NOTE(review): presumably importing colab_env is what makes the token below
# available in the environment — confirm.
import colab_env
import os

# os.getenv returns None when the variable is not set.
access_token_write = os.getenv("HUGGINGFACE_ACCESS_TOKEN_WRITE")

from huggingface_hub import login

# Log in to the Hugging Face Hub with the token read above.
login(
  token=access_token_write,
  add_to_git_credential=True
)

In [None]:
# Install the packages used by the cells below. No versions are pinned, so
# each run picks up whatever pip resolves at that time.
!pip install keras_hub --upgrade --quiet
!pip install tensorflow --quiet
!pip install datasets -q
!pip install opencv-python-headless -q
!pip install tf-keras -q
!pip install -U transformers --quiet

#!pip install tensorflow_text==2.11  -q # replace 2.11 with your tensorflow version

In [None]:
# Force a fresh install of tensorflow_text, then refresh the apt index and
# install libstdc++6.
!pip install --force-reinstall tensorflow_text -q
!apt-get update && apt-get install -y libstdc++6  # For Debian/Ubuntu-based systems

In [1]:
import jax

devices = jax.devices()

# Report whether any device JAX can see is a TPU.
if any(d.platform == 'tpu' for d in devices):
    print("TPU detected!")
else:
    print("No TPU detected.")

TPU detected!


In [None]:
import os
# Select the Keras backend through the environment; this is set before
# keras_hub is imported below.
os.environ["KERAS_BACKEND"] = "jax" # or "torch", or "tensorflow"

import tensorflow as tf
import keras_hub
# Load the Llama 3.2 1B Instruct preset with bfloat16 as the dtype.
model = keras_hub.models.Llama3CausalLM.from_preset(
    "hf://meta-llama/Llama-3.2-1B-Instruct", dtype="bfloat16"
)

In [3]:
# Smoke test: generate a continuation for a short prompt and print it.
response=model.generate("Hi there!")
print(response)

Hi there! I'm excited to connect with you!

Before I start chatting, I'd love to know a bit more about you!

Could you please share:

1. Your name
2. Where you're from (city or town)
3. What do you like to do in your free time
4. What kind of music do you enjoy listening to
5. Are there any hobbies or interests that you're particularly passionate about

Once I have this information, I'll do my best to provide you with personalized recommendations and advice!


### Data preparation

In [8]:
# --- Data Preparation (Flight Planning Focus) ---

# Define necessary lists
# Pools that the random flight examples are drawn from.
airports = ["JFK", "LAX", "LGA", "BOS", "SFO", "ORD", "DFW", "ATL", "SEA", "MIA", "DEN", "IAH", "MSP", "DTW", "PHX", "CLT", "LAS", "MCO", "EWR", "PHL"]
aircraft_types = ["Boeing 747", "Airbus A320", "Boeing 777", "Boeing 737", "Airbus A330", "Boeing 757", "Airbus A321", "Airbus A319", "Boeing 787", "Embraer E190"]
weather_conditions = ["Clear", "Cloudy", "Rainy", "Snowy", "Windy"]

# Accumulates the generated examples.
flight_data = []
# Function to create a flight data point
def create_flight_data_point(origin, destination, departure_date, aircraft, weather):
    """Build one example for the flight-planning task.

    Returns:
        A dict whose "input" is a prompt assembled from the arguments and
        whose "output" is a fixed placeholder string with empty plan fields.
    """
    prompt = f"Plan a flight from {origin} to {destination}. Departure: {departure_date}, Aircraft: {aircraft}, Weather: {weather}"
    placeholder_plan = "{'route': [], 'altitude': [], 'airspeed': [], 'fuel': []}"
    return {"input": prompt, "output": placeholder_plan}

number_routes = 100
import random

# Generate 2 * number_routes examples in a single pass. This makes the same
# sequence of random draws as running the loop body number_routes times twice
# in a row.
for _ in range(number_routes * 2):
    origin = random.choice(airports)
    destination = random.choice(airports)
    while origin == destination:  # Re-draw until origin and destination differ
        destination = random.choice(airports)
    # Days are drawn from 1-28, which exist in every month.
    departure_date = f"2024-{random.randint(1, 12):02}-{random.randint(1, 28):02}"
    aircraft = random.choice(aircraft_types)
    weather = random.choice(weather_conditions)
    flight_data.append(create_flight_data_point(origin, destination, departure_date, aircraft, weather))

In [None]:
from datasets import Dataset
# Wrap the list of example dicts in a `datasets.Dataset` and print it.
flight_dataset = Dataset.from_list(flight_data)
print(flight_dataset)

In [None]:
# Install or upgrade the libraries this cell imports.
!pip install --upgrade transformers -q
!pip install --upgrade datasets -q
!pip install --upgrade optax -q
!pip install flax -q

import os
# NOTE(review): earlier cells in this notebook already import jax. XLA flags
# set after JAX has initialised are presumably ignored — confirm, or set them
# before the first `import jax`.
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=8"
os.environ["KERAS_BACKEND"] = "jax"

import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P
from transformers import AutoTokenizer, TrainingArguments, FlaxAutoModelForCausalLM  # Import FlaxAutoModelForCausalLM
import tensorflow as tf
import warnings
warnings.filterwarnings("ignore", message="You seem to be using the pipelines sequentially on GPU")

import optax
from flax.training import train_state  # Import train_state

# --- TPU Detection ---
# for/else: the else branch runs only when the loop finishes without `break`,
# i.e. when no device reports the 'tpu' platform.
devices = jax.devices()
for device in devices:
    if device.platform == 'tpu':
        print("TPU detected!")
        break
else:
    print("No TPU detected.")

# Version and device summary for the run log.
print('\n\n')
print("TensorFlow version:", tf.__version__)
print("JAX version:", jax.__version__)
print("JAX devices:", jax.devices())
print("Num devices:", jax.device_count())
print('\n\n')

# Ensure that JAX sees the TPU:
# NOTE(review): this only prints a warning when the lookup raises RuntimeError.
# The mesh built further down in this cell calls jax.devices('tpu') again
# without a guard, so that call would presumably still fail without a TPU —
# confirm.
try:
    jax.devices("tpu")[0]
except RuntimeError:
    print("Warning: TPU not found. Code will run on CPU or GPU.")

# Model identifier
# NOTE(review): loss_fn in this cell calls `model`, and the only `model`
# visible in this file is the Llama-3.2-1B-Instruct preset loaded in an
# earlier cell. Confirm that a Llama-2 tokenizer is the intended match.
model_id = "meta-llama/Llama-2-7b-chat-hf"

# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained(model_id)
# Use the end-of-sequence token as the padding token.
tokenizer.pad_token = tokenizer.eos_token

import gc; gc.collect()

# --- Data Preparation ---
from datasets import Dataset
import random

# Pools that the random flight examples are drawn from (same values as in the
# earlier data-preparation cell).
airports = ["JFK", "LAX", "LGA", "BOS", "SFO", "ORD", "DFW", "ATL", "SEA", "MIA", "DEN", "IAH", "MSP", "DTW", "PHX", "CLT", "LAS", "MCO", "EWR", "PHL"]
aircraft_types = ["Boeing 747", "Airbus A320", "Boeing 777", "Boeing 737", "Airbus A330", "Boeing 757", "Airbus A321", "Airbus A319", "Boeing 787", "Embraer E190"]
weather_conditions = ["Clear", "Cloudy", "Rainy", "Snowy", "Windy"]

# Start from an empty list; examples from the earlier cell are discarded.
flight_data = []

def create_flight_data_point(origin, destination, departure_date, aircraft, weather):
    """Return one {"input", "output"} example for the flight-planning task.

    The "input" prompt is formatted from the arguments; the "output" is a
    fixed placeholder string with empty route/altitude/airspeed/fuel lists.
    """
    example = {}
    example["input"] = f"Plan a flight from {origin} to {destination}. Departure: {departure_date}, Aircraft: {aircraft}, Weather: {weather}"
    example["output"] = "{'route': [], 'altitude': [], 'airspeed': [], 'fuel': []}"
    return example

number_routes = 100
# Generate 2 * number_routes random examples.
for _ in range(number_routes * 2):
    origin = random.choice(airports)
    destination = random.choice(airports)
    # Re-draw the destination until it differs from the origin.
    while origin == destination:
        destination = random.choice(airports)
    # Days are drawn from 1-28, which exist in every month.
    departure_date = f"2024-{random.randint(1, 12):02}-{random.randint(1, 28):02}"
    aircraft = random.choice(aircraft_types)
    weather = random.choice(weather_conditions)
    flight_data.append(create_flight_data_point(origin, destination, departure_date, aircraft, weather))

# Wrap the list of example dicts in a `datasets.Dataset`.
flight_dataset = Dataset.from_list(flight_data)


# Tokenize and format the data
def tokenize_function(examples):
    """Tokenize the "input" and "output" columns of `examples`.

    Both columns go through the module-level `tokenizer` with padding to
    `max_length`, truncation enabled, and a length of 128.

    Returns:
        A dict with "input_ids" and "attention_mask" taken from the tokenized
        inputs, and "labels" taken from the token ids of the tokenized outputs.
    """
    prompts = examples['input']
    targets = examples['output']
    tokenizer_kwargs = dict(padding="max_length", truncation=True, max_length=128)
    encoded_prompts = tokenizer(prompts, **tokenizer_kwargs)
    encoded_targets = tokenizer(targets, **tokenizer_kwargs)

    features = {name: encoded_prompts[name] for name in ("input_ids", "attention_mask")}
    features["labels"] = encoded_targets["input_ids"]
    return features

# Apply the tokenizer to the dataset in batched mode.
tokenized_dataset = flight_dataset.map(tokenize_function, batched=True)

# --- TPU Configuration and Sharding ---
# Single mesh axis, 'data', over the TPU devices; data_sharding maps an
# array's leading axis onto it.
mesh = Mesh(jax.devices('tpu'), ('data',))
data_sharding = NamedSharding(mesh, P('data',))

import json  # Import the json module

def prepare_data_for_jax(batch, dtype=jnp.bfloat16):
    """Convert the model-facing columns of `batch` to sharded JAX arrays.

    Args:
        batch: Mapping from column name to values. Only "input_ids",
            "attention_mask" and "labels" are kept; other keys are dropped.
        dtype: Element type of the resulting arrays. Defaults to
            `jnp.bfloat16`, the value that was previously hard-coded.
            NOTE(review): bfloat16 carries 8 bits of significand, so integer
            token ids above 256 are not all exactly representable. Pass an
            integer dtype such as `jnp.int32` when exact ids are required.

    Returns:
        A new dict whose values are arrays placed with `data_sharding`.
    """
    wanted_keys = ("input_ids", "attention_mask", "labels")
    return {
        key: jax.device_put(jnp.array(value, dtype=dtype), data_sharding)
        for key, value in batch.items()
        if key in wanted_keys
    }

def loss_fn(batch):
    """Mean softmax cross-entropy between the model's logits and the labels.

    Args:
        batch: Mapping with "input_ids", "attention_mask" and "labels" arrays.
            It is only read; the caller's mapping is left unchanged.

    Returns:
        The scalar mean of the cross-entropy over all positions.
    """
    # Index rather than pop() so the caller's dict keeps its "labels" entry.
    labels = batch["labels"]

    # The model is called with "token_ids" / "padding_mask" keys; a leading
    # axis is added to each array before the call.
    model_inputs = {
        "token_ids": batch['input_ids'][jnp.newaxis, ...],
        "padding_mask": batch['attention_mask'][jnp.newaxis, ...],
    }
    # The call's result is used directly as the logits.
    logits = model(model_inputs)

    # The optax loss below expects integer class ids, so cast the labels
    # whatever dtype they arrive in (a no-op when they are already int32).
    labels = labels.astype(jnp.int32)

    # Flatten to (positions, vocab) and (positions,) before averaging.
    loss = optax.softmax_cross_entropy_with_integer_labels(
        logits=logits.reshape((-1, logits.shape[-1])), labels=labels.reshape((-1,))
    ).mean()
    return loss

# --- Fine-Tuning ---
# NOTE(review): in the code visible in this file, only `num_train_epochs` is
# read from this object (by the manual loop below). The other settings
# presumably have no effect on that loop — confirm.
training_args = TrainingArguments(
    output_dir="./results",
    num_train_epochs=3,
    per_device_train_batch_size=1,
    gradient_accumulation_steps=128,
    gradient_checkpointing=True,
    optim="adamw_8bit",
    save_steps=500,
    logging_steps=250,
    learning_rate=2e-4,
    weight_decay=0.01,
    warmup_ratio=0.03,
    lr_scheduler_type="cosine",
    tpu_num_cores=8,
    fp16=False,
    bf16=True,
    push_to_hub=False,
)

# Training loop
from tqdm import tqdm  # Import tqdm

# One pass over `tokenized_dataset` per epoch, with a progress bar.
# NOTE(review): the dataset is iterated directly, so each `batch` is
# presumably a single example rather than a batch — confirm.
for epoch in range(training_args.num_train_epochs):
    # Create a tqdm progress bar for each epoch
    with tqdm(tokenized_dataset, unit="batch", desc=f"Epoch {epoch + 1}/{training_args.num_train_epochs}") as pbar:
        for batch in pbar:
            batch = prepare_data_for_jax(batch)

            # Calculate loss and gradients using jax.value_and_grad
            # NOTE(review): value_and_grad differentiates with respect to the
            # function's first argument, which here is `batch`, not the model
            # weights. `grads` is not used afterwards in the visible code, so
            # no parameter update takes place in this loop.
            loss, grads = jax.value_and_grad(loss_fn)(batch)

            # Update the progress bar with loss information
            pbar.set_postfix({"loss": loss.item()})



# Save the fine-tuned model
#model.save_pretrained("./fine_tuned_llama", params=state.params) # Commented out as `state` is not defined