/
persistence_diagrams.py
140 lines (125 loc) · 4.92 KB
/
persistence_diagrams.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
"""Persistent-homology–related plotting functions and classes."""
# License: GNU AGPLv3
import numpy as np
import plotly.graph_objs as gobj
def plot_diagram(diagram, homology_dimensions=None, plotly_params=None):
    """Plot a single persistence diagram.

    Parameters
    ----------
    diagram : ndarray of shape (n_points, 3)
        The persistence diagram to plot, where the third dimension along axis 1
        contains homology dimensions, and the first two contain (birth, death)
        pairs to be used as coordinates in the two-dimensional plot.

    homology_dimensions : list of int or None, optional, default: ``None``
        Homology dimensions which will appear on the plot. If ``None``, all
        homology dimensions which appear in `diagram` will be plotted.

    plotly_params : dict or None, optional, default: ``None``
        Custom parameters to configure the plotly figure. Allowed keys are
        ``"traces"`` and ``"layout"``, and the corresponding values should be
        dictionaries containing keyword arguments as would be fed to the
        :meth:`update_traces` and :meth:`update_layout` methods of
        :class:`plotly.graph_objects.Figure`.

    Returns
    -------
    fig : :class:`plotly.graph_objects.Figure` object
        Figure representing the persistence diagram. Points lying on the
        diagonal (birth equal to death) are not drawn. Points with positive
        infinite death are drawn on a dashed horizontal line placed above all
        finite values and labelled with the infinity symbol in the legend.

    """
    # TODO: increase the marker size
    if homology_dimensions is None:
        homology_dimensions = np.unique(diagram[:, 2])

    # Drop zero-persistence points. Boolean indexing returns a copy, so the
    # caller's array is never modified by anything below.
    diagram = diagram[diagram[:, 0] != diagram[:, 1]]
    diagram_no_dims = diagram[:, :2]
    posinfinite_mask = np.isposinf(diagram_no_dims)

    # Axis limits are driven by the finite coordinates only. When there are
    # none (e.g. an empty diagram, or only infinite values) fall back to a
    # unit range instead of crashing on an empty reduction or producing an
    # infinite/NaN range.
    finite_vals = diagram_no_dims[np.isfinite(diagram_no_dims)]
    if finite_vals.size:
        min_val, max_val = finite_vals.min(), finite_vals.max()
    else:
        min_val, max_val = 0., 1.
    parameter_range = max_val - min_val
    if parameter_range == 0:
        # A single distinct finite value would otherwise collapse both axes
        # (and the infinity line) onto one value.
        parameter_range = 1.

    extra_space_factor = 0.02
    has_posinfinite_death = np.any(posinfinite_mask[:, 1])
    if has_posinfinite_death:
        # Infinite deaths are drawn slightly above the largest finite value,
        # and the view is enlarged to make room for them.
        posinfinity_val = max_val + 0.1 * parameter_range
        extra_space_factor += 0.1
    extra_space = extra_space_factor * parameter_range
    min_val_display = min_val - extra_space
    max_val_display = max_val + extra_space

    fig = gobj.Figure()
    # Diagonal reference line (birth == death).
    fig.add_trace(gobj.Scatter(
        x=[min_val_display, max_val_display],
        y=[min_val_display, max_val_display],
        mode="lines",
        line={"dash": "dash", "width": 1, "color": "black"},
        showlegend=False,
        hoverinfo="none"
    ))

    for dim in homology_dimensions:
        name = f"H{int(dim)}" if dim != np.inf else "Any homology dimension"
        subdiagram = diagram[diagram[:, 2] == dim]
        unique, inverse, counts = np.unique(
            subdiagram, axis=0, return_inverse=True, return_counts=True
        )
        # ``ravel`` keeps ``inverse`` one-dimensional across NumPy versions,
        # and ``tolist`` yields plain Python floats so that NumPy >= 2 does not
        # render points as "(np.float64(...), np.float64(...))".
        hovertext = [
            f"{tuple(unique[unique_row_index, :2].tolist())}" +
            (
                f", multiplicity: {counts[unique_row_index]}"
                if counts[unique_row_index] > 1 else ""
            )
            for unique_row_index in inverse.ravel()
        ]
        y = subdiagram[:, 1]
        if has_posinfinite_death:
            # Map infinite deaths onto the dashed "infinity" line.
            y = np.where(np.isposinf(y), posinfinity_val, y)
        fig.add_trace(gobj.Scatter(
            x=subdiagram[:, 0], y=y, mode="markers",
            hoverinfo="text", hovertext=hovertext, name=name
        ))

    fig.update_layout(
        width=500,
        height=500,
        xaxis1={
            "title": "Birth",
            "side": "bottom",
            "type": "linear",
            "range": [min_val_display, max_val_display],
            "autorange": False,
            "ticks": "outside",
            "showline": True,
            "zeroline": True,
            "linewidth": 1,
            "linecolor": "black",
            "mirror": False,
            "showexponent": "all",
            "exponentformat": "e"
        },
        yaxis1={
            "title": "Death",
            "side": "left",
            "type": "linear",
            "range": [min_val_display, max_val_display],
            "autorange": False, "scaleanchor": "x", "scaleratio": 1,
            "ticks": "outside",
            "showline": True,
            "zeroline": True,
            "linewidth": 1,
            "linecolor": "black",
            "mirror": False,
            "showexponent": "all",
            "exponentformat": "e"
        },
        plot_bgcolor="white"
    )

    # Add a horizontal dashed line for points with infinite death
    if has_posinfinite_death:
        fig.add_trace(gobj.Scatter(
            x=[min_val_display, max_val_display],
            y=[posinfinity_val, posinfinity_val],
            mode="lines",
            line={"dash": "dash", "width": 0.5, "color": "black"},
            showlegend=True,
            name=u"\u221E",
            hoverinfo="none"
        ))

    # Update traces and layout according to user input
    if plotly_params:
        fig.update_traces(plotly_params.get("traces", None))
        fig.update_layout(plotly_params.get("layout", None))

    return fig