
train_text_to_image.py | RuntimeError: "triu_tril_cuda_template" not implemented for 'BFloat16' #3453

Closed · morgankohler opened this issue May 16, 2023 · 7 comments · Fixed by huggingface/transformers#23942
Labels: bug (Something isn't working)

@morgankohler

Describe the bug

When running train_text_to_image.py, setting --mixed_precision="bf16" causes an error in the transformers CLIP text encoder. I am opening this here because I am not sure how to reproduce it from the transformers repo alone.

Reproduction

#!/bin/bash
export MODEL_NAME="CompVis/stable-diffusion-v1-4"
export dataset_name="lambdalabs/pokemon-blip-captions"

accelerate launch --mixed_precision="bf16" train_text_to_image.py \
  --pretrained_model_name_or_path=$MODEL_NAME \
  --dataset_name=$dataset_name \
  --use_ema \
  --resolution=512 --center_crop --random_flip \
  --train_batch_size=1 \
  --gradient_accumulation_steps=4 \
  --gradient_checkpointing \
  --max_train_steps=15000 \
  --learning_rate=1e-05 \
  --max_grad_norm=1 \
  --lr_scheduler="constant" --lr_warmup_steps=0 \
  --output_dir="sd-pokemon-model"

Logs

The following values were not passed to `accelerate launch` and had defaults used instead:
        `--num_processes` was set to a value of `2`
                More than one GPU was found, enabling multi-GPU training.
                If this was unintended please pass in `--num_processes=1`.
        `--num_machines` was set to a value of `1`
        `--dynamo_backend` was set to a value of `'no'`
To avoid this warning pass in values for each of the problematic parameters or run `accelerate config`.
/home/user/anaconda3/envs/pyenv/lib/python3.9/site-packages/accelerate/accelerator.py:260: FutureWarning: `logging_dir` is deprecated and will be removed in version 0.18.0 of 🤗 Accelerate. Use `project_dir` instead.
  warnings.warn(
/home/user/anaconda3/envs/pyenv/lib/python3.9/site-packages/accelerate/accelerator.py:260: FutureWarning: `logging_dir` is deprecated and will be removed in version 0.18.0 of 🤗 Accelerate. Use `project_dir` instead.
  warnings.warn(
05/16/2023 12:25:17 - INFO - __main__ - Distributed environment: MULTI_GPU  Backend: nccl
Num processes: 2
Process index: 1
Local process index: 1
Device: cuda:1

Mixed precision type: bf16

05/16/2023 12:25:17 - INFO - __main__ - Distributed environment: MULTI_GPU  Backend: nccl
Num processes: 2
Process index: 0
Local process index: 0
Device: cuda:0

Mixed precision type: bf16

{'thresholding', 'variance_type', 'clip_sample_range', 'sample_max_value', 'prediction_type', 'dynamic_thresholding_ratio'} was not found in config. Values will be initialized to default values.
{'norm_num_groups'} was not found in config. Values will be initialized to default values.
{'addition_embed_type', 'time_embedding_act_fn', 'resnet_time_scale_shift', 'mid_block_type', 'class_embeddings_concat', 'conv_out_kernel', 'cross_attention_norm', 'class_embed_type', 'timestep_post_act', 'dual_cross_attention', 'resnet_out_scale_factor', 'conv_in_kernel', 'only_cross_attention', 'encoder_hid_dim', 'time_embedding_type', 'mid_block_only_cross_attention', 'upcast_attention', 'use_linear_projection', 'resnet_skip_time_act', 'time_cond_proj_dim', 'projection_class_embeddings_input_dim', 'addition_embed_type_num_heads', 'num_class_embeds', 'time_embedding_dim'} was not found in config. Values will be initialized to default values.
{'addition_embed_type', 'time_embedding_act_fn', 'resnet_time_scale_shift', 'mid_block_type', 'class_embeddings_concat', 'conv_out_kernel', 'cross_attention_norm', 'class_embed_type', 'timestep_post_act', 'dual_cross_attention', 'resnet_out_scale_factor', 'conv_in_kernel', 'only_cross_attention', 'encoder_hid_dim', 'time_embedding_type', 'mid_block_only_cross_attention', 'upcast_attention', 'use_linear_projection', 'resnet_skip_time_act', 'time_cond_proj_dim', 'projection_class_embeddings_input_dim', 'addition_embed_type_num_heads', 'num_class_embeds', 'time_embedding_dim'} was not found in config. Values will be initialized to default values.
05/16/2023 12:25:26 - WARNING - datasets.builder - Found cached dataset parquet (/home/user/.cache/huggingface/datasets/lambdalabs___parquet/lambdalabs--pokemon-blip-captions-10e3527a764857bd/0.0.0/2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec)
100%|██████████████████████████████████████| 1/1 [00:00<00:00, 598.59it/s]
100%|██████████████████████████████████████| 1/1 [00:00<00:00, 626.67it/s]
05/16/2023 12:25:32 - INFO - __main__ - ***** Running training *****
05/16/2023 12:25:32 - INFO - __main__ -   Num examples = 833
05/16/2023 12:25:32 - INFO - __main__ -   Num Epochs = 143
05/16/2023 12:25:32 - INFO - __main__ -   Instantaneous batch size per device = 1
05/16/2023 12:25:32 - INFO - __main__ -   Total train batch size (w. parallel, distributed & accumulation) = 8
05/16/2023 12:25:32 - INFO - __main__ -   Gradient Accumulation steps = 4
05/16/2023 12:25:32 - INFO - __main__ -   Total optimization steps = 15000
Steps:   0%|                                    | 0/15000 [00:00<?, ?it/s]Traceback (most recent call last):
  File "/home/user/test_diffuser_train_script/train.py", line 959, in <module>
    main()
  File "/home/user/test_diffuser_train_script/train.py", line 848, in main
    encoder_hidden_states = text_encoder(batch["input_ids"])[0]
  File "/home/user/anaconda3/envs/pyenv/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1194, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/user/anaconda3/envs/pyenv/lib/python3.9/site-packages/transformers/models/clip/modeling_clip.py", line 816, in forward
    return self.text_model(
  File "/home/user/anaconda3/envs/pyenv/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1194, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/user/anaconda3/envs/pyenv/lib/python3.9/site-packages/transformers/models/clip/modeling_clip.py", line 717, in forward
    causal_attention_mask = self._build_causal_attention_mask(
  File "/home/user/anaconda3/envs/pyenv/lib/python3.9/site-packages/transformers/models/clip/modeling_clip.py", line 760, in _build_causal_attention_mask
    mask.triu_(1)  # zero out the lower diagonal
RuntimeError: "triu_tril_cuda_template" not implemented for 'BFloat16'
Traceback (most recent call last):
  File "/home/user/test_diffuser_train_script/train.py", line 959, in <module>
    main()
  File "/home/user/test_diffuser_train_script/train.py", line 848, in main
    encoder_hidden_states = text_encoder(batch["input_ids"])[0]
  File "/home/user/anaconda3/envs/pyenv/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1194, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/user/anaconda3/envs/pyenv/lib/python3.9/site-packages/transformers/models/clip/modeling_clip.py", line 816, in forward
    return self.text_model(
  File "/home/user/anaconda3/envs/pyenv/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1194, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/user/anaconda3/envs/pyenv/lib/python3.9/site-packages/transformers/models/clip/modeling_clip.py", line 717, in forward
    causal_attention_mask = self._build_causal_attention_mask(
  File "/home/user/anaconda3/envs/pyenv/lib/python3.9/site-packages/transformers/models/clip/modeling_clip.py", line 760, in _build_causal_attention_mask
    mask.triu_(1)  # zero out the lower diagonal
RuntimeError: "triu_tril_cuda_template" not implemented for 'BFloat16'
Steps:   0%|                                    | 0/15000 [00:01<?, ?it/s]
ERROR:torch.distributed.elastic.multiprocessing.api:failed (exitcode: 1) local_rank: 0 (pid: 2585075) of binary: /home/user/anaconda3/envs/pyenv/bin/python
Traceback (most recent call last):
  File "/home/user/anaconda3/envs/pyenv/bin/accelerate", line 8, in <module>
    sys.exit(main())
  File "/home/user/anaconda3/envs/pyenv/lib/python3.9/site-packages/accelerate/commands/accelerate_cli.py", line 45, in main
    args.func(args)
  File "/home/user/anaconda3/envs/pyenv/lib/python3.9/site-packages/accelerate/commands/launch.py", line 919, in launch_command
    multi_gpu_launcher(args)
  File "/home/user/anaconda3/envs/pyenv/lib/python3.9/site-packages/accelerate/commands/launch.py", line 612, in multi_gpu_launcher
    distrib_run.run(args)
  File "/home/user/anaconda3/envs/pyenv/lib/python3.9/site-packages/torch/distributed/run.py", line 753, in run
    elastic_launch(
  File "/home/user/anaconda3/envs/pyenv/lib/python3.9/site-packages/torch/distributed/launcher/api.py", line 132, in __call__
    return launch_agent(self._config, self._entrypoint, list(args))
  File "/home/user/anaconda3/envs/pyenv/lib/python3.9/site-packages/torch/distributed/launcher/api.py", line 246, in launch_agent
    raise ChildFailedError(
torch.distributed.elastic.multiprocessing.errors.ChildFailedError: 
============================================================
train.py FAILED
------------------------------------------------------------
Failures:
[1]:
  time      : 2023-05-16_12:25:39
  host      : gpu-server-3
  rank      : 1 (local_rank: 1)
  exitcode  : 1 (pid: 2585076)
  error_file: <N/A>
  traceback : To enable traceback see: https://pytorch.org/docs/stable/elastic/errors.html
------------------------------------------------------------
Root Cause (first observed failure):
[0]:
  time      : 2023-05-16_12:25:39
  host      : gpu-server-3
  rank      : 0 (local_rank: 0)
  exitcode  : 1 (pid: 2585075)
  error_file: <N/A>
  traceback : To enable traceback see: https://pytorch.org/docs/stable/elastic/errors.html
============================================================

System Info

diffusers version: 0.17.0.dev0
Platform: Linux-5.15.0-69-generic-x86_64-with-glibc2.31
Python version: 3.9.16
PyTorch version (GPU?): 1.13.1 (True)
Huggingface_hub version: 0.14.1
Transformers version: 4.30.0.dev0
Accelerate version: 0.20.0.dev0
xFormers version: 0.0.19
Using GPU in script?: Yes (2 A100)
Using distributed or parallel set-up in script?: Yes

@morgankohler added the bug label on May 16, 2023
@yasyf (Contributor) commented May 17, 2023

Looks like this is a transformers issue and can be fixed like this. cc @patrickvonplaten
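
A sketch of the general shape of such a workaround (not the actual upstream patch): build the causal mask in float32, where in-place triu_ does have a CUDA kernel, and only cast to the target dtype at the end. The function name and signature here are assumptions for illustration, modeled on CLIP's mask construction in transformers.

import torch

def build_causal_attention_mask(bsz, seq_len, dtype, device=None):
    # Construct the mask in float32, where triu_ has a CUDA kernel.
    mask = torch.empty(bsz, seq_len, seq_len, dtype=torch.float32, device=device)
    mask.fill_(torch.finfo(torch.float32).min)
    mask.triu_(1)  # zero out the lower diagonal
    mask = mask.unsqueeze(1)  # add a head dimension: [bsz, 1, seq_len, seq_len]
    # Only cast to the requested dtype (e.g. torch.bfloat16) at the end.
    return mask.to(dtype)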

@patrickvonplaten (Contributor)

Actually, I think we could redirect this issue directly to PyTorch. I don't think we should solve this in every model in transformers. Gently pinging @Chillee here.

@d8ahazard (Contributor)

> Actually, I think we could redirect this issue directly to PyTorch. I don't think we should solve this in every model in transformers. Gently pinging @Chillee here.

Was this issue introduced in torch 2.0.1? It just "appeared" in my Dreambooth app, and I'd definitely like to work around it until it's solved upstream.

@Chillee (Contributor) commented May 25, 2023

Is this a new problem? It looks like triu just never had a bfloat16 kernel added for CUDA.

That does seem to have been fixed just last week, though: pytorch/pytorch#101414
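
The failing op can be reproduced in isolation. A minimal sketch, assuming a CUDA machine and a PyTorch build from before pytorch/pytorch#101414 landed:

import torch

# On builds without the bf16 CUDA kernel (e.g. torch 1.13 / 2.0) this raises:
#   RuntimeError: "triu_tril_cuda_template" not implemented for 'BFloat16'
# The same call succeeds on CPU, and on builds that include the fix.
mask = torch.empty(4, 4, dtype=torch.bfloat16, device="cuda")
mask.fill_(torch.finfo(torch.bfloat16).min)
mask.triu_(1)
print(mask)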

@huseyintemiz

I have faced the same problem when running inference with StableDiffusionPipeline (v2.1) in bfloat16 (torch 2.0). The same inference code runs successfully on a torch 1.13 setup.
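
For reference, a minimal sketch of the kind of inference call that hits the same CLIP code path; the model ID and prompt are illustrative, not taken from the report above:

import torch
from diffusers import StableDiffusionPipeline

# Loading the pipeline in bfloat16 also puts the CLIP text encoder in
# bf16, so encoding the prompt routes through the failing triu_ call.
pipe = StableDiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-2-1", torch_dtype=torch.bfloat16
)
pipe = pipe.to("cuda")
image = pipe("a photo of an astronaut riding a horse").images[0]
image.save("out.png")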

@sayakpaul (Member)

Maybe you need to install PyTorch from source to ensure the support is reflected in your installation?
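
A quick way to check whether a given installation already includes the kernel (a sketch; the try/except just probes for the error above):

import torch

print(torch.__version__)
try:
    # Probe the exact op that fails in CLIP's causal mask construction.
    torch.empty(2, 2, dtype=torch.bfloat16, device="cuda").triu_(1)
    print("bf16 triu_ CUDA kernel is available")
except RuntimeError as err:
    print(f"kernel missing: {err}")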

@kashif (Contributor) commented Jun 1, 2023

huggingface/transformers#23942 should fix this issue
