
FSDP inference checkpoints #39

Merged (22 commits) on Jul 31, 2023

Changes from 6 commits

9 changes: 9 additions & 0 deletions docs/inference.md
@@ -34,6 +34,15 @@ The inference folder also includes a chat completion example, that adds built-in
python chat_completion.py --model_name "PATH/TO/MODEL/7B/" --prompt_file chats.json --quantization --use_auditnlg

```
## Loading back FSDP checkpoints

If you have fine-tuned your model with pure FSDP and saved the checkpoints with "SHARDED_STATE_DICT" as shown [here](../configs/fsdp.py), you can use the `checkpoint_converter_fsdp_hf.py` script to convert the FSDP sharded checkpoints into Hugging Face checkpoints. This lets you use the inference script as described above.

**To convert the checkpoint, use the following command**:
```bash
python checkpoint_converter_fsdp_hf.py --model_name PATH/to/FSDP/Checkpoints --save_dir PATH/to/save/checkpoints --model_path PATH/or/HF/model_name

# --model_path specifies the HF Llama model name or a local path containing config.json and tokenizer.json
```
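For illustration, here is a minimal sketch of what such a converter script might look like internally, assuming the repo's `fire`-based CLI style and the helpers introduced elsewhere in this PR (`load_llama_from_config` in `inference/model_utils.py` and `load_sharded_model_single_gpu` in `model_checkpointing`); the actual `checkpoint_converter_fsdp_hf.py` may differ:

```python
# Illustrative sketch only; import paths assume the script lives in inference/
# with the repository root appended to sys.path.
import fire
from transformers import LlamaTokenizer

from model_utils import load_llama_from_config
from model_checkpointing import load_sharded_model_single_gpu


def main(model_name, save_dir, model_path):
    # model_path: HF Llama model name or a local folder with config.json and tokenizer files
    model = load_llama_from_config(model_path)                # architecture only, random weights
    model = load_sharded_model_single_gpu(model, model_name)  # fill the weights from the FSDP shards
    model.save_pretrained(save_dir)                           # write a standard HF checkpoint
    LlamaTokenizer.from_pretrained(model_path).save_pretrained(save_dir)


if __name__ == "__main__":
    fire.Fire(main)
```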

## Other Inference Options

2 changes: 1 addition & 1 deletion inference/README.md
@@ -2,7 +2,7 @@

This folder contains inference examples for Llama 2. So far, we have provided support for three methods of inference:

-1. [inference script](inference.py) script provides support for Hugging Face accelerate and PEFT fine tuned models.
+1. [inference script](inference.py) script provides support for Hugging Face accelerate, PEFT and FSDP fine tuned models.

2. [vLLM_inference.py](vLLM_inference.py) script takes advantage of vLLM's paged attention concept for low latency.

13 changes: 11 additions & 2 deletions inference/inference.py
@@ -12,8 +12,17 @@

from transformers import LlamaTokenizer
from safety_utils import get_safety_checker
-from model_utils import load_model, load_peft_model
+from model_utils import load_model, load_peft_model, load_llama_from_config
from accelerate import init_empty_weights
# Get the current file's directory
current_directory = os.path.dirname(os.path.abspath(__file__))

# Get the parent directory
parent_directory = os.path.dirname(current_directory)

# Append the parent directory to sys.path
sys.path.append(parent_directory)
**Contributor review comment:** Please move the code for path setting inside main / a function.
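One possible shape for that refactor, purely illustrative (the helper name below is hypothetical, not from the PR):

```python
import os
import sys


def ensure_repo_root_on_path():
    # Hypothetical helper: the same path setup as above, but deferred until it is
    # actually needed instead of running at module import time.
    current_directory = os.path.dirname(os.path.abspath(__file__))
    parent_directory = os.path.dirname(current_directory)
    if parent_directory not in sys.path:
        sys.path.append(parent_directory)
```

main() would then call `ensure_repo_root_on_path()` before using anything from `model_checkpointing`, which also means the `from model_checkpointing import ...` line below would have to become a local import inside main().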

from model_checkpointing import load_sharded_model_single_gpu

def main(
    model_name,
@@ -83,7 +92,7 @@ def main(
    if peft_model:
        model = load_peft_model(model, peft_model)

-    model.eval()
+    # model.eval()

    batch = tokenizer(user_prompt, return_tensors="pt")
    batch = {k: v.to("cuda") for k, v in batch.items()}
12 changes: 10 additions & 2 deletions inference/model_utils.py
@@ -2,7 +2,7 @@
# This software may be used and distributed according to the terms of the GNU General Public License version 3.

from peft import PeftModel
-from transformers import LlamaForCausalLM
+from transformers import LlamaForCausalLM, LlamaConfig

# Function to load the main model for text generation
def load_model(model_name, quantization):
@@ -19,4 +19,12 @@ def load_model(model_name, quantization):
# Function to load the PeftModel for performance optimization
def load_peft_model(model, peft_model):
    peft_model = PeftModel.from_pretrained(model, peft_model)
    return peft_model

# Build the model from its config so FSDP checkpoints can be loaded into it
def load_llama_from_config(config_path):
    model_config = LlamaConfig.from_pretrained(config_path)
    model = LlamaForCausalLM(config=model_config)
    return model
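A brief usage note: this helper only instantiates the architecture from `config.json`, so the returned model carries freshly initialized (random) weights until a checkpoint is loaded into it. A minimal sketch with a placeholder path:

```python
from model_utils import load_llama_from_config

# Builds LlamaForCausalLM from config.json alone; no weights are downloaded or loaded.
model = load_llama_from_config("PATH/or/HF/model_name")
print(f"{sum(p.numel() for p in model.parameters()):,} parameters (randomly initialized)")
```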


1 change: 1 addition & 0 deletions model_checkpointing/__init__.py
@@ -10,4 +10,5 @@
    save_optimizer_checkpoint,
    save_model_and_optimizer_sharded,
    load_model_sharded,
    load_sharded_model_single_gpu
)
30 changes: 30 additions & 0 deletions model_checkpointing/checkpoint_handler.py
@@ -304,3 +304,33 @@ def save_distributed_model_checkpoint(model, rank, cfg, epoch=1):
    save_state_dict(state_dict, writer)

    return

def load_sharded_model_single_gpu(model, model_path):
    # Load an FSDP SHARDED_STATE_DICT checkpoint into `model` on a single process.
    # `no_dist=True` means no torch.distributed process group is required.
    # Assumes FileSystemReader and dist_cp (torch.distributed checkpoint module)
    # are imported at the top of this module, as used by the other loaders here.
    reader = FileSystemReader(model_path)

    state_dict = {
        "model": model.state_dict()
    }

    dist_cp.load_state_dict(
        state_dict=state_dict,
        storage_reader=reader,
        no_dist=True,
    )

    ck = state_dict["model"].keys()
    print(f"checkpoint key len = {len(ck)} and \n keys = {state_dict.keys()}")

    model.load_state_dict(state_dict["model"])

    print(f"Sharded state checkpoint loaded from {model_path}")
    return model
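Because the load above passes `no_dist=True`, no torch.distributed process group needs to be initialized, so a sharded checkpoint can be consolidated on a single GPU or even on CPU. A hedged end-to-end usage sketch, mirroring the converter flow documented earlier (paths are placeholders; import paths assume the snippet runs from the inference/ folder with the repo root on sys.path):

```python
from model_utils import load_llama_from_config
from model_checkpointing import load_sharded_model_single_gpu

# Rebuild the architecture, populate it from the FSDP sharded checkpoint,
# then save a plain Hugging Face checkpoint that inference.py can consume.
model = load_llama_from_config("PATH/or/HF/model_name")
model = load_sharded_model_single_gpu(model, "PATH/to/FSDP/Checkpoints")
model.save_pretrained("PATH/to/save/checkpoints")
```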