In [None]:
# Install the two PDF libraries the app cell imports: pypdfium2 (page previews) and PyMuPDF (imported as fitz, used for redaction).
! pip install pypdfium2 PyMuPDF

In [None]:
-- List PDF documents in a Snowflake stage
-- (this is the same stage the Streamlit app in the next cell reads its PDFs from)
list @ADVANCED_ANALYTICS.REDACT_PDF_DEMO.ORIG_PDFS

In [None]:
import streamlit as st
from snowflake.snowpark.context import get_active_session
import pypdfium2 as pdfium
import fitz  # PyMuPDF
import json
import os
import io

# --- App Setup ---
# Page heading and tagline. These are the first Streamlit calls in the script,
# so they are issued before any of the step sections further down.
st.title("🔒 PII Detection & Redaction Tool")
st.markdown("*Powered by Snowflake Cortex AI*")

# Get the active Snowflake session
# (the single session object reused for the stage listing, the file download
# and both Cortex queries in the rest of this script)
session = get_active_session()

# --- Functions ---

def get_staged_pdf_list(_sess, stage_name="ADVANCED_ANALYTICS.REDACT_PDF_DEMO.ORIG_PDFS"):
    """
    Refreshes a stage's directory table and returns the PDF files it lists.

    Args:
        _sess: Session object exposing ``sql(query)``; the result of that call
            must support ``collect()`` and ``to_pandas()``.
        stage_name: Fully qualified stage name, without the leading ``@``.
            Defaults to the demo stage so existing callers are unaffected.
            The value is interpolated into SQL as an identifier, so it must
            come from trusted configuration, never from user input.

    Returns:
        Whatever ``to_pandas()`` yields for the listing query: one
        ``RELATIVE_PATH`` column, filtered to names ending in ``.pdf``
        (case-insensitive) and sorted ascending.
    """
    # Refresh first so files uploaded since the last run appear in DIRECTORY().
    _sess.sql(f"ALTER STAGE {stage_name} REFRESH").collect()
    query = f"""
        SELECT RELATIVE_PATH
        FROM DIRECTORY(@{stage_name})
        WHERE RELATIVE_PATH ILIKE '%.pdf'
        ORDER BY RELATIVE_PATH ASC
    """
    return _sess.sql(query).to_pandas()

# Punctuation that surrounding prose commonly glues onto a word ("Smith," or
# "(555)"); it is ignored at token edges when comparing PII with page words.
_EDGE_PUNCTUATION = ".,;:!?\"'()[]{}<>"


def _normalize_token(token):
    """Lower-case a token and drop edge punctuation (punctuation-only tokens stay as-is)."""
    lowered = token.lower()
    return lowered.strip(_EDGE_PUNCTUATION) or lowered


def _rects_by_line(matched_words):
    """Merge the boxes of consecutive matched words into one rectangle per text line."""
    line_rects = {}
    for word in matched_words:
        # Word tuples are (x0, y0, x1, y1, text, block_no, line_no, word_no).
        line_key = (word[5], word[6])
        rect = fitz.Rect(word[:4])
        if line_key in line_rects:
            line_rects[line_key] = line_rects[line_key] | rect
        else:
            line_rects[line_key] = rect
    return list(line_rects.values())


def redact_pii_from_pdf(pdf_path, pii_list):
    """
    Redacts PII from a PDF by finding the coordinates of each word, matching them against
    the PII list, and then redacting the identified phrases.

    Matching is case-insensitive, ignores punctuation attached to the edges of a
    word, and treats any run of whitespace inside a PII phrase as a single
    separator. A phrase that wraps over several lines is covered with one
    rectangle per line, so text that merely sits between the line fragments is
    left untouched.

    Args:
        pdf_path: Path of the PDF file to open.
        pii_list: Iterable of PII phrases. Entries that are not strings, or that
            are blank, are ignored.

    Returns:
        A ``(pdf_bytes, total_redactions)`` tuple: the serialized document
        (unchanged when nothing matched) and the number of phrase occurrences
        that were redacted.
    """
    # Each phrase becomes a tuple of normalised tokens; the set de-duplicates
    # phrases that differ only in case, spacing or edge punctuation.
    pii_phrases = set()
    for item in pii_list:
        if not isinstance(item, str):
            continue
        tokens = tuple(_normalize_token(part) for part in item.split())
        if tokens:
            pii_phrases.add(tokens)

    # The context manager closes the document on every exit path, including
    # the early return below and any exception raised while redacting.
    with fitz.open(pdf_path) as doc:
        if not pii_phrases:
            st.warning("⚠️ PII list is empty, no redactions to perform.")
            return doc.tobytes(), 0

        st.info(f"🔍 Searching for {len(pii_phrases)} unique PII elements across {len(doc)} pages...")

        total_redactions = 0
        for page_num, page in enumerate(doc):
            # Get all words on the page with their coordinates
            words = page.get_text("words")
            if not words:
                continue

            # Normalise the page once instead of once per phrase and window.
            page_tokens = [_normalize_token(word[4]) for word in words]
            redactions_on_page = 0

            for phrase in pii_phrases:
                span = len(phrase)
                # Slide a window of the phrase's length over the page tokens.
                for start in range(len(page_tokens) - span + 1):
                    if tuple(page_tokens[start:start + span]) != phrase:
                        continue
                    for rect in _rects_by_line(words[start:start + span]):
                        page.add_redact_annot(rect, fill=(0, 0, 0))
                    redactions_on_page += 1

            if redactions_on_page > 0:
                st.write(f"✅ Marked {redactions_on_page} redactions on page {page_num + 1}.")
                # Apply redactions, which also removes the underlying text
                page.apply_redactions(images=fitz.PDF_REDACT_IMAGE_PIXELS)
                total_redactions += redactions_on_page

        return doc.tobytes(), total_redactions

# --- Main App Logic ---

# Streamlit re-runs this whole script on every interaction. Results of each
# step are therefore kept in st.session_state and re-displayed on each run.
try:
    # 1️⃣ PDF Selection
    st.header("1️⃣ Select PDF Document")
    df_files = get_staged_pdf_list(session)
    selected_pdf = st.selectbox(
        'Choose a PDF to analyze:',
        df_files['RELATIVE_PATH'],
        help="Select a PDF from your Snowflake stage."
    )

    if not selected_pdf:
        # Nothing selectable (e.g. the stage holds no PDFs): say so instead of
        # leaving the rest of the page blank.
        st.info("ℹ️ No PDF files were found in the stage. Upload a PDF and rerun the app.")
    else:
        # Results stored for a previously selected PDF must not leak into the
        # newly selected one, so drop them when the selection changes.
        if 'selected_pdf' in st.session_state and st.session_state.get('selected_pdf') != selected_pdf:
            keys_to_clear = ['extracted_content', 'pii_list', 'redacted_pdf_bytes', 'total_redactions', 'local_file_path']
            for key in keys_to_clear:
                if key in st.session_state:
                    del st.session_state[key]

        # Download file only if it's not already downloaded for the current session
        if 'local_file_path' not in st.session_state or st.session_state.get('selected_pdf') != selected_pdf:
            st.session_state['selected_pdf'] = selected_pdf
            stage_path = f"@ADVANCED_ANALYTICS.REDACT_PDF_DEMO.ORIG_PDFS/{selected_pdf}"
            temp_dir = "/tmp"
            session.file.get(stage_path, temp_dir)
            # The local copy is addressed by the file's base name inside temp_dir.
            local_file_path = os.path.join(temp_dir, os.path.basename(selected_pdf))
            st.session_state['local_file_path'] = local_file_path

        # Display PDF preview (first page only)
        pdf_document = pdfium.PdfDocument(st.session_state['local_file_path'])
        first_page = pdf_document[0]
        pil_image = first_page.render(scale=1.5).to_pil()
        st.image(pil_image, caption=f"Preview: {selected_pdf}", use_column_width=True)

        st.divider()

        # 2️⃣ Text Extraction
        st.header("2️⃣ Extract Document Text")
        if st.button(f"📝 Extract Text from {selected_pdf}", key="parse_button"):
            with st.spinner("Extracting text from document..."):
                try:
                    # Single quotes in the file name are doubled so the name
                    # stays inside its SQL string literal.
                    parse_query = f"""
                        SELECT snowflake.cortex.parse_document(
                            @ADVANCED_ANALYTICS.REDACT_PDF_DEMO.ORIG_PDFS,
                            '{selected_pdf.replace("'", "''")}'
                        )
                    """
                    parsed_result = session.sql(parse_query).collect()

                    if parsed_result:
                        json_string = parsed_result[0][0]
                        data = json.loads(json_string)
                        # Newlines are flattened to spaces before the text is stored.
                        content = data.get("content", "No content found.").replace('\n', ' ')
                        st.session_state['extracted_content'] = content
                except Exception as e:
                    st.error("❌ Failed to extract text from document.")
                    st.exception(e)

        # Always display extracted text if it exists, so it survives reruns
        # triggered by the other buttons.
        if 'extracted_content' in st.session_state:
            st.success("✅ Text extraction completed!")
            st.text_area("Extracted Text:", st.session_state.get('extracted_content'), height=300)

        st.divider()

        # 3️⃣ PII Detection
        st.header("3️⃣ Detect PII Elements")
        if st.button("Find PII in Extracted Text", key="pii_button"):
            if 'extracted_content' in st.session_state and st.session_state['extracted_content']:
                with st.spinner("Analyzing text for PII..."):
                    try:
                        prompt = f"""
                            Extract all PII from the text. PII includes names, phone numbers, emails, addresses, and unique URLs.
                            List each item. If none, state 'No PII detected'. Text: {st.session_state['extracted_content']}
                        """
                        # The prompt embeds untrusted document text, so it is
                        # passed as a bind parameter rather than spliced into
                        # the statement: a "$$" or quote inside the PDF text
                        # can no longer break or alter the SQL.
                        pii_query = """
                            SELECT AI_COMPLETE(model => 'claude-3-5-sonnet', prompt => ?,
                                response_format => {'type': 'json', 'schema': {'type': 'object',
                                'properties': {'pii_list': {'type': 'array', 'items': {'type': 'string'}}}}}) as pii_results
                        """
                        pii_results_df = session.sql(pii_query, params=[prompt]).to_pandas()
                        pii_output = pii_results_df['PII_RESULTS'][0]
                        pii_data = json.loads(pii_output)
                        st.session_state['pii_list'] = pii_data.get('pii_list', [])
                    except Exception as e:
                        st.error("❌ Failed to analyze PII.")
                        st.exception(e)
            else:
                st.warning("⚠️ Please extract text first before analyzing for PII.")

        # Always display PII list if it exists
        if 'pii_list' in st.session_state:
            st.success("PII Analysis Complete")
            pii_list = st.session_state.get('pii_list')
            if pii_list:
                st.markdown("### Found PII elements:")
                st.json(pii_list)
            else:
                st.info("✅ No PII elements were found in the text.")

        st.divider()

        # 4️⃣ PII Redaction
        st.header("4️⃣ Redact PII from PDF")
        if st.button("🖤 Redact PII", key="redact_button"):
            if 'pii_list' in st.session_state and 'local_file_path' in st.session_state:
                with st.spinner("Redacting PII from PDF... This may take a moment."):
                    try:
                        redacted_pdf_bytes, total_redactions = redact_pii_from_pdf(
                            st.session_state['local_file_path'],
                            st.session_state.get('pii_list', [])
                        )
                        st.session_state['redacted_pdf_bytes'] = redacted_pdf_bytes
                        st.session_state['total_redactions'] = total_redactions
                    except Exception as e:
                        st.error("❌ Failed to redact PDF.")
                        st.exception(e)
            else:
                st.warning("⚠️ Please detect PII first before redacting.")

        # Always display redacted PDF if it exists
        if 'redacted_pdf_bytes' in st.session_state:
            total_redactions = st.session_state.get('total_redactions', 0)
            redacted_pdf_bytes = st.session_state.get('redacted_pdf_bytes')
            if total_redactions > 0:
                st.success(f"✅ Successfully applied {total_redactions} redactions!")
                st.markdown("### Preview of Redacted PDF")
                redacted_pdf_stream = io.BytesIO(redacted_pdf_bytes)
                pdf_document_redacted = pdfium.PdfDocument(redacted_pdf_stream)
                first_page_redacted = pdf_document_redacted[0]
                pil_image_redacted = first_page_redacted.render(scale=1.5).to_pil()
                st.image(pil_image_redacted, caption=f"Preview: Redacted {selected_pdf}", use_column_width=True)
            else:
                st.warning("⚠️ Unable to redact the PDF - Most likely it is not machine readable.")

            st.download_button(
                label="📥 Download Redacted PDF",
                data=redacted_pdf_bytes,
                # selected_pdf is a stage-relative path and may contain "/";
                # only its base name is usable as a download file name.
                file_name=f"redacted_{os.path.basename(selected_pdf)}",
                mime="application/pdf"
            )

# Top-level boundary: report any otherwise-unhandled failure in the page.
except Exception as e:
    st.error("❌ An application error occurred. Please check the configuration and stage settings.")
    st.exception(e)