In [None]:
import util

# Seed the priority queue with the medical-concept prompts read from the text file.
initial_prompts = util.load_initial_prompts("experiment_results/medical_concepts.txt")
pq = util.PriorityQueue(
    initial=initial_prompts,
    filter_threshold=0.6,
    max_capacity=1000,
)



In [6]:
import ast
import re

# --- Configuration ---
CROWDING_INTERVAL = 10          # perform crowding every X iterations (stored as CrowdingManager.interval; not read by the class itself)
CROWDING_ITERATIONS = 3         # number of crowding passes per perform_crowding() call
NUMBER_OF_PROMPTS_TO_GROUP = 30 # how many top-ranked prompts are sent to the LLM per pass
MAX_RETRIES = 5                 # LLM attempts per request before giving up


class CrowdingManager:
    """
    Encapsulates all crowding-related logic: grouping duplicate prompts via LLM
    and pruning the priority queue accordingly.

    One crowding pass sends the top ``group_size`` prompts of the queue to the
    LLM as a 1-based numbered list, asks it to group descriptions of the same
    medical concept, keeps the first member of each group and removes the rest.

    Collaborator contracts (as used below):
      * ``client.get_llm_response(prompt=...)`` returns the LLM reply as a string.
      * ``pq.get_best_n(n=...)`` returns ``(pair, score)`` tuples where
        ``pair[0]`` is the description text; ``pq.delete_top_n(n)``,
        ``pq.insert(pair, score)`` and ``pq.max_capacity`` are also used.
    """

    def __init__(self,
                 client,
                 interval: int = CROWDING_INTERVAL,
                 iterations: int = CROWDING_ITERATIONS,
                 group_size: int = NUMBER_OF_PROMPTS_TO_GROUP,
                 max_retries: int = MAX_RETRIES):
        """
        Args:
            client: LLM client exposing ``get_llm_response(prompt=...)``.
            interval: crowding cadence; stored but not read by this class.
            iterations: number of crowding passes run by ``perform_crowding``.
            group_size: number of top-ranked prompts shown to the LLM per pass.
            max_retries: LLM attempts per request before ``ValueError`` is raised.
        """
        self.client = client
        self.interval = interval
        self.iterations = iterations
        self.group_size = group_size
        self.max_retries = max_retries

        # Templates are filled with str.format; the numbered lists they embed are 1-based.
        self.group_prompt = """The task is to group textual descriptions corresponding to exactly the same medical concept. 
Descriptions to be grouped:
{prompt_pairs_str}
Group the descriptions that has exactly same observation but differ only in language variations. Give the indexes of the grouped descriptions in the output.
Provide the output as follows: list[list[index:int]]. Make sure to include all the descriptions in the output, even if they are not grouped with others.
Let's think step by step.
"""
        self.retry_prompt = """The task is to group textual descriptions corresponding to exactly the same medical concept.
Descriptions to be grouped:
{prompt_pairs_str}

You've already grouped some descriptions, but there are still ungrouped descriptions remaining.
Current Grouped indexes:
{current_grouped_indexes}

Remaining descriptions to be grouped:
{prompt_pairs_str_remaining}

Provide the output as follows: list[list[index:int]]. Make sure to include all descriptions in the output, even if they are not grouped with others.
Let's think step by step."""

    def _parse_grouped_indexes(self, text: str):
        """Parse a Python-literal string such as ``"[[1, 2], [3]]"`` (no code execution)."""
        return ast.literal_eval(text)

    @staticmethod
    def _group_members(item) -> list:
        """Return the indexes in one LLM output item; a bare index is a singleton group."""
        if isinstance(item, (list, tuple)):
            return list(item)
        return [item]

    def _get_unique_indexes(self, grouped_indexes: list[list[int]]) -> list[int]:
        """
        Pick one representative (the first member) from each group.

        Bare integers are treated as singleton groups and empty groups are
        ignored. Representatives are de-duplicated in first-seen order so the
        same prompt is never reinserted twice when the LLM repeats a group head.
        """
        unique_indexes: list[int] = []
        for group in grouped_indexes:
            members = self._group_members(group)
            if not members:
                continue
            representative = members[0]
            if representative not in unique_indexes:
                unique_indexes.append(representative)
        return unique_indexes

    def _flatten_grouped_list(self, grouped_indexes: list[list[int]]) -> list[int]:
        """Flatten groups into one list of indexes, keeping repeats."""
        # The LLM sometimes mixes nested groups with bare indexes,
        # e.g. [[1, 2], [2, 3], 4, 5, 6], so both shapes are accepted.
        flat_list = []
        for item in grouped_indexes:
            flat_list.extend(self._group_members(item))
        return flat_list

    def _get_remaining_indexes(self, grouped_indexes: list[list[int]], total_count: int) -> list[int]:
        """Return, in ascending order, the 1-based indexes in ``1..total_count`` absent from ``grouped_indexes``."""
        flat_list = self._flatten_grouped_list(grouped_indexes)
        return sorted(set(range(1, total_count + 1)).difference(flat_list))

    def _get_grouped_indexes_from_llm(self, llm_prompt: str) -> list[list[int]]:
        """
        Send ``llm_prompt`` to the LLM and parse grouped indexes from its reply.

        The reply must contain a triple-backtick code block, optionally tagged
        with a language name (``python``, ``json``, ...), holding a Python
        literal list whose items are ints or lists/tuples of ints. Malformed
        replies are retried, up to ``self.max_retries`` attempts in total.

        Raises:
            ValueError: if no valid reply is obtained within ``max_retries`` attempts.
        """
        print("Sending Prompt: ", llm_prompt)
        last_error = None
        for _attempt in range(self.max_retries):
            try:
                response = self.client.get_llm_response(prompt=llm_prompt)
                print(response)
                # Extract the first fenced code block, with or without a language tag.
                m = re.search(r'```(?:[A-Za-z]+)?\s*([\s\S]*?)\s*```', response)
                if not m:
                    raise ValueError(
                        "No code block found between triple backticks")
                grouped_indexes = self._parse_grouped_indexes(m.group(1))
                # Validate here so malformed answers trigger a retry instead of
                # crashing later in set/range operations.
                if not isinstance(grouped_indexes, (list, tuple)):
                    raise ValueError(
                        f"Expected a list of groups, got {type(grouped_indexes).__name__}")
                if not all(isinstance(idx, int) for idx in self._flatten_grouped_list(grouped_indexes)):
                    raise ValueError("Every grouped index must be an integer")
                return list(grouped_indexes)
            except Exception as e:
                # Deliberately broad: client/network errors are retried as well.
                last_error = e
                print(f"Error in LLM response: {e}")
        raise ValueError(
            "Failed to get a valid response from the LLM after multiple retries.") from last_error

    def _format_prompt_list(self, prompt_pairs) -> str:
        """Render prompts as a 1-based numbered list, one description per line."""
        return "\n".join(
            f"{idx}. {pair[0]}" for idx, (pair, _score) in enumerate(prompt_pairs, start=1)
        )

    def _group_prompts(self, prompt_pairs) -> list:
        """
        Ask the LLM to group ``prompt_pairs``; return groups covering every prompt.

        One follow-up request is made if the first answer omits indexes. Any
        index still omitted after that is kept as a singleton group, so a prompt
        is only removed when the LLM explicitly grouped it with another one.
        """
        total_count = len(prompt_pairs)
        prompt_pairs_str = self._format_prompt_list(prompt_pairs)

        grouped_indexes = self._get_grouped_indexes_from_llm(
            llm_prompt=self.group_prompt.format(
                # num_of_prompts is unused by the default template; kept for custom templates.
                prompt_pairs_str=prompt_pairs_str, num_of_prompts=self.group_size),
        )

        remaining_indexes = self._get_remaining_indexes(grouped_indexes, total_count)
        if remaining_indexes:
            retry_prompt_str = self.retry_prompt.format(
                prompt_pairs_str=prompt_pairs_str,
                current_grouped_indexes=str(grouped_indexes),
                prompt_pairs_str_remaining="\n".join(
                    f"{idx}. {prompt_pairs[idx - 1][0][0]}" for idx in remaining_indexes
                ),
            )
            grouped_indexes = self._get_grouped_indexes_from_llm(
                llm_prompt=retry_prompt_str,
            )
            remaining_indexes = self._get_remaining_indexes(grouped_indexes, total_count)
            if remaining_indexes:
                print(f"Warning: LLM still omitted indexes {remaining_indexes}; keeping them as singletons.")
                grouped_indexes = list(grouped_indexes) + [[idx] for idx in remaining_indexes]

        return grouped_indexes

    def perform_crowding(self, pq: "util.PriorityQueue") -> "util.PriorityQueue":
        """
        Run ``self.iterations`` crowding passes over the head of ``pq``.

        Each pass retrieves the top ``group_size`` prompts, lets the LLM group
        duplicates, removes the retrieved prompts and reinserts one
        representative per group. A pass whose representatives fall outside
        the retrieved range is skipped. Crowding stops early if the queue is empty.

        Returns:
            The same ``pq`` object, mutated in place.
        """
        deleted_num = 0
        best_prompt_pairs_with_scores = None  # set by the most recent pass that completed
        for i in range(self.iterations):
            print(f"=== Iteration {i+1} of {self.iterations} ===")
            # retrieve the best prompt pairs from the priority queue
            prompt_pairs = pq.get_best_n(n=self.group_size)
            if not prompt_pairs:
                print("Priority queue is empty. Stopping crowding.")
                break

            grouped_indexes = self._group_prompts(prompt_pairs)
            unique_indexes = self._get_unique_indexes(grouped_indexes)
            print(f"Unique indexes: {unique_indexes}")

            print("Debug: Length of prompt pairs before selecting best:",
                  len(prompt_pairs))

            # if any index is out of range, skip this iteration
            if any(idx < 1 or idx > len(prompt_pairs) for idx in unique_indexes):
                print("Warning: Some indexes are out of range. Skipping this iteration.")
                continue

            best_prompt_pairs_with_scores = [
                prompt_pairs[idx - 1] for idx in unique_indexes]

            # Delete exactly the prompts that were shown to the LLM, then add
            # back one representative per group.
            pq.delete_top_n(len(prompt_pairs))
            for prompt_pair, score in best_prompt_pairs_with_scores:
                pq.insert(prompt_pair, score)

            deleted_num += len(prompt_pairs) - len(unique_indexes)
            print(
                f"Iteration {i+1} completed. Deleted {deleted_num} duplicate prompts so far.")

        if best_prompt_pairs_with_scores is not None:
            # NOTE(review): this keeps only the representatives of the last
            # completed pass. Assuming delete_top_n(max_capacity) empties the
            # queue, it also drops prompts ranked below the top group_size that
            # were never shown to the LLM. Preserved as-is; confirm this is intended.
            pq.delete_top_n(pq.max_capacity)
            for prompt_pair, score in best_prompt_pairs_with_scores:
                pq.insert(prompt_pair, score)

        return pq

In [7]:
import util
llm_client = util.LLMClient(provider="gemini")
crowding_manager = CrowdingManager(client=llm_client)

# perform_crowding mutates `pq` in place and returns the same object; the trailing
# semicolon stops Jupyter from echoing its uninformative default repr.
crowding_manager.perform_crowding(pq=pq);

=== Iteration 1 of 3 ===
Sending Prompt:  The task is to group textual descriptions corresponding to exactly the same medical concept. 
Descriptions to be grouped:
1. No atypical cells infiltrating surrounding tissues
2. No evidence of fibrosis
3. Normal follicular architecture is preserved
4. No prominent nucleoli are observed in lymphocytes
5. No giant cells or multinucleated cells are seen
6. No plasmacytoid differentiation is observed
7. Interfollicular areas contain small lymphocytes with regular nuclei
8. Lymphocytes exhibit a uniform population
9. No evidence of mitotic activity in lymphocytes
10. Stroma is delicate and sparsely collagenized
11. Blood vessels are small and exhibit normal morphology
12. Sinusoids are open and show normal lymphocyte flow
13. Nuclei are round to oval with smooth chromatin
14. Smooth, well-defined nuclear borders and absence of nuclear grooves in lymphocytes
15. Normal reactive germinal centers with tingible body macrophages
16. Lymphocytes are even

<util.PriorityQueue at 0x222df1489d0>