In [6]:
import os
import sys
import importlib

# Add the parent directory of the current working directory to the Python path
sys.path.append(os.path.abspath(os.path.join(os.getcwd(), '..')))

try:
    import felafax
    print("felafax package imported successfully")
except ImportError as e:
    print(f"Error importing felafax: {e}")

felafax package imported successfully


In [7]:
from felafax.trainer_engine import setup
setup.setup_environment()

In [10]:
from typing import (Any, Dict, List, Mapping, Optional, Sequence, Tuple,
                    Union)

import jax
import jax.numpy as jnp
import chex
import optax

import torch

from datasets import load_dataset
from transformers import default_data_collator

In [11]:
HUGGINGFACE_USERNAME = input("INPUT: Please provide your HUGGINGFACE_USERNAME: ") or "felarof01"
HUGGINGFACE_TOKEN = input("INPUT: Please provide your HUGGINGFACE_TOKEN: ") or "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"

INPUT: Please provide your HUGGINGFACE_USERNAME:  
INPUT: Please provide your HUGGINGFACE_TOKEN:  


In [14]:
# Select a supported model from above list to use!
MODEL_NAME = "Meta-Llama-3.1-8B"

In [19]:
model_path, model, model_configurator, tokenizer = automodel_lib.AutoJAXModelForCausalLM.from_pretrained("llama-3.1-8B-JAX",
                                                                           HUGGINGFACE_TOKEN)

Downloading model llama-3.1-8B-JAX...


Fetching 3 files: 100%|██████████| 3/3 [00:00<00:00, 16.34it/s]


llama-3.1-8B-JAX was downloaded to /home/felafax-storage/hf/models--felafax--llama-3.1-8B-JAX/snapshots/ebca17f216e4c02e0f31cc47264a9d65a4f5b9a9.


# Will just use the same dataset pipeline for now

In [16]:
def get_dataset(*, tokenizer, batch_size=1, seq_length=32, max_examples=None):
    # Define Alpaca prompt template
    alpaca_prompt = """Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.
    
    ### Instruction: {}
    
    ### Input: {}
    
    ### Response: {}"""
    
    EOS_TOKEN = tokenizer.eos_token
    
    # Defines formatting function.
    def _format_prompts(examples):
        instructions = examples["instruction"]
        inputs = examples["input"]
        outputs = examples["output"]
        texts = []
        for instruction, input, output in zip(instructions, inputs, outputs):
            text = alpaca_prompt.format(instruction, input, output) + EOS_TOKEN
            texts.append(text)
        return {"text": texts}

    def _tokenize(examples):
        tokenized = tokenizer(examples["text"], truncation=True, padding="max_length", max_length=seq_length+1)
        return {
            'input_tokens': [input_id[:-1] for input_id in tokenized['input_ids']],
            'target_tokens': [input_id[1:] for input_id in tokenized['input_ids']],
            'loss_masks': [input_id[1:] for input_id in tokenized['attention_mask']]
        }

    def _custom_collate_fn(batch: List[Dict[str, Any]]) -> Dict[str, jnp.ndarray]:
        """
        Collates batch items and converts PyTorch tensors to JAX arrays.
        Applies default_data_collator, then converts tensors to JAX format.
        """
        collated = default_data_collator(batch)
        jax_batch = {}
        for key, value in collated.items():
            jax_batch[key] = jnp.array(value.numpy()) if isinstance(value, torch.Tensor) else value
        
        return jax_batch

    # Load and preprocess the dataset
    dataset = load_dataset("yahma/alpaca-cleaned", split="train")
    if max_examples:
        dataset = dataset.select(range(max_examples))
    dataset = dataset.map(_format_prompts, batched=True)

    # Create train and test dataset.
    ds = dataset.train_test_split(test_size=0.15)
    for split in ['train', 'test']:
        ds[split] = ds[split].map(_tokenize, batched=True, remove_columns=dataset.column_names)

    # Create DataLoaders
    dataloader_args = dict(shuffle=True, batch_size=batch_size, collate_fn=_custom_collate_fn)
    train_dataloader = torch.utils.data.DataLoader(ds['train'], **dataloader_args)
    test_dataloader = torch.utils.data.DataLoader(ds['test'], **dataloader_args)

    return train_dataloader, test_dataloader

In [None]:
def test_dataset_pipeline(tokenizer):
    """Print shapes of first batch to verify dataset pipeline."""
    train_loader, _ = get_dataset(tokenizer=tokenizer, batch_size=1, seq_length=32, max_examples=32)
    batch = next(iter(train_loader))
    print("Input tokens shape:", batch['input_tokens'].shape)
    print("Target mask shape:", batch['target_tokens'].shape)
test_dataset_pipeline(tokenizer)

# Training loop

In [17]:
@chex.dataclass(frozen=True)
class TrainingConfig:
    learning_rate: float = 1e-4
    num_epochs: int = 1
    max_steps: int | None = 5
    batch_size: int = 32
    seq_length: int = 64
    dataset_size_limit: int | None = 32
    print_every_n_steps: int = 1


training_cfg = TrainingConfig()
optimizer = optax.sgd(training_cfg.learning_rate)

In [18]:
# Prepare dataset
train_dataloader, val_dataloader = get_dataset(
    tokenizer=tokenizer,
    seq_length=training_cfg.seq_length,
    max_examples=training_cfg.dataset_size_limit,
)

Map: 100%|██████████| 27/27 [00:00<00:00, 35.36 examples/s] 
Map: 100%|██████████| 5/5 [00:00<00:00,  7.72 examples/s]


In [58]:
from felafax.trainer_engine import utils, jax_utils
from felafax.trainer_engine import automodel_lib, checkpoint_lib, trainer_lib
from felafax import llama_config

setup.reload_modules()

Attempted to reload all felafax.trainer_engine modules


In [59]:
trainer = trainer_lib.CausalLMTrainer(
    model=model,
    model_ckpt_path="/home/felafax-storage/hf/models--felafax--llama-3.1-8B-JAX/snapshots/ebca17f216e4c02e0f31cc47264a9d65a4f5b9a9/llama3.1_8b_serialized.flax", # model_path,
    model_configurator=model_configurator,
    optimizer=optimizer,
    training_config=training_cfg,
    mesh=jax_utils.MESH, 
    model_params=model_params,
)

TypeError: 'NamedSharding' object is not iterable

In [41]:
model_params = trainer.model_params

In [None]:
model_params

In [57]:
trainer.state_shapes_partitioned

TrainState(step=PartitionSpec(), apply_fn=<bound method Module.apply of CausalLlamaModule(
    # attributes
    config = PretrainedConfig {
      "attention_dropout": 0.0,
      "base_model": "llama3.1_8b",
      "embedding_dropout": 0.0,
      "hidden_size": 4096,
      "initializer_range": 0.02,
      "intermediate_size": 14336,
      "max_position_embeddings": 8192,
      "num_attention_heads": 32,
      "num_hidden_layers": 32,
      "num_key_value_heads": 8,
      "residue_dropout": 0.0,
      "rms_norm_eps": 1e-05,
      "rope_theta": 500000.0,
      "transformers_version": "4.43.3",
      "vocab_size": 128256
    }
    
    dtype = float32
    param_dtype = float32
    precision = None
)>, params={'params': {'lm_head': {'kernel': PartitionSpec('fsdp', 'mp')}, 'transformer': {'h': {'0': {'attention': {'wk': {'kernel': PartitionSpec('fsdp', 'mp')}, 'wo': {'kernel': PartitionSpec('mp', 'fsdp')}, 'wq': {'kernel': PartitionSpec('fsdp', 'mp')}, 'wv': {'kernel': PartitionSpec('fsdp', '

In [52]:
trainer.train(train_dataloader, val_dataloader, run_jitted=False)

Starting epoch 0 of training...


ValueError: One of pjit outputs was given the sharding of NamedSharding(mesh=Mesh('dp': 1, 'fsdp': 8, 'mp': 1), spec=PartitionSpec(('dp', 'fsdp'), None, 'mp')), which implies that the global size of its dimension 0 should be divisible by 8, but it is equal to 1 (full shape: (1, 64, 4096))