# Kandinsky 2.2

**Kandinsky 2.2** is a multilingual text-to-image model capable of producing more photorealistic images. The model includes an image prior model for creating image embeddings from text prompts, and a decoder model that generates images based on the prior model's embeddings.

Diffusers provides two separate training scripts for Kandinsky 2.2: one for the prior model and one for the decoder model.

We will explore the [`train_text_to_image_prior.py`](https://github.com/huggingface/diffusers/blob/main/examples/kandinsky2_2/text_to_image/train_text_to_image_prior.py) script to train the prior model and the [`train_text_to_image_decoder.py`](https://github.com/huggingface/diffusers/blob/main/examples/kandinsky2_2/text_to_image/train_text_to_image_decoder.py) script to train the decoder.

As always, make sure to install the `diffusers` library from source:
```bash
git clone https://github.com/huggingface/diffusers
cd diffusers
pip install .
```

We need to navigate to the following folder and install the corresponding dependencies:
```bash
cd examples/kandinsky2_2/text_to_image
pip install -r requirements.txt
```
HuggingFace Accelerate is also helpful to train on multiple GPUs with mixed-precision.
```bash
pip install accelerate
```

Now we can initialize a HuggingFace Accelerate environment:
```bash
accelerate config
```
To set up a default Accelerate environment without choosing any configurations:
```bash
accelerate config default
```
Or, if our environment does not support an interactive shell (a notebook, for example), we can use:
```python
from accelerate.utils import write_basic_config
write_basic_config()
```

## Script parameters

The training script provides many parameters to customize the training run. We can find all of the parameters and their descriptions in the `parse_args()` function.

To speed up training with mixed precision using the `fp16` format, add the `--mixed_precision` parameter to the training command:
```bash
accelerate launch train_text_to_image_prior.py --mixed_precision="fp16"
```
Most of the parameters are identical to the parameters in the **Text-to-image** training guide.

### Min-SNR weighting

The `Min-SNR` weighting strategy can help with training by rebalancing the loss to achieve faster convergence. The training script supports predicting `epsilon` (noise) or `v_prediction`, and Min-SNR is compatible with both prediction types.

We can add the `--snr_gamma` parameter and set it to the recommended value of 5.0:
```bash
accelerate launch train_text_to_image_prior.py --snr_gamma=5.0
```
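
To make the rebalancing concrete, here is a minimal sketch of the per-timestep weight, assuming the usual Min-SNR definition where the signal-to-noise ratio is `alpha_bar / (1 - alpha_bar)`. The function name `min_snr_weight` is hypothetical and is not taken from the training script, whose implementation may differ in detail.

```python
def min_snr_weight(alpha_bar, snr_gamma=5.0, prediction_type="epsilon"):
    """Loss weight for one timestep under Min-SNR weighting (illustrative helper).

    alpha_bar is the cumulative product of (1 - beta) up to the timestep.
    """
    snr = alpha_bar / (1.0 - alpha_bar)
    clipped = min(snr, snr_gamma)  # cap the SNR at gamma
    if prediction_type == "epsilon":
        return clipped / snr
    if prediction_type == "v_prediction":
        return clipped / (snr + 1.0)
    raise ValueError(f"Unsupported prediction type: {prediction_type}")


# At alpha_bar = 0.5 the SNR is 1, which is below gamma, so nothing is clipped.
print(min_snr_weight(0.5))                                  # 1.0
print(min_snr_weight(0.5, prediction_type="v_prediction"))  # 0.5
# At a nearly clean timestep the SNR is large, so the weight is pushed down.
print(min_snr_weight(0.99) < 0.1)                           # True
```

Low-noise timesteps, which have a high SNR, are down-weighted so that they do not dominate the loss.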

## Training script

The training scripts are similar to the one in the Text-to-image training guide, but they have been modified to support training the prior and decoder models.

##### prior model

The `main()` function contains the code for preparing the dataset and training the model.

For the Kandinsky 2.2 prior model, the training script loads a scheduler and a tokenizer, plus a `CLIPImageProcessor` for preprocessing images and a `CLIPVisionModelWithProjection` model for encoding the images:
```python
    # Load scheduler, image_processor, tokenizer and models.
    noise_scheduler = DDPMScheduler(beta_schedule="squaredcos_cap_v2", prediction_type="sample")
    image_processor = CLIPImageProcessor.from_pretrained(
        args.pretrained_prior_model_name_or_path, subfolder="image_processor"
    )
    tokenizer = CLIPTokenizer.from_pretrained(args.pretrained_prior_model_name_or_path, subfolder="tokenizer")

    def deepspeed_zero_init_disabled_context_manager():
        """
        returns either a context list that includes one that will disable zero.Init or an empty context list
        """
        deepspeed_plugin = AcceleratorState().deepspeed_plugin if accelerate.state.is_initialized() else None
        if deepspeed_plugin is None:
            return []

        return [deepspeed_plugin.zero3_init_context_manager(enable=False)]

    weight_dtype = torch.float32
    if accelerator.mixed_precision == "fp16":
        weight_dtype = torch.float16
    elif accelerator.mixed_precision == "bf16":
        weight_dtype = torch.bfloat16
    with ContextManagers(deepspeed_zero_init_disabled_context_manager()):
        image_encoder = CLIPVisionModelWithProjection.from_pretrained(
            args.pretrained_prior_model_name_or_path, subfolder="image_encoder", torch_dtype=weight_dtype
        ).eval()
        text_encoder = CLIPTextModelWithProjection.from_pretrained(
            args.pretrained_prior_model_name_or_path, subfolder="text_encoder", torch_dtype=weight_dtype
        ).eval()
```
Kandinsky uses a `PriorTransformer` to generate the image embeddings, so we want to set up the optimizer to learn the prior model's parameters:
```python
    prior = PriorTransformer.from_pretrained(args.pretrained_prior_model_name_or_path, subfolder="prior")

    # Freeze text_encoder and image_encoder
    text_encoder.requires_grad_(False)
    image_encoder.requires_grad_(False)

    # Set prior to trainable.
    prior.train()
    # ...
    if args.use_8bit_adam:
        try:
            import bitsandbytes as bnb
        except ImportError:
            raise ImportError(
                "Please install bitsandbytes to use 8-bit Adam. You can do so by running `pip install bitsandbytes`"
            )

        optimizer_cls = bnb.optim.AdamW8bit
    else:
        optimizer_cls = torch.optim.AdamW
    optimizer = optimizer_cls(
        prior.parameters(),
        lr=args.learning_rate,
        betas=(args.adam_beta1, args.adam_beta2),
        weight_decay=args.adam_weight_decay,
        eps=args.adam_epsilon,
    )
```
Next, the input captions are tokenized, and images are preprocessed by the `CLIPImageProcessor`:
```python
    # Preprocessing the datasets.
    # We need to tokenize input captions and transform the images.
    def tokenize_captions(examples, is_train=True):
        captions = []
        for caption in examples[caption_column]:
            if isinstance(caption, str):
                captions.append(caption)
            elif isinstance(caption, (list, np.ndarray)):
                # take a random caption if there are multiple
                captions.append(random.choice(caption) if is_train else caption[0])
            else:
                raise ValueError(
                    f"Caption column `{caption_column}` should contain either strings or lists of strings."
                )
        inputs = tokenizer(
            captions, max_length=tokenizer.model_max_length, padding="max_length", truncation=True, return_tensors="pt"
        )
        text_input_ids = inputs.input_ids
        text_mask = inputs.attention_mask.bool()
        return text_input_ids, text_mask

    def preprocess_train(examples):
        images = [image.convert("RGB") for image in examples[image_column]]
        examples["clip_pixel_values"] = image_processor(images, return_tensors="pt").pixel_values
        examples["text_input_ids"], examples["text_mask"] = tokenize_captions(examples)
        return examples
```
Finally, the training loop encodes the input images into image embeddings with the image encoder, adds noise to those embeddings, and has the prior make a prediction from the noisy embeddings.
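
The noising step can be sketched in plain Python. This is a simplified illustration, assuming the standard squared-cosine schedule that `beta_schedule="squaredcos_cap_v2"` refers to and the standard DDPM forward process. The helper names are hypothetical; the script itself relies on `DDPMScheduler` rather than code like this.

```python
import math


def cosine_alpha_bar(t):
    """Signal level at continuous time t in [0, 1] for the squared-cosine schedule."""
    return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2


def alphas_cumprod(num_train_timesteps=1000, max_beta=0.999):
    """Discretize the schedule into betas and accumulate prod(1 - beta)."""
    values, running = [], 1.0
    for i in range(num_train_timesteps):
        t1 = i / num_train_timesteps
        t2 = (i + 1) / num_train_timesteps
        beta = min(1.0 - cosine_alpha_bar(t2) / cosine_alpha_bar(t1), max_beta)
        running *= 1.0 - beta
        values.append(running)
    return values


def add_noise(clean, noise, alpha_bar):
    """Forward diffusion: sqrt(alpha_bar) * clean + sqrt(1 - alpha_bar) * noise."""
    signal_scale = math.sqrt(alpha_bar)
    noise_scale = math.sqrt(1.0 - alpha_bar)
    return [signal_scale * x + noise_scale * n for x, n in zip(clean, noise)]


abar = alphas_cumprod()
clean_embedding = [0.5, -1.0, 2.0]  # toy stand-in for an image embedding
noise = [0.1, 0.3, -0.2]            # fixed values for illustration, normally Gaussian
noisy_embedding = add_noise(clean_embedding, noise, abar[500])
```

With `prediction_type="sample"`, as set on the scheduler above, the regression target is the clean embedding rather than the noise.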

##### decoder model

As with the prior model, the `main()` function contains the code for preparing the dataset and training the model.

Unlike the prior model, the decoder initializes a `VQModel` to decode the latents into images and it uses a `UNet2DConditionModel`:
```python
    def deepspeed_zero_init_disabled_context_manager():
        """
        returns either a context list that includes one that will disable zero.Init or an empty context list
        """
        deepspeed_plugin = AcceleratorState().deepspeed_plugin if accelerate.state.is_initialized() else None
        if deepspeed_plugin is None:
            return []

        return [deepspeed_plugin.zero3_init_context_manager(enable=False)]

    weight_dtype = torch.float32
    if accelerator.mixed_precision == "fp16":
        weight_dtype = torch.float16
    elif accelerator.mixed_precision == "bf16":
        weight_dtype = torch.bfloat16
    with ContextManagers(deepspeed_zero_init_disabled_context_manager()):
        vae = VQModel.from_pretrained(
            args.pretrained_decoder_model_name_or_path, subfolder="movq", torch_dtype=weight_dtype
        ).eval()
        image_encoder = CLIPVisionModelWithProjection.from_pretrained(
            args.pretrained_prior_model_name_or_path, subfolder="image_encoder", torch_dtype=weight_dtype
        ).eval()
    unet = UNet2DConditionModel.from_pretrained(args.pretrained_decoder_model_name_or_path, subfolder="unet")
```

Next, the script includes several image transforms and a `preprocess_train` function that applies the transforms to the images and returns the pixel values:
```python
    # Preprocessing the datasets.
    # We need to tokenize inputs and targets.
    column_names = dataset["train"].column_names

    image_column = args.image_column
    if image_column not in column_names:
        raise ValueError(f"'--image_column' value '{args.image_column}' needs to be one of: {', '.join(column_names)}")

    def center_crop(image):
        width, height = image.size
        new_size = min(width, height)
        left = (width - new_size) / 2
        top = (height - new_size) / 2
        right = (width + new_size) / 2
        bottom = (height + new_size) / 2
        return image.crop((left, top, right, bottom))

    def train_transforms(img):
        img = center_crop(img)
        img = img.resize((args.resolution, args.resolution), resample=Image.BICUBIC, reducing_gap=1)
        img = np.array(img).astype(np.float32) / 127.5 - 1
        img = torch.from_numpy(np.transpose(img, [2, 0, 1]))
        return img

    def preprocess_train(examples):
        images = [image.convert("RGB") for image in examples[image_column]]
        examples["pixel_values"] = [train_transforms(image) for image in images]
        examples["clip_pixel_values"] = image_processor(images, return_tensors="pt").pixel_values
        return examples
```
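
The crop box and the pixel scaling can be checked without loading an image. The following is a small sketch that mirrors the arithmetic in `center_crop` and `train_transforms`; `center_crop_box` and `scale_pixel` are hypothetical helpers written for this illustration, not functions from the script.

```python
def center_crop_box(width, height):
    """Return the (left, top, right, bottom) box of the largest centered square."""
    new_size = min(width, height)
    left = (width - new_size) / 2
    top = (height - new_size) / 2
    right = (width + new_size) / 2
    bottom = (height + new_size) / 2
    return (left, top, right, bottom)


def scale_pixel(value):
    """Map an 8-bit pixel value in [0, 255] to the range [-1, 1]."""
    return value / 127.5 - 1


print(center_crop_box(1024, 768))         # (128.0, 0.0, 896.0, 768.0)
print(scale_pixel(0), scale_pixel(255))   # -1.0 1.0
```

A 1024x768 image therefore loses 128 pixels on each side before being resized to `args.resolution`.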
Lastly, the training loop handles converting the images to latents, adding noise, and predicting the noise residual.

## Launch the script

As in the Text-to-image training guide, we will train our custom model on the Naruto BLIP captions dataset to generate our own Naruto characters.

##### prior model

```bash
export DATASET_NAME="lambdalabs/naruto-blip-captions"

accelerate launch --mixed_precision="fp16" train_text_to_image_prior.py \
  --dataset_name=$DATASET_NAME \
  --resolution=768 \
  --train_batch_size=1 \
  --gradient_accumulation_steps=4 \
  --max_train_steps=15000 \
  --learning_rate=1e-05 \
  --max_grad_norm=1 \
  --checkpoints_total_limit=3 \
  --lr_scheduler="constant" \
  --lr_warmup_steps=0 \
  --validation_prompts="A robot naruto, 4k photo" \
  --report_to="wandb" \
  --push_to_hub \
  --output_dir="kandi2-prior-naruto-model"
```
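
As a quick sanity check on these flags, the effective batch size can be worked out by hand. This sketch assumes a single process (one GPU), which is not stated in the command above, and assumes that `--max_train_steps` counts optimizer updates.

```python
train_batch_size = 1             # --train_batch_size
gradient_accumulation_steps = 4  # --gradient_accumulation_steps
max_train_steps = 15000          # --max_train_steps
num_processes = 1                # assumption: single-GPU run

# Each optimizer update sees this many samples.
effective_batch_size = train_batch_size * gradient_accumulation_steps * num_processes
# Total samples processed over the whole run.
samples_seen = effective_batch_size * max_train_steps

print(effective_batch_size)  # 4
print(samples_seen)          # 60000
```

Raising `num_processes` scales both numbers proportionally, so the learning rate may need revisiting on multi-GPU runs.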

Once training is finished, we can use the newly trained prior for inference:

```python
from diffusers import AutoPipelineForText2Image, DiffusionPipeline
import torch

# Load the fine-tuned prior pipeline.
prior_pipeline = DiffusionPipeline.from_pretrained(
    'path/to/our/custom/model',
    torch_dtype=torch.float16,
)
# Prefix every component name with "prior_" so the combined pipeline accepts them.
prior_components = {
    'prior_' + k: v for k, v in prior_pipeline.components.items()
}

pipeline = AutoPipelineForText2Image.from_pretrained(
    'kandinsky-community/kandinsky-2-2-decoder',
    **prior_components,
    torch_dtype=torch.float16
)
pipeline.enable_model_cpu_offload()

prompt = 'a robot naruto, 4k photo'

image = pipeline(prompt).images[0]
image
```

##### decoder model

```bash
export DATASET_NAME="lambdalabs/naruto-blip-captions"

accelerate launch --mixed_precision="fp16"  train_text_to_image_decoder.py \
  --dataset_name=$DATASET_NAME \
  --resolution=768 \
  --train_batch_size=1 \
  --gradient_accumulation_steps=4 \
  --gradient_checkpointing \
  --max_train_steps=15000 \
  --learning_rate=1e-05 \
  --max_grad_norm=1 \
  --checkpoints_total_limit=3 \
  --lr_scheduler="constant" \
  --lr_warmup_steps=0 \
  --validation_prompts="A robot naruto, 4k photo" \
  --report_to="wandb" \
  --push_to_hub \
  --output_dir="kandi2-decoder-naruto-model"
```

Once training is finished, we can use the newly trained decoder for inference:

```python
from diffusers import AutoPipelineForText2Image
import torch

pipeline = AutoPipelineForText2Image.from_pretrained(
    'path/to/our/custom/model',
    torch_dtype=torch.float16,
)
pipeline.enable_model_cpu_offload()
```

For the decoder model, we can also run inference from a saved checkpoint by loading it into the UNet, which is useful for viewing intermediate results:

```python
from diffusers import AutoPipelineForText2Image, UNet2DConditionModel
import torch

# Load the UNet weights from an intermediate checkpoint folder.
unet = UNet2DConditionModel.from_pretrained(
    'path/to/our/custom/model' + '/checkpoint-<N>/unet'
)

pipeline = AutoPipelineForText2Image.from_pretrained(
    'kandinsky-community/kandinsky-2-2-decoder',
    unet=unet,
    torch_dtype=torch.float16
)
pipeline.enable_model_cpu_offload()

image = pipeline(prompt="A robot naruto, 4k photo").images[0]
```
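
To pick a value for `<N>`, one option is to scan the output directory for the highest-numbered checkpoint folder. The helper below is a hypothetical sketch that only assumes the `checkpoint-<N>` folder naming shown in the path above; it is not part of Diffusers.

```python
import os
import re


def latest_checkpoint(output_dir):
    """Return the path of the highest-numbered checkpoint-<N> folder, or None."""
    pattern = re.compile(r"^checkpoint-(\d+)$")
    best_step, best_path = -1, None
    for name in os.listdir(output_dir):
        match = pattern.match(name)
        path = os.path.join(output_dir, name)
        # Compare step numbers as integers so that 1000 sorts after 500.
        if match and os.path.isdir(path) and int(match.group(1)) > best_step:
            best_step, best_path = int(match.group(1)), path
    return best_path
```

The returned path can then be joined with `unet` and passed to `UNet2DConditionModel.from_pretrained`.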