Skip to content

Commit

Permalink
dataset mutable properties fix (#509)
Browse files · Browse the repository at this point in the history
* dataset small improvement
* removed redundant method from the dataset
  • Loading branch information
yromanyshyn committed Jan 5, 2022
1 parent 1112b97 commit de8742f
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 30 deletions.
56 changes: 26 additions & 30 deletions deepchecks/base/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,12 +99,12 @@ class Dataset:
_data: pd.DataFrame
_max_categorical_ratio: float
_max_categories: int
_label_type: str
_label_type: t.Optional[str]

def __init__(
self,
df: pd.DataFrame,
label: t.Union[Hashable, pd.Series, pd.DataFrame, np.array] = None,
label: t.Union[Hashable, pd.Series, pd.DataFrame, np.ndarray] = None,
features: t.Optional[t.Sequence[Hashable]] = None,
cat_features: t.Optional[t.Sequence[Hashable]] = None,
index_name: t.Optional[Hashable] = None,
Expand Down Expand Up @@ -169,7 +169,7 @@ def __init__(
self._data[label_name] = label
else:
self._data[label_name] = np.array(label).reshape(-1, 1)
elif isinstance(label, t.Hashable):
elif isinstance(label, Hashable):
label_name = label
if label_name not in self._data.columns:
raise DeepchecksValueError(f'label column {label_name} not found in dataset columns')
Expand Down Expand Up @@ -246,12 +246,6 @@ def __init__(
if self._label_name in self.features:
raise DeepchecksValueError(f'label column {self._label_name} can not be a feature column')

if self._label_name:
try:
self.check_compatible_labels()
except DeepchecksValueError as e:
logger.warning(str(e))

if self._datetime_name in self.features:
raise DeepchecksValueError(f'datetime column {self._datetime_name} can not be a feature column')

Expand Down Expand Up @@ -402,8 +396,15 @@ def data(self) -> pd.DataFrame:
"""Return the data of dataset."""
return self._data

def copy(self: TDataset, new_data) -> TDataset:
"""Create a copy of this Dataset with new data."""
def copy(self: TDataset, new_data: pd.DataFrame) -> TDataset:
"""Create a copy of this Dataset with new data.
Args:
new_data (DataFrame): new data from which new dataset will be created
Returns:
Dataset: new dataset instance
"""
# Filter out if columns were dropped
features = [feat for feat in self._features if feat in new_data.columns]
cat_features = [feat for feat in self.cat_features if feat in new_data.columns]
Expand All @@ -419,7 +420,7 @@ def copy(self: TDataset, new_data) -> TDataset:
convert_datetime=False, max_categorical_ratio=self._max_categorical_ratio,
max_categories=self._max_categories, label_type=self.label_type)

def sample(self, n_samples: int, replace: bool = False, random_state: t.Optional[int] = None) -> TDataset:
def sample(self: TDataset, n_samples: int, replace: bool = False, random_state: t.Optional[int] = None) -> TDataset:
"""Create a copy of the dataset object, with the internal dataframe being a sample of the original dataframe.
Args:
Expand Down Expand Up @@ -447,20 +448,21 @@ def __len__(self) -> int:
return self.n_samples

@property
def label_type(self) -> t.Optional[str]:
    """Return the label type.

    Returns:
        t.Optional[str]: the label type string, or None if no label type was set
    """
    return self._label_type

def train_test_split(self,
def train_test_split(self: TDataset,
train_size: t.Union[int, float, None] = None,
test_size: t.Union[int, float] = 0.25,
random_state: int = 42,
shuffle: bool = True,
stratify: t.Union[t.List, pd.Series, np.ndarray, bool] = False) -> t.Tuple[TDataset, TDataset]:
stratify: t.Union[t.List, pd.Series, np.ndarray, bool] = False
) -> t.Tuple[TDataset, TDataset]:
"""Split dataset into random train and test datasets.
Args:
Expand Down Expand Up @@ -648,7 +650,7 @@ def features(self) -> t.List[Hashable]:
Returns:
List of feature names.
"""
return self._features
return list(self._features)

@property
def cat_features(self) -> t.List[Hashable]:
Expand All @@ -657,7 +659,7 @@ def cat_features(self) -> t.List[Hashable]:
Returns:
List of categorical feature names.
"""
return self._cat_features
return list(self._cat_features)

@property
def features_columns(self) -> t.Optional[pd.DataFrame]:
Expand All @@ -670,15 +672,15 @@ def features_columns(self) -> t.Optional[pd.DataFrame]:

@property
@lru_cache(maxsize=128)
def classes(self) -> t.Tuple[str, ...]:
    """Return the classes from the label column as a sorted tuple.

    Returning an immutable tuple (rather than a list) prevents callers from
    mutating the cached value in place.

    NOTE(review): `lru_cache` on a method keys on `self` and keeps every
    instance alive for the lifetime of the cache (ruff B019); consider
    `functools.cached_property` instead — TODO confirm no caller relies on
    the shared 128-entry cache.

    Returns:
        Sorted tuple of distinct label values; empty tuple when no label
        column is defined.
    """
    if self.label_col is not None:
        # dropna() first so missing labels are never reported as a class
        return tuple(sorted(self.label_col.dropna().unique().tolist()))
    return tuple()

@property
def columns_info(self) -> t.Dict[Hashable, str]:
Expand All @@ -705,12 +707,6 @@ def columns_info(self) -> t.Dict[Hashable, str]:
columns[column] = value
return columns

def check_compatible_labels(self):
"""Check if label column is supported by deepchecks."""
labels = self.label_col
if labels is None:
return

# Validations:

def validate_label(self):
Expand Down Expand Up @@ -795,7 +791,7 @@ def validate_shared_features(self, other) -> t.List[Hashable]:
"""
Dataset.validate_dataset(other)
if sorted(self.features) == sorted(other.features):
return self.features
return list(self.features)
else:
raise DeepchecksValueError('Check requires datasets to share the same features')

Expand All @@ -814,7 +810,7 @@ def validate_shared_categorical_features(self, other) -> t.List[Hashable]:
"""
Dataset.validate_dataset(other)
if sorted(self.cat_features) == sorted(other.cat_features):
return self.cat_features
return list(self.cat_features)
else:
raise DeepchecksValueError('Check requires datasets to share '
'the same categorical features. Possible reason is that some columns were'
Expand All @@ -838,8 +834,8 @@ def validate_shared_label(self, other) -> Hashable:
"""
Dataset.validate_dataset(other)
if (
self.label_name is not None and other.label_name is not None
and self.label_name == other.label_name
self.label_name is not None and other.label_name is not None
and self.label_name == other.label_name
):
return t.cast(Hashable, self.label_name)
else:
Expand Down
20 changes: 20 additions & 0 deletions tests/base/dataset_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,26 @@ def assert_dataset(dataset: Dataset, args):
is_(True)
)

def test_that_mutable_properties_modification_does_not_affect_dataset_state(iris):
    """Mutating lists returned by `features`/`cat_features` must not leak into the dataset."""
    feature_names = [
        'sepal length (cm)',
        'sepal width (cm)',
        'petal length (cm)',
        'petal width (cm)'
    ]
    dataset = Dataset(df=iris, features=feature_names)

    # Take the lists the properties hand out and mutate them in place.
    returned_features = dataset.features
    returned_cat_features = dataset.cat_features
    returned_features.append("New value")
    returned_cat_features.append("New value")

    # The dataset's internal state must be unaffected by those mutations.
    assert_that("New value" not in dataset.features)
    assert_that("New value" not in dataset.cat_features)


def test_dataset_empty_df(empty_df):
args = {'df': empty_df}
Expand Down

0 comments on commit de8742f

Please sign in to comment.