<a href="https://colab.research.google.com/github/fjadidi2001/Image_Inpaint/blob/main/FastInpaintingNet-Jan25.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Step 1: Import Libraries

In [1]:
import torch
import matplotlib.pyplot as plt
from torchvision.utils import make_grid
import numpy as np
from skimage.metrics import structural_similarity as ssim
from skimage.metrics import peak_signal_noise_ratio as psnr
import time

- torch: PyTorch library for tensor operations and neural networks.

- matplotlib.pyplot: For plotting and visualizing images.

- torchvision.utils.make_grid: Utility to create a grid of images.

- numpy: For numerical operations.

- skimage.metrics: For calculating image quality metrics like SSIM and PSNR.

- time: For measuring inference time.

# Step 2: Denormalize Function

In [2]:
def denormalize(tensor, mean=0.5, std=0.5):
    """Denormalize the tensor from [-1,1] to [0,1] range.

    Inverts a ``Normalize(mean, std)`` transform by computing
    ``tensor * std + mean`` and clamping the result to [0, 1]. With the
    default statistics (0.5, 0.5) this maps [-1, 1] onto [0, 1].

    Args:
        tensor: Normalized image tensor, either a single image (C, H, W) or a
            batch (B, C, H, W). With scalar ``mean``/``std`` any channel count
            is supported.
        mean: Scalar, or a per-channel sequence of means used at normalization.
        std: Scalar, or a per-channel sequence of standard deviations used at
            normalization.

    Returns:
        A new tensor with values in [0, 1], on the same device as ``tensor``.
        It keeps ``tensor``'s shape when the statistics are scalar or match its
        channel count, and its dtype when ``tensor`` is floating point. The
        input tensor is not modified.
    """
    # Build the statistics on the input's device. Match its floating dtype so
    # half/double inputs are not promoted. For non-float input, let PyTorch
    # choose its default float type.
    dtype = tensor.dtype if tensor.is_floating_point() else None
    mean = torch.as_tensor(mean, dtype=dtype, device=tensor.device)
    std = torch.as_tensor(std, dtype=dtype, device=tensor.device)
    # Per-channel statistics broadcast along the channel axis (dim -3).
    # Scalar (0-dim) statistics broadcast to any shape.
    if mean.ndim == 1:
        mean = mean.view(-1, 1, 1)
    if std.ndim == 1:
        std = std.view(-1, 1, 1)
    # The arithmetic below is out-of-place, so the caller's tensor is untouched
    # without an explicit clone().
    return torch.clamp(tensor * std + mean, 0, 1)

Purpose: Converts tensors that were normalized with a per-channel mean of 0.5 and standard deviation of 0.5 (so their values lie in [-1, 1]) back to the original [0, 1] display range.

Steps:

- Clone the tensor to avoid modifying the original.

- Define mean and standard deviation tensors.

- Apply the inverse of the normalization, $x_{\text{orig}} = x \cdot \sigma + \mu$. With $\mu = \sigma = 0.5$ this is $x_{\text{orig}} = (x + 1)/2$.

- Clamp the values to ensure they stay within [0, 1].

# Step 3: Evaluate and Visualize Function


In [3]:
def _to_hwc(tensor):
    """Return a (C, H, W) tensor as an (H, W, C) NumPy array on the CPU."""
    return tensor.detach().cpu().permute(1, 2, 0).numpy()


def _print_evaluation_report(psnr_scores, ssim_scores, inference_times):
    """Print summary statistics and a coarse quality/speed interpretation.

    Args:
        psnr_scores: Per-sample PSNR values in dB.
        ssim_scores: Per-sample SSIM values.
        inference_times: Per-sample forward-pass durations in seconds.
    """
    avg_psnr = np.mean(psnr_scores)
    avg_ssim = np.mean(ssim_scores)
    avg_time = np.mean(inference_times) * 1000  # milliseconds

    print("\n=== Inpainting Model Evaluation ===")
    print(f"\nImage Quality Metrics (averaged over {len(psnr_scores)} samples):")
    print(f"PSNR: {avg_psnr:.2f} dB (±{np.std(psnr_scores):.2f})")
    print(f"SSIM: {avg_ssim:.3f} (±{np.std(ssim_scores):.3f})")

    print("\nPerformance Metrics:")
    print(f"Average inference time: {avg_time:.1f}ms (±{np.std(inference_times)*1000:.1f}ms)")

    print("\nModel Performance Interpretation:")

    # PSNR interpretation
    print("\nPSNR Analysis:")
    if avg_psnr > 30:
        print("✓ Excellent quality (>30 dB)")
    elif avg_psnr > 25:
        print("✓ Good quality (25-30 dB)")
    else:
        print("⚠ Fair to poor quality (<25 dB)")

    # SSIM interpretation
    print("\nSSIM Analysis:")
    if avg_ssim > 0.90:
        print("✓ Excellent structural similarity (>0.90)")
    elif avg_ssim > 0.80:
        print("✓ Good structural similarity (0.80-0.90)")
    else:
        print("⚠ Fair to poor structural similarity (<0.80)")

    # Speed interpretation
    print("\nSpeed Analysis:")
    if avg_time < 50:
        print("✓ Very fast (<50ms)")
    elif avg_time < 100:
        print("✓ Fast (50-100ms)")
    else:
        print("⚠ Moderate to slow (>100ms)")


def evaluate_and_visualize(model, test_loader, device, num_samples=8, save_path='inpainting_results.png'):
    """
    Evaluate the model and create detailed visualizations of the results.

    Only the first batch of ``test_loader`` is used. For each image, a mask is
    produced by ``create_fast_mask`` (defined elsewhere in the notebook) and
    applied by multiplication, so zeros in the mask mark the hole. The model is
    called as ``model(masked_batch, mask_batch)`` and its output is
    denormalized from [-1, 1]. A 4-row figure (original / mask / masked input /
    inpainted) is saved to ``save_path`` and shown. PSNR/SSIM and timing are
    then printed.

    Args:
        model: Inpainting network; put into eval mode by this function.
        test_loader: Iterable whose batches have images at index 0, normalized
            to [-1, 1] (they are passed through ``denormalize``).
        device: Device the images are moved to.
        num_samples: Maximum number of images to evaluate. It is capped at the
            size of the first batch.
        save_path: Output path for the figure.

    Returns:
        dict with 'psnr' (mean dB), 'ssim' (mean) and 'inference_time'
        (mean seconds per image).

    Raises:
        ValueError: If the first batch contains no images.
    """
    model.eval()

    # Get a batch of test images; it may hold fewer than num_samples images.
    batch = next(iter(test_loader))
    images = batch[0][:num_samples].to(device)
    num_samples = images.shape[0]
    if num_samples == 0:
        raise ValueError("test_loader produced an empty batch; nothing to evaluate")

    psnr_scores = []
    ssim_scores = []
    inference_times = []
    row_titles = ('Original', 'Mask', 'Masked Input', 'Inpainted Result')

    # squeeze=False keeps `axes` 2-D so axes[row, col] works for one column too.
    fig, axes = plt.subplots(4, num_samples, figsize=(20, 16), squeeze=False)
    plt.suptitle('Inpainting Results', fontsize=16)

    with torch.no_grad():
        for i in range(num_samples):
            normalized = images[i]
            original = denormalize(normalized)  # [0, 1] for display and metrics

            mask = create_fast_mask(original)
            # The model's output is interpreted as [-1, 1] (it is denormalized
            # below), so feed it the masked image in that same normalized space
            # rather than the [0, 1] display image.
            model_input = normalized * mask

            # CUDA kernels run asynchronously; synchronize on both sides so the
            # measured interval covers the actual forward pass.
            if images.is_cuda:
                torch.cuda.synchronize()
            start_time = time.perf_counter()
            output = model(model_input.unsqueeze(0), mask.unsqueeze(0))
            if images.is_cuda:
                torch.cuda.synchronize()
            inference_times.append(time.perf_counter() - start_time)

            inpainted = denormalize(output[0])

            original_np = _to_hwc(original)
            inpainted_np = _to_hwc(inpainted)
            mask_np = _to_hwc(mask.float())
            if mask_np.shape[-1] == 1:
                mask_np = mask_np[..., 0]  # imshow rejects (H, W, 1) arrays

            panels = (original_np, mask_np, _to_hwc(original * mask), inpainted_np)
            for row, panel in enumerate(panels):
                ax = axes[row, i]
                ax.imshow(panel, cmap='gray' if row == 1 else None)
                ax.axis('off')
                if i == 0:
                    ax.set_title(row_titles[row])

            # Both images are in [0, 1]; state the range explicitly for both metrics.
            psnr_score = psnr(original_np, inpainted_np, data_range=1.0)
            ssim_score = ssim(original_np, inpainted_np, channel_axis=2, data_range=1.0)
            psnr_scores.append(psnr_score)
            ssim_scores.append(ssim_score)

            # Add metrics as text under the inpainted image
            axes[3, i].text(0.5, -0.2, f'PSNR: {psnr_score:.1f}\nSSIM: {ssim_score:.3f}',
                            ha='center', transform=axes[3, i].transAxes)

    plt.tight_layout()
    plt.savefig(save_path, bbox_inches='tight', dpi=300)
    plt.show()
    plt.close(fig)  # release the large figure once it has been saved/shown

    _print_evaluation_report(psnr_scores, ssim_scores, inference_times)

    return {
        'psnr': np.mean(psnr_scores),
        'ssim': np.mean(ssim_scores),
        'inference_time': np.mean(inference_times)
    }