# InstructPix2Pix

**InstructPix2Pix** is a Stable Diffusion model trained to edit images from human-provided instructions. This model is conditioned on the text prompt (or editing instruction) and the input image.

We will explore the [`train_instruct_pix2pix.py`](https://github.com/huggingface/diffusers/blob/main/examples/instruct_pix2pix/train_instruct_pix2pix.py) training script.

As always, make sure to install the `diffusers` library from source:
```bash
git clone https://github.com/huggingface/diffusers
cd diffusers
pip install .
```

We need to navigate to the following folder and install the corresponding dependencies:
```bash
cd examples/instruct_pix2pix
pip install -r requirements.txt
```
Hugging Face Accelerate is also helpful for training on multiple GPUs with mixed precision.
```bash
pip install accelerate
```

Now we can initialize a Hugging Face Accelerate environment:
```bash
accelerate config
```
To set up a default Accelerate environment without choosing any configurations:
```bash
accelerate config default
```
Or, if our environment does not support an interactive shell (a notebook, for example), we can use:
```python
from accelerate.utils import write_basic_config
write_basic_config()
```

## Script parameters

The training script provides many parameters to customize the training run. We can find all of the parameters and their descriptions in the `parse_args()` function.

To speed up training with mixed precision using the `fp16` format, add the `--mixed_precision` parameter to the training command:
```bash
accelerate launch train_instruct_pix2pix.py --mixed_precision="fp16"
```
Most of the parameters are identical to those in the **Text-to-image** training guide. The parameters specific to InstructPix2Pix are:
* `--original_image_column`: the original image before the edits are made
* `--edited_image_column`: the image after the edits are made
* `--edit_prompt_column`: the instructions to edit the image
* `--conditioning_dropout_prob`: the probability of dropping the original image and the edit prompt during training, which enables classifier-free guidance (CFG) for one or both conditioning inputs

## Training script

The general training script is covered in the **Text-to-image** training guide. This guide focuses on the parts that are specific to InstructPix2Pix.

The dataset preprocessing code and training loop are found in the `main()` function.

The script begins by modifying the number of input channels in the first convolutional layer of the UNet to account for InstructPix2Pix's additional conditioning image:
```python
    # InstructPix2Pix uses an additional image for conditioning. To accommodate that,
    # it uses 8 channels (instead of 4) in the first (conv) layer of the UNet. This UNet is
    # then fine-tuned on the custom InstructPix2Pix dataset. This modified UNet is initialized
    # from the pre-trained checkpoints. For the extra channels added to the first layer, they are
    # initialized to zero.
    logger.info("Initializing the InstructPix2Pix UNet from the pretrained UNet.")
    in_channels = 8
    out_channels = unet.conv_in.out_channels
    unet.register_to_config(in_channels=in_channels)

    with torch.no_grad():
        new_conv_in = nn.Conv2d(
            in_channels, out_channels, unet.conv_in.kernel_size, unet.conv_in.stride, unet.conv_in.padding
        )
        new_conv_in.weight.zero_()
        new_conv_in.weight[:, :4, :, :].copy_(unet.conv_in.weight)
        unet.conv_in = new_conv_in
```
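
To see why zero-initializing the extra channels is a safe starting point, here is a minimal pure-Python sketch (not the script's code). It treats a 1x1 convolution as a per-pixel weighted sum, reuses the same bias for both layers, and the helper names `conv1x1` and `expand_in_channels` are hypothetical. With zero weights on the new inputs, the expanded layer returns exactly what the pretrained layer returns for the first 4 channels, whatever the conditioning channels contain.

```python
def conv1x1(weight, bias, pixel):
    """Apply a 1x1 convolution to one pixel: out[o] = bias[o] + sum_i weight[o][i] * pixel[i]."""
    return [b + sum(w * x for w, x in zip(row, pixel)) for row, b in zip(weight, bias)]


def expand_in_channels(weight, new_in_channels):
    """Pad every output row with zeros so the layer accepts `new_in_channels` inputs."""
    return [row + [0.0] * (new_in_channels - len(row)) for row in weight]


# Hypothetical "pretrained" layer: 4 input channels, 2 output channels.
pretrained_weight = [[0.5, -1.0, 0.25, 2.0], [1.5, 0.0, -0.5, 1.0]]
bias = [0.1, -0.2]

expanded_weight = expand_in_channels(pretrained_weight, 8)

noisy_latent = [1.0, 2.0, 3.0, 4.0]     # the 4 channels the pretrained layer already expects
image_latent = [9.0, -7.0, 5.0, -3.0]   # the 4 extra conditioning channels

before = conv1x1(pretrained_weight, bias, noisy_latent)
after = conv1x1(expanded_weight, bias, noisy_latent + image_latent)
print(before == after)
```

Fine-tuning then moves the zeroed weights away from zero, so the model gradually learns to use the conditioning image.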

These UNet parameters are updated by the optimizer:
```python
    # Initialize the optimizer
    if args.use_8bit_adam:
        try:
            import bitsandbytes as bnb
        except ImportError:
            raise ImportError(
                "Please install bitsandbytes to use 8-bit Adam. You can do so by running `pip install bitsandbytes`"
            )

        optimizer_cls = bnb.optim.AdamW8bit
    else:
        optimizer_cls = torch.optim.AdamW

    optimizer = optimizer_cls(
        unet.parameters(),
        lr=args.learning_rate,
        betas=(args.adam_beta1, args.adam_beta2),
        weight_decay=args.adam_weight_decay,
        eps=args.adam_epsilon,
    )
```
Next, the original and edited images are preprocessed and the edit instructions are tokenized. The same image transformations are applied to the original and edited images.
```python
    # Preprocessing the datasets.
    # We need to tokenize input captions and transform the images.
    def tokenize_captions(captions):
        inputs = tokenizer(
            captions, max_length=tokenizer.model_max_length, padding="max_length", truncation=True, return_tensors="pt"
        )
        return inputs.input_ids

    # Preprocessing the datasets.
    train_transforms = transforms.Compose(
        [
            transforms.CenterCrop(args.resolution) if args.center_crop else transforms.RandomCrop(args.resolution),
            transforms.RandomHorizontalFlip() if args.random_flip else transforms.Lambda(lambda x: x),
        ]
    )

    def preprocess_images(examples):
        original_images = np.concatenate(
            [convert_to_np(image, args.resolution) for image in examples[original_image_column]]
        )
        edited_images = np.concatenate(
            [convert_to_np(image, args.resolution) for image in examples[edited_image_column]]
        )
        # We need to ensure that the original and the edited images undergo the same
        # augmentation transforms.
        images = np.concatenate([original_images, edited_images])
        images = torch.tensor(images)
        images = 2 * (images / 255) - 1
        return train_transforms(images)

    def preprocess_train(examples):
        # Preprocess images.
        preprocessed_images = preprocess_images(examples)
        # Since the original and edited images were concatenated before
        # applying the transformations, we need to separate them and reshape
        # them accordingly.
        original_images, edited_images = preprocessed_images.chunk(2)
        original_images = original_images.reshape(-1, 3, args.resolution, args.resolution)
        edited_images = edited_images.reshape(-1, 3, args.resolution, args.resolution)

        # Collate the preprocessed images into the `examples`.
        examples["original_pixel_values"] = original_images
        examples["edited_pixel_values"] = edited_images

        # Preprocess the captions.
        captions = list(examples[edit_prompt_column])
        examples["input_ids"] = tokenize_captions(captions)
        return examples
```
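
Stacking the two images before calling `train_transforms` matters because `RandomCrop` and `RandomHorizontalFlip` draw their random parameters once per call, so both images receive the same crop and flip. A rough pure-Python sketch of the same idea follows. It uses nested lists in place of tensors, and `paired_random_crop` is a hypothetical helper, not a function from the script.

```python
import random


def paired_random_crop(original, edited, size, rng):
    """Crop both images (lists of rows) with one shared random offset."""
    height, width = len(original), len(original[0])
    top = rng.randint(0, height - size)
    left = rng.randint(0, width - size)

    def crop(image):
        return [row[left:left + size] for row in image[top:top + size]]

    return crop(original), crop(edited)


# Toy 4x4 "images": the edited image is the original plus 100 at every pixel.
original = [[r * 4 + c for c in range(4)] for r in range(4)]
edited = [[value + 100 for value in row] for row in original]

rng = random.Random(0)
original_crop, edited_crop = paired_random_crop(original, edited, size=2, rng=rng)

# Because the offset is shared, the pixel-wise relationship is preserved.
aligned = all(
    e == o + 100
    for o_row, e_row in zip(original_crop, edited_crop)
    for o, e in zip(o_row, e_row)
)
print(aligned)
```

If each image were cropped with its own random offset, the pair would no longer show the same region and the model would be trained on misaligned before/after examples.
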
Finally, in the training loop, the script encodes the edited images into latent space and adds noise to them, then applies dropout to the original image and edit instruction embeddings to support CFG. This dropout is what lets the model modulate the influence of the edit instruction and the original image on the edited image at inference time.
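
The InstructPix2Pix paper describes dropping only the image conditioning for 5% of examples, only the text conditioning for 5%, and both for 5%. One way to get that split from a single uniform random number per example is sketched below. This is an illustration in plain Python rather than the script's tensor code, and `conditioning_dropout` is a hypothetical helper name.

```python
import random


def conditioning_dropout(r, p):
    """Decide which conditionings to drop for one example.

    r: a uniform random number in [0, 1); p: the conditioning dropout probability.
    Returns (drop_prompt, drop_image).
    """
    drop_prompt = r < 2 * p        # [0, 2p): replace the prompt with the empty prompt
    drop_image = p <= r < 3 * p    # [p, 3p): zero out the original-image latents
    return drop_prompt, drop_image


p = 0.05
rng = random.Random(0)
counts = {"prompt only": 0, "both": 0, "image only": 0, "neither": 0}
for _ in range(100_000):
    drop_prompt, drop_image = conditioning_dropout(rng.random(), p)
    if drop_prompt and drop_image:
        counts["both"] += 1
    elif drop_prompt:
        counts["prompt only"] += 1
    elif drop_image:
        counts["image only"] += 1
    else:
        counts["neither"] += 1

print(counts)  # roughly 5% / 5% / 5% / 85% of the examples
```

The two intervals overlap on `[p, 2p)`, which is what produces the "both dropped" case without a second random draw.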

## Launch the script

We will use the `fusing/instructpix2pix-1000-samples` dataset to train our custom InstructPix2Pix model.

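Set the `MODEL_NAME` environment variable to the Stable Diffusion checkpoint to fine-tune and `DATASET_ID` to the dataset name, then run the script to start training: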
```bash
accelerate launch --mixed_precision="fp16" train_instruct_pix2pix.py \
    --pretrained_model_name_or_path=$MODEL_NAME \
    --dataset_name=$DATASET_ID \
    --enable_xformers_memory_efficient_attention \
    --resolution=256 \
    --random_flip \
    --train_batch_size=4 \
    --gradient_accumulation_steps=4 \
    --gradient_checkpointing \
    --max_train_steps=15000 \
    --checkpointing_steps=5000 \
    --checkpoints_total_limit=1 \
    --learning_rate=5e-05 \
    --max_grad_norm=1 \
    --lr_warmup_steps=0 \
    --conditioning_dropout_prob=0.05 \
    --mixed_precision=fp16 \
    --seed=42 \
    --push_to_hub
```
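
A quick back-of-the-envelope check of what these flags imply can help when adapting the command. The sketch below assumes a single GPU (`num_processes = 1`) and a dataset of 1,000 examples, a size inferred only from the dataset name. It also assumes `--max_train_steps` counts optimizer updates rather than individual forward passes.

```python
train_batch_size = 4               # --train_batch_size
gradient_accumulation_steps = 4    # --gradient_accumulation_steps
max_train_steps = 15_000           # --max_train_steps (assumed to count optimizer updates)
num_processes = 1                  # assumption: one GPU
dataset_size = 1_000               # assumption: inferred from the dataset name

effective_batch_size = train_batch_size * gradient_accumulation_steps * num_processes
samples_seen = effective_batch_size * max_train_steps
approx_epochs = samples_seen / dataset_size

print(effective_batch_size, samples_seen, approx_epochs)
```

Under these assumptions each optimizer update sees 16 examples, so a small dataset is revisited many times over the run.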

After training is complete, we can use our new InstructPix2Pix model for inference:

```python
import torch
from diffusers import StableDiffusionInstructPix2PixPipeline
from diffusers.utils import load_image

pipeline = StableDiffusionInstructPix2PixPipeline.from_pretrained(
    "path/to/our/custom/model",
    torch_dtype=torch.float16
).to("cuda")

image = load_image("https://huggingface.co/datasets/sayakpaul/sample-datasets/resolve/main/test_pix2pix_4.png")
prompt = "add some ducks to the lake"
num_inference_steps = 20
image_guidance_scale = 1.5
guidance_scale = 10
generator = torch.Generator("cuda").manual_seed(111)

edited_image = pipeline(
    prompt,
    image=image,
    num_inference_steps=num_inference_steps,
    image_guidance_scale=image_guidance_scale,
    guidance_scale=guidance_scale,
    generator=generator,
).images[0]
edited_image
```
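
The two scales correspond to the two-way classifier-free guidance described in the InstructPix2Pix paper, where the noise estimate is extrapolated from three predictions: fully unconditional, image only, and image plus text. A minimal sketch of that combination on plain Python lists is shown here. The function name `combine_guidance` is hypothetical, and this is not the pipeline's internal code.

```python
def combine_guidance(eps_uncond, eps_image, eps_full, image_guidance_scale, guidance_scale):
    """Blend three noise predictions element-wise.

    eps_uncond: prediction with neither image nor text conditioning
    eps_image:  prediction with the image conditioning only
    eps_full:   prediction with both image and text conditioning
    """
    return [
        u + image_guidance_scale * (i - u) + guidance_scale * (f - i)
        for u, i, f in zip(eps_uncond, eps_image, eps_full)
    ]


# Toy values standing in for noise predictions.
eps_uncond = [0.0, 1.0]
eps_image = [1.0, 1.0]
eps_full = [2.0, 3.0]

# With both scales at 1, the result is just the fully conditioned prediction.
neutral = combine_guidance(eps_uncond, eps_image, eps_full, 1.0, 1.0)
# Larger scales push further in the image and text directions.
guided = combine_guidance(eps_uncond, eps_image, eps_full, 1.5, 10)
print(neutral, guided)
```

In practice, raising `image_guidance_scale` keeps the result closer to the input image, while raising `guidance_scale` makes the result follow the edit instruction more strongly.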

## Stable Diffusion XL

SDXL adds a second text encoder to its architecture. We need to use the [`train_instruct_pix2pix_sdxl.py`](https://github.com/huggingface/diffusers/blob/main/examples/instruct_pix2pix/train_instruct_pix2pix_sdxl.py) script to train an SDXL model to follow image editing instructions.