## 1.0 Environment Setup

This cell installs all the necessary libraries for our JAX/Flax project. We use the `-q` flag for a quiet installation.

* **`jax[tpu]`**: Installs JAX with specific optimizations for Google's TPUs.
* **`flax`**: A neural network library for JAX.
* **`optax`**: A gradient processing and optimization library for JAX.
* **`orbax`**: A library for checkpointing (saving model progress).
* **`transformers` / `datasets`**: Hugging Face libraries for models and data handling.
* **`wandb`**: For experiment tracking and logging.

> **Note on Warnings:** After this cell runs, you will likely see a long list of red `ERROR` messages about "dependency conflicts." This is **normal and expected** on Kaggle. It happens because our new libraries conflict with older, pre-installed packages we won't be using (like `torch`). These warnings can be safely ignored.

In [None]:
!pip install -q "jax[tpu]" flax optax orbax-checkpointing transformers datasets wandb ipywidgets

## 1.1 Verify Installation

Here, we import the core libraries we just installed. By printing their versions and checking for available TPU devices, we can confirm that the environment is set up correctly and all dependencies are accessible before proceeding.


> **Note on Output:** The output of this cell might include a `TqdmWarning` or an `INFO` message about a `rocm` backend. These are harmless informational messages and do not indicate a problem. A successful run will print your library versions and a list of available TPU devices.

In [2]:
import jax
import flax
import optax
import orbax.checkpoint
import transformers
import datasets

# Print versions to confirm
print(f"JAX version: {jax.__version__}")
print(f"Flax version: {flax.__version__}")
print(f"Optax version: {optax.__version__}")
print(f"Transformers version: {transformers.__version__}")
print(f"Datasets version: {datasets.__version__}")

# Check for TPU
try:
    print("TPU devices:", jax.devices("tpu"))
except:
    print("No TPU devices found. Ensure your notebook accelerator is set to TPU.")

print("\n✅ All key libraries imported successfully!")

  from .autonotebook import tqdm as notebook_tqdm


JAX version: 0.4.34
Flax version: 0.10.4
Optax version: 0.2.5
Transformers version: 4.53.0
Datasets version: 4.0.0


E0000 00:00:1755725894.627427      10 common_lib.cc:612] Could not set metric server port: INVALID_ARGUMENT: Could not find SliceBuilder port 8471 in any of the 0 ports provided in `tpu_process_addresses`="local"
=== Source Location Trace: ===
learning/45eac/tfrc/runtime/common_lib.cc:230


TPU devices: [TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0), TpuDevice(id=1, process_index=0, coords=(0,0,0), core_on_chip=1), TpuDevice(id=2, process_index=0, coords=(1,0,0), core_on_chip=0), TpuDevice(id=3, process_index=0, coords=(1,0,0), core_on_chip=1), TpuDevice(id=4, process_index=0, coords=(0,1,0), core_on_chip=0), TpuDevice(id=5, process_index=0, coords=(0,1,0), core_on_chip=1), TpuDevice(id=6, process_index=0, coords=(1,1,0), core_on_chip=0), TpuDevice(id=7, process_index=0, coords=(1,1,0), core_on_chip=1)]

✅ All key libraries imported successfully!


## 1.2 TPU Initialization and Authentication

Here, we'll verify that JAX can correctly identify and connect to the available TPU hardware. This step also handles authentication by triggering interactive login prompts for both Hugging Face (to access models) and Weights & Biases (for experiment tracking).

**Note on Output:** This cell will first print the list of available JAX devices, which should show 8 TPU cores. It will then display two separate input boxes. You will need to provide your Hugging Face User Access Token and your Weights & Biases API key in these prompts to proceed.

In [3]:
import jax
from huggingface_hub import notebook_login
import wandb

# Verify that all 8 TPU cores are visible to JAX
print("Verifying available JAX devices...")
print(jax.devices())

# Use notebook_login() to prompt for a Hugging Face token
print("\nPlease log in to Hugging Face Hub:")
notebook_login()

# Use wandb.login() to prompt for a W&B API key
print("\nPlease log in to Weights & Biases:")
wandb.login()

Verifying available JAX devices...
[TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0), TpuDevice(id=1, process_index=0, coords=(0,0,0), core_on_chip=1), TpuDevice(id=2, process_index=0, coords=(1,0,0), core_on_chip=0), TpuDevice(id=3, process_index=0, coords=(1,0,0), core_on_chip=1), TpuDevice(id=4, process_index=0, coords=(0,1,0), core_on_chip=0), TpuDevice(id=5, process_index=0, coords=(0,1,0), core_on_chip=1), TpuDevice(id=6, process_index=0, coords=(1,1,0), core_on_chip=0), TpuDevice(id=7, process_index=0, coords=(1,1,0), core_on_chip=1)]

Please log in to Hugging Face Hub:


ImportError: The `notebook_login` function can only be used in a notebook (Jupyter or Colab) and you need the `ipywidgets` module: `pip install ipywidgets`.