In [None]:
import os
import torch
import torchvision
import math
import numpy as np
from PIL import Image
from diffusers import AutoencoderKL

In [None]:
# Load only the VAE sub-model of Stable Diffusion v1.5.
# The keyword is `torch_dtype`; the previous `torch_type` spelling is not a
# recognised argument, so the half-precision request never reached the loader.
pipe = AutoencoderKL.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    subfolder="vae",
    torch_dtype=torch.float16,
)
pipe.to("cuda:1", torch.float16)

In [None]:
def add_padding(torch_image, patch_size):
    """Pad an image on its right and bottom edges up to multiples of ``patch_size``.

    Args:
        torch_image: tensor indexed as (channels, height, width).
        patch_size: tile edge length; the padded height and width are the
            smallest multiples of this value that fit the image.

    Returns:
        The padded tensor, indexed like the input. The original pixels stay
        in the top-left corner.

    Padding uses reflection. A single reflect pad must be strictly smaller
    than the axis it extends, so an image smaller than a patch is padded in
    several reflect passes. An axis that is only one pixel long cannot be
    reflected at all and is extended by replicating its edge instead.
    """
    tile_count_x = math.ceil(torch_image.shape[2] / patch_size)
    tile_count_y = math.ceil(torch_image.shape[1] / patch_size)
    pad_size_x = tile_count_x * patch_size - torch_image.shape[2]
    pad_size_y = tile_count_y * patch_size - torch_image.shape[1]

    # Work on a batched (1, C, H, W) view: 2-D reflect/replicate padding of
    # 4-D input is supported by every torch release.
    batched = torch_image.unsqueeze(0)
    while pad_size_x > 0 or pad_size_y > 0:
        current_height, current_width = batched.shape[2], batched.shape[3]
        step_x = max(min(pad_size_x, current_width - 1), 0)
        step_y = max(min(pad_size_y, current_height - 1), 0)
        if step_x == 0 and step_y == 0:
            # Whatever is left belongs to one-pixel axes: replicate the edge.
            batched = torch.nn.functional.pad(
                batched, (0, pad_size_x, 0, pad_size_y), mode="replicate"
            )
            break
        batched = torch.nn.functional.pad(
            batched, (0, step_x, 0, step_y), mode="reflect"
        )
        pad_size_x -= step_x
        pad_size_y -= step_y
    return batched.squeeze(0)

In [None]:
def remove_pads(torch_image, output_image):
    """Crop ``output_image`` back to the height and width of ``torch_image``.

    Both tensors are indexed as (channels, height, width). The crop keeps the
    top-left region, i.e. it drops rows and columns from the bottom and right.
    """
    rows = slice(0, torch_image.shape[1])
    cols = slice(0, torch_image.shape[2])
    return output_image[:, rows, cols]

In [None]:
def create_pathes(torch_image, patch_size):
    """Cut a (channels, height, width) tensor into non-overlapping square tiles.

    Tiles are produced column by column: the outer iteration walks left to
    right, the inner one top to bottom. Tiles at the right/bottom border are
    smaller when the image size is not a multiple of ``patch_size``.

    Returns:
        A list of tensor slices of the input.
    """
    print(torch_image.shape)
    height, width = torch_image.shape[1], torch_image.shape[2]
    return [
        torch_image[:, top:top + patch_size, left:left + patch_size]
        for left in range(0, width, patch_size)
        for top in range(0, height, patch_size)
    ]

In [None]:
def merge_pathes(pathes, padded_image_width, padded_image_height, patch_size):
    """Reassemble tiles into one (channels, height, width) tensor.

    Args:
        pathes: tiles in column-major order (all tiles of the left-most
            column from top to bottom, then the next column, and so on).
        padded_image_width: width of the image the tiles were cut from.
        padded_image_height: height of the image the tiles were cut from.
        patch_size: tile edge length used when cutting.

    Returns:
        A float32 tensor (the ``torch.zeros`` default dtype) holding the
        stitched image.

    The channel count is taken from the first tile (3 when ``pathes`` is
    empty), so tiles that are not 3-channel can be merged as well. The
    number of tiles per column is counted from the actual tile positions,
    which keeps the indexing correct even when the height is not an exact
    multiple of ``patch_size``.
    """
    stride = patch_size
    tiles_per_column = len(range(0, padded_image_height, stride))
    channels = pathes[0].shape[0] if len(pathes) > 0 else 3
    output_image = torch.zeros((channels, padded_image_height, padded_image_width))
    for col, x in enumerate(range(0, padded_image_width, stride)):
        for row, y in enumerate(range(0, padded_image_height, stride)):
            tile = pathes[col * tiles_per_column + row]
            output_image[:, y:y + patch_size, x:x + patch_size] = tile
    return output_image

In [None]:
# Read the input picture, force three RGB channels, and convert it to a
# tensor with ToTensor.
image = Image.open("../1_media/input_images/cat.jpg")
image = image.convert("RGB")
transformed_image = torchvision.transforms.ToTensor()(image)

# Pad to a multiple of the 1024-pixel tile size, then cut into tiles.
padded_image = add_padding(torch_image=transformed_image, patch_size=1024)
patches = create_pathes(torch_image=padded_image, patch_size=1024)

In [None]:
# Round-trip every tile through the VAE: keep the scaled latent for the
# diffusion experiments below and the decoded tile for reassembly.
latents = []
decodeds = []
for patch in patches:
    # Map the tile from [0, 1] to [-1, 1] and add a batch axis.
    patch = patch.unsqueeze(0) * 2 - 1
    patch = patch.to("cuda:1", torch.float16)
    # Inference only: without no_grad every encode/decode would record an
    # autograd graph through the whole VAE and hold its activations.
    with torch.no_grad():
        latent = pipe.encode(patch).latent_dist.sample() * 0.18215
        latents.append(latent.cpu())
        # Undo the 0.18215 latent scaling before decoding.
        latent_ = 1 / 0.18215 * latent
        decoded = pipe.decode(latent_)
    # Back from [-1, 1] to [0, 1].
    decoded = (decoded.sample / 2 + 0.5).clamp(0, 1)
    decoded = decoded.cpu().squeeze(0)
    decodeds.append(decoded)
    # Drop every GPU reference (including latent_) so empty_cache can
    # actually return the blocks.
    del patch
    del latent
    del latent_
    del decoded
    torch.cuda.empty_cache()

In [None]:
# Stitch the decoded tiles back together at the padded resolution.
tensor_output = merge_pathes(
    pathes=decodeds,
    padded_image_width=padded_image.shape[2],
    padded_image_height=padded_image.shape[1],
    patch_size=1024,
)
# Crop away the padding, convert to a PIL image and display it.
output = remove_pads(torch_image=transformed_image, output_image=tensor_output)
output = torchvision.transforms.ToPILImage()(output)
output.show()

In [None]:
import gc

# Collect unreachable Python objects first so the CUDA tensors they own are
# released, then hand the now-unused cached blocks back to the driver.
gc.collect()
torch.cuda.empty_cache()

In [None]:
from diffusers import UNet2DConditionModel, DPMSolverMultistepScheduler, LMSDiscreteScheduler

# A hub id plus subfolder is loaded with from_pretrained; from_config is meant
# for an already-loaded config dict. The scheduler holds no model weights, so
# it takes no dtype argument.
scheduler = DPMSolverMultistepScheduler.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    subfolder="scheduler",
)
scheduler.set_timesteps(51)

# `torch_dtype` (not `torch_type`) is the keyword that selects half precision.
unet = UNet2DConditionModel.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    subfolder="unet",
    torch_dtype=torch.float16,
)
unet.to("cuda:1", torch.float16)

In [None]:
import matplotlib.pyplot as plt

# For every tile latent, show the decoded image after noising it to six
# different positions of the 51-entry timestep schedule.
for latent in latents:
    noise = torch.randn_like(latent)
    fig, axs = plt.subplots(3, 2, figsize=(16, 32))

    for c, sampling_step in enumerate(range(0, 51, 10)):
        print(sampling_step)
        encoded_and_noised = scheduler.add_noise(latent, noise, timesteps=torch.tensor([scheduler.timesteps[sampling_step]]))

        # Undo the 0.18215 latent scaling before decoding.
        encoded_and_noised_ = 1 / 0.18215 * encoded_and_noised
        # Inference only: avoid recording an autograd graph through the VAE.
        with torch.no_grad():
            decoded = pipe.decode(encoded_and_noised_.to("cuda:1", torch.float16))
        decoded = (decoded.sample / 2 + 0.5).clamp(0, 1).squeeze(0)
        decoded = decoded.cpu().permute(1, 2, 0).numpy()
        decoded_image = Image.fromarray((decoded * 255).astype(np.uint8))

        axs[c // 2][c % 2].imshow(decoded_image)
        axs[c // 2][c % 2].set_title(f"Step - {sampling_step}")
        del encoded_and_noised
        del encoded_and_noised_
        torch.cuda.empty_cache()

In [None]:
# Noise the first tile latent to schedule index 35, save/show its decoding,
# and keep the noised latent for the denoising cells.
noised_latents = []
for latent in latents:
    fig, axs = plt.subplots(1, 1, figsize=(16, 8))
    noise = torch.randn_like(latent)
    noised_latent = scheduler.add_noise(latent, noise, timesteps=torch.tensor([scheduler.timesteps[35]]))

    # Undo the 0.18215 latent scaling before decoding.
    noised_latent_ = 1 / 0.18215 * noised_latent
    # Inference only: avoid recording an autograd graph through the VAE.
    with torch.no_grad():
        decoded = pipe.decode(noised_latent_.to("cuda:1", torch.float16))
    decoded = (decoded.sample / 2 + 0.5).clamp(0, 1).squeeze(0)
    decoded = decoded.cpu().permute(1, 2, 0).numpy()
    decoded_image = Image.fromarray((decoded * 255).astype(np.uint8))
    decoded_image.save("0.png")

    noised_latents.append(noised_latent.detach().cpu())
    axs.imshow(decoded_image)
    del noise
    del noised_latent_
    del noised_latent
    del decoded
    torch.cuda.empty_cache()
    # Only the first tile is processed.
    break

In [None]:
from transformers import CLIPTextModel, CLIPTokenizer

# Tokenizer and CLIP text encoder that belong to the same SD v1.5 checkpoint.
tokenizer = CLIPTokenizer.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    subfolder="tokenizer",
    torch_dtype=torch.float16,
)
text_encoder = CLIPTextModel.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    subfolder="text_encoder",
    torch_dtype=torch.float16,
)
# Move the encoder to the GPU used by the other models.
text_encoder = text_encoder.to("cuda:1")

In [None]:
# One-shot denoising experiment: predict the noise in the latent that was
# noised to schedule index 35 and decode the resulting clean-latent estimate.
prompt = [""]
text_input = tokenizer(prompt, padding="max_length", max_length=tokenizer.model_max_length, truncation=True, return_tensors="pt")
with torch.no_grad():
    text_embeddings = text_encoder(
        text_input.input_ids.to("cuda:1")
    )[0]

latent = noised_latents[0].to("cuda:1", torch.float16)

# The UNet is conditioned on the diffusion timestep value, not on the
# position in the schedule: pass scheduler.timesteps[35] (the value used by
# add_noise), not the bare index 35.
timestep = scheduler.timesteps[35]

latent_model_input = torch.cat([latent])

with torch.no_grad():
    noise_pred = unet(latent_model_input,
                      timestep,
                      encoder_hidden_states=text_embeddings
    )["sample"]

# add_noise built x_t = sqrt(a_t) * x_0 + sqrt(1 - a_t) * eps, so the estimate
# of the clean latent is x_0 = (x_t - sqrt(1 - a_t) * eps_pred) / sqrt(a_t).
# Subtracting the raw prediction ignores both coefficients.
alpha_prod_t = scheduler.alphas_cumprod[int(timestep)].item()
latent_model_input = (latent - (1 - alpha_prod_t) ** 0.5 * noise_pred) / alpha_prod_t ** 0.5

# Undo the 0.18215 latent scaling before decoding.
scaled_latent = 1 / 0.18215 * latent_model_input
# Inference only: avoid recording an autograd graph through the VAE.
with torch.no_grad():
    decoded = pipe.decode(scaled_latent.to("cuda:1", torch.float16))
decoded = (decoded.sample / 2 + 0.5).clamp(0, 1).squeeze(0)
decoded = decoded.cpu().permute(1, 2, 0).numpy()
decoded_image = Image.fromarray((decoded * 255).astype(np.uint8))
decoded_image.save("0_denoised.png")
decoded_image.show()

In [None]:
def text_enc(prompts, maxlen=None):
    """Encode prompts with the notebook's CLIP tokenizer and text encoder.

    Args:
        prompts: a string or list of strings to encode.
        maxlen: token length to pad/truncate to; defaults to
            ``tokenizer.model_max_length``.

    Returns:
        The first output of the text encoder as a half-precision tensor on
        the encoder's device.

    Runs under ``torch.no_grad()``: the embeddings are only used for
    inference, and a recorded autograd graph would keep the encoder weights
    alive even after the encoder itself is deleted.
    """
    if maxlen is None:
        maxlen = tokenizer.model_max_length
    inp = tokenizer(prompts, padding="max_length", max_length=maxlen, truncation=True, return_tensors="pt")
    with torch.no_grad():
        return text_encoder(inp.input_ids.to("cuda:1"))[0].half()

In [None]:
from tqdm import tqdm
prompt = [" "]
negative_prompt = ["lowres, extra digit, fewer digits, cropped, worst quality, low quality, text, word, icon, logo, hands, fingers, feet, face, eyes, anime, women, man, nude"]
guidance_scale = 5
# Schedule index the latent in `noised_latents` was noised to; denoising has
# to resume from this position, not from the start of the schedule.
start_step = 35

batch_size = len(prompt)

with torch.no_grad():
    text = text_enc(prompt)
    uncond = text_enc(negative_prompt * batch_size, text.shape[1])
# Frees the encoder; note that text_enc cannot be called again afterwards
# without reloading it.
del text_encoder
# Negative-prompt embedding first, prompt embedding second: the order must
# match the chunk(2) split of the UNet output below.
emb = torch.cat([uncond, text])

# Create the output directory once instead of checking on every step.
os.makedirs("./steps", exist_ok=True)

# Re-initialise the schedule (this also resets the multistep solver state).
scheduler.set_timesteps(51)

# The latent came out of scheduler.add_noise, so it is already at the noise
# level of the chosen timestep and needs no init_noise_sigma scaling.
noised_latent = noised_latents[0]
noised_latent = noised_latent.to("cuda:1", torch.float16)
print("Noised Latent Shape:", noised_latent.shape)

for i, time_step in enumerate(tqdm(scheduler.timesteps[start_step:]), start=start_step):
    with torch.no_grad():
        # Classifier-free guidance: run the negative-prompt and the prompt
        # branch in one batch.
        inp = scheduler.scale_model_input(torch.cat([noised_latent] * 2), time_step)
        uncond_pred, text_pred = unet(inp,
                                      time_step,
                                      encoder_hidden_states=emb).sample.chunk(2)
        noise_pred = uncond_pred + guidance_scale * (text_pred - uncond_pred)

        noised_latent = scheduler.step(noise_pred, time_step, noised_latent).prev_sample

        # Undo the 0.18215 latent scaling before decoding.
        noised_latent_ = 1 / 0.18215 * noised_latent
        decoded = pipe.decode(noised_latent_)
    decoded = (decoded.sample / 2 + 0.5).clamp(0, 1).squeeze(0)
    decoded = decoded.cpu().permute(1, 2, 0).numpy()
    decoded_image = Image.fromarray((decoded * 255).astype(np.uint8))
    # Files are numbered by schedule index (start_step .. 50).
    decoded_image.save(f'./steps/{i}.png')

    del noise_pred
    del decoded
    torch.cuda.empty_cache()

In [None]:
# Display the most recently assigned `decoded_image` (the last image produced
# by the denoising loop in the previous cell).
decoded_image.show()