In [None]:
import os
import sys
import logging
import gc
import time
import torch
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image, ImageOps, ImageEnhance
from IPython.display import display, HTML
from tqdm.notebook import tqdm
from transformers import (
    BlipProcessor, 
    BlipForConditionalGeneration,
    AutoProcessor, 
    AutoModelForCausalLM,
    AutoModelForSeq2SeqLM
)
import torchvision.transforms as T
from torchvision.transforms.functional import to_tensor
import torch.nn.functional as F
import warnings
from typing import Union, Tuple, List, Dict, Optional, Any
import json
import ipywidgets as widgets
from IPython.display import clear_output

# Silence DeprecationWarning and UserWarning messages for the whole session.
# NOTE(review): ignoring every UserWarning is broad and may hide useful
# library diagnostics — consider narrowing these filters.
warnings.filterwarnings("ignore", category=DeprecationWarning)
warnings.filterwarnings("ignore", category=UserWarning)

# Configure logger
class NotebookLoggingHandler(logging.Handler):
    """Logging handler that shows each record as a colour-coded HTML line.

    ERROR and WARNING records are bold red/orange, INFO records are green and
    every other level is gray. The formatted message is HTML-escaped before it
    is embedded, so text such as ``<class 'ValueError'>`` is shown literally
    instead of being interpreted as markup.
    """

    # Inline CSS per level name; levels not listed use _DEFAULT_STYLE.
    _LEVEL_STYLES = {
        'ERROR': "color:red; font-weight:bold",
        'WARNING': "color:orange; font-weight:bold",
        'INFO': "color:green",
    }
    _DEFAULT_STYLE = "color:gray"

    def _render_html(self, record):
        """Return the HTML snippet for *record* (formatted, escaped, styled)."""
        from html import escape  # function-scope import keeps this block self-contained

        level = record.levelname
        style = self._LEVEL_STYLES.get(level, self._DEFAULT_STYLE)
        # quote=False: the text lands in element content, so only &, < and >
        # need escaping; apostrophes in messages stay readable.
        message = escape(self.format(record), quote=False)
        return f"<div style='{style}'>[{level}] {message}</div>"

    def emit(self, record):
        """Display *record* in the notebook output area.

        Any failure while rendering is routed through ``handleError`` so that
        a logging problem cannot break the code that emitted the record.
        """
        try:
            display(HTML(self._render_html(record)))
        except Exception:
            self.handleError(record)

def setup_logging():
    """Create (or reset) the 'image_captioner' logger and return it.

    The logger is set to INFO level and gets two handlers that share one
    formatter: a NotebookLoggingHandler for styled notebook output and a
    FileHandler writing to 'image_caption_generation.log'.

    The function is safe to call repeatedly: handlers left over from an
    earlier call are removed and closed first, so re-running the notebook
    cell does not make every record appear twice.

    Returns:
        The configured ``logging.Logger``.
    """
    logger = logging.getLogger('image_captioner')
    logger.setLevel(logging.INFO)

    # logging.getLogger returns the same object for the same name, so handlers
    # from a previous run are still attached; drop them before adding new ones.
    for stale_handler in list(logger.handlers):
        logger.removeHandler(stale_handler)
        stale_handler.close()

    formatter = logging.Formatter('%(asctime)s - %(name)s - %(message)s', datefmt='%H:%M:%S')

    # Add custom notebook handler for pretty display
    notebook_handler = NotebookLoggingHandler()
    notebook_handler.setFormatter(formatter)
    logger.addHandler(notebook_handler)

    # Add file handler for persistent logs
    file_handler = logging.FileHandler('image_caption_generation.log')
    file_handler.setFormatter(formatter)
    logger.addHandler(file_handler)

    # Keep records away from the root logger's handlers so they are not
    # printed a second time by whatever is configured there.
    logger.propagate = False

    return logger

# Module-level logger used by the classes and functions defined below.
logger = setup_logging()

class GPUStatsMonitor:
    """Report and release CUDA memory; degrades to no-ops without a GPU."""

    def __init__(self):
        """Record whether CUDA is available and how many devices are visible."""
        self.has_gpu = torch.cuda.is_available()
        self.device_count = torch.cuda.device_count() if self.has_gpu else 0

    def log_memory_usage(self):
        """Log total, allocated, reserved and free memory (GB) for every GPU.

        "Free" is computed as total minus the memory reserved by the caching
        allocator. Logs a single "No GPU available" line when CUDA is absent.
        """
        if not self.has_gpu:
            logger.info("No GPU available")
            return

        bytes_per_gb = 1024**3
        for i in range(self.device_count):
            total_mem = torch.cuda.get_device_properties(i).total_memory / bytes_per_gb
            reserved = torch.cuda.memory_reserved(i) / bytes_per_gb
            allocated = torch.cuda.memory_allocated(i) / bytes_per_gb
            free = total_mem - reserved

            logger.info(f"GPU {i} - Total: {total_mem:.2f}GB | Used: {allocated:.2f}GB | "
                      f"Reserved: {reserved:.2f}GB | Free: {free:.2f}GB")

    def clear_memory(self):
        """Run the garbage collector, then release cached CUDA memory.

        The collector runs first because ``torch.cuda.empty_cache`` can only
        give back blocks that no tensor references any more; objects that are
        merely awaiting collection would otherwise keep their memory cached.
        Without a GPU only the garbage collection is performed.
        """
        gc.collect()
        if self.has_gpu:
            torch.cuda.empty_cache()
            logger.info("GPU memory cache cleared")

class ImageProcessor:
    """Stateless helpers for loading, displaying and tensorising images."""

    @staticmethod
    def preprocess_image(
        image_path: str,
        resize_dim: Tuple[int, int] = (224, 224),
        apply_normalization: bool = True,
        enhance_contrast: bool = False,
        contrast_factor: float = 1.2
    ) -> Image.Image:
        """
        Load an image from disk, optionally boost its contrast, and resize it.

        Args:
            image_path: Path to the image
            resize_dim: Target (width, height) passed to ``Image.resize``
            apply_normalization: Accepted for backward compatibility; this
                method returns a PIL image and does not use the flag
                (normalisation is available in ``convert_to_tensor``)
            enhance_contrast: Whether to apply ``ImageEnhance.Contrast``
            contrast_factor: Factor passed to the contrast enhancer

        Returns:
            Processed RGB PIL Image of size ``resize_dim``

        Raises:
            Exception: any error from opening or processing the image is
                logged and re-raised unchanged.
        """
        try:
            # Image.open keeps the underlying file open; convert() loads the
            # pixels into a new in-memory image, so the file can be closed as
            # soon as the with-block ends.
            with Image.open(image_path) as source:
                image = source.convert('RGB')

            # Optional contrast adjustment by a fixed factor
            if enhance_contrast:
                image = ImageEnhance.Contrast(image).enhance(contrast_factor)

            # Resize with the Lanczos filter
            return image.resize(resize_dim, Image.Resampling.LANCZOS)

        except Exception as e:
            logger.error(f"Error preprocessing image: {e}")
            raise

    @staticmethod
    def visualize_image(image: Image.Image, figsize: Tuple[int, int] = (8, 8)) -> None:
        """
        Display the image with matplotlib, without axes.

        Args:
            image: PIL Image to display
            figsize: Figure size for display
        """
        plt.figure(figsize=figsize)
        plt.imshow(image)
        plt.axis('off')
        plt.tight_layout()
        plt.show()

    @staticmethod
    def convert_to_tensor(
        image: Image.Image,
        device: torch.device,
        normalize: bool = True
    ) -> torch.Tensor:
        """
        Convert a PIL image to a tensor with a leading batch dimension.

        Args:
            image: PIL Image
            device: Target device
            normalize: Whether to normalize using ImageNet mean/std

        Returns:
            Image tensor on the specified device
        """
        if normalize:
            # ImageNet normalization values
            transform = T.Compose([
                T.ToTensor(),
                T.Normalize(
                    mean=[0.485, 0.456, 0.406],
                    std=[0.229, 0.224, 0.225]
                )
            ])
            tensor = transform(image)
        else:
            tensor = to_tensor(image)

        # unsqueeze(0) adds the batch dimension before moving to the device
        return tensor.unsqueeze(0).to(device)

class ModelManager:
    """Load, run and unload one of the supported image-captioning models."""

    # Registry of selectable models: key -> Hub name, loader family, description.
    AVAILABLE_MODELS = {
        "blip-base": {
            "name": "Salesforce/blip-image-captioning-base",
            "type": "blip",
            "desc": "Base BLIP model for image captioning"
        },
        "blip-large": {
            "name": "Salesforce/blip-image-captioning-large",
            "type": "blip",
            "desc": "Large BLIP model for image captioning (better quality)"
        },
        "git-base": {
            "name": "microsoft/git-base",
            "type": "git",
            "desc": "GenerativeImage2Text base model from Microsoft"
        }
    }

    def __init__(self, model_key: str = "blip-base", device: Optional[Union[torch.device, str, int]] = None):
        """
        Initialize model manager (nothing is downloaded or loaded yet)

        Args:
            model_key: Key into AVAILABLE_MODELS for the model to load
            device: Target device. None picks CUDA when available, an int is
                treated as a CUDA device index (falling back to CPU when CUDA
                is unavailable), a str is passed to ``torch.device``

        Raises:
            ValueError: if ``model_key`` is not in AVAILABLE_MODELS
        """
        if device is None:
            self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        elif isinstance(device, int):
            self.device = torch.device(f"cuda:{device}" if torch.cuda.is_available() else "cpu")
        else:
            self.device = device if isinstance(device, torch.device) else torch.device(device)

        self.model_key = model_key
        self.model_info = self.AVAILABLE_MODELS.get(model_key)

        if not self.model_info:
            raise ValueError(f"Unknown model: {model_key}. Available models: {list(self.AVAILABLE_MODELS.keys())}")

        self.processor = None
        self.model = None

    def load_model(self, use_half_precision: bool = True) -> None:
        """
        Load the specified model and processor onto ``self.device``

        Args:
            use_half_precision: Whether to convert the weights to FP16; only
                applied when the target device is CUDA

        Raises:
            Exception: any loading error is logged and re-raised; the manager
                is left with no model and no processor in that case.
        """
        model_name = self.model_info["name"]
        model_type = self.model_info["type"]

        gpu_monitor = GPUStatsMonitor()
        gpu_monitor.log_memory_usage()

        try:
            logger.info(f"Loading {model_name} on {self.device}...")
            start_time = time.time()

            if model_type == "blip":
                self.processor = BlipProcessor.from_pretrained(model_name)
                self.model = BlipForConditionalGeneration.from_pretrained(model_name).to(self.device)
            elif model_type == "git":
                self.processor = AutoProcessor.from_pretrained(model_name)
                self.model = AutoModelForCausalLM.from_pretrained(model_name).to(self.device)
            else:
                raise ValueError(f"Unsupported model type: {model_type}")

            # Convert to half precision if requested and supported
            if use_half_precision and self.device.type == "cuda" and torch.cuda.is_available():
                if hasattr(torch.cuda, 'amp') and hasattr(torch.cuda.amp, 'autocast'):
                    self.model = self.model.half()
                    logger.info("Model converted to half precision (FP16)")
                else:
                    logger.warning("Half precision requested but not supported")

            elapsed = time.time() - start_time
            logger.info(f"Model loaded in {elapsed:.2f} seconds")
            gpu_monitor.log_memory_usage()

        except Exception as e:
            # Do not keep a processor without a model (or a partly moved
            # model) around after a failed load.
            self.model = None
            self.processor = None
            logger.error(f"Failed to load model: {e}")
            raise

    def unload_model(self) -> None:
        """Drop the model and processor references and clear GPU memory."""
        if self.model is not None:
            self.model = None
            self.processor = None

            gpu_monitor = GPUStatsMonitor()
            gpu_monitor.clear_memory()
            gpu_monitor.log_memory_usage()

            logger.info("Model unloaded from memory")

    def generate_caption(
        self,
        image: Image.Image,
        max_length: int = 30,
        num_beams: int = 5,
        min_length: int = 5,
        top_p: float = 0.9,
        repetition_penalty: float = 1.0,
        temperature: float = 1.0,
        num_return_sequences: int = 1
    ) -> Union[str, List[str]]:
        """
        Generate caption(s) for the given image

        Args:
            image: PIL Image to caption
            max_length: Maximum length of generated caption
            num_beams: Number of beams for beam search; raised to
                ``num_return_sequences`` when smaller, because beam search
                cannot return more sequences than it has beams
            min_length: Minimum length of generated caption
            top_p: Top-p sampling probability
            repetition_penalty: Penalty for repeating tokens
            temperature: Sampling temperature
            num_return_sequences: Number of captions to generate

        NOTE(review): ``do_sample`` is never passed to ``generate`` here, so
        ``top_p`` and ``temperature`` presumably have no effect under beam
        search — confirm against the transformers generation documentation.

        Returns:
            Caption string when ``num_return_sequences`` is 1, otherwise a
            list with one caption per returned sequence

        Raises:
            RuntimeError: if the model or processor has not been loaded
            Exception: generation errors are logged and re-raised
        """
        if self.model is None or self.processor is None:
            raise RuntimeError("Model and processor must be loaded first")

        model_type = self.model_info["type"]

        try:
            if num_return_sequences > num_beams:
                logger.warning(
                    f"num_beams ({num_beams}) is smaller than num_return_sequences "
                    f"({num_return_sequences}); using {num_return_sequences} beams"
                )
                num_beams = num_return_sequences

            # Process image. The pixel tensor is cast to the dtype of the
            # weights so an FP16 model never receives FP32 input (the GIT
            # path runs without autocast and would otherwise fail).
            inputs = self.processor(images=image, return_tensors="pt")
            weight_dtype = next(self.model.parameters()).dtype
            pixel_values = inputs.pixel_values.to(device=self.device, dtype=weight_dtype)

            generate_kwargs = dict(
                max_length=max_length,
                min_length=min_length,
                num_beams=num_beams,
                top_p=top_p,
                repetition_penalty=repetition_penalty,
                temperature=temperature,
                num_return_sequences=num_return_sequences,
            )

            if model_type == "blip":
                # Use torch.cuda.amp.autocast for mixed precision inference if available
                use_autocast = (
                    self.device.type == "cuda"
                    and hasattr(torch.cuda, 'amp')
                    and hasattr(torch.cuda.amp, 'autocast')
                )
                if use_autocast:
                    with torch.cuda.amp.autocast(), torch.no_grad():
                        generated_ids = self.model.generate(pixel_values, **generate_kwargs)
                else:
                    with torch.no_grad():
                        generated_ids = self.model.generate(pixel_values, **generate_kwargs)
            elif model_type == "git":
                with torch.no_grad():
                    generated_ids = self.model.generate(pixel_values=pixel_values, **generate_kwargs)
            else:
                raise ValueError(f"Unsupported model type: {model_type}")

            # Decode all returned sequences in one call: generated_ids has one
            # row of token ids per sequence, and batch_decode yields one
            # string per row.
            captions = list(self.processor.batch_decode(generated_ids, skip_special_tokens=True))

            if num_return_sequences > 1:
                return captions
            return captions[0]

        except Exception as e:
            logger.error(f"Error generating caption: {e}")
            raise
            
class ImageCaptioningApp:
    """Interactive image captioning application for Jupyter notebooks.

    Builds an ipywidgets UI for choosing and loading a model, uploading an
    image, tuning generation parameters and displaying the resulting
    caption(s). The uploaded image and a ``<name>_caption.txt`` file are
    written to the current working directory.
    """

    def __init__(self):
        """Initialize application state and display the UI."""
        self.gpu_monitor = GPUStatsMonitor()
        self.current_image_path = None
        self.current_image = None
        self.model_manager = None
        self.model_key = "blip-base"  # Default model
        self.use_half_precision = True
        self.setup_ui()

    def setup_ui(self):
        """Create all widgets, wire their callbacks and display the layout."""
        # Model selection
        self.model_dropdown = widgets.Dropdown(
            options=list(ModelManager.AVAILABLE_MODELS.keys()),
            value=self.model_key,
            description='Model:',
            style={'description_width': 'initial'},
            layout=widgets.Layout(width='300px')
        )

        # Half precision toggle
        self.half_precision_toggle = widgets.Checkbox(
            value=self.use_half_precision,
            description='Use half precision (FP16)',
            disabled=False
        )

        # Image upload widget
        self.file_upload = widgets.FileUpload(
            accept='.jpg,.jpeg,.png',
            multiple=False,
            description='Upload Image:',
            layout=widgets.Layout(width='300px')
        )

        # Caption parameters
        self.max_length_slider = widgets.IntSlider(
            value=30, min=10, max=100, step=5,
            description='Max Length:',
            style={'description_width': 'initial'},
            layout=widgets.Layout(width='400px')
        )

        self.num_beams_slider = widgets.IntSlider(
            value=5, min=1, max=10, step=1,
            description='Beam Search:',
            style={'description_width': 'initial'},
            layout=widgets.Layout(width='400px')
        )

        self.temperature_slider = widgets.FloatSlider(
            value=1.0, min=0.1, max=2.0, step=0.1,
            description='Temperature:',
            style={'description_width': 'initial'},
            layout=widgets.Layout(width='400px')
        )

        self.num_captions_slider = widgets.IntSlider(
            value=1, min=1, max=5, step=1,
            description='Caption Count:',
            style={'description_width': 'initial'},
            layout=widgets.Layout(width='400px')
        )

        # Enhanced image processing options
        self.enhance_contrast_toggle = widgets.Checkbox(
            value=False,
            description='Enhance contrast',
            disabled=False
        )

        self.contrast_factor_slider = widgets.FloatSlider(
            value=1.2, min=0.5, max=2.0, step=0.1,
            description='Contrast:',
            disabled=True,
            style={'description_width': 'initial'},
            layout=widgets.Layout(width='400px')
        )

        # Link contrast toggle to slider: the slider is only usable while
        # contrast enhancement is switched on.
        def toggle_contrast_slider(change):
            self.contrast_factor_slider.disabled = not change['new']
        self.enhance_contrast_toggle.observe(toggle_contrast_slider, names='value')

        # Button to load model
        self.load_model_button = widgets.Button(
            description='Load Model',
            button_style='primary',
            tooltip='Load the selected model'
        )
        self.load_model_button.on_click(self.handle_load_model)

        # Button to unload model
        self.unload_model_button = widgets.Button(
            description='Unload Model',
            button_style='warning',
            tooltip='Unload model from memory',
            disabled=True
        )
        self.unload_model_button.on_click(self.handle_unload_model)

        # Button to generate caption (enabled once a model and an image exist)
        self.generate_button = widgets.Button(
            description='Generate Caption',
            button_style='success',
            tooltip='Generate caption for the image',
            disabled=True,
            layout=widgets.Layout(width='200px')
        )
        self.generate_button.on_click(self.handle_generate_caption)

        # Output area
        self.output_area = widgets.Output()

        # Handle file upload
        self.file_upload.observe(self.handle_file_upload, names='value')

        # Layout all widgets
        model_section = widgets.VBox([
            widgets.HTML("<h3>Model Settings</h3>"),
            widgets.HBox([self.model_dropdown, self.half_precision_toggle]),
            widgets.HBox([self.load_model_button, self.unload_model_button])
        ])

        image_section = widgets.VBox([
            widgets.HTML("<h3>Image Upload</h3>"),
            self.file_upload
        ])

        caption_params = widgets.VBox([
            widgets.HTML("<h3>Caption Parameters</h3>"),
            self.max_length_slider,
            self.num_beams_slider,
            self.temperature_slider,
            self.num_captions_slider
        ])

        image_processing = widgets.VBox([
            widgets.HTML("<h3>Image Processing</h3>"),
            self.enhance_contrast_toggle,
            self.contrast_factor_slider
        ])

        generate_section = widgets.VBox([
            widgets.HTML("<h3>Generate</h3>"),
            self.generate_button
        ])

        # Main layout
        self.main_ui = widgets.VBox([
            widgets.HTML("<h2>Advanced Image Caption Generator</h2>"),
            widgets.HBox([model_section, image_section]),
            widgets.HBox([caption_params, image_processing, generate_section]),
            widgets.HTML("<h3>Results</h3>"),
            self.output_area
        ])

        # Display the UI
        display(self.main_ui)

    def handle_load_model(self, button):
        """Load the model selected in the dropdown and update button states.

        ``self.model_manager`` is only assigned after the load succeeded, so
        a failed load never leaves a manager without a model behind (which
        would let a later upload enable the Generate button).
        """
        with self.output_area:
            clear_output()
            self.model_key = self.model_dropdown.value
            self.use_half_precision = self.half_precision_toggle.value

            try:
                manager = ModelManager(self.model_key)
                manager.load_model(use_half_precision=self.use_half_precision)
                self.model_manager = manager

                self.load_model_button.disabled = True
                self.unload_model_button.disabled = False

                # Enable generate button if we also have an image
                if self.current_image is not None:
                    self.generate_button.disabled = False

                model_desc = ModelManager.AVAILABLE_MODELS[self.model_key]["desc"]
                logger.info(f"Model loaded successfully: {model_desc}")

            except Exception as e:
                logger.error(f"Failed to load model: {str(e)}")

    def handle_unload_model(self, button):
        """Unload the current model (if any) and reset the button states."""
        with self.output_area:
            clear_output()
            if self.model_manager:
                self.model_manager.unload_model()
                self.model_manager = None

            self.load_model_button.disabled = False
            self.unload_model_button.disabled = True
            self.generate_button.disabled = True

    def handle_file_upload(self, change):
        """Store and display the image delivered by the FileUpload widget.

        The file is written to the current working directory under the base
        name of the uploaded file. Application state is only updated after
        the image has been decoded successfully.
        """
        if not change.new:
            return

        with self.output_area:
            clear_output()

            try:
                # The widget value is either a tuple of per-file dicts or a
                # dict mapping names to per-file dicts, depending on the
                # ipywidgets release — TODO confirm the version boundary.
                if isinstance(change.new, tuple):
                    if len(change.new) > 0 and isinstance(change.new[0], dict):
                        uploaded_file = change.new[0]
                    else:
                        logger.error("Unexpected file upload format")
                        return
                else:
                    uploaded_file = list(change.new.values())[0]

                # Extract content and filename
                if 'content' in uploaded_file and 'name' in uploaded_file:
                    content = uploaded_file['content']
                    # The name comes from the browser; keep only its last
                    # component so the write stays in the working directory.
                    filename = os.path.basename(uploaded_file['name'])
                else:
                    logger.error("Uploaded file missing required metadata")
                    return

                if not filename:
                    logger.error("Uploaded file missing required metadata")
                    return

                # Save file temporarily
                with open(filename, 'wb') as f:
                    f.write(content)

                # Decode before touching app state; convert() loads the pixels
                # so the file handle can be closed right away.
                with Image.open(filename) as opened:
                    loaded_image = opened.convert("RGB")

                self.current_image_path = filename
                self.current_image = loaded_image
                logger.info(f"Image uploaded: {filename}")

                # Display image
                ImageProcessor.visualize_image(self.current_image)

                # Enable generate button if model is loaded
                if self.model_manager is not None:
                    self.generate_button.disabled = False

            except Exception as e:
                logger.error(f"Error handling file upload: {str(e)}")

    def handle_generate_caption(self, button):
        """Preprocess the current image, caption it, then show and save the result.

        Does nothing when no image has been uploaded or no model is loaded.
        Captions are HTML-escaped for display and written as UTF-8 to
        ``<image name without extension>_caption.txt``.
        """
        if not self.current_image_path or not self.model_manager:
            return

        from html import escape  # function-scope import keeps this block self-contained

        with self.output_area:
            clear_output()

            try:
                # Get parameter values
                max_length = self.max_length_slider.value
                num_beams = self.num_beams_slider.value
                temperature = self.temperature_slider.value
                num_captions = self.num_captions_slider.value
                enhance_contrast = self.enhance_contrast_toggle.value
                contrast_factor = self.contrast_factor_slider.value

                # Beam search cannot return more captions than it has beams;
                # the sliders allow that combination, so widen the beam here.
                if num_captions > num_beams:
                    logger.warning(
                        f"Using {num_captions} beams instead of {num_beams} "
                        f"to return {num_captions} captions"
                    )
                    num_beams = num_captions

                logger.info("Processing image...")

                # Process image with selected options
                processed_image = ImageProcessor.preprocess_image(
                    self.current_image_path,
                    resize_dim=(224, 224),
                    enhance_contrast=enhance_contrast,
                    contrast_factor=contrast_factor
                )

                # Show processed image
                logger.info("Processed image:")
                ImageProcessor.visualize_image(processed_image)

                # Generate caption
                logger.info("Generating caption...")
                self.gpu_monitor.log_memory_usage()

                # Time the caption generation
                start_time = time.time()

                captions = self.model_manager.generate_caption(
                    processed_image,
                    max_length=max_length,
                    num_beams=num_beams,
                    temperature=temperature,
                    num_return_sequences=num_captions
                )

                elapsed = time.time() - start_time
                logger.info(f"Caption generation took {elapsed:.2f} seconds")

                # Display results (model text is escaped before it is put into HTML)
                if isinstance(captions, list):
                    display(HTML("<h3>Generated Captions:</h3>"))
                    for i, caption in enumerate(captions):
                        display(HTML(f"<div style='margin-bottom:10px'><b>Caption {i+1}:</b> {escape(caption)}</div>"))
                else:
                    display(HTML(f"<h3>Generated Caption:</h3><div><b>{escape(captions)}</b></div>"))

                # Save caption to file
                caption_path = f"{os.path.splitext(self.current_image_path)[0]}_caption.txt"
                with open(caption_path, "w", encoding="utf-8") as f:
                    if isinstance(captions, list):
                        for i, caption in enumerate(captions):
                            f.write(f"Caption {i+1}: {caption}\n")
                    else:
                        f.write(captions)

                logger.info("Caption saved to file")

            except Exception as e:
                logger.error(f"Error generating caption: {str(e)}")

# Notebook entry point: report the compute device, then build the UI
def run_image_caption_app():
    """Log which device will be used, then create and return the app.

    Returns:
        The newly constructed ImageCaptioningApp (its constructor displays
        the widgets).
    """
    if not torch.cuda.is_available():
        logger.warning("No GPU detected. Using CPU. Processing will be slower.")
    else:
        gpu_index = torch.cuda.current_device()
        gpu_name = torch.cuda.get_device_name(gpu_index)
        major, minor = torch.cuda.get_device_capability(gpu_index)
        logger.info(f"GPU detected: {gpu_name} (Compute capability: {major}.{minor})")

    return ImageCaptioningApp()

# Entry point: build and display the app when this code is executed directly
# (as in a notebook cell), but not when it is imported as a module.
if __name__ == "__main__":
    run_image_caption_app()

VBox(children=(HTML(value='<h2>Advanced Image Caption Generator</h2>'), HBox(children=(VBox(children=(HTML(val…

In [12]:
# ================================================================
# Auto-Enhancing Assessor with Iterative Exposure Fix & Download
# ================================================================
import cv2, numpy as np, os, tempfile, logging, matplotlib.pyplot as plt
from dataclasses import dataclass
from concurrent.futures import ThreadPoolExecutor, as_completed
from PIL import Image, ImageFilter, ImageEnhance
from IPython.display import display, HTML, clear_output, FileLink
import ipywidgets as widgets

# ------------------------- Logging -------------------------
# Configure root logging at INFO level with a "LEVEL: message" format and
# create the named logger for this cell.
logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s")
log = logging.getLogger("auto_assessor")

# ---------------------- Metric Framework ----------------------
@dataclass
class Threshold:
    lower: float|None=None
    upper: float|None=None
    def check(self, v):
        if self.lower is not None and v < self.lower: return False, f"{v:.2f} < {self.lower}"
        if self.upper is not None and v > self.upper: return False, f"{v:.2f} > {self.upper}"
        return True, ""

class Metric:
    """Base class for an image metric that is judged against a Threshold.

    Subclasses set ``name`` and implement ``compute``. Calling an instance
    returns a report dict with keys 'metric', 'value', 'pass_' and 'reason'.
    """
    name = "base"

    def __init__(self, th:Threshold):
        self.th = th

    def compute(self, gray, bgr):
        """Return the raw metric value; must be overridden by subclasses."""
        raise NotImplementedError

    def __call__(self, gray, bgr):
        raw_value = self.compute(gray, bgr)
        passed, explanation = self.th.check(raw_value)
        report = {"metric": self.name}
        # The reported value is rounded to 3 decimals; the threshold check
        # above uses the unrounded value.
        report["value"] = round(float(raw_value), 3)
        report["pass_"] = passed
        report["reason"] = explanation
        return report

class Brightness(Metric):
    """Mean value of the grayscale image."""
    name = "brightness"

    def compute(self, g, b):
        return np.mean(g)
class Contrast(Metric):
    """Standard deviation of the grayscale image."""
    name = "contrast"

    def compute(self, g, b):
        return np.std(g)
class SharpLap(Metric):
    """Variance of the Laplacian of the grayscale image."""
    name = "lap_var"

    def compute(self, g, b):
        laplacian = cv2.Laplacian(g, cv2.CV_64F)
        return laplacian.var()
class Entropy(Metric):
    """Entropy (base-2) of the 256-bin grayscale histogram."""
    name = "entropy"

    def compute(self, g, _):
        counts = cv2.calcHist([g], [0], None, [256], [0, 256]).flatten()
        probabilities = counts / counts.sum()
        # Empty bins are dropped so log2 is never evaluated at zero.
        occupied = probabilities[probabilities > 0]
        return -(occupied * np.log2(occupied)).sum()
class Exposure(Metric):
    """Larger of the dark (<30) and bright (>225) pixel shares, in percent."""
    name = "exposure"

    def compute(self, g, b):
        dark_share = np.mean(g < 30)
        bright_share = np.mean(g > 225)
        return max(dark_share, bright_share) * 100
class Colorfulness(Metric):
    """Colourfulness score from red-green and yellow-blue opponent channels.

    NOTE(review): resembles the Hasler-Suesstrunk colourfulness measure but
    uses absolute opponent values — confirm that this is intended.
    """
    name = "color"

    def compute(self, _, bgr):
        blue, green, red = cv2.split(bgr.astype(float))
        red_green = np.abs(red - green)
        yellow_blue = np.abs(0.5 * (red + green) - blue)
        spread = np.sqrt(red_green.var() + yellow_blue.var())
        magnitude = np.sqrt(red_green.mean() ** 2 + yellow_blue.mean() ** 2)
        return spread + 0.3 * magnitude

# Pass/fail limits per metric name, as Threshold(lower, upper); None leaves
# that side unbounded. _run_metrics looks every key up in METRIC_CLS, so the
# two mappings must use the same keys.
DEFAULT_THRESH = {
    "brightness": Threshold(40,210),
    "contrast":   Threshold(25,None),
    "lap_var":    Threshold(120,None),
    "entropy":    Threshold(4.5,None),
    "exposure":   Threshold(None,12),
    "color":      Threshold(10,None)
}
# Metric name -> class implementing it.
METRIC_CLS = dict(
    brightness=Brightness, contrast=Contrast, lap_var=SharpLap,
    entropy=Entropy, exposure=Exposure, color=Colorfulness
)

def _run_metrics(gray,bgr):
    """Evaluate every metric in DEFAULT_THRESH on a thread pool.

    Returns the list of report dicts in completion order, which is not
    necessarily the order of DEFAULT_THRESH.
    """
    metrics = []
    for metric_name, threshold in DEFAULT_THRESH.items():
        metrics.append(METRIC_CLS[metric_name](threshold))

    with ThreadPoolExecutor() as pool:
        pending = [pool.submit(metric, gray, bgr) for metric in metrics]
        reports = []
        for finished in as_completed(pending):
            reports.append(finished.result())
        return reports

def assess_pil(img:Image.Image) -> dict:
    """Run all quality metrics on a PIL image.

    Images that are not in RGB mode (e.g. grayscale or palette images) are
    converted to RGB first, because the colour conversions below expect an
    RGB pixel array.

    Returns:
        {"overall_pass": bool, "metrics": [report, ...]} where the reports
        are sorted by metric name and overall_pass is True only if every
        metric passed.
    """
    if img.mode != "RGB":
        img = img.convert("RGB")
    bgr = cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR)
    gray = cv2.cvtColor(bgr, cv2.COLOR_BGR2GRAY)
    # _run_metrics returns reports in completion order; sort for stable output.
    reports = sorted(_run_metrics(gray, bgr), key=lambda r: r["metric"])
    return {"overall_pass": all(r["pass_"] for r in reports), "metrics": reports}

# --------------------- Enhancement Steps ---------------------
def denoise_pil(img:Image.Image) -> Image.Image:
    """Return a denoised copy of img using OpenCV non-local-means (colour)."""
    pixels = np.array(img)
    # Positional arguments after dst: h, hColor, templateWindowSize, searchWindowSize
    luma_strength, chroma_strength = 10, 10
    template_window, search_window = 7, 21
    cleaned = cv2.fastNlMeansDenoisingColored(
        pixels, None, luma_strength, chroma_strength, template_window, search_window
    )
    return Image.fromarray(cleaned.astype(np.uint8))

def sharpen_pil(img:Image.Image) -> Image.Image:
    """Return a sharpened copy of img using a 3x3 convolution kernel."""
    pixels = np.array(img)
    # All neighbours weigh -1 and the centre 9, so the kernel sums to 1.
    sharpen_kernel = np.full((3, 3), -1)
    sharpen_kernel[1, 1] = 9
    filtered = cv2.filter2D(pixels, -1, sharpen_kernel)
    bounded = np.clip(filtered, 0, 255)
    return Image.fromarray(bounded.astype(np.uint8))

def percentile_stretch(img:Image.Image, low:int, high:int) -> Image.Image:
    """Stretch the luma channel so the given percentiles map to 0 and 255.

    Args:
        img: input image (converted via RGB -> YUV, so an RGB image is expected)
        low: lower percentile of the Y channel that is mapped to 0
        high: upper percentile of the Y channel that is mapped to 255

    Returns:
        A new image with the stretched Y channel. When both percentiles fall
        on the same luma value (e.g. a flat image) there is no range to
        stretch and an unchanged copy of ``img`` is returned.
    """
    arr = np.array(img)
    yuv = cv2.cvtColor(arr, cv2.COLOR_RGB2YUV)
    y = yuv[:,:,0]
    lo, hi = np.percentile(y, low), np.percentile(y, high)
    if hi <= lo:
        # Avoids a division by zero below, which would fill the channel with
        # NaN/inf before the cast to uint8.
        return img.copy()
    y2 = np.clip((y-lo)/(hi-lo)*255,0,255).astype(np.uint8)
    yuv[:,:,0] = y2
    rgb = cv2.cvtColor(yuv, cv2.COLOR_YUV2RGB)
    return Image.fromarray(rgb)

# ------------------- Display Helpers -------------------
def show(img, size=(5,3)):
    """Draw img in a new matplotlib figure of the given size, axes hidden."""
    plt.figure(figsize=size)
    plt.imshow(img)
    plt.axis('off')
    plt.show()

def html_table(rep:dict) -> str:
    """Render an assessment report as an HTML table.

    Args:
        rep: dict with 'overall_pass' (bool) and 'metrics', a list of report
            dicts carrying 'metric', 'value', 'pass_' and 'reason'.

    Returns:
        HTML string with one row per metric and a coloured PASS/FAIL caption.
        Metric names and reasons are HTML-escaped: a failure reason such as
        "12.00 < 25" contains a raw '<' that must not reach the markup.
    """
    from html import escape  # function-scope import keeps this block self-contained

    row_parts = []
    for r in rep["metrics"]:
        colour = 'green' if r['pass_'] else 'red'
        status = 'PASS' if r['pass_'] else 'FAIL'
        row_parts.append(
            f"<tr><td>{escape(str(r['metric']), quote=False)}</td>"
            f"<td>{r['value']:.2f}</td>"
            f"<td style='color:{colour}'>{status}</td>"
            f"<td>{escape(str(r['reason']), quote=False)}</td></tr>"
        )
    rows = "".join(row_parts)
    fg, txt = ("green","PASS ✅") if rep["overall_pass"] else ("red","FAIL ❌")
    return f"""
    <table style='border-collapse:collapse;margin-top:10px'>
      <tr><th>Metric</th><th>Value</th><th>Status</th><th>Reason</th></tr>
      {rows}
      <caption style='color:{fg};font-weight:bold'>{txt}</caption>
    </table>"""

# ---------------------- Upload Widget ----------------------
# Single-file upload control accepting .jpg/.jpeg/.png, and the Output widget
# that handle_upload renders its results into.
upload = widgets.FileUpload(accept='.jpg,.jpeg,.png', multiple=False, description='Upload Image')
out    = widgets.Output()

def extract_meta(v):
    """Return the first per-file record from a FileUpload ``value``.

    Args:
        v: either a dict whose values are per-file records, or a list/tuple
            of per-file records.

    Returns:
        The first record.

    Raises:
        ValueError: if ``v`` is of another type or contains no records.
    """
    if isinstance(v, dict):
        records = list(v.values())
    elif isinstance(v, (list, tuple)):
        records = v
    else:
        raise ValueError(f"Unsupported upload value type: {type(v).__name__}")

    if not records:
        raise ValueError("Upload value contains no files")
    return records[0]

def handle_upload(change):
    """Assess an uploaded image and, if it fails, try to enhance it step by step.

    Pipeline (each stage is displayed together with its metric table):
      1. assess the original; stop if it passes
      2. denoise, 3. sharpen
      4. try increasingly aggressive luma percentile stretches until one
         passes (the last attempt is kept if none does)
      5. save the result as ``enhanced_<name>`` in the working directory and
         show a download link.

    All work happens inside the Output widget so that errors are visible
    there instead of being lost in the widget callback.
    """
    if not change.new: return
    out.clear_output()

    with out:
        meta = extract_meta(upload.value)
        data = meta['content']
        raw_name = meta.get('name') or meta.get('metadata',{}).get('name','img')
        # The name comes from the browser; keep only its last path component
        # so both the temp file and the output file stay where intended.
        fname = os.path.basename(raw_name) or 'img'
        src = os.path.join(tempfile.gettempdir(), fname)
        with open(src,'wb') as f: f.write(data)
        # convert() loads the pixels, so the file can be closed immediately.
        with Image.open(src) as opened:
            orig = opened.convert("RGB")

        # 1) Original
        display(HTML("<h4>Original Image</h4>"))
        show(orig)
        rep1 = assess_pil(orig); display(HTML(html_table(rep1)))

        if not rep1["overall_pass"]:
            # 2) Denoise
            dn = denoise_pil(orig)
            display(HTML("<h4>Denoised Image</h4>")); show(dn)
            rep2 = assess_pil(dn); display(HTML(html_table(rep2)))

            # 3) Sharpen
            sp = sharpen_pil(dn)
            display(HTML("<h4>Sharpened Image</h4>")); show(sp)
            rep3 = assess_pil(sp); display(HTML(html_table(rep3)))

            # 4) Iterative exposure‐stretch: every attempt starts from the
            # sharpened image; `final` always holds the latest attempt, so it
            # is the passing one after a break and the last one otherwise.
            attempts = [(1,99),(2,98),(5,95),(10,90),(15,85),(20,80)]
            final = sp
            for i,(low,high) in enumerate(attempts,1):
                final = percentile_stretch(sp, low, high)
                display(HTML(f"<h4>Stretch Attempt {i}: low={low}% high={high}%</h4>"))
                show(final)
                rep_es = assess_pil(final); display(HTML(html_table(rep_es)))
                if rep_es["overall_pass"]:
                    break

            # 5) Save & Download
            out_name = f"enhanced_{fname}"
            out_path = os.path.join(os.getcwd(), out_name)
            final.save(out_path)
            # NOTE(review): FileLink builds its href from the given path, so a
            # path relative to the working directory is passed instead of the
            # absolute filesystem path — confirm the link resolves in the
            # target Jupyter front end.
            display(FileLink(out_name, result_html_prefix="📥 Download final enhanced image: "))

# Call handle_upload whenever the upload widget's value changes, then show
# the heading, the upload control and the output area.
upload.observe(handle_upload, names='value')
display(HTML("<h3>Image Quality Assessment & Auto‐Enhancement Pipeline</h3>"))
display(upload, out)

FileUpload(value=(), accept='.jpg,.jpeg,.png', description='Upload Image')

Output()