Merged
54 changes: 24 additions & 30 deletions mostlyai/qa/_accuracy.py
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import datetime
 import functools
 import hashlib
 import logging
@@ -1034,7 +1035,7 @@ def binning_data(
 def bin_data(
     df: pd.DataFrame,
     bins: int | dict[str, list],
-    non_categorical_label_style: Literal["bounded", "unbounded"] = "unbounded",
+    non_categorical_label_style: Literal["bounded", "unbounded", "lower"] = "unbounded",
 ) -> tuple[pd.DataFrame, dict[str, list]]:
     """
     Splits data into bins.
@@ -1048,41 +1049,32 @@ def bin_data(

     # Note, that we create a new pd.DataFrame to avoid fragmentation warning messages that can occur if we try to
     # replace hundreds of columns of a large dataset
-    cols = {}
-
-    bins_dct = {}
-    num_cols = [c for c in df.columns if pd.api.types.is_numeric_dtype(df[c])]
-    dat_cols = [c for c in df.columns if pd.api.types.is_datetime64_any_dtype(df[c])]
-    cat_cols = [c for c in df.columns if c not in num_cols + dat_cols]
+    cols, bins_dct = {}, {}
     if isinstance(bins, int):
-        for col in num_cols:
-            cols[col], bins_dct[col] = bin_numeric(df[col], bins, label_style=non_categorical_label_style)
-        for col in dat_cols:
-            cols[col], bins_dct[col] = bin_datetime(df[col], bins, label_style=non_categorical_label_style)
-        for col in cat_cols:
-            cols[col], bins_dct[col] = bin_categorical(df[col], bins)
-    else:  # bins is a dict
-        for col in num_cols:
-            if col in bins:
-                cols[col], _ = bin_numeric(df[col], bins[col], label_style=non_categorical_label_style)
-            else:
-                _LOG.warning(f"'{col}' is missing in bins")
-        for col in dat_cols:
-            if col in bins:
-                cols[col], _ = bin_datetime(df[col], bins[col], label_style=non_categorical_label_style)
-            else:
-                _LOG.warning(f"'{col}' is missing in bins")
-        for col in cat_cols:
-            if col in bins:
-                cols[col], _ = bin_categorical(df[col], bins[col])
-            else:
-                _LOG.warning(f"'{col}' is missing in bins")
+        for col in df.columns:
+            if pd.api.types.is_numeric_dtype(df[col]):
+                cols[col], bins_dct[col] = bin_numeric(df[col], bins, label_style=non_categorical_label_style)
+            elif pd.api.types.is_datetime64_any_dtype(df[col]):
+                cols[col], bins_dct[col] = bin_datetime(df[col], bins, label_style=non_categorical_label_style)
+            else:
+                cols[col], bins_dct[col] = bin_categorical(df[col], bins)
+    else:  # bins is a dict
+        for col in df.columns:
+            if col in bins:
+                if isinstance(bins[col][0], (int, float, np.integer, np.floating)):
+                    cols[col], _ = bin_numeric(df[col], bins[col], label_style=non_categorical_label_style)
+                elif isinstance(bins[col][0], (datetime.date, datetime.datetime, np.datetime64)):
+                    cols[col], _ = bin_datetime(df[col], bins[col], label_style=non_categorical_label_style)
+                else:
+                    cols[col], _ = bin_categorical(df[col], bins[col])
+            else:
+                _LOG.warning(f"'{col}' is missing in bins")
+                cols[col] = df[col]
         bins_dct = bins
     return pd.DataFrame(cols), bins_dct
 
 
 def bin_numeric(
-    col: pd.Series, bins: int | list[str], label_style: Literal["bounded", "unbounded"] = "unbounded"
+    col: pd.Series, bins: int | list[str], label_style: Literal["bounded", "unbounded", "lower"] = "unbounded"
 ) -> tuple[pd.Categorical, list]:
     def _clip(col, bins):
         if isinstance(bins, list):
@@ -1131,7 +1123,7 @@ def _adjust_breaks(breaks):
 
 
 def bin_datetime(
-    col: pd.Series, bins: int | list[str], label_style: Literal["bounded", "unbounded"] = "unbounded"
+    col: pd.Series, bins: int | list[str], label_style: Literal["bounded", "unbounded", "lower"] = "unbounded"
 ) -> tuple[pd.Categorical, list]:
     def _clip(col, bins):
         if isinstance(bins, list):
@@ -1184,7 +1176,7 @@ def bin_non_categorical(
     clip_and_breaks: Callable,
     create_labels: Callable,
     adjust_breaks: Callable,
-    label_style: Literal["bounded", "unbounded"] = "unbounded",
+    label_style: Literal["bounded", "unbounded", "lower"] = "unbounded",
 ) -> tuple[pd.Categorical, list]:
     col = col.fillna(np.nan).infer_objects(copy=False)
 
@@ -1203,7 +1195,9 @@
     )
     labels = [str(b) for b in breaks[:-1]]
 
-    if label_style == "unbounded":
+    if label_style == "lower":
+        new_labels_map = {label: f"{label}" for label in labels}
+    elif label_style == "unbounded":
         new_labels_map = {label: f"⪰ {label}" for label in labels}
     else:  # label_style == "bounded"
         new_labels_map = {label: f"⪰ {label} ≺ {next_label}" for label, next_label in zip(labels, labels[1:] + ["∞"])}
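Taken together, the reworked bin_data dispatches on column dtype when bins is an int, and on the type of each entry's first bin boundary when bins is a dict; the new "lower" label style keeps just the lower boundary as the bin label. A minimal sketch of the resulting behavior (the toy DataFrame below is illustrative, not from the PR):

    import pandas as pd
    from mostlyai.qa._accuracy import bin_data

    df = pd.DataFrame({
        "age": [23, 37, 58, 61],
        "joined": pd.to_datetime(["2020-01-05", "2021-06-30", "2022-03-14", "2023-11-01"]),
        "plan": ["basic", "pro", "pro", "basic"],
    })

    # bins as int: columns are dispatched by dtype (numeric / datetime / categorical)
    binned, bins = bin_data(df, bins=3)

    # bins as dict: dispatch runs on type(bins[col][0]), so one dict can mix
    # numeric, datetime, and categorical bins; columns missing from the dict
    # are passed through unbinned, with a warning
    rebinned, _ = bin_data(df, bins=bins, non_categorical_label_style="lower")
    # "lower" labels a bin by its lower boundary alone (e.g. "23"), versus
    # "⪰ 23" for "unbounded" or "⪰ 23 ≺ 37" for "bounded"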
31 changes: 15 additions & 16 deletions mostlyai/qa/_sampling.py
@@ -25,6 +25,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import datetime
 import logging
 import random
 import time
@@ -229,7 +230,7 @@ def pull_data_for_embeddings(
     ctx_primary_key: str | None = None,
     tgt_context_key: str | None = None,
     max_sample_size: int | None = None,
-    tgt_num_dat_bins: dict[str, list] | None = None,
+    bins: dict[str, list] | None = None,
 ) -> list[str]:
     _LOG.info("pulling data for embeddings")
     t0 = time.time()
@@ -264,12 +265,19 @@ def pull_data_for_embeddings(
         df_tgt = df_tgt.rename(columns={tgt_context_key: key})
         tgt_context_key = key
 
-    # bin numeric and datetime columns; partly also to prevent
-    # embedding distortion by adding extra precision to values
-    prefixes = string.ascii_lowercase + string.ascii_uppercase
-    tgt_num_dat_bins = tgt_num_dat_bins or {}
-    for i, col in enumerate(tgt_num_dat_bins.keys()):
-        df_tgt[col] = bin_num_dat(values=df_tgt[col], bins=tgt_num_dat_bins[col], prefix=prefixes[i % len(prefixes)])
+    # bin columns; also to prevent distortion of embeddings by adding extra precision or unknown values
+    bins = bins or {}
+    df_tgt.columns = [TGT_COLUMN_PREFIX + c if c != key else c for c in df_tgt.columns]
+    df_tgt, _ = bin_data(df_tgt, bins=bins, non_categorical_label_style="lower")
+    # add some prefix to make numeric and date values unique in the embedding space
+    for col in df_tgt.columns:
+        if col in bins:
+            if isinstance(
+                bins[col][0], (int, float, np.integer, np.floating, datetime.date, datetime.datetime, np.datetime64)
+            ):
+                prefixes = string.ascii_lowercase + string.ascii_uppercase
+                prefix = prefixes[xxhash.xxh32_intdigest(col) % len(prefixes)]
+                df_tgt[col] = prefix + df_tgt[col].astype(str)
 
     # split into chunks while keeping groups together and process in parallel
     n_jobs = min(16, max(1, cpu_count() - 1))
@@ -303,15 +311,6 @@ def sequence_to_string(sequence: pd.DataFrame) -> str:
     return strings
 
 
-def bin_num_dat(values: pd.Series, bins: list, prefix: str) -> pd.Series:
-    bins = sorted(set(bins))
-    binned = pd.cut(values, bins=bins, labels=bins[:-1], include_lowest=True).astype(str)
-    binned[values <= min(bins)] = str(bins[0])
-    binned[values >= max(bins)] = str(bins[-1])
-    binned[values.isna()] = "NA"
-    return prefix + binned
-
-
 def calculate_embeddings(
     strings: list[str],
     progress: ProgressCallbackWrapper | None = None,
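Note that the prefix letter is now derived from a hash of the column name instead of the column's position, so it is stable across runs and unaffected by column order or by which other columns are binned. The hashing step in isolation (a sketch, assuming the xxhash package that the new code calls into):

    import string
    import xxhash

    prefixes = string.ascii_lowercase + string.ascii_uppercase

    def column_prefix(col: str) -> str:
        # deterministic per column name: the same column always gets the same letter,
        # keeping identical values from different columns apart in the embedding space
        return prefixes[xxhash.xxh32_intdigest(col) % len(prefixes)]

    print(column_prefix("tgt::age"))   # same letter on every run
    print(column_prefix("tgt::plan"))  # typically a different letter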
13 changes: 3 additions & 10 deletions mostlyai/qa/reporting.py
@@ -12,7 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import datetime
 import logging
 import warnings
 from pathlib import Path
@@ -289,12 +288,6 @@ def report(
     embedder = load_embedder()
     _LOG.info("load tgt bins")
     bins = statistics.load_bins()
-    tgt_num_dat_bins = {
-        c.replace(TGT_COLUMN_PREFIX, ""): bins[c]
-        for c in bins.keys()
-        if c.replace(TGT_COLUMN_PREFIX, "") in trn_tgt_data.columns
-        and isinstance(bins[c][0], (int, float, datetime.date, datetime.datetime))
-    }
 
     _LOG.info("calculate embeddings for synthetic")
     syn_embeds = calculate_embeddings(
@@ -304,7 +297,7 @@
             ctx_primary_key=ctx_primary_key,
             tgt_context_key=tgt_context_key,
             max_sample_size=max_sample_size_embeddings_final,
-            tgt_num_dat_bins=tgt_num_dat_bins,
+            bins=bins,
         ),
         progress=progress,
         progress_from=25,
@@ -319,7 +312,7 @@
             ctx_primary_key=ctx_primary_key,
             tgt_context_key=tgt_context_key,
             max_sample_size=max_sample_size_embeddings_final,
-            tgt_num_dat_bins=tgt_num_dat_bins,
+            bins=bins,
         ),
         progress=progress,
         progress_from=45,
@@ -335,7 +328,7 @@
             ctx_primary_key=ctx_primary_key,
             tgt_context_key=tgt_context_key,
             max_sample_size=max_sample_size_embeddings_final,
-            tgt_num_dat_bins=tgt_num_dat_bins,
+            bins=bins,
         ),
         progress=progress,
         progress_from=65,
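The deleted pre-filtering is also why the datetime import goes away above: since bin_data now infers numeric vs. datetime bins from the bin boundaries themselves, report() can hand the full dict from statistics.load_bins() straight to pull_data_for_embeddings, prefixed keys and all.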
10 changes: 1 addition & 9 deletions mostlyai/qa/reporting_from_statistics.py
@@ -12,7 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import datetime
 import logging
 from pathlib import Path
 
@@ -28,7 +27,6 @@
 )
 from mostlyai.qa._sampling import pull_data_for_embeddings, calculate_embeddings, pull_data_for_coherence
 from mostlyai.qa._common import (
-    TGT_COLUMN_PREFIX,
     ProgressCallback,
     PrerequisiteNotMetError,
     check_min_sample_size,
@@ -169,12 +167,6 @@ def report_from_statistics(
     embedder = load_embedder()
     _LOG.info("load bins")
     bins = statistics.load_bins()
-    tgt_num_dat_bins = {
-        c.replace(TGT_COLUMN_PREFIX, ""): bins[c]
-        for c in bins.keys()
-        if c.replace(TGT_COLUMN_PREFIX, "") in syn_tgt_data.columns
-        and isinstance(bins[c][0], (int, float, datetime.date, datetime.datetime))
-    }
 
     _LOG.info("calculate embeddings for synthetic")
     syn_embeds = calculate_embeddings(
@@ -184,7 +176,7 @@ def report_from_statistics(
             ctx_primary_key=ctx_primary_key,
             tgt_context_key=tgt_context_key,
             max_sample_size=max_sample_size_embeddings,
-            tgt_num_dat_bins=tgt_num_dat_bins,
+            bins=bins,
         ),
         progress=progress,
         progress_from=40,
11 changes: 0 additions & 11 deletions tests/end_to_end/test_report.py
@@ -176,17 +176,6 @@ def test_report_flat_rare(tmp_path):
     assert metrics.accuracy.univariate == 0.0
     assert metrics.distances.ims_training == metrics.distances.ims_holdout == 0.0
 
-    # test case where rare values are not protected, and we leak trn into synthetic
-    syn_tgt_data = pd.DataFrame({"x": trn_tgt_data["x"].sample(100, replace=True)})
-    _, metrics = qa.report(
-        syn_tgt_data=syn_tgt_data,
-        trn_tgt_data=trn_tgt_data,
-        hol_tgt_data=hol_tgt_data,
-        statistics_path=statistics_path,
-    )
-    assert metrics.distances.ims_training > metrics.distances.ims_holdout
-    assert metrics.distances.dcr_training < metrics.distances.dcr_holdout
-
 
 def test_report_flat_early_exit(tmp_path):
     # test early exit for dfs with <100 rows
11 changes: 5 additions & 6 deletions tests/unit/test_accuracy.py
@@ -615,12 +615,11 @@ def test_trim_labels():
 
 
 def test_calculate_correlations(cols):
-    trn, hol, syn = cols
-    trn, bins = bin_data(trn, 3)
-    syn, _ = bin_data(syn, bins)
-    # prefix some columns with "tgt::"
-    columns = [f"tgt::{c}" if c != "cat" else c for idx, c in enumerate(trn.columns)]
-    trn.columns, syn.columns = columns, columns
+    trn, _, syn = cols
+    trn, bins = bin_data(trn[["num", "dt"]], 3)
+    syn, _ = bin_data(syn[["num", "dt"]], bins)
+    trn = trn.add_prefix("tgt::")
+    syn = syn.add_prefix("tgt::")
     corr_trn = calculate_correlations(trn)
     exp_corr_trn = pd.DataFrame(
         [
4 changes: 2 additions & 2 deletions tests/unit/test_sampling.py
@@ -33,7 +33,7 @@ def test_pull_data_for_embeddings_large_int(tmp_path):
         {"cc": list(np.random.randint(100, 200, size=1000)) + [1800218404984585216] + [pd.NA]}, dtype="Int64"
     )
     bins = {"cc": [100, 200]}
-    pull_data_for_embeddings(df_tgt=df, tgt_num_dat_bins=bins)
+    pull_data_for_embeddings(df_tgt=df, bins=bins)
 
 
 def test_pull_data_for_embeddings_dates(tmp_path):
@@ -48,4 +48,4 @@
         "y": [datetime(2020, 2, 1), datetime(2024, 1, 1)],
         "z": [datetime(2020, 2, 1), datetime(2024, 1, 1)],
     }
-    pull_data_for_embeddings(df_tgt=df, tgt_num_dat_bins=bins)
+    pull_data_for_embeddings(df_tgt=df, bins=bins)
pull_data_for_embeddings(df_tgt=df, bins=bins)