In [None]:
import torch
import os
import time
import pandas as pd
from tqdm.auto import tqdm
import argparse

import sys
sys.path.append('../.')
from utils.load_util import load_sdxl_models, load_pipe



# Which distilled variant to load alongside the base model.
# Options listed by the original author: "dmd", "lcm", "turbo", "lightning".
distillation_type = 'dmd'
device = 'cuda:0'
weights_dtype = torch.bfloat16

# load_sdxl_models returns the pipeline together with the base and distilled
# UNet/scheduler pairs; they are swapped into `pipe` later in the notebook.
(
    pipe,
    base_unet,
    base_scheduler,
    distilled_unet,
    distilled_scheduler,
) = load_sdxl_models(
    distillation_type=distillation_type,
    weights_dtype=weights_dtype,
    device=device,
)

In [None]:
# Guidance scales passed to the base and distilled pipeline calls respectively
base_guidance_scale = 5
distilled_guidance_scale = 0

run_base_till_timestep = None  # set to None if you want it to be automatically decided
run_distilled_from_timestep = 1  # index into the distilled scheduler's timesteps


# how many total timesteps to set for schedulers
base_num_inference_steps = 4
distilled_num_inference_steps = 4

# for paper consistent results use this
# NOTE: from here on both names refer to the SAME scheduler object, so each
# set_timesteps call below overwrites the schedule set by the previous one.
base_scheduler = distilled_scheduler

# set the timesteps for the model
# The base schedule is snapshotted before the distilled schedule is set;
# otherwise (because of the aliasing above) the closest-timestep search would
# compare the distilled schedule against itself whenever the two step counts
# differ.
base_scheduler.set_timesteps(base_num_inference_steps)
base_schedule = torch.as_tensor(base_scheduler.timesteps).clone()
distilled_scheduler.set_timesteps(distilled_num_inference_steps)
distilled_schedule = torch.as_tensor(distilled_scheduler.timesteps)

# automatically figure out what is the natural point to turn off the base model
if run_base_till_timestep is None:
    # timestep value at which the distilled model takes over
    distilled_timestep = distilled_schedule[run_distilled_from_timestep]

    # index of the closest timestep in the base schedule; stored as a plain int
    # so the automatic and the manually-set paths yield the same type
    timestep_distance = (base_schedule - distilled_timestep).abs()
    run_base_till_timestep = int(timestep_distance.argmin())

In [None]:
import matplotlib.pyplot as plt
import numpy as np

# User prompt
prompt = "cartoon character"  # Replace with your desired prompt

# Grid layout; exactly one image is generated per grid cell
nrows = 3
ncols = 5
num_images = nrows * ncols  # 3x5 grid -> 15 images

# Initialize variables
total_time = 0
all_images = []

pipe.set_progress_bar_config(disable=True)
# Generate images
for i in tqdm(range(num_images)):
    # Generate random seed; torch.manual_seed seeds torch's global RNG
    # (the returned generator object is not passed to the pipeline)
    seed = np.random.randint(0, 2**32 - 1)
    generator = torch.manual_seed(seed)

    # First use base model
    pipe.unet = base_unet
    pipe.scheduler = base_scheduler

    # The timed region covers both the base and the distilled pass
    start_time = time.perf_counter()
    base_latents = pipe(prompt=prompt, from_timestep=0, till_timestep=run_base_till_timestep,
                        guidance_scale=base_guidance_scale, num_inference_steps=base_num_inference_steps,
                        output_type='latent')

    # Switch to distilled model
    pipe.unet = distilled_unet
    pipe.scheduler = distilled_scheduler

    # Continue denoising from the base model's latents
    pil_image = pipe(prompt=prompt, start_latents=base_latents, guidance_scale=distilled_guidance_scale,
                     from_timestep=run_distilled_from_timestep, till_timestep=None,
                     num_inference_steps=distilled_num_inference_steps)[0]
    end_time = time.perf_counter()

    runtime = end_time - start_time
    total_time += runtime

    # Keep the pipeline output as-is for plotting and display
    all_images.append(pil_image)

# Create nrows x ncols grid with no whitespace
# squeeze=False keeps `axes` a 2-D array even for a single row or column
fig, axes = plt.subplots(nrows, ncols, figsize=(ncols*3, nrows*3), dpi=200, squeeze=False)
plt.subplots_adjust(wspace=0, hspace=0)  # Remove spacing between subplots
fig.patch.set_visible(False)  # Hide the figure's background
for i, ax in enumerate(axes.flat):
    # Guard against a grid that has more cells than generated images
    if i < len(all_images):
        ax.imshow(all_images[i])
    ax.axis('off')  # Remove axes
    ax.set_xticklabels([])  # Remove tick labels
    ax.set_yticklabels([])
    ax.tick_params(left=False, bottom=False)  # Remove ticks

# Remove any remaining borders/margins
plt.tight_layout(pad=0, h_pad=0, w_pad=0)
plt.subplots_adjust(left=0, right=1, top=1, bottom=0, wspace=0, hspace=0)
plt.margins(0, 0)
plt.gca().xaxis.set_major_locator(plt.NullLocator())
plt.gca().yaxis.set_major_locator(plt.NullLocator())

print(f"Total Runtime: {total_time:.4f} seconds")
# `runtime` holds the timing of the last generated image only
print(f"Single Runtime: {runtime:.4f} seconds")
# plt.savefig('grid.png', bbox_inches='tight', pad_inches=0)
plt.show()

In [None]:
# Display each generated image individually in the cell output.
for image in all_images:
    display(image)

# Individual Pipe Inference

In [None]:
import torch
import os
import time
import pandas as pd
from tqdm.auto import tqdm
import argparse

import sys
sys.path.append('.')
from utils.load_util import load_sdxl_models, load_pipe



# Model selection for the single-pipeline run.
distillation_type = 'dmd'  # set to None for base model
device = 'cuda:0'
weights_dtype = torch.bfloat16

# Rebinds `pipe` to the pipeline returned by load_pipe.
pipe = load_pipe(
    distillation_type=distillation_type,
    weights_dtype=weights_dtype,
    device=device,
)

In [None]:
# Sampling settings forwarded to the pipeline call in the generation cell.
num_inference_steps = 4  # passed as num_inference_steps=
guidance_scale = 0  # passed as guidance_scale=

In [None]:
import matplotlib.pyplot as plt
import numpy as np

# User prompt
prompt = "image of a dog"  # Replace with your desired prompt

# Grid layout; exactly one image is generated per grid cell
nrows = 5
ncols = 3
num_images = nrows * ncols  # 5x3 grid -> 15 images

# Initialize variables
total_time = 0
all_images = []

pipe.set_progress_bar_config(disable=True)
# Generate images
for i in tqdm(range(num_images)):
    # Generate random seed; torch.manual_seed seeds torch's global RNG
    # (the returned generator object is not passed to the pipeline)
    seed = np.random.randint(0, 2**32 - 1)
    generator = torch.manual_seed(seed)

    # Time a single pipeline call
    start_time = time.perf_counter()
    pil_image = pipe(prompt=prompt, guidance_scale=guidance_scale,
                     num_inference_steps=num_inference_steps)[0]
    end_time = time.perf_counter()

    runtime = end_time - start_time
    total_time += runtime

    # Keep the pipeline output as-is for plotting and display
    all_images.append(pil_image)


# Create nrows x ncols grid with no whitespace
# squeeze=False keeps `axes` a 2-D array even for a single row or column
fig, axes = plt.subplots(nrows, ncols, figsize=(ncols*3, nrows*3), dpi=200, squeeze=False)
plt.subplots_adjust(wspace=0, hspace=0)  # Remove spacing between subplots
fig.patch.set_visible(False)  # Hide the figure's background
for i, ax in enumerate(axes.flat):
    # Guard against a grid that has more cells than generated images
    if i < len(all_images):
        ax.imshow(all_images[i])
    ax.axis('off')  # Remove axes
    ax.set_xticklabels([])  # Remove tick labels
    ax.set_yticklabels([])
    ax.tick_params(left=False, bottom=False)  # Remove ticks

# Remove any remaining borders/margins
plt.tight_layout(pad=0, h_pad=0, w_pad=0)
plt.subplots_adjust(left=0, right=1, top=1, bottom=0, wspace=0, hspace=0)
plt.margins(0, 0)
plt.gca().xaxis.set_major_locator(plt.NullLocator())
plt.gca().yaxis.set_major_locator(plt.NullLocator())

print(f"Total Runtime: {total_time:.4f} seconds")
# `runtime` holds the timing of the last generated image only
print(f"Single Runtime: {runtime:.4f} seconds")
plt.savefig('grid.png', bbox_inches='tight', pad_inches=0)
plt.show()

In [None]:
# Display every generated image, in generation order, in the cell output.
display(*all_images)
    