In [1]:
import os
import sys
import importlib

# Add the parent directory of the current working directory to the Python path
sys.path.append(os.path.abspath(os.path.join(os.getcwd(), '..')))

try:
    import felafax
    print("felafax package imported successfully")
except ImportError as e:
    print(f"Error importing felafax: {e}")

felafax package imported successfully


In [2]:
from felafax.trainer_engine import setup
setup.setup_environment()

In [16]:
from felafax.trainer_engine import utils, jax_utils
from felafax.trainer_engine import automodel_lib, checkpoint_lib, trainer_lib
from felafax import llama_config

setup.reload_modules()

Attempted to reload all felafax.trainer_engine modules


In [4]:
from typing import (Any, Dict, List, Mapping, Optional, Sequence, Tuple,
                    Union)

import jax
import jax.numpy as jnp
import chex
import optax

import torch

from datasets import load_dataset
from transformers import default_data_collator

In [5]:
HUGGINGFACE_USERNAME = input("INPUT: Please provide your HUGGINGFACE_USERNAME: ") or "felarof01"
HUGGINGFACE_TOKEN = input("INPUT: Please provide your HUGGINGFACE_TOKEN: ") or "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"

INPUT: Please provide your HUGGINGFACE_USERNAME:  
INPUT: Please provide your HUGGINGFACE_TOKEN:  


In [6]:
# Select a supported model from above list to use!
MODEL_NAME = "Meta-Llama-3.1-8B"

In [17]:
model_path, model, model_configurator, tokenizer = automodel_lib.AutoJAXModelForCausalLM.from_pretrained("llama-3.1-8B-JAX",
                                                                           HUGGINGFACE_TOKEN)

Downloading model llama-3.1-8B-JAX...


Fetching 3 files: 100%|██████████| 3/3 [00:01<00:00,  1.86it/s]

llama-3.1-8B-JAX was downloaded to /home/felafax-storage/hf/models--felafax--llama-3.1-8B-JAX/snapshots/ebca17f216e4c02e0f31cc47264a9d65a4f5b9a9/llama3.1_8b_serialized.flax.





# Will just use the same dataset pipeline for now

In [8]:
def get_dataset(*, tokenizer, batch_size=1, seq_length=32, max_examples=None):
    # Define Alpaca prompt template
    alpaca_prompt = """Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.
    
    ### Instruction: {}
    
    ### Input: {}
    
    ### Response: {}"""
    
    EOS_TOKEN = tokenizer.eos_token
    
    # Defines formatting function.
    def _format_prompts(examples):
        instructions = examples["instruction"]
        inputs = examples["input"]
        outputs = examples["output"]
        texts = []
        for instruction, input, output in zip(instructions, inputs, outputs):
            text = alpaca_prompt.format(instruction, input, output) + EOS_TOKEN
            texts.append(text)
        return {"text": texts}

    def _tokenize(examples):
        tokenized = tokenizer(examples["text"], truncation=True, padding="max_length", max_length=seq_length+1)
        return {
            'input_tokens': [input_id[:-1] for input_id in tokenized['input_ids']],
            'target_tokens': [input_id[1:] for input_id in tokenized['input_ids']],
            'loss_masks': [input_id[1:] for input_id in tokenized['attention_mask']]
        }

    def _custom_collate_fn(batch: List[Dict[str, Any]]) -> Dict[str, jnp.ndarray]:
        """
        Collates batch items and converts PyTorch tensors to JAX arrays.
        Applies default_data_collator, then converts tensors to JAX format.
        """
        collated = default_data_collator(batch)
        jax_batch = {}
        for key, value in collated.items():
            jax_batch[key] = jnp.array(value.numpy()) if isinstance(value, torch.Tensor) else value
        
        return jax_batch

    # Load and preprocess the dataset
    dataset = load_dataset("yahma/alpaca-cleaned", split="train")
    if max_examples:
        dataset = dataset.select(range(max_examples))
    dataset = dataset.map(_format_prompts, batched=True)

    # Create train and test dataset.
    ds = dataset.train_test_split(test_size=0.15)
    for split in ['train', 'test']:
        ds[split] = ds[split].map(_tokenize, batched=True, remove_columns=dataset.column_names)

    # Create DataLoaders
    dataloader_args = dict(shuffle=True, batch_size=batch_size, collate_fn=_custom_collate_fn)
    train_dataloader = torch.utils.data.DataLoader(ds['train'], **dataloader_args)
    test_dataloader = torch.utils.data.DataLoader(ds['test'], **dataloader_args)

    return train_dataloader, test_dataloader

In [9]:
def test_dataset_pipeline(tokenizer):
    """Print shapes of first batch to verify dataset pipeline."""
    train_loader, _ = get_dataset(tokenizer=tokenizer, batch_size=1, seq_length=32, max_examples=32)
    batch = next(iter(train_loader))
    print("Input tokens shape:", batch['input_tokens'].shape)
    print("Target mask shape:", batch['target_tokens'].shape)
test_dataset_pipeline(tokenizer)

Map: 100%|██████████| 27/27 [00:03<00:00,  8.50 examples/s]
Map: 100%|██████████| 5/5 [00:03<00:00,  1.31 examples/s]

Input tokens shape: (1, 32)
Target mask shape: (1, 32)





# Training loop

In [10]:
@chex.dataclass(frozen=True)
class TrainingConfig:
    learning_rate: float = 1e-4
    num_epochs: int = 1
    max_steps: int | None = 5
    batch_size: int = 32
    seq_length: int = 64
    dataset_size_limit: int | None = 32
    print_every_n_steps: int = 1
    eval_every_n_steps: int = 1000


training_cfg = TrainingConfig()
optimizer = optax.sgd(training_cfg.learning_rate)

In [11]:
# Prepare dataset
train_dataloader, val_dataloader = get_dataset(
    tokenizer=tokenizer,
    seq_length=training_cfg.seq_length,
    max_examples=training_cfg.dataset_size_limit,
)

Map: 100%|██████████| 27/27 [00:01<00:00, 16.28 examples/s] 
Map: 100%|██████████| 5/5 [00:01<00:00,  3.09 examples/s]


In [12]:
trainer = trainer_lib.CausalLMTrainer(
    model=model,
    model_ckpt_path=model_path,
    model_configurator=model_configurator,
    optimizer=optimizer,
    training_config=training_cfg,
    mesh=jax_utils.MESH, 
    # model_params=model_params,
)

Loading causal language model...


In [13]:
model_params = trainer.model_params

In [15]:
state = trainer.train(train_dataloader, val_dataloader, run_jitted=True)

Starting epoch 0 of training...
Epoch 0, Step 0, Train Loss: 2.3347, Accuracy: 0.4375
Epoch 0, Step 0, Eval Loss: 2.0824, Accuracy: 0.5960
Epoch 0, Step 1, Train Loss: 1.9930, Accuracy: 0.5469
Epoch 0, Step 2, Train Loss: 2.2861, Accuracy: 0.5625
Epoch 0, Step 3, Train Loss: 1.5709, Accuracy: 0.6094
Epoch 0, Step 4, Train Loss: 1.5585, Accuracy: 0.6094
Epoch 0, Step 5, Train Loss: 1.6425, Accuracy: 0.6094


In [20]:
trainer.save_checkpoint(state, path="/home/ckpt/llama3.flax")

## Convert and upload checkpint

In [47]:
import gc
import json
import math
import os
import shutil

import flax
import jax
import jax.numpy as jnp
import numpy as np
import torch
from flax.traverse_util import flatten_dict
from transformers import LlamaConfig, LlamaForCausalLM

def match_keywords(string, positives, negatives):
    for positive in positives:
        if positive not in string:
            return False
    for negative in negatives:
        if negative in string:
            return False
    return True

def load_and_convert_checkpoint(path):
    _, flax_params = checkpoint_lib.Checkpointer.load_trainstate_checkpoint(path)
    import pdb; pdb.set_trace()
    flax_params = flatten_dict(flax_params['params'], sep='.')
    torch_params = {}
    for key, tensor in flax_params.items():
        if match_keywords(key, ["kernel"], ["norm", 'ln_f']):
            tensor = tensor.T
        torch_params[key] = torch.tensor(checkpoint_lib.float_tensor_to_dtype(tensor, 'fp32'),
                                         dtype=torch.float16)
    return torch_params

def read_json(path):
    with open(path, "r") as f:
        return json.load(f)


def write_json(text, path):
    with open(path, "w") as f:
        json.dump(text, f)


def permute(w, n_heads, input_dim, output_dim):
    # permute for sliced rotary embedding
    return w.view(n_heads, output_dim // n_heads // 2, 2,
                  input_dim).transpose(1, 2).reshape(output_dim, input_dim)

In [41]:

def write_model(loaded, model_path, llama_pretrained_config):
    os.makedirs(model_path, exist_ok=True)
    tmp_model_path = os.path.join(model_path, "tmp")
    os.makedirs(tmp_model_path, exist_ok=True)

    llama_config = llama_pretrained_config
    n_layers = llama_config.num_hidden_layers
    n_heads = llama_config.num_attention_heads
    n_kv_heads = llama_config.num_key_value_heads
    dim = llama_config.hidden_size
    dims_per_head = dim // n_heads
    base = llama_config.rope_theta
    inv_freq = 1.0 / (base**(torch.arange(0, dims_per_head, 2).float() /
                             dims_per_head))

    param_count = 0
    index_dict = {"weight_map": {}}
    for layer_i in range(n_layers):
        filename = f"pytorch_model-{layer_i + 1}-of-{n_layers + 1}.bin"
        state_dict = {
            f"model.layers.{layer_i}.self_attn.q_proj.weight":
            permute(
                loaded[f"transformer.h.{layer_i}.attention.wq.kernel"],
                llama_config.num_attention_heads,
                llama_config.hidden_size,
                llama_config.hidden_size,
            ),
            f"model.layers.{layer_i}.self_attn.k_proj.weight":
            permute(
                loaded[f"transformer.h.{layer_i}.attention.wk.kernel"],
                llama_config.num_key_value_heads,
                llama_config.hidden_size,
                llama_config.hidden_size //
                (llama_config.num_attention_heads //
                 llama_config.num_key_value_heads),
            ),
            f"model.layers.{layer_i}.self_attn.v_proj.weight":
            loaded[f"transformer.h.{layer_i}.attention.wv.kernel"],
            f"model.layers.{layer_i}.self_attn.o_proj.weight":
            loaded[f"transformer.h.{layer_i}.attention.wo.kernel"],
            f"model.layers.{layer_i}.mlp.gate_proj.weight":
            loaded[f"transformer.h.{layer_i}.feed_forward.w1.kernel"],
            f"model.layers.{layer_i}.mlp.down_proj.weight":
            loaded[f"transformer.h.{layer_i}.feed_forward.w2.kernel"],
            f"model.layers.{layer_i}.mlp.up_proj.weight":
            loaded[f"transformer.h.{layer_i}.feed_forward.w3.kernel"],
            f"model.layers.{layer_i}.input_layernorm.weight":
            loaded[f"transformer.h.{layer_i}.attention_norm.kernel"],
            f"model.layers.{layer_i}.post_attention_layernorm.weight":
            loaded[f"transformer.h.{layer_i}.ffn_norm.kernel"],
        }

        state_dict[
            f"model.layers.{layer_i}.self_attn.rotary_emb.inv_freq"] = inv_freq
        for k, v in state_dict.items():
            index_dict["weight_map"][k] = filename
            param_count += v.numel()
        torch.save(state_dict, os.path.join(tmp_model_path, filename))

    filename = f"pytorch_model-{n_layers + 1}-of-{n_layers + 1}.bin"
    # Unsharded
    state_dict = {
        "model.embed_tokens.weight": loaded["transformer.wte.embedding"],
        "model.norm.weight": loaded["transformer.ln_f.kernel"],
        "lm_head.weight": loaded["lm_head.kernel"],
    }

    for k, v in state_dict.items():
        index_dict["weight_map"][k] = filename
        param_count += v.numel()
    torch.save(state_dict, os.path.join(tmp_model_path, filename))

    # Write configs
    index_dict["metadata"] = {"total_size": param_count * 2}
    write_json(index_dict,
               os.path.join(tmp_model_path, "pytorch_model.bin.index.json"))

    config = LlamaConfig(
        vocab_size=llama_config.vocab_size,
        hidden_size=llama_config.hidden_size,
        intermediate_size=llama_config.intermediate_size,
        num_hidden_layers=llama_config.num_hidden_layers,
        num_attention_heads=llama_config.num_attention_heads,
        num_key_value_heads=llama_config.num_key_value_heads,
        initializer_range=llama_config.initializer_range,
        rms_norm_eps=llama_config.rms_norm_eps,
        max_position_embeddings=llama_config.max_position_embeddings,
        rope_theta=llama_config.rope_theta,
    )
    config.save_pretrained(tmp_model_path)

    # Make space so we can load the model properly now.
    del state_dict
    del loaded
    gc.collect()

    print("Loading the checkpoint in a Llama model.")
    model = LlamaForCausalLM.from_pretrained(tmp_model_path,
                                             torch_dtype=torch.float16)
    # Avoid saving this as part of the config.
    del model.config._name_or_path

    print("Saving in the Transformers format.")
    model.save_pretrained(model_path)
    shutil.rmtree(tmp_model_path)

In [None]:
loaded_params = load_and_convert_checkpoint('flax_params::/home/ckpt/llama3.flax')

In [None]:
loaded_params

In [50]:
write_model(
    loaded_params,
    model_path="/home/ckpt/",
    llama_pretrained_config=model_configurator.get_hf_pretrained_config(model_configurator.get_model_config()),
)

Loading the checkpoint in a Llama model.


Loading checkpoint shards: 100%|██████████| 33/33 [00:23<00:00,  1.41it/s]


Saving in the Transformers format.
