## **Components (Migrate from Stable Diffusion)**
Quick migration mental model (old → new)

1. U-Net → Transformer (MMDiT-X)

2. 1× CLIP → 2× CLIP + 1× T5-XXL (all frozen; T5 can be omitted at inference)
T5 handles language understanding; the CLIP encoders supply token-level and pooled embeddings. "Multi-encoder fusion providing both detailed and global context, plus advanced linguistic understanding with T5; supports optional omission to manage VRAM."

3. Same VAE role (encode/decode latents)

4. DDIM/DPM schedulers → FlowMatch-Euler/Heun (+ shift)

5. CFG only → CFG + optional SLG
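
The "+ shift" in item 4 refers to warping the noise levels so that more sampling steps are spent at high noise. A minimal sketch, assuming the formula `shift * sigma / (1 + (shift - 1) * sigma)` used by the diffusers flow-matching schedulers; the function names, the evenly spaced base grid, and `shift=3.0` are illustrative assumptions, not values taken from this notebook:

```python
def shift_sigma(sigma: float, shift: float) -> float:
    """Warp one noise level; shift > 1 pushes it toward 1.0 (more noise)."""
    return shift * sigma / (1.0 + (shift - 1.0) * sigma)


def shifted_schedule(num_steps: int, shift: float) -> list:
    """Evenly spaced sigmas from 1.0 downward (assumed base grid), warped, plus a final 0.0."""
    base = [1.0 - i / num_steps for i in range(num_steps)]
    return [shift_sigma(s, shift) for s in base] + [0.0]


schedule = shifted_schedule(num_steps=4, shift=3.0)
print(schedule)
```

With `shift=1.0` the schedule is unchanged, which makes the parameter easy to ablate.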


| Module                            | SD2                                                   | SD3                                                                                                     | Notes                                                                    |
| --------------------------------- | ----------------------------------------------------- | ------------------------------------------------------------------------------------------------------- | ------------------------------------------------------------------------ |
| **Text Encoders**                 | OpenCLIP-ViT-H/14                                     | CLIP-ViT/L, OpenCLIP-ViT/G (**new**), T5-XXL (**new**)                                                  | Multi-encoder fusion is the biggest change.                              |
| **Tokenizer**                     | CLIP tokenizer                                        | CLIP tokenizers, T5 tokenizer (**new**)                                                                 | Each encoder family needs its own tokenizer.                             |
| **UNet / Denoiser**               | UNet2DConditionModel (ResNet + CrossAttention blocks) | **MM-DiT (Multimodal Diffusion Transformer)** (**new**)                                                 | Transformer backbone replaces the convolutional UNet.                    |
| **Noise Scheduler**               | DDIM, PNDM, Euler, etc.                               | Flow-matching schedulers: FlowMatch-Euler/Heun (**new**)                                                | The model predicts a velocity instead of noise.                          |
| **Variational Autoencoder (VAE)** | 8× spatial downsampling, 4 latent channels            | 8× spatial downsampling, 16 latent channels (**new**)                                                   | Still compresses images into latent space, with more channels per pixel. |
| **Conditioning Mechanism**        | Cross-attention with CLIP embeddings                  | Joint attention over image and text tokens + pooled CLIP embeddings added to the timestep embedding (**new**) | Richer conditioning signals.                                       |
| **Safety Checker**                | Optional NSFW detector                                | Similar, integrated                                                                                     | No big change.                                                           |
| **Optimizer**                     | AdamW, EMA                                            | Same, but scaled for DiT training                                                                       | Training is heavier due to larger encoders.                              |


# Pseudocode for Stable Diffusion 3 Inference

def sd3_inference(prompt, num_steps=50, guidance_scale=7.5):
    # INPUT: text prompt (string)
    # OUTPUT: generated image (RGB)

    # [NEW in SD3] Multiple text encoders
    clip_l_emb = CLIP_ViT_L.encode(prompt)          # encoder 1
    clip_g_emb = OpenCLIP_ViT_G.encode(prompt)      # encoder 2
    t5_emb      = T5_XXL.encode(prompt)             # encoder 3
    conditioning = fuse_embeddings([clip_l_emb, clip_g_emb, t5_emb])

    # Latent init (same as SD2)
    latents = sample_gaussian_noise(shape=(latent_channels, H//8, W//8))

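    # Conditioning for the empty prompt, needed by classifier-free guidance (CFG)
    null_conditioning = fuse_embeddings(
        [CLIP_ViT_L.encode(""), OpenCLIP_ViT_G.encode(""), T5_XXL.encode("")]
    )

    # Iterative denoising
    for t in scheduler.timesteps(num_steps):
        # [NEW] MM-DiT backbone with flow matching scheduler
        v_cond   = MM_DiT(latents, conditioning, t)        # velocity given the prompt
        v_uncond = MM_DiT(latents, null_conditioning, t)   # velocity given the empty prompt
        # CFG combines the two model outputs; the scheduler itself takes no guidance scale
        velocity = v_uncond + guidance_scale * (v_cond - v_uncond)
        latents = scheduler.step(velocity, t, latents)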

    # Decode to image (same as SD2)
    image = VAE.decode(latents)
    return image
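
The sampling loop above can be exercised on a toy problem without any model. The sketch below is an illustration under stated assumptions, not the SD3 sampler: `oracle_velocity` is a hypothetical stand-in for MM-DiT that already knows the clean target, the 2-d vectors stand in for latents, and all function names are made up for this example. It shows the Euler update on a flow-matching path and where classifier-free guidance enters.

```python
def cfg_combine(v_uncond, v_cond, guidance_scale):
    """Classifier-free guidance: extrapolate from the unconditional to the conditional velocity."""
    return [u + guidance_scale * (c - u) for u, c in zip(v_uncond, v_cond)]


def oracle_velocity(x, sigma, target):
    """Stand-in for the model. On the straight path x = (1 - sigma) * target + sigma * noise,
    the velocity (noise - target) equals (x - target) / sigma."""
    return [(xi - ti) / sigma for xi, ti in zip(x, target)]


def euler_flow_sample(x, sigmas, target_cond, target_uncond, guidance_scale):
    """Integrate from sigma = 1 (pure noise) down to sigma = 0 with plain Euler steps."""
    for sigma, sigma_next in zip(sigmas[:-1], sigmas[1:]):
        v_cond = oracle_velocity(x, sigma, target_cond)
        v_uncond = oracle_velocity(x, sigma, target_uncond)
        v = cfg_combine(v_uncond, v_cond, guidance_scale)
        x = [xi + (sigma_next - sigma) * vi for xi, vi in zip(x, v)]
    return x


sigmas = [1.0, 0.75, 0.5, 0.25, 0.0]
noise = [2.0, 1.0]
plain = euler_flow_sample(noise, sigmas, [1.0, -1.0], [0.0, 0.0], guidance_scale=1.0)
guided = euler_flow_sample(noise, sigmas, [1.0, -1.0], [0.0, 0.0], guidance_scale=2.0)
print(plain, guided)
```

With a guidance scale of 1 the sample lands on the conditional target; a larger scale pushes it further away from the unconditional one, which is the intended effect of CFG.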


# Pseudocode for Stable Diffusion 3 Training
for batch in dataloader:
    images, captions = batch
    # INPUT: paired (image, text)
    # OUTPUT: trained weights for MM-DiT

    # Latent encoding (same as SD2)
    latents = VAE.encode(images)

    # Sample timestep + noise; flow matching interpolates linearly between data and noise
    noise = sample_gaussian_noise_like(latents)
    t = sample_random_timestep()                     # t in [0, 1]
    noisy_latents = (1 - t) * latents + t * noise    # t = 0 is clean data, t = 1 is pure noise

    # [NEW in SD3] Multi-encoder conditioning (all three encoders stay frozen)
    clip_l_emb = CLIP_ViT_L.encode(captions)
    clip_g_emb = OpenCLIP_ViT_G.encode(captions)
    t5_emb     = T5_XXL.encode(captions)
    conditioning = fuse_embeddings([clip_l_emb, clip_g_emb, t5_emb])

    # [NEW] MM-DiT predicts velocity instead of noise
    velocity_pred = MM_DiT(noisy_latents, conditioning, t)

    # [NEW] Flow-matching loss: the target velocity is d(noisy_latents)/dt = noise - latents
    loss = mse_loss(velocity_pred, noise - latents)

    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
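
The training target can be checked by hand on tiny vectors. A minimal sketch, assuming the linear (rectified-flow) path `x_t = (1 - t) * x0 + t * noise`; the helper names are hypothetical, and the logit-normal timestep sampler reflects the sampling choice described in the SD3 paper rather than anything shown in this notebook:

```python
import math
import random


def sample_logit_normal_t(rng, mean=0.0, std=1.0):
    """Draw t in (0, 1) by passing a Gaussian sample through a sigmoid."""
    return 1.0 / (1.0 + math.exp(-rng.gauss(mean, std)))


def make_training_pair(x0, noise, t):
    """Return the noisy input x_t and the velocity target d(x_t)/dt = noise - x0."""
    x_t = [(1.0 - t) * a + t * b for a, b in zip(x0, noise)]
    v_target = [b - a for a, b in zip(x0, noise)]
    return x_t, v_target


def mse(pred, target):
    """Mean squared error between two equally sized vectors."""
    return sum((p - q) ** 2 for p, q in zip(pred, target)) / len(pred)


rng = random.Random(0)
t = sample_logit_normal_t(rng)

x0 = [1.0, 2.0]       # stand-in for clean latents
noise = [3.0, -2.0]   # stand-in for Gaussian noise
x_t, v_target = make_training_pair(x0, noise, 0.25)
loss_zero_prediction = mse([0.0, 0.0], v_target)
print(t, x_t, v_target, loss_zero_prediction)
```

Note that the target does not depend on `t`; only the input `x_t` does.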



In [5]:
""" SD2
Text Prompt ──> OpenCLIP-ViT-H/14 ──> Text Embeddings
                                         │
                                         ▼
   Latent Noise ──> UNet (ResNet+Attention) ──> ε-prediction (noise)
                                         │
                              Scheduler (DDPM/DDIM/PNDM)
                                         │
                                         ▼
                                 Latent Image
                                         │
                                         ▼
                                VAE Decoder → RGB Image

SD3
Text Prompt ──> [CLIP-ViT/L]  ─┐
              [OpenCLIP-ViT/G] ├──> Fused Text Embeddings ──┐
              [T5-XXL]        ─┘                            │  [NEW]
                                                           ▼
   Latent Noise ──> MM-DiT Transformer ──> Velocity (dx/dt prediction) [NEW]
                                                             │
                                           Scheduler (Flow Matching) [NEW]
                                                             │
                                                             ▼
                                                     Latent Image
                                                             │
                                                             ▼
                                                    VAE Decoder → RGB Image

"""
None

### **1. Transformer**

In [1]:
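import torch
from diffusers import SD3Transformer2DModel

dtype = torch.bfloat16  # half-precision weights to reduce memory use
model_id = "stabilityai/stable-diffusion-3.5-medium"

# Load only the denoiser from the checkpoint's "transformer" subfolder
transformer = SD3Transformer2DModel.from_pretrained(
    model_id,
    subfolder="transformer",
    torch_dtype=dtype,
)
transformer  # display the module tree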

SD3Transformer2DModel(
  (pos_embed): PatchEmbed(
    (proj): Conv2d(16, 1536, kernel_size=(2, 2), stride=(2, 2))
  )
  (time_text_embed): CombinedTimestepTextProjEmbeddings(
    (time_proj): Timesteps()
    (timestep_embedder): TimestepEmbedding(
      (linear_1): Linear(in_features=256, out_features=1536, bias=True)
      (act): SiLU()
      (linear_2): Linear(in_features=1536, out_features=1536, bias=True)
    )
    (text_embedder): PixArtAlphaTextProjection(
      (linear_1): Linear(in_features=2048, out_features=1536, bias=True)
      (act_1): SiLU()
      (linear_2): Linear(in_features=1536, out_features=1536, bias=True)
    )
  )
  (context_embedder): Linear(in_features=4096, out_features=1536, bias=True)
  (transformer_blocks): ModuleList(
    (0-12): 13 x JointTransformerBlock(
      (norm1): SD35AdaLayerNormZeroX(
        (silu): SiLU()
        (linear): Linear(in_features=1536, out_features=13824, bias=True)
        (norm): LayerNorm((1536,), eps=1e-06, elementwise_affine=Fa
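
The `pos_embed` layer in the printout is a `Conv2d(16, 1536, kernel_size=(2, 2), stride=(2, 2))`, so every 2×2 patch of the 16-channel latent becomes one 1536-d token. A rough sketch of the resulting image-token count, assuming the VAE downsamples by 8× per side (as in the `H//8` of the pseudocode); the helper name is hypothetical:

```python
def num_image_tokens(height, width, vae_factor=8, patch_size=2):
    """Number of tokens the transformer sees for an image of the given pixel size."""
    latent_h, latent_w = height // vae_factor, width // vae_factor
    return (latent_h // patch_size) * (latent_w // patch_size)


tokens_1024 = num_image_tokens(1024, 1024)  # 128x128 latent -> 64x64 patches
tokens_512 = num_image_tokens(512, 512)     # 64x64 latent -> 32x32 patches
print(tokens_1024, tokens_512)
```

Attention cost grows with the square of this count, which is why resolution has a large effect on memory and speed.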

In [2]:
from petorch import AdapterAPI
from petorch.prebuilt.configs import LoraConfig
from petorch.utilities.func import get_module_num_parameters, freeze_module

config = LoraConfig(adapter_name='default', rank=8, alpha=16)


def freeze_adapter_and_print_info(module: torch.nn.Module):
    # Freeze the existing weights, then attach the LoRA adapter on top
    frozen_params = freeze_module(module)
    AdapterAPI.add_adapter(module, config, activate=False)
    train_params, non_train_params = get_module_num_parameters(module)

    # Sanity check: the non-trainable parameters are exactly the ones frozen above
    assert non_train_params == frozen_params, f"{non_train_params} != {frozen_params}"
    print(f"Train params:{train_params:,}, Non train params: {non_train_params:,}")
    print(train_params / non_train_params)  # ratio of trainable to frozen parameters


freeze_adapter_and_print_info(transformer)

[32m[1m[2025-08-29 09:03:35.841][0m[32m[0m [1m[PID 33309 | INFO    ][0m [36m[3mpetorch.utilities.logger:setup_logger:110 :~ [0m[36m[0m [1mLog level is set to `INFO`.[0m
Train params:16,505,856, Non train params: 2,243,171,520
0.007358267458745196
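
The trainable count is small because a rank-`r` LoRA adapter on a `Linear(in, out)` layer adds only two thin matrices, `r * (in + out)` parameters in total. A minimal sketch of that arithmetic, assuming the standard LoRA parameterization without bias; which layers petorch actually wraps is not shown here, so the layer sizes below are just examples taken from the printout:

```python
def lora_param_count(in_features, out_features, rank):
    """Parameters added by one LoRA adapter: A is (rank x in), B is (out x rank)."""
    return rank * in_features + out_features * rank


per_square_layer = lora_param_count(1536, 1536, rank=8)   # e.g. a 1536 -> 1536 projection
per_context_layer = lora_param_count(4096, 1536, rank=8)  # e.g. the 4096 -> 1536 context_embedder

# Ratio reported by the cell above, recomputed from its two printed counts
ratio = 16_505_856 / 2_243_171_520
print(per_square_layer, per_context_layer, f"{ratio:.2%}")
```

Roughly 0.74% of the frozen parameter count is trained, which matches the ratio printed above.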


In [None]:
from diffusers import StableDiffusion3Pipeline, StableDiffusion3Img2ImgPipeline

Example prompt: "A red bus driving on a bridge"

| Encoder | Token embeddings (shape)                             | Meaning                           | Pooled embedding (shape) |
| ------- | ---------------------------------------------------- | --------------------------------- | ------------------------ |
| CLIP-L  | `(77, 768)`                                          | Each of 77 tokens → 768-d vectors | `(768,)`                 |
| CLIP-G  | `(77, 1280)`                                         | Same 77 tokens → 1280-d vectors   | `(1280,)`                |
| T5-XXL  | `(32, 4096)` (assuming 32 tokens from SentencePiece) | Contextualized embeddings         | —                        |

Fusion before MM-DiT:

CLIP-L + CLIP-G → concatenated along the feature dimension → (77, 2048), then zero-padded → (77, 4096)

T5-XXL → (32, 4096), already at the target width

These two sequences are concatenated along the token dimension → (109, 4096) (77+32 tokens).
That’s the fused text embedding sequence given to MM-DiT, where `context_embedder` projects it from 4096 to 1536. The two pooled CLIP vectors are concatenated separately into a (2048,) vector that feeds `text_embedder`.
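
A minimal sketch of the fusion with plain Python lists as stand-ins for tensors, assuming the order used by the diffusers `StableDiffusion3Pipeline`: concatenate the two CLIP outputs feature-wise, zero-pad to T5's width, then concatenate with T5 token-wise. The helper names and the constant fill values are hypothetical; with real tensors the same steps would be `torch.cat` and `torch.nn.functional.pad`.

```python
def filled(rows, cols, value):
    """Stand-in for an encoder output of shape (rows, cols)."""
    return [[value] * cols for _ in range(rows)]


def fuse_text_embeddings(clip_l, clip_g, t5):
    """Return the fused sequence of shape (len(clip_l) + len(t5), width of t5)."""
    t5_width = len(t5[0])
    # 1) feature-wise concat of the two CLIP outputs: (77, 768) + (77, 1280) -> (77, 2048)
    clip = [row_l + row_g for row_l, row_g in zip(clip_l, clip_g)]
    # 2) zero-pad the feature dimension up to T5's width: (77, 2048) -> (77, 4096)
    clip = [row + [0.0] * (t5_width - len(row)) for row in clip]
    # 3) token-wise concat with T5: (77, 4096) + (32, 4096) -> (109, 4096)
    return clip + t5


fused = fuse_text_embeddings(filled(77, 768, 1.0), filled(77, 1280, 1.0), filled(32, 4096, 2.0))
fused_shape = (len(fused), len(fused[0]))

# Pooled vectors are handled separately: (768,) + (1280,) -> (2048,)
pooled = [1.0] * 768 + [1.0] * 1280
print(fused_shape, len(pooled))
```

The widths line up with the printed model: `context_embedder` expects 4096 input features and `text_embedder` expects 2048.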

| Term                                          | What it means                                                      | Example output                                                              |
| --------------------------------------------- | ------------------------------------------------------------------ | --------------------------------------------------------------------------- |
| **Token embedding (lookup layer)**            | Maps each token ID → vector (no context)                           | `(seq_len, hidden_dim)` but same word in different contexts has same vector |
| **Contextualized embedding (encoder output)** | Each token vector **incorporates context from surrounding tokens** | `(seq_len, hidden_dim)` but now “red” in “red bus” ≠ “red apple”            |
| **Sequence embedding / sentence embedding**   | Compresses **whole sequence** into **single vector**               | `(hidden_dim,)` or `(1, hidden_dim)` — usually via pooling or \[CLS] token  |
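
The three terms can be told apart on a toy example. The sketch below is purely illustrative: the 2-d lookup table and the averaging "encoder" are made up, and mean pooling is only one way to get a sequence vector (CLIP, for instance, takes the hidden state at its end-of-text token instead).

```python
EMBEDDING_TABLE = {"red": [1.0, 0.0], "bus": [0.0, 1.0], "apple": [0.5, 0.5]}  # toy values


def lookup(tokens):
    """Token embedding: one fixed vector per token, independent of context."""
    return [EMBEDDING_TABLE[t] for t in tokens]


def mean_pool(vectors):
    """Sequence embedding: average all token vectors into a single vector."""
    return [sum(column) / len(vectors) for column in zip(*vectors)]


def contextualize(vectors):
    """Toy encoder: mix each token vector with the mean of the whole sequence."""
    mean = mean_pool(vectors)
    return [[0.5 * v + 0.5 * m for v, m in zip(vec, mean)] for vec in vectors]


red_bus = contextualize(lookup(["red", "bus"]))
red_apple = contextualize(lookup(["red", "apple"]))
sentence_vector = mean_pool(red_bus)

print(lookup(["red", "bus"])[0], lookup(["red", "apple"])[0])  # identical lookup vectors
print(red_bus[0], red_apple[0])                                # "red" now differs by context
print(sentence_vector)                                         # one vector for the sequence
```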


https://huggingface.co/blog/sd3