diff --git a/responsibleai/responsibleai/rai_insights/rai_base_insights.py b/responsibleai/responsibleai/rai_insights/rai_base_insights.py index f108b41fe3..bb34ae05ab 100644 --- a/responsibleai/responsibleai/rai_insights/rai_base_insights.py +++ b/responsibleai/responsibleai/rai_insights/rai_base_insights.py @@ -29,7 +29,7 @@ class RAIBaseInsights(ABC): This class is abstract and should not be instantiated. """ - def __init__(self, model: Optional[Any], train: pd.DataFrame, + def __init__(self, model: Optional[Any], train: Optional[pd.DataFrame], test: pd.DataFrame, target_column: str, task_type: str, serializer: Optional[Any] = None): """Creates an RAIBaseInsights object. @@ -39,8 +39,11 @@ def __init__(self, model: Optional[Any], train: pd.DataFrame, or function that accepts a 2d ndarray. :type model: object :param train: The training dataset including the label column. + This parameter is optional as some extending downstream classes + may not require it. :type train: pandas.DataFrame :param test: The test dataset including the label column. + Note this parameter is always required. :type test: pandas.DataFrame :param target_column: The name of the label column. :type target_column: str @@ -138,13 +141,14 @@ def _save_data(self, path): """ data_directory = Path(path) / SerializationAttributes.DATA_DIRECTORY data_directory.mkdir(parents=True, exist_ok=True) - dtypes = self.train.dtypes.astype(str).to_dict() - self._write_to_file(data_directory / - (Metadata.TRAIN + _DTYPES + FileFormats.JSON), - json.dumps(dtypes)) - self._write_to_file(data_directory / - (Metadata.TRAIN + FileFormats.JSON), - self.train.to_json(orient='split')) + if self.train is not None: + dtypes = self.train.dtypes.astype(str).to_dict() + self._write_to_file(data_directory / + (Metadata.TRAIN + _DTYPES + FileFormats.JSON), + json.dumps(dtypes)) + self._write_to_file(data_directory / + (Metadata.TRAIN + FileFormats.JSON), + self.train.to_json(orient='split')) dtypes = self.test.dtypes.astype(str).to_dict() self._write_to_file(data_directory / @@ -231,12 +235,16 @@ def _load_data(inst, path): :type path: str """ data_directory = Path(path) / SerializationAttributes.DATA_DIRECTORY - with open(data_directory / - (Metadata.TRAIN + _DTYPES + FileFormats.JSON), 'r') as file: - types = json.load(file) - with open(data_directory / (Metadata.TRAIN + FileFormats.JSON), - 'r') as file: - train = pd.read_json(file, dtype=types, orient='split') + train_data_json = data_directory / (Metadata.TRAIN + FileFormats.JSON) + if train_data_json.exists(): + train_dtypes_path = (Metadata.TRAIN + _DTYPES + FileFormats.JSON) + train_dtypes_json = data_directory / train_dtypes_path + with open(train_dtypes_json, 'r') as file: + types = json.load(file) + with open(train_data_json, 'r') as file: + train = pd.read_json(file, dtype=types, orient='split') + else: + train = None inst.__dict__[Metadata.TRAIN] = train with open(data_directory / (Metadata.TEST + _DTYPES + FileFormats.JSON), 'r') as file: