Skip to content

Commit

Permalink
Expose label kwarg on plot methods. Closes #221.
Browse files: browse the repository at this point in the history
  • Loading branch information
polyatail committed Mar 6, 2019
1 parent fbbd44e commit d5f3dfc
Show file tree
Hide file tree
Showing 5 changed files with 76 additions and 31 deletions.
7 changes: 6 additions & 1 deletion onecodex/viz/_bargraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ def plot_bargraph(
return_chart=False,
haxis=None,
legend="auto",
label=None,
):
"""Plot a bargraph of relative abundance of taxa for multiple samples.
Expand Down Expand Up @@ -49,6 +50,10 @@ def plot_bargraph(
legend : `string`, optional
Title for color scale. Defaults to the field used to generate the plot, e.g.
readcount_w_children or abundance.
label : `string` or `callable`, optional
A metadata field (or function) used to label each analysis. If passing a function, a
dict containing the metadata for each analysis is passed as the first and only
positional argument. The callable function must return a string.
Examples
--------
Expand Down Expand Up @@ -96,7 +101,7 @@ def plot_bargraph(

# takes metadata columns and returns a dataframe with just those columns
# renames columns in the case where columns are taxids
magic_metadata, magic_fields = self._metadata_fetch(tooltip)
magic_metadata, magic_fields = self._metadata_fetch(tooltip, label=label)

# add sort order to long-format df
if haxis:
Expand Down
42 changes: 28 additions & 14 deletions onecodex/viz/_distance.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,14 +48,12 @@ def _cluster_by_sample(self, rank='auto', metric='braycurtis', linkage='average'
clustering = hierarchy.linkage(squareform(dist_matrix), method=linkage)
scipy_tree = hierarchy.dendrogram(clustering, no_plot=True)
ids_in_order = [self._results.index[int(x)] for x in scipy_tree['ivl']]
labels_in_order = [self.metadata['_display_name'][t] for t in ids_in_order]

return {
'dist_matrix': dist_matrix,
'clustering': clustering,
'scipy_tree': scipy_tree,
'ids_in_order': ids_in_order,
'labels_in_order': labels_in_order
'ids_in_order': ids_in_order
}

def _cluster_by_taxa(self, linkage='average'):
Expand All @@ -75,7 +73,7 @@ def _cluster_by_taxa(self, linkage='average'):

def plot_distance(self, rank='auto', metric='braycurtis',
title=None, xlabel=None, ylabel=None, tooltip=None, return_chart=False,
linkage='average'):
linkage='average', label=None):
"""Plot beta diversity distance matrix as a heatmap and dendrogram.
Parameters
Expand All @@ -96,6 +94,10 @@ def plot_distance(self, rank='auto', metric='braycurtis',
A string or list containing strings representing metadata fields. When a point in the
plot is hovered over, the value of the metadata associated with that sample will be
displayed in a modal.
label : `string` or `callable`, optional
A metadata field (or function) used to label each analysis. If passing a function, a
dict containing the metadata for each analysis is passed as the first and only
positional argument. The callable function must return a string.
Examples
--------
Expand All @@ -121,7 +123,9 @@ def plot_distance(self, rank='auto', metric='braycurtis',
else:
tooltip = []

magic_metadata, magic_fields = self._metadata_fetch(tooltip)
tooltip.insert(0, "Label")

magic_metadata, magic_fields = self._metadata_fetch(tooltip, label=label)
formatted_fields = []

for _, magic_field in magic_fields.items():
Expand All @@ -144,8 +148,6 @@ def plot_distance(self, rank='auto', metric='braycurtis',
else:
plot_data['Distance'].append(clust['dist_matrix'].iloc[idx1, idx2])

plot_data['1) Label'].append(self.metadata['_display_name'][id1])
plot_data['2) Label'].append(self.metadata['_display_name'][id2])
plot_data['classification_id'].append(id1)

for field_group, magic_field in zip(formatted_fields, magic_fields.values()):
Expand All @@ -154,13 +156,15 @@ def plot_distance(self, rank='auto', metric='braycurtis',

plot_data = pd.DataFrame(data=plot_data)

labels_in_order = magic_metadata['Label'][clust['ids_in_order']].tolist()

# it's important to tell altair to order the cells in the heatmap according to the clustering
# obtained from scipy
alt_kwargs = dict(
x=alt.X('1) Label:N', axis=alt.Axis(title=xlabel), sort=clust['labels_in_order']),
y=alt.Y('2) Label:N', axis=alt.Axis(title=ylabel, orient='right'), sort=clust['labels_in_order']),
x=alt.X('1) Label:N', axis=alt.Axis(title=xlabel), sort=labels_in_order),
y=alt.Y('2) Label:N', axis=alt.Axis(title=ylabel, orient='right'), sort=labels_in_order),
color='Distance:Q',
tooltip=['1) Label', '2) Label', 'Distance:Q'] + list(chain.from_iterable(formatted_fields)),
tooltip=list(chain.from_iterable(formatted_fields)) + ['Distance:Q'],
href='url:N',
url='https://app.onecodex.com/classification/' + alt.datum.classification_id
)
Expand All @@ -184,7 +188,7 @@ def plot_distance(self, rank='auto', metric='braycurtis',

def plot_mds(self, rank='auto', metric='braycurtis', method='pcoa',
title=None, xlabel=None, ylabel=None, color=None, size=None, tooltip=None,
return_chart=False):
return_chart=False, label=None):
"""Plot beta diversity distance matrix using multidimensional scaling (MDS).
Parameters
Expand Down Expand Up @@ -213,6 +217,10 @@ def plot_mds(self, rank='auto', metric='braycurtis', method='pcoa',
A string or list containing strings representing metadata fields. When a point in the
plot is hovered over, the value of the metadata associated with that sample will be
displayed in a modal.
label : `string` or `callable`, optional
A metadata field (or function) used to label each analysis. If passing a function, a
dict containing the metadata for each analysis is passed as the first and only
positional argument. The callable function must return a string.
Examples
--------
Expand Down Expand Up @@ -242,9 +250,15 @@ def plot_mds(self, rank='auto', metric='braycurtis', method='pcoa',
else:
tooltip = []

tooltip = list(set(['Label', color, size] + tooltip))
tooltip.insert(0, 'Label')

if color and color not in tooltip:
tooltip.insert(1, color)

if size and size not in tooltip:
tooltip.insert(2, size)

magic_metadata, magic_fields = self._metadata_fetch(tooltip)
magic_metadata, magic_fields = self._metadata_fetch(tooltip, label=label)

if method == 'smacof':
# adapted from https://scikit-learn.org/stable/auto_examples/manifold/plot_mds.html
Expand Down Expand Up @@ -303,7 +317,7 @@ def plot_mds(self, rank='auto', metric='braycurtis', method='pcoa',
alt_kwargs = dict(
x=alt.X(x_field, axis=alt.Axis(title=xlabel)),
y=alt.Y(y_field, axis=alt.Axis(title=ylabel)),
tooltip=[magic_fields[t] for t in tooltip if t],
tooltip=[magic_fields[t] for t in tooltip],
href='url:N',
url='https://app.onecodex.com/classification/' + alt.datum.classification_id
)
Expand Down
28 changes: 20 additions & 8 deletions onecodex/viz/_heatmap.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@
class VizHeatmapMixin(object):
def plot_heatmap(self, rank='auto', normalize='auto', top_n='auto', threshold='auto',
title=None, xlabel=None, ylabel=None, tooltip=None, return_chart=False,
linkage='average', haxis=None, metric='euclidean', legend='auto'):
linkage='average', haxis=None, metric='euclidean', legend='auto',
label=None):
"""Plot heatmap of taxa abundance/count data for several samples.
Parameters
Expand Down Expand Up @@ -44,6 +45,10 @@ def plot_heatmap(self, rank='auto', normalize='auto', top_n='auto', threshold='a
legend : `string`, optional
Title for color scale. Defaults to the field used to generate the plot, e.g.
readcount_w_children or abundance.
label : `string` or `callable`, optional
A metadata field (or function) used to label each analysis. If passing a function, a
dict containing the metadata for each analysis is passed as the first and only
positional argument. The callable function must return a string.
Examples
--------
Expand Down Expand Up @@ -88,10 +93,12 @@ def plot_heatmap(self, rank='auto', normalize='auto', top_n='auto', threshold='a
if haxis:
tooltip.append(haxis)

magic_metadata, magic_fields = self._metadata_fetch(tooltip)
tooltip.insert(0, "Label")

magic_metadata, magic_fields = self._metadata_fetch(tooltip, label=label)

# add columns for prettier display
df['display_name'] = self.metadata['_display_name'][df['classification_id']].tolist()
df['Label'] = magic_metadata['Label'][df['classification_id']].tolist()
df['tax_name'] = ['{} ({})'.format(self.taxonomy['name'][t], t) for t in df['tax_id']]

# and for metadata
Expand Down Expand Up @@ -122,7 +129,7 @@ def plot_heatmap(self, rank='auto', normalize='auto', top_n='auto', threshold='a
sample_cluster = df_sample_cluster.ocx._cluster_by_sample(rank=rank, metric=metric, linkage=linkage)
taxa_cluster = df_taxa_cluster.ocx._cluster_by_taxa(linkage=linkage)

labels_in_order = sample_cluster['labels_in_order']
labels_in_order = magic_metadata['Label'][sample_cluster['ids_in_order']].tolist()
else:
if not (pd.api.types.is_bool_dtype(df[magic_fields[haxis]]) or # noqa
pd.api.types.is_categorical_dtype(df[magic_fields[haxis]]) or # noqa
Expand Down Expand Up @@ -156,7 +163,7 @@ def plot_heatmap(self, rank='auto', normalize='auto', top_n='auto', threshold='a
if len(c_ids_in_group) < 3:
# clustering not possible in this case
cluster_by_group[group] = {
'labels_in_order': self.metadata._display_name[c_ids_in_group]
'ids_in_order': c_ids_in_group
}
else:
cluster_by_group[group] = sample_slice.ocx._cluster_by_sample(
Expand All @@ -165,7 +172,7 @@ def plot_heatmap(self, rank='auto', normalize='auto', top_n='auto', threshold='a

plot_data['x'].append(len(labels_in_order) + 0.25)

labels_in_order.extend(cluster_by_group[group]['labels_in_order'])
labels_in_order.extend(magic_metadata['Label'][cluster_by_group[group]['ids_in_order']].tolist())

plot_data['x'].append(len(labels_in_order) - 0.25)
plot_data['y'].extend([0, 0])
Expand Down Expand Up @@ -229,11 +236,16 @@ def plot_heatmap(self, rank='auto', normalize='auto', top_n='auto', threshold='a

top_label = alt.layer(label_text, label_bars)

# should ultimately be Label, tax_name, readcount_w_children, then custom fields
tooltip_for_altair = [magic_fields[f] for f in tooltip]
tooltip_for_altair.insert(1, "tax_name")
tooltip_for_altair.insert(2, "{}:Q".format(self._field))

alt_kwargs = dict(
x=alt.X('display_name:N', axis=alt.Axis(title=xlabel), sort=labels_in_order),
x=alt.X('Label:N', axis=alt.Axis(title=xlabel), sort=labels_in_order),
y=alt.Y('tax_name:N', axis=alt.Axis(title=ylabel), sort=taxa_cluster['labels_in_order']),
color=alt.Color('{}:Q'.format(self._field), legend=alt.Legend(title=legend)),
tooltip=['{}:Q'.format(self._field)] + [magic_fields[f] for f in tooltip],
tooltip=tooltip_for_altair,
href='url:N',
url='https://app.onecodex.com/classification/' + alt.datum.classification_id
)
Expand Down
12 changes: 8 additions & 4 deletions onecodex/viz/_metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,8 @@


class VizMetadataMixin(object):
def plot_metadata(self, rank='auto',
haxis='Label', vaxis='simpson',
title=None, xlabel=None, ylabel=None, return_chart=False, plot_type='auto'):
def plot_metadata(self, rank='auto', haxis='Label', vaxis='simpson', title=None, xlabel=None,
ylabel=None, return_chart=False, plot_type='auto', label=None):
"""Plot an arbitrary metadata field versus an arbitrary quantity as a boxplot or scatter plot.
Parameters
Expand Down Expand Up @@ -41,6 +40,11 @@ def plot_metadata(self, rank='auto',
By default, will determine plot type automatically based on the data. Otherwise, specify
one of 'boxplot' or 'scatter' to set the type of plot manually.
label : `string` or `callable`, optional
A metadata field (or function) used to label each analysis. If passing a function, a
dict containing the metadata for each analysis is passed as the first and only
positional argument. The callable function must return a string.
Examples
--------
Generate a boxplot of the abundance of Bacteroides (genus) of samples grouped by whether the
Expand All @@ -55,7 +59,7 @@ def plot_metadata(self, rank='auto',
raise OneCodexException('Plot type must be one of: auto, boxplot, scatter')

# alpha diversity is only allowed on vertical axis--horizontal can be magically mapped
df, magic_fields = self._metadata_fetch([haxis, 'Label'])
df, magic_fields = self._metadata_fetch([haxis, 'Label'], label=label)

if vaxis in ('simpson', 'chao1', 'shannon'):
df.loc[:, vaxis] = self.alpha_diversity(vaxis, rank=rank)
Expand Down
18 changes: 14 additions & 4 deletions onecodex/viz/_pca.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
class VizPCAMixin(object):
def plot_pca(self, rank='auto', normalize='auto', org_vectors=0, org_vectors_scale=None,
title=None, xlabel=None, ylabel=None, color=None, size=None, tooltip=None,
return_chart=False):
return_chart=False, label=None):
"""Perform principal component analysis and plot first two axes.
Parameters
Expand Down Expand Up @@ -39,6 +39,10 @@ def plot_pca(self, rank='auto', normalize='auto', org_vectors=0, org_vectors_sca
A string or list containing strings representing metadata fields. When a point in the
plot is hovered over, the value of the metadata associated with that sample will be
displayed in a modal.
label : `string` or `callable`, optional
A metadata field (or function) used to label each analysis. If passing a function, a
dict containing the metadata for each analysis is passed as the first and only
positional argument. The callable function must return a string.
Examples
--------
Expand Down Expand Up @@ -73,9 +77,15 @@ def plot_pca(self, rank='auto', normalize='auto', org_vectors=0, org_vectors_sca
else:
tooltip = []

tooltip = list(set(['Label', color, size] + tooltip))
tooltip.insert(0, 'Label')

magic_metadata, magic_fields = self._metadata_fetch(tooltip)
if color and color not in tooltip:
tooltip.insert(1, color)

if size and size not in tooltip:
tooltip.insert(2, size)

magic_metadata, magic_fields = self._metadata_fetch(tooltip, label=label)

pca = PCA()
pca_vals = pca.fit(df.values).transform(df.values)
Expand All @@ -94,7 +104,7 @@ def plot_pca(self, rank='auto', normalize='auto', org_vectors=0, org_vectors_sca
alt_kwargs = dict(
x=alt.X('PC1', axis=alt.Axis(title=xlabel)),
y=alt.Y('PC2', axis=alt.Axis(title=ylabel)),
tooltip=[magic_fields[t] for t in tooltip if t],
tooltip=[magic_fields[t] for t in tooltip],
href='url:N',
url='https://app.onecodex.com/classification/' + alt.datum.classification_id
)
Expand Down

0 comments on commit d5f3dfc

Please sign in to comment.