In [None]:
# Automatically reload imported modules (e.g. diffusion_prior from ../src)
# before executing code, so edits to the package are picked up without
# restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import sys
sys.path.append('../src')
from diffusion_prior.pipeline import DiffusionPrior
from diffusion_prior.model import DiffusionPriorUNet
from diffusion_prior.dataset import EmbeddingDataset, EmbeddingDataLoader

import torch

In [None]:
# Seed torch's global RNG before building the model, so that anything drawing
# from it (presumably the U-Net's weight initialisation, and any noise the
# pipeline samples from the global RNG later) is reproducible on
# Restart & Run All.
SEED = 0
torch.manual_seed(SEED)

# define prior model (U-Net) with condition dim = combined embed dim = CLIP-ViT-H-14 dim = 1024
prior_model = DiffusionPriorUNet(cond_dim=1024)

# define prior pipeline with train() and generate() methods
pipe = DiffusionPrior(prior_model)

In [None]:
# Random stand-in data for a smoke test: NUM_SAMPLES pairs of
# (combined embedding, image embedding), each of width EMBED_DIM to match the
# prior's cond_dim of 1024. The two tensors are independent Gaussian noise, so
# there is no real mapping to learn. This only exercises the train/generate
# plumbing.
# A dedicated, seeded generator makes the data identical on every re-run of
# this cell, independent of any other RNG use in the notebook.
NUM_SAMPLES, EMBED_DIM = 64, 1024
data_rng = torch.Generator().manual_seed(0)
combined_embeddings = torch.randn(NUM_SAMPLES, EMBED_DIM, generator=data_rng)
image_embeddings = torch.randn(NUM_SAMPLES, EMBED_DIM, generator=data_rng)

In [None]:
# Wrap the two tensors as a dataset of (combined embed, image embed) pairs.
# Presumably row i of each tensor forms one pair; confirm in EmbeddingDataset.
dataset = EmbeddingDataset(
    combined_embeddings,
    image_embeddings,
)

In [None]:
# Batch the paired dataset for training, 16 pairs per batch
# (presumably 4 batches per epoch for the 64 samples above).
dataloader = EmbeddingDataLoader(
    dataset,
    batch_size=16,
)

In [None]:
# Fit the diffusion prior: 10 epochs over the dataloader at learning rate 1e-4.
pipe.train(
    dataloader,
    learning_rate=1e-4,
    num_epochs=10,
)

In [None]:
# Pick the first pair to test generate(). generate() expects a 2-D [B, 1024]
# input, so keep the batch axis (B = 1) instead of passing a bare [1024] vector.
# The [:1] slice keeps the leading dimension, giving shape [1, 1024].
# image_embeds is the matching target embedding, e.g. for comparison with the
# generated one.
combined_embeds = combined_embeddings[:1]
image_embeds = image_embeddings[:1]

In [None]:
# Sample an image embedding from the trained prior, conditioned on the single
# combined embedding, using 50 inference steps and guidance scale 5.0
# (presumably classifier-free guidance; confirm in DiffusionPrior.generate).
image_embeds_generated = pipe.generate(
    guidance_scale=5.0,
    num_inference_steps=50,
    combined_embeds=combined_embeds,
)