In [None]:
!pip install --upgrade transformers -q

In [None]:
import os
# Set these before the torch / torch_xla / transformers imports in the next cell
# so the libraries see them when they initialise.
HF_CACHE_DIR = '/mnt/persistent-disk/hf/'

os.environ['USE_TORCH'] = 'True'  # To use transformers library in TPU
os.environ['XLA_USE_BF16'] = 'True'
os.environ['PJRT_DEVICE'] = 'TPU'
# Keep Hugging Face downloads on the persistent disk.
os.environ['HF_HUB_CACHE'] = HF_CACHE_DIR
os.environ['HF_HOME'] = HF_CACHE_DIR
# NOTE: no `!export VAR=...` here: every `!` line runs in its own throw-away
# subshell, so an export there never reaches this kernel. os.environ is what
# actually takes effect for this process.

In [None]:
import os
import contextlib
from dataclasses import dataclass

import torch
import numpy as np
import torch.distributed as dist
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr
# Switch PyTorch/XLA into SPMD mode. The `assert xr.is_spmd()` check below and the
# Mesh / xs.mark_sharding calls used during training rely on it.
xr.use_spmd()

import torch_xla.experimental.xla_sharding as xs 
from torch_xla.experimental.xla_sharded_tensor import XLAShardedTensor
from torch_xla.experimental.xla_sharding import Mesh

# from torch_xla.distributed.fsdp import XlaFullyShardedDataParallel as FSDP, checkpoint_module

import torch_xla.distributed.xla_multiprocessing as xmp
import torch_xla.distributed.parallel_loader as pl
import torch_xla.test.test_utils as test_utils

from transformers import AutoModelForCausalLM, AutoConfig, AutoTokenizer, default_data_collator
from datasets import Dataset, load_dataset, concatenate_datasets
from peft import LoraConfig, TaskType, get_peft_model

import logging

# Raise the Hugging Face library loggers to WARNING: INFO/DEBUG records are
# dropped, while warnings and errors are still emitted.
for _noisy_logger_name in ("datasets", "transformers"):
    logging.getLogger(_noisy_logger_name).setLevel(logging.WARNING)

In [None]:
# Fail fast if SPMD mode (enabled by xr.use_spmd() in the imports cell) is not active;
# the sharding-based training below depends on it.
assert xr.is_spmd(), "SPMD mode must be enabled (call xr.use_spmd()) before training."

In [None]:
import sys
import importlib
# An empty string on sys.path means "the current working directory", which is where
# trainer_lib/ is expected to live. The membership check keeps re-runs of this cell
# from appending duplicate entries.
if '' not in sys.path:
    sys.path.append('')
model_partitioning = importlib.import_module('trainer_lib.model_partitioning')
# Reload so edits to trainer_lib/model_partitioning.py are picked up when this cell is re-run.
importlib.reload(model_partitioning)

In [None]:
import os
import torch
import torch_xla.core.xla_model as xm
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
from safetensors.torch import load_file

In [None]:
import getpass

# This notebook can be used to fine-tune any of the 70B models listed below.
supported_models = [
    "meta-llama/Meta-Llama-3-70B",
    "meta-llama/Meta-Llama-3-70B-Instruct",
]

# Select a supported model from the list above.
MODEL_NAME = "meta-llama/Meta-Llama-3-70B"
assert MODEL_NAME in supported_models, (
    f"MODEL_NAME={MODEL_NAME!r} is not in supported_models: {supported_models}"
)

# getpass reads the token without echoing it, so it is not shown on screen or
# stored in the notebook's saved output the way input() would be.
HUGGINGFACE_TOKEN = getpass.getpass("**INPUT** your huggingface token: ")  # YOUR_HF_TOKEN
DEBUG_MODE = False  # Set True to iterate on a small data subset (see get_dataset's debug_mode).

TRAINER_CONFIG = {
    "epochs": 1,
    "batch_size": 1,
    "max_length": 512,  # tokens per training example after padding/truncation

    "lr": 5e-5,
    "logging_interval": 5,  # logs every 5 steps

    "lora_rank": 8,
    "lora_alpha": 32,
    "lora_dropout": 0.1,
}

In [None]:
%%capture
# %%capture swallows this cell's stdout/stderr, so nothing printed during login is displayed.
from huggingface_hub import login
login(token=HUGGINGFACE_TOKEN)

In [None]:
def apply_lora(*, model, lora_rank=None, lora_alpha=None, lora_dropout=None):
    """Wraps `model` with LoRA adapters for causal-LM fine-tuning.

    Args:
        model: Base model handed to `peft.get_peft_model`.
        lora_rank: LoRA rank `r`; 8 when None.
        lora_alpha: LoRA scaling alpha; 32 when None.
        lora_dropout: Dropout on the LoRA path; 0.1 when None. An explicit
            0.0 is honoured (disables LoRA dropout).

    Returns:
        The PEFT-wrapped model. Its trainable-parameter summary is printed.
    """
    peft_config = LoraConfig(
        task_type=TaskType.CAUSAL_LM,
        inference_mode=False,
        # `is None` rather than truthiness, so explicit zero values such as
        # lora_dropout=0.0 are not silently replaced by the defaults.
        r=8 if lora_rank is None else lora_rank,
        lora_alpha=32 if lora_alpha is None else lora_alpha,
        lora_dropout=0.1 if lora_dropout is None else lora_dropout,
    )
    model = get_peft_model(model, peft_config)
    model.print_trainable_parameters()
    return model

In [None]:
def init_model(*, model_name, hugging_face_token):
    """Downloads `model_name` from the Hugging Face Hub and prepares it for LoRA fine-tuning.

    Args:
        model_name: Hub repo id of a causal-LM checkpoint (e.g. MODEL_NAME).
        hugging_face_token: Hub access token (required for gated repos).

    Returns:
        (model, tokenizer): the LoRA/PEFT-wrapped model on CPU and its tokenizer.
    """
    # from_pretrained finds the checkpoint's shard files itself, so no shard count
    # is hard-coded. It also builds the model with every tensor initialised.
    # Deliberately NOT using meta-device init + to_empty() + load_state_dict(strict=False):
    # that leaves anything absent from the checkpoint (e.g. non-persistent buffers
    # such as rotary-embedding frequencies) as uninitialised memory.
    #   torch_dtype="auto": keep the dtype stored in the checkpoint instead of float32.
    #   low_cpu_mem_usage=True: avoid holding an extra randomly-initialised copy of
    #     the weights in host RAM while loading.
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        token=hugging_face_token,
        torch_dtype="auto",
        low_cpu_mem_usage=True,
    )

    tokenizer = AutoTokenizer.from_pretrained(
        model_name,
        token=hugging_face_token
    )

    if not tokenizer.pad_token:
        # The tokenizer defines no pad token: reuse EOS for padding.
        tokenizer.pad_token = tokenizer.eos_token
        # Record it on the model's own config (not a standalone AutoConfig object),
        # so the saved model config carries the pad token id.
        model.config.pad_token_id = tokenizer.pad_token_id

    model = apply_lora(
        model=model,
        lora_rank=TRAINER_CONFIG["lora_rank"],
        lora_alpha=TRAINER_CONFIG["lora_alpha"],
        lora_dropout=TRAINER_CONFIG["lora_dropout"],
    )

    return model, tokenizer

In [None]:
def apply_spmd(*, model, mesh):
    """Shards `model`'s layers over `mesh` via trainer_lib's partitioning helper.

    Uses the `model_partitioning` module imported earlier in this notebook, with the
    same keyword-argument call that `train` makes.
    """
    model_partitioning.partition_model(model=model, mesh=mesh)

# Configure dataset pipeline for your model

For this project, we use the cleaned **Alpaca dataset** (`yahma/alpaca-cleaned`), curated by yahma: a carefully filtered version of the original ~52,000-entry Alpaca collection. Feel free to replace this section with your own data-preparation code.

It's crucial to append the EOS token (end-of-sequence token) to every formatted example before tokenization. Without it, the fine-tuned model never learns where a response ends and may keep generating indefinitely.

In [None]:
def get_dataset(*, tokenizer, batch_size=None, max_length=None, debug_mode=False):
    """Builds train/test DataLoaders over Alpaca-cleaned for causal-LM fine-tuning.

    Args:
        tokenizer: Hugging Face tokenizer with a pad token defined.
        batch_size: DataLoader batch size; 1 when None/0.
        max_length: Length every example is padded/truncated to; 512 when None/0.
        debug_mode: If True, only the first 32 examples are used.

    Returns:
        (train_dataloader, test_dataloader). Each batch is a dict of `input_ids`,
        `attention_mask` and `labels` tensors of shape (batch, max_length).
        `labels` is -100 on padding positions.
    """
    batch_size = batch_size or 1
    max_length = max_length or 512

    # Define Alpaca prompt template
    alpaca_prompt = """Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.
    
    ### Instruction: {}
    
    ### Input: {}
    
    ### Response: {}"""

    EOS_TOKEN = tokenizer.eos_token

    def _format_prompts(examples):
        """Renders each (instruction, input, output) triple into one training text."""
        # The trailing EOS teaches the model where a response ends.
        texts = [
            alpaca_prompt.format(instruction, input_text, output) + EOS_TOKEN
            for instruction, input_text, output in zip(
                examples["instruction"], examples["input"], examples["output"]
            )
        ]
        return {"text": texts}

    def _tokenize(examples):
        """Tokenizes to exactly `max_length` tokens and builds loss labels."""
        tokenized = tokenizer(
            examples["text"],
            truncation=True,
            padding="max_length",
            max_length=max_length,
        )
        # transformers causal-LM heads shift labels internally (logits at position t
        # are scored against labels[t + 1]), so labels start as a copy of input_ids.
        # Shifting here as well would train the model to predict two tokens ahead.
        # Padding gets -100, the index the loss ignores. attention_mask (not the
        # token id) marks padding, because the pad token may share its id with the
        # real EOS appended to each example.
        tokenized["labels"] = [
            [token if keep else -100 for token, keep in zip(ids, mask)]
            for ids, mask in zip(tokenized["input_ids"], tokenized["attention_mask"])
        ]
        return tokenized

    # Load and preprocess the dataset.
    dataset = load_dataset("yahma/alpaca-cleaned", split="train")
    if debug_mode:
        dataset = dataset.select(range(32))  # Use just 32 examples for faster iteration
    dataset = dataset.map(_format_prompts, batched=True)

    # Create train and test dataset.
    ds = dataset.train_test_split(test_size=0.15)
    ds['train'] = ds['train'].map(_tokenize, batched=True, remove_columns=dataset.column_names)
    ds['test'] = ds['test'].map(_tokenize, batched=True, remove_columns=dataset.column_names)

    train_dataloader = torch.utils.data.DataLoader(
        ds['train'],
        shuffle=True,
        batch_size=batch_size,
        collate_fn=default_data_collator,
    )
    # The mean evaluation loss does not depend on order, so the test split is not shuffled.
    test_dataloader = torch.utils.data.DataLoader(
        ds['test'],
        shuffle=False,
        batch_size=batch_size,
        collate_fn=default_data_collator,
    )

    return train_dataloader, test_dataloader

# Train the model

Now let's train the model. We use PyTorch/XLA's SPMD (GSPMD) sharding to distribute the model across all available TPU devices. Devices are arranged in a mesh whose `fsdp` axis spans every device, giving Fully Sharded Data Parallel (FSDP)-style sharding. This approach allows efficient training on TPU hardware. We also use PyTorch/XLA's `MpDeviceLoader` to load data onto the TPU devices efficiently.

**NOTE:** The **first step of training will be slow** because XLA first has to compile the computation graph. Once compilation is complete, subsequent steps reuse the compiled and cached graph, run much faster, and use all of the TPU devices to accelerate training.


In [None]:
def print_training_update(step, loss, epoch=None):
    """Prints one ' | '-separated line of training metrics, then a blank line.

    Output happens only on the master ordinal.

    Args:
        step: Step index shown as `Step=`.
        loss: Loss value, formatted with 5 decimal places.
        epoch: Epoch index shown as `Epoch=`; the field is omitted when None.
    """
    if not xm.is_master_ordinal():
        return

    fields = ['Training']
    if epoch is not None:
        fields.append(f'Epoch={epoch}')
    fields.extend((f'Step={step}', f'Loss={loss:.5f}'))

    print(' | '.join(fields), flush=True)
    print()


In [None]:
# Build the LoRA-wrapped model and its tokenizer once in the notebook process,
# then wrap the model in xmp.MpModelWrapper; train() calls .to(device) on it.
model, tokenizer = init_model(model_name=MODEL_NAME,
                              hugging_face_token=HUGGINGFACE_TOKEN)
model = xmp.MpModelWrapper(model)

In [None]:
def _shard_batch(batch, mesh):
    """Returns (input_ids, attention_mask, labels) from `batch`, each sharding-annotated.

    Partition spec (0, 1) shards tensor dim 0 along mesh axis 0 and tensor dim 1
    along mesh axis 1, i.e. the "dp" and "fsdp" axes of the mesh built in `train`.
    """
    tensors = tuple(batch[key] for key in ("input_ids", "attention_mask", "labels"))
    for tensor in tensors:
        xs.mark_sharding(tensor, mesh, (0, 1))
    return tensors


def _evaluate(model, dataloader, mesh):
    """Returns the mean loss of `model` over `dataloader`, or None if it yields no batches."""
    model.eval()
    total_loss = None
    num_batches = 0
    with torch.no_grad():
        for batch in dataloader:
            input_ids, attention_mask, labels = _shard_batch(batch, mesh)
            output = model(
                input_ids=input_ids, attention_mask=attention_mask, labels=labels
            )
            # Accumulate on device; a single .item() at the end avoids a host
            # sync after every batch.
            batch_loss = output.loss.detach()
            total_loss = batch_loss if total_loss is None else total_loss + batch_loss
            num_batches += 1
    if num_batches == 0:
        return None
    return total_loss.item() / num_batches


def train(index):
    """Fine-tunes the global LoRA `model` with SPMD sharding over all TPU devices.

    Args:
        index: Process index. Unused; kept so the signature matches the
            `xmp.spawn` calling convention.

    Returns:
        dict with the XLA ordinal ("device") and the last training loss
        ("loss"; None if no batch was processed).

    Side effects:
        Rebinds the global `model` to the on-device, sharded model. When train()
        runs in the notebook process, later cells therefore see the trained model.

    Optional TRAINER_CONFIG keys (read with .get, so existing configs still work):
        max_steps: stop each epoch after this many steps (absent/None = full epoch).
        run_eval: if True, print the mean eval loss after each epoch.
    """
    global model

    torch.manual_seed(99)
    device = xm.xla_device()
    model = model.to(device)

    # Mesh with every device on the "fsdp" axis (dp = mp = 1).
    num_devices = xr.global_runtime_device_count()
    mesh_shape = (1, num_devices, 1)
    device_ids = np.arange(num_devices)
    mesh = Mesh(device_ids, mesh_shape, ("dp", "fsdp", "mp"))

    # Partition the model using SPMD.
    model_partitioning.partition_model(model=model, mesh=mesh)

    # Only parameters with requires_grad (the LoRA adapters) need optimizing;
    # this keeps Adam from iterating over the frozen base-model weights.
    trainable_params = [param for param in model.parameters() if param.requires_grad]
    optimizer = torch.optim.Adam(trainable_params, lr=TRAINER_CONFIG["lr"])

    train_dataloader, test_dataloader = get_dataset(
        tokenizer=tokenizer,
        batch_size=TRAINER_CONFIG["batch_size"],
        max_length=TRAINER_CONFIG["max_length"],
        debug_mode=DEBUG_MODE,
    )
    train_dataloader = pl.MpDeviceLoader(train_dataloader, device)
    test_dataloader = pl.MpDeviceLoader(test_dataloader, device)

    max_steps = TRAINER_CONFIG.get("max_steps")
    logging_interval = TRAINER_CONFIG["logging_interval"]

    def _log_step(step, loss, epoch):
        # Runs as an XLA step closure, after the step's graph has executed, so
        # .item() copies an already-computed value instead of stalling the loop.
        print_training_update(step, loss.item(), epoch)

    loss = None
    for epoch in range(TRAINER_CONFIG["epochs"]):
        xm.master_print(f"Epoch {epoch} train begin {test_utils.now()}")

        model.train()
        for step, batch in enumerate(train_dataloader):
            if max_steps is not None and step >= max_steps:
                break

            optimizer.zero_grad()
            input_ids, attention_mask, labels = _shard_batch(batch, mesh)
            output = model(
                input_ids=input_ids, attention_mask=attention_mask, labels=labels
            )
            loss = output.loss
            loss.backward()
            optimizer.step()

            # Register the logging closure before mark_step so it fires at the
            # end of this step, once `loss` has been materialized.
            if step % logging_interval == 0:
                xm.add_step_closure(_log_step, args=(step, loss, epoch))
            xm.mark_step()

        if TRAINER_CONFIG.get("run_eval", False):
            eval_loss = _evaluate(model, test_dataloader, mesh)
            if eval_loss is not None:
                xm.master_print(f"Eval loss: {eval_loss:.4f}")

    final_loss = loss.item() if loss is not None else None
    result = {'device': xm.get_ordinal(), 'loss': final_loss}
    print("Finished training!")
    return result

In [None]:
# SPMD mode (xr.use_spmd() in the imports cell) is single-process: this one Python
# process drives every TPU device, so train() is called directly rather than via
# xmp.spawn. Running in-process also means the global `model` that train() rebinds
# is the trained model the export cells below save; a forked child's updates would
# never reach this kernel. Real training errors now propagate instead of being
# swallowed by a blanket `except Exception`.
training_result = train(0)
training_result

# Export the model to HuggingFace Hub
The following cells merge the LoRA adapters into the base model, save the result locally, and optionally push it to the Hugging Face Hub.

In [None]:
import os
from huggingface_hub import login

# Authenticate with the Hugging Face Hub again before exporting (same token as the earlier login cell).
login(token=HUGGINGFACE_TOKEN)

In [None]:
HUGGINGFACE_USERNAME = input("Please provide your HUGGINGFACE_USERNAME: ")
LOCAL_SAVE_DIR = "/mnt/persistent-disk/felafax-finetuned-model/"
# Set to True to also push the saved model to the Hub as
# {HUGGINGFACE_USERNAME}/felafax-llama3-finetuned-70B.
UPLOAD_TO_HF = False

# Move the model back to host memory, then fold the LoRA adapters into the base
# weights so the result is a plain transformers model.
model = model.to(device='cpu')
merged_model = model.merge_and_unload()

print("Saving model locally...")
merged_model.save_pretrained(
    LOCAL_SAVE_DIR,
    max_shard_size="5GB",
    safe_serialization=True
)
# Save the tokenizer next to the weights so LOCAL_SAVE_DIR (and the uploaded repo)
# can be loaded on its own with AutoTokenizer / AutoModelForCausalLM.
tokenizer.save_pretrained(LOCAL_SAVE_DIR)

if UPLOAD_TO_HF:
    from huggingface_hub import HfApi

    print("Uploading to HF...")
    repo_id = f"{HUGGINGFACE_USERNAME}/felafax-llama3-finetuned-70B"
    api = HfApi()
    # upload_folder expects the target repo to exist; create it if necessary.
    api.create_repo(repo_id, repo_type="model", exist_ok=True, token=HUGGINGFACE_TOKEN)
    api.upload_folder(
        folder_path=LOCAL_SAVE_DIR,
        repo_id=repo_id,
        repo_type="model",
        ignore_patterns=[".*"],
        token=HUGGINGFACE_TOKEN
    )
else:
    print("Set UPLOAD_TO_HF = True if you want to upload to HF.")