Skip to content

Commit

Permalink
Add to_json() and from_json() methods to Cohort class (#1300)
Browse files Browse the repository at this point in the history
* Add to_json() and from_json() methods to Cohort class

Signed-off-by: Gaurav Gupta <gaugup@microsoft.com>

* Address code review comments

Signed-off-by: Gaurav Gupta <gaugup@microsoft.com>

* Fix linting

Signed-off-by: Gaurav Gupta <gaugup@microsoft.com>
  • Loading branch information
gaugup committed May 26, 2022
1 parent 7b0d7cf commit fddca3a
Show file tree
Hide file tree
Showing 2 changed files with 165 additions and 10 deletions.
100 changes: 100 additions & 0 deletions raiwidgets/raiwidgets/cohort.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

"""Module for defining cohorts in raiwidgets package."""

import json
from typing import Any, List, Optional

import numpy as np
Expand Down Expand Up @@ -110,6 +111,11 @@ def __init__(self, method: str, arg: List[Any], column: str):
self.arg = arg
self.column = column

def __eq__(self, cohort_filter: Any):
    """Compare this cohort filter with another object for equality.

    :param cohort_filter: The object to compare against.
    :type cohort_filter: Any
    :return: True if the other object is a cohort filter with the same
        method, arg and column, False if any of them differ, and
        NotImplemented if the other object is of an unrelated type.
    :rtype: bool
    """
    # Returning NotImplemented (rather than reading attributes that an
    # unrelated object such as None does not have) lets Python fall
    # back to its default comparison instead of raising AttributeError.
    if not isinstance(cohort_filter, type(self)):
        return NotImplemented

    return (self.method == cohort_filter.method and
            self.arg == cohort_filter.arg and
            self.column == cohort_filter.column)

def _validate_cohort_filter_parameters(
self, method: str, arg: List[Any], column: str):
"""Validate the input values for the cohort filter.
Expand Down Expand Up @@ -397,6 +403,100 @@ def __init__(self, name: str):
self.name = name
self.cohort_filter_list = None

def __eq__(self, cohort: Any):
    """Compare this cohort with another object for equality.

    Two cohorts are equal when they have the same name and their
    cohort filter lists are either both None or hold equal cohort
    filters in the same order.

    :param cohort: The object to compare against.
    :type cohort: Any
    :return: True if the cohorts are equal, False if they differ, and
        NotImplemented if the other object is of an unrelated type.
    :rtype: bool
    """
    # Avoid an AttributeError when compared with an unrelated object
    # (for example None); Python then uses its default comparison.
    if not isinstance(cohort, type(self)):
        return NotImplemented

    # List equality already covers the length check and the ordered
    # element-wise comparison, and None only equals None, so a cohort
    # without filters never equals one with a filter list.
    return (self.name == cohort.name and
            self.cohort_filter_list == cohort.cohort_filter_list)

@staticmethod
def _cohort_serializer(obj):
    """Expose an object's attributes as a dictionary for JSON encoding.

    Passed to json.dumps as the ``default`` hook by to_json(), so it
    is called for every value the JSON encoder cannot serialize
    natively.

    :param obj: The object whose attributes should be serialized.
    :type obj: Any
    :return: The attribute dictionary of the object.
    :rtype: dict[Any, Any]
    """
    attributes = obj.__dict__
    return attributes

def to_json(self):
    """Serialize this cohort into a JSON string.

    Objects the JSON encoder cannot handle natively are converted via
    Cohort._cohort_serializer.

    :return: The JSON string for the cohort.
    :rtype: str
    """
    serializer = Cohort._cohort_serializer
    return json.dumps(self, default=serializer)

@staticmethod
def _get_cohort_object(json_dict):
    """Build a Cohort object from a deserialized JSON dictionary.

    :param json_dict: JSON dictionary containing cohort data. It must
        have a "name" field and a "cohort_filter_list" field. The
        latter is either a list of dictionaries, each with "method",
        "arg" and "column" fields, or None for a cohort that has no
        cohort filters.
    :type json_dict: dict[str, Any]
    :return: The Cohort object.
    :rtype: Cohort
    :raises UserConfigValidationException: If the cohort data is not a
        dictionary, a required field is missing, or a field has an
        unexpected type.
    """
    if not isinstance(json_dict, dict):
        raise UserConfigValidationException(
            "Cohort data not of type dict for cohort deserialization")

    cohort_fields = ["name", "cohort_filter_list"]
    for cohort_field in cohort_fields:
        if cohort_field not in json_dict:
            raise UserConfigValidationException(
                "No {0} field found for cohort deserialization".format(
                    cohort_field))

    serialized_cohort_filters = json_dict['cohort_filter_list']
    if serialized_cohort_filters is None:
        # A cohort that never had a filter added keeps its
        # cohort_filter_list as None, which to_json() emits as null.
        # Accept it so that such a cohort survives a round trip.
        serialized_cohort_filters = []
    elif not isinstance(serialized_cohort_filters, list):
        raise UserConfigValidationException(
            "Field cohort_filter_list not of type list for "
            "cohort deserialization")

    deserialized_cohort = Cohort(json_dict['name'])
    cohort_filter_fields = ["method", "arg", "column"]
    for serialized_cohort_filter in serialized_cohort_filters:
        # Reject entries such as strings or numbers up front; indexing
        # them by field name below would raise a TypeError instead of
        # a validation error.
        if not isinstance(serialized_cohort_filter, dict):
            raise UserConfigValidationException(
                "Cohort filter not of type dict for cohort filter "
                "deserialization")

        for cohort_filter_field in cohort_filter_fields:
            if cohort_filter_field not in serialized_cohort_filter:
                raise UserConfigValidationException(
                    "No {0} field found for cohort filter "
                    "deserialization".format(cohort_filter_field))

        cohort_filter = CohortFilter(
            method=serialized_cohort_filter['method'],
            arg=serialized_cohort_filter['arg'],
            column=serialized_cohort_filter['column'])
        deserialized_cohort.add_cohort_filter(cohort_filter=cohort_filter)
    return deserialized_cohort

@staticmethod
def from_json(json_str):
    """Deserialize a JSON string into a Cohort object.

    The string is parsed with json.loads and the resulting data is
    handed to Cohort._get_cohort_object, which validates it and
    builds the cohort.

    :param json_str: Serialized JSON string.
    :type json_str: str
    :return: The Cohort object.
    :rtype: Cohort
    """
    return Cohort._get_cohort_object(json.loads(json_str))

def add_cohort_filter(self, cohort_filter: CohortFilter):
"""Add a cohort filter into the cohort.
:param cohort_filter: Cohort filter defined by CohortFilter class.
Expand Down
75 changes: 65 additions & 10 deletions raiwidgets/tests/test_cohort.py
Original file line number Diff line number Diff line change
Expand Up @@ -497,59 +497,114 @@ def test_cohort_serialization_single_value_method(self, method):
arg=[65], column='age')
cohort_1 = Cohort(name="Cohort New")
cohort_1.add_cohort_filter(cohort_filter_1)
json_str = json.dumps(cohort_1,
default=cohort_filter_json_converter)
json_str = cohort_1.to_json()

assert 'Cohort New' in json_str
assert method in json_str
assert '[65]' in json_str
assert 'age' in json_str

def test_cohort_serialization_deserialization_in_range_method(self):
    """Round-trip a cohort with a range filter through JSON."""
    cohort_filter_1 = CohortFilter(
        method=CohortFilterMethods.METHOD_RANGE,
        arg=[65.0, 70.0], column='age')
    cohort_1 = Cohort(name="Cohort New")
    cohort_1.add_cohort_filter(cohort_filter_1)

    # Serialization: every piece of the cohort shows up in the JSON.
    json_str = cohort_1.to_json()
    assert 'Cohort New' in json_str
    assert CohortFilterMethods.METHOD_RANGE in json_str
    assert '65.0' in json_str
    assert '70.0' in json_str
    assert 'age' in json_str

    # Deserialization: the rebuilt cohort matches the original one.
    cohort_1_new = Cohort.from_json(json_str)
    assert cohort_1_new.name == cohort_1.name
    assert len(cohort_1_new.cohort_filter_list) == \
        len(cohort_1.cohort_filter_list)
    assert cohort_1_new.cohort_filter_list[0].method == \
        cohort_1.cohort_filter_list[0].method

@pytest.mark.parametrize('method',
                         [CohortFilterMethods.METHOD_INCLUDES,
                          CohortFilterMethods.METHOD_EXCLUDES])
def test_cohort_serialization_deserialization_include_exclude_methods(
        self, method):
    """Round-trip include/exclude cohorts with str and int arguments."""
    # Cohort whose filter arguments are strings.
    cohort_filter_str = CohortFilter(method=method,
                                     arg=['val1', 'val2', 'val3'],
                                     column='age')
    cohort_str = Cohort(name="Cohort New Str")
    cohort_str.add_cohort_filter(cohort_filter_str)

    json_str = cohort_str.to_json()
    assert method in json_str
    assert 'val1' in json_str
    assert 'val2' in json_str
    assert 'val3' in json_str
    assert 'age' in json_str

    cohort_str_new = Cohort.from_json(json_str)
    assert cohort_str == cohort_str_new

    # Cohort whose filter arguments are integers.
    cohort_filter_int = CohortFilter(method=method,
                                     arg=[1, 2, 3],
                                     column='age')
    cohort_int = Cohort(name="Cohort New Int")
    cohort_int.add_cohort_filter(cohort_filter_int)

    json_str = cohort_int.to_json()
    assert method in json_str
    assert '1' in json_str
    assert '2' in json_str
    assert '3' in json_str
    assert 'age' in json_str

    cohort_int_new = Cohort.from_json(json_str)
    assert cohort_int == cohort_int_new

def test_cohort_deserialization_error_conditions(self):
    """Check that malformed cohort JSON is rejected with clear errors."""
    # Each entry pairs an invalid payload with the error message that
    # Cohort.from_json is expected to raise for it, in the order the
    # validation checks are performed.
    invalid_payloads = [
        ({},
         "No name field found for cohort deserialization"),
        ({'name': 'Cohort New'},
         "No cohort_filter_list field found for "
         "cohort deserialization"),
        ({'name': 'Cohort New', 'cohort_filter_list': {}},
         "Field cohort_filter_list not of type list "
         "for cohort deserialization"),
        ({'name': 'Cohort New', 'cohort_filter_list': [{}]},
         "No method field found for cohort filter "
         "deserialization"),
        ({'name': 'Cohort New',
          'cohort_filter_list': [{"method": "fake_method"}]},
         "No arg field found for cohort filter deserialization"),
        ({'name': 'Cohort New',
          'cohort_filter_list': [{"method": "fake_method", "arg": []}]},
         "No column field found for cohort filter "
         "deserialization"),
    ]

    for payload, expected_message in invalid_payloads:
        serialized_payload = json.dumps(payload)
        with pytest.raises(UserConfigValidationException,
                           match=expected_message):
            Cohort.from_json(serialized_payload)


class TestCohortList:
def test_cohort_list_serialization(self):
Expand Down

0 comments on commit fddca3a

Please sign in to comment.