In [7]:
%reload_ext autoreload
%matplotlib inline

import inspect
from dataclasses import dataclass
from typing import Optional

import numpy as np
import torch
from PIL import Image
from diffusers import PNDMScheduler, UNet2DConditionModel, AutoencoderKL
from huggingface_hub import hf_hub_download
from torch import nn
from tqdm import tqdm
from transformers import CLIPTokenizer, CLIPTextModel, CLIPImageProcessor

In [8]:
def load_traced_unet(
        checkpoint: str,
        subfolder: str,
        filename: str,
        dtype: torch.dtype,
        device: str = "cuda",
        local_files_only=False,
        cache_dir: Optional[str] = None,
        in_channels: int = 4,
) -> Optional[nn.Module]:
    """Download a TorchScript-traced UNet from the Hugging Face Hub and wrap it like a diffusers UNet.

    Args:
        checkpoint: Hub repo id to download from.
        subfolder: Folder inside the repo that holds the traced file.
        filename: Name of the traced TorchScript file.
        dtype: Value exposed as ``.dtype`` on the wrapper (the traced graph itself is not cast).
        device: Target device, as a string (e.g. ``"cuda"``) or a ``torch.device``.
        local_files_only: Forwarded to ``hf_hub_download``.
        cache_dir: Forwarded to ``hf_hub_download``.
        in_channels: Value exposed as ``.in_channels`` on the wrapper. The default of 4 is
            the usual Stable Diffusion latent channel count (assumed for this checkpoint).

    Returns:
        A module whose ``forward(latent_model_input, t, encoder_hidden_states)`` returns an
        object with a ``.sample`` tensor. Returns ``None`` when ``device`` is not CUDA, so
        callers can fall back to the regular UNet.
    """
    # torch.jit.load restores modules to the device they were saved from. The traced file is
    # assumed to have been recorded on CUDA, so skip it elsewhere instead of failing to load.
    # torch.device(...) accepts both str and torch.device, which callers pass interchangeably.
    if torch.device(device).type != "cuda":
        print(f"WARNING: Traced UNet only available for CUDA (got {device!r}), skipping")
        return None

    unet_file = hf_hub_download(
        checkpoint,
        subfolder=subfolder,
        filename=filename,
        local_files_only=local_files_only,
        cache_dir=cache_dir,
    )
    unet_traced = torch.jit.load(unet_file)

    class TracedUNet(nn.Module):
        # Mirrors the diffusers output type: callers read ``.sample``.
        @dataclass
        class UNet2DConditionOutput:
            sample: torch.Tensor

        def __init__(self):
            super().__init__()
            # Attribute mirrors of the eager UNet that callers may inspect.
            self.in_channels = in_channels
            self.device = device
            self.dtype = dtype

        def forward(self, latent_model_input, t, encoder_hidden_states):
            # The traced graph returns a tuple; element 0 is the predicted sample.
            sample = unet_traced(latent_model_input, t, encoder_hidden_states)[0]
            return self.UNet2DConditionOutput(sample=sample)

    return TracedUNet()

In [9]:
class FrozenCLIP(nn.Module):
    """CLIP tokenizer, text encoder and image processor loaded from a single checkpoint.

    ``embed_text`` converts a prompt into text-encoder hidden states without tracking gradients.
    """

    def __init__(self, MODEL="riffusion/riffusion-model-v1"):
        super().__init__()
        has_cuda = torch.cuda.is_available()
        self.device = torch.device("cuda" if has_cuda else "cpu")

        encoder = CLIPTextModel.from_pretrained(MODEL, subfolder="text_encoder")
        self.text_encoder = encoder.to(self.device)

        self.tokenizer = CLIPTokenizer.from_pretrained(MODEL, subfolder="tokenizer")
        self.feature_extractor = CLIPImageProcessor.from_pretrained(MODEL, subfolder="feature_extractor")

    @torch.no_grad()
    def embed_text(self, prompt):
        """Encode ``prompt`` into hidden states on ``self.device``.

        The prompt is padded/truncated to ``tokenizer.model_max_length`` tokens. The first
        output of the text encoder is returned, cast to the encoder's dtype.
        """
        encoded = self.tokenizer(
            prompt,
            padding="max_length",
            max_length=self.tokenizer.model_max_length,
            truncation=True,
            return_tensors="pt",
        )
        token_ids = encoded.input_ids.to(self.device)
        hidden_states = self.text_encoder(token_ids)[0]
        return hidden_states.to(dtype=self.text_encoder.dtype, device=self.device)

In [10]:
class SoundStyleTransferModel(nn.Module):
    """Riffusion-style img2img: re-style a spectrogram image towards a blend of two prompts.

    The text encoder, VAE, UNet and PNDM scheduler are all loaded from the ``MODEL``
    checkpoint. ``diffuse`` encodes an input image to VAE latents, noises them up to an
    intermediate timestep, then denoises them conditioned on an interpolation of two prompt
    embeddings, using classifier-free guidance when the guidance scale exceeds 1.
    """

    # Scale applied to VAE latents after encoding and removed again before decoding.
    LATENT_SCALE = 0.18215

    def __init__(self, MODEL="riffusion/riffusion-model-v1"):
        super().__init__()
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # Pass MODEL through so the text encoder matches the VAE/UNet/scheduler checkpoint.
        self.frozen_clip = FrozenCLIP(MODEL)

        self.vae = AutoencoderKL.from_pretrained(MODEL, subfolder="vae").to(self.device)

        # from_pretrained is the supported way to load a scheduler from a repo id; passing a
        # repo id to from_config is a deprecated code path in diffusers.
        self.scheduler = PNDMScheduler.from_pretrained(MODEL, subfolder="scheduler")
        self.scheduler.prk_timesteps = np.array([])

        # Try the traced UNet first. Only load the eager UNet when no traced one is
        # available, so its weights are not moved to the device just to be discarded.
        traced_unet = load_traced_unet(
            MODEL,
            subfolder="unet_traced",
            filename="unet_traced.pt",
            dtype=torch.float32,
            device=self.device
        )

        if traced_unet is not None:
            print("Loaded Traced UNet")
            self.unet = traced_unet
        else:
            self.unet = UNet2DConditionModel.from_pretrained(MODEL, subfolder="unet").to(self.device)

    @torch.no_grad()
    def encode_images(self, images):
        """Encode images in [-1, 1] to scaled VAE latents, sampling from the latent distribution."""
        return self.vae.encode(images).latent_dist.sample() * self.LATENT_SCALE

    @torch.no_grad()
    def decode_latents(self, latents):
        """Undo the latent scaling and decode latents back to image space with the VAE."""
        return self.vae.decode(latents / self.LATENT_SCALE).sample

    def forward(self, latents, text_embeddings, timesteps):
        """Run the UNet and return its predicted noise (the ``.sample`` field)."""
        result = self.unet(
            latents,
            timesteps,
            encoder_hidden_states=text_embeddings
        ).sample

        return result

    def get_text_embeddings(self, alpha, text_prompt_start, text_prompt_end):
        """Linearly interpolate the two prompt embeddings by ``alpha``.

        Returns:
            ``(text_embed, dtype)``, where ``dtype`` is the embeddings' dtype and is reused
            as the latent dtype.
        """
        embed_start = self.frozen_clip.embed_text(text_prompt_start)
        embed_end = self.frozen_clip.embed_text(text_prompt_end)
        text_embed = embed_start * (1.0 - alpha) + embed_end * alpha
        return text_embed, embed_start.dtype

    def original_timestep(self, alpha, denoising_a, denoising_b, inference_steps):
        """Pick the scheduler timestep at which partial diffusion starts.

        The denoising strength is interpolated between ``denoising_a`` and ``denoising_b``
        by ``alpha``. A higher strength starts from a noisier timestep.

        Returns:
            ``(timesteps, init_timestep, offset)``: a 1-element tensor with the start
            timestep, the offset-adjusted step count (capped at ``inference_steps``) used to
            index the schedule, and the scheduler's ``steps_offset``.
        """
        strength = (1 - alpha) * denoising_a + alpha * denoising_b

        offset = self.scheduler.config.get("steps_offset", 0)
        init_timestep = int(inference_steps * strength) + offset
        init_timestep = min(init_timestep, inference_steps)

        # NOTE(review): if init_timestep is 0 (strength ~0 with no offset), [-0] selects
        # timesteps[0], the noisiest step, rather than "no noise".
        timesteps = self.scheduler.timesteps[-init_timestep]
        timesteps = torch.tensor([timesteps], device=self.device)

        return timesteps, init_timestep, offset

    def partial_diffusion(self, latents, alpha, timesteps, dtype):
        """Noise ``latents`` to ``timesteps`` with noise slerped between two Gaussian draws by ``alpha``."""
        noise_a = torch.randn(latents.shape, device=self.device, dtype=dtype)
        noise_b = torch.randn(latents.shape, device=self.device, dtype=dtype)
        noise = self.slerp(alpha, noise_a, noise_b)
        latents = self.scheduler.add_noise(latents, noise, timesteps)
        return latents

    def get_extra_kwargs(self, eta):
        """Return ``{"eta": eta}`` if the scheduler's ``step`` accepts ``eta``, else an empty dict."""
        accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
        extra_step_kwargs = {}
        if accepts_eta:
            extra_step_kwargs["eta"] = eta
        return extra_step_kwargs

    @torch.inference_mode()
    def diffuse(
            self,
            init_image,
            text_prompt_start,
            text_prompt_end,
            inference_steps=50,
            denoising_a=0.75,
            denoising_b=0.75,
            guidance_a=7.0,
            guidance_b=7.0,
            alpha=0.75,
            eta=0.00
    ):
        """Re-style ``init_image`` towards a blend of two prompts and return a PIL image.

        Args:
            init_image: Input PIL image (any mode; it is converted to RGB).
            text_prompt_start, text_prompt_end: Prompts blended by ``alpha``.
            inference_steps: Number of scheduler steps for the full schedule.
            denoising_a, denoising_b: Denoising strengths blended by ``alpha``.
            guidance_a, guidance_b: Classifier-free guidance scales blended by ``alpha``;
                guidance is applied only when the blended scale is above 1.
            alpha: Interpolation weight (0 = start settings, 1 = end settings).
            eta: Forwarded to ``scheduler.step`` if the scheduler accepts it.
        """
        self.unet.eval()
        self.scheduler.set_timesteps(inference_steps)

        guidance_scale = guidance_a * (1.0 - alpha) + guidance_b * alpha
        # A scale of 1 or less would make guidance a no-op (or invert it), so skip it
        # and avoid doubling the UNet batch.
        do_guidance = guidance_scale > 1.0

        text_embed, latents_dtype = self.get_text_embeddings(alpha, text_prompt_start, text_prompt_end)
        if do_guidance:
            # Unconditional (empty prompt) embedding goes first; the UNet output is split
            # in the same order below.
            uncond_embed = self.frozen_clip.embed_text("").expand_as(text_embed)
            text_embed = torch.cat([uncond_embed, text_embed])

        image_torch = self.preprocess_image(init_image).to(device=self.device, dtype=latents_dtype)
        init_latents = self.encode_images(image_torch)

        # Partial diffusion: start from a noised version of the input latents.
        timesteps, init_timestep, offset = self.original_timestep(alpha, denoising_a, denoising_b, inference_steps)
        init_latents = self.partial_diffusion(init_latents, alpha, timesteps, dtype=latents_dtype)

        extra_step_kwargs = self.get_extra_kwargs(eta)
        t_start = max(inference_steps - init_timestep + offset, 0)
        timesteps = self.scheduler.timesteps[t_start:].to(self.device)

        # Mixed precision only on CUDA; on other devices autocast stays disabled.
        use_amp = self.device.type == "cuda"
        latents = init_latents
        for t in tqdm(timesteps, total=len(timesteps)):
            model_input = torch.cat([latents, latents]) if do_guidance else latents
            model_input = self.scheduler.scale_model_input(model_input, t)
            with torch.amp.autocast(self.device.type, enabled=use_amp):
                pred_noise = self.forward(model_input, text_embed, t)
            if do_guidance:
                pred_uncond, pred_text = pred_noise.chunk(2)
                pred_noise = pred_uncond + guidance_scale * (pred_text - pred_uncond)
            latents = self.scheduler.step(pred_noise, t, latents, **extra_step_kwargs).prev_sample

        decoded_image = self.decode_latents(latents)
        # Map from [-1, 1] back to [0, 1] and to channels-last for PIL.
        image = (decoded_image / 2 + 0.5).clamp(0, 1).cpu().permute(0, 2, 3, 1).squeeze().numpy()
        image = self.numpy_to_pil(image)[0]

        return image

    @staticmethod
    def preprocess_image(image: "Image.Image") -> torch.Tensor:
        """Convert a PIL image to a ``(1, 3, H, W)`` float tensor in ``[-1, 1]``.

        The image is converted to RGB, so grayscale or RGBA PNGs also work. Width and height
        are rounded down to multiples of 32 with a Lanczos resize.
        """
        image = image.convert("RGB")
        w, h = image.size
        w, h = map(lambda x: x - x % 32, (w, h))  # resize to integer multiple of 32
        image = image.resize((w, h), resample=Image.LANCZOS)

        image_np = np.array(image).astype(np.float32) / 255.0
        image_np = image_np[None].transpose(0, 3, 1, 2)

        image_torch = torch.from_numpy(image_np)

        return 2.0 * image_torch - 1.0

    @staticmethod
    def slerp(t: float, v0: torch.Tensor, v1: torch.Tensor, dot_threshold: float = 0.9995) -> torch.Tensor:
        """Spherical linear interpolation between ``v0`` and ``v1`` at fraction ``t``.

        When the inputs are nearly parallel or anti-parallel (``|cos| > dot_threshold``), it
        falls back to linear interpolation. Accepts numpy arrays or torch tensors. Torch
        inputs return a tensor on ``v0``'s device and with ``v0``'s dtype.
        """
        # Anything that is not a numpy array is treated as a torch tensor.
        inputs_are_torch = not isinstance(v0, np.ndarray)
        if inputs_are_torch:
            input_device, input_dtype = v0.device, v0.dtype
            v0 = v0.cpu().numpy()
            v1 = v1.cpu().numpy()

        dot = np.sum(v0 * v1 / (np.linalg.norm(v0) * np.linalg.norm(v1)))
        if np.abs(dot) > dot_threshold:
            v2 = (1 - t) * v0 + t * v1
        else:
            theta_0 = np.arccos(dot)
            sin_theta_0 = np.sin(theta_0)
            theta_t = theta_0 * t
            sin_theta_t = np.sin(theta_t)
            s0 = np.sin(theta_0 - theta_t) / sin_theta_0
            s1 = sin_theta_t / sin_theta_0
            v2 = s0 * v0 + s1 * v1

        if inputs_are_torch:
            # Cast back explicitly: NumPy scalar promotion can widen the result (e.g. to float64).
            v2 = torch.from_numpy(np.asarray(v2)).to(device=input_device, dtype=input_dtype)

        return v2

    @staticmethod
    def numpy_to_pil(images):
        """Convert a float array (assumed in [0, 1]), HWC or NHWC, to a list of PIL images."""
        if images.ndim == 3:
            images = images[None, ...]
        images = (images * 255).round().astype("uint8")
        if images.shape[-1] == 1:
            pil_images = [Image.fromarray(image.squeeze(), mode="L") for image in images]
        else:
            pil_images = [Image.fromarray(image) for image in images]

        return pil_images

In [11]:
# Build the pipeline from the Riffusion checkpoint (the constructor's default, spelled out).
model = SoundStyleTransferModel(MODEL="riffusion/riffusion-model-v1")

An error occurred while trying to fetch riffusion/riffusion-model-v1: riffusion/riffusion-model-v1 does not appear to have a file named diffusion_pytorch_model.safetensors.
Defaulting to unsafe serialization. Pass `allow_pickle=False` to raise an error instead.
An error occurred while trying to fetch riffusion/riffusion-model-v1: riffusion/riffusion-model-v1 does not appear to have a file named diffusion_pytorch_model.safetensors.
Defaulting to unsafe serialization. Pass `allow_pickle=False` to raise an error instead.


Loaded Traced UNet


In [12]:
# Seed torch's global RNG so runs are reproducible. partial_diffusion draws its noise with
# torch.randn, and the VAE latent sampling presumably also uses the global RNG.
SEED = 42
torch.manual_seed(SEED)

# Convert to RGB so grayscale or RGBA PNGs yield the 3 channels the VAE expects.
init_image = Image.open("in.png").convert("RGB")
prompt_start = "Church bells on sunday"
prompt_end = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8K"

# `image` stays bound to the output, as before, for any later cells that use it.
image = model.diffuse(init_image, prompt_start, prompt_end)

image.save("out.png")

100%|██████████| 38/38 [00:07<00:00,  5.28it/s]
