In [1]:
import io
import math

import requests
import torch
import torch.nn.functional as F
from accelerate import Accelerator
from diffusers import AutoencoderKL, UNet2DModel, DDIMPipeline, DDIMScheduler, DDPMPipeline, DDPMScheduler, DistillationPipeline
from diffusers.optimization import get_scheduler
from diffusers.training_utils import EMAModel
from PIL import Image
from torch.utils.data import Dataset
from torchvision.transforms import (
    CenterCrop,
    Compose,
    InterpolationMode,
    Normalize,
    RandomHorizontalFlip,
    Resize,
    ToTensor,
    ToPILImage,
)
from tqdm import tqdm

import utils




In [2]:
# Seed torch's global RNG up front so later torch random ops in this notebook are repeatable.
torch.manual_seed(seed=0)

<torch._C.Generator at 0x7f9a051d2010>

In [3]:
# Notebook-wide hyper-parameters from the project's utils module. `batch_size` and
# `use_ema` are read by later cells, and the whole object is passed to
# utils.get_train_transforms.
training_config = utils.DiffusionTrainingArgs()

In [4]:
# Load an image of my dog for this example
image_url = "https://i.imgur.com/IJcs4Aa.jpeg"
# Time out instead of hanging the kernel forever on a stalled connection, and fail
# loudly on HTTP errors rather than handing an error page to PIL.
response = requests.get(image_url, timeout=30)
response.raise_for_status()
# Decode from the fully downloaded bytes: `response.raw` skips requests' content
# decoding (gzip/deflate) and would leave PIL lazily reading from a live socket.
image = Image.open(io.BytesIO(response.content))

In [5]:
# Define the transforms to apply to the image for training
# Built by the project helper from training_config; the returned callable is applied
# to the RGB PIL image when the training dataset is constructed.
augmentations = utils.get_train_transforms(training_config)

In [6]:
class SingleImageDataset(Dataset):
    """Map-style dataset that repeats one pre-processed image ``batch_size`` times.

    Args:
        image: The sample returned for every valid index (in this notebook, the
            already-augmented training image).
        batch_size: Reported dataset length, so a loader using the same batch
            size sees exactly one batch of copies per epoch.
    """

    def __init__(self, image, batch_size):
        self.image = image
        self.batch_size = batch_size

    def __len__(self):
        return self.batch_size

    def __getitem__(self, idx):
        # Raise IndexError outside [-len, len). Plain Python iteration
        # (`for x in dataset`, `list(dataset)`) falls back to calling
        # __getitem__ with 0, 1, 2, ... until IndexError, so without this
        # check it never terminates.
        if not -self.batch_size <= idx < self.batch_size:
            raise IndexError(
                f"index {idx} out of range for dataset of length {self.batch_size}"
            )
        return self.image


In [7]:
# Apply the training transforms exactly once to the RGB photo. Because they run only
# here, any random augmentation is sampled a single time; the dataset then repeats
# this one transformed sample `batch_size` times.
train_image = augmentations(image.convert(mode="RGB"))
train_dataset = SingleImageDataset(image=train_image, batch_size=training_config.batch_size)

In [8]:
# Only the distillation driver is built here. The teacher UNet is loaded at the top
# of the distillation cell instead, so re-running that cell always restarts from the
# original checkpoint; loading it here as well was a redundant (large) checkpoint load.
distiller = DistillationPipeline()

Downloading:   0%|          | 0.00/455M [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/665 [00:00<?, ?B/s]

In [9]:
# `torch.manual_seed` re-seeds torch's global RNG and returns the default generator object.
generator = torch.manual_seed(seed=0)
# Starting number of diffusion timesteps; the distillation loop halves it each round
# (the distillation cell also re-assigns it before use).
N = 1000


In [None]:
# Iterative distillation: each round trains on the current teacher at N timesteps and
# halves N, then samples from the result with a DDPM scheduler of the new length.
# The teacher and N are reset here so re-running this cell starts from the original
# checkpoint instead of distilling an already-distilled model again.
teacher = UNet2DModel.from_pretrained("bglick13/minnie-diffusion")
N = 1000
NUM_DISTILL_STEPS = 2
EPOCHS_PER_STEP = 300
SAMPLE_SEED = 0
distilled_images = []
for distill_step in range(NUM_DISTILL_STEPS):
    print(f"Distill step {distill_step} from {N} -> {N // 2}")
    # NOTE(review): the returned `teacher` is assumed to be the distilled student that
    # seeds the next round — confirm against DistillationPipeline.
    teacher, distilled_ema, distill_accelerator = distiller(
        teacher, N, train_dataset, epochs=EPOCHS_PER_STEP, batch_size=training_config.batch_size
    )
    N = N // 2
    new_scheduler = DDPMScheduler(num_train_timesteps=N, beta_schedule="squaredcos_cap_v2")
    pipeline = DDPMPipeline(
        unet=distill_accelerator.unwrap_model(distilled_ema.averaged_model if training_config.use_ema else teacher),
        scheduler=new_scheduler,
    )

    # run pipeline in inference (sample random noise and denoise). A freshly seeded,
    # dedicated generator makes every round start from the same noise, so the
    # per-round samples are comparable and do not depend on how much global RNG state
    # training consumed.
    sample_generator = torch.Generator().manual_seed(SAMPLE_SEED)
    images = pipeline(generator=sample_generator, batch_size=training_config.batch_size, output_type="numpy").images

    # Convert float pixels to uint8 for display/saving.
    # NOTE(review): assumes output_type="numpy" yields values in [0, 1] — confirm.
    images_processed = (images * 255).round().astype("uint8")
    distilled_images.append(images_processed[0])


In [None]:
# Display train image for reference, undoing the normalisation first.
# NOTE(review): `* 0.5 + 0.5` assumes the transforms end with Normalize(mean=0.5, std=0.5) — confirm.
# The clamp keeps values in [0, 1] so ToPILImage's float -> uint8 conversion cannot wrap around.
train_image_display = ToPILImage()((train_image * 0.5 + 0.5).clamp(0, 1))
display(train_image_display)

# Show and save the first sample from each distillation round. A distinct loop
# variable keeps the source PIL `image` from the loading cell intact (the dataset
# cell calls `image.convert`, so overwriting it breaks re-runs).
for i, distilled_image in enumerate(distilled_images):
    print(f"Distilled image {i}")
    distilled_pil = Image.fromarray(distilled_image)
    display(distilled_pil)
    distilled_pil.save(f"distilled_{i}.png")

In [None]:
# Show up to two samples from the batch produced by the final distillation round
# (`images_processed` is left over from the last loop iteration). Slicing avoids an
# IndexError when the batch holds fewer than two images.
for final_sample in images_processed[:2]:
    display(Image.fromarray(final_sample))