In [None]:
import os
import matplotlib.pyplot as plt
from PIL import Image

# Folder layout: one sub-folder per sampler under base_dir
base_dir = '/media/alexander/DATA/dima_vinichenko/data/SDXL'
folders = ['DDIM', 'DPM', 'DPM_AYS', 'DDPM_original']

# Grid settings: one row per image index, one column per folder
num_images = 22
num_cols = len(folders)
num_rows = num_images

# Figure size
figsize_per_image = 2  # size of one grid cell, in inches
# squeeze=False keeps `axes` two-dimensional even when the grid has a single
# row or a single column, so the axes[row, col] indexing below always works.
fig, axes = plt.subplots(
    num_rows,
    num_cols,
    figsize=(figsize_per_image*num_cols, figsize_per_image*num_rows),
    squeeze=False,
)

# Load and show the images; files are named 1.png, 2.png, ... in every folder
for row in range(num_rows):
    image_idx = row + 1
    for col, folder in enumerate(folders):
        image_path = os.path.join(base_dir, folder, f"{image_idx}.png")
        if os.path.exists(image_path):
            # The context manager closes the file handle once imshow has
            # taken its copy of the pixel data.
            with Image.open(image_path) as img:
                axes[row, col].imshow(img)
        # Missing files simply leave an empty, axis-less cell
        axes[row, col].axis('off')

# Add the folder name as a title above each column
for col, folder in enumerate(folders):
    axes[0, col].set_title(folder, fontsize=10)

plt.tight_layout()
plt.show()


In [1]:
import torch
import numpy as np

from diffusers import StableDiffusionXLPipeline
from diffusers.utils import make_image_grid

from IPython.display import display

from diffusers import DPMSolverMultistepScheduler as DefaultDPMSolver

# Add support for setting custom timesteps
class DPMSolverMultistepScheduler(DefaultDPMSolver):
    """DPM-Solver multistep scheduler that also accepts an explicit timestep list."""

    def set_timesteps(
        self, num_inference_steps=None, device=None,
        timesteps=None
    ):
        """Configure the schedule, optionally from a caller-supplied `timesteps` list.

        With `timesteps=None` the call is forwarded unchanged to the parent
        scheduler. Otherwise sigmas are looked up per supplied timestep and the
        solver state is reset.
        """
        if timesteps is not None:
            # sigma_t = sqrt((1 - alpha_bar_t) / alpha_bar_t) for every training timestep
            sigma_table = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
            self.sigmas = torch.from_numpy(sigma_table[timesteps])

            # The final entry of the supplied schedule gets a sigma but no timestep.
            step_ids = torch.tensor(timesteps[:-1])
            self.timesteps = step_ids.to(device=device, dtype=torch.int64)

            self.num_inference_steps = len(timesteps)

            # Reset the multistep history
            self.model_outputs = [None] * self.config.solver_order
            self.lower_order_nums = 0

            # Index counters, left unset here
            self._step_index = None
            self._begin_index = None

            # Keep sigmas on the CPU to avoid too much CPU/GPU communication
            self.sigmas = self.sigmas.to("cpu")
        else:
            super().set_timesteps(num_inference_steps, device)


from diffusers import DPMSolverMultistepInverseScheduler as DefaultDPMSolverInverse

class DPMSolverMultistepInverseScheduler(DefaultDPMSolverInverse):
    """Inverse DPM-Solver multistep scheduler that also accepts an explicit timestep list."""

    def set_timesteps(
        self, num_inference_steps=None, device=None,
        timesteps=None
    ):
        """Configure the schedule, optionally from a caller-supplied `timesteps` list.

        With `timesteps=None` the call is forwarded unchanged to the parent
        scheduler. Otherwise sigmas are looked up per supplied timestep and the
        solver state is reset.
        """
        if timesteps is not None:
            # sigma_t = sqrt((1 - alpha_bar_t) / alpha_bar_t) for every training timestep
            sigma_table = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
            self.sigmas = torch.from_numpy(sigma_table[timesteps])

            # The final entry of the supplied schedule gets a sigma but no timestep.
            # NOTE(review): the notebook passes an ascending schedule here, so the
            # dropped entry is the largest timestep, not 0.
            step_ids = torch.tensor(timesteps[:-1])
            self.timesteps = step_ids.to(device=device, dtype=torch.int64)

            self.num_inference_steps = len(timesteps)

            # Reset the multistep history
            self.model_outputs = [None] * self.config.solver_order
            self.lower_order_nums = 0

            # Index counters, left unset here
            self._step_index = None
            self._begin_index = None

            # Keep sigmas on the CPU to avoid too much CPU/GPU communication
            self.sigmas = self.sigmas.to("cpu")
        else:
            super().set_timesteps(num_inference_steps, device)


In [2]:
from PIL import Image
from diffusers import DDIMInverseScheduler
from torchvision import transforms as tvt
from diffusers.utils import make_image_grid
import torch

from diffusers import DDIMScheduler

@torch.no_grad()
def ddim_inversion(input_image, num_steps):
    """Run the SDXL pipeline with a DDIM inverse scheduler on an encoded image.

    A fresh pipeline is loaded from the fixed checkpoint path on every call and
    placed on the module-level `device`, which must exist before calling.

    Args:
        input_image: image accepted by torchvision's ``ToTensor``
            (callers in this notebook pass the generated image).
        num_steps: forwarded to the pipeline as ``num_inference_steps``.

    Returns:
        Whatever the pipeline returns with ``return_dict=False`` and
        ``output_type='latent'``; callers in this notebook use element ``[0]``.
    """
    pipe = StableDiffusionXLPipeline.from_single_file(
        "/media/alexander/DATA/ai1_models/models_SDXL/sd_xl_base_1.0.safetensors",
        torch_dtype=torch.float16,
        use_safetensors=True,
        add_watermarker=False
    ).to(device)

    # Replace the sampler with its inverse counterpart, built from the same config.
    pipe.scheduler = DDIMInverseScheduler.from_config(pipe.scheduler.config)
    pipe.to(device)

    # The VAE encode is done in float32; the rest of the pipeline stays in float16.
    vae_dtype = torch.float32
    rng = torch.Generator(device).manual_seed(10)

    # Tensor with a leading batch axis.
    pixels = tvt.ToTensor()(input_image)[None, ...]
    pixels = pixels.to(device=device, dtype=vae_dtype)

    vae = pipe.vae
    vae.to(device, dtype=vae_dtype)
    vae_cfg = vae.config

    # Optional per-channel latent statistics from the VAE config.
    mean = getattr(vae_cfg, "latents_mean", None)
    if mean is not None:
        mean = torch.tensor(mean).view(1, 4, 1, 1)
    std = getattr(vae_cfg, "latents_std", None)
    if std is not None:
        std = torch.tensor(std).view(1, 4, 1, 1)

    # ToTensor yields [0, 1] for 8-bit images; `* 2 - 1` maps that to [-1, 1].
    encoded = vae.encode(pixels * 2 - 1).latent_dist.sample(rng)

    if mean is None or std is None:
        encoded = encoded * vae_cfg.scaling_factor
    else:
        mean = mean.to(device=device, dtype=vae_dtype)
        std = std.to(device=device, dtype=vae_dtype)
        encoded = (encoded - mean) * vae_cfg.scaling_factor / std

    latents = encoded.type(torch.float16)
    print('latents', latents.shape)

    # Empty prompts and guidance_scale=1.
    return pipe(prompt="", negative_prompt="", guidance_scale=1.,
                width=pixels.shape[-1], height=pixels.shape[-2],
                output_type='latent', return_dict=False,
                num_inference_steps=num_steps, latents=latents)
    


In [3]:
from PIL import Image
# from diffusers import DDIMInverseScheduler
from torchvision import transforms as tvt


@torch.no_grad()
def dpm_inversion(input_image, num_steps, timesteps):
    """Run the SDXL pipeline with the inverse DPM-Solver scheduler on an encoded image.

    A fresh pipeline is loaded from the fixed checkpoint path on every call and
    placed on the module-level `device`, which must exist before calling.

    Args:
        input_image: image accepted by torchvision's ``ToTensor``
            (callers in this notebook pass the generated image).
        num_steps: forwarded to the pipeline as ``num_inference_steps``.
        timesteps: forwarded to the pipeline as ``timesteps``; callers pass
            either an explicit list or ``None``.

    Returns:
        Whatever the pipeline returns with ``return_dict=False`` and
        ``output_type='latent'``; callers in this notebook use element ``[0]``.
    """
    pipe = StableDiffusionXLPipeline.from_single_file(
        "/media/alexander/DATA/ai1_models/models_SDXL/sd_xl_base_1.0.safetensors",
        torch_dtype=torch.float16,
        use_safetensors=True,
        add_watermarker=False
    ).to(device)

    # Replace the sampler with the inverse DPM-Solver, built from the same config.
    pipe.scheduler = DPMSolverMultistepInverseScheduler.from_config(pipe.scheduler.config)
    pipe.to(device)

    # The VAE encode is done in float32; the rest of the pipeline stays in float16.
    vae_dtype = torch.float32
    rng = torch.Generator(device).manual_seed(10)

    # Tensor with a leading batch axis.
    pixels = tvt.ToTensor()(input_image)[None, ...]
    pixels = pixels.to(device=device, dtype=vae_dtype)

    vae = pipe.vae
    vae.to(device, dtype=vae_dtype)
    vae_cfg = vae.config

    # Optional per-channel latent statistics from the VAE config.
    mean = getattr(vae_cfg, "latents_mean", None)
    if mean is not None:
        mean = torch.tensor(mean).view(1, 4, 1, 1)
    std = getattr(vae_cfg, "latents_std", None)
    if std is not None:
        std = torch.tensor(std).view(1, 4, 1, 1)

    # ToTensor yields [0, 1] for 8-bit images; `* 2 - 1` maps that to [-1, 1].
    encoded = vae.encode(pixels * 2 - 1).latent_dist.sample(rng)

    if mean is None or std is None:
        encoded = encoded * vae_cfg.scaling_factor
    else:
        mean = mean.to(device=device, dtype=vae_dtype)
        std = std.to(device=device, dtype=vae_dtype)
        encoded = (encoded - mean) * vae_cfg.scaling_factor / std

    latents = encoded.type(torch.float16)
    print('latents', latents.shape)

    # Empty prompts and guidance_scale=1.
    return pipe(prompt="", negative_prompt="", guidance_scale=1.,
                width=pixels.shape[-1], height=pixels.shape[-2],
                output_type='latent', return_dict=False,
                num_inference_steps=num_steps, latents=latents,
                timesteps=timesteps
                )


In [4]:
# Custom timestep schedule: 11 entries, descending from 999 to 0. It is passed
# as `timesteps` in the "DPM++ AYS" cell; the inversion call there uses it reversed.
# NOTE(review): presumably the Align-Your-Steps 10-step schedule for SDXL — confirm.
sampling_schedule = [999, 845, 730, 587, 443, 310, 193, 116, 53, 13, 0]
# Running logs of the source / edit prompts; the last cells of the notebook append to them.
prompt_list = []
edit_prompt_list = []

Generating the original image with DDPM

In [765]:
# Source prompts used to generate the original images.
# Paired by position with `edit_prompts`: the prompt-selection cell reads both
# lists at the same index.
orig_prompts = [
    'Photo of a golden retriever in the park',
    'Photo of a horse running on the beach',
    'Realistic photo of a cat on the sofa',
    'Photo of a tiger in the jungle',
    'Realistic photo of a rabbit in the garden',
    'Photo of a dog sleeping on a blanket',
    'Photo of a giraffe in the savannah',
    'Photo of a fox in the snowy field',
    'Photo of a zebra drinking from a lake',
    'Photo of a woman sitting in a cafe',
    'Realistic photo of a basketball on the court',
    'Realistic photo of a guitar on a wooden floor',
    'Photo of a man surfing on the ocean',
    'Realistic photo of a boy holding an ice cream',
    'Photo of a plane flying over the mountains',
    'Photo of a cat sleeping on a bed',
    'Photo of a lion resting on a rock',
    'Photo of a leopard in the jungle',
    'Realistic photo of a white dove in the sky',
    'Photo of a camel in the desert',
    'Photo of a cyclist riding on a mountain trail',
    'Photo of a sailboat on the ocean at sunset',
    'Realistic photo of a child playing with a kite',
    'Realistic photo of a woman holding a bouquet',
    'Photo of a black car on a city street',
    'Photo of a pizza on a wooden table',
    'Photo of a black dog on the snow',
    'Realistic photo of a firefighter in uniform',
    'Photo of a red rose in a garden',
    'Photo of a deer standing in a forest clearing',
    'Photo of a skier on a snowy mountain',
    'Photo of a man fishing by a river',
    'Photo of a yellow taxi in New York City',
    'Photo of a violin on a concert stage',
    'Photo of a shark swimming in the ocean',
    'Photo of a ballerina dancing in a studio',
    'Photo of a penguin on ice',
    'Photo of a raccoon in a tree',
    'Photo of a monkey sitting on a rock',
    'Realistic photo of goat on a cliff'
]

# Edit prompts used when re-generating from the inverted latents.
# Paired by position with `orig_prompts`: the prompt-selection cell reads both
# lists at the same index.
edit_prompts = [
    'Photo of a beagle in the park',
    'Photo of a camel running on the beach',
    'Realistic photo of a dog on the sofa',
    'Photo of a leopard in the jungle',
    'Realistic photo of a squirrel in the garden',
    'Photo of a cat sleeping on a blanket',
    'Photo of an elephant in the savannah',
    'Photo of a deer in the snowy field',
    'Photo of a gazelle drinking from a lake',
    'Photo of a man sitting in a cafe',
    'Realistic photo of a football on the court',
    'Realistic photo of a violin on a wooden floor',
    'Photo of a woman surfing on the ocean',
    'Realistic photo of a girl holding an ice cream',
    'Photo of a helicopter flying over the mountains',
    'Photo of a dog sleeping on a bed',
    'Photo of a cheetah resting on a rock',
    'Photo of a panther in the jungle',
    'Realistic photo of a black dove in the sky',
    'Photo of a llama in the desert',
    'Photo of a hiker riding on a mountain trail',
    'Photo of a yacht on the ocean at sunset',
    'Realistic photo of a boy playing with a kite',
    'Realistic photo of a man holding a bouquet',
    'Photo of a white car on a city street',
    'Photo of a burger on a wooden table',
    'Photo of a golden retriever on the snow',
    'Realistic photo of a police officer in uniform',
    'Photo of a sunflower in a garden',
    'Photo of a moose standing in a forest clearing',
    'Photo of a snowboarder on a snowy mountain',
    'Photo of a woman fishing by a river',
    'Photo of a blue taxi in New York City',
    'Photo of a cello on a concert stage',
    'Photo of a whale swimming in the ocean',
    'Photo of a gymnast dancing in a studio',
    'Photo of a seal on ice',
    'Photo of an owl in a tree',
    'Photo of a baboon sitting on a rock',
    'Realistic photo of ibex on a cliff'
]

In [766]:
from src.sdxl_pipeline import StableDiffusionXLPipeline
from src.utils import plot_image_grid, print_images

import torch
from diffusers import DDPMScheduler


# Pick the compute device: MPS first, then the first CUDA GPU, otherwise CPU.
if torch.backends.mps.is_available():
    device = torch.device("mps")
elif torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

# Load SDXL base from the local single-file checkpoint in half precision.
pipe = StableDiffusionXLPipeline.from_single_file(
    "/media/alexander/DATA/ai1_models/models_SDXL/sd_xl_base_1.0.safetensors",
    torch_dtype=torch.float16,
    use_safetensors=True,
    add_watermarker=False,
)
pipe = pipe.to(device)

# Use a DDPM scheduler (built from the loaded scheduler's config) for the original image.
pipe.scheduler = DDPMScheduler.from_config(pipe.scheduler.config)


Loading pipeline components...:   0%|          | 0/7 [00:00<?, ?it/s]

Some weights of the model checkpoint were not used when initializing CLIPTextModel: 
 ['text_model.embeddings.position_ids']


In [767]:
# Index used by the save cells as the output file name: images are written to "<folder>/{i}.png".
i = 22

Generation

In [768]:
# Position of the prompt pair to use, shared by both prompt lists.
tmp = 39

# `prompt` describes the original image, `prompt_` the edited one.
prompt, prompt_ = orig_prompts[tmp], edit_prompts[tmp]

print(prompt)
print(prompt_)

Realistic photo of goat on a cliff
Realistic photo of ibex on a cliff


In [783]:
num_steps = 50

# Generate the original image from `prompt`.
# No seed or generator is passed to this call.
image = pipe(
    prompt=prompt,
    negative_prompt="",
    guidance_scale=7.5,
    num_inference_steps=num_steps,
).images[0]

image


In [784]:
# Output folder for the original (DDPM) images.
path = '/media/alexander/DATA/dima_vinichenko'
image_path = f"{path}/data/SDXL/DDPM_original"

# Write the image under the current index `i`, then reopen the file for display.
image.save(image_path + f"/{i}.png")
im = Image.open(f"{image_path}/{i}.png")
im

DPM++ AYS

In [785]:
num_steps = 10

# Invert the original image; the schedule is handed over in ascending timestep order.
inv_latents = dpm_inversion(
    image,
    num_steps=num_steps,
    timesteps=list(reversed(sampling_schedule)),
)

# The already-loaded `pipe` is reused; only its scheduler is swapped
# for the DPM-Solver variant that accepts custom timesteps.
pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)

# Re-generate from the inverted latents with the edit prompt and the custom schedule.
image_ = pipe(
    prompt=prompt_,
    negative_prompt="",
    guidance_scale=7.5,
    num_inference_steps=num_steps,
    latents=inv_latents[0],
    timesteps=sampling_schedule,
).images[0]

# Original on the left, edited result on the right.
display(make_image_grid([image, image_], rows=1, cols=2))



In [786]:
# Output folder for the DPM++ results that use the custom schedule.
path = '/media/alexander/DATA/dima_vinichenko'
image_path = f"{path}/data/SDXL/DPM_AYS"

# Write the edited image under the current index `i`, then reopen the file for display.
image_.save(image_path + f"/{i}.png")
im = Image.open(f"{image_path}/{i}.png")
im


DPM++

In [787]:
num_steps = 10

# Invert the original image without a custom schedule (timesteps=None).
inv_latents = dpm_inversion(
    image,
    num_steps=num_steps,
    timesteps=None,
)

# The already-loaded `pipe` is reused; only its scheduler is swapped.
pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)

# Re-generate from the inverted latents with the edit prompt.
image_ = pipe(
    prompt=prompt_,
    negative_prompt="",
    guidance_scale=7.5,
    num_inference_steps=num_steps,
    latents=inv_latents[0],
    timesteps=None,
).images[0]

# Original on the left, edited result on the right.
display(make_image_grid([image, image_], rows=1, cols=2))

In [None]:
# Output folder for the DPM++ results.
path = '/media/alexander/DATA/dima_vinichenko'
image_path = f"{path}/data/SDXL/DPM"

# Write the edited image under the current index `i`, then reopen the file for display.
image_.save(image_path + f"/{i}.png")
im = Image.open(f"{image_path}/{i}.png")
im

DDIM

In [None]:
num_steps = 10

# Invert the original image with the DDIM inverse scheduler.
inv_latents = ddim_inversion(image, num_steps=num_steps)

# The already-loaded `pipe` is reused; only its scheduler is swapped to DDIM.
pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)

# Re-generate from the inverted latents with the edit prompt.
image_ = pipe(
    prompt=prompt_,
    negative_prompt="",
    guidance_scale=7.5,
    num_inference_steps=num_steps,
    latents=inv_latents[0],
).images[0]

# Original on the left, edited result on the right.
display(make_image_grid([image, image_], rows=1, cols=2))

In [None]:
# Output folder for the DDIM results.
path = '/media/alexander/DATA/dima_vinichenko'
image_path = f"{path}/data/SDXL/DDIM"

# Write the edited image under the current index `i`, then reopen the file for display.
image_.save(image_path + f"/{i}.png")
im = Image.open(f"{image_path}/{i}.png")
im

In [780]:
import gc

# Free the pipeline. The order matters: drop the notebook-level reference first,
# then collect garbage, then ask the CUDA allocator to release its memory.
# After this cell, `pipe` must be re-created before any generation cell is run again.
del pipe
gc.collect()
# Both calls are CUDA-specific.
torch.cuda.empty_cache()
torch.cuda.ipc_collect()

In [781]:
# Record the source prompt that was just processed.
# Re-running this cell appends the same prompt again.
prompt_list.extend([prompt])
prompt_list

['Beautiful DSLR Photograph of a penguin on the beach, golden hour',
 'Realistic photography of a dog on the snow',
 'realistic photo of a bear on the rock',
 'photo of a puppy on the grass',
 'photo of a yellow car in the city',
 'Realistic photography of a cat in the forest',
 'realistic photo of a wolf on the rock',
 'realistic photo of a wolf on the beach',
 'realistic photo of a wolf in the snow forest',
 'Photo of a sailing boat on the lake',
 'Photo of a cat on the windowsill',
 'Photo of a white swan swimming on the river',
 'Realistic photo of a cat on the sofa',
 'Photo of a tiger in the jungle',
 'Realistic photo of a rabbit in the garden',
 'Photo of a plane flying over the mountains',
 'Photo of a lion resting on a rock',
 'Photo of a black car on a city street',
 'Photo of a deer standing in a forest clearing',
 'Photo of a penguin on ice',
 'Photo of a raccoon in a tree',
 'Realistic photo of goat on a cliff']

In [782]:
# Record the edit prompt that was just processed.
# Re-running this cell appends the same prompt again.
edit_prompt_list.extend([prompt_])
edit_prompt_list


['Beautiful DSLR Photograph of a dog on the beach, golden hour',
 'Realistic photography of a cat on the snow',
 'realistic photo of a wolf on the rock',
 'photo of a kitty on the grass',
 'photo of a red car in the city',
 'Realistic photography of a dog in the forest',
 'realistic photo of a bear on the rock',
 'realistic photo of a bear on the beach',
 'realistic photo of a dog in the snow forest',
 'Photo of a motorboat on the lake',
 'Photo of a dog on the windowsill',
 'Photo of a black swan swimming on the river',
 'Realistic photo of a dog on the sofa',
 'Photo of a leopard in the jungle',
 'Realistic photo of a squirrel in the garden',
 'Photo of a helicopter flying over the mountains',
 'Photo of a cheetah resting on a rock',
 'Photo of a white car on a city street',
 'Photo of a moose standing in a forest clearing',
 'Photo of a seal on ice',
 'Photo of an owl in a tree',
 'Realistic photo of ibex on a cliff']