In [1]:
import util

# Seed the priority queue with the initial prompt pairs.
MEDICAL_CONCEPTS_PATH = "experiment_results/medical_concepts.txt"
initial_prompts = util.load_initial_prompts(MEDICAL_CONCEPTS_PATH)
pq = util.PriorityQueue(max_capacity=1000, filter_threshold=0.6, initial=initial_prompts)


# Prompt asking the LLM to group near-duplicate prompt pairs.
# Placeholders: {prompt_pairs_str}, {num_of_prompts}.
# The answer is requested inside a ```python fenced block because
# get_grouped_indexes_from_llm extracts the grouping from a fenced code block.
group_prompt = """The task is to group textual description pairs of visual discriminative features for tumor detection in histopathology.
Current Prompt Pairs:
{prompt_pairs_str}
Group the prompt pairs that describe the same observation and differ only in wording. Give the indexes of the grouped pairs in the output.
Provide the output as follows: list[list[index:int]]. Make sure to include all pairs in the output, even if they are not grouped with others.
Let's think step by step. Count from 1-{num_of_prompts} to verify each item is in the list.
Write the final answer as a Python list of lists inside a ```python code block.
"""

# Follow-up prompt used when the first answer left some pairs out.
# Placeholders: {prompt_pairs_str}, {current_grouped_indexes}, {prompt_pairs_str_remaining}.
# The caller replaces the previous grouping with this answer, so the LLM is
# asked for the complete grouping, not only the missing pairs.
retry_prompt = """The task is to group textual description pairs of visual discriminative features for tumor detection in histopathology.
Current Prompt Pairs:
{prompt_pairs_str}

You've already grouped some pairs, but there are still ungrouped pairs remaining.
Current Grouped indexes:
{current_grouped_indexes}

Remaining Prompt Pairs:
{prompt_pairs_str_remaining}

Add each remaining pair to the group that has the same observation, or put it in its own group, and return the complete grouping of all pairs.
Provide the output as follows: list[list[index:int]]. Make sure to include all pairs in the output, even if they are not grouped with others.
Let's think step by step.
Write the final answer as a Python list of lists inside a ```python code block."""

  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(


In [2]:
import re
import ast

def parse_grouped_indexes(text: str):
    """Parse a Python literal such as "[[1, 2], [3]]" into Python objects.

    ast.literal_eval only evaluates literals, so model output cannot execute
    code. Surrounding whitespace and newlines are stripped first, because a
    leading newline followed by an indented line is rejected by the parser.

    Args:
        text: The literal text, typically the body of an LLM code block.

    Returns:
        The evaluated literal (expected to be a list of lists of ints).

    Raises:
        ValueError / SyntaxError: If ``text`` is not a valid Python literal.
    """
    return ast.literal_eval(text.strip())

def get_unique_indexes(grouped_indexes: list[list[int]]) -> list[int]:
    """Pick one representative index per group.

    The first index of each group is kept, in group order. Empty groups are
    skipped (they have no representative and would otherwise raise
    IndexError). An index already chosen for an earlier group is not repeated,
    so each prompt is re-inserted into the queue at most once.

    Args:
        grouped_indexes: Groups of 1-based prompt indexes.

    Returns:
        The representative index of each non-empty group, without duplicates.
    """
    unique_indexes: list[int] = []
    seen: set[int] = set()
    for group in grouped_indexes:
        if not group:
            continue
        representative = group[0]  # keep the first index of each group
        if representative not in seen:
            seen.add(representative)
            unique_indexes.append(representative)
    return unique_indexes

def get_remaining_indexes(grouped_indexes: list[list[int]], total_count: int) -> list[int]:
    """Return the 1-based indexes in 1..total_count that appear in no group.

    The result is sorted so the printed list and the retry prompt built from
    it are deterministic (iteration order of a set is not guaranteed).

    Args:
        grouped_indexes: Groups of 1-based prompt indexes returned by the LLM.
        total_count: Number of prompts that were sent for grouping.

    Returns:
        Missing indexes in ascending order.
    """
    grouped = {idx for group in grouped_indexes for idx in group}
    return sorted(set(range(1, total_count + 1)) - grouped)

def get_grouped_indexes_from_llm(llm_prompt: str, client, max_retries) -> list[list[int]]:
    """Ask the LLM to group prompt pairs and parse its answer.

    The response must contain a fenced code block (```python, ```json, or a
    bare ```) holding a Python literal of the form list[list[int]]. Groups
    written as tuples are converted to lists. A response without a code block,
    with an unparsable literal, or with the wrong shape is treated as a
    failure. Any exception raised by ``client.get_llm_response`` is also
    treated as a failure. Each failure triggers another attempt.

    Args:
        llm_prompt: Full prompt text sent to the LLM.
        client: Object exposing ``get_llm_response(prompt=...) -> str``.
        max_retries: Maximum number of attempts; must be at least 1.

    Returns:
        The parsed grouping as a list of lists of ints.

    Raises:
        ValueError: If ``max_retries`` < 1, or if no attempt produced a valid
            grouping (the last error is chained as the cause).
    """
    if max_retries < 1:
        raise ValueError("max_retries must be at least 1")
    print("Sending Prompt: ", llm_prompt)
    last_error = None
    for attempt in range(max_retries):
        try:
            response = client.get_llm_response(prompt=llm_prompt)
            print(response)
            # Optional language tag, so ```python, ```json and bare ``` all parse.
            m = re.search(r'```[\w+-]*\s*([\s\S]*?)\s*```', response)
            if not m:
                raise ValueError("No fenced ``` ... ``` code block found")
            parsed = ast.literal_eval(m.group(1))
            if not isinstance(parsed, (list, tuple)):
                raise ValueError(f"Expected list[list[int]], got: {parsed!r}")
            grouped_indexes: list[list[int]] = []
            for group in parsed:
                # bool is a subclass of int; reject it explicitly.
                if not isinstance(group, (list, tuple)) or not all(
                    isinstance(idx, int) and not isinstance(idx, bool) for idx in group
                ):
                    raise ValueError(f"Expected list[list[int]], got: {parsed!r}")
                grouped_indexes.append(list(group))
            return grouped_indexes
        except Exception as e:
            last_error = e
            print(f"Error in LLM response (attempt {attempt + 1}/{max_retries}): {e}")
    raise ValueError(
        "Failed to get a valid response from the LLM after multiple retries."
    ) from last_error
    

In [4]:
import util

NUMBER_OF_PROMPTS_TO_GROUP = 30
CROWDING_ITERATIONS = 20
MAX_RETRIES = 5
client = util.LLMClient(provider='gemini')


def _format_pairs(pairs_with_scores, labels) -> str:
    """Render (pair, score) items as numbered lines, labelling each with the matching entry of `labels`."""
    return "\n".join(
        f"{label}. ('{pair[0]}' , '{pair[1]}')"
        for label, (pair, _score) in zip(labels, pairs_with_scores)
    )


def _clean_groups(grouped_indexes, total_count: int) -> list[list[int]]:
    """Keep only valid, not-yet-used 1-based indexes from the LLM's grouping.

    Non-int indexes, indexes outside 1..total_count, and indexes already used
    by an earlier group are dropped. Without this, an index of 0 would
    silently select the last pair via prompt_pairs[-1], and an index above
    total_count would raise IndexError. Groups left empty are removed.
    """
    seen: set[int] = set()
    cleaned: list[list[int]] = []
    for group in grouped_indexes:
        members = group if isinstance(group, (list, tuple)) else [group]
        kept = []
        for idx in members:
            is_valid_int = isinstance(idx, int) and not isinstance(idx, bool)
            if is_valid_int and 1 <= idx <= total_count and idx not in seen:
                seen.add(idx)
                kept.append(idx)
        if kept:
            cleaned.append(kept)
    return cleaned


deleted_num = 0
for iteration in range(CROWDING_ITERATIONS):
    print(f"=== Iteration {iteration+1} of {CROWDING_ITERATIONS} ===")
    # Retrieve the best prompt pairs; fewer than requested may come back once the queue runs low.
    prompt_pairs = pq.get_best_n(n=NUMBER_OF_PROMPTS_TO_GROUP)
    batch_size = len(prompt_pairs)
    if batch_size < 2:
        print("Fewer than two prompts left in the queue; nothing to group.")
        break
    prompt_pairs_str = _format_pairs(prompt_pairs, range(1, batch_size + 1))

    grouped_indexes = _clean_groups(
        get_grouped_indexes_from_llm(
            llm_prompt=group_prompt.format(prompt_pairs_str=prompt_pairs_str, num_of_prompts=batch_size),
            client=client,
            max_retries=MAX_RETRIES,
        ),
        batch_size,
    )

    remaining_indexes = sorted(get_remaining_indexes(grouped_indexes, batch_size))
    if remaining_indexes:
        print(f"Remaining indexes: {remaining_indexes}")
        retry_prompt_str = retry_prompt.format(
            prompt_pairs_str=prompt_pairs_str,
            current_grouped_indexes=str(grouped_indexes),
            prompt_pairs_str_remaining=_format_pairs(
                [prompt_pairs[idx - 1] for idx in remaining_indexes], remaining_indexes
            ),
        )
        # The retry answer replaces the first grouping entirely.
        grouped_indexes = _clean_groups(
            get_grouped_indexes_from_llm(
                llm_prompt=retry_prompt_str,
                client=client,
                max_retries=MAX_RETRIES,
            ),
            batch_size,
        )
        # Pairs the retry still omitted are kept as singletons: the whole batch is
        # deleted below, so an omitted pair would otherwise be lost without ever
        # being judged a duplicate.
        grouped_indexes += [
            [idx] for idx in sorted(get_remaining_indexes(grouped_indexes, batch_size))
        ]

    unique_indexes = get_unique_indexes(grouped_indexes)
    print(f"Unique indexes: {unique_indexes}")

    # Select the best prompts based on the unique (1-based) indexes.
    best_prompt_pairs_with_scores = [prompt_pairs[idx - 1] for idx in unique_indexes]

    # Delete the batch from the priority queue...
    print(f"Length of PQ before deletion: {len(pq)}")
    pq.delete_top_n(batch_size)
    print(f"Length of PQ after deletion: {len(pq)}")

    # ...and add the representatives back.
    for prompt_pair, score in best_prompt_pairs_with_scores:
        pq.insert(prompt_pair, score)

    deleted_num += batch_size - len(unique_indexes)
    print(f"Iteration {iteration+1} completed. Deleted {deleted_num} duplicate prompts so far.")

    

=== Iteration 1 of 20 ===
Sending Prompt:  The task is to group textual description pairs of visual discriminative features for tumor detection in histopathology. 
Current Prompt Pairs:
1. ('No atypical cells infiltrating surrounding tissues' , 'Atypical cells infiltrating surrounding tissues and disrupting normal structures')
2. ('No significant atypia in the surrounding lymphocytes' , 'Significant atypia observed in lymphocytes adjacent to tumor nests')
3. ('No evidence of fibrosis' , 'Prominent stromal fibrosis surrounding tumor nests')
4. ('Normal follicular architecture is preserved' , 'Disrupted follicular architecture with loss of polarity')
5. ('No prominent nucleoli are observed in lymphocytes' , 'Cells exhibit large, prominent, and irregular nucleoli')
6. ('No giant cells or multinucleated cells are seen' , 'Presence of multinucleated giant cells, suggestive of specific tumor types')
7. ('No plasmacytoid differentiation is observed' , 'Plasmacytoid differentiation is prominen

KeyboardInterrupt: 

In [5]:
# Keep only the deduplicated representatives chosen in the last crowding
# iteration: empty the queue, then re-insert them.
# NOTE(review): this relies on state from the crowding-loop cell; if that loop
# was interrupted, the representatives may come from an earlier run.
if "best_prompt_pairs_with_scores" not in globals():
    raise RuntimeError("Run the crowding loop cell before this cell.")

pq.delete_top_n(pq.max_capacity)  # Clear the queue at the end
for prompt_pair, score in best_prompt_pairs_with_scores:
    pq.insert(prompt_pair, score)  # Reinsert the best prompts into the queue

prompts_after_crowding_with_scores = pq.get_best_n(n=NUMBER_OF_PROMPTS_TO_GROUP)  # Get the final best prompts
for prompt_pair, score in prompts_after_crowding_with_scores:
    print(f"{prompt_pair}, Score: {score}")

('No atypical cells infiltrating surrounding tissues', 'Atypical cells infiltrating surrounding tissues and disrupting normal structures'), Score: 0.9013
('No significant atypia in the surrounding lymphocytes', 'Significant atypia observed in lymphocytes adjacent to tumor nests'), Score: 0.8997
('No evidence of fibrosis', 'Prominent stromal fibrosis surrounding tumor nests'), Score: 0.8994
('Normal follicular architecture is preserved', 'Disrupted follicular architecture with loss of polarity'), Score: 0.894
('No prominent nucleoli are observed in lymphocytes', 'Cells exhibit large, prominent, and irregular nucleoli'), Score: 0.8935
('No giant cells or multinucleated cells are seen', 'Presence of multinucleated giant cells, suggestive of specific tumor types'), Score: 0.8884
('No plasmacytoid differentiation is observed', 'Plasmacytoid differentiation is prominent within the tumor cells'), Score: 0.8883
('Interfollicular areas show a normal complement of T cells', 'Interfollicular area