# Clinical Notes Processing Pipeline: Comprehensive Overview

This pipeline provides a comprehensive workflow for processing, filtering, summarizing, and interacting with clinical notes. It's designed as a Jupyter notebook with interactive widgets that guide users through each step of the process.

## Pipeline Architecture

The pipeline consists of five main steps, each implemented as its own set of functions and widgets:

1. **Data Loading**: Extract clinical notes from CSV files
2. **Note Filtering**: Filter notes by MRN, encounter type, and event description
3. **AWS Configuration**: Set up AWS Bedrock for AI processing
4. **Note Summarization**: Generate individual or combined summaries of clinical notes
5. **Chat Interface**: Ask questions about the clinical notes

Each step builds on the previous one, but the updated implementation allows users to access any step directly through a button interface.

## Step 1: Data Loading

**Key Components:**
- `initialize_extraction()`: Sets up the UI for loading CSV data
- Input fields for CSV path and output directory
- "Load CSV" button that reads the data and creates the output directory

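This step loads clinical notes from a CSV file and prepares them for processing. The filtering and saving code reads the columns `MRN`, `encounter_type`, `event_desc`, `note_index`, and `blob_content_clean` (the note text) directly. The columns `ICD`, `med_service`, `reason_for_visit`, `encntr_id`, and `event_end_dt_tm` are optional and are written into each saved note's header when present.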
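
A missing column only surfaces later, as a `KeyError` during filtering or saving, so it can help to check the header up front. The sketch below is a hypothetical stdlib-only helper, `check_note_columns`, and is not part of the notebook. The split into required and optional columns is inferred from which columns the code indexes directly and which it reads with `row.get()`. With pandas, the same set difference could be taken against `df.columns` after `pd.read_csv`.

```python
import csv
import io

# Columns the notebook indexes directly (a missing one raises KeyError later).
REQUIRED_COLUMNS = {"MRN", "encounter_type", "event_desc", "note_index", "blob_content_clean"}
# Columns the notebook only reads with row.get(...) when writing note files.
OPTIONAL_COLUMNS = {"ICD", "med_service", "reason_for_visit", "encntr_id", "event_end_dt_tm"}


def check_note_columns(csv_text):
    """Return (missing_required, missing_optional) for the header row of csv_text."""
    reader = csv.reader(io.StringIO(csv_text))
    header = {name.strip() for name in next(reader, [])}
    return sorted(REQUIRED_COLUMNS - header), sorted(OPTIONAL_COLUMNS - header)


sample = (
    "MRN,encounter_type,event_desc,note_index,blob_content_clean\n"
    "123,Outpatient,Progress Note,1,Patient seen for follow-up.\n"
)
missing_required, missing_optional = check_note_columns(sample)
print(missing_required)  # []
```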

## Step 2: Note Filtering

**Key Components:**
- `initialize_filtering_widgets()`: Creates UI for filtering notes
- `search_mrn()`: Filters notes by MRN
- `filter_encounter_type()`: Further filters by encounter type
- `filter_event_desc()`: Final filtering by event description
- `save_blob_contents()`: Saves filtered notes as text files

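This step lets users narrow the notes progressively. Notes are filtered first by MRN (case-insensitive match), then by encounter type, then by event description, where `All` keeps every value at that stage. The final selection is saved in the output directory as one structured text file per note, named `<MRN>_<encounter_type>_<med_service>_<note_index>.txt`, inside a subdirectory for each MRN.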
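
The filtering and saving logic of this step can be sketched without pandas or widgets. This is a simplified, hypothetical re-implementation over a list of dicts; `progressive_filter` and `note_path` are illustrative names, not notebook functions. The MRN stage is a case-insensitive substring match, `'All'` skips a stage, and the output path follows the `<MRN>/<MRN>_<encounter_type>_<med_service>_<note_index>.txt` layout that `save_blob_contents()` writes.

```python
import os


def progressive_filter(rows, mrn, encounter_type="All", event_desc="All"):
    """Apply the three filters in order; 'All' skips a stage, as in the dropdowns."""
    # Stage 1: case-insensitive, plain substring match on MRN
    stage = [r for r in rows if mrn.lower() in str(r["MRN"]).lower()]
    # Stage 2: exact match on encounter type unless 'All' is selected
    if encounter_type != "All":
        stage = [r for r in stage if r["encounter_type"] == encounter_type]
    # Stage 3: exact match on event description unless 'All' is selected
    if event_desc != "All":
        stage = [r for r in stage if r["event_desc"] == event_desc]
    return stage


def note_path(notes_dir, row):
    """Build <notes_dir>/<MRN>/<MRN>_<encounter_type>_<med_service>_<note_index>.txt."""
    def clean(value):
        # Same substitutions as save_blob_contents(): '/', '\\' and ' ' become '_'
        return str(value).replace("/", "_").replace("\\", "_").replace(" ", "_")

    filename = (
        f"{row['MRN']}_{clean(row['encounter_type'])}_"
        f"{clean(row.get('med_service', 'unknown'))}_{row['note_index']}.txt"
    )
    return os.path.join(notes_dir, str(row["MRN"]), filename)


rows = [
    {"MRN": "A123", "encounter_type": "Outpatient", "event_desc": "Progress Note", "med_service": "Peds Neuro", "note_index": 1},
    {"MRN": "A123", "encounter_type": "Inpatient", "event_desc": "Discharge Summary", "med_service": "Peds", "note_index": 2},
    {"MRN": "B456", "encounter_type": "Outpatient", "event_desc": "Progress Note", "med_service": "Psych", "note_index": 3},
]
selected = progressive_filter(rows, "a12", encounter_type="Outpatient")
print([r["note_index"] for r in selected])  # [1]
print(note_path("notes", selected[0]))
```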

## Step 3: AWS Configuration

**Key Components:**
- `setup_aws_ui()`: Creates UI for AWS configuration
- `sso_login()`: Handles AWS SSO login
- `initialize_aws()`: Sets up AWS Bedrock clients

This step configures AWS credentials and initializes the Bedrock clients needed for AI processing. It supports SSO login and verifies access to Bedrock models.
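
The connection test in `initialize_aws()` calls `bedrock.list_foundation_models()` and prints the first ten model IDs. The sketch below isolates that reporting logic from boto3 so it can run against a fake response. `describe_models` is a hypothetical helper, and the fake response only mimics the `modelSummaries`/`modelId` fields the notebook reads. It reports a remaining count only when there are more than `limit` models, so a short list never prints a negative number.

```python
def describe_models(response, limit=10):
    """Format a list_foundation_models()-style response as printable lines.

    Assumes the response shape read in initialize_aws():
    {'modelSummaries': [{'modelId': ...}, ...]}.
    """
    summaries = response.get("modelSummaries", [])
    lines = [f"- {m['modelId']}" for m in summaries[:limit]]
    remaining = len(summaries) - limit
    if remaining > 0:
        # Only mention "more models" when some were actually left out
        lines.append(f"... and {remaining} more models")
    lines.append(f"Total models available: {len(summaries)}")
    return lines


fake = {"modelSummaries": [{"modelId": "anthropic.claude-v2"}, {"modelId": "amazon.titan-text-express-v1"}]}
for line in describe_models(fake):
    print(line)
```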

## Step 4: Note Summarization

**Key Components:**
- `initialize_summarization_ui()`: Creates UI for summarization
- `summarize_with_bedrock()`: Generates summaries using AWS Bedrock
- `summarize_combined_notes_chunked()`: Handles large sets of notes by chunking
- `filter_notes()`: Filters notes based on include/exclude keywords
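
`filter_notes()` applies two rules. A note is dropped if it matches any exclude keyword. When include keywords are given, a note is kept only if it also matches at least one of them. Matching is case-insensitive and checks both the filename and the note text. The standalone sketch below restates those rules, plus the comma-separated parsing the UI applies to the keyword fields. `parse_keywords` is a hypothetical name.

```python
def parse_keywords(text):
    """Split a comma-separated widget value into trimmed, non-empty keywords."""
    return [k.strip() for k in text.split(",") if k.strip()]


def filter_notes(notes, include_keywords, exclude_keywords):
    """Drop notes matching any exclude keyword; if include keywords are given,
    keep only notes matching at least one of them. Case-insensitive; checks
    both the filename and the note text."""
    def matches(filename, content, keywords):
        return any(k.lower() in filename.lower() or k.lower() in content.lower() for k in keywords)

    kept = {}
    for filename, content in notes.items():
        if exclude_keywords and matches(filename, content, exclude_keywords):
            continue  # exclusion always wins
        if include_keywords and not matches(filename, content, include_keywords):
            continue
        kept[filename] = content
    return kept


notes = {
    "123_Outpatient_Neuro_1.txt": "ADOS-2 administered today.",
    "123_Outpatient_Pulm_2.txt": "Follow-up for asthma.",
    "123_Outpatient_Billing_3.txt": "ADOS billing code only.",
}
kept = filter_notes(notes, parse_keywords("ados"), parse_keywords("billing, "))
print(list(kept))  # ['123_Outpatient_Neuro_1.txt']
```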

This step offers two summarization modes:
1. **Individual summaries**: Generates a separate summary for each note
2. **Combined summary**: Creates a comprehensive summary of all notes

The implementation includes:
- Model selection from popular Bedrock models
- MRN-specific processing
- Note filtering by keywords
- Saving individual summaries as separate text files
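
For the combined mode, `summarize_combined_notes_chunked()` handles note sets too large for one request. It summarizes chunks of notes and then combines the chunk summaries, which matches the split between `CHUNK_SUMMARIZATION_PROMPTS` and `FINAL_SUMMARIZATION_PROMPTS`. The sketch below shows that map-reduce shape under stated assumptions. Chunks are bounded by a character budget (`max_chars` is a hypothetical parameter, since the notebook's actual chunking rule is not shown here). The callables `summarize_chunk` and `summarize_final` stand in for Bedrock calls.

```python
def chunk_notes(notes, max_chars):
    """Pack labeled notes into chunks of at most max_chars characters.

    A single note longer than max_chars gets a chunk of its own rather than
    being split mid-note.
    """
    chunks, current, size = [], [], 0
    for filename, text in notes.items():
        block = f"--- CLINICAL NOTE: {filename} ---\n\n{text}"
        added = len(block) + (2 if current else 0)  # +2 for the "\n\n" separator
        if current and size + added > max_chars:
            chunks.append("\n\n".join(current))
            current, size, added = [], 0, len(block)
        current.append(block)
        size += added
    if current:
        chunks.append("\n\n".join(current))
    return chunks


def summarize_chunked(notes, summarize_chunk, summarize_final, max_chars=20000):
    """Map step: summarize each chunk. Reduce step: merge the chunk summaries."""
    chunk_summaries = [summarize_chunk(chunk) for chunk in chunk_notes(notes, max_chars)]
    if len(chunk_summaries) == 1:
        return chunk_summaries[0]  # nothing to merge
    return summarize_final("\n\n".join(chunk_summaries))


# Stand-ins for model calls, so the control flow can be run offline.
notes = {"a.txt": "x" * 50, "b.txt": "y" * 50, "c.txt": "z" * 50}
result = summarize_chunked(
    notes,
    summarize_chunk=lambda text: f"{text.count('CLINICAL NOTE')} note(s) summarized",
    summarize_final=lambda text: "FINAL: " + text.replace("\n\n", " | "),
    max_chars=170,
)
print(result)  # FINAL: 2 note(s) summarized | 1 note(s) summarized
```

Packing whole notes, rather than cutting at a fixed offset, keeps each note's header attached to its text so the chunk summaries can still refer to individual notes.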

## Step 5: Chat Interface

**Key Components:**
- `initialize_chat_ui()`: Creates UI for chatting with notes
- `chat_with_clinical_notes()`: Processes queries using AWS Bedrock
- `filter_notes()`: Filters notes based on include/exclude keywords
- `show_example_questions()`: Displays example questions

This step allows users to ask questions about the clinical notes and receive AI-generated answers. Features include:
- Model selection for Q&A
- Note filtering by keywords
- MRN-specific processing
- Example questions for guidance
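
`chat_with_clinical_notes()` sends the formatted prompt through `bedrock_runtime.invoke_model()`. Each model family expects a different JSON body and returns its text under a different key. The sketch below restates that mapping with the same generation settings the notebook uses (1000 output tokens, temperature 0, top_p 0.9). `build_request_body` and `extract_text` are hypothetical helper names. The sketch routes every `anthropic.claude-3*` ID, Claude 3 as well as 3.5, through the Messages API format, since Claude 3 models on Bedrock do not accept the older Text Completions body.

```python
import json

MAX_TOKENS, TEMPERATURE, TOP_P = 1000, 0, 0.9


def build_request_body(model_id, prompt):
    """Return the JSON string to pass as invoke_model(body=...) for model_id."""
    if model_id.startswith("anthropic.claude-3"):
        # Claude 3 / 3.5: Messages API format
        body = {
            "anthropic_version": "bedrock-2023-05-31",
            "max_tokens": MAX_TOKENS,
            "temperature": TEMPERATURE,
            "top_p": TOP_P,
            "messages": [{"role": "user", "content": [{"type": "text", "text": prompt}]}],
        }
    elif model_id.startswith("anthropic.claude"):
        # Older Claude models (e.g. claude-v2): Text Completions format
        body = {
            "prompt": f"\n\nHuman: {prompt}\n\nAssistant:",
            "max_tokens_to_sample": MAX_TOKENS,
            "temperature": TEMPERATURE,
            "top_p": TOP_P,
        }
    elif model_id.startswith("amazon.titan"):
        body = {
            "inputText": prompt,
            "textGenerationConfig": {"maxTokenCount": MAX_TOKENS, "temperature": TEMPERATURE, "topP": TOP_P},
        }
    elif model_id.startswith("meta.llama"):
        body = {"prompt": prompt, "max_gen_len": MAX_TOKENS, "temperature": TEMPERATURE, "top_p": TOP_P}
    else:
        raise ValueError(f"Model {model_id} not supported for Q&A")
    return json.dumps(body)


def extract_text(model_id, response_body):
    """Pull the generated text out of the parsed (json.loads) response body."""
    if model_id.startswith("anthropic.claude-3"):
        content = response_body.get("content") or [{}]
        return content[0].get("text", "")
    if model_id.startswith("anthropic.claude"):
        return response_body.get("completion", "")
    if model_id.startswith("amazon.titan"):
        results = response_body.get("results") or [{}]
        return results[0].get("outputText", "")
    if model_id.startswith("meta.llama"):
        return response_body.get("generation", "")
    raise ValueError(f"Model {model_id} not supported for Q&A")


body = json.loads(build_request_body("anthropic.claude-3-haiku-20240307-v1:0", "Summarize the plan."))
print(body["messages"][0]["content"][0]["text"])  # Summarize the plan.
```

In the notebook, the body would be sent with `bedrock_runtime.invoke_model(body=..., modelId=model_id)`. The response would then be parsed with `json.loads(response["body"].read())` before extracting the text.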

## Supporting Modules

### Prompt Templates
The `summarization_prompts.py` file contains prompt templates for different models and tasks:
- `INDIVIDUAL_NOTE_PROMPTS`: For summarizing individual notes
- `CHUNK_SUMMARIZATION_PROMPTS`: For summarizing chunks of notes
- `FINAL_SUMMARIZATION_PROMPTS`: For combining chunk summaries
- `COMBINED_NOTE_PROMPTS`: For summarizing all notes together
- `QA_PROMPTS`: For answering questions about notes
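
Each of these dictionaries maps a model family, as returned by `get_model_family()`, to a template string. The chat step fills the QA template with `template.format(context=..., query=...)` and falls back to the standard Claude template for Claude 3.5 IDs. The sketch below shows that lookup. The template text is a hypothetical placeholder; the real prompts in `summarization_prompts.py` are not reproduced here.

```python
def get_model_family(model_id):
    """Same classification as the notebook's get_model_family()."""
    if model_id.startswith("anthropic.claude"):
        return "anthropic.claude-3-5" if "claude-3-5" in model_id else "anthropic.claude"
    for family in ("amazon.titan", "meta.llama"):
        if model_id.startswith(family):
            return family
    return "unknown"


# Hypothetical placeholder templates; literal braces in a template would need doubling ({{ }}).
QA_PROMPTS = {
    "anthropic.claude": "Clinical notes:\n{context}\n\nQuestion: {query}\nAnswer only from the notes.",
    "amazon.titan": "Notes:\n{context}\n\nQuestion: {query}\nAnswer:",
    "meta.llama": "Notes:\n{context}\n\nQuestion: {query}\nAnswer:",
}


def format_qa_prompt(model_id, context, query, prompts=QA_PROMPTS):
    """Pick the template for the model's family and fill in context and query."""
    family = get_model_family(model_id)
    if family not in prompts and family == "anthropic.claude-3-5":
        family = "anthropic.claude"  # Claude 3.5 reuses the standard Claude template
    if family not in prompts:
        raise KeyError(f"Model family {family} not supported for Q&A")
    return prompts[family].format(context=context, query=query)


prompt = format_qa_prompt("anthropic.claude-3-5-sonnet-20241022-v2:0", "BP {120/80} noted.", "What was the BP?")
print(prompt.splitlines()[1])  # BP {120/80} noted.
```

`str.format` parses only the template itself, so braces inside the clinical notes passed as `context` are inserted verbatim.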

### Utility Functions
- `get_clinical_notes()`: Reads text files from a directory
- `combine_clinical_notes()`: Combines multiple notes with identifiers
- `get_model_family()`: Determines the model family from model ID
- `read_text_file()`: Reads content from a text file
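
`get_clinical_notes()` reads every `.txt` file under a directory, recursively, keyed by filename. `combine_clinical_notes()` then joins the non-empty notes, each under a `--- CLINICAL NOTE: <filename> ---` header, so the model can tell them apart. Below is a condensed, self-contained version of the two helpers. It is exercised against a temporary directory laid out with one folder per MRN; the MRN and file names are made up.

```python
import tempfile
from pathlib import Path


def get_clinical_notes(directory_path):
    """Map filename -> text for every .txt file under directory_path (recursive)."""
    path = Path(directory_path)
    if not path.is_dir():
        raise ValueError(f"Invalid directory path: {directory_path}")
    return {p.name: p.read_text(encoding="utf-8") for p in sorted(path.glob("**/*.txt"))}


def combine_clinical_notes(clinical_notes):
    """Join non-empty notes, each preceded by a header naming its file."""
    parts = [
        f"--- CLINICAL NOTE: {name} ---\n\n{content}"
        for name, content in clinical_notes.items()
        if content.strip()  # skip empty files
    ]
    return "\n\n".join(parts).strip()


with tempfile.TemporaryDirectory() as tmp:
    mrn_dir = Path(tmp) / "12345"
    mrn_dir.mkdir()
    (mrn_dir / "12345_Outpatient_Neuro_1.txt").write_text("First visit.", encoding="utf-8")
    (mrn_dir / "12345_Outpatient_Neuro_2.txt").write_text("   ", encoding="utf-8")
    notes = get_clinical_notes(tmp)
    combined = combine_clinical_notes(notes)

print(len(notes))  # 2
print(combined)
```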

## Data Flow

1. Clinical notes are loaded from a CSV file
2. Notes are filtered based on user criteria
3. Filtered notes are saved as text files organized by MRN
4. AWS Bedrock is configured for AI processing
5. Notes are summarized individually or collectively
6. Summaries are saved as text files
7. Users can ask questions about the notes through the chat interface

## Enhancements

The pipeline has been enhanced with:
1. **Step-by-step navigation**: Users can access any step directly
2. **MRN-specific processing**: Notes can be organized and processed by patient
3. **Note filtering**: Advanced filtering capabilities for both summarization and chat
4. **Model selection**: Users can choose from different AI models
5. **Individual summary files**: Each summary is saved as a separate text file
6. **Improved error handling**: Better feedback and error messages

This pipeline provides a comprehensive solution for healthcare professionals to process, summarize, and interact with clinical notes, making it easier to extract insights from large volumes of clinical documentation.

In [11]:
# Import necessary libraries
import pandas as pd
import os
# Run from ../src so summarization_prompts is importable and the ../data paths below resolve
os.chdir("../src")
import html
import json
from pathlib import Path
from IPython.display import display, clear_output

# UI components
import ipywidgets as widgets
from tqdm.notebook import tqdm

# AWS libraries
import boto3
import subprocess

# Import prompts from external file
from summarization_prompts import (
    INDIVIDUAL_NOTE_PROMPTS,
    CHUNK_SUMMARIZATION_PROMPTS,
    FINAL_SUMMARIZATION_PROMPTS,
    COMBINED_NOTE_PROMPTS,
    QA_PROMPTS
)

# Global variables
df = pd.DataFrame()
mrn_filtered_df = pd.DataFrame()
encounter_filtered_df = pd.DataFrame()
final_filtered_df = pd.DataFrame()
notes_dir = ''
chat_context = None
bedrock_runtime = None
bedrock = None

## Data Loading and Filtering Functions

In [12]:
def initialize_extraction():
    """Initialize the data extraction process with UI components"""
    global df, notes_dir
    
    # Default paths - adjust these as needed
    data_csv_path = '../data/randome_asd_notes_cleaned3.csv'
    notes_dir = '../data/notes/extracted_notes_ind_mrn'
    
    # Create a file path input widget
    csv_path_input = widgets.Text(
        value=data_csv_path,
        placeholder='Path to CSV file',
        description='CSV Path:',
        layout={'width': '80%'},
        style={'description_width': 'initial'}
    )
    
    # Create an output directory input widget
    output_path_input = widgets.Text(
        value=notes_dir,
        placeholder='Path to output directory for notes',
        description='Notes Path:',
        layout={'width': '80%'},
        style={'description_width': 'initial'}
    )
    
    # Create a load button
    load_button = widgets.Button(
        description='Load CSV',
        button_style='primary'
    )
    
    # Create an output widget for messages
    load_output = widgets.Output()
    
    # Define the button click handler
    def on_load_button_click(b):
        global df, notes_dir
        with load_output:
            clear_output()
            try:
                # Set paths based on user input
                data_csv_path = csv_path_input.value
                notes_dir = output_path_input.value
                
                # Create output directory
                os.makedirs(notes_dir, exist_ok=True)
                
                # Load the CSV file
                print(f"Loading data from {data_csv_path}...")
                df = pd.read_csv(data_csv_path)
                print(f"Successfully loaded {len(df)} records from CSV.")
                print(f"Notes will be saved to: {notes_dir}")
                
                # Initialize the filtering widgets after successful load
                initialize_filtering_widgets()
            except Exception as e:
                print(f"Error loading CSV: {e}")
    
    # Attach the click handler
    load_button.on_click(on_load_button_click)
    
    # Display the widgets
    display(widgets.VBox([
        widgets.HTML("<h3>Step 1: Load Data</h3>"),
        csv_path_input,
        output_path_input,
        load_button,
        load_output
    ]))

def initialize_filtering_widgets():
    """Initialize widgets for filtering clinical notes"""
    # Step 1: MRN Search
    mrn_search = widgets.Text(
        value='',
        placeholder='Enter MRN',
        description='MRN:',
        disabled=False
    )
    mrn_search_button = widgets.Button(description='Search MRN')
    
    # Step 2: Encounter Type Filter
    encounter_type_dropdown = widgets.Dropdown(
        description='Encounter Type:',
        disabled=True
    )
    encounter_type_button = widgets.Button(
        description='Filter Encounter Type', 
        disabled=True
    )
    
    # Step 3: Event Description Filter
    event_desc_dropdown = widgets.Dropdown(
        description='Event Description:',
        disabled=True
    )
    event_desc_button = widgets.Button(
        description='Filter Event Description', 
        disabled=True
    )
    
    # Output widget
    output = widgets.Output()
    
    # Define the filtering functions
    def search_mrn(b):
        global mrn_filtered_df
        with output:
            clear_output()
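            mrn = mrn_search.value.strip()
            if mrn:
                # Cast to str so numeric MRN columns work; regex=False matches the input literally
                mrn_filtered_df = df[df['MRN'].astype(str).str.contains(mrn, case=False, na=False, regex=False)]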
                print(f"Found {len(mrn_filtered_df)} records for MRN containing '{mrn}'")
                
                # Update encounter type dropdown (dropna() keeps sorted() from failing on missing values)
                encounter_type_dropdown.options = ['All'] + sorted(mrn_filtered_df['encounter_type'].dropna().unique().tolist())
                encounter_type_dropdown.disabled = False
                encounter_type_button.disabled = False
                
                display(mrn_filtered_df)
            else:
                print("Please enter an MRN")

    def filter_encounter_type(b):
        global encounter_filtered_df
        with output:
            clear_output()
            encounter_type = encounter_type_dropdown.value
            if encounter_type != 'All':
                encounter_filtered_df = mrn_filtered_df[mrn_filtered_df['encounter_type'] == encounter_type]
            else:
                encounter_filtered_df = mrn_filtered_df.copy()
            
            print(f"Found {len(encounter_filtered_df)} records for encounter type '{encounter_type}'")
            
            # Update event description dropdown (dropna() keeps sorted() from failing on missing values)
            event_desc_dropdown.options = ['All'] + sorted(encounter_filtered_df['event_desc'].dropna().unique().tolist())
            event_desc_dropdown.disabled = False
            event_desc_button.disabled = False
            
            display(encounter_filtered_df)

    def filter_event_desc(b):
        global final_filtered_df
        with output:
            clear_output()
            event_desc = event_desc_dropdown.value
            if event_desc != 'All':
                final_filtered_df = encounter_filtered_df[encounter_filtered_df['event_desc'] == event_desc]
            else:
                final_filtered_df = encounter_filtered_df.copy()
            
            print(f"Found {len(final_filtered_df)} records for event description '{event_desc}'")
            
            # Save each selected note as a structured text file, organized by MRN
            save_blob_contents(final_filtered_df)
            
            display(final_filtered_df)
    
    # Attach button click handlers
    mrn_search_button.on_click(search_mrn)
    encounter_type_button.on_click(filter_encounter_type)
    event_desc_button.on_click(filter_event_desc)
    
    # Display widgets
    display(widgets.VBox([
        widgets.HTML("<h3>Step 2: Filter Notes</h3>"),
        widgets.HBox([mrn_search, mrn_search_button]),
        widgets.HBox([encounter_type_dropdown, encounter_type_button]),
        widgets.HBox([event_desc_dropdown, event_desc_button]),
        output
    ]))

def save_blob_contents(filtered_df):
    """Save blob contents as text files organized by MRN"""
    saved_count = 0
    
    for _, row in filtered_df.iterrows():
        # Create directory for this MRN in the notes folder
        mrn_dir = os.path.join(notes_dir, str(row['MRN']))
        os.makedirs(mrn_dir, exist_ok=True)
        
        # Clean up any characters that might be problematic in filenames
        encounter_type = str(row['encounter_type']).replace('/', '_').replace('\\', '_').replace(' ', '_')
        med_service = str(row.get('med_service', 'unknown')).replace('/', '_').replace('\\', '_').replace(' ', '_')
        
        # Create a structured text file with all required fields
        structured_filename = f"{row['MRN']}_{encounter_type}_{med_service}_{row['note_index']}.txt"
        structured_filepath = os.path.join(mrn_dir, structured_filename)
        
        with open(structured_filepath, 'w', encoding='utf-8') as f:
            # Write metadata
            f.write(f"ICD: {row.get('ICD', 'N/A')}\n")
            f.write(f"Encounter Type: {row.get('encounter_type', 'N/A')}\n")
            f.write(f"Medical Service: {row.get('med_service', 'N/A')}\n")
            f.write(f"Reason for Visit: {row.get('reason_for_visit', 'N/A')}\n")
            f.write(f"Encounter ID: {row.get('encntr_id', 'N/A')}\n")
            f.write(f"Event Description: {row.get('event_desc', 'N/A')}\n")
            f.write(f"Event End Datetime: {row.get('event_end_dt_tm', 'N/A')}\n")
            f.write(f"MRN: {row['MRN']}\n")
            f.write(f"Note Index: {row['note_index']}\n\n")
            
            # Write the clinical note content
            f.write("CLINICAL NOTE\n")
            f.write("=============\n\n")
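            # Missing note text (NaN) is written as an empty body instead of raising TypeError
            note_text = row['blob_content_clean']
            f.write(str(note_text) if pd.notna(note_text) else '')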
        
        saved_count += 1
    
    print(f"Saved {saved_count} structured text files in {notes_dir}, organized by MRN")
    
    # Return the base notes directory (one subdirectory per MRN) for the next step (summarization)
    return notes_dir

## AWS Configuration Functions

In [13]:
def setup_aws_ui():
    """Create UI for AWS setup"""
    # Create input widgets
    profile_input = widgets.Text(
        value='plm-dev',
        placeholder='AWS SSO profile name',
        description='AWS Profile:',
        layout={'width': '50%'},
        style={'description_width': 'initial'}
    )
    
    region_input = widgets.Text(
        value='us-west-2',
        placeholder='AWS region',
        description='AWS Region:',
        layout={'width': '50%'},
        style={'description_width': 'initial'}
    )
    
    login_button = widgets.Button(
        description='Login with SSO',
        button_style='primary',
        layout={'width': 'auto'}
    )
    
    init_button = widgets.Button(
        description='Initialize AWS',
        button_style='success',
        layout={'width': 'auto'}
    )
    
    output = widgets.Output()
    
    # Define button handlers
    def on_login_button_click(b):
        with output:
            clear_output()
            sso_login(profile_input.value)
    
    def on_init_button_click(b):
        with output:
            clear_output()
            success = initialize_aws(profile_input.value, region_input.value)
            if success:
                # Initialize the summarization UI
                initialize_summarization_ui()
    
    # Attach handlers
    login_button.on_click(on_login_button_click)
    init_button.on_click(on_init_button_click)
    
    # Display widgets
    display(widgets.VBox([
        widgets.HTML("<h3>Step 3: Configure AWS Bedrock</h3>"),
        widgets.HTML("<p>Use your AWS SSO credentials to access Bedrock for note summarization</p>"),
        profile_input,
        region_input,
        widgets.HBox([login_button, init_button]),
        output
    ]))

def sso_login(profile_name="default"):
    """Attempt to login with AWS SSO"""
    print(f"Attempting to login with AWS SSO using profile '{profile_name}'...")
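    try:
        result = subprocess.run(
            ["aws", "sso", "login", "--profile", profile_name],
            capture_output=True,
            text=True
        )
    except FileNotFoundError:
        # subprocess.run raises this when the aws executable is not on PATH
        print("❌ AWS CLI not found. Install it, then run 'aws sso login --profile your-profile' in a terminal")
        return False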
    
    if result.returncode == 0:
        print("✅ SSO login successful")
        return True
    else:
        print(f"❌ SSO login failed: {result.stderr}")
        print("Please run 'aws sso login --profile your-profile' in a terminal")
        return False

def initialize_aws(profile_name="default", region_name="us-west-2"):
    """Initialize AWS clients"""
    global bedrock_runtime, bedrock
    
    try:
        # Create a session with the profile (raises ProfileNotFound for an unknown profile)
        session = boto3.Session(profile_name=profile_name, region_name=region_name)
        
        # Create Bedrock clients using the session
        bedrock_runtime = session.client('bedrock-runtime')
        bedrock = session.client('bedrock')
        
        # Test the connection
        models = bedrock.list_foundation_models()
        print("Available Bedrock models:")
        
        # Print the first 10 models
        summaries = models['modelSummaries']
        for model in summaries[:10]:
            print(f"- {model['modelId']}")
        if len(summaries) > 10:
            print(f"... and {len(summaries) - 10} more models")
        print(f"\nTotal models available: {len(summaries)}")
        
        return True
    except Exception as e:
        print(f"❌ Could not initialize AWS: {e}")
        print("\nTroubleshooting steps:")
        print(f"1. Run 'aws sso login --profile {profile_name}' in a terminal")
        print(f"2. Check that your profile is configured with the correct region")
        print(f"3. Verify that your IAM role has permissions to access Bedrock")
        print(f"4. Make sure Bedrock is available in region '{region_name}'")
        return False

## File Handling and Note Processing Functions

In [14]:
def read_text_file(file_path):
    """Read the content of a text file."""
    try:
        with open(file_path, 'r', encoding='utf-8') as f:
            return f.read()
    except Exception as e:
        print(f"Error reading file {file_path}: {e}")
        return ""

def get_clinical_notes(directory_path):
    """Get all text files in a directory."""
    path = Path(directory_path)
    if not path.exists() or not path.is_dir():
        raise ValueError(f"Invalid directory path: {directory_path}")
        
    files = {}
    for file_path in path.glob("**/*.txt"):  # Search recursively for all .txt files
        files[file_path.name] = read_text_file(file_path)
        
    print(f"Found {len(files)} text files")
    return files

def combine_clinical_notes(clinical_notes):
    """Combine all clinical notes into a single text with file identifiers."""
    combined_text = ""
    
    for filename, content in clinical_notes.items():
        if content.strip():  # Skip empty files
            combined_text += f"\n\n--- CLINICAL NOTE: {filename} ---\n\n"
            combined_text += content
    
    return combined_text.strip()

def get_model_family(model_id):
    """Determine the model family from the model ID."""
    if model_id.startswith("anthropic.claude"):
        # Check if it's a Claude 3.5 model which uses the messages API
        if "claude-3-5" in model_id:
            return "anthropic.claude-3-5"
        return "anthropic.claude"
    elif model_id.startswith("amazon.titan"):
        return "amazon.titan"
    elif model_id.startswith("meta.llama"):
        return "meta.llama"
    else:
        return "unknown"

## Chat Interface Functions

In [15]:
def initialize_chat_ui(directory_path, model_id):
    """Initialize the chat interface for asking questions about clinical notes.
    
    The query field can be populated from the predefined healthcare provider
    questions in QUESTION_LIBRARY. Queries are answered by the locally defined
    `chat_with_clinical_notes`, which formats them with the QA prompt templates.
    """
    global chat_context, bedrock_runtime

    # os, json, Path, widgets, display, clear_output, QA_PROMPTS and
    # get_model_family are already defined by the cells above
    from healthcare_provider_questions import QUESTION_LIBRARY

    # Answer a query against the loaded notes with the selected Bedrock model
    def chat_with_clinical_notes(query, context, model_id="anthropic.claude-v2"):
        try:
            model_family = get_model_family(model_id)
            if model_family not in QA_PROMPTS:
                # If using a Claude 3.5 model, fall back to standard Claude prompts
                if model_family == "anthropic.claude-3-5":
                    prompt_template = QA_PROMPTS["anthropic.claude"]
                else:
                    return f"Model family {model_family} not supported for Q&A"
            else:
                prompt_template = QA_PROMPTS[model_family]
                
            # Format the prompt with the provided clinical notes context and query
            formatted_prompt = prompt_template.format(context=context, query=query)
            
            # Prepare and invoke the model request based on the model family
            if model_id.startswith("anthropic.claude-3"):
                # Claude 3 and 3.5 models on Bedrock accept only the Messages API format
                body = json.dumps({
                    "anthropic_version": "bedrock-2023-05-31",
                    "max_tokens": 1000,
                    "temperature": 0,
                    "top_p": 0.9,
                    "messages": [
                        {
                            "role": "user",
                            "content": [
                                {
                                    "type": "text",
                                    "text": formatted_prompt
                                }
                            ]
                        }
                    ]
                })
                response = bedrock_runtime.invoke_model(body=body, modelId=model_id)
                response_body = json.loads(response.get("body").read())
                return response_body.get("content", [{}])[0].get("text", "")
                
            elif model_id.startswith("anthropic.claude"):
                body = json.dumps({
                    "prompt": f"\n\nHuman: {formatted_prompt}\n\nAssistant:",
                    "max_tokens_to_sample": 1000,
                    "temperature": 0,
                    "top_p": 0.9
                })
                response = bedrock_runtime.invoke_model(body=body, modelId=model_id)
                response_body = json.loads(response.get("body").read())
                return response_body.get("completion", "")
                
            elif model_id.startswith("amazon.titan"):
                body = json.dumps({
                    "inputText": formatted_prompt,
                    "textGenerationConfig": {
                        "maxTokenCount": 1000,
                        "temperature": 0,
                        "topP": 0.9
                    }
                })
                response = bedrock_runtime.invoke_model(body=body, modelId=model_id)
                response_body = json.loads(response.get("body").read())
                return response_body.get("results", [{}])[0].get("outputText", "")
                
            elif model_id.startswith("meta.llama"):
                body = json.dumps({
                    "prompt": formatted_prompt,
                    "max_gen_len": 1000,
                    "temperature": 0,
                    "top_p": 0.9
                })
                response = bedrock_runtime.invoke_model(body=body, modelId=model_id)
                response_body = json.loads(response.get("body").read())
                return response_body.get("generation", "")
                
            else:
                return f"Model {model_id} not supported for Q&A"
                
        except Exception as e:
            return f"Error processing Q&A: {str(e)}"

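    # Treat the last path component as the MRN only when it holds note files directly
    # (the <notes_dir>/<MRN>/*.txt layout written by save_blob_contents); otherwise use
    # the whole path as the base path with no MRN preselected.
    normalized = os.path.normpath(directory_path)
    base_path, mrn = normalized, ""
    parent, last = os.path.split(normalized)
    if parent and last and not last.startswith('.') and any(Path(normalized).glob("*.txt")):
        base_path, mrn = parent, last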

    # Popular Bedrock models.
    popular_models = [
        "anthropic.claude-v2",
        "anthropic.claude-3-sonnet-20240229-v1:0",
        "anthropic.claude-3-haiku-20240307-v1:0",
        "anthropic.claude-3-5-sonnet-20241022-v2:0",  # New model
        "anthropic.claude-3-5-haiku-20241022-v1:0",   # New model
        "meta.llama3-70b-instruct-v1:0",
        "meta.llama3-3-70b-instruct-v1:0",            # New model
        "amazon.titan-text-express-v1"
    ]

    # Model selector dropdown.
    model_dropdown = widgets.Dropdown(
        options=popular_models,
        value=model_id if model_id in popular_models else popular_models[0],
        description="Model:",
        style={'description_width': 'initial'},
        layout={'width': '50%'}
    )

    # Create provider dropdown from QUESTION_LIBRARY keys.
    provider_options = list(QUESTION_LIBRARY.keys())
    provider_dropdown = widgets.Dropdown(
        options=provider_options,
        value=provider_options[0] if provider_options else None,
        description="Provider:",
        style={'description_width': 'initial'},
        layout={'width': '50%'}
    )

    # Initialize question set dropdown based on selected provider.
    qs_dict = QUESTION_LIBRARY.get(provider_dropdown.value, {})
    qs_options = list(qs_dict.keys()) if qs_dict else []
    question_set_dropdown = widgets.Dropdown(
        options=qs_options,
        value=qs_options[0] if qs_options else None,
        description="Question Set:",
        style={'description_width': 'initial'},
        layout={'width': '50%'}
    )

    # Initialize specific question dropdown based on the question set.
    initial_questions = qs_dict.get(question_set_dropdown.value, []) if qs_dict else []
    question_dropdown = widgets.Dropdown(
        options=initial_questions,
        value=initial_questions[0] if initial_questions else None,
        description="Select Question:",
        style={'description_width': 'initial'},
        layout={'width': '80%'}
    )

    # Observers to update question set and question options.
    def on_provider_change(change):
        new_provider = change['new']
        qs = QUESTION_LIBRARY.get(new_provider, {})
        new_qs_options = list(qs.keys()) if qs else []
        question_set_dropdown.options = new_qs_options
        if new_qs_options:
            first_set = new_qs_options[0]
            question_set_dropdown.value = first_set
            question_dropdown.options = qs[first_set]
            question_dropdown.value = qs[first_set][0] if qs[first_set] else None
        else:
            # An empty Dropdown only accepts None as its value
            question_dropdown.options = []
            question_dropdown.value = None

    def on_question_set_change(change):
        current_provider = provider_dropdown.value
        qs = QUESTION_LIBRARY.get(current_provider, {})
        selected_set = change['new']
        if selected_set and selected_set in qs:
            question_dropdown.options = qs[selected_set]
            question_dropdown.value = qs[selected_set][0] if qs[selected_set] else None
        else:
            question_dropdown.options = []
            question_dropdown.value = None

    def on_question_select(change):
        if change['new']:
            query_input.value = change['new']

    provider_dropdown.observe(on_provider_change, names="value")
    question_set_dropdown.observe(on_question_set_change, names="value")
    question_dropdown.observe(on_question_select, names="value")

    # Create UI elements for the chat interface.
    chat_header = widgets.HTML("<h3>Step 5: Chat with Clinical Notes</h3>")
    chat_description = widgets.HTML(
        "<p>Ask questions about the clinical notes. Responses will be based on the content of the notes.</p>"
    )
    base_path_input = widgets.Text(
        value=base_path,
        placeholder='Enter base path to directory with clinical notes',
        description='Base Notes Path:',
        style={'description_width': 'initial'},
        layout={'width': '80%'}
    )
    mrn_input = widgets.Text(
        value=mrn,
        placeholder='Enter MRN',
        description='MRN:',
        style={'description_width': 'initial'},
        layout={'width': '50%'}
    )
    include_filter = widgets.Text(
        value='',
        placeholder='Enter keywords to include (comma-separated)',
        description='Include notes with:',
        style={'description_width': 'initial'},
        layout={'width': '80%'}
    )
    exclude_filter = widgets.Text(
        value='',
        placeholder='Enter keywords to exclude (comma-separated)',
        description='Exclude notes with:',
        style={'description_width': 'initial'},
        layout={'width': '80%'}
    )
    list_notes_button = widgets.Button(
        description='List Available Notes',
        button_style='info',
        tooltip='List available notes for the given path/MRN'
    )
    notes_output = widgets.Output()
    load_notes_button = widgets.Button(
        description='Load Filtered Notes',
        button_style='primary',
        tooltip='Load filtered clinical notes for chat'
    )
    loading_output = widgets.Output()
    query_input = widgets.Text(
        value='',
        placeholder='Ask a question about the clinical notes',
        description='Query:',
        style={'description_width': 'initial'},
        layout={'width': '90%'}
    )
    chat_button = widgets.Button(
        description='Ask',
        button_style='success',
        tooltip='Submit your question'
    )
    chat_history = widgets.Output()
    clear_chat_button = widgets.Button(
        description='Clear Chat',
        button_style='warning',
        tooltip='Clear chat history'
    )
    example_output = widgets.Output()
    with example_output:
        show_example_questions()

    # Function to list available notes.
    def list_available_notes(b):
        with notes_output:
            clear_output()
            base_path_val = base_path_input.value
            mrn_val = mrn_input.value.strip()
            actual_path = os.path.join(base_path_val, mrn_val) if mrn_val else base_path_val
            if not os.path.exists(actual_path):
                print(f"Directory not found: {actual_path}")
                return
            # Search recursively, as get_clinical_notes() does, so a base path without an MRN lists every note
            files = sorted(Path(actual_path).glob("**/*.txt"))
            if not files:
                print(f"No text files found in {actual_path}")
                return
            print(f"Found {len(files)} text files in {actual_path}:")
            for i, file_path in enumerate(files, 1):
                print(f"{i}. {file_path.name}")
            print("\nUse the filter fields above to include or exclude notes based on keywords.")

    # Function to filter notes.
    def filter_notes(notes_dict, include_keywords, exclude_keywords):
        if not (include_keywords or exclude_keywords):
            return notes_dict
        filtered_notes = {}
        for filename, content in notes_dict.items():
            if exclude_keywords:
                if any(keyword.lower() in filename.lower() or keyword.lower() in content.lower() for keyword in exclude_keywords):
                    continue
            if include_keywords:
                if any(keyword.lower() in filename.lower() or keyword.lower() in content.lower() for keyword in include_keywords):
                    filtered_notes[filename] = content
            else:
                filtered_notes[filename] = content
        return filtered_notes

    # Function to load notes into the chat context.
    def load_notes(b):
        global chat_context
        with loading_output:
            clear_output()
            base_path_val = base_path_input.value
            mrn_val = mrn_input.value.strip()
            include_keywords = [k.strip() for k in include_filter.value.split(',') if k.strip()]
            exclude_keywords = [k.strip() for k in exclude_filter.value.split(',') if k.strip()]
            actual_path = os.path.join(base_path_val, mrn_val) if mrn_val else base_path_val
            print(f"Loading clinical notes from {actual_path}...")
            try:
                all_clinical_notes = get_clinical_notes(actual_path)
                if not all_clinical_notes:
                    print(f"No clinical notes found in {actual_path}")
                    return
                if include_keywords or exclude_keywords:
                    print(f"Applying filters - Include: {include_keywords}, Exclude: {exclude_keywords}")
                    clinical_notes = filter_notes(all_clinical_notes, include_keywords, exclude_keywords)
                    print(f"Filtered from {len(all_clinical_notes)} to {len(clinical_notes)} notes")
                else:
                    clinical_notes = all_clinical_notes
                if not clinical_notes:
                    print("No notes remain after filtering. Please adjust your filter criteria.")
                    return
                chat_context = combine_clinical_notes(clinical_notes)
                print(f"Successfully loaded {len(clinical_notes)} notes. You can now ask questions.")
                if clinical_notes:
                    print("\nLoaded files:")
                    for i, filename in enumerate(list(clinical_notes.keys())[:5], 1):
                        print(f"{i}. {filename}")
                    if len(clinical_notes) > 5:
                        print(f"... and {len(clinical_notes) - 5} more files")
            except Exception as e:
                print(f"Error loading notes: {e}")
                import traceback
                traceback.print_exc()

    # Chat submission handler.
    def on_ask_button_click(b):
        query = query_input.value
        if not query:
            return
        if 'chat_context' not in globals() or not chat_context:
            with chat_history:
                print("Please load the clinical notes first using the 'Load Filtered Notes' button")
            return
        selected_model = model_dropdown.value
        with chat_history:
            print(f"\n🙋 You: {query}")
            print("\n🤖 Assistant: ", end="")
            response = chat_with_clinical_notes(query, chat_context, model_id=selected_model)
            print(response)
        query_input.value = ''

    # Clear chat handler.
    def clear_chat(b):
        with chat_history:
            chat_history.clear_output()
            print("Chat history cleared.")

    list_notes_button.on_click(list_available_notes)
    load_notes_button.on_click(load_notes)
    chat_button.on_click(on_ask_button_click)
    clear_chat_button.on_click(clear_chat)

    # Display the chat interface.
    display(widgets.VBox([
        chat_header,
        chat_description,
        widgets.HTML("<h4>Document Selection</h4>"),
        base_path_input,
        mrn_input,
        list_notes_button,
        notes_output,
        widgets.HTML("<h4>Filter Options</h4>"),
        include_filter,
        exclude_filter,
        load_notes_button,
        loading_output,
        widgets.HTML("<h4>Predefined Questions</h4>"),
        provider_dropdown,
        question_set_dropdown,
        question_dropdown,
        widgets.HTML("<h4>Chat Interface</h4>"),
        widgets.HTML("<p>Select a model to use for answering questions, or type your own query below:</p>"),
        model_dropdown,
        widgets.HBox([query_input, chat_button]),
        chat_history,
        clear_chat_button,
        example_output
    ]))
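Both UIs define the same `filter_notes` helper and build its keyword lists with an inline comma-splitting expression. The standalone sketch below restates those rules so they can be checked outside the widgets. An exclude hit always wins. A non-empty include list requires at least one hit. Matching is a case-insensitive substring test over both the filename and the note text. `parse_keywords` is a hypothetical name for the inline expression, and the sample notes are made up for illustration.

```python
from typing import Dict, List


def parse_keywords(raw: str) -> List[str]:
    """Split a comma-separated widget value into trimmed, non-empty keywords."""
    return [k.strip() for k in raw.split(",") if k.strip()]


def filter_notes(notes: Dict[str, str], include: List[str], exclude: List[str]) -> Dict[str, str]:
    """Keep notes that match no exclude keyword and, when include keywords are
    given, at least one include keyword. Matching is a case-insensitive
    substring test against both the filename and the note body."""
    def matches(filename: str, content: str, keywords: List[str]) -> bool:
        haystacks = (filename.lower(), content.lower())
        return any(k.lower() in h for k in keywords for h in haystacks)

    kept = {}
    for filename, content in notes.items():
        if exclude and matches(filename, content, exclude):
            continue  # exclusion is checked first and always wins
        if include and not matches(filename, content, include):
            continue  # an include list was given, but nothing matched
        kept[filename] = content
    return kept


notes = {
    "neuro_consult.txt": "Neurology follow-up for migraine.",
    "psych_draft.txt": "DRAFT psychiatry intake.",
    "cardio.txt": "Echo shows normal EF.",
}
include = parse_keywords(" Neurology, psychiatry ,")
exclude = parse_keywords("draft")
print(sorted(filter_notes(notes, include, exclude)))
```

Running it prints `['neuro_consult.txt']`. The psychiatry note matches an include keyword but is dropped because its filename contains the excluded word "draft".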

## Summarization Functions
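Each summarization function in this cell repeats the same per-family request and response handling. The sketch below gathers those shapes into pure functions that can be tested without AWS credentials. `build_request_body` and `extract_text` are hypothetical helper names. The field names and sampling settings are copied from the `invoke_model` bodies in this cell.

```python
import json
from typing import Any, Dict


def build_request_body(model_family: str, prompt: str, max_tokens: int) -> str:
    """Return the JSON request body this notebook sends for each model family."""
    if model_family == "anthropic.claude-3-5":
        # Messages API: the prompt goes in a single user turn
        payload: Dict[str, Any] = {
            "anthropic_version": "bedrock-2023-05-31",
            "max_tokens": max_tokens,
            "temperature": 0,
            "top_p": 0.9,
            "messages": [{"role": "user", "content": [{"type": "text", "text": prompt}]}],
        }
    elif model_family == "anthropic.claude":
        # Legacy Text Completions API needs the Human/Assistant framing
        payload = {
            "prompt": f"\n\nHuman: {prompt}\n\nAssistant:",
            "max_tokens_to_sample": max_tokens,
            "temperature": 0,
            "top_p": 0.9,
        }
    elif model_family == "amazon.titan":
        payload = {
            "inputText": prompt,
            "textGenerationConfig": {"maxTokenCount": max_tokens, "temperature": 0, "topP": 0.9},
        }
    elif model_family == "meta.llama":
        payload = {"prompt": prompt, "max_gen_len": max_tokens, "temperature": 0, "top_p": 0.9}
    else:
        raise ValueError(f"Unsupported model family: {model_family}")
    return json.dumps(payload)


def extract_text(model_family: str, response_body: Dict[str, Any]) -> str:
    """Pull the generated text out of a decoded invoke_model response body."""
    if model_family == "anthropic.claude-3-5":
        return response_body.get("content", [{}])[0].get("text", "")
    if model_family == "anthropic.claude":
        return response_body.get("completion", "")
    if model_family == "amazon.titan":
        return response_body.get("results", [{}])[0].get("outputText", "")
    if model_family == "meta.llama":
        return response_body.get("generation", "")
    raise ValueError(f"Unsupported model family: {model_family}")
```

With these helpers, each summarization function reduces to formatting its prompt, calling `bedrock_runtime.invoke_model(body=build_request_body(...), modelId=model_id)`, and passing the decoded body to `extract_text`. The sketch omits the AWS call so that it stays testable.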

def get_model_family(model_id):
    """Determine the model family from the model ID."""
    if model_id.startswith("anthropic.claude"):
        # Claude 3 and Claude 3.5 models accept only the Messages API on Bedrock,
        # so both map to the "anthropic.claude-3-5" family key used for that format
        if "claude-3" in model_id:
            return "anthropic.claude-3-5"
        return "anthropic.claude"
    elif model_id.startswith("amazon.titan"):
        return "amazon.titan"
    elif model_id.startswith("meta.llama"):
        return "meta.llama"
    else:
        return "unknown"

def summarize_with_bedrock(text, model_id="anthropic.claude-v2", max_tokens=1000):
    """Generate a summary of the provided text using AWS Bedrock."""
    try:
        # Import prompts from the summarization_prompts module
        from summarization_prompts import INDIVIDUAL_NOTE_PROMPTS
        
        # Determine the model family
        model_family = get_model_family(model_id)
        
        # Get the appropriate prompt template
        if model_family in INDIVIDUAL_NOTE_PROMPTS:
            prompt_template = INDIVIDUAL_NOTE_PROMPTS[model_family]
        else:
            # If not found, try the base Claude family
            if model_family == "anthropic.claude-3-5":
                prompt_template = INDIVIDUAL_NOTE_PROMPTS["anthropic.claude"]
            else:
                return f"Model family {model_family} not supported for summarization"
        
        # Format the prompt with the text
        formatted_prompt = prompt_template.format(text=text)
        
        # Different models have different payload formats
        if model_family == "anthropic.claude-3-5":
            # Claude 3.5 models use the messages API format
            body = json.dumps({
                "anthropic_version": "bedrock-2023-05-31",
                "max_tokens": max_tokens,
                "temperature": 0,
                "top_p": 0.9,
                "messages": [
                    {
                        "role": "user",
                        "content": [
                            {
                                "type": "text",
                                "text": formatted_prompt
                            }
                        ]
                    }
                ]
            })
            response = bedrock_runtime.invoke_model(body=body, modelId=model_id)
            response_body = json.loads(response.get("body").read())
            return response_body.get("content", [{}])[0].get("text", "")
            
        elif model_id.startswith("anthropic.claude"):
            # Legacy Text Completions API (Claude v2.x / Claude Instant only)
            body = json.dumps({
                "prompt": f"\n\nHuman: {formatted_prompt}\n\nAssistant:",
                "max_tokens_to_sample": max_tokens,
                "temperature": 0,
                "top_p": 0.9
            })
            response = bedrock_runtime.invoke_model(body=body, modelId=model_id)
            response_body = json.loads(response.get("body").read())
            return response_body.get("completion", "")
        
        elif model_id.startswith("amazon.titan"):
            # Amazon Titan models
            body = json.dumps({
                "inputText": formatted_prompt,
                "textGenerationConfig": {
                    "maxTokenCount": max_tokens,
                    "temperature": 0,
                    "topP": 0.9
                }
            })
            response = bedrock_runtime.invoke_model(body=body, modelId=model_id)
            response_body = json.loads(response.get("body").read())
            return response_body.get("results", [{}])[0].get("outputText", "")
            
        elif model_id.startswith("meta.llama"):
            # Meta Llama models
            body = json.dumps({
                "prompt": formatted_prompt,
                "max_gen_len": max_tokens,
                "temperature": 0,
                "top_p": 0.9
            })
            response = bedrock_runtime.invoke_model(body=body, modelId=model_id)
            response_body = json.loads(response.get("body").read())
            return response_body.get("generation", "")
        
        else:
            return "Model not supported for summarization"
        
    except Exception as e:
        return f"Error generating summary: {str(e)}"

def summarize_chunk(text, model_id, max_tokens):
    """Summarize a single chunk of notes."""
    try:
        # Import prompts from the summarization_prompts module
        from summarization_prompts import CHUNK_SUMMARIZATION_PROMPTS
        
        # Determine the model family
        model_family = get_model_family(model_id)
        
        # Get the appropriate prompt template
        if model_family in CHUNK_SUMMARIZATION_PROMPTS:
            prompt_template = CHUNK_SUMMARIZATION_PROMPTS[model_family]
        else:
            # If not found, try the base Claude family
            if model_family == "anthropic.claude-3-5":
                prompt_template = CHUNK_SUMMARIZATION_PROMPTS["anthropic.claude"]
            else:
                return f"Model family {model_family} not supported for summarization"
        
        # Format the prompt with the text
        formatted_prompt = prompt_template.format(text=text)
        
        # Handle different model types
        if model_family == "anthropic.claude-3-5":
            # Claude 3.5 models use the messages API format
            body = json.dumps({
                "anthropic_version": "bedrock-2023-05-31",
                "max_tokens": max_tokens,
                "temperature": 0,
                "top_p": 0.9,
                "messages": [
                    {
                        "role": "user",
                        "content": [
                            {
                                "type": "text",
                                "text": formatted_prompt
                            }
                        ]
                    }
                ]
            })
            response = bedrock_runtime.invoke_model(body=body, modelId=model_id)
            response_body = json.loads(response.get("body").read())
            return response_body.get("content", [{}])[0].get("text", "")
            
        elif model_id.startswith("anthropic.claude"):
            body = json.dumps({
                "prompt": f"\n\nHuman: {formatted_prompt}\n\nAssistant:",
                "max_tokens_to_sample": max_tokens,
                "temperature": 0,
                "top_p": 0.9
            })
            response = bedrock_runtime.invoke_model(body=body, modelId=model_id)
            response_body = json.loads(response.get("body").read())
            return response_body.get("completion", "")
        
        elif model_id.startswith("amazon.titan"):
            body = json.dumps({
                "inputText": formatted_prompt,
                "textGenerationConfig": {
                    "maxTokenCount": max_tokens,
                    "temperature": 0,
                    "topP": 0.9
                }
            })
            response = bedrock_runtime.invoke_model(body=body, modelId=model_id)
            response_body = json.loads(response.get("body").read())
            return response_body.get("results", [{}])[0].get("outputText", "")
            
        elif model_id.startswith("meta.llama"):
            body = json.dumps({
                "prompt": formatted_prompt,
                "max_gen_len": max_tokens,
                "temperature": 0,
                "top_p": 0.9
            })
            response = bedrock_runtime.invoke_model(body=body, modelId=model_id)
            response_body = json.loads(response.get("body").read())
            return response_body.get("generation", "")
        
        else:
            return f"Model {model_id} not supported for summarization"
            
    except Exception as e:
        return f"Error generating chunk summary: {str(e)}"

def summarize_final(combined_summaries, model_id, max_tokens):
    """Create a final summary from multiple chunk summaries."""
    try:
        # Import prompts from the summarization_prompts module
        from summarization_prompts import FINAL_SUMMARIZATION_PROMPTS
        
        # Determine the model family
        model_family = get_model_family(model_id)
        
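        # Get the appropriate prompt template, falling back to the base Claude
        # prompt for Claude 3.5 models (as summarize_chunk does)
        if model_family in FINAL_SUMMARIZATION_PROMPTS:
            prompt_template = FINAL_SUMMARIZATION_PROMPTS[model_family]
        elif model_family == "anthropic.claude-3-5":
            prompt_template = FINAL_SUMMARIZATION_PROMPTS["anthropic.claude"]
        else:
            return f"Model family {model_family} not supported for summarization"
        
        # Format the prompt with the text
        formatted_prompt = prompt_template.format(text=combined_summaries)
        
        # Handle different model types
        if model_family == "anthropic.claude-3-5":
            # Claude 3.5 models use the messages API format
            body = json.dumps({
                "anthropic_version": "bedrock-2023-05-31",
                "max_tokens": max_tokens,
                "temperature": 0,
                "top_p": 0.9,
                "messages": [
                    {
                        "role": "user",
                        "content": [{"type": "text", "text": formatted_prompt}]
                    }
                ]
            })
            response = bedrock_runtime.invoke_model(body=body, modelId=model_id)
            response_body = json.loads(response.get("body").read())
            return response_body.get("content", [{}])[0].get("text", "")
        
        elif model_id.startswith("anthropic.claude"):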
            body = json.dumps({
                "prompt": f"\n\nHuman: {formatted_prompt}\n\nAssistant:",
                "max_tokens_to_sample": max_tokens,
                "temperature": 0,
                "top_p": 0.9
            })
            response = bedrock_runtime.invoke_model(body=body, modelId=model_id)
            response_body = json.loads(response.get("body").read())
            return response_body.get("completion", "")
        
        elif model_id.startswith("amazon.titan"):
            body = json.dumps({
                "inputText": formatted_prompt,
                "textGenerationConfig": {
                    "maxTokenCount": max_tokens,
                    "temperature": 0,
                    "topP": 0.9
                }
            })
            response = bedrock_runtime.invoke_model(body=body, modelId=model_id)
            response_body = json.loads(response.get("body").read())
            return response_body.get("results", [{}])[0].get("outputText", "")
            
        elif model_id.startswith("meta.llama"):
            body = json.dumps({
                "prompt": formatted_prompt,
                "max_gen_len": max_tokens,
                "temperature": 0,
                "top_p": 0.9
            })
            response = bedrock_runtime.invoke_model(body=body, modelId=model_id)
            response_body = json.loads(response.get("body").read())
            return response_body.get("generation", "")
        
        else:
            return f"Model {model_id} not supported for summarization"
            
    except Exception as e:
        return f"Error generating final summary: {str(e)}"

def summarize_combined_notes(combined_text, model_id="anthropic.claude-v2", max_tokens=2000):
    """Generate a comprehensive summary of combined clinical notes."""
    try:
        # Import prompts from the summarization_prompts module
        from summarization_prompts import COMBINED_NOTE_PROMPTS
        
        # Determine the model family
        model_family = get_model_family(model_id)
        
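        # Get the appropriate prompt template, falling back to the base Claude
        # prompt for Claude 3.5 models (as summarize_with_bedrock does)
        if model_family in COMBINED_NOTE_PROMPTS:
            prompt_template = COMBINED_NOTE_PROMPTS[model_family]
        elif model_family == "anthropic.claude-3-5":
            prompt_template = COMBINED_NOTE_PROMPTS["anthropic.claude"]
        else:
            return f"Model family {model_family} not supported for summarization"
        
        # Format the prompt with the text
        formatted_prompt = prompt_template.format(text=combined_text)
        
        # Handle different model types
        if model_family == "anthropic.claude-3-5":
            # Claude 3.5 models use the messages API format
            body = json.dumps({
                "anthropic_version": "bedrock-2023-05-31",
                "max_tokens": max_tokens,
                "temperature": 0,
                "top_p": 0.9,
                "messages": [
                    {
                        "role": "user",
                        "content": [{"type": "text", "text": formatted_prompt}]
                    }
                ]
            })
            response = bedrock_runtime.invoke_model(body=body, modelId=model_id)
            response_body = json.loads(response.get("body").read())
            return response_body.get("content", [{}])[0].get("text", "")
        
        elif model_id.startswith("anthropic.claude"):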
            body = json.dumps({
                "prompt": f"\n\nHuman: {formatted_prompt}\n\nAssistant:",
                "max_tokens_to_sample": max_tokens,
                "temperature": 0,
                "top_p": 0.9
            })
            response = bedrock_runtime.invoke_model(body=body, modelId=model_id)
            response_body = json.loads(response.get("body").read())
            return response_body.get("completion", "")
        
        elif model_id.startswith("amazon.titan"):
            body = json.dumps({
                "inputText": formatted_prompt,
                "textGenerationConfig": {
                    "maxTokenCount": max_tokens,
                    "temperature": 0,
                    "topP": 0.9
                }
            })
            response = bedrock_runtime.invoke_model(body=body, modelId=model_id)
            response_body = json.loads(response.get("body").read())
            return response_body.get("results", [{}])[0].get("outputText", "")
            
        elif model_id.startswith("meta.llama"):
            body = json.dumps({
                "prompt": formatted_prompt,
                "max_gen_len": max_tokens,
                "temperature": 0,
                "top_p": 0.9
            })
            response = bedrock_runtime.invoke_model(body=body, modelId=model_id)
            response_body = json.loads(response.get("body").read())
            return response_body.get("generation", "")
        
        else:
            return "Model not supported for summarization"
        
    except Exception as e:
        return f"Error generating summary: {str(e)}"

def summarize_combined_notes_chunked(clinical_notes, model_id="anthropic.claude-v2", max_tokens=2000, max_input_tokens=7500):
    """Generate a comprehensive summary by processing notes in chunks to avoid token limits."""
    # Sort notes by filename to ensure consistency
    sorted_filenames = sorted(clinical_notes.keys())
    
    # Estimate characters per token (rough approximation)
    chars_per_token = 4  # This is an approximation, varies by model
    
    # Initialize variables
    all_summaries = []
    current_chunk = ""
    current_chunk_size = 0
    chunk_number = 1
    
    print(f"Processing {len(sorted_filenames)} notes in chunks...")
    
    # Process notes in chunks
    for filename in sorted_filenames:
        content = clinical_notes[filename]
        if not content.strip():
            continue
            
        note_text = f"\n\n--- CLINICAL NOTE: {filename} ---\n\n{content}"
        note_size = len(note_text) // chars_per_token
        
        # If adding this note would exceed limit, process current chunk
        if current_chunk_size + note_size > max_input_tokens and current_chunk:
            print(f"Processing chunk {chunk_number} (approx. {current_chunk_size} tokens)...")
            chunk_summary = summarize_chunk(current_chunk, model_id, max_tokens)
            all_summaries.append(chunk_summary)
            current_chunk = note_text
            current_chunk_size = note_size
            chunk_number += 1
        else:
            # Add to current chunk
            current_chunk += note_text
            current_chunk_size += note_size
    
    # Process final chunk if it exists
    if current_chunk:
        print(f"Processing final chunk {chunk_number} (approx. {current_chunk_size} tokens)...")
        chunk_summary = summarize_chunk(current_chunk, model_id, max_tokens)
        all_summaries.append(chunk_summary)
    
    # If we have multiple chunk summaries, summarize them together
    if len(all_summaries) > 1:
        print("Generating final summary from all chunk summaries...")
        combined_summaries = "\n\n--- CHUNK SUMMARIES ---\n\n" + "\n\n".join(all_summaries)
        final_summary = summarize_final(combined_summaries, model_id, max_tokens)
        return final_summary
    elif all_summaries:
        return all_summaries[0]
    else:
        return "No valid content found to summarize."
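`summarize_combined_notes_chunked` interleaves two concerns: packing notes into token-bounded chunks and calling the model. The sketch below isolates the packing step as a pure function, so the chunk boundaries can be inspected before any Bedrock calls are made. `pack_notes_into_chunks` is a hypothetical name. It keeps the notebook's estimate of 4 characters per token, so with the default `max_input_tokens=7500` each chunk holds roughly 30,000 characters, including the per-note headers.

```python
from typing import Dict, List


def pack_notes_into_chunks(notes: Dict[str, str], max_input_tokens: int = 7500,
                           chars_per_token: int = 4) -> List[str]:
    """Greedily group notes (sorted by filename) into chunks whose estimated
    token count stays within max_input_tokens. A single note larger than the
    budget still becomes its own, oversized chunk, as in the notebook."""
    chunks: List[str] = []
    current, current_size = "", 0
    for filename in sorted(notes):
        content = notes[filename]
        if not content.strip():
            continue  # skip empty notes
        note_text = f"\n\n--- CLINICAL NOTE: {filename} ---\n\n{content}"
        note_size = len(note_text) // chars_per_token  # rough token estimate
        if current and current_size + note_size > max_input_tokens:
            chunks.append(current)  # budget exceeded: close the current chunk
            current, current_size = note_text, note_size
        else:
            current += note_text
            current_size += note_size
    if current:
        chunks.append(current)
    return chunks


demo = {"a.txt": "x" * 400, "b.txt": "y" * 400, "c.txt": "", "d.txt": "z" * 400}
for i, chunk in enumerate(pack_notes_into_chunks(demo, max_input_tokens=250), 1):
    print(i, len(chunk) // 4)
```

With a 250-token budget, the three non-empty notes split into two chunks: `a.txt` and `b.txt` in the first, and `d.txt` in the second. The empty `c.txt` is skipped. Each returned chunk corresponds to one `summarize_chunk` call, and any result with two or more chunks triggers the extra `summarize_final` pass.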

## Summarization UI Initialization

def initialize_summarization_ui():
    """Initialize the UI for summarization"""
    # Define commonly used Bedrock models
    popular_models = [
        "anthropic.claude-v2",
        "anthropic.claude-3-sonnet-20240229-v1:0",
        "anthropic.claude-3-haiku-20240307-v1:0",
        "anthropic.claude-3-5-sonnet-20241022-v2:0",  # Claude 3.5 Sonnet v2
        "anthropic.claude-3-5-haiku-20241022-v1:0",   # Claude 3.5 Haiku
        "meta.llama3-70b-instruct-v1:0",
        "meta.llama3-3-70b-instruct-v1:0",            # Llama 3.3 70B
        "amazon.titan-text-express-v1"
    ]
    
    # Model selector dropdown
    model_dropdown = widgets.Dropdown(
        options=popular_models,
        value="anthropic.claude-v2",
        description="Model:",
        style={'description_width': 'initial'},
        layout={'width': '50%'}
    )
    
    # Base directory path input (defaults to the Step 1 notes directory when one is
    # set; globals().get avoids a NameError if this step is opened before Step 1)
    base_path_input = widgets.Text(
        value=globals().get('notes_dir') or '../out/notes',
        placeholder='Enter base path to directory with clinical notes',
        description='Base Notes Path:',
        style={'description_width': 'initial'},
        layout={'width': '80%'}
    )
    
    # MRN input field
    mrn_input = widgets.Text(
        value='',
        placeholder='Enter MRN to process specific patient notes',
        description='MRN (optional):',
        style={'description_width': 'initial'},
        layout={'width': '50%'}
    )
    
    # Summary type selector (individual or combined)
    summary_type = widgets.RadioButtons(
        options=['Individual summaries', 'Combined summary'],
        value='Individual summaries',
        description='Summary type:',
        style={'description_width': 'initial'}
    )
    
    # Max tokens slider
    max_tokens = widgets.IntSlider(
        value=2000,
        min=500,
        max=4000,
        step=500,
        description='Max tokens:',
        style={'description_width': 'initial'},
        layout={'width': '50%'}
    )
    
    # Filter options
    include_filter = widgets.Text(
        value='',
        placeholder='Enter keywords to include (comma-separated)',
        description='Include notes with:',
        style={'description_width': 'initial'},
        layout={'width': '80%'}
    )
    
    exclude_filter = widgets.Text(
        value='',
        placeholder='Enter keywords to exclude (comma-separated)',
        description='Exclude notes with:',
        style={'description_width': 'initial'},
        layout={'width': '80%'}
    )
    
    # List notes button
    list_notes_button = widgets.Button(
        description='List Available Notes',
        button_style='info',
        tooltip='List available notes for the given path/MRN'
    )
    
    # Notes selection output
    notes_output = widgets.Output()
    
    # Create the processing button
    process_button = widgets.Button(
        description='Generate Summaries',
        button_style='primary',
        tooltip='Click to process clinical notes'
    )
    
    # Create output area
    output = widgets.Output()
    
    # Function to list available notes
    def list_available_notes(b):
        with notes_output:
            clear_output()
            
            # Get base path and MRN
            base_path = base_path_input.value
            mrn = mrn_input.value.strip()
            
            # Determine the actual path to use
            if mrn:
                actual_path = os.path.join(base_path, mrn)
            else:
                actual_path = base_path
            
            try:
                # Check if directory exists
                if not os.path.exists(actual_path):
                    print(f"Directory not found: {actual_path}")
                    return
                
                # Get all text files in the directory
                files = list(Path(actual_path).glob("*.txt"))
                
                if not files:
                    print(f"No text files found in {actual_path}")
                    return
                
                print(f"Found {len(files)} text files in {actual_path}:")
                for i, file_path in enumerate(files, 1):
                    print(f"{i}. {file_path.name}")
                
                # Show filter instructions
                print("\nUse the filter fields above to include or exclude notes based on keywords.")
                print("For example, enter 'Neurology, Psychiatry' in the include field to only process notes containing those terms.")
                print("Enter 'temp, draft' in the exclude field to skip notes containing those terms.")
            
            except Exception as e:
                print(f"Error listing notes: {e}")
    
    # Function to filter notes based on include/exclude keywords
    def filter_notes(notes_dict, include_keywords, exclude_keywords):
        """Filter notes based on include and exclude keywords"""
        if not (include_keywords or exclude_keywords):
            return notes_dict  # No filtering needed
        
        filtered_notes = {}
        
        for filename, content in notes_dict.items():
            # Check exclude keywords first (if any keyword matches, skip this note)
            if exclude_keywords:
                if any(keyword.lower() in filename.lower() or 
                       keyword.lower() in content.lower() 
                       for keyword in exclude_keywords):
                    continue
            
            # Then check include keywords (if any are specified, at least one must match)
            if include_keywords:
                if any(keyword.lower() in filename.lower() or 
                       keyword.lower() in content.lower() 
                       for keyword in include_keywords):
                    filtered_notes[filename] = content
            else:
                # If no include keywords specified but passed exclude filter, include it
                filtered_notes[filename] = content
        
        return filtered_notes
    
    # Define the processing function
    def process_clinical_notes(b):
        with output:
            clear_output()
            
            # Get user inputs
            base_path = base_path_input.value
            mrn = mrn_input.value.strip()
            model_id = model_dropdown.value
            summary_mode = summary_type.value
            token_limit = max_tokens.value
            
            # Process include/exclude keywords
            include_keywords = [k.strip() for k in include_filter.value.split(',') if k.strip()]
            exclude_keywords = [k.strip() for k in exclude_filter.value.split(',') if k.strip()]
            
            # Determine the actual path to use
            if mrn:
                directory_path = os.path.join(base_path, mrn)
            else:
                directory_path = base_path
            
            # Validate input
            if not directory_path:
                print("Please enter a valid directory path")
                return
            
            try:
                # Get all clinical notes from the directory
                print(f"Loading clinical notes from {directory_path}...")
                all_clinical_notes = get_clinical_notes(directory_path)
                
                if not all_clinical_notes:
                    print(f"No clinical notes found in {directory_path}")
                    return
                
                # Apply filters if specified
                if include_keywords or exclude_keywords:
                    print(f"Applying filters - Include: {include_keywords}, Exclude: {exclude_keywords}")
                    clinical_notes = filter_notes(all_clinical_notes, include_keywords, exclude_keywords)
                    print(f"Filtered from {len(all_clinical_notes)} to {len(clinical_notes)} notes")
                else:
                    clinical_notes = all_clinical_notes
                
                if not clinical_notes:
                    print("No notes remain after filtering. Please adjust your filter criteria.")
                    return
                
                if summary_mode == 'Individual summaries':
                    # Collect one result dict per note (converted to a DataFrame below)
                    results = []
                    
                    # Process each clinical note individually
                    print(f"Generating individual summaries using {model_id}...")
                    for filename, note_text in tqdm(clinical_notes.items(), desc="Processing"):
                        # Check if text is not empty
                        if note_text.strip():
                            # Generate summary, using the max-tokens slider value
                            summary = summarize_with_bedrock(note_text, model_id=model_id, max_tokens=token_limit)
                            
                            # Save the summary to a text file with the same name + _individual_summary
                            summary_filename = os.path.splitext(filename)[0] + "_individual_summary.txt"
                            summary_filepath = os.path.join(directory_path, summary_filename)
                            
                            with open(summary_filepath, 'w', encoding='utf-8') as f:
                                f.write(f"SUMMARY OF {filename}\n")
                                f.write(f"Generated using {model_id}\n\n")
                                f.write(summary)
                            
                            # Record the result for the summary index
                            results.append({
                                "filename": filename,
                                "summary_filename": summary_filename,
                                "original_text": note_text[:100] + "..." if len(note_text) > 100 else note_text,
                                "summary": summary
                            })
                        else:
                            print(f"Skipping empty file: {filename}")
                    
                    # Create DataFrame
                    results_df = pd.DataFrame(results)
                    
                    # Save the results to a CSV file in the same directory
                    output_path = os.path.join(directory_path, "clinical_summaries_individual.csv")
                    results_df.to_csv(output_path, index=False)
                    print(f"Individual summaries saved to {directory_path}")
                    print(f"Summary index saved to {output_path}")
                    
                    # Display the DataFrame
                    display(results_df)
                    
                else:  # Combined summary
                    # Generate summary using chunking approach
                    print(f"\nGenerating combined summary using {model_id}...")
                    comprehensive_summary = summarize_combined_notes_chunked(
                        clinical_notes, 
                        model_id=model_id, 
                        max_tokens=token_limit,
                        max_input_tokens=7500  # Conservative limit to avoid errors
                    )
                    
                    # Save the combined summary to a text file
                    output_path = os.path.join(directory_path, "clinical_summary_combined.txt")
                    with open(output_path, 'w', encoding='utf-8') as f:
                        f.write(f"COMBINED SUMMARY OF {len(clinical_notes)} CLINICAL NOTES\n")
                        f.write(f"Generated using {model_id}\n\n")
                        f.write(comprehensive_summary)
                        
                    print(f"\nCombined summary saved to {output_path}")
                    
                    # Display the summary
                    print("\n==== COMPREHENSIVE SUMMARY ====")
                    print(comprehensive_summary)
                    
                # Prompt the user to proceed to the chat interface
                print("\nSummaries have been generated. You can now proceed to the chat interface to ask questions about the notes.")
                initialize_chat_ui(directory_path, model_id)
                    
            except Exception as e:
                print(f"Error processing clinical notes: {e}")
                import traceback
                traceback.print_exc()
    
    # Connect the button click events
    list_notes_button.on_click(list_available_notes)
    process_button.on_click(process_clinical_notes)
    
    # Display the widgets
    display(widgets.VBox([
        widgets.HTML("<h3>Step 4: Generate Summaries</h3>"),
        widgets.HTML("<p>Configure the summarization process below:</p>"),
        base_path_input,
        mrn_input,
        list_notes_button,
        notes_output,
        widgets.HTML("<h4>Filter Options</h4>"),
        include_filter,
        exclude_filter,
        widgets.HTML("<h4>Summarization Options</h4>"),
        model_dropdown,
        summary_type,
        max_tokens,
        process_button,
        output
    ]))
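The combined branch above calls `summarize_combined_notes_chunked` with `max_input_tokens=7500`, but that function's body is not shown in this section. One common way to stay under an input limit is map-reduce: pack notes into batches that fit a token budget, summarize each batch, then summarize the partial summaries. The sketch below is one such approach and may not match how the notebook's own function works. It assumes a rough heuristic of about four characters per token instead of a real tokenizer. `estimate_tokens`, `chunk_notes`, and `summarize_map_reduce` are hypothetical names, and `summarize` stands in for whatever Bedrock call the notebook makes.

```python
from typing import Callable, List


def estimate_tokens(text: str) -> int:
    """Rough token estimate (assumption: ~4 characters per token, not a real tokenizer)."""
    return max(1, len(text) // 4)


def chunk_notes(notes: List[str], max_input_tokens: int = 7500) -> List[List[str]]:
    """Group notes into batches whose estimated size stays within max_input_tokens.

    A single note larger than the budget is split into fixed-size character pieces.
    """
    max_chars = max_input_tokens * 4  # same 4-chars-per-token assumption as above
    chunks: List[List[str]] = []
    current: List[str] = []
    current_tokens = 0
    for note in notes:
        # Oversized notes are cut so that every piece fits the budget on its own
        pieces = [note[i:i + max_chars] for i in range(0, len(note), max_chars)] or [""]
        for piece in pieces:
            piece_tokens = estimate_tokens(piece)
            # Start a new batch when adding this piece would exceed the budget
            if current and current_tokens + piece_tokens > max_input_tokens:
                chunks.append(current)
                current, current_tokens = [], 0
            current.append(piece)
            current_tokens += piece_tokens
    if current:
        chunks.append(current)
    return chunks


def summarize_map_reduce(notes: List[str], summarize: Callable[[str], str],
                         max_input_tokens: int = 7500) -> str:
    """Summarize each batch, then summarize the partial summaries.

    Assumes the partial summaries together fit in one request.
    """
    partials = [summarize("\n\n".join(chunk)) for chunk in chunk_notes(notes, max_input_tokens)]
    if not partials:
        return ""
    if len(partials) == 1:
        return partials[0]
    return summarize("\n\n".join(partials))
```

A real implementation would likely count tokens with the model's own tokenizer and recurse on the reduce step if the partial summaries are still too long.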

## Main Pipeline Execution

# Global variables for chat context and AWS clients
chat_context = None
bedrock_runtime = None
bedrock = None

# Main Pipeline Execution
def run_pipeline():
    """Initialize the complete clinical notes pipeline with step buttons"""
    # Create step buttons
    step1_button = widgets.Button(
        description='Step 1: Load Data',
        button_style='primary',
        tooltip='Extract notes from CSV',
        layout={'width': 'auto', 'margin': '10px 0px'}
    )
    
    step2_button = widgets.Button(
        description='Step 2: Filter Notes',
        button_style='primary',
        tooltip='Filter extracted notes',
        layout={'width': 'auto', 'margin': '10px 0px'}
    )
    
    step3_button = widgets.Button(
        description='Step 3: Configure AWS',
        button_style='primary',
        tooltip='Set up AWS Bedrock',
        layout={'width': 'auto', 'margin': '10px 0px'}
    )
    
    step4_button = widgets.Button(
        description='Step 4: Generate Summaries',
        button_style='primary',
        tooltip='Summarize clinical notes',
        layout={'width': 'auto', 'margin': '10px 0px'}
    )
    
    step5_button = widgets.Button(
        description='Step 5: Chat with Notes',
        button_style='primary',
        tooltip='Ask questions about the notes',
        layout={'width': 'auto', 'margin': '10px 0px'}
    )
    
    # Step output area
    step_output = widgets.Output()
    
    # Connect button click handlers
    step1_button.on_click(lambda b: run_step1(step_output))
    step2_button.on_click(lambda b: run_step2(step_output))
    step3_button.on_click(lambda b: run_step3(step_output))
    step4_button.on_click(lambda b: run_step4(step_output))
    step5_button.on_click(lambda b: run_step5(step_output))
    
    # Display the pipeline interface
    display(widgets.HTML("<h2>Clinical Notes Pipeline</h2>"))
    display(widgets.HTML("<p>Click on each step button to execute that part of the pipeline:</p>"))
    display(widgets.VBox([
        step1_button,
        step2_button,
        step3_button,
        step4_button,
        step5_button
    ]))
    display(step_output)

# Step execution functions
def run_step1(output):
    """Run Step 1: Load Data"""
    with output:
        clear_output()
        print("Executing Step 1: Load Data...")
        initialize_extraction()

def run_step2(output):
    """Run Step 2: Filter Notes"""
    with output:
        clear_output()
        print("Executing Step 2: Filter Notes...")
        initialize_filtering_widgets()

def run_step3(output):
    """Run Step 3: Configure AWS"""
    with output:
        clear_output()
        print("Executing Step 3: Configure AWS...")
        setup_aws_ui()

def run_step4(output):
    """Run Step 4: Generate Summaries"""
    with output:
        clear_output()
        print("Executing Step 4: Generate Summaries...")
        initialize_summarization_ui()

def run_step5(output):
    """Run Step 5: Chat with Notes"""
    with output:
        clear_output()
        print("Executing Step 5: Chat with Notes...")
        # The directory and model chosen in Step 4 are not passed here,
        # so fall back to the global notes_dir
        directory_path = notes_dir
        # Default model used when Step 5 is opened directly
        model_id = "anthropic.claude-v2"
        initialize_chat_ui(directory_path, model_id)

# Example questions for the chat interface
def show_example_questions():
    """Display example questions that can be asked about the clinical notes"""
    example_questions = [
        "What are the primary diagnoses mentioned in the notes?",
        "What medications are prescribed to the patients?",
        "Are there any allergies mentioned in the notes?",
        "What treatments have been recommended?",
        "What symptoms are most frequently mentioned?",
        "Were any surgical procedures performed?",
        "What is the patient's medical history?",
        "What follow-up recommendations were given?"
    ]
    
    print("Example questions you can ask about the clinical notes:")
    for i, question in enumerate(example_questions, 1):
        print(f"{i}. {question}")

# Run the pipeline when the notebook is executed
run_pipeline()
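`run_pipeline` defines five nearly identical buttons by hand. A table of (label, tooltip, runner) rows could generate the click handlers in a loop instead. The sketch below is plain Python so it runs without ipywidgets; `StepSpec` and `build_handlers` are hypothetical names. The one subtlety is Python's late binding of closure variables: a bare `lambda b: runner(output)` created inside a loop looks up `runner` when it is called, so every button would run the last step. Binding the runner through a default argument avoids that.

```python
from typing import Callable, Dict, List, Tuple

# One row per pipeline step: (button label, tooltip, runner such as run_step1)
StepSpec = Tuple[str, str, Callable[[object], None]]


def build_handlers(steps: List[StepSpec], output: object) -> Dict[str, Callable[[object], None]]:
    """Map each button label to an on_click handler bound to its own runner."""
    handlers: Dict[str, Callable[[object], None]] = {}
    for label, _tooltip, runner in steps:
        # `run=runner` captures the current runner now; `b` is the clicked button,
        # which the runners in this notebook ignore
        handlers[label] = lambda b, run=runner: run(output)
    return handlers
```

In the notebook, each handler would then be attached with `widgets.Button(description=label, tooltip=tooltip, button_style='primary').on_click(handlers[label])`, so adding a step means adding one row rather than another block of button code.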

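`run_step5` always falls back to `notes_dir` and `anthropic.claude-v2`, so a directory or model chosen in Step 4 is not carried over when the user clicks the Step 5 button. A minimal sketch of one way to carry those choices forward, assuming a small shared state object that Step 4 updates; `PipelineState`, `record_summarization`, and `resolve_chat_target` are hypothetical and not part of the notebook.

```python
from dataclasses import dataclass
from typing import Optional, Tuple

DEFAULT_CHAT_MODEL = "anthropic.claude-v2"  # same fallback run_step5 uses


@dataclass
class PipelineState:
    """Hypothetical holder for values one step hands to the next."""
    last_summary_dir: Optional[str] = None
    last_model_id: Optional[str] = None


def record_summarization(state: PipelineState, directory_path: str, model_id: str) -> None:
    """Remember the directory and model used in Step 4."""
    state.last_summary_dir = directory_path
    state.last_model_id = model_id


def resolve_chat_target(state: PipelineState, notes_dir: str,
                        default_model: str = DEFAULT_CHAT_MODEL) -> Tuple[str, str]:
    """Prefer the values recorded by Step 4; otherwise fall back to notes_dir and the default model."""
    directory = state.last_summary_dir or notes_dir
    model = state.last_model_id or default_model
    return directory, model
```

Step 4 would call `record_summarization(state, directory_path, model_id)` just before `initialize_chat_ui`, and `run_step5` would use `directory_path, model_id = resolve_chat_target(state, notes_dir)` in place of its hard-coded values.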