Skip to content

[BUG] train_batch_size is not equal to micro_batch_per_gpu * gradient_acc_step * world_size #3982

@vecorro

Description

@vecorro

Describe the bug

When using DeepSpeed 0.10.0 (or any version > 0.8.2) with Ray 2.5.1, I get the following error when trying to run a job on 3 Ray workers:

AssertionError: Check batch related parameters. train_batch_size is not equal to micro_batch_per_gpu * gradient_acc_step * world_size 9 != 1 * 3 * 1

To Reproduce
Steps to reproduce the behavior:

Run the following script on a Ray cluster with 3 nodes, each hosting 1 NVIDIA GPU A100 (40GB)
ray job submit --address <head-IP> --working-dir ./src/ --runtime-env-json='{"pip": ["torch==2.0.1","transformers==4.30.2","deepspeed==0.10.0", "accelerate==0.21.0", "peft==0.4.0", "bitsandbytes==0.39.1","datasets==2.13.0", "einops==0.6.1"]}' -- python finetune.py --num-workers=3 --model="tiiuae/falcon-7b"

Expected behavior
I expected this release to include a fix to #3228

ds_report output
I cannot run ds_report as I'm using a Ray cluster, but the issue seems to be the same as #3228, although I'm using ZeRO stage 3.

DS config

# DeepSpeed configuration used by the HF Trainer. "auto" values are filled
# in by transformers from TrainingArguments at trainer construction time.
deepspeed_cfg = {
    "fp16": {
        "enabled": "auto",
        "loss_scale": 0,
        "loss_scale_window": 1000,
        "initial_scale_power": 16,
        "hysteresis": 2,
        "min_loss_scale": 1,
    },
    "optimizer": {
        "type": "AdamW",
        "params": {
            "lr": "auto",
            "betas": "auto",
            "eps": "auto",
            "weight_decay": "auto",
        },
    },
    "scheduler": {
        "type": "WarmupLR",
        "params": {
            "warmup_min_lr": "auto",
            "warmup_max_lr": "auto",
            "warmup_num_steps": "auto",
            "warmup_type": "linear",
        },
    },
    # ZeRO stage 3 with optimizer state and parameters offloaded to CPU.
    "zero_optimization": {
        "stage": 3,
        "offload_optimizer": {"device": "cpu", "pin_memory": False},
        "offload_param": {"device": "cpu", "pin_memory": False},
        "overlap_comm": False,
        "contiguous_gradients": True,
        "reduce_bucket_size": "auto",
        "stage3_prefetch_bucket_size": "auto",
        "stage3_param_persistence_threshold": "auto",
        "sub_group_size": 1e9,
        "stage3_max_live_parameters": 1e9,
        "stage3_max_reuse_distance": 1e9,
        "stage3_gather_16bit_weights_on_model_save": "auto",
    },
    "gradient_accumulation_steps": "auto",
    "gradient_clipping": False,
    "steps_per_print": 10,
    "train_batch_size": "auto",
    "train_micro_batch_size_per_gpu": "auto",
    "wall_clock_breakdown": False,
}

If applicable, add screenshots to help explain your problem.

System info (please complete the following information):

  • OS: Ray image rayproject/ray:nightly-py39-cu118
  • GPU count and types: 3 Ray 2.5.1 workers each with 1 x NVIDIA A100 (40 GB) GPU
  • Python version: 3.9
  • CUDA 11.8
  • Ray working env: --runtime-env-json='{"pip": ["torch==2.0.1","transformers==4.30.2","deepspeed==0.10.0", "accelerate==0.21.0", "peft==0.4.0", "bitsandbytes==0.39.1","datasets==2.13.0", "einops==0.6.1"]}'

Additional context
SCRIPT

import os

# Prepend the CUDA toolkit locations so DeepSpeed's JIT builds can find nvcc
# and the CUDA runtime libraries on each Ray worker.
# Use .get() with a default: LD_LIBRARY_PATH (and in odd environments PATH)
# may be unset, and indexing os.environ directly raises KeyError.
os.environ['LD_LIBRARY_PATH'] = f"/usr/local/cuda/lib64:{os.environ.get('LD_LIBRARY_PATH', '')}"
os.environ['PATH'] = f"/usr/local/cuda/bin:{os.environ.get('PATH', '')}"
print(f"LD_LIBRARY_PATH: {os.environ['LD_LIBRARY_PATH']}")
print(f"PATH: {os.environ['PATH']}")

import ray
import argparse
import torch
import json
import pandas as pd
from pprint import pprint
from datasets import load_dataset
from ray.data.preprocessors import BatchMapper
from ray.train.huggingface import TransformersTrainer
from ray.air.config import ScalingConfig
from transformers import AutoTokenizer, AutoModelForCausalLM
from transformers import TrainingArguments
from transformers import Trainer
from transformers import DataCollatorForLanguageModeling
from peft import prepare_model_for_int8_training
from peft import LoraConfig, get_peft_model
from ray.air import session

# LoRA hyperparameters.
LORA_R = 8            # rank of the low-rank adapter matrices
LORA_ALPHA = 16       # adapter scaling factor
LORA_DROPOUT = 0.05   # dropout applied inside the adapters
# Module names inside the model to wrap with LoRA adapters.
LORA_TARGET_MODULES = [
    "query_key_value",
    "dense",
    "dense_h_to_4h",
    "dense_4h_to_h",
]


def load_model(model_id, lora_config):
    """Load a causal-LM checkpoint in 8-bit, prepare it for int8 training,
    and wrap it with the given LoRA adapter configuration.

    Args:
        model_id: HuggingFace model identifier to download/load.
        lora_config: a peft ``LoraConfig`` describing the adapters.

    Returns:
        The PEFT-wrapped model ready for fine-tuning.
    """
    base_model = AutoModelForCausalLM.from_pretrained(
            model_id,
            load_in_8bit=True,
            torch_dtype=torch.float16,
            trust_remote_code=True,
    )
    # Disable the generation cache during training.
    base_model.config.use_cache = False
    base_model = prepare_model_for_int8_training(base_model)
    return get_peft_model(base_model, lora_config)


def trainer_init_per_worker(train_dataset, eval_dataset=None, **config):
    """Construct a HuggingFace Trainer on each Ray training worker.

    Args:
        train_dataset: the (already tokenized) training dataset shard.
        eval_dataset: optional evaluation dataset (unused here).
        **config: expects 'model', 'lora', and 'platform' entries, as
            assembled by ``trainer_init_config`` in the launcher.

    Returns:
        A configured ``transformers.Trainer``.
    """
    print(f"Is CUDA available? {torch.cuda.is_available()}")

    # Pin OMP threads to the CPU count Ray actually granted this worker.
    granted_cpus = session.get_trial_resources().bundles[-1].get("CPU", 1)
    os.environ["OMP_NUM_THREADS"] = str(granted_cpus)

    # Allow tf32 matmuls for better throughput.
    torch.backends.cuda.matmul.allow_tf32 = True

    # Tokenizer: reuse EOS as the padding token.
    model_id = config.get('model')
    tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=True)
    tokenizer.pad_token = tokenizer.eos_token

    # LoRA adapter settings, forwarded from the launcher config.
    lora_opts = config.get('lora')
    lora_config = LoraConfig(
            r=lora_opts.get('r'),
            lora_alpha=lora_opts.get('alpha'),
            target_modules=lora_opts.get('target_modules'),
            lora_dropout=lora_opts.get('dropout'),
            bias="none",
            task_type="CAUSAL_LM",
    )

    # Training configuration; the DeepSpeed dict (or None) comes from the
    # 'platform' entry and its "auto" fields are resolved from these args.
    platform_opts = config.get('platform')
    training_args = TrainingArguments(
            per_device_train_batch_size=1,
            gradient_accumulation_steps=3,
            num_train_epochs=1,
            learning_rate=2e-4,
            fp16=False,
            bf16=False,
            save_total_limit=4,
            logging_steps=5,
            save_strategy='steps',
            weight_decay=0,
            push_to_hub=False,
            disable_tqdm=True,
            no_cuda=not platform_opts.get('use_gpu'),
            gradient_checkpointing=True,
            output_dir="./outputs_ray",
            ddp_find_unused_parameters=False,
            deepspeed=platform_opts.get('deepspeed'),
            max_steps=40,
    )

    # Causal-LM collator (no masked-LM objective).
    collator = DataCollatorForLanguageModeling(tokenizer, mlm=False)

    return Trainer(
            model=load_model(model_id, lora_config),
            train_dataset=train_dataset,
            args=training_args,
            data_collator=collator,
    )


def prepare_dataset(path, model_id):
    """Load a HuggingFace dataset and return it as a Ray Dataset with a
    single 'text' column of rendered instruction prompts.

    Args:
        path: HuggingFace dataset path/name (e.g. "yahma/alpaca-cleaned").
        model_id: kept for interface compatibility; tokenization happens
            later in the batch mapper, so no tokenizer is needed here.

    Returns:
        A ``ray.data.Dataset`` with one 'text' column.
    """
    dataset = load_dataset(path, split="train")

    def generate_prompt(data_point):
        # The leading spaces on continuation lines are part of the prompt
        # text. The stray "# noqa: E501" that previously sat inside this
        # string (and so leaked into every training prompt) was removed.
        return f"""Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately 
    completes the request.
    ### Instruction:
    {data_point["instruction"]}
    ### Input:
    {data_point["input"]}
    ### Response:
    {data_point["output"]}"""

    dataset_prompts = {'text': [generate_prompt(dp) for dp in dataset]}

    # Transform to Ray dataset format
    dataset_prompts_df = pd.DataFrame.from_dict(dataset_prompts)
    dataset_ray = ray.data.from_pandas(dataset_prompts_df)

    return dataset_ray


def prepare_batch_mapper(model_id):
    """Return a Ray ``BatchMapper`` that tokenizes the 'text' column of
    incoming pandas batches into padded/truncated numpy tensors."""
    tok = AutoTokenizer.from_pretrained(model_id, use_fast=True)
    # Use EOS as the padding token.
    tok.pad_token = tok.eos_token

    def preprocess_function(batch):
        encoded = tok(
                list(batch["text"]),
                padding=True,
                truncation=True,
                return_tensors="np",
        )
        return dict(encoded)

    return BatchMapper(preprocess_function, batch_format="pandas")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--num-workers", type=int, default=2, help="Sets number of workers for training.")
    parser.add_argument("--use-cpu", action="store_true", default=False, help="Enables CPU training")
    parser.add_argument("--no-deepspeed", action="store_true", default=False, help="Disables DeepSpeed strategy")
    parser.add_argument('--model', action='store', type=str, default="tiiuae/falcon-7b",
                        help='Model from HuggingFace to use')
    parser.add_argument('--data', action='store', type=str, default="yahma/alpaca-cleaned",
                        help='Path of the data to use for finetuning the model')

    args, _ = parser.parse_known_args()

    # DeepSpeed config handed to the HF Trainer; None disables DeepSpeed.
    # Initialize to None up front so the name is always bound -- the previous
    # version only defined `deepspeed` inside the `if` and relied on
    # short-circuit evaluation of a conditional expression to avoid a
    # NameError when --no-deepspeed was passed.
    deepspeed = None
    if not args.no_deepspeed:
        deepspeed = {
                "fp16": {
                        "enabled": "auto",
                        "loss_scale": 0,
                        "loss_scale_window": 1000,
                        "initial_scale_power": 16,
                        "hysteresis": 2,
                        "min_loss_scale": 1
                },

                "optimizer": {
                        "type": "AdamW",
                        "params": {
                                "lr": "auto",
                                "betas": "auto",
                                "eps": "auto",
                                "weight_decay": "auto"
                        }
                },

                "scheduler": {
                        "type": "WarmupLR",
                        "params": {
                                "warmup_min_lr": "auto",
                                "warmup_max_lr": "auto",
                                "warmup_num_steps": "auto"
                        }
                },

                # ZeRO stage 3 with CPU offload of optimizer state and params.
                "zero_optimization": {
                        "stage": 3,
                        "offload_optimizer": {
                                "device": "cpu",
                                "pin_memory": False
                        },
                        "offload_param": {
                                "device": "cpu",
                                "pin_memory": False
                        },
                        "overlap_comm": False,
                        "contiguous_gradients": True,
                        "sub_group_size": 1e9,
                        "reduce_bucket_size": "auto",
                        "stage3_prefetch_bucket_size": "auto",
                        "stage3_param_persistence_threshold": "auto",
                        "stage3_max_live_parameters": 1e9,
                        "stage3_max_reuse_distance": 1e9,
                        "stage3_gather_16bit_weights_on_model_save": True
                },

                # "auto" batch fields are resolved by transformers from
                # TrainingArguments so that
                # train_batch_size == micro_batch * grad_accum * world_size.
                "gradient_accumulation_steps": "auto",
                "gradient_clipping": False,
                "steps_per_print": 10,
                "train_batch_size": "auto",
                "train_micro_batch_size_per_gpu": "auto",
                "wall_clock_breakdown": False,
        }

    # Init Ray cluster
    ray.init(address="auto")
    print(f" Ray Cluster resources:\n {ray.cluster_resources()}")

    # Prepare Ray dataset and batch mapper
    dataset = prepare_dataset(args.data, args.model)
    batch_mapper = prepare_batch_mapper(args.model)

    # Per-worker trainer configuration forwarded to trainer_init_per_worker.
    trainer_init_config = {"model": args.model,
                           "lora": {"r": LORA_R,
                                    "alpha": LORA_ALPHA,
                                    "dropout": LORA_DROPOUT,
                                    "target_modules": LORA_TARGET_MODULES
                                    },
                           "platform": {"use_gpu": not args.use_cpu,
                                        "deepspeed": deepspeed}
                           }

    trainer = TransformersTrainer(
            trainer_init_per_worker=trainer_init_per_worker,
            trainer_init_config=trainer_init_config,
            scaling_config=ScalingConfig(num_workers=args.num_workers, use_gpu=not args.use_cpu),
            datasets={
                    "train": dataset
            },
            preprocessor=batch_mapper,
    )

    # Launch the training on the cluster
    result = trainer.fit()

Metadata

Metadata

Assignees

Labels

bug (Something isn't working), training

Type

No type

Projects

No projects

Milestone

No milestone

Relationships

None yet

Development

No branches or pull requests

Issue actions