2 changes: 2 additions & 0 deletions hashprep/__init__.py
@@ -1,3 +1,5 @@
+from .config import HashPrepConfig as HashPrepConfig
 from .core.analyzer import DatasetAnalyzer as DatasetAnalyzer
+from .utils.config_loader import load_config as load_config
 
 __version__ = "0.1.0b2"
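
With these re-exports, config loading becomes part of the public API. A minimal usage sketch (the DatasetAnalyzer constructor arguments and the file paths are assumptions for illustration; only the names HashPrepConfig, DatasetAnalyzer, and load_config appear in this diff):

import pandas as pd

from hashprep import DatasetAnalyzer, load_config

df = pd.read_csv("train.csv")                  # hypothetical input file
config = load_config("hashprep.yaml")          # hypothetical config file path
analyzer = DatasetAnalyzer(df, config=config)  # assumed constructor signature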
8 changes: 7 additions & 1 deletion hashprep/checks/__init__.py
@@ -29,7 +29,13 @@
 def _check_dataset_drift(analyzer):
     """Wrapper for drift detection that uses analyzer's comparison_df."""
     if hasattr(analyzer, "comparison_df") and analyzer.comparison_df is not None:
-        return check_drift(analyzer.df, analyzer.comparison_df)
+        drift_cfg = analyzer.config.drift
+        return check_drift(
+            analyzer.df,
+            analyzer.comparison_df,
+            threshold=drift_cfg.p_value,
+            config=drift_cfg,
+        )
     return []


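The wrapper now resolves thresholds from the analyzer instead of from module-level constants. A stand-in sketch of the resulting call shape (SimpleNamespace replaces hashprep's real analyzer and config classes; only the attribute names come from this diff):

from types import SimpleNamespace

import pandas as pd

drift_cfg = SimpleNamespace(p_value=0.05)  # field name from this diff; value is an example
analyzer = SimpleNamespace(
    df=pd.DataFrame({"x": range(100)}),
    comparison_df=pd.DataFrame({"x": range(50, 150)}),
    config=SimpleNamespace(drift=drift_cfg),
)
# _check_dataset_drift(analyzer) now forwards both the p-value and the whole
# drift config:
# check_drift(analyzer.df, analyzer.comparison_df,
#             threshold=drift_cfg.p_value, config=drift_cfg)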
17 changes: 6 additions & 11 deletions hashprep/checks/columns.py
@@ -1,8 +1,5 @@
-from ..config import DEFAULT_CONFIG
 from .core import Issue
 
-_COL_THRESHOLDS = DEFAULT_CONFIG.columns
-
 
 def _check_single_value_columns(analyzer):
     issues = []
@@ -28,18 +25,15 @@ def _check_single_value_columns(analyzer):
     return issues
 
 
-def _check_high_cardinality(
-    analyzer,
-    threshold: int = _COL_THRESHOLDS.high_cardinality_count,
-    critical_threshold: float = _COL_THRESHOLDS.high_cardinality_ratio_critical,
-):
+def _check_high_cardinality(analyzer):
+    _cfg = analyzer.config.columns
     issues = []
     categorical_cols = analyzer.df.select_dtypes(include="object").columns.tolist()
     for col in categorical_cols:
         unique_count = int(analyzer.df[col].nunique())
         unique_ratio = float(unique_count / len(analyzer.df))
-        if unique_count > threshold:
-            severity = "critical" if unique_ratio > critical_threshold else "warning"
+        if unique_count > _cfg.high_cardinality_count:
+            severity = "critical" if unique_ratio > _cfg.high_cardinality_ratio_critical else "warning"
             impact = "high" if severity == "critical" else "medium"
             quick_fix = (
                 "Options: \n- Drop column: Avoids overfitting from unique identifiers (Pros: Simplifies model; Cons: Loses potential info).\n- Engineer feature: Extract patterns (e.g., titles from names) (Pros: Retains useful info; Cons: Requires domain knowledge).\n- Use hashing: Reduce dimensionality (Pros: Scalable; Cons: May lose interpretability)."
@@ -61,10 +55,11 @@ def _check_high_cardinality(
 
 def _check_duplicates(analyzer):
     issues = []
+    _cfg = analyzer.config.columns
     duplicate_rows = int(analyzer.df.duplicated().sum())
     if duplicate_rows > 0:
         duplicate_ratio = float(duplicate_rows / len(analyzer.df))
-        severity = "critical" if duplicate_ratio > _COL_THRESHOLDS.duplicate_ratio_critical else "warning"
+        severity = "critical" if duplicate_ratio > _cfg.duplicate_ratio_critical else "warning"
         impact = "high" if severity == "critical" else "medium"
         quick_fix = (
             "Options: \n- Drop duplicates: Ensures data integrity (Pros: Cleaner data; Cons: May lose valid repeats).\n- Verify duplicates: Check if intentional (e.g., time-series) (Pros: Validates data; Cons: Time-consuming)."
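Reading thresholds inside each check, rather than binding _COL_THRESHOLDS once at import time, means two analyzers in the same process can run with different cutoffs. A stand-in sketch (SimpleNamespace replaces the real columns config; field names are the ones used above, values are examples only):

from types import SimpleNamespace

strict = SimpleNamespace(
    high_cardinality_count=50,            # flag columns with >50 distinct values
    high_cardinality_ratio_critical=0.5,  # critical when >50% of rows are unique
    duplicate_ratio_critical=0.05,
)
lenient = SimpleNamespace(
    high_cardinality_count=5000,
    high_cardinality_ratio_critical=0.95,
    duplicate_ratio_critical=0.5,
)
# Each analyzer carries one of these as analyzer.config.columns, so the same
# checks can be strict for one dataset and lenient for another.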
16 changes: 7 additions & 9 deletions hashprep/checks/correlations.py
@@ -4,16 +4,10 @@
 import pandas as pd
 from scipy.stats import chi2_contingency, kendalltau, pearsonr, spearmanr
 
-from ..config import DEFAULT_CONFIG
 from ..utils.type_inference import is_usable_for_corr
 from .core import Issue
 from .discretizer import DiscretizationType, Discretizer
 
-_CORR = DEFAULT_CONFIG.correlations
-CORR_THRESHOLDS = _CORR.as_nested_dict()
-CAT_MAX_DISTINCT = _CORR.max_distinct_categories
-LOW_CARD_NUM_THRESHOLD = _CORR.low_cardinality_numeric
-
 
 def _cramers_v_corrected(table: pd.DataFrame) -> float:
     if table.empty or (table.shape[0] == 1 or table.shape[1] == 1):
@@ -37,8 +31,9 @@ def calculate_correlations(analyzer, thresholds=None):
     Compute correlations using internal defaults: Spearman + Pearson for numerics,
     with Kendall added automatically for low-cardinality pairs.
     """
+    _cfg = analyzer.config.correlations
     if thresholds is None:
-        thresholds = CORR_THRESHOLDS
+        thresholds = _cfg.as_nested_dict()
 
     inferred_types = analyzer.column_types  # Use analyzer.column_types for inferred types dict
     issues = []
@@ -50,7 +45,7 @@ def calculate_correlations(analyzer, thresholds=None):
         col
         for col, typ in inferred_types.items()
         if typ == "Categorical"
-        and 1 < analyzer.df[col].nunique() <= CAT_MAX_DISTINCT
+        and 1 < analyzer.df[col].nunique() <= _cfg.max_distinct_categories
         and is_usable_for_corr(analyzer.df[col])
     ]
 
@@ -62,6 +57,7 @@ def calculate_correlations(analyzer, thresholds=None):
 
 
 def _check_numeric_correlation(analyzer, numeric_cols: list, thresholds: dict):
+    _cfg = analyzer.config.correlations
     issues = []
     if len(numeric_cols) < 2:
         return issues
@@ -85,7 +81,9 @@ def _check_numeric_correlation(analyzer, numeric_cols: list, thresholds: dict):
 
         # Kendall (only for low-cardinality numerics)
         kendall_corr, kendall_p = None, None
-        is_low_card = series1.nunique() <= LOW_CARD_NUM_THRESHOLD or series2.nunique() <= LOW_CARD_NUM_THRESHOLD
+        is_low_card = (
+            series1.nunique() <= _cfg.low_cardinality_numeric or series2.nunique() <= _cfg.low_cardinality_numeric
+        )
         if is_low_card:
             kendall_corr, kendall_p = kendalltau(series1, series2)
             kendall_corr = abs(kendall_corr)
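For reference, a standard bias-corrected Cramér's V (Bergsma's correction), consistent with the signature and degenerate-table guard shown above. The function body is elided in this diff, so this is a sketch of the usual formulation, not a copy of hashprep's implementation:

import numpy as np
import pandas as pd
from scipy.stats import chi2_contingency


def cramers_v_corrected_sketch(table: pd.DataFrame) -> float:
    # Degenerate tables (empty, or a single row/column) carry no association.
    if table.empty or (table.shape[0] == 1 or table.shape[1] == 1):
        return 0.0
    n = table.to_numpy().sum()
    if n < 2:
        return 0.0
    chi2 = chi2_contingency(table)[0]
    r, k = table.shape
    phi2 = chi2 / n
    # Bias-correct phi^2 and the effective table dimensions.
    phi2_corr = max(0.0, phi2 - (k - 1) * (r - 1) / (n - 1))
    r_corr = r - (r - 1) ** 2 / (n - 1)
    k_corr = k - (k - 1) ** 2 / (n - 1)
    denom = min(k_corr - 1, r_corr - 1)
    return float(np.sqrt(phi2_corr / denom)) if denom > 0 else 0.0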
16 changes: 8 additions & 8 deletions hashprep/checks/datetime_checks.py
@@ -1,11 +1,8 @@
 import numpy as np
 import pandas as pd
 
-from ..config import DEFAULT_CONFIG
 from .core import Issue
 
-_DT_CFG = DEFAULT_CONFIG.datetime
-
 
 def _coerce_datetime(series: pd.Series) -> pd.Series:
     """Return a datetime Series regardless of whether the source is datetime64 or object."""
@@ -21,6 +18,7 @@ def _datetime_cols(analyzer) -> list[str]:
 
 def _check_datetime_future_dates(analyzer) -> list[Issue]:
     """Flag datetime columns that contain values in the future (likely data errors)."""
+    _cfg = analyzer.config.datetime
     issues = []
     now = pd.Timestamp.now()
 
@@ -34,7 +32,7 @@ def _check_datetime_future_dates(analyzer) -> list[Issue]:
             continue
 
         future_ratio = future_count / len(dt)
-        severity = "critical" if future_ratio > _DT_CFG.future_date_critical_ratio else "warning"
+        severity = "critical" if future_ratio > _cfg.future_date_critical_ratio else "warning"
         impact = "high" if severity == "critical" else "medium"
         issues.append(
             Issue(
@@ -59,11 +57,12 @@ def _check_datetime_future_dates(analyzer) -> list[Issue]:
 
 def _check_datetime_gaps(analyzer) -> list[Issue]:
     """Detect anomalously large gaps in datetime columns (broken time series)."""
+    _cfg = analyzer.config.datetime
     issues = []
 
     for col in _datetime_cols(analyzer):
         dt = _coerce_datetime(analyzer.df[col]).sort_values()
-        if len(dt) < _DT_CFG.min_rows_for_gap_check:
+        if len(dt) < _cfg.min_rows_for_gap_check:
             continue
 
         diffs = dt.diff().dropna()
@@ -79,8 +78,8 @@ def _check_datetime_gaps(analyzer) -> list[Issue]:
         max_gap = float(diff_seconds.max())
         ratio = max_gap / median_gap
 
-        if ratio >= _DT_CFG.gap_multiplier_warning:
-            severity = "critical" if ratio >= _DT_CFG.gap_multiplier_critical else "warning"
+        if ratio >= _cfg.gap_multiplier_warning:
+            severity = "critical" if ratio >= _cfg.gap_multiplier_critical else "warning"
             impact = "high" if severity == "critical" else "medium"
 
             # Locate the gap for a human-readable description
@@ -113,11 +112,12 @@ def _check_datetime_gaps(analyzer) -> list[Issue]:
 
 def _check_datetime_monotonicity(analyzer) -> list[Issue]:
     """Warn when a datetime column that looks like a time-series index is non-monotonic."""
+    _cfg = analyzer.config.datetime
     issues = []
 
     for col in _datetime_cols(analyzer):
         dt = _coerce_datetime(analyzer.df[col])
-        if len(dt) < _DT_CFG.min_rows_for_gap_check:
+        if len(dt) < _cfg.min_rows_for_gap_check:
             continue
 
         # Only flag if the column has mostly unique values (i.e., likely an index/timestamp)
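The gap check compares the largest inter-timestamp gap with the median gap. A standalone sketch of that arithmetic (the multiplier value is an example, not hashprep's default):

import pandas as pd

dt = pd.Series(
    pd.to_datetime(["2024-01-01", "2024-01-02", "2024-01-03", "2024-03-01", "2024-03-02"])
).sort_values()
diff_seconds = dt.diff().dropna().dt.total_seconds()

median_gap = float(diff_seconds.median())  # one day for this series
max_gap = float(diff_seconds.max())        # the ~58-day jump into March
ratio = max_gap / median_gap               # ~58

gap_multiplier_warning = 10.0              # example threshold only
print(ratio >= gap_multiplier_warning)     # True: the January-to-March gap is flagged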
17 changes: 8 additions & 9 deletions hashprep/checks/distribution.py
@@ -1,21 +1,19 @@
 from scipy.stats import kstest
 
-from ..config import DEFAULT_CONFIG
 from .core import Issue
 
-_DIST = DEFAULT_CONFIG.distribution
-
 
-def _check_uniform_distribution(analyzer, p_threshold: float = _DIST.uniform_p_value) -> list[Issue]:
+def _check_uniform_distribution(analyzer) -> list[Issue]:
     """
     Detect uniformly distributed numeric columns using Kolmogorov-Smirnov test.
     Uniform distributions often indicate synthetic IDs or sequential data.
     """
+    _cfg = analyzer.config.distribution
     issues = []
 
     for col in analyzer.df.select_dtypes(include="number").columns:
         series = analyzer.df[col].dropna()
-        if len(series) < _DIST.uniform_min_samples:
+        if len(series) < _cfg.uniform_min_samples:
             continue
 
         min_val, max_val = series.min(), series.max()
@@ -26,7 +24,7 @@ def _check_uniform_distribution(analyzer, p_threshold: float = _DIST.uniform_p_v
         _, p_val = kstest(normalized, "uniform")
         is_monotonic = series.is_monotonic_increasing or series.is_monotonic_decreasing
 
-        if p_val > p_threshold or is_monotonic:
+        if p_val > _cfg.uniform_p_value or is_monotonic:
             monotonic_note = " and monotonic" if is_monotonic else ""
             issues.append(
                 Issue(
@@ -47,22 +45,23 @@ def _check_uniform_distribution(analyzer, p_threshold: float = _DIST.uniform_p_v
     return issues
 
 
-def _check_unique_values(analyzer, threshold: float = _DIST.unique_value_ratio) -> list[Issue]:
+def _check_unique_values(analyzer) -> list[Issue]:
     """
     Detect columns where nearly all values are unique.
     High uniqueness often indicates identifiers, names, or free-text fields.
     """
+    _cfg = analyzer.config.distribution
     issues = []
 
     for col in analyzer.df.columns:
         series = analyzer.df[col].dropna()
-        if len(series) < _DIST.unique_min_samples:
+        if len(series) < _cfg.unique_min_samples:
             continue
 
         unique_count = series.nunique()
         unique_ratio = unique_count / len(series)
 
-        if unique_ratio >= threshold:
+        if unique_ratio >= _cfg.unique_value_ratio:
             issues.append(
                 Issue(
                     category="unique_values",
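The uniformity check min-max normalizes a column and KS-tests it against uniform(0, 1); here a high p-value (failure to reject uniformity) is the suspicious outcome, since real measurements are rarely uniform while IDs and counters often are. A standalone sketch (the threshold value is an example only):

import numpy as np
from scipy.stats import kstest

rng = np.random.default_rng(0)
series = rng.uniform(10, 20, size=500)  # stand-in for a numeric column

min_val, max_val = series.min(), series.max()
normalized = (series - min_val) / (max_val - min_val)
_, p_val = kstest(normalized, "uniform")

uniform_p_value = 0.05                  # example threshold only
print(p_val > uniform_p_value)          # typically True for genuinely uniform input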
20 changes: 11 additions & 9 deletions hashprep/checks/drift.py
@@ -9,14 +9,13 @@
 _log = get_logger("checks.drift")
 
 _DRIFT = DEFAULT_CONFIG.drift
-CRITICAL_P_VALUE = _DRIFT.critical_p_value
-MAX_CATEGORIES_FOR_CHI2 = _DRIFT.max_categories_for_chi2
 
 
 def check_drift(
     df_train: pd.DataFrame,
     df_test: pd.DataFrame,
     threshold: float = _DRIFT.p_value,
+    config=None,
 ) -> list[Issue]:
     """
     Check for distribution shift between two datasets.
@@ -25,10 +24,11 @@ def check_drift(
     if not isinstance(df_train, pd.DataFrame) or not isinstance(df_test, pd.DataFrame):
         raise TypeError("Both df_train and df_test must be pandas DataFrames")
 
+    drift_cfg = config if config is not None else _DRIFT
     issues = []
 
-    issues.extend(_check_numeric_drift(df_train, df_test, threshold))
-    issues.extend(_check_categorical_drift(df_train, df_test, threshold))
+    issues.extend(_check_numeric_drift(df_train, df_test, threshold, drift_cfg))
+    issues.extend(_check_categorical_drift(df_train, df_test, threshold, drift_cfg))
 
     return issues

@@ -37,6 +37,7 @@ def _check_numeric_drift(
     df_train: pd.DataFrame,
     df_test: pd.DataFrame,
     threshold: float,
+    drift_cfg,
 ) -> list[Issue]:
     """Check numeric columns for distribution drift using KS-test."""
     issues = []
@@ -55,7 +56,7 @@ def _check_numeric_drift(
         stat, p_val = ks_2samp(train_vals, test_vals)
 
         if p_val < threshold:
-            severity = "critical" if p_val < CRITICAL_P_VALUE else "warning"
+            severity = "critical" if p_val < drift_cfg.critical_p_value else "warning"
             issues.append(
                 Issue(
                     category="dataset_drift",
@@ -74,6 +75,7 @@ def _check_categorical_drift(
     df_train: pd.DataFrame,
     df_test: pd.DataFrame,
     threshold: float,
+    drift_cfg,
 ) -> list[Issue]:
     """Check categorical columns for distribution drift using Chi-square test."""
     issues = []
@@ -88,20 +90,20 @@ def _check_categorical_drift(
 
         new_categories = set(test_counts.index) - set(train_counts.index)
         if new_categories:
-            sample_new = list(new_categories)[: _DRIFT.max_new_category_samples]
+            sample_new = list(new_categories)[: drift_cfg.max_new_category_samples]
            issues.append(
                 Issue(
                     category="dataset_drift",
                     severity="warning",
                     column=col,
-                    description=f"New categories in test set for '{col}': {sample_new}{'...' if len(new_categories) > _DRIFT.max_new_category_samples else ''}",
+                    description=f"New categories in test set for '{col}': {sample_new}{'...' if len(new_categories) > drift_cfg.max_new_category_samples else ''}",
                     impact_score="medium",
                     quick_fix="Handle unseen categories in preprocessing pipeline (e.g., OrdinalEncoder with unknown_value).",
                 )
             )
 
         all_cats = list(set(train_counts.index) | set(test_counts.index))
-        if len(all_cats) > MAX_CATEGORIES_FOR_CHI2:
+        if len(all_cats) > drift_cfg.max_categories_for_chi2:
             continue
 
         train_total = train_counts.sum()
@@ -127,7 +129,7 @@ def _check_categorical_drift(
         chi2_stat, p_val = chisquare(observed_arr, f_exp=expected_arr)
 
         if p_val < threshold:
-            severity = "critical" if p_val < CRITICAL_P_VALUE else "warning"
+            severity = "critical" if p_val < drift_cfg.critical_p_value else "warning"
             issues.append(
                 Issue(
                     category="dataset_drift",
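With config threaded through, callers can pass one object carrying every drift threshold, falling back to the module-level _DRIFT default when omitted. A minimal sketch (SimpleNamespace stands in for the real drift config; the field names are the ones referenced in this diff, the values are examples):

from types import SimpleNamespace

import numpy as np
import pandas as pd

from hashprep.checks.drift import check_drift

rng = np.random.default_rng(0)
df_train = pd.DataFrame({"amount": rng.normal(0, 1, 500)})
df_test = pd.DataFrame({"amount": rng.normal(3, 1, 500)})  # deliberately shifted

drift_cfg = SimpleNamespace(
    critical_p_value=0.01,
    max_categories_for_chi2=20,
    max_new_category_samples=5,
)
issues = check_drift(df_train, df_test, threshold=0.05, config=drift_cfg)
# The shifted 'amount' column should come back as a dataset_drift issue.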
4 changes: 2 additions & 2 deletions hashprep/checks/imbalance.py
@@ -1,8 +1,8 @@
-from ..config import DEFAULT_CONFIG
 from .core import Issue
 
 
-def _check_class_imbalance(analyzer, threshold: float = DEFAULT_CONFIG.imbalance.majority_class_ratio):
+def _check_class_imbalance(analyzer):
+    threshold = analyzer.config.imbalance.majority_class_ratio
     issues = []
     if analyzer.target_col and analyzer.target_col in analyzer.df.columns:
         counts = analyzer.df[analyzer.target_col].value_counts(normalize=True)
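The imbalance check keeps the same arithmetic but reads its threshold from the analyzer's config. As a standalone sketch (0.8 and the comparison direction are illustrative assumptions; the actual comparison is cut off in this diff):

import pandas as pd

y = pd.Series(["ok"] * 95 + ["fraud"] * 5)    # stand-in target column
counts = y.value_counts(normalize=True)       # class shares, largest first

majority_class_ratio = 0.8                    # example threshold only
print(counts.iloc[0] > majority_class_ratio)  # True: 95% majority class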