
IPEX 2.1.10+xpu regression with Stable Diffusion workloads that were working with IPEX 2.0.120+xpu on WSL2/Linux #483

Closed
simonlui opened this issue Dec 14, 2023 · 24 comments
Labels
ARC, Crash, LLM

Comments

@simonlui

Describe the bug

Running ComfyUI, a Stable Diffusion frontend, with my Docker image here, I am hitting issues with the simplest default workflow, which uses Stable Diffusion 1.5. The following error occurs even after disabling ipex.optimize().

2023-12-14 08:57:22,727 - root - ERROR - !!! Exception during processing !!!
2023-12-14 08:57:22,727 - root - ERROR - Traceback (most recent call last):
  File "/ComfyUI/execution.py", line 153, in recursive_execute
    output_data, output_ui = get_output_data(obj, input_data_all)
  File "/ComfyUI/execution.py", line 83, in get_output_data
    return_values = map_node_over_list(obj, input_data_all, obj.FUNCTION, allow_interrupt=True)
  File "/ComfyUI/execution.py", line 76, in map_node_over_list
    results.append(getattr(obj, func)(**slice_dict(input_data_all, i)))
  File "/ComfyUI/nodes.py", line 1299, in sample
    return common_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, denoise=denoise)
  File "/ComfyUI/nodes.py", line 1269, in common_ksampler
    samples = comfy.sample.sample(model, noise, steps, cfg, sampler_name, scheduler, positive, negative, latent_image,
  File "/ComfyUI/comfy/sample.py", line 100, in sample
    samples = sampler.sample(noise, positive_copy, negative_copy, cfg=cfg, latent_image=latent_image, start_step=start_step, last_step=last_step, force_full_denoise=force_full_denoise, denoise_mask=noise_mask, sigmas=sigmas, callback=callback, disable_pbar=disable_pbar, seed=seed)
  File "/ComfyUI/comfy/samplers.py", line 715, in sample
    return sample(self.model, noise, positive, negative, cfg, self.device, sampler, sigmas, self.model_options, latent_image=latent_image, denoise_mask=denoise_mask, callback=callback, disable_pbar=disable_pbar, seed=seed)
  File "/ComfyUI/comfy/samplers.py", line 621, in sample
    samples = sampler.sample(model_wrap, sigmas, extra_args, callback, noise, latent_image, denoise_mask, disable_pbar)
  File "/ComfyUI/comfy/samplers.py", line 560, in sample
    samples = self.sampler_function(model_k, noise, sigmas, extra_args=extra_args, callback=k_callback, disable=disable_pbar, **self.extra_options)
  File "/deps_latest/venv/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/ComfyUI/comfy/k_diffusion/sampling.py", line 137, in sample_euler
    denoised = model(x, sigma_hat * s_in, **extra_args)
  File "/deps_latest/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/deps_latest/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/ComfyUI/comfy/samplers.py", line 284, in forward
    out = self.inner_model(x, sigma, cond=cond, uncond=uncond, cond_scale=cond_scale, model_options=model_options, seed=seed)
  File "/deps_latest/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/deps_latest/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/ComfyUI/comfy/samplers.py", line 274, in forward
    return self.apply_model(*args, **kwargs)
  File "/ComfyUI/comfy/samplers.py", line 271, in apply_model
    out = sampling_function(self.inner_model, x, timestep, uncond, cond, cond_scale, model_options=model_options, seed=seed)
  File "/ComfyUI/comfy/samplers.py", line 252, in sampling_function
    cond_pred, uncond_pred = calc_cond_uncond_batch(model, cond, uncond_, x, timestep, model_options)
  File "/ComfyUI/comfy/samplers.py", line 226, in calc_cond_uncond_batch
    output = model.apply_model(input_x, timestep_, **c).chunk(batch_chunks)
  File "/ComfyUI/comfy/model_base.py", line 85, in apply_model
    model_output = self.diffusion_model(xc, t, context=context, control=control, transformer_options=transformer_options, **extra_conds).float()
  File "/deps_latest/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/deps_latest/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/ComfyUI/comfy/ldm/modules/diffusionmodules/openaimodel.py", line 854, in forward
    h = forward_timestep_embed(module, h, emb, context, transformer_options, time_context=time_context, num_video_frames=num_video_frames, image_only_indicator=image_only_indicator)
  File "/ComfyUI/comfy/ldm/modules/diffusionmodules/openaimodel.py", line 46, in forward_timestep_embed
    x = layer(x, context, transformer_options)
  File "/deps_latest/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/deps_latest/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/ComfyUI/comfy/ldm/modules/attention.py", line 606, in forward
    x = block(x, context=context[i], transformer_options=transformer_options)
  File "/deps_latest/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/deps_latest/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/ComfyUI/comfy/ldm/modules/attention.py", line 433, in forward
    return checkpoint(self._forward, (x, context, transformer_options), self.parameters(), self.checkpoint)
  File "/ComfyUI/comfy/ldm/modules/diffusionmodules/util.py", line 189, in checkpoint
    return func(*inputs)
  File "/ComfyUI/comfy/ldm/modules/attention.py", line 493, in _forward
    n = self.attn1(n, context=context_attn1, value=value_attn1)
  File "/deps_latest/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/deps_latest/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/ComfyUI/comfy/ldm/modules/attention.py", line 388, in forward
    return self.to_out(out)
  File "/deps_latest/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/deps_latest/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/deps_latest/venv/lib/python3.10/site-packages/torch/nn/modules/container.py", line 215, in forward
    input = module(input)
  File "/deps_latest/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/deps_latest/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/deps_latest/venv/lib/python3.10/site-packages/torch/nn/modules/linear.py", line 114, in forward
    return F.linear(input, self.weight, self.bias)
RuntimeError: mat1 and mat2 must have the same dtype, but got Float and Half

Prompt executed in 0.08 seconds

This and other more complex processing workloads had, for the most part, no issues with IPEX 2.0.120+xpu. I also know this isn't an issue with the Docker image or with Docker itself, because running the installation sanity check yields the following.

root@fedora:/ComfyUI# python -c "import torch; import intel_extension_for_pytorch as ipex; print(torch.__version__); print(ipex.__version__); [print(f'[{i}]: {torch.xpu.get_device_properties(i)}') for i in range(torch.xpu.device_count())];"
/usr/local/lib/python3.10/dist-packages/torchvision/io/image.py:13: UserWarning: Failed to load image Python extension: ''If you don't plan on using image functionality from `torchvision.io`, you can ignore this warning. Otherwise, there might be something wrong with your environment. Did you have `libjpeg` or `libpng` installed before building `torchvision` from source?
  warn(
2.1.0a0+cxx11.abi
2.1.10+xpu
[0]: _DeviceProperties(name='Intel(R) Arc(TM) A770 Graphics', platform_name='Intel(R) Level-Zero', dev_type='gpu, support_fp64=1, total_memory=16288MB, max_compute_units=512, gpu_eu_count=512)

There are similar but probably distinct problems with other Stable Diffusion frontends. To sanity-check whether this was just a ComfyUI issue, I also tested SD.Next and the dev branch of stable-diffusion-webui (outside of Docker) and hit similar issues with similarly simple workloads and setups on the latest IPEX. SD.Next just reports that loading IPEX fails and does nothing, while stable-diffusion-webui (where IPEX support is in beta) produces a backtrace related to Generators not existing, which is a separate issue from this one that I can file if someone requests it.

Versions

Ran inside an Ubuntu 22.04 Docker container built from the aforementioned image, on a Fedora 39 host

PyTorch version: 2.1.0a0+cxx11.abi
PyTorch CXX11 ABI: Yes
IPEX version: 2.1.10+xpu
IPEX commit: a12f9f6
Build type: Release

OS: Ubuntu 22.04.3 LTS (x86_64)
GCC version: N/A
Clang version: N/A
IGC version: N/A
CMake version: N/A
Libc version: glibc-2.35

Python version: 3.10.12 (main, Nov 20 2023, 15:14:05) [GCC 11.4.0] (64-bit runtime)
Python platform: Linux-6.6.6-200.fc39.x86_64-x86_64-with-glibc2.35
Is XPU available: True
DPCPP runtime version: N/A
MKL version: N/A
GPU models and configuration:
[0] _DeviceProperties(name='Intel(R) Arc(TM) A770 Graphics', platform_name='Intel(R) Level-Zero', dev_type='gpu, support_fp64=1, total_memory=16288MB, max_compute_units=512, gpu_eu_count=512)
Intel OpenCL ICD version: 23.17.26241.33-647~22.04
Level Zero version: 1.3.26241.33-647~22.04

CPU:
Architecture: x86_64
CPU op-mode(s): 32-bit, 64-bit
Address sizes: 48 bits physical, 48 bits virtual
Byte Order: Little Endian
CPU(s): 32
On-line CPU(s) list: 0-31
Vendor ID: AuthenticAMD
Model name: AMD Ryzen 9 5950X 16-Core Processor
CPU family: 25
Model: 33
Thread(s) per core: 2
Core(s) per socket: 16
Socket(s): 1
Stepping: 0
Frequency boost: enabled
CPU max MHz: 5084.0000
CPU min MHz: 550.0000
BogoMIPS: 6800.05
Flags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ht syscall nx mmxext fxsr_opt pdpe1gb rdtscp lm constant_tsc rep_good nopl nonstop_tsc cpuid extd_apicid aperfmperf rapl pni pclmulqdq monitor ssse3 fma cx16 sse4_1 sse4_2 x2apic movbe popcnt aes xsave avx f16c rdrand lahf_lm cmp_legacy svm extapic cr8_legacy abm sse4a misalignsse 3dnowprefetch osvw ibs skinit wdt tce topoext perfctr_core perfctr_nb bpext perfctr_llc mwaitx cpb cat_l3 cdp_l3 hw_pstate ssbd mba ibrs ibpb stibp vmmcall fsgsbase bmi1 avx2 smep bmi2 erms invpcid cqm rdt_a rdseed adx smap clflushopt clwb sha_ni xsaveopt xsavec xgetbv1 xsaves cqm_llc cqm_occup_llc cqm_mbm_total cqm_mbm_local user_shstk clzero irperf xsaveerptr rdpru wbnoinvd arat npt lbrv svm_lock nrip_save tsc_scale vmcb_clean flushbyasid decodeassists pausefilter pfthreshold avic v_vmsave_vmload vgif v_spec_ctrl umip pku ospke vaes vpclmulqdq rdpid overflow_recov succor smca fsrm debug_swap
Virtualization: AMD-V
L1d cache: 512 KiB (16 instances)
L1i cache: 512 KiB (16 instances)
L2 cache: 8 MiB (16 instances)
L3 cache: 64 MiB (2 instances)
NUMA node(s): 1
NUMA node0 CPU(s): 0-31
Vulnerability Gather data sampling: Not affected
Vulnerability Itlb multihit: Not affected
Vulnerability L1tf: Not affected
Vulnerability Mds: Not affected
Vulnerability Meltdown: Not affected
Vulnerability Mmio stale data: Not affected
Vulnerability Retbleed: Not affected
Vulnerability Spec rstack overflow: Vulnerable: Safe RET, no microcode
Vulnerability Spec store bypass: Mitigation; Speculative Store Bypass disabled via prctl
Vulnerability Spectre v1: Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2: Mitigation; Retpolines, IBPB conditional, IBRS_FW, STIBP always-on, RSB filling, PBRSB-eIBRS Not affected
Vulnerability Srbds: Not affected
Vulnerability Tsx async abort: Not affected

Versions of relevant libraries:
[pip3] intel-extension-for-pytorch==2.1.10+xpu
[pip3] numpy==1.26.2
[pip3] torch==2.1.0a0+cxx11.abi
[pip3] torchaudio==2.1.0a0+cxx11.abi
[pip3] torchvision==0.16.0a0+cxx11.abi
[conda] N/A

@alexsin368

@simonlui I just picked up your issue and will work on reproducing it first.

@simonlui
Author

Let me know if you need help with the Docker image setup, but it should be self-explanatory. The base Dockerfile gives you the currently working version using IPEX 2.0.120+xpu, while the Dockerfile.latest file gives you a version using IPEX 2.1.10+xpu, which doesn't work.

@alexsin368

alexsin368 commented Dec 21, 2023

I'm running into an issue trying to build the Docker image. I've set my proxy settings, but it's complaining about the public key not being available. See the attached log.

docker_build_error.txt

@simonlui
Author

simonlui commented Dec 21, 2023

Sorry for the late reply. The issue seems to be with the GPG key you retrieved, which comes from Intel itself; I believe the command that fetches and installs it didn't run correctly. Can you clean out your build cache and rebuild the image to see what you get? I rebuilt the image from scratch a little while ago here in the United States, where I live, and could fetch and use the GPG key without issue, so this doesn't look like an outage of the key server, which was a problem a month ago. I'm not sure whether your geographic region, HTTPS proxy settings, or corporate firewall/policy is getting in the way, but I do find it odd that you can't reach your own company's domain for something like this. I pull the key using a line in the Dockerfile that I derived from Intel's own oneAPI installation documentation for APT-based Linux systems (Step 3 here). Here is the line I believe you are running into an issue with.

RUN no_proxy=$no_proxy wget --progress=dot:giga -O- https://apt.repos.intel.com/intel-gpg-keys/GPG-PUB-KEY-INTEL-SW-PRODUCTS.PUB    | gpg --dearmor | tee /usr/share/keyrings/oneapi-archive-keyring.gpg > /dev/null &&    echo 'deb [signed-by=/usr/share/keyrings/oneapi-archive-keyring.gpg] https://apt.repos.intel.com/oneapi all main'    | tee /etc/apt/sources.list.d/oneAPI.list

Aside from a few added arguments, this should be verbatim the commands from that page chained together. Maybe try without the proxy? Just for reference, this is what I use when I build my image locally: docker build -t ipex-arc-comfy:latest -f Dockerfile.latest .


Unrelated to the above, there are a few other pieces of information I forgot to include in my initial bug report that I believe you will find helpful in solving this bug. I ran a container with the following arguments on my Linux system to generate the bug report's output.

docker run -it --device /dev/dri --network=host -v /home/simonlui/Code_Repositories/ComfyUI:/ComfyUI:Z -v /home/simonlui/Code_Repositories/models:/models:Z -v deps:/deps -v huggingface:/root/.cache/huggingface  --security-opt=label=disable --name comfy-server -e ComfyArgs="--listen --disable-ipex-optimize" localhost/ipex-arc-comfy:latest

Obviously, adapt the above to your system, but you should be able to replicate it using the same arguments I did.

If you want to experiment with the image and installation inside a container, you need to override the startup.sh entrypoint script I set. You can do this by adding --entrypoint bash before the localhost/ipex-arc-comfy:latest argument in the example run command I used and omitting -e ComfyArgs="--listen --disable-ipex-optimize". You can still run ComfyUI manually by running the following lines inside such a container:

source /deps/venv/bin/activate
python3 /ComfyUI/main.py --listen --disable-ipex-optimize

You should then be able to quit back to the container's bash shell with Ctrl-C and do whatever you need, instead of the normally intended usage of just starting and stopping the container.

@alexsin368

The only system with an Arc A770 GPU that I have available to work with at this time is based outside of the US. Because I'm in the US and need a VPN to access it, there are proxy settings to configure. I resolved that, but I may have hit a firewall issue. I'm currently working with my team to resolve this before I can proceed with building the Docker image. This is not an issue on US-based systems.

And thank you for the additional notes. Your sample docker run command will help me run mine with changes to the arguments.

@simonlui
Author

simonlui commented Jan 7, 2024

Any updates on this issue?

@KerfuffleV2

The fact that it is taking weeks to locate an Intel GPU at Intel is really not inspiring much confidence, and AMD is coming out with a cheap 16GB VRAM card this month. My A770 is still in the return window until the end of the month, and I am seriously considering taking advantage of it. I got this card to do SD, and right now it is a serious struggle. Is there a reason to believe it's going to get better?

@simonlui
Author

The fact that it is taking weeks to locate an Intel GPU at Intel is really not inspiring much confidence, and AMD is coming out with a cheap 16GB VRAM card this month. My A770 is still in the return window until the end of the month, and I am seriously considering taking advantage of it. I got this card to do SD, and right now it is a serious struggle. Is there a reason to believe it's going to get better?

You can decide whatever you want to do, but there are paths forward. This is only an issue with the latest code and ComfyUI. With the previous version of IPEX, 2.0.120+xpu, SD works correctly with ComfyUI and quite a few extensions too. Other frontends like stable-diffusion-webui and SD.Next, which patched in workarounds for IPEX, should work without issue on the latest version if you must have it.

@KerfuffleV2

You can decide whatever you want to do, but there are paths forward. This is only an issue with the latest code and ComfyUI. With the previous version of IPEX, 2.0.120+xpu, SD works correctly with ComfyUI and quite a few extensions too. Other frontends like stable-diffusion-webui and SD.Next, which patched in workarounds for IPEX, should work without issue on the latest version if you must have it.

Yes, I know it's possible to get it working currently, after a fashion. I am actually maintaining my own set of patches based on SD.Next's whole CUDA emulation layer: comfyanonymous/ComfyUI#476 (comment) - so it is possible to hack it into a working state (with no help from Intel, as far as I know), but that is obviously not an ideal scenario.

I think we have different definitions of "path forward". You are saying the path forward is to stay in the past and just run an old version. Also, the community hacked together something that works around the issues this time but what about the next?

I really don't think I am unique in wanting to know that my hardware is going to continue to see support, that problems will be fixed/responded to in a relatively timely way. Running old versions and losing access to performance/security/QoL improvements is not a solution.

@simonlui
Author

I think we have different definitions of "path forward". You are saying the path forward is to stay in the past and just run an old version. Also, the community hacked together something that works around the issues this time but what about the next?

I would not have filed this issue the very day IPEX 2.1.10+xpu released if I didn't have a vested interest in getting it fixed properly. However, while things are in the "fixing" phase, by "path forward" I meant still being able to run Stable Diffusion in this scenario, albeit with a less-than-ideal setup, workarounds, and hacks. I wasn't aware whether you were stuck at that point, so it's great you know how to get it working. But I do agree with your definition too: the status quo is not acceptable going forward and should be fixed before the next release of IPEX.

I really don't think I am unique in wanting to know that my hardware is going to continue to see support, that problems will be fixed/responded to in a relatively timely way. Running old versions and losing access to performance/security/QoL improvements is not a solution.

Given the way Intel's GPUs have launched so far, with Intel being a new player in the field, and knowing where they are prioritizing things, I can't say I am surprised at the speed at which this is being addressed, and I will say it is frustrating for me too. But that also means users are going to have to put up with hassle on various fronts like this; people who have reviewed these cards have said as much, and various other issues have popped up about these GPUs. Again, you can decide what you want to do based on that and express displeasure at it, but I don't believe that helps solve the issue. I actually do want to file another enhancement issue about the root of the problem: the fact that most Stable Diffusion frontends need to implement a workaround layer for operations that should be natively supported in IPEX but are not.

@alexsin368

Hi @simonlui, thanks for your patience. The team has been supporting issues that came in over the last month, and I just came back from a business trip. Now that I'm back, I will prioritize your issue and get back to you with an update soon.

The first step of any debug is to reproduce the issue, which includes gathering the exact hardware and software resources. We don't have access to your setup, which means there is additional overhead before the issue can be reproduced. If you have an easier way to reproduce it, that would be most helpful and would speed up the time to resolution.

@simonlui
Author

If you have an easier way to reproduce the issue, that would be most helpful and can speed up the time to resolve it.

The main issue seems to be your access to hardware, and I don't know whether you have resolved your GPG key access issues yet. Technically, you should be able to reproduce the issue on anything in the same Arc Alchemist GPU family, since the sample workflow in ComfyUI that reproduces this error is a simple SD 1.5 workflow that should run on just about anything with an Arc family graphics card and IPEX, so even an A310 would work here. If you want to be stricter, you could use a GPU with the same chip, which I think includes the A770 8GB, the A750, and the A580. The Flex 170 is the only other card Intel produces that gets reasonably close to the A770 16GB, since it uses the same chip and memory configuration but different firmware, so you could try that if you have access to it.

@intel-ravig
Contributor

@simonlui - I took over this ticket recently and have been able to reproduce this issue.
After debugging and tracing a bit, I suspect the issue originates from the "forward" function of the CrossAttention class.

q, k, and v all have dtype torch.float16, but the line "self.to_out=....." sometimes produces torch.float32, which ultimately throws the dtype mismatch. [dtype initialization???]

Just for kicks, I made an addition in /deps_latest/venv/lib/python3.10/site-packages/torch/nn/modules/linear.py:

def forward(self, input: Tensor) -> Tensor:
    print(self.weight.dtype, input.dtype)
    if (self.weight.dtype != input.dtype):
        input = input.to(self.weight.dtype)
    return F.linear(input, self.weight, self.bias)

After making the above change, the model does run, but I see numerous dtype mismatches further down the line.
Since torch.float32 is the default, I suspect the dtype is not flowing into the functions properly or is getting set to the default. Can you look over your "to_out" code and see what is going on?
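
For reference, the failure mode can be reproduced outside of ComfyUI with a minimal sketch like the one below (the layer size is made up for illustration, and it runs on any device):

import torch

# A half-precision Linear layer (like to_out) receiving a float32 activation
# triggers the same class of dtype-mismatch RuntimeError seen in the backtrace.
layer = torch.nn.Linear(320, 320).half()
x = torch.randn(2, 320, dtype=torch.float32)

try:
    layer(x)
except RuntimeError as err:
    print(err)  # dtype mismatch between the input (Float) and the weight (Half)

# The patch above works around it by casting the input to the weight dtype:
out = layer(x.to(layer.weight.dtype))
print(out.dtype)  # torch.float16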

In the meantime, I will talk to the IPEX engineering team and bring this to their attention.

@simonlui
Author

After debugging and tracing a bit, I suspect the issue originates from the "forward" function of the CrossAttention class.

q, k, and v all have dtype torch.float16, but the line "self.to_out=....." sometimes produces torch.float32, which ultimately throws the dtype mismatch. [dtype initialization???]
After making the above change, the model does run, but I see numerous dtype mismatches further down the line.

This makes sense, as the backtrace does complain about this, but the strange thing is that ComfyUI somehow works with the previous version of IPEX if you run my non-latest Docker image, which sets that up. This seems like either a restriction that was put in or a regression from a change in IPEX. This may belong in a different issue (please tell me if you think so), but currently other IPEX integrations carry workarounds for functions IPEX does not implement correctly, so that everything needed to run Stable Diffusion works. Those workarounds shouldn't need to exist if IPEX were doing everything correctly. ComfyUI's implementation worked with the last IPEX version without those workarounds because some things, like the text encoder, were not being run on the GPU; if you run ComfyUI with --gpu-only, it never worked past a certain point. The expectation was that this arrangement, although fragile, should have been able to run on the latest version of IPEX, but it no longer does.

Since torch.float32 is the default, I suspect the dtype is not flowing into the functions properly or is getting set to the default. Can you look over your "to_out" code and see what is going on?

It depends on the model used, but the default Stable Diffusion checkpoint that ComfyUI tells you to use and downloads for the default workflow is FP16, from RunwayML's release. Also, ComfyUI is not my application; it is an open-source frontend for Stable Diffusion at https://github.com/comfyanonymous/ComfyUI. The code you are referring to is at https://github.com/comfyanonymous/ComfyUI/blob/d76a04b6ea61306349861a7c4657567507385947/comfy/ldm/modules/attention.py#L382, so I hope that helps you track down this issue.

@intel-ravig
Contributor

intel-ravig commented Jan 24, 2024

Thanks @simonlui. I understand that the newer version does not work, but note that both PyTorch and IPEX were upgraded.

Some comments here:
a. Have you verified the same setup with a CUDA device + the latest PyTorch?

b. Yes, the function "torch.nn.functional.scaled_dot_product_attention" is in question. Please refer to the similar issue tagged here: pytorch/pytorch#110213. Perhaps it is related, pointing towards a PyTorch difference rather than IPEX?

c. I tested fp16 models of CompVis/stablediffusionv1.4 and https://huggingface.co/runwayml/stable-diffusion-v1-5 on the latest released PyTorch and IPEX, and they both worked fine.

@Disty0

Disty0 commented Jan 26, 2024

This is an autocast issue rather than an sdpa issue. sdpa is working as intended; autocast should have caught it before it hit sdpa.

There are more autocast issues I have worked around here:
https://github.com/vladmandic/automatic/blob/dev/modules/intel/ipex/hijacks.py#L105
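
For illustration, a hypothetical hijack in the same spirit (a sketch, not the actual code behind that link) casts the attention output back to the query dtype before it reaches the following Linear layers:

import torch
import torch.nn.functional as F

# Monkey-patch sketch: force scaled_dot_product_attention to return the query
# dtype, since the XPU backend was observed to return float32 for float16 inputs.
_original_sdpa = F.scaled_dot_product_attention

def _sdpa_cast_back(query, key, value, *args, **kwargs):
    out = _original_sdpa(query, key, value, *args, **kwargs)
    return out.to(query.dtype) if out.dtype != query.dtype else out

torch.nn.functional.scaled_dot_product_attention = _sdpa_cast_back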

@jingxu10 added the ARC, Crash, and LLM labels Jan 26, 2024
@simonlui
Author

Sorry, I didn't have time to check and reply to this until now.

a.) The Docker arrangement works with an Nvidia GPU on another machine I had access to, though that was using WSL2 through Windows.
b.) Probably not; ComfyUI has had PyTorch 2.1 support from comfyanonymous/ComfyUI@48242be onwards, with the usage of xformers removed.
c.) I don't doubt that it may work if you use the models directly or via something like Diffusers. However, this is a Stable Diffusion frontend issue with ComfyUI, which doesn't use Diffusers to run these models, and it did work with the older versions of IPEX and PyTorch. I agree with @Disty0 that this may not be the fault of SDPA and could be something else that didn't work correctly, like autocasting of types.

@Disty0

Disty0 commented Feb 11, 2024

scaled_dot_product_attention returns float32 when float16 inputs are used with IPEX; bfloat16 inputs return bfloat16 as expected.
This might be related to the ComfyUI issue. Autocast should catch this, though.

Diffusers also catches it with this line:
https://github.com/huggingface/diffusers/blob/v0.26.2/src/diffusers/models/attention_processor.py#L1249
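
A quick way to check the behaviour described above, assuming an XPU device and this IPEX build are available (a standalone sketch, not part of any frontend):

import torch
import intel_extension_for_pytorch  # noqa: F401  (needed to register the xpu device)

# Compare sdpa's output dtype against the input dtype for fp16 and bf16.
for dtype in (torch.float16, torch.bfloat16):
    q = torch.randn(1, 8, 77, 64, device="xpu", dtype=dtype)
    k = torch.randn_like(q)
    v = torch.randn_like(q)
    out = torch.nn.functional.scaled_dot_product_attention(q, k, v)
    print(dtype, "->", out.dtype)  # float16 was reported to come back as float32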

@simonlui
Author

Have there been any updates on this issue? Based on the prior release cycles for IPEX on XPU, it seems a new version will be coming soon, so I would like to know whether this has been addressed or whether we will need to keep working around it. Thanks.

@intel-ravig
Contributor

@simonlui - This task is on our engineering team's priority list and they are actively working on it. However, it will likely not be fixed in the upcoming release, so the workarounds on the application side will be needed for some more time.

@intel-ravig
Contributor

intel-ravig commented Mar 22, 2024

@simonlui - I received word from the engineering team that a public commit addressing the issue has been merged.
7ea2a3c

I tested it on our internal branch and the issue seems to go away.

Can you please build from source on your side and verify functional correctness?

I also tried batch sizes greater than 1, up to 6, and they all run fine.

@simonlui
Author

simonlui commented Mar 23, 2024

@intel-ravig I can confirm it does indeed work after I compiled xpu-main via this repository's Docker.compile image process, although I had difficulty proceeding past a certain point because of DeepSpeed issues that I think your development team has yet to resolve. But since the build still produced the main packages, and I don't think DeepSpeed matters much in this scenario since it targets a separate acceleration unit from the GPUs themselves, I force-installed them with pip, tried out ComfyUI, and it seems to work fine without issue.

(screenshot attached)

I even tried offloading everything to the GPU itself; I have had issues in the past with the CLIP encoder crashing too, and that was also solved by this SDP fix. Thank you, and thanks to the engineering team for getting this in. This will help us all lower the maintenance burden of hacking fixes to work around this issue.

Edit: I should note this is with the base SDXL models loaded with torch.xpu.optimize too, so that also works well here.
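
For anyone following along, applying torch.xpu.optimize for fp16 inference looks roughly like the sketch below; the toy model and shapes are placeholders, not ComfyUI's actual loading path, and the exact arguments may differ per IPEX version:

import torch
import intel_extension_for_pytorch as ipex  # noqa: F401  (provides torch.xpu.optimize)

# Toy model placed on the XPU device and optimized for fp16 inference.
model = torch.nn.Sequential(torch.nn.Linear(320, 320), torch.nn.GELU()).eval().to("xpu")
model = torch.xpu.optimize(model, dtype=torch.float16)

with torch.no_grad(), torch.xpu.amp.autocast(enabled=True, dtype=torch.float16):
    out = model(torch.randn(4, 320, device="xpu"))
print(out.dtype)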

@intel-ravig
Contributor

intel-ravig commented Mar 25, 2024

@simonlui - Okay, great! Thanks for confirming. If you are satisfied with the fix, please close this issue.
If you would like us to take a look at the 'deepspeed' issue, please open a new issue/ticket with a reproducer attached. Thanks

@simonlui
Author

Sorry for the delay; I am closing the issue as it is fixed. As for the DeepSpeed issue, I am pretty sure you will rectify it at some point, since the Docker image does seem to be updated often.
