Skip to content

Commit

Permalink
Merge a621de0 into fd96da9
Browse files Browse the repository at this point in the history
  • Loading branch information
mortonjt committed Apr 23, 2017
2 parents fd96da9 + a621de0 commit 3e66028
Show file tree
Hide file tree
Showing 4 changed files with 139 additions and 47 deletions.
174 changes: 131 additions & 43 deletions gneiss/plot/_plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,19 +24,28 @@
from q2_types.feature_table import FeatureTable
from qiime2.plugin import Int, MetadataCategory, Str, Choices


from bokeh.embed import file_html
from bokeh.resources import CDN
from bokeh.plotting import figure, ColumnDataSource
from bokeh.layouts import row
from bokeh.models import HoverTool, BoxZoomTool, ResetTool
from bokeh.layouts import row, column
from bokeh.models import (HoverTool, BoxZoomTool, WheelZoomTool,
ResetTool, SaveTool, PanTool,
FuncTickFormatter, FixedTicker)
from bokeh.palettes import RdYlBu11 as palette


def _projected_prediction(model):
def _projected_prediction(model, plot_width=400, plot_height=400):
""" Create projected prediction plot
Parameters
----------
model : RegressionModel
Input regression model to plot prediction.
plot_width : int
Width of plot.
plot_height : int
Height of plot.
Returns
-------
Expand All @@ -54,8 +63,9 @@ def _projected_prediction(model):
raw = model.balances
raw['color'] = 'raw' # make raw values blue

p = figure(plot_width=400, plot_height=400,
tools=[hover, BoxZoomTool(), ResetTool()])
p = figure(plot_width=plot_width, plot_height=plot_height,
tools=[hover, BoxZoomTool(), ResetTool(),
WheelZoomTool(), SaveTool(), PanTool()])
raw_source = ColumnDataSource(raw)
pred_source = ColumnDataSource(pred)

Expand All @@ -66,15 +76,13 @@ def _projected_prediction(model):

p.title.text = 'Projected Prediction'
p.title_location = 'above'
p.title.align = 'center'
p.title.text_font_size = '18pt'

p.xaxis.axis_label = '{} ({:.2%})'.format(pcvar.index[0], pcvar.iloc[0])
p.yaxis.axis_label = '{} ({:.2%})'.format(pcvar.index[1], pcvar.iloc[1])
return p


def _projected_residuals(model):
def _projected_residuals(model, plot_width=400, plot_height=400):
""" Create projected residual plot
Parameters
Expand All @@ -92,22 +100,101 @@ def _projected_residuals(model):
)
pcvar = model.percent_explained()
resid = model.residuals()
p = figure(plot_width=400, plot_height=400,
tools=[hover, BoxZoomTool(), ResetTool()])
p = figure(plot_width=plot_width, plot_height=plot_height,
tools=[hover, BoxZoomTool(), ResetTool(),
WheelZoomTool(), SaveTool(), PanTool()])
resid_source = ColumnDataSource(resid)

p.circle(resid.columns[0], resid.columns[1], size=7,
source=resid_source, fill_color='blue', legend='residuals')

p.title.text = 'Projected Residuals'
p.title_location = 'above'
p.title.align = 'center'
p.title.text_font_size = '18pt'
p.xaxis.axis_label = '{} ({:.2%})'.format(pcvar.index[0], pcvar.iloc[0])
p.yaxis.axis_label = '{} ({:.2%})'.format(pcvar.index[1], pcvar.iloc[1])
return p


def _heatmap_summary(pvals, coefs, plot_width=1200, plot_height=400):
""" Plots heatmap of coefficients colored by pvalues
Parameters
----------
pvals : pd.DataFrame
Table of pvalues where rows are balances and columns are
covariates.
coefs : pd.DataFrame
Table of coefficients where rows are balances and columns are
covariates.
plot_width : int
Width of plot.
plot_height : int
Height of plot.
Returns
-------
bokeh.charts.Heatmap
Heatmap summarizing the regression statistics.
"""
c = coefs.reset_index()
c = c.rename(columns={'index': 'balance'})
# log scale for coloring
log_p = -np.log10(pvals+1e-200)
log_p = log_p.reset_index()
log_p = log_p.rename(columns={'index': 'balance'})
p = pvals.reset_index()
p = p.rename(columns={'index': 'balance'})

cm = pd.melt(c, id_vars='balance', var_name='Covariate',
value_name='Coefficient')
pm = pd.melt(p, id_vars='balance', var_name='Covariate',
value_name='Pvalue')
logpm = pd.melt(log_p, id_vars='balance', var_name='Covariate',
value_name='log_Pvalue')
m = pd.merge(cm, pm)
m = pd.merge(m, logpm)
hover = HoverTool(
tooltips=[("Pvalue", "@Pvalue"),
("Coefficient", "@Coefficient")]
)

N, _min, _max = len(palette), m.log_Pvalue.min(), m.log_Pvalue.max()
X = pd.Series(np.arange(len(pvals.index)), index=pvals.index)
Y = pd.Series(np.arange(len(pvals.columns)), index=pvals.columns)
m['X'] = [X.loc[i] for i in m.balance]
m['Y'] = [Y.loc[i] for i in m.Covariate]

for i in m.index:
x = m.loc[i, 'log_Pvalue']
ind = int(np.floor((x - _min) / (_max - _min) * (N - 1)))
m.loc[i, 'color'] = palette[ind]

source = ColumnDataSource(ColumnDataSource.from_df(m))
hm = figure(title='Regression Coefficients Summary',
plot_width=1200, plot_height=400,
tools=[hover, PanTool(), BoxZoomTool(),
WheelZoomTool(), ResetTool(),
SaveTool()])
hm.rect(x='X', y='Y', width=1, height=1,
fill_color='color', line_color="white", source=source)
Xlabels = pd.Series(pvals.index, index=np.arange(len(pvals.index)))
Ylabels = pd.Series(pvals.columns, index=np.arange(len(pvals.columns)), )

hm.xaxis[0].ticker = FixedTicker(ticks=Xlabels.index)
hm.xaxis.formatter = FuncTickFormatter(code="""
var labels = %s;
return labels[tick];
""" % Xlabels.to_dict())

hm.yaxis[0].ticker = FixedTicker(ticks=Ylabels.index)
hm.yaxis.formatter = FuncTickFormatter(code="""
var labels = %s;
return labels[tick];
""" % Ylabels.to_dict())

return hm


def _decorate_tree(t, series):
""" Attaches some default values on the tree for plotting.
Expand Down Expand Up @@ -135,9 +222,6 @@ def _decorate_tree(t, series):

def _deposit_results(model, output_dir):
""" Store all of the core regression results into a folder. """
pred = model.predict()
pred.to_csv(os.path.join(output_dir, 'predicted.csv'),
header=True, index=True)
coefficients = model.coefficients()
coefficients.to_csv(os.path.join(output_dir, 'coefficients.csv'),
header=True, index=True)
Expand All @@ -153,6 +237,7 @@ def _deposit_results(model, output_dir):
balances = model.balances
balances.to_csv(os.path.join(output_dir, 'balances.csv'),
header=True, index=True)
model.tree.write(os.path.join(output_dir, 'tree.nwk'))


def _deposit_results_html(index_f):
Expand All @@ -172,6 +257,9 @@ def _deposit_results_html(index_f):
index_f.write(('<th>Residuals</th>\n'))
index_f.write(('<a href="residuals.csv">'
'Download as CSV</a><br>\n'))
index_f.write(('<th>Tree</th>\n'))
index_f.write(('<a href="tree.nwk">'
'Download as Newick</a><br>\n'))


# OLS summary
Expand All @@ -193,15 +281,15 @@ def ols_summary(output_dir: str, model: OLSModel, ndim=10) -> None:
cv = model.loo()
# Relative importance of explanatory variables
relimp = model.lovo()
w, h = 400, 400 # plot width and height
w, h = 300, 300 # plot width and height
# Histogram of model mean squared error from cross validation
mse_p = figure(title="Cross Validation Mean Squared Error",
plot_width=w, plot_height=h)
mse_hist, edges = np.histogram(cv.mse, density=True, bins=20)
mse_p.quad(top=mse_hist, bottom=0, left=edges[:-1], right=edges[1:],
fill_color="#FFFF00", line_color="#033649", fill_alpha=0.5,
legend='CV Mean Squared Error')
mse_p.ray(x=model.mse, y=0, length=h,
mse_p.ray(x=model.mse, y=0, length=h*max(mse_hist),
angle=1.57079633, color='red',
legend='Model Error', line_width=0.5)

Expand All @@ -216,13 +304,12 @@ def ols_summary(output_dir: str, model: OLSModel, ndim=10) -> None:
angle=1.57079633, color='red',
legend='Model Error', line_width=0.5)

cvp = row(mse_p, pred_p)

# Explained sum of squares
ess = pd.Series({r.model.endog_names: r.ess for r in model.results})
# Summary object
smry = model.summary(ndim=10)

# Summary object
smry = model.summary(ndim=ndim)
_deposit_results(model, output_dir)
t = _decorate_tree(model.tree, ess)

p1 = radialplot(t, edge_color='color', figsize=(800, 800))
Expand All @@ -232,13 +319,14 @@ def ols_summary(output_dir: str, model: OLSModel, ndim=10) -> None:
p1.title.text_font_size = '18pt'

# 2D scatter plot for prediction on PB
p2 = _projected_prediction(model)
p3 = _projected_residuals(model)

p23 = row(p2, p3)

_deposit_results(model, output_dir)

p2 = _projected_prediction(model, plot_width=w, plot_height=h)
p3 = _projected_residuals(model, plot_width=w, plot_height=h)
hm_p = _heatmap_summary(model.pvalues, model.coefficients())
# combine the cross validation, explained sum of squares tree and
# residual plots into a single plot

p = row(column(mse_p, pred_p), column(p2, p3), p1)
p = column(hm_p, p)
index_fp = os.path.join(output_dir, 'index.html')
with open(index_fp, 'w') as index_f:
index_f.write('<html><body>\n')
Expand All @@ -247,13 +335,9 @@ def ols_summary(output_dir: str, model: OLSModel, ndim=10) -> None:
index_f.write('<th>Relative importance</th>\n')
index_f.write(relimp.to_html())
_deposit_results_html(index_f)
index_f.write('<th>Cross Validation</th>')
cv_html = file_html(cvp, CDN, 'Cross Validation')
index_f.write(cv_html)
ess_tree_html = file_html(p1, CDN, 'Explained Sum of Squares')
index_f.write(ess_tree_html)
reg_smry_html = file_html(p23, CDN, 'Prediction and Residual plot')
index_f.write(reg_smry_html)

plot_html = file_html(p, CDN, 'Diagnostics')
index_f.write(plot_html)
index_f.write('</body></html>\n')


Expand Down Expand Up @@ -293,7 +377,7 @@ def lme_summary(output_dir: str, model: LMEModel, ndim=10) -> None:
# log likelihood
loglike = pd.Series({r.model.endog_names: r.model.loglike(r.params)
for r in model.results})

w, h = 300, 300 # plot width and height
# Summary object
smry = model.summary(ndim=10)

Expand All @@ -305,10 +389,16 @@ def lme_summary(output_dir: str, model: LMEModel, ndim=10) -> None:
p1.title.text_font_size = '18pt'

# 2D scatter plot for prediction on PB
p2 = _projected_prediction(model)
p3 = _projected_residuals(model)
p2 = _projected_prediction(model, plot_width=w, plot_height=h)
p3 = _projected_residuals(model, plot_width=w, plot_height=h)

hm_p = _heatmap_summary(model.pvalues, model.coefficients(),
plot_width=900, plot_height=400)

p23 = row(p2, p3)
# combine the cross validation, explained sum of squares tree and
# residual plots into a single plot
p = row(column(p2, p3), p1)
p = column(hm_p, p)

# Deposit all regression results
_deposit_results(model, output_dir)
Expand All @@ -319,10 +409,8 @@ def lme_summary(output_dir: str, model: LMEModel, ndim=10) -> None:
index_f.write('<h1>Simplicial Linear Mixed Effects Summary</h1>\n')
index_f.write(smry.as_html())
_deposit_results_html(index_f)
ess_tree_html = file_html(p1, CDN, 'Loglikelihood')
index_f.write(ess_tree_html)
reg_smry_html = file_html(p23, CDN, 'Prediction and Residual plot')
index_f.write(reg_smry_html)
diag_html = file_html(p, CDN, 'Diagnostic plots')
index_f.write(diag_html)
index_f.write('</body></html>\n')


Expand Down
8 changes: 5 additions & 3 deletions gneiss/plot/_radial.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,8 @@
from gneiss.plot._dendrogram import UnrootedDendrogram
from bokeh.models.glyphs import Circle, Segment
from bokeh.models import ColumnDataSource, DataRange1d, Plot
from bokeh.models import HoverTool, BoxZoomTool, ResetTool
from bokeh.models import (HoverTool, BoxZoomTool, ResetTool,
WheelZoomTool, SaveTool, PanTool)


def radialplot(tree, node_color='node_color', node_size='node_size',
Expand Down Expand Up @@ -59,7 +60,7 @@ def radialplot(tree, node_color='node_color', node_size='node_size',

# This entire function was motivated by
# http://chuckpr.github.io/blog/trees2.html
t = UnrootedDendrogram.from_tree(tree)
t = UnrootedDendrogram.from_tree(tree.copy())

nodes = t.coords(figsize[0], figsize[1])

Expand Down Expand Up @@ -123,6 +124,7 @@ def df2ds(df):
tooltip += [(hover_var, "@" + hover_var)]

hover = HoverTool(renderers=[ns], tooltips=tooltip)
plot.add_tools(hover, BoxZoomTool(), ResetTool())
plot.add_tools(hover, BoxZoomTool(), ResetTool(),
WheelZoomTool(), SaveTool(), PanTool())

return plot
3 changes: 2 additions & 1 deletion gneiss/plot/tests/test_plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,14 +178,15 @@ def test_visualization(self):

with open(index_fp, 'r') as fh:
html = fh.read()

self.assertIn('<h1>Simplicial Linear Regression Summary</h1>',
html)
self.assertIn('<th>Relative importance</th>', html)
self.assertIn('<th>Cross Validation</th>', html)
self.assertIn('<th>Coefficients</th>\n', html)
self.assertIn('<th>Raw Balances</th>\n', html)
self.assertIn('<th>Predicted Proportions</th>\n', html)
self.assertIn('<th>Residuals</th>\n', html)
self.assertIn('<th>Tree</th>\n', html)


class TestLME_Summary(unittest.TestCase):
Expand Down
1 change: 1 addition & 0 deletions gneiss/plot/tests/test_radial.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,7 @@ def test_basic_plot(self):

self.assertDictEqual(p.renderers[0].data_source.data, exp_edges)
self.assertDictEqual(p.renderers[1].data_source.data, exp_nodes)
self.assertTrue(isinstance(t, TreeNode))


if __name__ == "__main__":
Expand Down

0 comments on commit 3e66028

Please sign in to comment.