From 312fcb174614f9ea3025161d92de03a2a693bf42 Mon Sep 17 00:00:00 2001 From: willtyler Date: Sun, 11 Feb 2024 02:57:30 +0000 Subject: [PATCH] Make dummy-code accept expressions as well as column field names --- hail/python/hail/matrixtable.py | 12 ++++++++---- .../test/hail/matrixtable/test_matrix_table.py | 14 ++++++++++++++ 2 files changed, 22 insertions(+), 4 deletions(-) diff --git a/hail/python/hail/matrixtable.py b/hail/python/hail/matrixtable.py index 54ca3264b21c..c40a88a5736a 100644 --- a/hail/python/hail/matrixtable.py +++ b/hail/python/hail/matrixtable.py @@ -1,5 +1,5 @@ import itertools -from typing import Iterable, Optional, Dict, Tuple, Any, List +from typing import Iterable, Optional, Dict, Tuple, Any, List, Union from collections import Counter from deprecated import deprecated import hail as hl @@ -4567,7 +4567,7 @@ def _calculate_new_partitions(self, n_partitions): ir.TableToValueApply(ht._tir, {'name': 'TableCalculateNewPartitions', 'nPartitions': n_partitions}) ) - def dummy_code(self, *column_field_names: str): + def dummy_code(self, *column_fields: Union[str, Expression]): """Dummy code categorical variables. Examples @@ -4583,8 +4583,8 @@ def dummy_code(self, *column_field_names: str): Parameters ---------- - column_field_names : variable-length args of :obj:`str` - The names of the column fields to dummy code. + column_fields : variable-length args of :class:`str` or :class:`.Expression` + The column fields to dummy code. Returns ------- @@ -4595,6 +4595,10 @@ def dummy_code(self, *column_field_names: str): :obj:`dict` A dictionary mapping the column field names to the categories of the column field. 
""" + column_field_names = [ + self._fields_inverse[column_field] if isinstance(column_field, Expression) else column_field + for column_field in column_fields + ] field_name_to_categories = { field_name: self.aggregate_cols(hl.agg.collect_as_set(self[field_name])) for field_name in column_field_names diff --git a/hail/python/test/hail/matrixtable/test_matrix_table.py b/hail/python/test/hail/matrixtable/test_matrix_table.py index e532b794d3a1..6440f3d2f5f7 100644 --- a/hail/python/test/hail/matrixtable/test_matrix_table.py +++ b/hail/python/test/hail/matrixtable/test_matrix_table.py @@ -1886,6 +1886,7 @@ def test_dummy_code(self): matrix_table = hl.utils.range_matrix_table(10, 10) smoking_categories = hl.literal(['current', 'former', 'never']) matrix_table = matrix_table.annotate_cols(smoking_status=smoking_categories[matrix_table.col_idx % 3]) + # Test using the column field name matrix_table, dummy_variable_field_names, _ = matrix_table.dummy_code('smoking_status') self.assertEqual( {'smoking_status__current', 'smoking_status__former', 'smoking_status__never'}, @@ -1895,6 +1896,19 @@ def test_dummy_code(self): self.assertEqual([0, 1, 0, 0, 1, 0, 0, 1, 0, 0], matrix_table['smoking_status__former'].collect()) self.assertEqual([0, 0, 1, 0, 0, 1, 0, 0, 1, 0], matrix_table['smoking_status__never'].collect()) + # Reset matrix_table + matrix_table = hl.utils.range_matrix_table(10, 10) + matrix_table = matrix_table.annotate_cols(smoking_status=smoking_categories[matrix_table.col_idx % 3]) + # Test using an expression + matrix_table, dummy_variable_field_names, _ = matrix_table.dummy_code(matrix_table.smoking_status) + self.assertEqual( + {'smoking_status__current', 'smoking_status__former', 'smoking_status__never'}, + set(dummy_variable_field_names), + ) + self.assertEqual([1, 0, 0, 1, 0, 0, 1, 0, 0, 1], matrix_table['smoking_status__current'].collect()) + self.assertEqual([0, 1, 0, 0, 1, 0, 0, 1, 0, 0], matrix_table['smoking_status__former'].collect()) + 
self.assertEqual([0, 0, 1, 0, 0, 1, 0, 0, 1, 0], matrix_table['smoking_status__never'].collect()) + def test_keys_before_scans(): mt = hl.utils.range_matrix_table(6, 6)