In [None]:
import util

FEW_SHOT = 1

# Seed the priority queue with the prompt pairs saved by the Experiment-70 run
# for this few-shot setting (loaded via util.load_last_iteration_prompts).
experiment_prompts_path = (
    f"final_results/Experiment-70-strategy-inv-bce-gemma3-{FEW_SHOT}shot_opt_pairs.txt"
)
initial_prompts = util.load_last_iteration_prompts(experiment_prompts_path)
pq = util.PriorityQueue(
    max_capacity=1000,
    filter_threshold=0.6,
    initial=initial_prompts,
)

# Display the current top 100 (prompt_pair, score) entries.
pq.get_best_n(100)


[(('No infiltration of atypical cells.',
   'Infiltration of atypical cells disrupting normal lymph node architecture.'),
  0.9998),
 (('Cells arranged in a normal lymphatic pattern.',
   'Cells arranged in sheets or solid nests.'),
  0.9998),
 (('No plasma cells are identified.',
   'Numerous plasma cells with eccentric nuclei and abundant cytoplasm.'),
  0.9994),
 (('Few scattered reactive lymphocytes.',
   'Sheets of atypical lymphocytes with brisk proliferation.'),
  0.9992),
 (('No infiltration of atypical cells into surrounding tissues.',
   'Infiltration of atypical cells into surrounding tissues, disrupting normal structures.'),
  0.9991),
 (('Few reactive lymphocytes are present.',
   'Dense infiltrate of atypical lymphocytes and plasma cells.'),
  0.9989),
 (('Immunocytes are rare.',
   'Abundant immunocytes are present, forming rosettes.'),
  0.9988),
 (('Few reactive lymphocytes are present.',
   'Numerous reactive lymphocytes and plasma cells are observed.'),
  0.9987),
 (

In [35]:
import ast
import re

# --- Crowding configuration (defaults for CrowdingManager) ---
# Intended cadence, in optimisation iterations, between crowding runs.
# NOTE(review): CrowdingManager only stores this; confirm the caller uses it.
CROWDING_INTERVAL = 10
# Number of crowding passes made by one perform_crowding() call.
CROWDING_ITERATIONS = 3
# How many top prompt pairs are sent to the LLM for grouping per pass.
NUMBER_OF_PROMPTS_TO_GROUP = 30
# LLM attempts per request before giving up on a parseable answer.
MAX_RETRIES = 5

class CrowdingManager:
    """
    Encapsulates all crowding-related logic: grouping duplicate prompts via LLM
    and pruning the priority queue accordingly.

    One crowding pass takes the top ``group_size`` prompt pairs from the queue,
    asks the LLM to group pairs that state the same observation in different
    words, keeps the first index of every group, and replaces the top of the
    queue with those representatives.

    Args:
        client: object exposing ``get_llm_response(prompt=...) -> str``.
        interval: stored as ``self.interval``; not read by this class
            (presumably consumed by the optimisation loop -- verify).
        iterations: number of crowding passes run by ``perform_crowding``.
        group_size: number of top prompt pairs sent to the LLM per pass.
        max_retries: LLM attempts per request before giving up.
    """

    def __init__(self,
                 client,
                 interval: int = CROWDING_INTERVAL,
                 iterations: int = CROWDING_ITERATIONS,
                 group_size: int = NUMBER_OF_PROMPTS_TO_GROUP,
                 max_retries: int = MAX_RETRIES):
        self.client = client
        self.interval = interval
        self.iterations = iterations
        self.group_size = group_size
        self.max_retries = max_retries

        # Templates are filled with str.format in perform_crowding and sent
        # verbatim to the LLM.
        self.group_prompt = """The task is to group textual description pairs of visual discriminative features for tumor detection in histopathology. 
Current Prompt Pairs: Format: <Index. Prompt Pair>
{prompt_pairs_str}
Each pair corresponds to a feature of the same medical concept. Group the prompt pairs that has exactly same observation but differ only in language variations. Give the indexes of the grouped pairs in the output.
Provide the output as follows: list[list[index:int]]. Make sure to include all pairs in the output, even if they are not grouped with others.
Let's think step by step.
"""
        self.retry_prompt = """The task is to group textual description pairs of visual discriminative features for tumor detection in histopathology. 
Current Prompt Pairs:
{prompt_pairs_str}

You've already grouped some pairs, but there are still ungrouped pairs remaining.
Current Grouped indexes:
{current_grouped_indexes}

Remaining Prompt Pairs:
{prompt_pairs_str_remaining}

Provide the output as follows: list[list[index:int]]. Make sure to include all pairs in the output, even if they are not grouped with others.
Let's think step by step."""

    def _parse_grouped_indexes(self, text: str):
        """Safely evaluate a Python literal (expected: a list of groups) from ``text``."""
        return ast.literal_eval(text)

    @staticmethod
    def _normalize_grouped_indexes(parsed) -> list:
        """Validate an LLM grouping and return it as a list of ``int | list[int]``.

        Tuples are accepted and converted to lists, because the other helpers
        only recognise ``list`` groups.

        Raises:
            ValueError: if ``parsed`` is not a list/tuple of ints or int groups.
        """
        def is_index(value) -> bool:
            # bool is a subclass of int but is never a valid index
            return isinstance(value, int) and not isinstance(value, bool)

        if not isinstance(parsed, (list, tuple)):
            raise ValueError(
                f"Expected a list of groups, got {type(parsed).__name__}")
        normalized = []
        for item in parsed:
            if isinstance(item, (list, tuple)):
                group = list(item)
                if not all(is_index(member) for member in group):
                    raise ValueError(f"Group {item!r} contains non-integer indexes")
                normalized.append(group)
            elif is_index(item):
                normalized.append(item)
            else:
                raise ValueError(f"Invalid group entry {item!r}")
        return normalized

    @staticmethod
    def _extract_trailing_list(text: str):
        """Return the last balanced ``[...]`` substring of ``text``, or None."""
        end = text.rfind("]")
        if end == -1:
            return None
        depth = 0
        for pos in range(end, -1, -1):
            char = text[pos]
            if char == "]":
                depth += 1
            elif char == "[":
                depth -= 1
                if depth == 0:
                    return text[pos:end + 1]
        return None

    def _extract_grouped_indexes(self, response: str) -> list:
        """Pull the index grouping out of a raw LLM response.

        Fenced code blocks are tried last-to-first, because with chain-of-thought
        the final answer usually comes last. Within a block, the whole content is
        tried first, then its last ``[...]`` expression. If the response has no
        fenced block at all, the last ``[...]`` expression of the full response is
        used, since the prompts do not explicitly request a code block.

        Raises:
            ValueError: if no candidate parses as a list of index groups.
        """
        code_blocks = re.findall(r'```(?:[A-Za-z]+)?\s*([\s\S]*?)\s*```', response)
        if code_blocks:
            candidates = code_blocks[::-1]
        else:
            trailing = self._extract_trailing_list(response)
            candidates = [] if trailing is None else [trailing]
        if not candidates:
            raise ValueError("No code block found between triple backticks")

        last_error = None
        for candidate in candidates:
            for text in (candidate, self._extract_trailing_list(candidate)):
                if text is None:
                    continue
                try:
                    return self._normalize_grouped_indexes(
                        self._parse_grouped_indexes(text))
                except (ValueError, SyntaxError, TypeError, MemoryError, RecursionError) as exc:
                    last_error = exc
        raise ValueError(
            f"No valid list of index groups found in LLM response: {last_error}")

    def _get_unique_indexes(self, grouped_indexes: list[list[int]]) -> list[int]:
        """Pick one representative (the first index) per group.

        Bare ints are their own representative. Empty groups are skipped, and a
        representative already chosen by an earlier group is not repeated, so the
        same prompt pair is never re-inserted twice.
        """
        unique_indexes: list[int] = []
        seen: set = set()
        for group in grouped_indexes:
            if isinstance(group, int):
                representative = group
            elif group:
                representative = group[0]
            else:
                continue
            if representative not in seen:
                seen.add(representative)
                unique_indexes.append(representative)
        return unique_indexes

    def _flatten_grouped_list(self, grouped_indexes: list[list[int]]) -> list[int]:
        """Flatten a grouping into a flat list of indexes.

        The LLM sometimes mixes groups and bare ints, e.g. ``[[1, 2], [2, 3], 4, 5]``,
        so both nested lists/tuples and individual items are accepted.
        """
        flat_list = []
        for item in grouped_indexes:
            if isinstance(item, (list, tuple)):
                flat_list.extend(item)
            else:
                flat_list.append(item)
        return flat_list

    def _get_remaining_indexes(self, grouped_indexes: list[list[int]], total_count: int) -> list[int]:
        """Return, in ascending order, the 1-based indexes in ``1..total_count`` not mentioned in the grouping."""
        mentioned = set(self._flatten_grouped_list(grouped_indexes))
        return sorted(set(range(1, total_count + 1)) - mentioned)

    def _merge_retry_grouping(self, first_grouping, retry_grouping, total_count: int) -> list:
        """Combine the first and retry groupings so that no prompt pair is lost.

        The retry answer takes precedence. Groups from the first answer whose
        members the retry answer does not mention are carried over, in case the
        LLM only regrouped the remaining pairs. Any index still missing is
        appended as its own singleton, so it is kept instead of being deleted.
        """
        covered = set(self._flatten_grouped_list(retry_grouping))
        merged = list(retry_grouping)
        for group in first_grouping:
            members = self._flatten_grouped_list([group])
            if not covered.intersection(members):
                merged.append(group)
                covered.update(members)
        still_missing = self._get_remaining_indexes(merged, total_count)
        if still_missing:
            print(f"Warning: LLM omitted indexes {still_missing}; keeping them ungrouped.")
            merged.extend(still_missing)
        return merged

    def _get_grouped_indexes_from_llm(self, llm_prompt: str) -> list[list[int]]:
        """Query the LLM until it returns a parseable grouping.

        Returns:
            A list whose items are ints or lists of ints (1-based indexes into
            the prompt-pair listing sent in ``llm_prompt``).

        Raises:
            ValueError: if no attempt yields a valid grouping. At least one
                attempt is always made, and the last underlying error is chained.
        """
        print("Sending Prompt: ", llm_prompt)
        last_error = None
        for _attempt in range(max(1, self.max_retries)):
            try:
                response = self.client.get_llm_response(prompt=llm_prompt)
                print(response)
                return self._extract_grouped_indexes(response)
            except Exception as e:  # broad on purpose: client errors are retried too
                last_error = e
                print(f"Error in LLM response: {e}")
        raise ValueError(
            "Failed to get a valid response from the LLM after multiple retries.") from last_error

    def perform_crowding(self, pq: util.PriorityQueue) -> util.PriorityQueue:
        """Run ``self.iterations`` crowding passes over the top of ``pq``.

        Each pass replaces the top ``group_size`` entries with one representative
        per LLM-identified duplicate group. A pass whose representatives fall
        outside the listed range is skipped. Stops early if the queue is empty.

        Returns:
            The same ``pq`` object, mutated in place. It is left unchanged if no
            pass completed.
        """
        deleted_num = 0
        best_prompt_pairs_with_scores = None
        for iteration in range(self.iterations):
            print(f"=== Iteration {iteration+1} of {self.iterations} ===")
            # retrieve the best prompt pairs from the priority queue
            prompt_pairs = pq.get_best_n(n=self.group_size)
            if not prompt_pairs:
                print("Priority queue is empty; nothing to crowd.")
                break
            num_pairs = len(prompt_pairs)
            prompt_pairs_str = "\n".join(
                f"{idx}. ('{pair[0]}' , '{pair[1]}')"
                for idx, (pair, _score) in enumerate(prompt_pairs, start=1)
            )

            grouped_indexes = self._get_grouped_indexes_from_llm(
                llm_prompt=self.group_prompt.format(prompt_pairs_str=prompt_pairs_str),
            )

            # Ask again about pairs the first answer left out.
            remaining_indexes = self._get_remaining_indexes(grouped_indexes, num_pairs)
            if remaining_indexes:
                retry_prompt_str = self.retry_prompt.format(
                    prompt_pairs_str=prompt_pairs_str,
                    current_grouped_indexes=str(grouped_indexes),
                    prompt_pairs_str_remaining="\n".join(
                        f"{idx}. {prompt_pairs[idx - 1][0]}" for idx in remaining_indexes
                    ),
                )
                retry_grouped_indexes = self._get_grouped_indexes_from_llm(
                    llm_prompt=retry_prompt_str,
                )
                grouped_indexes = self._merge_retry_grouping(
                    grouped_indexes, retry_grouped_indexes, num_pairs)

            unique_indexes = self._get_unique_indexes(grouped_indexes)
            print(f"Unique indexes: {unique_indexes}")
            print("Debug: Length of prompt pairs before selecting best:", num_pairs)

            # if any representative index is out of range, skip this iteration
            if any(idx < 1 or idx > num_pairs for idx in unique_indexes):
                print("Warning: Some indexes are out of range. Skipping this iteration.")
                continue

            best_prompt_pairs_with_scores = [prompt_pairs[idx - 1] for idx in unique_indexes]

            # Replace the listed top entries with their representatives only.
            pq.delete_top_n(num_pairs)
            for prompt_pair, score in best_prompt_pairs_with_scores:
                pq.insert(prompt_pair, score)

            deleted_num += num_pairs - len(unique_indexes)
            print(
                f"Iteration {iteration+1} completed. Deleted {deleted_num} duplicate prompts so far.")

        if best_prompt_pairs_with_scores is None:
            print("Warning: no crowding pass completed; priority queue left unchanged.")
            return pq

        # Clear the queue and keep only the last pass's representatives.
        # NOTE(review): this discards every entry outside the last deduplicated
        # group -- confirm that is the intended final state.
        pq.delete_top_n(pq.max_capacity)
        for prompt_pair, score in best_prompt_pairs_with_scores:
            pq.insert(prompt_pair, score)
        return pq

In [36]:
import util
# The Gemini client is the LLM that CrowdingManager asks to group duplicates.
llm_client = util.LLMClient(provider="gemini")
crowding_manager = CrowdingManager(client=llm_client)

# perform_crowding mutates and returns the same queue object, so re-running
# this cell crowds the already-crowded queue again.
crowded_pq = crowding_manager.perform_crowding(pq=pq)
pq = crowded_pq

=== Iteration 1 of 3 ===
Sending Prompt:  The task is to group textual description pairs of visual discriminative features for tumor detection in histopathology. 
Current Prompt Pairs: Format: <Index. Prompt Pair>
1. ('No infiltration of atypical cells.' , 'Infiltration of atypical cells disrupting normal lymph node architecture.')
2. ('Cells arranged in a normal lymphatic pattern.' , 'Cells arranged in sheets or solid nests.')
3. ('No plasma cells are identified.' , 'Numerous plasma cells with eccentric nuclei and abundant cytoplasm.')
4. ('Few scattered reactive lymphocytes.' , 'Sheets of atypical lymphocytes with brisk proliferation.')
5. ('No infiltration of atypical cells into surrounding tissues.' , 'Infiltration of atypical cells into surrounding tissues, disrupting normal structures.')
6. ('Few reactive lymphocytes are present.' , 'Dense infiltrate of atypical lymphocytes and plasma cells.')
7. ('Immunocytes are rare.' , 'Abundant immunocytes are present, forming rosettes.')
8.

In [None]:
# TODO: persist this listing to final_results/crowded/{FEW_SHOT}-shot.txt
# (currently it is only printed).
for pair, score in pq.get_best_n(n=100):
    print(f"{pair}, Score: {score}")

('No infiltration of atypical cells.', 'Infiltration of atypical cells disrupting normal lymph node architecture.'), Score: 0.9998
('No plasma cells are identified.', 'Numerous plasma cells with eccentric nuclei and abundant cytoplasm.'), Score: 0.9994
('Immunocytes are rare.', 'Abundant immunocytes are present, forming rosettes.'), Score: 0.9988
('Mature lymphocytes.', 'Immunoblasts and plasmablasts.'), Score: 0.9987
('No emperipolesis is observed.', 'Frequent emperipolesis is seen within tumor cells.'), Score: 0.9986
('Stroma is edematous with minimal cellularity.', 'Stroma is densely cellular with a myxoid appearance.'), Score: 0.9984
('Inflammatory infiltrate is minimal and lymphocytic.', 'Dense inflammatory infiltrate with neutrophils and eosinophils.'), Score: 0.998
('Background shows normal lymphoid aggregates.', 'Background shows diffuse sheets of atypical cells.'), Score: 0.9973
('Few reactive germinal centers.', 'Numerous large and atypical germinal centers with prominent nuc

In [None]:
import util
# Collect the scores of every entry the queue can hold, rather than a
# hard-coded 1000 that silently goes stale if max_capacity changes.
scores = [score for (_, score) in pq.get_best_n(n=pq.max_capacity)]
knee_point_analyzer = util.KneePointAnalysis(scores)
knee_point = knee_point_analyzer.find_knee_point()

# knee_point is then used as the number of top prompt pairs to keep.
print(f"Knee point found at: {knee_point}")

# TODO: persist this listing to final_results/knee/{FEW_SHOT}-shot.txt
# (currently it is only printed).
for prompt_pair, score in pq.get_best_n(n=knee_point):
    print(f"{prompt_pair}, Score: {score}")

Knee point found at: 11
('No infiltration of atypical cells.', 'Infiltration of atypical cells disrupting normal lymph node architecture.'), Score: 0.9998
('No plasma cells are identified.', 'Numerous plasma cells with eccentric nuclei and abundant cytoplasm.'), Score: 0.9994
('Immunocytes are rare.', 'Abundant immunocytes are present, forming rosettes.'), Score: 0.9988
('Mature lymphocytes.', 'Immunoblasts and plasmablasts.'), Score: 0.9987
('No emperipolesis is observed.', 'Frequent emperipolesis is seen within tumor cells.'), Score: 0.9986
('Stroma is edematous with minimal cellularity.', 'Stroma is densely cellular with a myxoid appearance.'), Score: 0.9984
('Inflammatory infiltrate is minimal and lymphocytic.', 'Dense inflammatory infiltrate with neutrophils and eosinophils.'), Score: 0.998
('Background shows normal lymphoid aggregates.', 'Background shows diffuse sheets of atypical cells.'), Score: 0.9973
('Few reactive germinal centers.', 'Numerous large and atypical germinal ce

# Final Performance

In [None]:
# These names are used below (the module-level annotation is evaluated at run
# time) but are not imported elsewhere in this notebook, so import them here
# to keep the cell runnable on a fresh kernel.
from typing import List

import numpy as np
import torch

# 1. Load the CLIP model, its image preprocessing transform, and the tokenizer.
model, preprocess, tokenizer = util.load_clip_model()
print("Model, preprocess, and tokenizer loaded successfully.")

# 2. Extract per-center embeddings and labels (one array per center).
centers_features: List[np.ndarray]
centers_labels: List[np.ndarray]
centers_features, centers_labels = util.extract_center_embeddings(
    model=model,
    preprocess=preprocess,
    num_centers=5,  # evaluating on all centers
    isTrain=False,  # test split only
)
print("Center embeddings extracted successfully.")

# 3. Convert each center's arrays to torch tensors.
centers_features = [torch.from_numpy(feat) for feat in centers_features]
centers_labels = [torch.from_numpy(label) for label in centers_labels]
assert len(centers_features) == len(centers_labels), "features/labels center count mismatch"

# 4. Evaluate the knee-selected prompt ensemble on each center.
for i, (center_features, center_labels) in enumerate(zip(centers_features, centers_labels)):
    print(f"Evaluating center {i}...")
    results = util.evaluate_prompt_list(
        pq.get_best_n(n=knee_point),
        center_features,
        center_labels,
        model,
        tokenizer,
        unweighted=False,
    )

    # TODO: write these results to final_results/evaluation/{FEW_SHOT}-shot-results.txt
    print("\n--- Ensemble Evaluation Results ---")
    print(f"Accuracy: {results['accuracy']:.4f}")
    print(f"AUC: {results['auc']:.4f}")
    print("Confusion Matrix:\n", results['cm'])
    print("Classification Report:\n", results['report'])