In [None]:
import os
import json
import asyncio
import uuid
from pathlib import Path
from datetime import datetime

import nats
from nats.js.api import StreamConfig, RetentionPolicy, DiscardPolicy
from dotenv import load_dotenv
from IPython.display import display, JSON

# Pull settings from keys/.env into the process environment
dotenv_file = os.path.join("keys", ".env")
load_dotenv(dotenv_file)

# Where the NATS server lives and which stream receives the image tasks
NAT_URL = os.getenv("NAT_URL", "nats://localhost:4222")
INPUT_STREAM = os.getenv("INPUT_STREAM", "IMAGE-TASKS")

# Per-mode subjects - these have to agree with the values used in main.py
INPUT_SUBJECT_TAGGER = os.getenv("INPUT_SUBJECT_TAGGER", "tagger.tasks.started.>")
INPUT_SUBJECT_CAPTIONING = os.getenv("INPUT_SUBJECT_CAPTIONING", "caption.tasks.started.>")
LOCAL_ENV = os.getenv("LOCAL_ENV", "1")

# Echo the effective configuration so a reader can verify it at a glance
print("NATS Image Producer Configuration:")
print(
    "\n".join(
        (
            f"NATS URL: {NAT_URL}",
            f"Input Stream: {INPUT_STREAM}",
            f"Tagging Subject: {INPUT_SUBJECT_TAGGER}",
            f"Captioning Subject: {INPUT_SUBJECT_CAPTIONING}",
            f"Local Environment: {LOCAL_ENV}",
        )
    )
)

# Folder scanned for images; created up front so later cells can list it
INPUT_DIR = "images_local"
os.makedirs(INPUT_DIR, exist_ok=True)

In [None]:
async def ensure_stream_exists(
    nats_url,
    stream_name,
    subjects
):
    """Ensure that a JetStream stream exists and covers the given subjects.

    If the stream is missing it is created as a work-queue stream. If it
    already exists, any requested subjects it does not yet list are appended
    to its configuration. A dedicated connection is opened for the check and
    always closed before returning.

    Args:
        nats_url: URL of the NATS server to connect to.
        stream_name: Name of the stream to look up or create.
        subjects: Iterable of subject strings the stream must capture.

    Raises:
        Exception: Any connection or JetStream API error other than
            "stream not found" is propagated to the caller instead of being
            mistaken for a missing stream.
    """
    # Local import: keeps the fix self-contained while relying only on the
    # `nats` package this notebook already depends on.
    from nats.js.errors import NotFoundError

    requested_subjects = list(dict.fromkeys(subjects))  # de-duplicate, keep order

    nc = await nats.connect(nats_url)
    try:
        js = nc.jetstream()

        try:
            stream_info = await js.stream_info(stream_name)
        except NotFoundError:
            # Only a genuine "stream not found" reply triggers creation;
            # timeouts, auth failures etc. must surface as-is.
            stream_config = StreamConfig(
                name=stream_name,
                subjects=requested_subjects,
                retention=RetentionPolicy.WORK_QUEUE,
                max_age=30 * 24 * 60 * 60,  # 30 days
                duplicate_window=6 * 60,  # 6 minutes
                discard=DiscardPolicy.NEW,  # Discard new messages if the stream is full
            )

            await js.add_stream(config=stream_config)
            print(f"Stream '{stream_name}' created with subjects: {subjects}")
            return

        print(f"Stream '{stream_name}' already exists with subjects: {stream_info.config.subjects}")

        # Append whatever is missing, preserving the existing subject order so
        # the update is deterministic from run to run.
        current_subjects = list(stream_info.config.subjects or [])
        missing_subjects = [s for s in requested_subjects if s not in current_subjects]

        if missing_subjects:
            print(f"Adding missing subjects to stream: {missing_subjects}")

            updated_config = stream_info.config
            updated_config.subjects = current_subjects + missing_subjects
            # Deliberately outside the NotFoundError handler: a failed update
            # on an existing stream must not fall through to add_stream.
            await js.update_stream(config=updated_config)
    finally:
        await nc.close()

In [None]:
async def publish_images_to_nats(
    folder_path,
    nats_url=NAT_URL,
    input_stream=INPUT_STREAM,
    mode="tagging",
    local_env=LOCAL_ENV,
    num_labels=5,
    prompt="OD"
):
    """
    Publish all image files from a folder to the NATS input stream/subject.

    One JSON message is published per image file. Files are visited in sorted
    name order, and a failure on one file is reported and skipped so the
    remaining files are still published.

    Args:
        folder_path: Path to the folder containing image files to process
        nats_url: The NATS server URL
        input_stream: The input stream name
        mode: Processing mode ("tagging" or "captioning")
        local_env: Local environment flag; "1" (or 1) makes the message URI
            use the file's absolute path
        num_labels: Number of labels to return for each image (for tagging)
        prompt: Prompt for Florence-2 model ("OD" or "MORE_DETAILED_CAPTION")

    Returns:
        List of the filenames that were successfully published.

    Raises:
        FileNotFoundError: If folder_path does not exist.
        NotADirectoryError: If folder_path exists but is not a directory.
        ValueError: If mode is neither "tagging" nor "captioning".
    """
    # Validate arguments before any network connection is opened.
    if not os.path.exists(folder_path):
        raise FileNotFoundError(f"Folder not found: {folder_path}")
    if not os.path.isdir(folder_path):
        raise NotADirectoryError(f"Not a folder: {folder_path}")

    # Reject unknown modes instead of silently routing them to captioning.
    subjects_by_mode = {
        "tagging": INPUT_SUBJECT_TAGGER,
        "captioning": INPUT_SUBJECT_CAPTIONING,
    }
    if mode not in subjects_by_mode:
        raise ValueError(
            f"Unknown mode {mode!r}; expected one of {sorted(subjects_by_mode)}"
        )
    # NOTE(review): the default subjects end in the '>' wildcard and are
    # published to verbatim - confirm the consumer in main.py expects that.
    input_subject = subjects_by_mode[mode]

    print(f"Using mode: {mode}, subject: {input_subject}, prompt: {prompt}")

    # First ensure the stream exists with all subjects
    await ensure_stream_exists(
        nats_url=nats_url,
        stream_name=input_stream,
        subjects=[INPUT_SUBJECT_TAGGER, INPUT_SUBJECT_CAPTIONING]
    )

    image_extensions = ('.png', '.jpg', '.jpeg', '.bmp', '.gif', '.tiff')
    # Compare as text so both "1" (from the environment) and 1 are accepted.
    use_absolute_path = str(local_env) == "1"

    # Track files published
    files_published = []

    nc = await nats.connect(nats_url)
    try:
        js = nc.jetstream()

        # Sorted so the publish order is reproducible across runs/platforms.
        for filename in sorted(os.listdir(folder_path)):
            if not filename.lower().endswith(image_extensions):
                continue

            file_path = os.path.join(folder_path, filename)
            # Skip directories (or other non-files) that merely look like images.
            if not os.path.isfile(file_path):
                continue

            print(f"Processing image file: {filename}")

            try:
                if use_absolute_path:
                    # Using absolute path for file:// URLs
                    abs_path = os.path.abspath(file_path)
                    uri = f"file://{abs_path}"
                else:
                    # NOTE(review): outside local mode the path is used as
                    # given (possibly relative); in production this might
                    # need to be an HTTP URL instead.
                    uri = f"file://{file_path}"

                # Create message payload
                message = {
                    "source": {
                        "uri": uri,
                        "type": "image"
                    },
                    "state": {
                        "status": "STARTED",
                        "timestamp": datetime.now().isoformat()
                    },
                    "prompt": prompt
                }

                # Add num_labels only for tagging mode
                if mode == "tagging":
                    message["num_labels"] = num_labels

                # The filename travels as a message header
                headers = {
                    "filename": filename
                }

                await js.publish(
                    input_subject,
                    json.dumps(message).encode(),
                    headers=headers
                )

                print(f"Published file {filename} to {input_subject}")
                files_published.append(filename)

            except Exception as e:
                # Best effort: report the failure and carry on with the next file.
                print(f"Error processing file {filename}: {str(e)}")

    finally:
        # Close NATS connection
        await nc.close()

    return files_published

async def process_folder_with_mode(folder_path, mode="tagging", num_labels=5, prompt=None):
    """
    Publish every image in a folder to NATS for the given mode and print a summary.

    Args:
        folder_path: Path to folder containing images
        mode: "tagging" or "captioning"
        num_labels: Number of labels for tagging mode
        prompt: Prompt override; when None, "OD" is used for tagging and
            "MORE_DETAILED_CAPTION" for any other mode

    Returns:
        The list of filenames returned by publish_images_to_nats.
    """
    # No explicit prompt: fall back to the one that goes with the mode.
    if prompt is None:
        prompt = "OD" if mode == "tagging" else "MORE_DETAILED_CAPTION"

    print(f"Publishing image files from {folder_path} to NATS using {mode} mode...")
    published = await publish_images_to_nats(
        folder_path,
        mode=mode,
        num_labels=num_labels,
        prompt=prompt,
    )

    # Nothing went out: say so and hand back the (empty) result unchanged.
    if not published:
        print("No files were published to NATS")
        return published

    print(f"\nPublished {len(published)} files to NATS")
    print("Files published:")
    for name in published:
        print(f"- {name}")

    return published

In [None]:
# 1. For tagging mode (original behavior)
await process_folder_with_mode(INPUT_DIR, mode="tagging", num_labels=5, prompt="OD")

In [None]:
# captioning

await process_folder_with_mode(INPUT_DIR, mode="captioning", prompt="MORE_DETAILED_CAPTION")

In [None]:
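# Both modes in one call (disabled): process_folder_both_modes is not defined in
# the cells shown here, so run the tagging and captioning cells above instead.
#await process_folder_both_modes(INPUT_DIR, num_labels=5)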