In [None]:
!pip install datasets

In [None]:
# Load the pretrained base-model components used for training.
from transformers import CLIPTextModel, CLIPTokenizer
from diffusers import AutoencoderKL, UNet2DConditionModel, DDPMScheduler
import torch

pretrained_model_name_or_path = "stabilityai/stable-diffusion-2-1"
# dtype later used when moving models to the device and when encoding image batches.
weight_dtype = torch.float32


def _load_component(component_cls, subfolder):
    """Load one component class from `subfolder` of the base model repository."""
    return component_cls.from_pretrained(pretrained_model_name_or_path, subfolder=subfolder)


tokenizer = _load_component(CLIPTokenizer, "tokenizer")
text_encoder = _load_component(CLIPTextModel, "text_encoder")
vae = _load_component(AutoencoderKL, "vae")
unet = _load_component(UNet2DConditionModel, "unet")
noise_scheduler = _load_component(DDPMScheduler, "scheduler")

In [None]:
# When the UNet's input channel count must change, swap in a wider first conv (conv_in).
# The UNet input becomes [noisy target latents | conditioning-image latents] (8 channels).
NEW_IN_CHANNELS = 8

old_conv = unet.conv_in
# Guard: re-running this cell must not try to widen an already-widened conv
# (the original weight copy would then fail on a shape mismatch).
if old_conv.in_channels != NEW_IN_CHANNELS:
    new_conv = torch.nn.Conv2d(
        NEW_IN_CHANNELS,
        old_conv.out_channels,
        kernel_size=old_conv.kernel_size,
        stride=old_conv.stride,
        padding=old_conv.padding,
        device=old_conv.weight.device,
        dtype=old_conv.weight.dtype,
    )

    with torch.no_grad():
        # Zero everything first so the extra conditioning channels start with no
        # influence: the widened UNet initially reproduces the pretrained one exactly.
        new_conv.weight.zero_()
        # Copy the pretrained kernels into the leading (original latent) channels.
        new_conv.weight[:, : old_conv.in_channels, :, :].copy_(old_conv.weight)
        new_conv.bias.copy_(old_conv.bias)

    unet.conv_in = new_conv  # replace

In [None]:
# Record the widened input (8 channels, matching the replaced conv_in) in the UNet config.
# `register_to_config` updates the config dict that `save_pretrained` serializes; plain
# attribute assignment on `unet.config` does not, so a saved model would still declare
# 4 input channels and fail to reload its 8-channel conv_in weights.
unet.register_to_config(in_channels=8)

In [None]:
# Freeze every pretrained weight. LoRA adapters and conv_in are re-enabled explicitly
# in later cells. A loop (instead of a trailing expression) avoids dumping the whole
# UNet module repr into the cell output.
for frozen_module in (vae, text_encoder, unet):
    frozen_module.requires_grad_(False)

In [None]:
# Train on the GPU when one is visible to torch, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(device)

In [None]:
import torch
from peft import LoraConfig, get_peft_model, PeftModel


# LoRA adapters on the UNet attention projections and feed-forward layers.
config = LoraConfig(
    r=64,
    lora_alpha=256,
    target_modules=["to_q", "to_k", "to_v", "to_out.0", "ff.net.0.proj", "ff.net.2"],
    lora_dropout=0.0,
    bias="none",
    # conv_in was replaced by a new, wider conv that is trained in full. Listing it
    # here makes PEFT train a copy of it and include that copy in
    # `lora_unet.save_pretrained()` checkpoints; otherwise only LoRA weights are saved
    # and the retrained conv_in would be lost from every adapter checkpoint.
    modules_to_save=["conv_in"],
)
lora_unet = get_peft_model(unet, config)

In [None]:
# Unfreeze the replaced conv_in: the whole UNet was frozen with requires_grad_(False)
# earlier, but conv_in's extra input channels are newly initialised and must be learned.
# If PEFT wraps conv_in via `modules_to_save`, it trains a copy; skip the wrapped
# `original_module` so the frozen reference weights are not handed to the optimizer.
for name, param in unet.named_parameters():
    if name.startswith("conv_in") and "original_module" not in name:
        param.requires_grad = True

In [None]:
# Make sure every PEFT-injected adapter weight (parameter name contains "lora_") is trainable.
lora_adapter_params = (p for param_name, p in lora_unet.named_parameters() if "lora_" in param_name)
for param in lora_adapter_params:
    param.requires_grad = True

In [None]:
# Move the LoRA-wrapped UNet and the frozen VAE / text encoder to the training device
# and dtype. Looping (rather than ending on an expression) keeps the text encoder's
# full module repr out of the cell output.
for module in (lora_unet, vae, text_encoder):
    module.to(device, dtype=weight_dtype)

In [None]:
# Optimise only the parameters left with requires_grad=True by the preceding cells.
trainable_params = list(filter(lambda p: p.requires_grad, lora_unet.parameters()))
optimizer = torch.optim.AdamW(trainable_params, lr=5e-5, weight_decay=0.0)

optimizer.zero_grad()

In [None]:
from datasets import load_dataset
import os

# Paired-image dataset (input_image / output_image / input_text columns are used below).
# The cache directory is a Google Drive path (presumably mounted in Colab) so the
# download can be reused across runtimes.
dataset_cache_dir = "/content/drive/MyDrive/stable_diffusion/webtoon_text_conversion_data"
dataset = load_dataset("jhc90/webtoon_text_conversion_data", cache_dir=dataset_cache_dir)

In [None]:
from datasets import DatasetDict
from torchvision import transforms


# Preprocessing shared by input and output images, applied in this order.
_image_steps = [
    transforms.Lambda(lambda img: img.convert("RGB")),  # force 3 channels
    transforms.Resize((512, 512), interpolation=transforms.InterpolationMode.BILINEAR),
    transforms.ToTensor(),  # PIL image -> float tensor in [0, 1]
    transforms.Normalize([0.5], [0.5]),  # (x - 0.5) / 0.5: [0, 1] -> [-1, 1]
]
image_transforms = transforms.Compose(_image_steps)

def transform_batch(batch):
    """Apply `image_transforms` to every image in a batch dict, in place.

    Replaces the "input_image" and "output_image" columns with lists of transformed
    images; every other key is left untouched. Returns the same dict object.
    """
    for column in ("input_image", "output_image"):
        batch[column] = list(map(image_transforms, batch[column]))
    return batch

# Attach the preprocessing to every split; it runs lazily whenever rows are accessed.
for split_name in ("train", "valid", "test"):
    dataset[split_name].set_transform(transform_batch)

In [None]:
# Report the size of each split.
for split_name in ("train", "valid", "test"):
    print(len(dataset[split_name]))

# Batch size 16 with 8 worker processes; only the training split is shuffled.
train_dataloader, valid_dataloader, test_dataloader = (
    torch.utils.data.DataLoader(
        dataset[split_name],
        batch_size=16,
        shuffle=(split_name == "train"),
        num_workers=8,
    )
    for split_name in ("train", "valid", "test")
)

In [None]:
from tqdm import tqdm
from diffusers.optimization import get_cosine_schedule_with_warmup


epochs = 5
# Run validation + checkpointing every `validate_every` steps within an epoch (step 0 skipped).
validate_every = 3800
checkpoint_dir = "/content/drive/MyDrive/stable_diffusion/lora_unet_text_conversion"

lr_scheduler = get_cosine_schedule_with_warmup(
    optimizer=optimizer,
    num_warmup_steps=1000,
    num_training_steps=(len(train_dataloader) * epochs),
)


def compute_diffusion_loss(batch):
    """Return the mean-squared denoising loss of `lora_unet` on one batch.

    The target image ("output_image") is VAE-encoded and noised at a random timestep.
    The conditioning image ("input_image") is VAE-encoded and concatenated channel-wise
    onto the noisy latents. "input_text" provides the text conditioning.
    Uses tokenizer, text_encoder, vae, noise_scheduler, lora_unet, device and
    weight_dtype from earlier cells.

    Raises:
        ValueError: if the scheduler's prediction_type is neither "epsilon" nor "v_prediction".
    """
    # The text encoder and VAE are frozen, so no autograd tracking is needed through them.
    with torch.no_grad():
        text_inputs = tokenizer(
            list(batch["input_text"]),
            padding="max_length",
            max_length=tokenizer.model_max_length,
            truncation=True,
            return_tensors="pt",
        )
        prompt_embeds = text_encoder(text_inputs.input_ids.to(device)).last_hidden_state

        input_latents = vae.encode(batch["input_image"].to(device, dtype=weight_dtype)).latent_dist.sample()
        input_latents = input_latents * vae.config.scaling_factor

        output_latents = vae.encode(batch["output_image"].to(device, dtype=weight_dtype)).latent_dist.sample()
        output_latents = output_latents * vae.config.scaling_factor

    # Sample the noise and one random timestep per example.
    noise = torch.randn_like(output_latents)
    bsz = output_latents.shape[0]
    timesteps = torch.randint(
        0, noise_scheduler.config.num_train_timesteps, (bsz,), device=output_latents.device
    ).long()

    # Forward diffusion: noise the *target* latents according to each sampled timestep.
    noisy_latents = noise_scheduler.add_noise(output_latents, noise, timesteps)

    # Concatenate the noisy target latents with the clean conditioning-image latents.
    latent_model_input = torch.cat([noisy_latents, input_latents], dim=1)

    noise_pred = lora_unet(latent_model_input, timesteps, prompt_embeds).sample

    # The regression target depends on what the scheduler's model is trained to predict.
    if noise_scheduler.config.prediction_type == "epsilon":
        target = noise
    elif noise_scheduler.config.prediction_type == "v_prediction":
        target = noise_scheduler.get_velocity(output_latents, noise, timesteps)
    else:
        raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")

    return torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="mean")


def run_validation():
    """Average `compute_diffusion_loss` over `valid_dataloader` without gradients.

    Returns float("nan") if the validation loader yields no batches. Leaves
    `lora_unet` in eval mode; the caller switches it back to train mode.
    """
    lora_unet.eval()
    total_loss = 0.0
    num_batches = 0
    with torch.no_grad():
        for val_batch in tqdm(valid_dataloader):
            total_loss += compute_diffusion_loss(val_batch).item()
            num_batches += 1
    return total_loss / num_batches if num_batches else float("nan")


def save_checkpoint(epoch, step, val_loss):
    """Save the LoRA adapter, optimizer state and LR-scheduler state under `checkpoint_dir`."""
    suffix = f"epoch{epoch}_step{step}_loss{val_loss}"
    lora_unet.save_pretrained(f"{checkpoint_dir}/lora_unet_adapter_{suffix}")
    torch.save(optimizer.state_dict(), f"{checkpoint_dir}/optimizer_{suffix}")
    torch.save(lr_scheduler.state_dict(), f"{checkpoint_dir}/lr_scheduler_{suffix}")


for epoch in range(epochs):
    lora_unet.train()
    for step, batch in enumerate(tqdm(train_dataloader)):
        loss = compute_diffusion_loss(batch)

        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        lr_scheduler.step()

        if step % 100 == 0:
            print(f"Epoch {epoch}, Step {step}, Learning rate: {optimizer.param_groups[0]['lr']}, Loss: {loss.item()}")

        if step % validate_every == 0 and step != 0:
            # Separate helpers keep the training loop's `step`/`batch` untouched.
            val_loss = run_validation()
            print(f"Validation Loss: {val_loss}")
            save_checkpoint(epoch, step, val_loss)
            lora_unet.train()

# Fold the LoRA weights into the base UNet and save the full model.
merged_unet = lora_unet.merge_and_unload()
merged_unet.save_pretrained(f"{checkpoint_dir}/lora_unet_last_last_last")