In [None]:
%load_ext autoreload
%autoreload 2

import os
import sys

import numpy as np
import proplot as pplt

import torch
import torch.nn as nn

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

# Put the parent of the working directory on sys.path so that the `utils`
# package imported right below can be resolved. os.getcwd() replaces the
# `!pwd` shell escape, which only works inside IPython on systems that ship a
# `pwd` binary. The value stays a one-element list because `!pwd` returned a
# list of output lines and `current_path[0]` is what gets used.
current_path = [os.getcwd()]
parent_path = os.path.dirname(current_path[0])
if parent_path not in sys.path:
    sys.path.append(parent_path)

import utils.model_handling as model_utils
import utils.dataset_generation as iso_data
import utils.histogram_analysis as hist_funcs
import utils.principal_curvature as curve_utils
import utils.plotting as plot_funcs

In [None]:
class QuadraticFunction(torch.nn.Module):
    """Quadratic form ``f(x) = x^T A x`` with a fixed coefficient matrix ``A``.

    ``A`` is stored in ``self.hessian`` as a parameter with
    ``requires_grad=False``, so it moves with the module on ``.to(...)``.
    Despite the attribute name, the Hessian of ``f`` is ``A + A^T`` (``2 A``
    for a symmetric ``A``), not ``A`` itself.

    Args:
        diag: Diagonal entries of ``A``; only used when ``hess`` is None.
            Defaults to ``[1.0, 2.0]``.
        hess: Full square coefficient matrix ``A`` (nested sequence, array or
            tensor). Takes precedence over ``diag``.

    Raises:
        ValueError: If the resulting matrix is not a square 2-D matrix.
    """

    def __init__(self, diag=None, hess=None):
        super().__init__()

        if hess is not None:
            matrix = self._as_tensor(hess)
        else:
            if diag is None:
                diag = [1.0, 2.0]
            matrix = torch.diag(self._as_tensor(diag))

        # Fail at construction time instead of with an opaque matmul/dot
        # shape error on the first forward call.
        if matrix.dim() != 2 or matrix.shape[0] != matrix.shape[1]:
            raise ValueError(
                f"coefficient matrix must be square and 2-D, got shape {tuple(matrix.shape)}"
            )

        self.hessian = torch.nn.Parameter(matrix.to(DEVICE), requires_grad=False)

    @staticmethod
    def _as_tensor(values):
        """Copy ``values`` into a new tensor (tensors are cloned to avoid the copy-construct warning)."""
        if isinstance(values, torch.Tensor):
            return values.detach().clone()
        return torch.tensor(values)

    def forward(self, x):
        """Return the scalar ``x^T A x`` for a single 1-D input vector ``x``."""
        # Promote both operands to a common dtype so that e.g. integer
        # coefficients work with float inputs; this is a no-op when the dtypes
        # already match.
        dtype = torch.promote_types(self.hessian.dtype, x.dtype)
        matrix = self.hessian.to(dtype)
        x = x.to(dtype)
        return torch.dot(x, torch.matmul(matrix, x))

# Default test function f(x) = x^T diag(1, 2) x; the following cells evaluate it on a grid and plot it.
f = QuadraticFunction([1.0, 2.0]).to(DEVICE)

In [None]:
# Regular 50 x 50 evaluation grid over [-3, 3] x [-3, 3].
xs = torch.linspace(-3, 3, 50).to(DEVICE)
ys = torch.linspace(-3, 3, 50).to(DEVICE)
XS, YS = torch.meshgrid(xs, ys)

# Stack the flattened coordinates into a (2, n_points) tensor: one column per grid point.
XS_flat, YS_flat = XS.flatten(), YS.flatten()
data = torch.vstack((XS_flat, YS_flat))

# f takes a single 1-D point, so evaluate it column by column and fold the
# values back into the grid shape.
zs = torch.tensor([f(column) for column in data.T]).reshape(XS.shape)

In [None]:
# Filled contour plot of f over the evaluation grid; the tensors are moved to
# the CPU and converted to NumPy arrays for plotting.
fig, axs = pplt.subplots(nrows=1, ncols=1)
axs.contourf(*(grid.cpu().numpy() for grid in (XS, YS, zs)))

In [None]:
# 2-D evaluation points (x, y) used by the curvature plots below.
points = [
    [x, y]
    for x, y in (
        # points off the y-axis
        (0.5, 1.0),
        (1.0, 2.0),
        (-0.5, -1.0),
        (0.5, -1.0),
        # points on the y-axis (x = 0)
        (0.0, -0.5),
        (0.0, 0.5),
        (0.0, 1.0),
        (0.0, 2.0),
    )
]

In [None]:
def value_grad_hess(f, point):
    """Evaluate `f` at `point` and return `(value, gradient, hessian)`.

    The gradient is obtained with `torch.autograd.functional.jacobian` and the
    Hessian with `torch.autograd.functional.hessian`, both taken with respect
    to `point`.
    """
    autograd_fn = torch.autograd.functional
    return f(point), autograd_fn.jacobian(f, point), autograd_fn.hessian(f, point)

# 2d function with 1d isosurface

First, let's test a symmetric quadratic. Observe that while the graph usually has a principal curvature direction that coincides with the decision boundary, it gives the wrong curvature for that direction.

In [None]:
# Symmetric quadratic f(x) = x^T diag(1, 1) x.
f = QuadraticFunction([1.0, 1.0]).to(DEVICE)

# Evaluate f on a regular 50 x 50 grid over [-3, 3] x [-3, 3] for the background contours.
xs = torch.linspace(-3, 3, 50).to(DEVICE)
ys = torch.linspace(-3, 3, 50).to(DEVICE)
XS, YS = torch.meshgrid(xs, ys)
XS_flat = XS.flatten()
YS_flat = YS.flatten()
data = torch.vstack((XS_flat, YS_flat))
zs = torch.tensor([f(data[:, i]) for i in range(len(data.T))]).reshape(XS.shape)

# The grid does not change inside the loop below, so convert it to NumPy once.
XS_np, YS_np, zs_np = XS.cpu().numpy(), YS.cpu().numpy(), zs.cpu().numpy()

fig, axs = pplt.subplots(nrows=len(points), ncols=2)


def plot_curvature(ax, curvatures, directions, origin=None):
    """Draw every direction as an arrow starting at `origin`, labelled with its curvature.

    Args:
        ax: Axes to draw on.
        curvatures: Tensor of curvature values, paired one-to-one with the directions.
        directions: Tensor whose columns are the direction vectors (it is
            transposed before iterating).
        origin: (x, y) start point of the arrows. When omitted, the global
            `np_point` is used so that three-argument calls keep working.
    """
    if origin is None:
        origin = np_point
    x0, y0 = origin[0], origin[1]
    for curvature, direction in zip(curvatures.detach().cpu().numpy(), directions.T.detach().cpu().numpy()):
        ax.arrow(x0, y0, direction[0], direction[1], width=0.05)
        ax.text(x0 + 1.0 * direction[0], y0 + 1.0 * direction[1], f'({direction[0]:.02f}, {direction[1]:.02f}): {curvature:.03f}')


for point_index, np_point in enumerate(points):
    point = torch.tensor(np_point, device=DEVICE)
    value, pt_grad, pt_hess = value_grad_hess(f, point)
    level = value.detach().cpu().numpy()

    # Both panels of a row share the same background: filled contours of f,
    # the level curve through the point (black) and the point itself.
    for ax in axs[point_index, :]:
        ax.contourf(XS_np, YS_np, zs_np)
        ax.contour(XS_np, YS_np, zs_np, levels=[level], color='black')
        ax.scatter([np_point[0]], [np_point[1]])
        ax.set_title(f'point: ({np_point[0]}, {np_point[1]})')

    # Left column: curvature of the isoresponse curve; right column: curvature of the graph of f.
    iso_shape_operator, iso_curvatures, iso_directions = curve_utils.local_response_curvature_isoresponse_surface(pt_grad, pt_hess)
    graph_shape_operator, graph_curvatures, graph_directions = curve_utils.local_response_curvature_graph(pt_grad, pt_hess)
    plot_curvature(axs[point_index, 0], iso_curvatures, iso_directions, origin=np_point)
    plot_curvature(axs[point_index, 1], graph_curvatures, graph_directions, origin=np_point)

Now let's go for something more asymmetric. Observe that the principal directions for the graph are no longer orthogonal in the parameter space! Also, the graph often doesn't even have a principal direction that aligns with the isosurface.

In [None]:
# Quadratic with off-diagonal coupling: f(x) = x^T [[1, 0.5], [0.5, 2]] x.
# Axis-aligned alternative: f = QuadraticFunction([1.0, 2.0]).to(DEVICE)
f = QuadraticFunction(hess=[[1.0, 0.5], [0.5, 2.0]]).to(DEVICE)

# Evaluate f on a regular 50 x 50 grid over [-3, 3] x [-3, 3] for the background contours.
xs = torch.linspace(-3, 3, 50).to(DEVICE)
ys = torch.linspace(-3, 3, 50).to(DEVICE)
XS, YS = torch.meshgrid(xs, ys)
XS_flat = XS.flatten()
YS_flat = YS.flatten()
data = torch.vstack((XS_flat, YS_flat))
zs = torch.tensor([f(data[:, i]) for i in range(len(data.T))]).reshape(XS.shape)

# The grid does not change inside the loop below, so convert it to NumPy once.
XS_np, YS_np, zs_np = XS.cpu().numpy(), YS.cpu().numpy(), zs.cpu().numpy()

fig, axs = pplt.subplots(nrows=len(points), ncols=2)


def plot_curvature(ax, curvatures, directions, origin=None):
    """Draw every direction as an arrow starting at `origin`, labelled with its curvature.

    Args:
        ax: Axes to draw on.
        curvatures: Tensor of curvature values, paired one-to-one with the directions.
        directions: Tensor whose columns are the direction vectors (it is
            transposed before iterating).
        origin: (x, y) start point of the arrows. When omitted, the global
            `np_point` is used so that three-argument calls keep working.
    """
    if origin is None:
        origin = np_point
    x0, y0 = origin[0], origin[1]
    for curvature, direction in zip(curvatures.detach().cpu().numpy(), directions.T.detach().cpu().numpy()):
        ax.arrow(x0, y0, direction[0], direction[1], width=0.05)
        ax.text(x0 + 1.0 * direction[0], y0 + 1.0 * direction[1], f'({direction[0]:.02f}, {direction[1]:.02f}): {curvature:.03f}')


for point_index, np_point in enumerate(points):
    point = torch.tensor(np_point, device=DEVICE)
    value, pt_grad, pt_hess = value_grad_hess(f, point)
    level = value.detach().cpu().numpy()

    # Both panels of a row share the same background: filled contours of f,
    # the level curve through the point (black) and the point itself.
    for ax in axs[point_index, :]:
        ax.contourf(XS_np, YS_np, zs_np)
        ax.contour(XS_np, YS_np, zs_np, levels=[level], color='black')
        ax.scatter([np_point[0]], [np_point[1]])
        ax.set_title(f'point: ({np_point[0]}, {np_point[1]})')

    # Left column: curvature of the isoresponse curve; right column: curvature of the graph of f.
    iso_shape_operator, iso_curvatures, iso_directions = curve_utils.local_response_curvature_isoresponse_surface(pt_grad, pt_hess)
    graph_shape_operator, graph_curvatures, graph_directions = curve_utils.local_response_curvature_graph(pt_grad, pt_hess)
    plot_curvature(axs[point_index, 0], iso_curvatures, iso_directions, origin=np_point)
    plot_curvature(axs[point_index, 1], graph_curvatures, graph_directions, origin=np_point)

Let's also try a linear function

In [None]:
# Linear function f(x, y) = x + y.
f = lambda x: torch.sum(x)

# Evaluate f on a regular 50 x 50 grid over [-3, 3] x [-3, 3] for the background contours.
xs = torch.linspace(-3, 3, 50).to(DEVICE)
ys = torch.linspace(-3, 3, 50).to(DEVICE)
XS, YS = torch.meshgrid(xs, ys)
XS_flat = XS.flatten()
YS_flat = YS.flatten()
data = torch.vstack((XS_flat, YS_flat))
zs = torch.tensor([f(data[:, i]) for i in range(len(data.T))]).reshape(XS.shape)

# The grid does not change inside the loop below, so convert it to NumPy once.
XS_np, YS_np, zs_np = XS.cpu().numpy(), YS.cpu().numpy(), zs.cpu().numpy()

fig, axs = pplt.subplots(nrows=len(points), ncols=2)


def plot_curvature(ax, curvatures, directions, origin=None):
    """Draw every direction as an arrow starting at `origin`, labelled with its curvature.

    Args:
        ax: Axes to draw on.
        curvatures: Tensor of curvature values, paired one-to-one with the directions.
        directions: Tensor whose columns are the direction vectors (it is
            transposed before iterating).
        origin: (x, y) start point of the arrows. When omitted, the global
            `np_point` is used so that three-argument calls keep working.
    """
    if origin is None:
        origin = np_point
    x0, y0 = origin[0], origin[1]
    for curvature, direction in zip(curvatures.detach().cpu().numpy(), directions.T.detach().cpu().numpy()):
        ax.arrow(x0, y0, direction[0], direction[1], width=0.05)
        ax.text(x0 + 1.0 * direction[0], y0 + 1.0 * direction[1], f'({direction[0]:.02f}, {direction[1]:.02f}): {curvature:.03f}')


for point_index, np_point in enumerate(points):
    point = torch.tensor(np_point, device=DEVICE)
    value, pt_grad, pt_hess = value_grad_hess(f, point)
    level = value.detach().cpu().numpy()

    # Both panels of a row share the same background: filled contours of f,
    # the level curve through the point (black) and the point itself.
    for ax in axs[point_index, :]:
        ax.contourf(XS_np, YS_np, zs_np)
        ax.contour(XS_np, YS_np, zs_np, levels=[level], color='black')
        ax.scatter([np_point[0]], [np_point[1]])
        ax.set_title(f'point: ({np_point[0]}, {np_point[1]})')

    # Left column: curvature of the isoresponse curve; right column: curvature of the graph of f.
    iso_shape_operator, iso_curvatures, iso_directions = curve_utils.local_response_curvature_isoresponse_surface(pt_grad, pt_hess)
    graph_shape_operator, graph_curvatures, graph_directions = curve_utils.local_response_curvature_graph(pt_grad, pt_hess)
    plot_curvature(axs[point_index, 0], iso_curvatures, iso_directions, origin=np_point)
    plot_curvature(axs[point_index, 1], graph_curvatures, graph_directions, origin=np_point)

In [None]:
# Evaluation points for f(x, y) = x * y. The origin is a critical point of f
# (its gradient (y, x) vanishes there).
# NOTE: this rebinds the global `points` that the earlier plotting cells read.
points = [
    [0.5, 1.0],
    [1.0, 2.0],
    [-1.0, 2.0],
    [0.0, 0.0],
    #[0.0, -0.5],
    #[0.0, 0.5],
    #[0.0, 1.0],
    #[0.0, 2.0],
]

f = lambda x: torch.prod(x)

# Evaluate f on a regular 50 x 50 grid over [-3, 3] x [-3, 3] for the background contours.
xs = torch.linspace(-3, 3, 50).to(DEVICE)
ys = torch.linspace(-3, 3, 50).to(DEVICE)
XS, YS = torch.meshgrid(xs, ys)
XS_flat = XS.flatten()
YS_flat = YS.flatten()
data = torch.vstack((XS_flat, YS_flat))
zs = torch.tensor([f(data[:, i]) for i in range(len(data.T))]).reshape(XS.shape)

# The grid does not change inside the loop below, so convert it to NumPy once.
XS_np, YS_np, zs_np = XS.cpu().numpy(), YS.cpu().numpy(), zs.cpu().numpy()

fig, axs = pplt.subplots(nrows=len(points), ncols=2)


def plot_curvature(ax, curvatures, directions, origin=None):
    """Draw every direction as an arrow starting at `origin`, labelled with its curvature.

    Args:
        ax: Axes to draw on.
        curvatures: Tensor of curvature values, paired one-to-one with the directions.
        directions: Tensor whose columns are the direction vectors (it is
            transposed before iterating).
        origin: (x, y) start point of the arrows. When omitted, the global
            `np_point` is used so that three-argument calls keep working.
    """
    if origin is None:
        origin = np_point
    x0, y0 = origin[0], origin[1]
    for curvature, direction in zip(curvatures.detach().cpu().numpy(), directions.T.detach().cpu().numpy()):
        ax.arrow(x0, y0, direction[0], direction[1], width=0.05)
        ax.text(x0 + 1.0 * direction[0], y0 + 1.0 * direction[1], f'({direction[0]:.02f}, {direction[1]:.02f}): {curvature:.03f}')


for point_index, np_point in enumerate(points):
    point = torch.tensor(np_point, device=DEVICE)
    value, pt_grad, pt_hess = value_grad_hess(f, point)
    level = value.detach().cpu().numpy()

    # Both panels of a row share the same background: filled contours of f,
    # the level curve through the point (black) and the point itself.
    for ax in axs[point_index, :]:
        ax.contourf(XS_np, YS_np, zs_np)
        ax.contour(XS_np, YS_np, zs_np, levels=[level], color='black')
        ax.scatter([np_point[0]], [np_point[1]])
        ax.set_title(f'point: ({np_point[0]}, {np_point[1]})')

    # Right column: curvature of the graph of f.
    graph_shape_operator, graph_curvatures, graph_directions = curve_utils.local_response_curvature_graph(pt_grad, pt_hess)

    # Left column: curvature of the isoresponse curve. The level set has no
    # well-defined normal where the gradient vanishes, so skip exactly those
    # points. Testing the gradient (instead of `point[0] != 0`) also keeps
    # points such as (0, 1), where x is zero but the gradient is not.
    # NOTE(review): assumes the isoresponse helper cannot handle a zero gradient - confirm.
    if torch.any(pt_grad != 0):
        iso_shape_operator, iso_curvatures, iso_directions = curve_utils.local_response_curvature_isoresponse_surface(pt_grad, pt_hess)
        plot_curvature(axs[point_index, 0], iso_curvatures, iso_directions, origin=np_point)
    plot_curvature(axs[point_index, 1], graph_curvatures, graph_directions, origin=np_point)

In [None]:
# Numeric check at a single point: f(x) = x^T diag(2, 2) x evaluated at (0, 1).
f = QuadraticFunction([2.0, 2.0]).to(DEVICE)
point = torch.tensor([0.0, 1.0]).to(DEVICE)
value, pt_grad, pt_hess = value_grad_hess(f, point)

# Curvature of the isoresponse curve and of the graph at that point.
iso_shape_operator, iso_curvatures, iso_directions = (
    curve_utils.local_response_curvature_isoresponse_surface(pt_grad, pt_hess)
)
graph_shape_operator, graph_curvatures, graph_directions = (
    curve_utils.local_response_curvature_graph(pt_grad, pt_hess)
)

for _label, _result in (
    ("ISO SHAPE", iso_shape_operator),
    ("ISO curv", iso_curvatures),
    ("ISO dir", iso_directions),
    ("graph SHAPE", graph_shape_operator),
    ("graph curv", graph_curvatures),
    ("graph dir", graph_directions),
):
    print(_label, _result)

## 3d, i.e. function 3d->1d resulting in 2d isosurface

First let's check that we get correct curvatures for a sphere

In [None]:
# f(x) = x^T diag(1, 1, 1) x = |x|^2: the level set through `point` is the
# sphere of radius |point| centred at the origin.
f_3d = QuadraticFunction([1.0, 1.0, 1.0]).to(DEVICE)

point = torch.tensor([0, 0, 2.0]).to(DEVICE)
value, pt_grad, pt_hess = value_grad_hess(f_3d, point)

print("eval point:", point)
print("eval value:", value)
print("eval grad:", pt_grad)
print()

# Curvature of the isoresponse surface and of the graph at that point.
iso_shape_operator, iso_curvatures, iso_directions = curve_utils.local_response_curvature_isoresponse_surface(pt_grad, pt_hess)
graph_shape_operator, graph_curvatures, graph_directions = curve_utils.local_response_curvature_graph(pt_grad, pt_hess)

print("ISO SHAPE", iso_shape_operator)
print("ISO curv", iso_curvatures)
print("ISO dir", iso_directions)
print()

# A sphere of radius r has Gaussian curvature 1 / r^2. Use the full squared
# norm of the point as r^2 rather than only its last coordinate, so the
# target stays correct when the point is moved off the z-axis.
radius_sq = torch.dot(point, point)

print("ISO Gauss", torch.prod(iso_curvatures))
print("Gauss target", 1/radius_sq)

print("ratio", torch.prod(iso_curvatures)*radius_sq)

print()

print("graph SHAPE", graph_shape_operator)
print("graph curv", graph_curvatures)
print("graph dir", graph_directions)

Now with projection to subspace

In [None]:
# Debug peek at the shape of the projection subspace. The tensor is only
# created in the next cell, so on a fresh kernel (Restart & Run All) the name
# does not exist yet; evaluate to None in that case instead of raising NameError.
getattr(globals().get('projection_subspace_of_interest'), 'shape', None)

In [None]:
# f(x) = x^T diag(1, 1, 1) x = |x|^2: the level set through `point` is the
# sphere of radius |point| centred at the origin.
f_3d = QuadraticFunction([1.0, 1.0, 1.0]).to(DEVICE)

# Rows are the vectors handed to the isoresponse helper as the subspace of interest.
#projection_subspace_of_interest = None
projection_subspace_of_interest = torch.tensor([
    #[2, 0, 2.0],
    [2, 1, 0.0],
]).to(DEVICE)

point = torch.tensor([0, 2, 2.0]).to(DEVICE)
value, pt_grad, pt_hess = value_grad_hess(f_3d, point)

print("eval point:", point)
print("eval value:", value)
print("eval grad:", pt_grad)
print()

# Isoresponse curvature restricted to the subspace; graph curvature is unrestricted.
iso_shape_operator, iso_curvatures, iso_directions = curve_utils.local_response_curvature_isoresponse_surface(
    pt_grad, pt_hess,
    projection_subspace_of_interest=projection_subspace_of_interest
)
graph_shape_operator, graph_curvatures, graph_directions = curve_utils.local_response_curvature_graph(pt_grad, pt_hess)

print("ISO SHAPE", iso_shape_operator)
print("ISO curv", iso_curvatures)
print("ISO dir", iso_directions)
print()

# A sphere of radius r has Gaussian curvature 1 / r^2 with r^2 = |point|^2.
# The point (0, 2, 2) is off the z-axis, so its last coordinate alone is not
# the radius; use the full squared norm.
# NOTE(review): 1 / r^2 is the product of BOTH principal curvatures of the full
# sphere. With a projected subspace the helper may return fewer curvatures, in
# which case this target needs adjusting - confirm against curve_utils.
radius_sq = torch.dot(point, point)

print("ISO Gauss", torch.prod(iso_curvatures))
print("Gauss target", 1/radius_sq)

print("ratio", torch.prod(iso_curvatures)*radius_sq)

print()

print("graph SHAPE", graph_shape_operator)
print("graph curv", graph_curvatures)
print("graph dir", graph_directions)

In [None]:
# f(x) = x^T diag(1, 1, 1) x = |x|^2: the level set through `point` is the
# sphere of radius |point| centred at the origin.
f_3d = QuadraticFunction([1.0, 1.0, 1.0]).to(DEVICE)

# Rows are the vectors handed to the isoresponse helper as the subspace of interest.
projection_subspace_of_interest = torch.tensor([
    [2, 2, 1.0],
    [2, 0, 0.0],
]).to(DEVICE)

point = torch.tensor([0, 2, 2.0]).to(DEVICE)
value, pt_grad, pt_hess = value_grad_hess(f_3d, point)

print("eval point:", point)
print("eval value:", value)
print("eval grad:", pt_grad)
print()

# Isoresponse curvature restricted to the subspace; graph curvature is unrestricted.
iso_shape_operator, iso_curvatures, iso_directions = curve_utils.local_response_curvature_isoresponse_surface(
    pt_grad, pt_hess,
    projection_subspace_of_interest=projection_subspace_of_interest
)
graph_shape_operator, graph_curvatures, graph_directions = curve_utils.local_response_curvature_graph(pt_grad, pt_hess)

print("ISO SHAPE", iso_shape_operator)
print("ISO curv", iso_curvatures)
print("ISO dir", iso_directions)
print()

# A sphere of radius r has Gaussian curvature 1 / r^2 with r^2 = |point|^2.
# The point (0, 2, 2) is off the z-axis, so its last coordinate alone is not
# the radius; use the full squared norm.
# NOTE(review): 1 / r^2 is the product of BOTH principal curvatures of the full
# sphere. With a projected subspace the helper may return fewer curvatures, in
# which case this target needs adjusting - confirm against curve_utils.
radius_sq = torch.dot(point, point)

print("ISO Gauss", torch.prod(iso_curvatures))
print("Gauss target", 1/radius_sq)

print("ratio", torch.prod(iso_curvatures)*radius_sq)

print()

print("graph SHAPE", graph_shape_operator)
print("graph curv", graph_curvatures)
print("graph dir", graph_directions)

And something more asymmetric

In [None]:
# Anisotropic quadratic f(x) = x^T diag(1, 2, 3) x evaluated at (-3, 2, -1).
f_3d = QuadraticFunction([1.0, 2.0, 3.0]).to(DEVICE)

point = torch.tensor([-3.0, 2.0, -1.0]).to(DEVICE)
value, pt_grad, pt_hess = value_grad_hess(f_3d, point)

for _label, _result in (
    ("eval point:", point),
    ("eval value:", value),
    ("eval grad:", pt_grad),
):
    print(_label, _result)
print()

# Curvature of the isoresponse surface and of the graph at that point.
iso_shape_operator, iso_curvatures, iso_directions = (
    curve_utils.local_response_curvature_isoresponse_surface(pt_grad, pt_hess)
)
graph_shape_operator, graph_curvatures, graph_directions = (
    curve_utils.local_response_curvature_graph(pt_grad, pt_hess)
)

for _label, _result in (
    ("ISO SHAPE", iso_shape_operator),
    ("ISO curv", iso_curvatures),
    ("ISO dir", iso_directions),
):
    print(_label, _result)
print()

for _label, _result in (
    ("graph SHAPE", graph_shape_operator),
    ("graph curv", graph_curvatures),
    ("graph dir", graph_directions),
):
    print(_label, _result)

And again with a subspace

In [None]:
# Anisotropic quadratic f(x) = x^T diag(1, 2, 3) x evaluated at (-3, 2, -1).
f_3d = QuadraticFunction([1.0, 2.0, 3.0]).to(DEVICE)

# Rows are the vectors handed to the isoresponse helper as the subspace of interest.
projection_subspace_of_interest = torch.tensor([
    [2, 2, 1.0],
    #[2, 0, 0.0],
]).to(DEVICE)

point = torch.tensor([-3.0, 2.0, -1.0]).to(DEVICE)
value, pt_grad, pt_hess = value_grad_hess(f_3d, point)

for _label, _result in (
    ("eval point:", point),
    ("eval value:", value),
    ("eval grad:", pt_grad),
):
    print(_label, _result)
print()

# Isoresponse curvature restricted to the subspace; graph curvature is unrestricted.
iso_shape_operator, iso_curvatures, iso_directions = (
    curve_utils.local_response_curvature_isoresponse_surface(
        pt_grad,
        pt_hess,
        projection_subspace_of_interest=projection_subspace_of_interest,
    )
)
graph_shape_operator, graph_curvatures, graph_directions = (
    curve_utils.local_response_curvature_graph(pt_grad, pt_hess)
)

for _label, _result in (
    ("ISO SHAPE", iso_shape_operator),
    ("ISO curv", iso_curvatures),
    ("ISO dir", iso_directions),
):
    print(_label, _result)
print()

for _label, _result in (
    ("graph SHAPE", graph_shape_operator),
    ("graph curv", graph_curvatures),
    ("graph dir", graph_directions),
):
    print(_label, _result)