<a href="https://colab.research.google.com/github/frank-morales2020/MLxDL/blob/main/TPUs_in_Colab.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## Examples: neural networks

We can use `jax.device_put` and `jax.jit`'s computation-follows-sharding features to parallelize computation in neural networks. Here are some simple examples, based on this basic neural network:

In [None]:
import jax
import jax.numpy as jnp

In [None]:
def predict(params, inputs):
  for W, b in params:
    outputs = jnp.dot(inputs, W) + b
    inputs = jnp.maximum(outputs, 0)
  return outputs

def loss(params, batch):
  inputs, targets = batch
  predictions = predict(params, inputs)
  return jnp.mean(jnp.sum((predictions - targets)**2, axis=-1))

In [None]:
loss_jit = jax.jit(loss)
gradfun = jax.jit(jax.grad(loss))

In [None]:
def init_layer(key, n_in, n_out):
  k1, k2 = jax.random.split(key)
  W = jax.random.normal(k1, (n_in, n_out)) / jnp.sqrt(n_in)
  b = jax.random.normal(k2, (n_out,))
  return W, b

def init_model(key, layer_sizes, batch_size):
  key, *keys = jax.random.split(key, len(layer_sizes))
  params = list(map(init_layer, keys, layer_sizes[:-1], layer_sizes[1:]))

  key, *keys = jax.random.split(key, 3)
  inputs = jax.random.normal(keys[0], (batch_size, layer_sizes[0]))
  targets = jax.random.normal(keys[1], (batch_size, layer_sizes[-1]))

  return params, (inputs, targets)

layer_sizes = [784, 8192, 8192, 8192, 10]
batch_size = 8192

params, batch = init_model(jax.random.key(0), layer_sizes, batch_size)

### 8-way batch data parallelism

In [None]:
mesh = jax.make_mesh((8,), ('batch',))

In [None]:
from jax.sharding import NamedSharding, PartitionSpec as P

sharding = NamedSharding(mesh, P('batch'))
replicated_sharding = NamedSharding(mesh, P())

In [None]:
batch = jax.device_put(batch, sharding)
params = jax.device_put(params, replicated_sharding)

In [None]:
loss_jit(params, batch)

Array(33.335655, dtype=float32)

In [None]:
step_size = 1e-5

for _ in range(30):
  grads = gradfun(params, batch)
  params = [(W - step_size * dW, b - step_size * db)
            for (W, b), (dW, db) in zip(params, grads)]

print(loss_jit(params, batch))

10.856516


In [None]:
%timeit -n 5 -r 5 gradfun(params, batch)[0][0].block_until_ready()

53.4 ms ± 34.2 µs per loop (mean ± std. dev. of 5 runs, 5 loops each)


In [None]:
batch_single = jax.device_put(batch, jax.devices()[0])
params_single = jax.device_put(params, jax.devices()[0])

In [None]:
%timeit -n 5 -r 5 gradfun(params_single, batch_single)[0][0].block_until_ready()

407 ms ± 190 ms per loop (mean ± std. dev. of 5 runs, 5 loops each)


### 4-way batch data parallelism and 2-way model tensor parallelism

In [None]:
mesh = jax.make_mesh((4, 2), ('batch', 'model'))

In [None]:
batch = jax.device_put(batch, NamedSharding(mesh, P('batch', None)))
jax.debug.visualize_array_sharding(batch[0])
jax.debug.visualize_array_sharding(batch[1])

In [None]:
replicated_sharding = NamedSharding(mesh, P())

In [None]:
(W1, b1), (W2, b2), (W3, b3), (W4, b4) = params

W1 = jax.device_put(W1, replicated_sharding)
b1 = jax.device_put(b1, replicated_sharding)

W2 = jax.device_put(W2, NamedSharding(mesh, P(None, 'model')))
b2 = jax.device_put(b2, NamedSharding(mesh, P('model')))

W3 = jax.device_put(W3, NamedSharding(mesh, P('model', None)))
b3 = jax.device_put(b3, replicated_sharding)

W4 = jax.device_put(W4, replicated_sharding)
b4 = jax.device_put(b4, replicated_sharding)

params = (W1, b1), (W2, b2), (W3, b3), (W4, b4)

In [None]:
jax.debug.visualize_array_sharding(W2)

In [None]:
jax.debug.visualize_array_sharding(W3)

In [None]:
print(loss_jit(params, batch))

10.856519


In [None]:
step_size = 1e-5

for _ in range(30):
    grads = gradfun(params, batch)
    params = [(W - step_size * dW, b - step_size * db)
              for (W, b), (dW, db) in zip(params, grads)]

In [None]:
print(loss_jit(params, batch))

10.848966


In [None]:
(W1, b1), (W2, b2), (W3, b3), (W4, b4) = params
jax.debug.visualize_array_sharding(W2)
jax.debug.visualize_array_sharding(W3)

In [None]:
%timeit -n 10 -r 10 gradfun(params, batch)[0][0].block_until_ready()

51.6 ms ± 530 µs per loop (mean ± std. dev. of 10 runs, 10 loops each)


### Generating random numbers

JAX comes with a functional, deterministic [random number generator](https://jax.readthedocs.io/en/latest/jep/263-prng.html). It underlies the various sampling functions in the [`jax.random` module](https://jax.readthedocs.io/en/latest/jax.random.html), such as `jax.random.uniform`.

JAX's random numbers are produced by a counter-based PRNG, so in principle, random number generation should be a pure map over counter values. A pure map is a trivially partitionable operation in principle. It should require no cross-device communication, nor any redundant computation across devices.

However, the existing stable RNG implementation is not automatically partitionable, for historical reasons.

Consider the following example, where a function draws random uniform numbers and adds them to the input, elementwise:

In [None]:
from jax.sharding import Mesh # Import the Mesh class from the correct module
from jax.sharding import NamedSharding, PartitionSpec as P

@jax.jit
def f(key, x):
  numbers = jax.random.uniform(key, x.shape)
  return x + numbers

key = jax.random.key(42)
mesh = Mesh(jax.devices(), 'x')
x_sharding = NamedSharding(mesh, P('x'))
x = jax.device_put(jnp.arange(24), x_sharding)

On a partitioned input, the function `f` produces output that is also partitioned:

In [None]:
jax.debug.visualize_array_sharding(f(key, x))

But if we inspect the compiled computation for `f` on this partitioned input, we see that it does involve some communication:

In [None]:
f_exe = f.lower(key, x).compile()
print('Communicating?', 'collective-permute' in f_exe.as_text())

Communicating? False


One way to work around this is to configure JAX with the experimental upgrade flag `jax_threefry_partitionable`. With the flag on, the "collective permute" operation is now gone from the compiled computation:

In [None]:
jax.config.update('jax_threefry_partitionable', True)
f_exe = f.lower(key, x).compile()
print('Communicating?', 'collective-permute' in f_exe.as_text())

Communicating? False


The output is still partitioned:

In [None]:
jax.debug.visualize_array_sharding(f(key, x))

One caveat to the `jax_threefry_partitionable` option, however, is that _the random values produced may be different than without the flag set_, even though they were generated by the same random key:

In [None]:
jax.config.update('jax_threefry_partitionable', False)
print('Stable:')
print(f(key, x))
print()

jax.config.update('jax_threefry_partitionable', True)
print('Partitionable:')
print(f(key, x))

Stable:
[ 0.72503686  1.8532515   2.983416    3.083253    4.0332246   5.4782867
  6.1720605   7.6900277   8.602836    9.810046   10.861367   11.907651
 12.330483   13.456195   14.808557   15.960099   16.067581   17.739723
 18.335474   19.46401    20.390276   21.116539   22.858128   23.223194  ]

Partitionable:
[ 0.48870957  1.6797972   2.6162715   3.561016    4.4506445   5.585866
  6.0748096   7.775133    8.698959    9.818634   10.350306   11.87282
 12.925881   13.86013    14.477554   15.818481   16.711355   17.586697
 18.073738   19.777622   20.404566   21.119123   22.026257   23.63918   ]


## LLM with TPU

In [None]:
!pip install colab-env --quiet

import warnings

warnings.filterwarnings("ignore", message="You seem to be using the pipelines sequentially on GPU")

import colab_env
import os

access_token_write = os.getenv("HUGGINGFACE_ACCESS_TOKEN_WRITE")

from huggingface_hub import login

login(
  token=access_token_write,
  add_to_git_credential=True
)

In [None]:
!pip install keras_hub --upgrade --quiet
!pip install tensorflow --quiet
!pip install datasets -q
!pip install opencv-python-headless -q
!pip install tf-keras -q
!pip install -U transformers --quiet

#!pip install tensorflow_text==2.11  -q # replace 2.11 with your tensorflow version

In [None]:
!pip install --force-reinstall tensorflow_text -q
!apt-get update && apt-get install -y libstdc++6  # For Debian/Ubuntu-based systems

In [None]:
import jax

devices = jax.devices()

for device in devices:
    if device.platform == 'tpu':
        print("TPU detected!")
        break
else:
    print("No TPU detected.")

TPU detected!


In [None]:
import os
os.environ["KERAS_BACKEND"] = "jax" # or "torch", or "tensorflow"

import tensorflow as tf
import keras_hub
model = keras_hub.models.Llama3CausalLM.from_preset(
    "hf://meta-llama/Llama-3.2-1B-Instruct", dtype="bfloat16"
)

In [None]:
response=model.generate("Hi there!")
print(response)

Hi there! I'm excited to connect with you!

Before I start chatting, I'd love to know a bit more about you!

Could you please share:

1. Your name
2. Where you're from (city or town)
3. What do you like to do in your free time
4. What kind of music do you enjoy listening to
5. Are there any hobbies or interests that you're particularly passionate about

Once I have this information, I'll do my best to provide you with personalized recommendations and advice!


data

In [None]:
# --- Data Preparation (Flight Planning Focus) ---

# Define necessary lists
airports = ["JFK", "LAX", "LGA", "BOS", "SFO", "ORD", "DFW", "ATL", "SEA", "MIA", "DEN", "IAH", "MSP", "DTW", "PHX", "CLT", "LAS", "MCO", "EWR", "PHL"]
aircraft_types = ["Boeing 747", "Airbus A320", "Boeing 777", "Boeing 737", "Airbus A330", "Boeing 757", "Airbus A321", "Airbus A319", "Boeing 787", "Embraer E190"]
weather_conditions = ["Clear", "Cloudy", "Rainy", "Snowy", "Windy"]

flight_data = []
# Function to create a flight data point
def create_flight_data_point(origin, destination, departure_date, aircraft, weather):
    return {
        "input": f"Plan a flight from {origin} to {destination}. Departure: {departure_date}, Aircraft: {aircraft}, Weather: {weather}",
        "output": "{'route': [], 'altitude': [], 'airspeed': [], 'fuel': []}"  # Placeholder for output
    }

number_routes = 100
# Generate more flight data points
import random
for _ in range(number_routes):  # Generate 90 more examples
    origin = random.choice(airports)
    destination = random.choice(airports)
    while origin == destination:  # Ensure origin and destination are different
        destination = random.choice(airports)
    departure_date = f"2024-{random.randint(1, 12):02}-{random.randint(1, 28):02}"
    aircraft = random.choice(aircraft_types)
    weather = random.choice(weather_conditions)
    flight_data.append(create_flight_data_point(origin, destination, departure_date, aircraft, weather))

for _ in range(number_routes):  # Generate 90 more examples
    origin = random.choice(airports)
    destination = random.choice(airports)
    while origin == destination:  # Ensure origin and destination are different
        destination = random.choice(airports)
    departure_date = f"2024-{random.randint(1, 12):02}-{random.randint(1, 28):02}"
    aircraft = random.choice(aircraft_types)
    weather = random.choice(weather_conditions)
    flight_data.append(create_flight_data_point(origin, destination, departure_date, aircraft, weather))

In [None]:
from datasets import Dataset
flight_dataset = Dataset.from_list(flight_data)
print(flight_dataset)

In [None]:
!pip install flax --quiet
!pip install --upgrade transformers -q

In [None]:
!pip install flax --quiet
!pip install --upgrade transformers -q
!pip install datasets evaluate -q

import warnings
warnings.filterwarnings("ignore", message="You seem to be using the pipel")

import jax
import jax.numpy as jnp
from transformers import AutoTokenizer, FlaxAutoModelForSequenceClassification, TrainingArguments, Trainer
from datasets import load_dataset
from datasets import Dataset
import evaluate
import optax
import numpy as np
import os
import evaluate
from typing import Dict, Any
import sys


from flax.training import train_state

# TPU Detection and Device Assignment
try:
    tpu_device = jax.devices("tpu")[0]  # Get the first TPU device
    USE_TPU = True
    print("TPU detected!")
except RuntimeError:
    tpu_device = None  # If no TPU is found, set tpu_device to None
    USE_TPU = False
    print("Warning: TPU not found. Code will run on CPU or GPU.")

def simple_op(x):
       return x + 1

x = jnp.array([1, 2, 3])
x_tpu = jax.device_put(x, tpu_device)
result = simple_op(x_tpu).block_until_ready()

print(result)


model_name = "distilbert-base-uncased"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = FlaxAutoModelForSequenceClassification.from_pretrained(model_name, num_labels=2)


# --- Data Preparation (Flight Planning Focus) ---
# Define necessary lists
airports = ["JFK", "LAX", "LGA", "BOS", "SFO", "ORD", "DFW", "ATL", "SEA", "MIA", "DEN", "IAH", "MSP", "DTW", "PHX", "CLT", "LAS", "MCO", "EWR", "PHL"]
aircraft_types = ["Boeing 747", "Airbus A320", "Boeing 777", "Boeing 737", "Airbus A330", "Boeing 757", "Airbus A321", "Airbus A319", "Boeing 787", "Embraer E190"]
weather_conditions = ["Clear", "Cloudy", "Rainy", "Snowy", "Windy"]

flight_data = []
# Function to create a flight data point
def create_flight_data_point(origin, destination, departure_date, aircraft, weather, label): # Added label parameter
    return {
        "input": f"Plan a flight from {origin} to {destination}. Departure: {departure_date}, Aircraft: {aircraft}, Weather: {weather}",
        "label": label  # Store the label directly, name changed to "label"
    }

number_routes = 10000
# Generate more flight data points
import random
for _ in range(number_routes):  # Generate examples with labels
    origin = random.choice(airports)
    destination = random.choice(airports)
    while origin == destination:  # Ensure origin and destination are different
        destination = random.choice(airports)
    departure_date = f"2024-{random.randint(1, 12):02}-{random.randint(1, 28):02}"
    aircraft = random.choice(aircraft_types)
    weather = random.choice(weather_conditions)
    label = random.randint(0, 1) # Assign random labels (0 or 1) for demonstration
    flight_data.append(create_flight_data_point(origin, destination, departure_date, aircraft, weather, label))



# Create the dataset
dataset = Dataset.from_list(flight_data)

# Create train/validation split using train_test_split
train_testvalid = dataset.train_test_split(test_size=0.2, seed=42)  # Split into 80% train, 20% test+validation
train_dataset = train_testvalid["train"]
testvalid_dataset = train_testvalid["test"]

# Further split test+validation into test and validation
test_valid = testvalid_dataset.train_test_split(test_size=0.5, seed=42)  # Split the test+validation into 50% each
eval_dataset = test_valid["test"]  # Now you have separate train, eval, and test datasets
test_dataset = test_valid["train"]

# Now you can use train_dataset and eval_dataset
small_train_dataset = train_dataset.shuffle(seed=42).select(range(1000))
small_eval_dataset = eval_dataset.shuffle(seed=42).select(range(200))

def tokenize_function(examples):
    inputs = examples['input']
    tokenized_inputs = tokenizer(inputs, padding="max_length", truncation=True, max_length=128)
    return {
        "input_ids": tokenized_inputs["input_ids"],
        "attention_mask": tokenized_inputs["attention_mask"],
        "label": examples['label']  # Access the label directly, name changed to "label"
    }

# Apply tokenization
tokenized_train_dataset = small_train_dataset.map(tokenize_function, batched=True)
tokenized_eval_dataset = small_eval_dataset.map(tokenize_function, batched=True)

# Convert to JAX and filter out invalid data points
def convert_and_filter(batch):
    # Convert only the 'label' to JAX arrays, keep 'input_ids' and 'attention_mask' as lists
    batch = {k: jnp.array(v) if k == 'label' else v for k, v in batch.items()}

    # Filter out invalid data points based on input_ids shape
    # updated to check length instead of truthiness, and iterate through outer list
    valid_indices = [i for i, ids in enumerate(batch['input_ids']) if len(ids) > 0]

    # If all data points are invalid, return an empty dictionary
    if not valid_indices:
        return {}

    # Otherwise, filter the batch
    filtered_batch = {k: [v[i] for i in valid_indices] for k, v in batch.items() if k in batch}
    return filtered_batch

# Filter and format datasets
tokenized_train_dataset = tokenized_train_dataset.map(convert_and_filter, batched=True, remove_columns=tokenized_train_dataset.column_names)
tokenized_eval_dataset = tokenized_eval_dataset.map(convert_and_filter, batched=True, remove_columns=tokenized_eval_dataset.column_names)


# Training State
learning_rate = 2e-5
optimizer = optax.adamw(learning_rate)
state = train_state.TrainState.create(
    apply_fn=model,
    params=model.params,
    tx=optimizer,
)

# Loss Function
def loss_fn(params, batch):
    outputs = model(batch["input_ids"], attention_mask=batch["attention_mask"], params=params)
    logits = outputs.logits
    loss = optax.softmax_cross_entropy_with_integer_labels(logits, batch["label"]).mean()
    return loss

# Training Step
@jax.jit
def train_step(state, batch):
    loss, grads = jax.value_and_grad(loss_fn)(state.params, batch)
    state = state.apply_gradients(grads=grads)
    return state, loss

def eval_step(params, batch):
      outputs = model(batch["input_ids"], attention_mask=batch["attention_mask"], params=params)
      logits = outputs.logits
      predictions = jnp.argmax(logits, axis=-1)
      # Convert batch["label"] to JAX array if it's not already
      label_array = jnp.array(batch["label"])
      return predictions == label_array  # Return the comparison result


def evaluate(state, eval_dataset):
      correct_predictions = 0  # Initialize counter for correct predictions
      total_predictions = 0   # Initialize counter for total predictions

      for i in range(0, len(eval_dataset), batch_size):
          # Get a batch of data using slicing
          batch = eval_dataset[i : i + batch_size]
          # Convert batch to a dictionary format
          batch = {k: v for k, v in zip(eval_dataset.column_names, zip(*batch.values()))}

          # Check if batch['input_ids'] is empty or contains invalid data
          if not batch['input_ids'] or not all(isinstance(ids, (list, np.ndarray, jnp.ndarray)) and len(ids) > 0 for ids in batch['input_ids']):
              continue

          # Pad the input_ids and attention_mask in the batch to the maximum length
          max_len = max(len(ids) for ids in batch['input_ids'])

          # Use jnp.pad to pad the JAX arrays
          # Convert ids to JAX array before padding if it's a list
          batch['input_ids'] = jnp.array([jnp.pad(jnp.array(ids), (0, max_len - len(ids)), 'constant', constant_values=0) if isinstance(ids, list) else jnp.pad(ids, (0, max_len - len(ids)), 'constant', constant_values=0) for ids in batch['input_ids']])
          batch['attention_mask'] = jnp.array([jnp.pad(jnp.array(mask), (0, max_len - len(mask)), 'constant', constant_values=0) if isinstance(mask, list) else jnp.pad(mask, (0, max_len - len(mask)), 'constant', constant_values=0) for mask in batch['attention_mask']])

          batch = jax.device_put(batch, tpu_device)  # Move batch to TPU device

          # Get comparison results for the batch
          comparison_results = eval_step(state.params, batch)

          # Accumulate correct and total predictions
          correct_predictions += jnp.sum(comparison_results)
          total_predictions += len(comparison_results)

      # Calculate and return the accuracy
      accuracy = correct_predictions / total_predictions if total_predictions > 0 else 0.0
      return accuracy



# Training Loop
num_epochs = 50
batch_size = 8
num_accum_steps = 4

print("Training...")
print(f"TPU Device: {tpu_device}")
print(f"Number of Epochs: {num_epochs}")
print(f"Batch Size: {batch_size}")
print(f"Total Training Examples: {len(tokenized_train_dataset)}")
print('\n')

def get_batch(dataset, start_index, batch_size):
    end_index = min(start_index + batch_size, len(dataset))
    batch = dataset[start_index:end_index]
    # If the batch is empty, skip it.
    if not batch:
        return None

    # Extract relevant fields for training.
    input_ids = jnp.array(batch["input_ids"])
    attention_mask = jnp.array(batch["attention_mask"])
    labels = jnp.array(batch["label"])

    # Pad the sequences (if necessary).
    max_len = max(len(ids) for ids in input_ids)
    input_ids = jnp.array([jnp.pad(ids, (0, max_len - len(ids)), 'constant', constant_values=0) for ids in input_ids])
    attention_mask = jnp.array([jnp.pad(mask, (0, max_len - len(mask)), 'constant', constant_values=0) for mask in attention_mask])

    # Put the batch on the TPU device (if available).
    if tpu_device:
        input_ids = jax.device_put(input_ids, tpu_device)
        attention_mask = jax.device_put(attention_mask, tpu_device)
        labels = jax.device_put(labels, tpu_device)

    return {"input_ids": input_ids, "attention_mask": attention_mask, "label": labels}


for epoch in range(num_epochs):
    for i in range(0, len(tokenized_train_dataset), batch_size):

        # Get the next batch or skip if empty.
        batch = get_batch(tokenized_train_dataset, i, batch_size)
        if batch is None:
            continue

        # --- Training Step ---
        state, loss = train_step(state, batch)

        # --- Print Loss (if needed) ---
        loss_host = jax.device_get(loss)
        if i % 500 == 0:
            loss_divided_by_10 = loss_host / 10
            jax.debug.print(f"Epoch: {epoch}, Batch: {i}, Loss/10: {loss_divided_by_10}")
        #else:
            #jax.debug.print(f"Epoch: {epoch}, Batch: {i}, Loss: {loss_host}")  # Print original loss otherwise

        sys.stdout.flush()

        # --- Evaluation after each batch update ---
        accuracy = evaluate(state, tokenized_eval_dataset)
        if i % 500 == 0:
           jax.debug.print(f"Epoch {epoch}, Batch {i} - Evaluation Accuracy: {accuracy}")



TPU detected!
[2 3 4]


Some weights of the model checkpoint at distilbert-base-uncased were not used when initializing FlaxDistilBertForSequenceClassification: {('vocab_transform', 'kernel'), ('vocab_layer_norm', 'bias'), ('vocab_layer_norm', 'scale'), ('vocab_projector', 'bias'), ('vocab_transform', 'bias')}
- This IS expected if you are initializing FlaxDistilBertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing FlaxDistilBertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of FlaxDistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: {('classifier', 'bias'), ('classifier', 'k

Map:   0%|          | 0/1000 [00:00<?, ? examples/s]

Map:   0%|          | 0/200 [00:00<?, ? examples/s]

Map:   0%|          | 0/1000 [00:00<?, ? examples/s]

Map:   0%|          | 0/200 [00:00<?, ? examples/s]

Training...
TPU Device: TPU_0(process=0,(0,0,0,0))
Number of Epochs: 50
Batch Size: 8
Total Training Examples: 1000


Epoch: 0, Batch: 0, Loss/10: 0.06930463016033173
Epoch 0, Batch 0 - Evaluation Accuracy: 0.0
Epoch: 1, Batch: 0, Loss/10: 0.06872449815273285
Epoch 1, Batch 0 - Evaluation Accuracy: 0.0
Epoch: 2, Batch: 0, Loss/10: 0.06950749456882477
Epoch 2, Batch 0 - Evaluation Accuracy: 0.0
Epoch: 3, Batch: 0, Loss/10: 0.07131718844175339
Epoch 3, Batch 0 - Evaluation Accuracy: 0.0
Epoch: 4, Batch: 0, Loss/10: 0.07118053734302521
Epoch 4, Batch 0 - Evaluation Accuracy: 0.0
Epoch: 5, Batch: 0, Loss/10: 0.07369828969240189
Epoch 5, Batch 0 - Evaluation Accuracy: 0.0
Epoch: 6, Batch: 0, Loss/10: 0.0715380311012268
Epoch 6, Batch 0 - Evaluation Accuracy: 0.0
Epoch: 7, Batch: 0, Loss/10: 0.07344922423362732
Epoch 7, Batch 0 - Evaluation Accuracy: 0.0
Epoch: 8, Batch: 0, Loss/10: 0.06829199939966202
Epoch 8, Batch 0 - Evaluation Accuracy: 0.0
Epoch: 9, Batch: 0, Loss/10: 0.067280992865562

In [19]:
!pip install geopy -q

In [None]:
import random
from datasets import Dataset
from geopy.geocoders import Nominatim
from geopy.distance import geodesic
from tqdm import tqdm  # Import tqdm for the progress bar


# Define necessary lists
airports = ["JFK", "LAX", "LGA", "BOS", "SFO", "ORD", "DFW", "ATL", "SEA", "MIA", "DEN", "IAH", "MSP", "DTW", "PHX", "CLT", "LAS", "MCO", "EWR", "PHL"]
aircraft_types = ["Boeing 747", "Airbus A320", "Boeing 777", "Boeing 737", "Airbus A330", "Boeing 757", "Airbus A321", "Airbus A319", "Boeing 787", "Embraer E190"]
weather_conditions = ["Clear", "Cloudy", "Rainy", "Snowy", "Windy"]

# Initialize geolocator
geolocator = Nominatim(user_agent="flight_planner", timeout=10)

def create_flight_data_point(origin, destination, departure_date, aircraft, weather):
    """Creates a flight data point with distance category as the label."""
    try:
        location_origin = geolocator.geocode(origin)
        location_destination = geolocator.geocode(destination)

        if location_origin and location_destination:
            # Calculate distance using geodesic
            distance = geodesic(
                (location_origin.latitude, location_origin.longitude),
                (location_destination.latitude, location_destination.longitude)
            ).kilometers

            # Categorize distance
            if distance < 500:
                distance_category = "short"
            elif distance < 1500:
                distance_category = "medium"
            else:
                distance_category = "long"

            # Modified input text:
            return {
                "input": f"Calculate the distance from {origin} to {destination}. Departure: {departure_date}, Aircraft: {aircraft}, Weather: {weather}",
                "label": distance_category  # Store distance category as the label
            }
        else:
            print(f"Could not find coordinates for {origin} or {destination}")
            return None

    except Exception as e:
        print(f"Error generating flight data point: {e}")
        return None

# --- Dataset Creation Loop ---
flight_data = []
number_routes = 10000  # You can adjust this number

# Wrap the loop with tqdm to create a progress bar
for _ in tqdm(range(number_routes), desc="Generating flight data"):
    origin = random.choice(airports)
    destination = random.choice(airports)
    while origin == destination:
        destination = random.choice(airports)
    departure_date = f"2024-{random.randint(1, 12):02}-{random.randint(1, 28):02}"
    aircraft = random.choice(aircraft_types)
    weather = random.choice(weather_conditions)

    data_point = create_flight_data_point(origin, destination, departure_date, aircraft, weather)
    if data_point:
        flight_data.append(data_point)

# --- Create the Dataset ---
dataset = Dataset.from_list(flight_data)
print(dataset)
## 127/10000


In [5]:
dataset = Dataset.from_list(flight_data)
dataset.save_to_disk("flight_dataset_tpu")

Saving the dataset (0/1 shards):   0%|          | 0/1127 [00:00<?, ? examples/s]

In [6]:
dataset.save_to_disk("flight_dataset_tpu")

from datasets import load_from_disk
dataset2 = load_from_disk("flight_dataset_tpu")

Saving the dataset (0/1 shards):   0%|          | 0/1127 [00:00<?, ? examples/s]

In [7]:
dataset2

Dataset({
    features: ['input', 'label'],
    num_rows: 1127
})

In [9]:
!pip install colab-env -q
import colab_env
import os
!cp -pr /content/flight_dataset_tpu /content/gdrive/MyDrive/datasets/

In [10]:
dataset3 = load_from_disk("/content/gdrive/MyDrive/datasets/flight_dataset_tpu")

In [None]:
dataset3

In [16]:
!pip install flax --quiet
!pip install --upgrade transformers -q
!pip install datasets evaluate -q


from typing import Dict, Any  # Import Any from typing
import random
from datasets import Dataset
from geopy.geocoders import Nominatim
from geopy.distance import geodesic
from tqdm import tqdm  # Import tqdm for the progress bar

import warnings
warnings.filterwarnings("ignore", message="You seem to be using the pipel")

import jax
import jax.numpy as jnp
from transformers import AutoTokenizer, FlaxAutoModelForSequenceClassification
from datasets import load_from_disk
from datasets import Dataset
import evaluate
import optax
import numpy as np
from flax.training import train_state

# TPU Detection and Device Assignment
try:
    tpu_device = jax.devices("tpu")[0]  # Get the first TPU device
    USE_TPU = True
    print("TPU detected!")
except RuntimeError:
    tpu_device = None  # If no TPU is found, set tpu_device to None
    USE_TPU = False
    print("Warning: TPU not found. Code will run on CPU or GPU.")

def simple_op(x):
    return x + 1

x = jnp.array([1, 2, 3])
x_tpu = jax.device_put(x, tpu_device)
result = simple_op(x_tpu).block_until_ready()

print(result)


# Model and Tokenizer
model_name = "distilbert-base-uncased"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = FlaxAutoModelForSequenceClassification.from_pretrained(model_name, num_labels=3)

# Create a label mapping (string to integer)
label_mapping = {
    "short": 0,
    "medium": 1,
    "long": 2
}

# Tokenize and format the data
def tokenize_function(examples):
    # convert string labels to integers using label_mapping
    examples["labels"] = [label_mapping[label] for label in examples["label"]]
    return tokenizer(examples["input"], padding="max_length", truncation=True)

tokenized_datasets = dataset.map(tokenize_function, batched=True)
tokenized_datasets = tokenized_datasets.remove_columns(["input", "label"])
tokenized_datasets.set_format("jax", columns=["input_ids", "attention_mask", "labels"])

# Training State with PRNG Key
key = jax.random.PRNGKey(0)  # Initialize a PRNG key

class TrainState(train_state.TrainState):
    key: jax.Array
    #apply_fn_with_dropout: Any  # Add type hint for apply_fn_with_dropout

def create_train_state(model, tx, key):
    return TrainState.create(
        apply_fn=model.__call__,
        params=model.params,
        tx=tx,
        key=key,
        # Remove apply_fn_with_dropout:
        # apply_fn_with_dropout=apply_fn_with_dropout
    )

learning_rate = 2e-5
optimizer = optax.adamw(learning_rate)
state = create_train_state(model, optimizer, key)

# Loss Function
def loss_fn(params, batch, dropout_key):
    labels_batch = batch.pop("labels")

    # Apply dropout using the dropout key
    # Removed the "method" argument
    logits = model(**batch, params=params, train=True, dropout_rng=dropout_key).logits

    loss = optax.softmax_cross_entropy_with_integer_labels(logits=logits, labels=labels_batch).mean()
    return loss

# Training Step (Modified)
@jax.jit
def train_step(state, batch):
    key, dropout_key = jax.random.split(state.key)

    def loss_fn_wrapped(params):
        return loss_fn(params, batch, dropout_key)

    loss_value, grads = jax.value_and_grad(loss_fn_wrapped)(state.params)
    state = state.apply_gradients(grads=grads)
    return state.replace(key=key), loss_value

def eval_step(params, batch):
    batch = {k: v for k, v in batch.items() if k != "labels"}

    logits = model(**batch, params=params, train=False).logits
    return logits.argmax(-1)

# Split the dataset into train and eval
train_testvalid = tokenized_datasets.train_test_split(test_size=0.2, seed=42)
train_dataset = train_testvalid["train"]
testvalid_dataset = train_testvalid["test"]

test_valid = testvalid_dataset.train_test_split(test_size=0.5, seed=42)
eval_dataset = test_valid["test"]
test_dataset = test_valid["train"]

small_train_dataset = train_dataset.shuffle(seed=42).select(range(800))
small_eval_dataset = eval_dataset.shuffle(seed=42).select(range(113))

# Training Loop
num_epochs = 5
batch_size = 8

print('\n')
print("Training...")
print(f"TPU Device: {tpu_device}")
print(f"Number of Epochs: {num_epochs}")
print(f"Batch Size: {batch_size}")
print(f"Total Training Examples: {len(small_train_dataset)}")
print('\n')

for epoch in range(num_epochs):
    # Train
    for i in range(0, len(small_train_dataset), batch_size):
        batch = small_train_dataset[i: i + batch_size]
        state, loss = train_step(state, batch)
        if i % 100 == 0:
            print(f"Epoch {epoch} - Batch {i} - Loss: {loss}")
        #print(f"Epoch {epoch} - Batch {i} - Loss: {loss}")

    # Eval
    all_predictions = []
    all_labels = []

    for i in range(0, len(small_eval_dataset), batch_size):
        batch = small_eval_dataset[i: i + batch_size]
        predictions = eval_step(state.params, batch)
        all_predictions.extend(predictions)
        all_labels.extend(batch["labels"])

    accuracy = evaluate.load("accuracy")
    acc = accuracy.compute(predictions=all_predictions, references=all_labels)
    if epoch % 1 == 0:
       print('\n')
       print(f"Epoch {epoch} - Accuracy: {acc}")
       print('\n')

TPU detected!
[2 3 4]


Some weights of the model checkpoint at distilbert-base-uncased were not used when initializing FlaxDistilBertForSequenceClassification: {('vocab_transform', 'kernel'), ('vocab_projector', 'bias'), ('vocab_layer_norm', 'bias'), ('vocab_transform', 'bias'), ('vocab_layer_norm', 'scale')}
- This IS expected if you are initializing FlaxDistilBertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing FlaxDistilBertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of FlaxDistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: {('classifier', 'bias'), ('pre_classifier'

Map:   0%|          | 0/1127 [00:00<?, ? examples/s]



Training...
TPU Device: TPU_0(process=0,(0,0,0,0))
Number of Epochs: 5
Batch Size: 8
Total Training Examples: 800


Epoch 0 - Batch 0 - Loss: 1.0353548526763916
Epoch 0 - Batch 200 - Loss: 0.7825049161911011
Epoch 0 - Batch 400 - Loss: 0.7804630994796753
Epoch 0 - Batch 600 - Loss: 0.5897069573402405
Epoch 0 - Accuracy: {'accuracy': 0.7964601769911505}
Epoch 1 - Batch 0 - Loss: 0.4925644099712372
Epoch 1 - Batch 200 - Loss: 0.2727748155593872
Epoch 1 - Batch 400 - Loss: 0.6110062003135681
Epoch 1 - Batch 600 - Loss: 0.26555338501930237
Epoch 1 - Accuracy: {'accuracy': 0.8407079646017699}
Epoch 2 - Batch 0 - Loss: 0.3579442501068115
Epoch 2 - Batch 200 - Loss: 0.20856547355651855
Epoch 2 - Batch 400 - Loss: 0.4017004668712616
Epoch 2 - Batch 600 - Loss: 0.1882483810186386
Epoch 2 - Accuracy: {'accuracy': 0.8938053097345132}
Epoch 3 - Batch 0 - Loss: 0.12177985906600952
Epoch 3 - Batch 200 - Loss: 0.2199835479259491
Epoch 3 - Batch 400 - Loss: 0.14725062251091003
Epoch 3 - Batch 600 - 

In [43]:
dataset

Dataset({
    features: ['input', 'label'],
    num_rows: 100
})