Skip to content

Commit

Permalink
Merge 1a74beb into a14ab26
Browse files Browse the repository at this point in the history
  • Loading branch information
mortonjt committed Apr 12, 2017
2 parents a14ab26 + 1a74beb commit 31c6c33
Show file tree
Hide file tree
Showing 4 changed files with 68 additions and 21 deletions.
2 changes: 1 addition & 1 deletion ci/conda_requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -6,4 +6,4 @@ IPython>4.0.0
notebook
scikit-bio=0.5.1
pyqt=4.11.4
bokeh
bokeh=0.12.4
24 changes: 15 additions & 9 deletions gneiss/plot/_heatmap.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,9 @@
from gneiss.util import match_tips


def heatmap(table, tree, mdvar, highlights=None,
grid_col='w', grid_width=2,
highlight_width=0.02, figsize=(5, 5)):
def heatmap(table, tree, mdvar, highlights=None, cmap='viridis',
linewidth=0.5, grid_col='w', grid_width=2, highlight_width=0.02,
figsize=(5, 5)):
""" Creates heatmap plotting object
Parameters
Expand All @@ -31,6 +31,10 @@ def heatmap(table, tree, mdvar, highlights=None,
subtree and the other for the right subtree highlight.
The first color will always correspond to the left subtree,
and the second color will always correspond to the right subtree.
cmap : str
Specifies the matplotlib colormap for the heatmap (default='viridis')
linewidth : int
Width of dendrogram lines.
mdvar: pd.Series
Metadata values for samples. The index must correspond to the
index of `table`.
Expand Down Expand Up @@ -98,12 +102,12 @@ def heatmap(table, tree, mdvar, highlights=None,

# plot heatmap
ax_heatmap = fig.add_axes([ax1_x, ax1_y, ax1_w, ax1_h], frame_on=True)
_plot_heatmap(ax_heatmap, table, mdvar, grid_col, grid_width)
_plot_heatmap(ax_heatmap, table, mdvar, grid_col, grid_width, cmap)

# plot dendrogram
ax_dendrogram = fig.add_axes([axm_x, axm_y, axm_w, axm_h],
frame_on=True, sharey=ax_heatmap)
_plot_dendrogram(ax_dendrogram, table, edges)
_plot_dendrogram(ax_dendrogram, table, edges, linewidth=linewidth)

# plot highlights for dendrogram
if highlights is not None:
Expand All @@ -114,6 +118,8 @@ def heatmap(table, tree, mdvar, highlights=None,
return fig


# TODO: Refactor and place in utils. This can be also
# be used for the balance_basis calculations
def _tree_coordinates(t):
""" Builds a matrix to link tree positions to matrix"""
# first traverse the tree to count the children
Expand Down Expand Up @@ -176,7 +182,7 @@ def _plot_highlights_dendrogram(ax_highlights, table, t, highlights):
ax_highlights.set_xticklabels(highlights.index, rotation=90)


def _plot_dendrogram(ax_dendrogram, table, edges):
def _plot_dendrogram(ax_dendrogram, table, edges, linewidth=1):
""" Plots the actual dendrogram.
Parameters
Expand All @@ -193,7 +199,7 @@ def _plot_dendrogram(ax_dendrogram, table, edges):
for i in range(len(edges.index)):
row = edges.iloc[i]
ax_dendrogram.plot([row.x0, row.x1],
[row.y0-offset, row.y1-offset], '-k')
[row.y0-offset, row.y1-offset], '-k', lw=linewidth)
ax_dendrogram.set_ylim([-offset, table.shape[0]-offset])
ax_dendrogram.set_yticks([])
ax_dendrogram.set_xticks([])
Expand Down Expand Up @@ -223,7 +229,7 @@ def _sort_table(table, mdvar):
return table, mdvar


def _plot_heatmap(ax_heatmap, table, mdvar, grid_col, grid_width):
def _plot_heatmap(ax_heatmap, table, mdvar, grid_col, grid_width, cmap):
""" Sorts metadata category and aligns with table.
Parameters
Expand All @@ -248,7 +254,7 @@ def _plot_heatmap(ax_heatmap, table, mdvar, grid_col, grid_width):
table, mdvar = _sort_table(table, mdvar)
table = table.iloc[::-1, :]
ax_heatmap.imshow(table, aspect='auto', interpolation='nearest',
cmap='viridis')
cmap=cmap)
ax_heatmap.set_ylim([0, table.shape[0]])
vcounts = mdvar.value_counts()

Expand Down
53 changes: 42 additions & 11 deletions gneiss/plot/_plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from q2_types.tree import Phylogeny, Rooted
from q2_composition.plugin_setup import Composition
from q2_types.feature_table import FeatureTable
from qiime2.plugin import Int, MetadataCategory
from qiime2.plugin import Int, MetadataCategory, Str, Choices

from bokeh.embed import file_html
from bokeh.resources import CDN
Expand Down Expand Up @@ -343,24 +343,48 @@ def lme_summary(output_dir: str, model: LMEModel, ndim=10) -> None:
"predicted fit and residuals")
)

_transform_methods = ['clr', 'log']
_mpl_colormaps = ['viridis', 'inferno', 'plasma', 'magma',
'Blues', 'BuGn', 'BuPu',
'GnBu', 'Greens', 'Greys', 'Oranges', 'OrRd',
'PuBu', 'PuBuGn', 'PuRd', 'Purples', 'RdPu',
'Reds', 'YlGn', 'YlGnBu', 'YlOrBr', 'YlOrRd',
'afmhot', 'autumn', 'bone', 'cool',
'copper', 'gist_heat', 'gray', 'hot',
'pink', 'spring', 'summer', 'winter',
'BrBG', 'bwr', 'coolwarm', 'PiYG', 'PRGn', 'PuOr',
'RdBu', 'RdGy', 'RdYlBu', 'RdYlGn', 'Spectral',
'seismic', 'Accent', 'Dark2', 'Paired', 'Pastel1',
'Pastel2', 'Set1', 'Set2', 'Set3', 'Vega10',
'Vega20', 'Vega20b', 'Vega20c',
'gist_earth', 'terrain', 'ocean', 'gist_stern',
'brg', 'CMRmap', 'cubehelix',
'gnuplot', 'gnuplot2', 'gist_ncar',
'nipy_spectral', 'jet', 'rainbow',
'gist_rainbow', 'hsv', 'flag', 'prism']


# Heatmap
def dendrogram_heatmap(output_dir: str, table: pd.DataFrame,
tree: TreeNode, metadata: MetadataCategory,
ndim=10):
ndim=10, method='clr', color_map='viridis'):

nodes = [n.name for n in tree.levelorder()]
nlen = min(ndim, len(nodes))
highlights = pd.DataFrame([['#00FF00', '#FF0000']] * nlen,
index=nodes[:nlen])

mat = pd.DataFrame(clr(centralize(table)),
index=table.index,
columns=table.columns)
if method == 'clr':
mat = pd.DataFrame(clr(centralize(table)),
index=table.index,
columns=table.columns)
elif method == 'log':
mat = pd.DataFrame(np.log(table),
index=table.index,
columns=table.columns)

# TODO: There are a few hard-coded constants here
# will need to have some adaptive defaults set in the future
fig = heatmap(mat, tree, metadata.to_series(), highlights,
fig = heatmap(mat, tree, metadata.to_series(), highlights, cmap=color_map,
highlight_width=0.01, figsize=(12, 8))
fig.savefig(os.path.join(output_dir, 'heatmap.svg'))

Expand All @@ -376,7 +400,9 @@ def dendrogram_heatmap(output_dir: str, table: pd.DataFrame,
function=dendrogram_heatmap,
inputs={'table': FeatureTable[Composition],
'tree': Phylogeny[Rooted]},
parameters={'metadata': MetadataCategory, 'ndim': Int},
parameters={'metadata': MetadataCategory, 'ndim': Int,
'method': Str % Choices(_transform_methods),
'color_map': Str % Choices(_mpl_colormaps)},
input_descriptions={
'table': ('The feature table that will be plotted as a heatmap. '
'This table is assumed to have strictly positive values.'),
Expand All @@ -387,10 +413,15 @@ def dendrogram_heatmap(output_dir: str, table: pd.DataFrame,
'present in this tree.')},
parameter_descriptions={
'metadata': ('Metadata to group the samples. '),
'ndim': 'Number of dimensions to highlight.'},
'ndim': 'Number of dimensions to highlight.',
'method': ("Specifies how the data should be normalized for display."
"Options include 'log' or 'clr' (default='clr')."),
'color_map': ("Specifies the color map for plotting the heatmap. "
"See https://matplotlib.org/examples/color/"
"colormaps_reference.html for more details.")
},
name='Dendrogram heatmap.',
description=("Visualize the feature tables as a heatmap. "
"with samples sorted along a specified metadata category "
"and features clustered together specified by the tree."
"In addition, the heatmap values are clr transformed.")
"and features clustered together specified by the tree.")
)
10 changes: 10 additions & 0 deletions gneiss/plot/tests/test_heatmap.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,11 +87,21 @@ def test_basic(self):
res = str(fig.get_axes()[0].get_xlabel())
self.assertEqual(res, "None")

def test_basic_line_width(self):
fig = heatmap(self.table, self.t, self.md,
figsize=(5, self.table.shape[0]), linewidth=1)

# Test to see if the lineages of the tree are ok
lines = list(fig.get_axes()[1].get_lines())
widths = [l.get_lw() for l in lines]
np.allclose(widths, [1.0] * len(widths))

def test_basic_highlights(self):
fig = heatmap(self.table, self.t, self.md, self.highlights)

# Test to see if the lineages of the tree are ok
lines = list(fig.get_axes()[1].get_lines())

pts = self.t.coords(width=20, height=self.table.shape[0])
pts['y'] = pts['y'] - 0.5 # account for offset
pts['x'] = pts['x'].astype(np.float)
Expand Down

0 comments on commit 31c6c33

Please sign in to comment.