Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[vds/combiner] use get_lgt for PGT handling as well #13829

Merged
merged 2 commits into from Oct 17, 2023
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
17 changes: 9 additions & 8 deletions hail/python/hail/vds/combiner/combine.py
Expand Up @@ -30,16 +30,16 @@ def make_variants_matrix_table(mt: MatrixTable,

transform_row = _transform_variant_function_map.get((mt.row.dtype, info_key))
if transform_row is None or not hl.current_backend()._is_registered_ir_function_name(transform_row._name):
def get_lgt(e, n_alleles, has_non_ref, row):
index = e.GT.unphased_diploid_gt_index()
def get_lgt(gt, n_alleles, has_non_ref, row):
index = gt.unphased_diploid_gt_index()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this function valid on phased GTs? Do we need to unphase it first?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we do thanks for checking

n_no_nonref = n_alleles - hl.int(has_non_ref)
triangle_without_nonref = hl.triangle(n_no_nonref)
return (hl.case()
.when(e.GT.is_haploid(),
hl.or_missing(e.GT[0] < n_no_nonref, e.GT))
.when(index < triangle_without_nonref, e.GT)
.when(gt.is_haploid(),
hl.or_missing(gt[0] < n_no_nonref, gt))
.when(index < triangle_without_nonref, gt)
.when(index < hl.triangle(n_alleles), hl.missing('call'))
.or_error('invalid GT ' + hl.str(e.GT) + ' at site ' + hl.str(row.locus)))
.or_error('invalid GT ' + hl.str(gt) + ' at site ' + hl.str(row.locus)))

def make_entry_struct(e, alleles_len, has_non_ref, row):
handled_fields = dict()
Expand All @@ -53,11 +53,12 @@ def make_entry_struct(e, alleles_len, has_non_ref, row):
raise hl.utils.FatalError("the Hail VDS combiner expects input GVCFs to have a 'GT' field in FORMAT.")

handled_fields['LA'] = hl.range(0, alleles_len - hl.if_else(has_non_ref, 1, 0))
handled_fields['LGT'] = get_lgt(e, alleles_len, has_non_ref, row)
handled_fields['LGT'] = get_lgt(e.GT, alleles_len, has_non_ref, row)
if 'AD' in e:
handled_fields['LAD'] = hl.if_else(has_non_ref, e.AD[:-1], e.AD)
if 'PGT' in e:
handled_fields['LPGT'] = e.PGT
handled_fields['LPGT'] = e.PGT if e.PGT.dtype != hl.tcall \
else get_lgt(e.PGT, alleles_len, has_non_ref, row)
if 'PL' in e:
handled_fields['LPL'] = hl.if_else(has_non_ref,
hl.if_else(alleles_len > 2,
Expand Down