Colab notebook #2
@woctezuma thanks!!! Is the base checkpoint the only one available for the base diffusion model? I cannot reproduce the results shown in Figure 1 of the paper with the same text prompt. Crowson, K. Clip guided diffusion hq 256x256 |
Unfortunately, this is normal, because the publicly available model:
You should get outputs similar to the third row of Figure 9. From a user perspective, the main benefit of GLIDE is that it is much faster than the CLIP-guided methods which I have tried so far.
I think so. From what I can see in the code below, there are 6 checkpoints:
glide-text2im/glide_text2im/download.py Lines 10 to 17 in 742510e
|
@woctezuma thanks! I can see that the sampling part is slightly different than yours, adding the
# Create the text tokens to feed to the model.
tokens = model.tokenizer.encode(prompt)
tokens, mask = model.tokenizer.padded_tokens_and_mask(
    tokens, options['text_ctx']
)

# Create the classifier-free guidance tokens (empty)
full_batch_size = batch_size * 2
uncond_tokens, uncond_mask = model.tokenizer.padded_tokens_and_mask(
    [], options['text_ctx']
)

# Pack the tokens together into model kwargs.
model_kwargs = dict(
    tokens=th.tensor(
        [tokens] * batch_size + [uncond_tokens] * batch_size, device=device
    ),
    mask=th.tensor(
        [mask] * batch_size + [uncond_mask] * batch_size,
        dtype=th.bool,
        device=device,
    ),
)

# Create a classifier-free guidance sampling function
def model_fn(x_t, ts, **kwargs):
    half = x_t[: len(x_t) // 2]
    combined = th.cat([half, half], dim=0)
    model_out = model(combined, ts, **kwargs)
    eps, rest = model_out[:, :3], model_out[:, 3:]
    cond_eps, uncond_eps = th.split(eps, len(eps) // 2, dim=0)
    half_eps = uncond_eps + guidance_scale * (cond_eps - uncond_eps)
    eps = th.cat([half_eps, half_eps], dim=0)
    return th.cat([eps, rest], dim=1)

# Sample from the base model.
model.del_cache()
samples = diffusion.p_sample_loop(
    model_fn,
    (full_batch_size, 3, options["image_size"], options["image_size"]),
    device=device,
    clip_denoised=True,
    progress=True,
    model_kwargs=model_kwargs,
    cond_fn=None,
)[:batch_size]
model.del_cache()

# Show the output
show_images(samples)
|
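In case it helps anyone reading the token-packing step above, here is a minimal pure-Python sketch of how the doubled batch is laid out: the first `batch_size` rows carry the prompt and the last `batch_size` rows carry the empty (unconditional) prompt. The `padded_tokens_and_mask` function below is a hypothetical stand-in written for this illustration, not the real tokenizer method, and `end_token=0` is a placeholder padding value (the actual tokenizer may pad with something else).

```python
def padded_tokens_and_mask(tokens, text_ctx, end_token=0):
    """Hypothetical stand-in: truncate/pad to text_ctx and build a validity mask."""
    tokens = tokens[:text_ctx]
    padding = text_ctx - len(tokens)
    padded = tokens + [end_token] * padding
    mask = [True] * len(tokens) + [False] * padding
    return padded, mask


def pack_for_guidance(prompt_tokens, batch_size, text_ctx):
    """Build the [conditional; unconditional] batch used for classifier-free guidance."""
    tokens, mask = padded_tokens_and_mask(prompt_tokens, text_ctx)
    # The unconditional half uses an empty prompt, so its mask is all False.
    uncond_tokens, uncond_mask = padded_tokens_and_mask([], text_ctx)
    return {
        "tokens": [tokens] * batch_size + [uncond_tokens] * batch_size,
        "mask": [mask] * batch_size + [uncond_mask] * batch_size,
    }


# Toy token ids, chosen only for illustration.
packed = pack_for_guidance([5, 6, 7], batch_size=2, text_ctx=4)
for row_tokens, row_mask in zip(packed["tokens"], packed["mask"]):
    print(row_tokens, row_mask)
```

This is why the sample shape uses `full_batch_size = batch_size * 2` and the result is sliced with `[:batch_size]` afterwards: only the first half of the batch corresponds to distinct guided samples.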
To clarify any confusion:
Unless I am missing something, the
# Sample from the base model.
model.del_cache()
samples = diffusion.p_sample_loop(
    model_fn,
    (full_batch_size, 3, options["image_size"], options["image_size"]),
    device=device,
    clip_denoised=True,
    progress=True,
    model_kwargs=model_kwargs,
    cond_fn=None,
)[:batch_size]
model.del_cache()
I need to see the diff of what you did to understand better. I would be glad to test this and see the results, if they are better. :) The black cat with white paws looks nice. 👍 |
Thanks! I have two versions, this one
samples = diffusion.p_sample_loop(
    model,
    (batch_size, 3, options["image_size"], options["image_size"]),
    device=device,
    clip_denoised=True,
    progress=True,
    model_kwargs=model_kwargs,
    cond_fn=cond_fn,
)
where
and in the latest colab from the repo
samples = diffusion.p_sample_loop(
    model_fn,
    (full_batch_size, 3, options["image_size"], options["image_size"]),
    device=device,
    clip_denoised=True,
    progress=True,
    model_kwargs=model_kwargs,
    cond_fn=None,
)[:batch_size]
with
def model_fn(x_t, ts, **kwargs):
    half = x_t[: len(x_t) // 2]
    combined = th.cat([half, half], dim=0)
    model_out = model(combined, ts, **kwargs)
    eps, rest = model_out[:, :3], model_out[:, 3:]
    cond_eps, uncond_eps = th.split(eps, len(eps) // 2, dim=0)
    half_eps = uncond_eps + guidance_scale * (cond_eps - uncond_eps)
    eps = th.cat([half_eps, half_eps], dim=0)
    return th.cat([eps, rest], dim=1)
|
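To make the difference in the second version easier to see, here is a rough sketch of what a wrapper like `model_fn` computes, ported to NumPy so it runs without the GLIDE weights. `toy_model` is a hypothetical stand-in invented for this example (it just returns a constant noise prediction of 1.0 for the conditional half and 0.5 for the unconditional half), and `make_model_fn` is only a convenience wrapper added here; neither is part of the repo.

```python
import numpy as np


def toy_model(x, ts, tokens=None):
    """Hypothetical stand-in for the network: 3 eps channels + 3 other channels."""
    batch = x.shape[0]
    is_cond = np.asarray(tokens, dtype=bool).reshape(batch, 1, 1, 1)
    # Conditional rows predict eps = 1.0, unconditional rows predict eps = 0.5.
    eps = np.where(is_cond, 1.0, 0.5) * np.ones_like(x)
    rest = np.zeros_like(x)
    return np.concatenate([eps, rest], axis=1)


def make_model_fn(model, guidance_scale):
    def model_fn(x_t, ts, **kwargs):
        # Both halves of the batch share the same noisy image.
        half = x_t[: len(x_t) // 2]
        combined = np.concatenate([half, half], axis=0)
        model_out = model(combined, ts, **kwargs)
        eps, rest = model_out[:, :3], model_out[:, 3:]
        # Split into the conditional and unconditional predictions.
        cond_eps, uncond_eps = np.split(eps, 2, axis=0)
        # Move away from the unconditional prediction toward the conditional one.
        half_eps = uncond_eps + guidance_scale * (cond_eps - uncond_eps)
        eps = np.concatenate([half_eps, half_eps], axis=0)
        return np.concatenate([eps, rest], axis=1)

    return model_fn


batch_size = 2
x_t = np.zeros((batch_size * 2, 3, 4, 4))
ts = np.zeros(batch_size * 2)
tokens = [True] * batch_size + [False] * batch_size

out = make_model_fn(toy_model, guidance_scale=3.0)(x_t, ts, tokens=tokens)
guided = out[:batch_size, :3]
print(out.shape, guided.mean())
```

With `guidance_scale = 1.0` the result equals the conditional prediction, and larger values extrapolate past it. Both halves of the output carry the same guided eps, which is why only `[:batch_size]` of the samples is kept. In the first version the guidance instead comes from `cond_fn`, so the batch is not doubled.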
I have slightly changed the structure of your text2im notebook so that:
Run text2im.ipynb
Reference: https://github.com/woctezuma/glide-text2im-colab