<a href="https://colab.research.google.com/github/felafax/felafax/blob/main/notebooks/Llama3_1_on_Free_Colab_TPU.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

To run this notebook on a **free** Google Colab TPU, click "*Runtime*" and then "*Run all*".
<div class="align-center">
  <a href="https://github.com/felafax/felafax"><img src="https://felafax.ai/felafax.svg" width="145"></a> ⭐ <i>Star us on <a href="https://github.com/felafax/felafax">GitHub</a></i> ⭐ and email us at founders@felafax.ai with any questions!
</div>

# Setup

In [11]:
!pip install git+https://github.com/felafax/felafax.git -q
!pip uninstall -y tensorflow && pip install tensorflow-cpu -q

In [12]:
import os
# Hide any GPUs so that libraries do not try to initialize CUDA on the TPU runtime.
os.environ['CUDA_VISIBLE_DEVICES'] = ''

In [13]:
MODEL_NAME = "meta-llama/Llama-3.2-1B"
HF_TOKEN = input("Please enter your HuggingFace token: ")
TRAINER_DIR = "/"
TEST_MODE = False


CHECKPOINT_DIR = os.path.join(TRAINER_DIR, "checkpoints")
EXPORT_DIR = os.path.join(TRAINER_DIR, "finetuned_export")
os.makedirs(CHECKPOINT_DIR, exist_ok=True)
os.makedirs(EXPORT_DIR, exist_ok=True)

Please enter your HuggingFace token:  hf_<redacted>


In [14]:
from felafax.trainer_engine import setup
setup.setup_environment(base_dir=TRAINER_DIR)

import jax
from transformers import AutoTokenizer
from felafax.trainer_engine import checkpoint, trainer, utils
from felafax.trainer_engine.data import data

# Step 0: Configure the training pipeline

In [15]:
dataset_config = data.DatasetConfig(
    data_source="yahma/alpaca-cleaned",
    max_seq_length=32,
    batch_size=8,
    num_workers=4,
    mask_prompt=False,
    train_test_split=0.15,

    ignore_index=-100,
    pad_id=0,
    seed=42,

    # Setting max_examples limits the number of examples in the dataset.
    # This is useful for testing the pipeline without running the entire dataset.
    max_examples=100 if TEST_MODE else None,
)
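
The way `max_seq_length`, `pad_id`, `ignore_index`, and `mask_prompt` interact can be pictured with a small sketch. This is not felafax's implementation: the helper `pad_and_mask` and its `prompt_len` argument are hypothetical, and the sketch only assumes the common convention that label positions set to `ignore_index` are skipped by the loss.

```python
def pad_and_mask(token_ids, prompt_len, max_seq_length=32, pad_id=0,
                 ignore_index=-100, mask_prompt=False):
    """Hypothetical helper: truncate/pad one example and build its labels."""
    # Truncate to the maximum sequence length.
    ids = list(token_ids[:max_seq_length])
    labels = list(ids)
    if mask_prompt:
        # Exclude the prompt tokens from the loss.
        n = min(prompt_len, len(labels))
        labels[:n] = [ignore_index] * n
    # Pad inputs with pad_id and labels with ignore_index,
    # so that padding positions never contribute to the loss.
    n_pad = max_seq_length - len(ids)
    ids += [pad_id] * n_pad
    labels += [ignore_index] * n_pad
    return ids, labels


ids, labels = pad_and_mask([5, 6, 7, 8], prompt_len=2,
                           max_seq_length=6, mask_prompt=True)
print(ids)     # [5, 6, 7, 8, 0, 0]
print(labels)  # [-100, -100, 7, 8, -100, -100]
```

Under this scheme, `max_seq_length=32` would truncate any example longer than 32 tokens, which keeps the demo fast but discards most of a long response.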


In [16]:
trainer_config = trainer.TrainerConfig(
    model_name=MODEL_NAME,
    param_dtype="bfloat16",
    compute_dtype="bfloat16",

    # Training configuration
    num_epochs=1,
    num_steps=50,
    use_lora=True,
    lora_rank=16,
    learning_rate=1e-3,
    log_interval=1,

    num_tpus=jax.device_count(),

    # Eval configuration
    eval_interval=50,
    eval_steps=5,

    # Additional info required by trainer
    base_dir=TRAINER_DIR,
    hf_token=HF_TOKEN,
)
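
To get a feel for what `lora_rank=16` implies, the arithmetic below counts the trainable parameters that LoRA adds to a single weight matrix: a rank-`r` adapter on a `d_in x d_out` matrix adds two low-rank factors with `r * (d_in + d_out)` parameters in total. The 2048 x 2048 shape is an illustrative assumption, not a value read from felafax or the model config, and this notebook does not state which layers felafax actually adapts.

```python
def lora_param_count(d_in, d_out, rank):
    """Parameters in the two low-rank factors A (d_in x rank) and B (rank x d_out)."""
    return rank * (d_in + d_out)


d_in, d_out, rank = 2048, 2048, 16  # illustrative shape (assumption)
full = d_in * d_out                 # parameters in the frozen weight matrix
lora = lora_param_count(d_in, d_out, rank)

print(full)                  # 4194304
print(lora)                  # 65536
print(f"{lora / full:.2%}")  # 1.56%
```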



In [17]:
checkpointer_config = checkpoint.CheckpointerConfig(
    checkpoint_dir=CHECKPOINT_DIR,
    max_to_keep=2,
    save_interval_steps=50,
    erase_existing_checkpoints=True,
)
checkpointer = checkpoint.Checkpointer(config=checkpointer_config)
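
The sketch below illustrates one plausible reading of `save_interval_steps` and `max_to_keep`: save every N steps and retain only the newest checkpoints. The function `kept_checkpoints` is hypothetical, and the 1-based step counting is an assumption; felafax's checkpointer may count steps or prune differently.

```python
def kept_checkpoints(num_steps, save_interval_steps, max_to_keep):
    """Hypothetical model: save every `save_interval_steps`, keep the newest `max_to_keep`."""
    saved = [s for s in range(1, num_steps + 1) if s % save_interval_steps == 0]
    return saved[-max_to_keep:]


print(kept_checkpoints(num_steps=50, save_interval_steps=50, max_to_keep=2))   # [50]
print(kept_checkpoints(num_steps=200, save_interval_steps=50, max_to_keep=2))  # [150, 200]
```

Under these assumptions, the settings in this notebook (50 training steps, a save every 50 steps) would leave a single checkpoint at the end of the run.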

# Step 1: Download the dataset

For this Colab, we use the cleaned **Alpaca dataset** published by yahma (`yahma/alpaca-cleaned`), a filtered version of the original 52,000-example Alpaca collection. Feel free to replace this section with your own data preparation code.

Make sure every tokenized training example ends with the EOS (end-of-sequence) token. Without it, the fine-tuned model may never learn where a response ends and can keep generating indefinitely.
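
If you bring your own data preparation code, appending the EOS token might look like the minimal sketch below. The prompt template and the `format_example` helper are hypothetical (felafax's `SFTDataset` may format examples differently), and the `"<|end_of_text|>"` string is only a stand-in; in practice, read the real value from `tokenizer.eos_token`.

```python
def format_example(example, eos_token):
    """Hypothetical Alpaca-style formatter that always ends with the EOS token."""
    prompt = f"### Instruction:\n{example['instruction']}\n\n"
    if example.get("input"):
        # The optional input field is only included when it is non-empty.
        prompt += f"### Input:\n{example['input']}\n\n"
    prompt += "### Response:\n"
    # The trailing EOS token marks where the response ends.
    return prompt + example["output"] + eos_token


EOS_TOKEN = "<|end_of_text|>"  # stand-in; use tokenizer.eos_token in practice
text = format_example(
    {"instruction": "Add 2 and 3.", "input": "", "output": "5"},
    EOS_TOKEN,
)
print(text.endswith(EOS_TOKEN))  # True
```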

In [18]:
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, token=HF_TOKEN)

# Download and load the data files
train_data, val_data = data.load_data(config=dataset_config)

# Create datasets for SFT (supervised fine-tuning)
train_dataset = data.SFTDataset(
    config=dataset_config,
    data=train_data,
    tokenizer=tokenizer,
)
val_dataset = data.SFTDataset(
    config=dataset_config,
    data=val_data,
    tokenizer=tokenizer,
)

# Create dataloaders
train_dataloader = data.create_dataloader(
    config=dataset_config,
    dataset=train_dataset,
    shuffle=True,
)
val_dataloader = data.create_dataloader(
    config=dataset_config,
    dataset=val_dataset,
    shuffle=False,
)


# Step 2: Create Trainer and load the model

In [19]:
llama_trainer = trainer.Trainer(
    trainer_config=trainer_config,
    train_dataloader=train_dataloader,
    val_dataloader=val_dataloader,
    checkpointer=checkpointer,
)

Creating TPU device mesh with shape (1, 2, 2)...
Loading model from HuggingFace...


In [20]:
llama_trainer.train()

Started epoch 1 of 1...
Step 0 | Train Loss: 0.0000 | Val Loss: 0.0000 | 


UnexpectedTracerError: Encountered an unexpected tracer. A function transformed by JAX had a side effect, allowing for a reference to an intermediate value with type float32[2048] wrapped in a DynamicJaxprTracer to escape the scope of the transformation.
JAX transformations require that functions explicitly return their outputs, and disallow saving intermediate values to global state.
The function being traced when the value leaked was f at /usr/local/lib/python3.10/site-packages/felafax/trainer_engine/models/llama3/jax/model.py:625 traced for scan.
------------------------------
The leaked intermediate value was created on line /tmp/ipykernel_99/1170177407.py:1 (<module>). 
------------------------------
When the value was created, the final 5 stack frames (most recent last) excluding JAX-internal frames were:
------------------------------
/tmp/ipykernel_99/1170177407.py:1 (<module>)
------------------------------

To catch the leak earlier, try setting the environment variable JAX_CHECK_TRACER_LEAKS or using the `jax.checking_leaks` context manager.
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.UnexpectedTracerError

# Step 3: Export fine-tuned model

In [None]:
llama_trainer.export(export_dir=EXPORT_DIR)

In [None]:
utils.upload_dir_to_hf(
    dir_path=EXPORT_DIR,
    repo_name="felarof01/test-llama3-alpaca-from-colab",
    token=HF_TOKEN,
)