# NSFW Analysis - Producer Notebook

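This notebook publishes NSFW analysis tasks to a NATS JetStream stream. Each task message references one image or video file by its `file://` URI so that a consumer can run NSFW content detection on it.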

In [None]:
import os
import json
import asyncio
import uuid
from pathlib import Path
from datetime import datetime

import nats
from nats.js.api import StreamConfig, RetentionPolicy, DiscardPolicy
from dotenv import load_dotenv
from IPython.display import display, JSON

In [None]:
# Read settings from keys/.env_file before any os.environ lookups below
load_dotenv(os.path.join("keys", ".env_file"))

# Folder scanned for images and videos; created up front so listing it never fails
INPUT_DIR = "images"
Path(INPUT_DIR).mkdir(parents=True, exist_ok=True)

# Address of the NATS server
NAT_URL = os.environ.get("NAT_URL", "nats://localhost:4222")

# JetStream stream and subject that carry NSFW tasks
INPUT_STREAM = os.environ.get("INPUT_STREAM", "NSFW-TASKS")
INPUT_SUBJECT = os.environ.get("INPUT_SUBJECT", "nsfw.tasks.started.>")

# Environment flag, kept as a string ("1" is treated as local by the publisher)
LOCAL_ENV = os.environ.get("LOCAL_ENV", "1")

# Echo the effective configuration
print("NATS NSFW Producer Configuration:")
print(f"NATS URL: {NAT_URL}")
print(f"NSFW Task Stream: {INPUT_STREAM}")
print(f"NSFW Task Subject: {INPUT_SUBJECT}")
print(f"Local Environment: {LOCAL_ENV}")
print(f"Input Directory: {INPUT_DIR}")

## Helper Functions

In [None]:
# Lower-case extensions (leading dot included) that are published as "image" tasks
image_file_extensions = [
    ".jpeg",
    ".jpg",
    ".tiff",
    ".png",
    ".tif",
    ".gif",
    ".bmp",
]

# Lower-case extensions (leading dot included) that are published as "video" tasks
video_file_extensions = [
    ".mp4",
    ".avi",
    ".mov",
    ".mkv",
    ".wmv",
    ".flv",
    ".webm",
]

In [None]:
async def ensure_stream_exists(nats_url, stream_name, subjects):
    """
    Ensure that the specified NATS stream exists with the given subjects.

    If the stream already exists, any requested subjects it does not list yet
    are added and the rest of its configuration is sent back unchanged. If the
    stream is missing, it is created as a work-queue stream. Any other failure
    (connection problems, rejected updates, ...) propagates to the caller
    instead of being mistaken for "stream not found".

    Args:
        nats_url: The NATS server URL
        stream_name: The name of the stream to look up or create
        subjects: List of subjects for the stream
    """
    # Local import: keeps this cell self-contained with the notebook's existing imports.
    from nats.js.errors import NotFoundError

    nc = await nats.connect(nats_url)
    js = nc.jetstream()

    try:
        # Only a "not found" answer means the stream has to be created.
        try:
            stream_info = await js.stream_info(stream_name)
        except NotFoundError:
            stream_info = None

        if stream_info is None:
            stream_config = StreamConfig(
                name=stream_name,
                subjects=subjects,
                retention=RetentionPolicy.WORK_QUEUE,  # Work queue mode
                max_age=30 * 24 * 60 * 60,  # 30 days retention
                duplicate_window=6 * 60,  # 6 minutes for duplicate detection
                discard=DiscardPolicy.NEW  # Discard new messages if the stream is full
            )

            await js.add_stream(config=stream_config)
            print(f"Created new stream '{stream_name}' with subjects: {subjects}")
            return

        print(f"Stream '{stream_name}' already exists with {stream_info.state.messages} messages")

        # Subjects are compared literally; wildcard coverage is not evaluated.
        existing_subjects = set(stream_info.config.subjects or [])
        requested_subjects = set(subjects)
        missing_subjects = requested_subjects - existing_subjects

        if missing_subjects:
            # Re-submit the stream's own configuration with the merged subject
            # list so retention, limits, etc. are not replaced by defaults.
            updated_config = stream_info.config
            updated_config.subjects = sorted(existing_subjects | requested_subjects)
            await js.update_stream(config=updated_config)
            print(f"Updated stream '{stream_name}' with additional subjects: {missing_subjects}")

    finally:
        await nc.close()

In [None]:
async def publish_files_to_nsfw(folder_path, nats_url=NAT_URL, local_env=LOCAL_ENV, threshold=0.5):
    """
    Publish image and video files from a folder to the NSFW NATS stream for analysis.

    One JSON task message is published per supported file (sub-folders and
    unsupported extensions are skipped). All messages from one call share the
    same batchId. A failure while publishing a single file is printed and the
    remaining files are still processed.

    Args:
        folder_path: Path to the folder containing files to process
        nats_url: The NATS server URL
        local_env: Local environment flag ("1" for local, "0" for production)
        threshold: Threshold for NSFW classification (0.0 to 1.0)

    Returns:
        List of dicts (filename, type, id, subject), one per published file

    Raises:
        FileNotFoundError: If folder_path does not exist
        NotADirectoryError: If folder_path exists but is not a folder
    """
    # Validate the input before opening any NATS connection
    if not os.path.exists(folder_path):
        raise FileNotFoundError(f"Folder not found: {folder_path}")
    if not os.path.isdir(folder_path):
        raise NotADirectoryError(f"Not a folder: {folder_path}")

    # Ensure the stream exists with the subject
    await ensure_stream_exists(
        nats_url=nats_url,
        stream_name=INPUT_STREAM,
        subjects=[INPUT_SUBJECT]
    )

    # Connect to NATS
    nc = await nats.connect(nats_url)
    js = nc.jetstream()

    # Track files published
    files_published = []

    # Computed once so every file from this call carries the same batch id,
    # even when publishing crosses a second boundary.
    batch_id = f"batch_{datetime.now().strftime('%Y%m%d_%H%M%S')}"

    try:
        # Process each file in the folder
        for filename in os.listdir(folder_path):
            file_path = os.path.join(folder_path, filename)

            # Skip if it's a directory
            if os.path.isdir(file_path):
                continue

            # Extension is lower-cased so matching is case-insensitive
            _, ext = os.path.splitext(filename.lower())

            if ext in image_file_extensions:
                file_type = "image"
            elif ext in video_file_extensions:
                file_type = "video"
            else:
                print(f"Skipping unsupported file type: {filename}")
                continue

            print(f"Processing {file_type} file: {filename}")

            try:
                # Create a file URI based on the environment
                if local_env == "1":
                    # Using absolute path for file:// URLs in local environment
                    abs_path = os.path.abspath(file_path)
                    uri = f"file://{abs_path}"
                else:
                    # In production, use the file path directly
                    uri = f"file://{file_path}"

                # Generate a unique ID for this request
                request_id = str(uuid.uuid4())

                # Create message payload for NSFW analysis
                message = {
                    "id": request_id,
                    "batchId": batch_id,
                    "source": {
                        "uri": uri,
                        "file_type": ext[1:],  # Remove the dot from extension
                        "file_name": filename
                    },
                    "options": {
                        "threshold": threshold
                    },
                    "state": {
                        "status": "STARTED",
                        "timestamp": datetime.now().isoformat()
                    }
                }

                # Create message headers
                headers = {
                    "filename": filename,
                    "file_type": file_type,
                    "content-type": "application/json"
                }

                # Images and videos go to separate subjects
                publish_subject = f"nsfw.tasks.started.{file_type}"

                # Publish message to NATS
                await js.publish(
                    publish_subject,
                    json.dumps(message).encode(),
                    headers=headers,
                    stream=INPUT_STREAM
                )

                print(f"Published {file_type} file {filename} to {publish_subject} in stream {INPUT_STREAM}")
                files_published.append({
                    "filename": filename,
                    "type": file_type,
                    "id": request_id,
                    "subject": publish_subject
                })

            except Exception as e:
                # Best effort: report the failure and continue with the next file
                print(f"Error processing file {filename}: {str(e)}")

    finally:
        # Close NATS connection
        await nc.close()

    return files_published

## Process Files in Input Directory

In [None]:
async def process_folder(folder_path=INPUT_DIR, threshold=0.5):
    """
    Publish every supported file in a folder for NSFW analysis and print a summary.

    Args:
        folder_path: Path to folder containing files
        threshold: Threshold for NSFW classification (0.0 to 1.0)

    Returns:
        The list of published-file records returned by publish_files_to_nsfw
    """
    print(f"Publishing files from {folder_path} to NSFW NATS endpoint...")

    published = await publish_files_to_nsfw(folder_path, threshold=threshold)

    # Nothing published: report it and hand back the (empty) result
    if not published:
        print("No files were published to NATS")
        return published

    print(f"\nPublished {len(published)} files to NATS")
    print("Files published:")

    def _report(heading, wanted_type):
        # Print the count for one file type followed by one line per file
        matching = [entry for entry in published if entry["type"] == wanted_type]
        print(f"{heading}: {len(matching)}")
        for entry in matching:
            print(f"- {entry['filename']} (ID: {entry['id']})")

    _report("Images", "image")
    _report("\nVideos", "video")

    return published

In [None]:
# Publish every supported image and video in INPUT_DIR with a threshold of 0.5.
# Top-level `await` is used because this statement runs as a notebook cell.
await process_folder(INPUT_DIR, threshold=0.5)

## Process with Different Thresholds

In [None]:
# You can also process with a different threshold
# await process_folder(INPUT_DIR, threshold=0.7)

## Process Specific File Types

In [None]:
async def process_files_by_type(folder_path=INPUT_DIR, file_type="image", threshold=0.5):
    """
    Process only specific file types in the folder

    Matching files are copied into a temporary sibling folder, that folder is
    published with publish_files_to_nsfw, and the temporary folder is removed
    again whether or not publishing succeeds.

    Args:
        folder_path: Path to folder containing files
        file_type: Type of files to process ("image" or "video")
        threshold: Threshold for NSFW classification (0.0 to 1.0)

    Returns:
        List of published-file records, or an empty list if no file matched

    Raises:
        ValueError: If file_type is neither "image" nor "video"
    """
    import shutil

    print(f"Publishing only {file_type} files from {folder_path} to NSFW NATS endpoint...")

    # Validate the type before touching the filesystem so a bad argument
    # does not leave an empty temporary folder behind.
    if file_type == "image":
        valid_extensions = image_file_extensions
    elif file_type == "video":
        valid_extensions = video_file_extensions
    else:
        raise ValueError(f"Invalid file type: {file_type}. Must be 'image' or 'video'.")

    # Create a temporary folder with only the specified file types
    temp_folder = f"{folder_path}_{file_type}_temp"
    os.makedirs(temp_folder, exist_ok=True)

    try:
        # Copy relevant files to temp folder
        copied_files = []

        for filename in os.listdir(folder_path):
            file_path = os.path.join(folder_path, filename)

            # Skip if it's a directory
            if os.path.isdir(file_path):
                continue

            # Check if file has the right extension (case-insensitive)
            _, ext = os.path.splitext(filename.lower())
            if ext in valid_extensions:
                shutil.copy2(file_path, os.path.join(temp_folder, filename))
                copied_files.append(filename)

        print(f"Copied {len(copied_files)} {file_type} files to temporary folder")

        if not copied_files:
            print(f"No {file_type} files found in {folder_path}")
            return []

        # NOTE(review): publish_files_to_nsfw builds each message URI from the
        # folder it is given, so the published URIs point into temp_folder,
        # which is deleted as soon as this function returns. Confirm consumers
        # do not need the file after that, or publish from folder_path instead.
        return await publish_files_to_nsfw(
            temp_folder,
            threshold=threshold
        )
    finally:
        # Clean up temporary folder on every exit path, including errors
        shutil.rmtree(temp_folder)

In [None]:
# Process only image files
# await process_files_by_type(INPUT_DIR, file_type="image", threshold=0.5)

In [None]:
# Process only video files
# await process_files_by_type(INPUT_DIR, file_type="video", threshold=0.5)