Same seed produces different results #1339

Closed
thisislance98 opened this issue Nov 18, 2022 · 16 comments
Labels
bug (Something isn't working), stale (Issues that haven't received updates)

Comments

@thisislance98

Describe the bug

Basically, if you set the scheduler to EulerAncestralDiscreteScheduler and use the lpw_stable_diffusion custom pipeline, you will get different images on each generation, even with the same seed.

Reproduction

Here is a colab to reproduce the results
https://colab.research.google.com/drive/1ypHAf2TBWnvxkLJf08cSMko3r48WYNYn?usp=sharing

Logs

No response

System Info

I'm using Colab with a standard GPU.

thisislance98 added the bug label on Nov 18, 2022
@averad

averad commented Nov 18, 2022

@thisislance98 looking at your code example, it appears you're not passing the seed information as latents to the pipe. Here is some information that should get you generating images with fixed seeds.

Stable Diffusion with Repeatable Seeds (excerpt below)

Latents Generation
In order to reuse the seeds we need to generate the latents ourselves. Otherwise, the pipeline will do it internally and we won't have a way to replicate them.

Latents are the initial random Gaussian noise that gets transformed to actual images during the diffusion process.

To generate them, we'll use a different random seed for each latent, and we'll save them so we can reuse them later.

# device, pipe, num_images, height and width are assumed to be defined earlier in the notebook
generator = torch.Generator(device=device)

latents = None
seeds = []
for _ in range(num_images):
    # Get a new random seed, store it and use it as the generator state
    seed = generator.seed()
    seeds.append(seed)
    generator = generator.manual_seed(seed)
    
    image_latents = torch.randn(
        (1, pipe.unet.in_channels, height // 8, width // 8),
        generator = generator,
        device = device
    )
    latents = image_latents if latents is None else torch.cat((latents, image_latents))
    
# latents should have shape (4, 4, 64, 64) in this case
latents.shape
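
For completeness, a minimal sketch (assuming the same pipe, device, height, width and a prompt variable as above) of how one of the saved seeds can be reused later to rebuild identical latents and pass them explicitly to the pipeline:

# Rebuild the latents that were generated from seeds[0]
generator = torch.Generator(device=device).manual_seed(seeds[0])
latents = torch.randn(
    (1, pipe.unet.in_channels, height // 8, width // 8),
    generator=generator,
    device=device,
)
# Passing latents explicitly makes the pipeline skip its own random initialization,
# so the same latents (plus the same prompt and settings) reproduce the same image
image = pipe(prompt, latents=latents).images[0]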

@Marcophono2

I can confirm that I have the same problem when using torch.bfloat16. The results are similar but differ in the details, e.g. the same lady in the same position but sometimes smiling, sometimes not, or looking to the left one time and to the right another. With torch.float16, the same seed gives the same result. I am also using the Euler (A) scheduler.

@averad

averad commented Nov 19, 2022

@Marcophono2 can you share the code you are using to generate the image?

@camenduru
Contributor

Maybe the second pipe call is missing generator=generator:

generator = torch.Generator(device="cuda").manual_seed(87)
display(pipe.text2img(prompt, negative_prompt=neg_prompt, width=512,height=512,max_embeddings_multiples=3, generator=generator).images[0])

generator = torch.Generator(device="cuda").manual_seed(87)
display(pipe.text2img(prompt, negative_prompt=neg_prompt, width=512,height=512,max_embeddings_multiples=3).images[0])

@thisislance98
Author

@averad I tried passing the same latents into text2img and am still getting different results

https://colab.research.google.com/drive/1ypHAf2TBWnvxkLJf08cSMko3r48WYNYn?usp=sharing

feel free to update the colab if you can get it to work

@camenduru
Contributor

@thisislance98 try generator = torch.cuda.manual_seed_all(seed) or generator = torch.cuda.manual_seed(seed)

for _ in range(2):
    # Use a fixed seed instead of drawing a new one from the generator
    seed = 23  # generator.seed()
    seeds.append(seed)
    # Note: torch.cuda.manual_seed_all() returns None, so torch.randn() below
    # falls back to the default CUDA generator, which was just seeded
    generator = torch.cuda.manual_seed_all(seed)

    image_latents = torch.randn(
        (1, pipe.unet.in_channels, height // 8, width // 8),
        generator=generator,
        device=device
    )
    # latents = image_latents if latents is None else torch.cat((latents, image_latents))
    # image_latents = image_latents.half()

    display(pipe.text2img(prompt, negative_prompt=neg_prompt, latents=image_latents, width=512, height=512, max_embeddings_multiples=3, num_inference_steps=steps).images[0])

@Marcophono2

Sure, @averad .
I found out that it's the combination of Euler A + xformers that outputs different images under the same conditions, including the seed of course. The differences are only visible in the details; the hair is slightly different, for example. Have a look at the four attached images. Initially I thought that two of those images were identical, but indeed they aren't.
[Four attached example images: 16689800240610, 16689800228514, 16689800134959, 16689800111334]

import torch
from diffusers import StableDiffusionPipeline, EulerAncestralDiscreteScheduler

euler = EulerAncestralDiscreteScheduler.from_config("runwayml/stable-diffusion-v1-5", subfolder="scheduler")

pipe1 = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", scheduler=euler, revision="fp16", torch_dtype=torch.float16).to("cuda:1")

pipe1.enable_xformers_memory_efficient_attention()
generator = torch.Generator(device='cuda').manual_seed(int(666))
with torch.inference_mode():
    image = pipe1("a beautiful woman", num_inference_steps=int(20), width=int(512), height=int(512), generator=generator, guidance_scale=float(7)).images[0]
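
To make the comparison objective rather than visual, here is a minimal sketch (assuming the pipe1 defined above) that generates the same prompt twice with an identical seed and measures the pixel difference between the two outputs:

import numpy as np
import torch

def generate(seed):
    generator = torch.Generator(device="cuda").manual_seed(seed)
    with torch.inference_mode():
        return pipe1("a beautiful woman", num_inference_steps=20, generator=generator).images[0]

# Convert to int16 so the subtraction below cannot underflow uint8 pixels
img_a = np.asarray(generate(666), dtype=np.int16)
img_b = np.asarray(generate(666), dtype=np.int16)

# 0 means bit-identical images; with xformers enabled this is typically non-zero
print("max abs pixel difference:", np.abs(img_a - img_b).max())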

@Marcophono2

I made some more tests and noticed that the differences decrease with an increasing number of steps. That is surprising; I would have expected the differences to multiply with every step. The very noisy image after step 1 is identical, but starting from step 2 there are differences. After 250 steps only a very, very small difference in overall noise is left, and it is only visible if you flip between the two images at the same position.

@patrickvonplaten
Contributor

That's very interesting! @Marcophono2, are you running the code always on the same device, same hardware, etc...?

@Marcophono2

Yes, @patrickvonplaten.

@adhikjoshi

That's very interesting! @Marcophono2, are you running the code always on the same device, same hardware, etc...?

lpw_stable_diffusion needs lots of updates

@antoche
Contributor

antoche commented Dec 6, 2022

I have just noticed that I am getting non-deterministic results too with EulerAncestral, even when enforcing all RNG to happen on the CPU (see #1514). The interesting thing is, results are deterministic if I call disable_xformers_memory_efficient_attention(). It looks like the memory-efficient attention has a strong effect on determinism.
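
For context, a minimal sketch of the setup described above, assuming a diffusers build that accepts a CPU torch.Generator for a CUDA pipeline (the behaviour discussed in #1514); the prompt and seed are placeholders:

import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16).to("cuda")

# Turn off the memory-efficient attention so the denoising loop itself is deterministic...
pipe.disable_xformers_memory_efficient_attention()

# ...and draw the initial noise from a CPU generator so the RNG does not depend on the GPU
generator = torch.Generator(device="cpu").manual_seed(1234)
image = pipe("a beautiful woman", generator=generator).images[0]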

@patrickvonplaten
Contributor

BTW, this PR should help a bit with this: #1718

@amrakm

amrakm commented Jan 10, 2023

I can confirm that xformers is the culprit here.

I tried passing a list of generators using the new method described in PR #1718, but it didn't work:
generator = [torch.Generator(device="cuda").manual_seed(i) for i in range(n_images)]

The only solution was to disable xformers. Note that this will affect performance:

pipe.disable_xformers_memory_efficient_attention()
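
For reference, a minimal sketch of how that list of generators is meant to be passed, assuming the batched-generator support from #1718 is available and that the list length matches the total number of images requested (prompt and n_images are placeholders):

import torch

n_images = 4
# One generator per image, so each image's latents come from its own fixed seed
generators = [torch.Generator(device="cuda").manual_seed(i) for i in range(n_images)]
images = pipe(prompt, generator=generators, num_images_per_prompt=n_images).images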

@takuma104
Contributor

Hi,
According to the xFormers team, the default Cutlass backend for memory_efficient_attention() is not guaranteed to be deterministic. In contrast, Flash Attention should be deterministic, according to them. I have written a patch to apply to Diffusers, and with it results are almost reproducible over multiple generations. I'm thinking of submitting a pull request once I have refined the code a bit more.

https://gist.github.com/takuma104/9d25bb87ae3b52e41e0132aa737c0b03
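
In case it helps others, a similar effect can be approximated without the patch, assuming the installed diffusers version exposes the attention_op argument of enable_xformers_memory_efficient_attention (older releases do not have it, so treat this as an assumption):

import xformers.ops

# Ask xformers to use the Flash Attention kernels instead of the default Cutlass backend
pipe.enable_xformers_memory_efficient_attention(
    attention_op=xformers.ops.MemoryEfficientAttentionFlashAttentionOp
)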

@github-actions

github-actions bot commented Feb 6, 2023

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

github-actions bot added the stale label on Feb 6, 2023