# An analysis of `clamp_with_grad()` 

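This notebook examines `clamp_with_grad()`, a custom autograd function sometimes used in the CLIP-guided art community. Its forward pass is an ordinary clamp. Its backward pass differs from `torch.clamp`, whose gradient is zero wherever the input lies outside the range: it still lets a gradient through when following that gradient would move the input back toward the range.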

I wanted to build some intuition about how the backward pass is defined and why the function works.



## Setup

Import the usual libraries and widen matplotlib's default figure size so that three plots fit side by side.

In [None]:
import torch
import matplotlib
import matplotlib.pyplot as plt
from math import pi

# Three 8-inch-wide panels side by side (the comparison figure at the end), 8 inches tall.
matplotlib.rcParams.update({'figure.figsize': (24, 8)})

## Define a helper function to access the gradient

In [None]:
def get_gradient(f, xs):
  """Return the elementwise gradient of ``f`` with respect to ``xs``.

  Evaluates ``f`` on a detached copy of ``xs`` and backpropagates a tensor of
  ones, i.e. computes d(sum(f(xs)))/dxs. For an elementwise ``f`` this is the
  pointwise derivative f'(xs).

  Args:
    f: Callable mapping a tensor to a tensor using differentiable torch ops.
    xs: Floating-point tensor of evaluation points. It is not modified.

  Returns:
    A tensor shaped like ``xs`` that does not require grad. If ``f``'s output is
    not connected to its input in the autograd graph, the result is all zeros.
  """
  xs = xs.detach().clone().requires_grad_(True)
  # Force grad mode on so the helper also works when called under torch.no_grad().
  with torch.enable_grad():
    ys = f(xs)
    if not ys.requires_grad:
      # The output has no autograd history at all (e.g. a constant), so its gradient is zero.
      return torch.zeros_like(xs)
    (grad,) = torch.autograd.grad(
        ys, xs, grad_outputs=torch.ones_like(ys), allow_unused=True
    )
  # allow_unused=True yields None when ys has a graph that never touches xs.
  return torch.zeros_like(xs) if grad is None else grad.detach()


# Sanity-check get_gradient against a derivative known analytically:
# by the chain rule, d/dx sin(5x) = 5*cos(5x).
x = torch.arange(-2, 2, 0.01)
sin_5x = lambda t: torch.sin(t * 5)
dsin_5x = get_gradient(sin_5x, x)
assert torch.allclose(dsin_5x, 5 * torch.cos(x * 5), atol=1e-5)

plt.plot(x, sin_5x(x), label="sin(5x)")
plt.plot(x, dsin_5x, label="d sin(5x)/dx")
plt.legend();

## Define `clamp_with_grad()` and a proposed variant, `clamp_with_grad2()`

In [None]:

class ClampWithGrad(torch.autograd.Function):
    """Clamp whose backward pass lets gradients pull out-of-range inputs back in.

    The forward pass is ``input.clamp(min, max)``. ``torch.clamp``'s own gradient is
    zero outside [min, max], which would leave such elements stuck. Here the
    backward pass keeps the upstream gradient for an element whenever
    ``grad * (input - clamp(input)) >= 0``, and zeroes it otherwise:

    * Inside the range the overshoot is 0, so the gradient always passes.
    * Above ``max`` (overshoot > 0), only non-negative gradients pass. These are the
      ones whose descent step ``x -= lr * grad`` moves x down toward the range.
    * Below ``min`` the rule is symmetric: only non-positive gradients pass.
    """

    @staticmethod
    def forward(ctx, input, min, max):
        # The bounds are not differentiated, so they are stored on ctx directly.
        # Only the pre-clamp input goes through save_for_backward.
        ctx.min, ctx.max = min, max
        ctx.save_for_backward(input)
        return input.clamp(min, max)

    @staticmethod
    def backward(ctx, grad_output):
        (input,) = ctx.saved_tensors
        overshoot = input - input.clamp(ctx.min, ctx.max)
        # True where the gradient points back toward (or stays within) the range.
        keep = grad_output * overshoot >= 0
        # One gradient per forward argument: input, min, max.
        return grad_output * keep, None, None

clamp_with_grad = ClampWithGrad.apply

class ClampWithGrad2(torch.autograd.Function):
    """Proposed clamp variant with a fixed-size pull-back gradient.

    The forward pass is ``input.clamp(min, max)``. In the backward pass:

    * Elements inside [min, max] receive the upstream gradient unchanged.
    * Elements outside the range instead receive ``sign(input - clamp(input))``:
      +1 above ``max`` and -1 below ``min``. This ignores the upstream gradient,
      so a gradient-descent step always moves them back toward the range.

    NOTE(review): outside the range the upstream gradient is discarded entirely,
    so the pull-back has unit magnitude regardless of the loss or its scale.
    """

    @staticmethod
    def forward(ctx, input, min, max):
        ctx.min, ctx.max = min, max
        ctx.save_for_backward(input)
        return input.clamp(min, max)

    @staticmethod
    def backward(ctx, grad_output):
        (input,) = ctx.saved_tensors
        overshoot = input - input.clamp(ctx.min, ctx.max)
        outside = overshoot != 0
        grad_input = torch.where(outside, overshoot.sign(), grad_output)
        # The bounds are not differentiated.
        return grad_input, None, None

clamp_with_grad2 = ClampWithGrad2.apply



In [None]:
def plot(fn, ax, title, lo=-1, hi=1, loss_fn=None):
  """Plot clamp(x), loss(x) and dloss/dx for one clamp implementation.

  Args:
    fn: Clamp implementation with the signature ``fn(x, min, max)``, e.g.
      ``torch.clamp`` or ``clamp_with_grad``.
    ax: Matplotlib axes to draw on.
    title: Panel title.
    lo, hi: Clamp bounds passed to ``fn``.
    loss_fn: Loss applied to the clamped value. It defaults to
      ``sin(c * pi) + 1``. For an alternative, pass e.g. ``lambda c: c ** 2``.
  """
  if loss_fn is None:
    # On [-1, 1] the derivative pi*cos(c*pi) changes sign at c = +/-0.5,
    # so gradients point both ways inside the clamp range.
    loss_fn = lambda c: torch.sin(c * pi) + 1

  x = torch.arange(-2, 2, 0.01)
  clamp = lambda t: fn(t, lo, hi)
  loss = lambda t: loss_fn(clamp(t))

  ax.plot(x, clamp(x), label="clamp", alpha=0.5)
  ax.plot(x, loss(x), label="loss", alpha=0.5)
  # Gradient through the clamp: this is where the implementations differ.
  ax.plot(x, get_gradient(loss, x), label="dloss/dx", alpha=0.5)
  ax.set_xlabel("x")
  ax.set_title(title)
  ax.legend()

# One panel per clamp implementation, all driven through the same loss.
panels = [
  (torch.clamp, "clamp"),
  (clamp_with_grad, "clamp_with_grad"),
  (clamp_with_grad2, "clamp_with_grad_proposed"),
]
fig, ax = plt.subplots(ncols=len(panels))
for panel_ax, (clamp_fn, panel_title) in zip(ax, panels):
  plot(clamp_fn, panel_ax, panel_title)
