In [1]:
import json
import os
from dataclasses import dataclass
from pathlib import Path

import numpy as np
import safetensors
import torch
import torch.nn as nn
import torch.nn.functional as F
from IPython.display import clear_output
from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3 import StableDiffusion3Pipeline
from tqdm import tqdm

# from models.clip import CLIPTextModelWithProjection, CLIPTextConfig

# Local snapshot of stabilityai/stable-diffusion-3-medium-diffusers. The default below is
# machine-specific; set the SD3_ROOT_DIR environment variable to point at your own copy.
ROOT_DIR = Path(os.environ.get(
    'SD3_ROOT_DIR',
    '/home/batman/.cache/huggingface/hub/models--stabilityai--stable-diffusion-3-medium-diffusers/snapshots/ea42f8cef0f178587cf766dc8129abd379c90671',
))

# Output resolution in pixels = latent grid size * VAE downsampling factor (128 * 8 = 1024).
default_sample_size = 128
vae_scale_factor = 8
height = width = default_sample_size * vae_scale_factor
batch_size = 1
num_images_per_prompt = 1
max_sequence_length = 256  # T5 token budget (same value the T5 tokenizer is padded/truncated to)
num_inference_steps = 28
guidance_scale = 7.0
# GPU settings (uncomment to run on CUDA in half precision):
# device = torch.device('cuda:0')
# dtype = torch.float16
# NOTE: float32 on CPU at 1024x1024 is very slow.
device = torch.device('cpu')
dtype = torch.float32

# Seed the global torch RNG so the initial latents are reproducible under Restart & Run All.
SEED = 0
torch.manual_seed(SEED)

prompt = "A man in a space suit on the Paris metro at rush hour, waiting to get off at his stop."

clear_output()

In [10]:
# Alternative to the component-by-component loading in the following cells: the whole
# pipeline can be created in one call (fetches from the Hugging Face Hub if not cached):
#   pipe = StableDiffusion3Pipeline.from_pretrained(
#       "stabilityai/stable-diffusion-3-medium-diffusers", torch_dtype=dtype
#   ).to(device)
clear_output()

In [10]:
from transformers import CLIPTextConfig, CLIPTextModelWithProjection, CLIPTokenizer


def _load_clip_components(suffix=''):
    """Load one CLIP tokenizer / text-encoder pair from ROOT_DIR.

    Args:
        suffix: appended to the subfolder names, e.g. '' -> 'tokenizer' / 'text_encoder',
            '_2' -> 'tokenizer_2' / 'text_encoder_2'.

    Returns:
        (tokenizer, text_encoder_config, text_encoder), with the encoder cast to the
        notebook-wide `dtype` and moved to `device`.
    """
    encoder_dir = ROOT_DIR / f'text_encoder{suffix}'
    tok = CLIPTokenizer.from_pretrained(ROOT_DIR / f'tokenizer{suffix}')
    # Parsed separately from from_pretrained (which reads the same folder itself);
    # presumably kept for inspection.
    config = CLIPTextConfig(**json.loads((encoder_dir / 'config.json').read_text()))
    encoder = CLIPTextModelWithProjection.from_pretrained(encoder_dir, torch_dtype=dtype).to(device)
    return tok, config, encoder


tokenizer, text_encoder_config, text_encoder = _load_clip_components()
tokenizer_2, text_encoder_config_2, text_encoder_2 = _load_clip_components('_2')

def get_clip_prompt_embeds(prompt, second=False, max_length=77):
    """Encode a prompt with one of the two CLIP text encoders.

    Args:
        prompt: a prompt string (a list of prompt strings is also accepted).
        second: if True use `tokenizer_2` / `text_encoder_2`, otherwise `tokenizer` / `text_encoder`.
        max_length: CLIP token budget; prompts are padded / truncated to exactly this length.

    Returns:
        (prompt_embeds, pooled_prompt_embeds), both cast to `dtype` on `device`:
        - prompt_embeds: penultimate hidden state, one row per prompt per generated image.
        - pooled_prompt_embeds: the encoder's projected `text_embeds`, same row layout.
    """
    _tokenizer = tokenizer_2 if second else tokenizer
    _text_encoder = text_encoder_2 if second else text_encoder
    prompts = [prompt] if isinstance(prompt, str) else list(prompt)

    text_input_ids = _tokenizer(
        prompts,
        padding="max_length",
        max_length=max_length,
        truncation=True,
        return_tensors="pt",
    ).input_ids

    encoder_output = _text_encoder(text_input_ids.to(device), output_hidden_states=True)
    pooled_prompt_embeds = encoder_output.text_embeds.to(dtype=dtype, device=device)
    # Conditioning uses the penultimate hidden state, not the final layer output.
    prompt_embeds = encoder_output.hidden_states[-2].to(dtype=dtype, device=device)

    # One copy per generated image, grouped per prompt. The batch size comes from the
    # tensors themselves rather than the global `batch_size`, so a mismatch can no longer
    # be silently absorbed by `.view(..., -1)` reshaping the feature dimension.
    prompt_embeds = prompt_embeds.repeat_interleave(num_images_per_prompt, dim=0)
    pooled_prompt_embeds = pooled_prompt_embeds.repeat_interleave(num_images_per_prompt, dim=0)
    return prompt_embeds, pooled_prompt_embeds

clear_output()

In [12]:
from transformers import T5EncoderModel, T5Tokenizer

# Third text encoder: T5 tokenizer plus the encoder-only T5 model, cast to the shared
# `dtype` and moved to `device`.
tokenizer_3 = T5Tokenizer.from_pretrained(ROOT_DIR / 'tokenizer_3')
text_encoder_3 = (
    T5EncoderModel
    .from_pretrained(ROOT_DIR / 'text_encoder_3', torch_dtype=dtype)
    .to(device)
)

def get_t5_prompt_embeds(prompt, max_length=None):
    """Encode a prompt with the T5 encoder (`tokenizer_3` / `text_encoder_3`).

    Args:
        prompt: a prompt string (a list of prompt strings is also accepted).
        max_length: T5 token budget; defaults to the config value `max_sequence_length`.

    Returns:
        The encoder's first output (last hidden state), cast to the notebook-wide `dtype`
        on `device`, with one row per prompt per generated image.
    """
    if max_length is None:
        max_length = max_sequence_length
    prompts = [prompt] if isinstance(prompt, str) else list(prompt)

    text_inputs = tokenizer_3(
        prompts,
        padding="max_length",
        max_length=max_length,
        truncation=True,
        add_special_tokens=True,
        return_tensors="pt",
    )

    prompt_embeds = text_encoder_3(text_inputs.input_ids.to(device))[0]
    # Cast to the shared `dtype` (as the CLIP embeddings are) so the two can be concatenated
    # even if T5 was loaded at a different precision.
    prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)

    # One copy per generated image, grouped per prompt; batch taken from the tensor itself.
    return prompt_embeds.repeat_interleave(num_images_per_prompt, dim=0)

clear_output()

In [5]:
from diffusers.schedulers import FlowMatchEulerDiscreteScheduler

# Flow-matching Euler scheduler loaded from the checkpoint's `scheduler` folder,
# pre-configured for the default number of steps.
scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(ROOT_DIR / 'scheduler')
scheduler.set_timesteps(num_inference_steps, device=device)
timesteps = scheduler.timesteps
# Appears to mirror the progress-bar bookkeeping of diffusers pipelines; nothing in this cell uses it.
num_warmup_steps = max(0, len(timesteps) - scheduler.order * num_inference_steps)

In [6]:
from diffusers.models.transformers import SD3Transformer2DModel

# Denoising transformer; its input channel count is the channel dimension of the
# latents created in `sample`.
transformer = (
    SD3Transformer2DModel
    .from_pretrained(ROOT_DIR / 'transformer', torch_dtype=dtype)
    .to(device)
)
num_channels_latents = transformer.config.in_channels
clear_output()

In [7]:
from diffusers.image_processor import VaeImageProcessor
from diffusers.models.autoencoders import AutoencoderKL

# The VAE decodes final latents to pixel space; the image processor turns the decoded
# tensor into PIL images.
vae = (
    AutoencoderKL
    .from_pretrained(ROOT_DIR / 'vae', torch_dtype=dtype)
    .to(device)
)
image_processor = VaeImageProcessor(vae_scale_factor=vae_scale_factor)
clear_output()

In [8]:
# Assemble the individually loaded components into a diffusers pipeline object.
pipe = StableDiffusion3Pipeline(
    scheduler=scheduler,
    transformer=transformer,
    vae=vae,
    tokenizer=tokenizer,
    text_encoder=text_encoder,
    tokenizer_2=tokenizer_2,
    text_encoder_2=text_encoder_2,
    tokenizer_3=tokenizer_3,
    text_encoder_3=text_encoder_3,
)

In [14]:
def _encode_with_all_encoders(text):
    """Return (sequence_embeds, pooled_embeds) for `text` from both CLIP encoders plus T5."""
    embed_1, pooled_1 = get_clip_prompt_embeds(text)
    embed_2, pooled_2 = get_clip_prompt_embeds(text, second=True)
    # Both CLIP sequences have the same token length, so they are joined feature-wise.
    clip_embeds = torch.cat([embed_1, embed_2], dim=-1)

    t5_embeds = get_t5_prompt_embeds(text)
    # Zero-pad the CLIP features up to the T5 width, then append the T5 tokens after the CLIP tokens.
    clip_embeds = F.pad(clip_embeds, (0, t5_embeds.shape[-1] - clip_embeds.shape[-1]))
    sequence_embeds = torch.cat([clip_embeds, t5_embeds], dim=-2)
    pooled_embeds = torch.cat([pooled_1, pooled_2], dim=-1)
    return sequence_embeds, pooled_embeds


def get_prompt_embeds(prompt, negative_prompt=""):
    """Build classifier-free-guidance conditioning for `prompt`.

    Args:
        prompt: the positive prompt.
        negative_prompt: text for the unconditional branch; None is treated as "".

    Returns:
        (prompt_embeds, pooled_prompt_embeds), each stacked along dim 0 as
        [negative rows; positive rows] so the result can be split with `.chunk(2)`.
    """
    if negative_prompt is None:
        negative_prompt = ""

    positive_embeds, positive_pooled = _encode_with_all_encoders(prompt)
    negative_embeds, negative_pooled = _encode_with_all_encoders(negative_prompt)

    combined_prompt_embeds = torch.cat([negative_embeds, positive_embeds], dim=0)
    combined_pooled_prompt_embeds = torch.cat([negative_pooled, positive_pooled], dim=0)
    return combined_prompt_embeds, combined_pooled_prompt_embeds

@torch.no_grad()
def sample(prompt, num_inference_steps=28, guidance_scale=2.0, generator=None, negative_prompt=""):
    """Generate images for `prompt` with classifier-free guidance.

    Args:
        prompt: the text prompt.
        num_inference_steps: number of scheduler steps.
        guidance_scale: CFG weight w in  uncond + w * (cond - uncond).
        generator: optional torch.Generator for reproducible initial noise; it must live on
            `device`. When None, the global torch RNG is used.
        negative_prompt: text for the unconditional branch (default: empty string).

    Returns:
        A list of PIL images from `image_processor.postprocess`.
    """
    prompt_embeds, pooled_prompt_embeds = get_prompt_embeds(prompt, negative_prompt)

    # The embeddings are stacked [negative; positive], so each half has one row per image to
    # generate. Sizing the latents from that (instead of the global `batch_size`) keeps the
    # CFG-doubled latent batch aligned with the embeddings when num_images_per_prompt > 1.
    latent_batch = prompt_embeds.shape[0] // 2
    shape = (latent_batch, num_channels_latents, int(height) // vae_scale_factor, int(width) // vae_scale_factor)
    latents = torch.randn(shape, generator=generator, device=device, dtype=dtype)
    scheduler.set_timesteps(num_inference_steps, device=device)

    for t in tqdm(scheduler.timesteps):
        # Evaluate the unconditional and conditional branches in one forward pass.
        latent_model_input = torch.cat([latents, latents])
        timestep = t.expand(latent_model_input.shape[0])

        noise_pred = transformer(
            hidden_states=latent_model_input,
            timestep=timestep,
            encoder_hidden_states=prompt_embeds,
            pooled_projections=pooled_prompt_embeds,
            return_dict=False,
        )[0]

        # Classifier-free guidance.
        noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
        noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

        # Advance the latents one scheduler step.
        latents = scheduler.step(noise_pred, t, latents, return_dict=False)[0]

    # Undo the VAE latent scaling/shift before decoding.
    latents = latents / vae.config.scaling_factor + vae.config.shift_factor
    image = vae.decode(latents, return_dict=False)[0]
    return image_processor.postprocess(image, output_type='pil')

# NOTE: with the CPU / float32 settings from the config cell, a 1024x1024 run is very slow;
# the run recorded in this notebook was interrupted at step 0/28.
# guidance_scale is passed explicitly and is lower than the config value `guidance_scale` (7.0).
image = sample(
    prompt,
    guidance_scale=2.0,
    num_inference_steps=num_inference_steps,
)
image[0]  # display the first generated PIL image

  0%|          | 0/28 [00:00<?, ?it/s]

torch.Size([2, 16, 128, 128])
torch.Size([2])
torch.Size([2, 333, 4096])
torch.Size([2, 2048])


  0%|          | 0/28 [00:12<?, ?it/s]


KeyboardInterrupt: 