How to set random seeds fixed #618

Open
scp92 opened this issue Jan 3, 2023 · 17 comments
Labels
good first issue Good for newcomers

Comments

@scp92

scp92 commented Jan 3, 2023

❓ Questions and Help

I get different results when I run the same code twice, even though the set_seed function below runs before everything else.

import os
import random

import numpy as np
import torch

def set_seed(seed):
    random.seed(seed)  # Python random module.
    np.random.seed(seed)  # NumPy module.
    os.environ['PYTHONHASHSEED'] = str(seed)
    torch.manual_seed(seed)  # Seeds the CPU RNG (covers torch.random as well).
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)  # If you are using multi-GPU.
    torch.backends.cudnn.benchmark = False  # Disable autotuning so kernel selection is stable.
    torch.backends.cudnn.deterministic = True  # Ask cuDNN for deterministic kernels.
@danthe3rd
Contributor

Hi,
What part of xFormers is causing this issue exactly? Can you post a minimal example that reproduces the non-deterministic behavior?

@SmiVan

SmiVan commented Jan 9, 2023

I have the same issue: enabling xformers in my project causes results to be slightly random even with the seed and other settings kept constant.

Observed on Stable Diffusion 1.5 and 2.1 using the huggingface/diffusers Stable Diffusion pipeline with the KDPM2 Ancestral Discrete scheduler.

The differences are very subtle; see the example below:
[image: frogs-on-stilts]

Note that the two xformers runs were performed sequentially here.

On some other prompts the differences may be more drastic (I'm not exactly sure why), but the general layout of the image remains the same.

@danthe3rd
Contributor

Hello, thanks for the report.
I will have a look, but we don't guarantee deterministic behavior in general when it comes to numerics. It depends on the order in which the operations are done/scheduled, and because (a+b)+c is generally different from a+(b+c) in floating-point arithmetic, guaranteeing pure determinism would add some complexity.
The difference should also be very small and should not affect overall accuracy. Is there a reason why you need a truly deterministic algorithm?
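
To illustrate the point about (a+b)+c above, here is a tiny, xformers-independent sketch of that non-associativity in plain Python (the values are made up purely to make the effect visible):

a, b, c = 1e20, -1e20, 1.0

print((a + b) + c)  # 1.0: the large terms cancel first, then 1.0 is added.
print(a + (b + c))  # 0.0: 1.0 is absorbed into -1e20 before the cancellation happens.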

@danthe3rd
Contributor

Also might be a duplicate of #624 which contains a minimal repro

@SmiVan

SmiVan commented Jan 9, 2023

It's not at all necessary for me to have a truly deterministic algorithm - I just found it a bit jarring when I noticed it.
That being said, I think it should be noted somewhere more clearly that there is a margin of variation involved when using this, just in case perfect reproducibility matters for someone.

Here's a different example with a more complex prompt where the variation is far more apparent (both images are with xformers, generated sequentially, XOR in the middle - though here the difference is obvious to the naked eye):
[image: more-frogs-on-stilts]

@takuma104
Contributor

takuma104 commented Jan 10, 2023

That being said, I think it should be noted somewhere more clearly that there is a margin of variation involved when using this, just in case perfect reproducibility matters for someone.

I agree. There is documentation on reproducibility for PyTorch:
https://pytorch.org/docs/stable/notes/randomness.html

That said, an application like Stable Diffusion may be a special case compared to typical machine learning applications.

Since test_mem_eff_attention.py only guarantees that the difference from the reference implementation (which behaves deterministically, as far as I have tried) is reasonably small (4e-3 in fp16), the accuracy required to exactly reproduce an image generated by Stable Diffusion seems more stringent than that.
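
For context, a rough sketch of that kind of tolerance check (the shapes and the fp32 reference path here are my own assumptions for illustration, not the actual test code):

import torch
import xformers.ops as xops

B, M, H, K = 2, 1024, 8, 64
q, k, v = (torch.randn(B, M, H, K, device="cuda", dtype=torch.float16) for _ in range(3))

# Memory-efficient attention (backend chosen automatically).
out = xops.memory_efficient_attention(q, k, v)

# Naive reference attention, computed in fp32 and cast back to fp16.
qf, kf, vf = (t.float().transpose(1, 2) for t in (q, k, v))  # -> (B, H, M, K)
attn = torch.softmax(qf @ kf.transpose(-1, -2) / (K ** 0.5), dim=-1)
ref = (attn @ vf).transpose(1, 2).to(q.dtype)

print((out - ref).abs().max())  # Small (on the order of 4e-3 in fp16), but not exactly zero.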

@danthe3rd added the "good first issue" label on Jan 10, 2023
@danthe3rd
Contributor

The use case of Stable Diffusion is quite extreme because we're applying mem-eff attention multiple times per model forward and repeating that dozens of times, which can amplify even very small variations. However, I totally agree with you that we should:
(1) Update the documentation to make that clear
(2) Maybe also raise a warning if torch.use_deterministic_algorithms is set when using mem-eff attention (see the sketch after this comment)

I added the "Good first issue" tag, so we welcome contributions for that - if you want to open a pull request to address either of those, that would be great :)
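
For suggestion (2), a rough sketch of what a Python-side check could look like (this is an assumption for illustration only; the actual change discussed below uses ATen's alertNotDeterministic on the C++ side, and _warn_if_deterministic_requested is a hypothetical helper):

import warnings

import torch

def _warn_if_deterministic_requested(op_name):
    # Reflects whether the user called torch.use_deterministic_algorithms(True).
    if torch.are_deterministic_algorithms_enabled():
        warnings.warn(
            f"{op_name} does not have a deterministic implementation, "
            "but torch.use_deterministic_algorithms(True) is set.",
            UserWarning,
        )

# Hypothetical call site, before dispatching to the memory-efficient attention kernels:
# _warn_if_deterministic_requested("memory_efficient_attention")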

@takuma104
Contributor

@danthe3rd I haven't opened a pull request yet, but I wrote the change anyway: main...takuma104:xformers:non-determistic-warn

Please let me know if there is an appropriate wording for the document, as I am not a native speaker.

I have a few questions.

  • Do I need to write unit test code? It seems that PyTorch has not written any test code for this part. As far as I have tested it in my environment, the behavior seems to be correct.

  • It looks like there are multiple backends, but so far I have only written one for cutlass. Is there a way to switch backends?

@danthe3rd
Copy link
Contributor

danthe3rd commented Jan 10, 2023

The code looks great! I didn't know about this alertNotDeterministic function actually.
We don't need to write tests for that.
But we would need to add this warning to the cutlass backward pass as well.
Regarding the note wording, we could simply say "This operator may be nondeterministic".
Can you open a PR? :)

Regarding Flash, I don't see a way to call "alertNotDeterministic" from Python, but maybe @tridao can add it to Flash-Attention later.

EDIT: Not sure whether Flash is actually deterministic or not; it would be worth testing.

@tridao

tridao commented Jan 10, 2023

@danthe3rd Is this about the forward pass or the backward pass?
I believe the forward pass for FlashAttention is deterministic (I had a test for this at some point), while the backward pass may not be (due to atomic adds).
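
A minimal sketch (my own assumption, not the test mentioned above) of how one could probe this: run the same backward pass twice on identical inputs and compare the gradients bitwise.

import torch
import xformers.ops as xops

def attention_grads(q, k, v):
    q, k, v = (t.clone().requires_grad_(True) for t in (q, k, v))
    out = xops.memory_efficient_attention(q, k, v, op=xops.MemoryEfficientAttentionFlashAttentionOp)
    out.sum().backward()
    return q.grad, k.grad, v.grad

B, M, H, K = 2, 2048, 8, 64  # FlashAttention requires K <= 128.
q, k, v = (torch.randn(B, M, H, K, device="cuda", dtype=torch.float16) for _ in range(3))

grads1 = attention_grads(q, k, v)
grads2 = attention_grads(q, k, v)
# May contain False: atomic adds in the backward pass can reorder floating-point sums.
print([torch.equal(g1, g2) for g1, g2 in zip(grads1, grads2)])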

@danthe3rd
Contributor

Oh, that's great news, thanks!
I'll make the change @takuma104

@takuma104
Contributor

@danthe3rd
I have added it to the backward pass as well and opened a pull request (#635). Thanks for the correction of the documentation wording; it is indeed better to keep it as simple as that.

I'm a little concerned about this alertNotDeterministic(). A warning like this will appear.

UserWarning: efficient_attention_forward_cutlass does not have a deterministic 
implementation, but you set 'torch.use_deterministic_algorithms(True, warn_only=True)'. 
You can file an issue at https://github.com/pytorch/pytorch/issues to help us 
prioritize adding deterministic support for this operation. 
(Triggered internally at /... /ATen/Context.cpp:82.)

This warning text seems to come from PyTorch itself. Hmmm...

@tridao
Thanks for the great news! It seems that passing op=xformers.ops.MemoryEfficientAttentionFlashAttentionOp makes the results reproducible. I will patch Diffusers and generate an image to see how reproducible it is.
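
A minimal sketch (my own assumption, not the Diffusers patch itself) of checking run-to-run reproducibility by forcing the FlashAttention operator and comparing two forward passes on identical inputs:

import torch
import xformers.ops as xops

B, M, H, K = 2, 4096, 8, 64  # K must be <= 128 for FlashAttention.
q, k, v = (torch.randn(B, M, H, K, device="cuda", dtype=torch.float16) for _ in range(3))

out1 = xops.memory_efficient_attention(q, k, v, op=xops.MemoryEfficientAttentionFlashAttentionOp)
out2 = xops.memory_efficient_attention(q, k, v, op=xops.MemoryEfficientAttentionFlashAttentionOp)

print(torch.equal(out1, out2))  # Expected True: the FlashAttention forward pass is reported as deterministic.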

@takuma104
Contributor

takuma104 commented Jan 11, 2023

I have tried patching Diffusers to force the use of FlashAttention. It seems to work without stopping the UNet inference, but I got the following error in the VAE decode.

ValueError: xformers.memory_efficient_attention: Operator flshattF does not support this input

The shapes of the q, k, v inputs are all (1, 1, 9216, 512). I tried specifying this shape in the minimal code and, sure enough, I got the same error. I followed up with the debugger, and it seems that the K of q and v exceeded SUPPORTED_MAX_K, which caused the ValueError. Hmmm.. :(

@danthe3rd
Contributor

Flash attention does not support K>128, unfortunately. But I believe you are doing something wrong there. Is your sequence really of size 1?!

@takuma104
Contributor

takuma104 commented Jan 11, 2023

Sorry, I have to make a correction.

Here is the code I used:
huggingface/diffusers@de7a88b

The following line was printed just before the ValueError:
torch.Size([1, 9216, 512]) torch.Size([1, 9216, 512]) torch.Size([1, 9216, 512])

I have the batch size set to 1. I interpreted this result as (B, M, H, K) = (1, 1, 9216, 512), but that is incorrect, given the following part of the memory_efficient_attention() specification:

If inputs have dimension 3, it is assumed that the dimensions are [B, M, K] and H=1

so it has to be interpreted as (B, M, H, K) = (1, 9216, 1, 512). So the sequence length is 9216.
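
A small sketch of that shape interpretation (illustrative only): a 3D input of shape [B, M, K] is treated as [B, M, H=1, K].

import torch

q3d = torch.randn(1, 9216, 512)  # [B, M, K], as reported by the debug print.
q4d = q3d.unsqueeze(2)           # -> [B, M, H, K] = [1, 9216, 1, 512]
print(q4d.shape)                 # torch.Size([1, 9216, 1, 512]): the sequence length M is 9216 and K is 512.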

@takuma104
Contributor

It usually uses FlashAttention, and if that is not available, it falls back to Cutlass, which gave me almost the result I wanted. Thanks for the information on FlashAttention.
https://gist.github.com/takuma104/9d25bb87ae3b52e41e0132aa737c0b03
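
A rough sketch (my own assumption, not the code in the gist above) of that fallback: try the FlashAttention operator first and use the Cutlass operator when FlashAttention rejects the inputs (e.g. K > 128). MemoryEfficientAttentionCutlassOp is assumed here as the Cutlass counterpart of the FlashAttention operator mentioned in this thread.

import xformers.ops as xops

def attention_with_fallback(q, k, v):
    try:
        return xops.memory_efficient_attention(q, k, v, op=xops.MemoryEfficientAttentionFlashAttentionOp)
    except (ValueError, NotImplementedError):
        # FlashAttention does not support these inputs; fall back to the Cutlass backend.
        return xops.memory_efficient_attention(q, k, v, op=xops.MemoryEfficientAttentionCutlassOp)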

@Drakmour

So is it now normal that generations with the exact same seed will have tiny differences when xformers is enabled? I checked my old generations from December and a bit earlier, and every time I generated images with the exact same settings I got 100% identical results. Now I will get changes and I can't prevent that? Because sometimes even a small change can cause a bad generation: an eye, a finger, etc. And when I make 100+ 512x512 images to regenerate them with highres fix, I can't, because I won't get the same image.
