<a href="https://colab.research.google.com/github/fjadidi2001/Image_Inpaint/blob/main/CM_GAN_Jan5.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

CM-GAN is an architecture for image inpainting: it is designed to fill in missing or corrupted regions of an image with realistic content.
### **1. Generator Architecture**
The generator is responsible for creating realistic inpainted images. CM-GAN uses **cascaded modulation** to process inputs. A common flow:

#### Input:
- An **image** with missing regions (masked image).
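- A **binary mask** marking the missing areas. This description uses 1 for missing and 0 for existing pixels; the data pipeline in Step 1 uses the inverse (1 for known, 0 for missing) so that the mask can be multiplied directly with the image.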

#### Layers:
1. **Convolutional Layers with Mask Concatenation**:
   - Initial layers concatenate the image with the binary mask.
   - Convolutions extract features from the visible context around the masked regions.

   **Purpose**: Learn the structure and surrounding context of the image.

2. **Cascaded Modulation Block**:
   - Combines **global modulation** (to understand overall image semantics) with **spatially adaptive modulation** (to handle local details).
   - Global modulation uses a feature map that spans the entire image.
   - Adaptive modulation applies location-specific adjustments.

   **Purpose**: Balance global coherence and local realism.

3. **Feature Propagation via Attention Mechanisms**:
   - **Enhanced attention** propagates contextual information from known to unknown areas.

   **Purpose**: Fill missing regions accurately based on the surrounding context.

4. **Output Layers**:
   - A final set of convolutions or deconvolutions reconstructs the inpainted image.

   **Purpose**: Generate the final high-quality inpainted output.
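
To make the cascaded modulation step more concrete, here is a minimal NumPy sketch. It is a simplified illustration and not the CM-GAN implementation: the function names, the `1 + ...` scaling form, the `tanh` squashing, and the choice to let the globally modulated features guide the spatial step are all assumptions made for this example.

```python
import numpy as np

def global_modulation(features, global_code, weight):
    """Scale every channel by a factor derived from one global code vector."""
    # features: (H, W, C), global_code: (D,), weight: (D, C)
    scale = 1.0 + global_code @ weight       # shape (C,), identical at every pixel
    return features * scale                  # broadcasts over H and W

def spatial_modulation(features, guide, weight):
    """Scale every location by a factor derived from a spatial guide map."""
    # guide: (H, W, C), weight: (C, C), acting like a 1x1 convolution
    scale = 1.0 + np.tanh(guide @ weight)    # shape (H, W, C), varies per pixel
    return features * scale

def cascaded_modulation(features, global_code, w_global, w_spatial):
    """Global step first; its output then guides the spatial step (assumed order)."""
    globally_modulated = global_modulation(features, global_code, w_global)
    return spatial_modulation(features, globally_modulated, w_spatial)

rng = np.random.default_rng(0)
features = rng.normal(size=(8, 8, 4)).astype(np.float32)
global_code = rng.normal(size=(16,)).astype(np.float32)
w_global = rng.normal(scale=0.1, size=(16, 4)).astype(np.float32)
w_spatial = rng.normal(scale=0.1, size=(4, 4)).astype(np.float32)

out = cascaded_modulation(features, global_code, w_global, w_spatial)
print(out.shape)  # (8, 8, 4)
```

The global step applies the same per-channel scaling everywhere, which is what ties the result to overall image semantics, while the spatial step lets the scaling differ from pixel to pixel.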

---

### **2. Discriminator Architecture**
The discriminator evaluates the inpainted images for realism.

1. **Input**:
   - The inpainted image (from the generator).
   - The corresponding ground truth image (actual image without missing areas).

2. **Layers**:
   - Convolutional layers extract features.
   - Outputs a **realism score**, indicating how realistic the inpainted image is.

3. **Loss Function**:
   - Often uses an **adversarial loss** (e.g., Wasserstein or hinge loss) to train the generator and discriminator in a competitive manner.
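
As an illustration of the hinge variant mentioned above, the following NumPy sketch computes both sides of the adversarial objective from raw discriminator scores. The function names and the example scores are made up for this example; whether a given CM-GAN implementation uses hinge loss or another adversarial loss is not stated here.

```python
import numpy as np

def discriminator_hinge_loss(real_scores, fake_scores):
    """Penalize real scores below +1 and fake scores above -1."""
    real_term = np.maximum(0.0, 1.0 - real_scores).mean()
    fake_term = np.maximum(0.0, 1.0 + fake_scores).mean()
    return real_term + fake_term

def generator_hinge_loss(fake_scores):
    """The generator tries to raise the discriminator's score on its outputs."""
    return -fake_scores.mean()

# Hypothetical scores for two real and two inpainted images
real_scores = np.array([1.5, 0.5])
fake_scores = np.array([-2.0, 0.5])

d_loss = float(discriminator_hinge_loss(real_scores, fake_scores))
g_loss = float(generator_hinge_loss(fake_scores))
print(d_loss, g_loss)  # 1.0 0.75
```

Scores that are already on the correct side of the margin (real above +1, fake below -1) contribute nothing to the discriminator loss, so training effort goes to the samples it still gets wrong.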

---

### **Key Components of CM-GAN**
1. **Object-Aware Training**:
   - Focuses on challenging regions, like objects, using annotations (e.g., panoptic segmentation).
   - Ensures that the generator fills object regions more realistically.

2. **Mask-Aware Encoding**:
   - Explicitly considers the mask during feature extraction.
   - Helps the generator learn to handle varied mask sizes and shapes.

3. **Enhanced Attention**:
   - Propagates information from visible areas to missing areas.
   - Improves inpainting quality for complex patterns.
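
One simple way to make the encoder mask-aware is to feed the mask in as an extra input channel. The sketch below assumes that approach and uses the convention stated earlier (1 marks missing pixels); `build_generator_input` is a hypothetical helper, not a function from any CM-GAN codebase.

```python
import numpy as np

def build_generator_input(image, hole_mask):
    """Zero out the missing pixels and append the mask as a fourth channel."""
    # image: (H, W, 3) in [0, 1]; hole_mask: (H, W) with 1 = missing, 0 = known
    hole = hole_mask[..., None].astype(image.dtype)
    masked_image = image * (1.0 - hole)
    return np.concatenate([masked_image, hole], axis=-1)

image = np.ones((4, 4, 3), dtype=np.float32)
hole_mask = np.zeros((4, 4), dtype=np.float32)
hole_mask[1:3, 1:3] = 1.0  # a 2x2 hole in the centre

generator_input = build_generator_input(image, hole_mask)
print(generator_input.shape)  # (4, 4, 4)
```

With the mask as a channel, the first convolution can tell a genuinely black pixel apart from a pixel that was zeroed because it is missing.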

---

### **How the Architecture Works**
1. **Training**:
   - The generator creates inpainted images.
   - The discriminator evaluates their realism.
   - Both networks are updated iteratively to improve their performance.

2. **Inference**:
   - Given an input image and a mask, the generator fills the missing regions.
   - No discriminator is needed during inference.
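
At inference time it is common in inpainting pipelines to keep the known pixels from the input and take only the missing pixels from the generator output. The sketch below shows that compositing step under the assumption that it is applied here; `composite` is a hypothetical helper and the mask again uses 1 for missing pixels.

```python
import numpy as np

def composite(original, generated, hole_mask):
    """Keep known pixels from the original and fill holes from the generator output."""
    # original, generated: (H, W, 3); hole_mask: (H, W) with 1 = missing, 0 = known
    hole = hole_mask[..., None].astype(original.dtype)
    return original * (1.0 - hole) + generated * hole

original = np.full((4, 4, 3), 0.25, dtype=np.float32)
generated = np.full((4, 4, 3), 0.75, dtype=np.float32)  # stand-in for generator output
hole_mask = np.zeros((4, 4), dtype=np.float32)
hole_mask[1:3, 1:3] = 1.0

result = composite(original, generated, hole_mask)
print(result[0, 0, 0], result[1, 1, 0])  # 0.25 0.75
```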


# Step 1: Gather the Dataset


import os
import matplotlib.pyplot as plt  # Needed by visualize_batch
import numpy as np
import tensorflow as tf
from tensorflow.keras.preprocessing.image import load_img, img_to_array
from skimage.draw import random_shapes
from sklearn.model_selection import train_test_split
from tqdm import tqdm

# Parameters
DATA_DIR = "./places2"  # Path to your dataset
IMG_SIZE = (128, 128)  # Target image size
BATCH_SIZE = 32
MASK_TYPE = "random"  # Options: "random", "rectangle"

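# Generate a binary mask (1 = known pixel, 0 = missing pixel)
def generate_mask(img_size, mask_type="random"):
    # tf.numpy_function passes strings as bytes and tuples as NumPy arrays
    if isinstance(mask_type, bytes):
        mask_type = mask_type.decode("utf-8")
    img_size = tuple(int(s) for s in img_size)
    if mask_type == "random":
        # Draw random shapes on a white (255) background; the shapes become the holes.
        # Recent scikit-image versions take channel_axis=None in place of multichannel=False.
        mask, _ = random_shapes(img_size, max_shapes=5, min_size=50, max_size=100, channel_axis=None)
        mask = (mask == 255).astype(np.float32)  # Background -> 1 (known), shapes -> 0 (missing)
    elif mask_type == "rectangle":
        # Cut one random rectangular hole out of an all-ones mask
        mask = np.ones(img_size, dtype=np.float32)
        x1, y1 = np.random.randint(0, img_size[0] // 2), np.random.randint(0, img_size[1] // 2)
        x2, y2 = np.random.randint(x1 + 1, img_size[0]), np.random.randint(y1 + 1, img_size[1])
        mask[x1:x2, y1:y2] = 0
    else:
        raise ValueError(f"Unknown mask_type: {mask_type!r}")
    return mask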

# Load and preprocess a single image
def load_and_preprocess_image(img_path, img_size):
    img = load_img(img_path, target_size=img_size)
    img = img_to_array(img) / 255.0  # Normalize to [0, 1]
    return img

# Create a TensorFlow Dataset
def create_dataset(data_dir, img_size, batch_size, mask_type="random"):
    # Get list of image paths
    image_paths = [os.path.join(data_dir, img_name) for img_name in sorted(os.listdir(data_dir))
                   if img_name.lower().endswith(('.jpg', '.png', '.jpeg'))]

    # Split into training and validation sets
    train_paths, val_paths = train_test_split(image_paths, test_size=0.2, random_state=42)

    # Function to load and preprocess images and masks
    def process_image(img_path):
        # Load and preprocess image
        img = tf.numpy_function(load_and_preprocess_image, [img_path, img_size], tf.float32)
        img.set_shape(img_size + (3,))  # Set shape explicitly

        # Generate mask
        mask = tf.numpy_function(generate_mask, [img_size, mask_type], tf.float32)
        mask.set_shape(img_size)  # Set shape explicitly

        # Apply mask to image
        masked_img = img * tf.expand_dims(mask, axis=-1)

        return masked_img, img, mask

    # Create TensorFlow Dataset
    train_dataset = tf.data.Dataset.from_tensor_slices(train_paths)
    train_dataset = train_dataset.map(process_image, num_parallel_calls=tf.data.AUTOTUNE)
    train_dataset = train_dataset.batch(batch_size).prefetch(tf.data.AUTOTUNE)

    val_dataset = tf.data.Dataset.from_tensor_slices(val_paths)
    val_dataset = val_dataset.map(process_image, num_parallel_calls=tf.data.AUTOTUNE)
    val_dataset = val_dataset.batch(batch_size).prefetch(tf.data.AUTOTUNE)

    return train_dataset, val_dataset

# Create the dataset
train_dataset, val_dataset = create_dataset(DATA_DIR, IMG_SIZE, BATCH_SIZE, MASK_TYPE)

# Visualize a sample batch
def visualize_batch(dataset):
    for masked_images, original_images, masks in dataset.take(1):
        plt.figure(figsize=(10, 5))
        for i in range(3):  # Display 3 samples
            plt.subplot(3, 3, i * 3 + 1)
            plt.title("Masked Image")
            plt.imshow(masked_images[i])
            plt.axis("off")

            plt.subplot(3, 3, i * 3 + 2)
            plt.title("Original Image")
            plt.imshow(original_images[i])
            plt.axis("off")

            plt.subplot(3, 3, i * 3 + 3)
            plt.title("Mask")
            plt.imshow(masks[i], cmap='gray')
            plt.axis("off")
        plt.show()

visualize_batch(train_dataset)

FileNotFoundError: [Errno 2] No such file or directory: './places2'
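
The `FileNotFoundError` above means that `./places2` does not exist in the runtime, so `os.listdir` fails before any image is loaded. A small guard, sketched below, reports this earlier and with a clearer message. `list_image_paths` is a hypothetical helper, and the temporary directory only stands in for a real dataset folder.

```python
import os
import tempfile

def list_image_paths(data_dir, extensions=(".jpg", ".jpeg", ".png")):
    """Return sorted image paths in data_dir, failing early if the folder is missing."""
    if not os.path.isdir(data_dir):
        raise FileNotFoundError(
            f"Dataset directory not found: {data_dir!r}. "
            "Download or mount the dataset and update DATA_DIR."
        )
    return sorted(
        os.path.join(data_dir, name)
        for name in os.listdir(data_dir)
        if name.lower().endswith(extensions)
    )

# Demonstration on a throwaway folder with two images and one non-image file
with tempfile.TemporaryDirectory() as tmp:
    for name in ("a.jpg", "b.PNG", "notes.txt"):
        open(os.path.join(tmp, name), "w").close()
    found = [os.path.basename(p) for p in list_image_paths(tmp)]

print(found)  # ['a.jpg', 'b.PNG']
```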