Skip to content

Commit

Permalink
removed action and simplified rrm and grm (#3944)
Browse files Browse the repository at this point in the history
* removed action and simplified rrm and grm

* deleted repetitive note in grm
  • Loading branch information
jbloom22 committed Jul 18, 2018
1 parent 03b493a commit 19213e5
Showing 1 changed file with 19 additions and 44 deletions.
63 changes: 19 additions & 44 deletions python/hail/methods/statgen.py
Original file line number Diff line number Diff line change
Expand Up @@ -1657,7 +1657,7 @@ def genetic_relatedness_matrix(call_expr) -> BlockMatrix:
Examples
--------
>>> grm = hl.genetic_relatedness_matrix(dataset.GT)
Notes
Expand Down Expand Up @@ -1694,8 +1694,8 @@ def genetic_relatedness_matrix(call_expr) -> BlockMatrix:
G_{ik} = \\frac{1}{m} \\sum_{j=1}^m \\frac{(C_{ij}-2p_j)(C_{kj}-2p_j)}{2 p_j (1-p_j)}
Note that variants for which the alternate allele frequency is zero or one are not
normalizable, and therefore removed prior to calculating the GRM.
This method drops variants with :math:`p_j = 0` or :math:`p_j = 1` before
computing kinship.
Parameters
----------
Expand All @@ -1710,29 +1710,20 @@ def genetic_relatedness_matrix(call_expr) -> BlockMatrix:
correspond to matrix table column index.
"""
mt = matrix_table_source('genetic_relatedness_matrix/call_expr', call_expr)
check_entry_indexed('genetic_relatedness_matrix/call_expr', call_expr)

mt = mt.select_entries(__gt=call_expr.n_alt_alleles())
mt = mt.annotate_rows(__AC=agg.sum(mt.__gt),
__n_called=agg.count_where(hl.is_defined(mt.__gt)))
mt = mt.select_rows(__AC=agg.sum(mt.__gt),
__n_called=agg.count_where(hl.is_defined(mt.__gt)))
mt = mt.filter_rows((mt.__AC > 0) & (mt.__AC < 2 * mt.__n_called))
mt = mt.persist()

n_variants = mt.count_rows()
if n_variants == 0:
raise FatalError("Cannot run GRM: found 0 variants after filtering out monomorphic sites.")
info("Computing GRM using {} variants.".format(n_variants))

mt = mt.annotate_rows(__mean_gt=mt.__AC / mt.__n_called)
mt = mt.annotate_rows(
__hwe_scaled_std_dev=hl.sqrt(mt.__mean_gt * (2 - mt.__mean_gt) * n_variants / 2))
mt = mt.select_rows(__mean_gt=mt.__AC / mt.__n_called)
mt = mt.annotate_rows(__hwe_scaled_std_dev=hl.sqrt(mt.__mean_gt * (2 - mt.__mean_gt)))

normalized_gt = hl.or_else((mt.__gt - mt.__mean_gt) / mt.__hwe_scaled_std_dev, 0.0)

bm = BlockMatrix.from_entry_expr(normalized_gt)
mt.unpersist()
grm = bm.T @ bm

return grm
return (bm.T @ bm) / (bm.n_rows / 2.0)


@typecheck(call_expr=expr_call)
Expand Down Expand Up @@ -1777,6 +1768,8 @@ def realized_relationship_matrix(call_expr) -> BlockMatrix:
where RRM uses empirical variance, GRM uses expected variance under
Hardy-Weinberg Equilibrium.
This method drops variants with zero variance before computing kinship.
Parameters
----------
call_expr : :class:`.CallExpression`
Expand All @@ -1790,38 +1783,20 @@ def realized_relationship_matrix(call_expr) -> BlockMatrix:
correspond to matrix table column index.
"""
mt = matrix_table_source('realized_relationship_matrix/call_expr', call_expr)
check_entry_indexed('realized_relationship_matrix/call_expr', call_expr)

mt = mt.select_entries(__gt=call_expr.n_alt_alleles())

mt = mt.annotate_rows(__AC=agg.sum(mt.__gt),
__ACsq=agg.sum(mt.__gt * mt.__gt),
__n_called=agg.count_where(hl.is_defined(mt.__gt)))

mt = mt.filter_rows((mt.__AC > 0) &
(mt.__AC < 2 * mt.__n_called) &
((mt.__AC != mt.__n_called) |
(mt.__ACsq != mt.__n_called)))
mt = mt.persist()

n_variants, n_samples = mt.count()

# once count_rows() adds partition_counts we can avoid annotating and filtering twice
if n_variants == 0:
raise FatalError("Cannot run RRM: found 0 variants after filtering out monomorphic sites.")
info("Computing RRM using {} variants.".format(n_variants))

mt = mt.annotate_rows(__mean_gt=mt.__AC / mt.__n_called)
mt = mt.annotate_rows(__scaled_std_dev=hl.sqrt((mt.__ACsq + (n_samples - mt.__n_called) * mt.__mean_gt ** 2) /
n_samples - mt.__mean_gt ** 2))
mt = mt.select_rows(__AC=agg.sum(mt.__gt),
__ACsq=agg.sum(mt.__gt * mt.__gt),
__n_called=agg.count_where(hl.is_defined(mt.__gt)))
mt = mt.select_rows(__mean_gt=mt.__AC / mt.__n_called,
__scaled_std_dev=hl.sqrt(mt.__ACsq - (mt.__AC ** 2) / mt.__n_called))
mt = mt.filter_rows(mt.__scaled_std_dev > 1e-30)

normalized_gt = hl.or_else((mt.__gt - mt.__mean_gt) / mt.__scaled_std_dev, 0.0)

bm = BlockMatrix.from_entry_expr(normalized_gt)
mt.unpersist()

rrm = (bm.T @ bm) / n_variants

return rrm
return (bm.T @ bm) / (bm.n_rows / bm.n_cols)


@typecheck(n_populations=int,
Expand Down

0 comments on commit 19213e5

Please sign in to comment.