# Optimisation Examples

This notebook demonstrates simple use of gradient descent to optimise the parameters corresponding to a point, line segment or curve, using an MSE loss between a target image and the image rendered from the parameters. The aim is to visually illustrate how the different geometric constructs get optimised (and to verify that the rasteriser is working correctly!).

The following code block just defines a function which performs plain gradient descent with a fixed learning rate (although it does optionally anneal the sigma value of the object(s) being drawn over time). In addition to computing the optimised parameters, an animation object is created which shows how the optimisation process progresses over time.

We also use this block to define the coordinate grid used for mapping pixel space to world space.

In [1]:
try: 
    from celluloid import Camera
except:
    !pip install celluloid
    from celluloid import Camera

try:
    from dsketch.raster.disttrans import point_edt2, line_edt2, curve_edt2_bruteforce
    from dsketch.raster.raster import exp
    from dsketch.raster.composite import softor
except:
    !pip install git+https://github.com/jonhare/DifferentiableSketching.git
    from dsketch.raster.disttrans import point_edt2, line_edt2, curve_edt2_bruteforce
    from dsketch.raster.raster import exp
    from dsketch.raster.composite import softor
    
import torch
from torch.nn.functional import mse_loss
import matplotlib.pyplot as plt
from IPython.display import HTML


def optimise(target, params, render_fn, lr=0.5, steps=200, sigma_init=0.1, sigma_min=1e-2, sigma_factor=2, sigma_step=50):
    """Fit ``params`` to ``target`` using plain, fixed-rate gradient descent.

    On every step the current parameters are rendered with ``render_fn``,
    compared to ``target`` with an MSE loss, and moved against the gradient.
    Every 10th step the target and the current estimate are captured as one
    frame of an animation.

    Args:
        target: Image to match; it is passed to ``mse_loss`` together with
            the output of ``render_fn``.
        params: Tensor of parameters to optimise. It is switched to
            ``requires_grad=True`` and updated in place, so the caller's
            tensor is modified.
        render_fn: Callable invoked as ``render_fn(params, sigma)`` that
            returns the rendered estimate.
        lr: Fixed learning rate.
        steps: Number of gradient descent steps.
        sigma_init: Sigma used for the first render.
        sigma_min: Lower bound for the annealed sigma.
        sigma_factor: Sigma is divided by this value at each annealing event.
        sigma_step: Annealing happens on every step whose index is a multiple
            of this value. That includes step 0, so ``sigma_init`` is only
            used for the very first render.

    Returns:
        A ``(params, anim)`` tuple: the optimised parameter tensor and the
        animation produced by ``camera.animate()``.
    """
    fig = plt.figure(figsize=(8, 4))
    camera = Camera(fig)
    sp1 = fig.add_subplot(1, 2, 1)
    sp1.set_title("Target")
    sp2 = fig.add_subplot(1, 2, 2)
    sp2.set_title("Estimate")

    params.requires_grad_(True)
    sigma = sigma_init

    for i in range(steps):
        params.grad = None
        est = render_fn(params, sigma)
        loss = mse_loss(est, target)
        loss.backward()

        # Apply the descent step outside of autograd: this avoids going
        # through ``.data`` and does not record a graph for the update itself.
        with torch.no_grad():
            params.sub_(params.grad * lr)

        if i % 10 == 0:
            # ``est`` is the render from *before* this step's update.
            sp1.imshow(target)
            sp2.imshow(est.detach())
            camera.snap()

        if i % sigma_step == 0 and sigma > sigma_min:
            sigma = max(sigma / sigma_factor, sigma_min)

    # create the animation
    anim = camera.animate()
    # Close this specific figure (not just whichever one is current) so the
    # static figure is not shown in addition to the animation.
    plt.close(fig)

    return params, anim

# Coordinate grid used to map pixel space to world space: grid[i, j] holds the
# pair (r[i], c[j]), with both axes sampled at ``sz`` evenly spaced values
# between -1 and 1.
sz = 100
r = torch.linspace(-1, 1, sz)
c = torch.linspace(-1, 1, sz)
# Pass ``indexing`` explicitly: "ij" is the layout the bare call produces, but
# leaving it out triggers a warning on recent PyTorch releases.
grid = torch.stack(torch.meshgrid(r, c, indexing="ij"), dim=2)

## Optimising a point

We first experiment with optimising the coordinates of a point:

In [2]:
def render_point(params, sigma):
    """Render the single point described by ``params`` onto the shared ``grid``.

    ``params`` is given a leading dimension of size 1 before ``point_edt2`` is
    called, ``sigma`` is forwarded unchanged to ``exp``, and dimension 0 is
    squeezed from the result again.
    """
    batched = params.unsqueeze(0)
    edt2 = point_edt2(batched, grid)
    raster = exp(edt2, sigma)
    return raster.squeeze(0)

# Target: a single point at (0, 0), rendered with sigma=1e-2.
target_params = torch.tensor([0.0,0.0])
target_image = render_point(target_params, 1e-2)

# Start the estimate at (0.3, 0.3) and optimise it with the default settings.
params = torch.tensor([0.3,0.3])
params, anim = optimise(target_image, params, render_point)

print(f"Target params:    {target_params}")
print(f"Optimised params: {params}")

# Last expression of the cell, so the animation is displayed as the output.
HTML(anim.to_jshtml())

Target params:    tensor([0., 0.])
Optimised params: tensor([8.9389e-10, 9.1189e-10], requires_grad=True)


One problem with such an optimisation is that it is very sensitive to the initial conditions; a perfectly valid approach in terms of reducing the loss is to move the point being optimised away from the target and outside of the image bounds. This obviously doesn't give a globally optimal solution (the loss won't be zero), but it does represent a local optimum with a relatively small loss.

## Multiple points



In [3]:
def render_points(params, sigma):
    """Render the points in ``params`` into one image on the shared ``grid``.

    ``params`` is handed to ``point_edt2`` as-is. The output of ``exp`` gets a
    leading dimension before it is passed to ``softor``, and dimension 0 is
    squeezed from what ``softor`` returns.
    NOTE(review): ``softor`` presumably merges the per-point rasters into a
    single image - confirm against the dsketch documentation.
    """
    edt2 = point_edt2(params, grid)
    rasters = exp(edt2, sigma)
    combined = softor(rasters.unsqueeze(0))
    return combined.squeeze(0)

# Target: two points, at (-0.5, -0.5) and (0.5, 0.5), rendered with sigma=1e-2.
target_params = torch.tensor([[-0.5, -0.5], [0.5,0.5]])
target_image = render_points(target_params, 1e-2)

# Initial estimate: one row per point. In the recorded output the optimised
# rows match the target points in the opposite order to target_params.
params = torch.tensor([[0.3,0.3], [-0.3,0.0]])
params, anim = optimise(target_image, params, render_points, lr=0.5)

print(f"Target params:    {target_params}")
print(f"Optimised params: {params}")

# Last expression of the cell, so the animation is displayed as the output.
HTML(anim.to_jshtml())

Target params:    tensor([[-0.5000, -0.5000],
        [ 0.5000,  0.5000]])
Optimised params: tensor([[ 0.5000,  0.5000],
        [-0.5000, -0.5000]], requires_grad=True)


## Line segments

We now look at straight line segments:

In [4]:
def render_line(params, sigma):
    """Render the single line segment described by ``params`` onto ``grid``.

    ``params`` is given a leading dimension of size 1 before ``line_edt2`` is
    called, ``sigma`` is forwarded unchanged to ``exp``, and dimension 0 is
    squeezed from the result again.
    """
    batched = params.unsqueeze(0)
    edt2 = line_edt2(batched, grid)
    raster = exp(edt2, sigma)
    return raster.squeeze(0)

# Target: one segment given by the two points (0.5, 0.5) and (-0.5, -0.5),
# rendered with sigma=1e-2.
target_params = torch.tensor([[0.5, 0.5], [-0.5, -0.5]])
target_image = render_line(target_params, 1e-2)

# Initial estimate: the segment between (0.3, -0.3) and (0.3, 0.3). In the
# recorded output the two optimised rows come back in the opposite order to
# target_params.
params = torch.tensor([[0.3,-0.3],[0.3,0.3]])
params, anim = optimise(target_image, params, render_line, lr=0.5)

print(f"Target params:    {target_params}")
print(f"Optimised params: {params}")

# Last expression of the cell, so the animation is displayed as the output.
HTML(anim.to_jshtml())

Target params:    tensor([[ 0.5000,  0.5000],
        [-0.5000, -0.5000]])
Optimised params: tensor([[-0.5000, -0.5000],
        [ 0.5000,  0.5000]], requires_grad=True)


Once again, there is a certain sensitivity to initial conditions and the gradients can force the line to collapse to a point and move off-image towards a local optimum.

## Multiple line segments


In [5]:
def render_lines(params, sigma):
    """Render the line segments in ``params`` into one image on ``grid``.

    ``params`` is handed to ``line_edt2`` as-is. The output of ``exp`` gets a
    leading dimension before it is passed to ``softor``, and dimension 0 is
    squeezed from what ``softor`` returns.
    NOTE(review): ``softor`` presumably merges the per-segment rasters into a
    single image - confirm against the dsketch documentation.
    """
    edt2 = line_edt2(params, grid)
    rasters = exp(edt2, sigma)
    combined = softor(rasters.unsqueeze(0))
    return combined.squeeze(0)

# Target: two segments, each given as a pair of points, rendered with
# sigma=1e-2.
target_params = torch.tensor([[[0, 0], [0.5, 0.5]], [[0.25, 0.5], [0.25, -0.5]]])
target_image = render_lines(target_params, 1e-2)

# This example uses a smaller learning rate (0.1) and more steps (1000) than
# the other examples in this notebook.
params = torch.tensor([[[0.3,-0.5],[0.8,0.5]], [[0.4,-0.3],[-0.2,-0.3]]])
params, anim = optimise(target_image, params, render_lines, lr=0.1, steps=1000)

print(f"Target params:    {target_params}")
print(f"Optimised params: {params}")

# Last expression of the cell, so the animation is displayed as the output.
HTML(anim.to_jshtml())

Target params:    tensor([[[ 0.0000,  0.0000],
         [ 0.5000,  0.5000]],

        [[ 0.2500,  0.5000],
         [ 0.2500, -0.5000]]])
Optimised params: tensor([[[ 4.3438e-04, -4.1726e-04],
         [ 4.9785e-01,  5.0175e-01]],

        [[ 2.5047e-01, -5.0000e-01],
         [ 2.4882e-01,  4.5418e-01]]], requires_grad=True)


## Optimising a Quadratic Bezier Curve

In [6]:
from dsketch.raster.disttrans import quadratic_bezier

def render_curve_bez2(params, sigma):
    """Render the curve described by ``params`` using ``quadratic_bezier``.

    ``params`` is given a leading dimension of size 1 and passed to
    ``curve_edt2_bruteforce`` with ``iters=3`` and ``slices=10``. The result is
    run through ``exp`` with ``sigma`` and dimension 0 is squeezed away again.
    """
    batched = params.unsqueeze(0)
    edt2 = curve_edt2_bruteforce(
        batched,
        grid,
        iters=3,
        slices=10,
        cfcn=quadratic_bezier,
    )
    raster = exp(edt2, sigma)
    return raster.squeeze(0)

# Target: a curve defined by three points, rendered with sigma=1e-2.
target_params = torch.tensor([[0.5, 0.5], [-1, 1], [-0.5, -0.5]])
target_image = render_curve_bez2(target_params, 1e-2)

# Initial estimate, optimised for 500 steps at lr=0.5.
params = torch.tensor([[0.5, -0.5], [-0.4, -0.4], [0.3, 0.5]])
params, anim = optimise(target_image, params, render_curve_bez2, lr=0.5, steps=500)

print(f"Target params:    {target_params}")
print(f"Optimised params: {params}")

# Last expression of the cell, so the animation is displayed as the output.
HTML(anim.to_jshtml())

Target params:    tensor([[ 0.5000,  0.5000],
        [-1.0000,  1.0000],
        [-0.5000, -0.5000]])
Optimised params: tensor([[ 0.4999,  0.4998],
        [-1.0000,  0.9999],
        [-0.5000, -0.5000]], requires_grad=True)


## Optimising a Cubic Bezier Curve

In [None]:
from dsketch.raster.disttrans import cubic_bezier

def render_curve_bez3(params, sigma):
    """Render the curve described by ``params`` using ``cubic_bezier``.

    ``params`` is given a leading dimension of size 1 and passed to
    ``curve_edt2_bruteforce`` with ``iters=3`` and ``slices=10``. The result is
    run through ``exp`` with ``sigma`` and dimension 0 is squeezed away again.
    """
    batched = params.unsqueeze(0)
    edt2 = curve_edt2_bruteforce(
        batched,
        grid,
        iters=3,
        slices=10,
        cfcn=cubic_bezier,
    )
    raster = exp(edt2, sigma)
    return raster.squeeze(0)

# Target: a curve defined by four points, rendered with sigma=1e-2.
target_params = torch.tensor([[0.5, 0.5], [-1, 1] , [1,-1], [-0.5, -0.5]])
target_image = render_curve_bez3(target_params, 1e-2)

# Initial estimate, optimised for 500 steps at lr=0.5. In the recorded output
# the optimised rows approximate the target points in reverse order.
params = torch.tensor([[0.5, -0.5], [-0.4, -0.4] ,[0.4,1], [0.3, 0.5]])
params, anim = optimise(target_image, params, render_curve_bez3, lr=0.5, steps=500)

print(f"Target params:    {target_params}")
print(f"Optimised params: {params}")

# Last expression of the cell, so the animation is displayed as the output.
HTML(anim.to_jshtml())

Target params:    tensor([[ 0.5000,  0.5000],
        [-1.0000,  1.0000],
        [ 1.0000, -1.0000],
        [-0.5000, -0.5000]])
Optimised params: tensor([[-0.4817, -0.5066],
        [ 0.9900, -0.9816],
        [-1.0099,  1.0091],
        [ 0.5178,  0.4948]], requires_grad=True)


## Optimising a Catmull-Rom Spline

In [None]:
from dsketch.raster.disttrans import centripetal_catmull_rom_spline

def render_curve_crs(params, sigma):
    """Render the curve described by ``params`` using
    ``centripetal_catmull_rom_spline``.

    ``params`` is given a leading dimension of size 1 and passed to
    ``curve_edt2_bruteforce`` with ``iters=2`` and ``slices=10``. The result is
    run through ``exp`` with ``sigma`` and dimension 0 is squeezed away again.
    """
    batched = params.unsqueeze(0)
    edt2 = curve_edt2_bruteforce(
        batched,
        grid,
        iters=2,
        slices=10,
        cfcn=centripetal_catmull_rom_spline,
    )
    raster = exp(edt2, sigma)
    return raster.squeeze(0)


# Target: a spline defined by four points, rendered with sigma=1e-2. The first
# and last points lie well outside the [-1, 1] range covered by the grid.
target_params = torch.tensor([[10, 1], [0.4, -0.4] , [-0.4, 0.8], [-8, -1]])
target_image = render_curve_crs(target_params, 1e-2)

# Debugging aids, left disabled:
# print(centripetal_catmull_rom_spline(target_params.unsqueeze(0),torch.tensor([1])))

# plt.imshow(target_image)

# Initial estimate, optimised for 500 steps at lr=0.5.
params = torch.tensor([[0.5, -0.5], [-0.4, -0.4] ,[0.4,1], [0.3, -0.5]])
params, anim = optimise(target_image, params, render_curve_crs, lr=0.5, steps=500)

print(f"Target params:    {target_params}")
print(f"Optimised params: {params}")

# Last expression of the cell, so the animation is displayed as the output.
HTML(anim.to_jshtml())

In [None]:
from dsketch.raster.disttrans import curve_edt2_polyline

def render_curve_crs_pl(params, sigma):
    """Render the curve described by ``params`` via ``curve_edt2_polyline``.

    ``params`` is given a leading dimension of size 1 and passed to
    ``curve_edt2_polyline`` with ``segments=20`` and
    ``centripetal_catmull_rom_spline`` as the curve function. The result is
    run through ``exp`` with ``sigma`` and dimension 0 is squeezed away again.
    """
    batched = params.unsqueeze(0)
    edt2 = curve_edt2_polyline(
        batched,
        grid,
        segments=20,
        cfcn=centripetal_catmull_rom_spline,
    )
    raster = exp(edt2, sigma)
    return raster.squeeze(0)


# This cell defines no target of its own: ``target_image`` and
# ``target_params`` must already exist from an earlier cell. The initial
# estimate is the same as in the preceding Catmull-Rom example.
params = torch.tensor([[0.5, -0.5], [-0.4, -0.4] ,[0.4,1], [0.3, -0.5]])
params, anim = optimise(target_image, params, render_curve_crs_pl, lr=0.5, steps=500)

print(f"Target params:    {target_params}")
print(f"Optimised params: {params}")

# Last expression of the cell, so the animation is displayed as the output.
HTML(anim.to_jshtml())