Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adding discrete curve test class to Geomstats #1814

Open
wants to merge 5 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions geomstats/geometry/discrete_curves.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,6 @@

import math

from scipy.interpolate import CubicSpline

import geomstats.backend as gs
from geomstats.algebra_utils import from_vector_to_diagonal_matrix
from geomstats.geometry.base import LevelSet
Expand All @@ -17,6 +15,7 @@
from geomstats.geometry.quotient_metric import QuotientMetric
from geomstats.geometry.riemannian_metric import RiemannianMetric
from geomstats.geometry.symmetric_matrices import SymmetricMatrices
from scipy.interpolate import CubicSpline

R2 = Euclidean(dim=2)
R3 = Euclidean(dim=3)
Expand Down
342 changes: 342 additions & 0 deletions geomstats/visualization/discrete_curves.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,342 @@
#!/usr/bin/env python
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Format the information at the top of this file: put it in a docstring format, as done in other .py files in geomstats.


"""Discrete_curve.py file details."""

__authors__ = "Jax Burd & Abhijith Atreya"
__course__ = "UCSB ECE 594N"
__professor__ = "Nina Miolane"
__deadline__ = "Thursday, 04/21/2022"

# -------------------------------------------------------------------------------------------------------------------

# IMPORTS
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Remove unnecessary comment

import matplotlib.pyplot as plt
import numpy as np
from geomstats.geometry.discrete_curves import DiscreteCurves
from geomstats.geometry.euclidean import Euclidean


# CLASS
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Remove unnecessary comment

class DiscreteCurveViz:
"""Space of discrete curves sampled at points in ambient_manifold.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Say """Visualization for the space of..."""


Each curve is represented by a 2d-array of shape `[n_sampling_points, ambient_dim]`.
A Batch of curves can be passed to
all methods either as a 3d-array if all curves have the same number of
sampled points, or as a list of 2d-arrays, each representing a curve.

Parameters
----------
curve_dimension : Manifold
Manifold in which curves take values.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Remove blank line between parameter description, here and everywhere in the docstrings

param_curves_list: list
List of lambda functions representing each parameterized curve.

n_sampling_points : int
Number of sampling points to applied to discretize the curves

Attributes
----------
dim : Manifold
Manifold in which curves take values.

param_curves : list
List of lambda functions representing each parameterized curve.

sampling_points : list
List of sampling points to pass through the parameterized curve functions.

curve_points: list
List of resulting points when sampling points are applied to their function.

"""

def __init__(self, curve_dimension, param_curves_list, sampling_points):
self.dim = curve_dimension
self.param_curves = param_curves_list
self.sampling_points = sampling_points
self.n = len(sampling_points[0])
self.curve_points = self.set_curves()

def set_curves(self):
"""Pass sampling points to curve functions as an internal helper function."""
curves = []

for i, p_curve in enumerate(self.param_curves):
curves.append(p_curve(self.sampling_points[i]))

return curves

def resample(self, adjusted_sampling_points):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

rename adjusted_sampling_points -> n_sampling_points

"""Resample the curve based on variable adjusted_sampling_points."""
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Docstring is incomplete: add missing Parameters and Returns descriptions

self.sampling_points = adjusted_sampling_points
curves = []

for i, p_curve in enumerate(self.param_curves):
curves.append(p_curve(self.sampling_points[i]))

self.n = len(list(adjusted_sampling_points))

self.curves_points = curves

def plot_3Dcurves(self, linestyles, labels, title):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

no uppercase letter in function names --> rename to plot_curves_3d

"""Create plots of given set of curves in 3D graph space.

Parameters
----------
linestyles : array-like, string elements
Matpotlib linestyles to apply to respective curves.

labels : array-like, string elements
Labels for axes on the plot.

title : string
Title of the plot.
"""
fig = plt.figure(figsize=(10, 10))
ax = fig.add_subplot(111, projection="3d")
for i, curve in enumerate(self.curve_points):
ax.plot(curve[:, 0], curve[:, 1], curve[:, 2], linestyles[i], linewidth=2)
ax.set_xlabel(labels[0])
ax.set_ylabel(labels[1])
ax.set_zlabel(labels[2])
ax.set_title(title)

def plot_geodesic(
self, n_times, inital_index, end_index, linestyles, labels, title
):
"""Create plots of geodesic between two chosen curves.

Parameters
----------
n_times : int
Number of geodesic curves to plot inbetween

inital_index : int
Index of the starting curve.

end_index : int
Index of the end curve.

linestyles : array-like, string elements
Matpotlib linestyles to apply to respective curves.

labels : array-like, string elements
Labels for axes on the plot.

title : string
Title of the plot.
"""
geod_fun = self.dim.srv_metric.geodesic(
initial_point=self.curve_points[inital_index],
end_point=self.curve_points[end_index],
)

times = np.linspace(0.0, 1.0, n_times)
geod = geod_fun(times)

fig = plt.figure(figsize=(10, 10))
ax = fig.add_subplot(111, projection="3d")

plt.figure(figsize=(10, 10))
ax.plot(geod[0, :, 0], geod[0, :, 1], geod[0, :, 2], linestyles[0])

for i in range(1, n_times - 1):
ax.plot(geod[i, :, 0], geod[i, :, 1], geod[i, :, 2], linestyles[1])

ax.plot(geod[-1, :, 0], geod[-1, :, 1], geod[-1, :, 2], linestyles[2])

ax.set_title(title)
ax.set_xlabel(labels[0])
ax.set_ylabel(labels[1])
ax.set_zlabel(labels[2])

def plot_geodesic_net(
self, n_times, inital_index, end_index, linestyles, labels, title, view_init
):
"""Create a plot of geodesics between two chosen curves in a wireframe style.

Parameters
----------
n_times : int
Number of geodesic curves to plot inbetween

inital_index : int
Index of the starting curve.

end_index : int
Index of the end curve.

linestyles : array-like, string elements
Matpotlib linestyles to apply to respective curves.

labels : array-like, string elements
Labels for axes on the plot.

title : string
Title of the plot.
"""
geod_fun = self.dim.srv_metric.geodesic(
initial_point=self.curve_points[inital_index],
end_point=self.curve_points[end_index],
)

times = np.linspace(0.0, 1.0, n_times)
geod = geod_fun(times)

fig = plt.figure(figsize=(10, 10))
ax = fig.add_subplot(111, projection="3d")

plt.figure(figsize=(10, 10))

ax.plot3D(
geod[0, :, 0], geod[0, :, 1], geod[0, :, 2], linestyles[0], linewidth=2
)
for i in range(1, n_times - 1):
ax.plot3D(
geod[i, :, 0], geod[i, :, 1], geod[i, :, 2], linestyles[1], linewidth=1
)
for j in range(self.n):
ax.plot3D(
geod[:, j, 0], geod[:, j, 1], geod[:, j, 2], linestyles[1], linewidth=1
)
ax.plot3D(
geod[-1, :, 0], geod[-1, :, 1], geod[-1, :, 2], linestyles[2], linewidth=2
)

ax.set_title(title)
ax.set_xlabel(labels[0])
ax.set_ylabel(labels[1])
ax.set_zlabel(labels[2])

if view_init:
ax.view_init(view_init[0], view_init[1])
fig

def plot_parallel_transport(
self,
n_times,
sampling_point_index,
inital_index,
end_index,
linestyles,
labels,
title,
view_init,
):
"""Highlights a certain line along a geodesic between curves in red.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Need imperativ form in docstring: Highlights --> Highlight


Parameters
----------
n_times : int
Number of geodesic curves to plot inbetween

sampling_point_index : int
Index of sampling point on both curves to highlight parallel transport

inital_index : int
Index of the starting curve.

end_index : int
Index of the end curve.

linestyles : array-like, string elements
Matpotlib linestyles to apply to respective curves.

labels : array-like, string elements
Labels for axes on the plot.

title : string
Title of the plot.

view_init : array-like
List of elevation arguement and rotation argument of plot3D view angle
"""
geod_fun = self.dim.srv_metric.geodesic(
initial_point=self.curve_points[inital_index],
end_point=self.curve_points[end_index],
)

times = np.linspace(0.0, 1.0, n_times)
geod = geod_fun(times)

fig = plt.figure(figsize=(10, 10))
ax = fig.add_subplot(111, projection="3d")

plt.figure(figsize=(10, 10))

ax.plot3D(
geod[0, :, 0], geod[0, :, 1], geod[0, :, 2], linestyles[0], linewidth=2
)
ax.plot3D(
geod[0, sampling_point_index, 0],
geod[0, sampling_point_index, 1],
geod[0, sampling_point_index, 2],
"or",
linewidth=2,
)

for i in range(1, n_times - 1):
ax.plot3D(
geod[i, :, 0], geod[i, :, 1], geod[i, :, 2], linestyles[1], linewidth=1
)
for j in range(self.n):
if j is sampling_point_index:
ax.plot3D(geod[:, j, 0], geod[:, j, 1], geod[:, j, 2], "r", linewidth=2)

else:
ax.plot3D(
geod[:, j, 0],
geod[:, j, 1],
geod[:, j, 2],
linestyles[1],
linewidth=1,
)

ax.plot3D(
geod[-1, :, 0], geod[-1, :, 1], geod[-1, :, 2], linestyles[2], linewidth=2
)
ax.plot3D(
geod[-1, sampling_point_index, 0],
geod[-1, sampling_point_index, 1],
geod[-1, sampling_point_index, 2],
"or",
linewidth=2,
)

ax.set_title(title)
ax.set_xlabel(labels[0])
ax.set_ylabel(labels[1])
ax.set_zlabel(labels[2])

if view_init:
ax.view_init(view_init[0], view_init[1])
fig


# TESTING
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Rm unnecessary comment.

def main():
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Great example! however, this should go in a notebook or in an separate example file: remove and put in 10_pratical_methods__shape_analysis...

"""Show an example usage of the vizualization implemented for discrete curves."""
R3 = Euclidean(dim=3)
dc = DiscreteCurves(ambient_manifold=R3)
param_curves_list = [
lambda x: np.array([np.cos(x), np.sin(x), x]),
lambda x: np.array([np.sin(x), np.cos(x), x]),
]
sampling_points = [np.linspace(0, 2 * np.pi, 10), np.linspace(0, 2 * np.pi, 5)]
dcv = DiscreteCurveViz(
curve_dimension=dc,
param_curves_list=param_curves_list,
sampling_points=sampling_points,
)
linestyles = ["r-", "b-"]
labels = ["x", "y", "z"]
title = "3D Curve Plot"
dcv.plot_3Dcurves(linestyles, labels, title)

return 0


if __name__ == "__main__":
main()
Loading