In [None]:
import os
import json
import asyncio
import uuid
from pathlib import Path
from datetime import datetime

import nats
from nats.js.api import StreamConfig, ConsumerConfig, AckPolicy, RetentionPolicy, DiscardPolicy
from dotenv import load_dotenv
from IPython.display import display, JSON, Image as IPImage
import matplotlib.pyplot as plt

# Pull settings from keys/.env into the process environment
load_dotenv(os.path.join("keys", ".env"))

# NATS endpoint and the stream that carries the processing results
NAT_URL = os.getenv("NAT_URL", "nats://localhost:4222")
OUTPUT_STREAM = os.getenv("OUTPUT_STREAM", "IMAGE-RESULTS")

# One result subject per processing mode
OUTPUT_SUBJECT_TAGGER = os.getenv("OUTPUT_SUBJECT_TAGGER", "tagger.results.completed")
OUTPUT_SUBJECT_CAPTIONING = os.getenv("OUTPUT_SUBJECT_CAPTIONING", "caption.results.completed")

LOCAL_ENV = os.getenv("LOCAL_ENV", "1")

# Result files live below OUTPUT_DIR, one sub-folder per mode
OUTPUT_DIR = "output_images"
OUTPUT_DIR_TAGGING = os.path.join(OUTPUT_DIR, "tagging")
OUTPUT_DIR_CAPTIONING = os.path.join(OUTPUT_DIR, "captioning")

# Make sure every output folder exists before anything tries to write to it
Path(OUTPUT_DIR).mkdir(parents=True, exist_ok=True)
Path(OUTPUT_DIR_TAGGING).mkdir(parents=True, exist_ok=True)
Path(OUTPUT_DIR_CAPTIONING).mkdir(parents=True, exist_ok=True)

# Echo the effective configuration so the run is self-describing
print("\n".join([
    "NATS Image Consumer Configuration:",
    f"NATS URL: {NAT_URL}",
    f"Output Stream: {OUTPUT_STREAM}",
    f"Tagging Subject: {OUTPUT_SUBJECT_TAGGER}",
    f"Captioning Subject: {OUTPUT_SUBJECT_CAPTIONING}",
    f"Local Environment: {LOCAL_ENV}",
    f"Output Directory: {OUTPUT_DIR}",
    f"  - Tagging Results: {OUTPUT_DIR_TAGGING}",
    f"  - Captioning Results: {OUTPUT_DIR_CAPTIONING}",
]))

In [None]:
def _detect_result_mode(documents):
    """Guess the result type from the first document's ``source.content``.

    Returns "tagging" when the first content entry is a dict carrying both
    'label' and 'confidence', "captioning" when it is a string, and
    "unknown" for anything else (missing/None source, empty or non-list
    content, non-dict document).
    """
    first_document = documents[0]
    source = first_document.get("source") if isinstance(first_document, dict) else None
    content = source.get("content") if isinstance(source, dict) else None

    if not isinstance(content, list) or not content:
        return "unknown"

    sample = content[0]
    if isinstance(sample, dict) and "label" in sample and "confidence" in sample:
        return "tagging"
    if isinstance(sample, str):
        return "captioning"
    return "unknown"


def analyze_image_results(data, mode="auto"):
    """
    Analyze image processing results from a single result

    Args:
        data: The result data to analyze; the documents are read from its
            "data" key
        mode: Processing mode - "tagging", "captioning", or "auto" (detect
            from the first document's content)

    Returns:
        None. The analysis is printed.
    """
    if not data:
        print("No data to analyze")
        return

    # Extract the data section from the result; a payload that is not a
    # dict (or whose "data" is not a list) is treated as having no documents
    documents = data.get("data", []) if isinstance(data, dict) else []
    if not isinstance(documents, list) or not documents:
        print("No documents found in result")
        return

    # Resolve "auto" to a concrete mode. If detection cannot classify the
    # payload the mode becomes "unknown" rather than staying "auto", so the
    # banner below matches the "cannot analyze" outcome.
    if mode == "auto":
        mode = _detect_result_mode(documents)

    print(f"Analyzing results in {mode.upper()} mode")

    # Analyze based on detected or specified mode
    if mode == "tagging":
        analyze_tagging_results(documents)
    elif mode == "captioning":
        analyze_captioning_results(documents)
    else:
        print("Unknown result format - cannot analyze")


def analyze_tagging_results(documents):
    """Print every document's tags, then a summary bucketed by confidence.

    Each document's tags are read from ``source.content``; every tag is
    expected to offer ``label`` and ``confidence``. Confidences are rounded
    to one decimal to form the buckets, and only buckets at 0.5 or above
    list their labels in the summary.
    """
    labels_per_bucket = {}
    tag_total = 0
    document_total = len(documents)

    for doc in documents:
        # Source details shown in the per-document header
        source = doc.get("source", {})
        file_name = source.get("file_name", "unknown")

        print(f"\nDocument: {doc.get('id', 'unknown')}")
        print(f"File: {file_name}")

        doc_tags = source.get("content", [])
        tag_total += len(doc_tags)

        if not doc_tags:
            print("No tags found for this document")
            continue

        print("\nTags:")
        for entry in doc_tags:
            name = entry.get("label", "unknown")
            score = entry.get("confidence", 0)

            # Bucket key: confidence rounded to one decimal place
            bucket = round(score * 10) / 10
            labels_per_bucket.setdefault(bucket, []).append(name)

            print(f"  - {name}: {score:.4f}")

    print(f"\nSummary:")
    print(f"Total documents analyzed: {document_total}")
    print(f"Total tags found: {tag_total}")
    print("\nTags by confidence level:")
    # Highest-confidence bucket first
    for bucket in sorted(labels_per_bucket, reverse=True):
        names = labels_per_bucket[bucket]
        print(f"  Confidence {bucket:.1f}: {len(names)} tags")
        if bucket < 0.5:
            # Low-confidence buckets only report their size
            continue
        joined = ", ".join(names)
        print(f"    Tags: {joined}")


def analyze_captioning_results(documents):
    """Print the captions found in each document, followed by a short summary.

    Captions are read from each document's ``source.content`` and printed
    one per line.
    """
    document_total = len(documents)

    print(f"\nCAPTIONING RESULTS")
    print(f"Total documents analyzed: {document_total}")

    for doc in documents:
        # Source details shown in the per-document header
        source = doc.get("source", {})
        file_name = source.get("file_name", "unknown")

        print(f"\nDocument: {doc.get('id', 'unknown')}")
        print(f"File: {file_name}")

        caption_lines = source.get("content", [])
        if not caption_lines:
            print("No caption found for this document")
            continue

        print("\nCaption:")
        for line in caption_lines:
            print(f"  {line}")

    print(f"\nSummary:")
    print(f"Processed {document_total} image(s) with the captioning model")


def visualize_top_tags(data, max_tags=10):
    """Draw a horizontal bar chart of the highest-confidence tags in a result.

    Args:
        data: Result payload; documents are read from its "data" key and
            tags from each document's ``source.content``.
        max_tags: Maximum number of tags to plot (values below 0 are
            treated as 0).

    Returns:
        None. Prints a short notice instead of plotting when the payload
        holds no tagging data.

    Only the first document decides whether the payload is tagging data;
    entries in later documents that are not tag dicts are skipped instead
    of aborting the whole chart.
    """
    if not data:
        return

    # Extract the data section from the result
    documents = data.get("data", []) if isinstance(data, dict) else []
    if not isinstance(documents, list) or not documents:
        return

    first_doc = documents[0]
    if not isinstance(first_doc, dict) or "source" not in first_doc:
        print("No tagging data to visualize")
        return

    # The first document's content tells tagging and captioning payloads apart
    first_source = first_doc.get("source")
    content = first_source.get("content", []) if isinstance(first_source, dict) else []
    if not content or not isinstance(content, list) or not isinstance(content[0], dict) or "label" not in content[0]:
        print("This appears to be captioning data, not tagging data - no visualization available")
        return

    # Keep the highest confidence seen for every label
    best_confidence = {}
    for document in documents:
        source = document.get("source") if isinstance(document, dict) else None
        tags = source.get("content") if isinstance(source, dict) else None
        if not isinstance(tags, list):
            continue

        for tag in tags:
            if not isinstance(tag, dict):
                # e.g. a caption string mixed into a tagging payload
                continue
            label = tag.get("label", "unknown")
            confidence = tag.get("confidence", 0)
            if not isinstance(confidence, (int, float)):
                # A null/non-numeric confidence cannot be ranked; count it as 0
                confidence = 0
            best_confidence[label] = max(best_confidence.get(label, confidence), confidence)

    # Sort by confidence and get top tags
    ranked = sorted(best_confidence.items(), key=lambda item: item[1], reverse=True)
    top_tags = ranked[:max(max_tags, 0)]

    if not top_tags:
        print("No tags to visualize")
        return

    labels = [label for label, _ in top_tags]
    confidences = [confidence for _, confidence in top_tags]

    # Explicit figure/axes instead of the pyplot state machine
    fig, ax = plt.subplots(figsize=(10, 6))
    ax.barh(labels, confidences, color='skyblue')
    ax.set_xlabel('Confidence')
    ax.set_title('Top Tags by Confidence')
    ax.set_xlim(0, 1.0)
    ax.invert_yaxis()  # Highest confidence at the top
    fig.tight_layout()
    plt.show()


In [None]:
def _subjects_for_mode(mode):
    """Return the result subjects covered by ``mode`` (empty list if the mode is not recognised)."""
    subjects = []
    if mode in ("tagging", "both"):
        subjects.append(OUTPUT_SUBJECT_TAGGER)
    if mode in ("captioning", "both"):
        subjects.append(OUTPUT_SUBJECT_CAPTIONING)
    return subjects


def _prepare_result_dirs(output_dir):
    """Create ``output_dir/tagging`` and ``output_dir/captioning`` and return them keyed by mode."""
    result_dirs = {
        "tagging": os.path.join(output_dir, "tagging"),
        "captioning": os.path.join(output_dir, "captioning"),
    }
    for path in result_dirs.values():
        os.makedirs(path, exist_ok=True)
    return result_dirs


def _result_filename(headers, seq):
    """Build the JSON file name for a message from its "filename" header, or from a timestamp plus ``seq``."""
    if headers and "filename" in headers:
        # The header value comes from the message itself; keep only the last
        # path component so something like "../../x" cannot be written
        # outside the result directory.
        safe_name = os.path.basename(str(headers["filename"]).replace("\\", "/"))
        if safe_name:
            return f"result_{safe_name}.json"

    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    return f"result_{timestamp}_{seq}.json"


async def _process_sequence(js, stream_name, seq, subjects_to_monitor, result_dirs, verb, messages_processed):
    """Fetch the message at ``seq``, save it as JSON and print its analysis.

    Messages whose subject is not in ``subjects_to_monitor`` are skipped.
    Any error is printed and swallowed so one bad message does not stop the
    batch. ``verb`` is the leading word of the progress line. Returns the
    processed-message count, incremented once the file has been written.
    """
    try:
        msg = await js.get_msg(stream_name, seq)

        if msg.subject not in subjects_to_monitor:
            print(f"Skipping message with subject {msg.subject}")
            return messages_processed

        # The subject decides both the analysis mode and the target folder
        msg_mode = "tagging" if msg.subject == OUTPUT_SUBJECT_TAGGER else "captioning"
        result_dir = result_dirs[msg_mode]

        data = json.loads(msg.data.decode("utf-8"))

        output_path = os.path.join(result_dir, _result_filename(getattr(msg, "headers", None), seq))
        with open(output_path, "w") as f:
            json.dump(data, f, indent=2)

        # Counted as processed as soon as the file exists, even if the
        # analysis below fails
        messages_processed += 1
        print(f"{verb} {msg_mode} message #{messages_processed}, seq={seq}, saved to {output_path}")

        print(f"\n{msg_mode.upper()} Analysis:")
        analyze_image_results(data, mode=msg_mode)

        if msg_mode == "tagging":
            print("\nVisualizing top tags:")
            visualize_top_tags(data)

        print("-" * 50)

    except Exception as e:
        print(f"Error processing message with seq={seq}: {str(e)}")

    return messages_processed


async def monitor_stream_for_new_messages(
    mode="both",
    output_dir=OUTPUT_DIR,
    nats_url=NAT_URL,
    stream_name=OUTPUT_STREAM,
    run_time_seconds=3600,  # Default to run for 1 hour
    poll_interval=1  # Seconds between polling the stream
):
    """
    Monitor a NATS stream for new messages based on the last seen sequence.

    Messages already in the stream when monitoring starts are ignored. Each
    new message on a monitored subject is saved to
    ``<output_dir>/tagging`` or ``<output_dir>/captioning`` and analyzed.

    Args:
        mode: Which mode to monitor - "tagging", "captioning", or "both"
        output_dir: Directory to save output files (per-mode sub-folders
            are created inside it)
        nats_url: The NATS server URL
        stream_name: The stream name to read from
        run_time_seconds: How long to run the consumer (in seconds)
        poll_interval: Seconds between polling the stream

    Returns:
        Number of messages saved (0 for an invalid mode or if the stream
        info cannot be read).
    """
    subjects_to_monitor = _subjects_for_mode(mode)
    if not subjects_to_monitor:
        print(f"Invalid mode: {mode}. Please use 'tagging', 'captioning', or 'both'")
        return 0

    print(f"Monitoring {mode} mode with subjects: {subjects_to_monitor}")

    # Results are written below the caller's output_dir
    result_dirs = _prepare_result_dirs(output_dir)

    # Connect to NATS
    nc = await nats.connect(nats_url)
    js = nc.jetstream()

    messages_processed = 0
    start_time = datetime.now()

    try:
        # Get stream info to find the current last sequence
        try:
            stream_info = await js.stream_info(stream_name)
            last_seq = stream_info.state.last_seq
            print(f"Connected to stream '{stream_name}'")
            print(f"Stream contains {stream_info.state.messages} messages")
            print(f"Last sequence is {last_seq}")
        except Exception as e:
            print(f"Error getting stream info: {e}")
            return messages_processed

        print(f"Stream monitor will run for {run_time_seconds} seconds or until interrupted")
        print(f"Results will be saved to {output_dir}")
        print("Monitoring for new messages...")

        # Only sequences after the current end of the stream are read
        current_seq = last_seq + 1
        # The idle notice is printed once per quiet period instead of on
        # every poll, which would flood the cell output
        idle_announced = False

        while (datetime.now() - start_time).total_seconds() < run_time_seconds:
            try:
                stream_info = await js.stream_info(stream_name)
                new_last_seq = stream_info.state.last_seq

                if new_last_seq >= current_seq:
                    idle_announced = False
                    print(f"Found {new_last_seq - current_seq + 1} new messages")

                    for seq in range(current_seq, new_last_seq + 1):
                        messages_processed = await _process_sequence(
                            js, stream_name, seq, subjects_to_monitor,
                            result_dirs, "Received", messages_processed,
                        )

                    current_seq = new_last_seq + 1
                elif not idle_announced:
                    print("No new messages found. Waiting...")
                    idle_announced = True

                # Wait before polling again
                await asyncio.sleep(poll_interval)

            except Exception as e:
                print(f"Error polling stream: {str(e)}")
                await asyncio.sleep(poll_interval)

    except KeyboardInterrupt:
        print("\nReceived interrupt signal, shutting down...")
    except Exception as e:
        print(f"Error in monitor: {str(e)}")
    finally:
        # Close NATS connection
        await nc.close()
        print("NATS connection closed")

    # Print summary
    run_time = (datetime.now() - start_time).total_seconds()
    print(f"\nMonitor summary:")
    print(f"- Run time: {run_time:.2f} seconds")
    print(f"- Messages processed: {messages_processed}")
    print(f"- Results saved to: {output_dir}")

    return messages_processed


# Alternative approach: Process existing messages and then monitor for new ones
async def process_and_monitor_stream(
    mode="both",
    output_dir=OUTPUT_DIR,
    nats_url=NAT_URL,
    stream_name=OUTPUT_STREAM,
    run_time_seconds=3600,  # Default to run for 1 hour
    poll_interval=1,  # Seconds between polling the stream
    start_from_beginning=False  # Whether to process all existing messages
):
    """
    Process existing messages in a stream and then monitor for new ones.

    With ``start_from_beginning=True`` (and a non-empty stream) reading
    starts at the stream's first sequence; otherwise only messages
    published after the call starts are read. Each message on a monitored
    subject is saved to ``<output_dir>/tagging`` or
    ``<output_dir>/captioning`` and analyzed.

    Args:
        mode: Which mode to monitor - "tagging", "captioning", or "both"
        output_dir: Directory to save output files (per-mode sub-folders
            are created inside it)
        nats_url: The NATS server URL
        stream_name: The stream name to read from
        run_time_seconds: How long to run the consumer (in seconds)
        poll_interval: Seconds between polling the stream
        start_from_beginning: Whether to process all existing messages

    Returns:
        Number of messages saved (0 for an invalid mode or if the stream
        info cannot be read).
    """
    subjects_to_monitor = _subjects_for_mode(mode)
    if not subjects_to_monitor:
        print(f"Invalid mode: {mode}. Please use 'tagging', 'captioning', or 'both'")
        return 0

    print(f"Processing and monitoring {mode} mode with subjects: {subjects_to_monitor}")

    # Results are written below the caller's output_dir
    result_dirs = _prepare_result_dirs(output_dir)

    # Connect to NATS
    nc = await nats.connect(nats_url)
    js = nc.jetstream()

    messages_processed = 0
    start_time = datetime.now()

    try:
        # Get stream info
        try:
            stream_info = await js.stream_info(stream_name)
            total_messages = stream_info.state.messages
            first_seq = stream_info.state.first_seq
            last_seq = stream_info.state.last_seq

            print(f"Connected to stream '{stream_name}'")
            print(f"Stream contains {total_messages} messages")
            print(f"Sequence range: {first_seq} to {last_seq}")
        except Exception as e:
            print(f"Error getting stream info: {e}")
            return messages_processed

        # Determine starting sequence
        if start_from_beginning and total_messages > 0:
            current_seq = first_seq
            print(f"Will process all existing messages starting from sequence {current_seq}")
        else:
            current_seq = last_seq + 1
            print(f"Will only process new messages starting from sequence {current_seq}")

        print(f"Stream processor will run for {run_time_seconds} seconds or until interrupted")
        print(f"Results will be saved to {output_dir}")
        print("Processing messages...")

        # The idle notice is printed once per quiet period instead of on
        # every poll, which would flood the cell output
        idle_announced = False

        while (datetime.now() - start_time).total_seconds() < run_time_seconds:
            try:
                stream_info = await js.stream_info(stream_name)
                new_last_seq = stream_info.state.last_seq

                if new_last_seq >= current_seq:
                    idle_announced = False
                    if current_seq <= last_seq:
                        # Still inside the range that existed at start-up
                        print(f"Processing {min(new_last_seq, last_seq) - current_seq + 1} existing messages")
                    else:
                        print(f"Found {new_last_seq - current_seq + 1} new messages")

                    for seq in range(current_seq, new_last_seq + 1):
                        messages_processed = await _process_sequence(
                            js, stream_name, seq, subjects_to_monitor,
                            result_dirs, "Processed", messages_processed,
                        )

                    current_seq = new_last_seq + 1
                elif not idle_announced:
                    print("No new messages found. Waiting...")
                    idle_announced = True

                # Wait before polling again
                await asyncio.sleep(poll_interval)

            except Exception as e:
                print(f"Error polling stream: {str(e)}")
                await asyncio.sleep(poll_interval)

    except KeyboardInterrupt:
        print("\nReceived interrupt signal, shutting down...")
    except Exception as e:
        print(f"Error in processor: {str(e)}")
    finally:
        # Close NATS connection
        await nc.close()
        print("NATS connection closed")

    # Print summary
    run_time = (datetime.now() - start_time).total_seconds()
    print(f"\nProcessor summary:")
    print(f"- Run time: {run_time:.2f} seconds")
    print(f"- Messages processed: {messages_processed}")
    print(f"- Results saved to: {output_dir}")

    return messages_processed

In [None]:
# Example usage - choose one of these to run:

# Option 1: Monitor only tagging results
# await monitor_stream_for_new_messages(mode="tagging", run_time_seconds=3600)

# Option 2: Monitor only captioning results
# await monitor_stream_for_new_messages(mode="captioning", run_time_seconds=3600)

# Option 3: Replay the tagging results already in the stream
# (start_from_beginning=True), then keep polling for new ones for up to an hour
await process_and_monitor_stream(
    mode="tagging",
    run_time_seconds=3600,
    poll_interval=1,
    start_from_beginning=True
)

In [None]:
# Captioning only: watch for new captioning results for up to an hour
await monitor_stream_for_new_messages(mode="captioning", run_time_seconds=3600)

In [None]:
# Replay the captioning results already in the stream (start_from_beginning=True),
# then keep polling for new ones for up to an hour
await process_and_monitor_stream(
    mode="captioning",
    run_time_seconds=3600,
    poll_interval=1,
    start_from_beginning=True
)

In [None]:
# Option 1: Monitor only new messages
await monitor_stream_for_new_messages(
    output_dir=output_directory,
    run_time_seconds=run_time,
    poll_interval=poll_interval
)

In [None]:
#Option 2: Process queue from beginning
await process_and_monitor_stream(
    output_dir=output_directory,
    run_time_seconds=run_time,
    poll_interval=poll_interval,
    start_from_beginning=start_from_beginning
)