
# Step 1: Import Libraries

In [5]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.nn import functional as F
from torchvision import transforms, datasets
from torch.utils.data import DataLoader, random_split
import numpy as np

Purpose: Import all necessary libraries for building and training the model, handling data, and performing image operations.

Key Libraries:

- torch: Core PyTorch library for tensor operations and neural networks.

- torch.nn: Neural network modules (e.g., layers, loss functions).

- torch.optim: Optimization algorithms (e.g., Adam).

- torchvision: Datasets, transforms, and utilities for image processing.

- numpy: Numerical computations.

# Step 2: Define the Inpainting Model


In [6]:
class FastInpaintingNet(nn.Module):
    def __init__(self):
        super(FastInpaintingNet, self).__init__()

        # Encoder: Extract features from the input
        self.encoder = nn.Sequential(
            nn.Conv2d(6, 64, kernel_size=3, padding=1),  # Input channels: 6 (3 RGB + 3 mask, since the mask has the image's shape)
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1),
            nn.ReLU(inplace=True)
        )

        # Middle Blocks: Process features with dilated convolutions
        self.middle = nn.Sequential(
            nn.Conv2d(256, 256, kernel_size=3, padding=2, dilation=2),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 256, kernel_size=3, padding=2, dilation=2),
            nn.ReLU(inplace=True)
        )

        # Decoder: Reconstruct the inpainted image
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 3, kernel_size=3, padding=1),  # Output channels: 3 (RGB image)
            nn.Tanh()  # Bound output to [-1, 1], matching the normalized input range
        )

    def forward(self, x, mask):
        # Concatenate the input image and mask along the channel dimension
        x = torch.cat([x, mask], dim=1)
        # Pass through the encoder, middle blocks, and decoder
        x = self.encoder(x)
        x = self.middle(x)
        x = self.decoder(x)
        return x

Purpose: Define a neural network for image inpainting.

Key Components:

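- Encoder: Two stride-2 convolutions reduce height and width by a factor of 4 while the channel count grows from 6 to 256.

- Middle Blocks: Two dilated convolutions (dilation 2) enlarge the receptive field without further downsampling.

- Decoder: Two transposed convolutions upsample the features back to the input resolution, and a final convolution followed by Tanh produces a 3-channel image in [-1, 1].

- Forward Pass: Concatenates the input image and mask along the channel dimension, processes them through the encoder, middle blocks, and decoder, and outputs the inpainted image.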
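
The shape changes described above can be checked by hand with the standard output-size formulas for convolutions and transposed convolutions. The following is a minimal sketch in plain Python, assuming a 128×128 input (the default `img_size` used later); the helper names `conv_out`, `deconv_out`, and `trace_shapes` are hypothetical and not part of the notebook.

```python
def conv_out(size, kernel, stride=1, padding=0, dilation=1):
    """Output length of a Conv2d layer along one spatial axis."""
    return (size + 2 * padding - dilation * (kernel - 1) - 1) // stride + 1


def deconv_out(size, kernel, stride=1, padding=0, dilation=1, output_padding=0):
    """Output length of a ConvTranspose2d layer along one spatial axis."""
    return (size - 1) * stride - 2 * padding + dilation * (kernel - 1) + output_padding + 1


def trace_shapes(size=128):
    """Return (stage, channels, side length) after each shape-changing layer."""
    trace = [("input (image + mask)", 6, size)]
    size = conv_out(size, kernel=3, padding=1)
    trace.append(("encoder conv 1", 64, size))
    size = conv_out(size, kernel=4, stride=2, padding=1)
    trace.append(("encoder conv 2", 128, size))
    size = conv_out(size, kernel=4, stride=2, padding=1)
    trace.append(("encoder conv 3", 256, size))
    for i in (1, 2):
        # padding=2 with dilation=2 keeps the spatial size unchanged
        size = conv_out(size, kernel=3, padding=2, dilation=2)
        trace.append((f"middle conv {i}", 256, size))
    size = deconv_out(size, kernel=4, stride=2, padding=1)
    trace.append(("decoder deconv 1", 128, size))
    size = deconv_out(size, kernel=4, stride=2, padding=1)
    trace.append(("decoder deconv 2", 64, size))
    size = conv_out(size, kernel=3, padding=1)
    trace.append(("decoder conv", 3, size))
    return trace


trace = trace_shapes(128)
for stage, channels, side in trace:
    print(f"{stage:22s} {channels:4d} x {side} x {side}")
```

Under these assumptions the side length goes 128 → 64 → 32 in the encoder, stays at 32 through the dilated blocks, and returns to 128 in the decoder, so the output has the same spatial size as the input.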

# Step 3: Set Up the Data Pipeline


In [7]:
def setup_data(root_dir='./data', img_size=128, batch_size=32):
    """
    Prepare the CIFAR-10 dataset for training, validation, and testing.
    """
    # Define image transformations
    transform = transforms.Compose([
        transforms.Resize((img_size, img_size)),  # Resize images to the specified size
        transforms.ToTensor(),  # Convert images to PyTorch tensors
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))  # Normalize to [-1, 1]
    ])

    # Load the CIFAR-10 dataset
    dataset = datasets.CIFAR10(root=root_dir, train=True, download=True, transform=transform)

    # Split the dataset into train, validation, and test sets
    total_size = len(dataset)
    train_size = int(0.7 * total_size)  # 70% for training
    val_size = int(0.15 * total_size)   # 15% for validation
    test_size = total_size - train_size - val_size  # Remaining 15% for testing

    train_dataset, val_dataset, test_dataset = random_split(dataset, [train_size, val_size, test_size])

    # Create data loaders
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=2, pin_memory=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=2, pin_memory=True)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=2, pin_memory=True)

    return train_loader, val_loader, test_loader

Purpose: Prepare the dataset and create data loaders for training, validation, and testing.

Steps:

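- Define image transformations: resize the 32×32 CIFAR-10 images to img_size × img_size, convert them to tensors, and normalize each channel to [-1, 1].

- Load the CIFAR-10 training set (train=True, 50,000 images).

- Split that set 70/15/15 into train, validation, and test subsets (35,000 / 7,500 / 7,500 images); the official CIFAR-10 test set is not used.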

- Create DataLoader objects for each split.
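
As written, `random_split` draws a different partition on every run. A minimal sketch of the same 70/15/15 arithmetic with a seeded generator is shown below; the seed value 42 and the `TensorDataset` stand-in (indices only, no image data) are assumptions for illustration, not part of the notebook.

```python
import torch
from torch.utils.data import TensorDataset, random_split

# Stand-in for the 50,000-image CIFAR-10 training set (indices only, no pixels).
dataset = TensorDataset(torch.arange(50_000))

total_size = len(dataset)
train_size = int(0.7 * total_size)
val_size = int(0.15 * total_size)
test_size = total_size - train_size - val_size  # remainder absorbs any rounding

# Hypothetical seed: passing a seeded generator makes the partition repeatable.
generator = torch.Generator().manual_seed(42)
train_set, val_set, test_set = random_split(
    dataset, [train_size, val_size, test_size], generator=generator
)

print(len(train_set), len(val_set), len(test_set))
```

Re-running the split with a generator seeded the same way should reproduce the same subsets, which keeps validation and test images from leaking into training across notebook restarts.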

# Step 4: Create a Mask for Inpainting


In [8]:
def create_fast_mask(image):
    """
    Create a rectangular mask for inpainting.
    """
    _, h, w = image.shape  # Get image height and width
    mask = torch.ones_like(image)  # Initialize mask with ones

    # Define a centered rectangular region to mask out
    mask_h = h // 3  # Mask height
    mask_w = w // 3  # Mask width
    top = (h - mask_h) // 2  # Top position
    left = (w - mask_w) // 2  # Left position

    # Set the rectangular region to 0 (masked area)
    mask[:, top:top+mask_h, left:left+mask_w] = 0
    return mask

Purpose: Generate a binary mask for inpainting.

Steps:

- Create a mask of ones with the same shape as the input image.

- Define a centered rectangular region and set its values to 0 (masked area).
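
To make the geometry concrete, the sketch below rebuilds the same centered mask for a random 3×128×128 tensor and measures the size of the hole. The function name `centered_mask` and the random test image are assumptions for illustration; the construction itself follows the steps above.

```python
import torch


def centered_mask(image):
    """Same construction as described above: 1 = keep pixel, 0 = region to inpaint."""
    _, h, w = image.shape
    mask = torch.ones_like(image)
    mask_h, mask_w = h // 3, w // 3
    top, left = (h - mask_h) // 2, (w - mask_w) // 2
    mask[:, top:top + mask_h, left:left + mask_w] = 0
    return mask


image = torch.rand(3, 128, 128)  # random stand-in for a real image
mask = centered_mask(image)
masked = image * mask            # pixels inside the hole become 0

hole_pixels = int((mask[0] == 0).sum())
hole_fraction = hole_pixels / (128 * 128)
print(mask.shape, hole_pixels, round(hole_fraction, 4))
```

For a 128×128 image the hole is 42×42 pixels (rows and columns 43 through 84), which is roughly 11% of the image area, and the mask is replicated across all three channels.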

# Step 5: Denormalize Function

In [2]:
def denormalize(tensor):
    """Denormalize the tensor from [-1,1] to [0,1] range"""
    tensor = tensor.clone()
    mean = torch.tensor([0.5, 0.5, 0.5]).view(-1, 1, 1).to(tensor.device)
    std = torch.tensor([0.5, 0.5, 0.5]).view(-1, 1, 1).to(tensor.device)
    return torch.clamp(tensor * std + mean, 0, 1)

Purpose: Converts normalized tensors (with values in the range [-1, 1]) back to the original range [0, 1].

Steps:

- Clone the tensor to avoid modifying the original.

- Define mean and standard deviation tensors.

- Apply the denormalization formula: tensor = (tensor * std) + mean.

- Clamp the values to ensure they stay within [0, 1].
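
The formula is the inverse of the `Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))` transform from the data pipeline, which computes `(x - mean) / std`. A small round-trip sketch, using scalar mean and std for brevity (an assumption; the notebook uses per-channel tensors with identical values):

```python
import torch

MEAN, STD = 0.5, 0.5  # same values as the per-channel mean/std in the pipeline


def normalize(x):
    """Forward transform applied by transforms.Normalize: [0, 1] -> [-1, 1]."""
    return (x - MEAN) / STD


def denormalize(x):
    """Inverse transform, clamped so out-of-range values stay displayable."""
    return torch.clamp(x * STD + MEAN, 0, 1)


pixels = torch.tensor([0.0, 0.25, 0.5, 1.0])
normalized = normalize(pixels)       # -1.0, -0.5, 0.0, 1.0
restored = denormalize(normalized)   # back to the original values

# Values outside [-1, 1] are clamped rather than mapped outside [0, 1].
overshoot = denormalize(torch.tensor([-1.2, 1.3]))
print(normalized, restored, overshoot)
```

The clamp matters mainly for safety: the Tanh output of the model already lies in [-1, 1], so clamping only guards against tensors that come from elsewhere.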

# Step 6: Evaluate and Visualize Function


In [3]:
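import time

import matplotlib.pyplot as plt
from skimage.metrics import peak_signal_noise_ratio as psnr
from skimage.metrics import structural_similarity as ssim

def evaluate_and_visualize(model, test_loader, device, num_samples=8, save_path='inpainting_results.png'):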
    """
    Evaluate the model and create detailed visualizations of the results
    """
    model.eval()

    # Lists to store metrics
    psnr_scores = []
    ssim_scores = []
    inference_times = []

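    # Get a batch of test images (cap num_samples at the batch size)
    batch = next(iter(test_loader))
    num_samples = min(num_samples, batch[0].size(0))
    images = batch[0][:num_samples].to(device)

    # Create figure (squeeze=False keeps axes 2-D even when num_samples == 1)
    fig, axes = plt.subplots(4, num_samples, figsize=(20, 16), squeeze=False)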
    plt.suptitle('Inpainting Results', fontsize=16)

    with torch.no_grad():
        # Process each image
        for i in range(num_samples):
            # Original image
            original = denormalize(images[i])
            axes[0, i].imshow(original.cpu().permute(1, 2, 0))
            axes[0, i].axis('off')
            if i == 0:
                axes[0, i].set_title('Original')

            # Create and apply mask
            mask = create_fast_mask(original)
            masked = original * mask
            axes[1, i].imshow(mask.cpu().permute(1, 2, 0), cmap='gray')
            axes[1, i].axis('off')
            if i == 0:
                axes[1, i].set_title('Mask')

            # Masked image
            axes[2, i].imshow(masked.cpu().permute(1, 2, 0))
            axes[2, i].axis('off')
            if i == 0:
                axes[2, i].set_title('Masked Input')

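            # Time the inference; synchronize so asynchronous GPU work is included
            if images.is_cuda:
                torch.cuda.synchronize()
            start_time = time.perf_counter()
            inpainted = model(masked.unsqueeze(0), mask.unsqueeze(0))
            if images.is_cuda:
                torch.cuda.synchronize()
            inference_time = time.perf_counter() - start_time
            inference_times.append(inference_time)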

            # Denormalize and show inpainted result
            inpainted = denormalize(inpainted[0])
            axes[3, i].imshow(inpainted.cpu().permute(1, 2, 0))
            axes[3, i].axis('off')
            if i == 0:
                axes[3, i].set_title('Inpainted Result')

            # Calculate metrics
            original_np = original.cpu().permute(1, 2, 0).numpy()
            inpainted_np = inpainted.cpu().permute(1, 2, 0).numpy()

            psnr_score = psnr(original_np, inpainted_np, data_range=1.0)
            ssim_score = ssim(original_np, inpainted_np, channel_axis=2, data_range=1.0)

            psnr_scores.append(psnr_score)
            ssim_scores.append(ssim_score)

            # Add metrics as text under the image
            axes[3, i].text(0.5, -0.2, f'PSNR: {psnr_score:.1f}\nSSIM: {ssim_score:.3f}',
                          ha='center', transform=axes[3, i].transAxes)

    plt.tight_layout()
    plt.savefig(save_path, bbox_inches='tight', dpi=300)
    plt.show()

    # Print detailed evaluation results
    print("\n=== Inpainting Model Evaluation ===")
    print(f"\nImage Quality Metrics (averaged over {num_samples} samples):")
    print(f"PSNR: {np.mean(psnr_scores):.2f} dB (±{np.std(psnr_scores):.2f})")
    print(f"SSIM: {np.mean(ssim_scores):.3f} (±{np.std(ssim_scores):.3f})")

    print("\nPerformance Metrics:")
    print(f"Average inference time: {np.mean(inference_times)*1000:.1f}ms (±{np.std(inference_times)*1000:.1f}ms)")

    # Interpret results
    print("\nModel Performance Interpretation:")
    avg_psnr = np.mean(psnr_scores)
    avg_ssim = np.mean(ssim_scores)

    # PSNR interpretation
    print("\nPSNR Analysis:")
    if avg_psnr > 30:
        print("✓ Excellent quality (>30 dB)")
    elif avg_psnr > 25:
        print("✓ Good quality (25-30 dB)")
    else:
        print("⚠ Fair to poor quality (<25 dB)")

    # SSIM interpretation
    print("\nSSIM Analysis:")
    if avg_ssim > 0.90:
        print("✓ Excellent structural similarity (>0.90)")
    elif avg_ssim > 0.80:
        print("✓ Good structural similarity (0.80-0.90)")
    else:
        print("⚠ Fair to poor structural similarity (<0.80)")

    # Speed interpretation
    avg_time = np.mean(inference_times) * 1000
    print("\nSpeed Analysis:")
    if avg_time < 50:
        print("✓ Very fast (<50ms)")
    elif avg_time < 100:
        print("✓ Fast (50-100ms)")
    else:
        print("⚠ Moderate to slow (>100ms)")

    return {
        'psnr': np.mean(psnr_scores),
        'ssim': np.mean(ssim_scores),
        'inference_time': np.mean(inference_times)
    }

- **Purpose**: Evaluates the model's performance on a test dataset and visualizes the results.
- **Steps**:
  1. Set the model to evaluation mode.
  2. Initialize lists to store metrics (PSNR, SSIM, inference times).
  3. Load a batch of test images.
  4. Create a figure to display the results.
  5. For each image:
     - Display the original image.
     - Create and apply a mask.
     - Display the masked image.
     - Perform inference and measure the time taken.
     - Display the inpainted result.
     - Calculate PSNR and SSIM between the original and the inpainted output over the whole image, not only the masked region.
  6. Save and display the figure.
  7. Print detailed evaluation results and interpret the metrics.
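
For reference, PSNR for images scaled to [0, 1] is `10 * log10(data_range**2 / MSE)`. The sketch below computes it directly with NumPy on a synthetic pair of images; the function name `psnr_manual` and the test arrays are assumptions for illustration and are not a replacement for the `skimage` call used in the notebook.

```python
import numpy as np


def psnr_manual(reference, estimate, data_range=1.0):
    """Peak signal-to-noise ratio in dB for images with the given value range."""
    diff = reference.astype(np.float64) - estimate.astype(np.float64)
    mse = np.mean(diff ** 2)
    if mse == 0:
        return float("inf")  # identical images
    return 10.0 * np.log10(data_range ** 2 / mse)


# Synthetic example: every pixel is off by 0.1, so MSE is about 0.01.
reference = np.full((8, 8, 3), 0.5)
estimate = reference + 0.1
value = psnr_manual(reference, estimate)
print(f"{value:.2f} dB")
```

A uniform error of 0.1 on a [0, 1] scale gives about 20 dB, which puts the 25 dB and 30 dB thresholds used in the interpretation step into perspective: they correspond to root-mean-square errors of roughly 0.056 and 0.032.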

