## 1.0 Environment Setup

This cell installs all the necessary libraries for our JAX/Flax project. We use the `-q` flag for a quiet installation.

* **`jax[tpu]`**: Installs JAX with specific optimizations for Google's TPUs.
* **`flax`**: A neural network library for JAX.
* **`optax`**: A gradient processing and optimization library for JAX.
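* **`orbax-checkpoint`**: The Orbax checkpointing library (saving model progress), imported as `orbax.checkpoint`.
* **`transformers` / `datasets`**: Hugging Face libraries for models and data handling.
* **`wandb`**: For experiment tracking and logging.
* **`ipywidgets`**: Renders the interactive login widget used by Hugging Face's `notebook_login()`.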

> **Note on Warnings:** After this cell runs, you will likely see a long list of red `ERROR` messages about "dependency conflicts." This is **normal and expected** on Kaggle. It happens because our new libraries conflict with older, pre-installed packages we won't be using (like `torch`). These warnings can be safely ignored.

In [None]:
!pip install -q "jax[tpu]" flax optax orbax-checkpoint transformers datasets wandb ipywidgets

## 1.1 Verify Installation

Here, we import the core libraries we just installed. By printing their versions and checking for available TPU devices, we can confirm that the environment is set up correctly and all dependencies are accessible before proceeding.


> **Note on Output:** The output of this cell might include a `TqdmWarning` or an `INFO` message about a `rocm` backend. These are harmless informational messages and do not indicate a problem. A successful run will print your library versions and a list of available TPU devices.

In [None]:
import jax
import flax
import optax
import orbax.checkpoint
import transformers
import datasets

# Print versions to confirm
print(f"JAX version: {jax.__version__}")
print(f"Flax version: {flax.__version__}")
print(f"Optax version: {optax.__version__}")
print(f"Transformers version: {transformers.__version__}")
print(f"Datasets version: {datasets.__version__}")

# Check for TPU
try:
    print("TPU devices:", jax.devices("tpu"))
except RuntimeError:  # raised by jax.devices() when the TPU backend is unavailable
    print("No TPU devices found. Ensure your notebook accelerator is set to TPU.")

print("\n✅ All key libraries imported successfully!")
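
As a complement to the import check above, installed versions can also be read from package metadata without importing anything, which is useful when an import itself is what fails. The following is a minimal sketch using only the standard library; the `PACKAGES` list and the `installed_versions` helper are illustrative names, not part of the original notebook, and the distribution names are assumed to match what the install cell pulled in.

```python
from importlib import metadata

# Distribution names as published on PyPI (assumed list; adjust to your install cell).
PACKAGES = ["jax", "flax", "optax", "orbax-checkpoint", "transformers", "datasets", "wandb"]


def installed_versions(names):
    """Map each distribution name to its installed version, or None if it is missing."""
    versions = {}
    for name in names:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


report = installed_versions(PACKAGES)
for name, version in report.items():
    print(f"{name:<18} {version or 'NOT INSTALLED'}")
```

Any entry reported as `NOT INSTALLED` points to a package the install cell did not actually provide, which narrows down a failing import before restarting the kernel.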

## 1.2 TPU Initialization and Authentication

Here, we'll verify that JAX can correctly identify and connect to the available TPU hardware. This step also handles authentication by triggering interactive login prompts for both Hugging Face (to access models) and Weights & Biases (for experiment tracking).

> **Note on Output:** This cell first prints the list of available JAX devices, which should show 8 TPU cores. It then displays two login prompts, one after the other. Enter your Hugging Face User Access Token in the first and your Weights & Biases API key in the second to proceed.

In [None]:
import jax
from huggingface_hub import notebook_login
import wandb

# Verify that all 8 TPU cores are visible to JAX
print("Verifying available JAX devices...")
print(jax.devices())

# Use notebook_login() to prompt for a Hugging Face token
print("\nPlease log in to Hugging Face Hub:")
notebook_login()

# Use wandb.login() to prompt for a W&B API key
print("\nPlease log in to Weights & Biases:")
wandb.login()

print("\n✅ Authentication complete.")
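
For runs where nobody is available to answer the prompts, the same logins can be driven from environment variables. This is a sketch under assumptions, not the notebook's own method: it assumes the tokens have already been exported as `HF_TOKEN` and `WANDB_API_KEY`, and `read_secret` and `login_from_env` are hypothetical helper names. It relies on `huggingface_hub.login(token=...)` and `wandb.login(key=...)`, which accept the credential directly.

```python
import os


def read_secret(name):
    """Return the stripped value of an environment variable, or None if unset or blank."""
    value = os.environ.get(name, "").strip()
    return value or None


def login_from_env():
    """Log in to each service whose token is present; return the services used."""
    used = []

    hf_token = read_secret("HF_TOKEN")  # assumed variable name
    if hf_token:
        # Imported lazily so this helper can be defined without the package installed.
        from huggingface_hub import login
        login(token=hf_token)
        used.append("huggingface")

    wandb_key = read_secret("WANDB_API_KEY")  # assumed variable name
    if wandb_key:
        import wandb
        wandb.login(key=wandb_key)
        used.append("wandb")

    return used
```

Calling `login_from_env()` returns the list of services it logged in to, so an empty list is a signal to fall back to the interactive prompts.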

## 1.3 Data Acquisition and Formatting

The first cell below writes a small sample dataset to a local `my_dataset.csv` file. The second cell loads that file, then defines and applies a formatting function that combines the `instruction` and `response` columns into the chat template required by Llama 3. The resulting text is stored in a new `text` column, and the processed dataset is saved to disk for the subsequent tokenization step.

> **Note on Output:** The first cell reports that the CSV file was written. The second prints a confirmation that the dataset has been loaded, displays one complete formatted prompt from the new `text` column, and ends with a success message confirming that the data has been saved to the `./processed_data` directory.

In [None]:
%%writefile my_dataset.csv
instruction,response
"How do I get a good sear on a steak in a stainless steel pan?","To get a great sear, preheat your stainless steel pan over medium-high heat until a drop of water sizzles and glides across the surface. Pat your steak completely dry, season it generously, and then place it in the hot, oiled pan. Don't move it for several minutes to allow a deep brown crust to form before flipping."
"What is the capital of Oklahoma?","The capital of Oklahoma is Oklahoma City."
"Write a simple Python function to add two numbers.","Certainly! Here's a simple Python function to add two numbers:\n\n```python\ndef add_numbers(a, b):\n    return a + b\n```"

In [None]:
from datasets import load_dataset
import os

# Load the dataset from the local CSV file
try:
    # load_dataset returns a DatasetDict, so we select the default 'train' split.
    dataset = load_dataset('csv', data_files='my_dataset.csv')['train']
    print("Dataset loaded successfully:")
    print(dataset)
except FileNotFoundError:
    print("Error: 'my_dataset.csv' not found. Please ensure the file is in the notebook's working directory.")
    raise  # Stop here: the steps below need `dataset` to exist

# Define the formatting function for the Llama 3 chat template
def format_prompt(sample):
    # This creates the structured text string for each sample
    sample['text'] = (
        f"<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n"
        f"{sample['instruction']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
        f"{sample['response']}<|eot_id|>"
    )
    return sample

# Apply the formatting function to the dataset to create the new 'text' column
formatted_dataset = dataset.map(format_prompt)

# Verify the new 'text' column by checking the first example
print("\n--- Sample of a formatted prompt ---")
print(formatted_dataset[0]['text'])

# Create a directory to save the processed data
output_dir = "./processed_data"
os.makedirs(output_dir, exist_ok=True)

# Save the processed dataset to disk for the next step
formatted_dataset.save_to_disk(output_dir)

print(f"\n✅ Formatted dataset saved to '{output_dir}'")
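
One detail worth checking in the formatted output: the third sample row writes its line breaks as `\n`, and CSV has no escape sequences, so a plain CSV parser keeps that as two literal characters (a backslash and an `n`) rather than a newline. The sketch below illustrates this with Python's standard `csv` module as a stand-in for the `datasets` loader, then applies the same template to the cleaned row. `unescape_newlines` is a hypothetical helper, and this `format_prompt` variant returns a new dict instead of modifying its input.

```python
import csv
import io

# Two rows in the same shape as my_dataset.csv. The second response contains
# a literal backslash followed by "n", exactly as it would be typed in the file.
CSV_TEXT = (
    'instruction,response\n'
    '"What is the capital of Oklahoma?","The capital of Oklahoma is Oklahoma City."\n'
    '"Say hi on two lines.","Hi\\nthere"\n'
)


def unescape_newlines(text):
    """Turn the two-character sequence backslash + n into a real line break."""
    return text.replace("\\n", "\n")


def format_prompt(sample):
    """Build the Llama 3 chat string used in the notebook and return a new dict."""
    text = (
        "<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n"
        f"{sample['instruction']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
        f"{sample['response']}<|eot_id|>"
    )
    return {**sample, "text": text}


rows = list(csv.DictReader(io.StringIO(CSV_TEXT)))

# repr() makes the difference visible: an escaped backslash before, a newline after.
print(repr(rows[1]["response"]))
cleaned = {**rows[1], "response": unescape_newlines(rows[1]["response"])}
print(repr(cleaned["response"]))

print(format_prompt(cleaned)["text"])
```

If the loaded dataset shows the same literal sequence when inspected with `repr()`, a cleaning step like this could be applied through `dataset.map` before formatting; whether it is needed depends on how the real data file encodes line breaks.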