-
Notifications
You must be signed in to change notification settings - Fork 240
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Adding discrete curve test class to Geomstats #1814
base: main
Are you sure you want to change the base?
Changes from all commits
2621d6f
cd896b2
8ec764b
875625f
6f69525
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,342 @@ | ||
#!/usr/bin/env python | ||
|
||
"""Discrete_curve.py file details.""" | ||
|
||
__authors__ = "Jax Burd & Abhijith Atreya" | ||
__course__ = "UCSB ECE 594N" | ||
__professor__ = "Nina Miolane" | ||
__deadline__ = "Thursday, 04/21/2022" | ||
|
||
# ------------------------------------------------------------------------------------------------------------------- | ||
|
||
# IMPORTS | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Remove unnecessary comment |
||
import matplotlib.pyplot as plt | ||
import numpy as np | ||
from geomstats.geometry.discrete_curves import DiscreteCurves | ||
from geomstats.geometry.euclidean import Euclidean | ||
|
||
|
||
# CLASS | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Remove unnecessary comment |
||
class DiscreteCurveViz: | ||
"""Space of discrete curves sampled at points in ambient_manifold. | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Say """Visualization for the space of...""" |
||
|
||
Each curve is represented by a 2d-array of shape `[n_sampling_points, ambient_dim]`. | ||
A Batch of curves can be passed to | ||
all methods either as a 3d-array if all curves have the same number of | ||
sampled points, or as a list of 2d-arrays, each representing a curve. | ||
|
||
Parameters | ||
---------- | ||
curve_dimension : Manifold | ||
Manifold in which curves take values. | ||
|
||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Remove blank line between parameter description, here and everywhere in the docstrings |
||
param_curves_list: list | ||
List of lambda functions representing each parameterized curve. | ||
|
||
n_sampling_points : int | ||
Number of sampling points to applied to discretize the curves | ||
|
||
Attributes | ||
---------- | ||
dim : Manifold | ||
Manifold in which curves take values. | ||
|
||
param_curves : list | ||
List of lambda functions representing each parameterized curve. | ||
|
||
sampling_points : list | ||
List of sampling points to pass through the parameterized curve functions. | ||
|
||
curve_points: list | ||
List of resulting points when sampling points are applied to their function. | ||
|
||
""" | ||
|
||
def __init__(self, curve_dimension, param_curves_list, sampling_points): | ||
self.dim = curve_dimension | ||
self.param_curves = param_curves_list | ||
self.sampling_points = sampling_points | ||
self.n = len(sampling_points[0]) | ||
self.curve_points = self.set_curves() | ||
|
||
def set_curves(self): | ||
"""Pass sampling points to curve functions as an internal helper function.""" | ||
curves = [] | ||
|
||
for i, p_curve in enumerate(self.param_curves): | ||
curves.append(p_curve(self.sampling_points[i])) | ||
|
||
return curves | ||
|
||
def resample(self, adjusted_sampling_points): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. rename adjusted_sampling_points -> n_sampling_points |
||
"""Resample the curve based on variable adjusted_sampling_points.""" | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Docstring is incomplete: add missing Parameters and Returns descriptions |
||
self.sampling_points = adjusted_sampling_points | ||
curves = [] | ||
|
||
for i, p_curve in enumerate(self.param_curves): | ||
curves.append(p_curve(self.sampling_points[i])) | ||
|
||
self.n = len(list(adjusted_sampling_points)) | ||
|
||
self.curves_points = curves | ||
|
||
def plot_3Dcurves(self, linestyles, labels, title): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. no uppercase letter in function names --> rename to plot_curves_3d |
||
"""Create plots of given set of curves in 3D graph space. | ||
|
||
Parameters | ||
---------- | ||
linestyles : array-like, string elements | ||
Matpotlib linestyles to apply to respective curves. | ||
|
||
labels : array-like, string elements | ||
Labels for axes on the plot. | ||
|
||
title : string | ||
Title of the plot. | ||
""" | ||
fig = plt.figure(figsize=(10, 10)) | ||
ax = fig.add_subplot(111, projection="3d") | ||
for i, curve in enumerate(self.curve_points): | ||
ax.plot(curve[:, 0], curve[:, 1], curve[:, 2], linestyles[i], linewidth=2) | ||
ax.set_xlabel(labels[0]) | ||
ax.set_ylabel(labels[1]) | ||
ax.set_zlabel(labels[2]) | ||
ax.set_title(title) | ||
|
||
def plot_geodesic( | ||
self, n_times, inital_index, end_index, linestyles, labels, title | ||
): | ||
"""Create plots of geodesic between two chosen curves. | ||
|
||
Parameters | ||
---------- | ||
n_times : int | ||
Number of geodesic curves to plot inbetween | ||
|
||
inital_index : int | ||
Index of the starting curve. | ||
|
||
end_index : int | ||
Index of the end curve. | ||
|
||
linestyles : array-like, string elements | ||
Matpotlib linestyles to apply to respective curves. | ||
|
||
labels : array-like, string elements | ||
Labels for axes on the plot. | ||
|
||
title : string | ||
Title of the plot. | ||
""" | ||
geod_fun = self.dim.srv_metric.geodesic( | ||
initial_point=self.curve_points[inital_index], | ||
end_point=self.curve_points[end_index], | ||
) | ||
|
||
times = np.linspace(0.0, 1.0, n_times) | ||
geod = geod_fun(times) | ||
|
||
fig = plt.figure(figsize=(10, 10)) | ||
ax = fig.add_subplot(111, projection="3d") | ||
|
||
plt.figure(figsize=(10, 10)) | ||
ax.plot(geod[0, :, 0], geod[0, :, 1], geod[0, :, 2], linestyles[0]) | ||
|
||
for i in range(1, n_times - 1): | ||
ax.plot(geod[i, :, 0], geod[i, :, 1], geod[i, :, 2], linestyles[1]) | ||
|
||
ax.plot(geod[-1, :, 0], geod[-1, :, 1], geod[-1, :, 2], linestyles[2]) | ||
|
||
ax.set_title(title) | ||
ax.set_xlabel(labels[0]) | ||
ax.set_ylabel(labels[1]) | ||
ax.set_zlabel(labels[2]) | ||
|
||
def plot_geodesic_net( | ||
self, n_times, inital_index, end_index, linestyles, labels, title, view_init | ||
): | ||
"""Create a plot of geodesics between two chosen curves in a wireframe style. | ||
|
||
Parameters | ||
---------- | ||
n_times : int | ||
Number of geodesic curves to plot inbetween | ||
|
||
inital_index : int | ||
Index of the starting curve. | ||
|
||
end_index : int | ||
Index of the end curve. | ||
|
||
linestyles : array-like, string elements | ||
Matpotlib linestyles to apply to respective curves. | ||
|
||
labels : array-like, string elements | ||
Labels for axes on the plot. | ||
|
||
title : string | ||
Title of the plot. | ||
""" | ||
geod_fun = self.dim.srv_metric.geodesic( | ||
initial_point=self.curve_points[inital_index], | ||
end_point=self.curve_points[end_index], | ||
) | ||
|
||
times = np.linspace(0.0, 1.0, n_times) | ||
geod = geod_fun(times) | ||
|
||
fig = plt.figure(figsize=(10, 10)) | ||
ax = fig.add_subplot(111, projection="3d") | ||
|
||
plt.figure(figsize=(10, 10)) | ||
|
||
ax.plot3D( | ||
geod[0, :, 0], geod[0, :, 1], geod[0, :, 2], linestyles[0], linewidth=2 | ||
) | ||
for i in range(1, n_times - 1): | ||
ax.plot3D( | ||
geod[i, :, 0], geod[i, :, 1], geod[i, :, 2], linestyles[1], linewidth=1 | ||
) | ||
for j in range(self.n): | ||
ax.plot3D( | ||
geod[:, j, 0], geod[:, j, 1], geod[:, j, 2], linestyles[1], linewidth=1 | ||
) | ||
ax.plot3D( | ||
geod[-1, :, 0], geod[-1, :, 1], geod[-1, :, 2], linestyles[2], linewidth=2 | ||
) | ||
|
||
ax.set_title(title) | ||
ax.set_xlabel(labels[0]) | ||
ax.set_ylabel(labels[1]) | ||
ax.set_zlabel(labels[2]) | ||
|
||
if view_init: | ||
ax.view_init(view_init[0], view_init[1]) | ||
fig | ||
|
||
def plot_parallel_transport( | ||
self, | ||
n_times, | ||
sampling_point_index, | ||
inital_index, | ||
end_index, | ||
linestyles, | ||
labels, | ||
title, | ||
view_init, | ||
): | ||
"""Highlights a certain line along a geodesic between curves in red. | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Need imperativ form in docstring: Highlights --> Highlight |
||
|
||
Parameters | ||
---------- | ||
n_times : int | ||
Number of geodesic curves to plot inbetween | ||
|
||
sampling_point_index : int | ||
Index of sampling point on both curves to highlight parallel transport | ||
|
||
inital_index : int | ||
Index of the starting curve. | ||
|
||
end_index : int | ||
Index of the end curve. | ||
|
||
linestyles : array-like, string elements | ||
Matpotlib linestyles to apply to respective curves. | ||
|
||
labels : array-like, string elements | ||
Labels for axes on the plot. | ||
|
||
title : string | ||
Title of the plot. | ||
|
||
view_init : array-like | ||
List of elevation arguement and rotation argument of plot3D view angle | ||
""" | ||
geod_fun = self.dim.srv_metric.geodesic( | ||
initial_point=self.curve_points[inital_index], | ||
end_point=self.curve_points[end_index], | ||
) | ||
|
||
times = np.linspace(0.0, 1.0, n_times) | ||
geod = geod_fun(times) | ||
|
||
fig = plt.figure(figsize=(10, 10)) | ||
ax = fig.add_subplot(111, projection="3d") | ||
|
||
plt.figure(figsize=(10, 10)) | ||
|
||
ax.plot3D( | ||
geod[0, :, 0], geod[0, :, 1], geod[0, :, 2], linestyles[0], linewidth=2 | ||
) | ||
ax.plot3D( | ||
geod[0, sampling_point_index, 0], | ||
geod[0, sampling_point_index, 1], | ||
geod[0, sampling_point_index, 2], | ||
"or", | ||
linewidth=2, | ||
) | ||
|
||
for i in range(1, n_times - 1): | ||
ax.plot3D( | ||
geod[i, :, 0], geod[i, :, 1], geod[i, :, 2], linestyles[1], linewidth=1 | ||
) | ||
for j in range(self.n): | ||
if j is sampling_point_index: | ||
ax.plot3D(geod[:, j, 0], geod[:, j, 1], geod[:, j, 2], "r", linewidth=2) | ||
|
||
else: | ||
ax.plot3D( | ||
geod[:, j, 0], | ||
geod[:, j, 1], | ||
geod[:, j, 2], | ||
linestyles[1], | ||
linewidth=1, | ||
) | ||
|
||
ax.plot3D( | ||
geod[-1, :, 0], geod[-1, :, 1], geod[-1, :, 2], linestyles[2], linewidth=2 | ||
) | ||
ax.plot3D( | ||
geod[-1, sampling_point_index, 0], | ||
geod[-1, sampling_point_index, 1], | ||
geod[-1, sampling_point_index, 2], | ||
"or", | ||
linewidth=2, | ||
) | ||
|
||
ax.set_title(title) | ||
ax.set_xlabel(labels[0]) | ||
ax.set_ylabel(labels[1]) | ||
ax.set_zlabel(labels[2]) | ||
|
||
if view_init: | ||
ax.view_init(view_init[0], view_init[1]) | ||
fig | ||
|
||
|
||
# TESTING | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Rm unnecessary comment. |
||
def main(): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Great example! however, this should go in a notebook or in an separate example file: remove and put in 10_pratical_methods__shape_analysis... |
||
"""Show an example usage of the vizualization implemented for discrete curves.""" | ||
R3 = Euclidean(dim=3) | ||
dc = DiscreteCurves(ambient_manifold=R3) | ||
param_curves_list = [ | ||
lambda x: np.array([np.cos(x), np.sin(x), x]), | ||
lambda x: np.array([np.sin(x), np.cos(x), x]), | ||
] | ||
sampling_points = [np.linspace(0, 2 * np.pi, 10), np.linspace(0, 2 * np.pi, 5)] | ||
dcv = DiscreteCurveViz( | ||
curve_dimension=dc, | ||
param_curves_list=param_curves_list, | ||
sampling_points=sampling_points, | ||
) | ||
linestyles = ["r-", "b-"] | ||
labels = ["x", "y", "z"] | ||
title = "3D Curve Plot" | ||
dcv.plot_3Dcurves(linestyles, labels, title) | ||
|
||
return 0 | ||
|
||
|
||
if __name__ == "__main__": | ||
main() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Format the information at the top of this file: put it in a docstring format, as done in other .py files in geomstats.