In [None]:
import torch
from diffusers import StableCascadeDecoderPipeline, StableCascadePriorPipeline

prior = StableCascadePriorPipeline.from_pretrained("stabilityai/stable-cascade-prior", variant="bf16", torch_dtype=torch.bfloat16).to("cuda")
decoder = StableCascadeDecoderPipeline.from_pretrained("stabilityai/stable-cascade", variant="bf16", torch_dtype=torch.float16).to("cuda")

In [None]:
import tqdm

In [None]:
prompt = "A high quality, creative, portrait of a dog, against a plain background."
negative_prompt = ""

for i in tqdm.tqdm(range(0, 100, 2)):
    # prior.enable_model_cpu_offload()
    prior_output = prior(
        prompt=prompt,
        height=1024,
        width=1024,
        negative_prompt=negative_prompt,
        guidance_scale=4.0,
        num_images_per_prompt=2,
        num_inference_steps=20
    )

    # decoder.enable_model_cpu_offload()
    decoder_output = decoder(
        image_embeddings=prior_output.image_embeddings.to(torch.float16),
        prompt=prompt,
        negative_prompt=negative_prompt,
        guidance_scale=0.0,
        output_type="pil",
        num_inference_steps=10
    )
    
    # Save the image
    for idx, image in enumerate(decoder_output.images):
        image.save(f"projects/custom/dataset/dog/{i+idx}.png")

In [None]:
animals = ["cat", "giraffe", "tiger", "bear", "penguin", "panda"]
for animal in animals:
    prompt = f"A high quality, creative, portrait of a {animal}, against a plain background."
    negative_prompt = ""

    for i in tqdm.tqdm(range(0, 100, 2)):
        # prior.enable_model_cpu_offload()
        prior_output = prior(
            prompt=prompt,
            height=1024,
            width=1024,
            negative_prompt=negative_prompt,
            guidance_scale=4.0,
            num_images_per_prompt=2,
            num_inference_steps=20
        )

        # decoder.enable_model_cpu_offload()
        decoder_output = decoder(
            image_embeddings=prior_output.image_embeddings.to(torch.float16),
            prompt=prompt,
            negative_prompt=negative_prompt,
            guidance_scale=0.0,
            output_type="pil",
            num_inference_steps=10
        )
        
        # Save the image
        for idx, image in enumerate(decoder_output.images):
            image.save(f"projects/custom/dataset/{animal}/{i+idx}.png")

In [1]:
import diffusers

In [6]:
model = diffusers.VQModel(
    in_channels=3,
    out_channels=3,
    down_block_types=("DownEncoderBlock2D", "DownEncoderBlock2D"),
    up_block_types=("UpDecoderBlock2D", "UpDecoderBlock2D"),
    block_out_channels=(32, 64),
    layers_per_block=2,
    latent_channels=4,
    sample_size=256,
    num_vq_embeddings=256,
    vq_embed_dim=64,
    norm_num_groups=16,
)
print(f"Model has {model.num_parameters():,} parameters.")

Model has 858,411 parameters.


In [8]:
import torch
import torch.nn as nn

In [14]:
model.forward(torch.randn(1, 3, 256, 256), True)

DecoderOutput(sample=tensor([[[[ 1.0439e+00, -4.9478e-01, -5.3723e-01,  ..., -4.6625e-01,
           -8.5494e-01, -7.6728e-01],
          [ 5.8595e-01,  1.0339e+00,  6.6085e-03,  ...,  8.0054e-01,
            1.3276e+00,  8.7352e-01],
          [ 1.6277e+00,  2.4285e+00,  7.0050e-01,  ...,  6.0781e-01,
            6.3180e-01,  1.3622e-01],
          ...,
          [ 7.7525e-01,  8.2628e-01,  1.6002e+00,  ...,  4.8426e-01,
           -4.3053e-01,  3.0097e-01],
          [ 1.0023e+00,  1.2221e+00,  1.7539e+00,  ...,  1.7192e-01,
            1.9404e-01, -7.6295e-03],
          [ 6.2281e-01,  6.0951e-01,  5.6773e-01,  ..., -1.1793e-02,
           -4.7666e-02, -1.5319e-01]],

         [[-7.6515e-01, -7.5619e-01, -1.1668e+00,  ...,  8.6761e-01,
            2.1760e-01,  3.0377e-01],
          [-1.0732e+00, -1.0866e+00, -4.8508e-01,  ...,  2.7568e-01,
            2.5070e-01,  1.9509e-01],
          [-2.5499e+00, -8.2628e-01, -7.6318e-01,  ..., -1.3304e-01,
            2.0589e-01,  2.7565e-01],