Conditional *image* generation (img2img) #70

Open
CallShaul opened this issue Jun 19, 2023 · 2 comments

Comments

@CallShaul

CallShaul commented Jun 19, 2023

Hi,

In order to add support for conditional image generation, in addition to feeding the initial image embedding in as unet_cond (extra_args['unet_cond'] = img_cond), what should I put in extra_args['cross_cond'] and extra_args['cross_cond_padding']?

(This is before the loss calculation in the line: losses = model.loss(reals, noise, sigma, aug_cond=aug_cond, **extra_args).)
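
For context, the setup being described sits inside the training loop in train.py and looks roughly like this (a sketch only, not the repo's exact code; whether and how the cross_cond keys need to be filled in is exactly the open question):

    # Sketch of the training-step setup described above; img_cond is the
    # conditioning image batch, the other names follow train.py loosely.
    extra_args = {}
    extra_args['unet_cond'] = img_cond.to(device)  # extra conditioning input for the UNet

    # Open question: what, if anything, should go here when conditioning on an
    # image rather than on a cross-attention (e.g. text) input?
    # extra_args['cross_cond'] = ...
    # extra_args['cross_cond_padding'] = ...

    losses = model.loss(reals, noise, sigma, aug_cond=aug_cond, **extra_args)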

@crowsonkb
@nekoshadow1
@brycedrennan

Thanks!

@drscotthawley

drscotthawley commented Feb 22, 2024

I'd be interested in seeing an answer to this as well. e.g. for the simple case of MNIST, how might we implement (or activate) class-conditional generation?

I see class_cond in the code and a cond_dropout_rate in the config files, so maybe it's already training that way... But in the output from demo(), it seems to just be random. Perhaps we just need to change line 369 in train.py from this...

            class_cond = torch.randint(0, num_classes, [accelerator.num_processes, n_per_proc], generator=demo_gen).to(device)

To something more "intentional", such as...

            class_cond = torch.remainder(torch.arange(0, accelerator.num_processes*n_per_proc), num_classes).reshape([accelerator.num_processes, n_per_proc]).int().to(device)

....?

Update: yep! That worked! :-)

[Attached image: demo_grid_13499_969d27db3303994e126b]
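
(For a concrete feel of what the cycling expression produces, here is a tiny standalone example; the grid shape is made up for illustration, and the .int()/.to(device) calls are omitted:)

    import torch

    # Hypothetical small demo grid: 2 processes x 5 samples each, 10 MNIST classes.
    num_classes, num_processes, n_per_proc = 10, 2, 5

    # Cycle through the class labels in order instead of sampling them at random,
    # so every class shows up in the demo grid.
    class_cond = torch.remainder(
        torch.arange(num_processes * n_per_proc), num_classes
    ).reshape(num_processes, n_per_proc)

    print(class_cond)
    # tensor([[0, 1, 2, 3, 4],
    #         [5, 6, 7, 8, 9]])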

@CallShaul
Author

CallShaul commented Apr 18, 2024

Solution:

I've made it work; here are the main steps (some more workarounds are needed to make it run, in inference as well, but this is the main idea). A rough training-loop sketch follows the list:

  1. Get the conditioning image in each training batch iteration:

     unet_cond = get_condition_channels(model_config, img_cond)
     extra_args['unet_cond'] = unet_cond.to(device)

  2. Modify the "losses" line and pass the image condition in there:

     losses = model.loss(reals, noise, sigma, aug_cond=aug_cond, **extra_args)

  • Some fixes also need to be made in the "forward" function of the model file image_v1.py.
  • Perform similar conditioning at the inference stage.
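
A rough sketch of the training-side idea (not the repo's exact code; get_condition_channels is the commenter's own helper rather than part of k-diffusion, and the batch key names, sample_density, aug_cond, etc. are assumed to come from the surrounding train.py script):

    import torch

    # Inside the existing training loop in train.py. The names model, device,
    # model_config, sample_density, aug_cond, opt and accelerator all come from
    # the surrounding script; the batch key names here are illustrative only.
    reals = batch['image'].to(device)
    img_cond = batch['cond_image'].to(device)  # the conditioning image

    noise = torch.randn_like(reals)
    sigma = sample_density([reals.shape[0]], device=device)  # noise-level sampler

    # Step 1: turn the conditioning image into the extra UNet input channels.
    extra_args = {'unet_cond': get_condition_channels(model_config, img_cond).to(device)}

    # Step 2: pass the image condition through to the loss via **extra_args.
    losses = model.loss(reals, noise, sigma, aug_cond=aug_cond, **extra_args)
    loss = losses.mean()
    accelerator.backward(loss)
    opt.step()
    opt.zero_grad()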
