In [1]:
import importlib
import os
import sys

# Add the current directory and its parent to the Python path.
# This allows importing modules from these directories.
sys.path.append(os.path.abspath(os.getcwd()))
sys.path.append(os.path.abspath(os.path.dirname(os.getcwd())))

try:
    import llama3_jax
    print("felafax package imported successfully")
except ImportError as e:
    print(f"Error importing llama3_jax: {e}")

felafax package imported successfully


In [2]:
from llama3_jax.trainer_engine import setup

setup.setup_environment(base_dir="/mnt/persistent-disk/")

In [3]:
from llama3_jax import llama_config
from llama3_jax.trainer_engine import (automodel_lib, checkpoint_lib,
                                       convert_lib, jax_utils, trainer_lib,
                                       utils)

setup.reload_modules("llama3_jax")

Reloaded all felafax modules.


In [4]:
from typing import Any, Dict, List, Mapping, Optional, Sequence, Tuple, Union

import chex
import jax
import jax.numpy as jnp
import optax
import torch
from datasets import load_dataset
from transformers import default_data_collator
from huggingface_hub import snapshot_download
import shutil
from datetime import datetime
import gzip

In [5]:
# Select a supported model from above list to use!
MODEL_NAME = "llama-3.1-8B-Instruct-JAX"

In [6]:
# Constants for paths
FELAFAX_DIR = "/mnt/persistent-disk" # os.path.dirname(os.path.dirname(llama3_jax.__file__))

EXPORT_DIR = os.path.join(FELAFAX_DIR, "export")
HF_EXPORT_DIR = os.path.join(FELAFAX_DIR, "hf_export")

current_datetime = datetime.now().strftime("%Y%m%d_%H%M%S")
GCS_DIR = f"/home/felafax-storage/checkpoints/{MODEL_NAME}/{current_datetime}/"

# Ensure directories exist
utils.makedirs(EXPORT_DIR, exist_ok=True)
utils.makedirs(HF_EXPORT_DIR, exist_ok=True)
utils.makedirs(GCS_DIR, exist_ok=True)

In [7]:
model_path, model, model_configurator, tokenizer = (
    automodel_lib.AutoJAXModelForCausalLM.from_pretrained(MODEL_NAME))

Downloading model llama-3.1-8B-Instruct-JAX...


Fetching 8 files:   0%|          | 0/8 [00:00<?, ?it/s]

llama-3.1-8B-Instruct-JAX was downloaded to /mnt/persistent-disk/hf/models--felafax--llama-3.1-8B-Instruct-JAX/snapshots/12d9565c6c550893fd3c0ab62c2b91b16acf1218/llama-3.1-8B-Instruct-JAX.flax.


# Dataset pipeline

In [8]:
def get_dataset(*, tokenizer, batch_size=1, seq_length=32, max_examples=None):
    # Define Alpaca prompt template
    alpaca_prompt = """Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.
    
    ### Instruction: {}
    
    ### Input: {}
    
    ### Response: {}"""
    
    EOS_TOKEN = tokenizer.eos_token
    
    # Defines formatting function.
    def _format_prompts(examples):
        instructions = examples["instruction"]
        inputs = examples["input"]
        outputs = examples["output"]
        texts = []
        for instruction, input, output in zip(instructions, inputs, outputs):
            text = alpaca_prompt.format(instruction, input, output) + EOS_TOKEN
            texts.append(text)
        return {"text": texts}

    def _tokenize(examples):
        tokenized = tokenizer(examples["text"], truncation=True, padding="max_length", max_length=seq_length+1)
        return {
            'input_tokens': [input_id[:-1] for input_id in tokenized['input_ids']],
            'target_tokens': [input_id[1:] for input_id in tokenized['input_ids']],
            'loss_masks': [input_id[1:] for input_id in tokenized['attention_mask']]
        }

    def _custom_collate_fn(batch: List[Dict[str, Any]]) -> Dict[str, jnp.ndarray]:
        """
        Collates batch items and converts PyTorch tensors to JAX arrays.
        Applies default_data_collator, then converts tensors to JAX format.
        """
        collated = default_data_collator(batch)
        jax_batch = {}
        for key, value in collated.items():
            jax_batch[key] = jnp.array(value.numpy()) if isinstance(value, torch.Tensor) else value
        
        return jax_batch

    # Load and preprocess the dataset
    dataset = load_dataset("yahma/alpaca-cleaned", split="train")
    if max_examples:
        dataset = dataset.select(range(max_examples))
    dataset = dataset.map(_format_prompts, batched=True)

    # Create train and test dataset.
    ds = dataset.train_test_split(test_size=0.15)
    for split in ['train', 'test']:
        ds[split] = ds[split].map(_tokenize, batched=True, remove_columns=dataset.column_names)

    # Create DataLoaders
    dataloader_args = dict(shuffle=True, batch_size=batch_size, collate_fn=_custom_collate_fn)
    train_dataloader = torch.utils.data.DataLoader(ds['train'], **dataloader_args)
    test_dataloader = torch.utils.data.DataLoader(ds['test'], **dataloader_args)

    return train_dataloader, test_dataloader

In [9]:
def test_dataset_pipeline(tokenizer):
    """Print shapes of first batch to verify dataset pipeline."""
    train_loader, _ = get_dataset(tokenizer=tokenizer, batch_size=1, seq_length=32, max_examples=32)
    batch = next(iter(train_loader))
    print("Input tokens shape:", batch['input_tokens'].shape)
    print("Target mask shape:", batch['target_tokens'].shape)
test_dataset_pipeline(tokenizer)

Map:   0%|          | 0/27 [00:00<?, ? examples/s]

Map:   0%|          | 0/5 [00:00<?, ? examples/s]

Input tokens shape: (1, 32)
Target mask shape: (1, 32)


# Training loop

In [10]:
@chex.dataclass(frozen=True)
class TrainingConfig:
    learning_rate: float = 1e-4
    num_epochs: int = 1
    max_steps: int | None = 5
    batch_size: int = 32
    seq_length: int = 64
    dataset_size_limit: int | None = 32
    print_every_n_steps: int = 1
    eval_every_n_steps: int = 1000
    max_eval_steps: int | None = 1


training_cfg = TrainingConfig()
optimizer = optax.sgd(training_cfg.learning_rate)

In [11]:
# Prepare dataset
train_dataloader, val_dataloader = get_dataset(
    tokenizer=tokenizer,
    seq_length=training_cfg.seq_length,
    max_examples=training_cfg.dataset_size_limit,
)

Map:   0%|          | 0/27 [00:00<?, ? examples/s]

Map:   0%|          | 0/5 [00:00<?, ? examples/s]

In [12]:
# Calculate and print training steps information
total_samples = len(train_dataloader.dataset)
batch_size = training_cfg.batch_size
steps_per_epoch = (total_samples + batch_size - 1) // batch_size
total_steps = steps_per_epoch * training_cfg.num_epochs

if training_cfg.max_steps:
    total_steps = min(total_steps, training_cfg.max_steps)

print("\nTraining Configuration Summary:")
print(f"Total samples: {total_samples}")
print(f"Batch size: {batch_size}")
print(f"Number of epochs: {training_cfg.num_epochs}")
print(f"Steps per epoch: {steps_per_epoch}")
print(f"Total training steps: {total_steps}")
if training_cfg.max_steps and total_steps == training_cfg.max_steps:
    print(
        f"*Note*: Total steps limited by max_steps setting ({training_cfg.max_steps})"
    )


Training Configuration Summary:
Total samples: 27
Batch size: 32
Number of epochs: 1
Steps per epoch: 1
Total training steps: 1


In [13]:
trainer = trainer_lib.CausalLMTrainer(
    model=model,
    model_ckpt_path="/mnt/persistent-disk/easy/easylm_llama3.1_8b.easylm",
    model_configurator=model_configurator,
    optimizer=optimizer,
    training_config=training_cfg,
    mesh=jax_utils.MESH, 
)

Loading causal language model...
> [0;32m/home/roadrunnerx-fork/llama3_jax/trainer_engine/trainer_lib.py[0m(98)[0;36msetup[0;34m()[0m
[0;32m     96 [0;31m            [0;32mif[0m [0mself[0m[0;34m.[0m[0mmodel_params[0m [0;32mis[0m [0;32mNone[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     97 [0;31m                [0mpdb[0m[0;34m.[0m[0mset_trace[0m[0;34m([0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m---> 98 [0;31m                _, self.model_params = self.checkpointer.load_trainstate_checkpoint(
[0m[0;32m     99 [0;31m                    [0;34m"params::"[0m [0;34m+[0m [0mself[0m[0;34m.[0m[0mmodel_ckpt_path[0m[0;34m,[0m [0mself[0m[0;34m.[0m[0mstate_shapes[0m[0;34m,[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    100 [0;31m                    self.shard_fns)
[0m


ipdb>  self.state


*** AttributeError: 'CausalLMTrainer' object has no attribute 'state'


ipdb>  self.state_shapes


TrainState(step=ShapeDtypeStruct(shape=(), dtype=int32), apply_fn=<bound method Module.apply of CausalLlamaModule(
    # attributes
    config = PretrainedConfig {
      "attention_dropout": 0.0,
      "base_model": "llama3.1_8b",
      "embedding_dropout": 0.0,
      "hidden_size": 4096,
      "initializer_range": 0.02,
      "intermediate_size": 14336,
      "max_position_embeddings": 8192,
      "num_attention_heads": 32,
      "num_hidden_layers": 32,
      "num_key_value_heads": 8,
      "residue_dropout": 0.0,
      "rms_norm_eps": 1e-05,
      "rope_theta": 500000.0,
      "transformers_version": "4.43.3",
      "vocab_size": 128256
    }
    
    dtype = float32
    param_dtype = float32
    precision = None
)>, params={'params': {'lm_head': {'kernel': ShapeDtypeStruct(shape=(4096, 128256), dtype=float32)}, 'transformer': {'h': {'0': {'attention': {'wk': {'kernel': ShapeDtypeStruct(shape=(4096, 1024), dtype=float32)}, 'wo': {'kernel': ShapeDtypeStruct(shape=(4096, 4096), dtype=

ipdb>  c


> [0;32m/home/roadrunnerx-fork/llama3_jax/trainer_engine/checkpoint_lib.py[0m(306)[0;36mload_trainstate_checkpoint[0;34m()[0m
[0;32m    304 [0;31m            [0;31m# Load the params in the streaming format[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    305 [0;31m            [0mpdb[0m[0;34m.[0m[0mset_trace[0m[0;34m([0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m--> 306 [0;31m            restored_params = cls.load_checkpoint(
[0m[0;32m    307 [0;31m                [0mpath[0m[0;34m=[0m[0mload_path[0m[0;34m,[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    308 [0;31m                [0mtarget[0m[0;34m=[0m[0mparams_target[0m[0;34m,[0m[0;34m[0m[0;34m[0m[0m
[0m


ipdb>  params_target


{'lm_head': {'kernel': ShapeDtypeStruct(shape=(4096, 128256), dtype=float32)}, 'transformer': {'h': {'0': {'attention': {'wk': {'kernel': ShapeDtypeStruct(shape=(4096, 1024), dtype=float32)}, 'wo': {'kernel': ShapeDtypeStruct(shape=(4096, 4096), dtype=float32)}, 'wq': {'kernel': ShapeDtypeStruct(shape=(4096, 4096), dtype=float32)}, 'wv': {'kernel': ShapeDtypeStruct(shape=(4096, 1024), dtype=float32)}}, 'attention_norm': {'kernel': ShapeDtypeStruct(shape=(4096,), dtype=float32)}, 'feed_forward': {'w1': {'kernel': ShapeDtypeStruct(shape=(4096, 14336), dtype=float32)}, 'w2': {'kernel': ShapeDtypeStruct(shape=(14336, 4096), dtype=float32)}, 'w3': {'kernel': ShapeDtypeStruct(shape=(4096, 14336), dtype=float32)}}, 'ffn_norm': {'kernel': ShapeDtypeStruct(shape=(4096,), dtype=float32)}}, '1': {'attention': {'wk': {'kernel': ShapeDtypeStruct(shape=(4096, 1024), dtype=float32)}, 'wo': {'kernel': ShapeDtypeStruct(shape=(4096, 4096), dtype=float32)}, 'wq': {'kernel': ShapeDtypeStruct(shape=(4096, 

ipdb>  params_shard_fns


{'lm_head': {'kernel': <function make_shard_and_gather_fns.<locals>.make_shard_fn.<locals>.shard_fn at 0x7fda402f00d0>}, 'transformer': {'h': {'0': {'attention': {'wk': {'kernel': <function make_shard_and_gather_fns.<locals>.make_shard_fn.<locals>.shard_fn at 0x7fda402f0c10>}, 'wo': {'kernel': <function make_shard_and_gather_fns.<locals>.make_shard_fn.<locals>.shard_fn at 0x7fda402f16c0>}, 'wq': {'kernel': <function make_shard_and_gather_fns.<locals>.make_shard_fn.<locals>.shard_fn at 0x7fda402f1a20>}, 'wv': {'kernel': <function make_shard_and_gather_fns.<locals>.make_shard_fn.<locals>.shard_fn at 0x7fda402f25f0>}}, 'attention_norm': {'kernel': <function make_shard_and_gather_fns.<locals>.make_shard_fn.<locals>.shard_fn at 0x7fda402f30a0>}, 'feed_forward': {'w1': {'kernel': <function make_shard_and_gather_fns.<locals>.make_shard_fn.<locals>.shard_fn at 0x7fda402f3f40>}, 'w2': {'kernel': <function make_shard_and_gather_fns.<locals>.make_shard_fn.<locals>.shard_fn at 0x7fda402f3eb0>}, 'w

ipdb>  c


> [0;32m/home/roadrunnerx-fork/llama3_jax/trainer_engine/checkpoint_lib.py[0m(257)[0;36mload_checkpoint[0;34m()[0m
[0;32m    255 [0;31m[0;34m[0m[0m
[0m[0;32m    256 [0;31m        [0mpdb[0m[0;34m.[0m[0mset_trace[0m[0;34m([0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m--> 257 [0;31m        [0mtrain_state[0m [0;34m=[0m [0munflatten_dict[0m[0;34m([0m[0mflattend_train_state[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    258 [0;31m        [0;32mif[0m [0mtarget[0m [0;32mis[0m [0;32mNone[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    259 [0;31m            [0;32mreturn[0m [0mtrain_state[0m[0;34m[0m[0;34m[0m[0m
[0m


ipdb>  c


In [14]:
state = trainer.train(train_dataloader, val_dataloader, run_jitted=True)

Starting epoch 0 of training...
Epoch 0, Step 0, Train Loss: 2.5638, Accuracy: 0.4844
Eval Step 0, Loss: 2.2536, Accuracy: 0.5312
Evaluation complete. Average Loss: 2.2536, Average Accuracy: 0.5312
Epoch 0, Step 0, Eval Loss: 2.2536, Accuracy: 0.5312
Epoch 0, Step 1, Train Loss: 2.4324, Accuracy: 0.5156
Epoch 0, Step 2, Train Loss: 2.0826, Accuracy: 0.5625
Epoch 0, Step 3, Train Loss: 2.1698, Accuracy: 0.5781
Epoch 0, Step 4, Train Loss: 1.9138, Accuracy: 0.5938
Epoch 0, Step 5, Train Loss: 1.9676, Accuracy: 0.6094


In [15]:
flax_checkpoint_path = os.path.join(EXPORT_DIR, MODEL_NAME)
trainer.save_checkpoint(state, path=flax_checkpoint_path)

Saving checkpoint to /mnt/persistent-disk/export/llama-3.1-8B-Instruct-JAX...
Checkpoint saved to /mnt/persistent-disk/export/llama-3.1-8B-Instruct-JAX.


In [None]:
convert_lib.save_hf_compatible_checkpoint(
    f'flax_params::{flax_checkpoint_path}', HF_EXPORT_DIR, model_configurator)

# Download and save the tokenizer
tokenizer_repo = f"felafax/tokenizer-{MODEL_NAME}"
tokenizer_dir = snapshot_download(repo_id=tokenizer_repo)

# Move all files from tokenizer_dir to HF_EXPORT_DIR
for item in os.listdir(tokenizer_dir):
    s = os.path.join(tokenizer_dir, item)
    d = os.path.join(HF_EXPORT_DIR, item)
    if os.path.isfile(s):
        shutil.copy2(s, d)
        print(f"Copied {item} to {HF_EXPORT_DIR}")
    elif os.path.isdir(s):
        shutil.copytree(s, d, dirs_exist_ok=True)
        print(f"Copied directory {item} to {HF_EXPORT_DIR}")
print(f"All tokenizer files saved to {HF_EXPORT_DIR}")

In [None]:
checkpoint_lib.copy_directory(HF_EXPORT_DIR, GCS_DIR)
print(f"Checkpoint copied to {GCS_DIR}")

In [None]:
# HUGGINGFACE_TOKEN = input("INPUT: Please provide your HUGGINGFACE_TOKEN: ")
# HUGGINGFACE_USERNAME = input(
#     "INPUT: Please provide your HUGGINGFACE_USERNAME: ")
# HUGGINGFACE_REPO_NAME = input(
#     "INPUT: Please provide your HUGGINGFACE_REPO_NAME: ")
# convert_lib.upload_checkpoint_to_hf(
#     HF_EXPORT_DIR, f"{HUGGINGFACE_USERNAME}/{HUGGINGFACE_REPO_NAME}",
#     HUGGINGFACE_TOKEN)