In [7]:
# Step 1: Search Kendra and select documents
import boto3
from IPython.display import display, Markdown
import ipywidgets as widgets

# Shared boto3 Kendra client; query_kendra() uses it unless given another client.
kendra_client = boto3.client(service_name="kendra")

# Function to query the Kendra index with filters
def query_kendra(index_id, query_text, jurisdiction=None, doc_type=None, client=None):
    """
    Query an AWS Kendra index with optional jurisdiction / document-type filters.

    Args:
        index_id: ID of the Kendra index to search.
        query_text: Free-text search query.
        jurisdiction: If truthy, keep only documents whose ``jurisdiction``
            attribute equals this string.
        doc_type: If truthy, keep only documents whose ``doc_type`` attribute
            equals this string.
        client: Optional boto3 Kendra client (e.g. a stub). Defaults to the
            module-level ``kendra_client``.

    Returns:
        One dict per result item that carries a ``DocumentId``::

            {"DocumentTitle": str, "DocumentId": str, "ExcerptText": str}

        ``DocumentTitle`` falls back to "No Title" and ``ExcerptText`` to
        "No Excerpt Available" when Kendra provides no text for them.
    """
    if client is None:
        client = kendra_client

    def _plain_text(field, default):
        # Kendra returns titles/excerpts as TextWithHighlights: {"Text": ..., "Highlights": [...]}.
        return (field or {}).get("Text") or default

    attribute_filters = [
        {"EqualsTo": {"Key": key, "Value": {"StringValue": value}}}
        for key, value in (("jurisdiction", jurisdiction), ("doc_type", doc_type))
        if value
    ]

    query_args = {"IndexId": index_id, "QueryText": query_text}
    # botocore rejects AttributeFilter=None, so the parameter is only sent when a
    # filter exists; a single condition is passed directly, several are AND-ed.
    if len(attribute_filters) == 1:
        query_args["AttributeFilter"] = attribute_filters[0]
    elif attribute_filters:
        query_args["AttributeFilter"] = {"AndAllFilters": attribute_filters}

    response = client.query(**query_args)

    return [
        {
            "DocumentTitle": _plain_text(item.get("DocumentTitle"), "No Title"),
            "DocumentId": item["DocumentId"],
            # Kendra puts the excerpt under "DocumentExcerpt"; the output key stays
            # "ExcerptText" for existing callers.
            "ExcerptText": _plain_text(item.get("DocumentExcerpt"), "No Excerpt Available"),
        }
        for item in response.get("ResultItems", [])
        if "DocumentId" in item
    ]

# --- Step 1 form widgets -----------------------------------------------------
# Free-text query plus two optional attribute filters ("" means "no filter").
query_input = widgets.Text(
    placeholder="Search your documents",
    description="Search:",
    layout=widgets.Layout(width="70%"),
)

jurisdiction_dropdown = widgets.Dropdown(
    options=["", "Colorado", "California", "Texas"],
    description="Jurisdiction:",
    layout=widgets.Layout(width="50%"),
)

doc_type_dropdown = widgets.Dropdown(
    options=["", "Guidance Memo", "Regulation", "Permit"],
    description="Doc Type:",
    layout=widgets.Layout(width="50%"),
)

# Shared output area for search results, and the button confirming the checkbox picks.
output = widgets.Output()
select_button = widgets.Button(description="Select Documents")

# Current selection (IDs and their titles, index-aligned); filled by select_documents().
selected_document_ids = []
selected_document_titles = []

# Kendra index searched by the Step 1 form.
KENDRA_INDEX_ID = "ac2e614a-1a60-4788-921f-439355c5756d"


def on_search_click(change):
    """
    Button callback: search Kendra with the form values and show selectable results.

    Args:
        change: The clicked button (supplied by ``on_click``); unused.

    For each hit this displays the title and excerpt, plus a checkbox that carries
    the hit's ``document_id``. ``select_button`` is re-wired on every search so that
    clicking it reads only the current result set's checkboxes.
    """
    with output:
        output.clear_output()

        query_text = query_input.value.strip()
        if not query_text:
            display(Markdown("**Please enter a search query.**"))
            return

        try:
            results = query_kendra(
                index_id=KENDRA_INDEX_ID,
                query_text=query_text,
                jurisdiction=jurisdiction_dropdown.value.strip() or None,
                doc_type=doc_type_dropdown.value.strip() or None,
            )
        except Exception as exc:
            # Surface AWS/validation errors in the output area instead of only the widget log.
            display(Markdown(f"**Search failed:** {exc}"))
            return

        if not results:
            display(Markdown("**No results found.**"))
            return

        display(Markdown("**Search Results:**"))
        document_checkboxes = []
        for result in results:
            # Blank line so Markdown renders the excerpt as its own paragraph under the title.
            display(Markdown(f"**{result['DocumentTitle']}**\n\n{result['ExcerptText']}"))
            checkbox = widgets.Checkbox(description=result['DocumentTitle'], value=False)
            checkbox.document_id = result["DocumentId"]  # read back by select_documents()
            document_checkboxes.append(checkbox)

        # Drop the handler registered by the previous search; otherwise every search
        # stacks another callback bound to an older, stale set of checkboxes.
        previous_handler = getattr(on_search_click, "_select_handler", None)
        if previous_handler is not None:
            select_button.on_click(previous_handler, remove=True)

        def select_handler(_button):
            select_documents(document_checkboxes)

        select_button.on_click(select_handler)
        on_search_click._select_handler = select_handler

        display(widgets.VBox(document_checkboxes))
        display(select_button)

def select_documents(document_checkboxes):
    """
    Copy the checked search results into the shared Step 1 selection lists.

    ``selected_document_ids`` and ``selected_document_titles`` are cleared and
    refilled in place (index-aligned), then a short summary is shown in ``output``.

    Args:
        document_checkboxes: Checkboxes created by the search; each carries a
            ``document_id`` attribute and uses the document title as its description.
    """
    with output:
        output.clear_output()
        selected_document_ids.clear()
        selected_document_titles.clear()

        checked_boxes = [box for box in document_checkboxes if box.value]
        for box in checked_boxes:
            selected_document_ids.append(box.document_id)
            selected_document_titles.append(box.description)

        if selected_document_ids:
            display(Markdown(f"**{len(selected_document_ids)} doc(s) selected. Proceed to step 2.**"))
            display(Markdown(f"**Doc Titles:** {', '.join(selected_document_titles)}"))
        else:
            display(Markdown("**No documents selected.**"))

# Wire the Search button and render the Step 1 form.
search_button = widgets.Button(description="Search Documents")
search_button.on_click(on_search_click)

display(
    query_input,
    jurisdiction_dropdown,
    doc_type_dropdown,
    search_button,
    output,
)

Text(value='', description='Search:', layout=Layout(width='70%'), placeholder='Search your documents')

Dropdown(description='Jurisdiction:', layout=Layout(width='50%'), options=('', 'Colorado', 'California', 'Texa…

Dropdown(description='Doc Type:', layout=Layout(width='50%'), options=('', 'Guidance Memo', 'Regulation', 'Per…

Button(description='Search Documents', style=ButtonStyle())

Output()

In [8]:
# Step 2: Fetch documents, chunk them, and map each chunk to its source (TODO: delete the older Step 2 cell if this version works)
import boto3
from IPython.display import display, Markdown
import ipywidgets as widgets
import re
import io
from PyPDF2 import PdfReader

# Shared boto3 S3 client; fetch_document_text() uses it unless given another client.
s3_client = boto3.client(service_name="s3")

# Function to fetch document text from S3
def fetch_document_text(document_uri, client=None):
    """
    Fetch a document from S3 and return its text.

    Args:
        document_uri: ``s3://bucket/key`` URI. The bucket comes from the URI, so
            documents may live in different buckets.
        client: Optional boto3 S3 client (e.g. a stub). Defaults to the
            module-level ``s3_client``.

    Returns:
        For keys ending in ``.pdf`` (any case), the non-empty page texts joined by
        newlines. Otherwise, the object body decoded as UTF-8.

    Raises:
        ValueError: if the URI is malformed, the key does not exist, or fetching,
            parsing, or decoding fails. The original exception is chained as the cause.
    """
    if client is None:
        client = s3_client

    # Require a non-empty key after the bucket.
    match = re.match(r"s3://([^/]+)/(.+)", document_uri)
    if not match:
        raise ValueError(f"Invalid S3 URI format: {document_uri}")

    bucket_name, object_key = match.groups()
    print(f"Fetching from S3: Bucket={bucket_name}, Key={object_key}")  # Debugging info

    try:
        response = client.get_object(Bucket=bucket_name, Key=object_key)
        content = response["Body"].read()

        if object_key.lower().endswith(".pdf"):
            pdf_reader = PdfReader(io.BytesIO(content))
            # Extract each page once (it is the expensive call) and skip pages with no text.
            page_texts = (page.extract_text() for page in pdf_reader.pages)
            return "\n".join(text for text in page_texts if text)

        return content.decode("utf-8")  # Handle text files
    except client.exceptions.NoSuchKey as exc:
        raise ValueError(f"Error: The file '{object_key}' does not exist in bucket '{bucket_name}'.") from exc
    except Exception as exc:
        raise ValueError(f"Unexpected error fetching document: {exc}") from exc

# Function to chunk document into smaller pieces
def chunk_document(text, chunk_size=500, overlap=0):
    """
    Split text into fixed-size character windows.

    Args:
        text: The document text.
        chunk_size: Maximum characters per chunk; must be positive.
        overlap: Characters shared by consecutive chunks, 0 <= overlap < chunk_size.
            The default of 0 gives back-to-back, non-overlapping chunks.

    Returns:
        A list of chunk strings, empty for empty text. Every chunk except possibly
        the last has exactly ``chunk_size`` characters.

    Raises:
        ValueError: if ``chunk_size`` or ``overlap`` is out of range.
    """
    if chunk_size <= 0:
        raise ValueError(f"chunk_size must be positive, got {chunk_size}")
    if not 0 <= overlap < chunk_size:
        raise ValueError(f"overlap must satisfy 0 <= overlap < chunk_size, got {overlap}")

    step = chunk_size - overlap
    chunks = []
    for start in range(0, len(text), step):
        chunks.append(text[start:start + chunk_size])
        if start + chunk_size >= len(text):
            break  # this window reaches the end; later windows would only repeat its tail
    return chunks

# UI for Ingestion and Optimization
chunk_size_slider = widgets.IntSlider(
    value=500, min=100, max=2000, step=100, description="Chunk Size:"
)
ingest_button = widgets.Button(description="Ingest Documents")
ingestion_output = widgets.Output()

# Step 2 results, read by Step 3: chunk texts and, index-aligned, each chunk's source document ID.
all_chunks = []
chunk_doc_map = []


def ingest_documents(change):
    """
    Button callback: fetch and chunk every document selected in Step 1.

    Rebinds the module-level ``all_chunks`` (chunk texts) and ``chunk_doc_map``
    (the source document ID of each chunk, index-aligned with ``all_chunks``).
    Documents that fail to load, or that yield no text, are reported and skipped.

    Args:
        change: The clicked button (supplied by ``on_click``); unused.
    """
    global all_chunks, chunk_doc_map
    with ingestion_output:
        ingestion_output.clear_output()
        if not selected_document_ids:
            display(Markdown("**No documents selected. Please complete Step 1.**"))
            return

        all_chunks = []  # Reset chunks
        chunk_doc_map = []  # Reset mapping

        for doc_id in selected_document_ids:
            try:
                doc_text = fetch_document_text(doc_id)
                if not doc_text.strip():
                    # e.g. a PDF whose pages yield no extractable text
                    display(Markdown(f"**Warning:** No text extracted from **{doc_id}**; skipped."))
                    continue
                chunks = chunk_document(doc_text, chunk_size=chunk_size_slider.value)
            except ValueError as e:
                display(Markdown(f"**Error:** {e}"))
                continue

            all_chunks.extend(chunks)
            chunk_doc_map.extend([doc_id] * len(chunks))  # one source entry per chunk
            display(Markdown(f"Successfully processed document: **{doc_id}**"))

        if all_chunks:
            display(Markdown(f"**{len(all_chunks)} document chunks stored. Proceed to Step 3.**"))

            # Debugging: list the distinct document sources (sorted for stable output)
            print("Debugging: Unique document sources used:")
            for source in sorted(set(chunk_doc_map)):
                print(source)
        else:
            display(Markdown("**No valid text extracted.**"))

ingest_button.on_click(ingest_documents)
display(widgets.VBox([chunk_size_slider, ingest_button, ingestion_output]))


VBox(children=(IntSlider(value=500, description='Chunk Size:', max=2000, min=100, step=100), Button(descriptio…

In [4]:
# Step 3: Enhanced interactive chat interface with source tracking
import boto3
import json
from IPython.display import display, Markdown, HTML
import ipywidgets as widgets
from collections import Counter

# Shared Bedrock Runtime client used by the invoke_claude_model* helpers.
bedrock_runtime_client = boto3.client(service_name="bedrock-runtime")

# Function to invoke the Claude model
def invoke_claude_model(context, question, max_output_tokens=6000,
                        model_id="anthropic.claude-3-5-sonnet-20240620-v1:0", client=None):
    """
    Ask Claude on Amazon Bedrock a question, sending retrieved document text as context.

    Args:
        context: Retrieved document text placed before the question.
        question: The user's question.
        max_output_tokens: Sent as the Messages API ``max_tokens`` limit.
        model_id: Bedrock model ID to invoke.
        client: Optional boto3 ``bedrock-runtime`` client. Defaults to the
            module-level ``bedrock_runtime_client``.

    Returns:
        The reply's text blocks joined with single spaces. Returns ``None`` if the
        call fails or the reply has no ``content`` list; the error is printed.
    """
    try:
        runtime = bedrock_runtime_client if client is None else client
        messages = [
            {"role": "user", "content": f"Context: {context}\nQuestion: {question}"}
        ]

        response = runtime.invoke_model(
            modelId=model_id,
            body=json.dumps({
                "messages": messages,
                "max_tokens": max_output_tokens,
                "anthropic_version": "bedrock-2023-05-31"
            })
        )
        result = json.loads(response["body"].read().decode("utf-8"))

        content = result.get("content")
        if not isinstance(content, list):
            raise ValueError("Claude API response does not contain valid 'content'.")
        # Keep only text blocks; .get() tolerates blocks lacking "type"/"text".
        return " ".join(block.get("text", "") for block in content if block.get("type") == "text")
    except Exception as e:
        # Best effort: report and return None so the UI can show a friendly error.
        print(f"Error: {e}")
        return None

# Function to find the most relevant chunks for a question
def find_relevant_chunks(question, chunks, chunk_doc_map, num_chunks=15):
    """
    Rank chunks by keyword overlap with the question and label each with its source.

    Words are compared case-insensitively as runs of word characters, so punctuation
    such as a trailing "?" or "." does not block a match.

    Args:
        question: The user's question.
        chunks: Chunk texts.
        chunk_doc_map: Source document of each chunk (index-aligned with ``chunks``).
        num_chunks: Maximum number of chunks to return.

    Returns:
        ``(relevant_chunks, selected_sources, source_counter, used_chunks)``:
          - ``relevant_chunks``: "[Source: <doc>] <text>" strings, best match first.
          - ``selected_sources``: distinct sources of the returned chunks, in first-use order.
          - ``source_counter``: ``Counter`` of returned chunks per source.
          - ``used_chunks``: ``(chunk_text, source, index)`` tuples for the returned chunks.

        If no chunk shares a word with the question, up to ``num_chunks`` chunks
        spread evenly across ``chunks`` are returned instead.
    """
    import re  # local import: this cell does not import re

    if num_chunks <= 0 or not chunks:
        return [], [], Counter(), []

    def _words(text):
        return set(re.findall(r"\w+", text.lower()))

    keywords = _words(question)
    scored = []  # (match_count, chunk_index)
    for index, chunk in enumerate(chunks):
        match_count = len(keywords & _words(chunk))
        if match_count > 0:
            scored.append((match_count, index))

    # list.sort is stable (also with reverse=True), so ties keep document order.
    scored.sort(key=lambda pair: pair[0], reverse=True)
    indices = [index for _, index in scored[:num_chunks]]

    # Fallback: nothing matched, so sample chunks evenly across the corpus.
    if not indices:
        step = max(1, len(chunks) // num_chunks)
        indices = list(range(0, len(chunks), step))[:num_chunks]

    used_chunks = [(chunks[i], chunk_doc_map[i], i) for i in indices]
    relevant_chunks = [f"[Source: {source}] {text}" for text, source, _ in used_chunks]
    source_counter = Counter(source for _, source, _ in used_chunks)
    # Derived from the chunks actually returned (not every matching chunk).
    selected_sources = list(dict.fromkeys(source for _, source, _ in used_chunks))

    return relevant_chunks, selected_sources, source_counter, used_chunks

def create_enhanced_chat_interface():
    """
    Build the Step 3 chat UI with per-source usage reporting.

    Each question is answered from the Step 2 globals ``all_chunks`` / ``chunk_doc_map``.
    Up to 35 keyword-matched chunks (``find_relevant_chunks``) are sent to Claude
    (``invoke_claude_model``). The answer is shown with a table of chunks used per
    source, plus a button that lists every ingested source.

    Questions are submitted with the Ask button or by pressing Enter. The text box
    syncs its value only on Enter or when it loses focus, so leaving the box with a
    non-empty question also submits it.

    Returns:
        An ``ipywidgets.VBox`` holding the interface.
    """
    from html import escape  # model output and source IDs are inserted into raw HTML

    question_input = widgets.Text(
        value='',
        placeholder='Ask a question about your documents',
        description='Question:',
        disabled=False,
        # Sync only on Enter/blur, so the value observer below acts as the Enter handler.
        continuous_update=False,
        layout=widgets.Layout(width='80%')
    )

    submit_button = widgets.Button(
        description='Ask',
        button_style='primary',
        tooltip='Submit your question'
    )

    # Scrolling output area, tall enough for the answer plus source statistics
    chat_output = widgets.Output(
        layout=widgets.Layout(
            height='550px',
            width='95%',
            overflow='auto',
            border='1px solid #ddd',
            padding='10px'
        )
    )

    def _source_table_html(source_counter):
        """Return the styled per-source usage table, wrapped in its container, as one HTML string."""
        total_chunks = sum(source_counter.values()) or 1
        rows = []
        for source, count in source_counter.most_common():
            percentage = (count / total_chunks) * 100
            rows.append(
                "<tr>"
                f"<td style='padding: 8px; border: 1px solid #ddd;'>{escape(str(source))}</td>"
                f"<td style='padding: 8px; text-align: center; border: 1px solid #ddd;'>{count}</td>"
                f"<td style='padding: 8px; text-align: center; border: 1px solid #ddd;'>{percentage:.1f}%</td>"
                "</tr>"
            )
        # The container <div> must be in the same HTML payload as the table: separate
        # display() calls are separate outputs, so a lone opening <div> wraps nothing.
        return (
            "<div style='margin: 10px 0; padding: 10px; background-color: #f8f8f8; border-radius: 5px;'>"
            "<table style='width:100%; border-collapse: collapse;'>"
            "<tr style='background-color: #e0e0e0;'>"
            "<th style='padding: 8px; text-align: left; border: 1px solid #ddd;'>Source</th>"
            "<th style='padding: 8px; text-align: center; border: 1px solid #ddd;'>Chunks Used</th>"
            "<th style='padding: 8px; text-align: center; border: 1px solid #ddd;'>% of Context</th>"
            "</tr>"
            + "".join(rows)
            + "</table></div>"
        )

    def _display_all_sources_button(source_counter):
        """Display a button that lists every ingested source and how many of its chunks were used."""
        show_all_button = widgets.Button(
            description="Show All Available Sources",
            button_style="info",
            layout=widgets.Layout(width='auto')
        )
        sources_output = widgets.Output()

        def on_show_all(b):
            with sources_output:
                sources_output.clear_output()
                display(HTML("<b>All Available Sources:</b>"))
                for src in dict.fromkeys(chunk_doc_map):  # distinct, in ingestion order
                    has_chunks = src in source_counter
                    style = "color: green;" if has_chunks else "color: #888;"
                    count_text = f"({source_counter[src]} chunks)" if has_chunks else "(not used)"
                    display(HTML(f"<div style='margin-left: 20px; {style}'>• {escape(str(src))} {count_text}</div>"))

        show_all_button.on_click(on_show_all)
        display(show_all_button)
        display(sources_output)

    def on_submit_question(b):
        """Answer the current question (button callback; ``b`` is unused)."""
        question = question_input.value.strip()
        if not question:
            return

        if not all_chunks:
            with chat_output:
                display(Markdown("**Error:** No document chunks available. Please run Step 2 first."))
            return

        with chat_output:
            chat_output.clear_output()
            display(Markdown(f"**User Question:** {question}"))
            display(Markdown("*Processing...*"))

            context_chunks, selected_sources, source_counter, _ = find_relevant_chunks(
                question, all_chunks, chunk_doc_map, num_chunks=35
            )
            context = " ".join(context_chunks)
            print(f"Found {len(context_chunks)} relevant chunks from {len(selected_sources)} documents")

            chat_output.clear_output()
            display(Markdown(f"**User Question:** {question}"))

            response = invoke_claude_model(context, question, max_output_tokens=6000)

            if response:
                # Escape so "<", ">" and "&" in the answer render literally.
                display(HTML(
                    "<div style='white-space: pre-wrap; margin: 10px 0;'>"
                    f"<b>AI Response:</b><br>{escape(response)}</div>"
                ))
                if selected_sources:
                    display(HTML("<b>Source Usage Statistics:</b>"))
                    display(HTML(_source_table_html(source_counter)))
                    _display_all_sources_button(source_counter)
            else:
                display(Markdown("**Error:** Failed to get a response from the AI model."))

        # Clear the input for the next question (the observer ignores the empty value).
        question_input.value = ''

    submit_button.on_click(on_submit_question)

    def on_question_committed(change):
        """Submit when a non-empty value is committed (Enter, or the box losing focus)."""
        if change['new'].strip():
            on_submit_question(None)

    question_input.observe(on_question_committed, names='value')

    return widgets.VBox([
        widgets.HTML("<h3>Chat with Your Documents (Enhanced Source Tracking)</h3>"),
        widgets.HBox([question_input, submit_button]),
        chat_output
    ])

# To use this in a notebook, run:
display(create_enhanced_chat_interface())


VBox(children=(HTML(value='<h3>Chat with Your Documents (Enhanced Source Tracking)</h3>'), HBox(children=(Text…

Output()

In [5]:
# Step 3.1: Enhanced Source Diversity Version
import boto3
import json
from IPython.display import display, Markdown, HTML
import ipywidgets as widgets
from collections import Counter

# Function to invoke the Claude model with source diversity instruction
def invoke_claude_model_diverse(context, question, max_output_tokens=6000,
                                model_id="anthropic.claude-3-5-sonnet-20240620-v1:0", client=None):
    """
    Ask Claude on Amazon Bedrock a question, instructing it to draw on all sources in the context.

    Args:
        context: Retrieved document text (chunks labelled with their source).
        question: The user's question.
        max_output_tokens: Sent as the Messages API ``max_tokens`` limit.
        model_id: Bedrock model ID to invoke.
        client: Optional boto3 ``bedrock-runtime`` client. Defaults to the
            module-level ``bedrock_runtime_client`` (created in Step 3).

    Returns:
        The reply's text blocks joined with single spaces. Returns ``None`` if the
        call fails or the reply has no ``content`` list; the error is printed.
    """
    try:
        runtime = bedrock_runtime_client if client is None else client

        # Explicit instruction to consider all sources
        instruction = "Please consider information from ALL available sources in the provided context when answering. Make sure to draw from as many different sources as possible for your response."
        messages = [
            {"role": "user", "content": f"Context: {context}\n\n{instruction}\n\nQuestion: {question}"}
        ]

        response = runtime.invoke_model(
            modelId=model_id,
            body=json.dumps({
                "messages": messages,
                "max_tokens": max_output_tokens,
                "anthropic_version": "bedrock-2023-05-31"
            })
        )
        result = json.loads(response["body"].read().decode("utf-8"))

        content = result.get("content")
        if not isinstance(content, list):
            raise ValueError("Claude API response does not contain valid 'content'.")
        # Keep only text blocks; .get() tolerates blocks lacking "type"/"text".
        return " ".join(block.get("text", "") for block in content if block.get("type") == "text")
    except Exception as e:
        # Best effort: report and return None so the UI can show a friendly error.
        print(f"Error: {e}")
        return None

# Enhanced function to find relevant chunks with maximized source diversity
def find_relevant_chunks_diverse(question, chunks, chunk_doc_map, num_chunks=50):
    """
    Select context chunks by keyword overlap, boosted so that every source is represented.

    1. Score each chunk by the number of distinct question words it contains
       (case-insensitive, punctuation ignored). If nothing matches, every chunk
       gets a nominal score of 0.1.
    2. Multiply each score by ``1 + 0.5 * max_count / count(source)``, where the
       counts are scored chunks per source, so sparser sources get a bigger boost.
    3. Take the best chunk from every source that has a scored chunk, then fill the
       remaining ``num_chunks`` slots with the best of the rest. If there are more
       such sources than ``num_chunks``, the result holds exactly one chunk per
       source and therefore exceeds ``num_chunks``.

    Args:
        question: The user's question.
        chunks: Chunk texts.
        chunk_doc_map: Source document of each chunk (index-aligned with ``chunks``).
        num_chunks: Target number of chunks to return.

    Returns:
        ``(relevant_chunks, selected_sources, used_source_counter, used_chunks)``:
        "[Source: <doc>] <text>" strings ordered by boosted score, the distinct
        sources used, a ``Counter`` of chunks per source, and
        ``(chunk_text, source, index)`` tuples.
    """
    import re  # local import: this cell does not import re

    def _words(text):
        return set(re.findall(r"\w+", text.lower()))

    keywords = _words(question)
    # dict.fromkeys keeps first-appearance order, making source iteration (and tie order) deterministic.
    all_sources = list(dict.fromkeys(chunk_doc_map))
    print(f"Available sources: {len(all_sources)}")

    # PHASE 1: score chunks by keyword matches -> (chunk, score, source, index)
    scored_chunks = []
    for i, chunk in enumerate(chunks):
        match_count = len(keywords & _words(chunk))
        if match_count > 0:
            scored_chunks.append((chunk, match_count, chunk_doc_map[i], i))

    # No keyword matches: consider every chunk with a minimal score
    if not scored_chunks:
        scored_chunks = [(chunk, 0.1, chunk_doc_map[i], i) for i, chunk in enumerate(chunks)]

    # PHASE 2: boost chunks from sources with fewer scored chunks
    source_counts = Counter(src for _, _, src, _ in scored_chunks)
    max_count = max(source_counts.values()) if source_counts else 1
    boosted = [
        (chunk, score * (1.0 + (max_count / source_counts[src]) * 0.5), src, idx)
        for chunk, score, src, idx in scored_chunks
    ]
    boosted.sort(key=lambda entry: entry[1], reverse=True)  # stable: ties keep document order

    # PHASE 3: guarantee one chunk per source. `boosted` is sorted best-first,
    # so the first entry seen for a source is its highest-scoring chunk.
    best_by_source = {}
    for entry in boosted:
        best_by_source.setdefault(entry[2], entry)
    ensured_chunks = [best_by_source[src] for src in all_sources if src in best_by_source]
    ensured_indices = {entry[3] for entry in ensured_chunks}

    # Fill the remaining slots with the best other chunks (already in score order).
    remaining_chunks = [entry for entry in boosted if entry[3] not in ensured_indices]
    # Clamp at 0: a negative slice bound would take almost every remaining chunk.
    free_slots = max(0, num_chunks - len(ensured_chunks))
    final_chunks = ensured_chunks + remaining_chunks[:free_slots]
    final_chunks.sort(key=lambda entry: entry[1], reverse=True)

    used_source_counter = Counter(entry[2] for entry in final_chunks)
    selected_sources = list(used_source_counter.keys())
    used_chunks = [(text, src, idx) for text, _, src, idx in final_chunks]
    relevant_chunks = [f"[Source: {src}] {text}" for text, _, src, _ in final_chunks]

    print(f"Sources representation: {len(selected_sources)}/{len(all_sources)} sources")
    for source, count in used_source_counter.most_common():
        print(f"  - {source}: {count} chunks")

    return relevant_chunks, selected_sources, used_source_counter, used_chunks

# Create our enhanced chat interface with source diversity
def create_diverse_chat_interface():
    """
    Build the Step 3.1 chat UI, which favours answers that draw on many sources.

    Each question is answered from the Step 2 globals ``all_chunks`` / ``chunk_doc_map``
    using ``find_relevant_chunks_diverse`` and ``invoke_claude_model_diverse``. The
    chunk budget is set by the "Max Chunks" slider. The answer is followed by a
    source-coverage line, a per-source usage table, and a button listing every
    ingested source.

    Questions are submitted with the Ask button or by pressing Enter. The text box
    syncs its value only on Enter or when it loses focus, so leaving the box with a
    non-empty question also submits it.

    Returns:
        An ``ipywidgets.VBox`` holding the interface.
    """
    from html import escape  # model output and source IDs are inserted into raw HTML

    question_input = widgets.Text(
        value='',
        placeholder='Ask a question about your documents',
        description='Question:',
        disabled=False,
        # Sync only on Enter/blur, so the value observer below acts as the Enter handler.
        continuous_update=False,
        layout=widgets.Layout(width='80%')
    )

    submit_button = widgets.Button(
        description='Ask',
        button_style='primary',
        tooltip='Submit your question'
    )

    # Chunk budget passed to find_relevant_chunks_diverse
    chunk_slider = widgets.IntSlider(
        value=50,
        min=15,
        max=100,
        step=5,
        description='Max Chunks:',
        layout=widgets.Layout(width='50%'),
        style={'description_width': 'initial'}
    )

    # Scrolling output area, tall enough for the answer plus source statistics
    chat_output = widgets.Output(
        layout=widgets.Layout(
            height='650px',
            width='95%',
            overflow='auto',
            border='1px solid #ddd',
            padding='10px'
        )
    )

    def _show_question(question, status=None):
        """Clear the chat area and show the question, plus an optional italic status line."""
        chat_output.clear_output()
        display(Markdown(f"**User Question:** {question}"))
        if status:
            display(Markdown(f"*{status}*"))

    def _source_table_html(source_counter):
        """Return the styled per-source usage table, wrapped in its container, as one HTML string."""
        total_chunks = sum(source_counter.values()) or 1
        rows = []
        for source, count in source_counter.most_common():
            percentage = (count / total_chunks) * 100
            rows.append(
                "<tr>"
                f"<td style='padding: 8px; border: 1px solid #ddd;'>{escape(str(source))}</td>"
                f"<td style='padding: 8px; text-align: center; border: 1px solid #ddd;'>{count}</td>"
                f"<td style='padding: 8px; text-align: center; border: 1px solid #ddd;'>{percentage:.1f}%</td>"
                "</tr>"
            )
        # The container <div> must be in the same HTML payload as the table: separate
        # display() calls are separate outputs, so a lone opening <div> wraps nothing.
        return (
            "<div style='margin: 10px 0; padding: 10px; background-color: #f8f8f8; border-radius: 5px;'>"
            "<table style='width:100%; border-collapse: collapse;'>"
            "<tr style='background-color: #e0e0e0;'>"
            "<th style='padding: 8px; text-align: left; border: 1px solid #ddd;'>Source</th>"
            "<th style='padding: 8px; text-align: center; border: 1px solid #ddd;'>Chunks Used</th>"
            "<th style='padding: 8px; text-align: center; border: 1px solid #ddd;'>% of Context</th>"
            "</tr>"
            + "".join(rows)
            + "</table></div>"
        )

    def _display_all_sources_button(source_counter):
        """Display a button that lists every ingested source and how many of its chunks were used."""
        show_all_button = widgets.Button(
            description="Show All Available Sources",
            button_style="info",
            layout=widgets.Layout(width='auto')
        )
        sources_output = widgets.Output()

        def on_show_all(b):
            with sources_output:
                sources_output.clear_output()
                display(HTML("<b>All Available Sources:</b>"))
                for src in dict.fromkeys(chunk_doc_map):  # distinct, in ingestion order
                    has_chunks = src in source_counter
                    style = "color: green; font-weight: bold;" if has_chunks else "color: #888;"
                    count_text = f"({source_counter[src]} chunks)" if has_chunks else "(not used)"
                    display(HTML(f"<div style='margin-left: 20px; {style}'>• {escape(str(src))} {count_text}</div>"))

        show_all_button.on_click(on_show_all)
        display(show_all_button)
        display(sources_output)

    def on_submit_question(b):
        """Answer the current question (button callback; ``b`` is unused)."""
        question = question_input.value.strip()
        if not question:
            return

        if not all_chunks:
            with chat_output:
                display(Markdown("**Error:** No document chunks available. Please run Step 2 first."))
            return

        with chat_output:
            _show_question(question, "Processing with source diversity optimization...")

            context_chunks, selected_sources, source_counter, _ = find_relevant_chunks_diverse(
                question, all_chunks, chunk_doc_map, num_chunks=chunk_slider.value
            )
            context = " ".join(context_chunks)

            _show_question(question, "Generating answer with source diversity optimization...")
            response = invoke_claude_model_diverse(context, question, max_output_tokens=6000)

            if response:
                _show_question(question)
                # Escape so "<", ">" and "&" in the answer render literally.
                display(HTML(
                    "<div style='white-space: pre-wrap; margin: 10px 0;'>"
                    f"<b>AI Response:</b><br>{escape(response)}</div>"
                ))

                all_sources = set(chunk_doc_map)
                source_coverage = (len(selected_sources) / len(all_sources)) * 100 if all_sources else 0
                display(HTML(f"<div style='margin: 15px 0;'>"
                             f"<b>Source Coverage:</b> {len(selected_sources)}/{len(all_sources)} sources "
                             f"({source_coverage:.1f}%)</div>"))

                if selected_sources:
                    display(HTML("<b>Source Usage Statistics:</b>"))
                    display(HTML(_source_table_html(source_counter)))
                    _display_all_sources_button(source_counter)
            else:
                display(Markdown("**Error:** Failed to get a response from the AI model."))

        # Clear the input for the next question (the observer ignores the empty value).
        question_input.value = ''

    submit_button.on_click(on_submit_question)

    def on_question_committed(change):
        """Submit when a non-empty value is committed (Enter, or the box losing focus)."""
        if change['new'].strip():
            on_submit_question(None)

    question_input.observe(on_question_committed, names='value')

    return widgets.VBox([
        widgets.HTML("<h3>📚 Chat with Your Documents (Enhanced Source Diversity)</h3>"),
        widgets.HBox([question_input, submit_button]),
        widgets.HBox([chunk_slider]),
        chat_output
    ])

# Display the enhanced interface
display(create_diverse_chat_interface())

VBox(children=(HTML(value='<h3>📚 Chat with Your Documents (Enhanced Source Diversity)</h3>'), HBox(children=(T…

Output()

In [6]:
#step 3.2 improved and balanced source diversity, invoke model chat with your docs
import json
import re
from collections import Counter, defaultdict
from html import escape as html_escape

import boto3
import ipywidgets as widgets
from IPython.display import display, Markdown, HTML

# Shared Bedrock Runtime client used by the model-invocation helper in this cell
bedrock_runtime_client = boto3.client(service_name="bedrock-runtime")

# Function to invoke the Claude model with balanced sources
def invoke_claude_model_balanced(context, question, max_output_tokens=6000,
                                 model_id="anthropic.claude-3-5-sonnet-20240620-v1:0",
                                 client=None):
    """
    Invokes Claude with explicit instruction to use balanced information from all sources.

    Args:
        context: text of the selected chunks, placed before the instruction.
        question: the user's question, placed last in the prompt.
        max_output_tokens: sent as ``max_tokens`` in the request body.
        model_id: Bedrock model identifier; defaults to the model used so far.
        client: object exposing ``invoke_model(modelId=..., body=...)``, e.g. a
            boto3 "bedrock-runtime" client. Defaults to the module-level
            ``bedrock_runtime_client``; passing one is mainly useful for tests.

    Returns:
        The text content blocks of the reply joined with single spaces, or
        None when the call fails or the reply has no ``content`` list. Errors
        are printed rather than raised so the chat UI can show its own message.
    """
    # Explicit instruction asking the model to weigh every source equally
    instruction = "Please use a balanced approach drawing from ALL available sources in the provided context. Give approximately equal weight to each different source document when crafting your response."

    messages = [
        {"role": "user", "content": f"Context: {context}\n\n{instruction}\n\nQuestion: {question}"}
    ]

    try:
        # Resolve the client lazily so a missing global is reported like any other failure.
        runtime = client if client is not None else bedrock_runtime_client
        request_body = json.dumps({
            "messages": messages,
            "max_tokens": max_output_tokens,
            "anthropic_version": "bedrock-2023-05-31"
        })
        response = runtime.invoke_model(modelId=model_id, body=request_body)

        result = json.loads(response["body"].read().decode("utf-8"))
        content = result.get("content") if isinstance(result, dict) else None
        if not isinstance(content, list):
            raise ValueError("Claude API response does not contain valid 'content'.")

        # Keep only text blocks; use .get so an unexpected block shape is skipped
        # instead of raising and discarding an otherwise valid answer.
        return " ".join(
            block.get("text", "")
            for block in content
            if isinstance(block, dict) and block.get("type") == "text"
        )
    except Exception as e:
        print(f"Error: {e}")
        return None

# Truly balanced chunk selection algorithm
def _keyword_set(text):
    """Return the set of lower-cased word tokens in `text`.

    Tokens are runs of letters, digits and underscores, so surrounding
    punctuation is dropped: "deadline?" and "deadline:" both become "deadline".
    """
    return set(re.findall(r"\w+", text.lower()))


def find_relevant_chunks_balanced(question, chunks, chunk_doc_map, num_chunks=50, diversity_weight=0.7):
    """
    Finds relevant chunks with a balanced approach across sources.

    1. Scores every chunk by how many distinct question words it contains
       (case-insensitive, punctuation ignored).
    2. Gives every source an equal allocation of ``num_chunks // n_sources``
       slots (at least one each). The leftover slots go to the sources that
       appear first in ``chunk_doc_map``. Each source fills its slots with its
       own highest-scoring chunks.
    3. If some sources had too few chunks to fill their allocation, tops up to
       ``num_chunks`` with the remaining chunks. These are ranked by
       ``score - diversity_weight * (chunks already taken from that source) ** 2``.

    Note: because every source gets at least one slot, more than ``num_chunks``
    chunks are returned when there are more sources than slots.

    Args:
        question: the user question; its words are the matching keywords.
        chunks: list of chunk texts.
        chunk_doc_map: source name for each chunk, aligned by index with
            ``chunks`` (must be at least as long as ``chunks``).
        num_chunks: target number of chunks to return.
        diversity_weight: strength of the per-source penalty in step 3.

    Returns:
        (relevant_chunks, selected_sources, used_source_counter, used_chunks):
        - relevant_chunks: context strings prefixed with "[Source: ...]".
        - selected_sources: the sources that contributed.
        - used_source_counter: a Counter of chunks per source.
        - used_chunks: (text, source, index) tuples for each selected chunk.
        All are empty when there are no sources.
    """
    keywords = _keyword_set(question)
    # First-appearance order (not set order, which varies between interpreter
    # runs for strings) keeps the leftover-slot assignment reproducible.
    all_sources = list(dict.fromkeys(chunk_doc_map))
    print(f"Available sources: {len(all_sources)}")

    if not all_sources:
        # Nothing to balance across (and avoids dividing by zero below).
        return [], [], Counter(), []

    # PHASE 1 + 2: score each chunk and group it under its source
    source_chunks = defaultdict(list)
    for idx, chunk in enumerate(chunks):
        source = chunk_doc_map[idx]
        score = len(keywords & _keyword_set(chunk))
        source_chunks[source].append((chunk, score, source, idx))

    # Most relevant first within each source (stable, so ties keep document order)
    for entries in source_chunks.values():
        entries.sort(key=lambda entry: entry[1], reverse=True)

    # PHASE 3: equal allocation per source, at least one slot each
    source_count = len(all_sources)
    chunks_per_source = max(1, num_chunks // source_count)
    leftover_slots = max(0, num_chunks - chunks_per_source * source_count)

    # PHASE 4: each source takes its top chunks up to its allocation
    final_chunks = []
    used_per_source = Counter()
    for position, source in enumerate(all_sources):
        allocation = chunks_per_source + (1 if position < leftover_slots else 0)
        selection = source_chunks[source][:allocation]
        final_chunks.extend(selection)
        used_per_source[source] = len(selection)

    # PHASE 5: fill any unused slots, penalising sources already well represented
    if len(final_chunks) < num_chunks:
        boosted_remaining = []
        for source in all_sources:
            already_used = used_per_source[source]
            penalty = (already_used ** 2) * diversity_weight  # quadratic diversity penalty
            for chunk, score, src, idx in source_chunks[source][already_used:]:
                boosted_remaining.append((chunk, score - penalty, src, idx))

        boosted_remaining.sort(key=lambda entry: entry[1], reverse=True)
        final_chunks.extend(boosted_remaining[:num_chunks - len(final_chunks)])

    # Statistics and output
    used_source_counter = Counter(entry[2] for entry in final_chunks)
    selected_sources = list(used_source_counter.keys())
    used_chunks = [(entry[0], entry[2], entry[3]) for entry in final_chunks]  # (text, source, index)
    relevant_chunks = [f"[Source: {entry[2]}] {entry[0]}" for entry in final_chunks]

    print(f"Sources representation: {len(selected_sources)}/{len(all_sources)} sources")
    for source, count in used_source_counter.most_common():
        print(f"  - {source}: {count} chunks")

    return relevant_chunks, selected_sources, used_source_counter, used_chunks

# Create balanced chat interface
def _source_distribution_html(source_counter):
    """Return the per-source usage table as one self-contained HTML string.

    Args:
        source_counter: Counter mapping source name -> chunks used in the context.

    The grey container <div> is emitted in the same string as the table.
    Separate display(HTML("<div ...>")) / display(HTML("</div>")) calls render
    as independent outputs, so the container would never enclose the table.
    Source names are HTML-escaped since they may contain characters such as
    "<" or "&".
    """
    total_chunks = sum(source_counter.values())
    parts = [
        "<div style='margin: 10px 0; padding: 10px; background-color: #f8f8f8; border-radius: 5px;'>",
        """<table style='width:100%; border-collapse: collapse;'>
            <tr style='background-color: #e0e0e0;'>
                <th style='padding: 8px; text-align: left; border: 1px solid #ddd;'>Source</th>
                <th style='padding: 8px; text-align: center; border: 1px solid #ddd;'>Chunks Used</th>
                <th style='padding: 8px; text-align: center; border: 1px solid #ddd;'>% of Context</th>
            </tr>""",
    ]
    for source, count in source_counter.most_common():
        percentage = (count / total_chunks) * 100 if total_chunks else 0.0
        parts.append(f"""<tr>
                <td style='padding: 8px; border: 1px solid #ddd;'>{html_escape(str(source))}</td>
                <td style='padding: 8px; text-align: center; border: 1px solid #ddd;'>{count}</td>
                <td style='padding: 8px; text-align: center; border: 1px solid #ddd;'>{percentage:.1f}%</td>
            </tr>""")
    parts.append("</table>")
    parts.append("</div>")
    return "".join(parts)


def create_balanced_chat_interface():
    """
    Creates and returns an improved chat interface with truly balanced source distribution.

    At submit time the interface reads two notebook globals produced by Step 2:
    `all_chunks` (chunk texts) and `chunk_doc_map` (the source of each chunk,
    aligned by index). It selects context with `find_relevant_chunks_balanced`
    and answers with `invoke_claude_model_balanced`.

    Returns:
        widgets.VBox: title, question row, slider row and the scrolling output area.
    """
    # continuous_update=False makes the value sync only on Enter (or when the
    # field loses focus); the value observer below treats that as a submission.
    question_input = widgets.Text(
        value='',
        placeholder='Ask a question about your documents',
        description='Question:',
        disabled=False,
        continuous_update=False,
        layout=widgets.Layout(width='80%')
    )

    submit_button = widgets.Button(
        description='Ask',
        button_style='primary',
        tooltip='Submit your question'
    )

    # Chunk count and diversity weight sliders
    chunk_slider = widgets.IntSlider(
        value=50,
        min=15,
        max=100,
        step=5,
        description='Max Chunks:',
        layout=widgets.Layout(width='50%'),
        style={'description_width': 'initial'}
    )

    diversity_slider = widgets.FloatSlider(
        value=0.7,
        min=0.1,
        max=1.0,
        step=0.1,
        description='Diversity Weight:',
        layout=widgets.Layout(width='50%'),
        style={'description_width': 'initial'},
        tooltip='Higher value means more balanced source distribution'
    )

    # Scrolling output area, tall enough for the answer plus source statistics
    chat_output = widgets.Output(
        layout=widgets.Layout(
            height='650px',
            width='95%',
            overflow='auto',
            border='1px solid #ddd',
            padding='10px'
        )
    )

    def on_submit_question(b):
        """Answer the current question and render source statistics."""
        question = question_input.value.strip()
        if not question:
            return

        # Look the Step 2 results up at call time; a missing global should give
        # the friendly message below rather than a NameError inside a callback.
        chunks = globals().get("all_chunks")
        doc_map = globals().get("chunk_doc_map")
        if not chunks or doc_map is None:
            with chat_output:
                display(Markdown("**Error:** No document chunks available. Please run Step 2 first."))
            return

        with chat_output:
            chat_output.clear_output()
            display(Markdown(f"**User Question:** {question}"))
            display(Markdown("*Processing with balanced source distribution...*"))

            context_chunks, selected_sources, source_counter, _ = find_relevant_chunks_balanced(
                question, chunks, doc_map,
                num_chunks=chunk_slider.value,
                diversity_weight=diversity_slider.value
            )
            # Blank lines keep each "[Source: ...]" chunk visually separate in the prompt.
            context = "\n\n".join(context_chunks)

            chat_output.clear_output()
            display(Markdown(f"**User Question:** {question}"))
            display(Markdown("*Generating answer with balanced source distribution...*"))

            response = invoke_claude_model_balanced(context, question, max_output_tokens=6000)

            if not response:
                display(Markdown("**Error:** Failed to get a response from the AI model."))
            else:
                chat_output.clear_output()
                display(Markdown(f"**User Question:** {question}"))

                # Escape the model text: "<" or "&" in an answer must not be parsed as markup.
                display(HTML(f"<div style='white-space: pre-wrap; margin: 10px 0;'><b>AI Response:</b><br>{html_escape(response)}</div>"))

                # Source balance metrics
                all_sources = set(doc_map)
                source_coverage = (len(selected_sources) / len(all_sources)) * 100 if all_sources else 0

                display(HTML(f"<div style='margin: 15px 0;'>"
                             f"<b>Source Coverage:</b> {len(selected_sources)}/{len(all_sources)} sources "
                             f"({source_coverage:.1f}%)</div>"))

                if selected_sources:
                    display(HTML("<b>Source Distribution:</b>"))
                    display(HTML(_source_distribution_html(source_counter)))

                    # Option to list every available source, used or not
                    show_all_button = widgets.Button(
                        description="Show All Available Sources",
                        button_style="info",
                        layout=widgets.Layout(width='auto')
                    )
                    sources_output = widgets.Output()

                    def on_show_all(_):
                        with sources_output:
                            sources_output.clear_output()
                            display(HTML("<b>All Available Sources:</b>"))
                            # Sorted so the listing is stable between clicks and runs
                            for src in sorted(set(doc_map), key=str):
                                has_chunks = src in source_counter
                                style = "color: green; font-weight: bold;" if has_chunks else "color: #888;"
                                count_text = f"({source_counter[src]} chunks)" if has_chunks else "(not used)"
                                display(HTML(f"<div style='margin-left: 20px; {style}'>• {html_escape(str(src))} {count_text}</div>"))

                    show_all_button.on_click(on_show_all)
                    display(show_all_button)
                    display(sources_output)

        # Clear the input field for the next question (re-triggers the value
        # observer with '', which it ignores).
        question_input.value = ''

    submit_button.on_click(on_submit_question)

    def on_value_change(change):
        # A Text widget's value never contains "\n", so Enter is detected via
        # continuous_update=False: any non-empty synced value is a submission.
        if change["new"].strip():
            on_submit_question(None)

    question_input.observe(on_value_change, names='value')

    return widgets.VBox([
        widgets.HTML("<h3>⚖️ Chat with Your Documents (Balanced Source Distribution)</h3>"),
        widgets.HBox([question_input, submit_button]),
        widgets.HBox([chunk_slider, diversity_slider]),
        chat_output
    ])


# Render the balanced chat UI (the VBox returned by create_balanced_chat_interface) as this cell's output
display(create_balanced_chat_interface()) 

VBox(children=(HTML(value='<h3>⚖️ Chat with Your Documents (Balanced Source Distribution)</h3>'), HBox(childre…

Output()

Output()

Output()