Copyright (c) Meta Platforms, Inc. and affiliates.
This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.

## Quick Start Notebook

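This notebook is based on the Llama 2 quick-start recipe for training a Llama 2 model on a single GPU (e.g. an A10 with 24 GB) using int8 quantization and LoRA. As adapted here, it loads a previously trained LoRA adapter for the Ego4D LTA dataset, evaluates it, and runs a few generation and forward-pass checks.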

### Step 0: Install prerequisites and convert the checkpoint

The example uses the Hugging Face trainer and model, which means the checkpoint has to be converted from its original format into the Hugging Face format.
The conversion can be done with the `convert_llama_weights_to_hf.py` script that ships with the `transformers` package.
Given that the original checkpoint resides under `models/7B`, install the requirements and convert the checkpoint with, for example, `python convert_llama_weights_to_hf.py --input_dir models --model_size 7B --output_dir models_hf/7B`. The cells below assume the converted checkpoint already exists at the `model_name` path they configure.

In [None]:
%cd "/users/swang299/code/AntGPT-Llama2/llama-recipes"

import os
import sys
from typing import List, Union

import fire
import torch
import transformers
from datasets import load_dataset
import os.path as osp
from tqdm import tqdm
from peft import PeftModel

# Unused imports removed
from utils import fsdp_auto_wrap_policy
from transformers import (
    LlamaForCausalLM,
    LlamaTokenizer,
    AutoModelForCausalLM,
    AutoModelForSeq2SeqLM,
    AutoTokenizer,
    default_data_collator,
    BitsAndBytesConfig
)
import torch.distributed as dist
# Unused imports removed
from utils.train_utils import (
    set_tokenizer_params,
    train,
    evaluation,
    freeze_transformer_layers,
    check_frozen_layers_peft_model,
    setup,
    setup_environ_flags,
    cleanup,
    clear_gpu_cache,
    get_parameter_dtypes,
    print_model_size,
    get_policies  
)

from utils.dataset_utils import get_preprocessed_dataset

from utils.config_utils import (
    update_config,
    generate_peft_config,
    generate_dataset_config,
)
from peft import get_peft_model, TaskType, prepare_model_for_int8_training
import configs
from torch.distributed.fsdp import (
    FullyShardedDataParallel as FSDP,
    MixedPrecision,
)
from torch.utils.data import DistributedSampler
import policies
from policies import AnyPrecisionAdamW
from configs import fsdp_config, train_config
import torch.optim as optim
from torch.optim.lr_scheduler import StepLR
from pkg_resources import packaging
import torch
import torch.cuda.nccl as nccl
import torch.distributed as dist
from transformers.models.llama.modeling_llama import LlamaDecoderLayer


# Update the configuration for the training and sharding process.
# NOTE: the dict name `kwags` (sic) is kept because the dataset cell passes it
# to generate_dataset_config.
kwags = {
    "use_peft": True,
    "peft_method": "lora",
    "quantization": True,
    "model_name": '/gpfs/data/superlab/models/llama2/llama/checkpoints/hf/Llama-2-7b-hf',
    "output_dir": "./",
    "dataset": "ego4d_lta_dataset",
}
update_config((train_config, fsdp_config), **kwags)

# Set the seeds for reproducibility
torch.cuda.manual_seed(train_config.seed)
torch.manual_seed(train_config.seed)

if train_config.enable_fsdp:
    setup()
    # torchrun exports these variables for every worker process.
    local_rank = int(os.environ["LOCAL_RANK"])
    rank = int(os.environ["RANK"])
    world_size = int(os.environ["WORLD_SIZE"])

    if torch.distributed.is_initialized():
        # Select the GPU by its node-local index. The global `rank` exceeds the
        # per-node device count on multi-node jobs. Nesting this here also
        # guarantees `rank`/`local_rank` are defined.
        torch.cuda.set_device(local_rank)
        setup_environ_flags(rank)

# Gradient accumulation implied by the configured batch sizes (computed for
# reference; the visible cells of this notebook do not consume it).
gradient_accumulation_steps = train_config.batch_size_training // train_config.micro_batch_size

# Load the pre-trained model. With quantization enabled, weights are loaded in
# int8 and placed with device_map="auto".
model = LlamaForCausalLM.from_pretrained(
    train_config.model_name,
    load_in_8bit=True if train_config.quantization else None,
    device_map="auto" if train_config.quantization else None,
)

print_model_size(model, train_config, rank if train_config.enable_fsdp else 0)

# Prepare the model for int8 training if quantization is enabled
if train_config.quantization:
    model = prepare_model_for_int8_training(model)

# Convert the model to bfloat16 if fsdp and pure_bf16 is enabled
if train_config.enable_fsdp and fsdp_config.pure_bf16:
    model.to(torch.bfloat16)

# Load the tokenizer. No pad token is added here (the add_special_tokens
# call for "<PAD>" is disabled), so pad_token_id may be unset.
tokenizer = LlamaTokenizer.from_pretrained(train_config.model_name, legacy=False)

if train_config.use_peft:
    # Load an already-trained LoRA adapter rather than creating a fresh one.
    # A fresh adapter would instead be:
    #   model = get_peft_model(model, generate_peft_config(train_config, kwags))
    model = PeftModel.from_pretrained(model, 'peft_ckpt/ego4d_lta/lora/7B/0')

# Set up FSDP if enable_fsdp is enabled
if train_config.enable_fsdp:
    if not train_config.use_peft and train_config.freeze_layers:
        # freeze_transformer_layers takes the model as its first argument.
        freeze_transformer_layers(model, train_config.num_freeze_layers)

    mixed_precision_policy, wrapping_policy = get_policies(fsdp_config, rank)
    my_auto_wrapping_policy = fsdp_auto_wrap_policy(model, LlamaDecoderLayer)

    model = FSDP(
        model,
        auto_wrap_policy=my_auto_wrapping_policy if train_config.use_peft else wrapping_policy,
        mixed_precision=mixed_precision_policy if not fsdp_config.pure_bf16 else None,
        sharding_strategy=fsdp_config.sharding_strategy,
        device_id=torch.cuda.current_device(),
        limit_all_gathers=True,
    )
    if fsdp_config.fsdp_activation_checkpointing:
        policies.apply_fsdp_checkpointing(model)
elif not train_config.quantization and not train_config.enable_fsdp:
    # Quantized models were already placed by device_map="auto".
    model.to("cuda")



In [None]:

# Dataset-specific config derived from the same overrides (`kwags`) used for the model.
dataset_config = generate_dataset_config(train_config, kwags)

# Load and preprocess the dataset for training and validation.
dataset_train = get_preprocessed_dataset(
    tokenizer,
    dataset_config,
    split="train",
)
if not train_config.enable_fsdp or rank == 0:
    print(f"--> Training Set Length = {len(dataset_train)}")

dataset_val = get_preprocessed_dataset(
    tokenizer,
    dataset_config,
    split="test",
)
if not train_config.enable_fsdp or rank == 0:
    print(f"--> Validation Set Length = {len(dataset_val)}")

# Under FSDP each rank reads its own shard via DistributedSampler; otherwise
# the DataLoaders iterate the full datasets.
train_sampler = None
val_sampler = None
if train_config.enable_fsdp:
    train_sampler = DistributedSampler(
        dataset_train,
        rank=dist.get_rank(),
        num_replicas=dist.get_world_size(),
        shuffle=True,
    )
    if train_config.run_validation:
        val_sampler = DistributedSampler(
            dataset_val,
            rank=dist.get_rank(),
            num_replicas=dist.get_world_size(),
        )

# Create DataLoaders for the training and validation dataset.
train_dataloader = torch.utils.data.DataLoader(
    dataset_train,
    batch_size=train_config.batch_size_training,
    num_workers=train_config.num_workers_dataloader,
    pin_memory=True,
    sampler=train_sampler,
    # DataLoader forbids shuffle=True together with a sampler. Without one,
    # shuffle here so single-GPU training does not see examples in file order.
    shuffle=train_sampler is None,
    drop_last=True,
    collate_fn=default_data_collator,
)

# Always bind the name; it stays None when validation is disabled.
eval_dataloader = None
if train_config.run_validation:
    eval_dataloader = torch.utils.data.DataLoader(
        dataset_val,
        batch_size=train_config.val_batch_size,
        num_workers=train_config.num_workers_dataloader,
        pin_memory=True,
        sampler=val_sampler,
        # Keep the final partial batch so every validation example is scored.
        drop_last=False,
        collate_fn=default_data_collator,
    )

# Initialize the optimizer and learning rate scheduler.
if fsdp_config.pure_bf16 and fsdp_config.optimizer == "anyprecision":
    optimizer = AnyPrecisionAdamW(
        model.parameters(),
        lr=train_config.lr,
        momentum_dtype=torch.bfloat16,
        variance_dtype=torch.bfloat16,
        use_kahan_summation=False,
    )
else:
    optimizer = optim.AdamW(
        model.parameters(),
        lr=train_config.lr,
        weight_decay=0.0,
    )
scheduler = StepLR(optimizer, step_size=1, gamma=train_config.gamma)


In [None]:
from utils.memory_utils import MemoryTrace

def myevaluation(model, train_config, eval_dataloader, local_rank, tokenizer, verbose=True):
    """
    Evaluate the model on the given dataloader and report loss and perplexity.

    Args:
        model: The model to evaluate. It is called as ``model(**batch)`` and must
            return an object with ``.loss`` and ``.logits``.
        train_config: Config object; only ``enable_fsdp`` is read.
        eval_dataloader: Iterable of dict batches whose values are tensors.
        local_rank: Device that batches are moved to when FSDP is enabled
            (otherwise batches go to ``'cuda:0'``).
        tokenizer: Tokenizer whose ``batch_decode`` turns argmax predictions into text.
        verbose: When True (default), print the first decoded prediction of each
            batch followed by a separator line.

    Returns:
        (eval_ppl, eval_epoch_loss): the perplexity ``exp(loss)`` and the loss
        averaged over batches (and over ranks when reduced under FSDP).

    Raises:
        ValueError: if ``eval_dataloader`` yields no batches.
    """
    model.eval()
    eval_preds = []
    eval_loss = 0.0
    num_batches = 0
    device = local_rank if train_config.enable_fsdp else 'cuda:0'
    with MemoryTrace():
        for batch in tqdm(eval_dataloader, colour="green", desc="evaluating Epoch"):
            for key in batch.keys():
                batch[key] = batch[key].to(device)
            # No gradients are needed during evaluation; this saves memory.
            with torch.no_grad():
                outputs = model(**batch)
                # Assumes outputs.loss is already a per-batch mean (as for HF
                # causal-LM models), so the average below is over batches, not samples.
                eval_loss += outputs.loss.detach().float()
            num_batches += 1

            # Decode argmax predictions once and reuse them for logging and collection.
            preds = torch.argmax(outputs.logits, -1)
            decoded = tokenizer.batch_decode(preds.detach().cpu().numpy(), skip_special_tokens=True)
            if verbose:
                print(decoded[0])
                print("-------------------------")
            eval_preds.extend(decoded)

    if num_batches == 0:
        raise ValueError("eval_dataloader yielded no batches")

    # Under FSDP, sum the per-rank loss totals, then average over all ranks' batches.
    world_size = 1
    if torch.cuda.device_count() > 1 and train_config.enable_fsdp:
        dist.all_reduce(eval_loss, op=dist.ReduceOp.SUM)
        world_size = dist.get_world_size()

    eval_epoch_loss = eval_loss / (num_batches * world_size)
    eval_ppl = torch.exp(eval_epoch_loss)

    print(f" {eval_ppl=} {eval_epoch_loss=}")
    return eval_ppl, eval_epoch_loss


In [None]:
# Baseline evaluation of the loaded model before any further training.
# Keyword arguments make the call self-describing. The rank is only needed
# (and only defined) when FSDP is enabled.
eval_results = myevaluation(
    model=model,
    train_config=train_config,
    eval_dataloader=eval_dataloader,
    local_rank=(local_rank if train_config.enable_fsdp else None),
    tokenizer=tokenizer,
)
print(str(eval_results))

In [None]:
%cd "/users/swang299/code/AntGPT-Llama2/llama-recipes"
import fire
import torch
import os
import sys
import time
from typing import List

from transformers import LlamaTokenizer
from inference.safety_utils import get_safety_checker
from inference.model_utils import load_model, load_peft_model

# Generation settings used by the inference cells below.
model_name = '/gpfs/data/superlab/models/llama2/llama/checkpoints/hf/Llama-2-7b-hf'
peft_model = 'ego4d_lta/7B'  # only used by the disabled reload snippet at the end of this cell
quantization = True
max_new_tokens = 100  # The maximum number of tokens to generate
seed = 42  # Seed value for reproducibility
do_sample = True  # Whether to use sampling; greedy decoding is used otherwise.
min_length = None  # The minimum length of the generated sequence (input prompt + min_new_tokens)
use_cache = True  # [optional] Whether to reuse past key/value attentions (if applicable) to speed up decoding.
top_p = 1.0  # [optional] If < 1, keep only the smallest set of most probable tokens whose probabilities add up to top_p.
temperature = 1.0  # [optional] Value used to modulate the next-token probabilities.
top_k = 50  # [optional] Number of highest-probability tokens kept for top-k filtering.
repetition_penalty = 1.0  # Repetition penalty; 1.0 means no penalty.
# The settings below previously ended with trailing commas, which made each one a
# 1-tuple (e.g. length_penalty == (1,)) instead of a scalar/bool.
length_penalty = 1  # [optional] Exponential length penalty used with beam-based generation.
enable_azure_content_safety = False  # Enable safety check with the Azure content safety API
enable_sensitive_topics = False  # Enable check for sensitive topics using AuditNLG APIs
enable_saleforce_content_safety = False  # Enable safety check with the Salesforce safety flan-t5

kwargs = {}
# Set the seeds for reproducibility
torch.cuda.manual_seed(seed)
torch.manual_seed(seed)

# Optional: reload a fresh base model + adapter instead of reusing `model` and
# `tokenizer` from the first cell. This is kept as comments rather than a bare
# string literal, because Jupyter would echo a trailing string as the cell's output.
# model = load_model(model_name, quantization)
# tokenizer = LlamaTokenizer.from_pretrained(model_name)
# tokenizer.add_special_tokens({"pad_token": "<PAD>"})
# if peft_model:
#     model = load_peft_model(model, peft_model)
# model.eval()

In [None]:
# Sample a continuation for an Ego4D LTA-style action-prediction prompt, using
# the generation settings from the configuration cell.
user_prompt = "Predict the next most possible 20 actions in the format of verb noun pair in chronological order that match the given observed 8 actions and common sense most. Below is the observed 8 actions.\n\n### Observed actions: attach pump, hold phone, hold phone, hold screw, hold screwdriver, put screwdriver, hold drill, put drill\n\n### Prediction: "

encoded = tokenizer(user_prompt, return_tensors="pt")
batch = dict((name, tensor.to("cuda")) for name, tensor in encoded.items())

with torch.no_grad():
    outputs = model.generate(
        **batch,
        max_new_tokens=max_new_tokens,
        do_sample=do_sample,
        top_p=top_p,
        temperature=temperature,
        min_length=min_length,
        use_cache=use_cache,
        top_k=top_k,
        repetition_penalty=repetition_penalty,
        length_penalty=length_penalty,
        **kwargs,
    )

# The decoded sequence contains the prompt followed by the generated tokens.
output_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
print("Model output:\n" + output_text)

In [None]:
# Teacher-forced sanity check: run one forward pass and decode the argmax token
# at every position.
# The action-prediction prompt is kept under its own name. It was previously
# assigned to `user_prompt` and immediately overwritten, so it was never used.
lta_prompt = "Predict the next most possible 20 actions in the format of verb noun pair in chronological order that match the given observed 8 actions and common sense most. Below is the observed 8 actions.\n\n### Observed actions: take shirt, adjust shirt, take shirt, take shirt, take shirt, adjust shirt, adjust shirt, take iron\n\n### Prediction: "
user_prompt = """Below is an instruction that describes a task. Write a response that appropriately completes the request.

### Instruction:
Describe a time when you had to make a difficult decision.

### Response:
"""
batch = {k: v.to("cuda") for k, v in tokenizer(user_prompt, return_tensors="pt").items()}
print(batch['input_ids'].shape)
with torch.no_grad():
    outputs = model(**batch)
print(outputs.logits.shape)
# For a causal LM, logits at position i score token i+1. The decoded text is
# therefore the model's next-token guess at each position (shifted one step
# relative to the input), not a free-running generation.
preds = torch.argmax(outputs.logits, -1)
print(tokenizer.batch_decode(preds.detach().cpu().numpy(), skip_special_tokens=False)[0])

In [None]:
# Show both the pad token and its id. The first cell adds no pad token, and a
# bare `None` as the last expression renders nothing in Jupyter, so a tuple
# makes an unset value visible as `(None, None)`.
tokenizer.pad_token, tokenizer.pad_token_id