## 1.0 Environment Setup

This cell installs all the necessary libraries for our JAX/Flax project. We use the `-q` flag for a quiet installation.

* **`jax[tpu]`**: Installs JAX with specific optimizations for Google's TPUs.
* **`flax`**: A neural network library for JAX.
* **`optax`**: A gradient processing and optimization library for JAX.
* **`orbax`**: A library for checkpointing (saving model progress).
* **`transformers` / `datasets`**: Hugging Face libraries for models and data handling.
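* **`wandb`**: For experiment tracking and logging.
* **`ipywidgets`**: Renders interactive notebook widgets, such as the Hugging Face login box used in step 1.2.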

> **Note on Warnings:** After this cell runs, you will likely see a long list of red `ERROR` messages about "dependency conflicts." This is **normal and expected** on Kaggle. It happens because our new libraries conflict with older, pre-installed packages we won't be using (like `torch`). These warnings can be safely ignored.
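If you want a check that does not depend on pip's noisy resolver output, a minimal sketch (an optional addition, not part of the original notebook) can query installed distribution versions with the standard library's `importlib.metadata`. The helper name `installed_versions` is hypothetical, and the distribution names are assumed to match those in the `pip install` command below.

```python
from importlib.metadata import PackageNotFoundError, version


def installed_versions(dist_names):
    """Return {distribution name: version string, or None if not installed}."""
    result = {}
    for name in dist_names:
        try:
            result[name] = version(name)
        except PackageNotFoundError:
            # The distribution is not installed in this environment
            result[name] = None
    return result


required = [
    "jax", "flax", "optax",
    "orbax-checkpoint",  # distribution that provides the orbax.checkpoint module
    "transformers", "datasets", "wandb", "ipywidgets",
]
for name, ver in installed_versions(required).items():
    print(f"{name:>16}: {ver or 'NOT INSTALLED'}")
```

A `None` entry points to a genuinely missing package, whereas the resolver `ERROR` lines only report version mismatches with packages this project does not use.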

In [1]:
!pip install -q "jax[tpu]" flax optax orbax transformers datasets wandb ipywidgets

[notice] A new release of pip is available: 23.0.1 -> 25.2
[notice] To update, run: pip install --upgrade pip


## 1.1 Verify Installation

Here, we import the core libraries we just installed. By printing their versions and checking for available TPU devices, we can confirm that the environment is set up correctly and all dependencies are accessible before proceeding.


> **Note on Output:** The output of this cell might include a `TqdmWarning`, an `INFO` message about a `rocm` backend, or an `E0000 ... Could not set metric server port` log line from the TPU runtime. These messages are harmless and do not indicate a problem. A successful run prints your library versions and a list of available TPU devices.
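For a check that fails loudly instead of only printing a list, a minimal sketch can group devices by platform. It relies only on the `platform` attribute that JAX device objects expose; the helpers `summarize_devices` and `require_devices`, and the default of 8 expected devices (taken from the device list printed in this notebook), are assumptions for illustration.

```python
from collections import Counter


def summarize_devices(devices):
    """Count devices per platform string, e.g. {'tpu': 8}; needs only a .platform attribute."""
    return dict(Counter(d.platform for d in devices))


def require_devices(devices, platform="tpu", expected=8):
    """Raise RuntimeError if fewer than `expected` devices of `platform` are visible."""
    counts = summarize_devices(devices)
    found = counts.get(platform, 0)
    if found < expected:
        raise RuntimeError(
            f"Expected at least {expected} '{platform}' devices, found {found}: {counts}"
        )
    return counts
```

In the notebook this would be called as `require_devices(jax.devices())`. If no TPU backend is available, `jax.devices()` falls back to the default backend (typically CPU), so the check raises rather than letting later cells run on the wrong hardware.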

In [2]:
import jax
import flax
import optax
import orbax.checkpoint
import transformers
import datasets

# Print versions to confirm
print(f"JAX version: {jax.__version__}")
print(f"Flax version: {flax.__version__}")
print(f"Optax version: {optax.__version__}")
print(f"Transformers version: {transformers.__version__}")
print(f"Datasets version: {datasets.__version__}")

# Check for TPU devices; jax.devices("tpu") raises RuntimeError when no TPU backend is available
try:
    print("TPU devices:", jax.devices("tpu"))
except RuntimeError:
    print("No TPU devices found. Ensure your notebook accelerator is set to TPU.")

print("\n✅ All key libraries imported successfully!")

JAX version: 0.4.34
Flax version: 0.10.4
Optax version: 0.2.5
Transformers version: 4.53.0
Datasets version: 4.0.0


E0000 00:00:1755900947.803982      10 common_lib.cc:612] Could not set metric server port: INVALID_ARGUMENT: Could not find SliceBuilder port 8471 in any of the 0 ports provided in `tpu_process_addresses`="local"
=== Source Location Trace: === 
learning/45eac/tfrc/runtime/common_lib.cc:230


TPU devices: [TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0), TpuDevice(id=1, process_index=0, coords=(0,0,0), core_on_chip=1), TpuDevice(id=2, process_index=0, coords=(1,0,0), core_on_chip=0), TpuDevice(id=3, process_index=0, coords=(1,0,0), core_on_chip=1), TpuDevice(id=4, process_index=0, coords=(0,1,0), core_on_chip=0), TpuDevice(id=5, process_index=0, coords=(0,1,0), core_on_chip=1), TpuDevice(id=6, process_index=0, coords=(1,1,0), core_on_chip=0), TpuDevice(id=7, process_index=0, coords=(1,1,0), core_on_chip=1)]

✅ All key libraries imported successfully!


## 1.2 TPU Initialization and Authentication

Here, we'll verify that JAX can correctly identify and connect to the available TPU hardware. This step also handles authentication by triggering interactive login prompts for both Hugging Face (to access models) and Weights & Biases (for experiment tracking).

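**Note on Output:** This cell first prints the list of available JAX devices, which should show 8 TPU cores. It then displays two login prompts: a Hugging Face widget and a Weights & Biases text prompt. Enter your Hugging Face User Access Token and your Weights & Biases API key to proceed. Logging in to Hugging Face is necessary but not sufficient for the gated Llama model used later: your account must also have been granted access on that model's Hub page.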
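On a scheduled or otherwise non-interactive run, the login prompts cannot be answered. A minimal sketch, assuming the credentials are stored in environment variables named `HF_TOKEN` and `WANDB_API_KEY` (names that `huggingface_hub` and `wandb` also read on their own), reads them up front and fails with a clear message if either is missing. The helper `read_secret` is hypothetical; on Kaggle you could populate these variables from the notebook's Secrets add-on.

```python
import os


def read_secret(env_var):
    """Return the stripped value of env_var, raising a descriptive error if it is unset or empty."""
    value = os.environ.get(env_var, "").strip()
    if not value:
        raise RuntimeError(
            f"{env_var} is not set; export it or add it as a notebook secret before logging in."
        )
    return value
```

With the values available, `huggingface_hub.login(token=read_secret("HF_TOKEN"))` and `wandb.login(key=read_secret("WANDB_API_KEY"))` replace the interactive prompts.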

In [3]:
import jax
from huggingface_hub import notebook_login
import wandb

# Verify that all 8 TPU cores are visible to JAX
print("Verifying available JAX devices...")
print(jax.devices())

# Use notebook_login() to prompt for a Hugging Face token
print("\nPlease log in to Hugging Face Hub:")
notebook_login()

# Use wandb.login() to prompt for a W&B API key
print("\nPlease log in to Weights & Biases:")
wandb.login()

print("\n✅ Authentication complete.")

Verifying available JAX devices...
[TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0), TpuDevice(id=1, process_index=0, coords=(0,0,0), core_on_chip=1), TpuDevice(id=2, process_index=0, coords=(1,0,0), core_on_chip=0), TpuDevice(id=3, process_index=0, coords=(1,0,0), core_on_chip=1), TpuDevice(id=4, process_index=0, coords=(0,1,0), core_on_chip=0), TpuDevice(id=5, process_index=0, coords=(0,1,0), core_on_chip=1), TpuDevice(id=6, process_index=0, coords=(1,1,0), core_on_chip=0), TpuDevice(id=7, process_index=0, coords=(1,1,0), core_on_chip=1)]

Please log in to Hugging Face Hub:




Please log in to Weights & Biases:


[34m[1mwandb[0m: Logging into wandb.ai. (Learn how to deploy a W&B server locally: https://wandb.me/wandb-server)
[34m[1mwandb[0m: You can find your API key in your browser here: https://wandb.ai/authorize
[34m[1mwandb[0m: Paste an API key from your profile and hit enter:

  ········


[34m[1mwandb[0m: No netrc file found, creating one.
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc
[34m[1mwandb[0m: Currently logged in as: [33mmarkdonaho[0m ([33mmarkdonaho-vatical-investment-group[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin



✅ Authentication complete.


## 1.3 Data Acquisition and Formatting

This cell loads the custom dataset from a local `my_dataset.csv` file. It then defines and applies a formatting function to transform the `instruction` and `response` columns into the specific chat template required by Llama 3. The resulting structured text is stored in a new `text` column, and the entire processed dataset is saved to disk for the subsequent tokenization step.

**Note on Output:** The first cell writes the sample CSV file. The second prints a confirmation that the dataset has been loaded, displays one complete example of a formatted prompt from the new `text` column, and ends with a success message confirming that the data has been saved to the `./processed_data` directory.
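The `format_prompt` function used in this step handles a single user/assistant exchange. If the data later grows to include a system prompt or several turns, a generalized formatter could look like the minimal sketch below. It assumes the same Llama 3 header and `<|eot_id|>` layout that `format_prompt` uses, and the helper name `format_conversation` is hypothetical. For a single user/assistant pair it produces exactly the same string as `format_prompt`.

```python
ALLOWED_ROLES = {"system", "user", "assistant"}


def format_conversation(messages):
    """Render [{'role': ..., 'content': ...}, ...] using the Llama 3 header/eot layout."""
    parts = ["<|begin_of_text|>"]
    for msg in messages:
        if msg["role"] not in ALLOWED_ROLES:
            raise ValueError(f"Unexpected role: {msg['role']!r}")
        # Each turn: role header, blank line, content, end-of-turn marker
        parts.append(
            f"<|start_header_id|>{msg['role']}<|end_header_id|>\n\n{msg['content']}<|eot_id|>"
        )
    return "".join(parts)
```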

In [4]:
%%writefile my_dataset.csv
instruction,response
"How do I get a good sear on a steak in a stainless steel pan?","To get a great sear, preheat your stainless steel pan over medium-high heat until a drop of water sizzles and glides across the surface. Pat your steak completely dry, season it generously, and then place it in the hot, oiled pan. Don't move it for several minutes to allow a deep brown crust to form before flipping."
"What is the capital of Oklahoma?","The capital of Oklahoma is Oklahoma City."
"Write a simple Python function to add two numbers.","Certainly! Here's a simple Python function to add two numbers:\n\n```python\ndef add_numbers(a, b):\n    return a + b\n```"

Writing my_dataset.csv
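One detail of the CSV format matters for the third row above: inside a quoted field, `\n` is stored as two literal characters (a backslash and an `n`), not as a line break, and CSV parsers such as the one behind `load_dataset` do not translate it. A real line break has to appear as an actual newline inside the quotes. The sketch below uses Python's standard `csv` module (not `datasets`) to show the difference; the sample strings are illustrative.

```python
import csv
import io

# One field holds the two characters "\" and "n"; the other holds a real newline
rows = [
    {"instruction": "escaped", "response": "line one\\nline two"},
    {"instruction": "real", "response": "line one\nline two"},
]

# Write to an in-memory CSV; the writer quotes the field that contains a newline
buffer = io.StringIO(newline="")
writer = csv.DictWriter(buffer, fieldnames=["instruction", "response"])
writer.writeheader()
writer.writerows(rows)

# Read the CSV back and check which responses contain a real line break
buffer.seek(0)
parsed = list(csv.DictReader(buffer))
for row in parsed:
    print(row["instruction"], "->", "newline" if "\n" in row["response"] else "no newline")
# Prints:
# escaped -> no newline
# real -> newline
```

Applied to `my_dataset.csv`, this means the third response contains literal `\n` sequences rather than line breaks. If real line breaks are wanted in the training text, one option is to write them as actual newlines inside the quoted field, or to generate the CSV programmatically as above.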


In [5]:
from datasets import load_dataset
import os

# Load the dataset from the local CSV file
try:
    # load_dataset returns a DatasetDict; a single CSV file is placed in the default 'train' split.
    dataset = load_dataset('csv', data_files='my_dataset.csv')['train']
    print("Dataset loaded successfully:")
    print(dataset)
except FileNotFoundError:
    print("Error: 'my_dataset.csv' not found. Please ensure the file is in the notebook's working directory.")
    # Stop execution; the rest of this cell needs `dataset`
    raise

# Define the formatting function for the Llama 3 chat template
def format_prompt(sample):
    # This creates the structured text string for each sample
    sample['text'] = (
        f"<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n"
        f"{sample['instruction']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
        f"{sample['response']}<|eot_id|>"
    )
    return sample

# Apply the formatting function to the dataset to create the new 'text' column
formatted_dataset = dataset.map(format_prompt)

# Verify the new 'text' column by checking the first example
print("\n--- Sample of a formatted prompt ---")
print(formatted_dataset[0]['text'])

# Create a directory to save the processed data
output_dir = "./processed_data"
os.makedirs(output_dir, exist_ok=True)

# Save the processed dataset to disk for the next step
formatted_dataset.save_to_disk(output_dir)

print(f"\n✅ Formatted dataset saved to '{output_dir}'")


Dataset loaded successfully:
Dataset({
    features: ['instruction', 'response'],
    num_rows: 3
})




--- Sample of a formatted prompt ---
<|begin_of_text|><|start_header_id|>user<|end_header_id|>

How do I get a good sear on a steak in a stainless steel pan?<|eot_id|><|start_header_id|>assistant<|end_header_id|>

To get a great sear, preheat your stainless steel pan over medium-high heat until a drop of water sizzles and glides across the surface. Pat your steak completely dry, season it generously, and then place it in the hot, oiled pan. Don't move it for several minutes to allow a deep brown crust to form before flipping.<|eot_id|>




✅ Formatted dataset saved to './processed_data'


## 1.4 Tokenization and Data Processing
This cell handles the crucial step of **tokenization**. We load the pre-formatted text data saved earlier and use the official Llama 3.1 tokenizer to convert the human-readable strings into numerical token IDs that the model can process. We set a maximum sequence length of `512` tokens, truncating longer examples and padding shorter ones, so that every input has the same shape. Finally, we remove the original text columns, leaving only the tokenized data, and save it to a new directory.

**Note on Output:** This cell will print messages confirming that the tokenizer and dataset have been loaded. After processing, it will display the features of the new tokenized dataset (which should include `input_ids` and `attention_mask`) and a final message confirming that the data has been saved to the `./tokenized_data` directory.
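To make the effect of `truncation=True`, `max_length`, and `padding="max_length"` concrete, here is a pure-Python sketch of the same transformation on a list of token IDs. It assumes right-side padding and uses made-up token IDs and a made-up pad ID; the real tokenizer does this internally, and which side it pads is controlled by its `padding_side` attribute. The helper `pad_or_truncate` is hypothetical.

```python
def pad_or_truncate(token_ids, max_length, pad_id):
    """Mimic truncation plus right padding to a fixed length; return (ids, attention_mask)."""
    ids = list(token_ids[:max_length])  # drop tokens beyond max_length
    mask = [1] * len(ids)               # 1 marks real tokens
    padding = max_length - len(ids)
    ids += [pad_id] * padding           # fill the remainder with the pad token
    mask += [0] * padding               # 0 marks padding the model should ignore
    return ids, mask


short_ids, short_mask = pad_or_truncate([101, 102, 103], max_length=6, pad_id=0)
long_ids, long_mask = pad_or_truncate(list(range(1, 10)), max_length=6, pad_id=0)
print(short_ids, short_mask)  # → [101, 102, 103, 0, 0, 0] [1, 1, 1, 0, 0, 0]
print(long_ids, long_mask)    # → [1, 2, 3, 4, 5, 6] [1, 1, 1, 1, 1, 1]
```

Because this notebook reuses the EOS token as the pad token, a padded position and a genuine EOS share the same ID; only the attention mask tells them apart.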

In [6]:
from transformers import AutoTokenizer
from datasets import load_from_disk
import os

# Define the model ID for Llama 3.1 8B (base model)
model_id = "meta-llama/Llama-3.1-8B"

# --- 1. Load the processed dataset from disk ---
try:
    formatted_dataset = load_from_disk("./processed_data")
    print("✅ Successfully loaded formatted dataset from './processed_data'")
except FileNotFoundError:
    print("❌ Error: Could not find './processed_data'. Please run the previous cell (1.3) to generate it.")
    # Stop execution if the data isn't there
    raise

# --- 2. Load the tokenizer from Hugging Face Hub ---
# notebook_login() in cell 1.2 supplies the token for gated models; the account must also have been granted access to this repo on the Hub.
tokenizer = AutoTokenizer.from_pretrained(model_id)
print(f"✅ Tokenizer for '{model_id}' loaded successfully.")

# --- 3. Set the padding token ---
# Llama 3 does not define a dedicated padding token.
# A common practice is to reuse the end-of-sequence (EOS) token for padding;
# the attention_mask (0 for padding) distinguishes padded positions from real EOS tokens.
tokenizer.pad_token = tokenizer.eos_token
print(f"Tokenizer padding token set to EOS token: '{tokenizer.pad_token}' (ID: {tokenizer.eos_token_id})")

# --- 4. Define the tokenization function ---
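def tokenize_function(examples):
    # The tokenizer converts text to 'input_ids' and generates an 'attention_mask'.
    # We pad to max_length so every sequence has the same shape; XLA compiles a
    # separate program per input shape, so fixed shapes avoid recompilation on TPUs.
    # The formatted text already starts with <|begin_of_text|>, so we disable the
    # tokenizer's automatic special tokens to avoid a duplicated BOS token.
    return tokenizer(
        examples["text"],
        truncation=True,
        max_length=512,
        padding="max_length",
        add_special_tokens=False,
    )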

# --- 5. Apply the tokenization function to the dataset ---
print("\nTokenizing dataset... this may take a moment.")
tokenized_dataset = formatted_dataset.map(
    tokenize_function,
    batched=True,
    # Remove the original text columns as they are no longer needed
    remove_columns=["instruction", "response", "text"]
)
print("✅ Tokenization complete.")
print("\nFeatures of the tokenized dataset:")
print(tokenized_dataset)

# --- 6. Save the final tokenized dataset to disk ---
tokenized_output_dir = "./tokenized_data"
os.makedirs(tokenized_output_dir, exist_ok=True)
tokenized_dataset.save_to_disk(tokenized_output_dir)
print(f"\n✅ Tokenized dataset saved to '{tokenized_output_dir}'")



✅ Successfully loaded formatted dataset from './processed_data'


OSError: You are trying to access a gated repo.
Make sure to have access to it at https://huggingface.co/meta-llama/Llama-3.1-8B.
401 Client Error. (Request ID: Root=1-68a8ec53-18c4f0b77a0c784a7c0e8003;0c1a9ba2-bb8a-415d-86e4-0b6c523ae601)

Cannot access gated repo for url https://huggingface.co/meta-llama/Llama-3.1-8B/resolve/main/config.json.
Access to model meta-llama/Llama-3.1-8B is restricted. You must have access to it and be authenticated to access it. Please log in.

## Module 1 Verification
This final cell for Module 1 acts as a quick sanity check. It loads the final tokenized data and decodes the first sample (`input_ids`) back into text. This allows you to visually inspect the result and confirm that the entire data pipeline—from CSV loading to formatting and tokenization—has executed successfully.

**Note on Output:** The output should be the full, formatted text of the first entry in your dataset (the one about searing a steak), complete with all the special Llama 3 control tokens, plus a long run of repeated EOS padding tokens, since every sample was padded to 512 tokens. This confirms the tokenizer is working correctly.
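To inspect only the real text and check for a duplicated start token, a small sketch can use the `attention_mask` column saved alongside `input_ids` in `./tokenized_data`. The helpers `strip_padding` and `count_leading` are hypothetical. Llama 3 tokenizers typically prepend `<|begin_of_text|>` automatically unless `add_special_tokens=False` is passed, so a leading count of 2 would mean both the formatted text's copy and the tokenizer's copy were kept.

```python
def strip_padding(input_ids, attention_mask):
    """Keep only token IDs whose attention-mask entry is 1 (works for left or right padding)."""
    if len(input_ids) != len(attention_mask):
        raise ValueError("input_ids and attention_mask must have the same length")
    return [tok for tok, keep in zip(input_ids, attention_mask) if keep == 1]


def count_leading(input_ids, token_id):
    """Count how many times token_id repeats at the very start of the sequence."""
    count = 0
    for tok in input_ids:
        if tok != token_id:
            break
        count += 1
    return count
```

In the verification cell this could be used as `real_ids = strip_padding(final_dataset[0]["input_ids"], final_dataset[0]["attention_mask"])`, followed by `tokenizer.decode(real_ids, skip_special_tokens=False)` and `count_leading(real_ids, tokenizer.bos_token_id)`, which should be `1`.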

In [None]:
from datasets import load_from_disk
from transformers import AutoTokenizer

# Load the same tokenizer used in cell 1.4 (Llama 3.1 8B)
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.1-8B")

# Load the final tokenized data from disk
final_dataset = load_from_disk("./tokenized_data")

print("✅ Module 1 Verification: Loaded final dataset.")
print("\n--- Decoding the first sample to verify integrity ---")

# Retrieve the token IDs for the first sample
first_sample_tokens = final_dataset[0]['input_ids']

# Use the tokenizer to decode the token IDs back into a string
decoded_text = tokenizer.decode(first_sample_tokens, skip_special_tokens=False)

print(decoded_text)
print("\n\n✅ If you see the formatted Llama 3 prompt above, Module 1 is complete and successful!")