In [1]:
!pip install -r ../requirements.txt -q

[0m

In [2]:
import importlib
import os
import sys

# Make modules that live in the notebook's directory or its parent importable.
# Guarded so that re-running this cell does not keep appending duplicate
# entries to sys.path.
def add_to_sys_path(path: str) -> None:
    """Append the absolute form of `path` to sys.path unless it is already present.

    Args:
        path: Directory to make importable; relative paths are resolved
            against the current working directory.
    """
    abs_path = os.path.abspath(path)
    if abs_path not in sys.path:
        sys.path.append(abs_path)


add_to_sys_path(os.getcwd())
add_to_sys_path(os.path.dirname(os.getcwd()))

felafax package imported successfully


In [3]:
import llama3_jax
from llama3_jax.trainer_engine import setup
# Root directory handed to the trainer engine's environment setup; model
# downloads in this notebook end up under it (see the download path printed below).
PERSISTENT_DISK_DIR = "/mnt/persistent-disk/"
setup.setup_environment(base_dir=PERSISTENT_DISK_DIR)

In [4]:
from llama3_jax.trainer_engine import (automodel_lib, checkpoint_lib,
                                       convert_lib, llama_config, jax_utils, trainer_lib,
                                       utils)
# Reload the llama3_jax modules (prints "Reloaded all felafax modules.").
# Presumably lets edits to the library be picked up without restarting the
# kernel — NOTE(review): confirm against setup.reload_modules.
setup.reload_modules("llama3_jax")

  from .autonotebook import tqdm as notebook_tqdm


Reloaded all felafax modules.


In [5]:
from typing import Any, Dict, List, Mapping, Optional, Sequence, Tuple, Union

import chex
import jax
import jax.numpy as jnp
import optax
import torch
from datasets import load_dataset
from transformers import default_data_collator
from huggingface_hub import snapshot_download
import shutil
from datetime import datetime
import gzip

In [6]:
# Model to fine-tune. Passed to AutoJAXModelForCausalLM.from_pretrained below
# and reused to name the exported checkpoints.
# NOTE(review): the list of supported model names is not shown in this
# notebook — check automodel_lib for the valid values.
MODEL_NAME = "llama-3.1-8B-Instruct-JAX"

In [8]:
# Fetch and load the selected model. from_pretrained yields a 4-tuple:
# local checkpoint path, model, model configurator and tokenizer.
auto_model_cls = automodel_lib.AutoJAXModelForCausalLM
model_path, model, model_configurator, tokenizer = auto_model_cls.from_pretrained(
    MODEL_NAME)

Downloading model llama-3.1-8B-Instruct-JAX...


Fetching 8 files: 100%|██████████| 8/8 [08:38<00:00, 64.77s/it] 


llama-3.1-8B-Instruct-JAX was downloaded to /mnt/persistent-disk/hf/models--felafax--llama-3.1-8B-Instruct-JAX/snapshots/12d9565c6c550893fd3c0ab62c2b91b16acf1218/llama-3.1-8B-Instruct-JAX.flax.


# Dataset pipeline

In [9]:
# Alpaca prompt template. Assembled from explicit "\n" pieces: a triple-quoted
# string written inside a function body carries the function's indentation
# (and whitespace-only lines) into every rendered training prompt.
ALPACA_PROMPT_TEMPLATE = (
    "Below is an instruction that describes a task, paired with an input that "
    "provides further context. Write a response that appropriately completes "
    "the request.\n"
    "\n"
    "### Instruction: {}\n"
    "\n"
    "### Input: {}\n"
    "\n"
    "### Response: {}"
)


def get_dataset(*, tokenizer, batch_size=1, seq_length=32, max_examples=None,
                seed=None):
    """Build train/test DataLoaders over yahma/alpaca-cleaned for causal-LM training.

    Each record is rendered with ALPACA_PROMPT_TEMPLATE and terminated with the
    tokenizer's EOS token. It is then tokenized to `seq_length + 1` tokens
    (truncated / padded to max length) and split into shifted input/target
    sequences of length `seq_length`.

    Args:
        tokenizer: Hugging Face-style tokenizer providing `eos_token` and
            `__call__(texts, truncation=..., padding=..., max_length=...)`.
        batch_size: Examples per batch for both loaders.
        seq_length: Length of each `input_tokens` / `target_tokens` row.
        max_examples: If truthy, only the first `max_examples` raw records are
            used (clamped to the dataset size).
        seed: Optional seed for the 85/15 train/test split. None keeps the
            unseeded (non-reproducible) split.

    Returns:
        (train_dataloader, test_dataloader). Batches are dicts with keys
        "input_tokens", "target_tokens" and "loss_masks", whose torch tensors
        are converted to JAX arrays. Only the training loader is shuffled.
    """
    eos_token = tokenizer.eos_token

    def _format_prompts(examples):
        # `context` instead of `input` so the builtin is not shadowed.
        texts = [
            ALPACA_PROMPT_TEMPLATE.format(instruction, context, output) + eos_token
            for instruction, context, output in zip(
                examples["instruction"], examples["input"], examples["output"])
        ]
        return {"text": texts}

    def _tokenize(examples):
        # One extra position so inputs and targets can be offset by one token.
        tokenized = tokenizer(examples["text"], truncation=True,
                              padding="max_length", max_length=seq_length + 1)
        return {
            'input_tokens': [ids[:-1] for ids in tokenized['input_ids']],
            'target_tokens': [ids[1:] for ids in tokenized['input_ids']],
            # Attention mask aligned with the targets, so padded positions
            # are masked out of the loss.
            'loss_masks': [mask[1:] for mask in tokenized['attention_mask']],
        }

    def _custom_collate_fn(batch: List[Dict[str, Any]]) -> Dict[str, jnp.ndarray]:
        """Collate with default_data_collator, then convert torch tensors to JAX arrays."""
        collated = default_data_collator(batch)
        return {
            key: jnp.array(value.numpy()) if isinstance(value, torch.Tensor) else value
            for key, value in collated.items()
        }

    # Load, optionally subsample, and render prompts.
    dataset = load_dataset("yahma/alpaca-cleaned", split="train")
    if max_examples:
        dataset = dataset.select(range(min(max_examples, len(dataset))))
    dataset = dataset.map(_format_prompts, batched=True)

    # Split, then tokenize each split, dropping the raw text columns.
    ds = dataset.train_test_split(test_size=0.15, seed=seed)
    for split in ['train', 'test']:
        ds[split] = ds[split].map(_tokenize, batched=True,
                                  remove_columns=dataset.column_names)

    train_dataloader = torch.utils.data.DataLoader(
        ds['train'], shuffle=True, batch_size=batch_size,
        collate_fn=_custom_collate_fn)
    # Fixed evaluation order: a truncated eval pass (e.g. max_eval_steps=1)
    # then scores the same examples every time, so eval metrics are comparable
    # across training steps.
    test_dataloader = torch.utils.data.DataLoader(
        ds['test'], shuffle=False, batch_size=batch_size,
        collate_fn=_custom_collate_fn)

    return train_dataloader, test_dataloader

In [10]:
def test_dataset_pipeline(tokenizer):
    """Smoke-test get_dataset by printing and checking the first training batch's shapes.

    Uses a tiny slice (32 records, batch_size=1, seq_length=32) so it runs fast.

    Raises:
        AssertionError: if the input, target and loss-mask arrays disagree in shape.
    """
    train_loader, _ = get_dataset(tokenizer=tokenizer, batch_size=1, seq_length=32, max_examples=32)
    batch = next(iter(train_loader))
    print("Input tokens shape:", batch['input_tokens'].shape)
    print("Target tokens shape:", batch['target_tokens'].shape)
    print("Loss mask shape:", batch['loss_masks'].shape)
    # Inputs, targets and mask are all length-seq_length slices of the same
    # tokenization, so they must line up exactly.
    assert batch['input_tokens'].shape == batch['target_tokens'].shape == batch['loss_masks'].shape


test_dataset_pipeline(tokenizer)

Downloading readme: 100%|██████████| 11.6k/11.6k [00:00<00:00, 23.5MB/s]
Downloading data: 100%|██████████| 44.3M/44.3M [00:00<00:00, 118MB/s] 
Generating train split: 51760 examples [00:00, 150940.12 examples/s]
Map: 100%|██████████| 32/32 [00:00<00:00, 7448.26 examples/s]
Map: 100%|██████████| 27/27 [00:00<00:00, 1405.09 examples/s]
Map: 100%|██████████| 5/5 [00:00<00:00, 945.26 examples/s]


Input tokens shape: (1, 32)
Target mask shape: (1, 32)


# Training loop

In [11]:
@chex.dataclass(frozen=True)
class TrainerConfig:
    """Hyperparameters for the fine-tuning run in this notebook.

    Read by get_dataset (batch_size, seq_length, dataset_size_limit), by the
    step-count summary, and by trainer_lib.CausalLMTrainer as `training_config`.
    """
    # Step size for the optax.sgd optimizer created below.
    learning_rate: float = 1e-4
    # Number of passes over the training split (multiplies steps per epoch).
    num_epochs: int = 1
    # Cap on total training steps; a falsy value means no cap in the step
    # summary. NOTE(review): enforcement happens inside trainer_lib — confirm.
    max_steps: int | None = 100
    # Examples per batch for both the train and eval DataLoaders.
    batch_size: int = 32
    # Tokens per example (inputs and targets are shifted by one token).
    seq_length: int = 64
    # Optional cap on raw dataset records (forwarded as get_dataset's max_examples).
    dataset_size_limit: int | None = None
    # Presumably the train-log cadence (the log prints every 5th step) — consumed by trainer_lib.
    print_every_n_steps: int = 5
    # Presumably the evaluation cadence in steps — consumed by trainer_lib; confirm.
    eval_every_n_steps: int = 1000
    # Presumably caps batches per eval pass (the log shows one eval step) — confirm.
    max_eval_steps: int | None = 1


trainer_config = TrainerConfig()
# Plain SGD with the constant learning rate from the config.
optimizer = optax.sgd(trainer_config.learning_rate)

In [12]:
# Prepare dataset: build train/validation DataLoaders sized by the trainer
# configuration defined above.
train_dataloader, val_dataloader = get_dataset(
    tokenizer=tokenizer,
    seq_length=trainer_config.seq_length,
    batch_size=trainer_config.batch_size,
    max_examples=trainer_config.dataset_size_limit,
)

Map: 100%|██████████| 51760/51760 [00:00<00:00, 95472.40 examples/s] 
Map: 100%|██████████| 43996/43996 [00:21<00:00, 2006.26 examples/s]
Map: 100%|██████████| 7764/7764 [00:03<00:00, 2052.44 examples/s]


In [13]:
# Calculate and print training steps information
total_samples = len(train_dataloader.dataset)
batch_size = trainer_config.batch_size
# len(DataLoader) is the number of batches the loader will actually yield
# (here ceil(total_samples / batch_size)). Unlike a hand-written formula, it
# stays correct if drop_last or a custom sampler is ever configured.
steps_per_epoch = len(train_dataloader)
uncapped_steps = steps_per_epoch * trainer_config.num_epochs

# A falsy max_steps (None or 0) means "no cap". The note below is printed only
# when the cap actually shortens the run, not when it merely equals the total.
limited_by_max_steps = (
    bool(trainer_config.max_steps) and trainer_config.max_steps < uncapped_steps
)
total_steps = trainer_config.max_steps if limited_by_max_steps else uncapped_steps

print("\nTraining Configuration Summary:")
print(f"Total samples: {total_samples}")
print(f"Batch size: {batch_size}")
print(f"Number of epochs: {trainer_config.num_epochs}")
print(f"Steps per epoch: {steps_per_epoch}")
print(f"Total training steps: {total_steps}")
if limited_by_max_steps:
    print(
        f"*Note*: Total steps limited by max_steps setting ({trainer_config.max_steps})"
    )


Training Configuration Summary:
Total samples: 43996
Batch size: 32
Number of epochs: 1
Steps per epoch: 1375
Total training steps: 100
*Note*: Total steps limited by max_steps setting (100)


In [14]:
# Build the causal-LM trainer from the loaded model (plus its checkpoint path
# and configurator), the optimizer and config defined above, and the device
# mesh exposed by jax_utils.
trainer = trainer_lib.CausalLMTrainer(
    model_name=MODEL_NAME,
    model=model,
    model_configurator=model_configurator,
    model_ckpt_path=model_path,
    training_config=trainer_config,
    optimizer=optimizer,
    mesh=jax_utils.MESH,
)

Loading causal language model...


In [15]:
import time

# Wall-clock timing around the training run. The try/finally records end_time
# even if training is interrupted (e.g. KeyboardInterrupt from the notebook's
# stop button), so the elapsed-time cell below still has both timestamps.
start_time = time.time()
print(f"Start time: {start_time:.4f}")

try:
    state = trainer.train(train_dataloader, val_dataloader, run_jitted=True)
finally:
    end_time = time.time()
    print(f"End time: {end_time:.4f}")

Start time: 1725751128.2981
Starting epoch 0 of training...
Epoch 0, Step 0, Train Loss: 3.2309, Accuracy: 0.3906
Eval Step 0, Loss: 2.8140, Accuracy: 0.4531
Evaluation complete. Average Loss: 2.8140, Average Accuracy: 0.4531
Epoch 0, Step 0, Eval Loss: 2.8140, Accuracy: 0.4531
Epoch 0, Step 5, Train Loss: 1.9765, Accuracy: 0.5938
Epoch 0, Step 10, Train Loss: 1.5610, Accuracy: 0.6562
Epoch 0, Step 15, Train Loss: 1.5088, Accuracy: 0.6719
Epoch 0, Step 20, Train Loss: 1.1654, Accuracy: 0.7344
Epoch 0, Step 25, Train Loss: 0.8645, Accuracy: 0.7969
Epoch 0, Step 30, Train Loss: 0.9819, Accuracy: 0.8125
Epoch 0, Step 35, Train Loss: 0.9629, Accuracy: 0.7656
Epoch 0, Step 40, Train Loss: 0.6825, Accuracy: 0.8438
Epoch 0, Step 45, Train Loss: 0.4973, Accuracy: 0.8438
Epoch 0, Step 50, Train Loss: 0.5720, Accuracy: 0.8438
Epoch 0, Step 55, Train Loss: 0.7773, Accuracy: 0.8438
Epoch 0, Step 60, Train Loss: 0.4638, Accuracy: 0.9062
Epoch 0, Step 65, Train Loss: 0.6748, Accuracy: 0.8438
Epoch 0

In [16]:
# Wall-clock seconds spent in the training cell above.
elapsed_time = end_time - start_time
print("Execution time: {:.4f} seconds".format(elapsed_time))

Execution time: 1330.6318 seconds


## Export checkpoint

In [None]:
# Local output locations for the exported checkpoints.
FELAFAX_DIR = "/mnt/persistent-disk"
EXPORT_DIR = os.path.join(FELAFAX_DIR, "export")        # Flax checkpoint
HF_EXPORT_DIR = os.path.join(FELAFAX_DIR, "hf_export")  # Hugging Face format

# Timestamped destination so each run gets its own checkpoint folder.
# NOTE(review): despite the name, GCS_DIR is a filesystem path — presumably a
# mounted storage bucket; confirm.
current_datetime = datetime.now().strftime("%Y%m%d_%H%M%S")
GCS_DIR = f"/home/felafax-storage/checkpoints/{MODEL_NAME}/{current_datetime}/"

# Create every output directory up front (exist_ok=True, presumably mirroring
# os.makedirs so existing directories are fine).
for output_dir in (EXPORT_DIR, HF_EXPORT_DIR, GCS_DIR):
    utils.makedirs(output_dir, exist_ok=True)

In [17]:
# Save the trained state as a Flax checkpoint at EXPORT_DIR/<MODEL_NAME>; the
# Hugging Face conversion cell below reads it back from this same path.
flax_checkpoint_path = os.path.join(EXPORT_DIR, MODEL_NAME)
trainer.save_checkpoint(state, path=flax_checkpoint_path)

Saving checkpoint to /mnt/persistent-disk/export/llama-3.1-8B-Instruct-JAX...
Checkpoint saved to /mnt/persistent-disk/export/llama-3.1-8B-Instruct-JAX.


In [18]:
# Convert the saved Flax checkpoint into a Hugging Face-compatible checkpoint
# written to HF_EXPORT_DIR.
convert_lib.save_hf_compatible_checkpoint(
    f'flax_params::{flax_checkpoint_path}', HF_EXPORT_DIR, model_configurator)

# Download the tokenizer repo snapshot from the Hub.
tokenizer_repo = "felafax/tokenizer-llama-3.1-8B-Instruct-JAX"
tokenizer_dir = snapshot_download(repo_id=tokenizer_repo)

# Copy (not move) every file and directory of the snapshot into HF_EXPORT_DIR.
# NOTE(review): same-named files already written by the conversion step are
# overwritten (the log shows config.json and generation_config.json being
# copied) — confirm that is intended.
for entry_name in os.listdir(tokenizer_dir):
    src_path = os.path.join(tokenizer_dir, entry_name)
    dst_path = os.path.join(HF_EXPORT_DIR, entry_name)
    if os.path.isfile(src_path):
        shutil.copy2(src_path, dst_path)
        print(f"Copied {entry_name} to {HF_EXPORT_DIR}")
        continue
    if os.path.isdir(src_path):
        shutil.copytree(src_path, dst_path, dirs_exist_ok=True)
        print(f"Copied directory {entry_name} to {HF_EXPORT_DIR}")
print(f"All tokenizer files saved to {HF_EXPORT_DIR}")

Loading the checkpoint in a Llama model.


Loading checkpoint shards: 100%|██████████| 33/33 [00:20<00:00,  1.60it/s]


Saving in the Transformers format.


Fetching 7 files: 100%|██████████| 7/7 [00:00<00:00,  8.24it/s]

Copied README.md to /mnt/persistent-disk/hf_export
Copied tokenizer.json to /mnt/persistent-disk/hf_export
Copied .gitattributes to /mnt/persistent-disk/hf_export
Copied special_tokens_map.json to /mnt/persistent-disk/hf_export
Copied tokenizer_config.json to /mnt/persistent-disk/hf_export
Copied generation_config.json to /mnt/persistent-disk/hf_export
Copied config.json to /mnt/persistent-disk/hf_export
All tokenizer files saved to /mnt/persistent-disk/hf_export





In [19]:
# Copy the Hugging Face export directory into the timestamped GCS_DIR.
checkpoint_lib.copy_directory(HF_EXPORT_DIR, GCS_DIR)
print("Checkpoint copied to " + GCS_DIR)

Starting to copy directory /mnt/persistent-disk/hf_export to /home/felafax-storage/checkpoints/llama-3.1-8B-Instruct-JAX/20240907_230836/
Checkpoint copied to /home/felafax-storage/checkpoints/llama-3.1-8B-Instruct-JAX/20240907_230836/


In [20]:
# Optional: upload the Hugging Face export to the Hub.
# Disabled by default so "Run All" never blocks on interactive prompts;
# set UPLOAD_TO_HF = True to run it.
UPLOAD_TO_HF = False

if UPLOAD_TO_HF:
    import getpass

    # Read the token from the environment if set; otherwise prompt with getpass.
    # Unlike input(), getpass does not echo the secret into the notebook output,
    # so it cannot leak into a saved/shared notebook.
    HUGGINGFACE_TOKEN = os.environ.get("HUGGINGFACE_TOKEN") or getpass.getpass(
        "INPUT: Please provide your HUGGINGFACE_TOKEN: ")
    HUGGINGFACE_USERNAME = input(
        "INPUT: Please provide your HUGGINGFACE_USERNAME: ")
    HUGGINGFACE_REPO_NAME = input(
        "INPUT: Please provide your HUGGINGFACE_REPO_NAME: ")
    convert_lib.upload_checkpoint_to_hf(
        HF_EXPORT_DIR, f"{HUGGINGFACE_USERNAME}/{HUGGINGFACE_REPO_NAME}",
        HUGGINGFACE_TOKEN)