Skip to content

Commit

Permalink
Make dummy-code except expressions as well as column field names
Browse files Browse the repository at this point in the history
  • Loading branch information
Will-Tyler committed Mar 8, 2024
1 parent 2313069 commit 312fcb1
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 4 deletions.
12 changes: 8 additions & 4 deletions hail/python/hail/matrixtable.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import itertools
from typing import Iterable, Optional, Dict, Tuple, Any, List
from typing import Iterable, Optional, Dict, Tuple, Any, List, Union
from collections import Counter
from deprecated import deprecated
import hail as hl
Expand Down Expand Up @@ -4567,7 +4567,7 @@ def _calculate_new_partitions(self, n_partitions):
ir.TableToValueApply(ht._tir, {'name': 'TableCalculateNewPartitions', 'nPartitions': n_partitions})
)

def dummy_code(self, *column_field_names: str):
def dummy_code(self, *column_fields: Union[str, Expression]):
"""Dummy code categorical variables.
Examples
Expand All @@ -4583,8 +4583,8 @@ def dummy_code(self, *column_field_names: str):
Parameters
----------
column_field_names : variable-length args of :obj:`str`
The names of the column fields to dummy code.
column_fields : variable-length args of :class:`str` or :class:`.Expression`
The column fields to dummy code.
Returns
-------
Expand All @@ -4595,6 +4595,10 @@ def dummy_code(self, *column_field_names: str):
:obj:`dict`
A dictionary mapping the column field names to the categories of the column field.
"""
column_field_names = [
self._fields_inverse[column_field] if isinstance(column_field, Expression) else column_field
for column_field in column_fields
]
field_name_to_categories = {
field_name: self.aggregate_cols(hl.agg.collect_as_set(self[field_name]))
for field_name in column_field_names
Expand Down
14 changes: 14 additions & 0 deletions hail/python/test/hail/matrixtable/test_matrix_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -1886,6 +1886,7 @@ def test_dummy_code(self):
matrix_table = hl.utils.range_matrix_table(10, 10)
smoking_categories = hl.literal(['current', 'former', 'never'])
matrix_table = matrix_table.annotate_cols(smoking_status=smoking_categories[matrix_table.col_idx % 3])
# Test using the column field name
matrix_table, dummy_variable_field_names, _ = matrix_table.dummy_code('smoking_status')
self.assertEqual(
{'smoking_status__current', 'smoking_status__former', 'smoking_status__never'},
Expand All @@ -1895,6 +1896,19 @@ def test_dummy_code(self):
self.assertEqual([0, 1, 0, 0, 1, 0, 0, 1, 0, 0], matrix_table['smoking_status__former'].collect())
self.assertEqual([0, 0, 1, 0, 0, 1, 0, 0, 1, 0], matrix_table['smoking_status__never'].collect())

# Reset matrix_table
matrix_table = hl.utils.range_matrix_table(10, 10)
matrix_table = matrix_table.annotate_cols(smoking_status=smoking_categories[matrix_table.col_idx % 3])
# Test using an expression
matrix_table, dummy_variable_field_names, _ = matrix_table.dummy_code(matrix_table.smoking_status)
self.assertEqual(
{'smoking_status__current', 'smoking_status__former', 'smoking_status__never'},
set(dummy_variable_field_names),
)
self.assertEqual([1, 0, 0, 1, 0, 0, 1, 0, 0, 1], matrix_table['smoking_status__current'].collect())
self.assertEqual([0, 1, 0, 0, 1, 0, 0, 1, 0, 0], matrix_table['smoking_status__former'].collect())
self.assertEqual([0, 0, 1, 0, 0, 1, 0, 0, 1, 0], matrix_table['smoking_status__never'].collect())


def test_keys_before_scans():
mt = hl.utils.range_matrix_table(6, 6)
Expand Down

0 comments on commit 312fcb1

Please sign in to comment.