[BUG]: RuntimeError: Expected to have finished reduction in the prior iteration before starting a new one. #995

Closed
480284856 opened this issue May 18, 2022 · 12 comments · Fixed by #1123
Labels
bug Something isn't working

@480284856

480284856 commented May 18, 2022

🐛 Describe the bug

After following the ResNet50 example in the tutorial as closely as I could, I got the error in the title. As with my previous experience using Hugging Face's accelerate, I can't figure out this complex problem on my first try. I have tried my best to solve it, and the likely reason is shown below. Running colossalai check -i outputs:
Colossalai should be built with cuda extension to use the FP16 optimizer
If you want to activate cuda mode for MoE, please install with cuda_ext!
CUDA Version: N/A (CUDA_HOME is not set)
PyTorch Version: 1.11.0+cu102
CUDA Version in PyTorch Build: 10.2
PyTorch CUDA Version Match: x
CUDA Extension: x

However, I also tried on a machine with CUDA 11.3 and got the same error.

Below is part of my code:

import json

import colossalai
import torch.nn.functional as F
import torch.optim as optim
from colossalai.core import global_context as gpc
from colossalai.logging import get_dist_logger
from datasets import load_metric
from torch.utils.data import DataLoader
from tqdm import tqdm

# Config, JiebaTokenizer and DS are project-specific helpers not shown here; BB is defined below.

logger = get_dist_logger()
# args = colossalai.get_default_parser().parse_args()
colossalai.launch_from_torch(config='config.py')
config = Config()
tokenizer = JiebaTokenizer.from_pretrained('Lowin/chinese-bigbird-base-4096')
model = BB()
optimizer = optim.AdamW(params=model.parameters(), lr=1e-5, weight_decay=1e-2)
lossFunc = F.cross_entropy
rouge = load_metric('rouge')

valida = json.load(open("dataset/dev.json"))
trains = json.load(open("dataset/train.json"))
dataSetTrain = DS(trains, tokenizer, config)
dataSetValid = DS(valida, tokenizer, config)
tDL = DataLoader(dataSetTrain, batch_size=config.batch_size_train, shuffle=True)
vDL = DataLoader(dataSetValid, batch_size=config.batch_size_valid)

engine, tDL, vDL, _ = colossalai.initialize(
    model,
    optimizer,
    lossFunc,
    tDL,
    vDL
)

for epoch in range(gpc.config.NUM_EPOCH):
    tDL = tqdm(tDL, leave=False)
    engine.train()
    for batch in tDL:
        labels = batch.pop('labels').cuda()
        batch = {key: value.cuda() for key, value in batch.items()}
        logist = engine(batch)
        loss_sum = engine.criterion(logist.view(-1, config.vocab_size), labels.view(-1))
        title_length = labels.ne(0).sum().item()
        loss = loss_sum / title_length
        engine.backward(loss)
        engine.step()
        engine.zero_grad()
        tDL.set_description(f'Epoch:{epoch}:')
        tDL.set_postfix(loss=loss.item())
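Separately from the DDP error, the loss normalization in this loop may not do what the variable names suggest. F.cross_entropy defaults to reduction='mean', so loss_sum is already averaged over every position, padding included, and dividing it by title_length scales it down a second time. Below is a minimal sketch of a per-token average over non-padding labels, assuming label id 0 is padding (as labels.ne(0) implies). token_level_loss, manual_token_level_loss, PAD_ID and the toy tensor sizes are illustrative and not part of the original code.

```python
import torch
import torch.nn.functional as F

PAD_ID = 0  # assumption: label id 0 marks padding, as labels.ne(0) in the loop implies


def token_level_loss(logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
    """Mean cross-entropy over non-padding tokens.

    logits: (batch, tokens, vocab); labels: (batch, tokens).
    """
    vocab_size = logits.size(-1)
    # ignore_index removes padding positions from both the sum and the count,
    # so the default 'mean' reduction divides by the number of non-padding tokens.
    return F.cross_entropy(
        logits.reshape(-1, vocab_size), labels.reshape(-1), ignore_index=PAD_ID
    )


def manual_token_level_loss(logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
    """The same quantity written as an explicit sum divided by the non-padding count."""
    vocab_size = logits.size(-1)
    flat_labels = labels.reshape(-1)
    per_token = F.cross_entropy(
        logits.reshape(-1, vocab_size), flat_labels, reduction="none"
    )
    mask = flat_labels.ne(PAD_ID)
    return per_token[mask].sum() / mask.sum()


torch.manual_seed(0)
logits = torch.randn(2, 5, 11)  # toy sizes: batch 2, 5 tokens, vocab 11
labels = torch.tensor([[3, 7, 1, 0, 0], [4, 2, 9, 8, 0]])
print(token_level_loss(logits, labels), manual_token_level_loss(logits, labels))
```

In the loop above, passing functools.partial(F.cross_entropy, ignore_index=0) as the criterion to colossalai.initialize and dropping the division by title_length should have the same effect, provided engine.criterion simply forwards its arguments.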

Code of the model:

import torch
from transformers import BigBirdModel


class BB(torch.nn.Module):
    def __init__(self):
        super(BB, self).__init__()
        self.transformer = BigBirdModel.from_pretrained('Lowin/chinese-bigbird-base-4096')
        self.dropout = torch.nn.Dropout(0.2)
        self.output = torch.nn.Linear(768, 39999)

    def forward(self, batch):
        # batch = self._set_token_type_ids_(batch)
        outputs = self.transformer(**batch).last_hidden_state  # (batch_size, num_tokens, hidden_size)
        logits = self.output(self.dropout(outputs))  # (batch_size, num_tokens, vocab_size)
        return logits

Here is the error output:
/home/guxj/anaconda3/envs/NLP_colossalai/lib/python3.8/site-packages/transformers/models/big_bird/modeling_big_bird.py:981: UserWarning: floordiv is deprecated, and its behavior will change in a future version of pytorch. It currently rounds toward 0 (like the 'trunc' function NOT 'floor'). This results in incorrect rounding for negative values. To keep the current behavior, use torch.div(a, b, rounding_mode='trunc'), or for actual floor division, use torch.div(a, b, rounding_mode='floor').
  torch.arange(indices.shape[0] * indices.shape[1] * num_indices_to_gather, device=indices.device)
/home/guxj/anaconda3/envs/NLP_colossalai/lib/python3.8/site-packages/transformers/models/big_bird/modeling_big_bird.py:981: UserWarning: floordiv is deprecated, and its behavior will change in a future version of pytorch. It currently rounds toward 0 (like the 'trunc' function NOT 'floor'). This results in incorrect rounding for negative values. To keep the current behavior, use torch.div(a, b, rounding_mode='trunc'), or for actual floor division, use torch.div(a, b, rounding_mode='floor').
  torch.arange(indices.shape[0] * indices.shape[1] * num_indices_to_gather, device=indices.device)
Traceback (most recent call last):
  File "test3_v3.3.py", line 138, in <module>
    logist = engine(batch)
  File "/home/guxj/anaconda3/envs/NLP_colossalai/lib/python3.8/site-packages/colossalai/engine/_base_engine.py", line 183, in __call__
    return self.model(*args, **kwargs)
  File "/home/guxj/anaconda3/envs/NLP_colossalai/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1110, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/guxj/anaconda3/envs/NLP_colossalai/lib/python3.8/site-packages/torch/nn/parallel/distributed.py", line 947, in forward
Traceback (most recent call last):
  File "test3_v3.3.py", line 138, in <module>
    logist = engine(batch)
  File "/home/guxj/anaconda3/envs/NLP_colossalai/lib/python3.8/site-packages/colossalai/engine/_base_engine.py", line 183, in __call__
    return self.model(*args, **kwargs)
  File "/home/guxj/anaconda3/envs/NLP_colossalai/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1110, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/guxj/anaconda3/envs/NLP_colossalai/lib/python3.8/site-packages/torch/nn/parallel/distributed.py", line 947, in forward
    if torch.is_grad_enabled() and self.reducer._rebuild_buckets():
RuntimeError: Expected to have finished reduction in the prior iteration before starting a new one. This error indicates that your module has parameters that were not used in producing loss. You can enable unused parameter detection by passing the keyword argument find_unused_parameters=True to torch.nn.parallel.DistributedDataParallel, and by
making sure all forward function outputs participate in calculating loss.
If you already have done the above, then the distributed data parallel module wasn't able to locate the output tensors in the return value of your module's forward function. Please include the loss function and the structure of the return value of forward of your module when reporting this issue (e.g. list, dict, iterable).
Parameter indices which did not receive grad for rank 0: 197 198
In addition, you can set the environment variable TORCH_DISTRIBUTED_DEBUG to either INFO or DETAIL to print out information about which particular parameters did not receive gradient on this rank as part of this error
    if torch.is_grad_enabled() and self.reducer._rebuild_buckets():
RuntimeError: Expected to have finished reduction in the prior iteration before starting a new one. This error indicates that your module has parameters that were not used in producing loss. You can enable unused parameter detection by passing the keyword argument find_unused_parameters=True to torch.nn.parallel.DistributedDataParallel, and by
making sure all forward function outputs participate in calculating loss.
If you already have done the above, then the distributed data parallel module wasn't able to locate the output tensors in the return value of your module's forward function. Please include the loss function and the structure of the return value of forward of your module when reporting this issue (e.g. list, dict, iterable).
Parameter indices which did not receive grad for rank 1: 197 198
In addition, you can set the environment variable TORCH_DISTRIBUTED_DEBUG to either INFO or DETAIL to print out information about which particular parameters did not receive gradient on this rank as part of this error
ERROR:torch.distributed.elastic.multiprocessing.api:failed (exitcode: 1) local_rank: 0 (pid: 44596) of binary: /home/guxj/anaconda3/envs/NLP_colossalai/bin/python
Traceback (most recent call last):
  File "/home/guxj/anaconda3/envs/NLP_colossalai/lib/python3.8/runpy.py", line 194, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/home/guxj/anaconda3/envs/NLP_colossalai/lib/python3.8/runpy.py", line 87, in _run_code
    exec(code, run_globals)
  File "/home/guxj/anaconda3/envs/NLP_colossalai/lib/python3.8/site-packages/torch/distributed/launch.py", line 193, in <module>
    main()
  File "/home/guxj/anaconda3/envs/NLP_colossalai/lib/python3.8/site-packages/torch/distributed/launch.py", line 189, in main
    launch(args)
  File "/home/guxj/anaconda3/envs/NLP_colossalai/lib/python3.8/site-packages/torch/distributed/launch.py", line 174, in launch
    run(args)
  File "/home/guxj/anaconda3/envs/NLP_colossalai/lib/python3.8/site-packages/torch/distributed/run.py", line 715, in run
    elastic_launch(
  File "/home/guxj/anaconda3/envs/NLP_colossalai/lib/python3.8/site-packages/torch/distributed/launcher/api.py", line 131, in __call__
    return launch_agent(self._config, self._entrypoint, list(args))
  File "/home/guxj/anaconda3/envs/NLP_colossalai/lib/python3.8/site-packages/torch/distributed/launcher/api.py", line 245, in launch_agent
    raise ChildFailedError(
torch.distributed.elastic.multiprocessing.errors.ChildFailedError:
============================================================
test3_v3.3.py FAILED
------------------------------------------------------------
Failures:
[1]:
time : 2022-05-18_01:27:08
host : dlp01
rank : 1 (local_rank: 1)
exitcode : 1 (pid: 44597)
error_file: <N/A>
traceback : To enable traceback see: https://pytorch.org/docs/stable/elastic/errors.html
------------------------------------------------------------
Root Cause (first observed failure):
[0]:
time : 2022-05-18_01:27:08
host : dlp01
rank : 0 (local_rank: 0)
exitcode : 1 (pid: 44596)
error_file: <N/A>
traceback : To enable traceback see: https://pytorch.org/docs/stable/elastic/errors.html
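Both ranks report parameter indices 197 and 198 as never receiving a gradient. Besides setting TORCH_DISTRIBUTED_DEBUG=DETAIL as the message suggests, one way to turn such indices into names is a single-process forward/backward pass that lists the trainable parameters whose .grad is still None. The sketch below assumes DDP numbers parameters in model.parameters() order. ToyBackbone, ToyBB and unused_parameter_report are hypothetical stand-ins that mimic BB: the backbone computes a pooled output that forward discards. Applied to the real BB(), the same helper should name indices 197 and 198. If the checkpoint uses the standard 12-layer base layout, counting BigBirdModel's parameters (5 embedding tensors plus 16 per layer, 197 in total) points to the pooler's weight and bias, which fits the fact that forward only reads last_hidden_state; in that case BigBirdModel.from_pretrained('Lowin/chinese-bigbird-base-4096', add_pooling_layer=False) may sidestep the problem. That attribution is an inference, not something confirmed in this thread.

```python
import torch
import torch.nn as nn


def unused_parameter_report(model: nn.Module, loss: torch.Tensor) -> list:
    """Hypothetical helper: backpropagate `loss` once and return (index, name)
    for each trainable parameter that received no gradient. Indices follow
    model.parameters() order (assumed to match DDP's numbering)."""
    model.zero_grad(set_to_none=True)
    loss.backward()
    trainable = [(n, p) for n, p in model.named_parameters() if p.requires_grad]
    return [(i, n) for i, (n, p) in enumerate(trainable) if p.grad is None]


class ToyBackbone(nn.Module):
    """Hypothetical stand-in for a pretrained encoder that also owns a pooler."""

    def __init__(self, hidden: int = 8):
        super().__init__()
        self.encoder = nn.Linear(hidden, hidden)
        self.pooler = nn.Linear(hidden, hidden)  # feeds only the pooled output

    def forward(self, x):
        hidden_states = self.encoder(x)
        pooled = torch.tanh(self.pooler(hidden_states[:, 0]))
        return hidden_states, pooled


class ToyBB(nn.Module):
    """Hypothetical mirror of BB: uses per-token states, ignores the pooled output."""

    def __init__(self, hidden: int = 8, vocab: int = 5):
        super().__init__()
        self.transformer = ToyBackbone(hidden)
        self.output = nn.Linear(hidden, vocab)

    def forward(self, x):
        hidden_states, _pooled = self.transformer(x)
        return self.output(hidden_states)


torch.manual_seed(0)
model = ToyBB()
x = torch.randn(2, 4, 8)  # (batch, tokens, hidden)
labels = torch.randint(0, 5, (2, 4))
logits = model(x)
loss = nn.functional.cross_entropy(logits.reshape(-1, 5), labels.reshape(-1))
print(unused_parameter_report(model, loss))
# → [(2, 'transformer.pooler.weight'), (3, 'transformer.pooler.bias')]
```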

Environment

CUDA: 10.2
pytorch: 1.11.0
python: 3.8.13 (miniconda)

@480284856 480284856 added the bug Something isn't working label May 18, 2022
@FrankLeeeee
Contributor

Hi, thanks for your report. May I check whether you installed colossalai via our download page or built it from source?

@480284856
Author

480284856 commented May 18, 2022

I built it from source, following the documentation exactly:
git clone https://github.com/hpcaitech/ColossalAI.git
cd ColossalAI

# install dependency
pip install -r requirements/requirements.txt

# install colossalai
pip install .

@FrankLeeeee
Contributor

I see, there are two problems.

  1. Colossal-AI could not detect CUDA when installing. Can you export CUDA_HOME=<path/to/your/cuda> or install directly from our download page? Adding the -v flag to pip install (i.e. pip install -v .) will show logs explaining why the CUDA extension was not installed.
  2. There are unused parameters in your model.
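Before reinstalling, it can help to see what PyTorch itself detects. The hypothetical helper below is not how colossalai check works internally; it only reads the CUDA_HOME environment variable, torch.utils.cpp_extension.CUDA_HOME (the toolkit path PyTorch's extension builder resolved, or None if none was found) and torch.version.cuda (the CUDA version the installed PyTorch was built with, None for CPU-only builds).

```python
import os

import torch
from torch.utils.cpp_extension import CUDA_HOME


def cuda_build_report() -> dict:
    """Hypothetical helper: collect settings that affect building CUDA extensions."""
    return {
        # What the shell exports; None means the variable is unset.
        "CUDA_HOME (env)": os.environ.get("CUDA_HOME"),
        # Toolkit path PyTorch's extension builder resolved, or None.
        "CUDA_HOME (torch)": CUDA_HOME,
        # CUDA version the installed PyTorch was built with, e.g. "10.2".
        "torch built with CUDA": torch.version.cuda,
        # Whether a GPU is usable at runtime (separate from build-time detection).
        "GPU visible": torch.cuda.is_available(),
    }


for key, value in cuda_build_report().items():
    print(f"{key}: {value}")
```

If the torch-side CUDA_HOME is None, exporting CUDA_HOME as suggested above (ideally pointing at a toolkit that matches the "torch built with CUDA" version) and re-running pip install -v . should show whether the extension now builds.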

@480284856
Author

Thanks. I recreated the environment and the error remains, so I think the cause is point 2.
Could you tell me how to pass find_unused_parameters=True to my model with Colossal-AI?

@FrankLeeeee
Contributor

Unfortunately, this is not supported yet. However, you can use ZeRO in your configuration so that torch DDP is not used. As for the torch DDP parameters, we will fix this and allow users to customize them.

@480284856
Author

480284856 commented May 18, 2022

@FrankLeeeee I tried the ZeRO config; the content of my config.py is:

from colossalai.amp import AMP_TYPE
from colossalai.zero.shard_utils import TensorShardStrategy

NUM_EPOCH = 12
gradient_accumulation = 10
clip_grad_norm = 2.0
shard_strategy = 'zero'

zero = dict(
    model_config=dict(
        shard_strategy=TensorShardStrategy(),
        reduce_scatter_bucket_size_mb=25,
        fp32_reduce_scatter=False,
        tensor_placement_policy="cuda",
        gradient_predivide_factor=1.0,
        use_memory_tracer=False,
        reuse_fp16_shard=False
    ),
    optimizer_config=dict(
        gpu_margin_mem_ratio=0.8,
        initial_scale=2**5,
        min_scale=1,
        growth_factor=2,
        backoff_factor=0.5,
        growth_interval=1000,
        hysteresis=2,
        max_scale=2**32
    )
)

and then I initialized my model with:

model = BB()
with ZeroInitContext(target_device=torch.cuda.current_device(),
                     shard_strategy=gpc.config.zero.model_config.shard_strategy,
                     shard_param=True):
    model = nn.parallel.DistributedDataParallel(model,device_ids=torch.cuda.current_device(),find_unused_parameters=True)

Unfortunately, it threw another error:

Traceback (most recent call last):
  File "test3_v3.3.py", line 117, in <module>
    model = nn.parallel.DistributedDataParallel(model,device_ids=torch.cuda.current_device(),find_unused_parameters=True)
  File "/home/guxj/anaconda3/envs/NLP_colossalAI/lib/python3.8/site-packages/colossalai/utils/model/utils.py", line 52, in wrapper
    f(module, *args, **kwargs)
  File "/home/guxj/anaconda3/envs/NLP_colossalAI/lib/python3.8/site-packages/torch/nn/parallel/distributed.py", line 546, in __init__
    if device_ids is not None and len(device_ids) > 1:
TypeError: object of type 'int' has no len()
WARNING:torch.distributed.elastic.multiprocessing.api:Sending process 66168 closing signal SIGTERM
ERROR:torch.distributed.elastic.multiprocessing.api:failed (exitcode: 1) local_rank: 0 (pid: 66167) of binary: /home/guxj/anaconda3/envs/NLP_colossalAI/bin/python

May I ask for your help again? (^_^)

@FrankLeeeee
Contributor

This line is no longer needed when using zero.
model = nn.parallel.DistributedDataParallel(model,device_ids=torch.cuda.current_device(),find_unused_parameters=True)
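For reference, the TypeError above comes from device_ids: torch's DistributedDataParallel expects a list of devices (for example device_ids=[local_rank]) for a single-GPU module, or None for a CPU module, not a bare int. The sketch below shows plain PyTorch DDP with find_unused_parameters=True outside Colossal-AI, assuming a single process on CPU with the gloo backend and a file-based rendezvous; TinyModel and train_with_unused_params are hypothetical. Removing find_unused_parameters=True from it should reproduce the "Expected to have finished reduction" RuntimeError on the second step. Under Colossal-AI, as noted above, the model should not be wrapped by hand.

```python
import os
import tempfile

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP


class TinyModel(nn.Module):
    """Hypothetical model: `unused` has parameters but never feeds the output."""

    def __init__(self):
        super().__init__()
        self.used = nn.Linear(4, 3)
        self.unused = nn.Linear(4, 3)

    def forward(self, x):
        return self.used(x)


def train_with_unused_params(steps: int = 3) -> float:
    # Rendezvous through a fresh empty file; world_size=1 keeps this single-process.
    fd, path = tempfile.mkstemp()
    os.close(fd)
    dist.init_process_group("gloo", init_method=f"file://{path}", rank=0, world_size=1)
    try:
        # CPU module, so device_ids stays None. On GPU it would be a list, e.g.
        # DDP(model.cuda(local_rank), device_ids=[local_rank], find_unused_parameters=True)
        model = DDP(TinyModel(), find_unused_parameters=True)
        optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
        x = torch.randn(8, 4)
        target = torch.randint(0, 3, (8,))
        for _ in range(steps):
            optimizer.zero_grad()
            loss = nn.functional.cross_entropy(model(x), target)
            loss.backward()
            optimizer.step()
        return loss.item()
    finally:
        dist.destroy_process_group()
        if os.path.exists(path):
            os.remove(path)


print(train_with_unused_params())
```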

@480284856
Author

But if I create my model with:

with ZeroInitContext(target_device=torch.cuda.current_device(),
                     shard_strategy=gpc.config.zero.model_config.shard_strategy,
                     shard_param=True):
    model = BB()

it stops with the following error:

RuntimeError: Error(s) in loading state_dict for BigBirdModel:
        size mismatch for bert.embeddings.word_embeddings.weight: copying a param with shape torch.Size([40000, 768]) from checkpoint, the shape in current model is torch.Size([15360000]).
        size mismatch for bert.embeddings.position_embeddings.weight: copying a param with shape torch.Size([4096, 768]) from checkpoint, the shape in current model is torch.Size([1572864]).
        size mismatch for bert.embeddings.token_type_embeddings.weight: copying a param with shape torch.Size([2, 768]) from checkpoint, the shape in current model is torch.Size([768]).

Maybe I should give the machine two sets of model weights, but I only have one.
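The shapes in this error are consistent with each parameter having been flattened and split evenly across the two ranks by shard_param=True before from_pretrained copied the checkpoint in: 40000 × 768 / 2 = 15,360,000, 4096 × 768 / 2 = 1,572,864, and 2 × 768 / 2 = 768. A quick check of that arithmetic, assuming a world size of 2 (the launcher output earlier in this issue shows ranks 0 and 1); the dictionary names are illustrative:

```python
WORLD_SIZE = 2  # assumption: two ranks, as in the launcher output earlier in this issue

# Full parameter shapes stored in the checkpoint (from the error message).
checkpoint_shapes = {
    "word_embeddings": (40000, 768),
    "position_embeddings": (4096, 768),
    "token_type_embeddings": (2, 768),
}
# Flat sizes the ZeRO-initialized model reported for the same parameters.
reported_shard_sizes = {
    "word_embeddings": 15360000,
    "position_embeddings": 1572864,
    "token_type_embeddings": 768,
}

for name, (rows, cols) in checkpoint_shapes.items():
    shard = rows * cols // WORLD_SIZE  # one rank's share of the flattened tensor
    print(name, shard, shard == reported_shard_sizes[name])
```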

@FrankLeeeee
Contributor

FrankLeeeee commented May 18, 2022

This seems a bit tricky. I have asked our team to investigate this issue. For now, I will create a temporary branch that enables find_unused_parameters so that you can give it a try. I will notify you once it is ready.

@480284856
Author

OK, looking forward to using it!

@FrankLeeeee
Contributor

FrankLeeeee commented May 19, 2022

Hi @480284856, I was a bit busy yesterday, but I have just created a new branch. You can install Colossal-AI with the following commands.

git clone https://github.com/FrankLeeeee/ColossalAI.git
cd ColossalAI
git checkout hotfix/support-torch-ddp-config

pip install -r requirements/requirements.txt
pip install -v .

If the CUDA extension is not installed, the reason will be shown in the first few lines of the pip install -v . output.

You can configure torch DDP by adding the following to your config.py.

torch_ddp = dict(
    find_unused_parameters=True
)

Colossal-AI already sets process_group and device_ids by default.

@FrankLeeeee
Contributor

Meanwhile, you can also try the NAIVE mode of AMP, since it does not use torch DDP either.
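A config.py fragment for that suggestion might look like the following. It is a sketch assuming the fp16 key and AMP_TYPE.NAIVE mode from the Colossal-AI documentation of that period (the reporter's config above already imports AMP_TYPE); other settings such as NUM_EPOCH would stay as before.

```python
from colossalai.amp import AMP_TYPE

# Select Colossal-AI's NAIVE AMP mode for mixed precision (assumed config key: fp16)
fp16 = dict(
    mode=AMP_TYPE.NAIVE
)
```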
