In [None]:
import os
import json
import asyncio
import uuid
from pathlib import Path
from datetime import datetime

import nats
from nats.js.api import StreamConfig, ConsumerConfig, AckPolicy, RetentionPolicy, DiscardPolicy
from dotenv import load_dotenv
from IPython.display import display, JSON

# Populate os.environ from keys/.env before any setting below is read.
load_dotenv(os.path.join("keys", ".env"))

# --- NATS / JetStream settings (each overridable through the environment) ---
NAT_URL = os.getenv("NAT_URL", "nats://localhost:4222")
OUTPUT_STREAM = os.getenv("OUTPUT_STREAM", "PII-RESULTS")
# A concrete subject rather than a '>' wildcard pattern.
# NOTE(review): OUTPUT_SUBJECT and LOCAL_ENV are only echoed below; nothing else in the
# visible notebook reads them — confirm they are still needed.
OUTPUT_SUBJECT = os.getenv("OUTPUT_SUBJECT", "pii.results.completed")
LOCAL_ENV = os.getenv("LOCAL_ENV", "1")

# --- Local output ---
# Every consumed result is written here as a JSON file.
OUTPUT_DIR = "output"
os.makedirs(OUTPUT_DIR, exist_ok=True)

# Echo the effective configuration so the cell output records what this run used.
print("\n".join([
    "NATS Consumer Configuration:",
    f"NATS URL: {NAT_URL}",
    f"Output Stream: {OUTPUT_STREAM}",
    f"Output Subject: {OUTPUT_SUBJECT}",
    f"Local Environment: {LOCAL_ENV}",
    f"Output Directory: {OUTPUT_DIR}",
]))

In [None]:
def analyze_pii_results(data):
    """Print a summary of the PII findings contained in one result payload.

    Args:
        data: Decoded JSON payload (a mapping). Documents are read from
            ``data["result"]["data"]``, and each document's findings from its
            ``"pii_items"`` list. Each item's ``"type"`` value is tallied;
            items without a type (or that are not objects) count as ``"unknown"``.
            JSON ``null`` in any of these positions is treated as "empty".

    Returns:
        ``None`` when there is nothing to analyze. Otherwise a dict with
        ``total_documents``, ``total_pii_items`` and ``pii_by_type``
        (type -> count, ordered by descending count). Callers that ignore
        the return value are unaffected.
    """
    if not data:
        print("No data to analyze")
        return None

    # `or {}` / `or []` also cover explicit JSON nulls: `.get(key, default)` alone
    # returns None for {"result": null} and the chained `.get` would then crash.
    result = data.get("result") or {}
    documents = result.get("data") or []

    if not documents:
        print("No documents found in result")
        return None

    total_documents = len(documents)
    total_pii_items = 0
    pii_by_type = {}

    for document in documents:
        # Non-object entries carry no readable findings; count them as documents only.
        pii_items = (document.get("pii_items") if isinstance(document, dict) else None) or []
        total_pii_items += len(pii_items)

        for item in pii_items:
            pii_type = item.get("type", "unknown") if isinstance(item, dict) else "unknown"
            pii_by_type[pii_type] = pii_by_type.get(pii_type, 0) + 1

    # sorted() is stable even with reverse=True, so equal counts keep first-seen order.
    ranked = sorted(pii_by_type.items(), key=lambda kv: kv[1], reverse=True)

    print(f"Total documents analyzed: {total_documents}")
    print(f"Total PII items found: {total_pii_items}")
    print("\nPII by type:")
    for pii_type, count in ranked:
        print(f"  - {pii_type}: {count}")

    return {
        "total_documents": total_documents,
        "total_pii_items": total_pii_items,
        "pii_by_type": dict(ranked),
    }

In [None]:
def _output_filename(msg, seq):
    """Return a bare file name (no directory part) for saving message `seq`.

    Uses the message's "filename" header when present. Header values are chosen by
    whoever published the message, so only the final path component is kept: a
    value like "../../x.json" or "/etc/x" can never write outside the output dir.
    Falls back to "result_<timestamp>_<seq>.json" when no usable name remains.
    """
    headers = getattr(msg, "headers", None)
    name = headers["filename"] if headers and "filename" in headers else ""
    # Normalise Windows separators too, so "..\\x" is also reduced to "x".
    name = os.path.basename(str(name).replace("\\", "/"))
    if name in ("", ".", ".."):
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        name = f"result_{timestamp}_{seq}.json"
    # NOTE(review): messages sharing the same "filename" header overwrite each other.
    return name


async def _fetch_and_save_message(js, stream_name, seq, output_dir):
    """Fetch message `seq` from the stream, save its JSON payload, return (data, path).

    Any error from fetching, decoding or writing propagates to the caller, which
    logs it and moves on to the next sequence.
    """
    msg = await js.get_msg(stream_name, seq)
    data = json.loads(msg.data.decode("utf-8"))
    output_path = os.path.join(output_dir, _output_filename(msg, seq))
    with open(output_path, "w") as f:
        json.dump(data, f, indent=2)
    return data, output_path


async def _run_stream_consumer(
    output_dir,
    nats_url,
    stream_name,
    run_time_seconds,
    poll_interval,
    start_from_beginning,
    role,
):
    """Shared implementation behind both public stream consumers.

    Connects to NATS, then polls `stream_name` every `poll_interval` seconds
    until `run_time_seconds` have elapsed. Every message fetched is saved as
    JSON under `output_dir` and summarised with `analyze_pii_results`.

    Args:
        start_from_beginning: Start at the stream's first sequence (if it holds
            any messages) instead of just after its current last sequence.
        role: "monitor" or "processor"; only changes the wording of log output.

    Returns:
        Number of messages successfully saved.
    """
    is_monitor = role == "monitor"
    os.makedirs(output_dir, exist_ok=True)

    nc = await nats.connect(nats_url)
    js = nc.jetstream()

    messages_processed = 0

    # loop.time() is monotonic, so wall-clock adjustments cannot stretch or cut
    # short the run the way datetime.now() differences can.
    loop = asyncio.get_running_loop()
    start_time = loop.time()
    deadline = start_time + run_time_seconds

    try:
        try:
            stream_info = await js.stream_info(stream_name)
            state = stream_info.state
            total_messages = state.messages
            first_seq = state.first_seq
            last_seq = state.last_seq

            print(f"Connected to stream '{stream_name}'")
            print(f"Stream contains {total_messages} messages")
            if is_monitor:
                print(f"Last sequence is {last_seq}")
            else:
                print(f"Sequence range: {first_seq} to {last_seq}")
        except Exception as e:
            print(f"Error getting stream info: {e}")
            return messages_processed

        if start_from_beginning and total_messages > 0:
            current_seq = first_seq
            print(f"Will process all existing messages starting from sequence {current_seq}")
        else:
            current_seq = last_seq + 1
            if not is_monitor:
                print(f"Will only process new messages starting from sequence {current_seq}")

        print(f"Stream {role} will run for {run_time_seconds} seconds or until interrupted")
        print(f"Results will be saved to {output_dir}")
        print("Monitoring for new messages..." if is_monitor else "Processing messages...")

        verb = "Received" if is_monitor else "Processed"
        idle_reported = False

        while loop.time() < deadline:
            try:
                new_last_seq = (await js.stream_info(stream_name)).state.last_seq

                if new_last_seq >= current_seq:
                    idle_reported = False
                    # Sequences up to the startup last_seq already existed when we started.
                    if current_seq <= last_seq:
                        print(f"Processing {min(new_last_seq, last_seq) - current_seq + 1} existing messages")
                    else:
                        print(f"Found {new_last_seq - current_seq + 1} new messages")

                    for seq in range(current_seq, new_last_seq + 1):
                        try:
                            data, output_path = await _fetch_and_save_message(
                                js, stream_name, seq, output_dir
                            )
                            messages_processed += 1
                            print(f"{verb} message #{messages_processed}, seq={seq}, saved to {output_path}")

                            print("\nPII Analysis:")
                            analyze_pii_results(data)
                            print("-" * 50)
                        except Exception as e:
                            # Best effort: a bad or deleted message must not stop the consumer.
                            print(f"Error processing message with seq={seq}: {str(e)}")

                    current_seq = new_last_seq + 1
                elif not idle_reported:
                    # Report idleness once per quiet period rather than on every poll,
                    # which would otherwise flood the cell output (~3600 lines/hour).
                    print("No new messages found. Waiting...")
                    idle_reported = True
            except Exception as e:
                print(f"Error polling stream: {str(e)}")

            # Never sleep past the deadline.
            await asyncio.sleep(max(0.0, min(poll_interval, deadline - loop.time())))

    except KeyboardInterrupt:
        print("\nReceived interrupt signal, shutting down...")
    except Exception as e:
        print(f"Error in {role}: {str(e)}")
    finally:
        await nc.close()
        print("NATS connection closed")

    run_time = loop.time() - start_time
    print(f"\n{'Monitor' if is_monitor else 'Processor'} summary:")
    print(f"- Run time: {run_time:.2f} seconds")
    print(f"- Messages processed: {messages_processed}")
    print(f"- Results saved to: {output_dir}")

    return messages_processed


async def monitor_stream_for_new_messages(
    output_dir=OUTPUT_DIR,
    nats_url=NAT_URL,
    stream_name=OUTPUT_STREAM,
    run_time_seconds=3600,  # Default to run for 1 hour
    poll_interval=1  # Seconds between polling the stream
):
    """
    Monitor a NATS stream for messages published after this call starts.

    Messages already in the stream are skipped; polling begins just after the
    stream's last sequence at startup.

    Args:
        output_dir: Directory to save output files
        nats_url: The NATS server URL
        stream_name: The stream name to read from
        run_time_seconds: How long to run the consumer (in seconds)
        poll_interval: Seconds between polling the stream

    Returns:
        Number of messages successfully saved.
    """
    return await _run_stream_consumer(
        output_dir,
        nats_url,
        stream_name,
        run_time_seconds,
        poll_interval,
        start_from_beginning=False,
        role="monitor",
    )


async def process_and_monitor_stream(
    output_dir=OUTPUT_DIR,
    nats_url=NAT_URL,
    stream_name=OUTPUT_STREAM,
    run_time_seconds=3600,  # Default to run for 1 hour
    poll_interval=1,  # Seconds between polling the stream
    start_from_beginning=False  # Whether to process all existing messages
):
    """
    Optionally replay the messages already in a stream, then monitor for new ones.

    Args:
        output_dir: Directory to save output files
        nats_url: The NATS server URL
        stream_name: The stream name to read from
        run_time_seconds: How long to run the consumer (in seconds)
        poll_interval: Seconds between polling the stream
        start_from_beginning: Whether to process all existing messages

    Returns:
        Number of messages successfully saved.
    """
    return await _run_stream_consumer(
        output_dir,
        nats_url,
        stream_name,
        run_time_seconds,
        poll_interval,
        start_from_beginning=start_from_beginning,
        role="processor",
    )

In [None]:
# Parameters shared by the consumer cells that follow.
# NOTE: each consumer cell blocks for up to `run_time` seconds, so running the
# whole notebook top to bottom takes roughly 2 * run_time.
run_time = 3600  # seconds (1 hour)

# Reuse the directory configured (and created) in the configuration cell
# instead of repeating the literal path here.
output_directory = OUTPUT_DIR

poll_interval = 1  # seconds between stream polls
# Read by the "process from beginning" cell: True replays every message already stored.
start_from_beginning = True

In [None]:
# monitor only new messages
messages_processed = await monitor_stream_for_new_messages(
    output_dir=output_directory,
    run_time_seconds=run_time,
    poll_interval=poll_interval
)

In [None]:
# Process queue from begining
messages_processed = await process_and_monitor_stream(
    output_dir=output_directory,
    run_time_seconds=run_time,
    poll_interval=poll_interval,
    start_from_beginning=start_from_beginning
)