Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ joblib = ">=1.2.0"
Jinja2 = ">=3.1.2"
scikit-learn = ">=1.4.0"
sentence-transformers = ">=3.1.0"
rich = "^13.9.4"

[tool.poetry.group.dev.dependencies]
ruff = "0.7.0"
Expand Down
54 changes: 40 additions & 14 deletions src/mostlyai/qa/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,11 @@
# limitations under the License.

import logging
from typing import Protocol
from functools import partial
from typing import Protocol, Callable

import pandas as pd
from tqdm.auto import tqdm
from rich.progress import Progress

from mostlyai.qa.filesystem import Statistics

Expand Down Expand Up @@ -73,18 +74,43 @@ class PrerequisiteNotMetError(Exception):


class ProgressCallback(Protocol):
    """Structural type for a progress-reporting callable.

    Any callable that accepts the current step and the total number of steps
    (both integers) and returns nothing satisfies this protocol.
    """

    def __call__(self, current: int, total: int) -> None:
        """Report that `current` out of `total` steps have been completed."""
        ...


def add_tqdm(on_progress: ProgressCallback | None = None, description: str = "Processing") -> ProgressCallback:
    """Wrap a progress callback so that every update also drives a tqdm bar.

    Args:
        on_progress: Optional user callback; it is invoked first, with the same
            `(current, total)` values, on every update.
        description: Label displayed next to the tqdm bar.

    Returns:
        A callback taking `(current, total)` that forwards to `on_progress`
        (when given) and then moves the tqdm bar to `current`.
    """
    pbar = tqdm(desc=description, total=100)

    def _on_progress(current: int, total: int):
        if on_progress is not None:
            on_progress(current, total)
        # keep the bar's scale in sync with what the caller reports, instead of
        # silently assuming that the total is always 100
        if total != pbar.total:
            pbar.total = total
            pbar.refresh()
        # tqdm's update() is incremental, so advance by the delta to the absolute position
        pbar.update(current - pbar.n)
        # release the bar once the work is complete; close() is idempotent and
        # turns any later update() into a no-op
        if current >= total:
            pbar.close()

    return _on_progress
def __call__(self, total: float | None = None, completed: float | None = None, **kwargs) -> None: ...


class ProgressCallbackWrapper:
    """Context manager that routes progress updates to a callback or to a rich progress bar.

    If a custom `update_progress` callback is supplied, all updates are forwarded
    to it and the extra keyword arguments are ignored. Otherwise a
    `rich.progress.Progress` display is started, a single task is added using the
    extra keyword arguments (e.g. `description=...`), and updates are applied to
    that task.

    On entering the context, progress is reset to `completed=0, total=1`. On a
    clean exit it is set to `completed=1, total=1`. The rich display (if one was
    created) is always stopped on exit, including when the body or the final
    update raises.
    """

    def __init__(self, update_progress: ProgressCallback | None = None, **kwargs):
        self._update_progress, self._teardown_progress = self._wrap_progress_callback(update_progress, **kwargs)

    @staticmethod
    def _wrap_progress_callback(
        update_progress: ProgressCallback | None = None, **kwargs
    ) -> tuple[ProgressCallback, Callable[[], None]]:
        """Return the effective update callback together with its teardown function.

        Args:
            update_progress: Custom callback; when None, a rich progress bar is created.
            **kwargs: Passed to `Progress.add_task` when the rich bar is created.

        Returns:
            A tuple `(update_progress, teardown_progress)`; the teardown is a no-op
            when a custom callback was supplied.
        """
        rich_progress = None
        # compare against None explicitly, so that a callable object that happens
        # to be falsy is still honored as a custom callback
        if update_progress is None:
            rich_progress = Progress()
            rich_progress.start()
            try:
                task_id = rich_progress.add_task(**kwargs)
            except Exception:
                # do not leave a live display running if the task cannot be created
                rich_progress.stop()
                raise
            update_progress = partial(rich_progress.update, task_id=task_id)

        def teardown_progress() -> None:
            if rich_progress is not None:
                # render the final state once more before shutting the display down
                rich_progress.refresh()
                rich_progress.stop()

        return update_progress, teardown_progress

    def update(self, total: float | None = None, completed: float | None = None, **kwargs) -> None:
        """Forward a progress update to the underlying callback."""
        self._update_progress(total=total, completed=completed, **kwargs)

    def __enter__(self):
        self._update_progress(completed=0, total=1)
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        try:
            # only report completion if the body finished without an exception
            if exc_type is None:
                self._update_progress(completed=1, total=1)
        finally:
            # always release the display, even if the final update itself raised
            self._teardown_progress()


def check_min_sample_size(size: int, min: int, type: str) -> None:
Expand Down
45 changes: 23 additions & 22 deletions src/mostlyai/qa/report.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,11 +42,11 @@
ProgressCallback,
PrerequisiteNotMetError,
check_min_sample_size,
add_tqdm,
NXT_COLUMN,
CTX_COLUMN_PREFIX,
TGT_COLUMN_PREFIX,
REPORT_CREDITS,
ProgressCallbackWrapper,
)
from mostlyai.qa.filesystem import Statistics, TemporaryWorkspace

Expand All @@ -71,7 +71,7 @@ def report(
max_sample_size_accuracy: int | None = None,
max_sample_size_embeddings: int | None = None,
statistics_path: str | Path | None = None,
on_progress: ProgressCallback | None = None,
update_progress: ProgressCallback | None = None,
) -> tuple[Path, Metrics | None]:
"""
Generate HTML report and metrics for comparing synthetic and original data samples.
Expand All @@ -93,7 +93,7 @@ def report(
max_sample_size_accuracy: Max sample size for accuracy
max_sample_size_embeddings: Max sample size for embeddings (similarity & distances)
statistics_path: Path of where to store the statistics to be used by `report_from_statistics`
on_progress: A custom progress callback
update_progress: A custom progress callback
Returns:
1. Path to the HTML report
2. Pydantic Metrics:
Expand All @@ -119,10 +119,10 @@ def report(
- `dcr_share`: Share of synthetic samples that are closer to a training sample than to a holdout sample. This shall not be significantly larger than 50\%.
"""

with TemporaryWorkspace() as workspace:
on_progress = add_tqdm(on_progress, description="Creating report")
on_progress(current=0, total=100)

with (
TemporaryWorkspace() as workspace,
ProgressCallbackWrapper(update_progress, description="Create report 🚀") as progress,
):
# ensure all columns are present and in the same order as training data
syn_tgt_data = syn_tgt_data[trn_tgt_data.columns]
if hol_tgt_data is not None:
Expand Down Expand Up @@ -165,7 +165,6 @@ def report(
_LOG.info(err)
statistics.mark_early_exit()
html_report.store_early_exit_report(report_path)
on_progress(current=100, total=100)
return report_path, None

# prepare datasets for accuracy
Expand Down Expand Up @@ -194,7 +193,7 @@ def report(
max_sample_size=max_sample_size_accuracy,
setup=setup,
)
on_progress(current=5, total=100)
progress.update(completed=5, total=100)

_LOG.info("prepare training data for accuracy started")
trn = pull_data_for_accuracy(
Expand All @@ -205,7 +204,7 @@ def report(
max_sample_size=max_sample_size_accuracy,
setup=setup,
)
on_progress(current=10, total=100)
progress.update(completed=10, total=100)

# coerce dtypes to match the original training data dtypes
for col in trn:
Expand All @@ -222,7 +221,7 @@ def report(
statistics=statistics,
workspace=workspace,
)
on_progress(current=20, total=100)
progress.update(completed=20, total=100)

# ensure that embeddings are all equal size for a fair 3-way comparison
max_sample_size_embeddings = min(
Expand All @@ -232,7 +231,9 @@ def report(
hol_sample_size or float("inf"),
)

def _calc_pull_embeds(df_tgt: pd.DataFrame, df_ctx: pd.DataFrame, start: int, stop: int) -> np.ndarray:
def _calc_pull_embeds(
df_tgt: pd.DataFrame, df_ctx: pd.DataFrame, progress_from: int, progress_to: int
) -> np.ndarray:
strings = pull_data_for_embeddings(
df_tgt=df_tgt,
df_ctx=df_ctx,
Expand All @@ -241,24 +242,24 @@ def _calc_pull_embeds(df_tgt: pd.DataFrame, df_ctx: pd.DataFrame, start: int, st
max_sample_size=max_sample_size_embeddings,
)
# split into buckets for calculating embeddings to avoid memory issues and report continuous progress
buckets = np.array_split(strings, stop - start)
buckets = np.array_split(strings, progress_to - progress_from)
buckets = [b for b in buckets if len(b) > 0]
embeds = []
for i, bucket in enumerate(buckets, 1):
embeds += [calculate_embeddings(bucket.tolist())]
on_progress(current=start + i, total=100)
on_progress(current=stop, total=100)
progress.update(completed=progress_from + i, total=100)
progress.update(completed=progress_to, total=100)
embeds = np.concatenate(embeds, axis=0)
_LOG.info(f"calculated embeddings {embeds.shape}")
return embeds

syn_embeds = _calc_pull_embeds(df_tgt=syn_tgt_data, df_ctx=syn_ctx_data, start=20, stop=40)
trn_embeds = _calc_pull_embeds(df_tgt=trn_tgt_data, df_ctx=trn_ctx_data, start=40, stop=60)
syn_embeds = _calc_pull_embeds(df_tgt=syn_tgt_data, df_ctx=syn_ctx_data, progress_from=20, progress_to=40)
trn_embeds = _calc_pull_embeds(df_tgt=trn_tgt_data, df_ctx=trn_ctx_data, progress_from=40, progress_to=60)
if hol_tgt_data is not None:
hol_embeds = _calc_pull_embeds(df_tgt=hol_tgt_data, df_ctx=hol_ctx_data, start=60, stop=80)
hol_embeds = _calc_pull_embeds(df_tgt=hol_tgt_data, df_ctx=hol_ctx_data, progress_from=60, progress_to=80)
else:
hol_embeds = None
on_progress(current=80, total=100)
progress.update(completed=80, total=100)

_LOG.info("report similarity")
sim_cosine_trn_hol, sim_cosine_trn_syn, sim_auc_trn_hol, sim_auc_trn_syn = report_similarity(
Expand All @@ -268,7 +269,7 @@ def _calc_pull_embeds(df_tgt: pd.DataFrame, df_ctx: pd.DataFrame, start: int, st
workspace=workspace,
statistics=statistics,
)
on_progress(current=90, total=100)
progress.update(completed=90, total=100)

_LOG.info("report distances")
dcr_trn, dcr_hol = report_distances(
Expand All @@ -277,7 +278,7 @@ def _calc_pull_embeds(df_tgt: pd.DataFrame, df_ctx: pd.DataFrame, start: int, st
hol_embeds=hol_embeds,
workspace=workspace,
)
on_progress(current=99, total=100)
progress.update(completed=99, total=100)

metrics = calculate_metrics(
acc_uni=acc_uni,
Expand Down Expand Up @@ -314,7 +315,7 @@ def _calc_pull_embeds(df_tgt: pd.DataFrame, df_ctx: pd.DataFrame, start: int, st
acc_biv=acc_biv,
corr_trn=corr_trn,
)
on_progress(current=100, total=100)
progress.update(completed=100, total=100)
return report_path, metrics


Expand Down
21 changes: 10 additions & 11 deletions src/mostlyai/qa/report_from_statistics.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,10 @@
ProgressCallback,
PrerequisiteNotMetError,
check_min_sample_size,
add_tqdm,
check_statistics_prerequisite,
determine_data_size,
REPORT_CREDITS,
ProgressCallbackWrapper,
)
from mostlyai.qa.filesystem import Statistics, TemporaryWorkspace

Expand All @@ -50,12 +50,12 @@ def report_from_statistics(
report_extra_info: str = "",
max_sample_size_accuracy: int | None = None,
max_sample_size_embeddings: int | None = None,
on_progress: ProgressCallback | None = None,
update_progress: ProgressCallback | None = None,
) -> Path:
with TemporaryWorkspace() as workspace:
on_progress = add_tqdm(on_progress, description="Creating report from statistics")
on_progress(current=0, total=100)

with (
TemporaryWorkspace() as workspace,
ProgressCallbackWrapper(update_progress, description="Create report 🚀") as progress,
):
# prepare report_path
if report_path is None:
report_path = Path.cwd() / "data-report.html"
Expand All @@ -73,7 +73,6 @@ def report_from_statistics(
check_min_sample_size(syn_sample_size, 100, "synthetic")
except PrerequisiteNotMetError:
html_report.store_early_exit_report(report_path)
on_progress(current=100, total=100)
return report_path

meta = statistics.load_meta()
Expand All @@ -96,15 +95,15 @@ def report_from_statistics(
max_sample_size=max_sample_size_accuracy,
)
_LOG.info(f"sample synthetic data finished ({syn.shape=})")
on_progress(current=20, total=100)
progress.update(completed=20, total=100)

# calculate and plot accuracy and correlations
acc_uni, acc_biv, corr_trn = report_accuracy_and_correlations_from_statistics(
syn=syn,
statistics=statistics,
workspace=workspace,
)
on_progress(current=30, total=100)
progress.update(completed=30, total=100)

_LOG.info("calculate embeddings for synthetic")
syn_embeds = calculate_embeddings(
Expand All @@ -123,7 +122,7 @@ def report_from_statistics(
workspace=workspace,
statistics=statistics,
)
on_progress(current=50, total=100)
progress.update(completed=50, total=100)

meta |= {
"rows_synthetic": syn.shape[0],
Expand All @@ -144,7 +143,7 @@ def report_from_statistics(
acc_biv=acc_biv,
corr_trn=corr_trn,
)
on_progress(current=100, total=100)
progress.update(completed=100, total=100)
return report_path


Expand Down