In [2]:
from diffusers.models.unets.unet_2d_condition import UNet2DConditionOutput
from torch import nn
from enum import StrEnum
from transformers.modeling_outputs import BaseModelOutputWithPooling
from diffusers.models.autoencoders.vae import DiagonalGaussianDistribution
from lightning import LightningModule
from diffusers import StableDiffusionPipeline, AutoencoderKL, UNet2DConditionModel,  DDIMScheduler, DDPMScheduler
from transformers import CLIPTokenizer, CLIPTextModel, CLIPTokenizerFast, CLIPImageProcessor
from typing import Any, Self
from lightning.pytorch.utilities.types import STEP_OUTPUT, OptimizerLRScheduler
from pydantic import BaseModel, ConfigDict, model_validator
import torch
import torch.nn.functional as F



class SDTrainBatch(BaseModel):
    model_config = ConfigDict(arbitrary_types_allowed=True)

    input_ids: torch.Tensor
    images: torch.Tensor

    @model_validator(mode="after")
    def check_size(self)->Self:
        assert self.input_ids.size(0) == self.images.size(0)
        assert self.images.size(1) == 3
        return self


class PredictionType(StrEnum):
    EPSILON = "epsilon"
    SAMPLE = "sample"
    V_PREDICTION = "v_prediction"


class StableDiffusionModule(LightningModule):
    """
    Attributes:
        unet:
        vae:
        text_encoder:
        tokenizer:
         scheduler:

    """
    def __init__(self,
                pipeline: StableDiffusionPipeline | None = None,

                 *,
                 model_id: str | None = None,
                 vae: AutoencoderKL | None = None,
                 unet: UNet2DConditionModel | None = None,
                 text_encoder: CLIPTextModel | None = None,
                 tokenizer: CLIPTokenizer|CLIPTokenizerFast | None = None,
                 scheduler: DDIMScheduler|DDPMScheduler | None = None, # TODO: Make enum for typehint, hf so stupid
                 feature_extractor: CLIPImageProcessor|None=None,
                 **addition_kwargs
                 ):
        super().__init__()
        kwargs = dict(
            unet = unet,
            vae=vae,
            text_encoder = text_encoder,
            tokenizer=tokenizer,
            scheduler = scheduler,
            model_id=model_id,
            feature_extractor=feature_extractor,
        )
        # Delete all None values.
        del_k = []
        for k,v in kwargs.items():
            if v is None:
                del_k.append(k)
        for k in del_k:
            del kwargs[k]

        if pipeline is not None:
            assert isinstance(pipeline, StableDiffusionPipeline)
            if kwargs!={}:
                raise ValueError(f"All arguments must be None when `pipeline` is given. Got `{kwargs}.`")
        else:
            if kwargs.pop("model_id",None) is not None:
                pipeline = StableDiffusionPipeline.from_pretrained(model_id,**kwargs)
            else:
                # Does not require feature extractor.
                if "feature_extractor" not in kwargs:
                    kwargs["feature_extractor"] = None
                pipeline = StableDiffusionPipeline(**kwargs, requires_safety_checker=False, safety_checker=None, image_encoder= None)

        self.unet: UNet2DConditionModel = pipeline.unet
        self.vae: AutoencoderKL = pipeline.vae
        self.text_encoder: CLIPTextModel = pipeline.text_encoder
        self.tokenizer:CLIPTokenizer|CLIPTokenizerFast = pipeline.tokenizer
        self.scheduler:DDIMScheduler|DDPMScheduler = pipeline.scheduler
        self.pipeline = pipeline

        self.addition_kwargs = addition_kwargs

    @property
    def latents_values_scaling(self) -> float:
        """Default to 0.18215 and latents values should be multiplied after `vae.encode` and `divided` before `vae.decode`."""
        return self.vae.config.scaling_factor

    @property
    def latents_spatial_reduced_ratio(self) -> int:
        """The reduced ratio"""
        return 2 ** (len(self.vae.config.block_out_channels) - 1)

    def vae_encode(self, images: torch.Tensor)->torch.Tensor:
        latent_dist:DiagonalGaussianDistribution = self.vae.encode(images).latent_dist
        latents = latent_dist.sample()*self.latents_values_scaling
        return latents

    def vae_decode(self, latents: torch.Tensor)->torch.Tensor:
        return self.vae.decode(latents/self.latents_values_scaling).sample

    def loss_fn(self, prediction: torch.Tensor, target: torch.Tensor)->torch.Tensor:
        reduction = self.addition_kwargs.get("reduction","mean")
        return F.mse_loss(prediction, target,reduction=reduction)

    def training_step(self, batch: SDTrainBatch, batch_index:int  , **kwargs: Any) -> STEP_OUTPUT:

        latents = self.vae_encode(images:=batch.images)
        noises = torch.rand_like(latents)
        timesteps = torch.randint(0, self.scheduler.config.num_train_timesteps,(latents.size(0),))
        noisy_latents = self.scheduler.add_noise(latents, noises, timesteps)

        text_encoder_output: BaseModelOutputWithPooling = self.text_encoder(input_ids:=batch.input_ids)
        text_embeddings = text_encoder_output.last_hidden_state

        print(images.shape, noises.shape, timesteps.shape)
        unet_output:UNet2DConditionOutput = self.unet(noisy_latents, timesteps, text_embeddings)
        model_pred = unet_output.sample

        if (pred_type := self.scheduler.config.prediction_type) == PredictionType.EPSILON:
            target = noises
        elif pred_type == PredictionType.SAMPLE:
            target = images
        elif pred_type == PredictionType.V_PREDICTION:
            target = self.scheduler.get_velocity(latents, noises, timesteps)
        else:
            raise RuntimeError(f"The `prediction_type`=`{pred_type}` is not supported.")

        loss = self.loss_fn(model_pred, target)

        # self.log("train_loss", loss)

        return loss

    def validation_step(self, *args: Any, **kwargs: Any) -> STEP_OUTPUT:
        pass

    def configure_optimizers(self) -> OptimizerLRScheduler:
        return torch.optim.Adam(self.parameters())


In [2]:
from torchvision.transforms.v2 import Resize
from PIL import Image
out = Resize(256)(Image.new("RGB", (700, 700), color="red"))


In [5]:
a = lambda : 5


In [8]:
a(5)

TypeError: <lambda>() takes 0 positional arguments but 1 was given

In [3]:
from diffusers import DDPMScheduler
from lightning.pytorch.utilities.model_summary import ModelSummary

model_id = "stabilityai/stable-diffusion-2-1"

scheduler_subfolder = "scheduler"
module = StableDiffusionModule(model_id=model_id)

Loading pipeline components...:   0%|          | 0/6 [00:00<?, ?it/s]

In [5]:
ModelSummary(module, max_depth=-1)

     | Name                                                                | Type                    | Params | Mode
--------------------------------------------------------------------------------------------------------------------------
0    | unet                                                                | UNet2DConditionModel    | 865 M  | eval
1    | unet.conv_in                                                        | Conv2d                  | 11.8 K | eval
2    | unet.time_proj                                                      | Timesteps               | 0      | eval
3    | unet.time_embedding                                                 | TimestepEmbedding       | 2.1 M  | eval
4    | unet.time_embedding.linear_1                                        | Linear                  | 410 K  | eval
5    | unet.time_embedding.act                                             | SiLU                    | 0      | eval
6    | unet.time_embedding.linear_2                       

In [52]:
batch = SDTrainBatch(
    input_ids = torch.randint(0,100,[2,12]),
    images = torch.randn([2,3,128,128])
)
loss = module.training_step(batch,1)

torch.Size([2, 3, 128, 128]) torch.Size([2, 4, 16, 16]) torch.Size([2])


/home/hieu/Workspace/projects/petorch/.venv/lib/python3.13/site-packages/lightning/pytorch/core/module.py:441: You are trying to `self.log()` but the `self.trainer` reference is not registered on the model yet. This is most likely because the model hasn't been passed to the `Trainer`


In [57]:
loss.shape

torch.Size([])