In [None]:
import torch, torch.nn.functional as F, random, wandb, time
import torchvision.transforms as T
import random
from torchvision import transforms
from diffusers import AutoencoderDC, SanaTransformer2DModel
from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
from transformers import AutoModel, AutoTokenizer, set_seed, Siglip2TextModel
from datasets import load_dataset, Dataset, DatasetDict
from tqdm import tqdm
from torch.utils.data import DataLoader
from copy import deepcopy
from functools import partial

from utils import (
    linear_multistep_coeff,
    pil_add_text, 
    latent_to_PIL, 
    make_grid, 
    encode_prompt, 
    dcae_scalingf, 
    free_memory, 
    generate,
    generate_lms,
)

# Reproducibility: transformers.set_seed seeds Python `random`, NumPy and PyTorch in one call.
set_seed(seed := 42)

In [None]:
dtype = torch.bfloat16

# Pick the best available accelerator: CUDA, then Apple-silicon MPS, then CPU.
# `torch.backends.mps.is_available()` is the documented MPS check; `torch.mps.is_available`
# is not guaranteed to exist on every PyTorch version.
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"

# Diffusion transformer loaded from a local checkpoint directory (relative to the notebook's cwd).
transformer = SanaTransformer2DModel.from_pretrained(
    "./cp-e76"
).to(device).to(dtype)

In [None]:
TEXT_ENCODER_ID = "HuggingFaceTB/SmolLM2-360M"

# Text encoder: the SmolLM2 base model (AutoModel, no LM head), in the notebook-wide dtype.
text_encoder = AutoModel.from_pretrained(TEXT_ENCODER_ID, torch_dtype=dtype).to(device)
# Tokenizers have no dtype, so no `torch_dtype` is passed here. An unused tensor-dtype kwarg
# would only end up in the tokenizer's init kwargs.
tokenizer = AutoTokenizer.from_pretrained(TEXT_ENCODER_ID)
# Reuse EOS as the padding token. Presumably the tokenizer ships without a pad token — TODO confirm.
tokenizer.pad_token = tokenizer.eos_token
# Sana's DC-AE autoencoder (the `vae` subfolder of the Sana 600M checkpoint).
dcae = AutoencoderDC.from_pretrained("Efficient-Large-Model/Sana_600M_1024px_diffusers", subfolder="vae", torch_dtype=dtype).to(device)

In [None]:
# Report the text encoder size in millions of parameters.
n_text_encoder_params = sum(param.numel() for param in text_encoder.parameters())
print(f"text_encoder parameters: {n_text_encoder_params / 1e6:.2f}M")

In [None]:
pipeline_components = dict(
    transformer=transformer,
    tokenizer=tokenizer,
    text_encoder=text_encoder,
    dcae=dcae,
)
# `guidance_scale` is deliberately absent from this config: it is swept below.
inference_config = dict(
    latent_dim=[1, 32, 8, 12],  # latent tensor shape passed to `generate` (presumably batch, channels, h, w — TODO confirm)
    latent_seed=92216724,  # fixed seed, presumably so only the guidance scale varies across the grid
    num_steps=20,
)

# Alternative prompts tried during development:
# prompt = "a photo of a red train in the mountains"
# prompt = "a white cat with purple ears"
# prompt = "a beautiful field of flowers"
# prompt = "a man and a woman on the beach "
# prompt = "a dog in the swimming pool"
# prompt = "a black car in the middle of a beautiful endless field of white flowers"
prompt = "A polar bear sitting in a red car on the beach"

# Same prompt and settings, one image per guidance scale, shown side by side.
guidance_scales = [2, 4, 5, 6, 7, 10]
make_grid([
    generate(prompt, guidance_scale=cfg, **pipeline_components, **inference_config)
    for cfg in guidance_scales
])

In [None]:
# Components shared by every `generate` call.
pipeline_components = {
    "transformer": transformer,
    "tokenizer": tokenizer,
    "text_encoder": text_encoder,
    "dcae": dcae,
}
# A single, larger image at a fixed guidance scale and seed.
inference_config = {
    "guidance_scale": 5,
    "latent_dim": [1, 32, 15, 20],
    "latent_seed": 92216724,
    "num_steps": 20,
}

prompt = "a beautiful mountain landscape showing a bunch of people"

img = generate(prompt, **pipeline_components, **inference_config)
img

In [None]:
# Re-display the image produced by the previous single-prompt cell.
img

In [None]:
# Components shared by every `generate` call.
pipeline_components = {
    "transformer": transformer,
    "tokenizer": tokenizer,
    "text_encoder": text_encoder,
    "dcae": dcae,
}
# No latent_seed is set here. Presumably this is so repeated `generate` calls draw
# different noise — TODO confirm in utils.generate.
inference_config = {
    "guidance_scale": 5,
    "latent_dim": [1, 32, 8, 8],
    # "latent_seed": 9221672424,
    "num_steps": 20,
}

prompts = [
    "a dog",
    "a dog with blue eyes",
    "a dog on the beach",
    "a dog in the swimming pool",
    "a shark in the ocean",
    "a blue airplane taking off at the airport",
    "a bird in the swimming pool",
    "a beautiful snowy mountain landscape",
    "a woman",
    "a woman and her dog on the beach",
    "a woman eating a cheeseburger",
    "an astronaut riding a rainbow unicorn",
]

# As many samples per prompt as there are prompts (presumably for a square gallery).
num_imgs_per_label = len(prompts)

In [None]:
# Generate `num_imgs_per_label` samples per prompt. For each sample, keep both the final
# image and the per-step x0 predictions returned by `generate(..., return_xps=True)`.
images = {p: [] for p in prompts}
# x0s[prompt][img_no][step] -> x0 prediction of sample `img_no` at denoising `step`.
# The empty-list placeholders are overwritten by the loop below.
x0s = {
    p: [{step: [] for step in range(inference_config["num_steps"])}
        for _ in range(num_imgs_per_label)]
    for p in prompts
}

for prompt in tqdm(prompts):
    # Each sample is indexable: sample[0] is the final image, sample[1] the per-step x0 list.
    samples = [
        generate(prompt, return_xps=True, **pipeline_components, **inference_config)
        for _ in range(num_imgs_per_label)
    ]
    images[prompt] = [sample[0] for sample in samples]
    for img_no, sample in enumerate(samples):
        sample_xps = sample[1]
        for step in range(inference_config["num_steps"]):
            x0s[prompt][img_no][step] = sample_xps[step]


In [None]:
# Gallery of all final images: one labelled row per prompt, rows stacked vertically.
labelled_rows = [
    pil_add_text(make_grid(images[p]), p, font_size=40, position=(0, 0))
    for p in prompts
]
gallery = make_grid(labelled_rows, len(prompts), 1)
gallery.save("output.png")

gallery

In [None]:
# One animation frame per denoising step. Each frame stacks, for every prompt, a labelled
# row holding that step's x0 prediction for all of the prompt's samples.
gallery_x0s = []
for step in range(inference_config["num_steps"]):
    step_rows = [
        pil_add_text(
            make_grid([x0s[prompt][img_no][step] for img_no in range(num_imgs_per_label)]),
            prompt,
        )
        for prompt in prompts
    ]
    gallery_x0s.append(make_grid(step_rows, len(prompts), 1))

# Full-size GIF. The frames are deep-copied so that saving (which records encoder info on
# the first frame) leaves `gallery_x0s` untouched.
gif_anim = deepcopy(gallery_x0s)
# Hold the final prediction for 15 extra frames (100 ms each).
gif_anim.extend([gif_anim[-1]] * 15)

# NOTE(review): the GIF starts with the final frame and then appends frames 1..-2, so the
# step-0 prediction is never shown — confirm this is intended (e.g. as a preview thumbnail).
gif_anim[-1].save("output.gif", save_all=True, append_images=gif_anim[1:-1], duration=100, loop=0)


In [None]:
# !ffmpeg -i output.gif -movflags faststart -pix_fmt yuv420p -vf "scale=trunc(iw/2)*2:trunc(ih/2)*2" output.mp4

In [None]:
# Reduced-size GIF: each side is divided by 2.5 (40% of full size, not half, despite the
# file name). `resize` already returns new images, so the source frames need no copy.
gif_anim = [i.resize((int(i.width // 2.5), int(i.height // 2.5))) for i in gallery_x0s]
# Hold the final prediction for 15 extra frames (100 ms each).
gif_anim += [gif_anim[-1]] * 15

# NOTE(review): starts with the final frame and appends frames 1..-2, so step 0 is not shown.
gif_anim[-1].save("output_half.gif", save_all=True, append_images=gif_anim[1:-1], duration=100, loop=0)


In [None]:
# !ffmpeg -i output_half.gif -movflags faststart -pix_fmt yuv420p -vf "scale=trunc(iw/2)*2:trunc(ih/2)*2" output_half.mp4

In [None]:
# Subset of `prompts` (picked by index) used for the smaller "selected" animation.
selected_indices = (0, 2, 4, 6, 7, 8, 9, 10)
prompt_selection = [prompts[i] for i in selected_indices]
prompt_selection

In [None]:
# x0-prediction animation for the selected prompts only.
# Cell-specific names (`num_imgs_selected`, `gallery_x0s_selected`) keep this cell from
# overwriting `num_imgs_per_label` / `gallery_x0s`. Those globals are read by the
# generation and full-gallery cells, so re-running those cells afterwards would otherwise
# silently change their output.
num_inf_steps = inference_config["num_steps"]  # must match the steps stored in x0s
# As many samples per row as selected prompts (presumably for a square layout).
num_imgs_selected = len(prompt_selection)
assert num_imgs_selected <= num_imgs_per_label, (
    f"need {num_imgs_selected} samples per prompt, only {num_imgs_per_label} were generated"
)

gallery_x0s_selected = [
    make_grid([
        pil_add_text(
            make_grid([x0s[prompt][img_no][step] for img_no in range(num_imgs_selected)]),
            prompt, font_size=40, stroke_width=2, position=(2, 0),
        )
        for prompt in prompt_selection
    ], len(prompt_selection), 1)
    for step in range(num_inf_steps)
]

# Half-size frames. `resize` returns new images, so the source frames need no copy.
gif_anim = [frame.resize((frame.width // 2, frame.height // 2)) for frame in gallery_x0s_selected]
# Hold the final prediction for 15 extra frames (100 ms each).
gif_anim += [gif_anim[-1]] * 15

# NOTE(review): starts with the final frame and appends frames 1..-2, so step 0 is not shown.
gif_anim[-1].save("output_selected.gif", save_all=True, append_images=gif_anim[1:-1], duration=100, loop=0)

In [None]:
# !ffmpeg -i output_selected.gif -movflags faststart -pix_fmt yuv420p -vf "scale=trunc(iw/2)*2:trunc(ih/2)*2" output_selected.mp4