In [1]:
%pip install ipywidgets

Defaulting to user installation because normal site-packages is not writeable
Looking in indexes: https://pypi.org/simple, https://pypi.ngc.nvidia.com
Note: you may need to restart the kernel to use updated packages.


In [None]:
import os
import shutil
import csv
import random
import numpy as np
from PIL import Image
import torch
from torchvision import transforms
from diffusers import ControlNetModel, StableDiffusionInpaintPipeline
from skimage.metrics import peak_signal_noise_ratio as compute_psnr
from skimage.metrics import structural_similarity as compute_ssim
import lpips  # pip install lpips
from glob import glob
from tqdm import tqdm
import cv2  # OpenCV

# ---------------------------
# Set seeds for reproducibility
# ---------------------------
random.seed(42)
np.random.seed(42)
torch.manual_seed(42)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(42)

# ---------------------------
# Utility: Mask refinement using morphological closing
# ---------------------------
def refine_mask_cv(mask, kernel_size=3, iterations=1):
    """
    Refine a binary mask using morphological closing to smooth out noise.
    Adjust kernel_size and iterations for your data.
    """
    kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (kernel_size, kernel_size))
    refined_mask = cv2.morphologyEx(mask, cv2.MORPH_CLOSE, kernel, iterations=iterations)
    return refined_mask

# ---------------------------
# Inpainting functions
# ---------------------------
def run_lama_inpaint(image_path, mask_path, output_path):
    from simple_lama_inpainting import SimpleLama  # import here in case it's not global
    simple_lama = SimpleLama()
    image = Image.open(image_path)
    mask = Image.open(mask_path).convert('L')
    result = simple_lama(image, mask)
    result.save(output_path)
    
def run_opencv_inpaint(image_path, mask_path, output_path, inpaintRadius=3, method=cv2.INPAINT_TELEA, refine=True):
    """
    OpenCV inpainting using the Telea algorithm.
    Reads the image and mask using OpenCV, thresholds the mask to ensure it's binary,
    optionally refines the mask, performs inpainting, and saves the result.
    """
    image = cv2.imread(image_path, cv2.IMREAD_COLOR)
    mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
    # Ensure mask is binary
    _, mask = cv2.threshold(mask, 127, 255, cv2.THRESH_BINARY)
    if refine:
        mask = refine_mask_cv(mask, kernel_size=3, iterations=1)
    inpainted = cv2.inpaint(image, mask, inpaintRadius, method)
    cv2.imwrite(output_path, inpainted)

# ---------------------------
# ControlNet helper functions
# ---------------------------
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32

def clean_huggingface_cache(model_path):
    """Remove unnecessary Hugging Face cache directories and .lock files."""
    for root, dirs, files in os.walk(model_path, topdown=False):
        for name in files:
            if name.endswith(".lock"):
                os.remove(os.path.join(root, name))
        for name in dirs:
            if name.startswith("models--") or name == "temp":
                shutil.rmtree(os.path.join(root, name), ignore_errors=True)

def get_latest_snapshot(model_path):
    """Find and move the correct snapshot folder for a downloaded model."""
    if os.path.exists(model_path):
        for subdir in os.listdir(model_path):
            snapshot_path = os.path.join(model_path, subdir, "snapshots")
            if os.path.exists(snapshot_path):
                snapshots = sorted(os.listdir(snapshot_path), reverse=True)
                if snapshots:
                    latest_snapshot = os.path.join(snapshot_path, snapshots[0])
                    for file_name in os.listdir(latest_snapshot):
                        src = os.path.join(latest_snapshot, file_name)
                        dest = os.path.join(model_path, file_name)
                        if not os.path.exists(dest):
                            shutil.move(src, dest)
                    shutil.rmtree(os.path.dirname(latest_snapshot), ignore_errors=True)
                    return model_path
    return model_path

def check_and_download_model(model_name, model_path, is_controlnet=False):
    """Check if the model exists; if not, download and move it to the correct directory."""
    if is_controlnet:
        model_path = os.path.join(model_path, "controlnet")
    else:
        model_path = os.path.join(model_path, "stable-diffusion")

    if os.path.exists(model_path) and os.listdir(model_path):
        return

    # Download silently
    temp_dir = os.path.join("models", "temp")
    if is_controlnet:
        ControlNetModel.from_pretrained(model_name, cache_dir=temp_dir)
    else:
        StableDiffusionInpaintPipeline.from_pretrained(model_name, cache_dir=temp_dir)

    correct_model_path = get_latest_snapshot(temp_dir)
    os.makedirs(model_path, exist_ok=True)
    for file_name in os.listdir(correct_model_path):
        src = os.path.join(correct_model_path, file_name)
        dest = os.path.join(model_path, file_name)
        if not os.path.exists(dest):
            shutil.move(src, dest)
    shutil.rmtree(temp_dir, ignore_errors=True)

def load_controlnet():
    device = "cuda" if torch.cuda.is_available() else "cpu"
    models_dir = "models"
    os.makedirs(models_dir, exist_ok=True)
    controlnet_dir = os.path.join(models_dir, "controlnet")
    stable_diffusion_dir = os.path.join(models_dir, "stable-diffusion")
    os.makedirs(controlnet_dir, exist_ok=True)
    os.makedirs(stable_diffusion_dir, exist_ok=True)

    check_and_download_model("stabilityai/stable-diffusion-2-inpainting", models_dir, is_controlnet=False)
    check_and_download_model("lllyasviel/control_v11p_sd15_inpaint", models_dir, is_controlnet=True)

    clean_huggingface_cache(models_dir)

    pipe = StableDiffusionInpaintPipeline.from_pretrained(
        stable_diffusion_dir, torch_dtype=torch_dtype, local_files_only=True
    ).to(device, dtype=torch_dtype)
    return pipe

def make_divisible_by_8(size):
    """Ensure both width and height are divisible by 8."""
    width, height = size
    width = (width // 8) * 8
    height = (height // 8) * 8
    return width, height

def run_controlnet_inpaint(image_path, mask_path, pipe, reference_images, prompt, output_path, seed=42):
    # Open image and mask
    image = Image.open(image_path).convert("RGB")
    mask = Image.open(mask_path).convert("L")
    original_size = image.size
    adjusted_size = make_divisible_by_8(original_size)

    conditioning = None
    if reference_images:
        conditioning = [
            img.resize(adjusted_size, Image.Resampling.LANCZOS)
            for img in reference_images
        ]

    # Create a generator with a fixed seed for reproducibility
    device = "cuda" if torch.cuda.is_available() else "cpu"
    generator = torch.Generator(device=device).manual_seed(seed)

    result = pipe(
        prompt=prompt,
        image=image.resize(adjusted_size, Image.Resampling.LANCZOS),
        mask_image=mask.resize(adjusted_size, Image.Resampling.LANCZOS),
        conditioning_image=conditioning,
        height=adjusted_size[1],
        width=adjusted_size[0],
        generator=generator
    ).images[0]
    result = result.resize(original_size, Image.Resampling.LANCZOS)
    result.save(output_path)

# ---------------------------
# LPIPS model loading
# ---------------------------
def load_lpips_model(model_dir="models/lpips"):
    os.makedirs(model_dir, exist_ok=True)
    model_path = os.path.join(model_dir, "lpips_alex.pth")
    model = lpips.LPIPS(net='alex')
    if os.path.exists(model_path):
        model.load_state_dict(torch.load(model_path, map_location='cpu'))
    else:
        torch.save(model.state_dict(), model_path)
    model.eval()
    if torch.cuda.is_available():
        model.cuda()
    return model

lpips_model = load_lpips_model()

# ---------------------------
# Evaluation functions
# ---------------------------
def prepare_for_lpips(pil_image):
    tensor = transforms.ToTensor()(pil_image).unsqueeze(0)
    tensor = tensor * 2 - 1
    if torch.cuda.is_available():
        tensor = tensor.cuda()
    return tensor

def evaluate_metrics(gt_img, inpaint_img):
    gt_np = np.array(gt_img).astype(np.float32) / 255.0
    inpaint_np = np.array(inpaint_img).astype(np.float32) / 255.0

    if gt_np.shape != inpaint_np.shape:
        inpaint_img = inpaint_img.resize(gt_img.size, Image.Resampling.LANCZOS)
        inpaint_np = np.array(inpaint_img).astype(np.float32) / 255.0

    psnr = compute_psnr(gt_np, inpaint_np, data_range=1.0)

    min_size = min(gt_np.shape[0], gt_np.shape[1])
    win_size = 7 if min_size >= 7 else (min_size if min_size % 2 == 1 else min_size - 1)
    ssim = compute_ssim(gt_np, inpaint_np, win_size=win_size, channel_axis=2, data_range=1.0)

    gt_tensor = prepare_for_lpips(gt_img)
    inpaint_tensor = prepare_for_lpips(inpaint_img)
    with torch.no_grad():
        lpips_distance = lpips_model(gt_tensor, inpaint_tensor).item()

    return psnr, ssim, lpips_distance

# ---------------------------
# Main combined evaluation
# ---------------------------
if __name__ == "__main__":
    image_dir = "DUT-OMRON-image"   # Ground truth images (JPEG)
    mask_dir = "DUT-OMRON-mask"     # Masks (PNG)
    results_dir = "results"
    lama_dir = os.path.join(results_dir, "lama")
    controlnet_dir = os.path.join(results_dir, "controlnet")
    opencv_dir = os.path.join(results_dir, "opencv")  # Directory for OpenCV results
    os.makedirs(lama_dir, exist_ok=True)
    os.makedirs(controlnet_dir, exist_ok=True)
    os.makedirs(opencv_dir, exist_ok=True)

    pipe = load_controlnet()
   
    prompt = (
        "Replace the masked region with a natural extension of the surrounding background, ensuring the textures, colors, and lighting blend seamlessly. Do not recreate any specific object shapes from the mask."
    )

    evaluation_results = []
    image_paths = sorted(glob(os.path.join(image_dir, "*.*")))
    pbar = tqdm(image_paths, total=len(image_paths), desc="Processing images", leave=True)
    
    for image_path in pbar:
        filename = os.path.basename(image_path)
        basename = os.path.splitext(filename)[0]
        mask_path = os.path.join(mask_dir, basename + ".png")  # Adjust extension if needed
        if not os.path.exists(mask_path):
            # silently skip missing masks
            continue

        out_lama = os.path.join(lama_dir, filename)
        out_controlnet = os.path.join(controlnet_dir, filename)
        out_opencv = os.path.join(opencv_dir, filename)

        try:
            run_lama_inpaint(image_path, mask_path, out_lama)
        except Exception as e:
            print(f"Error in Lama for {filename}: {e}")
            continue

        try:
            reference_images = None  # or [] if needed
            # You can optionally vary the seed per image (e.g., seed = 42 + idx) for diversity
            run_controlnet_inpaint(image_path, mask_path, pipe, reference_images, prompt, out_controlnet, seed=42)
        except Exception as e:
            print(f"Error in ControlNet for {filename}: {e}")
            continue

        try:
            run_opencv_inpaint(image_path, mask_path, out_opencv, refine=True)
        except Exception as e:
            print(f"Error in OpenCV for {filename}: {e}")
            continue

        gt_image = Image.open(image_path).convert("RGB")
        lama_result = Image.open(out_lama).convert("RGB")
        controlnet_result = Image.open(out_controlnet).convert("RGB")
        opencv_result = Image.open(out_opencv).convert("RGB")
        
        lama_psnr, _, _ = evaluate_metrics(gt_image, lama_result)
        controlnet_psnr, _, _ = evaluate_metrics(gt_image, controlnet_result)
        opencv_psnr, _, _ = evaluate_metrics(gt_image, opencv_result)
        
        # Update the progress bar with the latest metrics
        pbar.set_postfix({
            "Lama_PSNR": f"{lama_psnr:.2f}",
            "ControlNet_PSNR": f"{controlnet_psnr:.2f}",
            "OpenCV_PSNR": f"{opencv_psnr:.2f}"
        })
        
        evaluation_results.append({
            'filename': filename,
            'lama_PSNR': lama_psnr,
            'controlnet_PSNR': controlnet_psnr,
            'opencv_PSNR': opencv_psnr
        })
        
    # Write results to CSV
    csv_file_path = "evaluation_results.csv"
    csv_fields = ['filename', 'lama_PSNR', 'controlnet_PSNR', 'opencv_PSNR']
    with open(csv_file_path, mode='w', newline='') as csv_file:
        writer = csv.DictWriter(csv_file, fieldnames=csv_fields)
        writer.writeheader()
        for row in evaluation_results:
            writer.writerow(row)
    
    print(f"Processing images: {len(image_paths)}/{len(image_paths)} completed. Results saved to {csv_file_path}")


  "cipher": algorithms.TripleDES,
  "class": algorithms.Blowfish,
  "class": algorithms.TripleDES,


Setting up [LPIPS] perceptual loss: trunk [alex], v[0.1], spatial [off]
Loading model from: C:\Users\ng_mi\AppData\Roaming\Python\Python312\site-packages\lpips\weights\v0.1\alex.pth


Loading pipeline components...:   0%|          | 0/6 [00:00<?, ?it/s]

Processing images:   0%|          | 0/5168 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

Processing images:   0%|          | 1/5168 [00:05<8:13:29,  5.73s/it, Lama_PSNR=13.53, ControlNet_PSNR=13.74, OpenCV_PSNR=14.42]

  0%|          | 0/50 [00:00<?, ?it/s]

Processing images:   0%|          | 2/5168 [00:10<7:17:24,  5.08s/it, Lama_PSNR=17.19, ControlNet_PSNR=16.68, OpenCV_PSNR=15.83]

  0%|          | 0/50 [00:00<?, ?it/s]

Processing images:   0%|          | 3/5168 [00:15<7:13:47,  5.04s/it, Lama_PSNR=22.79, ControlNet_PSNR=22.39, OpenCV_PSNR=21.31]

  0%|          | 0/50 [00:00<?, ?it/s]

Processing images:   0%|          | 4/5168 [00:19<6:52:31,  4.79s/it, Lama_PSNR=15.66, ControlNet_PSNR=10.60, OpenCV_PSNR=11.02]

  0%|          | 0/50 [00:00<?, ?it/s]

Processing images:   0%|          | 5/5168 [00:24<6:36:01,  4.60s/it, Lama_PSNR=16.55, ControlNet_PSNR=16.44, OpenCV_PSNR=21.23]

  0%|          | 0/50 [00:00<?, ?it/s]

Processing images:   0%|          | 6/5168 [00:28<6:37:10,  4.62s/it, Lama_PSNR=18.57, ControlNet_PSNR=13.69, OpenCV_PSNR=13.49]

  0%|          | 0/50 [00:00<?, ?it/s]

Processing images:   0%|          | 7/5168 [00:33<6:37:30,  4.62s/it, Lama_PSNR=20.18, ControlNet_PSNR=14.30, OpenCV_PSNR=14.74]

  0%|          | 0/50 [00:00<?, ?it/s]

Processing images:   0%|          | 8/5168 [00:38<6:58:19,  4.86s/it, Lama_PSNR=17.69, ControlNet_PSNR=14.67, OpenCV_PSNR=14.83]

  0%|          | 0/50 [00:00<?, ?it/s]

Processing images:   0%|          | 9/5168 [00:43<6:59:03,  4.87s/it, Lama_PSNR=16.56, ControlNet_PSNR=14.90, OpenCV_PSNR=16.41]

  0%|          | 0/50 [00:00<?, ?it/s]

Processing images:   0%|          | 10/5168 [00:47<6:41:24,  4.67s/it, Lama_PSNR=16.27, ControlNet_PSNR=15.46, OpenCV_PSNR=14.64]

  0%|          | 0/50 [00:00<?, ?it/s]

Processing images:   0%|          | 11/5168 [00:52<6:54:34,  4.82s/it, Lama_PSNR=23.26, ControlNet_PSNR=16.25, OpenCV_PSNR=19.24]

  0%|          | 0/50 [00:00<?, ?it/s]

Processing images:   0%|          | 12/5168 [00:57<6:44:04,  4.70s/it, Lama_PSNR=22.04, ControlNet_PSNR=18.64, OpenCV_PSNR=19.14]

  0%|          | 0/50 [00:00<?, ?it/s]

Processing images:   0%|          | 13/5168 [01:01<6:40:59,  4.67s/it, Lama_PSNR=18.33, ControlNet_PSNR=15.12, OpenCV_PSNR=18.09]

  0%|          | 0/50 [00:00<?, ?it/s]

Processing images:   0%|          | 14/5168 [01:06<6:39:12,  4.65s/it, Lama_PSNR=15.84, ControlNet_PSNR=14.00, OpenCV_PSNR=14.63]

  0%|          | 0/50 [00:00<?, ?it/s]

Processing images:   0%|          | 15/5168 [01:11<6:47:23,  4.74s/it, Lama_PSNR=19.15, ControlNet_PSNR=16.89, OpenCV_PSNR=20.16]

  0%|          | 0/50 [00:00<?, ?it/s]

Processing images:   0%|          | 16/5168 [01:16<6:56:51,  4.85s/it, Lama_PSNR=24.67, ControlNet_PSNR=25.99, OpenCV_PSNR=26.16]

  0%|          | 0/50 [00:00<?, ?it/s]

Processing images:   0%|          | 17/5168 [01:21<6:59:00,  4.88s/it, Lama_PSNR=13.96, ControlNet_PSNR=12.68, OpenCV_PSNR=16.05]

  0%|          | 0/50 [00:00<?, ?it/s]

Processing images:   0%|          | 18/5168 [01:26<6:51:16,  4.79s/it, Lama_PSNR=16.64, ControlNet_PSNR=16.33, OpenCV_PSNR=15.74]

  0%|          | 0/50 [00:00<?, ?it/s]

Processing images:   0%|          | 19/5168 [01:31<6:54:38,  4.83s/it, Lama_PSNR=17.18, ControlNet_PSNR=13.63, OpenCV_PSNR=15.75]

  0%|          | 0/50 [00:00<?, ?it/s]

Processing images:   0%|          | 20/5168 [01:35<6:51:40,  4.80s/it, Lama_PSNR=16.55, ControlNet_PSNR=17.87, OpenCV_PSNR=20.41]

  0%|          | 0/50 [00:00<?, ?it/s]

Processing images:   0%|          | 21/5168 [01:40<6:38:10,  4.64s/it, Lama_PSNR=21.04, ControlNet_PSNR=20.30, OpenCV_PSNR=20.51]

  0%|          | 0/50 [00:00<?, ?it/s]

Processing images:   0%|          | 22/5168 [01:45<6:47:24,  4.75s/it, Lama_PSNR=17.35, ControlNet_PSNR=13.37, OpenCV_PSNR=15.30]

  0%|          | 0/50 [00:00<?, ?it/s]

Processing images:   0%|          | 23/5168 [01:50<6:52:05,  4.81s/it, Lama_PSNR=47.21, ControlNet_PSNR=26.95, OpenCV_PSNR=38.52]

  0%|          | 0/50 [00:00<?, ?it/s]

Processing images:   0%|          | 24/5168 [01:54<6:41:22,  4.68s/it, Lama_PSNR=22.46, ControlNet_PSNR=22.90, OpenCV_PSNR=27.05]

  0%|          | 0/50 [00:00<?, ?it/s]

Processing images:   0%|          | 25/5168 [01:58<6:25:58,  4.50s/it, Lama_PSNR=13.65, ControlNet_PSNR=14.04, OpenCV_PSNR=16.95]

  0%|          | 0/50 [00:00<?, ?it/s]

Processing images:   1%|          | 26/5168 [02:02<6:20:12,  4.44s/it, Lama_PSNR=16.82, ControlNet_PSNR=22.83, OpenCV_PSNR=31.46]

  0%|          | 0/50 [00:00<?, ?it/s]

Processing images:   1%|          | 27/5168 [02:07<6:28:25,  4.53s/it, Lama_PSNR=15.97, ControlNet_PSNR=15.73, OpenCV_PSNR=17.70]

  0%|          | 0/50 [00:00<?, ?it/s]

Processing images:   1%|          | 28/5168 [02:11<6:24:35,  4.49s/it, Lama_PSNR=13.39, ControlNet_PSNR=14.96, OpenCV_PSNR=19.29]

  0%|          | 0/50 [00:00<?, ?it/s]

Processing images:   1%|          | 29/5168 [02:16<6:25:33,  4.50s/it, Lama_PSNR=20.93, ControlNet_PSNR=15.08, OpenCV_PSNR=16.80]

  0%|          | 0/50 [00:00<?, ?it/s]

Processing images:   1%|          | 30/5168 [02:21<6:32:32,  4.58s/it, Lama_PSNR=25.21, ControlNet_PSNR=25.66, OpenCV_PSNR=32.12]

  0%|          | 0/50 [00:00<?, ?it/s]