Skip to content

Commit

Permalink
PLAT-1008: Make datetime detection more strict by optionally supporti…
Browse files Browse the repository at this point in the history
…ng a must_match_all parameter

GitOrigin-RevId: a8a43e3b3b27d0375d7afc682661711ba630ef51
  • Loading branch information
misberner committed Aug 30, 2023
1 parent c8cfc19 commit bc1016e
Show file tree
Hide file tree
Showing 4 changed files with 160 additions and 37 deletions.
2 changes: 1 addition & 1 deletion src/gretel_synthetics/actgan/actgan_wrapper.py
Expand Up @@ -70,7 +70,7 @@ def fit(self, data: Union[pd.DataFrame, str]) -> None:
if self._auto_transform_datetimes:
if self._verbose:
logger.info("Attempting datetime auto-detection...")
detector.fit_datetime(data, with_suffix=True)
detector.fit_datetime(data, with_suffix=True, must_match_all=True)

detector.fit_empty_columns(data)
if self._verbose:
Expand Down
74 changes: 69 additions & 5 deletions src/gretel_synthetics/detectors/dates.py
Expand Up @@ -309,6 +309,23 @@ def _maybe_match(date, format) -> Tuple[Optional[datetime], Optional[str]]:
return None, None


def _check_series(series: pd.Series, format: str) -> bool:
    """Return whether ``pd.to_datetime`` accepts all of ``series`` with ``format``.

    Args:
        series: The values to check.
        format: A strptime-style format string. It may contain the
            formatting-only directives ``!`` and ``%-``, which are removed
            before the format is handed to pandas.

    Returns:
        ``True`` if ``pd.to_datetime(series, format=...)`` completes without
        raising, ``False`` if it raises any ``Exception``.
    """
    # Remove non-standard formatting directives which are relevant for formatting
    # only, not for parsing. The first one, `!`, is introduced by us (see
    # ``_strptime_extra``), the second one, `%-`, is a directive not recognized
    # by pandas and stripped by RDT as well (see
    # https://github.com/sdv-dev/RDT/pull/458/files#r835690711 ).
    pd_format = format.replace("!", "").replace("%-", "%")
    try:
        pd.to_datetime(series, format=pd_format)
    except Exception:
        # Conservatively treat any parsing error as "the format didn't work".
        # This is to prevent errors in the SDV code downstream.
        # ``Exception`` (rather than a bare ``except``) is used so that
        # ``KeyboardInterrupt`` and ``SystemExit`` still propagate.
        return False
    return True


def _parse_date_multiple(
input_date: str,
date_str_fmts: Union[List[str], Set[str]] = _date_str_fmt_permutations,
Expand All @@ -334,7 +351,46 @@ def _maybe_d_str_to_fmt_multiple(input_date: str, with_suffix: bool) -> Iterator
pass


def _infer_from_series(series: Iterable[str], with_suffix: bool) -> Optional[str]:
def _infer_from_series_match_all(series: pd.Series, with_suffix: bool) -> Optional[str]:
    """Infer a datetime format that matches every value in ``series``.

    Candidate formats are generated from the first value and then narrowed
    down against the remaining values.

    Args:
        series: The string values to inspect. Values are accessed by position,
            so the series index does not need to be a default ``RangeIndex``.
        with_suffix: Passed through to ``_maybe_d_str_to_fmt_multiple`` when
            generating candidate formats from the first value.

    Returns:
        The first surviving candidate format, or ``None`` if the series is
        empty or no candidate matches all values.
    """
    if series.empty:
        return None

    # We store the candidate formats as a list instead of a set to ensure a deterministic
    # result (the order of ``_maybe_d_str_to_fmt_multiple`` is deterministic as well).
    # This matches the behavior of ``_infer_from_series``, which - due to the above
    # property as well as ``Counter``s stable iteration based on insertion order -
    # is deterministic as well.
    #
    # ``.iloc`` is used throughout: ``i`` is a position (it is compared against
    # ``len(series)``), whereas plain ``series[i]`` would be a label lookup and
    # fail or return the wrong row on a series without a default index.
    candidate_fmts = list(_maybe_d_str_to_fmt_multiple(series.iloc[0], with_suffix))
    i = 1
    # Empirically, ``pd.to_datetime`` is about 8x faster than checking individual values.
    # Conservatively, we fall back to calling ``pd.to_datetime`` on the entire remaining
    # series when we have 4 or fewer candidate formats left.
    # In most cases, the number of candidate formats will be lower than both 4 and 8
    # after the first invocation anyway.
    while len(candidate_fmts) > 4 and i < len(series):
        value = series.iloc[i]
        candidate_fmts = [
            fmt for fmt in candidate_fmts if _maybe_match(value, fmt) != (None, None)
        ]
        i += 1

    if i < len(series):
        # If we haven't exhausted the whole series yet, do a ``pd.to_datetime``
        # call for the remaining values to weed out incorrect formats.
        remaining_series = series.iloc[i:]
        candidate_fmts = [
            fmt for fmt in candidate_fmts if _check_series(remaining_series, fmt)
        ]

    return candidate_fmts[0] if candidate_fmts else None


def _infer_from_series(
series: pd.Series, with_suffix: bool, must_match_all: bool = False
) -> Optional[str]:
if must_match_all:
return _infer_from_series_match_all(series, with_suffix)

counter = Counter()
for value in series:
for fmt in _maybe_d_str_to_fmt_multiple(value, with_suffix):
Expand All @@ -347,7 +403,10 @@ def _infer_from_series(series: Iterable[str], with_suffix: bool) -> Optional[str


def detect_datetimes(
df: pd.DataFrame, sample_size: Optional[int] = None, with_suffix: bool = False
df: pd.DataFrame,
sample_size: Optional[int] = None,
with_suffix: bool = False,
must_match_all: bool = False,
) -> DateTimeColumns:
if sample_size is None:
sample_size = SAMPLE_SIZE
Expand All @@ -356,9 +415,14 @@ def detect_datetimes(
col for col, col_type in df.dtypes.iteritems() if col_type == "object"
]
for object_col in object_cols:
curr_series: pd.Series = df[object_col].dropna(axis=0).reset_index(drop=True)
sampled_series_str = (curr_series.sample(sample_size, replace=True)).astype(str)
inferred_format = _infer_from_series(sampled_series_str, with_suffix)
test_series: pd.Series = df[object_col].dropna(axis=0).reset_index(drop=True)
# Only sample when we don't require the format to match all entries
if not must_match_all and len(test_series) > sample_size:
test_series = test_series.sample(sample_size)
test_series_str = test_series.astype(str)
inferred_format = _infer_from_series(
test_series_str, with_suffix, must_match_all
)
if inferred_format is not None:
inferred_format = inferred_format.replace("!", "")
column_data.columns[object_col] = DateTimeColumn(
Expand Down
6 changes: 5 additions & 1 deletion src/gretel_synthetics/detectors/sdv.py
Expand Up @@ -164,9 +164,13 @@ def fit_datetime(
data: pd.DataFrame,
sample_size: Optional[int] = None,
with_suffix: bool = False,
must_match_all: bool = False,
) -> None:
detections = detect_datetimes(
data, sample_size=sample_size, with_suffix=with_suffix
data,
sample_size=sample_size,
with_suffix=with_suffix,
must_match_all=must_match_all,
)
for _, column_info in detections.columns.items():
type_, transformer = datetime_column_to_sdv(column_info)
Expand Down
115 changes: 85 additions & 30 deletions tests/detectors/test_detectors_dates.py
Expand Up @@ -2,6 +2,7 @@

from datetime import datetime, timedelta, timezone

import numpy as np
import pandas as pd
import pytest

Expand Down Expand Up @@ -67,31 +68,70 @@ def test_date_str_tokenizer(input_str, expected_mask):
assert _tokenize_date_str(input_str).masked_str == expected_mask


def test_infer_from_series():
dates = ["12/20/2020", "10/17/2020", "08/10/2020", "01/22/2020", "09/01/2020"]
assert _infer_from_series(dates, False) == "%m/%d/%Y"
@pytest.mark.parametrize("must_match_all", [False, True])
def test_infer_from_series(must_match_all):
    """A uniform month/day/year column is detected in both matching modes."""
    raw_values = ["12/20/2020", "10/17/2020", "08/10/2020", "01/22/2020", "09/01/2020"]
    inferred = _infer_from_series(
        pd.Series(raw_values), False, must_match_all=must_match_all
    )
    assert inferred == "%m/%d/%Y"


def test_infer_from_bad_date():
dates = ["#NAME?", "1000#", "Jim", "3", "$moola"]
assert _infer_from_series(dates, False) is None
@pytest.mark.parametrize("must_match_all", [False, True])
def test_infer_from_bad_date(must_match_all):
    """Values that are not dates yield no format in either matching mode."""
    non_dates = pd.Series(["#NAME?", "1000#", "Jim", "3", "$moola"])
    inferred = _infer_from_series(non_dates, False, must_match_all=must_match_all)
    assert inferred is None


def test_infer_from_some_bad_date():
dates = ["#NAME?", "1000#", "Jim", "3", "10/17/2020"]
assert _infer_from_series(dates, False) == "%m/%d/%Y"
dates = pd.Series(["#NAME?", "1000#", "Jim", "3", "10/17/2020"])
assert _infer_from_series(dates, False, must_match_all=False) == "%m/%d/%Y"


def test_infer_from_some_bad_date_with_match_all():
    """With ``must_match_all=True``, a single valid date among junk is not enough."""
    mostly_junk = pd.Series(["#NAME?", "1000#", "Jim", "3", "10/17/2020"])
    inferred = _infer_from_series(mostly_junk, False, must_match_all=True)
    assert inferred is None


@pytest.mark.parametrize("must_match_all", [False, True])
def test_infer_from_12_hour(must_match_all):
    """12-hour clock times with AM/PM are detected in both matching modes."""
    clock_times = pd.Series(["8:15 AM", "9:20 PM", "1:55 PM"])
    inferred = _infer_from_series(clock_times, False, must_match_all=must_match_all)
    assert inferred == "%I:%M %p"


@pytest.mark.parametrize("with_suffix", [True, False])
@pytest.mark.parametrize("must_match_all", [False, True])
def test_detect_datetimes(with_suffix, must_match_all, test_df):
# Based on the values in the DF, we assert the `with_suffix` flag
# should not change any of the results
check = detect_datetimes(
test_df, with_suffix=with_suffix, must_match_all=must_match_all
)
assert set(check.column_names) == {"dates", "iso"}
assert check.get_column_info("random") is None

dates = check.get_column_info("dates")
assert dates.name == "dates"
assert dates.inferred_format == "%m/%d/%Y"

def test_infer_from_12_hour():
dates = ["8:15 AM", "9:20 PM", "1:55 PM"]
assert _infer_from_series(dates, False) == "%I:%M %p"
iso = check.get_column_info("iso")
assert iso.name == "iso"
assert iso.inferred_format == "%Y-%m-%dT%X.%f"


@pytest.mark.parametrize("with_suffix", [True, False])
def test_detect_datetimes(with_suffix, test_df):
@pytest.mark.parametrize("must_match_all", [False, True])
def test_detect_datetimes_with_nans(with_suffix, must_match_all, test_df):
# Create a copy to prevent modification to the session-scoped fixture
# object.
test_df = test_df.copy()
# Blank out first row
test_df.iloc[0, :] = np.nan

# Based on the values in the DF, we assert the `with_suffix` flag
# should not change any of the results
check = detect_datetimes(test_df, with_suffix=with_suffix)
check = detect_datetimes(
test_df, with_suffix=with_suffix, must_match_all=must_match_all
)
assert set(check.column_names) == {"dates", "iso"}
assert check.get_column_info("random") is None

Expand All @@ -104,27 +144,41 @@ def test_detect_datetimes(with_suffix, test_df):
assert iso.inferred_format == "%Y-%m-%dT%X.%f"


def test_infer_with_suffix():
dates = [
"2020-12-20T00:00:00Z",
"2020-10-17T00:00:00Z",
"2020-08-10T00:00:00Z",
"2020-01-22T00:00:00Z",
"2020-09-01T00:00:00Z",
]
assert _infer_from_series(dates, True) == "%Y-%m-%dT%XZ"
@pytest.mark.parametrize("must_match_all", [False, True])
def test_infer_with_suffix(must_match_all):
dates = pd.Series(
[
"2020-12-20T00:00:00Z",
"2020-10-17T00:00:00Z",
"2020-08-10T00:00:00Z",
"2020-01-22T00:00:00Z",
"2020-09-01T00:00:00Z",
]
)
assert (
_infer_from_series(dates, True, must_match_all=must_match_all) == "%Y-%m-%dT%XZ"
)

dates_2 = [d.replace("Z", "+00:00") for d in dates.copy()]
assert _infer_from_series(dates_2, True) == "%Y-%m-%dT%X+00:00"
dates_2 = pd.Series([d.replace("Z", "+00:00") for d in dates])
assert (
_infer_from_series(dates_2, True, must_match_all=must_match_all)
== "%Y-%m-%dT%X+00:00"
)

dates_3 = [d.replace("Z", "-00:00") for d in dates.copy()]
assert _infer_from_series(dates_3, True) == "%Y-%m-%dT%X-00:00"
dates_3 = pd.Series([d.replace("Z", "-00:00") for d in dates])
assert (
_infer_from_series(dates_3, True, must_match_all=must_match_all)
== "%Y-%m-%dT%X-00:00"
)


def test_detect_datetimes_with_suffix(test_df):
@pytest.mark.parametrize("must_match_all", [False, True])
def test_detect_datetimes_with_suffix(must_match_all, test_df):
# Prevent modification of the session-scoped fixture object
test_df = test_df.copy()
# Add a TZ suffix of "Z" to the iso strings
test_df["iso"] = test_df["iso"].astype("string").apply(lambda val: val + "Z")
check = detect_datetimes(test_df, with_suffix=True)
check = detect_datetimes(test_df, with_suffix=True, must_match_all=must_match_all)
assert set(check.column_names) == {"dates", "iso"}

iso = check.get_column_info("iso")
Expand All @@ -134,7 +188,8 @@ def test_detect_datetimes_with_suffix(test_df):
assert iso.inferred_format == "%Y-%m-%dT%X.%fZ"


def test_detect_datetimes_custom_formats():
@pytest.mark.parametrize("must_match_all", [False, True])
def test_detect_datetimes_custom_formats(must_match_all):
df = pd.DataFrame(
{
"str": ["a", "b", "c"],
Expand All @@ -151,7 +206,7 @@ def test_detect_datetimes_custom_formats():
}
)

check = detect_datetimes(df)
check = detect_datetimes(df, must_match_all=must_match_all)

assert set(check.column_names) == {
"dateandtime",
Expand Down

0 comments on commit bc1016e

Please sign in to comment.