# Training Kandinsky 2.2 with LoRA
This notebook trains LoRA adapters for both the Kandinsky 2.2 prior and decoder, and can also run basic inference with them.



In [None]:
# @title Install Requirements
%%bash
# Clone only when missing, so re-running this cell does not fail on an
# existing checkout; the explicit target matches the pip path below.
[ -d /content/diffusers ] || git clone https://github.com/ai-forever/diffusers /content/diffusers
pip install /content/diffusers
pip install transformers accelerate bitsandbytes safetensors
# -y answers apt's confirmation prompt, which this non-interactive cell cannot;
# apt-get is the script-stable front end (plain `apt` warns in scripts).
apt-get install -y aria2

In [None]:
# @title Mount Drive (Optional)
# Needed only when training data lives on Google Drive or when outputs are
# written there (see `output_to_drive` in "Set Parameters").
from google.colab import drive

drive.mount(mountpoint='/content/drive')

In [None]:
# @title Locating Train Data Directory
# @markdown Define the location of your training data. This cell will also create a folder based on your input.
# @markdown Regularization images are not presently used.
import os
from IPython.utils import capture

%store -r

train_data_dir = "/content/LoRA/train_data"  # @param {type:'string'}
reg_data_dir = "/content/LoRA/reg_data"  # @param {type:'string'}

for dir in [train_data_dir, reg_data_dir]:
    if dir:
        with capture.capture_output() as cap:
            os.makedirs(dir, exist_ok=True)
            %store dir
            del cap

print(f"Your train data directory : {train_data_dir}")
if reg_data_dir:
    print(f"Your reg data directory : {reg_data_dir}")

In [None]:
# @title Extract Dataset

import os
import shutil
from pathlib import Path

# @markdown Use this section if your dataset is in a `zip` (or `tar`) file that has been uploaded somewhere, or is already under `/content`. This cell downloads the archive and extracts it to `unzip_to`, or to `train_data_dir` if `unzip_to` is empty.
root_dir = "/content"
zipfile_url = "" #@param {type:"string"}
zipfile_name = "zipfile.zip"
unzip_to = "" #@param {type:"string"}

# @markdown Hugging Face token, only needed for private/gated `huggingface.co` URLs. Leave empty to use the `HF_TOKEN` environment variable.
# SECURITY: never put a real token in the default value below; it is saved
# and shared with the notebook. Any token that was ever committed here must
# be revoked.
hf_token = "" #@param {type:"string"}
hf_token = hf_token or os.environ.get("HF_TOKEN", "")
# Pre-quoted so it can be spliced directly into the aria2c shell command.
user_header = f'"Authorization: Bearer {hf_token}"'

if unzip_to:
    os.makedirs(unzip_to, exist_ok=True)
else:
    unzip_to = train_data_dir


def download_dataset(url):
    if url.startswith("/content"):
        return url
    elif "drive.google.com" in url:
        os.chdir(root_dir)
        !gdown --fuzzy {url}
        return f"{root_dir}/{zipfile_name}"
    elif "huggingface.co" in url:
        if "/blob/" in url:
            url = url.replace("/blob/", "/resolve/")
        !aria2c --console-log-level=error --summary-interval=10 --header={user_header} -c -x 16 -k 1M -s 16 -d {root_dir} -o {zipfile_name} {url}
        return f"{root_dir}/{zipfile_name}"
    else:
        !aria2c --console-log-level=error --summary-interval=10 -c -x 16 -k 1M -s 16 -d {root_dir} -o {zipfile_name} {url}
        return f"{root_dir}/{zipfile_name}"


def extract_dataset(zip_file, output_path):
    if zip_file.startswith("/content"):
        if ".tar" in zip_file:
          !tar -xvf {zip_file} -C "{output_path}"
        else:
          !unzip -j -o {zip_file} -d "{output_path}"
    else:
        if ".tar" in zip_file:
          !tar -xvf "{zip_file}" -C "{output_path}"
        else:
          !unzip -j -o "{zip_file}" -d "{output_path}"


def remove_files(train_dir, files_to_move):
    """Delete the files named in `files_to_move` from `train_dir`.

    Only regular files directly inside `train_dir` are removed; names that
    are absent (or are directories) are skipped. The previous version had an
    unreachable "move" branch that referenced an undefined `training_dir`;
    its effective behaviour was always to delete, which is what this does.
    """
    for filename in files_to_move:
        file_path = os.path.join(train_dir, filename)
        if os.path.isfile(file_path):
            os.remove(file_path)


# Skip the download/extract step when no archive is configured, so
# "Run all" does not fail on the empty default URL.
if zipfile_url:
    zip_file = download_dataset(zipfile_url)
    extract_dataset(zip_file, unzip_to)
    # download_dataset returns local (/content...) paths unchanged. Only delete
    # archives this cell downloaded itself, never the user's own file, which
    # may live on a mounted Drive.
    if zip_file != zipfile_url:
        os.remove(zip_file)
else:
    print("zipfile_url is empty; skipping download and extraction.")

# Metadata files removed from the training folder after extraction.
files_to_move = (
    "meta_cap.json",
    "meta_cap_dd.json",
    "meta_lat.json",
    "meta_clean.json",
)

remove_files(train_data_dir, files_to_move)


In [None]:
# @title Create CSV from dataset. Assumes all images are the same file format.
extension = ".png" # @param {type:"string"}
caption_extension = ".txt" # @param {type:"string"}
csv_output_path = "/content/LoRA/dataset.csv"
import csv
import os
from pathlib import Path

# An empty extension would match every file and break the caption lookup.
if not extension:
  raise ValueError("`extension` must be set, e.g. '.png'.")

# /content/LoRA is only created by the earlier cell when the default data
# directories are used; make sure the CSV's folder exists either way.
os.makedirs(os.path.dirname(csv_output_path), exist_ok=True)

with open(csv_output_path, 'w', newline='', encoding='utf-8') as csvfile:
  csvwriter = csv.writer(csvfile)
  csvwriter.writerow(['paths','caption'])
  # sorted() keeps the row order stable across runs.
  for filename in sorted(os.listdir(train_data_dir)):
    if filename.endswith(extension):
      # Swap only the trailing extension: str.replace would also rewrite an
      # earlier match (e.g. "a.png.v2.png" -> "a.txt.v2.txt").
      caption_path = os.path.join(train_data_dir, filename[:-len(extension)] + caption_extension)
      with open(caption_path, 'r', encoding='utf-8') as captionfile:
        caption = captionfile.read().rstrip()
      print(caption)
      csvwriter.writerow([os.path.join(train_data_dir, filename), caption])

# Strip the newline after the last row. Re-reading in text mode also turns
# csv's default "\r\n" row endings into "\n" when the content is written back.
# NOTE(review): presumably the training script's CSV reader needs this; confirm.
with open(csv_output_path, "r+", encoding='utf-8') as fn:
  content = fn.read()
  content = content.rstrip('\n')
  fn.seek(0)
  fn.write(content)
  fn.truncate()


In [None]:
# @title Set Parameters
# @markdown The official example uses the same parameters for both prior and decoder, so I've replicated this here.<br>
# @markdown Rank is hardcoded to 4. This is intentional, you WILL get NaN errors on higher ranks.
output_to_drive = False # @param {type:'boolean'}
# Drive paths require the "Mount Drive" cell to have been run.
if output_to_drive:
  prior_output_dir='/content/drive/MyDrive/kandinsky_lora/prior'
  decoder_output_dir='/content/drive/MyDrive/kandinsky_lora/decoder'
else:
  prior_output_dir='/content/kandinsky_lora/prior'
  decoder_output_dir='/content/kandinsky_lora/decoder'
# Passed to accelerate via --mixed_precision; accelerate spells "disabled"
# as "no" (the former "none" option is not an accepted value).
mixed_precision = "fp16" # @param ["fp16","bf16","no"]
image_resolution = 768 # @param {type:"slider", min:512, max:1024, step:128}
train_batch_size = 1 # @param {type: "integer"}
max_train_steps = 2500 # @param {type: "integer"}
checkpointing_steps = 500 # @param {type: "integer"}
gradient_accumulation_steps = 1 # @param {type: "integer"}
lr = 1e-5 # @param {type: "number"}
lr_scheduler = "cosine" # @param ["constant", "constant_with_warmup", "cosine" ]
lr_warmup_steps = 0 # @param {type: "integer"}
dataloader_num_workers = 0 # @param {type: "integer"}
snr_gamma = 5.0 # @param {type: "number"}
weight_decay = 0.0 # @param {type: "number"}
train_seed = 0 # @param {type: "integer"}
# LoRA rank used by later cells that reference `rank`. Deliberately fixed;
# see the note above before changing it.
rank = 4

In [None]:
# @title Train Decoder Lora
command = (f'python3 /content/diffusers/examples/kandinsky2_2_train/tune_decoder_lora.py '
           #f'--train_image_folder={train_data_dir} '
           f'--train_images_paths_csv=/content/LoRA/dataset.csv '
           f'--image_resolution={image_resolution} '
           f'--train_batch_size={train_batch_size} '
           f'--gradient_accumulation_steps={gradient_accumulation_steps} '
           f'--gradient_checkpointing '
           f'--mixed_precision={mixed_precision}  '
           f'--max_train_steps={max_train_steps} '
           f'--lr={lr} '
           f'--max_grad_norm=1 '
           f'--lr_scheduler={lr_scheduler} '
           f'--lr_warmup_steps={lr_warmup_steps} '
           f'--output_dir={decoder_output_dir} '
           f'--rank=4 '
           f'--snr_gamma={snr_gamma} '
           f'--use_8bit_adam '
           f'--checkpointing_steps={checkpointing_steps} '
           f"--dataloader_num_workers={dataloader_num_workers} "
           f"--seed={train_seed} "
           f"--weight_decay={weight_decay} ")
!{command}

In [None]:
# @title Train Prior Lora
command = (f'python3 /content/diffusers/examples/kandinsky2_2_train/tune_prior_lora.py '
           #f'--train_image_folder={train_data_dir} '
           f'--train_images_paths_csv=/content/LoRA/dataset.csv '
           f'--train_batch_size={train_batch_size} '
           f'--gradient_accumulation_steps={gradient_accumulation_steps} '
           f'--gradient_checkpointing '
           f'--mixed_precision={mixed_precision}  '
           f'--max_train_steps={max_train_steps} '
           f'--lr={lr} '
           f'--max_grad_norm=1 '
           f'--lr_scheduler={lr_scheduler} '
           f'--lr_warmup_steps={lr_warmup_steps} '
           f'--output_dir={prior_output_dir} '
           f'--rank={rank} '
           f'--snr_gamma={snr_gamma} '
           f'--use_8bit_adam '
           f'--checkpointing_steps={checkpointing_steps} '
           f"--dataloader_num_workers={dataloader_num_workers} "
           f"--seed={train_seed} "
           f"--weight_decay={weight_decay} ")
!{command}

In [None]:
# @title Convert To Safetensors
from collections import defaultdict
import torch
from safetensors.torch import load_file, save_file

def shared_pointers(tensors):
    """Find groups of keys whose tensors share the same data pointer.

    Args:
        tensors: mapping of name -> tensor (anything with `.data_ptr()`).

    Returns:
        A list of key lists, one per data pointer used by more than one key.
        Keys keep their insertion order within each group.
    """
    keys_by_address = defaultdict(list)
    for key, tensor in tensors.items():
        keys_by_address[tensor.data_ptr()].append(key)
    return [keys for keys in keys_by_address.values() if len(keys) > 1]

# @markdown Input File
pt_filename = "" # @param {type: "string"}
# @markdown Output File (leave empty to write next to the input with a `.safetensors` extension)
sf_filename = "" # @param {type: "string"}
import os

if not pt_filename:
    raise ValueError("Set `pt_filename` to the LoRA weights file (.bin/.pt) to convert.")
if not sf_filename:
    sf_filename = os.path.splitext(pt_filename)[0] + ".safetensors"

# SECURITY: torch.load unpickles the file, which can execute arbitrary code.
# Only convert checkpoints you produced yourself or otherwise trust.
loaded = torch.load(pt_filename, map_location="cpu")
# Some checkpoint formats nest the weights under a "state_dict" key.
if "state_dict" in loaded:
    loaded = loaded["state_dict"]
# safetensors rejects tensors that share memory: keep the first key of each
# aliased group and drop the others.
shared = shared_pointers(loaded)
for shared_weights in shared:
    for name in shared_weights[1:]:
        loaded.pop(name)

# safetensors also requires contiguous tensors.
loaded = {k: v.contiguous() for k, v in loaded.items()}
save_file(loaded, sf_filename, metadata={"format": "pt"})
print(f"Saved {len(loaded)} tensors to {sf_filename}")

## Inference
This inference code has very few optimizations applied, so it is mainly useful for basic testing of a trained LoRA.
For better inference, use Kubin.

In [None]:
# @title Inference Settings

# @markdown ###LoRA settings
# Each use_*_lora flag controls whether "Load Pipelines" attaches that LoRA;
# the matching *_lora_path is the file it passes to torch.load.
use_decoder_lora = True # @param {type: "boolean"}
use_prior_lora = True # @param {type: "boolean"}
decoder_lora_path = "" # @param {type: "string"}
prior_lora_path = "" # @param {type: "string"}

# @markdown ###Generation settings
# `prompt` and `negative_prompt` are each run through the prior; the negative
# result is passed to the decoder as `negative_image_embeds`.
# prior_steps / decoder_steps are the num_inference_steps of those calls,
# width/height set the decoder output size, `seed` is passed to
# torch.manual_seed, and the first generated image is saved to `filename`.
prompt = "" # @param {type: "string"}
negative_prompt = "" # @param {type: "string"}
prior_steps = 25 # @param {type: "integer"}
decoder_steps = 50 # @param {type: "integer"}
width = 512 # @param {type: "integer"}
height = 512 # @param {type: "integer"}
seed = 1 # @param {type: "integer"}
filename = "demo.png" # @param {type: "string"}

In [None]:
# @title Load Pipelines
import sys
from diffusers import KandinskyV22Pipeline, KandinskyV22PriorPipeline
import torch
import PIL
import torch
from diffusers.utils import load_image
from torchvision import transforms
from transformers import CLIPVisionModelWithProjection
from diffusers.models import UNet2DConditionModel
import numpy as np
image_encoder = CLIPVisionModelWithProjection.from_pretrained('kandinsky-community/kandinsky-2-2-prior', subfolder='image_encoder').to(torch.float16).to('cuda')
unet = UNet2DConditionModel.from_pretrained('kandinsky-community/kandinsky-2-2-decoder', subfolder='unet').to(torch.float16).to('cuda')
prior = KandinskyV22PriorPipeline.from_pretrained('kandinsky-community/kandinsky-2-2-prior', image_encoder=image_encoder, torch_dtype=torch.float16)
prior = prior.to("cuda")
decoder = KandinskyV22Pipeline.from_pretrained('kandinsky-community/kandinsky-2-2-decoder', unet=unet, torch_dtype=torch.float16)
decoder = decoder.to("cuda")

from diffusers.models.attention_processor import LoRAAttnProcessor, LoRAAttnAddedKVProcessor


def _infer_lora_rank(state_dict, default=4):
    """Return the LoRA rank stored in `state_dict`, or `default` if none is found.

    NOTE(review): assumes diffusers' LoRALinearLayer layout, where each
    `...down.weight` is a (rank, in_features) matrix; confirm for this fork.
    """
    for key, value in state_dict.items():
        if key.endswith("down.weight"):
            return int(value.shape[0])
    return default


def _warn_on_unmatched_keys(label, load_result):
    """Report LoRA keys that load_state_dict(strict=False) silently ignored."""
    unexpected = load_result.unexpected_keys
    if unexpected:
        print(f"WARNING: {len(unexpected)} {label} LoRA keys did not match the model and were ignored "
              f"(e.g. {unexpected[:3]}). Check the file and the LoRA rank.")


if use_decoder_lora:
    if not decoder_lora_path:
        raise ValueError("use_decoder_lora is enabled but decoder_lora_path is empty.")
    lora_attn_procs = {}
    d = torch.load(decoder_lora_path)
    # Use the rank the checkpoint was trained with (the training cells fix it to 4).
    decoder_rank = _infer_lora_rank(d)
    for name in decoder.unet.attn_processors.keys():
        # attn1 is self-attention, so it has no cross-attention input.
        cross_attention_dim = None if name.endswith("attn1.processor") else decoder.unet.config.cross_attention_dim
        if name.startswith("mid_block"):
            hidden_size = decoder.unet.config.block_out_channels[-1]
        elif name.startswith("up_blocks"):
            # Single-digit block index right after the "up_blocks." prefix.
            block_id = int(name[len("up_blocks.")])
            hidden_size = list(reversed(decoder.unet.config.block_out_channels))[block_id]
        elif name.startswith("down_blocks"):
            block_id = int(name[len("down_blocks.")])
            hidden_size = decoder.unet.config.block_out_channels[block_id]
        lora_attn_procs[name] = LoRAAttnAddedKVProcessor(
                hidden_size=hidden_size,
                cross_attention_dim=cross_attention_dim,
                rank=decoder_rank,
        ).to('cuda')
    decoder.unet.set_attn_processor(lora_attn_procs)
    # strict=False: the file holds only LoRA weights, not the full UNet.
    _warn_on_unmatched_keys("decoder", decoder.unet.load_state_dict(d, strict=False))

if use_prior_lora:
    if not prior_lora_path:
        raise ValueError("use_prior_lora is enabled but prior_lora_path is empty.")
    prior_state = torch.load(prior_lora_path)
    prior_rank = _infer_lora_rank(prior_state)
    lora_attn_procs = {}
    for name in prior.prior.attn_processors.keys():
        # NOTE(review): 2048 is assumed to be the prior's attention width; confirm.
        lora_attn_procs[name] = LoRAAttnProcessor(hidden_size=2048, rank=prior_rank).to('cuda')
    prior.prior.set_attn_processor(lora_attn_procs)
    _warn_on_unmatched_keys("prior", prior.prior.load_state_dict(prior_state, strict=False))

In [None]:
# @title Generate
# Seed torch's global RNG before sampling; no explicit generator is passed
# to either pipeline below.
torch.manual_seed(seed)

# The prior runs once for the prompt and once for the negative prompt; the
# latter's embedding becomes the decoder's `negative_image_embeds`.
prior_kwargs = {"num_inference_steps": prior_steps, "num_images_per_prompt": 1}
img_emb = prior(prompt=prompt, **prior_kwargs)
neg_emb = prior(prompt=negative_prompt, **prior_kwargs)

images = decoder(
    image_embeds=img_emb.image_embeds,
    negative_image_embeds=neg_emb.image_embeds,
    num_inference_steps=decoder_steps,
    height=height,
    width=width,
)
first_image = images.images[0]
first_image.save(filename)