Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

make train data optional on RAIBaseInsights class #2029

Merged
merged 1 commit into from
Apr 21, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
36 changes: 22 additions & 14 deletions responsibleai/responsibleai/rai_insights/rai_base_insights.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ class RAIBaseInsights(ABC):
This class is abstract and should not be instantiated.
"""

def __init__(self, model: Optional[Any], train: pd.DataFrame,
def __init__(self, model: Optional[Any], train: Optional[pd.DataFrame],
imatiach-msft marked this conversation as resolved.
Show resolved Hide resolved
test: pd.DataFrame, target_column: str, task_type: str,
serializer: Optional[Any] = None):
"""Creates an RAIBaseInsights object.
Expand All @@ -39,8 +39,11 @@ def __init__(self, model: Optional[Any], train: pd.DataFrame,
or function that accepts a 2d ndarray.
:type model: object
:param train: The training dataset including the label column.
May be None, because some downstream subclasses do not require
training data.
:type train: Optional[pandas.DataFrame]
:param test: The test dataset including the label column.
Note this parameter is always required.
:type test: pandas.DataFrame
:param target_column: The name of the label column.
:type target_column: str
Expand Down Expand Up @@ -138,13 +141,14 @@ def _save_data(self, path):
"""
data_directory = Path(path) / SerializationAttributes.DATA_DIRECTORY
data_directory.mkdir(parents=True, exist_ok=True)
dtypes = self.train.dtypes.astype(str).to_dict()
self._write_to_file(data_directory /
(Metadata.TRAIN + _DTYPES + FileFormats.JSON),
json.dumps(dtypes))
self._write_to_file(data_directory /
(Metadata.TRAIN + FileFormats.JSON),
self.train.to_json(orient='split'))
if self.train is not None:
dtypes = self.train.dtypes.astype(str).to_dict()
self._write_to_file(data_directory /
(Metadata.TRAIN + _DTYPES + FileFormats.JSON),
json.dumps(dtypes))
self._write_to_file(data_directory /
(Metadata.TRAIN + FileFormats.JSON),
self.train.to_json(orient='split'))

dtypes = self.test.dtypes.astype(str).to_dict()
self._write_to_file(data_directory /
Expand Down Expand Up @@ -231,12 +235,16 @@ def _load_data(inst, path):
:type path: str
"""
data_directory = Path(path) / SerializationAttributes.DATA_DIRECTORY
with open(data_directory /
(Metadata.TRAIN + _DTYPES + FileFormats.JSON), 'r') as file:
types = json.load(file)
with open(data_directory / (Metadata.TRAIN + FileFormats.JSON),
'r') as file:
train = pd.read_json(file, dtype=types, orient='split')
train_data_json = data_directory / (Metadata.TRAIN + FileFormats.JSON)
if train_data_json.exists():
train_dtypes_path = (Metadata.TRAIN + _DTYPES + FileFormats.JSON)
train_dtypes_json = data_directory / train_dtypes_path
with open(train_dtypes_json, 'r') as file:
types = json.load(file)
with open(train_data_json, 'r') as file:
train = pd.read_json(file, dtype=types, orient='split')
else:
train = None
inst.__dict__[Metadata.TRAIN] = train
with open(data_directory /
(Metadata.TEST + _DTYPES + FileFormats.JSON), 'r') as file:
Expand Down