Some feedback #6

Open · rkfg opened this issue Sep 23, 2023 · 12 comments

rkfg commented Sep 23, 2023

Disclaimer: I'm not an AI researcher so I could've done something wrong.

So I got it working after some minor changes and import fixes. Now, when I run it at 512x768 resolution, the issue is: RuntimeError: cuFFT only supports dimensions whose sizes are powers of two when computing in half precision, but got a signal size of[12, 8]. I suppose that's the corresponding feature-map size, which becomes rectangular at that base resolution; 512x512 gives an [8, 8] array here and everything works fine.
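
For reference, the error is easy to reproduce outside the webui with a minimal FFT call on a half-precision tensor of that shape (a sketch; it assumes a CUDA device and the 1280-channel skip feature from the patch below):

    import torch

    # 512x768 gives a 12x8 feature map at the deepest skip connection;
    # cuFFT's half-precision path only accepts power-of-two signal sizes.
    x = torch.randn(2, 1280, 12, 8, dtype=torch.float16, device="cuda")
    torch.fft.fftn(x, dim=(-2, -1))  # raises the RuntimeError quoted above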

As expected, running A1111 with --no-half makes it work but the speed is much worse.

I used the parameters for SD1.4 and simply hardcoded them to quickly test if it works at all. On a fine-tuned model (epiCRealism naturalSin), FreeU makes the images worse: they become more saturated and the skin texture turns to plastic (maybe because we suppress exactly the high-frequency features?). It starts looking more like the base models or the early fine-tuned models:

Original:
[image]

FreeU:
[image]

AnimateDiff doesn't seem to work in --no-half mode (it throws a CUDA error), so we're limited to 512x512. Same symptoms of oversaturation; the skin-quality comparison doesn't apply because of grain and artifacts. However, FreeU added a third hand. I tried two slightly different prompts.

Original:
[images 02086-751880090, 02085-751880090]

FreeU:
[images 02087-751880090, 02084-751880090]

The faces are garbled in all cases, but to be honest I much prefer the results without FreeU. The colors are better, the anatomy is better, and the skirt is much more detailed and moves more naturally.

My patch, applied to stable-diffusion-webui/repositories/stable-diffusion-stability-ai:

diff --git a/ldm/modules/diffusionmodules/openaimodel.py b/ldm/modules/diffusionmodules/openaimodel.py
index cc3875c..ede0b5a 100644
--- a/ldm/modules/diffusionmodules/openaimodel.py
+++ b/ldm/modules/diffusionmodules/openaimodel.py
@@ -4,6 +4,7 @@ import math
 import numpy as np
 import torch as th
 import torch.nn as nn
+import torch.fft as fft
 import torch.nn.functional as F
 
 from ldm.modules.diffusionmodules.util import (
@@ -418,6 +419,24 @@ class Timestep(nn.Module):
         return timestep_embedding(t, self.dim)
 
 
+def Fourier_filter(x, threshold, scale):
+    # FFT
+    x_freq = fft.fftn(x, dim=(-2, -1))
+    x_freq = fft.fftshift(x_freq, dim=(-2, -1))
+
+    B, C, H, W = x_freq.shape
+    mask = th.ones((B, C, H, W)).cuda()
+
+    crow, ccol = H // 2, W //2
+    mask[..., crow - threshold:crow + threshold, ccol - threshold:ccol + threshold] = scale
+    x_freq = x_freq * mask
+
+    # IFFT
+    x_freq = fft.ifftshift(x_freq, dim=(-2, -1))
+    x_filtered = fft.ifftn(x_freq, dim=(-2, -1)).real
+
+    return x_filtered
+
 class UNetModel(nn.Module):
     """
     The full UNet model with attention and timestep embedding.
@@ -798,8 +817,24 @@ class UNetModel(nn.Module):
             hs.append(h)
         h = self.middle_block(h, emb, context)
         for module in self.output_blocks:
-            h = th.cat([h, hs.pop()], dim=1)
-            h = module(h, emb, context)
+            if True:
+                hs_ = hs.pop()
+
+                # --------------- FreeU code -----------------------
+                # Only operate on the first two stages
+                if h.shape[1] == 1280:
+                    h[:,:640] = h[:,:640] * 1.2
+                    hs_ = Fourier_filter(hs_, threshold=1, scale=0.9)
+                if h.shape[1] == 640:
+                    h[:,:320] = h[:,:320] * 1.4
+                    hs_ = Fourier_filter(hs_, threshold=1, scale=0.2)
+                # ---------------------------------------------------------
+
+                h = th.cat([h, hs_], dim=1)
+                h = module(h, emb, context)
+            else:
+                h = th.cat([h, hs.pop()], dim=1)
+                h = module(h, emb, context)
         h = h.type(x.dtype)
         if self.predict_codebook_ids:
             return self.id_predictor(h)

The simplest way to switch between FreeU and vanilla is to change if True: to if False:. Again, it's just a hack to test whether it works.
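
A slightly less intrusive toggle (a sketch; FREEU is a hypothetical environment variable, not something the webui defines) would be to read a flag once instead of editing the file:

    import os

    # Hypothetical switch: launch the webui with FREEU=1 to enable the patch,
    # leave it unset to get vanilla behaviour.
    FREEU_ENABLED = os.environ.get("FREEU", "0") == "1"

    # ...then in UNetModel.forward, replace the hardcoded `if True:` with
    # `if FREEU_ENABLED:`.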

In conclusion, if everything is correct on my end, it's probably not worth it for the best fine-tuned models. Worse, to make it work in all cases you have to run in full 32-bit precision at a ≈3x slowdown, and the images still look worse than without it. The base models surely benefit from it, but honestly, who uses them except researchers and LoRA trainers?

I hope I made a mistake somewhere and these results are wrong. After all, I just copied the part that differs from the original code and fixed the errors to make it run, but who knows.

ChenyangSi (Owner) commented:

Thanks for your feedback. I tried AnimateDiff according to https://github.com/guoyww/AnimateDiff.

The following results are without cherry-picking:

Original:
[animation, prompt: "best quality, masterpiece, 1girl, looking at viewer, blurry background, upper ..."]

FreeU (different factors):
[two animations, same prompt]

rkfg (Author) commented Sep 23, 2023

Maybe it works better for cartoon-style images and animations because they naturally lack high frequencies... What were the factors that you used in the above? I can try with anime models.

ChenyangSi (Owner) commented:

This is just a simple attempt following the AnimateDiff readme. We will provide more results on the FreeU page and in the paper. We appreciate your continued interest.

justindujardin commented Sep 24, 2023

> Maybe it works better for cartoon-style images and animations because they naturally lack high frequencies... What were the factors that you used in the above? I can try with anime models.

@rkfg I also had poor results with FreeU, and then I started switching s1/s2/b1/b2 back to 1.0 at some point during the denoising process. The global features seem to be mostly settled early so that you can transition back to normal values between about 30% and 75% of the way through your steps.

And the results are greatly improved. I've tested it thoroughly, and most of the FreeU "fixes" are kept while still letting the fine details shine through at the end.

[image: freeu_partial_mario]

This is basically what I'm doing:

    steps = 30
    unet.freeu.sd21()  # apply the SD 2.1 FreeU factors (my own wrapper)

    def cb(step, _, __):
        # halfway through denoising, set s1/s2/b1/b2 back to 1.0 (vanilla UNet)
        if step == int(steps * 0.5):
            unet.freeu.ones()

    output = pipe(prompt, num_inference_steps=steps, callback=cb)

[image: freeu_partial_waterfall]
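
A rough equivalent of the same schedule using the public diffusers API (a sketch; it assumes a diffusers release that already ships enable_freeu/disable_freeu and the old per-step callback argument, and the factors are only example values):

    import torch
    from diffusers import StableDiffusionPipeline

    pipe = StableDiffusionPipeline.from_pretrained(
        "stabilityai/stable-diffusion-2-1", torch_dtype=torch.float16
    ).to("cuda")

    steps = 30
    pipe.enable_freeu(s1=0.9, s2=0.2, b1=1.1, b2=1.2)  # example factors only

    def cb(step, timestep, latents):
        # Past the halfway point the global structure is settled,
        # so switch back to the vanilla UNet for the fine details.
        if step == int(steps * 0.5):
            pipe.disable_freeu()

    image = pipe("a waterfall", num_inference_steps=steps, callback=cb).images[0]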

rkfg (Author) commented Sep 24, 2023

Mario is noticeably improved, yes, but I prefer the vanilla waterfall; it's more detailed and interesting even though the contrast is a bit lower. SD 2.1 isn't that good in general, even with finetunes. Can you try it on the best 1.5 models, both cartoon and realistic? It would be interesting to see whether this method can improve the output over what we can get without it.

adhikjoshi commented:

> (quoting justindujardin's comment above)

Can you share the full code snippet?

kadirnar commented:

@ChenyangSi, @rkfg
Hi, have you tested with torch.complex32? It loses all the features and produces a gray photo.

justindujardin commented:

> Can you share the full code snippet?

I describe my changes in the diffusers repo: huggingface/diffusers#5164 (comment)

dajes commented Sep 24, 2023

> RuntimeError: cuFFT only supports dimensions whose sizes are powers of two when computing in half precision, but got a signal size of[12, 8]

To fix this you can just cast x to float for this operation in the first line of the Fourier_filter function:

    x_freq = fft.fftn(x.float(), dim=(-2, -1))

and cast back in the last line:

    return x_filtered.to(x.dtype)
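
Putting that together, the whole Fourier_filter could look roughly like this (a sketch, reusing the th/fft imports from the patch above; it also builds the mask on the input's device rather than hardcoding .cuda()):

    def Fourier_filter(x, threshold, scale):
        dtype = x.dtype
        # Do the FFT in float32: cuFFT's half-precision path only supports
        # power-of-two signal sizes, which breaks at non-square resolutions.
        x_freq = fft.fftn(x.float(), dim=(-2, -1))
        x_freq = fft.fftshift(x_freq, dim=(-2, -1))

        B, C, H, W = x_freq.shape
        mask = th.ones((B, C, H, W), device=x.device)

        # Scale the low-frequency band around the spectrum centre.
        crow, ccol = H // 2, W // 2
        mask[..., crow - threshold:crow + threshold, ccol - threshold:ccol + threshold] = scale
        x_freq = x_freq * mask

        # Inverse FFT and cast back to the original dtype (e.g. float16).
        x_freq = fft.ifftshift(x_freq, dim=(-2, -1))
        x_filtered = fft.ifftn(x_freq, dim=(-2, -1)).real

        return x_filtered.to(dtype)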


YisuiTT commented Sep 25, 2023

@ChenyangSi Hi, can you share how to add the FreeU code to a T2V model like AnimateDiff?


ykk648 commented Oct 20, 2023

@YisuiTT you can refer to my code: ykk648/AnimateDiff-I2V@0842585
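
For reference, the core of that change is reshaping the video features so the 2D filter can be applied per frame; a minimal sketch (the (B, C, T, H, W) layout and the helper name are assumptions, not AnimateDiff's exact internals):

    def fourier_filter_video(x, threshold, scale):
        # AnimateDiff-style features are 5D: (batch, channels, frames, H, W).
        # Fold the frame axis into the batch axis, filter each frame with the
        # 2D Fourier_filter from the patch above, then restore the layout.
        B, C, T, H, W = x.shape
        x2d = x.permute(0, 2, 1, 3, 4).reshape(B * T, C, H, W)
        x2d = Fourier_filter(x2d, threshold=threshold, scale=scale)
        return x2d.reshape(B, T, C, H, W).permute(0, 2, 1, 3, 4)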


YisuiTT commented Nov 7, 2023

@ykk648 Thank you for your code.
