diff --git a/gneiss/plot/_plot.py b/gneiss/plot/_plot.py index f0aeab3..d7ead42 100644 --- a/gneiss/plot/_plot.py +++ b/gneiss/plot/_plot.py @@ -24,19 +24,28 @@ from q2_types.feature_table import FeatureTable from qiime2.plugin import Int, MetadataCategory, Str, Choices + from bokeh.embed import file_html from bokeh.resources import CDN from bokeh.plotting import figure, ColumnDataSource -from bokeh.layouts import row -from bokeh.models import HoverTool, BoxZoomTool, ResetTool +from bokeh.layouts import row, column +from bokeh.models import (HoverTool, BoxZoomTool, WheelZoomTool, + ResetTool, SaveTool, PanTool, + FuncTickFormatter, FixedTicker) +from bokeh.palettes import RdYlBu11 as palette -def _projected_prediction(model): +def _projected_prediction(model, plot_width=400, plot_height=400): """ Create projected prediction plot Parameters ---------- model : RegressionModel + Input regression model to plot prediction. + plot_width : int + Width of plot. + plot_height : int + Height of plot. 
def _melt_balance_table(table, value_name):
    """ Reshape a balance-by-covariate table into long format.

    Parameters
    ----------
    table : pd.DataFrame
        Table where rows are balances and columns are covariates.
    value_name : str
        Name of the column that will hold the table entries.

    Returns
    -------
    pd.DataFrame
        Long-format table with the columns 'balance', 'Covariate' and
        `value_name`.
    """
    long_form = table.copy()
    # Name the index explicitly so `reset_index` always yields a 'balance'
    # column, even when the incoming index already carries another name.
    long_form.index.name = 'balance'
    return pd.melt(long_form.reset_index(), id_vars='balance',
                   var_name='Covariate', value_name=value_name)


def _bin_colors(values, colors):
    """ Assign a color to every value by linear binning over the value range.

    Parameters
    ----------
    values : pd.Series
        Numeric values to color.
    colors : sequence of str
        Ordered colors; the smallest value maps to the first color and the
        largest value maps to the last color.

    Returns
    -------
    list of str
        One color per entry of `values`.
    """
    lowest, highest = values.min(), values.max()
    span = highest - lowest
    if not span > 0:
        # All values are identical (or there are none): the linear scaling
        # below would divide by zero, so use a single color instead.
        return [colors[0]] * len(values)
    scaled = (values - lowest) / span * (len(colors) - 1)
    bins = np.floor(scaled).astype(int)
    return [colors[i] for i in bins]


def _heatmap_summary(pvals, coefs, plot_width=1200, plot_height=400):
    """ Plots heatmap of coefficients colored by pvalues

    Parameters
    ----------
    pvals : pd.DataFrame
        Table of pvalues where rows are balances and columns are
        covariates.
    coefs : pd.DataFrame
        Table of coefficients where rows are balances and columns are
        covariates.
    plot_width : int
        Width of plot.
    plot_height : int
        Height of plot.

    Returns
    -------
    bokeh figure
        The figure returned by `bokeh.plotting.figure`, holding one
        rectangle per (balance, covariate) pair colored by -log10(pvalue).
    """
    # Local import keeps this block self-contained.
    import json

    m = pd.merge(_melt_balance_table(coefs, 'Coefficient'),
                 _melt_balance_table(pvals, 'Pvalue'))
    # log scale for coloring; the offset avoids taking log10(0)
    m['log_Pvalue'] = -np.log10(m['Pvalue'] + 1e-200)

    # Grid coordinates: balances along x, covariates along y.
    x_pos = {label: i for i, label in enumerate(pvals.index)}
    y_pos = {label: i for i, label in enumerate(pvals.columns)}
    m['X'] = m['balance'].map(x_pos)
    m['Y'] = m['Covariate'].map(y_pos)
    m['color'] = _bin_colors(m['log_Pvalue'], palette)

    hover = HoverTool(
        tooltips=[("Pvalue", "@Pvalue"),
                  ("Coefficient", "@Coefficient")]
    )
    source = ColumnDataSource(ColumnDataSource.from_df(m))
    # Honor the requested dimensions instead of hard-coded values.
    hm = figure(title='Regression Coefficients Summary',
                plot_width=plot_width, plot_height=plot_height,
                tools=[hover, PanTool(), BoxZoomTool(),
                       WheelZoomTool(), ResetTool(),
                       SaveTool()])
    hm.rect(x='X', y='Y', width=1, height=1,
            fill_color='color', line_color="white", source=source)

    def _label_formatter(positions):
        # Serialize as JSON so labels containing quotes or other special
        # characters still form a valid JavaScript object literal.
        labels = json.dumps({i: str(label)
                             for label, i in positions.items()})
        return FuncTickFormatter(code="""
        var labels = %s;
        return labels[tick];
        """ % labels)

    hm.xaxis[0].ticker = FixedTicker(ticks=list(range(len(pvals.index))))
    hm.xaxis.formatter = _label_formatter(x_pos)

    hm.yaxis[0].ticker = FixedTicker(ticks=list(range(len(pvals.columns))))
    hm.yaxis.formatter = _label_formatter(y_pos)

    return hm