diff --git a/gneiss/plot/_plot.py b/gneiss/plot/_plot.py index f0aeab3..d7ead42 100644 --- a/gneiss/plot/_plot.py +++ b/gneiss/plot/_plot.py @@ -24,19 +24,28 @@ from q2_types.feature_table import FeatureTable from qiime2.plugin import Int, MetadataCategory, Str, Choices + from bokeh.embed import file_html from bokeh.resources import CDN from bokeh.plotting import figure, ColumnDataSource -from bokeh.layouts import row -from bokeh.models import HoverTool, BoxZoomTool, ResetTool +from bokeh.layouts import row, column +from bokeh.models import (HoverTool, BoxZoomTool, WheelZoomTool, + ResetTool, SaveTool, PanTool, + FuncTickFormatter, FixedTicker) +from bokeh.palettes import RdYlBu11 as palette -def _projected_prediction(model): +def _projected_prediction(model, plot_width=400, plot_height=400): """ Create projected prediction plot Parameters ---------- model : RegressionModel + Input regression model to plot prediction. + plot_width : int + Width of plot. + plot_height : int + Height of plot. Returns ------- @@ -54,8 +63,9 @@ def _projected_prediction(model): raw = model.balances raw['color'] = 'raw' # make raw values blue - p = figure(plot_width=400, plot_height=400, - tools=[hover, BoxZoomTool(), ResetTool()]) + p = figure(plot_width=plot_width, plot_height=plot_height, + tools=[hover, BoxZoomTool(), ResetTool(), + WheelZoomTool(), SaveTool(), PanTool()]) raw_source = ColumnDataSource(raw) pred_source = ColumnDataSource(pred) @@ -66,15 +76,13 @@ def _projected_prediction(model): p.title.text = 'Projected Prediction' p.title_location = 'above' - p.title.align = 'center' - p.title.text_font_size = '18pt' p.xaxis.axis_label = '{} ({:.2%})'.format(pcvar.index[0], pcvar.iloc[0]) p.yaxis.axis_label = '{} ({:.2%})'.format(pcvar.index[1], pcvar.iloc[1]) return p -def _projected_residuals(model): +def _projected_residuals(model, plot_width=400, plot_height=400): """ Create projected residual plot Parameters @@ -92,8 +100,9 @@ def _projected_residuals(model): ) pcvar = model.percent_explained() resid = model.residuals() - p = figure(plot_width=400, plot_height=400, - tools=[hover, BoxZoomTool(), ResetTool()]) + p = figure(plot_width=plot_width, plot_height=plot_height, + tools=[hover, BoxZoomTool(), ResetTool(), + WheelZoomTool(), SaveTool(), PanTool()]) resid_source = ColumnDataSource(resid) p.circle(resid.columns[0], resid.columns[1], size=7, @@ -101,13 +110,91 @@ def _projected_residuals(model): p.title.text = 'Projected Residuals' p.title_location = 'above' - p.title.align = 'center' - p.title.text_font_size = '18pt' p.xaxis.axis_label = '{} ({:.2%})'.format(pcvar.index[0], pcvar.iloc[0]) p.yaxis.axis_label = '{} ({:.2%})'.format(pcvar.index[1], pcvar.iloc[1]) return p +def _heatmap_summary(pvals, coefs, plot_width=1200, plot_height=400): + """ Plots heatmap of coefficients colored by pvalues + + Parameters + ---------- + pvals : pd.DataFrame + Table of pvalues where rows are balances and columns are + covariates. + coefs : pd.DataFrame + Table of coefficients where rows are balances and columns are + covariates. + plot_width : int + Width of plot. + plot_height : int + Height of plot. + + Returns + ------- + bokeh.charts.Heatmap + Heatmap summarizing the regression statistics. + """ + c = coefs.reset_index() + c = c.rename(columns={'index': 'balance'}) + # log scale for coloring + log_p = -np.log10(pvals+1e-200) + log_p = log_p.reset_index() + log_p = log_p.rename(columns={'index': 'balance'}) + p = pvals.reset_index() + p = p.rename(columns={'index': 'balance'}) + + cm = pd.melt(c, id_vars='balance', var_name='Covariate', + value_name='Coefficient') + pm = pd.melt(p, id_vars='balance', var_name='Covariate', + value_name='Pvalue') + logpm = pd.melt(log_p, id_vars='balance', var_name='Covariate', + value_name='log_Pvalue') + m = pd.merge(cm, pm) + m = pd.merge(m, logpm) + hover = HoverTool( + tooltips=[("Pvalue", "@Pvalue"), + ("Coefficient", "@Coefficient")] + ) + + N, _min, _max = len(palette), m.log_Pvalue.min(), m.log_Pvalue.max() + X = pd.Series(np.arange(len(pvals.index)), index=pvals.index) + Y = pd.Series(np.arange(len(pvals.columns)), index=pvals.columns) + m['X'] = [X.loc[i] for i in m.balance] + m['Y'] = [Y.loc[i] for i in m.Covariate] + + for i in m.index: + x = m.loc[i, 'log_Pvalue'] + ind = int(np.floor((x - _min) / (_max - _min) * (N - 1))) + m.loc[i, 'color'] = palette[ind] + + source = ColumnDataSource(ColumnDataSource.from_df(m)) + hm = figure(title='Regression Coefficients Summary', + plot_width=1200, plot_height=400, + tools=[hover, PanTool(), BoxZoomTool(), + WheelZoomTool(), ResetTool(), + SaveTool()]) + hm.rect(x='X', y='Y', width=1, height=1, + fill_color='color', line_color="white", source=source) + Xlabels = pd.Series(pvals.index, index=np.arange(len(pvals.index))) + Ylabels = pd.Series(pvals.columns, index=np.arange(len(pvals.columns)), ) + + hm.xaxis[0].ticker = FixedTicker(ticks=Xlabels.index) + hm.xaxis.formatter = FuncTickFormatter(code=""" + var labels = %s; + return labels[tick]; + """ % Xlabels.to_dict()) + + hm.yaxis[0].ticker = FixedTicker(ticks=Ylabels.index) + hm.yaxis.formatter = FuncTickFormatter(code=""" + var labels = %s; + return labels[tick]; + """ % Ylabels.to_dict()) + + return hm + + def _decorate_tree(t, series): """ Attaches some default values on the tree for plotting. @@ -135,9 +222,6 @@ def _decorate_tree(t, series): def _deposit_results(model, output_dir): """ Store all of the core regression results into a folder. """ - pred = model.predict() - pred.to_csv(os.path.join(output_dir, 'predicted.csv'), - header=True, index=True) coefficients = model.coefficients() coefficients.to_csv(os.path.join(output_dir, 'coefficients.csv'), header=True, index=True) @@ -153,6 +237,7 @@ def _deposit_results(model, output_dir): balances = model.balances balances.to_csv(os.path.join(output_dir, 'balances.csv'), header=True, index=True) + model.tree.write(os.path.join(output_dir, 'tree.nwk')) def _deposit_results_html(index_f): @@ -172,6 +257,9 @@ def _deposit_results_html(index_f): index_f.write(('Residuals\n')) index_f.write(('' 'Download as CSV
\n')) + index_f.write(('Tree\n')) + index_f.write(('' + 'Download as Newick
\n')) # OLS summary @@ -193,7 +281,7 @@ def ols_summary(output_dir: str, model: OLSModel, ndim=10) -> None: cv = model.loo() # Relative importance of explanatory variables relimp = model.lovo() - w, h = 400, 400 # plot width and height + w, h = 300, 300 # plot width and height # Histogram of model mean squared error from cross validation mse_p = figure(title="Cross Validation Mean Squared Error", plot_width=w, plot_height=h) @@ -201,7 +289,7 @@ def ols_summary(output_dir: str, model: OLSModel, ndim=10) -> None: mse_p.quad(top=mse_hist, bottom=0, left=edges[:-1], right=edges[1:], fill_color="#FFFF00", line_color="#033649", fill_alpha=0.5, legend='CV Mean Squared Error') - mse_p.ray(x=model.mse, y=0, length=h, + mse_p.ray(x=model.mse, y=0, length=h*max(mse_hist), angle=1.57079633, color='red', legend='Model Error', line_width=0.5) @@ -216,13 +304,12 @@ def ols_summary(output_dir: str, model: OLSModel, ndim=10) -> None: angle=1.57079633, color='red', legend='Model Error', line_width=0.5) - cvp = row(mse_p, pred_p) - # Explained sum of squares ess = pd.Series({r.model.endog_names: r.ess for r in model.results}) - # Summary object - smry = model.summary(ndim=10) + # Summary object + smry = model.summary(ndim=ndim) + _deposit_results(model, output_dir) t = _decorate_tree(model.tree, ess) p1 = radialplot(t, edge_color='color', figsize=(800, 800)) @@ -232,13 +319,14 @@ def ols_summary(output_dir: str, model: OLSModel, ndim=10) -> None: p1.title.text_font_size = '18pt' # 2D scatter plot for prediction on PB - p2 = _projected_prediction(model) - p3 = _projected_residuals(model) - - p23 = row(p2, p3) - - _deposit_results(model, output_dir) - + p2 = _projected_prediction(model, plot_width=w, plot_height=h) + p3 = _projected_residuals(model, plot_width=w, plot_height=h) + hm_p = _heatmap_summary(model.pvalues, model.coefficients()) + # combine the cross validation, explained sum of squares tree and + # residual plots into a single plot + + p = row(column(mse_p, pred_p), column(p2, p3), p1) + p = column(hm_p, p) index_fp = os.path.join(output_dir, 'index.html') with open(index_fp, 'w') as index_f: index_f.write('\n') @@ -247,13 +335,9 @@ def ols_summary(output_dir: str, model: OLSModel, ndim=10) -> None: index_f.write('Relative importance\n') index_f.write(relimp.to_html()) _deposit_results_html(index_f) - index_f.write('Cross Validation') - cv_html = file_html(cvp, CDN, 'Cross Validation') - index_f.write(cv_html) - ess_tree_html = file_html(p1, CDN, 'Explained Sum of Squares') - index_f.write(ess_tree_html) - reg_smry_html = file_html(p23, CDN, 'Prediction and Residual plot') - index_f.write(reg_smry_html) + + plot_html = file_html(p, CDN, 'Diagnostics') + index_f.write(plot_html) index_f.write('\n') @@ -293,7 +377,7 @@ def lme_summary(output_dir: str, model: LMEModel, ndim=10) -> None: # log likelihood loglike = pd.Series({r.model.endog_names: r.model.loglike(r.params) for r in model.results}) - + w, h = 300, 300 # plot width and height # Summary object smry = model.summary(ndim=10) @@ -305,10 +389,16 @@ def lme_summary(output_dir: str, model: LMEModel, ndim=10) -> None: p1.title.text_font_size = '18pt' # 2D scatter plot for prediction on PB - p2 = _projected_prediction(model) - p3 = _projected_residuals(model) + p2 = _projected_prediction(model, plot_width=w, plot_height=h) + p3 = _projected_residuals(model, plot_width=w, plot_height=h) + + hm_p = _heatmap_summary(model.pvalues, model.coefficients(), + plot_width=900, plot_height=400) - p23 = row(p2, p3) + # combine the cross validation, explained sum of squares tree and + # residual plots into a single plot + p = row(column(p2, p3), p1) + p = column(hm_p, p) # Deposit all regression results _deposit_results(model, output_dir) @@ -319,10 +409,8 @@ def lme_summary(output_dir: str, model: LMEModel, ndim=10) -> None: index_f.write('

Simplicial Linear Mixed Effects Summary

\n') index_f.write(smry.as_html()) _deposit_results_html(index_f) - ess_tree_html = file_html(p1, CDN, 'Loglikelihood') - index_f.write(ess_tree_html) - reg_smry_html = file_html(p23, CDN, 'Prediction and Residual plot') - index_f.write(reg_smry_html) + diag_html = file_html(p, CDN, 'Diagnostic plots') + index_f.write(diag_html) index_f.write('\n') diff --git a/gneiss/plot/_radial.py b/gneiss/plot/_radial.py index dd45395..55dfcc9 100644 --- a/gneiss/plot/_radial.py +++ b/gneiss/plot/_radial.py @@ -9,7 +9,8 @@ from gneiss.plot._dendrogram import UnrootedDendrogram from bokeh.models.glyphs import Circle, Segment from bokeh.models import ColumnDataSource, DataRange1d, Plot -from bokeh.models import HoverTool, BoxZoomTool, ResetTool +from bokeh.models import (HoverTool, BoxZoomTool, ResetTool, + WheelZoomTool, SaveTool, PanTool) def radialplot(tree, node_color='node_color', node_size='node_size', @@ -59,7 +60,7 @@ def radialplot(tree, node_color='node_color', node_size='node_size', # This entire function was motivated by # http://chuckpr.github.io/blog/trees2.html - t = UnrootedDendrogram.from_tree(tree) + t = UnrootedDendrogram.from_tree(tree.copy()) nodes = t.coords(figsize[0], figsize[1]) @@ -123,6 +124,7 @@ def df2ds(df): tooltip += [(hover_var, "@" + hover_var)] hover = HoverTool(renderers=[ns], tooltips=tooltip) - plot.add_tools(hover, BoxZoomTool(), ResetTool()) + plot.add_tools(hover, BoxZoomTool(), ResetTool(), + WheelZoomTool(), SaveTool(), PanTool()) return plot diff --git a/gneiss/plot/tests/test_plot.py b/gneiss/plot/tests/test_plot.py index 980c626..396cb53 100644 --- a/gneiss/plot/tests/test_plot.py +++ b/gneiss/plot/tests/test_plot.py @@ -178,14 +178,15 @@ def test_visualization(self): with open(index_fp, 'r') as fh: html = fh.read() + self.assertIn('

Simplicial Linear Regression Summary

', html) self.assertIn('Relative importance', html) - self.assertIn('Cross Validation', html) self.assertIn('Coefficients\n', html) self.assertIn('Raw Balances\n', html) self.assertIn('Predicted Proportions\n', html) self.assertIn('Residuals\n', html) + self.assertIn('Tree\n', html) class TestLME_Summary(unittest.TestCase): diff --git a/gneiss/plot/tests/test_radial.py b/gneiss/plot/tests/test_radial.py index acd2cf6..f08a618 100644 --- a/gneiss/plot/tests/test_radial.py +++ b/gneiss/plot/tests/test_radial.py @@ -87,6 +87,7 @@ def test_basic_plot(self): self.assertDictEqual(p.renderers[0].data_source.data, exp_edges) self.assertDictEqual(p.renderers[1].data_source.data, exp_nodes) + self.assertTrue(isinstance(t, TreeNode)) if __name__ == "__main__":