Missing [existing] default config for accelerator in trainer module #29993

Closed
1 of 4 tasks
b5y opened this issue Apr 2, 2024 · 5 comments · Fixed by #29997
b5y commented Apr 2, 2024

System Info

  • transformers version: 4.39.3
  • Platform: Linux-6.5.0-25-generic-x86_64-with-glibc2.35
  • Python version: 3.11.7
  • Huggingface_hub version: 0.20.2
  • Safetensors version: 0.4.1
  • Accelerate version: 0.28.0
  • Accelerate config: - compute_environment: LOCAL_MACHINE
    • distributed_type: MULTI_GPU
    • mixed_precision: fp16
    • use_cpu: False
    • debug: False
    • num_processes: 2
    • machine_rank: 0
    • num_machines: 1
    • rdzv_backend: static
    • same_network: False
    • main_training_function: main
    • downcast_bf16: False
    • tpu_use_cluster: False
    • tpu_use_sudo: False
  • PyTorch version (GPU?): 2.1.2+cu118 (True)
  • Tensorflow version (GPU?): not installed (NA)
  • Flax version (CPU?/GPU?/TPU?): not installed (NA)
  • Jax version: not installed
  • JaxLib version: not installed
  • Using GPU in script?: Yes
  • Using distributed or parallel set-up in script?: Yes (?)

Who can help?

@muellerzr and @pacman100

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

import logging
import os
from pathlib import Path

from data import TrainDatasetForEmbedding, EmbedCollator
from modeling import BiEncoderModel
from trainer import BiTrainer

from transformers import AutoConfig, AutoTokenizer
from transformers import (
    HfArgumentParser,
    set_seed,
)

from arguments import ModelArguments, DataArguments, RetrieverTrainingArguments as TrainingArguments


logger = logging.getLogger(__name__)

set_seed(42)

num_labels = 1
model_name_or_path = "BAAI/bge-m3"


model_args: ModelArguments
data_args: DataArguments
training_args: TrainingArguments

data_args = DataArguments


tokenizer = AutoTokenizer.from_pretrained(
    model_name_or_path,
    cache_dir=".cache_dir",
    use_fast=False,
)


config = AutoConfig.from_pretrained(
    model_name_or_path,
    num_labels=num_labels,
    cache_dir=".cache_dir",
)


# normalizing
normlized = True
# pooling method: either cls or mean
sentence_pooling_method = 'cls' # or 'mean'
# share the negatives across all GPUs. This argument will extend the number of negatives.
negatives_cross_device = False
# It will influence the distribution of similarity scores.
temperature = 0.02
# use passages in the same batch as negatives. Default value is True.
use_inbatch_neg = True


model = BiEncoderModel(
    model_name=model_name_or_path,
    normlized=normlized,
    sentence_pooling_method=sentence_pooling_method,
    negatives_cross_device=negatives_cross_device,
    temperature=temperature,
    use_inbatch_neg=use_inbatch_neg,
)


data_args.train_data = os.path.abspath("data_split")

train_dataset = TrainDatasetForEmbedding(args=data_args, tokenizer=tokenizer)
training_args = TrainingArguments

training_args.per_device_train_batch_size = 256
training_args.train_group_size = 15
training_args.negatives_cross_device = negatives_cross_device
# select an appropriate learning rate for your model. Recommended: 1e-5/2e-5/3e-5 for large/base/small-scale models.
training_args.learning_rate = 3e-5
training_args.temperature = temperature
# instruction for the query, which will be added to each query. You can also set it to "" to add nothing to the query.
training_args.query_instruction_for_retrieval = ""
# use passages in the same batch as negatives. Default value is True.
training_args.use_inbatch_neg = use_inbatch_neg


# max length for query. Please set it according to the average length of queries in your data.
data_args.query_max_len = 33
# max length for passage. Please set it according to the average length of passages in your data.
data_args.passage_max_len = 115


trainer = BiTrainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    data_collator=EmbedCollator(
        tokenizer,
        query_max_len=data_args.query_max_len,
        passage_max_len=data_args.passage_max_len
    ),
    tokenizer=tokenizer
)

Expected behavior

The Trainer should fall back to the default accelerate config when one is not provided.

I've been trying to reproduce the training code from the FlagEmbedding project in a Jupyter notebook.
There seems to be a problem with accelerator_config. The RetrieverTrainingArguments class is modified and looks like this:

import os
from dataclasses import dataclass, field
from typing import Optional, Dict

from transformers import TrainingArguments
from transformers.trainer_pt_utils import AcceleratorConfig

###.....
### Other classes are the same as in arguments.py in the aforementioned link
###.....

@dataclass
class RetrieveAcceleratorConfig(AcceleratorConfig):
    split_batches: bool = field(
        default=False,
        metadata={
            "help": "Whether or not the accelerator should split the batches yielded by the dataloaders across the devices. If"
                    " `True` the actual batch size used will be the same on any kind of distributed processes, but it must be a"
                    " round multiple of the `num_processes` you are using. If `False`, actual batch size used will be the one set"
                    " in your script multiplied by the number of processes."
        },
    )
    dispatch_batches: Optional[bool] = field(
        default=None,
        metadata={
            "help": "If set to `True`, the dataloader prepared by the Accelerator is only iterated through on the main process"
                    " and then the batches are split and broadcast to each process. Will default to `True` for `DataLoader` whose"
                    " underlying dataset is an `IterableDataset`, `False` otherwise."
        },
    )
    even_batches: bool = field(
        default=True,
        metadata={
            "help": "If set to `True`, in cases where the total batch size across all processes does not exactly divide the"
                    " dataset, samples at the start of the dataset will be duplicated so the batch can be divided equally among"
                    " all workers."
        },
    )
    use_seedable_sampler: bool = field(
        default=True,
        metadata={
            "help": "Whether or not to use a fully seedable random sampler ([`accelerate.data_loader.SeedableRandomSampler`])."
                    " Ensures training results are fully reproducible using a different sampling technique."
                    " While seed-to-seed results may differ, on average the differences are negligible when using"
                    " multiple different seeds to compare. Should also be run with [`~utils.set_seed`] for the best results."
        },
    )

@dataclass
class RetrieverTrainingArguments(TrainingArguments):
    negatives_cross_device: bool = field(default=False, metadata={"help": "share negatives across devices"})
    temperature: Optional[float] = field(default=0.02)
    fix_position_embedding: bool = field(default=False,
                                         metadata={"help": "Freeze the parameters of position embeddings"})
    sentence_pooling_method: str = field(default='cls', metadata={"help": "the pooling method, should be cls or mean"})
    normlized: bool = field(default=True)
    use_inbatch_neg: bool = field(default=True, metadata={"help": "use passages in the same batch as negatives"})
    deepspeed_plugin: Optional[str] = field(default=None, metadata={"help": "deepspeed plugin to use"})
    debug: Optional[str] = field(
        default="",
        metadata={
            "help": (
                "Whether or not to enable debug mode. default is '', "
                "`underflow_overflow` (Detect underflow and overflow in activations and weights), "
            )
        },
    )
    accelerator_config: Optional[dict] = field(
        default=None,
        metadata={
            "help": (
                "Config to be used with the internal Accelerator object initialization. The value is either a "
                "accelerator json config file (e.g., `accelerator_config.json`) or an already loaded json file as `dict`."
            )
        },
    )

And I am getting the following error in the Jupyter notebook:

---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
Cell In[17], line 1
----> 1 trainer = BiTrainer(
      2     model=model,
      3     args=training_args,
      4     train_dataset=train_dataset,
      5     data_collator=EmbedCollator(
      6         tokenizer,
      7         query_max_len=data_args.query_max_len,
      8         passage_max_len=data_args.passage_max_len
      9     ),
     10     tokenizer=tokenizer
     11 )

File ~/anaconda3/envs/ai/lib/python3.11/site-packages/transformers/trainer.py:373, in Trainer.__init__(self, model, args, data_collator, train_dataset, eval_dataset, tokenizer, model_init, compute_metrics, callbacks, optimizers, preprocess_logits_for_metrics)
    370 self.deepspeed = None
    371 self.is_in_train = False
--> 373 self.create_accelerator_and_postprocess()
    375 # memory metrics - must set up as early as possible
    376 self._memory_tracker = TrainerMemoryTracker(self.args.skip_memory_metrics)

File ~/anaconda3/envs/ai/lib/python3.11/site-packages/transformers/trainer.py:4255, in Trainer.create_accelerator_and_postprocess(self)
   4249 gradient_accumulation_plugin = GradientAccumulationPlugin(**grad_acc_kwargs)
   4251 # create accelerator object
   4252 self.accelerator = Accelerator(
   4253     deepspeed_plugin=self.args.deepspeed_plugin,
   4254     gradient_accumulation_plugin=gradient_accumulation_plugin,
-> 4255     **self.args.accelerator_config.to_dict(),
   4256 )
   4257 # some Trainer classes need to use `gather` instead of `gather_for_metrics`, thus we store a flag
   4258 self.gather_function = self.accelerator.gather_for_metrics

AttributeError: 'NoneType' object has no attribute 'to_dict'

Tested with transformers versions v4.39.2 and v4.39.3.

UPDATED: Even if I put the full path to default_config.yaml, the error stays the same, with 'NoneType' replaced by 'str'.
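For reference, what the Trainer expects at that point is an AcceleratorConfig instance, not None or a raw path string. A minimal sketch, assuming transformers 4.39:

from transformers.trainer_pt_utils import AcceleratorConfig

# Trainer.create_accelerator_and_postprocess calls args.accelerator_config.to_dict(),
# so by that point the field must hold an AcceleratorConfig instance:
cfg = AcceleratorConfig(split_batches=False, even_batches=True, use_seedable_sampler=True)
print(cfg.to_dict())  # works

# A bare None (or an unconverted path string) has no .to_dict(), which is exactly
# the AttributeError shown in the traceback above.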

muellerzr self-assigned this Apr 2, 2024
@muellerzr (Contributor)

Thanks, I'll look into this. It's odd that it's None, because the control flow should account for this. Thanks for the flag.

@muellerzr (Contributor)

@b5y your issue is that the default here is a type, not an instance. It should still remain as None because otherwise you can start up the distributed process :)
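A minimal sketch of that distinction (the output_dir value is just a placeholder):

from transformers import TrainingArguments

# Assigning the class itself never runs __init__/__post_init__, so fields such as
# accelerator_config are never resolved into an AcceleratorConfig instance:
training_args = TrainingArguments          # a type, not an instance

# Instantiating it runs __post_init__, which resolves the accelerator defaults:
training_args = TrainingArguments(output_dir="output_dir")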

b5y (Author) commented Apr 2, 2024

Well, I know that there must be an instance, not a type. But the main reason I left it as it is now is that I wanted to see the result of the last cell in the notebook (the one with the BiTrainer class). I've tried many approaches, especially those mentioned in the docs, but eventually nothing worked. So I left the draft version, not the one I tried from the beginning where I put default=None.

I knew how to fix it inside the transformers library, but my main goal was/is to prepare production-ready code.

muellerzr (Contributor) commented Apr 2, 2024

I'm not sure what's happening in your code, as I'm unable to reproduce this issue. Please provide a full reproducer with your exact current code so we can help.

One thing I notice: why are you not instantiating the TrainingArguments?

training_args = TrainingArguments

training_args.per_device_train_batch_size = 256
training_args.train_group_size = 15
training_args.negatives_cross_device = negatives_cross_device
# select an appropriate learning rate for your model. Recommended: 1e-5/2e-5/3e-5 for large/base/small-scale models.
training_args.learning_rate = 3e-5
training_args.temperature = temperature
# instruction for the query, which will be added to each query. You can also set it to "" to add nothing to the query.
training_args.query_instruction_for_retrieval = ""
# use passages in the same batch as negatives. Default value is True.
training_args.use_inbatch_neg = use_inbatch_neg

This does not seem right whatsoever, and one should not be modifying values like this after the fact.

b5y (Author) commented Apr 3, 2024

Sorry, that was a typo in the instantiation. And I do agree about not modifying values after the fact. Setting up the instantiation like this

training_args = TrainingArguments(
    output_dir = "output_dir",
    per_device_train_batch_size=256,
    # train_group_size=15,
    learning_rate = 3e-5,
    temperature = temperature,
    # query_instruction_for_retrieval = "",
    use_inbatch_neg = use_inbatch_neg
)

fixed my problem.

Anyway, the error with the accelerator_config field was weird.
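If the accelerator settings do need to be customized, passing a dict (or a path to a JSON config) at construction time lets TrainingArguments convert it during __post_init__. A sketch, assuming the stock TrainingArguments from transformers >= 4.38 rather than the modified subclass above:

training_args = TrainingArguments(
    output_dir="output_dir",
    per_device_train_batch_size=256,
    learning_rate=3e-5,
    # dict form of the accelerator settings; converted to an AcceleratorConfig internally
    accelerator_config={
        "split_batches": False,
        "even_batches": True,
        "use_seedable_sampler": True,
    },
)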
