<a href="https://colab.research.google.com/github/felarof99/roadrunner-fork/blob/main/%F0%9F%A6%8A__Llama3_1_8b_on_Free_Colab_TPU.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

To run this, press "*Runtime*" and press "*Run all*" on a **free** Google Colab TPU!
<div class="align-center">
  <a href="https://github.com/felafax/felafax"><img src="https://felafax.ai/felafax.svg" width="145"></a> ⭐ <i>Star us on <a href="https://github.com/felafax/felafax">GitHub</a></i> ⭐ and email us at founders@felafax.ai with any questions!
</div>

# Setup

In [23]:
# Install felafax straight from its GitHub repository (no version pin, so this
# tracks whatever the repository's default branch currently contains).
!pip install --upgrade git+https://github.com/felafax/felafax -q
# Replace any installed TensorFlow with the tensorflow-cpu package.
# NOTE(review): presumably so TensorFlow does not compete with JAX for the
# accelerator — confirm.
!pip uninstall -y tensorflow && pip install tensorflow-cpu -q

  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone
[0m

In [24]:
import os
# Set before jax/torch are imported below; an empty value hides CUDA GPUs from
# libraries that honor this variable.
os.environ['CUDA_VISIBLE_DEVICES'] = ''
import sys
import importlib
from typing import (Any, Dict, List, Mapping, Optional, Sequence, Tuple,
                    Union)

# Add both the current working directory and its parent directory to the
# Python path so the project package imported below can be resolved from either.
sys.path.append(os.path.abspath(os.path.join(os.getcwd(), '.')))
sys.path.append(os.path.abspath(os.path.join(os.getcwd(), '..')))

import llama3_jax
from llama3_jax.trainer_engine import setup
# NOTE(review): base_dir="/" presumably selects the root under which models are
# cached (the download below lands in /hf/...) — confirm in setup.
setup.setup_environment(base_dir="/")

from llama3_jax.trainer_engine import (automodel_lib, checkpoint_lib,
                                       convert_lib, dataset_lib, jax_utils, llama_config,
                                       trainer_lib, utils)
# Re-import the project modules so that re-running this cell picks up a freshly
# installed version (the call prints "Reloaded all felafax modules.").
setup.reload_modules("llama3_jax")

import jax
import jax.numpy as jnp
import chex
import optax

import torch
from datasets import load_dataset


Reloaded all felafax modules.


# Step 0: Configure LoRA params and precision for training (jnp.bfloat16 or jnp.float32)

In [25]:
# Name of the pretrained model; passed both to from_pretrained() and to the
# CausalLMTrainer further down, so change it in this one place.
MODEL_NAME = "colab-llama-3.1-8B-Instruct-JAX"

In [26]:
# Load the pretrained model. The call returns, in order: the checkpoint path,
# the model, the model configurator and the tokenizer.
# dtype / param_dtype select the precision (jnp.bfloat16 or jnp.float32);
# lora_rank / lora_alpha are the LoRA parameters.
model_path, model, model_configurator, tokenizer = (
    automodel_lib.AutoJAXModelForCausalLM.from_pretrained(
        MODEL_NAME,
        dtype=jnp.bfloat16,
        param_dtype=jnp.bfloat16,
        lora_rank=8,
        lora_alpha=16,))


Downloading model colab-llama-3.1-8B-Instruct-JAX...


Fetching 9 files:   0%|          | 0/9 [00:00<?, ?it/s]

colab-llama-3.1-8B-Instruct-JAX was downloaded to /hf/models--felafax--colab-llama-3.1-8B-Instruct-JAX/snapshots/63234acd3cbc6c72c47054ba0d902969811d833c/llama-3.1-8B-Instruct-JAX.flax.


# Step 1: Prepare the dataset

For this colab, we're utilizing the refined **Alpaca dataset**, curated by yahma. This dataset is a carefully filtered selection of 52,000 entries from the original Alpaca collection. Feel free to substitute this section with your own data preparation code if you prefer.

It's crucial to include the EOS_TOKEN (End of Sequence Token) in your tokenized output. Failing to do so may result in endless generation loops.

In [27]:
def get_dataset(*, tokenizer, batch_size=1, seq_length=32, max_examples=None, seed=None):
    """Build train and test DataLoaders over the yahma/alpaca-cleaned dataset.

    Each example is rendered with the Alpaca prompt template, terminated with
    the tokenizer's EOS token, tokenized to exactly ``seq_length + 1`` tokens
    (padded or truncated) and then split into next-token-prediction pairs.

    Args:
        tokenizer: Tokenizer exposing ``eos_token`` and callable on a list of
            strings with ``truncation``/``padding``/``max_length`` arguments.
        batch_size: Number of examples per batch.
        seq_length: Length of the input/target sequences in each batch.
        max_examples: If truthy, only the first ``max_examples`` rows of the
            dataset are used (clamped to the dataset size). ``None`` keeps
            every row.
        seed: Optional integer. When given, the train/test split and the
            training shuffle order are reproducible; ``None`` keeps them
            random.

    Returns:
        ``(train_dataloader, test_dataloader)``. Every batch is a dict of JAX
        arrays with keys ``'input_tokens'``, ``'target_tokens'`` and
        ``'loss_masks'``; ``target_tokens`` is ``input_tokens`` shifted left
        by one position and ``loss_masks`` is the attention mask aligned with
        the targets.
    """
    # Define Alpaca prompt template
    # NOTE(review): the continuation lines of this literal carry the function's
    # four-space indent into the prompt text. Kept byte-identical because
    # changing it changes the training text — confirm whether it is intended.
    alpaca_prompt = """Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.

    ### Instruction: {}


    ### Input: {}

    ### Response: {}"""

    EOS_TOKEN = tokenizer.eos_token

    def _format_prompts(examples):
        """Render a batch of rows into prompt strings terminated by EOS."""
        rows = zip(examples["instruction"], examples["input"], examples["output"])
        # `input_text` rather than `input`, to avoid shadowing the builtin.
        texts = [
            alpaca_prompt.format(instruction, input_text, output) + EOS_TOKEN
            for instruction, input_text, output in rows
        ]
        return {"text": texts}

    def _tokenize(examples):
        """Tokenize prompts and derive shifted input/target/mask sequences."""
        # One extra token is requested so that dropping the last token (inputs)
        # or the first token (targets) leaves exactly `seq_length` positions.
        # Prompts longer than that are truncated by the tokenizer.
        tokenized = tokenizer(examples["text"], truncation=True, padding="max_length", max_length=seq_length + 1)
        return {
            'input_tokens': [ids[:-1] for ids in tokenized['input_ids']],
            'target_tokens': [ids[1:] for ids in tokenized['input_ids']],
            # Shifted like the targets so each mask entry lines up with the
            # token being predicted.
            'loss_masks': [mask[1:] for mask in tokenized['attention_mask']],
        }

    def _custom_collate_fn(batch: List[Dict[str, Any]]) -> Dict[str, jnp.ndarray]:
        """
        Collates batch items and converts to JAX arrays.
        """
        keys = ('input_tokens', 'target_tokens', 'loss_masks')
        return {key: jnp.array([item[key] for item in batch]) for key in keys}

    # Load and preprocess the dataset
    dataset = load_dataset("yahma/alpaca-cleaned", split="train")
    if max_examples:
        # Clamp so a limit larger than the dataset selects everything instead
        # of failing on out-of-range indices.
        dataset = dataset.select(range(min(max_examples, len(dataset))))
    dataset = dataset.map(_format_prompts, batched=True)

    # Create train and test dataset.
    ds = dataset.train_test_split(test_size=0.15, seed=seed)
    for split in ['train', 'test']:
        # Drop the original text columns so only the token fields remain.
        ds[split] = ds[split].map(_tokenize, batched=True, remove_columns=dataset.column_names)

    # Create DataLoaders
    # A seeded generator makes the training shuffle order reproducible; with
    # seed=None the DataLoader falls back to its default random shuffling.
    train_generator = torch.Generator().manual_seed(seed) if seed is not None else None
    train_dataloader = torch.utils.data.DataLoader(
        ds['train'],
        shuffle=True,
        batch_size=batch_size,
        collate_fn=_custom_collate_fn,
        generator=train_generator,
    )
    # Evaluation data is not shuffled, so a capped number of eval steps sees
    # the same examples every time and metrics stay comparable across runs.
    test_dataloader = torch.utils.data.DataLoader(
        ds['test'],
        shuffle=False,
        batch_size=batch_size,
        collate_fn=_custom_collate_fn,
    )

    return train_dataloader, test_dataloader

### Uncomment the code below ⬇️ if you'd like to run and test 💯 your dataset pipeline.



In [28]:
# def test_dataset_pipeline(tokenizer):
#     """Print shapes of first batch to verify dataset pipeline."""
#     train_loader, _ = get_dataset(tokenizer=tokenizer, batch_size=1, seq_length=32, max_examples=32)
#     batch = next(iter(train_loader))

#     print("Input tokens shape:", batch['input_tokens'].shape)
#     print("Target tokens shape:", batch['target_tokens'].shape)
# test_dataset_pipeline(tokenizer)

# Step 2: Configure hyperparameters below and train!

In [29]:
@chex.dataclass(frozen=True)
class TrainerConfig:
    """Hyperparameters for the fine-tuning run in this notebook.

    The dataclass is frozen, so pass overrides to the constructor instead of
    assigning to fields afterwards.
    """
    # dataset pipeline knobs
    batch_size: int = 8  # forwarded to get_dataset(batch_size=...)
    seq_length: int = 32  # forwarded to get_dataset(seq_length=...)
    dataset_size_limit: int | None = None  # forwarded as max_examples; None uses the full dataset

    # training pipeline knobs
    learning_rate: float = 1e-3  # passed to optax.sgd below
    num_epochs: int = 1
    max_steps: int | None = 20  # the pipeline summary reports total steps as limited by this setting

    print_every_n_steps: int = 5  # the training log prints a loss/accuracy line at this interval

    # eval
    # NOTE(review): both fields are consumed inside trainer_lib — presumably the
    # evaluation interval and the cap on evaluation batches; confirm there.
    eval_every_n_steps: int = 1000
    max_eval_steps: int | None = 1


# Instantiate the config with its defaults; pass keyword overrides here to
# change a hyperparameter, e.g. TrainerConfig(max_steps=100).
trainer_config = TrainerConfig()
# Plain SGD at the configured learning rate.
optimizer = optax.sgd(trainer_config.learning_rate)


In [30]:
# Prepare dataset: build the train/validation DataLoaders from the dataset
# knobs in trainer_config (dataset_size_limit=None keeps every example).
train_dataloader, val_dataloader = get_dataset(
    tokenizer=tokenizer,
    batch_size=trainer_config.batch_size,
    seq_length=trainer_config.seq_length,
    max_examples=trainer_config.dataset_size_limit
)

Map:   0%|          | 0/51760 [00:00<?, ? examples/s]

Map:   0%|          | 0/43996 [00:00<?, ? examples/s]

Map:   0%|          | 0/7764 [00:00<?, ? examples/s]

In [31]:
# Print a summary of the planned run (sample count, batch size, epochs, steps
# per epoch, total steps) before any training starts.
trainer_lib.pprint_training_pipeline(train_dataloader, trainer_config)


Training Configuration Summary:
Total samples: 43996
Batch size: 8
Number of epochs: 1
Steps per epoch: 5500
Total training steps: 20
*Note*: Total steps limited by max_steps setting (20)


In [32]:
# Build the trainer from the loaded model, its checkpoint path, the optimizer
# and the training config.
# NOTE(review): jax_utils.MESH is defined inside felafax — presumably the device
# mesh spanning the available TPU cores; confirm in jax_utils.
trainer = trainer_lib.CausalLMTrainer(
    model_name=MODEL_NAME,
    model=model,
    model_ckpt_path=model_path,
    model_configurator=model_configurator,
    optimizer=optimizer,
    training_config=trainer_config,
    mesh=jax_utils.MESH,
    dtype=jnp.bfloat16,  # precision to use for training
)

Loading causal language model with dtype <class 'jax.numpy.bfloat16'>...


In [33]:
# Run training on the train loader, with the validation loader available for
# evaluation; the returned state is kept for later use.
# NOTE(review): run_jitted=False presumably executes the train step without
# jax.jit — confirm in trainer_lib before relying on it for performance.
state = trainer.train(train_dataloader, val_dataloader, run_jitted=False)

Starting epoch 0 of training...
Epoch 0, Step 0, Train Loss: 3.7606, Accuracy: 0.3438
Epoch 0, Step 5, Train Loss: 3.3515, Accuracy: 0.3438
Epoch 0, Step 10, Train Loss: 2.7629, Accuracy: 0.4375
Epoch 0, Step 15, Train Loss: 2.2916, Accuracy: 0.5000
Epoch 0, Step 20, Train Loss: 1.5016, Accuracy: 0.6562
