Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 31 additions & 3 deletions src/harmony/matching/matcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,30 @@
)
from harmony.schemas.text_vector import TextVector

import os


def get_batch_size(default=50):
    """Return the batch size configured via the ``BATCH_SIZE`` environment variable.

    Args:
        default: Value used when ``BATCH_SIZE`` is unset, and returned as-is
            when the configured value cannot be converted to an integer.

    Returns:
        The configured batch size as an ``int``, with negative values raised
        to ``0``; or ``default`` unchanged when the conversion fails.
    """
    raw_value = os.getenv("BATCH_SIZE", default)
    try:
        parsed = int(raw_value)
    except (ValueError, TypeError):
        # Non-numeric (or otherwise unconvertible) setting: fall back to the default.
        return default
    # Negative sizes are clamped to zero.
    return parsed if parsed > 0 else 0
def process_items_in_batches(items, llm_function):
    """Apply ``llm_function`` to ``items`` in chunks and return one flat list.

    The chunk size comes from ``get_batch_size()``. A size of ``0`` disables
    batching, so ``llm_function`` is called once with all of ``items``.

    Args:
        items: Sliceable sequence of inputs (must support ``len`` and slicing).
        llm_function: Callable that takes a sequence of items and returns an
            iterable of results for them.

    Returns:
        A ``list`` of the results, concatenated in batch order. Both the
        batched and the unbatched path return a plain list, so callers see the
        same container type whatever ``BATCH_SIZE`` is set to.
    """
    batch_size = get_batch_size()

    if batch_size == 0:
        # Batching disabled: one call for everything. The result is wrapped in
        # list() so this path matches the batched path, which always builds a
        # list via extend().
        return list(llm_function(items))

    results = []
    # Slice lazily, one batch at a time, rather than materialising every batch
    # up front. With empty ``items`` the loop body never runs and ``[]`` is
    # returned without calling ``llm_function``.
    for start in range(0, len(items), batch_size):
        results.extend(llm_function(items[start:start + batch_size]))
    return results


def cosine_similarity(vec1: ndarray, vec2: ndarray) -> ndarray:
dp = dot(vec1, vec2.T)
Expand Down Expand Up @@ -127,8 +151,11 @@ def create_full_text_vectors(
# Texts with no cached vector
texts_not_cached = [x.text for x in text_vectors if not x.vector]



# Get vectors for all texts not cached
new_vectors_list: List = vectorisation_function(texts_not_cached).tolist()
new_vectors_list: List = process_items_in_batches(texts_not_cached, vectorisation_function)


# Create a dictionary with new vectors
new_vectors_dict = {}
Expand Down Expand Up @@ -382,7 +409,7 @@ def match_questions_with_catalogue_instruments(

instrument_idx_to_score = {}
for instrument_idx, average_sim in instrument_idx_to_cosine_similarities_average.items():
score = average_sim * (0.1+instrument_idx_to_top_matches_ct.get(instrument_idx, 0))
score = average_sim * (0.1 + instrument_idx_to_top_matches_ct.get(instrument_idx, 0))
instrument_idx_to_score[instrument_idx] = score

# Find the top 10 best instrument idx matches, index 0 containing the best match etc.
Expand Down Expand Up @@ -432,7 +459,8 @@ def match_questions_with_catalogue_instruments(
"info": info,
"num_matched_questions": num_top_match_questions,
"num_ref_instrument_questions": num_questions_in_ref_instrument,
"mean_cosine_similarity": instrument_idx_to_cosine_similarities_average.get(top_catalogue_instrument_idx)
"mean_cosine_similarity": instrument_idx_to_cosine_similarities_average.get(
top_catalogue_instrument_idx)
},
))

Expand Down
84 changes: 84 additions & 0 deletions tests/test_batching_in_matcher.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
import sys
import os
import unittest
import numpy

sys.path.append("../src")
from unittest import TestCase, mock
from harmony.matching.matcher import get_batch_size
from harmony.matching.matcher import process_items_in_batches


# Mock LLM function
def mock_llm_function(batch):
    """Stand-in for an LLM call: return one ``"Processed: ..."`` string per item of ``batch``."""
    processed = []
    for item in batch:
        processed.append(f"Processed: {item}")
    return processed


class TestMatcherBatching(TestCase):
    """Check ``process_items_in_batches`` under various ``BATCH_SIZE`` environment settings."""

    @mock.patch.dict(os.environ, {"BATCH_SIZE": "5"})
    def test_batched_processing(self):
        """With BATCH_SIZE=5, ten inputs all come back, in input order."""
        expected = [
            "Processed: item0",
            "Processed: item1",
            "Processed: item2",
            "Processed: item3",
            "Processed: item4",
            "Processed: item5",
            "Processed: item6",
            "Processed: item7",
            "Processed: item8",
            "Processed: item9",
        ]
        inputs = [f"item{i}" for i in range(10)]

        outputs = process_items_in_batches(inputs, mock_llm_function)

        self.assertEqual(len(outputs), 10)
        self.assertEqual(outputs, expected)

    @mock.patch.dict(os.environ, {"BATCH_SIZE": "5"})
    def test_large_batch_size(self):
        """With BATCH_SIZE=5 and only three inputs, all three come back in order."""
        expected = [
            "Processed: item0",
            "Processed: item1",
            "Processed: item2",
        ]
        inputs = [f"item{i}" for i in range(3)]

        outputs = process_items_in_batches(inputs, mock_llm_function)

        self.assertEqual(len(outputs), 3)
        self.assertEqual(outputs, expected)

    @mock.patch.dict(os.environ, {"BATCH_SIZE": "0"})
    def test_no_batching(self):
        """With BATCH_SIZE=0, ten inputs all come back, in input order."""
        expected = [
            "Processed: item0",
            "Processed: item1",
            "Processed: item2",
            "Processed: item3",
            "Processed: item4",
            "Processed: item5",
            "Processed: item6",
            "Processed: item7",
            "Processed: item8",
            "Processed: item9",
        ]
        inputs = [f"item{i}" for i in range(10)]

        outputs = process_items_in_batches(inputs, mock_llm_function)

        self.assertEqual(len(outputs), 10)
        self.assertEqual(outputs, expected)

    @mock.patch.dict(os.environ, {"BATCH_SIZE": "-5"})
    def test_negative_batch_size(self):
        """With a negative BATCH_SIZE, every input still yields a result."""
        inputs = [f"item{i}" for i in range(10)]
        outputs = process_items_in_batches(inputs, mock_llm_function)
        self.assertEqual(len(outputs), 10)

    @mock.patch.dict(os.environ, {}, clear=True)
    def test_default_batch_size(self):
        """With BATCH_SIZE unset (environment cleared), every input yields a result."""
        inputs = [f"item{i}" for i in range(10)]
        outputs = process_items_in_batches(inputs, mock_llm_function)
        self.assertEqual(len(outputs), 10)

    @mock.patch.dict(os.environ, {"BATCH_SIZE": "invalid"})
    def test_invalid_batch_size(self):
        """With a non-numeric BATCH_SIZE, every input still yields a result."""
        inputs = [f"item{i}" for i in range(10)]
        outputs = process_items_in_batches(inputs, mock_llm_function)
        self.assertEqual(len(outputs), 10)


# Run the tests via unittest when this file is executed directly as a script.
if __name__ == "__main__":
unittest.main()