In [None]:
%load_ext autoreload
%autoreload 2

import os
import sys

import numpy as np
import proplot as pplt
from scipy.stats import ortho_group

import torch
import torch.nn as nn

# Device used for every tensor and module in this notebook. Uncomment the next
# line instead to use a GPU when one is available.
#DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
DEVICE = 'cpu'

# Make the project's `utils` package (one directory above the current working
# directory) importable. os.getcwd() replaces the `!pwd` shell capture so this
# also works where no `pwd` executable exists (e.g. Windows).
parent_path = os.path.dirname(os.getcwd())
if parent_path not in sys.path:
    sys.path.append(parent_path)

import utils.model_handling as model_utils
import utils.dataset_generation as iso_data
import utils.histogram_analysis as hist_funcs
import utils.principal_curvature as curve_utils
import utils.plotting as plot_funcs

In [None]:
class QuadraticFunction(torch.nn.Module):
    """Quadratic form f(x) = x^T H x with a fixed, non-trainable matrix H.

    `self.hessian` stores H itself. The Hessian of f is H + H^T, i.e. 2 H
    for a symmetric H.

    Args:
        diag: diagonal entries of H, used only when `hess` is None.
            Defaults to [1.0, 2.0].
        hess: full matrix H (nested list, numpy array or tensor); takes
            precedence over `diag`.
    """

    def __init__(self, diag=None, hess=None):
        super().__init__()

        if hess is not None:
            matrix = self._as_new_tensor(hess)
        else:
            if diag is None:
                diag = [1.0, 2.0]
            matrix = torch.diag(self._as_new_tensor(diag))

        # Integer input (e.g. diag=[1, 2]) would give an integer H, and forward
        # would then fail with a dtype mismatch against a floating-point x.
        if not matrix.is_floating_point():
            matrix = matrix.to(torch.get_default_dtype())

        self.hessian = torch.nn.Parameter(matrix.to(DEVICE), requires_grad=False)

    @staticmethod
    def _as_new_tensor(values):
        """Copy `values` into a new tensor; existing tensors are detached and cloned."""
        # torch.tensor() on an existing tensor emits a UserWarning, so clone instead.
        if isinstance(values, torch.Tensor):
            return values.detach().clone()
        return torch.tensor(values)

    def forward(self, x):
        """Evaluate x^T H x.

        A 1-D `x` of length n returns a 0-dim tensor. An input of shape
        (..., n) is treated as a batch of points and returns shape (...).
        `x` must have the same dtype as H.
        """
        if x.dim() == 1:
            return torch.dot(x, torch.matmul(self.hessian, x))
        return torch.einsum('...i,ij,...j->...', x, self.hessian, x)

# Anisotropic quadratic f(x) = x^T diag(1, 2) x, evaluated on the grid in the next cells.
f = QuadraticFunction([1.0, 2.0]).to(DEVICE)

In [None]:
# Evaluate f on a 50 x 50 grid covering [-3, 3]^2. Column k of `data` is the
# k-th grid point; `zs` holds f at every point, reshaped to the grid.
xs, ys = (torch.linspace(-3, 3, 50).to(DEVICE) for _ in range(2))
XS, YS = torch.meshgrid(xs, ys)
XS_flat, YS_flat = XS.flatten(), YS.flatten()
data = torch.vstack((XS_flat, YS_flat))
zs = torch.tensor([f(column) for column in data.T]).reshape(XS.shape)

In [None]:
# Filled contours of f over the [-3, 3]^2 grid.
fig, axs = pplt.subplots(nrows=1, ncols=1)
axs.contourf(*(grid_tensor.cpu().numpy() for grid_tensor in (XS, YS, zs)))

In [None]:
# Points at which curvatures are evaluated and plotted in the 2D experiments below.
# The first entry is a float32 numpy array on the unit circle; the rest are plain lists.
points = [
    np.array([np.sqrt(1/2), np.sqrt(1/2)], dtype=np.float32),
    [0.5, 1.0],
    [1.0, 2.0],
    [-0.5, -1.0], 
    [0.5, -1.0], 
    [0.0, -0.5],
    [0.0, 0.5],
    [0.0, 1.0],
    [0.0, 2.0],
]

In [None]:
def value_grad_hess(f, point):
    """Return (f(point), gradient of f at point, Hessian of f at point) using torch.autograd.functional."""
    autograd_fn = torch.autograd.functional
    return (
        f(point),
        autograd_fn.jacobian(f, point),
        autograd_fn.hessian(f, point),
    )

# 2D function with a 1D isosurface

First, let's test a symmetric quadratic. Observe that although the graph usually has a principal curvature direction that coincides with the decision boundary (the isoresponse curve), the graph reports the wrong curvature for this direction.

In [None]:
f = QuadraticFunction([1.0, 1.0]).to(DEVICE)

# Evaluate f on a 50 x 50 grid over [-3, 3]^2 for the background contours.
xs = torch.linspace(-3, 3, 50).to(DEVICE)
ys = torch.linspace(-3, 3, 50).to(DEVICE)
XS, YS = torch.meshgrid(xs, ys)
XS_flat = XS.flatten()
YS_flat = YS.flatten()
data = torch.vstack((XS_flat, YS_flat))
zs = torch.tensor([f(data[:, i]) for i in range(len(data.T))]).reshape(XS.shape)
grid_np = (XS.cpu().numpy(), YS.cpu().numpy(), zs.cpu().numpy())

# One row per point: isoresponse curvature on the left, graph curvature on the right.
fig, axs = pplt.subplots(nrows=len(points), ncols=2)

def plot_curvature(ax, curvatures, directions, origin=None):
    """Draw each principal direction (a column of `directions`) as an arrow from `origin`,
    labelled with the direction and its curvature.

    `origin` falls back to the global `np_point` when omitted; pass it
    explicitly so the plot does not depend on leftover loop state.
    """
    if origin is None:
        origin = np_point
    for curvature, direction in zip(curvatures.detach().cpu().numpy(), directions.T.detach().cpu().numpy()):
        ax.arrow(origin[0], origin[1], direction[0], direction[1], width=0.05)
        ax.text(origin[0] + 1.0 * direction[0], origin[1] + 1.0 * direction[1], f'({direction[0]:.02f}, {direction[1]:.02f}): {curvature:.03f}')

for point_index, np_point in enumerate(points):

    point = torch.tensor(np_point, device=DEVICE)
    value, pt_grad, pt_hess = value_grad_hess(f, point)

    for ax in axs[point_index, :]:
        ax.contourf(*grid_np)
        # Level line through the current point; `colors` is contour's line-colour keyword.
        ax.contour(*grid_np, levels=[value.detach().cpu().numpy()], colors='black')
        ax.scatter([np_point[0]], [np_point[1]])
        ax.set_title(f'point: ({np_point[0]}, {np_point[1]})')

    # isoresponse (left column) vs graph (right column) principal curvatures
    iso_shape_operator, iso_curvatures, iso_directions = curve_utils.local_response_curvature_isoresponse_surface(pt_grad, pt_hess)
    graph_shape_operator, graph_curvatures, graph_directions = curve_utils.local_response_curvature_graph(pt_grad, pt_hess)
    plot_curvature(axs[point_index, 0], iso_curvatures, iso_directions, origin=np_point)
    plot_curvature(axs[point_index, 1], graph_curvatures, graph_directions, origin=np_point)

Now let's try something more asymmetric. Observe that the principal directions for the graph are no longer orthogonal in the parameter space! Also, the graph often doesn't even have a principal direction that aligns with the isosurface.

In [None]:
# Quadratic with off-diagonal coupling, so its principal axes are rotated.
f = QuadraticFunction(hess=[[1.0, 0.5], [0.5, 2.0]]).to(DEVICE)

# Evaluate f on a 50 x 50 grid over [-3, 3]^2 for the background contours.
xs = torch.linspace(-3, 3, 50).to(DEVICE)
ys = torch.linspace(-3, 3, 50).to(DEVICE)
XS, YS = torch.meshgrid(xs, ys)
XS_flat = XS.flatten()
YS_flat = YS.flatten()
data = torch.vstack((XS_flat, YS_flat))
zs = torch.tensor([f(data[:, i]) for i in range(len(data.T))]).reshape(XS.shape)
grid_np = (XS.cpu().numpy(), YS.cpu().numpy(), zs.cpu().numpy())

# One row per point: isoresponse curvature on the left, graph curvature on the right.
fig, axs = pplt.subplots(nrows=len(points), ncols=2)

def plot_curvature(ax, curvatures, directions, origin=None):
    """Draw each principal direction (a column of `directions`) as an arrow from `origin`,
    labelled with the direction and its curvature.

    `origin` falls back to the global `np_point` when omitted; pass it
    explicitly so the plot does not depend on leftover loop state.
    """
    if origin is None:
        origin = np_point
    for curvature, direction in zip(curvatures.detach().cpu().numpy(), directions.T.detach().cpu().numpy()):
        ax.arrow(origin[0], origin[1], direction[0], direction[1], width=0.05)
        ax.text(origin[0] + 1.0 * direction[0], origin[1] + 1.0 * direction[1], f'({direction[0]:.02f}, {direction[1]:.02f}): {curvature:.03f}')

for point_index, np_point in enumerate(points):

    point = torch.tensor(np_point, device=DEVICE)
    value, pt_grad, pt_hess = value_grad_hess(f, point)

    for ax in axs[point_index, :]:
        ax.contourf(*grid_np)
        # Level line through the current point; `colors` is contour's line-colour keyword.
        ax.contour(*grid_np, levels=[value.detach().cpu().numpy()], colors='black')
        ax.scatter([np_point[0]], [np_point[1]])
        ax.set_title(f'point: ({np_point[0]}, {np_point[1]})')

    # isoresponse (left column) vs graph (right column) principal curvatures
    iso_shape_operator, iso_curvatures, iso_directions = curve_utils.local_response_curvature_isoresponse_surface(pt_grad, pt_hess)
    graph_shape_operator, graph_curvatures, graph_directions = curve_utils.local_response_curvature_graph(pt_grad, pt_hess)
    plot_curvature(axs[point_index, 0], iso_curvatures, iso_directions, origin=np_point)
    plot_curvature(axs[point_index, 1], graph_curvatures, graph_directions, origin=np_point)

Let's also try a linear function. Its Hessian is zero everywhere, so its isoresponse curves are straight lines.

In [None]:
# Linear function f(x) = x0 + x1 (zero Hessian).
f = lambda x: torch.sum(x)

# Evaluate f on a 50 x 50 grid over [-3, 3]^2 for the background contours.
xs = torch.linspace(-3, 3, 50).to(DEVICE)
ys = torch.linspace(-3, 3, 50).to(DEVICE)
XS, YS = torch.meshgrid(xs, ys)
XS_flat = XS.flatten()
YS_flat = YS.flatten()
data = torch.vstack((XS_flat, YS_flat))
zs = torch.tensor([f(data[:, i]) for i in range(len(data.T))]).reshape(XS.shape)
grid_np = (XS.cpu().numpy(), YS.cpu().numpy(), zs.cpu().numpy())

# One row per point: isoresponse curvature on the left, graph curvature on the right.
fig, axs = pplt.subplots(nrows=len(points), ncols=2)

def plot_curvature(ax, curvatures, directions, origin=None):
    """Draw each principal direction (a column of `directions`) as an arrow from `origin`,
    labelled with the direction and its curvature.

    `origin` falls back to the global `np_point` when omitted; pass it
    explicitly so the plot does not depend on leftover loop state.
    """
    if origin is None:
        origin = np_point
    for curvature, direction in zip(curvatures.detach().cpu().numpy(), directions.T.detach().cpu().numpy()):
        ax.arrow(origin[0], origin[1], direction[0], direction[1], width=0.05)
        ax.text(origin[0] + 1.0 * direction[0], origin[1] + 1.0 * direction[1], f'({direction[0]:.02f}, {direction[1]:.02f}): {curvature:.03f}')

for point_index, np_point in enumerate(points):

    point = torch.tensor(np_point, device=DEVICE)
    value, pt_grad, pt_hess = value_grad_hess(f, point)

    for ax in axs[point_index, :]:
        ax.contourf(*grid_np)
        # Level line through the current point; `colors` is contour's line-colour keyword.
        ax.contour(*grid_np, levels=[value.detach().cpu().numpy()], colors='black')
        ax.scatter([np_point[0]], [np_point[1]])
        ax.set_title(f'point: ({np_point[0]}, {np_point[1]})')

    # isoresponse (left column) vs graph (right column) principal curvatures
    iso_shape_operator, iso_curvatures, iso_directions = curve_utils.local_response_curvature_isoresponse_surface(pt_grad, pt_hess)
    graph_shape_operator, graph_curvatures, graph_directions = curve_utils.local_response_curvature_graph(pt_grad, pt_hess)
    plot_curvature(axs[point_index, 0], iso_curvatures, iso_directions, origin=np_point)
    plot_curvature(axs[point_index, 1], graph_curvatures, graph_directions, origin=np_point)

In [None]:
# Evaluation points for the bilinear function below.
points = [
    [0.5, 1.0],
    [1.0, 2.0],
    [-1.0, 2.0],
    [0.0, 0.0],
]

# Bilinear f(x) = x0 * x1, written out explicitly because torch.prod returns a
# wrong Hessian at (0, 0).
f = lambda x: x[0]*x[1]

# Evaluate f on a 50 x 50 grid over [-3, 3]^2 for the background contours.
xs = torch.linspace(-3, 3, 50).to(DEVICE)
ys = torch.linspace(-3, 3, 50).to(DEVICE)
XS, YS = torch.meshgrid(xs, ys)
XS_flat = XS.flatten()
YS_flat = YS.flatten()
data = torch.vstack((XS_flat, YS_flat))
zs = torch.tensor([f(data[:, i]) for i in range(len(data.T))]).reshape(XS.shape)
grid_np = (XS.cpu().numpy(), YS.cpu().numpy(), zs.cpu().numpy())

# One row per point: isoresponse curvature on the left, graph curvature on the right.
fig, axs = pplt.subplots(nrows=len(points), ncols=2)

def plot_curvature(ax, curvatures, directions, origin=None):
    """Draw each principal direction (a column of `directions`) as an arrow from `origin`,
    labelled with the direction and its curvature.

    `origin` falls back to the global `np_point` when omitted; pass it
    explicitly so the plot does not depend on leftover loop state.
    """
    if origin is None:
        origin = np_point
    for curvature, direction in zip(curvatures.detach().cpu().numpy(), directions.T.detach().cpu().numpy()):
        ax.arrow(origin[0], origin[1], direction[0], direction[1], width=0.05)
        ax.text(origin[0] + 1.0 * direction[0], origin[1] + 1.0 * direction[1], f'({direction[0]:.02f}, {direction[1]:.02f}): {curvature:.03f}')

for point_index, np_point in enumerate(points):
    point = torch.tensor(np_point, device=DEVICE)
    # Print only after `point` is rebuilt; printing earlier showed the previous point.
    print(point)
    value, pt_grad, pt_hess = value_grad_hess(f, point)
    print(pt_hess)

    for ax in axs[point_index, :]:
        ax.contourf(*grid_np)
        # Level line through the current point; `colors` is contour's line-colour keyword.
        ax.contour(*grid_np, levels=[value.detach().cpu().numpy()], colors='black')
        ax.scatter([np_point[0]], [np_point[1]])
        ax.set_title(f'point: ({np_point[0]}, {np_point[1]})')

    graph_shape_operator, graph_curvatures, graph_directions = curve_utils.local_response_curvature_graph(pt_grad, pt_hess)
    # For f = x0 * x1 the gradient is (x1, x0), so point[0] == 0 exactly when the
    # last gradient component is zero (at the origin the whole gradient is zero).
    # The isoresponse curvature is skipped there.
    # NOTE(review): presumably curve_utils divides by the last gradient component — confirm.
    if point[0] != 0:
        iso_shape_operator, iso_curvatures, iso_directions = curve_utils.local_response_curvature_isoresponse_surface(pt_grad, pt_hess)
        plot_curvature(axs[point_index, 0], iso_curvatures, iso_directions, origin=np_point)
    plot_curvature(axs[point_index, 1], graph_curvatures, graph_directions, origin=np_point)

In [None]:
# Single-point comparison for the isotropic quadratic f(x) = x^T diag(2, 2) x at (0, 1).
f = QuadraticFunction([2.0, 2.0]).to(DEVICE)
point = torch.tensor([0.0, 1.0]).to(DEVICE)
value, pt_grad, pt_hess = value_grad_hess(f, point)

# isoresponse curve vs graph of f
iso_shape_operator, iso_curvatures, iso_directions = curve_utils.local_response_curvature_isoresponse_surface(pt_grad, pt_hess)
graph_shape_operator, graph_curvatures, graph_directions = curve_utils.local_response_curvature_graph(pt_grad, pt_hess)

for _label, _value in (
    ("ISO SHAPE", iso_shape_operator),
    ("ISO curv", iso_curvatures),
    ("ISO dir", iso_directions),
    ("graph SHAPE", graph_shape_operator),
    ("graph curv", graph_curvatures),
    ("graph dir", graph_directions),
):
    print(_label, _value)

## 3D: a function $\mathbb{R}^3 \to \mathbb{R}$ with a 2D isosurface

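First, let's check that we get the correct curvatures for a sphere. For $f(x) = \|x\|^2$ the isosurface through $p$ is a sphere of radius $r = \|p\|$. Both of its principal curvatures are $1/r$, so its Gauss curvature is $1/r^2$.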

In [None]:
# Sanity check on a sphere: f_3d(x) = |x|^2, so the isosurface through `point`
# is a sphere of radius r = |point|. Every principal curvature of a sphere is
# 1/r, so the product of the k returned curvatures should be r**-k (the printed
# ratio should be +-1, depending on the sign convention).
f_3d = QuadraticFunction([1.0, 1.0, 1.0]).to(DEVICE)

point = torch.tensor([0, 0, 2.0]).to(DEVICE)
value, pt_grad, pt_hess = value_grad_hess(f_3d, point)

print("eval point:", point)
print("eval value:", value)
print("eval grad:", pt_grad)
print()

# isoresponse
iso_shape_operator, iso_curvatures, iso_directions = curve_utils.local_response_curvature_isoresponse_surface(pt_grad, pt_hess)
graph_shape_operator, graph_curvatures, graph_directions = curve_utils.local_response_curvature_graph(pt_grad, pt_hess)

print("ISO SHAPE", iso_shape_operator)
print("ISO curv", iso_curvatures)
print("ISO dir", iso_directions)
print()

# Use the actual radius |point|: point[-1] is only the radius for points on the
# last coordinate axis.
radius = point.norm()
num_curvatures = iso_curvatures.numel()
print("ISO Gauss", torch.prod(iso_curvatures))
print("Gauss target", 1/radius**num_curvatures)

print("ratio", torch.prod(iso_curvatures)*radius**num_curvatures)

print()

print("graph SHAPE", graph_shape_operator)
print("graph curv", graph_curvatures)
print("graph dir", graph_directions)

Now with a projection onto a subspace of interest.

In [None]:
# Sphere check restricted to a subspace of interest. f_3d(x) = |x|^2, so the
# isosurface through `point` is a sphere of radius r = |point|, and every
# curvature of a sphere along a tangent direction is 1/r. The product of the
# k returned curvatures is therefore compared with r**-k.
# NOTE(review): assumes each returned curvature is a normal curvature of the
# isosurface — confirm against curve_utils.
f_3d = QuadraticFunction([1.0, 1.0, 1.0]).to(DEVICE)
projection_subspace_of_interest = torch.tensor([
    #[2, 0, 2.0],
    [2, 1, 0.0],
]).to(DEVICE)

point = torch.tensor([0, 2, 2.0]).to(DEVICE)
value, pt_grad, pt_hess = value_grad_hess(f_3d, point)

print("eval point:", point)
print("eval value:", value)
print("eval grad:", pt_grad)
print()

# isoresponse
iso_shape_operator, iso_curvatures, iso_directions = curve_utils.local_response_curvature_isoresponse_surface(
    pt_grad, pt_hess,
    projection_subspace_of_interest=projection_subspace_of_interest
)
graph_shape_operator, graph_curvatures, graph_directions = curve_utils.local_response_curvature_graph(pt_grad, pt_hess)

print("ISO SHAPE", iso_shape_operator)
print("ISO curv", iso_curvatures)
print("ISO dir", iso_directions)
print()

# The radius here is |(0, 2, 2)| = sqrt(8), not point[-1] = 2.
radius = point.norm()
num_curvatures = iso_curvatures.numel()
print("ISO Gauss", torch.prod(iso_curvatures))
print("Gauss target", 1/radius**num_curvatures)

print("ratio", torch.prod(iso_curvatures)*radius**num_curvatures)

print()

print("graph SHAPE", graph_shape_operator)
print("graph curv", graph_curvatures)
print("graph dir", graph_directions)

In [None]:
# Sphere check with a two-row subspace of interest. f_3d(x) = |x|^2, so the
# isosurface through `point` is a sphere of radius r = |point|, and every
# curvature of a sphere along a tangent direction is 1/r. The product of the
# k returned curvatures is therefore compared with r**-k.
# NOTE(review): assumes each returned curvature is a normal curvature of the
# isosurface — confirm against curve_utils.
f_3d = QuadraticFunction([1.0, 1.0, 1.0]).to(DEVICE)
projection_subspace_of_interest = torch.tensor([
    [2, 2, 1.0],
    [2, 0, 0.0],
]).to(DEVICE)

point = torch.tensor([0, 2, 2.0]).to(DEVICE)
value, pt_grad, pt_hess = value_grad_hess(f_3d, point)

print("eval point:", point)
print("eval value:", value)
print("eval grad:", pt_grad)
print()

# isoresponse
iso_shape_operator, iso_curvatures, iso_directions = curve_utils.local_response_curvature_isoresponse_surface(
    pt_grad, pt_hess,
    projection_subspace_of_interest=projection_subspace_of_interest
)
graph_shape_operator, graph_curvatures, graph_directions = curve_utils.local_response_curvature_graph(pt_grad, pt_hess)

print("ISO SHAPE", iso_shape_operator)
print("ISO curv", iso_curvatures)
print("ISO dir", iso_directions)
print()

# The radius here is |(0, 2, 2)| = sqrt(8), not point[-1] = 2.
radius = point.norm()
num_curvatures = iso_curvatures.numel()
print("ISO Gauss", torch.prod(iso_curvatures))
print("Gauss target", 1/radius**num_curvatures)

print("ratio", torch.prod(iso_curvatures)*radius**num_curvatures)

print()

print("graph SHAPE", graph_shape_operator)
print("graph curv", graph_curvatures)
print("graph dir", graph_directions)

And something more asymmetric.

In [None]:
# Anisotropic, degenerate quadratic: the second coordinate has zero weight.
f_3d = QuadraticFunction([1.0, 0.0, 3.0]).to(DEVICE)

point = torch.tensor([-3.0, 2.0, -1.0]).to(DEVICE)
value, pt_grad, pt_hess = value_grad_hess(f_3d, point)

for _label, _value in (("eval point:", point), ("eval value:", value), ("eval grad:", pt_grad)):
    print(_label, _value)
print()

# isoresponse surface vs graph of f
iso_shape_operator, iso_curvatures, iso_directions = curve_utils.local_response_curvature_isoresponse_surface(pt_grad, pt_hess)
graph_shape_operator, graph_curvatures, graph_directions = curve_utils.local_response_curvature_graph(pt_grad, pt_hess)

for _label, _value in (("ISO SHAPE", iso_shape_operator), ("ISO curv", iso_curvatures), ("ISO dir", iso_directions)):
    print(_label, _value)
print()


for _label, _value in (("graph SHAPE", graph_shape_operator), ("graph curv", graph_curvatures), ("graph dir", graph_directions)):
    print(_label, _value)

In [None]:
# Interactive (ipympl) figures, so the 3D plots can be rotated.
%matplotlib widget

In [None]:
from mpl_toolkits.mplot3d import axes3d
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.patches import FancyArrowPatch
from mpl_toolkits.mplot3d.proj3d import proj_transform
from mpl_toolkits.mplot3d.axes3d import Axes3D

# https://gist.github.com/WetHat/1d6cd0f7309535311a539b42cccca89c
class Arrow3D(FancyArrowPatch):
    """2D FancyArrowPatch that follows the 3D arrow from (x, y, z) to (x+dx, y+dy, z+dz).

    The 3D endpoints are projected with the axes' projection matrix whenever
    the arrow is drawn or depth-sorted.
    """

    def __init__(self, x, y, z, dx, dy, dz, *args, **kwargs):
        super().__init__((0, 0), (0, 0), *args, **kwargs)
        self._xyz = (x, y, z)
        self._dxdydz = (dx, dy, dz)

    def _project(self):
        """Project both endpoints with self.axes.M, update the 2D positions, and return the projected zs."""
        x1, y1, z1 = self._xyz
        dx, dy, dz = self._dxdydz
        x2, y2, z2 = (x1 + dx, y1 + dy, z1 + dz)

        xs, ys, zs = proj_transform((x1, x2), (y1, y2), (z1, z2), self.axes.M)
        self.set_positions((xs[0], ys[0]), (xs[1], ys[1]))
        return zs

    def draw(self, renderer):
        self._project()
        super().draw(renderer)

    def do_3d_projection(self, renderer=None):
        # Matplotlib >= 3.5 calls this on every patch of an Axes3D to compute the
        # draw order; without it drawing raises AttributeError. The returned
        # value is used as the patch depth.
        return np.min(self._project())
        
        
def _arrow3D(ax, x, y, z, dx, dy, dz, *args, **kwargs):
    '''Create an Arrow3D starting at (x, y, z) with offset (dx, dy, dz) and add it to `ax`.'''
    ax.add_artist(Arrow3D(x, y, z, dx, dy, dz, *args, **kwargs))


# Register the helper as a method so any Axes3D supports `ax.arrow3D(...)`.
Axes3D.arrow3D = _arrow3D


def plot_implicit(fn, ax, value=0, bbox=(-2.0, 2.0)):
    '''Plot the implicit surface fn(x, y, z) == value as stacks of contour lines.

    fn    -- callable fn(x, y, z) on numpy arrays; in each call one coordinate
             is a scalar slice position and the other two are 2-D grids
    ax    -- 3D axes to draw into
    value -- level of fn to draw
    bbox  -- (min, max) pair used as the limits of all three axes
    '''
    xmin, xmax, ymin, ymax, zmin, zmax = bbox * 3
    grid_1d = np.linspace(xmin, xmax, 100)         # contour resolution
    slice_positions = np.linspace(xmin, xmax, 15)  # slices per plane orientation
    plane_u, plane_v = np.meshgrid(grid_1d, grid_1d)

    # In the plane "axis == c", the contour of (fn - value + c) at level c is
    # exactly the curve fn == value within that slice.
    for c in slice_positions:  # slices parallel to the XY plane
        residual = fn(plane_u, plane_v, c) - value
        ax.contour(plane_u, plane_v, residual + c, [c], zdir='z', zorder=0)

    for c in slice_positions:  # slices parallel to the XZ plane
        residual = fn(plane_u, c, plane_v) - value
        ax.contour(plane_u, residual + c, plane_v, [c], zdir='y', zorder=0)

    for c in slice_positions:  # slices parallel to the YZ plane
        residual = fn(c, plane_u, plane_v) - value
        ax.contour(residual + c, plane_u, plane_v, [c], zdir='x', zorder=0)

    # The contoured surfaces extend far past the drawn level; pin the limits so
    # matplotlib does not grow the axes to contain them.
    ax.set_zlim3d(zmin, zmax)
    ax.set_xlim3d(xmin, xmax)
    ax.set_ylim3d(ymin, ymax)

def plot_manifold(f, point, ax):
    """Draw the isosurface of `f` through `point` on the 3D axes `ax`, with its principal directions.

    The level set {x : f(x) == f(point)} is drawn with plot_implicit, `point`
    is marked in red, and the isoresponse principal directions from
    curve_utils are drawn as 3D arrows starting at `point`.

    Args:
        f: callable mapping a length-3 1-D torch tensor to a scalar tensor.
        point: length-3 1-D torch tensor on the isosurface to plot.
        ax: 3D axes providing an `arrow3D` method.
    """
    value = f(point)

    def fn(x, y, z):
        # plot_implicit passes two grids and one scalar slice position;
        # broadcast the scalar to the grid shape.
        if x.ndim == 0:
            x = np.ones_like(y) * x
        if y.ndim == 0:
            y = np.ones_like(x) * y
        if z.ndim == 0:
            z = np.ones_like(x) * z

        # f is evaluated one point at a time, so it need not support batched input.
        out = [f(torch.tensor([_x, _y, _z], dtype=torch.float).to(DEVICE)).detach().cpu().numpy() for _x, _y, _z in zip(x.flatten(), y.flatten(), z.flatten())]
        return np.reshape(out, x.shape)

    np_point = point.detach().cpu().numpy()
    # Draw the level of *this* f at *this* point (the global `value_orig` was
    # used before, which is wrong for any other f/point).
    plot_implicit(fn, ax, value=value.detach().cpu().numpy())
    ax.scatter([np_point[0]], [np_point[1]], [np_point[2]], s=50, color='red', zorder=1000)

    value, pt_grad, pt_hess = value_grad_hess(f, point)
    print(pt_grad)
    iso_shape_operator, iso_curvatures, iso_directions = curve_utils.local_response_curvature_isoresponse_surface(pt_grad, pt_hess)

    def plot_curvature(ax, curvatures, directions):
        # One arrow per principal direction (a column of `directions`), starting at the point.
        for curvature, direction in zip(curvatures.detach().cpu().numpy(), directions.T.detach().cpu().numpy()):
            ax.arrow3D(np_point[0], np_point[1], np_point[2], direction[0], direction[1], direction[2],
                       mutation_scale=10,
                       zorder=10,
                      )

    plot_curvature(ax, iso_curvatures, iso_directions)

In [None]:
# Same curvature computation, repeated after a manual (random orthogonal) coordinate transform

In [None]:
# Switch back to static inline figures.
%matplotlib inline

In [None]:
# Invariance check: the principal curvatures of an isosurface should not depend
# on the coordinate system. Compute them for f_3d at `point`, then again for the
# same function expressed in randomly rotated coordinates.
f_3d = QuadraticFunction([3.0, 3.0, 1.0]).to(DEVICE)
point = torch.tensor([1.0, 1.0, 0.00001]).to(DEVICE)

# --- original coordinates ---
value_orig, pt_grad_orig, pt_hess_orig = value_grad_hess(f_3d, point)
print(value_orig)
(
    iso_shape_operator_orig,
    iso_curvatures_orig,
    iso_directions_orig,
) = curve_utils.local_response_curvature_isoresponse_surface(pt_grad_orig, pt_hess_orig)

# --- rotated coordinates ---
# M is a random orthogonal matrix taking old coordinates to new ones. Use e.g.
# rst = np.random.RandomState(seed=58) for a reproducible rotation (seeds 46,
# 56 and 58 were noted earlier).
rst = None
M = ortho_group.rvs(len(point), random_state=rst)
coordinate_transformation = torch.tensor(M).to(DEVICE)


def new_f(x):
    """f_3d in the new coordinates: map x back with M^T (= M^-1), then evaluate."""
    return f_3d(torch.matmul(coordinate_transformation.T.type(x.dtype), x))


new_point = torch.matmul(coordinate_transformation.type(point.dtype), point).detach().clone()

value_new, pt_grad_new, pt_hess_new = value_grad_hess(new_f, new_point)
print(value_new)
(
    iso_shape_operator_new,
    iso_curvatures_new,
    iso_directions_new,
) = curve_utils.local_response_curvature_isoresponse_surface(pt_grad_new, pt_hess_new)

# The "orig" and "new" values are expected to agree (up to the ordering of the curvatures).
for _label, _value in (
    ("curvature orig", iso_curvatures_orig),
    ("curvature new", iso_curvatures_new),
    ("gauss curvature orig", torch.prod(iso_curvatures_orig)),
    ("gauss curvature new", torch.prod(iso_curvatures_new)),
):
    print(_label, _value)

# Plotting (disabled):
# fig = plt.figure()
# ax1 = fig.add_subplot(121, projection='3d')
# ax2 = fig.add_subplot(122, projection='3d')
# plot_manifold(f_3d, point, ax1)
# plot_manifold(new_f, new_point, ax2)
# ax1.view_init(elev=90., azim=270)
# ax2.view_init(elev=90., azim=0)

In [None]:
from scipy.linalg import null_space

In [None]:
# Orthonormal basis of the null space of the gradient (as a 1 x n row), i.e. of
# the tangent space of the isosurface at `point` in the original coordinates.
null_space(pt_grad_orig.cpu().numpy()[np.newaxis, :])

In [None]:
# Tangent-space basis at `new_point` in the rotated coordinates, for comparison
# with the original-coordinate basis. (A bare `null_space()` call raises
# TypeError: the matrix argument is required.)
null_space(pt_grad_new.cpu().numpy()[np.newaxis, :])

In [None]:
# Gradient at `point` in the original coordinates.
pt_grad_orig

In [None]:
# Gradient at `new_point` in the rotated coordinates.
pt_grad_new

In [None]:
# Gradient / Hessian blocks of f at one point, hard-coded (presumably copied
# from an earlier run). The coordinates are split into a (first two) and
# b (last one). On the isosurface b is an implicit function b = g(a), with
# gradient g_a = -f_a / f_b (implicit function theorem).
pt_grad_a = torch.tensor([[ 5.2957],
        [-2.4524]], dtype=torch.float64)
pt_grad_b = torch.tensor([[1.3933]], dtype=torch.float64)
pt_hess_aa = torch.tensor([[ 5.5909, -1.1310],
        [-1.1310,  2.8733]], dtype=torch.float64)
pt_hess_ab = torch.tensor([[-0.4358],
        [-1.2047]], dtype=torch.float64)
pt_hess_bb = torch.tensor([[5.5358]], dtype=torch.float64)

grad_g = -pt_grad_a / pt_grad_b

# Extended-precision copies, to rule out rounding error. np.longdouble replaces
# np.float128, which only exists on platforms with a wide long double (it is
# missing e.g. on Windows). Where long double is 64-bit this adds no precision.
np_pt_grad_a = pt_grad_a.cpu().numpy().astype(np.longdouble)
np_pt_grad_b = pt_grad_b.cpu().numpy().astype(np.longdouble)
np_pt_hess_aa = pt_hess_aa.cpu().numpy().astype(np.longdouble)
np_pt_hess_ab = pt_hess_ab.cpu().numpy().astype(np.longdouble)
np_pt_hess_bb = pt_hess_bb.cpu().numpy().astype(np.longdouble)

np_grad_g = -np_pt_grad_a / np_pt_grad_b

In [None]:
# Implicit-function gradient g_a in float64 (torch) and in extended precision (numpy), side by side.
grad_g, np_grad_g

In [None]:
# Hessian of the implicit function b = g(a):
#   H_g = -(1 / f_b) * (f_aa + f_ab g_a^T + g_a f_ab^T + f_bb g_a g_a^T)
# f_ab and g_a are (n, 1) columns, so column * row broadcasting gives the outer
# products. An elementwise product such as diag(f_ab) * (g_a + g_a^T) would keep
# only the diagonal of the cross terms.
(-1 / pt_grad_b) * (
    pt_hess_ab * grad_g.T + grad_g * pt_hess_ab.T
    +
    pt_hess_bb * grad_g * grad_g.T
    +
    pt_hess_aa
)

In [None]:
# Same implicit Hessian in extended precision:
#   H_g = -(1 / f_b) * (f_aa + f_ab g_a^T + g_a f_ab^T + f_bb g_a g_a^T)
# Column * row broadcasting gives the outer products of the (n, 1) columns.
(-1 / np_pt_grad_b) * (
    np_pt_hess_ab * np_grad_g.T + np_grad_g * np_pt_hess_ab.T
    +
    np_pt_hess_bb * np_grad_g * np_grad_g.T
    +
    np_pt_hess_aa
)

In [None]:
# Elementwise product of diag(f_ab) with (g_a + g_a^T): only the diagonal survives,
# so this is NOT the cross term f_ab g_a^T + g_a f_ab^T of the implicit Hessian.
np.diag(np_pt_hess_ab.flatten()) * (np_grad_g + np_grad_g.T)

In [None]:
# Cross-check of the implicit Hessian with explicit matrix products for the
# cross terms f_ab g_a^T + g_a f_ab^T (each an (n, 1) @ (1, n) outer product).
# An elementwise f_ab^T * (g_a + g_a^T) does not give these terms and is not symmetric.
(-1 / np_pt_grad_b) * (
    np_pt_hess_ab @ np_grad_g.T + np_grad_g @ np_pt_hess_ab.T
    +
    np_pt_hess_bb * np_grad_g * np_grad_g.T
    +
    np_pt_hess_aa
)

In [None]:
# Same implicit Hessian, arranged as
#   -(1 / f_b) * [f_aa + f_ab g_a^T - (1 / f_b) f_a (f_ab^T + f_bb g_a^T)]
# f_ab g_a^T must be an outer product (column * row). The elementwise product
# of the two rows f_ab^T * g_a^T makes the result non-symmetric.
(
    (-1 / pt_grad_b) * (
        - 1 / pt_grad_b * pt_grad_a * (
            pt_hess_ab.T +
            pt_hess_bb * grad_g.T
        ) +
        pt_hess_ab * grad_g.T
        + pt_hess_aa
    )
)

In [None]:
# Implicit Hessian as the derivative of g_a = -f_a / f_b, with b = g(a):
#   H_g = -(1 / f_b) (f_aa + f_ab g_a^T) + (1 / f_b^2) f_a (f_ab^T + f_bb g_a^T)
# Both products are outer products (column * row), so the result is symmetric.
(
    (-1 / pt_grad_b) * (
        pt_hess_aa +
        pt_hess_ab * grad_g.T
    ) +
    (1 / pt_grad_b ** 2) * pt_grad_a * (
        pt_hess_ab.T +
        pt_hess_bb * grad_g.T
    )
)

In [None]:
# Extended-precision version of the chain-rule form of the implicit Hessian:
#   H_g = -(1 / f_b) (f_aa + f_ab g_a^T) + (1 / f_b^2) f_a (f_ab^T + f_bb g_a^T)
# f_ab g_a^T is an outer product (column * row), so h is symmetric.
h = (
    (-1 / np_pt_grad_b) * (
        np_pt_hess_aa +
        np_pt_hess_ab * np_grad_g.T
    ) +
    (1 / np_pt_grad_b ** 2) * np_pt_grad_a * (
        np_pt_hess_ab.T +
        np_pt_hess_bb * np_grad_g.T
    )
)
h

In [None]:
# g_a = -f_a / f_b as an (n, 1) column.
grad_g

In [None]:
# f_ab as a (1, n) row.
pt_hess_ab.T

In [None]:
# f_bb * g_a^T, a (1, n) row used in the second term of the implicit Hessian.
pt_hess_bb * grad_g.T

In [None]:
# Second term of the implicit Hessian, (1 / f_b^2) * f_a (f_ab^T + f_bb g_a^T): a column times a row.
(1 / pt_grad_b ** 2) * pt_grad_a * (pt_hess_ab.T + pt_hess_bb * grad_g.T)

In [None]:
# f_b ** 2, the denominator of the second term.
torch.tensor([[1.3933]], dtype=torch.float64)**2

In [None]:
# Output of the "full formula" for the implicit Hessian, recorded from an earlier
# run on CUDA. It was previously pasted as raw REPL output, which does not
# execute; it is kept here as a CPU tensor for comparison.
hess_g_reference = torch.tensor([[-63.7809,  25.6250],
                                 [ 26.7512, -11.3266]], dtype=torch.float64)
# A Hessian must be symmetric. The nonzero off-diagonal entries of this
# difference show the recorded result is not. That is consistent with an
# elementwise product where the outer product f_ab g_a^T is needed.
hess_g_reference - hess_g_reference.T

In [None]:
# Isosurfaces of f_3d at `point` and of the rotated new_f at `new_point`
# (from the coordinate-transform cell), side by side and viewed from above.
# ax1/ax2 are created here because the cell that built them is commented out.
# plot_manifold evaluates the function point by point on dense grids, so this can take a while.
fig = plt.figure()
ax1 = fig.add_subplot(121, projection='3d')
ax2 = fig.add_subplot(122, projection='3d')
plot_manifold(f_3d, point, ax1)
plot_manifold(new_f, new_point, ax2)

ax1.view_init(elev=90., azim=270)
ax2.view_init(elev=90., azim=0)
plt.show()

In [None]:
def f(x):
    """Product of all entries of x via torch.prod (used to probe torch.prod's Hessian at the origin)."""
    return x.prod()

In [None]:
# Value, gradient and Hessian of the torch.prod-based f at the origin. The
# expected Hessian is [[0, 1], [1, 0]]; compare with the explicit x[0]*x[1] in the next cell.
value_grad_hess(f, torch.tensor([0.0, 0.0]).to(DEVICE))

In [None]:
# Reference: Hessian of the explicit product x[0]*x[1] at the origin, expected [[0, 1], [1, 0]].
torch.autograd.functional.hessian(lambda x: x[0]*x[1], torch.tensor([0.0, 0.0]).to(DEVICE))

In [None]:
# Isoresponse curvature of an anisotropic quadratic, computed through a random
# orthogonal `coordinate_transformation` passed to curve_utils.
f_3d = QuadraticFunction([1.0, 3.0, 3.0]).to(DEVICE)


point = torch.tensor([0.0, 0.0, 1]).to(DEVICE)
value, pt_grad, pt_hess = value_grad_hess(f_3d, point)

# Random rotation. Alternatives tried before: the identity (np.eye), a rotation
# acting only on the first len(point) - 1 coordinates, or no transformation (None).
M = ortho_group.rvs(len(point))
coordinate_transformation = torch.tensor(M, dtype=torch.double).to(DEVICE)

for _label, _value in (("eval point:", point), ("eval value:", value), ("eval grad:", pt_grad)):
    print(_label, _value)
print()

# isoresponse
iso_shape_operator, iso_curvatures, iso_directions = curve_utils.local_response_curvature_isoresponse_surface(
    pt_grad, pt_hess,
    coordinate_transformation=coordinate_transformation,
)

print("ISO curv", iso_curvatures)
print(torch.prod(iso_curvatures))
print()

And again with a projection onto a subspace of interest.

In [None]:
# Anisotropic quadratic, with the isoresponse curvature restricted to
# `projection_subspace_of_interest` (presumably the span of its rows — confirm in curve_utils).
f_3d = QuadraticFunction([1.0, 2.0, 3.0]).to(DEVICE)

projection_subspace_of_interest = torch.tensor([
    [2, 2, 1.0],
]).to(DEVICE)


point = torch.tensor([-3.0, 2.0, -1.0]).to(DEVICE)
value, pt_grad, pt_hess = value_grad_hess(f_3d, point)

for _label, _value in (("eval point:", point), ("eval value:", value), ("eval grad:", pt_grad)):
    print(_label, _value)
print()

# isoresponse surface (projected) vs graph of f
iso_shape_operator, iso_curvatures, iso_directions = curve_utils.local_response_curvature_isoresponse_surface(
    pt_grad, pt_hess,
    projection_subspace_of_interest=projection_subspace_of_interest,
)
graph_shape_operator, graph_curvatures, graph_directions = curve_utils.local_response_curvature_graph(pt_grad, pt_hess)

for _label, _value in (("ISO SHAPE", iso_shape_operator), ("ISO curv", iso_curvatures), ("ISO dir", iso_directions)):
    print(_label, _value)
print()


for _label, _value in (("graph SHAPE", graph_shape_operator), ("graph curv", graph_curvatures), ("graph dir", graph_directions)):
    print(_label, _value)

In [None]:
# Sample a random 10 x 10 orthogonal matrix (inspection only).
ortho_group.rvs(10)

In [None]:
# 1000-dimensional quadratic with one zero eigenvalue (diag[2] = 0), expressed
# in a randomly rotated basis; the curvatures of smallest magnitude are printed.
diag = np.random.randn(1000)
diag[2] = 0
hess = np.diag(diag)
base_change = ortho_group.rvs(1000)
hess = base_change.T @ hess @ base_change
# Pass the rotated matrix: QuadraticFunction(diag) ignored `hess`, so the
# (expensive) random rotation above had no effect.
f_3d = QuadraticFunction(hess=hess).to(DEVICE)

point = torch.tensor(np.ones(1000)).to(DEVICE)
value, pt_grad, pt_hess = value_grad_hess(f_3d, point)

print()

# isoresponse
iso_shape_operator, iso_curvatures, iso_directions = curve_utils.local_response_curvature_isoresponse_surface(pt_grad, pt_hess)
graph_shape_operator, graph_curvatures, graph_directions = curve_utils.local_response_curvature_graph(pt_grad, pt_hess)

# The ten isoresponse curvatures with the smallest absolute value.
print("small curvatures", iso_curvatures[torch.sort(torch.abs(iso_curvatures)).indices[:10]])
print()

In [None]:
# High-dimensional sphere check: f(x) = |x|^2 in 1000 dimensions, so the
# isosurface through `point` is a 999-dimensional sphere of radius
# |point| = point[-1] = 1.5. All 999 principal curvatures should be 1/1.5, so
# the Gauss target is 1.5 ** -999 and the ratio should be +-1 (sign convention).
f_3d = QuadraticFunction(np.ones(1000)).to(DEVICE)
projection_subspace_of_interest = None

point = torch.tensor([0.0] * 999 + [1.5], dtype=torch.double).to(DEVICE)
value, pt_grad, pt_hess = value_grad_hess(f_3d, point)

# isoresponse
iso_shape_operator, iso_curvatures, iso_directions = curve_utils.local_response_curvature_isoresponse_surface(
    pt_grad, pt_hess,
    projection_subspace_of_interest=projection_subspace_of_interest,
)
graph_shape_operator, graph_curvatures, graph_directions = curve_utils.local_response_curvature_graph(pt_grad, pt_hess)

surface_dim = len(point) - 1  # 999
print("ISO Gauss", torch.prod(iso_curvatures))
print("Gauss target", 1/point[-1]**surface_dim)
print("ratio", torch.prod(iso_curvatures)*point[-1]**surface_dim)