Merged
Commits
50 commits
cd86f42
Init commit
isamu-isozaki Oct 23, 2023
c0e44b0
Removed einops
isamu-isozaki Oct 29, 2023
ccaf393
Added default movq config for training
isamu-isozaki Oct 29, 2023
4b361cc
Update explanation of prompts
isamu-isozaki Oct 29, 2023
a726069
Fixed inheritance of discriminator and init_tracker
isamu-isozaki Oct 30, 2023
0b0cea3
Fixed incompatible api between muse and here
isamu-isozaki Oct 30, 2023
68be3c5
Fixed output
isamu-isozaki Oct 30, 2023
3072e5d
Setup init training
isamu-isozaki Oct 30, 2023
a302201
Basic structure done
isamu-isozaki Oct 30, 2023
388f880
Removed attention for quick tests
isamu-isozaki Oct 31, 2023
fca82c5
Style fixes
isamu-isozaki Oct 31, 2023
1924fab
Fixed vae/vqgan styles
isamu-isozaki Oct 31, 2023
5637444
Removed redefinition of wandb
isamu-isozaki Oct 31, 2023
2f5421d
Fixed log_validation and tqdm
isamu-isozaki Oct 31, 2023
e318ca8
Nothing commit
isamu-isozaki Oct 31, 2023
c69bff6
Merge branch 'vqgan' of https://github.com/isamu-isozaki/diffusers in…
isamu-isozaki Oct 31, 2023
f59f22a
Fixed merge conflicts
isamu-isozaki Feb 28, 2024
ce00b6c
Added commit loss to lookup_from_codebook
isamu-isozaki Feb 28, 2024
71c612e
Merge branch 'main' into vqgan
sayakpaul Mar 7, 2024
79bdc26
Update src/diffusers/models/vq_model.py
isamu-isozaki Mar 7, 2024
d16fea1
Adding perliminary README
isamu-isozaki Mar 8, 2024
ce7b2ec
Merge branch 'vqgan' of https://github.com/isamu-isozaki/diffusers in…
isamu-isozaki Mar 8, 2024
af8a47d
Merge branch 'vqgan' of https://github.com/isamu-isozaki/diffusers in…
isamu-isozaki Mar 25, 2024
42504e5
Fixed one typo
isamu-isozaki Mar 25, 2024
a3c6658
Merge branch 'main' into vqgan
sayakpaul Mar 26, 2024
7c6aeec
Local changes
isamu-isozaki Apr 26, 2024
7854b0f
Merge branch 'vqgan' of https://github.com/isamu-isozaki/diffusers in…
isamu-isozaki Apr 26, 2024
e1ed9ee
Merge branch 'main' into vqgan
sayakpaul Apr 27, 2024
4ad7a22
Fixed main issues
isamu-isozaki Apr 28, 2024
4570347
Merging
isamu-isozaki Apr 28, 2024
0e79a25
Merging
isamu-isozaki Apr 28, 2024
5087644
Update src/diffusers/models/vq_model.py
isamu-isozaki Apr 28, 2024
45abf09
Testing+Fixed bugs in training script
isamu-isozaki Apr 28, 2024
a3f1e03
Merge branch 'vqgan' of https://github.com/isamu-isozaki/diffusers in…
isamu-isozaki Apr 28, 2024
eb59684
Some style fixes
isamu-isozaki Apr 28, 2024
97367cb
Added wandb to docs
isamu-isozaki Apr 28, 2024
35aa51b
Merge branch 'main' into vqgan
sayakpaul Apr 29, 2024
d797bcd
Fixed timm test
isamu-isozaki Apr 29, 2024
d04733c
Merge branch 'vqgan' of https://github.com/isamu-isozaki/diffusers in…
isamu-isozaki Apr 29, 2024
cc9a3e7
Merge branch 'main' into vqgan
isamu-isozaki Apr 29, 2024
d481d1f
Merge branch 'main' into vqgan
sayakpaul Apr 30, 2024
1149dfb
get testing suite ready.
sayakpaul Apr 30, 2024
b219a1a
Merge branch 'main' into vqgan
sayakpaul Apr 30, 2024
d705ed4
remove return loss
isamu-isozaki Apr 30, 2024
75d36b6
remove return_loss
isamu-isozaki Apr 30, 2024
6e3ef01
Remove diffs
isamu-isozaki Apr 30, 2024
adbed45
Remove diffs
isamu-isozaki Apr 30, 2024
d19c78e
Merge branch 'main' into vqgan
DN6 May 7, 2024
9ebae82
Merge branch 'main' into vqgan
isamu-isozaki May 15, 2024
9f46121
fix ruff format
isamu-isozaki May 15, 2024
2 changes: 1 addition & 1 deletion .github/workflows/pr_tests.yml
@@ -156,7 +156,7 @@ jobs:
       if: ${{ matrix.config.framework == 'pytorch_examples' }}
       run: |
         python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
-        python -m uv pip install peft
+        python -m uv pip install peft timm
         python -m pytest -n 4 --max-worker-restart=0 --dist=loadfile \
           --make-reports=tests_${{ matrix.config.report }} \
           examples
1 change: 1 addition & 0 deletions .github/workflows/push_tests.yml
@@ -426,6 +426,7 @@ jobs:
         HUGGING_FACE_HUB_TOKEN: ${{ secrets.HUGGING_FACE_HUB_TOKEN }}
       run: |
         python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
+        python -m uv pip install timm
         python -m pytest -n 1 --max-worker-restart=0 --dist=loadfile -s -v --make-reports=examples_torch_cuda examples/

- name: Failure short reports
2 changes: 1 addition & 1 deletion .github/workflows/push_tests_fast.yml
@@ -107,7 +107,7 @@ jobs:
       if: ${{ matrix.config.framework == 'pytorch_examples' }}
       run: |
         python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
-        python -m uv pip install peft
+        python -m uv pip install peft timm
         python -m pytest -n 4 --max-worker-restart=0 --dist=loadfile \
           --make-reports=tests_${{ matrix.config.report }} \
           examples
127 changes: 127 additions & 0 deletions examples/vqgan/README.md
@@ -0,0 +1,127 @@
## Training a VQGAN VAE
VQVAEs were first introduced in [Neural Discrete Representation Learning](https://arxiv.org/abs/1711.00937) and were later combined with a GAN in the paper [Taming Transformers for High-Resolution Image Synthesis](https://arxiv.org/abs/2012.09841). The basic idea of a VQVAE is that it is a variational autoencoder whose latent space consists of discrete tokens, similar to the tokens used by LLMs. This script was adapted from a [PR to Hugging Face's open-muse project](https://github.com/huggingface/open-muse/pull/52), with the general code following [lucidrains' implementation of the VQGAN training script](https://github.com/lucidrains/muse-maskgit-pytorch/blob/main/muse_maskgit_pytorch/trainers.py); both of these implementations follow the [taming-transformers repo](https://github.com/CompVis/taming-transformers?tab=readme-ov-file).
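
To make the discrete-latent idea concrete, here is a rough sketch (not part of this PR) of round-tripping an image through an existing `VQModel`, assuming the pretrained Kandinsky 2.2 MoVQ checkpoint referenced later in this README:

```python
import torch

from diffusers import VQModel

# Pretrained MoVQ-style VQGAN VAE from the Kandinsky 2.2 decoder repo.
model = VQModel.from_pretrained("kandinsky-community/kandinsky-2-2-decoder", subfolder="movq")

image = torch.randn(1, 3, 256, 256)    # stand-in for an image normalized to [-1, 1]
latents = model.encode(image).latents  # continuous latents before quantization
# Snap each latent vector to its nearest codebook entry to get discrete token indices.
quantized, commit_loss, (_, _, indices) = model.quantize(latents)
reconstruction = model.decode(latents).sample  # decode() quantizes internally by default
print(indices.shape, reconstruction.shape)
```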


Creating a training image set is [described in a different document](https://huggingface.co/docs/datasets/image_process#image-datasets).

### Installing the dependencies

Before running the scripts, make sure to install the library's training dependencies:

**Important**

To make sure you can successfully run the latest versions of the example scripts, we highly recommend **installing from source** and keeping the install up to date as we update the example scripts frequently and install some example-specific requirements. To do this, execute the following steps in a new virtual environment:
```bash
git clone https://github.com/huggingface/diffusers
cd diffusers
pip install .
```

Then cd into the example folder and run
```bash
pip install -r requirements.txt
```


And initialize an [🤗Accelerate](https://github.com/huggingface/accelerate/) environment with:

```bash
accelerate config
```

### Training on CIFAR10

The command to train a VQGAN model on the `cifar10` dataset:

```bash
accelerate launch train_vqgan.py \
  --dataset_name=cifar10 \
  --image_column=img \
  --validation_images images/bird.jpg images/car.jpg images/dog.jpg images/frog.jpg \
  --resolution=128 \
  --train_batch_size=2 \
  --gradient_accumulation_steps=8 \
  --report_to=wandb
```

An example training run by @sayakpaul is [here](https://wandb.ai/sayakpaul/vqgan-training/runs/0m5kzdfp), and a smaller-scale one is [here](https://wandb.ai/dsbuddy27/vqgan-training/runs/eqd6xi4n?nw=nwuserisamu). The validation images can be obtained from [here](https://huggingface.co/datasets/diffusers/docs-images/tree/main/vqgan_validation_images).
The simplest way to improve the quality of a VQGAN model is to maximize the amount of information present in the bottleneck. The easiest way to do this is to increase the image resolution. Other ways include, but are not limited to, lowering compression by downsampling fewer times or increasing the vocabulary size, which can be at most around 16384. How to do this is shown below.

### Modifying the architecture

To modify the architecture of the VQGAN model, you can save the config taken from [here](https://huggingface.co/kandinsky-community/kandinsky-2-2-decoder/blob/main/movq/config.json) and then provide it to the script with the `--model_config_name_or_path` option. This config is shown below:
```json
{
  "_class_name": "VQModel",
  "_diffusers_version": "0.17.0.dev0",
  "act_fn": "silu",
  "block_out_channels": [
    128,
    256,
    256,
    512
  ],
  "down_block_types": [
    "DownEncoderBlock2D",
    "DownEncoderBlock2D",
    "DownEncoderBlock2D",
    "AttnDownEncoderBlock2D"
  ],
  "in_channels": 3,
  "latent_channels": 4,
  "layers_per_block": 2,
  "norm_num_groups": 32,
  "norm_type": "spatial",
  "num_vq_embeddings": 16384,
  "out_channels": 3,
  "sample_size": 32,
  "scaling_factor": 0.18215,
  "up_block_types": [
    "AttnUpDecoderBlock2D",
    "UpDecoderBlock2D",
    "UpDecoderBlock2D",
    "UpDecoderBlock2D"
  ],
  "vq_embed_dim": 4
}
```
To lower the number of layers in a VQGAN, you can remove layers by modifying `block_out_channels`, `down_block_types`, and `up_block_types` as shown below:
```json
{
  "_class_name": "VQModel",
  "_diffusers_version": "0.17.0.dev0",
  "act_fn": "silu",
  "block_out_channels": [
    128,
    256,
    256
  ],
  "down_block_types": [
    "DownEncoderBlock2D",
    "DownEncoderBlock2D",
    "DownEncoderBlock2D"
  ],
  "in_channels": 3,
  "latent_channels": 4,
  "layers_per_block": 2,
  "norm_num_groups": 32,
  "norm_type": "spatial",
  "num_vq_embeddings": 16384,
  "out_channels": 3,
  "sample_size": 32,
  "scaling_factor": 0.18215,
  "up_block_types": [
    "UpDecoderBlock2D",
    "UpDecoderBlock2D",
    "UpDecoderBlock2D"
  ],
  "vq_embed_dim": 4
}
```
To increase the vocabulary size, you can increase `num_vq_embeddings`. However, [some research](https://magvit.cs.cmu.edu/v2/) shows that VQGAN representations start degrading beyond 2^14 ≈ 16384 VQ embeddings, so it is not recommended to go past that.
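
To sanity-check a modified architecture before launching a full run, you can instantiate a model from the config directly. Below is a quick sketch (not part of this PR); `movq_config.json` is a hypothetical local path where you saved one of the JSON configs above:

```python
import json

from diffusers import VQModel

# Load the saved JSON config and build the (randomly initialized) model from it.
with open("movq_config.json") as f:
    config = json.load(f)
model = VQModel.from_config(config)

num_params = sum(p.numel() for p in model.parameters())
print(f"{num_params / 1e6:.1f}M parameters, {model.config.num_vq_embeddings} codebook entries")
```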

### Extra training tips/ideas
During logging, take care to make sure `data_time` stays low. `data_time` is the time spent loading data, during which the GPU is idle, so it is essentially wasted time. The easiest way to lower it is to increase `--dataloader_num_workers` to a higher number like 4. Due to a bug in PyTorch, this only works on Linux-based systems. For more details, check [here](https://github.com/huggingface/diffusers/issues/7646).
Secondly, training can be considered done when both the discriminator and the generator losses have converged.
Thirdly, another low-hanging fruit is enabling EMA with the `--use_ema` parameter. This tends to make the output images smoother. The downside is that you have to lower your batch size by 1, but it may be worth it.
Another, more experimental, low-hanging fruit is swapping VGG19 for a different model in the LPIPS-style perceptual loss using the `--timm_model_backend` parameter. If you do this, I recommend also changing the `timm_model_layers` parameter to the layer in your model that you think best represents the features. However, be careful with the feature-map norms, since they can easily dominate the loss; see the sketch below.
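
If you want to compare candidate `timm` backbones before committing to one, a quick, illustrative sketch (not part of this PR) is to look at their feature-map magnitudes; whether `features_only=True` is available depends on the specific timm model:

```python
import timm
import torch

# Compare feature-map norms of two candidate perceptual-loss backbones.
for name in ["vgg19", "resnet50"]:
    backbone = timm.create_model(name, pretrained=False, features_only=True)
    backbone.eval()
    with torch.no_grad():
        feats = backbone(torch.randn(1, 3, 224, 224))
    norms = [round(f.abs().mean().item(), 3) for f in feats]
    print(name, norms)  # very large norms can end up dominating the perceptual loss
```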
48 changes: 48 additions & 0 deletions examples/vqgan/discriminator.py
@@ -0,0 +1,48 @@
"""
Ported from Paella
"""

import torch
from torch import nn

from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.models.modeling_utils import ModelMixin


# Discriminator model ported from Paella https://github.com/dome272/Paella/blob/main/src_distributed/vqgan.py
class Discriminator(ModelMixin, ConfigMixin):
    @register_to_config
    def __init__(self, in_channels=3, cond_channels=0, hidden_channels=512, depth=6):
        super().__init__()
        d = max(depth - 3, 3)
        # Stack of spectral-normalized stride-2 convolutions; each level doubles the
        # channel count (capped at hidden_channels) and halves the spatial resolution.
        layers = [
            nn.utils.spectral_norm(
                nn.Conv2d(in_channels, hidden_channels // (2**d), kernel_size=3, stride=2, padding=1)
            ),
            nn.LeakyReLU(0.2),
        ]
        for i in range(depth - 1):
            c_in = hidden_channels // (2 ** max((d - i), 0))
            c_out = hidden_channels // (2 ** max((d - 1 - i), 0))
            layers.append(nn.utils.spectral_norm(nn.Conv2d(c_in, c_out, kernel_size=3, stride=2, padding=1)))
            layers.append(nn.InstanceNorm2d(c_out))
            layers.append(nn.LeakyReLU(0.2))
        self.encoder = nn.Sequential(*layers)
        # 1x1 convolution producing one score per spatial patch, squashed to [0, 1] by the sigmoid.
        self.shuffle = nn.Conv2d(
            (hidden_channels + cond_channels) if cond_channels > 0 else hidden_channels, 1, kernel_size=1
        )
        self.logits = nn.Sigmoid()

    def forward(self, x, cond=None):
        x = self.encoder(x)
        if cond is not None:
            # Broadcast the conditioning vector over the spatial dimensions and concatenate.
            cond = cond.view(
                cond.size(0),
                cond.size(1),
                1,
                1,
            ).expand(-1, -1, x.size(-2), x.size(-1))
            x = torch.cat([x, cond], dim=1)
        x = self.shuffle(x)
        x = self.logits(x)
        return x
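
A rough usage sketch (not part of the diff), run from `examples/vqgan/`: the discriminator maps an image, plus an optional conditioning vector, to a per-patch realness map in [0, 1].

```python
import torch

from discriminator import Discriminator

disc = Discriminator(in_channels=3, cond_channels=0, hidden_channels=512, depth=6)
images = torch.randn(2, 3, 256, 256)
realness = disc(images)
print(realness.shape)  # (2, 1, 4, 4): one "real vs. fake" score per spatial patch
```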
8 changes: 8 additions & 0 deletions examples/vqgan/requirements.txt
@@ -0,0 +1,8 @@
accelerate>=0.16.0
torchvision
transformers>=4.25.1
datasets
timm
numpy
tqdm
tensorboard