From 71753f4fb0d6a8452c2f8c9d6fb51e4244d1cbe3 Mon Sep 17 00:00:00 2001 From: Michael Platzer Date: Mon, 12 May 2025 18:47:23 -0400 Subject: [PATCH 1/4] chore: limit samples for AUC; drop FAISS again --- mostlyai/qa/_distances.py | 32 ++++++-------------- mostlyai/qa/_similarity.py | 4 +++ pyproject.toml | 1 - uv.lock | 60 +++++++++----------------------------- 4 files changed, 26 insertions(+), 71 deletions(-) diff --git a/mostlyai/qa/_distances.py b/mostlyai/qa/_distances.py index b75e0db..eb03236 100644 --- a/mostlyai/qa/_distances.py +++ b/mostlyai/qa/_distances.py @@ -13,11 +13,12 @@ # limitations under the License. import logging -import platform import time import numpy as np import networkx as nx import xxhash +from sklearn.neighbors import NearestNeighbors +from joblib import cpu_count from mostlyai.qa._common import ( CHARTS_COLORS, @@ -42,22 +43,9 @@ def calculate_dcrs_nndrs( t0 = time.time() data = data[data[:, 0].argsort()] # sort data by first dimension to enforce deterministic results - if platform.system() == "Linux": - # use FAISS on Linux for best performance - import faiss # type: ignore - - index = faiss.IndexFlatL2(data.shape[1]) - index.add(data) - dcrs, _ = index.search(query, 2) - dcrs = np.sqrt(dcrs) # FAISS returns squared distances - else: - # use sklearn as a fallback on non-Linux systems to avoid segfaults; these occurred when using QA as part of SDK - from sklearn.neighbors import NearestNeighbors # type: ignore - from joblib import cpu_count # type: ignore - - index = NearestNeighbors(n_neighbors=2, algorithm="auto", metric="l2", n_jobs=min(16, max(1, cpu_count() - 1))) - index.fit(data) - dcrs, _ = index.kneighbors(query) + index = NearestNeighbors(n_neighbors=2, algorithm="auto", metric="l2", n_jobs=min(16, max(1, cpu_count() - 1))) + index.fit(data) + dcrs, _ = index.kneighbors(query) dcr = dcrs[:, 0] nndr = (dcrs[:, 0] + 1e-8) / (dcrs[:, 1] + 1e-8) _LOG.info(f"calculated DCRs for {data.shape=} and {query.shape=} in {time.time() - t0:.2f}s") @@ -85,14 +73,12 @@ def calculate_distances( groups = [] # check all columns together groups += [np.arange(ori_embeds.shape[1])] - # check subsets of correlated columns together + # check 3 correlated subsets of columns if ori_embeds.shape[1] > 10: - k = max(3, ori_embeds.shape[1] // 10) - groups += split_columns_into_correlated_groups(ori_embeds, k=k) - # check random subsets of columns + groups += split_columns_into_correlated_groups(ori_embeds, k=3) + # check 3 random subsets of columns if ori_embeds.shape[1] > 10: - k = max(3, ori_embeds.shape[1] // 10) - groups += split_columns_into_random_groups(ori_embeds, k=k) + groups += split_columns_into_random_groups(ori_embeds, k=3) dcr_share = 0.0 nndr_ratio = 1.0 for columns in groups: diff --git a/mostlyai/qa/_similarity.py b/mostlyai/qa/_similarity.py index b124289..812427d 100644 --- a/mostlyai/qa/_similarity.py +++ b/mostlyai/qa/_similarity.py @@ -69,6 +69,10 @@ def calculate_mean_auc(embeds1, embeds2): for a ML model to discriminate between two embedding arrays. """ + # limit the number of samples to 10000 + embeds1 = embeds1[:10000] + embeds2 = embeds2[:10000] + # create labels for the data labels1 = np.zeros(embeds1.shape[0]) labels2 = np.ones(embeds2.shape[0]) diff --git a/pyproject.toml b/pyproject.toml index 91ecc30..c5b288e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -39,7 +39,6 @@ dependencies = [ "accelerate>=1.5.0", "torch>=2.6.0", "xxhash>=3.5.0", - "faiss-cpu>=1.7.0", ] [project.urls] diff --git a/uv.lock b/uv.lock index 0a2338d..3fe44c4 100644 --- a/uv.lock +++ b/uv.lock @@ -472,38 +472,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/7b/8f/c4d9bafc34ad7ad5d8dc16dd1347ee0e507a52c3adb6bfa8887e1c6a26ba/executing-2.2.0-py2.py3-none-any.whl", hash = "sha256:11387150cad388d62750327a53d3339fad4888b39a6fe233c3afbb54ecffd3aa", size = 26702 }, ] -[[package]] -name = "faiss-cpu" -version = "1.11.0" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "numpy" }, - { name = "packaging" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/e7/9a/e33fc563f007924dd4ec3c5101fe5320298d6c13c158a24a9ed849058569/faiss_cpu-1.11.0.tar.gz", hash = "sha256:44877b896a2b30a61e35ea4970d008e8822545cb340eca4eff223ac7f40a1db9", size = 70218 } -wheels = [ - { url = "https://files.pythonhosted.org/packages/ed/e5/7490368ec421e44efd60a21aa88d244653c674d8d6ee6bc455d8ee3d02ed/faiss_cpu-1.11.0-cp310-cp310-macosx_14_0_arm64.whl", hash = "sha256:1995119152928c68096b0c1e5816e3ee5b1eebcf615b80370874523be009d0f6", size = 3307996 }, - { url = "https://files.pythonhosted.org/packages/dd/ac/a94fbbbf4f38c2ad11862af92c071ff346630ebf33f3d36fe75c3817c2f0/faiss_cpu-1.11.0-cp310-cp310-macosx_14_0_x86_64.whl", hash = "sha256:788d7bf24293fdecc1b93f1414ca5cc62ebd5f2fecfcbb1d77f0e0530621c95d", size = 7886309 }, - { url = "https://files.pythonhosted.org/packages/63/48/ad79f34f1b9eba58c32399ad4fbedec3f2a717d72fb03648e906aab48a52/faiss_cpu-1.11.0-cp310-cp310-manylinux_2_28_aarch64.whl", hash = "sha256:73408d52429558f67889581c0c6d206eedcf6fabe308908f2bdcd28fd5e8be4a", size = 3778443 }, - { url = "https://files.pythonhosted.org/packages/95/67/3c6b94dd3223a8ecaff1c10c11b4ac6f3f13f1ba8ab6b6109c24b6e9b23d/faiss_cpu-1.11.0-cp310-cp310-manylinux_2_28_x86_64.whl", hash = "sha256:1f53513682ca94c76472544fa5f071553e428a1453e0b9755c9673f68de45f12", size = 31295174 }, - { url = "https://files.pythonhosted.org/packages/a4/2c/d843256aabdb7f20f0f87f61efe3fb7c2c8e7487915f560ba523cfcbab57/faiss_cpu-1.11.0-cp310-cp310-win_amd64.whl", hash = "sha256:30489de0356d3afa0b492ca55da164d02453db2f7323c682b69334fde9e8d48e", size = 15003860 }, - { url = "https://files.pythonhosted.org/packages/ed/83/8aefc4d07624a868e046cc23ede8a59bebda57f09f72aee2150ef0855a82/faiss_cpu-1.11.0-cp311-cp311-macosx_14_0_arm64.whl", hash = "sha256:a90d1c81d0ecf2157e1d2576c482d734d10760652a5b2fcfa269916611e41f1c", size = 3307997 }, - { url = "https://files.pythonhosted.org/packages/2b/64/f97e91d89dc6327e08f619fe387d7d9945bc4be3b0f1ca1e494a41c92ebe/faiss_cpu-1.11.0-cp311-cp311-macosx_14_0_x86_64.whl", hash = "sha256:2c39a388b059fb82cd97fbaa7310c3580ced63bf285be531453bfffbe89ea3dd", size = 7886308 }, - { url = "https://files.pythonhosted.org/packages/44/0a/7c17b6df017b0bc127c6aa4066b028281e67ab83d134c7433c4e75cd6bb6/faiss_cpu-1.11.0-cp311-cp311-manylinux_2_28_aarch64.whl", hash = "sha256:a4e3433ffc7f9b8707a7963db04f8676a5756868d325644db2db9d67a618b7a0", size = 3778441 }, - { url = "https://files.pythonhosted.org/packages/53/45/7c85551025d9f0237d891b5cffdc5d4a366011d53b4b0a423b972cc52cea/faiss_cpu-1.11.0-cp311-cp311-manylinux_2_28_x86_64.whl", hash = "sha256:926645f1b6829623bc88e93bc8ca872504d604718ada3262e505177939aaee0a", size = 31295136 }, - { url = "https://files.pythonhosted.org/packages/7f/9a/accade34b8668b21206c0c4cf0b96cd0b750b693ba5b255c1c10cfee460f/faiss_cpu-1.11.0-cp311-cp311-win_amd64.whl", hash = "sha256:931db6ed2197c03a7fdf833b057c13529afa2cec8a827aa081b7f0543e4e671b", size = 15003710 }, - { url = "https://files.pythonhosted.org/packages/3b/d3/7178fa07047fd770964a83543329bb5e3fc1447004cfd85186ccf65ec3ee/faiss_cpu-1.11.0-cp312-cp312-macosx_14_0_arm64.whl", hash = "sha256:356437b9a46f98c25831cdae70ca484bd6c05065af6256d87f6505005e9135b9", size = 3313807 }, - { url = "https://files.pythonhosted.org/packages/9e/71/25f5f7b70a9f22a3efe19e7288278da460b043a3b60ad98e4e47401ed5aa/faiss_cpu-1.11.0-cp312-cp312-macosx_14_0_x86_64.whl", hash = "sha256:c4a3d35993e614847f3221c6931529c0bac637a00eff0d55293e1db5cb98c85f", size = 7913537 }, - { url = "https://files.pythonhosted.org/packages/b0/c8/a5cb8466c981ad47750e1d5fda3d4223c82f9da947538749a582b3a2d35c/faiss_cpu-1.11.0-cp312-cp312-manylinux_2_28_aarch64.whl", hash = "sha256:8f9af33e0b8324e8199b93eb70ac4a951df02802a9dcff88e9afc183b11666f0", size = 3785180 }, - { url = "https://files.pythonhosted.org/packages/7f/37/eaf15a7d80e1aad74f56cf737b31b4547a1a664ad3c6e4cfaf90e82454a8/faiss_cpu-1.11.0-cp312-cp312-manylinux_2_28_x86_64.whl", hash = "sha256:48b7e7876829e6bdf7333041800fa3c1753bb0c47e07662e3ef55aca86981430", size = 31287630 }, - { url = "https://files.pythonhosted.org/packages/ff/5c/902a78347e9c47baaf133e47863134e564c39f9afe105795b16ee986b0df/faiss_cpu-1.11.0-cp312-cp312-win_amd64.whl", hash = "sha256:bdc199311266d2be9d299da52361cad981393327b2b8aa55af31a1b75eaaf522", size = 15005398 }, - { url = "https://files.pythonhosted.org/packages/92/90/d2329ce56423cc61f4c20ae6b4db001c6f88f28bf5a7ef7f8bbc246fd485/faiss_cpu-1.11.0-cp313-cp313-macosx_14_0_arm64.whl", hash = "sha256:0c98e5feff83b87348e44eac4d578d6f201780dae6f27f08a11d55536a20b3a8", size = 3313807 }, - { url = "https://files.pythonhosted.org/packages/24/14/8af8f996d54e6097a86e6048b1a2c958c52dc985eb4f935027615079939e/faiss_cpu-1.11.0-cp313-cp313-macosx_14_0_x86_64.whl", hash = "sha256:796e90389427b1c1fb06abdb0427bb343b6350f80112a2e6090ac8f176ff7416", size = 7913539 }, - { url = "https://files.pythonhosted.org/packages/b2/2b/437c2f36c3aa3cffe041479fced1c76420d3e92e1f434f1da3be3e6f32b1/faiss_cpu-1.11.0-cp313-cp313-manylinux_2_28_aarch64.whl", hash = "sha256:2b6e355dda72b3050991bc32031b558b8f83a2b3537a2b9e905a84f28585b47e", size = 3785181 }, - { url = "https://files.pythonhosted.org/packages/66/75/955527414371843f558234df66fa0b62c6e86e71e4022b1be9333ac6004c/faiss_cpu-1.11.0-cp313-cp313-manylinux_2_28_x86_64.whl", hash = "sha256:6c482d07194638c169b4422774366e7472877d09181ea86835e782e6304d4185", size = 31287635 }, - { url = "https://files.pythonhosted.org/packages/50/51/35b7a3f47f7859363a367c344ae5d415ea9eda65db0a7d497c7ea2c0b576/faiss_cpu-1.11.0-cp313-cp313-win_amd64.whl", hash = "sha256:13eac45299532b10e911bff1abbb19d1bf5211aa9e72afeade653c3f1e50e042", size = 15005455 }, -] - [[package]] name = "fieldz" version = "0.1.0" @@ -614,17 +582,17 @@ wheels = [ [[package]] name = "hf-xet" -version = "1.1.0" +version = "1.1.1" source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/aa/2c/70009910fcbd204bde75842b60c1e47fe72edb0e978954cb8001735885c7/hf_xet-1.1.0.tar.gz", hash = "sha256:a7c2a4c2b6eee9ce0a1a367a82b60d95ba634420ef1c250addad7aa4af419cf4", size = 263996 } +sdist = { url = "https://files.pythonhosted.org/packages/3a/09/e2fc5ccd6f9828efbd9135d5aab70895fa6891752ce84c57026c48f3f33d/hf_xet-1.1.1.tar.gz", hash = "sha256:3e75d6e04c38c80115b640c025d68c3dc14d62f8b244011dfe547363674a1e87", size = 277864 } wheels = [ - { url = "https://files.pythonhosted.org/packages/dc/fd/0db331297e331f0f02005fd7ea666439bf15efd74f0dd62af02a43236a1b/hf_xet-1.1.0-cp37-abi3-macosx_10_12_x86_64.whl", hash = "sha256:0322c42551e275fcb7949c083a54a81b2898e50787c9aa74284fcb8d2c58c12c", size = 5069444 }, - { url = "https://files.pythonhosted.org/packages/b9/7d/4d7ae44219d3744ad55669cb90ef3d4ed9f5f8a4729fa635a6499491cb78/hf_xet-1.1.0-cp37-abi3-macosx_11_0_arm64.whl", hash = "sha256:667153a0304ac2debf2af95a8ff7687186f885b493f4cd16344869af270cd110", size = 4881465 }, - { url = "https://files.pythonhosted.org/packages/83/9a/d40d2a57b132d609d8a4ccc29e59ed69749021610616749cabcda2532158/hf_xet-1.1.0-cp37-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:995eeffb119636ea617b96c7d7bf3c3f5ea8727fa57974574e25d700b8532d48", size = 53584225 }, - { url = "https://files.pythonhosted.org/packages/2e/01/d94553f91d85746e0862f24d239da88d10f5ce252b028565744e982432f4/hf_xet-1.1.0-cp37-abi3-manylinux_2_28_aarch64.whl", hash = "sha256:3aee847da362393331f515c4010d0aaa1c2669acfcca1f4b28946d6949cc0086", size = 52043680 }, - { url = "https://files.pythonhosted.org/packages/29/89/1f31853bf378f0ceb3363c07fd8a12af9b904b1f8c21e65eb5c19397bc98/hf_xet-1.1.0-cp37-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:68c5813a6074aa36e12ef5983230e3b03148cce61e0fcdd294096493795565b4", size = 53072672 }, - { url = "https://files.pythonhosted.org/packages/b5/9f/5ecb92b18a4b2135a72a95dc08bcbeda9176f46642c745ee052420d2aea8/hf_xet-1.1.0-cp37-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:4ee9222bf9274b1c198b88a929de0b5a49349c4962d89c5b3b2f0f7f47d9761c", size = 53521053 }, - { url = "https://files.pythonhosted.org/packages/53/d6/cb32842cbf1cf5a154b41fa918a2fd86003af9bca227a2397cd7f312a8a6/hf_xet-1.1.0-cp37-abi3-win_amd64.whl", hash = "sha256:73153eab9abf3d6973b21e94a67ccba5d595c3e12feb8c0bf50be02964e7f126", size = 4204376 }, + { url = "https://files.pythonhosted.org/packages/97/f5/81194ea8e4a585d7d4d0f2ad1ca16e05a4b0c5a385bb2610a8e6da1d2c3d/hf_xet-1.1.1-cp37-abi3-macosx_10_12_x86_64.whl", hash = "sha256:e39a8513f0854656116c837d387d9a41e9d78430b1a181442f04c223cbc4e8f8", size = 5274857 }, + { url = "https://files.pythonhosted.org/packages/55/3c/36342b3fa247f2580821a4b183d38f36fb20e911a8307df625790e734359/hf_xet-1.1.1-cp37-abi3-macosx_11_0_arm64.whl", hash = "sha256:c60cd67be384cb9e592fa6dfd29a10fddffa1feb2f3b31f53e980630d1ca0fd6", size = 5079657 }, + { url = "https://files.pythonhosted.org/packages/b0/c1/4f770cc7be79287905e13765d4a7e1949dce3483f90867f532ff56e7126b/hf_xet-1.1.1-cp37-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5efc6cf15930d9b0cef25c0444e00c2f55d9e09f856f26ed8c809fd5cd1aa044", size = 25506200 }, + { url = "https://files.pythonhosted.org/packages/94/69/1ec612f8e9e2ca27563adfca926cfb41bbe988e30d4cd6fc1e0c019e5570/hf_xet-1.1.1-cp37-abi3-manylinux_2_28_aarch64.whl", hash = "sha256:504bbc8341edc2aa4b3c20c1fdda818554ab34e4175730f42e2a90436cbbe706", size = 24469080 }, + { url = "https://files.pythonhosted.org/packages/8d/96/9201773a0ebb982aa5936f19bdd04d404bc5d74e23f30bce6e857757998b/hf_xet-1.1.1-cp37-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:87d030157a21016c2cddf757a5fd6f1f364d86afef6e190e63a37dd4dc6f6c98", size = 25641374 }, + { url = "https://files.pythonhosted.org/packages/ba/14/10a4cf526070e774bdc7ce68202dc27a15751ddc22c6b47a5ecb6d8ea4ad/hf_xet-1.1.1-cp37-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:6e9b640f0f002b3bea36739b30cf3133b3175c27a342b39315be9a9bdb0cec5b", size = 25425434 }, + { url = "https://files.pythonhosted.org/packages/bd/25/7015a82b3b165747ba85b0383e5d5278d268f3a30460f6d55849903cf272/hf_xet-1.1.1-cp37-abi3-win_amd64.whl", hash = "sha256:215a4e95009a0b9795ca3cf33db4e8d1248139593d7e1185661cd19b062d2b82", size = 4391897 }, ] [[package]] @@ -1340,11 +1308,10 @@ wheels = [ [[package]] name = "mostlyai-qa" -version = "1.8.4" +version = "1.9.2" source = { editable = "." } dependencies = [ { name = "accelerate" }, - { name = "faiss-cpu" }, { name = "jinja2" }, { name = "joblib" }, { name = "model2vec" }, @@ -1382,7 +1349,6 @@ docs = [ [package.metadata] requires-dist = [ { name = "accelerate", specifier = ">=1.5.0" }, - { name = "faiss-cpu", specifier = ">=1.7.0" }, { name = "jinja2", specifier = ">=3.1.2" }, { name = "joblib", specifier = ">=1.4.2" }, { name = "model2vec", specifier = ">=0.4.1" }, @@ -1437,11 +1403,11 @@ wheels = [ [[package]] name = "narwhals" -version = "1.38.2" +version = "1.39.0" source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/dd/fa/f4c2b2524b6b1e7475af933849ecaad280822eab5631151ccb1993d600ce/narwhals-1.38.2.tar.gz", hash = "sha256:7c5fbc9f2b8e1d5d95f49dcef9c2d94bf17810de68c87ff195dc7d22f7b3eeb5", size = 277368 } +sdist = { url = "https://files.pythonhosted.org/packages/80/eb/50bd207820ac84b86dee216e2caf3fdf456e59bc854543f8d0ab8637cd66/narwhals-1.39.0.tar.gz", hash = "sha256:6f114def701fc6a23b0523ad53700ae545ad3f5ece041c64dc3ad59e699da43f", size = 471928 } wheels = [ - { url = "https://files.pythonhosted.org/packages/56/45/59251121eb801f033ffdd56d8893a0f20ec661bdc362f8d52a4e4d547f91/narwhals-1.38.2-py3-none-any.whl", hash = "sha256:a33a182e32f18d794a04e7828a5c401fb26ce9083f609993e7e5064aace641c7", size = 338437 }, + { url = "https://files.pythonhosted.org/packages/22/58/a12a534269aa5ba9abdf73a9e0deb600297b71cbf7291bca212944663143/narwhals-1.39.0-py3-none-any.whl", hash = "sha256:50b6778f4b4249eb86c88dd17c3907dd004a16ec25b02d5effaf226a2bcfb940", size = 339242 }, ] [[package]] From e4e0095771f9853509c41656342735c92216be7f Mon Sep 17 00:00:00 2001 From: Michael Platzer Date: Mon, 12 May 2025 19:16:47 -0400 Subject: [PATCH 2/4] improve log messages --- mostlyai/qa/reporting.py | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/mostlyai/qa/reporting.py b/mostlyai/qa/reporting.py index 8b80fd1..0cb3981 100644 --- a/mostlyai/qa/reporting.py +++ b/mostlyai/qa/reporting.py @@ -205,7 +205,6 @@ def report( else: setup = "1:1" - _LOG.info("prepare training data for accuracy") trn = prepare_data_for_accuracy( df_tgt=trn_tgt_data, df_ctx=trn_ctx_data, @@ -214,8 +213,8 @@ def report( max_sample_size=max_sample_size_accuracy, setup=setup, ) + _LOG.info(f"prepared training data for accuracy: {trn.shape}") if hol_tgt_data is not None: - _LOG.info("prepare holdout data for accuracy") hol = prepare_data_for_accuracy( df_tgt=hol_tgt_data, df_ctx=hol_ctx_data, @@ -225,13 +224,13 @@ def report( setup=setup, ori_dtypes=trn.dtypes.to_dict(), ) + _LOG.info(f"prepared holdout data for accuracy: {hol.shape}") ori = pd.concat([trn, hol], axis=0, ignore_index=True) else: hol = None ori = trn progress.update(completed=5, total=100) - _LOG.info("prepare synthetic data for accuracy") syn = prepare_data_for_accuracy( df_tgt=syn_tgt_data, df_ctx=syn_ctx_data, @@ -241,29 +240,29 @@ def report( setup=setup, ori_dtypes=trn.dtypes.to_dict(), ) + _LOG.info(f"prepared synthetic data for accuracy: {syn.shape}") progress.update(completed=10, total=100) # do coherence analysis only if there are non-fk columns in the target data do_coherence = setup == "1:N" and len(trn_tgt_data.columns) > 1 if do_coherence: - _LOG.info("prepare original data for coherence started") ori_coh, ori_coh_bins = prepare_data_for_coherence( df_tgt=pd.concat([trn_tgt_data, hol_tgt_data]) if hol_tgt_data is not None else trn_tgt_data, tgt_context_key=tgt_context_key, max_sample_size=max_sample_size_coherence, ) - _LOG.info("prepare synthetic data for coherence started") + _LOG.info(f"prepared original data for coherence: {ori_coh.shape}") syn_coh, _ = prepare_data_for_coherence( df_tgt=syn_tgt_data, tgt_context_key=tgt_context_key, bins=ori_coh_bins, max_sample_size=max_sample_size_coherence, ) - _LOG.info("store bins used for training data for coherence") + _LOG.info(f"prepared synthetic data for coherence: {syn_coh.shape}") statistics.store_coherence_bins(bins=ori_coh_bins) + _LOG.info("stored bins used for training data for coherence") progress.update(completed=15, total=100) - _LOG.info("calculate embeddings") syn_embeds, trn_embeds, hol_embeds = prepare_data_for_embeddings( syn_tgt_data=syn_tgt_data, trn_tgt_data=trn_tgt_data, @@ -275,6 +274,9 @@ def report( tgt_context_key=tgt_context_key, max_sample_size=max_sample_size_embeddings, ) + _LOG.info( + f"calculated embeddings: syn={syn_embeds.shape}, trn={trn_embeds.shape}, hol={hol_embeds.shape if hol_embeds is not None else None}" + ) progress.update(completed=20, total=100) ## 1. ACCURACY ## From 3846eae7da7432bd68ff6d5b5bab3910fd45edc1 Mon Sep 17 00:00:00 2001 From: Michael Platzer Date: Mon, 12 May 2025 20:07:07 -0400 Subject: [PATCH 3/4] cap seq_len at 100; handle empty tgt --- mostlyai/qa/_sampling.py | 8 +++++--- mostlyai/qa/reporting.py | 2 ++ 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/mostlyai/qa/_sampling.py b/mostlyai/qa/_sampling.py index 5aab941..0677b98 100644 --- a/mostlyai/qa/_sampling.py +++ b/mostlyai/qa/_sampling.py @@ -300,11 +300,13 @@ def prepare_data_for_embeddings( # cap to Q95 sequence length of original to avoid excessive samples per group distorting results if tgt_context_key is not None: + cap_sequence_length = 100 q95_sequence_length = trn_tgt_data.groupby(key).size().quantile(0.95) - syn_tgt_data = syn_tgt_data.groupby(key).sample(frac=1).groupby(key).head(n=q95_sequence_length) - trn_tgt_data = trn_tgt_data.groupby(key).sample(frac=1).groupby(key).head(n=q95_sequence_length) + max_sequence_length = min(q95_sequence_length, cap_sequence_length) + syn_tgt_data = syn_tgt_data.groupby(key).sample(frac=1).groupby(key).head(n=max_sequence_length) + trn_tgt_data = trn_tgt_data.groupby(key).sample(frac=1).groupby(key).head(n=max_sequence_length) hol_tgt_data = ( - hol_tgt_data.groupby(key).sample(frac=1).groupby(key).head(n=q95_sequence_length) if hol else None + hol_tgt_data.groupby(key).sample(frac=1).groupby(key).head(n=max_sequence_length) if hol else None ) # drop key from data as its not relevant for embeddings diff --git a/mostlyai/qa/reporting.py b/mostlyai/qa/reporting.py index 0cb3981..47b1abf 100644 --- a/mostlyai/qa/reporting.py +++ b/mostlyai/qa/reporting.py @@ -181,6 +181,8 @@ def report( check_min_sample_size(trn_sample_size, 90, "training") if hol_tgt_data is not None: check_min_sample_size(hol_sample_size, 10, "holdout") + if trn_tgt_data.shape[1] == 0 or syn_tgt_data.shape[1] == 0: + raise PrerequisiteNotMetError("Provided data has no columns.") except PrerequisiteNotMetError as err: _LOG.info(err) statistics.mark_early_exit() From bde828f83109f67007f55a2e745515f6145a928f Mon Sep 17 00:00:00 2001 From: Michael Platzer Date: Mon, 12 May 2025 20:20:22 -0400 Subject: [PATCH 4/4] cap compute for contour plots --- mostlyai/qa/_similarity.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/mostlyai/qa/_similarity.py b/mostlyai/qa/_similarity.py index 812427d..448ec17 100644 --- a/mostlyai/qa/_similarity.py +++ b/mostlyai/qa/_similarity.py @@ -199,6 +199,11 @@ def plot_store_similarity_contours( if trn_embeds.shape[1] < 3: return + # limit the number of samples to 10000 + syn_embeds = syn_embeds[:10000] + trn_embeds = trn_embeds[:10000] + hol_embeds = hol_embeds[:10000] if hol_embeds is not None else None + # perform PCA on trn embeddings pca_model = PCA(n_components=3) pca_model.fit(trn_embeds)