diff --git a/hail/python/hail/docs/plot.rst b/hail/python/hail/docs/plot.rst index 4e6e58c3179..0bd42563310 100644 --- a/hail/python/hail/docs/plot.rst +++ b/hail/python/hail/docs/plot.rst @@ -21,12 +21,14 @@ Plot functions in Hail accept data in the form of either Python objects or :clas histogram cumulative_histogram + histogram2d scatter qq manhattan .. autofunction:: histogram .. autofunction:: cumulative_histogram +.. autofunction:: histogram2d .. autofunction:: scatter .. autofunction:: qq .. autofunction:: manhattan diff --git a/hail/python/hail/docs/tutorials/plotting.ipynb b/hail/python/hail/docs/tutorials/plotting.ipynb index d28513e856d..749d037d99d 100644 --- a/hail/python/hail/docs/tutorials/plotting.ipynb +++ b/hail/python/hail/docs/tutorials/plotting.ipynb @@ -171,6 +171,25 @@ "show(gridplot([p, p2], ncols=2, plot_width=400, plot_height=400))" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 2-D histogram\n", + "\n", + "For visualizing relationships between variables in large datasets (where scatter plots may be less informative since they highlight outliers), the `histogram_2d()` function will create a heatmap with the number of observations in each section of a 2-d grid based on two variables." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "p = hl.plot.histogram2d(pca_scores.scores[0], pca_scores.scores[1])\n", + "show(p)" + ] + }, { "cell_type": "markdown", "metadata": {}, @@ -228,13 +247,6 @@ "p = hl.plot.manhattan(gwas.p_value, hover_fields=hover_fields, collect_all=True)\n", "show(p)" ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] } ], "metadata": { diff --git a/hail/python/hail/plot/__init__.py b/hail/python/hail/plot/__init__.py index 046dae4e7d9..9c12289b149 100644 --- a/hail/python/hail/plot/__init__.py +++ b/hail/python/hail/plot/__init__.py @@ -1,9 +1,10 @@ -from .plots import output_notebook, show, histogram, cumulative_histogram, scatter, qq, manhattan +from .plots import output_notebook, show, histogram, cumulative_histogram, histogram2d, scatter, qq, manhattan __all__ = ['output_notebook', 'show', 'histogram', 'cumulative_histogram', 'scatter', + 'histogram2d', 'qq', 'manhattan'] diff --git a/hail/python/hail/plot/plots.py b/hail/python/hail/plot/plots.py index 5c88d859306..b7614234818 100644 --- a/hail/python/hail/plot/plots.py +++ b/hail/python/hail/plot/plots.py @@ -1,6 +1,7 @@ from math import log, isnan, log10 import numpy as np +import bokeh import bokeh.io from bokeh.models import * from bokeh.plotting import figure @@ -146,6 +147,169 @@ def cumulative_histogram(data, range=None, bins=50, legend=None, title=None, nor return p +@typecheck(x=expr_numeric, y=expr_numeric, bins=oneof(int, sequenceof(int)), + range=nullable(sized_tupleof(nullable(sized_tupleof(numeric, numeric)), + nullable(sized_tupleof(numeric, numeric)))), + title=nullable(str), width=int, height=int, + font_size=str, colors=sequenceof(str)) +def histogram2d(x, y, bins=40, range=None, + title=None, width=600, height=600, font_size='7pt', + colors=bokeh.palettes.all_palettes['Blues'][7][::-1]): + """Plot a two-dimensional histogram. + + ``x`` and ``y`` must both be a :class:`NumericExpression` from the same :class:`Table`. + + If ``x_range`` or ``y_range`` are not provided, the function will do a pass through the data to determine + min and max of each variable. + + Examples + -------- + + >>> ht = hail.utils.range_table(1000).annotate(x=hail.rand_norm(), y=hail.rand_norm()) + >>> p_hist = hail.plot.histogram2d(ht.x, ht.y) + + >>> ht = hail.utils.range_table(1000).annotate(x=hail.rand_norm(), y=hail.rand_norm()) + >>> p_hist = hail.plot.histogram2d(ht.x, ht.y, bins=10, range=((0, 1), None)) + + Parameters + ---------- + x : :class:`.NumericExpression` + Expression for x-axis (from a Hail table). + y : :class:`.NumericExpression` + Expression for y-axis (from the same Hail table as ``x``). + bins : int or [int, int] + The bin specification: + - If int, the number of bins for the two dimensions (nx = ny = bins). + - If [int, int], the number of bins in each dimension (nx, ny = bins). + The default value is 40. + range : None or ((float, float), (float, float)) + The leftmost and rightmost edges of the bins along each dimension: + ((xmin, xmax), (ymin, ymax)). All values outside of this range will be considered outliers + and not tallied in the histogram. If this value is None, or either of the inner lists is None, + the range will be computed from the data. + width : int + Plot width (default 600px). + height : int + Plot height (default 600px). + title : str + Title of the plot. + font_size : str + String of font size in points (default '7pt'). + colors : List[str] + List of colors (hex codes, or strings as described + `here `__). Compatible with one of the many + built-in palettes available `here `__. + + Returns + ------- + :class:`bokeh.plotting.figure.Figure` + """ + source = x._indices.source + y_source = y._indices.source + + if source is None or y_source is None: + raise ValueError("histogram_2d expects two expressions of 'Table', found scalar expression") + if isinstance(source, hail.MatrixTable): + raise ValueError("histogram_2d requires source to be Table, not MatrixTable") + if source != y_source: + raise ValueError(f"histogram_2d expects two expressions from the same 'Table', found {source} and {y_source}") + check_row_indexed('histogram_2d', x) + check_row_indexed('histogram_2d', y) + if isinstance(bins, int): + x_bins = y_bins = bins + else: + x_bins, y_bins = bins + if range is None: + x_range = y_range = None + else: + x_range, y_range = range + if x_range is None or y_range is None: + warnings.warn('At least one range was not defined in histogram_2d. Doing two passes...') + ranges = source.aggregate(hail.struct(x_stats=hail.agg.stats(x), + y_stats=hail.agg.stats(y))) + if x_range is None: + x_range = (ranges.x_stats.min, ranges.x_stats.max) + if y_range is None: + y_range = (ranges.y_stats.min, ranges.y_stats.max) + else: + warnings.warn('If x_range or y_range are specified in histogram_2d, and there are points ' + 'outside of these ranges, they will not be plotted') + x_range = list(map(float, x_range)) + y_range = list(map(float, y_range)) + x_spacing = (x_range[1] - x_range[0]) / x_bins + y_spacing = (y_range[1] - y_range[0]) / y_bins + + def frange(start, stop, step): + from itertools import count, takewhile + return takewhile(lambda x: x <= stop, count(start, step)) + + x_levels = hail.literal(list(frange(x_range[0], x_range[1], x_spacing))[::-1]) + y_levels = hail.literal(list(frange(y_range[0], y_range[1], y_spacing))[::-1]) + + grouped_ht = source.group_by( + x=hail.str(x_levels.find(lambda w: x >= w)), + y=hail.str(y_levels.find(lambda w: y >= w)) + ).aggregate(c=hail.agg.count()) + data = grouped_ht.filter(hail.is_defined(grouped_ht.x) & (grouped_ht.x != str(x_range[1])) & + hail.is_defined(grouped_ht.y) & (grouped_ht.y != str(y_range[1]))).to_pandas() + + mapper = LinearColorMapper(palette=colors, low=data.c.min(), high=data.c.max()) + + x_axis = sorted(set(data.x), key=lambda z: float(z)) + y_axis = sorted(set(data.y), key=lambda z: float(z)) + p = figure(title=title, + x_range=x_axis, y_range=y_axis, + x_axis_location="above", plot_width=width, plot_height=height, + tools="hover,save,pan,box_zoom,reset,wheel_zoom", toolbar_location='below') + + p.grid.grid_line_color = None + p.axis.axis_line_color = None + p.axis.major_tick_line_color = None + p.axis.major_label_standoff = 0 + p.axis.major_label_text_font_size = font_size + import math + p.xaxis.major_label_orientation = math.pi / 3 + + p.rect(x='x', y='y', width=1, height=1, + source=data, + fill_color={'field': 'c', 'transform': mapper}, + line_color=None) + + color_bar = ColorBar(color_mapper=mapper, major_label_text_font_size=font_size, + ticker=BasicTicker(desired_num_ticks=6), + label_standoff=6, border_line_color=None, location=(0, 0)) + p.add_layout(color_bar, 'right') + + def set_font_size(p, font_size: str = '12pt'): + """Set most of the font sizes in a bokeh figure + + Parameters + ---------- + p : :class:`bokeh.plotting.figure.Figure` + Input figure. + font_size : str + String of font size in points (e.g. '12pt'). + + Returns + ------- + :class:`bokeh.plotting.figure.Figure` + """ + p.legend.label_text_font_size = font_size + p.xaxis.axis_label_text_font_size = font_size + p.yaxis.axis_label_text_font_size = font_size + p.xaxis.major_label_text_font_size = font_size + p.yaxis.major_label_text_font_size = font_size + if hasattr(p.title, 'text_font_size'): + p.title.text_font_size = font_size + if hasattr(p.xaxis, 'group_text_font_size'): + p.xaxis.group_text_font_size = font_size + return p + + p.select_one(HoverTool).tooltips = [('x', '@x'), ('y', '@y',), ('count', '@c')] + p = set_font_size(p, font_size) + return p + + @typecheck(x=oneof(sequenceof(numeric), expr_float64), y=oneof(sequenceof(numeric), expr_float64), label=oneof(nullable(str), expr_str, sequenceof(str)), title=nullable(str), xlabel=nullable(str), ylabel=nullable(str), size=int, legend=bool,