Adding VQGAN Training script #5483
Merged
50 commits
cd86f42
Init commit
isamu-isozaki c0e44b0
Removed einops
isamu-isozaki ccaf393
Added default movq config for training
isamu-isozaki 4b361cc
Update explanation of prompts
isamu-isozaki a726069
Fixed inheritance of discriminator and init_tracker
isamu-isozaki 0b0cea3
Fixed incompatible api between muse and here
isamu-isozaki 68be3c5
Fixed output
isamu-isozaki 3072e5d
Setup init training
isamu-isozaki a302201
Basic structure done
isamu-isozaki 388f880
Removed attention for quick tests
isamu-isozaki fca82c5
Style fixes
isamu-isozaki 1924fab
Fixed vae/vqgan styles
isamu-isozaki 5637444
Removed redefinition of wandb
isamu-isozaki 2f5421d
Fixed log_validation and tqdm
isamu-isozaki e318ca8
Nothing commit
isamu-isozaki c69bff6
Merge branch 'vqgan' of https://github.com/isamu-isozaki/diffusers in…
isamu-isozaki f59f22a
Fixed merge conflicts
isamu-isozaki ce00b6c
Added commit loss to lookup_from_codebook
isamu-isozaki 71c612e
Merge branch 'main' into vqgan
sayakpaul 79bdc26
Update src/diffusers/models/vq_model.py
isamu-isozaki d16fea1
Adding perliminary README
isamu-isozaki ce7b2ec
Merge branch 'vqgan' of https://github.com/isamu-isozaki/diffusers in…
isamu-isozaki af8a47d
Merge branch 'vqgan' of https://github.com/isamu-isozaki/diffusers in…
isamu-isozaki 42504e5
Fixed one typo
isamu-isozaki a3c6658
Merge branch 'main' into vqgan
sayakpaul 7c6aeec
Local changes
isamu-isozaki 7854b0f
Merge branch 'vqgan' of https://github.com/isamu-isozaki/diffusers in…
isamu-isozaki e1ed9ee
Merge branch 'main' into vqgan
sayakpaul 4ad7a22
Fixed main issues
isamu-isozaki 4570347
Merging
isamu-isozaki 0e79a25
Merging
isamu-isozaki 5087644
Update src/diffusers/models/vq_model.py
isamu-isozaki 45abf09
Testing+Fixed bugs in training script
isamu-isozaki a3f1e03
Merge branch 'vqgan' of https://github.com/isamu-isozaki/diffusers in…
isamu-isozaki eb59684
Some style fixes
isamu-isozaki 97367cb
Added wandb to docs
isamu-isozaki 35aa51b
Merge branch 'main' into vqgan
sayakpaul d797bcd
Fixed timm test
isamu-isozaki d04733c
Merge branch 'vqgan' of https://github.com/isamu-isozaki/diffusers in…
isamu-isozaki cc9a3e7
Merge branch 'main' into vqgan
isamu-isozaki d481d1f
Merge branch 'main' into vqgan
sayakpaul 1149dfb
get testing suite ready.
sayakpaul b219a1a
Merge branch 'main' into vqgan
sayakpaul d705ed4
remove return loss
isamu-isozaki 75d36b6
remove return_loss
isamu-isozaki 6e3ef01
Remove diffs
isamu-isozaki adbed45
Remove diffs
isamu-isozaki d19c78e
Merge branch 'main' into vqgan
DN6 9ebae82
Merge branch 'main' into vqgan
isamu-isozaki 9f46121
fix ruff format
isamu-isozaki
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,127 @@ | ||
| ## Training a VQGAN VAE | ||
| VQVAEs were first introduced in [Neural Discrete Representation Learning](https://arxiv.org/abs/1711.00937) and were combined with a GAN in the paper [Taming Transformers for High-Resolution Image Synthesis](https://arxiv.org/abs/2012.09841). The basic idea of a VQVAE is that it is a type of variational autoencoder whose latent space consists of discrete tokens, similar to the tokens used by LLMs. This script was adapted from a [PR to Hugging Face's open-muse project](https://github.com/huggingface/open-muse/pull/52), with the general code following [lucidrains' implementation of the VQGAN training script](https://github.com/lucidrains/muse-maskgit-pytorch/blob/main/muse_maskgit_pytorch/trainers.py); both of these implementations in turn follow the [taming-transformers repo](https://github.com/CompVis/taming-transformers?tab=readme-ov-file). | ||
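To make the "tokens as the latent space" idea concrete, here is a minimal sketch of the quantization step in plain Python. It is not the code used by the training script: the function names `quantize` and `lookup`, the tiny 2-dimensional codebook, and the example latents are all illustrative assumptions. Each continuous latent vector is replaced by the index (token) of its nearest codebook entry, and decoding starts from the codebook vectors those tokens point to.

```python
def quantize(vectors, codebook):
    """Map each latent vector to the index of its nearest codebook entry."""
    tokens = []
    for v in vectors:
        # Squared Euclidean distance from v to every codebook entry.
        distances = [sum((a - b) ** 2 for a, b in zip(v, entry)) for entry in codebook]
        tokens.append(distances.index(min(distances)))
    return tokens


def lookup(tokens, codebook):
    """Turn tokens back into the codebook vectors they refer to."""
    return [codebook[t] for t in tokens]


# Hypothetical codebook with 4 entries of dimension 2 (a real model uses thousands).
codebook = [[0.0, 0.0], [1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
latents = [[0.1, -0.2], [0.9, 0.8], [0.2, 1.1]]

tokens = quantize(latents, codebook)
quantized = lookup(tokens, codebook)
print(tokens)     # [0, 3, 2]
print(quantized)  # [[0.0, 0.0], [1.0, 1.0], [0.0, 1.0]]
```

In a trained model, the encoder produces the latents and the codebook is learned jointly with it; the sketch only shows the nearest-neighbour lookup that turns an image into tokens.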
|
|
||
|
|
||
| Creating a training image set is [described in a different document](https://huggingface.co/docs/datasets/image_process#image-datasets). | ||
|
|
||
| ### Installing the dependencies | ||
|
|
||
| Before running the scripts, make sure to install the library's training dependencies: | ||
|
|
||
| **Important** | ||
|
|
||
| To make sure you can successfully run the latest versions of the example scripts, we highly recommend **installing from source** and keeping the install up to date as we update the example scripts frequently and install some example-specific requirements. To do this, execute the following steps in a new virtual environment: | ||
| ```bash | ||
| git clone https://github.com/huggingface/diffusers | ||
| cd diffusers | ||
| pip install . | ||
| ``` | ||
|
|
||
| Then cd into the example folder and run | ||
| ```bash | ||
| pip install -r requirements.txt | ||
| ``` | ||
|
|
||
|
|
||
| And initialize an [🤗Accelerate](https://github.com/huggingface/accelerate/) environment with: | ||
|
|
||
| ```bash | ||
| accelerate config | ||
| ``` | ||
|
|
||
| ### Training on CIFAR10 | ||
|
|
||
| The command to train a VQGAN model on the CIFAR10 dataset: | ||
|
|
||
| ```bash | ||
| accelerate launch train_vqgan.py \ | ||
| --dataset_name=cifar10 \ | ||
| --image_column=img \ | ||
| --validation_images images/bird.jpg images/car.jpg images/dog.jpg images/frog.jpg \ | ||
| --resolution=128 \ | ||
| --train_batch_size=2 \ | ||
| --gradient_accumulation_steps=8 \ | ||
| --report_to=wandb | ||
| ``` | ||
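As a quick sanity check on the flags above, the number of images contributing to each optimizer step can be worked out as follows. This is a sketch under the usual assumption that the per-device batch size is multiplied by the accumulation steps and by the number of processes; `effective_batch_size` and `num_processes` are illustrative names, not options of the script.

```python
def effective_batch_size(train_batch_size, gradient_accumulation_steps, num_processes=1):
    """Images seen per optimizer step across all processes."""
    return train_batch_size * gradient_accumulation_steps * num_processes


# Values from the command above, on a single process.
print(effective_batch_size(2, 8))                   # 16
# The same flags launched on 4 processes (hypothetical setup).
print(effective_batch_size(2, 8, num_processes=4))  # 64
```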
|
|
||
| An example training run by @sayakpaul is [here](https://wandb.ai/sayakpaul/vqgan-training/runs/0m5kzdfp), and a smaller-scale one is [here](https://wandb.ai/dsbuddy27/vqgan-training/runs/eqd6xi4n?nw=nwuserisamu). The validation images can be obtained from [here](https://huggingface.co/datasets/diffusers/docs-images/tree/main/vqgan_validation_images). | ||
| The simplest way to improve the quality of a VQGAN model is to maximize the amount of information present in the bottleneck. The easiest way to do this is to increase the image resolution. Other options include, but are not limited to, lowering compression by downsampling fewer times or increasing the vocabulary size, which should be at most around 16384. How to do this is shown below. | ||
|
|
||
| ### Modifying the architecture | ||
|
|
||
| To modify the architecture of the VQGAN model, you can save the config taken from [here](https://huggingface.co/kandinsky-community/kandinsky-2-2-decoder/blob/main/movq/config.json) and then pass it to the script with the option --model_config_name_or_path. The config is shown below: | ||
| ``` | ||
| { | ||
| "_class_name": "VQModel", | ||
| "_diffusers_version": "0.17.0.dev0", | ||
| "act_fn": "silu", | ||
| "block_out_channels": [ | ||
| 128, | ||
| 256, | ||
| 256, | ||
| 512 | ||
| ], | ||
| "down_block_types": [ | ||
| "DownEncoderBlock2D", | ||
| "DownEncoderBlock2D", | ||
| "DownEncoderBlock2D", | ||
| "AttnDownEncoderBlock2D" | ||
| ], | ||
| "in_channels": 3, | ||
| "latent_channels": 4, | ||
| "layers_per_block": 2, | ||
| "norm_num_groups": 32, | ||
| "norm_type": "spatial", | ||
| "num_vq_embeddings": 16384, | ||
| "out_channels": 3, | ||
| "sample_size": 32, | ||
| "scaling_factor": 0.18215, | ||
| "up_block_types": [ | ||
| "AttnUpDecoderBlock2D", | ||
| "UpDecoderBlock2D", | ||
| "UpDecoderBlock2D", | ||
| "UpDecoderBlock2D" | ||
| ], | ||
| "vq_embed_dim": 4 | ||
| } | ||
| ``` | ||
| To reduce the number of layers in a VQGAN, you can remove layers by modifying block_out_channels, down_block_types, and up_block_types as shown below: | ||
| ``` | ||
| { | ||
| "_class_name": "VQModel", | ||
| "_diffusers_version": "0.17.0.dev0", | ||
| "act_fn": "silu", | ||
| "block_out_channels": [ | ||
| 128, | ||
| 256, | ||
| 256 | ||
| ], | ||
| "down_block_types": [ | ||
| "DownEncoderBlock2D", | ||
| "DownEncoderBlock2D", | ||
| "DownEncoderBlock2D" | ||
| ], | ||
| "in_channels": 3, | ||
| "latent_channels": 4, | ||
| "layers_per_block": 2, | ||
| "norm_num_groups": 32, | ||
| "norm_type": "spatial", | ||
| "num_vq_embeddings": 16384, | ||
| "out_channels": 3, | ||
| "sample_size": 32, | ||
| "scaling_factor": 0.18215, | ||
| "up_block_types": [ | ||
| "UpDecoderBlock2D", | ||
| "UpDecoderBlock2D", | ||
| "UpDecoderBlock2D" | ||
| ], | ||
| "vq_embed_dim": 4 | ||
| } | ||
| ``` | ||
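The same reduction can be done programmatically instead of editing the JSON by hand. The following is a sketch, not part of the training script: `drop_attention_level` is a hypothetical helper, and `base` contains only the subset of config fields that change. It drops the last encoder level, the matching first decoder level, and the last channel entry, which is the edit made between the two configs above.

```python
import copy
import json


def drop_attention_level(config):
    """Return a copy of a VQModel config with its deepest level removed."""
    trimmed = copy.deepcopy(config)
    # The encoder lists its deepest block last, the decoder lists it first.
    trimmed["block_out_channels"] = config["block_out_channels"][:-1]
    trimmed["down_block_types"] = config["down_block_types"][:-1]
    trimmed["up_block_types"] = config["up_block_types"][1:]
    return trimmed


# Subset of the original config (other fields are left out for brevity).
base = {
    "block_out_channels": [128, 256, 256, 512],
    "down_block_types": [
        "DownEncoderBlock2D",
        "DownEncoderBlock2D",
        "DownEncoderBlock2D",
        "AttnDownEncoderBlock2D",
    ],
    "up_block_types": [
        "AttnUpDecoderBlock2D",
        "UpDecoderBlock2D",
        "UpDecoderBlock2D",
        "UpDecoderBlock2D",
    ],
    "num_vq_embeddings": 16384,
}

small = drop_attention_level(base)
# json.dumps always emits valid JSON, so no trailing commas can slip in.
print(json.dumps(small, indent=2))
```

Writing the result of `json.dumps` to a file and passing that file to --model_config_name_or_path should then work the same way as a hand-edited config.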
| To increase the size of the vocabulary, you can increase num_vq_embeddings. However, [some research](https://magvit.cs.cmu.edu/v2/) shows that the representations of VQGANs start degrading beyond 2^14 = 16384 VQ embeddings, so going past that is not recommended. | ||
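A rough way to compare these options is to estimate how many tokens, and how many bits, the bottleneck can carry. The sketch below assumes that the encoder downsamples by 2 between consecutive levels, so a config with `n` entries in block_out_channels compresses each side by `2 ** (n - 1)`; that is an assumption about the model rather than something stated here, and `bottleneck_capacity` is an illustrative name.

```python
import math


def bottleneck_capacity(resolution, num_levels, num_vq_embeddings):
    """Return (number of tokens, upper bound on bits) for one image."""
    # Assumption: one 2x downsample between consecutive levels.
    factor = 2 ** (num_levels - 1)
    side = resolution // factor
    tokens = side * side
    # Each token picks one of num_vq_embeddings entries.
    bits = tokens * math.log2(num_vq_embeddings)
    return tokens, bits


print(bottleneck_capacity(128, 4, 16384))  # (256, 3584.0)
print(bottleneck_capacity(128, 3, 16384))  # (1024, 14336.0)
print(bottleneck_capacity(256, 4, 16384))  # (1024, 14336.0)
```

Under this assumption, removing one level or doubling the resolution each quadruple the token count, whereas doubling the vocabulary adds only one bit per token.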
|
|
||
| ## Extra training tips/ideas | ||
| While logging, make sure data_time stays low. data_time is the time spent loading data, during which the GPU is idle, so it is essentially wasted time. The easiest way to lower data_time is to raise --dataloader_num_workers to a higher number such as 4. Due to a bug in PyTorch, this only works on Linux-based systems. For more details, check [here](https://github.com/huggingface/diffusers/issues/7646). | ||
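To illustrate what data_time measures, here is a small sketch that separates the time spent waiting for a batch from the time spent on the training step. It is not how the script computes its metric: `timed_batches` and `slow_loader` are hypothetical helpers, and the multiplication stands in for a real forward and backward pass.

```python
import time


def timed_batches(loader):
    """Yield (batch, seconds spent waiting for that batch)."""
    iterator = iter(loader)
    while True:
        start = time.perf_counter()
        try:
            batch = next(iterator)
        except StopIteration:
            return
        yield batch, time.perf_counter() - start


def slow_loader(num_batches, delay):
    """Stand-in for a data loader that takes `delay` seconds per batch."""
    for i in range(num_batches):
        time.sleep(delay)
        yield i


data_time = 0.0
compute_time = 0.0
seen = []
for batch, wait in timed_batches(slow_loader(3, 0.01)):
    data_time += wait
    start = time.perf_counter()
    seen.append(batch * 2)  # placeholder for the training step
    compute_time += time.perf_counter() - start

print(f"data_time={data_time:.4f}s compute_time={compute_time:.4f}s")
```

If the data-loading share dominates in a real run, the loader rather than the GPU is the bottleneck, and more workers are likely to help.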
| Secondly, training can be considered done when both the discriminator and the generator losses converge. | ||
| Thirdly, another low-hanging fruit is using EMA via the --use_ema parameter. This tends to make the output images smoother. The downside is that you have to lower your batch size by 1, but it may be worth it. | ||
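For readers unfamiliar with EMA, the update rule is sketched below on plain lists of floats. This is an illustration of the general technique, not the script's implementation: `ema_update` and the decay value are assumptions. A second, slowly moving copy of the weights is kept alongside the trained ones, which is presumably where the extra memory cost comes from.

```python
def ema_update(ema_params, params, decay=0.999):
    """Move each EMA weight a small step towards the current weight."""
    return [decay * e + (1.0 - decay) * p for e, p in zip(ema_params, params)]


params = [1.0, -2.0]
ema = list(params)   # the EMA copy starts as a clone of the weights

params = [2.0, 0.0]  # pretend an optimizer step changed the weights
ema = ema_update(ema, params, decay=0.9)
print(ema)           # approximately [1.1, -1.8]
```

Validation and inference then use the EMA copy, which averages out step-to-step noise in the trained weights.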
| Another, more experimental, low-hanging fruit is switching the LPIPS loss from vgg19 to a different model using --timm_model_backend. If you do this, I recommend also changing the timm_model_layers parameter to the layer in your model that you think gives the best representation. However, be careful with the feature map norms, since they can easily dominate the loss. | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,48 @@ | ||
| """ | ||
| Ported from Paella | ||
| """ | ||
|
|
||
| import torch | ||
| from torch import nn | ||
|
|
||
| from diffusers.configuration_utils import ConfigMixin, register_to_config | ||
| from diffusers.models.modeling_utils import ModelMixin | ||
|
|
||
|
|
||
| # Discriminator model ported from Paella https://github.com/dome272/Paella/blob/main/src_distributed/vqgan.py | ||
| class Discriminator(ModelMixin, ConfigMixin): | ||
|     @register_to_config | ||
|     def __init__(self, in_channels=3, cond_channels=0, hidden_channels=512, depth=6): | ||
|         super().__init__() | ||
|         # Exponent controlling how narrow the first layers are relative to hidden_channels | ||
|         d = max(depth - 3, 3) | ||
|         layers = [ | ||
|             nn.utils.spectral_norm( | ||
|                 nn.Conv2d(in_channels, hidden_channels // (2**d), kernel_size=3, stride=2, padding=1) | ||
|             ), | ||
|             nn.LeakyReLU(0.2), | ||
|         ] | ||
|         # Each further block halves the resolution and doubles the width until hidden_channels is reached | ||
|         for i in range(depth - 1): | ||
|             c_in = hidden_channels // (2 ** max((d - i), 0)) | ||
|             c_out = hidden_channels // (2 ** max((d - 1 - i), 0)) | ||
|             layers.append(nn.utils.spectral_norm(nn.Conv2d(c_in, c_out, kernel_size=3, stride=2, padding=1))) | ||
|             layers.append(nn.InstanceNorm2d(c_out)) | ||
|             layers.append(nn.LeakyReLU(0.2)) | ||
|         self.encoder = nn.Sequential(*layers) | ||
|         # 1x1 convolution mapping the features (plus optional conditioning) to one score per location | ||
|         self.shuffle = nn.Conv2d( | ||
|             (hidden_channels + cond_channels) if cond_channels > 0 else hidden_channels, 1, kernel_size=1 | ||
|         ) | ||
|         self.logits = nn.Sigmoid() | ||
|
|
||
|     def forward(self, x, cond=None): | ||
|         x = self.encoder(x) | ||
|         if cond is not None: | ||
|             # Broadcast the conditioning vector over the spatial grid before concatenating | ||
|             cond = cond.view( | ||
|                 cond.size(0), | ||
|                 cond.size(1), | ||
|                 1, | ||
|                 1, | ||
|             ).expand(-1, -1, x.size(-2), x.size(-1)) | ||
|             x = torch.cat([x, cond], dim=1) | ||
|         x = self.shuffle(x) | ||
|         x = self.logits(x) | ||
|         return x |
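To see what the constructor arguments of this discriminator imply, the channel widths and the output resolution can be traced without PyTorch. The sketch below mirrors the arithmetic in `__init__` for the unconditional case; `discriminator_schedule` and its `resolution` argument are illustrative names and not part of the module.

```python
def discriminator_schedule(in_channels=3, hidden_channels=512, depth=6, resolution=128):
    """Return (channel widths after each conv, final spatial size)."""
    d = max(depth - 3, 3)
    # The first conv maps the image to the narrowest width.
    channels = [in_channels, hidden_channels // (2 ** d)]
    # Remaining convs use the same c_out formula as the module above.
    for i in range(depth - 1):
        channels.append(hidden_channels // (2 ** max(d - 1 - i, 0)))
    size = resolution
    for _ in range(depth):
        # kernel_size=3, stride=2, padding=1: out = floor((size - 1) / 2) + 1
        size = (size - 1) // 2 + 1
    return channels, size


channels, size = discriminator_schedule()
print(channels)  # [3, 64, 128, 256, 512, 512, 512]
print(size)      # 2
```

With the defaults and 128x128 inputs, the encoder therefore ends in a 512-channel 2x2 feature map, and the final 1x1 convolution plus sigmoid yields a 2x2 grid of real/fake probabilities per image.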
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,8 @@ | ||
| accelerate>=0.16.0 | ||
| torchvision | ||
| transformers>=4.25.1 | ||
| datasets | ||
| timm | ||
| numpy | ||
| tqdm | ||
| tensorboard |