# Setup

In [1]:
from diffusers import (
    DDPMScheduler,
    UNet2DConditionModel,
    # UNet2DModel,
    AutoencoderKL,
    StableDiffusionPipeline,
    DiffusionPipeline,
)
from diffusers.optimization import get_scheduler
from diffusers.utils import convert_state_dict_to_diffusers
from diffusers.training_utils import cast_training_params
from diffusers.utils import make_image_grid
from transformers import CLIPTextModel, CLIPTokenizer

from datasets import load_dataset
from torchvision import transforms

from peft import LoraConfig
from peft.utils import get_peft_model_state_dict

from accelerate import Accelerator
from accelerate.utils import ProjectConfiguration

import os
import gc
import shutil
import torch
import torch.nn.functional as F
from tqdm.auto import tqdm

  from .autonotebook import tqdm as notebook_tqdm


In [17]:
# Run settings: everything a reader might want to tune lives in this dict.
config = {
    "output_dir": "./output",
    "batch_size": 8,
    "train_epochs": 20,
    "save_ckpt_every_n_epochs": 5,
    "validate_every_n_epochs": 2,
    "lora_rank": 4,  # other ranks tried: 2, 8
    "seed": 42,
}

# Seed torch's global RNG up front. Without this, the seed above was only
# applied to validation sampling, so torch-driven randomness in training
# (DataLoader shuffling, `torch.randn_like` noise, `torch.randint` timesteps)
# differed from run to run.
torch.manual_seed(config["seed"])

device = "cuda" if torch.cuda.is_available() else "cpu"

# Accelerate project settings rooted at the output directory.
accelerator_project_config = ProjectConfiguration(project_dir=config["output_dir"])
os.makedirs(config["output_dir"], exist_ok=True)

accelerator = Accelerator(
    mixed_precision="fp16",  # use amp
    project_config=accelerator_project_config,
)

# Data Preparation

In [6]:
# Pull the dataset from the Hugging Face Hub.
DATASET_NAME = "kusnim1121/filtered-one-piece-with-caption"
dataset = load_dataset(DATASET_NAME)

# Image pipeline applied to every training picture, in order:
# resize -> random 512 crop -> random horizontal flip -> tensor -> normalize
# with mean 0.5 / std 0.5.
augmentation_steps = [
    transforms.Resize(512, interpolation=transforms.InterpolationMode.BILINEAR),
    transforms.RandomCrop(512),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5]),
]
preprocess = transforms.Compose(augmentation_steps)


def transform(examples):
    """Turn a batch of raw dataset rows into training inputs.

    Relies on the module-level ``preprocess`` pipeline and ``tokenizer``.
    NOTE: ``tokenizer`` is created in a later cell, so this function must not
    be called before that cell has run.

    Args:
        examples: Mapping with an ``"image"`` list (objects exposing
            ``.convert("RGB")``) and a ``"caption"`` list of strings.

    Returns:
        dict with ``"images"`` (list of preprocessed image tensors) and
        ``"captions"`` (token ids padded/truncated to
        ``tokenizer.model_max_length``).
    """
    pixel_tensors = []
    for picture in examples["image"]:
        pixel_tensors.append(preprocess(picture.convert("RGB")))

    tokenized = tokenizer(
        examples["caption"],
        truncation=True,
        padding="max_length",
        max_length=tokenizer.model_max_length,
        return_tensors="pt",
    )
    return {"images": pixel_tensors, "captions": tokenized.input_ids}


# Attach `transform` to the training split.
train_dataset = dataset["train"].with_transform(transform)
# val_dataset = dataset["val"].with_transform(transform)

# Shuffled loader; the trailing partial batch is dropped.
loader_options = dict(
    batch_size=config["batch_size"],
    shuffle=True,
    drop_last=True,
)
train_dataloader = torch.utils.data.DataLoader(train_dataset, **loader_options)




# SD setup & LoRA config

In [None]:
# Load scheduler, tokenizer and models from the pretrained SD checkpoint.
SD_PATH = "runwayml/stable-diffusion-v1-5"

noise_scheduler = DDPMScheduler.from_pretrained(SD_PATH, subfolder="scheduler")
vae = AutoencoderKL.from_pretrained(SD_PATH, subfolder="vae")
unet = UNet2DConditionModel.from_pretrained(SD_PATH, subfolder="unet")
text_encoder = CLIPTextModel.from_pretrained(SD_PATH, subfolder="text_encoder")
tokenizer = CLIPTokenizer.from_pretrained(SD_PATH, subfolder="tokenizer")

# Freeze every pretrained weight (saves memory); only the LoRA adapter added
# below is trained. One `requires_grad_(False)` per module is sufficient, so
# the extra per-parameter loop over the unet has been dropped.
unet.requires_grad_(False)
vae.requires_grad_(False)
text_encoder.requires_grad_(False)

# LoRA on the attention projections; alpha == rank.
unet_lora_config = LoraConfig(
    r=config["lora_rank"],
    lora_alpha=config["lora_rank"],
    init_lora_weights="gaussian",
    target_modules=["to_k", "to_q", "to_v", "to_out.0"],
)

# Move the frozen models to the device and cast them to weight_dtype.
weight_dtype = torch.float16
unet.to(accelerator.device, dtype=weight_dtype)
vae.to(accelerator.device, dtype=weight_dtype)
text_encoder.to(accelerator.device, dtype=weight_dtype)

# Add adapter and make sure the trainable params are in float32.
unet.add_adapter(unet_lora_config)
# only upcast trainable parameters (LoRA) into fp32
cast_training_params(unet, dtype=torch.float32)

# Materialize as a list so the collection can be iterated more than once
# (a bare `filter` object is exhausted after the optimizer consumes it).
lora_layers = [p for p in unet.parameters() if p.requires_grad]
optimizer = torch.optim.AdamW(
    lora_layers,
    lr=1e-4,
    betas=(0.9, 0.999),
    # Was 1e2, which with lr=1e-4 shrinks every LoRA weight by
    # lr * weight_decay = 1% per step; 1e-2 is the conventional value.
    weight_decay=1e-2,
    eps=1e-8,
)

lr_scheduler = get_scheduler(
    name="constant",
    optimizer=optimizer,
    # num_warmup_steps=500,
    num_training_steps=config["train_epochs"] * len(train_dataloader),
)

# Validation function (saves a sample image grid)

In [None]:
def validate(pipeline, epoch):
    """Generate sample images for fixed prompts and save them as a grid.

    The prompt list is repeated ``nrows`` times, so each grid row holds one
    image per prompt. Sampling uses a generator seeded with
    ``config["seed"]``. The grid is written to
    ``<config["output_dir"]>/samples/<epoch:04d>.png``.

    Args:
        pipeline: Callable diffusion pipeline; called with a list of prompts
            and expected to return an object with an ``.images`` list.
        epoch: Integer used to name the output file.

    Returns:
        The list of generated images (previously the function returned
        ``None`` even though callers bind its result).
    """
    nrows = 3
    base_prompts = [
        "a man, in one piece style",
        "a woman, in one piece style",
        "an anime character, in one piece style",
    ]
    val_prompts = base_prompts * nrows
    images = pipeline(
        val_prompts,
        num_inference_steps=30,
        generator=torch.Generator().manual_seed(config["seed"]),
    ).images  # List[PIL.Image]

    # Derive the column count from the prompt list so the grid shape always
    # matches the number of generated images.
    image_grid = make_image_grid(images, rows=nrows, cols=len(base_prompts))

    # Save the grid
    test_dir = os.path.join(config["output_dir"], "samples")
    os.makedirs(test_dir, exist_ok=True)
    image_grid.save(f"{test_dir}/{epoch:04d}.png")

    return images

In [None]:
# Baseline samples before any training step: build a pipeline around the
# current unet and run validation with epoch=-1 (saved as "-001.png").
pipeline = DiffusionPipeline.from_pretrained(SD_PATH, unet=unet, torch_dtype=weight_dtype)
pipeline = pipeline.to(accelerator.device)
images = validate(pipeline, epoch=-1)

# Drop the pipeline and release cached GPU memory before training.
del pipeline
gc.collect()
torch.cuda.empty_cache()

# Train & Validate

In [None]:
# Prepare for mixed precision training
unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(unet, optimizer, train_dataloader, lr_scheduler)

# Train!
global_step = 0
for epoch in range(config["train_epochs"]):
    progress_bar = tqdm(total=len(train_dataloader))
    progress_bar.set_description(f"Epoch {epoch}")

    unet.train()
    for step, batch in enumerate(train_dataloader):
        # Convert images to latent space
        latents = vae.encode(batch["images"].to(dtype=weight_dtype)).latent_dist.sample()
        latents = latents * vae.config.scaling_factor

        # Sample noise that we'll add to the latents
        noise = torch.randn_like(latents)

        # Sample a random timestep for each image. Use the actual batch size
        # rather than config["batch_size"] so the shapes still line up if the
        # loader ever yields a smaller batch.
        bsz = latents.shape[0]
        timesteps = torch.randint(
            0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device
        )
        timesteps = timesteps.long()

        # Add noise to the latents according to the noise magnitude at each timestep
        noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
        # The text encoder is frozen, so no autograd graph is needed for it.
        with torch.no_grad():
            encoder_hidden_states = text_encoder(batch["captions"], return_dict=False)[0].to(accelerator.device)
        # Predict the noise residual and compute loss
        model_pred = unet(noisy_latents, timesteps, encoder_hidden_states, return_dict=False)[0]

        loss = F.mse_loss(model_pred.float(), noise.float(), reduction="mean")

        # Backpropagate
        accelerator.backward(loss)
        optimizer.step()
        lr_scheduler.step()
        optimizer.zero_grad()

        progress_bar.update(1)
        logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0], "step": global_step}
        # Show the metrics on the bar too: no tracker is configured on the
        # accelerator in this notebook, so accelerator.log alone displays nothing.
        progress_bar.set_postfix(**logs)
        accelerator.log(logs, step=global_step)
        global_step += 1

    # Close the per-epoch bar so finished bars do not accumulate.
    progress_bar.close()

    # Save LoRA weights every N epochs (including epoch 0) and after the last epoch.
    if epoch % config["save_ckpt_every_n_epochs"] == 0 or epoch == config["train_epochs"] - 1:
        save_path = os.path.join(config["output_dir"], f"epoch-{epoch:04d}")
        # Start from a clean directory if a previous run left one behind.
        if os.path.exists(save_path):
            shutil.rmtree(save_path)
        # accelerator.save_state(save_path)

        unwrapped_unet = accelerator.unwrap_model(unet)
        unet_lora_state_dict = convert_state_dict_to_diffusers(get_peft_model_state_dict(unwrapped_unet))

        StableDiffusionPipeline.save_lora_weights(
            save_directory=save_path,
            unet_lora_layers=unet_lora_state_dict,
            safe_serialization=True,
        )

        print(f"Saved state to {save_path}")

    # Render validation samples every N epochs and after the last epoch.
    if epoch % config["validate_every_n_epochs"] == 0 or epoch == config["train_epochs"] - 1:
        # create pipeline
        pipeline = DiffusionPipeline.from_pretrained(
            SD_PATH,
            unet=accelerator.unwrap_model(unet),
            # torch_dtype=weight_dtype,
        ).to(accelerator.device)
        images = validate(pipeline, epoch)

        # Release the temporary pipeline and cached GPU memory.
        del pipeline
        gc.collect()
        torch.cuda.empty_cache()

# Push to hub

In [None]:
from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
from huggingface_hub import create_repo, upload_folder

# Create (or reuse) the target repository on the Hub.
repo_id = create_repo(repo_id="kusnim1121/stable-diffusion-one-piece-lora", exist_ok=True).repo_id
model_description = f"""
# LoRA text2image fine-tuning - {repo_id}
These are LoRA adaption weights for {SD_PATH}. The weights were fine-tuned on the one-piece-with-caption dataset.
"""
model_card = load_or_create_model_card(
    repo_id_or_path=repo_id,
    from_training=True,
    license="mit",
    base_model=SD_PATH,
    model_description=model_description,
    inference=True,
)

tags = [
    "stable-diffusion",
    "stable-diffusion-diffusers",
    "text-to-image",
    "diffusers",
    "diffusers-training",
    "lora",
]
model_card = populate_model_card(model_card, tags=tags)
model_card.push_to_hub(repo_id)

# Upload the checkpoint written after the LAST epoch. The training loop
# always saves at epoch == train_epochs - 1 under "epoch-<epoch:04d>".
# Previously this uploaded "epoch-0000", i.e. the weights after the first
# epoch, under the commit message "End of training".
final_epoch = config["train_epochs"] - 1
final_ckpt_dir = os.path.join(config["output_dir"], f"epoch-{final_epoch:04d}")
upload_folder(
    repo_id=repo_id,
    folder_path=final_ckpt_dir,
    commit_message="End of training",
    # ignore_patterns=["step_*", "epoch_*"],
)

# Test

In [None]:
# Fresh base pipeline for inference.
pipeline = DiffusionPipeline.from_pretrained(
    SD_PATH,
    torch_dtype=torch.float16,
).to(device)

# Load the LoRA weights saved after the final epoch. The path is derived from
# the config so it matches what the training loop wrote; the previous
# hard-coded "./output/epoch-0499" is never produced with train_epochs=20.
lora_dir = os.path.join(config["output_dir"], f"epoch-{config['train_epochs'] - 1:04d}")
pipeline.load_lora_weights(lora_dir)

images = pipeline(
    prompt="an anime character, one piece style",
    generator=torch.Generator(device="cpu").manual_seed(config["seed"]),
    num_inference_steps=30,
).images
images[0]

In [None]:
# Remove the LoRA weights, then rerun the same prompt and seed for comparison.
pipeline.unload_lora_weights()

# TODO: fix (carried over from the original cell; the issue is not described)
base_model_call = dict(
    prompt="an anime character, one piece style",
    generator=torch.Generator(device="cpu").manual_seed(config["seed"]),
    num_inference_steps=30,
)
images = pipeline(**base_model_call).images
images[0]