Skip to content

Commit

Permalink
fixes [gaps between clusters](#32) issue
Browse files Browse the repository at this point in the history
  • Loading branch information
jeremymanning committed Nov 30, 2021
1 parent 5f0e5de commit 0123b00
Showing 1 changed file with 10 additions and 7 deletions.
17 changes: 10 additions & 7 deletions hypertools/plot/static.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,12 +224,8 @@ def static_plot(data, **kwargs):
fig.add_trace(get_plotly_shape(dummy_coords, **s, name=n, legendgroup=n))
kwargs['showlegend'] = False

# FIXME: fill this in-- add "null" objects to the legend, and then force showlegend to be False for all other
# shapes. also need to change how trace "legendgroups" are inferred; if legend_override is specified (and not
# None), shape legendgroups should correspond to cluster labels rather than traces

if type(data) is list:
names = kwargs.pop('name', [str(d) for d in range(len(data))]) # FIXME: flagging for updating...
names = kwargs.pop('name', [str(d) for d in range(len(data))])
for i, d in enumerate(data):
opts = {'color': get(color, i), 'fig': fig, 'name': get(names, i), 'legendgroup': get(names, i)}

Expand Down Expand Up @@ -291,13 +287,20 @@ def static_plot(data, **kwargs):
if legend_override is not None:
next_labels = legend_override['labels'].iloc[inds].values
for k in np.unique(next_labels):
group_inds = np.where(next_labels == k)[0]
group_inds = inds[np.where(next_labels == k)[0]]

if max(group_inds) < data.shape[0] - 1:
group_inds = np.append(group_inds, [max(group_inds) + 1])

group_opts = opts.copy()
group_opts['legendgroup'] = legend_override['names'][k]
fig.add_trace(get_plotly_shape(data.values[inds[group_inds], :],
fig.add_trace(get_plotly_shape(data.values[group_inds, :],
**dw.core.update_dict(kwargs, group_opts),
color=mpl2plotly_color(c)))
else:
if max(inds) < data.shape[0] - 1:
inds = np.append(inds, [max(inds) + 1])

fig.add_trace(get_plotly_shape(data.values[inds, :],
**dw.core.update_dict(kwargs, opts),
color=mpl2plotly_color(c)))
Expand Down

0 comments on commit 0123b00

Please sign in to comment.