In [1]:
!pip install --upgrade transformers -q

In [2]:
import os
os.environ['USE_TORCH'] = 'True'  # To use transformers library in TPU
os.environ['XLA_USE_BF16'] = 'True'
os.environ['PJRT_DEVICE'] = 'TPU'
os.environ['HF_HUB_CACHE'] = '/mnt/persistent-disk/hf/'
os.environ['HF_HOME'] = '/mnt/persistent-disk/hf/'

In [3]:
import os
import contextlib
from dataclasses import dataclass

import torch
import numpy as np
import torch.distributed as dist
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr
xr.use_spmd()
xr.initialize_cache('/mnt/persistent-disk/xla/')

import torch_xla.distributed.spmd as xs
from torch_xla.experimental.xla_sharded_tensor import XLAShardedTensor
from torch_xla.experimental.xla_sharding import Mesh

import torch_xla.distributed.xla_multiprocessing as xmp
import torch_xla.distributed.parallel_loader as pl
import torch_xla.test.test_utils as test_utils

from transformers import AutoModelForCausalLM, AutoConfig, AutoTokenizer, default_data_collator
from datasets import Dataset, load_dataset, concatenate_datasets
from peft import LoraConfig, TaskType, get_peft_model
from safetensors.torch import load_file

import logging
logging.getLogger("datasets").setLevel(logging.WARNING)
logging.getLogger("transformers").setLevel(logging.WARNING)

  from torch_xla.experimental.xla_sharded_tensor import XLAShardedTensor
  from torch_xla.experimental.xla_sharding import Mesh
  from .autonotebook import tqdm as notebook_tqdm


In [4]:
assert xr.is_spmd()==True

In [5]:
import sys
import importlib
sys.path.append('')
model_partitioning = importlib.import_module('LLaMa3.model_partitioning')
importlib.reload(model_partitioning)

<module 'LLaMa3.model_partitioning' from '/home/RoadrunnerX/LLaMa3/model_partitioning.py'>

In [6]:
def print_model(model, num_layers=1):
    print_is_xla()
    for i, (name, param) in enumerate(model.named_parameters()):
        if i >= num_layers:
            break
        
        print(f"Layer: {name}")
        print(f"Shape: {param.shape}")
        print(f"Device: {param.device}")
        print(f"Tensor value: {param.data}")
        print("-" * 50)
        
def print_is_xla():
    import torch_xla
    print("is xla: ", torch_xla._XLAC._xla_runtime_is_initialized())

In [7]:
# On a single TPU VM host, you can train/tune LLaMa 3/3.1 8B models with full precision or LoRA.
supported_models = [
    "TinyLlama/TinyLlama-1.1B-step-50K-105b",
    "meta-llama/Llama-2-7b-hf",
    "meta-llama/Meta-Llama-3-8B",
    "meta-llama/Meta-Llama-3.1-8B"
]

# Select a supported model from above list to use!
MODEL_NAME = "meta-llama/Meta-Llama-3.1-8B"
HUGGINGFACE_TOKEN = "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"
DEBUG_MODE = False

TRAINER_CONFIG = {
    "epochs": 1,
    "batch_size": 1,
    "max_length": 512,
    
    "lr": 5e-5,
    "logging_interval": 5,  # logs every 5 steps
    
    "lora_rank": 8,
    "lora_alpha": 32,
    "lora_dropout": 0.1,
}

# Configure LoRA config for your model.
Use the below code to configure the LoRA config for your model.

In [8]:
def apply_lora(*, model, lora_rank=None, lora_alpha=None, lora_dropout=None):
    """Applies LoRA configuration to the model."""
    peft_config = LoraConfig(
        task_type=TaskType.CAUSAL_LM,
        inference_mode=False,
        r=8 if not lora_rank else lora_rank,
        lora_alpha=32 if not lora_alpha else lora_alpha,
        lora_dropout=0.1 if not lora_dropout else lora_dropout,
    )
    model = get_peft_model(model, peft_config)
    model.print_trainable_parameters()
    return model

In [9]:
def init_model(*, model_name, hugging_face_token):
    """Downloads and initializes the model."""
    config = AutoConfig.from_pretrained(
        model_name, 
        token=hugging_face_token)
    
    tokenizer = AutoTokenizer.from_pretrained(
        model_name, 
        token=hugging_face_token
    )

    if not tokenizer.pad_token:
        tokenizer.pad_token = tokenizer.eos_token
        config.pad_token_id = tokenizer.pad_token_id
        
    model = AutoModelForCausalLM.from_pretrained(
        model_name, 
        token=hugging_face_token,
        low_cpu_mem_usage=True
    )

    model = apply_lora(
        model=model,
        lora_rank=TRAINER_CONFIG["lora_rank"],
        lora_alpha=TRAINER_CONFIG["lora_alpha"],
        lora_dropout=TRAINER_CONFIG["lora_dropout"],
    )

    model = xmp.MpModelWrapper(model)

    return model, tokenizer

In [10]:
def apply_spmd(*, model, mesh):
    # Apply on layers within model.
    model_partitioning_util.partition_model(model, mesh)

# Configure dataset pipeline for your model

For this project, we're utilizing the refined **Alpaca dataset**, curated by yahma. This dataset is a carefully filtered selection of 52,000 entries from the original Alpaca collection. Feel free to substitute this section with your own data preparation code if you prefer.

It's crucial to include the EOS_TOKEN (End of Sequence Token) in your tokenized output. Failing to do so may result in endless generation loops.

In [11]:
def get_dataset(*, tokenizer, batch_size=None, max_length=None, debug_mode=False):
    # Define Alpaca prompt template
    alpaca_prompt = """Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.
    
    ### Instruction: {}
    
    ### Input: {}
    
    ### Response: {}"""
    
    EOS_TOKEN = tokenizer.eos_token
    
    # Define formatting function.
    def _format_prompts(examples):
        instructions = examples["instruction"]
        inputs = examples["input"]
        outputs = examples["output"]
        texts = []
        for instruction, input, output in zip(instructions, inputs, outputs):
            text = alpaca_prompt.format(instruction, input, output) + EOS_TOKEN
            texts.append(text)
        return {"text": texts}

    # Tokenize the dataset.
    def _tokenize(examples):
        # Tokenized is list within list. Compute labels for causalLM by shifting input_id; 
        # consequently truncate input_id to penultimate position.
        tokenized = tokenizer(examples["text"], truncation=True, padding="max_length", max_length=512+1 if not max_length else max_length+1)
        labels = tokenized['input_ids'].copy()
        tokenized['labels'] = [label[1:] for label in labels]
        tokenized['input_ids'] = [input_id[:-1] for input_id in tokenized['input_ids']]
        return tokenized

    # Load and preprocess the dataset.
    dataset = load_dataset("yahma/alpaca-cleaned", split="train")
    if debug_mode:
        dataset = dataset.select(range(32)) # Use just 32 exampfor faster iteration
    dataset = dataset.map(_format_prompts, batched=True)

    # Create train and test dataset.
    ds = dataset.train_test_split(test_size=0.15)
    ds['train'] = ds['train'].map(_tokenize, batched=True, remove_columns=dataset.column_names)
    ds['test'] = ds['test'].map(_tokenize, batched=True, remove_columns=dataset.column_names)

    # Create DataLoader
    train_dataloader = torch.utils.data.DataLoader(
        ds['train'],
        shuffle=True,
        batch_size=1 if not batch_size else batch_size,
        collate_fn=default_data_collator,
    )
    
    test_dataloader = torch.utils.data.DataLoader(
        ds['test'],
        shuffle=True,
        batch_size=1 if not batch_size else batch_size,
        collate_fn=default_data_collator,
    )

    return train_dataloader, test_dataloader

# Train the model

Now let's train the model. We are using PyTorch XLA's Fully Sharded Data Parallel (FSDP) to distribute the model across the 8 TPU cores available on TPU v3-8. This approach allows for efficient training on TPU hardware. We also utilize PyTorch/XLA's MpDeviceLoader to efficiently load data onto the TPU cores.

**NOTE:** It's important to note that the **first step of training will be slow**. This is because XLA takes time initially to compile the computational graph. However, once the compilation is complete, subsequent steps will run much faster using compiled+cached graph, and leveraging the full power of the all TPU cores for accelerated training.


In [12]:
def print_training_update(device,
                          step,
                          loss,
                          rate,
                          epoch=None):
    """Prints the training metrics at a given step."""
    if xm.is_master_ordinal():  # Only print on the master device
        update_data = [
            'Training',
            f'Epoch={epoch}' if epoch is not None else None,
            f'Step={step}',
            f'Loss={loss:.5f}',
            # f'Rate={rate:.2f}',
            # f'Time={now()}'
        ]
        print(' | '.join(item for item in update_data if item), flush=True)
        print()


In [14]:
model, tokenizer = init_model(
        model_name=MODEL_NAME, hugging_face_token=HUGGINGFACE_TOKEN
)

Loading checkpoint shards: 100%|██████████| 4/4 [00:02<00:00,  1.73it/s]


trainable params: 3,407,872 || all params: 8,033,669,120 || trainable%: 0.04241987003816259


In [15]:
def train(index):
    global model, tokenizer

    torch.manual_seed(99)
    device = xm.xla_device()
    model = model.to(device)

    # Create a mesh for the model partitioning.
    num_devices = xr.global_runtime_device_count()
    mesh_shape = (1, num_devices, 1)
    device_ids = np.array(range(num_devices))
    mesh = Mesh(device_ids, mesh_shape, ("dp", "fsdp", "mp"))

    # model = checkpoint_module(model)
    
    # Partition the model using SPMD.
    model_partitioning.partition_model(model=model, mesh=mesh)
    
    # Configure the training loop.
    optimizer = torch.optim.Adam(model.parameters(), lr=TRAINER_CONFIG["lr"])

    train_dataloader, test_dataloader = get_dataset(
        tokenizer=tokenizer,
        batch_size=TRAINER_CONFIG["batch_size"],
        max_length=TRAINER_CONFIG["max_length"],
    )
    train_dataloader = pl.MpDeviceLoader(
        train_dataloader, 
        device
    ) 
    test_dataloader = pl.MpDeviceLoader(
        test_dataloader, 
        device
    )

    for epoch in range(TRAINER_CONFIG["epochs"]):
        xm.master_print(f"Epoch {epoch} train begin {test_utils.now()}")
        tracker = xm.RateTracker()
        
        model.train()
        for step, batch in enumerate(train_dataloader):
            if step>10:
                break
                
            optimizer.zero_grad()
            
            input_ids, attention_mask, labels = (
                batch["input_ids"],
                batch["attention_mask"],
                batch["labels"],
            )
            xs.mark_sharding(input_ids, mesh, (0, 1))
            xs.mark_sharding(attention_mask, mesh, (0, 1))
            xs.mark_sharding(labels, mesh, (0, 1))
            
            output = model(
                input_ids=input_ids, attention_mask=attention_mask, labels=labels
            )
            loss = output.loss
            loss.backward()
            
            optimizer.step()
            xm.mark_step()

            if step%TRAINER_CONFIG["logging_interval"]==0:
                loss_cpu = loss.item()
                xm.add_step_closure(
                    print_training_update,
                    args=(device, step, loss_cpu, tracker.rate(), epoch)
                )
            
        model.eval()
        eval_loss = 0
        with torch.no_grad():
            for step, batch in enumerate(test_dataloader):
                input_ids, attention_mask, labels = (
                    batch["input_ids"],
                    batch["attention_mask"],
                    batch["labels"],
                )
                xs.mark_sharding(input_ids, mesh, (0, 1))
                xs.mark_sharding(attention_mask, mesh, (0, 1))
                xs.mark_sharding(labels, mesh, (0, 1))
                
                output = model(
                    input_ids=input_ids, attention_mask=attention_mask, labels=labels
                )
                eval_loss += output.loss.item()
        avg_eval_loss = eval_loss / len(test_dataloader)
        xm.add_step_closure(
            lambda: print(f"Eval loss: {avg_eval_loss:.4f}"),
        )


In [16]:
xmp.spawn(train, args=(), start_method="fork")

I0000 00:00:1722464499.867909   25675 pjrt_api.cc:100] GetPjrtApi was found for tpu at /usr/local/lib/python3.10/site-packages/libtpu/libtpu.so
I0000 00:00:1722464499.867909   25674 pjrt_api.cc:100] GetPjrtApi was found for tpu at /usr/local/lib/python3.10/site-packages/libtpu/libtpu.so
I0000 00:00:1722464499.867989   25674 pjrt_api.cc:79] PJRT_Api is set for device type tpu
I0000 00:00:1722464499.867988   25675 pjrt_api.cc:79] PJRT_Api is set for device type tpu
I0000 00:00:1722464499.868004   25675 pjrt_api.cc:146] The PJRT plugin has PJRT API version 0.46. The framework PJRT API version is 0.46.
I0000 00:00:1722464499.868002   25674 pjrt_api.cc:146] The PJRT plugin has PJRT API version 0.46. The framework PJRT API version is 0.46.
I0000 00:00:1722464499.883581   25672 pjrt_api.cc:100] GetPjrtApi was found for tpu at /usr/local/lib/python3.10/site-packages/libtpu/libtpu.so
I0000 00:00:1722464499.883637   25672 pjrt_api.cc:79] PJRT_Api is set for device type tpu
I0000 00:00:1722464499

Epoch 0 train begin 22:22:30


Map: 100%|██████████| 43996/43996 [00:25<00:00, 1694.57 examples/s]
Map:  75%|███████▌  | 33000/43996 [00:19<00:07, 1449.05 examples/s]

Epoch 0 train begin 22:22:33


Map: 100%|██████████| 43996/43996 [00:25<00:00, 1720.92 examples/s]


Epoch 0 train begin 22:22:36


Map: 100%|██████████| 43996/43996 [00:24<00:00, 1783.70 examples/s]


Epoch 0 train begin 22:22:38
Training | Epoch=0 | Step=0 | Loss=13.53303

Training | Epoch=0 | Step=5 | Loss=12.80109

Training | Epoch=0 | Step=10 | Loss=12.47537



We detected that you are passing `past_key_values` as a tuple and this is deprecated and will be removed in v4.43. Please use an appropriate `Cache` class (https://huggingface.co/docs/transformers/v4.41.3/en/internal/generation_utils#transformers.Cache)
We detected that you are passing `past_key_values` as a tuple and this is deprecated and will be removed in v4.43. Please use an appropriate `Cache` class (https://huggingface.co/docs/transformers/v4.41.3/en/internal/generation_utils#transformers.Cache)
We detected that you are passing `past_key_values` as a tuple and this is deprecated and will be removed in v4.43. Please use an appropriate `Cache` class (https://huggingface.co/docs/transformers/v4.41.3/en/internal/generation_utils#transformers.Cache)
We detected that you are passing `past_key_values` as a tuple and this is deprecated and will be removed in v4.43. Please use an appropriate `Cache` class (https://huggingface.co/docs/transformers/v4.41.3/en/internal/generation_utils#tran

AttributeError: 'float' object has no attribute 'item'

# Export the model to HuggingFace Hub
Uncoment the following cell to push the model to HuggingFace Hub.

In [17]:
import os
from huggingface_hub import login

login(token=HUGGINGFACE_TOKEN)

The token has not been saved to the git credentials helper. Pass `add_to_git_credential=True` in this function directly or `--add-to-git-credential` if using via `huggingface-cli` if you want to set the git credential as well.
Token is valid (permission: write).
Your token has been saved to /mnt/persistent-disk/hf/token
Login successful


In [18]:
model = model.cpu()
model.push_to_hub(
    "felarof01/llama3-8.1-roadrunner",
    tokenizer=tokenizer,
    private=False,
    create_pr=False,
    max_shard_size="5GB",
)

AttributeError: 'MpModelWrapper' object has no attribute 'cpu'