diff --git a/poetry.lock b/poetry.lock
index feafa92..b5524ab 100644
--- a/poetry.lock
+++ b/poetry.lock
@@ -1,4 +1,4 @@
-# This file is automatically @generated by Poetry 1.8.3 and should not be changed by hand.
+# This file is automatically @generated by Poetry 1.8.4 and should not be changed by hand.
 
 [[package]]
 name = "annotated-types"
@@ -3035,4 +3035,4 @@ type = ["pytest-mypy"]
 [metadata]
 lock-version = "2.0"
 python-versions = "^3.10"
-content-hash = "84b6afdb4c4101d2a0fbe8aeddabb53e03e9aa5886dd1e8bf95181c24c17c349"
+content-hash = "12fab0b7c571095916f88504bb577687deb6c8d23b82ac4530ca2e82c1170980"
diff --git a/pyproject.toml b/pyproject.toml
index 6e8c58b..cbbaca4 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -23,6 +23,7 @@ joblib = ">=1.2.0"
 Jinja2 = ">=3.1.2"
 scikit-learn = ">=1.4.0"
 sentence-transformers = ">=3.1.0"
+rich = "^13.9.4"
 
 [tool.poetry.group.dev.dependencies]
 ruff = "0.7.0"
diff --git a/src/mostlyai/qa/common.py b/src/mostlyai/qa/common.py
index f6ad52e..ae2ac8c 100644
--- a/src/mostlyai/qa/common.py
+++ b/src/mostlyai/qa/common.py
@@ -13,10 +13,11 @@
 # limitations under the License.
 
 import logging
-from typing import Protocol
+from functools import partial
+from typing import Protocol, Callable
 
 import pandas as pd
-from tqdm.auto import tqdm
+from rich.progress import Progress
 
 from mostlyai.qa.filesystem import Statistics
 
@@ -73,18 +74,43 @@ class PrerequisiteNotMetError(Exception):
 
 
 class ProgressCallback(Protocol):
-    def __call__(self, current: int, total: int) -> None: ...
-
-
-def add_tqdm(on_progress: ProgressCallback | None = None, description: str = "Processing") -> ProgressCallback:
-    pbar = tqdm(desc=description, total=100)
-
-    def _on_progress(current: int, total: int):
-        if on_progress is not None:
-            on_progress(current, total)
-        pbar.update(current - pbar.n)
-
-    return _on_progress
+    def __call__(self, total: float | None = None, completed: float | None = None, **kwargs) -> None: ...
+
+
+class ProgressCallbackWrapper:
+    @staticmethod
+    def _wrap_progress_callback(
+        update_progress: ProgressCallback | None = None, **kwargs
+    ) -> tuple[ProgressCallback, Callable]:
+        if not update_progress:
+            rich_progress = Progress()
+            rich_progress.start()
+            task_id = rich_progress.add_task(**kwargs)
+            update_progress = partial(rich_progress.update, task_id=task_id)
+        else:
+            rich_progress = None
+
+        def teardown_progress():
+            if rich_progress:
+                rich_progress.refresh()
+                rich_progress.stop()
+
+        return update_progress, teardown_progress
+
+    def update(self, total: float | None = None, completed: float | None = None, **kwargs) -> None:
+        self._update_progress(total=total, completed=completed, **kwargs)
+
+    def __init__(self, update_progress: ProgressCallback | None = None, **kwargs):
+        self._update_progress, self._teardown_progress = self._wrap_progress_callback(update_progress, **kwargs)
+
+    def __enter__(self):
+        self._update_progress(completed=0, total=1)
+        return self
+
+    def __exit__(self, exc_type, exc_value, traceback):
+        if exc_type is None:
+            self._update_progress(completed=1, total=1)
+        self._teardown_progress()
 
 
 def check_min_sample_size(size: int, min: int, type: str) -> None:
diff --git a/src/mostlyai/qa/report.py b/src/mostlyai/qa/report.py
index ca77cc7..b84e328 100644
--- a/src/mostlyai/qa/report.py
+++ b/src/mostlyai/qa/report.py
@@ -42,11 +42,11 @@
     ProgressCallback,
     PrerequisiteNotMetError,
     check_min_sample_size,
-    add_tqdm,
     NXT_COLUMN,
     CTX_COLUMN_PREFIX,
     TGT_COLUMN_PREFIX,
     REPORT_CREDITS,
+    ProgressCallbackWrapper,
 )
 from mostlyai.qa.filesystem import Statistics, TemporaryWorkspace
 
@@ -71,7 +71,7 @@ def report(
     max_sample_size_accuracy: int | None = None,
     max_sample_size_embeddings: int | None = None,
     statistics_path: str | Path | None = None,
-    on_progress: ProgressCallback | None = None,
+    update_progress: ProgressCallback | None = None,
 ) -> tuple[Path, Metrics | None]:
     """
     Generate HTML report and metrics for comparing synthetic and original data samples.
@@ -93,7 +93,7 @@ def report(
         max_sample_size_accuracy: Max sample size for accuracy
         max_sample_size_embeddings: Max sample size for embeddings (similarity & distances)
         statistics_path: Path of where to store the statistics to be used by `report_from_statistics`
-        on_progress: A custom progress callback
+        update_progress: A custom progress callback
     Returns:
         1. Path to the HTML report
         2. Pydantic Metrics:
@@ -119,10 +119,10 @@ def report(
 
            - `dcr_share`: Share of synthetic samples that are closer to a training sample than to a holdout sample. This shall not be significantly larger than 50\%.
     """
-    with TemporaryWorkspace() as workspace:
-        on_progress = add_tqdm(on_progress, description="Creating report")
-        on_progress(current=0, total=100)
-
+    with (
+        TemporaryWorkspace() as workspace,
+        ProgressCallbackWrapper(update_progress, description="Create report 🚀") as progress,
+    ):
         # ensure all columns are present and in the same order as training data
         syn_tgt_data = syn_tgt_data[trn_tgt_data.columns]
         if hol_tgt_data is not None:
@@ -165,7 +165,6 @@ def report(
             _LOG.info(err)
             statistics.mark_early_exit()
             html_report.store_early_exit_report(report_path)
-            on_progress(current=100, total=100)
             return report_path, None
 
         # prepare datasets for accuracy
@@ -194,7 +193,7 @@ def report(
             max_sample_size=max_sample_size_accuracy,
             setup=setup,
         )
-        on_progress(current=5, total=100)
+        progress.update(completed=5, total=100)
 
         _LOG.info("prepare training data for accuracy started")
         trn = pull_data_for_accuracy(
@@ -205,7 +204,7 @@ def report(
             max_sample_size=max_sample_size_accuracy,
             setup=setup,
         )
-        on_progress(current=10, total=100)
+        progress.update(completed=10, total=100)
 
         # coerce dtypes to match the original training data dtypes
         for col in trn:
@@ -222,7 +221,7 @@ def report(
             statistics=statistics,
             workspace=workspace,
         )
-        on_progress(current=20, total=100)
+        progress.update(completed=20, total=100)
 
         # ensure that embeddings are all equal size for a fair 3-way comparison
         max_sample_size_embeddings = min(
@@ -232,7 +231,9 @@ def report(
             hol_sample_size or float("inf"),
         )
 
-        def _calc_pull_embeds(df_tgt: pd.DataFrame, df_ctx: pd.DataFrame, start: int, stop: int) -> np.ndarray:
+        def _calc_pull_embeds(
+            df_tgt: pd.DataFrame, df_ctx: pd.DataFrame, progress_from: int, progress_to: int
+        ) -> np.ndarray:
             strings = pull_data_for_embeddings(
                 df_tgt=df_tgt,
                 df_ctx=df_ctx,
@@ -241,24 +242,24 @@ def _calc_pull_embeds(df_tgt: pd.DataFrame, df_ctx: pd.DataFrame, start: int, st
                 max_sample_size=max_sample_size_embeddings,
             )
             # split into buckets for calculating embeddings to avoid memory issues and report continuous progress
-            buckets = np.array_split(strings, stop - start)
+            buckets = np.array_split(strings, progress_to - progress_from)
             buckets = [b for b in buckets if len(b) > 0]
             embeds = []
             for i, bucket in enumerate(buckets, 1):
                 embeds += [calculate_embeddings(bucket.tolist())]
-                on_progress(current=start + i, total=100)
-            on_progress(current=stop, total=100)
+                progress.update(completed=progress_from + i, total=100)
+            progress.update(completed=progress_to, total=100)
             embeds = np.concatenate(embeds, axis=0)
             _LOG.info(f"calculated embeddings {embeds.shape}")
             return embeds
 
-        syn_embeds = _calc_pull_embeds(df_tgt=syn_tgt_data, df_ctx=syn_ctx_data, start=20, stop=40)
-        trn_embeds = _calc_pull_embeds(df_tgt=trn_tgt_data, df_ctx=trn_ctx_data, start=40, stop=60)
+        syn_embeds = _calc_pull_embeds(df_tgt=syn_tgt_data, df_ctx=syn_ctx_data, progress_from=20, progress_to=40)
+        trn_embeds = _calc_pull_embeds(df_tgt=trn_tgt_data, df_ctx=trn_ctx_data, progress_from=40, progress_to=60)
         if hol_tgt_data is not None:
-            hol_embeds = _calc_pull_embeds(df_tgt=hol_tgt_data, df_ctx=hol_ctx_data, start=60, stop=80)
+            hol_embeds = _calc_pull_embeds(df_tgt=hol_tgt_data, df_ctx=hol_ctx_data, progress_from=60, progress_to=80)
         else:
             hol_embeds = None
-        on_progress(current=80, total=100)
+        progress.update(completed=80, total=100)
 
         _LOG.info("report similarity")
         sim_cosine_trn_hol, sim_cosine_trn_syn, sim_auc_trn_hol, sim_auc_trn_syn = report_similarity(
@@ -268,7 +269,7 @@ def _calc_pull_embeds(df_tgt: pd.DataFrame, df_ctx: pd.DataFrame, start: int, st
             workspace=workspace,
             statistics=statistics,
         )
-        on_progress(current=90, total=100)
+        progress.update(completed=90, total=100)
 
         _LOG.info("report distances")
         dcr_trn, dcr_hol = report_distances(
@@ -277,7 +278,7 @@ def _calc_pull_embeds(df_tgt: pd.DataFrame, df_ctx: pd.DataFrame, start: int, st
             hol_embeds=hol_embeds,
             workspace=workspace,
         )
-        on_progress(current=99, total=100)
+        progress.update(completed=99, total=100)
 
         metrics = calculate_metrics(
             acc_uni=acc_uni,
@@ -314,7 +315,7 @@ def _calc_pull_embeds(df_tgt: pd.DataFrame, df_ctx: pd.DataFrame, start: int, st
             acc_biv=acc_biv,
             corr_trn=corr_trn,
         )
-        on_progress(current=100, total=100)
+        progress.update(completed=100, total=100)
 
         return report_path, metrics
 
diff --git a/src/mostlyai/qa/report_from_statistics.py b/src/mostlyai/qa/report_from_statistics.py
index e8815ae..82e609e 100644
--- a/src/mostlyai/qa/report_from_statistics.py
+++ b/src/mostlyai/qa/report_from_statistics.py
@@ -26,10 +26,10 @@
     ProgressCallback,
     PrerequisiteNotMetError,
     check_min_sample_size,
-    add_tqdm,
     check_statistics_prerequisite,
     determine_data_size,
     REPORT_CREDITS,
+    ProgressCallbackWrapper,
 )
 from mostlyai.qa.filesystem import Statistics, TemporaryWorkspace
 
@@ -50,12 +50,12 @@ def report_from_statistics(
     report_extra_info: str = "",
     max_sample_size_accuracy: int | None = None,
     max_sample_size_embeddings: int | None = None,
-    on_progress: ProgressCallback | None = None,
+    update_progress: ProgressCallback | None = None,
 ) -> Path:
-    with TemporaryWorkspace() as workspace:
-        on_progress = add_tqdm(on_progress, description="Creating report from statistics")
-        on_progress(current=0, total=100)
-
+    with (
+        TemporaryWorkspace() as workspace,
+        ProgressCallbackWrapper(update_progress, description="Create report 🚀") as progress,
+    ):
         # prepare report_path
         if report_path is None:
             report_path = Path.cwd() / "data-report.html"
@@ -73,7 +73,6 @@ def report_from_statistics(
             check_min_sample_size(syn_sample_size, 100, "synthetic")
         except PrerequisiteNotMetError:
             html_report.store_early_exit_report(report_path)
-            on_progress(current=100, total=100)
             return report_path
 
         meta = statistics.load_meta()
@@ -96,7 +95,7 @@ def report_from_statistics(
             max_sample_size=max_sample_size_accuracy,
         )
         _LOG.info(f"sample synthetic data finished ({syn.shape=})")
-        on_progress(current=20, total=100)
+        progress.update(completed=20, total=100)
 
         # calculate and plot accuracy and correlations
         acc_uni, acc_biv, corr_trn = report_accuracy_and_correlations_from_statistics(
@@ -104,7 +103,7 @@ def report_from_statistics(
             statistics=statistics,
             workspace=workspace,
         )
-        on_progress(current=30, total=100)
+        progress.update(completed=30, total=100)
 
         _LOG.info("calculate embeddings for synthetic")
         syn_embeds = calculate_embeddings(
@@ -123,7 +122,7 @@ def report_from_statistics(
             workspace=workspace,
             statistics=statistics,
         )
-        on_progress(current=50, total=100)
+        progress.update(completed=50, total=100)
 
         meta |= {
             "rows_synthetic": syn.shape[0],
@@ -144,7 +143,7 @@ def report_from_statistics(
             acc_biv=acc_biv,
             corr_trn=corr_trn,
        )
-        on_progress(current=100, total=100)
+        progress.update(completed=100, total=100)
 
         return report_path
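
For reference, a minimal usage sketch of the renamed callback hook (not part of this change set): the `report()` parameters and the keyword-based `ProgressCallback` protocol are taken from the diff above, while the file names and dataframes are hypothetical. Omitting `update_progress` falls back to the rich progress bar that `ProgressCallbackWrapper` now creates internally.

    # Illustrative sketch only; input files are made up.
    import pandas as pd

    from mostlyai.qa.report import report


    def update_progress(total: float | None = None, completed: float | None = None, **kwargs) -> None:
        # Conforms to the new ProgressCallback protocol: keyword updates
        # (completed/total) instead of the old positional on_progress(current, total).
        if total is not None and completed is not None:
            print(f"QA report progress: {completed}/{total}")


    syn = pd.read_csv("synthetic-samples.csv")  # hypothetical input data
    trn = pd.read_csv("training-samples.csv")

    report_path, metrics = report(
        syn_tgt_data=syn,
        trn_tgt_data=trn,
        update_progress=update_progress,  # omit to get the default rich progress bar
    )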