Fine-tuning with examples/pytorch/language-modeling/run_clm.py on torch/XLA + FSDP produces abnormal models #27432

Closed · totorochina opened this issue Nov 10, 2023 · 0 comments · Fixed by #27652

System Info

  • transformers version: 4.36.0.dev0
  • Platform: Linux-5.13.0-1027-gcp-x86_64-with-glibc2.29
  • Python version: 3.8.10
  • Huggingface_hub version: 0.19.0
  • Safetensors version: 0.4.0
  • Accelerate version: 0.24.1
  • Accelerate config: not found
  • PyTorch version (GPU?): 2.1.0a0+gitcc01568 (False)
  • Tensorflow version (GPU?): not installed (NA)
  • Flax version (CPU?/GPU?/TPU?): not installed (NA)
  • Jax version: not installed
  • JaxLib version: not installed
  • Using GPU in script?: no, using a TPU with torch_xla
  • Using distributed or parallel set-up in script?: yes, using flags for xla_fsdp

Who can help?

@muellerzr @pacman100

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

I was following these two blog posts/docs:
https://pytorch.org/blog/large-scale-training-hugging-face/
https://huggingface.co/docs/transformers/main_classes/trainer#pytorchxla-fully-sharded-data-parallel

  1. Create a v3-8 TPU VM on Google Cloud and log in
export PROJECT=<project_name>
export REGION=<region>
export ZONE=<tpu_vm_instance_zone>
export VPC=<vpc_name>
export SUBNET=<vpc_subnet>
export TPUVM=<tpu_vm_instance_name>
export TYPE=v3-8
export IMAGE=tpu-vm-pt-2.0

gcloud compute tpus tpu-vm create ${TPUVM} \
    --zone=${ZONE} \
    --accelerator-type=${TYPE} \
    --version=${IMAGE} \
    --network=${VPC} \
    --subnetwork="projects/${PROJECT}/regions/${REGION}/subnetworks/${SUBNET}" \
    --internal-ips

gcloud alpha compute tpus tpu-vm ssh ${TPUVM} --zone=${ZONE} --tunnel-through-iap
  2. Update to the latest torch_xla nightly
sudo apt update -y && sudo apt upgrade -y
pip install https://storage.googleapis.com/tpu-pytorch/wheels/tpuvm/torch-nightly-cp38-cp38-linux_x86_64.whl
pip install https://storage.googleapis.com/tpu-pytorch/wheels/tpuvm/torch_xla-nightly-cp38-cp38-linux_x86_64.whl
pip install https://storage.googleapis.com/tpu-pytorch/wheels/tpuvm/torchvision-nightly-cp38-cp38-linux_x86_64.whl
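
Not part of the original report, but a quick sanity check at this point confirms torch_xla can actually see the TPU before launching training (a minimal sketch; run it with PJRT_DEVICE=TPU set in the environment):

import torch_xla.core.xla_model as xm

# Lists the XLA devices visible to this process
print(xm.get_xla_supported_devices())
# Default XLA device handle for the current process
print(xm.xla_device())
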
  3. Install the latest transformers. I tried both the latest main and the v4.31-release branch, with accelerate==0.21.0 (as used in the blog) and with the latest release; the problem remained the same.
cd $HOME
# git clone -b v4.31-release https://github.com/huggingface/transformers.git
git clone https://github.com/huggingface/transformers.git
cd transformers
# For Python 3.8
pip install -e .
pip install datasets evaluate scikit-learn accelerate py7zr
  4. Prepare llama2_fsdp_config.json, copy it to the home folder, and log in with an HF token.
# huggingface-cli login --token <YOUR_HF_TOKEN>
# llama2_fsdp_config.json
{
    "fsdp_transformer_layer_cls_to_wrap": [
        "LlamaDecoderLayer"
    ],
    "xla": true,
    "xla_fsdp_settings": {
        "compute_dtype": "bfloat16",
        "shard_param_on_dim_0": true,
        "pin_layout_in_collective_ops": true
    },
   "xla_fsdp_grad_ckpt": true
}
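
For context, Trainer forwards the xla_fsdp_settings dict as keyword arguments when it wraps the model in torch_xla's FSDP class. A simplified sketch of that wrapping (not the exact Trainer code; the Linear is a stand-in for the loaded LlamaForCausalLM):

import torch
import torch_xla.core.xla_model as xm
from torch_xla.distributed.fsdp import XlaFullyShardedDataParallel as FSDP

model = torch.nn.Linear(16, 16).to(xm.xla_device())
model = FSDP(
    model,
    compute_dtype=torch.bfloat16,        # "compute_dtype": "bfloat16"
    shard_param_on_dim_0=True,           # "shard_param_on_dim_0": true
    pin_layout_in_collective_ops=True,   # "pin_layout_in_collective_ops": true
)
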
  5. Run run_clm.py with xla_spawn.py, setting --model_name_or_path so that it fine-tunes instead of training from scratch
export PJRT_DEVICE=TPU
nohup python3 -u examples/pytorch/xla_spawn.py --num_cores 8 examples/pytorch/language-modeling/run_clm.py \
    --model_name_or_path "meta-llama/Llama-2-7b-hf" \
    --num_train_epochs 3 \
    --dataset_name wikitext \
    --dataset_config_name wikitext-2-raw-v1 \
    --per_device_train_batch_size 6 \
    --per_device_eval_batch_size 6 \
    --do_train \
    --do_eval \
    --output_dir /tmp/llama-2-7b-hf-ft-xla \
    --overwrite_output_dir \
    --cache_dir /tmp \
    --block_size 2048 \
    --optim adafactor \
    --save_strategy no \
    --logging_strategy no \
    --gradient_checkpointing \
    --fsdp "full_shard" \
    --fsdp_config ~/llama2_fsdp_config.json > run.log 2>&1 &
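
For reference, xla_spawn.py is a thin launcher around torch_xla's multiprocessing API; conceptually it boils down to something like this (a sketch, with _mp_fn standing in for run_clm's entry point):

import torch_xla.distributed.xla_multiprocessing as xmp

def _mp_fn(index):
    # each of the --num_cores processes runs the training script's main() here
    ...

if __name__ == "__main__":
    xmp.spawn(_mp_fn, args=(), nprocs=8)
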
  6. On a CUDA device, load the fine-tuned model and run inference
from transformers import AutoTokenizer
import transformers
import torch
import os
# from_pretrained does not expand "~", so expand the path explicitly
model = os.path.expanduser("~/llama-2-7b-hf-ft-xla")
tokenizer = AutoTokenizer.from_pretrained(model, use_auth_token=True)
pipeline = transformers.pipeline(
    "text-generation",
    model=model,
    torch_dtype=torch.bfloat16,
    device_map="auto",
)
sequences = pipeline(
    ['I liked "Breaking Bad" and "Band of Brothers". Do you have any recommendations of other shows I might like?\n'] * 1,
    do_sample=True,
    top_k=10,
    num_return_sequences=1,
    eos_token_id=tokenizer.eos_token_id,
    max_length=200,
)
print(sequences)

With the fine-tuned Llama2-7B, inference fails with a runtime error:

RuntimeError: probability tensor contains either `inf`, `nan` or element < 0

I also tried GPT2: the fine-tuned model can be loaded and used for inference, but it produces garbage output like

[{'generated_text': 'I liked "Breaking Bad" and "Band of Brothers". Do you have any recommendations of other shows I might like?\n contends Creator smiling reminiscentoffset prophets contends contends Sheffield contends wetlandslocked maximizing maximizing WIratorct continuity=- ...'}]
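
Both symptoms point at corrupted weights rather than a pipeline problem. A quick scan of the saved checkpoint for non-finite values would confirm this (a hypothetical diagnostic; the path is an assumption, adjust it to the actual output_dir):

import torch

state_dict = torch.load("/tmp/llama-2-7b-hf-ft-xla/pytorch_model.bin",
                        map_location="cpu")
for name, tensor in state_dict.items():
    # flag any weight tensor containing NaN or Inf
    if torch.is_tensor(tensor) and tensor.is_floating_point():
        if not torch.isfinite(tensor).all():
            print(f"non-finite values in {name}")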

For both the fine-tuned Llama2-7B and GPT2, I get warnings like the following when instantiating transformers.pipeline:

Some weights of the model checkpoint at /home/hzchen/scripts/llm/gpt-ft-test were not used when initializing GPT2LMHeadModel: [<FSDP_LAYERS_OMITTED...>]
- This IS expected if you are initializing GPT2LMHeadModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing GPT2LMHeadModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of GPT2LMHeadModel were not initialized from the model checkpoint at /home/hzchen/scripts/llm/gpt-ft-test and are newly initialized: [<LAYERS_OMITTED...>]
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.

I also noticed that the output fine-tuned models are abnormally small: the fine-tuned GPT2 is 60+ MB while the original is 500+ MB, and the fine-tuned Llama2-7B is 3.2 GB while the original is 13 GB (fine-tuning the same model on CUDA produces a 20+ GB output).
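
The size gap can be quantified by counting the parameters actually stored in the checkpoint; a total far below the expected ~6.7B would mean only local shards (or flattened FSDP buffers) were written (again a hypothetical check with an assumed path):

import torch

state_dict = torch.load("/tmp/llama-2-7b-hf-ft-xla/pytorch_model.bin",
                        map_location="cpu")
total = sum(t.numel() for t in state_dict.values() if torch.is_tensor(t))
print(f"checkpoint holds {total / 1e9:.2f}B parameters")  # Llama-2-7B has ~6.74B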

I also tried accelerate + FSDP on 8x L4 GPUs with the same configuration, and everything worked fine, which leads me to believe the problem is in XLA+FSDP.

Below is how I ran it successfully on CUDA devices:

# fsdp_config.yaml
compute_environment: LOCAL_MACHINE
debug: false
distributed_type: FSDP
downcast_bf16: 'no'
fsdp_config:
  fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
  fsdp_backward_prefetch_policy: NO_PREFETCH
  fsdp_cpu_ram_efficient_loading: true
  fsdp_forward_prefetch: false
  fsdp_offload_params: false
  fsdp_sharding_strategy: 1
  fsdp_state_dict_type: FULL_STATE_DICT
  fsdp_sync_module_states: true
  fsdp_transformer_layer_cls_to_wrap: LlamaDecoderLayer
  fsdp_use_orig_params: false
machine_rank: 0
main_training_function: main
mixed_precision: bf16
num_machines: 1
num_processes: 8
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false

nohup accelerate launch --config_file ~/fsdp_config.yaml examples/pytorch/language-modeling/run_clm.py \
    --model_name_or_path "meta-llama/Llama-2-7b-hf" \
    --num_train_epochs 3 \
    --dataset_name wikitext \
    --dataset_config_name wikitext-2-raw-v1 \
    --per_device_train_batch_size 2 \
    --per_device_eval_batch_size 2 \
    --do_train \
    --do_eval \
    --output_dir /tmp/llama-2-7b-hf-ft-cuda \
    --overwrite_output_dir \
    --cache_dir /tmp \
    --block_size 2048 \
    --optim adafactor \
    --save_strategy no \
    --logging_strategy no \
    --gradient_checkpointing > run.log 2>&1 &

Expected behavior

Models fine-tuned with XLA+FSDP on TPU should be usable, just as they are when fine-tuned with Accelerate+FSDP on GPUs.
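
For what it's worth, torch_xla ships a utility for merging per-rank FSDP shards back into a full model checkpoint. Whether it applies here depends on how Trainer saved the files, so treat this as a sketch of the general mechanism (with assumed paths and torch_xla's default shard naming) rather than a confirmed fix; the actual fix landed in #27652:

from torch_xla.distributed.fsdp import consolidate_sharded_model_checkpoints

# The same utility is also exposed as a CLI:
#   python3 -m torch_xla.distributed.fsdp.consolidate_sharded_ckpts \
#       --ckpt_prefix /tmp/llama-2-7b-hf-ft-xla/final_ckpt \
#       --ckpt_suffix "_rank-*-of-*.pth"
consolidate_sharded_model_checkpoints(
    ckpt_prefix="/tmp/llama-2-7b-hf-ft-xla/final_ckpt",  # hypothetical shard prefix
    ckpt_suffix="_rank-*-of-*.pth",
)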
