Skip to content

Commit

Permalink
Make dummy-code a matrix table method
Browse files Browse the repository at this point in the history
  • Loading branch information
Will-Tyler committed Mar 8, 2024
1 parent 11ab54d commit 2313069
Show file tree
Hide file tree
Showing 7 changed files with 55 additions and 60 deletions.
1 change: 0 additions & 1 deletion hail/python/hail/docs/methods/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,6 @@ identity by descent [1]_, KING [2]_, and PC-Relate [3]_.
maximal_independent_set
rename_duplicates
segment_intervals
dummy_code

.. [1] Purcell, Shaun et al. “PLINK: a tool set for whole-genome association and
population-based linkage analyses.” American journal of human genetics
Expand Down
1 change: 0 additions & 1 deletion hail/python/hail/docs/methods/misc.rst
Original file line number Diff line number Diff line change
Expand Up @@ -17,4 +17,3 @@ Miscellaneous
.. autofunction:: maximal_independent_set
.. autofunction:: rename_duplicates
.. autofunction:: segment_intervals
.. autofunction:: dummy_code
41 changes: 41 additions & 0 deletions hail/python/hail/matrixtable.py
Original file line number Diff line number Diff line change
Expand Up @@ -4567,5 +4567,46 @@ def _calculate_new_partitions(self, n_partitions):
ir.TableToValueApply(ht._tir, {'name': 'TableCalculateNewPartitions', 'nPartitions': n_partitions})
)

def dummy_code(self, *column_field_names: str):
    """Dummy code categorical variables.

    For every listed column field, each distinct value (category) found in
    that field yields a new column field named ``<field>__<category>`` whose
    value is the integer conversion of ``field == category``.

    Examples
    --------
    >>> mt = hl.balding_nichols_model(1, 50, 100)
    >>> smoking_categories = hl.literal(['current', 'former', 'never'])
    >>> mt = mt.annotate_cols(
    ...     smoking_status = smoking_categories[hl.rand_cat([1, 1, 1])],
    ...     is_case = hl.rand_bool(0.5)
    ... )
    >>> mt, dummy_names, categories = mt.dummy_code("smoking_status")

    Notes
    -----
    One column aggregation is run per listed field to collect its categories.
    Dummy fields are created field by field in the order the names were
    passed, and within a field in order of the categories' string form, so
    the returned name list is the same from run to run.

    Parameters
    ----------
    column_field_names : variable-length args of :obj:`str`
        The names of the column fields to dummy code.

    Returns
    -------
    :class:`.MatrixTable`
        The matrix table with the dummy-coded variables.
    :obj:`list`
        A list of the names of the dummy-coded variables.
    :obj:`dict`
        A dictionary mapping the column field names to the categories of the column field.
    """
    field_name_to_categories = {
        field_name: self.aggregate_cols(hl.agg.collect_as_set(self[field_name]))
        for field_name in column_field_names
    }
    # Categories come back as a set, whose iteration order is arbitrary (and,
    # for strings, differs between interpreter runs). Sort on the string form
    # so the generated fields and the returned name list are reproducible;
    # key=str also copes with a missing (None) category sitting next to
    # non-None values.
    dummy_variables = {
        f'{field_name}__{category}': hl.int(self[field_name] == category)
        for field_name in column_field_names
        for category in sorted(field_name_to_categories[field_name], key=str)
    }
    matrix_table = self.annotate_cols(**dummy_variables)
    dummy_variable_field_names = list(dummy_variables)
    return matrix_table, dummy_variable_field_names, field_name_to_categories


matrix_table_type.set(MatrixTable)
3 changes: 1 addition & 2 deletions hail/python/hail/methods/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@
compute_charr,
vep_json_typ,
)
from .misc import rename_duplicates, maximal_independent_set, segment_intervals, filter_intervals, dummy_code
from .misc import rename_duplicates, maximal_independent_set, segment_intervals, filter_intervals
from .relatedness import identity_by_descent, king, pc_relate, simulate_random_mating

__all__ = [
Expand Down Expand Up @@ -133,7 +133,6 @@
'balding_nichols_model',
'ld_prune',
'filter_intervals',
'dummy_code',
'segment_intervals',
'de_novo',
'filter_alleles',
Expand Down
43 changes: 0 additions & 43 deletions hail/python/hail/methods/misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -490,46 +490,3 @@ def segment_intervals(ht, points):
)
ht = ht.annotate(__new_intervals=interval_results, lower=lower, higher=higher).explode('__new_intervals')
return ht.key_by(**{list(ht.key)[0]: ht.__new_intervals}).drop('__new_intervals')


def dummy_code(matrix_table: MatrixTable, *column_field_names: str):
    """Dummy code categorical variables.

    For every listed column field, each distinct value (category) found in
    that field yields a new column field named ``<field>__<category>`` whose
    value is the integer conversion of ``field == category``.

    Examples
    --------
    >>> mt = hail.balding_nichols_model(1, 50, 100)
    >>> smoking_categories = hail.literal(['current', 'former', 'never'])
    >>> mt = mt.annotate_cols(
    ...     smoking_status = smoking_categories[hail.rand_cat([1, 1, 1])],
    ...     is_case = hail.rand_bool(0.5)
    ... )
    >>> mt, dummy_names, categories = dummy_code(mt, "smoking_status")

    Notes
    -----
    One column aggregation is run per listed field to collect its categories.
    Dummy fields are created field by field in the order the names were
    passed, and within a field in order of the categories' string form, so
    the returned name list is the same from run to run.

    Parameters
    ----------
    matrix_table : :class:`.MatrixTable`
        A matrix table.
    *column_field_names : :obj:`str`
        The names of the column fields in the matrix table to dummy code.

    Returns
    -------
    :class:`.MatrixTable`
        The matrix table with the dummy-coded variables.
    :obj:`list`
        A list of the names of the dummy-coded variables.
    :obj:`dict`
        A dictionary mapping the column field names to the categories of the column field.
    """
    field_name_to_categories = {
        field_name: matrix_table.aggregate_cols(hl.agg.collect_as_set(matrix_table[field_name]))
        for field_name in column_field_names
    }
    # Categories come back as a set, whose iteration order is arbitrary (and,
    # for strings, differs between interpreter runs). Sort on the string form
    # so the generated fields and the returned name list are reproducible;
    # key=str also copes with a missing (None) category sitting next to
    # non-None values.
    dummy_variables = {
        f'{field_name}__{category}': hl.int(matrix_table[field_name] == category)
        for field_name in column_field_names
        for category in sorted(field_name_to_categories[field_name], key=str)
    }
    matrix_table = matrix_table.annotate_cols(**dummy_variables)
    dummy_variable_field_names = list(dummy_variables)
    return matrix_table, dummy_variable_field_names, field_name_to_categories
13 changes: 13 additions & 0 deletions hail/python/test/hail/matrixtable/test_matrix_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -1882,6 +1882,19 @@ def test_lower_row_agg_init_arg(self):
mt = mt.semi_join_rows(rows)
hl.hwe_normalized_pca(mt.GT)

def test_dummy_code(self):
    """Dummy coding a three-category column field yields one 0/1 field per category."""
    mt = hl.utils.range_matrix_table(10, 10)
    statuses = hl.literal(['current', 'former', 'never'])
    # Column i gets statuses[i % 3], so each category recurs every third column.
    mt = mt.annotate_cols(smoking_status=statuses[mt.col_idx % 3])
    mt, dummy_names, _ = mt.dummy_code('smoking_status')

    expected_by_field = {
        'smoking_status__current': [1, 0, 0, 1, 0, 0, 1, 0, 0, 1],
        'smoking_status__former': [0, 1, 0, 0, 1, 0, 0, 1, 0, 0],
        'smoking_status__never': [0, 0, 1, 0, 0, 1, 0, 0, 1, 0],
    }
    self.assertEqual(set(expected_by_field), set(dummy_names))
    for dummy_name, expected_values in expected_by_field.items():
        self.assertEqual(expected_values, mt[dummy_name].collect())


def test_keys_before_scans():
mt = hl.utils.range_matrix_table(6, 6)
Expand Down
13 changes: 0 additions & 13 deletions hail/python/test/hail/methods/test_misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -268,16 +268,3 @@ def test_segment_intervals(self):
hl.interval(52, 52),
]
)

def test_dummy_code(self):
    """The module-level dummy_code adds one 0/1 column field per observed category."""
    matrix_table = hl.utils.range_matrix_table(10, 10)
    categories = hl.literal(['current', 'former', 'never'])
    # Column i gets categories[i % 3], so each category recurs every third column.
    status = categories[matrix_table.col_idx % 3]
    matrix_table = matrix_table.annotate_cols(smoking_status=status)
    matrix_table, names, _ = hl.methods.misc.dummy_code(matrix_table, 'smoking_status')

    expected_names = {'smoking_status__current', 'smoking_status__former', 'smoking_status__never'}
    self.assertEqual(expected_names, set(names))

    expected_columns = (
        ('smoking_status__current', [1, 0, 0, 1, 0, 0, 1, 0, 0, 1]),
        ('smoking_status__former', [0, 1, 0, 0, 1, 0, 0, 1, 0, 0]),
        ('smoking_status__never', [0, 0, 1, 0, 0, 1, 0, 0, 1, 0]),
    )
    for name, expected in expected_columns:
        self.assertEqual(expected, matrix_table[name].collect())

0 comments on commit 2313069

Please sign in to comment.