From da76dfafe4bec7ca99bf31b4d2e2dabe08cce204 Mon Sep 17 00:00:00 2001 From: makrianast Date: Fri, 22 Nov 2024 17:06:13 +0200 Subject: [PATCH] added batching to matcher.py and unit tests --- src/harmony/matching/matcher.py | 34 +++++++++++-- tests/test_batching_in_matcher.py | 84 +++++++++++++++++++++++++++++++ 2 files changed, 115 insertions(+), 3 deletions(-) create mode 100644 tests/test_batching_in_matcher.py diff --git a/src/harmony/matching/matcher.py b/src/harmony/matching/matcher.py index c409e54..2dfd2da 100644 --- a/src/harmony/matching/matcher.py +++ b/src/harmony/matching/matcher.py @@ -42,6 +42,30 @@ ) from harmony.schemas.text_vector import TextVector +import os + + +def get_batch_size(default=50): + try: + batch_size = int(os.getenv("BATCH_SIZE", default)) + return max(batch_size, 0) + except (ValueError, TypeError): + return default +def process_items_in_batches(items, llm_function): + batch_size = get_batch_size() + + if batch_size == 0: + return llm_function(items) + + + batches = [items[i:i + batch_size] for i in range(0, len(items), batch_size)] + + results = [] + for batch in batches: + batch_results = llm_function(batch) + results.extend(batch_results) + return results + def cosine_similarity(vec1: ndarray, vec2: ndarray) -> ndarray: dp = dot(vec1, vec2.T) @@ -127,8 +151,11 @@ def create_full_text_vectors( # Texts with no cached vector texts_not_cached = [x.text for x in text_vectors if not x.vector] + + # Get vectors for all texts not cached - new_vectors_list: List = vectorisation_function(texts_not_cached).tolist() + new_vectors_list: List = process_items_in_batches(texts_not_cached, vectorisation_function) + # Create a dictionary with new vectors new_vectors_dict = {} @@ -382,7 +409,7 @@ def match_questions_with_catalogue_instruments( instrument_idx_to_score = {} for instrument_idx, average_sim in instrument_idx_to_cosine_similarities_average.items(): - score = average_sim * (0.1+instrument_idx_to_top_matches_ct.get(instrument_idx, 0)) + score = average_sim * (0.1 + instrument_idx_to_top_matches_ct.get(instrument_idx, 0)) instrument_idx_to_score[instrument_idx] = score # Find the top 10 best instrument idx matches, index 0 containing the best match etc. @@ -432,7 +459,8 @@ def match_questions_with_catalogue_instruments( "info": info, "num_matched_questions": num_top_match_questions, "num_ref_instrument_questions": num_questions_in_ref_instrument, - "mean_cosine_similarity": instrument_idx_to_cosine_similarities_average.get(top_catalogue_instrument_idx) + "mean_cosine_similarity": instrument_idx_to_cosine_similarities_average.get( + top_catalogue_instrument_idx) }, )) diff --git a/tests/test_batching_in_matcher.py b/tests/test_batching_in_matcher.py new file mode 100644 index 0000000..b69a5bf --- /dev/null +++ b/tests/test_batching_in_matcher.py @@ -0,0 +1,84 @@ +import sys +import os +import unittest +import numpy + +sys.path.append("../src") +from unittest import TestCase, mock +from harmony.matching.matcher import get_batch_size +from harmony.matching.matcher import process_items_in_batches + + +# Mock LLM function +def mock_llm_function(batch): + """Simulates processing a batch.""" + return [f"Processed: {item}" for item in batch] + + +class TestMatcherBatching(TestCase): + + @mock.patch.dict(os.environ, {"BATCH_SIZE": "5"}) + def test_batched_processing(self): + """Test that 10 items are divided into 2 batches of 5 each.""" + items = [f"item{i}" for i in range(10)] # 10 items to process + results = process_items_in_batches(items, mock_llm_function) + + self.assertEqual(len(results), 10) + + expected = [ + "Processed: item0", "Processed: item1", "Processed: item2", "Processed: item3", "Processed: item4", + "Processed: item5", "Processed: item6", "Processed: item7", "Processed: item8", "Processed: item9", + ] + self.assertEqual(results, expected) + + @mock.patch.dict(os.environ, {"BATCH_SIZE": "5"}) + def test_large_batch_size(self): + """Test batch size greater than input size.""" + items = [f"item{i}" for i in range(3)] # Only 3 items + results = process_items_in_batches(items, mock_llm_function) + + self.assertEqual(len(results), 3) + + expected = [ + "Processed: item0", "Processed: item1", "Processed: item2", + ] + self.assertEqual(results, expected) + + @mock.patch.dict(os.environ, {"BATCH_SIZE": "0"}) + def test_no_batching(self): + """Test no batching (all items processed in one batch).""" + items = [f"item{i}" for i in range(10)] # 10 items to process + results = process_items_in_batches(items, mock_llm_function) + + self.assertEqual(len(results), 10) + + expected = [ + "Processed: item0", "Processed: item1", "Processed: item2", "Processed: item3", "Processed: item4", + "Processed: item5", "Processed: item6", "Processed: item7", "Processed: item8", "Processed: item9", + ] + self.assertEqual(results, expected) + + @mock.patch.dict(os.environ, {"BATCH_SIZE": "-5"}) + def test_negative_batch_size(self): + """Test when BATCH_SIZE is negative, it defaults to 0.""" + items = [f"item{i}" for i in range(10)] + results = process_items_in_batches(items, mock_llm_function) + self.assertEqual(len(results), 10) + + @mock.patch.dict(os.environ, {}, clear=True) + def test_default_batch_size(self): + """Test when BATCH_SIZE is not set, it defaults to 50.""" + items = [f"item{i}" for i in range(10)] + results = process_items_in_batches(items, mock_llm_function) + self.assertEqual(len(results), 10) + + @mock.patch.dict(os.environ, {"BATCH_SIZE": "invalid"}) + def test_invalid_batch_size(self): + """Test when BATCH_SIZE is invalid, it defaults to 50.""" + items = [f"item{i}" for i in range(10)] + results = process_items_in_batches(items, mock_llm_function) + self.assertEqual(len(results), 10) + + +if __name__ == "__main__": + unittest.main()