# AI-Driven Depth of Field Simulations with Gradio and GPU

This notebook showcases the setup of a GPU-accelerated Depth of Field simulation on uploaded images:
- Generate depth maps from images using a pre-trained model.
- Apply a depth of field effect to images based on the generated depth map.

To enhance the user experience, a Gradio interface is implemented.


In [None]:
# optional: silence every Python warning for the rest of the session
import warnings
warnings.simplefilter('ignore')

## 1. Project Folder

To keep your files organized, create a separate folder for the project. 
Place the notebook file into the project folder.

## 2. Anaconda Navigator

**Environments**: Create a new environment, providing a name and selecting the Python package. 

**Home**: Install JupyterLab and launch it. 

## 3. JupyterLab

Navigate to the project folder. Open the notebook.

Verify that you are working in the newly created environment by running:

In [None]:
# Lists every conda environment; conda marks the currently active one with an asterisk (*).
!conda env list

## 4. Connecting CUDA for GPU Support

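In the Anaconda Navigator, go to the **Terminal** of the current environment and run the PyTorch install command that matches your CUDA version (the selector at https://pytorch.org/get-started/locally/ generates it), for example:

    conda install pytorch torchvision pytorch-cuda=12.1 -c pytorch -c nvidia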

## 5. Verifying CUDA-enabled GPU

Check if CUDA is available by running:

In [None]:
import torch
# True when PyTorch can use a CUDA device; as the cell's last expression the result is displayed.
torch.cuda.is_available()

## 6. Install Gradio, Transformers, Pillow, OpenCV

Again, use the **Terminal** of the Anaconda environment for installation:

    pip install gradio transformers pillow opencv-python

## 7. Importing Libraries

In [None]:
import atexit
import logging
import os
import tempfile

import cv2
import gradio as gr
import numpy as np
import torch
from PIL import Image
from transformers import DPTForDepthEstimation, DPTImageProcessor

## 8. The Main Code

In [None]:
# Logging: emit INFO-and-above records and use a logger named after this module
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Prefer CUDA when PyTorch reports it as available; otherwise run on the CPU
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
logger.info(f"Using device: {device}")

# Pre-trained depth estimation checkpoint - https://huggingface.co/Intel/dpt-hybrid-midas
model_name = "Intel/dpt-hybrid-midas"
processor = DPTImageProcessor.from_pretrained(model_name)  # Prepares images as model inputs
model = DPTForDepthEstimation.from_pretrained(model_name)  # Loads the model weights
model = model.to(device)  # Place the model on the selected device

# Define function to create a depth map
def create_depth_map(original_image):
    """Estimate a depth map for a PIL image, scaled to the range [0, 1].

    The image is run through the module-level ``processor`` and ``model`` on
    ``device``; the prediction is resized to the image's own resolution and
    min-max normalized.

    Args:
        original_image: PIL image to analyse.

    Returns:
        A 2-D NumPy array of shape (height, width) with values in [0, 1].
        When the prediction is constant (no depth variation) an all-zero map
        is returned instead of dividing by zero.
    """
    # Prepare image for the model and move the tensors to the model's device
    inputs = processor(images=original_image, return_tensors="pt").to(device)

    # Inference only: no gradients needed
    with torch.no_grad():
        predicted_depth = model(**inputs).predicted_depth

    # Resize the prediction to the original resolution. PIL reports size as
    # (width, height) while interpolate expects (height, width), hence [::-1].
    resized = torch.nn.functional.interpolate(
        predicted_depth.unsqueeze(1),
        size=original_image.size[::-1],
        mode="bicubic",
        align_corners=False,
    )
    # Index batch and channel explicitly: a bare squeeze() would also drop a
    # height or width of 1 and return a 1-D array for such images.
    depth_map = resized[0, 0].cpu().numpy()

    # Min-max normalize to [0, 1], guarding against a flat prediction
    depth_min = depth_map.min()
    depth_range = depth_map.max() - depth_min
    if depth_range == 0:
        return np.zeros_like(depth_map)
    return (depth_map - depth_min) / depth_range

# Define function to apply depth of field effect
def apply_depth_of_field(original_image, depth_map, blur_strength=15):
    """Blur the parts of an image whose depth value is below the mean depth.

    Args:
        original_image: PIL image to process.
        depth_map: 2-D array of depth values; assumed to be in [0, 1] and to
            have the same height and width as ``original_image``.
        blur_strength: Desired Gaussian kernel size. It is rounded to an
            integer, clamped to [3, 51] and bumped to the next odd number,
            because the kernel size must be odd.

    Returns:
        A tuple ``(output_image, depth_map_bgr)``: the processed PIL image and
        the depth map as a three-channel uint8 array scaled to [0, 255].

    NOTE(review): the original comment said areas *closer* than the threshold
    are blurred. For MiDaS-style models that predict inverse depth (larger =
    nearer) the values below the mean are the farther regions, i.e. the
    background is blurred - confirm against the model card.
    """
    original_image_np = np.array(original_image)

    # Quantize the depth map to 8 bits; the three-channel copy is returned to the caller
    depth_map_gray = (depth_map * 255).astype(np.uint8)
    depth_map_np = cv2.cvtColor(depth_map_gray, cv2.COLOR_GRAY2BGR)

    # Back to [0, 1], now carrying the 8-bit quantization
    depth_map_normalized = depth_map_gray / 255.0

    # Pixels whose depth value is below the average are the ones that get blurred
    threshold = np.mean(depth_map_normalized)
    mask = depth_map_normalized < threshold

    # The kernel size must be an odd integer; UI sliders may deliver floats
    kernel_size = max(3, min(int(round(blur_strength)), 51))
    if kernel_size % 2 == 0:
        kernel_size += 1

    blurred_image = cv2.GaussianBlur(original_image_np, (kernel_size, kernel_size), 0)

    # Add a channel axis to the mask only for multi-channel images; a 2-D
    # grayscale image must be combined with the 2-D mask directly.
    if original_image_np.ndim == 3:
        mask = mask[:, :, np.newaxis]
    output_image = np.where(mask, blurred_image, original_image_np)

    return Image.fromarray(output_image.astype(np.uint8)), depth_map_np

# Define main image processing function
# Paths of every temporary PNG written by process_image; removed at interpreter exit
_temp_png_paths = []


def _remove_temp_pngs():
    """Best-effort deletion of the temporary PNGs recorded in _temp_png_paths."""
    while _temp_png_paths:
        path = _temp_png_paths.pop()
        try:
            os.remove(path)
        except OSError as e:
            logger.error(f"Error cleaning up file {path}: {str(e)}")


atexit.register(_remove_temp_pngs)


def _save_temp_png(image):
    """Save a PIL image to a new temporary .png file and return its path."""
    # Only the reserved name is needed, so close the handle before saving by
    # name; leaving it open leaks one file handle per call.
    with tempfile.NamedTemporaryFile(delete=False, suffix='.png') as handle:
        path = handle.name
    _temp_png_paths.append(path)
    image.save(path)
    return path


def process_image(original_image, blur_strength, state):
    """Gradio callback: build the depth-of-field image and its download files.

    Args:
        original_image: Uploaded PIL image, or None when no image is supplied.
        blur_strength: Blur strength passed on to ``apply_depth_of_field``.
        state: Dict with keys ``"original_image"`` and ``"depth_map"`` from a
            previous call, or None on the first call.

    Returns:
        A tuple ``(output_image, depth_map_path, output_path, state)``.
        When there is neither an image nor a state, all four items are None.
        When processing fails, the error is logged and
        ``(None, None, None, state)`` is returned.
    """
    try:
        if original_image is None and state is None:  # Nothing to work with
            logger.warning("No image uploaded")
            return None, None, None, None

        if original_image is None:
            # No new upload: fall back to the image kept in the state
            original_image = state["original_image"]
        elif state is None or state["original_image"] != original_image:
            # First image, or a different one: compute a fresh depth map
            state = {"original_image": original_image, "depth_map": create_depth_map(original_image)}
        # Otherwise the same image was submitted again (e.g. only the slider
        # moved), so the cached depth map is reused instead of re-running the model.

        logger.info(f"Processing image with blur strength: {blur_strength}")

        # Apply depth of field effect; the BGR depth array it also returns is not needed here
        output_image, _ = apply_depth_of_field(original_image, state["depth_map"], blur_strength)

        # Save depth map (as a grayscale image) and output image as PNG files
        depth_map_pil = Image.fromarray((state["depth_map"] * 255).astype(np.uint8))
        depth_map_path = _save_temp_png(depth_map_pil)
        output_path = _save_temp_png(output_image)

        logger.info("Processing completed successfully")
        return output_image, depth_map_path, output_path, state

    # Top-level UI boundary: log the failure with its traceback and return empty outputs
    except Exception as e:
        logger.exception(f"Error in process_image: {str(e)}")
        return None, None, None, state

# Define Gradio interface
# Inputs, in the order process_image receives them: image, blur strength, state
input_components = [
    gr.Image(type="pil", label="Upload Original Image (PNG/JPEG)"),
    gr.Slider(minimum=1, maximum=50, step=1, value=15, label="Blur Strength"),
    "state",  # Hidden per-session state
]

# Outputs, in the order process_image returns them
output_components = [
    gr.Image(label="Output Image with Depth of Field"),
    gr.File(label="Download Depth Map"),
    gr.File(label="Download Processed Image"),
    "state",  # Hidden per-session state
]

iface = gr.Interface(
    fn=process_image,
    inputs=input_components,
    outputs=output_components,
    title="Depth of Field Simulation",
    description="Upload an original image to generate a depth map and apply a shallow depth of field effect.",
    allow_flagging="never",  # No flagging button
)

# Start the web UI in debug mode
iface.launch(debug=True)

# Clean up temporary files
def cleanup(files):
    """Delete each path in ``files``, logging any failure instead of raising."""
    for path in files:
        try:
            os.remove(path)
        except Exception as err:
            # Best effort: report the problem and carry on with the next path
            logger.error(f"Error cleaning up file {path}: {str(err)}")

# Register cleanup function
# NOTE(review): the list handed to cleanup() here is a fresh empty literal that
# no other code can reach or append to, so this exit hook deletes no files.
import atexit
atexit.register(cleanup, [])  # Register cleanup function to run at exit
