diff --git a/mostlyai/qa/_sampling.py b/mostlyai/qa/_sampling.py index 4665a5a..ccd7389 100644 --- a/mostlyai/qa/_sampling.py +++ b/mostlyai/qa/_sampling.py @@ -110,15 +110,16 @@ def pull_data_for_accuracy( df = pd.merge(df, df_tgt, on=key, how="left") df = pd.merge(df, df_nxt, on=key, how="left") df = df.drop(columns=[key]) - - # remove records with sequence length equal to 0 count_column = f"{TGT_COLUMN_PREFIX}{COUNT_COLUMN}" df[count_column] = df[count_column].fillna(0).astype("Int64") - df = df.loc[df[count_column] > 0].reset_index(drop=True) + # determine setup if not provided if setup is None: setup = "1:1" if (df[count_column] == 1).all() else "1:N" + # remove records with sequence length equal to 0 + df = df.loc[df[count_column] > 0].reset_index(drop=True) + # for 1:1 ctx/tgt setups, drop nxt and count columns; ensure at least one column remains if setup == "1:1": df = df.drop(columns=[c for c in df.columns if c.startswith(NXT_COLUMN_PREFIX)]) diff --git a/mostlyai/qa/reporting_from_statistics.py b/mostlyai/qa/reporting_from_statistics.py index 6f61911..b920499 100644 --- a/mostlyai/qa/reporting_from_statistics.py +++ b/mostlyai/qa/reporting_from_statistics.py @@ -113,6 +113,9 @@ def report_from_statistics( ctx_primary_key=ctx_primary_key, tgt_context_key=tgt_context_key, max_sample_size=max_sample_size_accuracy, + # always pull Sequence Length and nxt columns for synthetic data + # and let downstream functions decide if they are needed + setup="1:N", ) _LOG.info(f"sample synthetic data finished ({syn.shape=})") progress.update(completed=20, total=100)