<a href="https://colab.research.google.com/github/mishra-yogendra/DSPy-Practical-Assignment/blob/main/DSPy_Practical_Assignment.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# DSPy Entity Extraction & Knowledge Graph Generator
## Assignment: Structuring Unstructured Data with LLMs

### This notebook demonstrates:
1. Entity extraction from web content
2. Intelligent deduplication with confidence loops
3. Mermaid knowledge graph generation
4. CSV export of structured data


In [3]:
!pip install dspy
import dspy
from dspy import Predict, ChainOfThought
from pydantic import BaseModel, Field
from typing import List
import requests
from bs4 import BeautifulSoup
import pandas as pd
import re
import os

Collecting dspy
  Downloading dspy-3.0.3-py3-none-any.whl.metadata (7.2 kB)
Collecting backoff>=2.2 (from dspy)
  Downloading backoff-2.2.1-py3-none-any.whl.metadata (14 kB)
Collecting optuna>=3.4.0 (from dspy)
  Downloading optuna-4.5.0-py3-none-any.whl.metadata (17 kB)
Collecting magicattr>=0.1.6 (from dspy)
  Downloading magicattr-0.1.6-py2.py3-none-any.whl.metadata (3.2 kB)
Collecting litellm>=1.64.0 (from dspy)
  Downloading litellm-1.79.1-py3-none-any.whl.metadata (30 kB)
Collecting diskcache>=5.6.0 (from dspy)
  Downloading diskcache-5.6.3-py3-none-any.whl.metadata (20 kB)
Collecting json-repair>=0.30.0 (from dspy)
  Downloading json_repair-0.52.5-py3-none-any.whl.metadata (11 kB)
Collecting asyncer==0.0.8 (from dspy)
  Downloading asyncer-0.0.8-py3-none-any.whl.metadata (6.7 kB)
Collecting gepa==0.0.7 (from gepa[dspy]==0.0.7->dspy)
  Downloading gepa-0.0.7-py3-none-any.whl.metadata (22 kB)
Collecting fastuuid>=0.13.0 (from litellm>=1.64.0->dspy)
  Downloading fastuuid-0.14.0-cp

In [38]:
# CONFIGURATION

# Configure DSPy with Groq AI
import os

# Prefer a key that is already exported in the environment. The literal is only
# a placeholder: previously it was written into os.environ unconditionally,
# which replaced a real, already-configured GROQ_API_KEY with the placeholder.
API_KEY = os.environ.get("GROQ_API_KEY") or "YOUR_GROQ_API_KEY"
os.environ["GROQ_API_KEY"] = API_KEY

try:
    # Groq AI configuration - Fast inference with open source models
    lm = dspy.LM(
        model="groq/llama-3.3-70b-versatile",  # Best balance of speed and quality
        api_key=API_KEY,
        api_base="https://api.groq.com/openai/v1",
        temperature=0.3,  # Lower temperature for consistent extraction
        max_tokens=8000
    )
    dspy.configure(lm=lm)
    print("‚úì DSPy configured with Groq AI")
    print(f"  Model: llama-3.3-70b-versatile")
except Exception as e:
    # Best-effort: report the problem and let the later connectivity test decide.
    print(f"Configuration failed: {e}")
    print("Please check your Groq API key")


# Source pages fed through the pipeline, in processing order.
# The 1-based position of each URL is used to name its Mermaid output file.
URLS = [
    "https://en.wikipedia.org/wiki/Sustainable_agriculture",
    "https://www.nature.com/articles/d41586-025-03353-5",
    "https://www.sciencedirect.com/science/article/pii/S1043661820315152",
    "https://www.ncbi.nlm.nih.gov/pmc/articles/PMC10457221/",
    "https://www.fao.org/3/y4671e/y4671e06.htm",
    "https://www.medscape.com/viewarticle/time-reconsider-tramadol-chronic-pain-2025a1000ria",
    "https://www.sciencedirect.com/science/article/pii/S0378378220307088",
    "https://www.frontiersin.org/news/2025/09/01/rectangle-telescope-finding-habitable-planets",
    "https://www.medscape.com/viewarticle/second-dose-boosts-shingles-protection-adults-aged-65-years-2025a1000ro7",
    "https://www.theguardian.com/global-development/2025/oct/13/astro-ambassadors-stargazers-himalayas-hanle-ladakh-india",
]


# PYDANTIC MODELS FOR STRUCTURED OUTPUT

class EntityWithAttr(BaseModel):
    """Structured entity with semantic type"""
    # One extracted entity: its surface name plus a coarse semantic category.
    # Used as the element type of ExtractEntities.entities and carried through
    # deduplication, diagram generation and the CSV export.
    # NOTE(review): the docstring and Field descriptions are presumably exposed
    # to the model through the generated schema — treat them as runtime text
    # and confirm before rewording.

    # Entity name as it appears in the text; also the key for Mermaid node IDs.
    entity: str = Field(description="the named entity")
    # Free-form category label; written to the CSV as `tag_type`.
    attr_type: str = Field(description="semantic type (e.g. Drug, Disease, Concept, Method, Process)")


class EntityList(BaseModel):
    """List of entities extracted from text"""
    # Wrapper around a list of EntityWithAttr.
    # NOTE(review): not referenced elsewhere in the visible part of this
    # notebook (ExtractEntities declares List[EntityWithAttr] directly) —
    # confirm whether it is still needed.
    entities: List[EntityWithAttr]


class DeduplicatedEntities(BaseModel):
    """Deduplicated list of entities"""
    # Structured counterpart of the DeduplicateEntities signature output.
    # NOTE(review): not referenced elsewhere in the visible part of this
    # notebook; the signature returns List[str] plus a float instead — confirm
    # whether this model is still needed.
    deduplicated: List[EntityWithAttr]
    # Self-reported confidence; the description asks for a value in [0, 1] but
    # no validator enforces that range here.
    confidence: float = Field(description="confidence score 0-1")


class RelationTriple(BaseModel):
    """Relationship between two entities"""
    # A directed edge `source --relation--> target`.
    # generate_mermaid_diagram only draws a triple when both `source` and
    # `target` match (case-insensitively) an entity name it was given.
    source: str    # name of the entity the edge starts from
    relation: str  # edge label; truncated to 40 characters when drawn
    target: str    # name of the entity the edge points to


class RelationList(BaseModel):
    """List of relationships"""
    # Wrapper around a list of RelationTriple.
    # NOTE(review): not referenced elsewhere in the visible part of this
    # notebook (ExtractRelations declares List[RelationTriple] directly) —
    # confirm whether it is still needed.
    triples: List[RelationTriple]


# DSPY SIGNATURES

class ExtractEntities(dspy.Signature):
    """Extract named entities and their semantic types from text"""
    # NOTE(review): DSPy presumably sends the docstring above to the model as
    # the task instruction — treat it as runtime text, not as a comment.

    # Input: one chunk of page text, supplied by extract_entities_from_text.
    paragraph: str = dspy.InputField()
    # Output: the entities found in that chunk, typed as EntityWithAttr.
    entities: List[EntityWithAttr] = dspy.OutputField()


class DeduplicateEntities(dspy.Signature):
    """Deduplicate similar entities (e.g., 'PB IC', 'pea-barley intercrop' -> 1 entity)"""
    # NOTE(review): DSPy presumably sends the docstring above to the model as
    # the task instruction — treat it as runtime text, not as a comment.

    # Input: deduplicate_with_lm passes one "name (type)" entry per line.
    items: str = dspy.InputField(desc="list of entities to deduplicate")
    # Output: the entries the model decided to keep, as plain strings; the
    # caller maps them back onto the original EntityWithAttr objects.
    deduplicated: List[str] = dspy.OutputField()
    # Output: self-reported confidence, compared against `target_confidence`
    # (default 0.8) in deduplicate_with_lm to decide whether to retry.
    confidence: float = dspy.OutputField(desc="confidence score 0-1")


class ExtractRelations(dspy.Signature):
    """Extract relationships between entities"""
    # NOTE(review): DSPy presumably sends the docstring above to the model as
    # the task instruction — treat it as runtime text, not as a comment.

    # Input: the (truncated) page text, supplied by extract_relations.
    text: str = dspy.InputField()
    # Input: comma-separated entity names that triples are expected to use.
    entities: str = dspy.InputField(desc="list of valid entities")
    # Output: directed (source, relation, target) triples.
    triples: List[RelationTriple] = dspy.OutputField()


# UTILITY FUNCTIONS

def scrape_text(url: str, max_chars: int = 5000) -> str:
    """Fetch a web page and return its visible text as one normalised string.

    Args:
        url: Address of the page to download.
        max_chars: Upper bound on the length of the returned text.

    Returns:
        The page text with script/style/nav/footer/header elements removed and
        every whitespace run collapsed to a single space, cut to `max_chars`.
        Returns "" (after printing the error) if anything goes wrong.
    """
    try:
        # Some sites reject clients that do not look like a browser.
        browser_headers = {'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36'}
        page = requests.get(url, headers=browser_headers, timeout=10)
        page.raise_for_status()

        document = BeautifulSoup(page.content, 'html.parser')

        # Drop elements that carry code or page chrome rather than content.
        boilerplate_tags = ["script", "style", "nav", "footer", "header"]
        for element in document(boilerplate_tags):
            element.decompose()

        visible = document.get_text(separator=' ', strip=True)
        collapsed = re.sub(r'\s+', ' ', visible)
        return collapsed[:max_chars]
    except Exception as e:
        print(f"Error scraping {url}: {e}")
        return ""


def clean_entity_name(name: str) -> str:
    """Turn an entity name into an identifier usable as a Mermaid node ID.

    Characters other than ASCII letters, digits and whitespace are dropped,
    surrounding whitespace is trimmed, remaining whitespace runs become single
    underscores, and the result is capped at 50 characters.

    Args:
        name: Raw entity name (may contain punctuation or non-ASCII text).

    Returns:
        A non-empty identifier made of ASCII letters, digits and underscores.
        When nothing usable is left (e.g. "!!!" or a purely non-ASCII name), a
        deterministic ``entity_<crc32 hex>`` ID derived from the original name
        is returned instead of an empty string, which would produce an invalid
        Mermaid node.
    """
    # Strip after removing punctuation so "( foo )" does not become "_foo_".
    cleaned = re.sub(r'[^a-zA-Z0-9\s]', '', name).strip()
    cleaned = re.sub(r'\s+', '_', cleaned)[:50]  # Limit length
    if cleaned:
        return cleaned

    # Fallback: stable per-name ID so distinct unusable names stay distinct.
    import zlib
    checksum = zlib.crc32(name.encode('utf-8', errors='replace'))
    return f"entity_{checksum:08x}"


# MAIN PROCESSING FUNCTIONS

def extract_entities_from_text(
    text: str,
    chunk_size: int = 2000,
    max_chars: int = 6000,
    fallback_chars: int = 1500,
) -> List[EntityWithAttr]:
    """Extract entities from text chunk by chunk using DSPy predictors.

    The first `max_chars` characters of `text` are cut into consecutive slices
    of `chunk_size` characters. Each slice goes through a plain `Predict`; if
    that raises, the first `fallback_chars` characters of the slice are retried
    once with `ChainOfThought`. Chunks that fail both ways are skipped.

    Args:
        text: Source text to analyse.
        chunk_size: Characters per chunk (must be positive).
        max_chars: Only this many leading characters of `text` are processed.
        fallback_chars: Chunk prefix length used for the ChainOfThought retry.

    Returns:
        All entities found, in chunk order; duplicates are NOT removed here.

    Raises:
        ValueError: If `chunk_size` is not positive.
    """
    if chunk_size < 1:
        raise ValueError("chunk_size must be a positive integer")

    predictor = Predict(ExtractEntities)
    cot_predictor = None  # built lazily, and only once, on the first failure

    limit = min(len(text), max_chars)
    chunks = [text[start:start + chunk_size] for start in range(0, limit, chunk_size)]

    all_entities = []
    for idx, chunk in enumerate(chunks):
        try:
            print(f"Processing chunk {idx + 1}/{len(chunks)}...")
            result = predictor(paragraph=chunk)
            if hasattr(result, 'entities') and result.entities:
                all_entities.extend(result.entities)
                print(f"Found {len(result.entities)} entities in chunk {idx + 1}")
        except Exception as e:
            print(f"Error in chunk {idx + 1}: {str(e)[:100]}")
            # Retry on a shorter prefix with ChainOfThought before giving up.
            try:
                if cot_predictor is None:
                    cot_predictor = ChainOfThought(ExtractEntities)
                result = cot_predictor(paragraph=chunk[:fallback_chars])
                if hasattr(result, 'entities') and result.entities:
                    all_entities.extend(result.entities)
                    print(f"CoT found {len(result.entities)} entities in chunk {idx + 1}")
            except Exception as e2:
                # Best-effort: log and move on to the next chunk.
                print(f"CoT also failed: {str(e2)[:100]}")

    return all_entities


def _unique_entities(entities):
    """Drop repeated (name, type) pairs case-insensitively, keeping first occurrences."""
    seen = set()
    unique = []
    for e in entities:
        key = (e.entity.lower(), e.attr_type.lower())
        if key not in seen:
            seen.add(key)
            unique.append(e)
    return unique


def _kept_name_keys(kept_strings):
    """Lower-cased names the LM kept, each with and without one trailing '(type)' suffix."""
    keys = set()
    for raw in kept_strings or []:
        text = str(raw).strip().lower()
        if not text:
            continue
        keys.add(text)
        # The LM is shown "name (type)" entries and may echo that format back.
        keys.add(re.sub(r'\s*\([^()]*\)$', '', text).strip())
    keys.discard("")
    return keys


def deduplicate_with_lm(items: List[EntityWithAttr], target_confidence: float = 0.8) -> List[EntityWithAttr]:
    """Deduplicate entities with an LM, retrying until it reports enough confidence.

    The entities are sent to the LM as one "name (type)" entry per line. The
    answer is accepted when its confidence reaches `target_confidence`, or
    unconditionally on the third attempt. Accepted answers are mapped back
    onto the original objects by exact, case-insensitive name comparison
    (ignoring an echoed "(type)" suffix). The earlier substring test kept any
    entity whose name merely occurred inside a kept string, e.g. "cancer"
    inside "breast cancer".

    Falls back to plain exact-duplicate removal when every attempt raises, or
    when the accepted answer matches none of the inputs, so entities are never
    silently lost.

    Args:
        items: Entities to deduplicate.
        target_confidence: Minimum self-reported confidence for early acceptance.

    Returns:
        A list of the original entity objects with duplicates removed.
    """
    if not items:
        # Nothing to deduplicate; avoid a pointless LM call.
        return []

    predictor = Predict(DeduplicateEntities)

    # Convert entities to string format
    items_str = "\n".join([f"{e.entity} ({e.attr_type})" for e in items])

    # NOTE(review): every attempt sends the identical prompt; if the LM client
    # caches responses, the retries may return the same answer — confirm.
    max_attempts = 3
    for attempt in range(max_attempts):
        try:
            result = predictor(items=items_str)

            # A missing or unparsable confidence counts as "not confident".
            try:
                confidence = float(getattr(result, 'confidence', 0.5))
            except (TypeError, ValueError):
                confidence = 0.0

            if confidence >= target_confidence or attempt == max_attempts - 1:
                kept = _kept_name_keys(getattr(result, 'deduplicated', None))
                survivors = [e for e in items if e.entity.strip().lower() in kept]
                if survivors:
                    return _unique_entities(survivors)
                print("Deduplication output matched no input entities; using exact-match fallback")
                break
        except Exception as e:
            print(f"Deduplication attempt {attempt + 1} failed: {e}")

    # Fallback: manual deduplication
    return _unique_entities(items)


def extract_relations(text: str, entities: List[EntityWithAttr], max_chars: int = 2000) -> List[RelationTriple]:
    """Ask the LM for relationships between the given entities.

    Args:
        text: Source text; only its first `max_chars` characters are sent.
        entities: Entities whose names are offered to the LM as valid endpoints.
        max_chars: Length of the text prefix passed to the LM.

    Returns:
        The predicted triples. An empty list is returned when the call fails
        (the error is printed) or when the prediction carries no triples. It is
        never `None`, because callers take `len()` of the result.
    """
    predictor = Predict(ExtractRelations)

    entity_list = ", ".join(e.entity for e in entities)

    try:
        result = predictor(text=text[:max_chars], entities=entity_list)
    except Exception as e:
        print(f"Error extracting relations: {e}")
        return []

    # The attribute may be absent or None depending on how parsing went.
    return getattr(result, 'triples', None) or []


def generate_mermaid_diagram(entities: List[EntityWithAttr], triples: List[RelationTriple]) -> str:
    """Render entities and relations as a Mermaid `graph TD` flowchart.

    Every entity becomes a node labelled with its name (first 40 characters)
    and its type in italics. A triple becomes an edge only when both endpoints
    match an entity name case-insensitively; identical edges are emitted once.
    Double quotes in labels are written as Mermaid's `#quot;` entity code so
    they cannot terminate the quoted label, and entities or edges whose node
    ID would be empty are skipped because they would be invalid Mermaid.

    Args:
        entities: Nodes of the graph.
        triples: Candidate edges.

    Returns:
        The diagram source, one statement per line, ending with a newline.
    """
    def _escape(label_text: str) -> str:
        # Escape after truncating so an entity code is never cut in half.
        return label_text.replace('"', '#quot;')

    lines = ["graph TD"]

    # Create entity set for validation
    entity_set = {e.entity.lower() for e in entities}

    # Add nodes with types
    for entity in entities:
        clean_id = clean_entity_name(entity.entity)
        if not clean_id:
            continue
        label = _escape(entity.entity[:40])  # Limit label length
        lines.append(f'  {clean_id}["{label}<br/><i>{_escape(entity.attr_type)}</i>"]')

    # Add edges
    added_edges = set()
    for triple in triples:
        # Validate entities exist
        if triple.source.lower() not in entity_set or triple.target.lower() not in entity_set:
            continue
        src_clean = clean_entity_name(triple.source)
        dst_clean = clean_entity_name(triple.target)
        if not src_clean or not dst_clean:
            continue
        label = triple.relation[:40]  # Limit label length

        edge_key = (src_clean, dst_clean, label)
        if edge_key not in added_edges:
            lines.append(f'  {src_clean} -- "{_escape(label)}" --> {dst_clean}')
            added_edges.add(edge_key)

    return "\n".join(lines) + "\n"


# MAIN PIPELINE

def _run_extraction_steps(url: str, text: str, index: int):
    """Run steps 2-5 (extract, deduplicate, relate, draw) on text that is already in hand."""
    # Step 2: Extract entities
    print("Step 2: Extracting entities...")
    entities = extract_entities_from_text(text)
    print(f"Extracted {len(entities)} entities")

    # Step 3: Deduplicate
    print("Step 3: Deduplicating entities...")
    deduped_entities = deduplicate_with_lm(entities)
    print(f"Deduplicated to {len(deduped_entities)} unique entities")

    # Step 4: Extract relations
    print("Step 4: Extracting relationships...")
    relations = extract_relations(text, deduped_entities)
    print(f"Extracted {len(relations)} relationships")

    # Step 5: Generate Mermaid diagram
    print("Step 5: Generating Mermaid diagram...")
    mermaid = generate_mermaid_diagram(deduped_entities, relations)
    print(f"Generated diagram")

    return {
        'url': url,
        'entities': deduped_entities,
        'relations': relations,
        'mermaid': mermaid,
        'index': index + 1
    }


def process_url(url: str, index: int):
    """Process a single URL through the complete pipeline.

    Args:
        url: Page to scrape and analyse.
        index: Zero-based position of the URL in `URLS`.

    Returns:
        A dict with keys 'url', 'entities', 'relations', 'mermaid' and 'index'
        (1-based), or None when the page could not be scraped.
    """
    print(f"\n{'='*80}")
    # The total comes from URLS rather than a hard-coded 10.
    print(f"Processing URL {index + 1}/{len(URLS)}: {url}")
    print(f"{'='*80}")

    # Step 1: Scrape text
    print("Step 1: Scraping content...")
    text = scrape_text(url)
    if not text:
        print("Failed to scrape content")
        return None
    print(f"Scraped {len(text)} characters")

    return _run_extraction_steps(url, text, index)


def process_url_with_manual_text(url: str, text: str, index: int):
    """Process a URL with pre-extracted text instead of scraping.

    Args:
        url: Page the text belongs to (used for reporting only).
        text: Text to analyse in place of scraped content.
        index: Zero-based position of the URL in `URLS`.

    Returns:
        The same dict shape as `process_url`.
    """
    print(f"\n{'='*80}")
    print(f"Processing URL {index + 1} (manual text): {url}")
    print(f"{'='*80}")

    # Skip scraping, use provided text
    print(f"Step 1: Using provided text ({len(text)} characters)")

    return _run_extraction_steps(url, text, index)


# MANUAL TEXT DEFINITIONS
# Hand-copied abstract and keyword text used in place of scraping.
# main() maps text_url1 to the ScienceDirect article S1043661820315152 and
# text_url2 to S0378378220307088 through its `manual_texts` dictionary.
# These strings are sent to the LM verbatim, so their wording is runtime data.
text_url1 = """
Ivermectin is a macrolide antiparasitic drug ... Recently, ivermectin has been reported to inhibit the proliferation of several tumor cells by regulating multiple signaling pathways. This suggests that ivermectin may be an anticancer drug with great potential. Here, we reviewed the related mechanisms by which ivermectin inhibited the development of different cancers and promoted programmed cell death and discussed the prospects for the clinical application of ivermectin as an anticancer drug for neoplasm therapy.

Keywords: Ivermectin, avermectin, selamectin, doramectin, moxidectin, cancer, tumor, neoplasm, triple-negative breast cancer (TNBC), drug repositioning, apoptosis, autophagy, pyroptosis, proliferation, metastasis, angiogenic activity, multidrug resistance (MDR), cell death, PAK1 kinase, EGFR, HER2, ASC, GSDMD, LDH, PARP, P-glycoprotein (P-gp), SOX-2, OCT-4, STAT3, YAP1, HMGB1, HSP27, signaling pathways, crosstalk, chemotherapy drugs, targeted drugs.
"""

text_url2 = """
There is a significant relationship between ambient temperature and mortality ... in vulnerable groups, especially in elderly over the age of 65 years, infants and individuals with co-morbid cardiovascular and/or respiratory conditions, there is a deficiency in thermoregulation. When temperatures exceed a certain limit, being cold winter spells or heat waves, there is an increase in the number of deaths. ... Besides the direct effect of temperature rises on human health, global warming will have a negative impact on primary producers and livestock, leading to malnutrition ... Public health measures ... improved urban planning and reduction in energy consumption ... reduce the carbon footprint and help avert global warming, thus reducing mortality.

Keywords: global warming, ambient temperature, thermoregulation, heat regulation, heat waves, cold spells, mortality, malnutrition, carbon footprint, urban planning, energy consumption, climate change, air pollution, greenhouse effect, cardiovascular, respiratory disease, vector-borne diseases, waterborne diseases, foodborne diseases, mental health problems, allergies, elderly, infants, primary producers, livestock, vulnerable groups, air pollution, carbon dioxide, methane, nitrous oxide, ocean acidification, ozone depletion, adaptation, acclimatization, epidemiological studies, public health measures.
"""


def main():
    """Run the pipeline over every URL and write the outputs.

    For each entry in `URLS` the page is processed (using hand-copied text for
    the two ScienceDirect articles instead of scraping), its Mermaid diagram is
    written to `output/mermaid_<n>.md`, and its entities are collected. All
    entities are then written to `output/tags.csv` with the columns
    link / tag / tag_type, and a summary is printed. URLs that fail to process
    are skipped, and the summary lists only the diagram files actually written.
    """
    print("Starting DSPy Entity Extraction Pipeline")
    print(f"Processing {len(URLS)} URLs...\n")

    # Create output directory
    os.makedirs('output', exist_ok=True)

    results = []
    tag_rows = []        # one [link, tag, tag_type] row per entity
    saved_diagrams = []  # paths of the Mermaid files actually written

    # Dictionary for manual text overrides
    manual_texts = {
        'https://www.sciencedirect.com/science/article/pii/S1043661820315152': text_url1,
        'https://www.sciencedirect.com/science/article/pii/S0378378220307088': text_url2
    }

    # Process all URLs
    for i, url in enumerate(URLS):
        # Check if we have manual text for this URL
        if url in manual_texts:
            result = process_url_with_manual_text(url, manual_texts[url], i)
        else:
            result = process_url(url, i)

        if not result:
            continue
        results.append(result)

        # Save Mermaid diagram
        mermaid_file = f'output/mermaid_{result["index"]}.md'
        with open(mermaid_file, 'w', encoding='utf-8') as f:
            f.write(result['mermaid'])
        saved_diagrams.append(mermaid_file)
        print(f"Saved {mermaid_file}")

        # Add to CSV
        for entity in result['entities']:
            tag_rows.append([url, entity.entity, entity.attr_type])

    # Save CSV
    df = pd.DataFrame(tag_rows, columns=['link', 'tag', 'tag_type'])
    csv_file = 'output/tags.csv'
    df.to_csv(csv_file, index=False)
    print(f"\n‚úì Saved {csv_file} with {len(df)} rows")

    # Print summary
    print("\n" + "="*80)
    print("‚úì PIPELINE COMPLETE!")
    print("="*80)
    print(f"Processed: {len(results)}/{len(URLS)} URLs")
    print(f"Total entities: {sum(len(r['entities']) for r in results)}")
    print(f"Total relationships: {sum(len(r['relations']) for r in results)}")
    print(f"\nOutputs saved to 'output/' directory:")
    # List what was really written instead of always claiming files 1 to 10.
    for path in saved_diagrams:
        print(f"  - {os.path.basename(path)}")
    print(f"  - tags.csv")


if __name__ == "__main__":
    # Test API connection first. Only the connectivity probe sits inside the
    # try block: main() runs in the else branch so that a failure inside the
    # pipeline is no longer misreported as "API connection failed".
    print("üîç Testing API connection...")
    try:
        test_predictor = Predict("question -> answer")
        test_result = test_predictor(question="What is 2+2?")
        print(f"‚úì API connection successful! Test response: {test_result.answer}")
    except Exception as e:
        print(f"‚ùå API connection failed: {e}")
        print("\nüîß Troubleshooting steps:")
        print("1. Check your API key is correct")
        print("2. Verify Groq endpoint: https://api.groq.com/openai/v1")
        print("3. Check your quota at Groq dashboard")
        print("4. Try alternative configuration:")
        print("\n   # Option A: Use environment variables")
        print("   export GROQ_API_KEY='your-key'")
        print("\n   # Option B: Try different model")
        print("   lm = dspy.LM(model='groq/llama-3.1-70b-versatile', api_key=API_KEY, api_base='...')")
    else:
        print("\nStarting main pipeline...\n")
        main()


‚úì DSPy configured with Groq AI
  Model: llama-3.3-70b-versatile
üîç Testing API connection...
‚úì API connection successful! Test response: The answer to 2+2 is 4.

Starting main pipeline...

Starting DSPy Entity Extraction Pipeline
Processing 10 URLs...


Processing URL 1/10: https://en.wikipedia.org/wiki/Sustainable_agriculture
Step 1: Scraping content...
Scraped 5000 characters
Step 2: Extracting entities...
Processing chunk 1/3...
Found 12 entities in chunk 1
Processing chunk 2/3...
Found 68 entities in chunk 2
Processing chunk 3/3...
Found 38 entities in chunk 3
Extracted 118 entities
Step 3: Deduplicating entities...
Deduplicated to 106 unique entities
Step 4: Extracting relationships...
Extracted 10 relationships
Step 5: Generating Mermaid diagram...
Generated diagram
Saved output/mermaid_1.md

Processing URL 2/10: https://www.nature.com/articles/d41586-025-03353-5
Step 1: Scraping content...
Scraped 5000 characters
Step 2: Extracting entities...
Processing chunk 1/3...
Found