make ld_prune fast again #5078

Merged Jan 24, 2019 (33 commits; diff shown is changes from 30 commits)

Commits
ddac8b2  fast ldprune (danking, Jan 4, 2019)
9c093ba  remove unnecessary writes (danking, Jan 4, 2019)
f1fb603  remove unnecessary coerce sorted (danking, Jan 4, 2019)
ada7979  better error messages (danking, Jan 4, 2019)
7168e23  remove unncessary sort (danking, Jan 4, 2019)
5f20758  carefully avoid sorts (danking, Jan 4, 2019)
099579c  fix filtering expression (danking, Jan 4, 2019)
8bfa95b  remove rogue files (danking, Jan 4, 2019)
ea27190  BM.entries optionally keys its output (danking, Jan 8, 2019)
11e4b7c  use konrads xmas present (danking, Jan 8, 2019)
3e72db1  Update statgen.py (danking, Jan 11, 2019)
56bfa43  Update misc.py (danking, Jan 11, 2019)
74a4d0a  wip (danking, Jan 16, 2019)
53fea49  fix mis (danking, Jan 17, 2019)
eec60bf  passes tests (danking, Jan 17, 2019)
0b16c8e  passes tests (danking, Jan 17, 2019)
bc9cd60  remove gradle changes (danking, Jan 17, 2019)
408522c  wip works maybe (danking, Jan 17, 2019)
c4bf4e0  wip works (danking, Jan 17, 2019)
55041ed  read/write before mis (danking, Jan 17, 2019)
01713d3  use localize (danking, Jan 17, 2019)
0d9c814  fix mis (danking, Jan 17, 2019)
a1a4be2  run sort once (danking, Jan 17, 2019)
a5edd98  no unnecessary select_entries (danking, Jan 17, 2019)
a7857f2  avoid an allocation of a struct for require biallelic (danking, Jan 17, 2019)
d9c3f1e  remove rogue maxfail (danking, Jan 23, 2019)
71815dc  fix capitalization of returns (danking, Jan 23, 2019)
943808b  cannot reliably check this without running the computation twice (danking, Jan 23, 2019)
baf9ccb  fix when expr is expression but no normalization etc. is done (danking, Jan 23, 2019)
cbcd4a8  fix spaces/tabs (danking, Jan 23, 2019)
e768bbd  Update blockmatrix.py (danking, Jan 24, 2019)
a62ee9a  fix readme (danking, Jan 24, 2019)
b3bdf66  no fixmes (danking, Jan 24, 2019)
5 changes: 5 additions & 0 deletions hail/build.gradle
@@ -154,6 +154,11 @@ task deploy(type: Exec, dependsOn: ['generateBuildInfo', 'shadowJar']) {
outputs.upToDateWhen { false }
}

task pipInstall(type: Exec, dependsOn: ['generateBuildInfo', 'shadowJar']) {
commandLine 'bash', 'python/pipinstall.sh'
outputs.upToDateWhen { false }
}

compileScala {
dependsOn generateBuildInfo

19 changes: 13 additions & 6 deletions hail/python/hail/linalg/blockmatrix.py
@@ -681,10 +681,14 @@ def write_from_entry_expr(entry_expr, path, overwrite=False, mean_impute=False,
check_entry_indexed('BlockMatrix.write_from_entry_expr', entry_expr)
mt = matrix_table_source('BlockMatrix.write_from_entry_expr', entry_expr)

if (not (mean_impute or center or normalize)) and (entry_expr in mt._fields_inverse):
# FIXME: remove once select_entries on a field is free
field = mt._fields_inverse[entry_expr]
mt._write_block_matrix(path, overwrite, field, block_size)
if not (mean_impute or center or normalize):
if entry_expr in mt._fields_inverse:
# FIXME: remove once select_entries on a field is free
field = mt._fields_inverse[entry_expr]
mt._write_block_matrix(path, overwrite, field, block_size)
Review comment (Contributor): This prevents field pruning. Emitting a mt.select_entries('field') would fix that.

Reply (Author): addressed. I also removed the FIXME in favor of a punch-list issue: #5202

else:
field = Env.get_uid()
mt.select_entries(**{field: entry_expr})._write_block_matrix(path, overwrite, field, block_size)
else:
n_cols = mt.count_cols()
mt = mt.select_entries(__x=entry_expr)
@@ -1554,7 +1558,7 @@ def sum(self, axis=None):
else:
raise ValueError(f'axis must be None, 0, or 1: found {axis}')

def entries(self):
def entries(self, keyed=True):
"""Returns a table with the indices and value of each block matrix entry.

Examples
@@ -1595,7 +1599,10 @@ def entries(self):
:class:`.Table`
Table with a row for each entry.
"""
return Table._from_java(self._jbm.entriesTable(Env.hc()._jhc))
t = Table._from_java(self._jbm.entriesTable(Env.hc()._jhc))
if keyed:
t = t.key_by('i', 'j')
return t

@staticmethod
@typecheck(path_in=str,
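The `keyed` flag exists because a block matrix streams its entries out block by block, an order that is not sorted by `(i, j)`; keying the resulting table therefore forces a sort that `ld_prune` does not need. A minimal pure-Python sketch of that ordering (the `block_major_entries` helper is hypothetical, for illustration only, not Hail's implementation):

```python
import numpy as np

def block_major_entries(m, block_size):
    """Yield (i, j, value) triples block by block, the order in which a
    block-partitioned matrix naturally produces its entries."""
    n_rows, n_cols = m.shape
    for bi in range(0, n_rows, block_size):
        for bj in range(0, n_cols, block_size):
            block = m[bi:bi + block_size, bj:bj + block_size]
            for i in range(block.shape[0]):
                for j in range(block.shape[1]):
                    yield (bi + i, bj + j, float(block[i, j]))

m = np.arange(16.0).reshape(4, 4)
entries = list(block_major_entries(m, block_size=2))
# The first 2x2 block yields (0,0), (0,1), (1,0), (1,1); the next block
# jumps back to row 0 -- so the stream is not (i, j)-sorted.
assert entries != sorted(entries)
```

With `keyed=False`, a caller can filter these triples directly and only pays for the `(i, j)` sort if it explicitly keys the table afterwards.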
44 changes: 28 additions & 16 deletions hail/python/hail/methods/misc.py
@@ -6,7 +6,7 @@
from hail.matrixtable import MatrixTable
from hail.table import Table
from hail.typecheck import *
from hail.utils import Interval, Struct
from hail.utils import Interval, Struct, new_temp_file
from hail.utils.misc import plural
from hail.utils.java import Env, joption, info
from hail.ir import *
@@ -15,8 +15,9 @@
@typecheck(i=Expression,
j=Expression,
keep=bool,
tie_breaker=nullable(func_spec(2, expr_numeric)))
def maximal_independent_set(i, j, keep=True, tie_breaker=None) -> Table:
tie_breaker=nullable(func_spec(2, expr_numeric)),
keyed=bool)
def maximal_independent_set(i, j, keep=True, tie_breaker=None, keyed=True) -> Table:
"""Return a table containing the vertices in a near
`maximal independent set <https://en.wikipedia.org/wiki/Maximal_independent_set>`_
of an undirected graph whose edges are given by a two-column table.
@@ -101,6 +102,9 @@ def maximal_independent_set(i, j, keep=True, tie_breaker=None) -> Table:
If ``True``, return vertices in set. If ``False``, return vertices removed.
tie_breaker : function
Function used to order nodes with equal degree.
keyed : :obj:`bool`
If ``True``, key the resulting table by the `node` field; this requires a sort.

Returns
-------
@@ -137,19 +141,27 @@ def maximal_independent_set(i, j, keep=True, tie_breaker=None) -> Table:
t, _ = source._process_joins(i, j)
tie_breaker_str = None

nodes = (t.select(node=[i, j])
.explode('node')
.key_by('node')
.select())

edges = t.select(__i=i, __j=j).key_by().select('__i', '__j')
nodes_in_set = Env.hail().utils.Graph.maximalIndependentSet(edges._jt.collect(), node_t._parsable_string(), joption(tie_breaker_str))
nt = Table._from_java(nodes._jt.annotateGlobal(nodes_in_set, hl.tset(node_t)._parsable_string(), 'nodes_in_set'))
nt = (nt
.filter(nt.nodes_in_set.contains(nt.node), keep)
.drop('nodes_in_set'))

return nt
edges_path = new_temp_file()
edges.write(edges_path)
edges = hl.read_table(edges_path)

mis_nodes = Env.hail().utils.Graph.maximalIndependentSet(
edges._jt.collect(),
node_t._parsable_string(),
joption(tie_breaker_str))

nodes = edges.select(node = [edges.__i, edges.__j])
nodes = nodes.explode(nodes.node)
# avoid serializing `mis_nodes` from java to python and back to java
nodes = Table._from_java(
nodes._jt.annotateGlobal(
mis_nodes, hl.tset(node_t)._parsable_string(), 'mis_nodes'))
nodes = nodes.filter(nodes.mis_nodes.contains(nodes.node), keep)
nodes = nodes.select_globals()
if keyed:
return nodes.key_by('node')
return nodes


def require_col_key_str(dataset: MatrixTable, method: str):
@@ -216,7 +228,7 @@ def require_biallelic(dataset, method) -> MatrixTable:
require_row_key_variant(dataset, method)
return dataset._select_rows(method,
hl.case()
.when(dataset.alleles.length() == 2, dataset.row)
.when(dataset.alleles.length() == 2, dataset._rvrow)
.or_error(f"'{method}' expects biallelic variants ('alleles' field of length 2), found " +
hl.str(dataset.locus) + ", " + hl.str(dataset.alleles)))

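For intuition, the computation that `Graph.maximalIndependentSet` performs on the collected edges can be sketched in pure Python as the usual greedy algorithm: repeatedly delete a highest-degree vertex until no edges remain, so the survivors form a (near-)maximal independent set. The tie-breaking convention below (among equal-degree vertices, delete the one the comparator orders last) is an assumption for illustration; Hail's Scala implementation may differ in detail:

```python
from collections import defaultdict
from functools import cmp_to_key

def greedy_mis_removed(edges, tie_breaker=None):
    """Return the set of vertices removed by greedy high-degree deletion;
    the vertices left over form a (near-)maximal independent set."""
    adj = defaultdict(set)
    for u, v in edges:
        adj[u].add(v)
        adj[v].add(u)
    removed = set()
    while any(adj[v] for v in adj):
        max_deg = max(len(adj[v]) for v in adj)
        candidates = [v for v in adj if len(adj[v]) == max_deg]
        if tie_breaker is not None:
            # assumed convention: the comparator orders "prefer to keep"
            # first, so the candidate ordered last is deleted
            candidates.sort(key=cmp_to_key(tie_breaker))
        v = candidates[-1]
        removed.add(v)
        for u in adj.pop(v):
            adj[u].discard(v)
    return removed

# On the path a-b-c, vertex b has the highest degree and is removed,
# leaving {a, c} as the independent set.
assert greedy_mis_removed([("a", "b"), ("b", "c")]) == {"b"}
```

Returning the removed set rather than the kept set mirrors the `keep=False` mode that `ld_prune` uses.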
65 changes: 32 additions & 33 deletions hail/python/hail/methods/statgen.py
@@ -3311,60 +3311,59 @@ def ld_prune(call_expr, r2=0.2, bp_window_size=1000000, memory_per_core=256, kee
.write(locally_pruned_table_path, overwrite=True))
locally_pruned_table = hl.read_table(locally_pruned_table_path).add_index()

locally_pruned_ds_path = new_temp_file()
mt = mt.annotate_rows(info=locally_pruned_table[mt.row_key])
(mt.filter_rows(hl.is_defined(mt.info))
.write(locally_pruned_ds_path, overwrite=True))
locally_pruned_ds = hl.read_matrix_table(locally_pruned_ds_path)
mt = mt.filter_rows(hl.is_defined(mt.info))

n_locally_pruned_variants = locally_pruned_ds.count_rows()
info(f'ld_prune: local pruning stage retained {n_locally_pruned_variants} variants')

standardized_mean_imputed_gt_expr = hl.or_else(
(locally_pruned_ds[field].n_alt_alleles() - locally_pruned_ds.info.mean) * locally_pruned_ds.info.centered_length_rec,
0.0)

std_gt_bm = BlockMatrix.from_entry_expr(standardized_mean_imputed_gt_expr, block_size=block_size)
std_gt_bm = BlockMatrix.from_entry_expr(
hl.or_else(
(mt[field].n_alt_alleles() - mt.info.mean) * mt.info.centered_length_rec,
0.0),
block_size=block_size)
r2_bm = (std_gt_bm @ std_gt_bm.T) ** 2

_, stops = hl.linalg.utils.locus_windows(locally_pruned_table.locus, bp_window_size)

entries = r2_bm.sparsify_row_intervals(range(stops.size), stops, blocks_only=True).entries()
entries = r2_bm.sparsify_row_intervals(range(stops.size), stops, blocks_only=True).entries(keyed=False)
entries = entries.filter((entries.entry >= r2) & (entries.i < entries.j))
entries = entries.select(i = hl.int32(entries.i), j = hl.int32(entries.j))

locally_pruned_info = locally_pruned_table.key_by('idx').select('locus', 'mean')

entries = entries.select(info_i=locally_pruned_info[entries.i],
info_j=locally_pruned_info[entries.j])
if keep_higher_maf:
fields = ['mean', 'locus']
else:
fields = ['locus']

entries = entries.filter((entries.info_i.locus.contig == entries.info_j.locus.contig)
& (entries.info_j.locus.position - entries.info_i.locus.position <= bp_window_size))
info = locally_pruned_table.aggregate(
hl.agg.collect(locally_pruned_table.row.select('idx', *fields)), _localize=False)
info = hl.sorted(info, key=lambda x: x.idx)

entries_path = new_temp_file()
entries.write(entries_path, overwrite=True)
entries = hl.read_table(entries_path)
entries = entries.annotate_globals(info = info)

n_edges = entries.count()
info(f'ld_prune: correlation graph of locally-pruned variants has {n_edges} edges,'
f'\n finding maximal independent set...')
entries = entries.filter(
(entries.info[entries.i].locus.contig == entries.info[entries.j].locus.contig) &
(entries.info[entries.j].locus.position - entries.info[entries.i].locus.position <= bp_window_size))

if keep_higher_maf:
entries = entries.key_by(
entries = entries.annotate(
i=hl.struct(idx=entries.i,
twice_maf=hl.min(entries.info_i.mean, 2.0 - entries.info_i.mean)),
twice_maf=hl.min(entries.info[entries.i].mean, 2.0 - entries.info[entries.i].mean)),
j=hl.struct(idx=entries.j,
twice_maf=hl.min(entries.info_j.mean, 2.0 - entries.info_j.mean)))
twice_maf=hl.min(entries.info[entries.j].mean, 2.0 - entries.info[entries.j].mean)))

def tie_breaker(l, r):
return hl.sign(r.twice_maf - l.twice_maf)

variants_to_remove = hl.maximal_independent_set(entries.i, entries.j, keep=False, tie_breaker=tie_breaker)
variants_to_remove = variants_to_remove.key_by(variants_to_remove.node.idx)
else:
variants_to_remove = hl.maximal_independent_set(entries.i, entries.j, keep=False)
tie_breaker = None

variants_to_remove = hl.maximal_independent_set(
entries.i, entries.j, keep=False, tie_breaker=tie_breaker, keyed=False)

locally_pruned_table = locally_pruned_table.annotate_globals(
variants_to_remove = variants_to_remove.aggregate(
hl.agg.collect_as_set(variants_to_remove.node.idx), _localize=False))
return locally_pruned_table.filter(
hl.is_defined(variants_to_remove[locally_pruned_table.idx]), keep=False).select().persist()
locally_pruned_table.variants_to_remove.contains(hl.int32(locally_pruned_table.idx)),
keep=False
).select().persist()


def _warn_if_no_intercept(caller, covariates):
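The standardized, mean-imputed genotype matrix that feeds the `r2_bm = (std_gt_bm @ std_gt_bm.T) ** 2` step can be mirrored in plain numpy. This is a sketch of the math only, not Hail's implementation; it assumes `centered_length_rec` is the reciprocal of the centered vector's length, so each row of `x` is unit-norm and `x @ x.T` is the correlation matrix:

```python
import numpy as np

def ld_r2(n_alt_alleles):
    """Pairwise r**2 between variants (rows) of an n_alt_alleles matrix.
    np.nan marks missing calls, which are mean-imputed (0 after centering).
    Assumes no monomorphic variants (zero variance rows would divide by 0);
    ld_prune filters those upstream."""
    g = np.asarray(n_alt_alleles, dtype=float)
    mean = np.nanmean(g, axis=1, keepdims=True)
    centered = np.where(np.isnan(g), 0.0, g - mean)
    norms = np.sqrt((centered ** 2).sum(axis=1, keepdims=True))
    x = centered / norms          # each row has unit length
    return (x @ x.T) ** 2         # squared Pearson correlation

# Identical variants are perfectly correlated, and squaring makes r**2
# blind to sign, so a perfectly anti-correlated variant also scores 1.0.
r2 = ld_r2([[0, 1, 2, 2],
            [0, 1, 2, 2],
            [2, 1, 0, 0]])
assert np.allclose(r2[0, 1], 1.0) and np.allclose(r2[0, 2], 1.0)
```

Hail computes the same product on a distributed `BlockMatrix` and sparsifies to the blocks within the base-pair window before ever materializing entries.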
23 changes: 23 additions & 0 deletions hail/python/pipinstall.sh
@@ -0,0 +1,23 @@
#!/bin/bash

set -ex

cd $(CDPATH= cd -- "$(dirname -- "$0")" && pwd)

cleanup() {
trap "" INT TERM
rm hail/hail-all-spark.jar
rm README.md
rm -rf build/lib
}
trap cleanup EXIT
trap "exit 24" INT TERM

python3=${HAIL_PYTHON3:-python3}

cp ../build/libs/hail-all-spark.jar hail/
cp ../../README.md .
rm -f dist/*
$python3 setup.py sdist bdist_wheel
ls dist
pip install -U dist/hail-$(cat hail/hail_pip_version)-py3-none-any.whl
4 changes: 2 additions & 2 deletions hail/src/main/scala/is/hail/linalg/BlockMatrix.scala
@@ -1213,7 +1213,7 @@ class BlockMatrix(val blocks: RDD[((Int, Int), BDM[Double])],
}
}

new Table(hc, entriesRDD, rvRowType, Array("i", "j"))
new Table(hc, entriesRDD, rvRowType, Array[String]())
Review comment (Contributor): this is technically a user-facing change, right? That's OK since this is experimental, but I'll remember to put it in the change log.

Reply (Author): Fixed by adding an optional keyed=True parameter in python which calls key_by if necessary.

}
}

@@ -1704,4 +1704,4 @@ class WriteBlocksRDD(path: String,
outPerBlockCol.foreach(_.close())
blockPartFiles.iterator
}
}
}
5 changes: 2 additions & 3 deletions hail/src/main/scala/is/hail/methods/LocalLDPrune.scala
@@ -11,8 +11,8 @@ import is.hail.expr.types.physical.{PArray, PInt64Required, PStruct}
import is.hail.expr.types.virtual._
import is.hail.rvd.{RVD, RVDType}
import is.hail.table.Table
import is.hail.variant._
import is.hail.utils._
import is.hail.variant._

object BitPackedVectorView {
val bpvElementSize: Long = PInt64Required.byteSize
@@ -365,8 +365,7 @@ case class LocalLDPrune(
}
})

TableValue(tableType, BroadcastRow.empty(mv.sparkContext),
RVD.coerce(tableRVDType, sitesOnly.crdd))
TableValue(tableType, BroadcastRow.empty(mv.sparkContext), sitesOnly)
}
}

12 changes: 9 additions & 3 deletions hail/src/test/scala/is/hail/linalg/BlockMatrixSuite.scala
@@ -757,9 +757,15 @@ class BlockMatrixSuite extends SparkSuite {
val data = (0 until 50).map(_.toDouble).toArray
val lm = new BDM[Double](5, 10, data)
val bm = toBM(lm, blockSize = 2)

assert(bm.filterBlocks(Array(0, 1, 6)).entriesTable(hc).collect().map(r => r.get(2).asInstanceOf[Double]) sameElements
Array(0, 5, 20, 25, 1, 6, 21, 26, 2, 7, 3, 8).map(_.toDouble))

val expected = bm
.filterBlocks(Array(0, 1, 6))
.entriesTable(hc)
.collect()
.sortBy(r => (r.get(0).asInstanceOf[Long], r.get(1).asInstanceOf[Long]))
.map(r => r.get(2).asInstanceOf[Double])

assert(expected sameElements Array[Double](0, 5, 20, 25, 1, 6, 21, 26, 2, 7, 3, 8))
}

@Test