Skip to content

Commit

Permalink
make train data optional on RAIBaseInsights class (#2029)
Browse files Browse the repository at this point in the history
  • Loading branch information
imatiach-msft committed Apr 21, 2023
1 parent af5f4a1 commit 5c8671b
Showing 1 changed file with 22 additions and 14 deletions.
36 changes: 22 additions & 14 deletions responsibleai/responsibleai/rai_insights/rai_base_insights.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ class RAIBaseInsights(ABC):
This class is abstract and should not be instantiated.
"""

def __init__(self, model: Optional[Any], train: pd.DataFrame,
def __init__(self, model: Optional[Any], train: Optional[pd.DataFrame],
test: pd.DataFrame, target_column: str, task_type: str,
serializer: Optional[Any] = None):
"""Creates an RAIBaseInsights object.
Expand All @@ -39,8 +39,11 @@ def __init__(self, model: Optional[Any], train: pd.DataFrame,
or function that accepts a 2d ndarray.
:type model: object
:param train: The training dataset including the label column.
This parameter is optional as some extending downstream classes
may not require it.
:type train: pandas.DataFrame
:param test: The test dataset including the label column.
Note this parameter is always required.
:type test: pandas.DataFrame
:param target_column: The name of the label column.
:type target_column: str
Expand Down Expand Up @@ -138,13 +141,14 @@ def _save_data(self, path):
"""
data_directory = Path(path) / SerializationAttributes.DATA_DIRECTORY
data_directory.mkdir(parents=True, exist_ok=True)
dtypes = self.train.dtypes.astype(str).to_dict()
self._write_to_file(data_directory /
(Metadata.TRAIN + _DTYPES + FileFormats.JSON),
json.dumps(dtypes))
self._write_to_file(data_directory /
(Metadata.TRAIN + FileFormats.JSON),
self.train.to_json(orient='split'))
if self.train is not None:
dtypes = self.train.dtypes.astype(str).to_dict()
self._write_to_file(data_directory /
(Metadata.TRAIN + _DTYPES + FileFormats.JSON),
json.dumps(dtypes))
self._write_to_file(data_directory /
(Metadata.TRAIN + FileFormats.JSON),
self.train.to_json(orient='split'))

dtypes = self.test.dtypes.astype(str).to_dict()
self._write_to_file(data_directory /
Expand Down Expand Up @@ -231,12 +235,16 @@ def _load_data(inst, path):
:type path: str
"""
data_directory = Path(path) / SerializationAttributes.DATA_DIRECTORY
with open(data_directory /
(Metadata.TRAIN + _DTYPES + FileFormats.JSON), 'r') as file:
types = json.load(file)
with open(data_directory / (Metadata.TRAIN + FileFormats.JSON),
'r') as file:
train = pd.read_json(file, dtype=types, orient='split')
train_data_json = data_directory / (Metadata.TRAIN + FileFormats.JSON)
if train_data_json.exists():
train_dtypes_path = (Metadata.TRAIN + _DTYPES + FileFormats.JSON)
train_dtypes_json = data_directory / train_dtypes_path
with open(train_dtypes_json, 'r') as file:
types = json.load(file)
with open(train_data_json, 'r') as file:
train = pd.read_json(file, dtype=types, orient='split')
else:
train = None
inst.__dict__[Metadata.TRAIN] = train
with open(data_directory /
(Metadata.TEST + _DTYPES + FileFormats.JSON), 'r') as file:
Expand Down

0 comments on commit 5c8671b

Please sign in to comment.