In [1]:
%load_ext autoreload
%autoreload 2

import os
# Restrict this process to GPU index 3 (machine-specific choice). The variable
# is set before `torch` is imported so that it is already in place when CUDA
# gets initialised.
os.environ["CUDA_VISIBLE_DEVICES"]="3"

import torch
from torch import Tensor
import torch.nn.functional as F
# Default CUDA device; used by the cells below to place latents.
device = torch.device('cuda')
import matplotlib.pyplot as plt
import math
from tqdm import tqdm

from diffusers import DDIMPipeline, DDIMScheduler, DDPMPipeline, DDPMScheduler, StableDiffusionPipeline
# NOTE(review): the two star imports below make it hard to tell where a name
# comes from and may shadow earlier names; consider importing explicitly.
from typing import *
from jaxtyping import *

A matching Triton is not available, some optimizations will not be enabled.
Error caught was: No module named 'triton'


In [2]:
from pds import PDS, PDSConfig

# Build the configuration first, then hand it to the PDS wrapper. The
# checkpoint identifier is the Stable Diffusion 2.1 "base" model.
pds_config = PDSConfig(
    sd_pretrained_model_or_path='stabilityai/stable-diffusion-2-1-base',
)
pds = PDS(pds_config)

Loading pipeline components...:   0%|          | 0/6 [00:00<?, ?it/s]

In [5]:
# Read the reference image, keep only the first three channels (drops an alpha
# channel if present), then rearrange (H, W, C) -> (1, C, H, W) and move the
# result to the device the UNet lives on.
reference_rgb = torch.tensor(plt.imread('./base.png'))[..., :3]
reference = reference_rgb.permute(2, 0, 1).unsqueeze(0).to(pds.unet.device)

In [6]:
def decode_latent(latent):
    """Decode a latent batch with ``pds.decode_latent`` and tile it into one image.

    The latent is detached, moved to the module-level ``device`` and decoded
    with autograd disabled. The decoded tensor is cast to float32 on the CPU
    and its axes are rearranged from ``(d0, d1, d2, d3)`` to
    ``(d2, d0 * d3, d1)``.

    Assumes ``pds.decode_latent`` returns ``(batch, channels, height, width)``
    (TODO confirm); under that assumption the result is ``(height,
    batch * width, channels)``, i.e. the batch laid out side by side.
    """
    on_device = latent.detach().to(device)
    with torch.no_grad():
        decoded_batch = pds.decode_latent(on_device)

    # Single permute equivalent to permute(0, 2, 3, 1) followed by
    # permute(1, 0, 2, 3).
    tiled = decoded_batch.float().cpu().permute(2, 0, 3, 1)
    return tiled.flatten(start_dim=1, end_dim=2)

In [7]:
# Encode the reference image into a latent. No gradient is needed through the
# encoder, so run it with autograd disabled: this avoids keeping an encoder
# graph alive and guarantees the latent does not require grad, so clones of it
# are leaf tensors that an optimizer will accept.
with torch.no_grad():
    reference_latent = pds.encode_image(reference)

# Decode the latent back into an image array via decode_latent.
decoded = decode_latent(reference_latent)

In [None]:
# PDS editing: optimise a copy of the reference latent using the gradient that
# `pds(...)` returns for the (source prompt -> target prompt) pair.
# NOTE(review): unlike the generation cells, this cell never sets
# `pds.config.guidance_scale`, so it runs with whatever value is currently on
# the shared config (possibly left over from another cell) — confirm intended.
num_steps = 2000
save_every = 20  # write a preview image every `save_every` steps

# detach() makes `im` a leaf tensor no matter how `reference_latent` was
# produced; torch.optim refuses to optimise non-leaf tensors.
im = reference_latent.detach().clone().to(device)
im.requires_grad_(True)

im_optimizer = torch.optim.AdamW([im], lr=0.01, betas=(0.9, 0.99), eps=1e-15)

for step in tqdm(range(num_steps)):
    im_optimizer.zero_grad()

    # The gradient is taken from the returned dict rather than from autograd,
    # so the call itself runs with autograd disabled.
    with torch.no_grad():
        pds_dict = pds(
            tgt_x0=im,
            src_x0=reference_latent.clone(),
            tgt_prompt="a DSLR photo of a dog in a winter wonderland",
            src_prompt="a DSLR photo of a dog",
            return_dict=True
        )
    grad = pds_dict['grad']

    # Inject the externally computed gradient into `im.grad`, then step.
    im.backward(gradient=grad)
    im_optimizer.step()

    # Also save after the last step so the file on disk is the final result,
    # not the state from up to `save_every - 1` steps earlier.
    if step % save_every == 0 or step == num_steps - 1:
        # decode_latent already detaches its input and returns a CPU tensor.
        decoded = decode_latent(im).numpy()
        plt.imsave('./pds_debug.png', decoded)

In [None]:
# SDS Generation
# Optimise a randomly initialised latent using the gradient returned by
# `pds.sds_loss` for a single text prompt.
num_steps = 4000
save_every = 20  # write a preview image every `save_every` steps

# Seed torch's RNG so the initial noise below is the same on every run of
# this cell.
torch.manual_seed(0)

# Start from scaled Gaussian noise with the same shape/dtype/device as the
# reference latent.
im = 0.8 * torch.randn_like(reference_latent)
im.requires_grad_(True)

im_optimizer = torch.optim.AdamW([im], lr=0.003, betas=(0.9, 0.99), eps=1e-15)

# Loop-invariant: set once instead of on every iteration.
# NOTE: this mutates the shared `pds.config` and persists after this cell.
pds.config.guidance_scale = 7.5

for step in tqdm(range(num_steps)):
    im_optimizer.zero_grad()

    # The gradient is taken from the returned dict rather than from autograd,
    # so the call itself runs with autograd disabled.
    with torch.no_grad():
        pds_dict = pds.sds_loss(
            im=im,
            prompt="a DSLR photo of a dog in a winter wonderland",
            return_dict=True
        )
    grad = pds_dict['grad']

    # Inject the externally computed gradient into `im.grad`, then step.
    im.backward(gradient=grad)
    im_optimizer.step()

    # Also save after the last step so the file on disk is the final result,
    # not the state from up to `save_every - 1` steps earlier.
    if step % save_every == 0 or step == num_steps - 1:
        # decode_latent already detaches its input and returns a CPU tensor.
        decoded = decode_latent(im).numpy()
        plt.imsave('./sds_gen_debug.png', decoded)

In [55]:
# PDS Generation
# Optimise a randomly initialised latent batch using the gradient returned by
# `pds.pds_gen` for a single text prompt.
num_steps = 4000
save_every = 20  # write a preview image every `save_every` steps
batch_size = 1

# Seed torch's RNG so the initial noise below is the same on every run of
# this cell.
torch.manual_seed(0)

# Scaled Gaussian noise shaped like `batch_size` copies of the reference
# latent stacked along the first axis.
im = 0.8 * torch.randn_like(reference_latent.repeat(batch_size, 1, 1, 1))
im.requires_grad_(True)

im_optimizer = torch.optim.AdamW([im], lr=0.003, betas=(0.9, 0.99), eps=1e-15)

# Loop-invariant: set once instead of on every iteration.
# NOTE: this mutates the shared `pds.config` and persists after this cell.
pds.config.guidance_scale = 7.5

for step in tqdm(range(num_steps)):
    im_optimizer.zero_grad()

    # The gradient is taken from the returned dict rather than from autograd,
    # so the call itself runs with autograd disabled.
    with torch.no_grad():
        pds_dict = pds.pds_gen(
            im=im,
            prompt="a DSLR photo of a dog in a winter wonderland",
            return_dict=True
        )
    grad = pds_dict['grad']

    # Inject the externally computed gradient into `im.grad`, then step.
    im.backward(gradient=grad)
    im_optimizer.step()

    # Also save after the last step so the file on disk is the final result,
    # not the state from up to `save_every - 1` steps earlier.
    if step % save_every == 0 or step == num_steps - 1:
        # decode_latent already detaches its input and returns a CPU tensor.
        decoded = decode_latent(im).numpy()
        plt.imsave('./pds_gen_debug.png', decoded)

 42%|████▏     | 1670/4000 [05:26<07:15,  5.35it/s]

In [20]:
# PDS + SDEdit Generation
# Optimise a randomly initialised latent batch using the gradient returned by
# `pds.pds_gen_sdedit_src`, with a step-dependent schedule for its
# `skip_percentage` and `num_solve_steps` arguments.
num_steps = 4000
save_every = 20  # write a preview image every `save_every` steps
batch_size = 1

# Seed torch's RNG so the initial noise below is the same on every run of
# this cell.
torch.manual_seed(0)

# Scaled Gaussian noise shaped like `batch_size` copies of the reference
# latent stacked along the first axis.
im = 0.8 * torch.randn_like(reference_latent.repeat(batch_size, 1, 1, 1))
im.requires_grad_(True)

im_optimizer = torch.optim.AdamW([im], lr=0.01, betas=(0.9, 0.99), eps=1e-15)

# Loop-invariant: set once instead of on every iteration.
# NOTE: this mutates the shared `pds.config` and persists after this cell.
pds.config.guidance_scale = 100

for step in tqdm(range(num_steps)):
    im_optimizer.zero_grad()

    # Ramps linearly from 0 and saturates at 0.8 from step 1200 onwards.
    skip_percentage = min(step / 1500, 0.8)
    # Starts at 12 and grows by one every 200 steps; the cap of 20 extra steps
    # is not reached within 4000 iterations (maximum here is 12 + 19 = 31).
    num_solve_steps = 12 + min(step // 200, 20)

    # The gradient is taken from the returned dict rather than from autograd,
    # so the call itself runs with autograd disabled.
    with torch.no_grad():
        pds_dict = pds.pds_gen_sdedit_src(
            im=im,
            prompt="a DSLR photo of a dog in a winter wonderland",
            skip_percentage=skip_percentage,
            num_solve_steps=num_solve_steps,
            return_dict=True
        )
    grad = pds_dict['grad']

    # Inject the externally computed gradient into `im.grad`, then step.
    im.backward(gradient=grad)
    im_optimizer.step()

    # Also save after the last step so the file on disk is the final result,
    # not the state from up to `save_every - 1` steps earlier.
    if step % save_every == 0 or step == num_steps - 1:
        # decode_latent already detaches its input and returns a CPU tensor.
        decoded = decode_latent(im).numpy()
        plt.imsave('./pds_gen_sdedit_debug.png', decoded)

  0%|          | 0/4000 [00:00<?, ?it/s]

 32%|███▏      | 1267/4000 [42:31<1:31:43,  2.01s/it]
