# <a id='toc1_'></a>[Demo of the LLM Label Retrieval](#toc0_)

**Table of contents**<a id='toc0_'></a>    
- [Demo of the LLM Label Retrieval](#toc1_)    
  - [Imports](#toc1_1_)    
  - [BooleanQuery](#toc1_2_)    
  - [DSPy Signatures](#toc1_3_)    
  - [QueryToLabelsTranslator Module](#toc1_4_)    
  - [Paper Filter Module](#toc1_5_)    
  - [AdvancedQueryTranslator](#toc1_6_)    
  - [Testing the AdvancedQueryTranslator](#toc1_7_)    


## <a id='toc1_1_'></a>[Imports](#toc0_)

In [1]:
import dspy
from typing import List, Dict, Set
from dataclasses import dataclass
import re

## <a id='toc1_2_'></a>[BooleanQuery](#toc0_)

In [2]:
@dataclass
class BooleanQuery:
    """Represents a boolean query over labels"""
    must_have: Set[str]      # AND conditions
    should_have: Set[str]    # OR conditions
    must_not_have: Set[str]  # NOT conditions
    confidence: float = 1.0

    def __str__(self):
        parts = []
        if self.must_have:
            parts.append(f"MUST: {list(self.must_have)}")
        if self.should_have:
            parts.append(f"SHOULD: {list(self.should_have)}")
        if self.must_not_have:
            parts.append(f"NOT: {list(self.must_not_have)}")
        return " | ".join(parts)
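
A quick way to see what `__str__` produces is to build a query by hand. The sketch below repeats the dataclass so it runs on its own, with two changes that are assumptions of this example and not part of the class above: the set fields default to empty sets, and labels are sorted before printing so the output is deterministic. The label names are taken from the sample output later in this notebook.

```python
from dataclasses import dataclass, field
from typing import Set


@dataclass
class BooleanQuery:
    """Standalone copy of the query container, with empty-set defaults (assumption)."""
    must_have: Set[str] = field(default_factory=set)      # AND conditions
    should_have: Set[str] = field(default_factory=set)    # OR conditions
    must_not_have: Set[str] = field(default_factory=set)  # NOT conditions
    confidence: float = 1.0

    def __str__(self) -> str:
        parts = []
        if self.must_have:
            # sorted() keeps the rendering stable; plain sets have no fixed order
            parts.append(f"MUST: {sorted(self.must_have)}")
        if self.should_have:
            parts.append(f"SHOULD: {sorted(self.should_have)}")
        if self.must_not_have:
            parts.append(f"NOT: {sorted(self.must_not_have)}")
        return " | ".join(parts)


query = BooleanQuery(
    must_have={"ALS"},
    should_have={"Drug Discovery", "Small Molecule Therapies"},
    must_not_have={"Animal Model"},
)
rendered = str(query)
print(rendered)
```

An empty query renders as an empty string, which is worth checking for before passing it to a filter.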

## <a id='toc1_3_'></a>[DSPy Signatures](#toc0_)

In [3]:
class QueryParser(dspy.Signature):
    """Parse natural language query into structured components using the available labels. Think broadly about the different clusters of labels."""

    query = dspy.InputField(desc="Natural language query from user")
    available_labels = dspy.InputField(
        desc="List of available labels in the database which will be used to match concepts in the query to actual labels (comma-separated)")
    
    main_concepts = dspy.OutputField(
        desc="Primary concepts the user is looking for (comma-separated)")
    required_concepts = dspy.OutputField(
        desc="Concepts that MUST be present (comma-separated, empty if none)")
    optional_concepts = dspy.OutputField(
        desc="Concepts that SHOULD be present but not required (comma-separated, empty if none)")
    excluded_concepts = dspy.OutputField(
        desc="Concepts that must NOT be present (comma-separated, empty if none)")

In [4]:
class LabelMatcher(dspy.Signature):
    """Match parsed concepts to actual database labels"""

    concept = dspy.InputField(desc="A concept extracted from user query")
    available_labels = dspy.InputField(desc="List of available labels with their usage counts")

    matched_labels = dspy.OutputField(
        desc="Best matching labels for this concept (comma-separated). MUST be from available_labels list only. Prefer labels with higher usage counts when multiple good matches exist.")
    confidence = dspy.OutputField(desc="Confidence score 0-1 for the matches")
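
Both signatures declare their fields without type annotations, so every output comes back as a string. The helpers below are a minimal sketch of how those strings might be normalized before use; `split_csv_field` and `parse_confidence` are hypothetical names introduced for illustration and are not part of DSPy or of the modules in this notebook.

```python
from typing import List


def split_csv_field(raw: str) -> List[str]:
    """Split a comma-separated output field, dropping blanks and duplicates."""
    items: List[str] = []
    for item in raw.split(","):
        item = item.strip()
        if item and item not in items:
            items.append(item)
    return items


def parse_confidence(raw: str, default: float = 0.0) -> float:
    """Turn a free-text confidence such as '0.8' into a float clamped to [0, 1]."""
    try:
        value = float(raw.strip())
    except (AttributeError, ValueError):
        # The model answered with something non-numeric, e.g. "high"
        return default
    return min(1.0, max(0.0, value))


print(split_csv_field(" ALS, Drug Discovery, ,ALS "))
print(parse_confidence("0.85"), parse_confidence("high"), parse_confidence("1.7"))
```

Note that this kind of splitting assumes no label contains a comma of its own.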

## <a id='toc1_4_'></a>[QueryToLabelsTranslator Module](#toc0_)

In [5]:
class QueryToLabelsTranslator(dspy.Module):
    """Main translator from natural language to boolean label queries with DSPy Refine validation"""

    def __init__(self, available_labels: List[str], label_counts: Dict[str, int], max_retries: int = 3):
        super().__init__()
        self.available_labels = available_labels
        self.label_set = set(available_labels)
        self.label_counts = label_counts
        self.max_retries = max_retries

        # Create formatted label list with counts for LLM context
        self.labels_with_counts = self._format_labels_with_counts()

        # DSPy modules
        self.query_parser = dspy.ChainOfThought(QueryParser)
        
        # Create reward function for label validation (correct signature: args, pred)
        def label_validation_reward(args, pred: dspy.Prediction) -> float:
            """Reward function that returns 1.0 if all labels are valid, 0.0 otherwise"""
            try:
                matched_labels = []
                if hasattr(pred, 'matched_labels') and pred.matched_labels.strip():
                    matched_labels = [label.strip() for label in pred.matched_labels.split(',') if label.strip()]
                
                # Check if all labels are in available_labels
                valid_labels = all(label in self.label_set for label in matched_labels)
                return 1.0 if valid_labels else 0.0
            except Exception:  # avoid a bare except, which would also swallow KeyboardInterrupt
                return 0.0
        
        # Wrap LabelMatcher with Refine for automatic feedback-based retries
        self.label_matcher = dspy.Refine(
            dspy.ChainOfThought(LabelMatcher),
            N=max_retries,
            reward_fn=label_validation_reward,
            threshold=1.0  # Stop when we get perfect validation (reward = 1.0)
        )

    def _format_labels_with_counts(self) -> str:
        """Format labels with their counts for LLM context"""
        # Sort labels by count (descending) to show most popular first
        sorted_labels = sorted(self.label_counts.items(), key=lambda x: x[1], reverse=True)
        
        # Format as "label (count=X)" 
        formatted_labels = [f"{label} (count={count})" for label, count in sorted_labels]
        return ", ".join(formatted_labels)

    def forward(self, query: str) -> BooleanQuery:
        """Translate natural language query to boolean label query"""

        # Step 1: Parse the natural language query
        parsed = self.query_parser(
            query=query,
            available_labels=self.labels_with_counts
        )

        # Step 2: Match concepts to actual labels
        boolean_query = BooleanQuery(
            must_have=set(),
            should_have=set(),
            must_not_have=set()
        )

        # Process main concepts (these become SHOULD conditions)
        if parsed.main_concepts.strip():
            main_labels = self._match_concepts_to_labels(
                parsed.main_concepts.split(',')
            )
            boolean_query.should_have.update(main_labels)

        # Process required concepts (these become MUST conditions)
        if parsed.required_concepts.strip():
            required_labels = self._match_concepts_to_labels(
                parsed.required_concepts.split(',')
            )
            boolean_query.must_have.update(required_labels)

        # Process optional concepts (these become additional SHOULD conditions)
        if parsed.optional_concepts.strip():
            optional_labels = self._match_concepts_to_labels(
                parsed.optional_concepts.split(',')
            )
            boolean_query.should_have.update(optional_labels)

        # Process excluded concepts (these become NOT conditions)
        if parsed.excluded_concepts.strip():
            excluded_labels = self._match_concepts_to_labels(
                parsed.excluded_concepts.split(',')
            )
            boolean_query.must_not_have.update(excluded_labels)

        return boolean_query

    def _match_concepts_to_labels(self, concepts: List[str]) -> Set[str]:
        """Match a list of concepts to actual database labels using Refine validation"""
        matched_labels = set()

        for concept in concepts:
            concept = concept.strip()
            if not concept:
                continue

            try:
                # Use Refine wrapped LabelMatcher - it will automatically provide feedback and retry
                matches = self.label_matcher(
                    concept=concept,
                    available_labels=self.labels_with_counts
                )

                # Parse matched labels - Refine retries until the reward threshold is met or
                # attempts run out, so the returned prediction can still contain invalid labels
                if matches.matched_labels.strip():
                    for label in matches.matched_labels.split(','):
                        label = label.strip()
                        if label and label in self.label_set:  # Double-check validity
                            matched_labels.add(label)

            except Exception as e:
                # Handle any errors (network, JSON parsing, etc.)
                print(f"Warning: Could not match concept '{concept}' to valid labels: {e}")
                continue

        return matched_labels
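
The validation logic can be checked without calling an LLM. The sketch below is an illustration under stated assumptions: it pulls the label formatting and the reward check out into plain functions (`format_labels_with_counts` and `make_label_reward` are names chosen for this example) and uses `types.SimpleNamespace` as a stand-in for a `dspy.Prediction`.

```python
from types import SimpleNamespace
from typing import Callable, Dict, Set


def format_labels_with_counts(label_counts: Dict[str, int]) -> str:
    """Render labels as 'label (count=N)', most frequent first."""
    ordered = sorted(label_counts.items(), key=lambda kv: kv[1], reverse=True)
    return ", ".join(f"{label} (count={count})" for label, count in ordered)


def make_label_reward(label_set: Set[str]) -> Callable[[dict, object], float]:
    """Build a reward function with the (args, pred) shape used above."""
    def reward(args: dict, pred: object) -> float:
        raw = getattr(pred, "matched_labels", "") or ""
        labels = [part.strip() for part in raw.split(",") if part.strip()]
        # 1.0 only when every returned label exists in the database
        return 1.0 if all(label in label_set for label in labels) else 0.0
    return reward


counts = {"Drug Discovery": 258, "Machine Learning In Chemistry": 441, "ALS": 221}
reward = make_label_reward(set(counts))

print(format_labels_with_counts(counts))
print(reward({}, SimpleNamespace(matched_labels="ALS, Drug Discovery")))
print(reward({}, SimpleNamespace(matched_labels="ALS, Alzheimer's Disease")))
print(reward({}, SimpleNamespace(matched_labels="ALS (count=221)")))
print(reward({}, SimpleNamespace(matched_labels="")))
```

Two behaviours are worth noticing. A label echoed back with its count suffix fails validation, because the label set holds bare names. An empty answer scores 1.0, since there is nothing invalid in it, so a perfect reward does not imply that any label was found.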

## <a id='toc1_5_'></a>[Paper Filter Module](#toc0_)

In [6]:
class PaperFilter:
    """Filter papers based on boolean label queries"""

    def __init__(self, papers_db: List[Dict]):
        """
        papers_db: List of dicts with keys: 'id', 'title', 'abstract', 'labels'
        """
        self.papers_db = papers_db

    def filter_papers(self, boolean_query: BooleanQuery, min_should_match: int = 1) -> List[Dict]:
        """Filter papers based on boolean query"""

        filtered_papers = []

        for paper in self.papers_db:
            paper_labels = set(paper.get('labels', []))

            # Check MUST conditions
            if boolean_query.must_have:
                if not boolean_query.must_have.issubset(paper_labels):
                    continue

            # Check NOT conditions
            if boolean_query.must_not_have:
                if boolean_query.must_not_have.intersection(paper_labels):
                    continue

            # Check SHOULD conditions
            if boolean_query.should_have:
                should_matches = len(
                    boolean_query.should_have.intersection(paper_labels))
                if should_matches < min_should_match:
                    continue

                # Relevance score: fraction of SHOULD labels present on the paper
                relevance_score = should_matches / len(boolean_query.should_have)
            else:
                relevance_score = 1.0

            # Append a copy so scores from one query never leak into papers_db
            filtered_papers.append({**paper, 'relevance_score': relevance_score})

        # Sort by relevance score
        filtered_papers.sort(key=lambda x: x['relevance_score'], reverse=True)

        return filtered_papers
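
The filtering rules are easier to follow on a few hand-made records. The function below is a standalone sketch of the same MUST / NOT / SHOULD checks, written for illustration; `filter_by_labels` and the four toy papers are assumptions of this example. It attaches the score to a copy of each record, so the input list is left untouched.

```python
from typing import Dict, List, Set


def filter_by_labels(
    papers: List[Dict],
    must: Set[str],
    should: Set[str],
    must_not: Set[str],
    min_should_match: int = 1,
) -> List[Dict]:
    """Return scored copies of the papers that satisfy the boolean query."""
    results = []
    for paper in papers:
        labels = set(paper.get("labels", []))
        if not must <= labels:      # every MUST label has to be present
            continue
        if must_not & labels:       # any NOT label disqualifies the paper
            continue
        if should:
            hits = len(should & labels)
            if hits < min_should_match:
                continue
            score = hits / len(should)
        else:
            score = 1.0
        results.append({**paper, "relevance_score": score})
    results.sort(key=lambda p: p["relevance_score"], reverse=True)
    return results


papers = [
    {"id": 1, "labels": ["ALS", "Drug Discovery"]},
    {"id": 2, "labels": ["ALS", "Animal Model", "Drug Discovery"]},
    {"id": 3, "labels": ["ALS", "Drug Discovery", "Small Molecule Therapies"]},
    {"id": 4, "labels": ["Drug Discovery"]},
]
hits = filter_by_labels(
    papers,
    must={"ALS"},
    should={"Drug Discovery", "Small Molecule Therapies"},
    must_not={"Animal Model"},
)
for paper in hits:
    print(paper["id"], paper["relevance_score"])
```

Paper 2 is dropped by the NOT condition and paper 4 by the MUST condition; paper 3 ranks first because it matches both SHOULD labels.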

## <a id='toc1_6_'></a>[AdvancedQueryTranslator](#toc0_)

In [7]:
class AdvancedQueryTranslator(dspy.Module):
    """Enhanced translator that handles complex boolean logic with Refine validation"""

    def __init__(self, available_labels: List[str], label_counts: Dict[str, int], max_retries: int = 3):
        super().__init__()
        self.basic_translator = QueryToLabelsTranslator(available_labels, label_counts, max_retries)
        self.available_labels = available_labels
        self.label_counts = label_counts
        self.max_retries = max_retries

    def forward(self, query: str, force_basic: bool = False) -> BooleanQuery:
        """Handle complex queries with explicit boolean logic"""

        # Check for explicit boolean operators
        if not force_basic and self._has_explicit_boolean_logic(query):
            return self._parse_explicit_boolean_query(query)
        else:
            return self.basic_translator(query)

    def _has_explicit_boolean_logic(self, query: str) -> bool:
        """Check if query contains explicit AND, OR, NOT operators"""
        # Match whole words only: a substring test would treat the "or" in
        # "for" or "Transformers" as an OR operator
        pattern = r'\b(and|or|not|but not|except|without)\b'
        return re.search(pattern, query, re.IGNORECASE) is not None

    def _parse_explicit_boolean_query(self, query: str) -> BooleanQuery:
        """Parse queries with explicit boolean logic"""

        # This is a simplified parser - you could make it much more sophisticated
        query_lower = query.lower()

        # Split on boolean operators
        parts = re.split(
            r'\b(and|or|not|but not|except|without)\b', query_lower)

        boolean_query = BooleanQuery(set(), set(), set())
        current_mode = 'should'  # Default mode

        for part in parts:
            part = part.strip()
            if not part:
                continue

            if part in ['and']:
                current_mode = 'must'
            elif part in ['or']:
                current_mode = 'should'
            elif part in ['not', 'but not', 'except', 'without']:
                current_mode = 'not'
            else:
                # Extract labels from this part
                part_labels = self._extract_labels_from_text(part)

                if current_mode == 'must':
                    boolean_query.must_have.update(part_labels)
                elif current_mode == 'should':
                    boolean_query.should_have.update(part_labels)
                elif current_mode == 'not':
                    boolean_query.must_not_have.update(part_labels)

        return boolean_query

    def _extract_labels_from_text(self, text: str) -> Set[str]:
        """Extract labels from a piece of text"""
        # Use the basic translator on just this piece
        sub_query = self.basic_translator(text)
        return sub_query.should_have.union(sub_query.must_have)
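
The operator handling can be exercised without an LLM. The sketch below is an illustration, not the notebook's implementation: it compiles one whole-word pattern and returns `(mode, text)` pairs instead of resolving labels. `OPERATOR_PATTERN`, `has_explicit_boolean_logic` and `split_on_operators` are names chosen for this example. Whole-word matching matters here because a substring test would treat the "or" inside "Transformers" as an operator.

```python
import re
from typing import List, Tuple

# Whole-word, case-insensitive operators; the capture group keeps them in re.split output
OPERATOR_PATTERN = re.compile(r"\b(but not|and|or|not|except|without)\b", re.IGNORECASE)

MODE_FOR_OPERATOR = {
    "and": "must",
    "or": "should",
    "not": "not",
    "but not": "not",
    "except": "not",
    "without": "not",
}


def has_explicit_boolean_logic(query: str) -> bool:
    """True when the query contains an operator as a standalone word."""
    return OPERATOR_PATTERN.search(query) is not None


def split_on_operators(query: str) -> List[Tuple[str, str]]:
    """Split a query into (mode, text) pairs; text before any operator is 'should'."""
    mode = "should"
    pairs: List[Tuple[str, str]] = []
    for part in OPERATOR_PATTERN.split(query):
        part = part.strip()
        if not part:
            continue
        key = part.lower()
        if key in MODE_FOR_OPERATOR:
            mode = MODE_FOR_OPERATOR[key]   # operator: switch mode for what follows
        else:
            pairs.append((mode, part))      # operand: record it under the current mode
    return pairs


print(has_explicit_boolean_logic("Transformers for chemistry"))
print(split_on_operators("Machine learning but not computational chemistry"))
print(split_on_operators("Deep learning AND molecular dynamics"))
```

As in the parser above, an operator only affects the text that follows it, so in "Deep learning AND molecular dynamics" the first operand stays in SHOULD mode and only the second becomes a MUST condition.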

## <a id='toc1_7_'></a>[Testing the AdvancedQueryTranslator](#toc0_)

In [14]:
import pandas as pd
from collections import Counter
import os

# 🔑 LLM backend: point DSPy to your provider (OpenAI shown here)
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
if not OPENAI_API_KEY:
    raise ValueError("Please set the OPENAI_API_KEY environment variable.")
turbo = dspy.LM(model="gpt-4o-mini", temperature=0, max_tokens=16384, api_key=OPENAI_API_KEY)
dspy.settings.configure(lm=turbo)


# Load the Acelot Library CSV
print("Loading Acelot Library data...")
df = pd.read_csv("Acelot Library.csv")
print(f"Loaded {len(df)} papers from Acelot Library")

# Extract all unique tags from the Tags column to build available_labels
print("\nExtracting available labels from Tags column...")
all_tags = []
for idx, row in df.iterrows():
    tags = row.get('Tags', '')
    if pd.notna(tags) and tags.strip():
        # Split tags by common delimiters (semicolon, comma, pipe)
        tag_list = []
        for delimiter in [';', ',', '|']:
            if delimiter in tags:
                tag_list = [tag.strip() for tag in tags.split(delimiter)]
                break
        else:
            # No delimiter found, treat as single tag
            tag_list = [tags.strip()]
        
        # Add all tags to the list (for counting)
        all_tags.extend([tag for tag in tag_list if tag])

# Count label occurrences
label_counts = Counter(all_tags)
available_labels = list(label_counts.keys())

print(f"Found {len(available_labels)} unique labels")
print("Sample labels with counts:")
for label, count in list(label_counts.most_common(10)):
    print(f"  {label}: {count} papers")

# Create papers_db from DataFrame
print("\nCreating papers database...")
papers_db = []
for idx, row in df.iterrows():
    # Extract tags for this paper
    tags = row.get('Tags', '')
    paper_labels = []
    if pd.notna(tags) and tags.strip():
        # Split tags by common delimiters
        for delimiter in [';', ',', '|']:
            if delimiter in tags:
                paper_labels = [tag.strip() for tag in tags.split(delimiter)]
                break
        else:
            paper_labels = [tags.strip()]
        
        # Clean up labels
        paper_labels = [tag for tag in paper_labels if tag]
    
    paper = {
        'id': idx,
        'title': row.get('Title', 'No Title'),
        'abstract': row.get('Abstract', 'No Abstract'),
        'labels': paper_labels,
        'library_url': row.get('Library URL', ''),
        'year': row.get('year', ''),
        'journal': row.get('Journal', ''),
        'authors': row.get('Author', '')
    }
    papers_db.append(paper)

print(f"Created papers database with {len(papers_db)} papers")

# Initialize translator and filter with Refine validation and label counts
print("\nInitializing translator with Refine validation and label count priors...")

# Create translator using Refine pattern with label counts
translator = AdvancedQueryTranslator(available_labels, label_counts, max_retries=3)
paper_filter = PaperFilter(papers_db)

# Show label distribution insights
print("\nLabel distribution insights:")
print(f"Most common labels: {dict(label_counts.most_common(5))}")
print(f"Least common labels: {dict(label_counts.most_common()[-5:])}")
print(f"Total label instances: {sum(label_counts.values())}")
print(f"Average labels per paper: {sum(label_counts.values()) / len(papers_db):.2f}")

# Test queries
test_queries = [
    "Find papers on the discovery of new alzheimer's drugs",
    # "Find papers on protein folding using machine learning",
    # "Drug discovery with graph neural networks",
    # "Machine learning but not computational chemistry",
    # "Deep learning AND molecular dynamics",
    # "Transformers OR neural networks for chemistry"
]

print("\n" + "=" * 60)
print("Query2Label with DSPy Refine + Label Count Priors - Acelot Library Test")
print("=" * 60)

for query in test_queries:
    print(f"\nQuery: {query}")
    try:
        boolean_query = translator(query, force_basic=True)
        print(f"Boolean Query: {boolean_query}")

        filtered_papers = paper_filter.filter_papers(boolean_query)
        print(f"Found {len(filtered_papers)} papers")
        
        # Show top 5 results
        for i, paper in enumerate(filtered_papers[:5]):
            print(f"  {i+1}. {paper['title'][:80]}..." if len(paper['title']) > 80 else f"  {i+1}. {paper['title']}")
            print(f"     Score: {paper.get('relevance_score', 0):.2f} | Labels: {paper['labels'][:3]}")
            if paper.get('library_url'):
                print(f"     URL: {paper['library_url']}")
            print()
        
        if len(filtered_papers) > 5:
            print(f"  ... and {len(filtered_papers) - 5} more papers")
            
    except Exception as e:
        print(f"Error processing query: {e}")
    print("-" * 40)

print(f"\nTotal papers in database: {len(papers_db)}")
print(f"Total available labels: {len(available_labels)}")
print("Top 10 most used labels:", [label for label, _ in label_counts.most_common(10)])

Loading Acelot Library data...
Loaded 1613 papers from Acelot Library

Extracting available labels from Tags column...
Found 420 unique labels
Sample labels with counts:
  Machine Learning In Chemistry: 441 papers
  In Silico Drug Discovery: 325 papers
  Drug Discovery: 258 papers
  Gain of Toxicity and Loss of Function: 251 papers
  Animal Model: 236 papers
  Reaction and Retrosynthesis Prediction: 228 papers
  Small Molecule Therapies: 225 papers
  ALS: 221 papers
  Protein Ligand Binding: 208 papers
  Misfolded Proteins: 205 papers

Creating papers database...
Created papers database with 1613 papers

Initializing translator with Refine validation and label count priors...

Label distribution insights:
Most common labels: {'Machine Learning In Chemistry': 441, 'In Silico Drug Discovery': 325, 'Drug Discovery': 258, 'Gain of Toxicity and Loss of Function': 251, 'Animal Model': 236}
Least common labels: {'AME': 1, '@shruti': 1, 'In Vitro Assay': 1, 'Irritable Bowel Disease': 1, 'autoi

In [17]:
turbo.inspect_history()

[2025-07-05T18:49:47.119022]

System message:

Your input fields are:
1. `query` (str): Natural language query from user
2. `available_labels` (str): List of available labels in the database which will be used to match concepts in the query to actual labels (comma-separated)
Your output fields are:
1. `reasoning` (str): 
2. `main_concepts` (str): Primary concepts the user is looking for (comma-separated)
3. `required_concepts` (str): Concepts that MUST be present (comma-separated, empty if none)
4. `optional_concepts` (str): Concepts that SHOULD be present but not required (comma-separated, empty if none)
5. `excluded_concepts` (str): Concepts that must NOT be present (comma-separated, empty if none)
All interactions will be structured in the following way, with the appropriate values filled in.

[[ ## query ## ]]
{query}

[[ ## available_labels ## ]]
{available_labels}

[[ ## reasoning ## ]]
{reasoning}

[[ ## main_concepts ## ]]
{main_concepts}

[[ ## required_c