# Load Model

In [None]:
# Notebook-wide parameters: file locations first, training settings after.
import os

# Dataset root directory and the model definition YAML stored inside it.
root_dir = '/home/lolicon/data/dataset/Illya/'
config_path = '/home/lolicon/data/dataset/Illya/model.yaml'

# Pretrained weights taken from the local stable-diffusion-webui checkout:
# the base Stable Diffusion checkpoint and the ControlNet openpose weights.
main_ckpt_path = '../stable-diffusion-webui/models/Stable-diffusion/CounterfeitV30_v30.safetensors'
ctrl_pose_path = '../stable-diffusion-webui/models/ControlNet/control_sd15_openpose.pth'

# LoRA checkpoint kept under the dataset root.
# Alternative: use this same file as the main checkpoint instead:
# main_ckpt_path = os.path.join(root_dir, 'lora_ckpt', 'epoch=1999.ckpt')
lora_ckpt_path = os.path.join(root_dir, 'lora_ckpt', 'epoch=1999.ckpt')

# Training settings.
num_epochs, num_gpus, batch_size = 2, 1, 2
logger_freq = 20000
learning_rate = 1e-5
sd_locked, only_mid_control = True, False

In [None]:
import config

import cv2
import einops
import numpy as np
import torch
import random

from pytorch_lightning import seed_everything
from annotator.util import resize_image, HWC3
# from annotator.canny import CannyDetector
from cldm.model import create_model, load_state_dict
from cldm.ddim_hacked import DDIMSampler

# apply_canny = CannyDetector()
# config_path = '/home/lolicon/data/dataset/lycoris/model.yaml'
# Build the model from the YAML config on the CPU first, then move it to the
# compute device in a single step.
model = create_model(config_path).cpu()
# Prefer the GPU, but fall back to the CPU so the notebook still runs (slowly)
# on a machine without CUDA instead of failing inside `.to('cuda')`.
torch_device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = model.to(torch_device)


In [None]:

from utils.model_loader import load_state_dict

# Load the pretrained weights and install them into `model`.
# `load_state_dict` here is the one imported from utils.model_loader in this
# cell, which replaces the earlier import of the same name from cldm.model.
# The main checkpoint carries the weights for the UNet, encoder, decoder and
# text embedding. Alternative main checkpoint:
# main_ckpt_path = '../stable-diffusion-webui/models/Stable-diffusion/pastelMixStylizedAnime_pastelMixFull.safetensors'

sd_ckpt = load_state_dict(main_ckpt_path)
# ControlNet pose weights (path set in the params cell as `ctrl_pose_path`).
sd_ctrl = load_state_dict(ctrl_pose_path)
# lora_model = load_state_dict(lora_ckpt_path)
# The main weights go in positionally; the ControlNet weights are passed under
# the `pose_model` keyword.
model.load_multi_state_dict(sd_ckpt, pose_model=sd_ctrl)


# Create Dataset

In [None]:
from data import CustomDataset


# Build the dataset from the dataset root directory.
# NOTE(review): the on-disk layout CustomDataset expects under root_dir is not
# visible in this notebook — confirm in the `data` module.
dataset = CustomDataset(root_dir)


# Create LoRA Network

In [None]:
from modules.Lora import LoRANetwork, LoRAModule

# Load the saved LoRA state dict from `lora_ckpt_path`.
lora_unet = load_state_dict(lora_ckpt_path)
# Build a LoRA network around model.control_model.pose_model (passed as
# `unet`) with lora_dim=8 and alpha=4.0; the loaded state dict is handed over
# as `weights_sd`.
lora_network = LoRANetwork(unet=model.control_model.pose_model, lora_dim=8, alpha=4.0, weights_sd=lora_unet)
# Request application to the UNet only (apply_text_encoder=False).
# NOTE(review): Ensemble.__init__ calls apply_to again with the same flags —
# confirm that applying twice is harmless.
lora_network.apply_to(apply_unet=True, apply_text_encoder=False)
# lora_network = load_state_dict(lora_ckpt_path)
# loar_network = Lo
# print(model.control_model.pose_model)




# Ensemble Model

In [None]:

from pytorch_lightning.utilities.distributed import rank_zero_only
import pytorch_lightning as pl


class Ensemble(pl.LightningModule):
    """Lightning module that trains a LoRA network alongside a wrapped model.

    The training hooks (``training_step``, ``on_train_batch_start``,
    ``on_train_batch_end``) are delegated to ``model``; the optimizer
    parameters and the saved state come from ``lora`` only.

    Args:
        model: Wrapped model. This class reads ``model.control_model`` and
            ``model.model`` and calls the three training hooks on it.
        lora: LoRA network providing ``apply_to``,
            ``prepare_optimizer_params`` and ``state_dict``.
        learning_rate: Learning rate handed to the LoRA parameter groups and
            to AdamW.
        sd_locked: Copied onto ``model.sd_locked``.
        only_mid_control: Copied onto ``model.only_mid_control``.
        *args, **kwargs: Forwarded to ``pl.LightningModule.__init__``.
    """

    def __init__(self, model, lora, learning_rate=1e-5, sd_locked=True, only_mid_control=False, *args, **kwargs):
        super().__init__(*args, **kwargs)

        # Apply LoRA with the UNet flag on and the text-encoder flag off.
        lora.apply_to(apply_unet=True, apply_text_encoder=False)

        self.model = model
        self.lora = lora

        # Replace the wrapped model's `log` with this module's `log`, so that
        # logging done inside the delegated hooks goes through this module.
        self.model.log = self.log

        self.lora.train()
        self.num_timesteps = 40

        self.learning_rate = learning_rate
        self.model.sd_locked = sd_locked
        self.model.only_mid_control = only_mid_control

    def prepare_grad_etc(self):
        """Enable gradients on the LoRA network and both wrapped sub-models."""
        self.lora.requires_grad_(True)
        self.model.control_model.requires_grad_(True)
        self.model.model.requires_grad_(True)

    def state_dict(self, *args, **kwargs):
        """Return the LoRA network's state dict only.

        The standard ``nn.Module.state_dict`` arguments (``destination``,
        ``prefix``, ``keep_vars``) are accepted and forwarded to the LoRA
        network, so callers that use the base-class signature do not fail
        with ``TypeError``. Calling with no arguments behaves as before.
        """
        return self.lora.state_dict(*args, **kwargs)

    def prepare_optimizer_params(self):
        """Return the LoRA parameter groups for the optimizer.

        The same learning rate is passed for both positional arguments.
        NOTE(review): presumably these are the text-encoder and UNet learning
        rates — confirm against LoRANetwork.prepare_optimizer_params.
        """
        return self.lora.prepare_optimizer_params(self.learning_rate, self.learning_rate)

    def configure_optimizers(self):
        """Create an AdamW optimizer over the LoRA parameter groups."""
        lr = self.learning_rate
        params = self.prepare_optimizer_params()
        opt = torch.optim.AdamW(params, lr=lr)
        return opt

    def training_step(self, batch, batch_idx):
        """Delegate the training step to the wrapped model and return its result."""
        return self.model.training_step(batch, batch_idx)

    def on_train_batch_start(self, batch, batch_idx):
        """Forward the batch-start hook to the wrapped model (result discarded)."""
        self.model.on_train_batch_start(batch, batch_idx)

    def on_train_batch_end(self, *args, **kwargs):
        """Forward the batch-end hook, with whatever arguments Lightning passes."""
        return self.model.on_train_batch_end(*args, **kwargs)


# Sampling

In [None]:
import os
import cv2
import time
import einops
# Positive prompt used for sampling.
prompt = ["masterpiece, best quality, seduction, cute face, 1girl, sexy, 8K, high resolution, weapon"]
# Alternative positive prompt kept for reference:
# prompt = ["1girl, aurora, blonde_hair, city_lights, cloud, cloudy_sky, diffraction_spikes, feather_hair_ornament, feathers, gradient_sky, hair_ornament, holding, horizon, illyasviel_von_einzbern, lens_flare, light_rays, long_hair, looking_at_viewer, ocean, orange_sky, outdoors, prisma_illya, red_eyes, shooting_star, sitting, sky, solo, sparkle, star_\\(sky\\), starry_sky, sun, sunlight, sunrise, sunset, twilight"]

# Negative prompt used for sampling.
# Alternative negative prompt kept for reference:
# prompt_negative = ["deformation, ugly, bad quality, distortion"]
prompt_negative = ["easynegative, lowres, text, cropped, worst quality, low quality, normal quality, jpeg artifacts, signature, watermark, username, blurry, bad anatomy, bad hands, error, missing fingers, extra digits, fewer digits, bad feet, bad colours, missing arms, text, water print, logo"]

# Pose hint image: root_dir/hint/23.png
hint_path = os.path.join(root_dir, 'hint', '23.png')
source = cv2.imread(hint_path)
if source is None:
    # cv2.imread reports a missing or unreadable file by returning None;
    # fail here with a clear message instead of inside cvtColor.
    raise FileNotFoundError(f"could not read pose hint image: {hint_path}")
source = cv2.cvtColor(source, cv2.COLOR_BGR2RGB)

# The hint must end up with the same (height, width) as the generated image.
height = 512
width = 512
# Time-based seed, as an integer because torch.manual_seed expects an int.
# Alternative: random.randint(0, 2147483647)
random_seed = int(time.time())

# cv2.resize takes dsize as (width, height), not (height, width).
source = cv2.resize(source, (width, height), interpolation=cv2.INTER_LINEAR)
# Scale to [0, 1], move to the compute device and add a batch axis.
source = torch.tensor((source.astype(np.float32) / 255.0)).to(torch_device).unsqueeze(0)
# Channel-last -> channel-first.
source = einops.rearrange(source, 'b h w c -> b c h w').clone()

print(source.shape)

num_inference_steps = 40
cfg_scale = 7.5
batch_size = 1

In [None]:
from utils.prompt_parser import get_learned_conditioning, get_multicond_learned_conditioning
import utils.prompt_parser
from modules.sampler import VanillaStableDiffusionSampler
from ldm.models.diffusion.ddim import DDIMSampler
from PIL import Image
%load_ext autoreload
%autoreload

# model.to(accelerator.device)
# LoraNet.to(torch_device)
# model.to(torch_device)

# prompt = ["A cool digital illustration of a steampunk computer laboratory with clockwork machines, 4k, detailed, trending in artstation, fantasy vivid colors"]
# prompt_negative = [""]
# orig_sampler = VanillaStableDiffusionSampler(DDIMSampler, )
# lora_network.apply_to
# DDIM sampler bound to the loaded model.
orig_sampler = DDIMSampler(model)

# Positive text embedding
positive_text_embeddings = model.get_learned_conditioning(prompt)
print(f'the postive text embedding: {positive_text_embeddings}')

# Negative text embedding
negative_text_embeddings = model.get_learned_conditioning(prompt_negative)
print(f'the negative text embedding: {negative_text_embeddings}')

# Initial latents. The noise is drawn on the CPU and then moved to the compute
# device (presumably so the seeded noise does not depend on the GPU).
latent_shape = (4, height // 8, width // 8)
torch.manual_seed(int(random_seed))
latents = torch.randn((batch_size, *latent_shape), device='cpu')
latents = latents.to(torch_device)

# Extra prompt fed to the pose branch.
add_prompt = 'illya'

# Conditional and unconditional inputs share the same pose image and pose
# text; they differ only in the cross-attention text embedding.
conditioning = {
    'c_crossattn': [positive_text_embeddings],
    'pose_1': [source],
    'pose_1_text': [model.get_learned_conditioning([add_prompt])],
}
unconditional_conditioning = {
    'c_crossattn': [negative_text_embeddings],
    'pose_1': [source],
    'pose_1_text': [model.get_learned_conditioning([add_prompt])],
}

# batch_size and shape are taken from the same values used to build `latents`
# so that the sampler arguments cannot drift out of sync with x_T.
sample_ddim, intermediates = orig_sampler.sample(
    S=num_inference_steps,
    batch_size=batch_size,
    shape=latent_shape,
    conditioning=conditioning,
    x_T=latents,
    unconditional_conditioning=unconditional_conditioning,
    unconditional_guidance_scale=cfg_scale,
)



# Show image

In [None]:
from PIL import Image
# Decode the sampled latents with the model's first stage.
image = model.decode_first_stage(sample_ddim)

# Rescale with x / 2 + 0.5, clamp to [0, 1], then bring the result to the CPU
# as a channel-last numpy array.
image = (image / 2.0 + 0.5).clamp(0, 1).detach().cpu().permute(0, 2, 3, 1).numpy()
# Quantise to 8-bit and wrap every batch item as a PIL image.
images = (image * 255).round().astype("uint8")
pil_images = list(map(Image.fromarray, images))
pil_images[0]

In [None]:
# Decode the sampled latents with the model's first stage.
image = model.decode_first_stage(sample_ddim)

# Rescale with x / 2 + 0.5, clamp to [0, 1], then bring the result to the CPU
# as a channel-last numpy array.
image = (image / 2.0 + 0.5).clamp(0, 1).detach().cpu().permute(0, 2, 3, 1).numpy()
# Quantise to 8-bit and wrap every batch item as a PIL image.
images = (image * 255).round().astype("uint8")
pil_images = list(map(Image.fromarray, images))
pil_images[0]


In [None]:
# Decode the sampled latents with the model's first stage.
image = model.decode_first_stage(sample_ddim)

# Rescale with x / 2 + 0.5, clamp to [0, 1], then bring the result to the CPU
# as a channel-last numpy array.
image = (image / 2.0 + 0.5).clamp(0, 1).detach().cpu().permute(0, 2, 3, 1).numpy()
# Quantise to 8-bit and wrap every batch item as a PIL image.
images = (image * 255).round().astype("uint8")
pil_images = list(map(Image.fromarray, images))
pil_images[0]


In [None]:

# Decode the first element of dataset item 4 (with a batch axis added).
# NOTE(review): `Lora_dataset` and `accelerator` are not defined in the
# visible cells — presumably left over from an earlier session; confirm.
image = model.decode_first_stage(
    Lora_dataset[4][0].to(accelerator.device).unsqueeze(0)
)

# Rescale with x / 2 + 0.5, clamp to [0, 1], then bring the result to the CPU
# as a channel-last numpy array.
image = (image / 2.0 + 0.5).clamp(0, 1).detach().cpu().permute(0, 2, 3, 1).numpy()
# Quantise to 8-bit and wrap every batch item as a PIL image.
images = (image * 255).round().astype("uint8")
pil_images = list(map(Image.fromarray, images))
pil_images[0]

In [None]:
# NOTE(review): `LoraNet` is not defined in any visible cell — presumably left
# over from an earlier session; confirm before running this cell.
LoraNet.load_module()
# LoraNet.unload_module()

In [None]:
import pytorch_lightning as pl
from ldm.util import log_txt_as_img, exists, default, ismap, isimage, mean_flat, count_params, instantiate_from_config
# Set the gradient

class trainModel(pl.LightningModule):
    """Pair a latent-diffusion model with a LoRA network and compute its training loss.

    Args:
        LDMmodel: Diffusion model. This class uses its ``q_sample``,
            ``model``, ``get_loss``, ``get_learned_conditioning``,
            ``parameterization``, ``l_simple_weight``, ``lvlb_weights``,
            ``original_elbo_weight``, ``num_timesteps``, ``device`` and
            ``first_stage_model`` members.
        LoraNet: LoRA network stored alongside the model; only its
            ``requires_grad_`` method is used here.
    """

    def __init__(self, LDMmodel, LoraNet):
        super().__init__()
        self.model = LDMmodel
        self.LoraNet = LoraNet

    def set_gradient(self):
        """Move the first-stage model to the CPU and enable gradients everywhere."""
        self.model.first_stage_model = self.model.first_stage_model.to('cpu')
        self.model.requires_grad_(requires_grad=True)
        self.LoraNet.requires_grad_(requires_grad=True)

    def p_losses(self, x_start, t, cond, noise=None):
        """Compute the diffusion loss for one batch.

        Args:
            x_start: Clean input that is noised via ``model.q_sample``.
            t: Timestep indices; also used to index ``model.lvlb_weights``.
            cond: Conditioning passed to the inner model as ``c_crossattn``.
            noise: Optional noise; drawn with ``torch.randn_like(x_start)``
                when not given.

        Returns:
            Tuple ``(loss, loss_dict)`` where ``loss_dict`` holds the
            ``loss_simple``, ``loss_vlb`` and ``loss`` entries under a
            ``train/`` or ``val/`` prefix.

        Raises:
            NotImplementedError: If ``model.parameterization`` is neither
                ``"eps"`` nor ``"x0"``.
        """
        noise = default(noise, lambda: torch.randn_like(x_start))
        x_noisy = self.model.q_sample(x_start=x_start, t=t, noise=noise)
        model_out = self.model.model(x_noisy, t=t, c_crossattn=[cond])

        loss_dict = {}
        if self.model.parameterization == "eps":
            target = noise
        elif self.model.parameterization == "x0":
            target = x_start
        else:
            # The attribute lives on the wrapped model, not on this wrapper;
            # reading it from `self` would raise AttributeError instead.
            raise NotImplementedError(f"Paramterization {self.model.parameterization} not yet supported")

        # Per-sample loss: reduce over every axis except the batch axis.
        loss = self.model.get_loss(model_out, target, mean=False).mean(dim=[1, 2, 3])

        log_prefix = 'train' if self.model.training else 'val'

        loss_dict.update({f'{log_prefix}/loss_simple': loss.mean()})
        loss_simple = loss.mean() * self.model.l_simple_weight

        loss_vlb = (self.model.lvlb_weights[t] * loss).mean()
        loss_dict.update({f'{log_prefix}/loss_vlb': loss_vlb})

        loss = loss_simple + self.model.original_elbo_weight * loss_vlb

        loss_dict.update({f'{log_prefix}/loss': loss})

        return loss, loss_dict

    def forward(self, image, cond):
        """Encode ``cond``, draw one random timestep per batch item, and return ``p_losses``.

        Args:
            image: Batch passed to ``p_losses`` as ``x_start``; its first
                dimension sets how many timesteps are drawn.
            cond: Raw conditioning given to ``model.get_learned_conditioning``.

        Returns:
            The ``(loss, loss_dict)`` tuple from ``p_losses``.
        """
        cond = self.model.get_learned_conditioning(cond)
        # Timesteps are drawn uniformly from [0, model.num_timesteps).
        t = torch.randint(0, self.model.num_timesteps, (image.shape[0],), device=self.model.device).long()
        return self.p_losses(image, t, cond)


# Wrap the loaded model and the LoRA network in the training module.
# NOTE(review): `LoraNet` is not defined in any visible cell — confirm where
# it comes from before running this cell.
myModel = trainModel(model, LoraNet)

# Moves first_stage_model to the CPU and enables gradients on the model and
# the LoRA network (see trainModel.set_gradient).
myModel.set_gradient()

