# Core 

> Minimal pipeline for Diffusion Guidance experiments.

In [None]:
#| default_exp core

In [None]:
#| export
# imports for diffusion models
from abc import ABC
import importlib
from PIL import Image
import torch
from tqdm.auto    import tqdm
from transformers import CLIPTextModel, CLIPTokenizer
from diffusers    import AutoencoderKL, UNet2DConditionModel
from diffusers    import LMSDiscreteScheduler, EulerDiscreteScheduler, DPMSolverMultistepScheduler, EulerAncestralDiscreteScheduler
import torch
from torch import nn
try:
    from k_diffusion.external import CompVisDenoiser, CompVisVDenoiser
    from k_diffusion.sampling import get_sigmas_karras
    import k_diffusion.sampling as k_sampling
except:
    print(f'WARNING: Could not import k_diffusion')
from min_diffusion.kdiff import *

In [None]:
#| export

    
class MinimalDiffusion:
    """Loads a Stable Diffusion pipeline.
    
    The goal is to have more control of the image generation loop. 
    This class loads the following individual pieces:
        - Tokenizer
        - Text encoder
        - VAE
        - U-Net
        - Sampler
        
    The `self.generate` function uses these pieces to run a Diffusion image generation loop.
    
    This class can be subclassed and any of its methods overridden to gain even more control over the Diffusion pipeline.
    """
    def __init__(self, model_name, device, dtype, revision, generator=None, use_k_diffusion='',
                 better_vae='', unet_attn_slice=True, scheduler_kls='euler'):
        """Stores the pipeline configuration; call `load` to actually build the pieces.

        Args:
            model_name: model id or local path passed to the `from_pretrained` calls.
            device: device the text encoder, VAE and U-Net are moved to by `to_device`.
            dtype: torch dtype used to load the weights and to create the initial latents.
            revision: model revision; only used when loading a `better_vae`.
            generator: optional `torch.Generator` used to draw the initial latents.
            use_k_diffusion: key into `SAMPLER_LOOKUP`; an empty string uses the scheduler loop.
            better_vae: '' (use the model's own VAE), 'ema' or 'mse'.
            unet_attn_slice: whether to enable U-Net attention slicing.
            scheduler_kls: 'euler', 'euler_ancestral' or 'dpm_multi'; anything else selects LMS.
        """
        self.model_name = model_name
        self.device = device
        self.dtype = dtype
        self.revision = revision
        # keep the caller's generator so the initial latents can be seeded
        self.generator = generator
        self.use_k_diffusion = use_k_diffusion
        self.better_vae = better_vae
        self.unet_attn_slice = unet_attn_slice
        self.scheduler_kls = scheduler_kls


    def load(self):
        """Loads the individual pieces of the Diffusion pipeline and moves them to `self.device`.

        The pieces are stored as attributes (`tokenizer`, `text_encoder`, `vae`, `unet`,
        `scheduler`); nothing is returned.
        """
        # load the pieces
        self.load_text_pieces()
        self.load_vae()
        self.load_unet()
        self.load_scheduler()
        # put them on the device
        self.to_device()

    def load_text_pieces(self):
        """Creates the CLIP tokenizer and text encoder from the model's
        `tokenizer` and `text_encoder` subfolders.
        """
        self.tokenizer = CLIPTokenizer.from_pretrained(
            self.model_name,
            subfolder="tokenizer",
            torch_dtype=self.dtype)
        self.text_encoder = CLIPTextModel.from_pretrained(
            self.model_name,
            subfolder="text_encoder",
            torch_dtype=self.dtype)


    def load_vae(self):
        """Loads the Variational Auto-Encoder into `self.vae`.
        
        Optionally loads an improved `better_vae` from the stability.ai team.
            It can be either the `ema` or `mse` VAE.
        """
        # optionally use a VAE from stability that was trained for longer
        if self.better_vae:
            assert self.better_vae in ('ema', 'mse'), \
                f"better_vae must be 'ema' or 'mse', got {self.better_vae!r}"
            print(f'Using the improved VAE "{self.better_vae}" from stability.ai')
            vae = AutoencoderKL.from_pretrained(
                f"stabilityai/sd-vae-ft-{self.better_vae}",
                revision=self.revision,
                torch_dtype=self.dtype)
        else:
            vae = AutoencoderKL.from_pretrained(self.model_name, subfolder='vae',
                                                torch_dtype=self.dtype)
        self.vae = vae


    def load_unet(self):
        """Loads the U-Net into `self.unet`.
        
        Optionally uses attention slicing to fit on smaller GPU cards.
        """
        unet = UNet2DConditionModel.from_pretrained(
            self.model_name,
            subfolder="unet",
            #revision=self.revision,
            torch_dtype=self.dtype)
        # optionally enable unet attention slicing
        if self.unet_attn_slice:
            print('Enabling default unet attention slicing.')
            if isinstance(unet.config.attention_head_dim, int):
                # half the attention head size is usually a good trade-off between
                # speed and memory
                slice_size = unet.config.attention_head_dim // 2
            else:
                # if `attention_head_dim` is a list, take the smallest head size
                slice_size = min(unet.config.attention_head_dim)
            unet.set_attention_slice(slice_size)
        self.unet = unet


    def load_scheduler(self):
        """Loads the scheduler selected by `self.scheduler_kls` into `self.scheduler`.

        'euler', 'euler_ancestral' and 'dpm_multi' pick the matching diffusers scheduler;
        any other value falls back to `LMSDiscreteScheduler`.
        """
        if self.scheduler_kls == 'euler':
            sched_kls = EulerDiscreteScheduler
        elif self.scheduler_kls == 'euler_ancestral':
            sched_kls = EulerAncestralDiscreteScheduler
        elif self.scheduler_kls == 'dpm_multi':
            sched_kls = DPMSolverMultistepScheduler
        else:
            sched_kls = LMSDiscreteScheduler
        print(f'Using scheduler: {sched_kls}')
        # request v-prediction for SD-2. Compare only the last path component so both the hub id
        # "stabilityai/stable-diffusion-2" and a local ".../stable-diffusion-2" directory match.
        if self.model_name.rstrip('/').split('/')[-1] == 'stable-diffusion-2':
            # diffusers spells this value with an underscore
            sched_kwargs = {'prediction_type': 'v_prediction'}
        else:
            sched_kwargs = {}
        self.scheduler = sched_kls.from_pretrained(self.model_name,
                                                   subfolder="scheduler",
                                                   **sched_kwargs)


    def generate(
        self,
        prompt,
        guide_tfm=None,
        width=512,
        height=512,
        steps=50,
        use_karras_sigmas=False,
        **kwargs
    ):
        """Main image generation loop.

        Args:
            prompt: text prompt to condition on.
            guide_tfm: callable `(uncond, cond, idx) -> pred` that merges the unconditional and
                conditional predictions at step `idx`. Defaults to static Classifier-free
                Guidance with a scale of 7.5.
            width, height: image size in pixels. Only used when `self.init_latents` is not set
                yet; afterwards the shared initial latents are reused as-is.
            steps: number of diffusion steps.
            use_karras_sigmas: forwarded to the k-diffusion sampler; unused otherwise.
            **kwargs: `negative_prompt` (str) replaces the empty unconditional prompt.

        Returns:
            The generated PIL Image.
        """
        # if no guidance transform was given, use the default update
        if guide_tfm is None:
            print('NOTE: Using the default, static Classifier-free Guidance.')
            G = 7.5
            def guide_tfm(uncond, cond, idx): return uncond + G * (cond - uncond)
        self.guide_tfm = guide_tfm

        # prepare the text embeddings
        text = self.encode_text(prompt)
        neg_prompt = kwargs.get('negative_prompt', '')
        if neg_prompt:
            print(f'Using negative prompt: {neg_prompt}')
        uncond = self.encode_text(neg_prompt)

        # start from the shared, initial latents
        if getattr(self, 'init_latents', None) is None:
            self.init_latents = self.get_initial_latents(height, width)
        latents = self.init_latents.clone().to(self.unet.device)

        # set the number of timesteps
        # TODO: get alphas_cumprod from the k_diffusion schedules
        self.scheduler.set_timesteps(steps, device=self.unet.device)

        if self.use_k_diffusion:
            print(f'NOTE: Generating with k-diffusion Samplers')

            # load the sampler class
            SamplerCls = SAMPLER_LOOKUP[self.use_k_diffusion]
            model = ModelWrapper(self.unet, self.scheduler.alphas_cumprod)
            sampler = SamplerCls(model, self.model_name)

            # move wrapped sigmas and log-sigmas to device
            sampler.cv_denoiser.sigmas = sampler.cv_denoiser.sigmas.to(latents.device)
            sampler.cv_denoiser.log_sigmas = sampler.cv_denoiser.log_sigmas.to(latents.device)

            # sample with k_diffusion, using the prompt as positive and the
            # (possibly negative) unconditional prompt as neutral conditioning
            latents = sampler.sample(
                num_steps=steps,
                initial_latent=latents,
                positive_conditioning=text,
                neutral_conditioning=uncond,
                t_start=None,#t_enc,
                mask=None,#mask,
                orig_latent=None,#init_latent,
                shape=latents.shape,
                batch_size=1,
                guide_tfm=guide_tfm,
                use_karras_sigmas=use_karras_sigmas,
            )

        else:
            # prepare the conditional and unconditional inputs (unconditional first)
            text_emb = torch.cat([uncond, text]).type(self.unet.dtype)
            # scale the latents
            latents = latents * self.scheduler.init_noise_sigma
            # run the diffusion process
            for i, ts in enumerate(tqdm(self.scheduler.timesteps)):
                latents = self.diffuse_step(latents, text_emb, ts, i)

        # decode the final latents and return the generated image
        return self.image_from_latents(latents)


    def diffuse_step(self, latents, text_emb, ts, idx):
        """Runs a single diffusion step.

        Args:
            latents: current latents (a single copy; they are duplicated internally).
            text_emb: unconditional and conditional embeddings, concatenated in that order.
            ts: current scheduler timestep.
            idx: index of the step, forwarded to `self.guide_tfm`.

        Returns:
            The latents produced by `self.scheduler.step`.
        """
        # duplicate the latents so the unconditional and conditional passes share one U-Net call
        inp = self.scheduler.scale_model_input(torch.cat([latents] * 2), ts)
        with torch.no_grad():
            tf = ts
            # cast timesteps to float32 when MPS is available
            # (presumably the MPS backend rejects the default timestep dtype)
            if torch.backends.mps.is_available():
                tf = ts.type(torch.float32)
            preds = self.unet(inp, tf, encoder_hidden_states=text_emb)
            u, t = preds.sample.chunk(2)

        # run classifier-free guidance
        pred = self.guide_tfm(u, t, idx)

        # update and return the latents
        return self.scheduler.step(pred, ts, latents).prev_sample


    def encode_text(self, prompts, maxlen=None):
        """Extracts text embeddings from the given `prompts`.

        Args:
            prompts: a string or list of strings.
            maxlen: padding/truncation length; defaults to `self.tokenizer.model_max_length`.

        Returns:
            The first output of the text encoder (its hidden states), computed without gradients.
        """
        maxlen = maxlen or self.tokenizer.model_max_length
        inp = self.tokenizer(prompts, padding="max_length", max_length=maxlen,
                             truncation=True, return_tensors="pt")
        # send the ids to wherever the encoder lives; `to_device` may have moved it off `self.device`
        inp_ids = inp.input_ids.to(self.text_encoder.device)
        # inference only: avoid building an autograd graph through the text encoder
        with torch.no_grad():
            return self.text_encoder(inp_ids)[0]


    def to_device(self, device=None):
        """Places the pipeline pieces on the given device (default: `self.device`).
        
        Note: assumes we keep Scheduler and Tokenizer on the cpu.
        """
        device = device or self.device
        for m in (self.text_encoder, self.vae, self.unet):
            m.to(device)


    def set_initial_latents(self, latents):
        """Sets the given `latents` as the initial noise latents.
        """
        self.init_latents = latents


    def get_initial_latents(self, height, width):
        """Returns an initial set of Gaussian-noise latents.

        Args:
            height, width: image size in pixels; the latents are 8x smaller in each dimension.

        Returns:
            Tensor of shape (1, unet in_channels, height // 8, width // 8) with `self.dtype`,
            drawn with `self.generator` when one was given.
        """
        shape = (1, self.unet.config.in_channels, height // 8, width // 8)
        # torch.randn requires the generator and the output to share a device
        device = self.generator.device if self.generator is not None else None
        return torch.randn(shape, dtype=self.dtype, generator=self.generator, device=device)


    def image_from_latents(self, latents):
        """Scales diffusion `latents` and turns the first decoded image into a PIL Image.
        """
        # undo the VAE latent scaling (0.18215 for the SD VAEs when the config does not say)
        scaling_factor = getattr(self.vae.config, 'scaling_factor', 0.18215)
        latents = latents / scaling_factor
        with torch.no_grad():
            data = self.vae.decode(latents.type(self.vae.dtype)).sample[0]
        # map from [-1, 1] to [0, 255] and create the PIL image
        data = (data / 2 + 0.5).clamp(0, 1)
        data = data.cpu().permute(1, 2, 0).float().numpy()
        data = (data * 255).round().astype("uint8")
        return Image.fromarray(data)

In [None]:
# #| export

# def get_device() -> str:
#     """Return the best torch backend available"""
#     if torch.cuda.is_available():
#         return "cuda"

#     if torch.backends.mps.is_available():
#         return "mps:0"

#     return "cpu"

# def ensure_4_dim(t: torch.Tensor):
#     if len(t.shape) == 3:
#         t = t.unsqueeze(dim=0)
#     return t

# class SamplerName:
#     PLMS = "plms"
#     DDIM = "ddim"
#     K_DPM_FAST = "k_dpm_fast"
#     K_DPM_ADAPTIVE = "k_dpm_adaptive"
#     K_LMS = "k_lms"
#     K_DPM_2 = "k_dpm_2"
#     K_DPM_2_ANCESTRAL = "k_dpm_2_a"
#     K_DPMPP_2M = "k_dpmpp_2m"
#     K_DPMPP_2S_ANCESTRAL = "k_dpmpp_2s_a"
#     K_EULER = "k_euler"
#     K_EULER_ANCESTRAL = "k_euler_a"
#     K_HEUN = "k_heun"
#     K_DPMPP_SDE = 'k_dpmpp_sde'


# class ImageSampler(ABC):
#     short_name: str
#     name: str
#     default_steps: int
#     default_size: int

#     def __init__(self, model):
#         self.model = model
#         self.device = get_device()
        
        
# def get_noise_prediction(
#     denoise_func,
#     noisy_latent,
#     time_encoding,
#     neutral_conditioning,
#     positive_conditioning,
#     signal_amplification=7.5,
# ):
#     noisy_latent = ensure_4_dim(noisy_latent)

#     noisy_latent_in = torch.cat([noisy_latent] * 2)
#     time_encoding_in = torch.cat([time_encoding] * 2)
#     if isinstance(positive_conditioning, dict):
#         assert isinstance(neutral_conditioning, dict)
#         conditioning_in = {}
#         for k in positive_conditioning:
#             if isinstance(positive_conditioning[k], list):
#                 conditioning_in[k] = [
#                     torch.cat([neutral_conditioning[k][i], positive_conditioning[k][i]])
#                     for i in range(len(positive_conditioning[k]))
#                 ]
#             else:
#                 conditioning_in[k] = torch.cat(
#                     [neutral_conditioning[k], positive_conditioning[k]]
#                 )
#     else:
#         conditioning_in = torch.cat([neutral_conditioning, positive_conditioning])

#     # the k-diffusion samplers actually return the denoised predicted latents but things seem
#     # to work anyway
#     noise_pred_neutral, noise_pred_positive = denoise_func(
#         noisy_latent_in, time_encoding_in, conditioning_in
#     ).chunk(2)

#     amplified_noise_pred = signal_amplification * (
#         noise_pred_positive - noise_pred_neutral
#     )
#     noise_pred = noise_pred_neutral + amplified_noise_pred

#     return noise_pred


# def mask_blend(noisy_latent, orig_latent, mask, mask_noise, ts, model):
#     """
#     Apply a mask to the noisy_latent.
#     ts is a decreasing value between 1000 and 1
#     """
#     assert orig_latent is not None
#     log_latent(orig_latent, "orig_latent")
#     noised_orig_latent = model.q_sample(orig_latent, ts, mask_noise)

#     # this helps prevent the weird disjointed images that can happen with masking
#     hint_strength = 1
#     # if we're in the first 10% of the steps then don't fully noise the parts
#     # of the image we're not changing so that the algorithm can learn from the context
#     if ts > 1000:
#         hinted_orig_latent = (
#             noised_orig_latent * (1 - hint_strength) + orig_latent * hint_strength
#         )
#         log_latent(hinted_orig_latent, f"hinted_orig_latent {ts}")
#     else:
#         hinted_orig_latent = noised_orig_latent

#     hinted_orig_latent_masked = hinted_orig_latent * mask
#     log_latent(hinted_orig_latent_masked, f"hinted_orig_latent_masked {ts}")
#     noisy_latent_masked = (1.0 - mask) * noisy_latent
#     log_latent(noisy_latent_masked, f"noisy_latent_masked {ts}")
#     noisy_latent = hinted_orig_latent_masked + noisy_latent_masked
#     log_latent(noisy_latent, f"mask-blended noisy_latent {ts}")
#     return noisy_latent


# #| export

# class StandardCompVisDenoiser(CompVisDenoiser):
#     def apply_model(self, *args, **kwargs):
#         return self.inner_model.apply_model(*args, **kwargs)


# class StandardCompVisVDenoiser(CompVisVDenoiser):
#     def apply_model(self, *args, **kwargs):
#         return self.inner_model.apply_model(*args, **kwargs)

# #| export

# def sample_dpm_adaptive(
#     model, x, sigmas, extra_args=None, disable=False, callback=None
# ):
#     sigma_min = sigmas[-2]
#     sigma_max = sigmas[0]
#     return k_sampling.sample_dpm_adaptive(
#         model=model,
#         x=x,
#         sigma_min=sigma_min,
#         sigma_max=sigma_max,
#         extra_args=extra_args,
#         disable=disable,
#         callback=callback,
#     )


# def sample_dpm_fast(model, x, sigmas, extra_args=None, disable=False, callback=None):
#     sigma_min = sigmas[-2]
#     sigma_max = sigmas[0]
#     return k_sampling.sample_dpm_fast(
#         model=model,
#         x=x,
#         sigma_min=sigma_min,
#         sigma_max=sigma_max,
#         n=len(sigmas),
#         extra_args=extra_args,
#         disable=disable,
#         callback=callback,
#     )


# class KDiffusionSampler(ImageSampler, ABC):
#     sampler_func: callable

#     def __init__(self, model):
#         super().__init__(model)
#         denoiseer_cls = (
#             StandardCompVisVDenoiser
#             # if model.parameterization == "v"
#             # else StandardCompVisDenoiser
#         )
#         self.cv_denoiser = denoiseer_cls(model)

#     def sample(
#         self,
#         num_steps,
#         shape,
#         neutral_conditioning,
#         positive_conditioning,
#         guidance_scale,
#         batch_size=1,
#         mask=None,
#         orig_latent=None,
#         initial_latent=None,
#         t_start=None,
#     ):
#         # if positive_conditioning.shape[0] != batch_size:
#         #     raise ValueError(
#         #         f"Got {positive_conditioning.shape[0]} conditionings but batch-size is {batch_size}"
#         #     )

#         if initial_latent is None:
#             initial_latent = torch.randn(shape, device="cpu").to(self.device)

#         #log_latent(initial_latent, "initial_latent")
#         if t_start is not None:
#             t_start = num_steps - t_start + 1

#         sigmas = self.cv_denoiser.get_sigmas(num_steps)[t_start:]

#         # if our number of steps is zero, just return the initial latent
#         if sigmas.nelement() == 0:
#             if orig_latent is not None:
#                 return orig_latent
#             return initial_latent

#         x = initial_latent * sigmas[0]
#         #log_latent(x, "initial_sigma_noised_tensor")
#         model_wrap_cfg = CFGDenoiser(self.cv_denoiser)

#         mask_noise = None
#         if mask is not None:
#             mask_noise = torch.randn_like(initial_latent, device="cpu").to(
#                 initial_latent.device
#             )

#         samples = self.sampler_func(
#             model=model_wrap_cfg,
#             x=x,
#             sigmas=sigmas,
#             extra_args={
#                 "cond": positive_conditioning,
#                 "uncond": neutral_conditioning,
#                 "cond_scale": guidance_scale,
#                 "mask": mask,
#                 "mask_noise": mask_noise,
#                 "orig_latent": orig_latent,
#             },
#             disable=False,
#             #callback=callback,
#         )

#         return samples


# class DPMFastSampler(KDiffusionSampler):
#     short_name = SamplerName.K_DPM_FAST
#     name = "Diffusion probabilistic models - fast"
#     default_steps = 15
#     sampler_func = staticmethod(sample_dpm_fast)


# class DPMAdaptiveSampler(KDiffusionSampler):
#     short_name = SamplerName.K_DPM_ADAPTIVE
#     name = "Diffusion probabilistic models - adaptive"
#     default_steps = 40
#     sampler_func = staticmethod(sample_dpm_adaptive)


# class DPM2Sampler(KDiffusionSampler):
#     short_name = SamplerName.K_DPM_2
#     name = "Diffusion probabilistic models - 2"
#     default_steps = 40
#     sampler_func = staticmethod(k_sampling.sample_dpm_2)


# class DPM2AncestralSampler(KDiffusionSampler):
#     short_name = SamplerName.K_DPM_2_ANCESTRAL
#     name = "Diffusion probabilistic models - 2 ancestral"
#     default_steps = 40
#     sampler_func = staticmethod(k_sampling.sample_dpm_2_ancestral)


# class DPMPP2MSampler(KDiffusionSampler):
#     short_name = SamplerName.K_DPMPP_2M
#     name = "Diffusion probabilistic models - 2m"
#     default_steps = 15
#     sampler_func = staticmethod(k_sampling.sample_dpmpp_2m)
    
    
# class DPMPPSDESampler(KDiffusionSampler):
#     short_name = SamplerName.K_DPMPP_SDE
#     name = "Diffusion probabilistic models - 2m"
#     default_steps = 30
#     sampler_func = staticmethod(k_sampling.sample_dpmpp_sde)


# class DPMPP2SAncestralSampler(KDiffusionSampler):
#     short_name = SamplerName.K_DPMPP_2S_ANCESTRAL
#     name = "Ancestral sampling with DPM-Solver++(2S) second-order steps."
#     default_steps = 15
#     sampler_func = staticmethod(k_sampling.sample_dpmpp_2s_ancestral)


# class EulerSampler(KDiffusionSampler):
#     short_name = SamplerName.K_EULER
#     name = "Algorithm 2 (Euler steps) from Karras et al. (2022)"
#     default_steps = 40
#     sampler_func = staticmethod(k_sampling.sample_euler)


# class EulerAncestralSampler(KDiffusionSampler):
#     short_name = SamplerName.K_EULER_ANCESTRAL
#     name = "Euler ancestral"
#     default_steps = 40
#     sampler_func = staticmethod(k_sampling.sample_euler_ancestral)


# class HeunSampler(KDiffusionSampler):
#     short_name = SamplerName.K_HEUN
#     name = "Algorithm 2 (Heun steps) from Karras et al. (2022)."
#     default_steps = 40
#     sampler_func = staticmethod(k_sampling.sample_heun)


# class LMSSampler(KDiffusionSampler):
#     short_name = SamplerName.K_LMS
#     name = "LMS"
#     default_steps = 40
#     sampler_func = staticmethod(k_sampling.sample_lms)


# class CFGDenoiser(nn.Module):
#     """
#     Conditional forward guidance wrapper
#     """

#     def __init__(self, model):
#         super().__init__()
#         self.inner_model = model
#         self.device = get_device()

#     def forward(
#         self,
#         x,
#         sigma,
#         uncond,
#         cond,
#         cond_scale,
#         mask=None,
#         mask_noise=None,
#         orig_latent=None,
#     ):
#         def _wrapper(noisy_latent_in, time_encoding_in, conditioning_in):
#             return self.inner_model(
#                 noisy_latent_in, time_encoding_in, cond=conditioning_in
#             )

#         if mask is not None:
#             assert orig_latent is not None
#             t = self.inner_model.sigma_to_t(sigma, quantize=True).to(self.device)
#             big_sigma = max(sigma, 1)
#             x = mask_blend(
#                 noisy_latent=x,
#                 orig_latent=orig_latent * big_sigma,
#                 mask=mask,
#                 mask_noise=mask_noise * big_sigma,
#                 ts=t,
#                 model=self.inner_model.inner_model,
#             )

#         noise_pred = get_noise_prediction(
#             denoise_func=_wrapper,
#             noisy_latent=x,
#             time_encoding=sigma,
#             neutral_conditioning=uncond,
#             positive_conditioning=cond,
#             signal_amplification=cond_scale,
#         )

#         return noise_pred
    
    
# SAMPLERS = [
#     # PLMSSampler,
#     # DDIMSampler,
#     DPMFastSampler,
#     DPMAdaptiveSampler,
#     LMSSampler,
#     DPM2Sampler,
#     DPM2AncestralSampler,
#     DPMPP2MSampler,
#     DPMPPSDESampler,
#     DPMPP2SAncestralSampler,
#     EulerSampler,
#     EulerAncestralSampler,
#     HeunSampler,
# ]

# SAMPLER_LOOKUP = {sampler.short_name: sampler for sampler in SAMPLERS}

In [None]:
#| hide
# Export this notebook's `#| export` cells to the library module (see `#| default_exp` above).
import nbdev; nbdev.nbdev_export()