# Guidance transformation classes to normalize and schedule Classifier-free Guidance.

In [None]:
#|default_exp transforms

The following classes implement:  

1. Baseline Classifier-free Guidance  
2. Scaling the prediction by the ratio of norms `‖u‖ / ‖pred‖`  
3. Scaling the `t - u` update to the norm of `u`, i.e. by `‖u‖ / ‖t - u‖`  
4. Applying both 2. and 3.

In [None]:
#|export

'''Code for blog post:
    https://enzokro.dev/blog/posts/2022-11-15-guidance-expts-1
'''
import math
import torch


In [None]:
#| export

class GuidanceTfm:
    "Baseline Classifier-free Guidance for Diffusion."
    name = "CFGuidance"
    def __init__(self, schedules, *args, **kwargs):
        # `schedules` maps a parameter name (e.g. 'g') to a sequence of values
        # indexed by timestep `idx` (see `scheduler`). Extra positional and
        # keyword arguments are accepted but ignored.
        self.schedules = schedules

    def encode(self, u, t, idx=None):
        """Apply guidance to `u` and `t` with optional pre/post processing.

        Runs `pre_proc`, `guide`, then `post_proc`; the result is stored on
        `self.pred` and also returned.
        """
        self.pre_proc(u, t, idx)
        self.guide(u, t, idx)
        self.post_proc(u, t, idx)
        return self.pred

    def guide(self, u, t, idx=None):
        """Mix `u` and `t` as ``u + g * (t - u)``, with `g` scheduled at `idx`.

        `u` and `t` are presumably the unconditional and text-conditioned
        noise predictions (NOTE(review): inferred from naming — confirm).
        """
        self.pred = u + (self.scheduler('g', idx) * (t - u))

    # Hooks for subclasses; both are no-ops in the baseline transform.
    def pre_proc(self, u, t, idx=None): pass
    def post_proc(self, u, t, idx=None): pass

    def scheduler(self, name, idx):
        """Return the scheduled value for parameter `name` at timestep `idx`.

        Raises:
            KeyError: if `self.schedules` has no entry for `name`.
        """
        schedule = self.schedules.get(name)
        # `.get` yields None for a missing key; fail with a KeyError naming the
        # parameter instead of "'NoneType' object is not subscriptable".
        if schedule is None:
            raise KeyError(f"no schedule provided for parameter {name!r}")
        return schedule[idx]

    def __call__(self, *args, **kwargs):
        "Alias for `encode`."
        return self.encode(*args, **kwargs)
    
    
class BaseNormGuidance(GuidanceTfm):
    "Scales the guided prediction so its overall norm matches the norm of `u`."
    name = "BaseNormGuidance"
    def post_proc(self, u, t, idx=None):
        """Replace `self.pred` with ``pred * (||u|| / ||pred||)``.

        Norms are taken over the whole tensor by `torch.linalg.norm`. A zero
        prediction has no direction to rescale: the ratio ``||u|| / 0`` would
        turn it into NaN, so it is left unchanged instead.
        """
        pred_norm = torch.linalg.norm(self.pred)
        scaled = self.pred * (torch.linalg.norm(u) / pred_norm)
        # Select the unscaled (all-zero) prediction when its norm is zero,
        # so 0 * inf = NaN never reaches the latents.
        self.pred = torch.where(pred_norm > 0, scaled, self.pred)
        
        
class TNormGuidance(GuidanceTfm):
    "Rescales the `t - u` update to the norm of `u` before applying guidance."
    name = "TNormGuidance"
    def guide(self, u, t, idx=None):
        """Set ``self.pred = u + g * (t - u) / ||t - u|| * ||u||``.

        `g` is the scheduled guidance value at `idx`; norms are taken over the
        whole tensor by `torch.linalg.norm`. When ``t == u`` the update is zero
        but ``||t - u|| == 0`` makes the division 0/0 = NaN, so the update is
        taken as zero and the prediction is just `u`.
        """
        diff = t - u
        diff_norm = torch.linalg.norm(diff)
        step = self.scheduler('g', idx) * diff
        # Same operation order as the unguarded formula for nonzero norms.
        update = step / diff_norm * torch.linalg.norm(u)
        update = torch.where(diff_norm > 0, update, torch.zeros_like(update))
        self.pred = u + update
        
        
class FullNormGuidance(TNormGuidance, BaseNormGuidance):
    """Applies both Base and T-Norm scaling to the noise prediction.

    Through the method resolution order, `guide` comes from `TNormGuidance`
    (the `t - u` update rescaled to the norm of `u`) and `post_proc` comes
    from `BaseNormGuidance` (the prediction rescaled to the norm of `u`).
    """
    name = "FullNormGuidance"

In [None]:
#| hide
# Run nbdev's export so the `#|export` cells are written out to the library
# module (this notebook targets `transforms`, per its `#|default_exp`).
import nbdev; nbdev.nbdev_export()