Skip to content

Commit

Permalink
[vds] Add a function to impute sex ploidy directly from a coverage MT (
Browse files Browse the repository at this point in the history
…#12195)

* Add function to impute sex chromosome ploidy from precomputed interval coverage MT

* Add test for `impute_sex_chr_ploidy_from_interval_coverage`

* Fix `impute_sex_chr_ploidy_from_interval_coverage` types

* Fix logger

* Fix interval annotation when `use_variant_dataset`

* Expose impute_sex_chr_ploidy_from_interval_coverage in hl.vds.__init__
  • Loading branch information
jkgoodrich committed Sep 29, 2022
1 parent f669168 commit 2d7e006
Show file tree
Hide file tree
Showing 3 changed files with 118 additions and 25 deletions.
5 changes: 3 additions & 2 deletions hail/python/hail/vds/__init__.py
@@ -1,8 +1,8 @@
from . import combiner
from .functions import lgt_to_gt
from .methods import filter_intervals, filter_samples, filter_variants, sample_qc, split_multi, to_dense_mt, \
to_merged_sparse_mt, segment_reference_blocks, write_variant_datasets, interval_coverage, impute_sex_chromosome_ploidy, \
filter_chromosomes
to_merged_sparse_mt, segment_reference_blocks, write_variant_datasets, interval_coverage, \
impute_sex_chr_ploidy_from_interval_coverage, impute_sex_chromosome_ploidy, filter_chromosomes
from .variant_dataset import VariantDataset, read_vds
from .combiner import load_combiner, new_combiner

Expand All @@ -23,6 +23,7 @@
'write_variant_datasets',
'segment_reference_blocks',
'interval_coverage',
'impute_sex_chr_ploidy_from_interval_coverage',
'impute_sex_chromosome_ploidy',
'lgt_to_gt'
]
97 changes: 74 additions & 23 deletions hail/python/hail/vds/methods.py
Expand Up @@ -372,6 +372,75 @@ def new_la_index(old_idx):
return VariantDataset(reference_data, variant_data)


@typecheck(mt=MatrixTable, normalization_contig=str)
def impute_sex_chr_ploidy_from_interval_coverage(
        mt: 'MatrixTable',
        normalization_contig: str,
) -> 'Table':
    """Impute sex chromosome ploidy from a precomputed interval coverage MatrixTable.

    The input MatrixTable must have the following row fields:

     - ``interval`` (*interval*): Genomic interval of interest.
     - ``interval_size`` (*int32*): Size of interval, in bases.

    And the following entry fields:

     - ``sum_dp`` (*int64*): Sum of depth values by base across the interval.

    Returns a :class:`.Table` with sample ID keys, with the following fields:

     - ``autosomal_mean_dp`` (*float64*): Mean depth on calling intervals on normalization contig.
     - ``x_mean_dp`` (*float64*): Mean depth on calling intervals on X chromosome.
     - ``x_ploidy`` (*float64*): Estimated ploidy on X chromosome. Equal to ``2 * x_mean_dp / autosomal_mean_dp``.
     - ``y_mean_dp`` (*float64*): Mean depth on calling intervals on Y chromosome.
     - ``y_ploidy`` (*float64*): Estimated ploidy on Y chromosome. Equal to ``2 * y_mean_dp / autosomal_mean_dp``.

    Mean depth for a contig is the per-sample sum of ``sum_dp`` divided by the
    sum of ``interval_size`` over the intervals whose start lies on that contig.
    A contig that has no intervals in `mt` is assigned a mean depth of ``0.0``.

    Parameters
    ----------
    mt : :class:`.MatrixTable`
        Interval-by-sample MatrixTable with sum of depth values across the interval.
    normalization_contig : str
        Autosomal contig for depth comparison.

    Returns
    -------
    :class:`.Table`

    Raises
    ------
    NotImplementedError
        If the reference genome of ``mt.interval`` does not define exactly one
        X contig and exactly one Y contig.
    """

    # The reference genome is taken from the locus type of the interval endpoints.
    rg = mt.interval.start.dtype.reference_genome

    # Exactly one X and one Y contig are required; X is validated before Y.
    if len(rg.x_contigs) != 1:
        raise NotImplementedError(
            f"reference genome {rg.name!r} has multiple X contigs, this is not supported in 'impute_sex_chr_ploidy_from_interval_coverage'"
        )
    chr_x = rg.x_contigs[0]
    if len(rg.y_contigs) != 1:
        raise NotImplementedError(
            f"reference genome {rg.name!r} has multiple Y contigs, this is not supported in 'impute_sex_chr_ploidy_from_interval_coverage'"
        )
    chr_y = rg.y_contigs[0]

    # Each interval is attributed to the contig of its start locus.
    mt = mt.annotate_rows(contig=mt.interval.start.contig)

    # Per sample: dict of contig -> (total depth / total interval size).
    mt = mt.annotate_cols(
        __mean_dp=hl.agg.group_by(
            mt.contig, hl.agg.sum(mt.sum_dp) / hl.agg.sum(mt.interval_size)
        )
    )

    mean_dp_dict = mt.__mean_dp
    # Contigs absent from the dict fall back to a mean depth of 0.0.
    auto_dp = mean_dp_dict.get(normalization_contig, 0.0)
    x_dp = mean_dp_dict.get(chr_x, 0.0)
    y_dp = mean_dp_dict.get(chr_y, 0.0)

    # transmute_cols drops the temporary `__mean_dp` field referenced above.
    # NOTE(review): no guard for auto_dp == 0.0; ploidy values are then
    # presumably non-finite -- confirm whether callers rely on that.
    per_sample = mt.transmute_cols(autosomal_mean_dp=auto_dp,
                                   x_mean_dp=x_dp,
                                   x_ploidy=2 * x_dp / auto_dp,
                                   y_mean_dp=y_dp,
                                   y_ploidy=2 * y_dp / auto_dp)

    # Log under this function's own name, matching the error messages above.
    info("'impute_sex_chr_ploidy_from_interval_coverage': computing and checkpointing coverage and karyotype metrics")
    return per_sample.cols().checkpoint(new_temp_file('impute_sex_karyotype', extension='ht'))


@typecheck(vds=VariantDataset,
calling_intervals=oneof(Table, expr_array(expr_interval(expr_locus()))),
normalization_contig=str,
Expand Down Expand Up @@ -446,43 +515,25 @@ def impute_sex_chromosome_ploidy(
raise NotImplementedError(
f"reference genome {rg.name!r} has multiple X contigs, this is not supported in 'impute_sex_chromosome_ploidy'"
)
chr_x = rg.x_contigs[0]
if len(rg.y_contigs) != 1:
raise NotImplementedError(
f"reference genome {rg.name!r} has multiple Y contigs, this is not supported in 'impute_sex_chromosome_ploidy'"
)
chr_y = rg.y_contigs[0]

kept_contig_filter = hl.array(chrs_represented).map(lambda x: hl.parse_locus_interval(x, reference_genome=rg))
vds = VariantDataset(hl.filter_intervals(vds.reference_data, kept_contig_filter),
hl.filter_intervals(vds.variant_data, kept_contig_filter))

if use_variant_dataset:
mt = vds.variant_data
mt = mt.filter_rows(hl.is_defined(calling_intervals[mt.locus]))
coverage = mt.select_cols(
__mean_dp=hl.agg.group_by(mt.locus.contig,
hl.agg.sum(mt.DP)
/ hl.agg.filter(mt["LGT" if "LGT" in mt.entry else "GT"].is_non_ref(),
hl.agg.count())))
calling_intervals = calling_intervals.annotate(interval_dup=interval)
mt = mt.annotate_rows(interval=calling_intervals[mt.locus].interval_dup)
mt = mt.filter_rows(hl.is_defined(mt.interval))
coverage = mt.select_entries(sum_dp=mt.DP, interval_size=hl.is_defined(mt.DP))
else:
coverage = interval_coverage(vds, calling_intervals, gq_thresholds=()).drop('gq_thresholds')

coverage = coverage.annotate_rows(contig=coverage.interval.start.contig)
coverage = coverage.annotate_cols(
__mean_dp=hl.agg.group_by(coverage.contig, hl.agg.sum(coverage.sum_dp) / hl.agg.sum(coverage.interval_size)))

mean_dp_dict = coverage.__mean_dp
auto_dp = mean_dp_dict.get(normalization_contig, 0.0)
x_dp = mean_dp_dict.get(chr_x, 0.0)
y_dp = mean_dp_dict.get(chr_y, 0.0)
per_sample = coverage.transmute_cols(autosomal_mean_dp=auto_dp,
x_mean_dp=x_dp,
x_ploidy=2 * x_dp / auto_dp,
y_mean_dp=y_dp,
y_ploidy=2 * y_dp / auto_dp)
info("'impute_sex_chromosome_ploidy': computing and checkpointing coverage and karyotype metrics")
return per_sample.cols().checkpoint(new_temp_file('impute_sex_karyotype', extension='ht'))
return impute_sex_chr_ploidy_from_interval_coverage(coverage, normalization_contig)


@typecheck(vds=VariantDataset, variants_table=Table, keep=bool)
Expand Down
41 changes: 41 additions & 0 deletions hail/python/test/hail/vds/test_vds.py
Expand Up @@ -257,6 +257,47 @@ def test_interval_coverage():
pytest.approx(obs.mean_dp, exp.mean_dp)


def test_impute_sex_chr_ploidy_from_interval_coverage():
    """Check ploidy imputation on a hand-built interval-by-sample coverage MT.

    Two samples are given summed depths over intervals on contig 20 (the
    normalization contig), X and Y. ``interval_size`` is derived below as
    ``end.position - start.position``, so the totals per contig are
    20: 20 + 5 = 25, X: 10 + 10 = 20, Y: 10 + 5 = 15.
    """
    norm_interval_1 = hl.parse_locus_interval('20:10-30', reference_genome='GRCh37')
    norm_interval_2 = hl.parse_locus_interval('20:40-45', reference_genome='GRCh37')
    x_interval_1 = hl.parse_locus_interval('X:10-20', reference_genome='GRCh37')
    x_interval_2 = hl.parse_locus_interval('X:25-35', reference_genome='GRCh37')
    y_interval_1 = hl.parse_locus_interval('Y:10-20', reference_genome='GRCh37')
    y_interval_2 = hl.parse_locus_interval('Y:25-30', reference_genome='GRCh37')

    # sample_xx deliberately has no Y rows; sample_xy has coverage on all three contigs.
    coverage_rows = [
        hl.Struct(s='sample_xx', interval=norm_interval_1, sum_dp=195),
        hl.Struct(s='sample_xx', interval=norm_interval_2, sum_dp=55),
        hl.Struct(s='sample_xx', interval=x_interval_1, sum_dp=95),
        hl.Struct(s='sample_xx', interval=x_interval_2, sum_dp=85),
        hl.Struct(s='sample_xy', interval=norm_interval_1, sum_dp=190),
        hl.Struct(s='sample_xy', interval=norm_interval_2, sum_dp=85),
        hl.Struct(s='sample_xy', interval=x_interval_1, sum_dp=61),
        hl.Struct(s='sample_xy', interval=x_interval_2, sum_dp=49),
        hl.Struct(s='sample_xy', interval=y_interval_1, sum_dp=54),
        hl.Struct(s='sample_xy', interval=y_interval_2, sum_dp=45),
    ]
    mt = hl.Table.parallelize(
        coverage_rows,
        schema=hl.dtype('struct{s:str,interval:interval<locus<GRCh37>>,sum_dp:int32}'),
    ).to_matrix_table(row_key=['interval'], col_key=['s'])

    mt = mt.annotate_rows(interval_size=mt.interval.end.position - mt.interval.start.position)
    r = hl.vds.impute_sex_chr_ploidy_from_interval_coverage(mt, normalization_contig='20')

    expected = [
        # 250 / 25 = 10.0; 180 / 20 = 9.0; 2 * 9.0 / 10.0 = 1.8; no Y depth.
        hl.Struct(s='sample_xx',
                  autosomal_mean_dp=10.0,
                  x_mean_dp=9.0,
                  x_ploidy=1.8,
                  y_mean_dp=0.0,
                  y_ploidy=0.0),
        # 275 / 25 = 11.0; 110 / 20 = 5.5; 2 * 5.5 / 11.0 = 1.0;
        # 99 / 15 = 6.6; 2 * 6.6 / 11.0 = 1.2.
        hl.Struct(s='sample_xy',
                  autosomal_mean_dp=11.0,
                  x_mean_dp=5.5,
                  x_ploidy=1.0,
                  y_mean_dp=6.6,
                  y_ploidy=1.2),
    ]
    observed = r.collect()

    # Row order and the exact set of fields are still pinned exactly.
    assert [row.s for row in observed] == [row.s for row in expected]
    float_fields = ('autosomal_mean_dp', 'x_mean_dp', 'x_ploidy', 'y_mean_dp', 'y_ploidy')
    for obs, exp in zip(observed, expected):
        assert set(obs.keys()) == set(exp.keys())
        # Depths and ploidies are computed float ratios, so compare with a
        # tolerance rather than bit-for-bit equality.
        for field in float_fields:
            assert obs[field] == pytest.approx(exp[field]), (obs.s, field)


def test_impute_sex_chromosome_ploidy():
x_par_end = 2699521
y_par_end = 2649521
Expand Down

0 comments on commit 2d7e006

Please sign in to comment.