Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

STY: Refactoring OLS summary #166

Merged
merged 2 commits into from
Apr 24, 2017
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
128 changes: 97 additions & 31 deletions gneiss/plot/_plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,19 +24,27 @@
from q2_types.feature_table import FeatureTable
from qiime2.plugin import Int, MetadataCategory, Str, Choices


from bokeh.embed import file_html
from bokeh.resources import CDN
from bokeh.plotting import figure, ColumnDataSource
from bokeh.layouts import row
from bokeh.models import HoverTool, BoxZoomTool, ResetTool
from bokeh.layouts import row, column
from bokeh.models import (HoverTool, BoxZoomTool, WheelZoomTool,
ResetTool, SaveTool, PanTool)
from bokeh.charts import HeatMap


def _projected_prediction(model):
def _projected_prediction(model, plot_width=400, plot_height=400):
""" Create projected prediction plot

Parameters
----------
model : RegressionModel
Input regression model to plot prediction.
plot_width : int
Width of plot.
plot_height : int
Height of plot.

Returns
-------
Expand All @@ -54,8 +62,9 @@ def _projected_prediction(model):
raw = model.balances
raw['color'] = 'raw' # make raw values blue

p = figure(plot_width=400, plot_height=400,
tools=[hover, BoxZoomTool(), ResetTool()])
p = figure(plot_width=plot_width, plot_height=plot_height,
tools=[hover, BoxZoomTool(), ResetTool(),
WheelZoomTool(), SaveTool()])
raw_source = ColumnDataSource(raw)
pred_source = ColumnDataSource(pred)

Expand All @@ -74,7 +83,7 @@ def _projected_prediction(model):
return p


def _projected_residuals(model):
def _projected_residuals(model, plot_width=400, plot_height=400):
""" Create projected residual plot

Parameters
Expand All @@ -92,8 +101,9 @@ def _projected_residuals(model):
)
pcvar = model.percent_explained()
resid = model.residuals()
p = figure(plot_width=400, plot_height=400,
tools=[hover, BoxZoomTool(), ResetTool()])
p = figure(plot_width=plot_width, plot_height=plot_height,
tools=[hover, BoxZoomTool(), ResetTool(),
WheelZoomTool(), SaveTool()])
resid_source = ColumnDataSource(resid)

p.circle(resid.columns[0], resid.columns[1], size=7,
Expand All @@ -108,6 +118,65 @@ def _projected_residuals(model):
return p


def _heatmap_summary(pvals, coefs, cmap='viridis',
plot_width=1200, plot_height=400):
""" Plots heatmap of coefficients colored by pvalues

Parameters
----------
pvals : pd.DataFrame
Table of pvalues where rows are balances and columns are
covariates.
coefs : pd.DataFrame
Table of coefficients where rows are balances and columns are
covariates.
cmap : str
Color scheme to plot heatmap.
plot_width : int
Width of plot.
plot_height : int
Height of plot.

Returns
-------
bokeh.charts.Heatmap
Heatmap summarizing the regression statistics.
"""
c = coefs.reset_index()
c = c.rename(columns={'index': 'balance'})
# log scale for coloring
log_p = -np.log10(pvals+1e-200)
log_p = log_p.reset_index()
log_p = log_p.rename(columns={'index': 'balance'})
p = pvals.reset_index()
p = p.rename(columns={'index': 'balance'})

cm = pd.melt(c, id_vars='balance', var_name='Covariate',
value_name='Coefficient')
pm = pd.melt(p, id_vars='balance', var_name='Covariate',
value_name='Pvalue')
logpm = pd.melt(log_p, id_vars='balance', var_name='Covariate',
value_name='log_Pvalue')
m = pd.merge(cm, pm)
m = pd.merge(m, logpm)

hover = HoverTool(
tooltips=[
("log_Pvalue", "@log_Pvalue"),
]
)
source = ColumnDataSource(ColumnDataSource.from_df(m))

hm = HeatMap(m, x='balance', y='Covariate', values='log_Pvalue',
title='Regression Coefficients Summary',
sort_dim={'x': False}, width=plot_width, stat=None,
plot_height=plot_height, legend=False, source=source,
tools=[hover, PanTool(), BoxZoomTool(), WheelZoomTool(),
ResetTool(), SaveTool()])

return hm


def _decorate_tree(t, series):
""" Attaches some default values on the tree for plotting.

Expand Down Expand Up @@ -135,9 +204,6 @@ def _decorate_tree(t, series):

def _deposit_results(model, output_dir):
""" Store all of the core regression results into a folder. """
pred = model.predict()
pred.to_csv(os.path.join(output_dir, 'predicted.csv'),
header=True, index=True)
coefficients = model.coefficients()
coefficients.to_csv(os.path.join(output_dir, 'coefficients.csv'),
header=True, index=True)
Expand All @@ -153,6 +219,7 @@ def _deposit_results(model, output_dir):
balances = model.balances
balances.to_csv(os.path.join(output_dir, 'balances.csv'),
header=True, index=True)
model.tree.write(os.path.join(output_dir, 'tree.nwk'))


def _deposit_results_html(index_f):
Expand All @@ -172,6 +239,9 @@ def _deposit_results_html(index_f):
index_f.write(('<th>Residuals</th>\n'))
index_f.write(('<a href="residuals.csv">'
'Download as CSV</a><br>\n'))
index_f.write(('<th>Tree</th>\n'))
index_f.write(('<a href="tree.nwk">'
'Download as Newick</a><br>\n'))


# OLS summary
Expand All @@ -193,15 +263,15 @@ def ols_summary(output_dir: str, model: OLSModel, ndim=10) -> None:
cv = model.loo()
# Relative importance of explanatory variables
relimp = model.lovo()
w, h = 400, 400 # plot width and height
w, h = 300, 300 # plot width and height
# Histogram of model mean squared error from cross validation
mse_p = figure(title="Cross Validation Mean Squared Error",
plot_width=w, plot_height=h)
mse_hist, edges = np.histogram(cv.mse, density=True, bins=20)
mse_p.quad(top=mse_hist, bottom=0, left=edges[:-1], right=edges[1:],
fill_color="#FFFF00", line_color="#033649", fill_alpha=0.5,
legend='CV Mean Squared Error')
mse_p.ray(x=model.mse, y=0, length=h,
mse_p.ray(x=model.mse, y=0, length=h*max(mse_hist),
angle=1.57079633, color='red',
legend='Model Error', line_width=0.5)

Expand All @@ -216,13 +286,12 @@ def ols_summary(output_dir: str, model: OLSModel, ndim=10) -> None:
angle=1.57079633, color='red',
legend='Model Error', line_width=0.5)

cvp = row(mse_p, pred_p)

# Explained sum of squares
ess = pd.Series({r.model.endog_names: r.ess for r in model.results})
# Summary object
smry = model.summary(ndim=10)

# Summary object
smry = model.summary(ndim=0)
_deposit_results(model, output_dir)
t = _decorate_tree(model.tree, ess)

p1 = radialplot(t, edge_color='color', figsize=(800, 800))
Expand All @@ -232,13 +301,14 @@ def ols_summary(output_dir: str, model: OLSModel, ndim=10) -> None:
p1.title.text_font_size = '18pt'

# 2D scatter plot for prediction on PB
p2 = _projected_prediction(model)
p3 = _projected_residuals(model)

p23 = row(p2, p3)

_deposit_results(model, output_dir)

p2 = _projected_prediction(model, plot_width=w, plot_height=h)
p3 = _projected_residuals(model, plot_width=w, plot_height=h)
hm_p = _heatmap_summary(model.pvalues, model.coefficients())
# combine the cross validation, explained sum of squares tree and
# residual plots into a single plot
p = column(row(mse_p, pred_p), row(p2, p3))
p = row(p, p1)
p = column(hm_p, p)
index_fp = os.path.join(output_dir, 'index.html')
with open(index_fp, 'w') as index_f:
index_f.write('<html><body>\n')
Expand All @@ -247,13 +317,9 @@ def ols_summary(output_dir: str, model: OLSModel, ndim=10) -> None:
index_f.write('<th>Relative importance</th>\n')
index_f.write(relimp.to_html())
_deposit_results_html(index_f)
index_f.write('<th>Cross Validation</th>')
cv_html = file_html(cvp, CDN, 'Cross Validation')
index_f.write(cv_html)
ess_tree_html = file_html(p1, CDN, 'Explained Sum of Squares')
index_f.write(ess_tree_html)
reg_smry_html = file_html(p23, CDN, 'Prediction and Residual plot')
index_f.write(reg_smry_html)

plot_html = file_html(p, CDN, 'Diagnostics')
index_f.write(plot_html)
index_f.write('</body></html>\n')


Expand Down
8 changes: 5 additions & 3 deletions gneiss/plot/_radial.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,8 @@
from gneiss.plot._dendrogram import UnrootedDendrogram
from bokeh.models.glyphs import Circle, Segment
from bokeh.models import ColumnDataSource, DataRange1d, Plot
from bokeh.models import HoverTool, BoxZoomTool, ResetTool
from bokeh.models import (HoverTool, BoxZoomTool, ResetTool,
WheelZoomTool, SaveTool)


def radialplot(tree, node_color='node_color', node_size='node_size',
Expand Down Expand Up @@ -59,7 +60,7 @@ def radialplot(tree, node_color='node_color', node_size='node_size',

# This entire function was motivated by
# http://chuckpr.github.io/blog/trees2.html
t = UnrootedDendrogram.from_tree(tree)
t = UnrootedDendrogram.from_tree(tree.copy())

nodes = t.coords(figsize[0], figsize[1])

Expand Down Expand Up @@ -123,6 +124,7 @@ def df2ds(df):
tooltip += [(hover_var, "@" + hover_var)]

hover = HoverTool(renderers=[ns], tooltips=tooltip)
plot.add_tools(hover, BoxZoomTool(), ResetTool())
plot.add_tools(hover, BoxZoomTool(), ResetTool(),
WheelZoomTool(), SaveTool())

return plot
3 changes: 2 additions & 1 deletion gneiss/plot/tests/test_plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,14 +178,15 @@ def test_visualization(self):

with open(index_fp, 'r') as fh:
html = fh.read()

self.assertIn('<h1>Simplicial Linear Regression Summary</h1>',
html)
self.assertIn('<th>Relative importance</th>', html)
self.assertIn('<th>Cross Validation</th>', html)
self.assertIn('<th>Coefficients</th>\n', html)
self.assertIn('<th>Raw Balances</th>\n', html)
self.assertIn('<th>Predicted Proportions</th>\n', html)
self.assertIn('<th>Residuals</th>\n', html)
self.assertIn('<th>Tree</th>\n', html)


class TestLME_Summary(unittest.TestCase):
Expand Down
1 change: 1 addition & 0 deletions gneiss/plot/tests/test_radial.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,7 @@ def test_basic_plot(self):

self.assertDictEqual(p.renderers[0].data_source.data, exp_edges)
self.assertDictEqual(p.renderers[1].data_source.data, exp_nodes)
self.assertTrue(isinstance(t, TreeNode))


if __name__ == "__main__":
Expand Down