
Add support for torch.compile #1024

Merged
merged 2 commits into kohya-ss:dev on Jan 4, 2024

Conversation

p1atdev
Contributor

@p1atdev p1atdev commented Dec 26, 2023

Added:

  • New options, --torch_compile and --dynamo_backend
    • --torch_compile: Enables torch.compile. Default is False.
      • This option is currently incompatible with --xformers. Please use the --sdpa option instead.
    • --dynamo_backend: The backend used with torch.compile. Default is "inductor". "eager", "aot_eager", "inductor", "aot_ts_nvfuser", "nvprims_nvfuser", "cudagraphs", "ofi", "fx2trt", "onnxrt" are available, but most are not tested.
      • inductor and eager have been confirmed to work (see the sketch after this list).
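
For readers new to the feature, the two flags correspond conceptually to wrapping the model with torch.compile and picking a TorchDynamo backend. A minimal, illustrative sketch only (toy module and shapes; this is not the actual wiring inside the training scripts):

import torch
import torch.nn as nn

# Toy module standing in for the real network; purely illustrative.
model = nn.Sequential(nn.Linear(320, 320), nn.GELU(), nn.Linear(320, 320))

# Conceptually, --torch_compile wraps the model like this, and
# --dynamo_backend selects the backend passed to torch.compile.
model = torch.compile(model, backend="inductor")  # or "eager", "aot_eager", ...

x = torch.randn(4, 77, 320)
y = model(x)  # the first call triggers compilation; later calls reuse the compiled graph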

Changed:

  • Bumped the einops version from 0.6.0 to 0.6.1 for compatibility with torch.compile. (more information)

Related:

@p1atdev p1atdev changed the title from "Add support torch.compile" to "Add support for torch.compile" Dec 26, 2023
@FurkanGozukara

torch.compile is not available on Windows, right?

Also, what improvements/changes does it bring?

@p1atdev
Contributor Author

p1atdev commented Dec 26, 2023

Yes, torch.compile does not work on Windows, but it works on WSL.

In my small experiment, training with options --sdpa, --torch_compile and --dynamo_backend eager was faster than --xformers only. (RTX 3070Ti)

wandb: https://wandb.ai/p1atdev/sd-scripts-torch_compile/workspace?workspace=user-p1atdev

Training is very slow for the first few steps after torch.compile kicks in, but then it gets faster. Also, torch.compile is expected to perform well on modern NVIDIA GPUs (like the H100, A100, or V100).

https://pytorch.org/tutorials/intermediate/torch_compile_tutorial.html

Therefore, using torch.compile may be faster than just using xformers for large training runs.
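
To illustrate the warm-up effect in isolation, here is a hypothetical micro-benchmark (toy model and made-up sizes, CUDA assumed; actual numbers depend heavily on the GPU and backend):

import time
import torch
import torch.nn as nn

# Hypothetical micro-benchmark; the model and sizes are made up.
model = torch.compile(nn.Linear(1024, 1024).cuda(), backend="inductor")
x = torch.randn(64, 1024, device="cuda")

for step in range(5):
    start = time.perf_counter()
    model(x)
    torch.cuda.synchronize()
    # step 0 includes graph capture and kernel compilation, so it is much slower;
    # later steps reuse the cached compiled code.
    print(f"step {step}: {time.perf_counter() - start:.3f}s")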

@FurkanGozukara

> Yes, torch.compile does not work on Windows, but it works on WSL.
>
> In my small experiment, training with options --sdpa, --torch_compile and --dynamo_backend eager was faster than --xformers only. (RTX 3070Ti)
>
> wandb: https://wandb.ai/p1atdev/sd-scripts-torch_compile/workspace?workspace=user-p1atdev
>
> Training is very slow for the first few steps after torch.compile kicks in, but then it gets faster. Also, torch.compile is expected to perform well on modern NVIDIA GPUs (like the H100, A100, or V100).
>
> https://pytorch.org/tutorials/intermediate/torch_compile_tutorial.html
>
> Therefore, using torch.compile may be faster than just using xformers for large training runs.

Thanks. Can you tell me the it/s difference?

By the way, the best-looking example is xformers.

@p1atdev
Contributor Author

p1atdev commented Dec 27, 2023

The following are screenshots taken during longer trainings:

  • First run with --sdpa, --torch_compile and --dynamo_backend eager
    [screenshot]

  • First run with --xformers
    [screenshot]

  • Second run with --sdpa, --torch_compile and --dynamo_backend eager
    [screenshot]

wandb: https://wandb.ai/p1atdev/pvc-torch_compile

I'm not familiar with torch.compile so I don't know the exact details, but I think torch.compile requires a certain number of warm-up steps, so the second training run with torch.compile is faster than xformers.

@kohya-ss kohya-ss changed the base branch from main to dev January 4, 2024 01:49
@kohya-ss kohya-ss merged commit 07bf2a2 into kohya-ss:dev Jan 4, 2024
1 check passed
@kohya-ss
Owner

kohya-ss commented Jan 4, 2024

Sorry for the delay. Thank you so much for this great PR! I don't use Linux/WSL personally, but this is really nice!

@sdbds
Contributor

sdbds commented Jan 4, 2024

> Yes, torch.compile does not work on Windows, but it works on WSL.
>
> In my small experiment, training with options --sdpa, --torch_compile and --dynamo_backend eager was faster than --xformers only. (RTX 3070Ti)
>
> wandb: https://wandb.ai/p1atdev/sd-scripts-torch_compile/workspace?workspace=user-p1atdev
>
> Training is very slow for the first few steps after torch.compile kicks in, but then it gets faster. Also, torch.compile is expected to perform well on modern NVIDIA GPUs (like the H100, A100, or V100).
>
> https://pytorch.org/tutorials/intermediate/torch_compile_tutorial.html
>
> Therefore, using torch.compile may be faster than just using xformers for large training runs.

I noticed that this requires accelerate 0.0.25, not the 0.0.23 that sd-scripts currently pins.
We should upgrade the dependency version, otherwise a lot of options won't work.
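
For context, recent accelerate releases expose torch.compile through the dynamo_backend argument of Accelerator, which is presumably why the pinned version matters here. A minimal sketch, assuming that is roughly how the option is wired (the PR's actual integration may differ):

import torch
from accelerate import Accelerator

model = torch.nn.Linear(320, 320)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

# Newer accelerate versions accept a TorchDynamo backend name directly;
# an older pinned version may not expose this, hence the concern above.
accelerator = Accelerator(dynamo_backend="inductor")
model, optimizer = accelerator.prepare(model, optimizer)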

kohya-ss added a commit that referenced this pull request Jan 4, 2024
@kohya-ss
Owner

kohya-ss commented Jan 4, 2024

I updated accelerate to 0.0.25. I hope this makes this PR work.

@FurkanGozukara

@p1atdev

What does "second time of training" mean?

@dill-shower

> The following are screenshots taken during longer trainings:

Can you please test the speed with the inductor backend?

@kohya-ss
Owner

@p1atdev What version of PyTorch do you recommend, 2.1 or does 2.0 work fine? I would like to mention it in the documentation when updating.

@p1atdev
Contributor Author

p1atdev commented Jan 13, 2024

I tested with PyTorch version 2.1.2+cu118 and it worked. Also according to the PyTorch release notes, torch.compile is more stable in version 2.1 or later.

https://github.com/pytorch/pytorch/releases/tag/v2.1.0

@kohya-ss
Owner

Thank you for the clarification!

@feffy380
Contributor

feffy380 commented Jan 18, 2024

@p1atdev What training script did you test with? With train_network.py (SD1.x lora) it crashes in the sdpa forward function:

torch._dynamo.exc.TorchRuntimeError: Failed running call_function <function rearrange at 0x775e95df74c0>(*(FakeTensor(..., device='cuda:0', size=(4, s1, 320), dtype=torch.float16,
           grad_fn=<CloneBackward0>), 'b n (h d) -> b h n d'), **{'h': 8}):
unhashable type: non-singleton SymInt

from user code:
   File "/home/hope/src/sd/sd-scripts/library/original_unet.py", line 741, in resume_in_forward_sdpa_at_739
    q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=h), (q_in, k_in, v_in))
  File "/home/hope/src/sd/sd-scripts/library/original_unet.py", line 741, in <lambda>
    q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=h), (q_in, k_in, v_in))

(However, I am using pytorch-rocm 2.3 nightly and don't know if torch.compile is fully supported on AMD in the first place)
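
For what it's worth, the rearrange call in that traceback can also be expressed with native tensor ops, which keeps einops out of the compiled graph. A hedged workaround sketch, illustrative only (split_heads is a hypothetical helper, not part of sd-scripts):

import torch

def split_heads(t: torch.Tensor, h: int) -> torch.Tensor:
    # Same result as einops: rearrange(t, "b n (h d) -> b h n d", h=h)
    b, n, hd = t.shape
    return t.reshape(b, n, h, hd // h).permute(0, 2, 1, 3)

q_in = torch.randn(4, 77, 320)
q = split_heads(q_in, h=8)  # shape: (4, 8, 77, 40)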

Disty0 pushed a commit to Disty0/sd-scripts that referenced this pull request Jan 28, 2024
@jdack41

jdack41 commented Feb 9, 2024

I've got the same error on Mac, and also the same error on WSL with CUDA. The Mac's torch version is 2.2.0, and WSL has 2.1.2.

@jdack41

jdack41 commented Feb 10, 2024

@p1atdev @kohya-ss
It seems there is an issue when training with sdxl_train.py; sdxl_train_network.py could run.

@ultranationalism

> I've got the same error on Mac, and also the same error on WSL with CUDA. The Mac's torch version is 2.2.0, and WSL has 2.1.2.

Same error on torch 2.2.0+cu118.

@ultranationalism

> @p1atdev @kohya-ss It seems there is an issue when training with sdxl_train.py; sdxl_train_network.py could run.

Try upgrading einops to the latest version on torch 2.1.2+cu118.

@iamargentum

I updated my torch to 2.1.1 and tried training with the --torch_compile flag, but it keeps failing with this error: "LayerNormKernelImpl" not implemented for 'Half'.
Any idea what this is and how it could be fixed?

@jdack41

jdack41 commented Feb 13, 2024

> @p1atdev @kohya-ss It seems there is an issue when training with sdxl_train.py; sdxl_train_network.py could run.

> Try upgrading einops to the latest version on torch 2.1.2+cu118.

Thank you for the reply. This solved the unhashable error on both Mac and WSL (updated einops to 0.7.0).
But the saved weights cause a NansException in Automatic1111 (not LoRA; fine-tuned weights).

@jdack41

jdack41 commented Feb 15, 2024

https://discuss.pytorch.org/t/how-to-save-load-a-model-with-torch-compile/179739/2
According to this thread, torch.compile adds the prefix '_orig_mod.' to the keys of the model's state_dict().
So removing '_orig_mod.' when saving the weights solved the NansException, like this:

    def update_sd(prefix, sd):
        for k, v in sd.items():
            # strip the '_orig_mod.' prefix that torch.compile adds to parameter names
            key = prefix + k.replace('_orig_mod.', '')
            if save_dtype is not None:
                v = v.detach().clone().to("cpu").to(save_dtype)
            state_dict[key] = v
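
A related detail: torch.compile returns an OptimizedModule that keeps the original module under ._orig_mod, so the prefix can also be avoided by reading the state dict from the unwrapped module. A small standalone sketch (not the sd-scripts code; illustrative only):

import torch

model = torch.nn.Linear(320, 320)
compiled = torch.compile(model)

# Option 1: strip the "_orig_mod." prefix that torch.compile adds to the keys.
clean_sd = {k.replace("_orig_mod.", ""): v for k, v in compiled.state_dict().items()}

# Option 2: read the state dict from the original, unwrapped module instead.
clean_sd = compiled._orig_mod.state_dict()

torch.save(clean_sd, "model.pt")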

@jdack41

jdack41 commented Feb 15, 2024

Also, saving LoRA probably needs to do the same, or the saved file will be structurally broken.

@jdack41 jdack41 mentioned this pull request Feb 15, 2024