# In [None]:
import gradio as gr
from PIL import Image
import torch
import cv2
import numpy as np
from transformers import AutoImageProcessor, AutoModelForDepthEstimation
import tempfile
import os
from tqdm import tqdm
import matplotlib.pyplot as plt
import time

# Depth model state shared by every request; filled in lazily by load_model().
processor, model = None, None

# Dropdown label -> OpenCV colormap constant. Dict insertion order is the
# order the labels appear in the UI dropdowns.
COLORMAP_OPTIONS = {
    name: getattr(cv2, f"COLORMAP_{name}")
    for name in ("INFERNO", "JET", "VIRIDIS", "PLASMA", "MAGMA", "HOT")
}

# Dropdown label -> Hugging Face Hub repository id handed to from_pretrained().
MODEL_OPTIONS = {
    "Depth Anything Small": "nielsr/depth-anything-small",
    "Depth Anything Base": "nielsr/depth-anything-base",
}

def load_model(model_name):
    """Load the depth-estimation processor and model for ``model_name``.

    The loaded pair is stored in the module-level ``processor`` and ``model``
    globals (used by the processing functions) and also returned.

    Args:
        model_name: Key into ``MODEL_OPTIONS``; an unknown name raises
            ``KeyError``.

    Returns:
        ``(processor, model)`` tuple.
    """
    global processor, model
    repo_id = MODEL_OPTIONS[model_name]
    # Use the GPU when present, but fall back to CPU instead of crashing with
    # a CUDA error on machines without one.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    new_processor = AutoImageProcessor.from_pretrained(repo_id)
    new_model = AutoModelForDepthEstimation.from_pretrained(repo_id).to(device)
    new_model.eval()
    # Only publish once both loads succeeded, so a failed download leaves the
    # previously loaded processor/model pair intact instead of mismatched.
    processor, model = new_processor, new_model
    return processor, model

def create_depth_histogram(depth_map):
    """Render a 50-bin histogram of depth values to a temporary PNG file.

    Args:
        depth_map: Array-like of depth values; it is flattened, so any shape
            works.

    Returns:
        Path to the PNG file. The file is not deleted here; its lifetime is
        the caller's (or Gradio's) responsibility.
    """
    # Build the figure without pyplot: pyplot keeps global "current figure"
    # state that is not safe to share between concurrent Gradio requests, and
    # a Figure that is never registered with pyplot needs no explicit close.
    from matplotlib.figure import Figure

    fig = Figure(figsize=(6, 4))
    ax = fig.subplots()
    ax.hist(np.ravel(depth_map), bins=50, color='blue', alpha=0.7)
    ax.set_title('Depth Distribution')
    ax.set_xlabel('Depth Value')
    ax.set_ylabel('Frequency')
    ax.grid(True, alpha=0.3)

    # mkstemp creates the file atomically; tempfile.mktemp only returns a
    # name (deprecated: another process may claim it before we write).
    fd, temp_path = tempfile.mkstemp(suffix='.png')
    os.close(fd)
    try:
        fig.savefig(temp_path)
    except Exception:
        # Don't leave an empty temp file behind if rendering fails.
        os.remove(temp_path)
        raise

    return temp_path

def normalize_depth_to_uint8(depth_np):
    """Linearly rescale a depth array onto the full ``uint8`` range.

    The minimum maps to 0 and the maximum to 255; fractional results are
    truncated by ``astype``. A constant array (max == min) would otherwise
    divide by zero and produce NaNs, so it maps to all zeros instead.

    Args:
        depth_np: Numeric NumPy array of depth predictions.

    Returns:
        ``np.uint8`` array with the same shape as ``depth_np``.
    """
    d_min = depth_np.min()
    span = depth_np.max() - d_min
    if not (span > 0):
        return np.zeros(depth_np.shape, dtype=np.uint8)
    return ((depth_np - d_min) * 255 / span).astype(np.uint8)


def process_image(image, colormap="INFERNO", model_name="Depth Anything Small",
                  show_histogram=False, enhance_contrast=False):
    """Estimate depth for a single image and render it with a colormap.

    Args:
        image: PIL image, or a NumPy array accepted by ``Image.fromarray``.
        colormap: Key into ``COLORMAP_OPTIONS``.
        model_name: Key into ``MODEL_OPTIONS``. The model is (re)loaded when
            nothing is loaded yet or a different model is currently loaded.
        show_histogram: If true, also render a histogram of the depth map.
        enhance_contrast: If true, histogram-equalize the depth map before
            applying the colormap.

    Returns:
        dict with ``"processed_image"`` (colorized depth as a PIL RGB image,
        same size as the input), ``"original_image"`` (the RGB PIL input) and,
        only when ``show_histogram`` is true, ``"histogram"`` (PNG file path).

    Raises:
        ValueError: If ``image`` is None.
    """
    if image is None:
        raise ValueError("No image provided")

    # Reload when nothing is loaded or the user picked a different model.
    # This relies on transformers recording the repo id in
    # ``model.name_or_path``; if that attribute is absent we simply reload
    # every call (correct, just slower).
    repo_id = MODEL_OPTIONS[model_name]
    if (processor is None or model is None
            or getattr(model, "name_or_path", None) != repo_id):
        load_model(model_name)

    # Convert to PIL Image if necessary and force 3-channel RGB.
    if isinstance(image, np.ndarray):
        image = Image.fromarray(image)
    if image.mode != 'RGB':
        image = image.convert('RGB')

    # PIL reports (width, height); no need to copy pixels into an array.
    w, h = image.size

    # Move inputs to whichever device load_model() put the model on.
    pixel_values = processor(images=image, return_tensors="pt").pixel_values.to(model.device)

    with torch.no_grad():
        predicted_depth = model(pixel_values).predicted_depth

    # Upsample the prediction back to the input resolution. [None] adds a
    # leading axis for interpolate and [0, 0] strips the leading axes again
    # (assumes predicted_depth is (batch=1, H, W) — TODO confirm for other
    # checkpoints).
    depth = torch.nn.functional.interpolate(
        predicted_depth[None],
        (h, w),
        mode='bilinear',
        align_corners=False
    )[0, 0]

    linear_depth = normalize_depth_to_uint8(depth.cpu().numpy())
    display_depth = cv2.equalizeHist(linear_depth) if enhance_contrast else linear_depth

    # OpenCV colormaps produce BGR; convert to RGB for PIL/Gradio.
    colored_depth = cv2.applyColorMap(display_depth, COLORMAP_OPTIONS[colormap])
    colored_depth_rgb = cv2.cvtColor(colored_depth, cv2.COLOR_BGR2RGB)

    output = {
        "processed_image": Image.fromarray(colored_depth_rgb),
        "original_image": image,
    }

    if show_histogram:
        # Histogram the linearly scaled values: after equalization the
        # distribution is flat by construction and says nothing about depth.
        output["histogram"] = create_depth_histogram(linear_depth)

    return output

def process_video(video_path, colormap="INFERNO", model_name="Depth Anything Small", 
                 enhance_contrast=False):
    """Estimate depth for every frame of a video and write a colorized video.

    Each frame is normalized independently (its own min/max), so colors are
    relative within a frame and may shift between frames.

    Args:
        video_path: Path to the input video file.
        colormap: Key into ``COLORMAP_OPTIONS``.
        model_name: Key into ``MODEL_OPTIONS``. The model is (re)loaded when
            nothing is loaded yet or a different model is currently loaded.
        enhance_contrast: If true, histogram-equalize each frame's depth map
            before applying the colormap.

    Returns:
        Path to a temporary ``.mp4`` file (``mp4v`` fourcc) with the frame
        size reported by the input video.

    Raises:
        ValueError: If the input cannot be opened or the output writer cannot
            be created.
    """
    # Reload when nothing is loaded or the user picked a different model
    # (relies on transformers' ``model.name_or_path``; if absent we reload).
    repo_id = MODEL_OPTIONS[model_name]
    if (processor is None or model is None
            or getattr(model, "name_or_path", None) != repo_id):
        load_model(model_name)
    cmap = COLORMAP_OPTIONS[colormap]

    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        raise ValueError("Error opening video file")

    out = None
    try:
        frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        # Keep fractional rates such as 29.97 (int() truncation makes the
        # output run at the wrong speed) and fall back to 30 when the
        # container reports no rate.
        fps = cap.get(cv2.CAP_PROP_FPS)
        if not (fps > 0):
            fps = 30.0
        # Frame counts are estimates (or 0) for some containers, so they only
        # drive the progress bar; the loop below reads until end of stream.
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))

        # mkstemp creates the file atomically (tempfile.mktemp is deprecated).
        fd, output_path = tempfile.mkstemp(suffix='.mp4')
        os.close(fd)
        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
        out = cv2.VideoWriter(output_path, fourcc, fps, (frame_width, frame_height))
        if not out.isOpened():
            raise ValueError("Error opening video writer")

        with tqdm(total=total_frames if total_frames > 0 else None,
                  desc="Processing video") as progress:
            while True:
                ret, frame = cap.read()
                if not ret:
                    break

                # OpenCV decodes BGR; the processor expects an RGB PIL image.
                frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                pil_image = Image.fromarray(frame_rgb)

                pixel_values = processor(
                    images=pil_image, return_tensors="pt"
                ).pixel_values.to(model.device)

                with torch.no_grad():
                    predicted_depth = model(pixel_values).predicted_depth

                # Resize to the size the writer was opened with; VideoWriter
                # expects every frame at that size.
                depth = torch.nn.functional.interpolate(
                    predicted_depth[None],
                    (frame_height, frame_width),
                    mode='bilinear',
                    align_corners=False
                )[0, 0]

                depth_np = depth.cpu().numpy()
                d_min = depth_np.min()
                span = depth_np.max() - d_min
                if span > 0:
                    depth_u8 = ((depth_np - d_min) * 255 / span).astype(np.uint8)
                else:
                    # Flat depth (e.g. a blank frame): avoid 0/0 -> NaN.
                    depth_u8 = np.zeros(depth_np.shape, dtype=np.uint8)
                if enhance_contrast:
                    depth_u8 = cv2.equalizeHist(depth_u8)

                # applyColorMap output is BGR, which is what VideoWriter expects.
                out.write(cv2.applyColorMap(depth_u8, cmap))
                progress.update(1)
    finally:
        cap.release()
        if out is not None:
            out.release()
        # No cv2.destroyAllWindows(): no windows are opened here, and on
        # headless OpenCV builds that call raises instead of being a no-op.
        torch.cuda.empty_cache()

    return output_path

# File extensions routed to the video pipeline (compared case-insensitively).
_VIDEO_EXTENSIONS = ('.mp4', '.avi', '.mov', '.mkv', '.webm')


def process_input(input_data, colormap, model_name, show_histogram, enhance_contrast):
    """Dispatch a Gradio input to the image or the video pipeline.

    Args:
        input_data: A video file path (string) or an image (PIL / NumPy).
        colormap: Key into ``COLORMAP_OPTIONS``.
        model_name: Key into ``MODEL_OPTIONS``.
        show_histogram: Whether to render a depth histogram (images only).
        enhance_contrast: Whether to histogram-equalize the depth map.

    Returns:
        For a video: the processed video's file path. That is the value for
        the single ``Video`` output component. A dict with a string key is
        not a valid value for that component.
        For an image: ``[depth_map_image, histogram_path_or_None]``, matching
        the ``(output_image, hist_image)`` outputs in order.

    Raises:
        gr.Error: If nothing was uploaded (shown to the user in the UI).
    """
    if input_data is None:
        raise gr.Error("Please upload an image or video first.")

    if isinstance(input_data, str) and input_data.lower().endswith(_VIDEO_EXTENSIONS):
        return process_video(input_data, colormap, model_name, enhance_contrast)

    result = process_image(
        input_data,
        colormap,
        model_name,
        show_histogram,
        enhance_contrast
    )
    # process_image only adds "histogram" when show_histogram is true.
    return [result["processed_image"], result.get("histogram")]

# Create enhanced Gradio interface: an image tab and a video tab, both wired to
# process_input, plus an explanatory "About" accordion.
with gr.Blocks(theme=gr.themes.Soft()) as app:
    gr.Markdown("""
    # 🎯 Advanced Depth Estimation App
    
    Estimate and visualize depth in images and videos using state-of-the-art AI models.
    
    ### Features:
    - Process images and videos
    - Multiple colormap options
    - Different model options
    - Depth distribution histogram
    - Contrast enhancement
    
    ### Instructions:
    1. Upload an image or video
    2. Select your preferred options
    3. Click the process button
    4. View the results and analysis
    """)
    
    with gr.Tabs():
        with gr.TabItem("Image Processing"):
            with gr.Row():
                with gr.Column(scale=1):
                    input_image = gr.Image(
                        label="Input Image", 
                        type="pil",
                        elem_id="input_image"
                    )
                    with gr.Row():
                        colormap_choice = gr.Dropdown(
                            choices=list(COLORMAP_OPTIONS.keys()),
                            value="INFERNO",
                            label="Colormap"
                        )
                        model_choice = gr.Dropdown(
                            choices=list(MODEL_OPTIONS.keys()),
                            value="Depth Anything Small",
                            label="Model"
                        )
                    with gr.Row():
                        show_hist = gr.Checkbox(
                            label="Show Histogram",
                            value=True
                        )
                        enhance = gr.Checkbox(
                            label="Enhance Contrast",
                            value=False
                        )
                    image_button = gr.Button(
                        "Process Image", 
                        variant="primary"
                    )
                
                with gr.Column(scale=1):
                    output_image = gr.Image(
                        label="Depth Map",
                        type="pil",
                        elem_id="output_image"
                    )
                    # The histogram arrives as a PNG file path.
                    hist_image = gr.Image(
                        label="Depth Distribution",
                        type="filepath",
                        elem_id="hist_image"
                    )
        
        with gr.TabItem("Video Processing"):
            with gr.Row():
                with gr.Column(scale=1):
                    input_video = gr.Video(
                        label="Input Video",
                        elem_id="input_video"
                    )
                    with gr.Row():
                        video_colormap = gr.Dropdown(
                            choices=list(COLORMAP_OPTIONS.keys()),
                            value="INFERNO",
                            label="Colormap"
                        )
                        video_model = gr.Dropdown(
                            choices=list(MODEL_OPTIONS.keys()),
                            value="Depth Anything Small",
                            label="Model"
                        )
                    video_enhance = gr.Checkbox(
                        label="Enhance Contrast",
                        value=False
                    )
                    video_button = gr.Button(
                        "Process Video", 
                        variant="primary"
                    )
                
                with gr.Column(scale=1):
                    output_video = gr.Video(
                        label="Processed Video",
                        elem_id="output_video"
                    )
    
    # Depth Anything predicts relative inverse depth (disparity), and the
    # processing functions map the largest value of each image/frame to the
    # bright end of the colormap, so brighter means closer, not farther.
    with gr.Accordion("About", open=False):
        gr.Markdown("""
        ### How it works:
        This app uses deep learning models to estimate relative depth from regular images and videos.
        Depth Anything predicts relative inverse depth (disparity), so larger values mean closer surfaces.
        Values are rescaled per image (per frame for videos), so colors show relative, not absolute, distance:
        - Brighter, warmer colors (e.g. yellow/white in INFERNO) indicate objects closer to the camera
        - Darker, cooler colors (e.g. black/purple in INFERNO) indicate objects farther away
        
        ### Models Available:
        - Depth Anything Small: Faster, lighter model
        - Depth Anything Base: More accurate, but slower
        
        ### Tips:
        - Try different colormaps to find the best visualization
        - Enable contrast enhancement for better detail in some scenes
        - Check the histogram to understand depth distribution
        """)
    
    # Set up event handlers
    image_button.click(
        fn=process_input,
        inputs=[
            input_image,
            colormap_choice,
            model_choice,
            show_hist,
            enhance
        ],
        outputs=[
            output_image,
            hist_image,
        ]
    )
    
    video_button.click(
        fn=process_input,
        inputs=[
            input_video,
            video_colormap,
            video_model,
            # Hidden placeholder that fills process_input's show_histogram
            # slot; the video pipeline never produces a histogram.
            gr.Checkbox(value=False, visible=False),
            video_enhance
        ],
        outputs=[
            output_video
        ]
    )

# Launch the app. share=True asks Gradio for a temporary public share link in
# addition to the local server. That remains the default, but setting
# GRADIO_SHARE=0/false/no keeps the app local-only.
if __name__ == "__main__":
    _share = os.environ.get("GRADIO_SHARE", "true").strip().lower() not in ("0", "false", "no")
    app.launch(share=_share)