To run this on a **free** Google Colab TPU, open the "*Runtime*" menu and click "*Run all*"!
<div class="align-center">
  <a href="https://github.com/felafax/felafax"><img src="https://felafax.ai/felafax.svg" width="145"></a> ⭐ <i>Star us on <a href="https://github.com/felafax/felafax">GitHub</a></i> ⭐ and email us at founders@felafax.ai with any questions!
</div>

# Setup

In [1]:
!pip install --upgrade pip
!pip install git+https://github.com/felafax/felafax.git -q
!pip install --upgrade jax jaxlib
!pip install --upgrade "jax[tpu]>=0.4.38" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
!pip uninstall -y tensorflow && pip install tensorflow-cpu -q

Collecting pip
  Downloading pip-24.3.1-py3-none-any.whl.metadata (3.7 kB)
Downloading pip-24.3.1-py3-none-any.whl (1.8 MB)
   ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 1.8/1.8 MB 20.1 MB/s eta 0:00:00
Installing collected packages: pip
  Attempting uninstall: pip
    Found existing installation: pip 24.1.2
    Uninstalling pip-24.1.2:
      Successfully uninstalled pip-24.1.2
Successfully installed pip-24.3.1
  Installing build dependencies ... done
  Getting requirements to build wheel ... done
  Preparing metadata (pyproject.toml) ... done
   ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 2.3/2.3 MB 48.3 MB/s eta 0:00:00
   ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 102.0/102.0 MB 78.9 MB/s eta 0:00:00
   ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 12.4/12.4 MB 183.9 MB/s eta 0:00:00

In [2]:
import os
# Hide any CUDA GPUs so libraries such as PyTorch don't try to use them; training runs on the TPU.
os.environ['CUDA_VISIBLE_DEVICES'] = ''

In [3]:
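import getpass

MODEL_NAME = "meta-llama/Llama-3.2-1B"
# Llama 3.2 is a gated model: the token's account must have accepted Meta's license on Hugging Face.
# getpass hides the token as you type, so it isn't stored in the notebook output.
HF_TOKEN = getpass.getpass("Please enter your HuggingFace token: ")
TRAINER_DIR = "/"  # base directory for checkpoints and exported weights
# Set to True to run on a small subset of the data (see max_examples below).
TEST_MODE = False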


CHECKPOINT_DIR = os.path.join(TRAINER_DIR, "checkpoints")
EXPORT_DIR = os.path.join(TRAINER_DIR, "finetuned_export")
os.makedirs(CHECKPOINT_DIR, exist_ok=True)
os.makedirs(EXPORT_DIR, exist_ok=True)

Please enter your HuggingFace token: ··········


In [4]:
from felafax.trainer_engine import setup
setup.setup_environment(base_dir=TRAINER_DIR)

import jax
from transformers import AutoTokenizer
from felafax.trainer_engine import checkpoint, trainer, utils
from felafax.trainer_engine.data import data

# Step 0: Configure the different parts of the training pipeline

In [5]:
dataset_config = data.DatasetConfig(
    data_source="yahma/alpaca-cleaned",
    max_seq_length=32,
    batch_size=8,
    num_workers=4,
    mask_prompt=False,
    train_test_split=0.15,

    ignore_index=-100,
    pad_id=0,
    seed=42,

    # Setting max_examples limits the number of examples in the dataset.
    # This is useful for testing the pipeline without running the entire dataset.
    max_examples=100 if TEST_MODE else None,
)
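The `max_seq_length`, `pad_id`, `ignore_index`, and `mask_prompt` fields together decide how each example becomes a fixed-length training row. The sketch below illustrates that common supervised fine-tuning convention; it is not Felafax's actual `SFTDataset` code, and the helper name `build_sft_row` is hypothetical. It truncates to `max_seq_length`, pads `input_ids` with `pad_id`, and sets labels to `ignore_index` for padding (and for prompt tokens when `mask_prompt=True`) so those positions do not contribute to the loss. Shifting labels for next-token prediction is assumed to happen later, in the model or loss.

```python
from typing import List, Tuple


def build_sft_row(
    prompt_ids: List[int],
    response_ids: List[int],
    max_seq_length: int = 32,
    pad_id: int = 0,
    ignore_index: int = -100,
    mask_prompt: bool = False,
) -> Tuple[List[int], List[int]]:
    """Hypothetical sketch: turn one tokenized (prompt, response) pair into
    fixed-length input_ids and labels."""
    # Concatenate and truncate to the configured sequence length.
    input_ids = (prompt_ids + response_ids)[:max_seq_length]
    labels = list(input_ids)

    if mask_prompt:
        # Hide prompt tokens from the loss so only the response is learned.
        n_prompt = min(len(prompt_ids), len(labels))
        labels[:n_prompt] = [ignore_index] * n_prompt

    # Padding positions get pad_id as input and ignore_index as label.
    n_pad = max_seq_length - len(input_ids)
    input_ids = input_ids + [pad_id] * n_pad
    labels = labels + [ignore_index] * n_pad
    return input_ids, labels


ids, labels = build_sft_row([5, 6, 7], [8, 9], max_seq_length=8, mask_prompt=True)
print(ids)     # [5, 6, 7, 8, 9, 0, 0, 0]
print(labels)  # [-100, -100, -100, 8, 9, -100, -100, -100]
```

With `mask_prompt=False`, as in this notebook, the prompt tokens presumably stay in the loss as well. Also note that `max_seq_length=32` is very short for Alpaca, so many examples will likely be cut off; that keeps this demo fast, but you would raise it for a real run.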


In [6]:
trainer_config = trainer.TrainerConfig(
    model_name=MODEL_NAME,
    param_dtype="bfloat16",
    compute_dtype="bfloat16",

    # Training configuration
    num_epochs=1,
    num_steps=50,
    use_lora=True,
    lora_rank=16,
    learning_rate=1e-3,
    log_interval=1,

    num_tpus=jax.device_count(),

    # Eval configuration
    eval_interval=10,
    eval_steps=5,

    # Additional info required by trainer
    base_dir=TRAINER_DIR,
    hf_token=HF_TOKEN,
)
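`use_lora=True` with `lora_rank=16` means that, instead of updating a full weight matrix W of shape d_out × d_in, training learns two small factors B (d_out × r) and A (r × d_in) whose product BA is added to the frozen W (usually with a scaling factor). The arithmetic below shows why this is cheap for one projection, assuming a square 2048 × 2048 matrix (2048 is the hidden size in Llama-3.2-1B's published config). Which layers Felafax actually adapts is not stated here, so the whole-model total will differ.

```python
def lora_param_counts(d_out: int, d_in: int, rank: int) -> tuple:
    """Compare trainable parameters: full fine-tuning vs. a rank-`rank` LoRA adapter."""
    full = d_out * d_in            # every entry of W is trainable
    lora = rank * (d_in + d_out)   # B is d_out x r, A is r x d_in
    return full, lora, lora / full


full, lora, frac = lora_param_counts(2048, 2048, rank=16)
print(f"full: {full:,}  lora: {lora:,}  fraction: {frac:.2%}")
# full: 4,194,304  lora: 65,536  fraction: 1.56%
```

Because the adapter grows linearly with the rank while the full matrix grows with d_out × d_in, doubling `lora_rank` doubles the adapter size but keeps it a small fraction of the original weights.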



In [7]:
checkpointer_config = checkpoint.CheckpointerConfig(
    checkpoint_dir=CHECKPOINT_DIR,
    max_to_keep=2,
    save_interval_steps=50,
    erase_existing_checkpoints=True,
)
checkpointer = checkpoint.Checkpointer(config=checkpointer_config)
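`save_interval_steps` and `max_to_keep` describe a common rotation policy: save every N steps and delete the oldest checkpoint once more than `max_to_keep` exist. The plain-Python simulation below shows which steps would survive under that policy. It is an assumption about how the two options interact, not Felafax's `Checkpointer` code, which may count steps from 0 or also save a final checkpoint; `simulate_retention` is a hypothetical helper.

```python
from collections import deque
from typing import List


def simulate_retention(num_steps: int, save_interval_steps: int, max_to_keep: int) -> List[int]:
    """Return the steps whose checkpoints remain after training (hypothetical policy)."""
    kept = deque()
    for step in range(1, num_steps + 1):
        if step % save_interval_steps == 0:
            kept.append(step)
            if len(kept) > max_to_keep:
                kept.popleft()  # drop the oldest checkpoint
    return list(kept)


print(simulate_retention(num_steps=50, save_interval_steps=50, max_to_keep=2))   # [50]
print(simulate_retention(num_steps=500, save_interval_steps=50, max_to_keep=2))  # [450, 500]
```

Under this assumption, the notebook's settings (`num_steps=50`, `save_interval_steps=50`) write at most one checkpoint, so `max_to_keep=2` only starts to matter for longer runs.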

# Step 1: Download the dataset

For this Colab, we use the cleaned **Alpaca dataset** curated by yahma (`yahma/alpaca-cleaned`). It is a cleaned-up version of the original Stanford Alpaca data, with problematic entries fixed or removed, leaving 51,760 examples. Feel free to replace this section with your own data-preparation code.
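To get a feel for the sizes involved, here is back-of-the-envelope arithmetic for this configuration: 51,760 examples, `train_test_split=0.15`, `batch_size=8`, and `num_steps=50`. Exact counts depend on how `load_data` rounds the split and whether the dataloader drops a final partial batch (assumed here), so treat these numbers as approximations. `split_sizes` is a hypothetical helper.

```python
def split_sizes(num_examples: int, test_fraction: float, batch_size: int):
    """Approximate train/val sizes and steps per epoch (rounding is an assumption)."""
    val = round(num_examples * test_fraction)
    train = num_examples - val
    steps_per_epoch = train // batch_size  # assumes the last partial batch is dropped
    return train, val, steps_per_epoch


train, val, steps = split_sizes(51_760, 0.15, batch_size=8)
print(train, val, steps)  # 43996 7764 5499
print(50 * 8 / train)     # ≈ 0.0091, i.e. under 1% of one epoch
```

So if `num_steps=50` caps the run, as the trainer config suggests, training sees only about 400 examples, far short of `num_epochs=1`. That makes this a quick demo rather than a full fine-tune.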

It's crucial to include the EOS (end-of-sequence) token in your tokenized output. Without it, the model never learns where a response should end and may keep generating indefinitely.
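A minimal way to guarantee the EOS token is present is to append it after encoding, but only if it isn't already there. The helper below is a hypothetical sketch that accepts any encode function, so it runs without downloading a tokenizer; in this notebook you would pass something like `tokenizer.encode` and `tokenizer.eos_token_id` (standard Hugging Face tokenizer members). It truncates *before* appending EOS, because cutting a sequence to `max_seq_length` afterwards could silently drop the EOS token. Whether `data.SFTDataset` already does this is not stated here, so check before adding it twice.

```python
from typing import Callable, List, Optional


def encode_with_eos(
    text: str,
    encode: Callable[[str], List[int]],
    eos_id: int,
    max_len: Optional[int] = None,
) -> List[int]:
    """Encode `text`, truncate, then make sure the last token is EOS."""
    ids = list(encode(text))
    if ids and ids[-1] == eos_id:
        ids = ids[:-1]  # strip it so we can re-append after truncation
    if max_len is not None:
        ids = ids[: max(max_len - 1, 0)]  # reserve one slot for EOS
    return ids + [eos_id]


def toy_encode(text: str) -> List[int]:
    """Toy 'tokenizer' for illustration: one id per whitespace-separated word."""
    vocab = {"the": 1, "cat": 2, "sat": 3, "down": 4}
    return [vocab[w] for w in text.split()]


EOS = 99  # hypothetical EOS id for the toy tokenizer
print(encode_with_eos("the cat sat down", toy_encode, EOS))             # [1, 2, 3, 4, 99]
print(encode_with_eos("the cat sat down", toy_encode, EOS, max_len=3))  # [1, 2, 99]
```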

In [8]:
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, token=HF_TOKEN)

# Download and load the data files
train_data, val_data = data.load_data(config=dataset_config)

# Create datasets for SFT (supervised fine-tuning)
train_dataset = data.SFTDataset(
    config=dataset_config,
    data=train_data,
    tokenizer=tokenizer,
)
val_dataset = data.SFTDataset(
    config=dataset_config,
    data=val_data,
    tokenizer=tokenizer,
)

# Create dataloaders
train_dataloader = data.create_dataloader(
    config=dataset_config,
    dataset=train_dataset,
    shuffle=True,
)
val_dataloader = data.create_dataloader(
    config=dataset_config,
    dataset=val_dataset,
    shuffle=False,
)


tokenizer_config.json:   0%|          | 0.00/50.5k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/9.09M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/301 [00:00<?, ?B/s]

README.md:   0%|          | 0.00/11.6k [00:00<?, ?B/s]

alpaca_data_cleaned.json:   0%|          | 0.00/44.3M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/51760 [00:00<?, ? examples/s]
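`create_dataloader` is called with `shuffle=True` for training and `shuffle=False` for validation: shuffling decorrelates consecutive training batches, while a fixed order keeps the validation loss comparable from one eval to the next. The generator below is a hypothetical, framework-free sketch of that behavior (seeded shuffle, fixed-size batches). The real dataloader also handles `num_workers` and collation, which are omitted here.

```python
import random
from typing import Iterator, List, Sequence, TypeVar

T = TypeVar("T")


def iterate_batches(
    examples: Sequence[T], batch_size: int, shuffle: bool, seed: int = 42
) -> Iterator[List[T]]:
    """Yield lists of `batch_size` examples; the final partial batch is dropped."""
    order = list(range(len(examples)))
    if shuffle:
        random.Random(seed).shuffle(order)  # seeded, so runs are reproducible
    for start in range(0, len(order) - batch_size + 1, batch_size):
        yield [examples[i] for i in order[start : start + batch_size]]


val_batches = list(iterate_batches(list(range(20)), batch_size=8, shuffle=False))
print(val_batches)  # [[0, 1, 2, 3, 4, 5, 6, 7], [8, 9, 10, 11, 12, 13, 14, 15]]
```

Dropping the partial batch is an assumption; it keeps every batch the same shape, which avoids recompiling a jitted training step on TPU.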

# Step 2: Create Trainer and load the model

In [9]:
felafax_trainer = trainer.Trainer(
    trainer_config=trainer_config,
    train_dataloader=train_dataloader,
    val_dataloader=val_dataloader,
    checkpointer=checkpointer,
)

Creating TPU device mesh with shape (1, 1, 1)...
Loading model from HuggingFace...


config.json:   0%|          | 0.00/843 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/2.47G [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/185 [00:00<?, ?B/s]

In [10]:
felafax_trainer.train()

Started epoch 1 of 1...
Step 0 | Train Loss: 0.0000 | Val Loss: 0.0000 | 
Step 0 | Train Loss: 4.0413 | Val Loss: 0.0000 | 
Step 1 | Train Loss: 3.8412 | Val Loss: 0.0000 | 
Step 2 | Train Loss: 3.0881 | Val Loss: 0.0000 | 
Step 3 | Train Loss: 2.5106 | Val Loss: 0.0000 | 
Step 4 | Train Loss: 1.9687 | Val Loss: 0.0000 | 
Step 5 | Train Loss: 1.5145 | Val Loss: 0.0000 | 
Step 6 | Train Loss: 1.2508 | Val Loss: 0.0000 | 
Step 7 | Train Loss: 1.2086 | Val Loss: 0.0000 | 
Step 8 | Train Loss: 1.2840 | Val Loss: 0.0000 | 
Running eval for 5 steps...


See an explanation at https://jax.readthedocs.io/en/latest/faq.html#buffer-donation.


Step 9 | Train Loss: 1.2698 | Val Loss: 1.2129 | 
Step 10 | Train Loss: 1.1982 | Val Loss: 1.2129 | 
Step 11 | Train Loss: 1.1585 | Val Loss: 1.2129 | 
Step 12 | Train Loss: 1.0318 | Val Loss: 1.2129 | 
Step 13 | Train Loss: 1.1123 | Val Loss: 1.2129 | 
Step 14 | Train Loss: 1.1512 | Val Loss: 1.2129 | 
Step 15 | Train Loss: 1.0620 | Val Loss: 1.2129 | 
Step 16 | Train Loss: 1.2142 | Val Loss: 1.2129 | 
Step 17 | Train Loss: 1.1156 | Val Loss: 1.2129 | 
Step 18 | Train Loss: 1.1159 | Val Loss: 1.2129 | 
Running eval for 5 steps...
Step 19 | Train Loss: 1.0767 | Val Loss: 1.0942 | 
Step 20 | Train Loss: 0.9980 | Val Loss: 1.0942 | 
Step 21 | Train Loss: 1.0388 | Val Loss: 1.0942 | 
Step 22 | Train Loss: 1.0657 | Val Loss: 1.0942 | 
Step 23 | Train Loss: 1.0664 | Val Loss: 1.0942 | 
Step 24 | Train Loss: 0.9311 | Val Loss: 1.0942 | 
Step 25 | Train Loss: 0.9855 | Val Loss: 1.0942 | 
Step 26 | Train Loss: 0.9989 | Val Loss: 1.0942 | 
Step 27 | Train Loss: 1.0500 | Val Loss: 1.0942 | 
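Two things help when reading this log. The `Val Loss` column only changes right after a `Running eval for 5 steps...` line, so between evals it appears to repeat the most recent validation result. And, assuming the logged loss is the mean per-token cross-entropy in nats (the usual convention for language models), `exp(loss)` gives the perplexity, which is easier to interpret. The loss values below are copied from the log.

```python
import math


def perplexity(mean_nll: float) -> float:
    """Perplexity from mean per-token negative log-likelihood (natural log)."""
    return math.exp(mean_nll)


for label, loss in [("first train step", 4.0413), ("first eval", 1.2129), ("second eval", 1.0942)]:
    print(f"{label}: loss={loss:.4f}  perplexity={perplexity(loss):.2f}")
# perplexities: roughly 56.90, 3.36 and 2.99
```

Read this way, the drop in validation loss from 1.2129 to 1.0942 means the model went from being about as uncertain as a uniform choice among ~3.4 tokens to ~3.0 tokens per position on the held-out data.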

# Step 3: Export fine-tuned model

In [None]:
felafax_trainer.export(export_dir=EXPORT_DIR)

In [None]:
utils.upload_dir_to_hf(
    dir_path=EXPORT_DIR,
    repo_name="felarof01/test-llama3-alpaca-from-colab",  # change to your own "<username>/<repo-name>"
    token=HF_TOKEN,
)
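If you'd rather call the Hugging Face Hub client directly, the sketch below creates the repo if needed and uploads the export directory with `huggingface_hub`'s `HfApi.create_repo` and `HfApi.upload_folder`. It is an assumed rough equivalent, not a description of what `utils.upload_dir_to_hf` does internally, and `upload_export` is a hypothetical name.

```python
def upload_export(dir_path: str, repo_id: str, token: str, private: bool = False) -> None:
    """Create the Hub repo if it doesn't exist, then upload every file in dir_path."""
    # Imported lazily so defining this helper doesn't require the package.
    from huggingface_hub import HfApi

    api = HfApi(token=token)
    api.create_repo(repo_id=repo_id, private=private, exist_ok=True)
    api.upload_folder(folder_path=dir_path, repo_id=repo_id)


# Example (network call, so not run here):
# upload_export(EXPORT_DIR, "<username>/<repo-name>", HF_TOKEN, private=True)
```

Passing `private=True` is a sensible default for quick experiments like this one, since the uploaded weights come from a 50-step demo run.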