### from https://github.com/riveSunder/simple_diffusion_demo/blob/master/diffusion_demo.ipynb

In [None]:
from init_notebook import *

In [None]:
# Load the reference picture as a tensor and shrink it for quick experiments.
# The context manager releases the file handle as soon as the pixels are read.
# Converting to RGB pins the channel count to 3, matching the 3-channel
# tensors the VAE cells in this notebook are fed with.
with PIL.Image.open("/home/bergi/Pictures/csv-turing.png") as _source_image:
    some_image = VF.to_tensor(_source_image.convert("RGB"))
# NOTE(review): `resize` comes from init_notebook; the second argument is
# presumably a scale factor -- confirm against its definition.
some_image = resize(some_image, 1/8)
print(some_image.shape)
# Last expression of the cell: rendered as the cell output.
VF.to_pil_image(some_image)

In [None]:
from diffusers import StableDiffusionPipeline, AutoPipelineForImage2Image

from diffusers.pipelines.pipeline_utils import numpy_to_pil
from transformers import CLIPTokenizer, CLIPTextModel
from diffusers import AutoencoderKL, UNet2DConditionModel, PNDMScheduler, LMSDiscreteScheduler

In [None]:
pipe_name = "CompVis/stable-diffusion-v1-4"

# Hard-wired switch between the two configurations of this notebook:
# float32 on CPU (active), or float16 on CUDA with attention slicing enabled.
use_cuda = False
my_dtype = torch.float16 if use_cuda else torch.float32
my_device = torch.device("cuda" if use_cuda else "cpu")

pipe = StableDiffusionPipeline.from_pretrained(
    pipe_name, torch_dtype=my_dtype, safety_checker=None,
).to(my_device)
if use_cuda:
    pipe.enable_attention_slicing()

In [None]:
# Widget that receives a fresh preview after every pipeline step.
image_widget = ImageWidget()
display(image_widget)


def _callback(pipe, i, timestep, data: dict):
    """Step-end hook: decode the current latents and show them in `image_widget`.

    Reads ``data["latents"]``, divides by the VAE's configured scaling factor,
    decodes, applies ``* .5 + .5``, arranges the batch with ``make_grid``,
    clamps to [0, 1] and hands the result to the widget.

    ``i`` and ``timestep`` are accepted but unused; ``data`` is returned
    unchanged.
    """
    unscaled_latents = data["latents"] / pipe.vae.config.scaling_factor
    frames = pipe.vae.decode(unscaled_latents).sample
    frames = frames * .5 + .5
    preview = make_grid(frames).clamp(0, 1)
    image_widget.set_torch(preview)
    return data
    
# Run the pipeline on a tiny 64x64 canvas; `_callback` is invoked at the end of
# every step. Alternative prompt: "happy workers in the butter factory"
_generation_kwargs = dict(
    num_inference_steps=20,
    num_images_per_prompt=3,
    guidance_scale=9.0,
    callback_on_step_end=_callback,
    width=64,
    height=64,
)
my_output = pipe("red square on yellow background", **_generation_kwargs)


In [None]:
# Show the first image of the pipeline result as the cell output.
my_output.images[0]

In [None]:
# Grab the pipeline's VAE (the same module `_callback` reaches via `pipe.vae`)
# and report its parameter count.
vae = pipe.vae
_vae_param_count = num_module_parameters(vae)
print(f"params: {_vae_param_count:,}")

In [None]:
# Feed two random 8x8 latents through the VAE's `decoder` submodule and preview
# the result.
# NOTE(review): this calls `vae.decoder` directly rather than `vae.decode` --
# confirm that bypassing the `decode` wrapper is intended.
with torch.no_grad():  # pure inference: do not build an autograd graph
    _random_latents = torch.randn(2, vae.config.latent_channels, 8, 8) * .5
    _decoded = vae.decoder(_random_latents.to(device=my_device, dtype=my_dtype))
# Map the output to [0, 1] the same way `_callback` does before converting to
# an image, so out-of-range values are clamped rather than handed to PIL as-is.
VF.to_pil_image(make_grid((_decoded * .5 + .5).clamp(0, 1)))

In [None]:
# Display the VAE configuration; other cells read `latent_channels` and
# `scaling_factor` from it.
vae.config

In [None]:
# Encode a random 1x3x32x32 batch and display the returned object
# (a later cell reads `.latent_dist` from the same call).
vae.encode(torch.rand(1, 3, 32, 32))

In [None]:
# Run only the VAE's `encoder` submodule on the top-left crop (at most
# 128x128) of the reference image and report the output shape.
# Wrapped in no_grad like the other inference cells so no autograd graph is
# kept alive; the input is moved to the model's device/dtype first.
with torch.no_grad():
    _crop = some_image[:, :128, :128].unsqueeze(0)
    encoded = vae.encoder(_crop.to(device=my_device, dtype=my_dtype))
print(encoded.shape)

In [None]:
# VAE round trip: encode a crop of the half-size reference image, sample a
# latent from the returned distribution, decode it and show the reconstruction.
with torch.inference_mode():
    # Whole-image variant, kept for reference:
    #display(VF.to_pil_image(vae(some_image[None, ...]).sample[0]))
    half_size = resize(some_image, .5)
    crop = half_size[:, :64, :64]
    dist = vae.encode(crop[None]).latent_dist
    encoded = dist.sample()
    print(encoded.shape)

    decoded = vae.decode(encoded).sample
    reconstruction = VF.to_pil_image(decoded[0])
    display(reconstruction)

# Keep the sampled latent around as a possible starting point for later cells.
input_latents = encoded

In [None]:
# Pull the individual parts out of the pipeline with a single `components`
# lookup, then report the UNet's parameter count.
_components = pipe.components
unet, tokenizer, text_encoder, scheduler = (
    _components[key]
    for key in ("unet", "tokenizer", "text_encoder", "scheduler")
)

_unet_param_count = num_module_parameters(unet)
print(f"params: {_unet_param_count:,}")

In [None]:
# Tokenizer settings shared by both prompts: pad/truncate to the tokenizer's
# maximum length and return PyTorch tensors.
_tokenizer_kwargs = dict(
    padding="max_length",
    max_length=tokenizer.model_max_length,
    truncation=True,
    return_tensors="pt",
)
tokens = tokenizer("cthulhu", **_tokenizer_kwargs)
empty_tokens = tokenizer([""], **_tokenizer_kwargs)
max_length = tokens.input_ids.shape[-1]

with torch.no_grad():
    # Element [0] of the text encoder's output is used as the embedding.
    notext_embeddings = text_encoder(empty_tokens.input_ids.to(my_device))[0]
    _prompt_embeddings = text_encoder(tokens.input_ids.to(my_device))[0]

# Batch layout: empty-prompt embedding first, "cthulhu" embedding second.
text_embeddings = torch.cat([notext_embeddings, _prompt_embeddings])

In [None]:
# Build two starting latents by encoding uniformly random 512x512 images and
# scaling the sampled latents by the scheduler's initial noise sigma.
with torch.no_grad():
    # Move the input to the model's device/dtype so this cell also works when
    # the CUDA/float16 configuration is selected (a no-op for CPU/float32).
    _random_images = torch.rand(2, 3, 512, 512).to(device=my_device, dtype=my_dtype)
    latents_in = vae.encode(_random_images).latent_dist.sample()
    latents_in = latents_in * scheduler.init_noise_sigma

In [None]:
#scheduler = PNDMScheduler?#(**scheduler.config)

In [None]:
# Manual denoising loop: run the UNet for 10 scheduler steps, keep every
# intermediate latent in `latents_history`, and show a decoded preview after
# each step.
#
# Alternative starting point (the encoded reference image):
#latents = input_latents.repeat(2, 1, 1, 1)
latents = latents_in.to(my_device)

latents_history = []
image_widget = ImageWidget()
display(image_widget)
scheduler.set_timesteps(10)
with torch.inference_mode():
    for step_idx, timestep in enumerate(tqdm(scheduler.timesteps)):
        # Scale the model input for *this* timestep. Doing it once before the
        # loop read `timestep` before it was defined (NameError on a fresh
        # kernel, a stale value on re-runs).
        model_input = scheduler.scale_model_input(latents, timestep)

        # `text_embeddings` holds [empty prompt, "cthulhu"], so the two latents
        # are presumably conditioned on one each; no guidance mixing is done.
        predicted_latents = unet(model_input, timestep, text_embeddings).sample
        latents = scheduler.step(predicted_latents, timestep, latents).prev_sample

        latents_history.append(latents)

        # Decode exactly like `_callback`: undo the VAE scaling factor and map
        # the decoder output to [0, 1] before clamping.
        images = vae.decode(latents / vae.config.scaling_factor).sample * .5 + .5
        image_widget.set_torch(make_grid(images.clamp(0, 1)))

In [None]:
# Show the shape of the fourth entry of the scheduler's `ets` list.
# NOTE(review): `ets` looks like scheduler-specific internal state -- confirm
# the scheduler type provides it and that it holds at least 4 entries here.
scheduler.ets[3].shape

In [None]:
# Decode every latent recorded in `latents_history` and show them as one grid.
with torch.no_grad():
    _all_latents = torch.concat(latents_history)
    # Same post-processing as `_callback`: undo the VAE scaling factor, then
    # map the decoder output to [0, 1] before clamping.
    images = vae.decode(_all_latents / vae.config.scaling_factor).sample
    images = (images * .5 + .5).clamp(0, 1)
VF.to_pil_image(make_grid(images))