In [1]:
import os
import glob
import base64

from dotenv import load_dotenv
load_dotenv()
os.environ["PATH"] = "/opt/homebrew/bin:" + os.environ.get("PATH", "")

from unstructured.partition.pdf import partition_pdf
from unstructured.partition.pptx import partition_pptx
from unstructured.partition.xlsx import partition_xlsx
from unstructured.chunking.title import chunk_by_title

from langchain_core.documents import Document
from langchain_core.messages import HumanMessage
from langchain_openai import OpenAIEmbeddings, ChatOpenAI
from langchain_chroma import Chroma

embedding_model = OpenAIEmbeddings(model="text-embedding-3-large")

llm = ChatOpenAI(temperature=0, model="gpt-4o")


  from .autonotebook import tqdm as notebook_tqdm


In [15]:
def process_image_with_vision(base64_image):
    """Process image using vision model to extract content"""
    try:
        enhanced_prompt = """
        Analyze this image and extract all information in a way that answers potential questions users might ask.

        For charts/graphs:
        - Identify the highest, lowest, and notable values with specific numbers
        - Explain trends, patterns, and relationships in the data
        - Answer questions like "what has the highest/lowest value", "which items are above/below a threshold"
        - Provide rankings and comparisons

        For tables in images:
        - Extract all data points with specific numbers and percentages  
        - Identify the most/least common items
        - Explain relationships between columns and rows

        For text content:
        - Extract all readable text verbatim
        - Preserve important numbers, dates, and key facts
        - Maintain structure and hierarchy

        For diagrams/flowcharts:
        - Describe the process, flow, or relationships shown
        - Explain connections between elements
        - Identify key decision points or outcomes

        Write your response in natural language that would match how someone would ask questions about this content. Include specific numbers, rankings, and be explicit about comparisons and superlatives (highest, lowest, most common, etc.).
        """
        
        message = HumanMessage(content=[
            {
                "type": "text", 
                "text": enhanced_prompt
            },
            {
                "type": "image_url", 
                "image_url": {"url": f"data:image/png;base64,{base64_image}"}
            }
        ])
        
        response = llm.invoke([message])
        return response.content
        
    except Exception as e:
        print(f"    ❌ Error processing image with vision model: {str(e)}")
        return None

In [None]:
# =============================================================================
# STEP 1: FIND ALL FILES IN DATA FOLDER
# =============================================================================
data_folder = "./AMTAGVI"
print(f"\n📁 Looking for files in: {data_folder}")

# Find all PDF and PowerPoint files
pdf_files = glob.glob(os.path.join(data_folder, "**/*.pdf"), recursive=True)
pptx_files = glob.glob(os.path.join(data_folder, "**/*.pptx"), recursive=True)
xlsx_files = glob.glob(os.path.join(data_folder, "**/*.xlsx"), recursive=True)
png_files = glob.glob(os.path.join(data_folder, "**/*.png"), recursive=True)
jpg_files = glob.glob(os.path.join(data_folder, "**/*.jpg"), recursive=True)
jpeg_files = glob.glob(os.path.join(data_folder, "**/*.jpeg"), recursive=True)

# all_files = pdf_files + pptx_files + xlsx_files
all_files = pdf_files + pptx_files + xlsx_files + png_files + jpg_files + jpeg_files

print(f"Found {len(pdf_files)} PDF files and {len(pptx_files)} PowerPoint files and {len(xlsx_files)} xlsx files and {len(png_files)} PNG files and {len(jpg_files)} JPG files and {len(jpeg_files)} JPEG files ")

print(f"Total files to process: {len(all_files)}")

if len(all_files) == 0:
    print("❌ No files found! Make sure your data folder contains PDF or PPTX files.")
    exit()


In [None]:
# =============================================================================
# STEP 2: PROCESS EACH FILE (PARTITION INTO ELEMENTS)
# =============================================================================
print(f"\n🔄 Processing files...")

all_elements = []

for file_path in all_files:
    print(f"Processing: {os.path.basename(file_path)}")
    
    try:
        if file_path.endswith('.pdf'):
            # Process PDF files
            elements = partition_pdf(
                filename=file_path,
                strategy="hi_res",
                hi_res_model_name="yolox",
                infer_table_structure=True,
                languages=["eng"], 
                extract_image_block_types=["Image"], 
                extract_image_block_to_payload=True 
            )
        
        elif file_path.endswith(('.pptx', '.ppt')):
            # Process PowerPoint files
            elements = partition_pptx(
                filename=file_path,
                infer_table_structure=True,
                include_slide_notes=True,
                strategy="hi_res"
            )

        elif file_path.endswith('.xlsx'): 
           elements = partition_xlsx(
               filename=file_path, 
               infer_table_structure=True
            )
           
        elif file_path.endswith(('.png', '.jpg', '.jpeg', '.gif', '.bmp')):
            # Process standalone images - convert to base64 first
            with open(file_path, "rb") as image_file:
                image_data = image_file.read()
            base64_image = base64.b64encode(image_data).decode()
            
            vision_content = process_image_with_vision(base64_image)
            
            if vision_content:
                # Create a mock element for standalone images
                class MockElement:
                    def __init__(self, text, file_path):
                        self.text = text
                        self.category = "Image"
                        self.metadata = MockMetadata(file_path)
                
                class MockMetadata:
                    def __init__(self, file_path):
                        self.filename = os.path.basename(file_path)
                        self.file_path = file_path
                        
                    def to_dict(self):
                        return {
                            'filename': self.filename,
                            'filetype': 'standalone_image',
                            'content_type': 'image_vision_extracted',
                            'image_source': 'standalone'
                        }
                
                vision_element = MockElement(vision_content, file_path)
                elements = [vision_element]
        else:
            print(f"  ⚠️ Unsupported file type: {file_path}")
            continue

        all_elements.extend(elements)
        print(f"  ✅ Extracted {len(elements)} elements")
                
    except Exception as e:
        print(f"  ❌ Error processing {file_path}: {str(e)}")
        continue

print(f"\n📊 Total elements extracted from all files: {len(all_elements)}")


In [None]:
# =============================================================================
# SAVE PROCESSED ELEMENTS FOR FUTURE RUNS
# =============================================================================
import pickle

# Only save if all_elements exists and has data
if all_elements:
    print(f"💾 Saving {len(all_elements)} elements...")
    with open("./partitioned_elements.pkl", "wb") as f:
        pickle.dump(all_elements, f)
    print("✅ Saved!")
else:
    print("⚠️ No all_elements to save")

In [2]:
import pickle

with open("./partitioned_elements.pkl", "rb") as f:
    all_elements = pickle.load(f)

print(f"{len(all_elements)} elements have been loaded")

36759 elements have been loaded


In [3]:
# =============================================================================
# STEP 3: FILTER OUT UNWANTED SECTIONS
# =============================================================================
print(f"\n🔍 Filtering out unwanted sections...")

# Keywords to filter out (easily extensible)
FILTER_KEYWORDS = [
    'references',
    # 'bibliography',
    # 'acknowledgments', 
    # 'appendix',
    # 'glossary'
]

def filter_sections(elements, keywords):
    """Remove sections containing any of the specified keywords"""
    filtered = []
    removed = []
    skip_section = False
    current_heading = None
    
    for element in elements:
        # Check if this is a heading/title
        is_heading = hasattr(element, 'category') and element.category in ['Title', 'Header']
        element_text = getattr(element, 'text', '') or ''
        
        if is_heading:
            # Check if heading contains any filter keywords
            skip_section = any(keyword.lower() in element_text.lower() for keyword in keywords)
            if skip_section:
                current_heading = element_text
        
        # Remove element if we're in a filtered section
        if skip_section:
            removed.append({
                'text': element_text[:100] + ('...' if len(element_text) > 100 else ''),
                'category': getattr(element, 'category', 'Unknown'),
                'filename': getattr(element.metadata, 'filename', 'Unknown') if hasattr(element, 'metadata') else 'Unknown',
                'section': current_heading
            })
        else:
            filtered.append(element)
    
    return filtered, removed

# Apply filter
filtered_elements, removed_elements = filter_sections(all_elements, FILTER_KEYWORDS)

# Print results
print(f"📊 Filtered {len(removed_elements)} elements from {len(all_elements)} total")

if removed_elements:
    print(f"\n🗑️ REMOVED ELEMENTS:")
    for filename in set(elem['filename'] for elem in removed_elements):
        file_elements = [e for e in removed_elements if e['filename'] == filename]
        print(f"\n📄 {filename} ({len(file_elements)} removed)")
        for elem in file_elements:
            print(f"  • [{elem['category']}] {elem['text']}")



🔍 Filtering out unwanted sections...
📊 Filtered 1194 elements from 36759 total

🗑️ REMOVED ELEMENTS:

📄 prc-us-00483-ref2-Tawbi 2022.pdf (28 removed)
  • [Title] References
  • [NarrativeText] 1. Larkin J, Chiarion-Sileni V, Gonzalez R, et al. Five-year survival with combined nivolumab and ip...
  • [NarrativeText] 3. Durham NM, Nirschl CJ, Jackson CM, et al. Lymphocyte Activation Gene 3 (LAG-3) modulates the abil...
  • [NarrativeText] 4. Workman CJ, Cauley LS, Kim I-J, Blackman MA, Woodland DL, Vignali DAA. Lymphocyte activation gene...
  • [NarrativeText] 5. Hemon P, Jean-Louis F, Ramgolam K, et al. MHC class II engagement by its li- gand LAG-3 (CD223) c...
  • [NarrativeText] 6. Woo S-R, Turnis ME, Goldberg MV, et al. Immune inhibitory molecules LAG-3 and PD-1 synergisticall...
  • [NarrativeText] 9. Ascierto PA, Melero I, Bhatia S, et al. Initial efficacy of anti-lymphocyte activa- tion gene-3 (...
  • [NarrativeText] 10. Eisenhauer EA, Therasse P, Bogaerts J, et al. New response

In [4]:
# Update elements
all_elements = filtered_elements
print(f"\n✅ Now using {len(all_elements)} filtered elements")


✅ Now using 35565 filtered elements


In [16]:
embedded_images = [el for el in all_elements if getattr(el, 'category', '') == 'Image' and hasattr(el.metadata, 'image_base64')] 
len(embedded_images) 

1661

In [None]:
# =============================================================================
# STEP 4: PROCESS IMAGES WITH VISION MODEL
# =============================================================================
print(f"\n🖼️ Processing images with vision model...") 

# Extract embedded images from PDFs/PPTs
embedded_images = [el for el in all_elements if getattr(el, 'category', '') == 'Image' and hasattr(el.metadata, 'image_base64')] 

# Process embedded images 
image_documents = [] 

for i, img_element in enumerate(embedded_images): 
    try:
        # Get base64 image data 
        base64_image = img_element.metadata.image_base64 
        
        # Process with vision model 
        vision_content = process_image_with_vision(base64_image) 
        
        if vision_content:
            # Create minimal metadata
            metadata = {
                'filename': img_element.metadata.filename, 
                'filetype': img_element.metadata.filetype, 
                'content_type': 'image_vision_extracted',  
                'image_source': 'embedded'
            }
            
            # Create document
            doc = Document(
                page_content=vision_content,
                metadata=metadata
            )
            
            image_documents.append(doc)
            print(f"  ✅ Processed embedded image {i+1}")
        
    except Exception as e:
        print(f"  ❌ Error processing embedded image {i+1}: {str(e)}")
        continue

In [7]:
# =============================================================================
# SAVE PROCESSED ELEMENTS FOR FUTURE RUNS
# =============================================================================
import pickle

# Only save if all_elements exists and has data
if embedded_images:
    print(f"💾 Saving {len(embedded_images)} embedded_images chunks...")
    with open("./embedded_images.pkl", "wb") as f:
        pickle.dump(embedded_images, f)
    print("✅ Saved!")
else:
    print("⚠️ No embedded_images chunks to save")

NameError: name 'standalone_images' is not defined

In [None]:
# Process standalone images (MockElements)
standalone_images = [el for el in all_elements if hasattr(el, '__class__') and el.__class__.__name__ == 'MockElement']

for i, img_element in enumerate(standalone_images):
    try:
        # Get vision content (already processed)
        vision_content = img_element.text
        
        # Get minimal metadata
        metadata = img_element.metadata.to_dict()
        
        # Create document
        doc = Document(
            page_content=vision_content,
            metadata=metadata
        )
        
        image_documents.append(doc)
        print(f"  ✅ Processed standalone image {i+1}")
        
    except Exception as e:
        print(f"  ❌ Error processing standalone image {i+1}: {str(e)}")
        continue

print(f"🖼️ Created {len(image_documents)} image documents")

In [5]:
def create_table_summary(table_html, table_text):
    """Use LLM to create query-friendly summaries"""
    
    enhanced_prompt = f"""
    Analyze this table and create a comprehensive summary that answers potential questions users might ask.

    Table Data:
    {table_text}
    
    Table HTML:
    {table_html}
    
    IMPORTANT INSTRUCTIONS:
    1. **Data Quality Check**: If you notice corrupted text, garbled data, or major discrepancies between the Table Data and Table HTML, acknowledge this rather than guessing. Say "extraction quality issues detected" if the data appears unreliable.
    
    2. **Source Priority**: When there are differences between Table Data and Table HTML, prioritize the Table Data section as it's usually more reliable.
    
    3. **Summary Requirements**:
       - What type of data this table contains
       - The highest/lowest values and what they represent  
       - Any patterns, trends, or notable findings
       - Specific numbers, percentages, and key data points
       - Answer common questions like "what are the most/least common", "which has the highest/lowest", "what percentage"
    
    4. **Response Format**: 
       - Keep summary concise (200-300 words)
       - Write in natural language that matches how users ask questions
       - Include specific numbers and rankings when clearly visible
       - If data is unclear or corrupted, state this explicitly rather than reconstructing
    
    5. **Accuracy Over Completeness**: Only mention specific details you can clearly see in the data. When in doubt, say the information is unclear rather than guessing.
    """
    
    enhanced_summary = llm.invoke(enhanced_prompt)
    
    return enhanced_summary

In [None]:
# =============================================================================
# STEP 5: EXTRACT AND SUMMARIZE TABLES
# =============================================================================
print(f"\n📋 Processing tables...") 

tables = [el for el in all_elements if getattr(el, 'category', '') == "Table"] 
print(f"Found {len(tables)} tables") 

llm = ChatOpenAI(temperature=0, model="gpt-4o") 

table_documents = [] 

for i, table in enumerate(tables): 
    
    try:
        # Get table HTML if available
        table_html = getattr(table.metadata, 'text_as_html', None) 
        
        if table_html:
            # Summarize the table using LLM
            table_doc = Document(page_content=table_html) 
            
            summary = create_table_summary(table_html, table.text) 
            
            # Create metadata for the table
            metadata = table.metadata.to_dict() if hasattr(table.metadata, 'to_dict') else {} 
            
            # Clean up metadata (fix lists that ChromaDB doesn't like)
            for key, value in metadata.items():
                if isinstance(value, list):
                    if len(value) > 0:
                        metadata[key] = str(value[0])
                    else:
                        metadata[key] = None
                elif value is not None:
                    metadata[key] = str(value)
            
            # Add table-specific metadata
            metadata.update({
                'content_type': 'table_summary',
                'original_html': table_html
            })
            
            # Create document with summary
            table_content = f"""
                {summary}

                Source Data: {table.text}
                """
            
            doc = Document(
                page_content=table_content,
                metadata=metadata
            )
            
            table_documents.append(doc)
            print(f"  ✅ Summarized table {i+1}")
        
        else:
            print(f"  ⚠️ Table {i+1} has no HTML content, skipping")
            
    except Exception as e:
        print(f"  ❌ Error processing table {i+1}: {str(e)}")
        continue

print(f"📋 Created {len(table_documents)} table summary documents")

In [None]:
# This is for SAVING 

import pickle

if table_documents:
    print(f"💾 Saving {len(table_documents)} elements...")
    with open("./summarized_documents.pkl", "wb") as f:
        pickle.dump(table_documents, f)
    print("✅ Saved!")
else:
    print("⚠️ No table_documents to save")

In [6]:
import pickle


with open("./summarized_documents.pkl", "rb") as f:
    table_documents = pickle.load(f)
    print(f"{len(table_documents)} table documents have been loaded")

1127 table documents have been loaded


In [7]:
# =============================================================================
# STEP 6: CHUNK THE REGULAR CONTENT  
# =============================================================================
print(f"\n✂️ Chunking content...")

# Get only non-image, non-table elements for text chunking
regular_elements = [el for el in all_elements if getattr(el, 'category', '') not in ['Image', 'Table']]

# Chunk all elements
chunks = chunk_by_title(
    regular_elements,
    combine_text_under_n_chars=100,
    max_characters=3000,
)

print(f"Created {len(chunks)} chunks")


✂️ Chunking content...
Created 6221 chunks


In [15]:
# =============================================================================
# STEP 7: CONVERT CHUNKS TO LANGCHAIN DOCUMENTS
# =============================================================================

# Convert chunks to LangChain Documents
chunk_documents = []

for i, chunk in enumerate(chunks):
    # Extract text content
    text_content = chunk.text
    
    # Get metadata
    metadata = chunk.metadata.to_dict() if hasattr(chunk.metadata, 'to_dict') else {}
    
    # Clean up metadata for ChromaDB
    for key, value in metadata.items():
        if isinstance(value, list):
            if len(value) > 0:
                metadata[key] = str(value[0])
            else:
                metadata[key] = None
        elif value is not None:
            metadata[key] = str(value)
    
    # Add chunk info
    metadata.update({
        'chunk_index': i,
        'chunk_id': getattr(chunk, 'id', f'chunk_{i}'),
        'content_type': 'text_chunk'
    })
    
    # Create document
    doc = Document(
        page_content=text_content,
        metadata=metadata
    )
    chunk_documents.append(doc)

print(f"✂️ Created {len(chunk_documents)} chunk documents")

✂️ Created 6221 chunk documents


In [None]:
# =============================================================================
# STEP 8: COMBINE ALL DOCUMENTS
# =============================================================================
print(f"\n🔗 Combining all documents...")

# Combine regular chunks, table summaries, and image documents
all_documents = chunk_documents + table_documents + image_documents

print(f"Total documents for vector database: {len(all_documents)}")
print(f"  - Text chunks: {len(chunk_documents)}")
print(f"  - Table chunks: {len(table_documents)}")
print(f"  - Image chunks: {len(image_documents)}")

In [16]:
# =============================================================================
# STEP 8: COMBINE ALL DOCUMENTS
# =============================================================================
print(f"\n🔗 Combining all documents...")

# Combine regular chunks, table summaries, and image documents
all_documents = chunk_documents + table_documents 

print(f"Total documents for vector database: {len(all_documents)}")


🔗 Combining all documents...
Total documents for vector database: 7348


In [17]:
import hashlib

def deduplicate_documents_by_content(documents):
    seen_hashes = set()
    unique_documents = []
    duplicate_sources = []
    
    for doc in documents:
        content_hash = hashlib.md5(doc.page_content.encode()).hexdigest()
        
        if content_hash not in seen_hashes:
            seen_hashes.add(content_hash)
            unique_documents.append(doc)
        else:
            # Log where duplicates come from
            source = doc.metadata.get('filename', 'unknown')
            duplicate_sources.append(source)
    
    print(f"Duplicate sources: {set(duplicate_sources)}")
    return unique_documents

print(f"\n🔍 Deduplicating documents...")
print(f"Before deduplication: {len(all_documents)} documents")

unique_documents = deduplicate_documents_by_content(all_documents)

print(f"After deduplication: {len(unique_documents)} documents")



🔍 Deduplicating documents...
Before deduplication: 7348 documents
Duplicate sources: {'prc-us-00484-ref1-Amtagvi PI 02 2024.pdf', 'CLEAN_5228_BRANDED AMTAGVI SEM Ads_HCP_050225.pdf', '5237_AMTAGVI 5-Year Data Email_R3_042425.pdf', 'Medical Lifileucel Scientific Communication Platform_Revision_11-15-24-dm (1).pptx', 'Iovance Data on File 2025.pdf', 'AMTAGVI banner ads for ASCO enewsletter_040225.pdf', 'prc-us-00484-ref3-Iovance DOF 2025.pdf', 'EXTS_5228_AMTAGVI SEM Ads_R3_040825.pdf', '5237_AMTAGVI Email 1_R3_040925 PC (1).pdf', 'Comments for PRC_5228_BRANDED AMTAGVI SEM Ads_HCP_050225.pdf', 'PRC Comments_AMTAGVI.com Unbranded HCP-DTC SEM (May 2025) 4.25.25.pdf', '5228_AMTAGVI SEM Ads_r1v3_032125.xlsx', 'DRAFT Copy of 5228_AMTAGVI SEM Ads_r1v3_032125.xlsx', 'PRJ5237_IOV25_5237_Amtagvi_Email_2_v9_FCv1.pdf', '5237_AMTAGVI_Referral Email_R4v2_042925.pdf', 'AMTAGVI Order Sheet.pdf', 'IOVANCE-Unbranded Mini Campaign-RFP-Submission_v1.pptx', 'AMTAGVI 2025 NPP Media Expansion Recommendations.

In [22]:
import os
from pinecone import Pinecone, ServerlessSpec
from langchain_community.retrievers import PineconeHybridSearchRetriever
from langchain_openai import OpenAIEmbeddings
from pinecone_text.sparse import BM25Encoder

index_name = "pixacore"

# Initialize the pinecone client
api_key = "pcsk_6p7LX3_MP2gq5GRAkEj8WLhoU8pCHYLwLn6xtxNsPtVtVzytR5dTGnGdwb57dwL3EhodRh"
pc = Pinecone(api_key=api_key)

# Create the index - IMPORTANT: Must use dotproduct for hybrid search
if index_name not in pc.list_indexes().names():
    pc.create_index(
        name=index_name, 
        dimension=3072,  # OpenAI embedding dimension
        metric="dotproduct",  # sparse values supported only for dotproduct
        spec=ServerlessSpec(
            cloud="aws", 
            region="us-east-1"  
        )
    )

In [23]:
# Get the index
index = pc.Index(index_name)

# Create BM25 encoder
bm25_encoder = BM25Encoder().default()

# Fit BM25 encoder on your text corpus
text_corpus = [doc.page_content for doc in unique_documents]

In [24]:
bm25_encoder.fit(text_corpus)

100%|██████████| 4218/4218 [00:04<00:00, 1040.64it/s]


<pinecone_text.sparse.bm25_encoder.BM25Encoder at 0x2b79e3750>

In [57]:
# Store the BM25 values (optional but recommended)
bm25_encoder.dump("bm25_values.json")

# Create the retriever with alpha parameter for weighting
pinecone_retriever = PineconeHybridSearchRetriever(
    embeddings=embedding_model,
    sparse_encoder=bm25_encoder,
    index=index,
    alpha=0.5,  # Adjust this to control dense vs sparse weighting
    top_k=30
)

In [27]:
# Add texts with metadata to the retriever
text_list = [doc.page_content for doc in unique_documents]
metadata_list = [doc.metadata for doc in unique_documents]

In [29]:
len(text_list)

4218

In [28]:
# Add texts with their metadata
pinecone_retriever.add_texts(
    texts=text_list,
    metadatas=metadata_list  # This preserves the Document metadata
)

100%|██████████| 132/132 [19:19<00:00,  8.78s/it] 


In [56]:
pinecone_retriever.invoke("Count of people diagnosed with melanoma")

[Document(metadata={'chunk_id': '6786c649-4c93-4f16-8bbc-baeead097e3c', 'chunk_index': 6091.0, 'content_type': 'text_chunk', 'file_directory': './AMTAGVI/5528_Unbranded Mini Campaign RFP/00_Resources from Iovance/Reference Materials', 'filename': 'Amtagvi_onboarding_deck_0430.pptx', 'filetype': 'application/vnd.openxmlformats-officedocument.presentationml.presentation', 'languages': 'eng', 'last_modified': '2025-05-22T07:18:00', 'orig_elements': 'eJztmN1vGzcMwP8VwUCBFqgT3beue8oydAiQrEX68VIUBiVRttq7003SpcmK/e+jzjbarR3qlwEz4Bf7fCIlivyZpPTu8wI77HGIK6sXz9iiznRlijYvijwXpaolcM4RWplJaeq6XDxlix4jaIhA8p8XCiKunX9YaRzjhl5lJGFshyttPapIQ2nes/OLm9cXv769Oq+qXKzeDNLDoFGzGztYdgn9CHY9sNvnL885X91icJNXGJjxrmdX7g4Ghee3aNAjPbEbWtVb6MJit9oAPaZ1LvoI6zu7coN04LUd1mSY+rjiZcHPxjHe7xXiwzgrwDh2ljZh3XB+N+gzN+Jw33fG+R5iWDpjrELt1JScdDZ6DPQ9i/fd336miTsY1hOsMdDM7xY4rBfv57chrnqnrbE4OznnebXk1TLPX/PmWSaecZ60R9JcDVMv0Sc/ZvMr/yU2JYiG57VpM2FyKaSURZaVWnCpdQV5sfiTNCLexyR8PdHq6fdun69t7DBJ/DPiHGsQCIUwqiiyFlUFbSVE1pRcirqGU8SPJOI3S

In [None]:
# # What happens inside add_texts() for each document:
# for i, text in enumerate(text_list):
#     # Step 1: Create dense vector using OpenAI
#     dense_vector = self.embeddings.embed_query(text)  # Uses your embedding_model
    
#     # Step 2: Create sparse vector using fitted BM25
#     sparse_vector = self.sparse_encoder.encode_documents([text])[0]  # Uses your bm25_encoder
    
#     # Step 3: Upsert BOTH vectors together to Pinecone
#     self.index.upsert(vectors=[{
#         'id': f'generated_id_{i}',
#         'values': dense_vector,           # Dense vector here
#         'sparse_values': sparse_vector,   # Sparse vector here  
#         'metadata': {
#             'context': text,              # Original text
#             **metadatas[i]               # Your metadata
#         }
#     }])

In [40]:
import time
import os
from typing import List
from langchain.schema import Document

# =============================================================================
# STEP 9: CREATE OR LOAD VECTOR DATABASE
# =============================================================================
print(f"\n🗄️ Setting up vector database...")

# Define where to save/load the database
db_path = "./full_vector"
collection_name = "my_documents"

def create_vector_store_in_batches(
    documents: List[Document], 
    embedding_model, 
    collection_name: str,
    db_path: str,
    batch_size: int = 100,  # Adjust based on your chunk sizes
    delay: float = 1.0      # Delay between batches to respect rate limits
):
    """Create vector store by processing documents in batches"""
    
    print(f"🗄️ Creating vector database in batches...")
    print(f"  - Total documents: {len(documents)}")
    print(f"  - Batch size: {batch_size}")
    
    # Initialize vector store with first batch
    first_batch = documents[:batch_size]
    print(f"Creating initial vector store with {len(first_batch)} documents...")
    
    vector_store = Chroma.from_documents(
        documents=first_batch,
        embedding=embedding_model,
        collection_name=collection_name,
        persist_directory=db_path
    )
    
    # Process remaining documents in batches
    for i in range(batch_size, len(documents), batch_size):
        batch = documents[i:i + batch_size]
        print(f"Adding batch {i//batch_size + 1}: documents {i+1}-{min(i+len(batch), len(documents))}")
        
        try:
            vector_store.add_documents(batch)
            
            # Adding delay to respect rate limits
            if delay > 0:
                time.sleep(delay)
                
        except Exception as e:
            print(f"Error processing batch {i//batch_size + 1}: {e}")
            continue
    
    print(f"✅ Vector database created successfully!")
    return vector_store

# Check if database already exists
if os.path.exists(db_path) and os.listdir(db_path):
    print("📁 Existing vector database found! Loading from disk...")
    
    # Load existing vector store
    vector_store = Chroma(
        collection_name=collection_name,
        embedding_function=embedding_model,
        persist_directory=db_path
    )
    
    # Get collection info
    try:
        doc_count = vector_store._collection.count()
        print(f"✅ Vector database loaded successfully!")
        print(f"  - Database location: {db_path}")
        print(f"  - Collection name: {collection_name}")
        print(f"  - Documents in collection: {doc_count}")
    except Exception as e:
        print(f"⚠️  Error getting collection info: {e}")
        print(f"✅ Vector database loaded successfully!")
        print(f"  - Database location: {db_path}")
        print(f"  - Collection name: {collection_name}")
    
else:
    print("🆕 No existing database found. Creating new vector database...")
    
    # Create new vector store using batch processing
    vector_store = create_vector_store_in_batches(
        documents=unique_documents,
        embeddings=embedding_model,
        collection_name=collection_name,
        db_path=db_path,
        batch_size=100,  # Start small and increase if needed
        delay=1.0
    )


🗄️ Setting up vector database...
📁 Existing vector database found! Loading from disk...
✅ Vector database loaded successfully!
  - Database location: ./full_vector
  - Collection name: my_documents
  - Documents in collection: 4291


In [41]:
import pickle
import json
import os
from pathlib import Path
from langchain_community.retrievers import BM25Retriever
from langchain.schema import Document
import nltk
from nltk.tokenize import word_tokenize

try:
    nltk.data.find('tokenizers/punkt')
except LookupError:
    nltk.download('punkt', quiet=True)
    nltk.download('punkt_tab', quiet=True)


def save_bm25(bm25_retriever, file_path):
    """
    Save BM25 retriever to disk
    
    Args:
        bm25_retriever: Your BM25Retriever object
        file_path: Where to save it
    """
    print(f"💾 Saving BM25 to {file_path}...")
    
    Path(file_path).parent.mkdir(parents=True, exist_ok=True)
    
    with open(file_path, 'wb') as f:
        pickle.dump(bm25_retriever, f)
    
    print(f"✅ BM25 saved successfully!")
    nltk.download("punkt_tab", quiet=True)


def load_bm25(file_path):
    """
    Load BM25 retriever from disk
    
    Args:
        file_path: Where to load from
        
    Returns:
        BM25Retriever object or None if file doesn't exist
    """
    if not os.path.exists(file_path):
        print(f"❌ No BM25 file found at {file_path}")
        return None
    
    print(f"📂 Loading BM25 from {file_path}...")
    
    with open(file_path, 'rb') as f:
        bm25_retriever = pickle.load(f)
    
    print(f"✅ BM25 loaded successfully! ({len(bm25_retriever.docs)} documents)")
    return bm25_retriever



def create_or_load_bm25(documents, file_path, force_rebuild=False):

    if not force_rebuild:
        bm25_retriever = load_bm25(file_path)
        if bm25_retriever is not None:
            return bm25_retriever
    
    # Create new BM25
    print(f"🔧 Building new BM25 index from {len(documents)} documents...")
    
    bm25_retriever = BM25Retriever.from_documents(
        documents=documents, 
        preprocess_func=word_tokenize
    )
    bm25_retriever.k = 25  # Return top 25 results
    
    # Save for next time
    save_bm25(bm25_retriever, file_path)
    
    return bm25_retriever


In [42]:
import pickle

def load_bm25(file_path):
    """
    Load BM25 retriever from disk
    
    Args:
        file_path: Where to load from
        
    Returns:
        BM25Retriever object or None if file doesn't exist
    """
    if not os.path.exists(file_path):
        print(f"❌ No BM25 file found at {file_path}")
        return None
    
    print(f"📂 Loading BM25 from {file_path}...")
    
    with open(file_path, 'rb') as f:
        bm25_retriever = pickle.load(f)
    
    print(f"✅ BM25 loaded successfully! ({len(bm25_retriever.docs)} documents)")
    return bm25_retriever


bm25_retriever = load_bm25(file_path="./full_bm25_retriever.pkl")
bm25_retriever.k = 15


📂 Loading BM25 from ./full_bm25_retriever.pkl...
✅ BM25 loaded successfully! (4291 documents)


In [43]:
# bm25_retriever = BM25Retriever.from_documents(documents=all_documents, preprocess_func=word_tokenize) 

bm25_retriever = create_or_load_bm25(
    documents=unique_documents,
    file_path="./full_bm25_retriever.pkl"
) 

bm25_retriever.k = 15

print(f"✅ bm25_retriever created successfully!")


📂 Loading BM25 from ./full_bm25_retriever.pkl...
✅ BM25 loaded successfully! (4291 documents)
✅ bm25_retriever created successfully!


In [44]:
# Create vector retriever
vector_retriever = vector_store.as_retriever( 
    search_type="similarity", 
    search_kwargs={"k": 15}
) 

print(f"✅ vector_retriever created successfully!")

✅ vector_retriever created successfully!


In [45]:
from langchain.retrievers import EnsembleRetriever

ensemble_retriever = EnsembleRetriever(
    retrievers=[vector_retriever, bm25_retriever], 
    weights=[0.5, 0.5]
)

In [47]:
def generate_query_variations(original_query, num_variations):

    from langchain_openai import ChatOpenAI
    from langchain_core.messages import HumanMessage, SystemMessage
    from pydantic import BaseModel, Field
    from typing import List
    
    class QueryVariations(BaseModel):
        """Schema for query variations response"""
        variations: List[str] = Field(
            description=f"Exactly {num_variations} different variations of the original query",
            min_length=num_variations,
            max_length=num_variations
        )
    
    system_prompt = """You are an expert at generating search query variations. Your task is to create different ways to ask the same question that will help retrieve comprehensive information from a knowledge base.

For each variation, use different:
- Terminology (technical vs common terms, synonyms)
- Phrasing and sentence structure  
- Keywords while maintaining the same core intent

Keep variations focused, specific, and don't change the fundamental meaning."""

    human_prompt = f"""Generate {num_variations} different variations of this query:

"{original_query}"

Each variation should capture the same intent but use different wording, terminology, or phrasing."""

    try:
        llm = ChatOpenAI(
            model="gpt-4o",  
            temperature=0.0
        )
        
        structured_llm = llm.with_structured_output(QueryVariations)
        
        messages = [
            SystemMessage(content=system_prompt),
            HumanMessage(content=human_prompt)
        ]
        
        response = structured_llm.invoke(messages) 
        
        all_queries = [original_query] + response.variations 
        
        print(f"✅ Generated {len(response.variations)} query variations:")
        for i, query in enumerate(all_queries):
            prefix = "Original" if i == 0 else f"Variation {i}"
            print(f"   {prefix}: {query}")
        
        return all_queries
        
    except Exception as e:
        print(f"⚠️ Error generating query variations: {e}")
        print("📝 Falling back to original query only")
        return [original_query]


In [48]:
from simhash import Simhash
import time

def simhash_deduplication(docs, similarity_threshold):
    """
    Deduplicate documents using Simhash
    
    Args:
        docs: List of LangChain Document objects
        similarity_threshold: Hamming distance threshold (lower = more similar)
                             Typical values: 0-3 for very similar, 3-6 for similar
    """
    if len(docs) <= 10:
        return docs
    
    print(f"🔍 Simhash deduplicating {len(docs)} documents...")
    start = time.time()
    
    # Create simhashes for all documents
    doc_hashes = []
    for doc in docs:
        # Create simhash from document content
        sim_hash = Simhash(doc.page_content)
        doc_hashes.append(sim_hash)
    
    print(f"   Hash generation: {time.time() - start:.2f}s")
    start = time.time()
    
    # Find duplicates
    to_remove = set()
    for i in range(len(docs)):
        if i in to_remove:
            continue
        for j in range(i + 1, len(docs)):
            if j in to_remove:
                continue
            
            # Calculate Hamming distance between hashes
            distance = doc_hashes[i].distance(doc_hashes[j])
            if distance <= similarity_threshold:
                to_remove.add(j)  # Remove the later document
    
    # Filter documents
    final_docs = [doc for i, doc in enumerate(docs) if i not in to_remove]
    
    print(f"   Comparison: {time.time() - start:.2f}s")
    print(f"   Simhash filtering: {len(docs)} → {len(final_docs)} docs")
    
    return final_docs

In [49]:
from langchain_cohere import CohereRerank

print(f"\n🎯 Setting up Cohere reranking with contextual compression...")

cohere_reranker = CohereRerank(
    model="rerank-english-v3.0",  
    top_n=10,  
    cohere_api_key=os.getenv("COHERE_API_KEY")  
)

def retrieve_chunks(query):

    total_start = time.time()

    start = time.time()

    # all_results = ensemble_retriever.invoke(query)
    all_results = retriever.invoke(query)
    print(f"Ensemble Retriever Results {len(all_results)}")

    print(f"Step 1: {time.time() - start:.2f}s")

    start = time.time()

    deduplicated_docs = simhash_deduplication(
        all_results, 
        similarity_threshold=2, 
    )
    print(f"Step 2: {time.time() - start:.2f}s")
    print(f"deduplicated_docs: {len(deduplicated_docs)}")

    
    start = time.time()

    if len(deduplicated_docs) <= 10:
        print(f"Step 3: 0.00s (skipped - only {len(deduplicated_docs)} docs)")
        final_docs = deduplicated_docs
    else:
        start = time.time()
        final_docs = cohere_reranker.compress_documents(
            documents=deduplicated_docs[:20],  
            query=query
        )
       

    print(f"final_docs: {len(final_docs)}")

    print(f"Step 3: {time.time() - start:.2f}s") 
    
    print(f"TOTAL: {time.time() - total_start:.2f}s")
    
    return final_docs

print(f"✅ RAG Fusion pipeline setup complete!")


🎯 Setting up Cohere reranking with contextual compression...
✅ RAG Fusion pipeline setup complete!


In [50]:
from langchain_openai import ChatOpenAI
from langchain_core.messages import HumanMessage, SystemMessage
from pydantic import BaseModel, Field
from typing import List, Optional

class Citation(BaseModel):
    filename: str = Field(description="The filename of the source document")
    page_number: Optional[str] = Field(default=None, description="The page number of the source document (if available)")


class StructuredAnswer(BaseModel):
    answer: str = Field(description="The comprehensive answer to the question")
    citations: List[Citation] = Field(description="List of source filenames that support the answer")


llm = ChatOpenAI(
    model="gpt-4o",  
    temperature=0.0
)

structured_llm = llm.with_structured_output(StructuredAnswer)


In [51]:
from langchain.prompts import ChatPromptTemplate
from langchain.schema.runnable import RunnablePassthrough, RunnableLambda
from langchain.schema.output_parser import StrOutputParser 

retriever = RunnableLambda(retrieve_chunks)

def format_docs(docs):
    formatted_chunks = []
    
    for i, doc in enumerate(docs):
        filename = doc.metadata.get('filename', 'unknown')
        page_number = doc.metadata.get('page_number', 'unknown')
        
        chunk_text = f"[Source: {filename}, Page: {page_number}]\n{doc.page_content}\n"
        formatted_chunks.append(chunk_text)
    
    return "\n".join(formatted_chunks)

prompt = ChatPromptTemplate.from_template("""
                                          
Please carefully review the conversation history and the provided documents below to answer the current question. Consider any context from previous exchanges when formulating your response.
                                          
Question: {input}

If you can find a complete answer, provide it with supporting evidence. If you can only find partial information or related context, explain what you found and acknowledge what's missing. If the documents don't contain relevant information, let me know what topics or areas the documents do cover instead.

Context: {context} 

Guidelines for your response:
- Be helpful and informative even when information is incomplete
- If you can't find a direct answer, mention what related information is available
- Always cite specific evidence from the documents when possible
- If no relevant information exists, briefly describe what the documents do contain

Format your response as:
ANSWER: [your response - can be direct answer, partial information, or explanation of available content]
EVIDENCE: [specific quotes or references from the documents, or note if no relevant content found]"""
)

In [52]:
rag_chain = (
    {"input": RunnablePassthrough(),  
    "context": retriever | format_docs,  
    } | prompt | structured_llm ) 

rag_chain.invoke("what is amtagvi?")

RecursionError: maximum recursion depth exceeded

In [35]:
from langgraph.graph import StateGraph, END
from typing import TypedDict, List
from langchain_core.messages import BaseMessage, SystemMessage, HumanMessage
import json

class State(TypedDict):
    original_question: str
    history: List[BaseMessage]

In [41]:
from typing import TypedDict, List
from langchain_core.messages import BaseMessage, HumanMessage, AIMessage, SystemMessage
from langchain.schema import Document
from pydantic import BaseModel, Field
from langchain_openai import ChatOpenAI
from langchain_core.prompts import ChatPromptTemplate
from langgraph.graph import StateGraph, END
from langgraph.checkpoint.memory import MemorySaver


class AgentState(TypedDict):
    messages: List[BaseMessage]
    documents: List[Document]
    on_topic: str
    rephrased_question: str
    proceed_to_generate: bool
    rephrase_count: int
    question: HumanMessage


class GradeQuestion(BaseModel):
    score: str = Field(
        description="Question is about medical topics? If yes -> 'Yes' if not -> 'No'"
    )


class GradeDocuments(BaseModel):
    """Grade multiple documents at once for relevance to the question."""
    relevant_document_indices: List[int] = Field(
        description="List of indices (0-based) of documents that are relevant to the question"
    )
    reasoning: str = Field(
        description="Brief explanation of why these documents were selected as relevant"
    )


def question_rewriter(state: AgentState):
    print(f"Entering question_rewriter with following state: {state}")

    state["documents"] = []
    state["on_topic"] = ""
    state["rephrased_question"] = ""
    state["proceed_to_generate"] = False
    state["rephrase_count"] = 0

    if "messages" not in state or state["messages"] is None:
        state["messages"] = []

    if state["question"] not in state["messages"]:
        state["messages"].append(state["question"])

    if len(state["messages"]) > 1:
        conversation = state["messages"][:-1]
        current_question = state["question"].content
        messages = [
            SystemMessage(
                content="You are a helpful assistant that rephrases the user's question to be a standalone question optimized for retrieval."
            )
        ]
        messages.extend(conversation)
        messages.append(HumanMessage(content=current_question))
        rephrase_prompt = ChatPromptTemplate.from_messages(messages)
        llm = ChatOpenAI(model="gpt-4o-mini")
        prompt = rephrase_prompt.format()
        response = llm.invoke(prompt)
        better_question = response.content.strip()
        print(f"question_rewriter: Rephrased question: {better_question}")
        state["rephrased_question"] = better_question
    else:
        state["rephrased_question"] = state["question"].content
    return state


def question_classifier(state: AgentState):
    print("Entering question_classifier")
    system_message = SystemMessage(
        content=""" You are a classifier that determines whether a user's question is about medical topics including:
    
    1. Diseases, conditions, and symptoms
    2. Medical treatments and procedures
    3. Medications and pharmaceuticals
    4. Human anatomy and physiology
    5. Medical diagnostics and testing
    6. Healthcare and medical advice
    7. Mental health and psychology
    8. Nutrition and dietary health
    9. Medical research and studies
    10. Any other healthcare or medical-related topics
    
    If the question IS about any medical topics, respond with 'Yes'. Otherwise, respond with 'No'.
    """
    )

    human_message = HumanMessage(
        content=f"User question: {state['rephrased_question']}"
    )
    grade_prompt = ChatPromptTemplate.from_messages([system_message, human_message])
    llm = ChatOpenAI(model="gpt-4o")
    structured_llm = llm.with_structured_output(GradeQuestion)
    grader_llm = grade_prompt | structured_llm
    result = grader_llm.invoke({})
    state["on_topic"] = result.score.strip()
    print(f"question_classifier: on_topic = {state['on_topic']}")
    return state


def on_topic_router(state: AgentState):
    print("Entering on_topic_router")
    on_topic = state.get("on_topic", "").strip().lower()
    if on_topic == "yes":
        print("Routing to retrieve")
        return "retrieve"
    else:
        print("Routing to off_topic_response")
        return "off_topic_response"


def retrieve(state: AgentState):
    print("Entering retrieve")
    documents = retriever.invoke(state["rephrased_question"])
    print(f"retrieve: Retrieved {len(documents)} documents")
    state["documents"] = documents
    return state


def retrieval_grader(state: AgentState):
    print("Entering retrieval_grader")
    
    if not state["documents"]:
        print("No documents to grade")
        state["proceed_to_generate"] = False
        return state
    
    # Prepare all documents for single LLM call
    documents_text = ""
    for i, doc in enumerate(state["documents"]):
        documents_text += f"Document {i}:\n{doc.page_content}\n\n"
    
    system_message = SystemMessage(
        content="""You are a grader assessing the relevance of retrieved documents to a user question.
        
You will be given multiple documents numbered from 0 onwards. Your task is to:
1. Identify which documents contain information relevant to the user's question
2. Return the indices (numbers) of the relevant documents
3. Provide brief reasoning for your selection

A document is relevant if it contains information that could help answer the user's question, even if partially."""
    )
    
    human_message = HumanMessage(
        content=f"User question: {state['rephrased_question']}\n\nDocuments to evaluate:\n{documents_text}"
    )
    
    grade_prompt = ChatPromptTemplate.from_messages([system_message, human_message])
    llm = ChatOpenAI(model="gpt-4o")
    structured_llm = llm.with_structured_output(GradeDocuments)
    grader_llm = grade_prompt | structured_llm
    result = grader_llm.invoke({})
    
    relevant_docs = []
    for idx in result.relevant_document_indices:
        if 0 <= idx < len(state["documents"]):
            relevant_docs.append(state["documents"][idx])
            print(f"Document {idx} marked as relevant")
    
    print(f"Grading reasoning: {result.reasoning}")
    print(f"Selected {len(relevant_docs)} relevant documents out of {len(state['documents'])}")
    
    state["documents"] = relevant_docs
    state["proceed_to_generate"] = len(relevant_docs) > 0
    print(f"retrieval_grader: proceed_to_generate = {state['proceed_to_generate']}")
    return state


def proceed_router(state: AgentState):
    print("Entering proceed_router")
    rephrase_count = state.get("rephrase_count", 0)
    if state.get("proceed_to_generate", False):
        print("Routing to generate_answer")
        return "generate_answer"
    elif rephrase_count >= 2:
        print("Maximum rephrase attempts reached. Cannot find relevant documents.")
        return "generate_answer"
    else:
        print("Routing to refine_question")
        return "refine_question"
    

def refine_question(state: AgentState):
    print("Entering refine_question")
    rephrase_count = state.get("rephrase_count", 0)
    if rephrase_count >= 2:
        print("Maximum rephrase attempts reached")
        return state
    question_to_refine = state["rephrased_question"]
    system_message = SystemMessage(
        content="""You are a helpful assistant that slightly refines the user's question to improve retrieval results for medical information.
Provide a slightly adjusted version of the question that might yield better results from a medical knowledge base."""
    )
    human_message = HumanMessage(
        content=f"Original question: {question_to_refine}\n\nProvide a slightly refined question for medical information retrieval."
    )
    refine_prompt = ChatPromptTemplate.from_messages([system_message, human_message])
    llm = ChatOpenAI(model="gpt-4o")
    prompt = refine_prompt.format()
    response = llm.invoke(prompt)
    refined_question = response.content.strip()
    print(f"refine_question: Refined question: {refined_question}")
    state["rephrased_question"] = refined_question
    state["rephrase_count"] = rephrase_count + 1
    return state


def generate_answer(state: AgentState):
    print("Entering generate_answer")
    if "messages" not in state or state["messages"] is None:
        raise ValueError("State must include 'messages' before generating an answer.")

    history = state["messages"]
    documents = state["documents"]
    rephrased_question = state["rephrased_question"]

    rag_chain = prompt | llm

    response = rag_chain.invoke(
        {"chat_history": history, "context": documents, "input": rephrased_question}
    )

    generation = response.content.strip()

    state["messages"].append(AIMessage(content=generation))
    print(f"generate_answer: Generated response: {generation}")
    return state

def off_topic_response(state: AgentState):
    print("Entering off_topic_response")
    if "messages" not in state or state["messages"] is None:
        state["messages"] = []
    state["messages"].append(
        AIMessage(
            content="I'm sorry! I can only answer questions related to medical and healthcare topics."
        )
    )
    return state


prompt = ChatPromptTemplate.from_template("""
You are a document-based assistant. Answer questions using ONLY the provided context documents. Never use external knowledge or training data beyond what's in the documents.

Conversation History: {chat_history}
Question: {input}
Context Documents: {context}

INSTRUCTIONS:
1. First, check if context documents are provided
2. If documents exist: search for direct answers to the question
3. If no direct answer exists: look for related, partially relevant, or tangentially connected information
4. If documents exist but contain no relevant info: describe what topics the documents DO cover
5. If no context documents provided: inform user that no relevant documents were found
6. NEVER supplement with knowledge from outside the provided documents

Response Format:
ANSWER: 
- If no context provided: "No relevant documents were found in the knowledge base for your question about [topic]"
- If direct answer found: [Provide complete answer]
- If partial/related info found: [Explain what related information exists and how it connects to the question]
- If documents exist but no relevant info: "The provided documents do not contain information about [specific topic]. However, the documents do cover: [list main topics from the actual documents]"

EVIDENCE: 
- If no context: "No documents retrieved from knowledge base"
- Otherwise: [Cite specific passages or "No information about [topic] found in documents"]

Guidelines:
- Be genuinely helpful by exploring connections the user might not have considered
- Acknowledge limitations clearly when information is missing or incomplete
- Focus on what IS available rather than what isn't
- Never invent or assume information not explicitly in the documents

Response:""")

workflow = StateGraph(AgentState)
workflow.add_node("question_rewriter", question_rewriter)
workflow.add_node("question_classifier", question_classifier)
workflow.add_node("off_topic_response", off_topic_response)
workflow.add_node("retrieve", retrieve)
workflow.add_node("retrieval_grader", retrieval_grader)
workflow.add_node("generate_answer", generate_answer)
workflow.add_node("refine_question", refine_question)

workflow.add_edge("question_rewriter", "question_classifier") 
workflow.add_conditional_edges(
    "question_classifier",
    on_topic_router,
    {
        "retrieve": "retrieve",
        "off_topic_response": "off_topic_response",
    },
)
workflow.add_edge("retrieve", "retrieval_grader")
workflow.add_conditional_edges(
    "retrieval_grader",
    proceed_router,
    {
        "generate_answer": "generate_answer",
        "refine_question": "refine_question"
    },
)
workflow.add_edge("refine_question", "retrieve")
workflow.add_edge("generate_answer", END)
workflow.add_edge("off_topic_response", END)
workflow.set_entry_point("question_rewriter")

checkpointer = MemorySaver()

graph = workflow.compile(checkpointer=checkpointer)

In [56]:
input_data = {
    "question": HumanMessage(
        content="How is AMTAGVI provided and what is the dose?"
    )
}

result = graph.invoke(input=input_data, config={"configurable": {"thread_id": 15}})

Entering question_rewriter with following state: {'question': HumanMessage(content='How is AMTAGVI provided and what is the dose?', additional_kwargs={}, response_metadata={})}
Entering question_classifier
question_classifier: on_topic = Yes
Entering on_topic_router
Routing to retrieve
Entering retrieve
Ensemble Retriever Results 43
Step 1: 1.43s
🔍 Simhash deduplicating 43 documents...
   Hash generation: 0.05s
   Comparison: 0.00s
   Simhash filtering: 43 → 40 docs
Step 2: 0.06s
deduplicated_docs: 40
final_docs: 10
Step 3: 0.66s
TOTAL: 2.15s
retrieve: Retrieved 10 documents
Entering retrieval_grader
Document 0 marked as relevant
Document 1 marked as relevant
Document 2 marked as relevant
Document 4 marked as relevant
Document 6 marked as relevant
Grading reasoning: Documents 0 and 1 clearly provide specific information on how AMTAGVI is provided and what the dose is, stating that it is a single dose of viable tumor-derived T cells in infusion bags. Document 2 describes the preparation

In [57]:
ai_message = result['messages'][-1].content

print("AI Response:")
print(ai_message)
print("\nCitations:")

seen_citations = set()
citation_count = 0

for doc in result['documents']:
    filename = doc.metadata['filename']
    page_number = doc.metadata['page_number']
    citation_key = (filename, page_number)
    
    if citation_key not in seen_citations:
        seen_citations.add(citation_key)
        citation_count += 1
        print(f"Citation {citation_count}: {filename} (page {page_number})")

AI Response:
ANSWER: AMTAGVI is provided as a single dose for infusion containing a suspension of tumor-derived T cells. The dose is supplied in 1 to 4 patient-specific IV infusion bags in individual protective metal cassettes. Each dose contains 7.5 x 10^9 to 72 x 10^9 viable cells.

EVIDENCE: 
- "AMTAGVI® is provided as a single dose for infusion containing a suspension of tumor-derived T cells. The dose is supplied in 1 to 4 patient-speciﬁc IV infusion bag(s) in individual protective metal cassettes. Each dose contains 7.5 x 10^9 to 72 x 10^9 viable cells." (Document metadata: 'chunk_id': '2debb310-87fc-4941-bfbd-ac5abb043f8b', 'page_number': '3')

Citations:
Citation 1: AMTAGVI Commercial FAQ Document.pdf (page 3)
Citation 2: Amtagvi PI 2024.pdf (page 2)
Citation 3: prc-us-00469-ref1-Amtagvi PI 02 2024.pdf (page 3)
Citation 4: prc-us-00483-ref2-Amtagvi PI 02 2024.pdf (page 5)
Citation 5: prc-us-00469-ref1-Amtagvi PI 02 2024.pdf (page 5)
