Skip to content

Commit

Permalink
add plot_4d
Browse files Browse the repository at this point in the history
  • Loading branch information
cmbant committed Jun 4, 2021
1 parent d32983d commit d2e8b8a
Showing 1 changed file with 35 additions and 5 deletions.
40 changes: 35 additions & 5 deletions getdist/plots.py
Original file line number Diff line number Diff line change
Expand Up @@ -2697,7 +2697,9 @@ def add_colorbar(self, param, orientation='vertical', mappable=None, ax=None,
**color_label_in_axes** - if True, label is not added (insert as text label in plot instead)
:return: The new :class:`~matplotlib:matplotlib.colorbar.Colorbar` instance
"""
cb = self.fig.colorbar(mappable, orientation=orientation, ax=self.get_axes(ax), **colorbar_args)
kwargs = {'orientation': orientation}
kwargs.update(colorbar_args)
cb = self.fig.colorbar(mappable, ax=self.get_axes(ax), **kwargs)
cb.set_alpha(1)
if not ax_args.get('color_label_in_axes'):
self.add_colorbar_label(cb, param)
Expand Down Expand Up @@ -3019,13 +3021,12 @@ def add_4d_scatter(self, root, params, ax, color_bar=False, max_scatter_points:
x, y, z = samples[:3]
colors = fixed_color or samples[3]

opts = dict(marker=kwargs.get('marker', 'o'), cmap=self.settings.colormap_scatter,
s=kwargs.get('s', self.settings.scatter_size))
opts.update(kwargs)
opts = dict({'marker': 'o', 'cmap': self.settings.colormap_scatter,
's': self.settings.scatter_size}, **kwargs)
ax.scatter(x, y, z, c=colors, depthshade=True, **opts)

if color_bar and not fixed_color:
mappable = cm.ScalarMappable(plt.Normalize(colors.min(), colors.max()))
mappable = cm.ScalarMappable(plt.Normalize(colors.min(), colors.max()), cmap=opts['cmap'])
mappable.set_array(colors)
self.last_colorbar = self.add_colorbar(params[3], mappable=mappable,
ax=ax, colorbar_args=colorbar_args)
Expand All @@ -3035,6 +3036,7 @@ def add_4d_scatter(self, root, params, ax, color_bar=False, max_scatter_points:
def plot_4d(self, roots, params, color_bar=True, colorbar_args: Mapping = empty_dict,
ax=None, lims=empty_dict, azim: Optional[float] = 77, elev: Optional[float] = None,
alpha: Union[float, Sequence[float]] = 0.1, marker='o',
max_scatter_points: Optional[int] = None,
shadow_color=None, shadow_alpha=None, fixed_color=None, compare_colors=None,
animate=False, anim_angle_degrees=0.6, anim_fps=15,
mp4_filename: Optional[str] = None, mp4_bitrate=-1, **kwargs):
Expand All @@ -3056,6 +3058,7 @@ def plot_4d(self, roots, params, color_bar=True, colorbar_args: Mapping = empty_
:param elev: elevation for initial view
:param alpha: alpha, or list of alphas for each root, to use for scatter samples
:param marker: marker, or list of markers for each root
:param max_scatter_points: if set, maximum number of points to plots from each root
:param shadow_color: if not None, a color value (or list of color values) to use for plotting axes-projected
samples; or True to plot gray shadows
:param shadow_alpha: if not None, separate alpha or list of alpha for shadows
Expand All @@ -3069,6 +3072,32 @@ def plot_4d(self, roots, params, color_bar=True, colorbar_args: Mapping = empty_
:param mp4_filename: if animating, optional filename to produce mp4 video
:param mp4_bitrate: bitrate
:param kwargs: additional optional arguments for :func:`~matplotlib:matplotlib.axes3d.Axes3D.scatter`
.. plot::
:include-source:
from getdist import plots, gaussian_mixtures
samples1, samples2 = gaussian_mixtures.randomTestMCSamples(ndim=4, nMCSamples=2)
samples1.samples[:, 0] *= 5 # stretch out in one direction
g = plots.get_single_plotter(width_inch=8)
g.plot_4d([samples1, samples2], ['x0', 'x1', 'x2', 'x3'], cmap='viridis',
alpha=[0.3, 0.1], shadow_color=False, compare_colors=['k'])
.. plot::
:include-source:
from getdist import plots, gaussian_mixtures
samples1 = gaussian_mixtures.randomTestMCSamples(ndim=4)
samples1.samples[:, 0] *= 5 # stretch out in one direction
g.plot_4d(samples1, ['x0', 'x1', 'x2', 'x3'], cmap='jet',
alpha=0.4, shadow_alpha=0.05, shadow_color=True,
max_scatter_points=6000,
lims={'x2': (-3, 3), 'x3': (-3, 3)},
colorbar_args={'shrink': 0.6});
"""
roots = makeList(roots)
if not params:
Expand All @@ -3086,6 +3115,7 @@ def plot_4d(self, roots, params, color_bar=True, colorbar_args: Mapping = empty_
if compare_colors is not None
else None)),
lims=lims, alpha=alph, marker=mark,
max_scatter_points=max_scatter_points,
colorbar_args=colorbar_args, **kwargs))

axes = ax.xaxis, ax.yaxis, ax.zaxis
Expand Down

0 comments on commit d2e8b8a

Please sign in to comment.