This notebook show how to load a pretrained model from the diffusers library and encapsulate it in a module such that it can be used in a pytorch lightning Trainer pipeline.

In [1]:
from diffusers import AutoencoderKL
from src.models.lit_vae import LitVAE

In [2]:
# load the VAE model from the Huggingface hub
vae = AutoencoderKL.from_pretrained("stabilityai/stable-diffusion-3.5-medium", subfolder="vae")
vae.eval()

AutoencoderKL(
  (encoder): Encoder(
    (conv_in): Conv2d(3, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (down_blocks): ModuleList(
      (0): DownEncoderBlock2D(
        (resnets): ModuleList(
          (0-1): 2 x ResnetBlock2D(
            (norm1): GroupNorm(32, 128, eps=1e-06, affine=True)
            (conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (norm2): GroupNorm(32, 128, eps=1e-06, affine=True)
            (dropout): Dropout(p=0.0, inplace=False)
            (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (nonlinearity): SiLU()
          )
        )
        (downsamplers): ModuleList(
          (0): Downsample2D(
            (conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(2, 2))
          )
        )
      )
      (1): DownEncoderBlock2D(
        (resnets): ModuleList(
          (0): ResnetBlock2D(
            (norm1): GroupNorm(32, 128, eps=1e-06, affine=True)
            (c

In [3]:
# create a pytorch lightning module from the VAE model
lit_vae = LitVAE(vae)
lit_vae.eval()

LitVAE(
  (model): AutoencoderKL(
    (encoder): Encoder(
      (conv_in): Conv2d(3, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (down_blocks): ModuleList(
        (0): DownEncoderBlock2D(
          (resnets): ModuleList(
            (0-1): 2 x ResnetBlock2D(
              (norm1): GroupNorm(32, 128, eps=1e-06, affine=True)
              (conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
              (norm2): GroupNorm(32, 128, eps=1e-06, affine=True)
              (dropout): Dropout(p=0.0, inplace=False)
              (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
              (nonlinearity): SiLU()
            )
          )
          (downsamplers): ModuleList(
            (0): Downsample2D(
              (conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(2, 2))
            )
          )
        )
        (1): DownEncoderBlock2D(
          (resnets): ModuleList(
            (0): ResnetBlock2D(
             

In [4]:
# Create data module
from src.dataloader.ffhq import FFHQWeightedTensorDataset
from src.dataloader.weighting import DataWeighter

# Datamodule
img_dir="/pfs/work7/workspace/scratch/ma_mgraevin-optdif/data/ffhq/images1024x1024"
pt_dir="/pfs/work7/workspace/scratch/ma_mgraevin-optdif/data/ffhq/pt_images"
train_attr_path="/pfs/work7/workspace/scratch/ma_mgraevin-optdif/data/ffhq/ffhq_smile_scores.json"
val_attr_path="/pfs/work7/workspace/scratch/ma_mgraevin-optdif/data/ffhq/ffhq_smile_scores.json"
combined_attr_path="/pfs/work7/workspace/scratch/ma_mgraevin-optdif/data/ffhq/ffhq_smile_scores.json"
max_property_value=5
min_property_value=0
mode="all"
batch_size=128
num_workers=2 # 4

# Weighter
weight_type="uniform"
rank_weight_k=1e-3
weight_quantile=None
dbas_noise=None
rwr_alpha=None

from argparse import Namespace

args = Namespace(
    img_dir=img_dir,
    pt_dir=pt_dir,
    train_attr_path=train_attr_path,
    val_attr_path=val_attr_path,
    combined_attr_path=combined_attr_path,
    max_property_value=max_property_value,
    min_property_value=min_property_value,
    mode=mode,
    batch_size=2,
    num_workers=num_workers,
    weight_type=weight_type,
    rank_weight_k=rank_weight_k,
    weight_quantile=weight_quantile,
    dbas_noise=dbas_noise,
    rwr_alpha=rwr_alpha
)

datamodule = FFHQWeightedTensorDataset(args, DataWeighter(args))
datamodule.setup() # assignment into train/validation split is made and weights are set

In [5]:
# get one batch of data
batch = next(iter(datamodule.train_dataloader()))
batch.shape

torch.Size([2, 3, 256, 256])

In [17]:
# apply forward pass of the VAE model
recon_batch, latent_dist = lit_vae(batch)

In [19]:
latent_dist.var

tensor([[[[2.1981e-08, 3.6542e-09, 5.0910e-09,  ..., 2.9410e-09,
           3.0296e-09, 1.9005e-08],
          [3.7177e-09, 3.9722e-09, 5.6782e-09,  ..., 2.8220e-09,
           3.1870e-09, 3.8599e-09],
          [3.7848e-09, 3.9655e-09, 4.5738e-09,  ..., 4.1777e-09,
           4.6818e-09, 5.8745e-09],
          ...,
          [1.2254e-09, 2.3659e-09, 2.8435e-09,  ..., 2.1406e-09,
           2.0443e-09, 2.1301e-09],
          [1.2691e-09, 2.1038e-09, 3.6276e-09,  ..., 2.7891e-09,
           1.8665e-09, 2.5764e-09],
          [4.5417e-08, 1.0386e-08, 1.0909e-08,  ..., 8.4666e-09,
           9.3299e-09, 5.9755e-08]],

         [[1.7706e-08, 2.5898e-09, 3.1363e-09,  ..., 2.3293e-09,
           1.9652e-09, 1.4713e-08],
          [2.5446e-09, 2.9112e-09, 3.0711e-09,  ..., 1.9907e-09,
           2.0756e-09, 2.5620e-09],
          [3.0528e-09, 2.3885e-09, 2.5124e-09,  ..., 2.9873e-09,
           3.1041e-09, 3.8116e-09],
          ...,
          [9.7413e-10, 1.8493e-09, 2.0367e-09,  ..., 1.6170