diff --git a/gnomad/utils/annotations.py b/gnomad/utils/annotations.py index 1971d8f8d..903ab3c10 100644 --- a/gnomad/utils/annotations.py +++ b/gnomad/utils/annotations.py @@ -1917,10 +1917,10 @@ def annotate_downsamplings( else: t = t.annotate(downsampling=ht[t.s]) - t = t.annotate_globals( - downsamplings=downsamplings, - ds_gen_anc_counts=gen_anc_counts, - ) + t = t.annotate_globals(downsamplings=downsamplings) + + if gen_anc_counts: + t = t.annotate_globals(ds_gen_anc_counts=gen_anc_counts) return t diff --git a/tests/utils/test_annotations.py b/tests/utils/test_annotations.py index a97234a92..f6b060ecc 100644 --- a/tests/utils/test_annotations.py +++ b/tests/utils/test_annotations.py @@ -6,6 +6,7 @@ import pytest from gnomad.utils.annotations import ( + annotate_downsamplings, fill_missing_key_combinations, get_copy_state_by_sex, merge_array_expressions, @@ -1085,3 +1086,162 @@ def test_merge_histograms_sum_with_negatives_error(self, sample_ht): result_hist = merge_histograms([ht.hist1, ht.hist2], operation="sum") with pytest.raises(Exception): ht.select(result_hist=result_hist).collect() + + +class TestAnnotateDownsamplings: + """Test the annotate_downsamplings function.""" + + @pytest.fixture + def sample_matrix_table(self): + """Create a sample MatrixTable for testing.""" + samples = [ + {"s": "sample1", "gen_anc": "AFR"}, + {"s": "sample2", "gen_anc": "EUR"}, + {"s": "sample3", "gen_anc": "AFR"}, + {"s": "sample4", "gen_anc": "EUR"}, + {"s": "sample5", "gen_anc": "SAS"}, + ] + + variants = [ + { + "locus": hl.locus("chr1", 1000, reference_genome="GRCh38"), + "alleles": ["A", "T"], + }, + { + "locus": hl.locus("chr1", 2000, reference_genome="GRCh38"), + "alleles": ["C", "G"], + }, + { + "locus": hl.locus("chr1", 3000, reference_genome="GRCh38"), + "alleles": ["T", "A"], + }, + ] + + sample_table = hl.Table.parallelize( + samples, + hl.tstruct(s=hl.tstr, gen_anc=hl.tstr), + ).key_by("s") + + entries = [] + for variant in variants: + for sample in samples: + entries.append( + { + "locus": variant["locus"], + "alleles": variant["alleles"], + "s": sample["s"], + "GT": hl.call(0, 1), + } + ) + + mt = hl.Table.parallelize( + entries, + hl.tstruct( + locus=hl.tlocus("GRCh38"), + alleles=hl.tarray(hl.tstr), + s=hl.tstr, + GT=hl.tcall, + ), + ).to_matrix_table(row_key=["locus", "alleles"], col_key=["s"]) + + mt = mt.annotate_cols(gen_anc=sample_table[mt.s].gen_anc) + + return mt + + @pytest.fixture + def sample_table(self): + """Create a sample Table for testing.""" + return hl.Table.parallelize( + [ + {"s": "sample1", "sex": "XX", "gen_anc": "AFR"}, + {"s": "sample2", "sex": "XY", "gen_anc": "EUR"}, + {"s": "sample3", "sex": "XX", "gen_anc": "AFR"}, + {"s": "sample4", "sex": "XY", "gen_anc": "EUR"}, + {"s": "sample5", "sex": "XX", "gen_anc": "SAS"}, + ], + hl.tstruct(s=hl.tstr, sex=hl.tstr, gen_anc=hl.tstr), + ).key_by("s") + + def test_annotate_downsamplings_matrix_table_no_gen_anc(self, sample_matrix_table): + """Test annotate_downsamplings with MatrixTable input without genetic ancestry.""" + downsamplings = [2, 3, 4] + + result = annotate_downsamplings(sample_matrix_table, downsamplings) + + assert isinstance(result, hl.MatrixTable) + assert "downsampling" in result.col.dtype + assert "downsamplings" in result.globals.dtype + + result_downsamplings = hl.eval(result.downsamplings) + assert result_downsamplings == [2, 3, 4] + + sample_cols = result.cols().collect() + for col in sample_cols: + assert "global_idx" in col.downsampling + + def test_annotate_downsamplings_matrix_table_with_gen_anc( + self, sample_matrix_table + ): + """Test annotate_downsamplings with MatrixTable input with genetic ancestry.""" + downsamplings = [2, 3, 4] + gen_anc_expr = sample_matrix_table.gen_anc + + result = annotate_downsamplings( + sample_matrix_table, downsamplings, gen_anc_expr + ) + + assert isinstance(result, hl.MatrixTable) + assert "downsampling" in result.col.dtype + assert "downsamplings" in result.globals.dtype + assert "ds_gen_anc_counts" in result.globals.dtype + + result_downsamplings = hl.eval(result.downsamplings) + assert result_downsamplings == [1, 2, 3, 4] + + gen_anc_counts = hl.eval(result.ds_gen_anc_counts) + assert gen_anc_counts == {"AFR": 2, "EUR": 2, "SAS": 1} + + sample_cols = result.cols().collect() + for col in sample_cols: + assert "global_idx" in col.downsampling + assert "gen_anc_idx" in col.downsampling + + def test_annotate_downsamplings_table_no_gen_anc(self, sample_table): + """Test annotate_downsamplings with Table input without genetic ancestry.""" + downsamplings = [2, 3, 4] + + result = annotate_downsamplings(sample_table, downsamplings) + + assert isinstance(result, hl.Table) + assert "downsampling" in result.row.dtype + assert "downsamplings" in result.globals.dtype + + result_downsamplings = hl.eval(result.downsamplings) + assert result_downsamplings == [2, 3, 4] + + rows = result.collect() + for row in rows: + assert "global_idx" in row.downsampling + + def test_annotate_downsamplings_table_with_gen_anc(self, sample_table): + """Test annotate_downsamplings with Table input with genetic ancestry.""" + downsamplings = [2, 3, 4] + gen_anc_expr = sample_table.gen_anc + + result = annotate_downsamplings(sample_table, downsamplings, gen_anc_expr) + + assert isinstance(result, hl.Table) + assert "downsampling" in result.row.dtype + assert "downsamplings" in result.globals.dtype + assert "ds_gen_anc_counts" in result.globals.dtype + + result_downsamplings = hl.eval(result.downsamplings) + assert result_downsamplings == [1, 2, 3, 4] + + gen_anc_counts = hl.eval(result.ds_gen_anc_counts) + assert gen_anc_counts == {"AFR": 2, "EUR": 2, "SAS": 1} + + rows = result.collect() + for row in rows: + assert "global_idx" in row.downsampling + assert "gen_anc_idx" in row.downsampling