This notebook shows how to load a pretrained model from the diffusers library and wrap it in a module so that it can be used in a PyTorch Lightning Trainer pipeline.

In [None]:
from diffusers import AutoencoderKL
from src.models.lit_vae import LitVAE

In [None]:
# Fetch the pretrained VAE (the "vae" subfolder of the Stable Diffusion 3.5 medium
# repository on the Hugging Face hub) and put it into evaluation mode in one go;
# nn.Module.eval() returns the module itself, so the call can be chained.
vae = AutoencoderKL.from_pretrained(
    "stabilityai/stable-diffusion-3.5-medium",
    subfolder="vae",
).eval()
# Echo the model so the notebook still renders its architecture as the cell output.
vae

In [None]:
# create a pytorch lightning module from the VAE model
# LitVAE is the project wrapper imported from src.models.lit_vae; it receives the
# VAE instance loaded in the previous cell.
lit_vae = LitVAE(vae)
# As the last expression of the cell, the return value of eval() is what the notebook displays.
# NOTE(review): presumably this switches the wrapper to evaluation mode
# (nn.Module-style eval) — confirm LitVAE does not override eval().
lit_vae.eval()

In [None]:
# Create data module
#
# Builds the argument namespace consumed by both the data weighter and the FFHQ
# datamodule, then instantiates the datamodule and runs its setup step.
from argparse import Namespace

from src.dataloader.ffhq import FFHQWeightedTensorDataset
from src.dataloader.weighting import DataWeighter

# Datamodule
# All FFHQ artefacts live below a single scratch directory; derive every path
# from it so that relocating the data only requires changing one line.
data_root = "/pfs/work7/workspace/scratch/ma_mgraevin-optdif/data/ffhq"
img_dir = f"{data_root}/images1024x1024"
pt_dir = f"{data_root}/pt_images"
# The train, validation and combined attribute files all point to the same JSON.
train_attr_path = f"{data_root}/ffhq_smile_scores.json"
val_attr_path = f"{data_root}/ffhq_smile_scores.json"
combined_attr_path = f"{data_root}/ffhq_smile_scores.json"
max_property_value = 5
min_property_value = 0
mode = "all"
# Small batch: this notebook only pulls one batch for a single demo forward pass.
# This is the value that is actually handed to the datamodule below.
batch_size = 2
num_workers = 2  # 4

# Weighter
# NOTE(review): with weight_type="uniform" the remaining weighter options are
# presumably unused — confirm against DataWeighter.
weight_type = "uniform"
rank_weight_k = 1e-3
weight_quantile = None
dbas_noise = None
rwr_alpha = None

# Bundle everything into an argparse-style namespace, since both DataWeighter
# and FFHQWeightedTensorDataset are constructed from an `args` object.
args = Namespace(
    img_dir=img_dir,
    pt_dir=pt_dir,
    train_attr_path=train_attr_path,
    val_attr_path=val_attr_path,
    combined_attr_path=combined_attr_path,
    max_property_value=max_property_value,
    min_property_value=min_property_value,
    mode=mode,
    batch_size=batch_size,
    num_workers=num_workers,
    weight_type=weight_type,
    rank_weight_k=rank_weight_k,
    weight_quantile=weight_quantile,
    dbas_noise=dbas_noise,
    rwr_alpha=rwr_alpha,
)

datamodule = FFHQWeightedTensorDataset(args, DataWeighter(args))
datamodule.setup()  # assignment into train/validation split is made and weights are set

In [None]:
# get one batch of data
# Creates a fresh iterator over the training dataloader and takes only its first batch.
batch = next(iter(datamodule.train_dataloader()))
# Display the batch shape as the cell output.
# NOTE(review): assumes the loader yields a single tensor-like object rather
# than a tuple/dict — confirm against FFHQWeightedTensorDataset.
batch.shape

In [None]:
# apply forward pass of the VAE model
import torch

# This is inference only: disable gradient tracking so that no autograd graph
# (and the activations it would retain) is kept for the forward pass.
with torch.no_grad():
    # The wrapper returns a pair, unpacked here as the reconstructed batch and
    # the latent distribution object inspected in the next cell.
    recon_batch, latent_dist = lit_vae(batch)

In [None]:
# Display the `var` attribute of the latent distribution returned by the forward pass.
# NOTE(review): presumably the per-element variance of a diagonal Gaussian
# posterior — confirm against LitVAE.forward.
latent_dist.var