# Wuerstchen

The **Wuerstchen** model reduces computational costs by compressing the latent space by 42x, which accelerates inference without compromising image quality.

During training, Wuerstchen uses two models (VQGAN + autoencoder) to compress the latents, and then a third model (text-conditioned latent diffusion model) is conditioned on this highly compressed space to generate an image.
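
To get a feel for what a 42x compression means, the following minimal sketch computes the side length of the compressed latent grid for a given image size. It assumes the 42x factor applies to each spatial side and that sizes are rounded up; `compressed_side` is a hypothetical helper for illustration, not part of `diffusers`.

```python
import math


def compressed_side(pixels: int, factor: float = 42.0) -> int:
    """Side length of the latent grid, assuming a per-side compression factor (rounded up)."""
    return math.ceil(pixels / factor)


for size in (768, 1024):
    side = compressed_side(size)
    # Ratio of pixel positions to latent positions under this assumption.
    ratio = (size * size) / (side * side)
    print(f"{size}x{size} image -> {side}x{side} latent grid (~{ratio:.0f}x fewer positions)")
```

Under these assumptions, the diffusion model only has to process a few hundred latent positions per image, which is where the savings in compute come from.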

We will explore the [`train_text_to_image_prior.py`](https://github.com/huggingface/diffusers/tree/wuerschten-tests/examples/wuerstchen/text_to_image) script to train our custom Wuerstchen model.

As always, make sure to install the `diffusers` library from source:
```bash
git clone https://github.com/huggingface/diffusers
cd diffusers
git checkout wuerschten-tests
pip install .
```

We need to navigate to the following folder and install the corresponding dependencies:
```bash
cd examples/wuerstchen/text_to_image
pip install -r requirements.txt
```
Hugging Face Accelerate is also useful for training on multiple GPUs and with mixed precision:
```bash
pip install accelerate
```

Now we can initialize a Hugging Face Accelerate environment:
```bash
accelerate config
```
To set up a default Accelerate environment without choosing any configurations:
```bash
accelerate config default
```
Or, if our environment does not support an interactive shell (a notebook, for example), we can use:
```python
from accelerate.utils import write_basic_config
write_basic_config()
```

## Script parameters

The training script provides many parameters to customize the training run. We can find all of the parameters and their descriptions in the `parse_args()` function.

To speed up training with mixed precision using the `fp16` format, add the `--mixed_precision` parameter to the training command:
```bash
accelerate launch train_text_to_image_prior.py --mixed_precision="fp16"
```
Most of the parameters are identical to the parameters in the **Text-to-image** training guide.
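
As a rough illustration of how such parameters are declared and read, here is a simplified, hypothetical `parse_args()`. The flag names are ones used in this guide, but the default values and the reduced set of options are placeholders and do not mirror the real script.

```python
import argparse


def parse_args(argv=None):
    """Illustrative subset of training arguments; defaults are placeholders."""
    parser = argparse.ArgumentParser(description="Sketch of a training-script argument parser.")
    parser.add_argument("--mixed_precision", type=str, default=None, choices=["no", "fp16", "bf16"])
    parser.add_argument("--resolution", type=int, default=512)
    parser.add_argument("--train_batch_size", type=int, default=1)
    parser.add_argument("--gradient_accumulation_steps", type=int, default=1)
    parser.add_argument("--learning_rate", type=float, default=1e-4)
    # Passing argv explicitly makes the parser easy to try without a command line.
    return parser.parse_args(argv)


args = parse_args(["--mixed_precision", "fp16", "--resolution", "768"])
print(args.mixed_precision, args.resolution, args.train_batch_size)
```

Any flag that is not passed falls back to its default, which is why the launch command only needs to list the values we want to override.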

## Training script

The training script is similar to the Text-to-image training guide, but it has been modified to support Wuerstchen.

The `main()` function starts by initializing the image encoder, an **EfficientNet**, in addition to the usual scheduler and tokenizer:
```python
    def deepspeed_zero_init_disabled_context_manager():
        """
        returns either a context list that includes one that will disable zero.Init or an empty context list
        """
        deepspeed_plugin = AcceleratorState().deepspeed_plugin if is_initialized() else None
        if deepspeed_plugin is None:
            return []

        return [deepspeed_plugin.zero3_init_context_manager(enable=False)]

    weight_dtype = torch.float32
    if accelerator.mixed_precision == "fp16":
        weight_dtype = torch.float16
    elif accelerator.mixed_precision == "bf16":
        weight_dtype = torch.bfloat16
    with ContextManagers(deepspeed_zero_init_disabled_context_manager()):
        pretrained_checkpoint_file = hf_hub_download("dome272/wuerstchen", filename="model_v2_stage_b.pt")
        state_dict = torch.load(pretrained_checkpoint_file, map_location="cpu")
        image_encoder = EfficientNetEncoder()
        image_encoder.load_state_dict(state_dict["effnet_state_dict"])
        image_encoder.eval()

        text_encoder = CLIPTextModel.from_pretrained(
            args.pretrained_prior_model_name_or_path, subfolder="text_encoder", torch_dtype=weight_dtype
        ).eval()
```

We will also load the `WuerstchenPrior` model for optimization:
```python
    # load prior model
    prior = WuerstchenPrior.from_pretrained(args.pretrained_prior_model_name_or_path, subfolder="prior")
```
Then we apply some transforms to the images and tokenize the captions:
```python
    # Preprocessing the datasets.
    # We need to tokenize input captions and transform the images
    def tokenize_captions(examples, is_train=True):
        captions = []
        for caption in examples[caption_column]:
            if isinstance(caption, str):
                captions.append(caption)
            elif isinstance(caption, (list, np.ndarray)):
                # take a random caption if there are multiple
                captions.append(random.choice(caption) if is_train else caption[0])
            else:
                raise ValueError(
                    f"Caption column `{caption_column}` should contain either strings or lists of strings."
                )
        inputs = tokenizer(
            captions, max_length=tokenizer.model_max_length, padding="max_length", truncation=True, return_tensors="pt"
        )
        text_input_ids = inputs.input_ids
        text_mask = inputs.attention_mask.bool()
        return text_input_ids, text_mask

    effnet_transforms = transforms.Compose(
        [
            transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR, antialias=True),
            transforms.CenterCrop(args.resolution),
            transforms.ToTensor(),
            transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
        ]
    )

    def preprocess_train(examples):
        images = [image.convert("RGB") for image in examples[image_column]]
        examples["effnet_pixel_values"] = [effnet_transforms(image) for image in images]
        examples["text_input_ids"], examples["text_mask"] = tokenize_captions(examples)
        return examples
```
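
`transforms.Normalize` maps each channel value `x` to `(x - mean) / std`. The sketch below reproduces that arithmetic in plain Python for a single pixel, using the mean and standard deviation from `effnet_transforms`; `normalize_pixel` is a hypothetical helper for illustration only.

```python
# Per-channel statistics taken from the `effnet_transforms` definition.
MEAN = (0.485, 0.456, 0.406)
STD = (0.229, 0.224, 0.225)


def normalize_pixel(rgb, mean=MEAN, std=STD):
    """Apply (x - mean) / std to each channel of one RGB pixel in the [0, 1] range."""
    return tuple((value - m) / s for value, m, s in zip(rgb, mean, std))


white = normalize_pixel((1.0, 1.0, 1.0))
black = normalize_pixel((0.0, 0.0, 0.0))
print("white ->", tuple(round(v, 3) for v in white))
print("black ->", tuple(round(v, 3) for v in black))
```

After normalization, pixel values are roughly centered around zero, which is the input range the image encoder expects.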
Finally, the training loop handles compressing the images to latent space with the `EfficientNetEncoder`, adding noise to the latents, and predicting the noise residual with the `WuerstchenPrior` model.
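
The noising and loss steps can be sketched as follows. This is a minimal illustration that assumes a standard DDPM-style forward process and a mean squared error loss on the noise; the schedule and loss used by the actual script may differ. Plain Python lists stand in for tensors, and `add_noise` and `mse` are hypothetical helpers.

```python
import math
import random


def add_noise(latents, noise, alpha_cumprod):
    """DDPM-style forward process: sqrt(a) * x + sqrt(1 - a) * eps."""
    signal_scale = math.sqrt(alpha_cumprod)
    noise_scale = math.sqrt(1.0 - alpha_cumprod)
    return [signal_scale * x + noise_scale * e for x, e in zip(latents, noise)]


def mse(prediction, target):
    """Mean squared error between two equally sized lists."""
    return sum((p - t) ** 2 for p, t in zip(prediction, target)) / len(target)


random.seed(0)
latents = [random.gauss(0.0, 1.0) for _ in range(8)]  # stand-in for the encoded image latents
noise = [random.gauss(0.0, 1.0) for _ in range(8)]    # Gaussian noise with the same shape
noisy_latents = add_noise(latents, noise, alpha_cumprod=0.5)

# A real run would call the prior on `noisy_latents` here; this placeholder predicts zero noise.
predicted_noise = [0.0] * len(noise)
loss = mse(predicted_noise, noise)
print(f"loss for a zero prediction: {loss:.4f}")
```

The closer the predicted noise is to the noise that was actually added, the smaller the loss, which is the signal used to update the prior.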

## Launch the script

```bash
export DATASET_NAME="lambdalabs/naruto-blip-captions"

accelerate launch  train_text_to_image_prior.py \
  --mixed_precision="fp16" \
  --dataset_name=$DATASET_NAME \
  --resolution=768 \
  --train_batch_size=4 \
  --gradient_accumulation_steps=4 \
  --gradient_checkpointing \
  --dataloader_num_workers=4 \
  --max_train_steps=15000 \
  --learning_rate=1e-05 \
  --max_grad_norm=1 \
  --checkpoints_total_limit=3 \
  --lr_scheduler="constant" \
  --lr_warmup_steps=0 \
  --validation_prompts="A robot naruto, 4k photo" \
  --report_to="wandb" \
  --push_to_hub \
  --output_dir="wuerstchen-prior-naruto-model"
```
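
To reason about the batch-related flags in the command above, this small sketch computes the effective batch size per optimizer update. It assumes a single process (one GPU) and that `--max_train_steps` counts optimizer updates; adjust `num_processes` for a multi-GPU run.

```python
train_batch_size = 4             # --train_batch_size
gradient_accumulation_steps = 4  # --gradient_accumulation_steps
num_processes = 1                # assumption: single GPU
max_train_steps = 15000          # --max_train_steps

# Samples contributing to each optimizer update.
effective_batch_size = train_batch_size * gradient_accumulation_steps * num_processes
# Total samples processed over the run, under the assumptions stated above.
samples_seen = effective_batch_size * max_train_steps
print(f"effective batch size: {effective_batch_size}, samples seen: {samples_seen}")
```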

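Once training is complete, we can use the newly trained model for inference:

```python
import torch
from diffusers import AutoPipelineForText2Image
from diffusers.pipelines.wuerstchen import DEFAULT_STAGE_C_TIMESTEPS

# Load the trained model (replace the path with the output directory or Hub repository).
pipeline = AutoPipelineForText2Image.from_pretrained(
    "path/to/our/custom/wuerstchen/model",
    torch_dtype=torch.float16,
).to("cuda")  # float16 inference assumes a CUDA GPU is available

caption = "a cute bird naruto holding a shield"
images = pipeline(
    caption,
    width=1024,
    height=1536,
    prior_timesteps=DEFAULT_STAGE_C_TIMESTEPS,
    prior_guidance_scale=4.0,
    num_images_per_prompt=2,
).images

# Two images are generated per prompt; take the first one.
image = images[0]
```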