# Fine-tune a Large Language Model step by step


This notebook shows how to fine-tune a large language model (LLM) on a single machine with 4 GPUs using PyTorch Fully Sharded Data Parallel (FSDP). For demonstration purposes, I use the Llama2-7B model and the IMDB sentiment analysis task, which is reformulated as a generation task. 

By training Llama2-7B for one epoch (~25 minutes) in this notebook, we achieve <span style="color:#008bf8ff; font-weight: bold;">96.86%</span> accuracy, surpassing the best model on this leaderboard: [IMDB benchmark](https://paperswithcode.com/sota/sentiment-analysis-on-imdb). You can reproduce the result by clicking `Run All`. 

This notebook requires PyTorch to be installed; it installs all other dependencies itself.

****

This notebook covers the following steps:

- **Initializing the distributed environment**

- **Loading and exploring a dataset**

- **Designing a prompt to transform classification task into generation task**

- **Loading training data in a distributed manner**

- **Configuring the FSDP model wrapper**

- **Configuring FSDP activation checkpointing**

- **Saving sharded model weights to disk**

- **Training LLM with FSDP**

- **Evaluating generative LLM**

****

## Fully Sharded Data Parallel

Fully Sharded Data Parallel (FSDP) shards data, model parameters, gradients, and optimizer states so that very large models can be trained with limited resources. It is inspired by [Automatic Cross-Replica Sharding of Weight Update in Data-Parallel Training](https://arxiv.org/abs/2004.13336) and [DeepSpeed ZeRO](https://arxiv.org/abs/1910.02054). See the [FSDP tutorial](https://pytorch.org/tutorials/intermediate/FSDP_tutorial.html) for details; here we give a brief introduction to FSDP.


In standard [Distributed Data Parallel (DDP)](https://pytorch.org/tutorials/intermediate/ddp_tutorial.html) training, every worker processes a separate data batch with a full copy of the model. The model weights, optimizer states, and gradients are replicated across all workers. For a very large model, a single GPU may not have enough memory to hold all of these tensors.

Like DDP, FSDP shards the data, but it also shards the model parameters, optimizer states, and gradients. With 4 workers, for example, each worker holds only 1/4 of each of these tensors.

FSDP also splits the model vertically: it divides the whole model into FSDP units and executes the forward and backward passes unit by unit (for an LLM, we typically put one layer into one unit). 
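
To make the saving concrete, here is a rough back-of-the-envelope estimate. Every number in it is an assumption for illustration: a nominal 7 billion parameters, bf16 weights and gradients, two AdamW moment buffers kept in bf16, and 32 equally sized FSDP units. Activations and framework overhead are ignored, so real usage will be higher.

```python
# Back-of-the-envelope memory estimate (all figures are assumptions).
NUM_PARAMS = 7e9          # nominal "7B"; the exact count differs slightly
BYTES_WEIGHTS = 2         # bf16 weights
BYTES_GRADS = 2           # bf16 gradients
BYTES_OPTIMIZER = 2 * 2   # two AdamW moment buffers, assumed to be bf16
NUM_WORKERS = 4
NUM_UNITS = 32            # one FSDP unit per decoder layer (approximation)

GIB = 2**30


def persistent_gib(num_workers, sharded):
    """Weights + gradients + optimizer state held by one GPU, in GiB."""
    total = NUM_PARAMS * (BYTES_WEIGHTS + BYTES_GRADS + BYTES_OPTIMIZER)
    return total / (num_workers if sharded else 1) / GIB


def gathered_unit_gib():
    """Transient full weights of the single unit being executed, in GiB."""
    return NUM_PARAMS * BYTES_WEIGHTS / NUM_UNITS / GIB


ddp = persistent_gib(NUM_WORKERS, sharded=False)
fsdp = persistent_gib(NUM_WORKERS, sharded=True)
print(f"DDP  per GPU: {ddp:.1f} GiB")
print(f"FSDP per GPU: {fsdp:.1f} GiB + {gathered_unit_gib():.2f} GiB for the active unit")
```

Under these assumptions the replicated DDP state would not fit on a typical 40 GB GPU, while the sharded state plus one gathered unit would.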


FSDP data and model split:

In [None]:
from IPython.display import Image
Image(filename='img/fsdp.jpg') 

## Install dependencies

In [None]:
!pip install transformers==4.34.0
!pip install accelerate==0.22.0
!pip install sentencepiece==0.1.99
!pip install datasets==2.14.4
!pip install seaborn==0.12.2

## Initialize distributed environment

We train the LLM in a distributed way, so the first step is to initialize the distributed environment. In the notebook cell below, only one worker is initialized. At the end of this notebook, we launch the training job with the `torchrun` command, which starts 4 workers.
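
For orientation, `torchrun` passes each worker its identity through the environment variables `RANK`, `LOCAL_RANK`, and `WORLD_SIZE`. The sketch below shows one way these could be read; the helper name `read_torchrun_env` and the example values are hypothetical and are not used elsewhere in this notebook.

```python
import os


def read_torchrun_env(env=os.environ):
    """Return (rank, local_rank, world_size) from torchrun-style variables.

    Falls back to a single-process setup when the variables are absent,
    for example when the code runs outside of torchrun.
    """
    rank = int(env.get("RANK", 0))
    local_rank = int(env.get("LOCAL_RANK", 0))
    world_size = int(env.get("WORLD_SIZE", 1))
    return rank, local_rank, world_size


# Hypothetical values for the third of four processes on a single node.
example_env = {"RANK": "2", "LOCAL_RANK": "2", "WORLD_SIZE": "4"}
print(read_torchrun_env(example_env))
```

On a single node the global rank and the local rank coincide, which is why this notebook can use `dist.get_rank()` as the GPU index. On multiple nodes the device should be chosen from `LOCAL_RANK` instead.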

In [None]:
import torch.distributed as dist

dist.init_process_group("nccl")
world_size = dist.get_world_size()
local_rank = dist.get_rank()

print(world_size, local_rank)

## Load IMDB dataset

First, let's load the IMDB dataset and do some minimal exploratory data analysis (EDA).

IMDB is a Large Movie Review Dataset. This is a dataset for binary sentiment classification containing substantially more data than previous benchmark datasets. It provides a set of 25,000 highly polar movie reviews for training, and 25,000 for testing. There is additional unlabeled data for use as well.

See [IMDB](https://huggingface.co/datasets/imdb) for more information.

In [None]:
from datasets import load_dataset
from collections import Counter
import seaborn as sns

In [None]:
MODEL_DIR = "llama2/models_hf/7B"

In [None]:
dataset = load_dataset('imdb')

In [None]:
dataset

In [None]:
num_words = list(map(lambda x: len(x.split(' ')), dataset['train']['text']))

In [None]:
max(num_words)

In [None]:
sns.histplot(num_words)

In [None]:
sum(num_words) / len(num_words)

In [None]:
def show_label_count(labels):
    label_names =  ['negative', 'positive']
    c = dict(Counter([label_names[x] for x in labels]))
    data = {
        'x': list(c.keys()),
        'y': list(c.values())
    }
    sns.barplot(data, x='x', y='y')

In [None]:
show_label_count(dataset['train']['label'])

In [None]:
show_label_count(dataset['test']['label'])

In [None]:
dataset['train'][0]

## Design a prompt to reformulate the IMDB classification task as a generation task

In [None]:
import copy
import torch
from transformers import LlamaForCausalLM, LlamaTokenizer
from torch.utils.data import DataLoader

Splitting the prompt into two parts allows truncation to be applied only to the `text` part.
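
The effect of truncating only the first part can be illustrated with a toy function on plain token lists. This is a simplified model of the behaviour, not the Hugging Face implementation; the function name and the example tokens are made up for illustration.

```python
def truncate_only_first(first, second, max_length):
    """Trim tokens from the end of `first` so that the pair fits max_length.

    `second` is never shortened, so the answer part always survives.
    """
    overflow = len(first) + len(second) - max_length
    if overflow > 0:
        if overflow >= len(first):
            raise ValueError("the second part alone does not fit in max_length")
        first = first[:-overflow]
    return first + second


# Hypothetical, already tokenized inputs: a 10-token review and a 3-token answer.
review = [f"w{i}" for i in range(10)]
answer = ["####", "Answer:", "negative"]

pair = truncate_only_first(review, answer, max_length=8)
print(pair)
```

However long the review is, the answer tokens stay at the end of the sequence, which is what the label masking later relies on.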

In [None]:
prompt_part1 = \
f'''Given a movie review/comment by a user in following format:
#### Movie review:
<review>
#### Answer:
<answer>
Please rate the movie review as positive or negative from the perspective of the user's overall personal feelings to the movie. Answer it with only 'positive' or 'negative' without any explanation.

#### Movie review:
{{text}}
'''

prompt_part2 = \
f'''
#### Answer: {{label}}'''

prompt_part2_inference = \
'''
#### Answer: '''

In [None]:
dataset['train'][0]

### The actual training data looks like this:

In [None]:
print(prompt_part1.format(text=dataset['train'][0]['text']) + prompt_part2.format(label='negative'))

When we use the above text as training data, by default the model is trained on every token, and the loss of every token contributes equally. This is not what we want: we want to fine-tune the model only on the label token.
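
A minimal sketch of this masking on plain Python lists follows. It assumes right padding and a label that occupies exactly one token at the last non-padded position; the token ids are invented for the example. The value `-100` is the default `ignore_index` of PyTorch's cross-entropy loss, so positions set to it contribute nothing to the loss.

```python
IGNORE_INDEX = -100  # default ignore_index of torch.nn.CrossEntropyLoss


def mask_all_but_last_token(input_ids, attention_mask):
    """Keep the label only at the last real (non-padded) position."""
    n_real = sum(attention_mask)              # number of non-padding tokens
    labels = [IGNORE_INDEX] * len(input_ids)  # ignore everything by default
    labels[n_real - 1] = input_ids[n_real - 1]
    return labels


# Hypothetical ids: 1 doubles as BOS and padding, 42 stands for the label token.
input_ids = [1, 11, 12, 13, 42, 1, 1, 1]
attention_mask = [1, 1, 1, 1, 1, 0, 0, 0]

labels = mask_all_but_last_token(input_ids, attention_mask)
print(labels)
```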

### The actual data for inference:

In [None]:
print(prompt_part1.format(text=dataset['train'][0]['text']) + prompt_part2_inference)

In [None]:
tokenizer = LlamaTokenizer.from_pretrained(MODEL_DIR)
tokenizer.pad_token = tokenizer.bos_token

In [None]:
MAX_LENGTH = 512
# we only finetune using the loss at label position, ignore other labels.
IGNORE_INDEX = -100

label_names = ['negative', 'positive']

def preprocess_train(sample):
    part1 = prompt_part1.format(text=sample['text'])
    part2 = prompt_part2.format(label=label_names[sample['label']])
    
    tokenized = tokenizer('<s> ' + part1, part2, add_special_tokens=False, truncation='only_first', padding='max_length', max_length=MAX_LENGTH)
    
    labels = torch.tensor(copy.deepcopy(tokenized['input_ids']), dtype=torch.int64)
    actual_token_len = sum(tokenized['attention_mask'])

    labels[ :actual_token_len-1] = IGNORE_INDEX
    labels[actual_token_len:] = IGNORE_INDEX

    return {
        'input_ids': torch.tensor(tokenized['input_ids'], dtype=torch.int64),
        'attention_mask': torch.tensor(tokenized['attention_mask']),
        'labels': labels
    }



In [None]:
train_ds = dataset['train'].map(preprocess_train, remove_columns=['text', 'label']).with_format('torch')

In [None]:
train_ds[0]

In [None]:
train_ds

## Configure FSDP

We configure FSDP as follows:

* The Llama2 model has 32 decoder layers; we wrap each layer into an FSDP unit, so that the model is sharded vertically.

* Activation checkpointing: the forward pass of each unit is rerun during the backward pass, so intermediate tensors can be released from GPU memory while training proceeds layer by layer.

* Wrap the model in the FSDP wrapper to enable FSDP.
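
Why activation checkpointing saves memory can be seen with a simplified counting model. It only counts how many activation tensors must stay alive for the backward pass and treats all tensors as equal in size. The figure of 20 intermediate tensors per layer is a made-up assumption; only the 32 layers come from the model.

```python
NUM_LAYERS = 32          # decoder layers in Llama2-7B
TENSORS_PER_LAYER = 20   # hypothetical number of intermediates per layer


def stored_activations(num_layers, tensors_per_layer, checkpointing):
    """Peak number of activation tensors kept alive for the backward pass."""
    if not checkpointing:
        # Every intermediate tensor of every layer is kept until backward.
        return num_layers * tensors_per_layer
    # Only each layer's input is kept; during backward, one layer at a time
    # is recomputed, so only its intermediates exist at any moment.
    return num_layers + tensors_per_layer


without = stored_activations(NUM_LAYERS, TENSORS_PER_LAYER, checkpointing=False)
with_ckpt = stored_activations(NUM_LAYERS, TENSORS_PER_LAYER, checkpointing=True)
print(f"without checkpointing: {without} tensors, with: {with_ckpt} tensors")
```

The price is roughly one extra forward pass per layer during the backward pass.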

In [None]:
from functools import partial
from dataclasses import dataclass
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    checkpoint_wrapper,
    CheckpointImpl,
    apply_activation_checkpointing
)

from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
from torch.distributed.fsdp import ShardingStrategy, FullStateDictConfig
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.fully_sharded_data_parallel import StateDictType

from transformers.models.llama.modeling_llama import LlamaDecoderLayer

fullstate_save_policy = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)

non_reentrant_wrapper = partial(
    checkpoint_wrapper,
    checkpoint_impl=CheckpointImpl.NO_REENTRANT,
)

check_fn = lambda submodule: isinstance(submodule, LlamaDecoderLayer)

def get_llama_wrapper():
    llama_auto_wrap_policy = partial(
        transformer_auto_wrap_policy,
        transformer_layer_cls={
            LlamaDecoderLayer,
        },
    )

    return llama_auto_wrap_policy

def apply_fsdp_checkpointing(model):
    print(f"--> applying fsdp activation checkpointing...")

    apply_activation_checkpointing(
        model, checkpoint_wrapper_fn=non_reentrant_wrapper, check_fn=check_fn
    )

@dataclass
class fsdp_config:
    mixed_precision: bool=True
    use_fp16: bool=False
    sharding_strategy: ShardingStrategy = ShardingStrategy.FULL_SHARD
    checkpoint_type: StateDictType = StateDictType.SHARDED_STATE_DICT
    fsdp_activation_checkpointing: bool=True
    pure_bf16: bool = False
    optimizer: str= "AdamW"

def setup_fdsp_model():
    model = LlamaForCausalLM.from_pretrained(MODEL_DIR, load_in_8bit=False, device_map=None, torch_dtype=torch.float16, use_cache=True)
    model.to(torch.bfloat16)
    model.config.pad_token_id = model.config.bos_token_id

    model = FSDP(
        model,
        auto_wrap_policy=get_llama_wrapper(),
        mixed_precision=None,
        sharding_strategy=fsdp_config.sharding_strategy,
        device_id=torch.cuda.current_device(),
        limit_all_gathers=True,
        sync_module_states=False,
        param_init_fn=None
    )
    apply_fsdp_checkpointing(model)
    return model


## Save model checkpoint during training

FSDP model weights are spread across all ranks, so every rank must call this function to assemble the full state dict; only rank 0 then saves it to disk.
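
The pattern can be mimicked with a toy, single-process stand-in in which each "rank" holds a slice of every parameter. The function and data below are purely illustrative and do not call any FSDP API. They mirror the `rank0_only=True` behaviour: every rank contributes, but only rank 0 ends up with a populated state dict.

```python
def toy_full_state_dict(shards, rank, rank0_only=True):
    """Toy stand-in for gathering a sharded state dict.

    `shards` maps rank -> {parameter name: list slice held by that rank}.
    With rank0_only, only rank 0 keeps the assembled result; all other
    ranks receive an empty dict.
    """
    if rank0_only and rank != 0:
        return {}
    full = {}
    for r in sorted(shards):
        for name, piece in shards[r].items():
            full.setdefault(name, []).extend(piece)
    return full


# Hypothetical parameter "w" split across two ranks.
shards = {0: {"w": [0.1, 0.2]}, 1: {"w": [0.3, 0.4]}}

for rank in shards:
    state = toy_full_state_dict(shards, rank)
    print(f"rank {rank}: {state}")
```

This is why the save function checks `rank == 0` before writing: the other ranks have nothing to save.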

In [None]:
from pathlib import Path

def save_model_checkpoint(
    model, 
    output_dir,
    rank
):
    """saving model via rank0 cpu streaming and full_state_dict"""

    with FSDP.state_dict_type(
        model, StateDictType.FULL_STATE_DICT, fullstate_save_policy
    ):
        cpu_state = model.state_dict()

        print(f"saving process: rank {rank}  done w model state_dict\n")
   

    if rank == 0:
        print(f"--> saving model ...")
        save_dir = Path.cwd() / output_dir
        save_dir.mkdir(parents=True, exist_ok=True)
        save_full_path = str(save_dir) + "/pytorch_model.bin"

        # save model
        torch.save(cpu_state, save_full_path)
        
        print(f"model checkpoint saved at {save_full_path}\n")

## Define training loop

In [None]:
from tqdm import tqdm
from transformers.tokenization_utils_base import BatchEncoding

OUTPUT_DIR = './checkpoints_llama_1'
NUM_EPOCHS = 1

def train(model, train_dataloader, optimizer, local_rank):
    for epoch in range(NUM_EPOCHS):
        for step, x in tqdm(enumerate(train_dataloader), total=len(train_dataloader), disable=(local_rank!=0), desc=f'Epoch {epoch}/{NUM_EPOCHS}'):
            model.train()
            x = BatchEncoding(x).to(local_rank)

            loss = model(**x).loss
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
            if step % 50 == 0 and local_rank == 0:
                print('train loss:', loss.item()) 
    save_model_checkpoint(model, OUTPUT_DIR, local_rank)

## Launch the training

Now let's put all of the above together and launch the training loop in a distributed fashion using the `torchrun` command.
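
As a quick sanity check on the training length, the number of optimizer steps per epoch can be worked out from the settings used in `train.py`: a per-process `BATCH_SIZE` of 16, 4 processes, a `DistributedSampler`, and `drop_last=True` on the `DataLoader`. The sketch assumes the sampler's default behaviour of giving each rank `ceil(N / world_size)` samples.

```python
import math

NUM_SAMPLES = 25_000   # size of the IMDB training split
WORLD_SIZE = 4         # --nproc_per_node 4
BATCH_SIZE = 16        # per-process batch size in train.py


def steps_per_epoch(num_samples, world_size, batch_size):
    # DistributedSampler gives each rank ceil(N / world_size) samples.
    per_rank = math.ceil(num_samples / world_size)
    # drop_last=True in the DataLoader discards the final partial batch.
    return per_rank // batch_size


global_batch = WORLD_SIZE * BATCH_SIZE
steps = steps_per_epoch(NUM_SAMPLES, WORLD_SIZE, BATCH_SIZE)
print("global batch size:", global_batch)
print("optimizer steps per epoch:", steps)
```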

In [None]:
%%writefile train.py

import os
from pathlib import Path
from datasets import load_dataset
from collections import Counter
import copy
from tqdm import tqdm
from functools import partial
from dataclasses import dataclass

import torch
import torch.optim as optim
from torch.optim.lr_scheduler import StepLR
from torch.utils.data import DataLoader, DistributedSampler
import torch.distributed as dist
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    checkpoint_wrapper,
    CheckpointImpl,
    apply_activation_checkpointing
)
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
from torch.distributed.fsdp import ShardingStrategy, FullStateDictConfig
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.fully_sharded_data_parallel import StateDictType

from transformers.models.llama.modeling_llama import LlamaDecoderLayer
from transformers.tokenization_utils_base import BatchEncoding
from transformers import LlamaForCausalLM, LlamaTokenizer, default_data_collator

fullstate_save_policy = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)

non_reentrant_wrapper = partial(
    checkpoint_wrapper,
    checkpoint_impl=CheckpointImpl.NO_REENTRANT,
)

check_fn = lambda submodule: isinstance(submodule, LlamaDecoderLayer)


MODEL_DIR = "llama2/models_hf/7B"
NUM_EPOCHS = 1
OUTPUT_DIR = './checkpoints_llama_1'
BATCH_SIZE = 16
LR = 2e-5

# initialize distributed environment
dist.init_process_group("nccl")
world_size = dist.get_world_size()
local_rank = dist.get_rank()
torch.cuda.set_device(local_rank)
torch.cuda.empty_cache()

dataset = load_dataset('imdb')


prompt_part1 = \
f'''Given a movie review/comment by a user in following format:
#### Movie review:
<review>
#### Answer:
<answer>
Please rate the movie review as positive or negative from the perspective of the user's overall personal feelings to the movie. Answer it with only 'positive' or 'negative' without any explanation.

#### Movie review:
{{text}}
'''

prompt_part2 = \
f'''
#### Answer: {{label}}'''

prompt_part2_inference = \
'''
#### Answer: '''

tokenizer = LlamaTokenizer.from_pretrained(MODEL_DIR)
tokenizer.pad_token = tokenizer.bos_token

MAX_LENGTH = 512
# we only finetune using the loss at label position, ignore other labels.
IGNORE_INDEX = -100

label_names = ['negative', 'positive']

def preprocess_train(sample):
    part1 = prompt_part1.format(text=sample['text'])
    part2 = prompt_part2.format(label=label_names[sample['label']])
    
    tokenized = tokenizer('<s> ' + part1, part2, add_special_tokens=False, truncation='only_first', padding='max_length', max_length=MAX_LENGTH)
    
    labels = torch.tensor(copy.deepcopy(tokenized['input_ids']), dtype=torch.int64)
    actual_token_len = sum(tokenized['attention_mask'])

    labels[:actual_token_len-1] = IGNORE_INDEX
    labels[actual_token_len:] = IGNORE_INDEX

    return {
        'input_ids': torch.tensor(tokenized['input_ids'], dtype=torch.int64),
        'attention_mask': torch.tensor(tokenized['attention_mask']),
        'labels': labels
    }

def create_train_dataloader(train_ds, batch_size, local_rank):
    train_sampler = DistributedSampler(
        train_ds,
        rank=dist.get_rank(),
        num_replicas=dist.get_world_size(),
        shuffle=True,
    )

    train_dataloader = torch.utils.data.DataLoader(
        train_ds,
        batch_size=batch_size,
        num_workers=4,
        pin_memory=True,
        sampler=train_sampler,
        drop_last=True,
        collate_fn=default_data_collator,
    )
    return train_dataloader


def get_llama_wrapper():
    llama_auto_wrap_policy = partial(
        transformer_auto_wrap_policy,
        transformer_layer_cls={
            LlamaDecoderLayer,
        },
    )

    return llama_auto_wrap_policy

def apply_fsdp_checkpointing(model):
    print(f"--> applying fsdp activation checkpointing...")

    apply_activation_checkpointing(
        model, checkpoint_wrapper_fn=non_reentrant_wrapper, check_fn=check_fn
    )

@dataclass
class fsdp_config:
    mixed_precision: bool=True
    use_fp16: bool=False
    sharding_strategy: ShardingStrategy = ShardingStrategy.FULL_SHARD
    checkpoint_type: StateDictType = StateDictType.SHARDED_STATE_DICT
    fsdp_activation_checkpointing: bool=True
    pure_bf16: bool = False
    optimizer: str= "AdamW"

def setup_fdsp_model():
    model = LlamaForCausalLM.from_pretrained(MODEL_DIR, load_in_8bit=False, device_map=None, torch_dtype=torch.float16, use_cache=True)
    model.to(torch.bfloat16)
    model.config.pad_token_id = model.config.bos_token_id

    model = FSDP(
        model,
        auto_wrap_policy=get_llama_wrapper(),
        mixed_precision=None,
        sharding_strategy=fsdp_config.sharding_strategy,
        device_id=torch.cuda.current_device(),
        limit_all_gathers=True,
        sync_module_states=False,
        param_init_fn=None
    )
    apply_fsdp_checkpointing(model)
    return model


def save_model_checkpoint(
    model, 
    output_dir,
    rank
):
    """saving model via rank0 cpu streaming and full_state_dict"""

    with FSDP.state_dict_type(
        model, StateDictType.FULL_STATE_DICT, fullstate_save_policy
    ):
        cpu_state = model.state_dict()

        print(f"saving process: rank {rank}  done w model state_dict\n")
   

    if rank == 0:
        print(f"--> saving model ...")
        save_dir = Path.cwd() / output_dir
        save_dir.mkdir(parents=True, exist_ok=True)
        save_full_path = str(save_dir) + "/pytorch_model.bin"

        # save model
        torch.save(cpu_state, save_full_path)
        
        print(f"model checkpoint saved at {save_full_path}\n")

def train(model, train_dataloader, optimizer, local_rank):
    for epoch in range(NUM_EPOCHS):
        for step, x in tqdm(enumerate(train_dataloader), total=len(train_dataloader), disable=(local_rank!=0), desc=f'Epoch {epoch}/{NUM_EPOCHS}'):
            model.train()
            x = BatchEncoding(x).to(local_rank)

            loss = model(**x).loss
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
            if step % 50 == 0 and local_rank == 0:
                print('train loss:', loss.item()) 
    save_model_checkpoint(model, OUTPUT_DIR, local_rank)

train_ds = dataset['train'].map(preprocess_train, remove_columns=['text', 'label']).with_format('torch')
train_dataloader = create_train_dataloader(train_ds, BATCH_SIZE, local_rank)

model = setup_fdsp_model()
optimizer = optim.AdamW(model.parameters(), lr=LR, weight_decay=0.0)
train(model, train_dataloader, optimizer, local_rank)


In [None]:
!NCCL_DEBUG=WARN torchrun --nnodes 1 --nproc_per_node 4 train.py

## Evaluation


During training, we saved only `pytorch_model.bin` for simplicity, so we need to copy the model configuration files from the pretrained model directory to the checkpoint directory.

For this specific task, we only need to generate one token (`positive` or `negative`), so calling `model.forward` is sufficient.

For tasks that need to generate multiple tokens, use `model.generate` instead of `model.forward`.
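
The single-token prediction boils down to reading the logits at the last real (non-padded) position and comparing only the scores of the two candidate tokens. Here is a small sketch on plain lists, assuming right padding; the five-word vocabulary and the candidate ids `3` and `4` are invented for the example.

```python
def predict_label(logits, attention_mask, candidate_ids):
    """Pick the candidate whose logit is highest at the last real position.

    logits: one list of vocabulary scores per sequence position.
    attention_mask: 1 for real tokens, 0 for padding (right padding assumed).
    candidate_ids: token ids of the allowed answers, in label order.
    """
    last = sum(attention_mask) - 1
    scores = [logits[last][token_id] for token_id in candidate_ids]
    return scores.index(max(scores))


# Hypothetical logits for a 4-position sequence over a 5-token vocabulary.
logits = [
    [0.0, 0.1, 0.2, 0.3, 0.4],
    [0.5, 0.1, 0.0, 0.2, 0.1],
    [0.1, 0.0, 0.2, 1.5, 2.5],   # last real position
    [9.0, 0.0, 0.0, 0.0, 0.0],   # padding, must be ignored
]
attention_mask = [1, 1, 1, 0]
CANDIDATE_IDS = [3, 4]           # made-up ids for 'negative' and 'positive'

pred = predict_label(logits, attention_mask, CANDIDATE_IDS)
print(["negative", "positive"][pred])
```

Restricting the comparison to the two candidates guarantees a valid label even if some other token has a higher logit.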

In [None]:
!cp $MODEL_DIR/*.json $OUTPUT_DIR/

In [None]:
!ls $OUTPUT_DIR/

In [None]:
from transformers.tokenization_utils_base import BatchEncoding
import numpy as np

In [None]:
model = LlamaForCausalLM.from_pretrained(OUTPUT_DIR, load_in_8bit=False, device_map='cuda:0', torch_dtype=torch.float16, use_cache=True)
model.to(torch.bfloat16)
model.config.pad_token_id = model.config.bos_token_id

In [None]:
def preprocess_eval(sample):
    part1 = prompt_part1.format(text=sample['text'])
    
    return tokenizer('<s> ' + part1, prompt_part2_inference, add_special_tokens=False, truncation='only_first', padding='max_length', max_length=MAX_LENGTH)

In [None]:
dataset_eval = load_dataset('imdb', split='test')
eval_ds = dataset_eval.map(preprocess_eval, remove_columns=['text', 'label']).with_format('torch')
eval_dataloader = torch.utils.data.DataLoader(eval_ds,batch_size=32)

In [None]:
print(tokenizer.encode('negative positive', add_special_tokens=False))

In [None]:
result = []
with torch.no_grad():
    for x in tqdm(eval_dataloader, total=len(eval_dataloader)):
        model_inputs = BatchEncoding(x).to(0)
        outputs = model(**model_inputs)
        actual_token_len = model_inputs['attention_mask'].sum(-1).unsqueeze(1).unsqueeze(-1).expand(-1,-1,32000)
        gathered = torch.gather(outputs.logits.detach(), dim=1, index=actual_token_len-1).squeeze(1).cpu()
        result.append(torch.argmax(gathered.squeeze(1)[:, [8178, 6374]], dim=-1))

In [None]:
preds = torch.cat(result, dim=0).numpy()
labels = np.array(dataset_eval['label'])

In [None]:
(preds == labels).sum() / len(preds)