<a href="https://colab.research.google.com/github/fjadidi2001/Image_Inpaint/blob/main/BIDS_Net.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Bidirectional Interaction Dual-Stream Network (BIDS-Net)

> The article proposes a novel approach for image inpainting called the Bidirectional Interaction Dual-Stream Network (BIDS-Net), integrating CNN and Transformer models to enhance inpainting quality by leveraging their complementary strengths.

## Methodology Overview
1. Dual-Stream Structure:

- CNN Stream: Captures rich local patterns and refines details.
- Transformer Stream: Models long-range contextual correlations for global information.
- Both streams are based on a U-shaped encoder-decoder structure to facilitate efficient multi-scale context reasoning.

2. Bidirectional Feature Interaction (BFI):

- Implements **bidirectional feature alignment and fusion** between the CNN and Transformer streams.
- Employs **Selective Feature Fusion (SFF)** for adaptive feature integration by learning channel weights.

3. Fast Global Self-Attention:

- Uses a **kernelizable fast-attention mechanism** in the Transformer stream, reducing attention complexity from quadratic to linear in the number of tokens.

4. Loss Functions:

- Combines pixel-wise reconstruction, adversarial, perceptual, and style losses to ensure inpainting quality and perceptual consistency.

Key findings on design choices:

- Channel allocation: Performance is best when the CNN and Transformer streams are given equal importance.
- Fusion direction: Bidirectional fusion outperforms unidirectional and unified-path approaches.
- Fusion technique: SFF surpasses element-wise addition and concatenation.
- Number of random features: 72 orthogonal random features give the best trade-off.

### **Mask Creation Process**

#### 1. **Purpose of Masking in Image Inpainting**
   - Masks simulate corrupted regions by marking the areas of an image to be restored.
   - Mask pixels take **value 1** for corrupted regions and **value 0** for uncorrupted ones, so training can treat the two selectively.
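
As a small illustration of the 1/0 convention above, the sketch below blanks out the masked pixels of an image with NumPy. The helper name `apply_mask` and the `fill_value` parameter are assumptions made for this example, not part of the article.

```python
import numpy as np

def apply_mask(image, mask, fill_value=0.0):
    """Blank out pixels where mask == 1 and keep pixels where mask == 0.

    image: float array of shape (H, W, C); mask: array of shape (H, W).
    """
    mask = mask[..., None].astype(image.dtype)  # (H, W) -> (H, W, 1)
    return image * (1.0 - mask) + fill_value * mask

# Toy 4x4 RGB image with a 2x2 hole in the centre
image = np.full((4, 4, 3), 0.5, dtype=np.float32)
mask = np.zeros((4, 4), dtype=np.uint8)
mask[1:3, 1:3] = 1
corrupted = apply_mask(image, mask)
```

The same expression, with the roles reversed, is what lets a network output be pasted only into the hole: `output * mask + image * (1 - mask)`.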

#### 2. **Mask Datasets**
   - **Mask Set I**: Contains irregular shapes with various hole-to-image area ratios (10%–60%) to simulate real-world image corruption scenarios.
   - **Mask Set II**: Focuses on **large-scale corruptions**, derived from a large mask sampling strategy, targeting challenges in **large-hole inpainting**.
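
To make the hole-to-image area ratio concrete, here is a minimal sketch that measures it and assigns a mask to a ratio bucket. The 10% bucket width and the names `hole_ratio` and `ratio_bucket` are assumptions for illustration; the text above only states the overall 10%–60% range.

```python
import numpy as np

def hole_ratio(mask):
    """Fraction of pixels marked as corrupted (non-zero)."""
    return float(np.count_nonzero(mask)) / mask.size

def ratio_bucket(mask, edges=(0.1, 0.2, 0.3, 0.4, 0.5, 0.6)):
    """Return the half-open (low, high) bucket containing the hole ratio.

    Returns None when the ratio falls outside the configured range.
    """
    ratio = hole_ratio(mask)
    for low, high in zip(edges[:-1], edges[1:]):
        if low <= ratio < high:
            return (low, high)
    return None

# A 10x10 mask with a 5x5 hole covers 25% of the image
mask = np.zeros((10, 10), dtype=np.uint8)
mask[:5, :5] = 1
bucket = ratio_bucket(mask)
```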

#### 3. **Techniques for Mask Creation**
   - **Random Irregular Masks**:
     - Generated using freehand-like curves and random polygons.
     - Often involve **random rotations** and **flipping** for augmentation.
   - **Large-Hole Masks**:
     - Created by sampling large continuous regions, ensuring high diversity in shape and size.
   - **Tools and Libraries**:
     - Python libraries like **OpenCV** and **NumPy** for procedural generation of irregular shapes.
     - **External mask datasets** for additional diversity, e.g., Mask datasets from previous works such as [29].

---

### **Model Architecture: BIDS-Net**

#### 1. **Overall Structure**
   - A **dual-stream network** combining **CNN** and **Transformer** models in a parallel design.
   - Built on a **U-shaped encoder-decoder structure** for multi-scale feature extraction.

#### 2. **Key Components**
   - **CNN Stream**:
     - Focus: Capturing **local patterns** for texture refinement.
     - Built with **pre-activation residual blocks** for efficient and robust feature learning.
   - **Transformer Stream**:
     - Focus: Modeling **long-range contextual correlations**.
     - Uses **fast global self-attention** for scalability and reduced computational overhead.
   - **Bidirectional Feature Interaction (BFI)**:
     - Bridges the CNN and Transformer streams with **feature alignment** and **adaptive fusion**.

#### 3. **Detailed Implementation Steps**
   - **Input Projection**:
     - Corrupted images and masks are projected into separate feature spaces for the CNN and Transformer streams.
     - Transformer features are downsampled to balance computational cost and performance.
   - **Encoding Stage**:
     - Each stream extracts features using **convolutional blocks (CNN)** and **Transformer blocks**.
     - Features are fused bidirectionally via the **BFI module**.
   - **Bottleneck Stage**:
     - Features from both streams interact for enhanced context reasoning at the lowest spatial resolution.
   - **Decoding Stage**:
     - Outputs from both streams are upsampled and concatenated for final refinement.
   - **Output Projection**:
     - Combined features are transformed back to the image space for inpainting results.

---

### **Relevant Techniques and Algorithms**

#### 1. **Fast Global Self-Attention**
   - Reduces standard attention's quadratic complexity to linear in the number of tokens:
     - **Kernelizable Attention**: Positive orthogonal random features approximate the softmax kernel, so the full attention matrix is never formed explicitly.
   - Ensures **scalability** and efficiency for high-resolution images.
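
The idea can be sketched in NumPy for a single attention head. This is a simplified illustration under stated assumptions, not the article's implementation: the function names are hypothetical, the feature map is the common positive form `exp(w·x - |x|²/2)`, and no numerical stabilisation is applied.

```python
import numpy as np

def orthogonal_random_features(num_features, dim, rng):
    """Stack QR-orthogonalised Gaussian blocks into a (num_features, dim) matrix."""
    blocks = []
    for _ in range(-(-num_features // dim)):  # ceil(num_features / dim) blocks
        q_mat, _r = np.linalg.qr(rng.standard_normal((dim, dim)))
        blocks.append(q_mat.T)
    w = np.concatenate(blocks, axis=0)[:num_features]
    # Rescale rows so their norms match those of i.i.d. Gaussian vectors.
    norms = np.linalg.norm(rng.standard_normal((num_features, dim)), axis=1)
    return w * norms[:, None]

def positive_feature_map(x, w):
    """phi(x) = exp(w x - |x|^2 / 2) / sqrt(m); every entry is positive."""
    proj = x @ w.T
    sq_norm = 0.5 * np.sum(x * x, axis=-1, keepdims=True)
    return np.exp(proj - sq_norm) / np.sqrt(w.shape[0])

def linear_attention(q, k, v, w):
    """Attention in O(N) memory: never builds the N x N attention matrix."""
    d = q.shape[-1]
    q_f = positive_feature_map(q / d ** 0.25, w)  # (N, m)
    k_f = positive_feature_map(k / d ** 0.25, w)  # (N, m)
    kv = k_f.T @ v                                # (m, d_v), summed over tokens
    normaliser = q_f @ k_f.sum(axis=0)            # (N,)
    return (q_f @ kv) / normaliser[:, None]

rng = np.random.default_rng(0)
n, d, m = 16, 8, 72  # tokens, head dimension, random features
q, k, v = (rng.standard_normal((n, d)) for _ in range(3))
w = orthogonal_random_features(m, d, rng)
out = linear_attention(q, k, v, w)
```

Because `k_f.T @ v` is computed first, cost grows linearly with the number of tokens `n`; the result equals row-normalising `q_f @ k_f.T` and multiplying by `v`, just with the products reordered.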

#### 2. **Selective Feature Fusion (SFF)**
   - Adapts weights for each channel during fusion, ensuring:
     - CNN benefits from Transformer’s global context.
     - Transformer incorporates CNN’s local details.
   - Based on the **Selective Kernel Convolution** technique.
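
A minimal PyTorch sketch of SK-style selective fusion follows. It assumes both inputs already share the same shape; the class name, the `reduction` parameter, and the layer layout are illustrative guesses rather than the article's exact design.

```python
import torch
import torch.nn as nn

class SelectiveFeatureFusion(nn.Module):
    """Fuse two feature maps with per-channel softmax weights (SK-style)."""

    def __init__(self, channels, reduction=4):
        super().__init__()
        hidden = max(channels // reduction, 8)
        # Squeeze: global average pooling followed by a bottleneck
        self.squeeze = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(channels, hidden, kernel_size=1),
            nn.ReLU(inplace=True),
        )
        # Excite: one logit per channel and per branch
        self.excite = nn.Conv2d(hidden, channels * 2, kernel_size=1)

    def branch_weights(self, feat_a, feat_b):
        b, c = feat_a.shape[:2]
        logits = self.excite(self.squeeze(feat_a + feat_b))  # (B, 2C, 1, 1)
        return logits.view(b, 2, c, 1, 1).softmax(dim=1)     # sums to 1 over branches

    def forward(self, feat_a, feat_b):
        weights = self.branch_weights(feat_a, feat_b)
        return weights[:, 0] * feat_a + weights[:, 1] * feat_b

torch.manual_seed(0)
sff = SelectiveFeatureFusion(16)
cnn_feat = torch.randn(2, 16, 8, 8)
trans_feat = torch.randn(2, 16, 8, 8)
fused = sff(cnn_feat, trans_feat)
```

For bidirectional interaction, one plausible arrangement is two such modules, one producing the features passed on to the CNN stream and one for the Transformer stream.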

#### 3. **Loss Functions**
   - **Pixel-wise Reconstruction Loss**: Ensures pixel-level consistency.
   - **Adversarial Loss**: Improves texture realism by incorporating a discriminator network.
   - **Perceptual Loss**: Derived from a pre-trained VGG-19, enhancing perceptual similarity.
   - **Style Loss**: Preserves stylistic details using Gram matrices.
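
The Gram-matrix idea behind the style loss can be shown without a pre-trained network. In the sketch below the feature maps are random stand-ins for VGG activations, and the normalisation by `C*H*W` and the L1 distance between Gram matrices are assumptions; the article may weight or normalise differently.

```python
import numpy as np

def gram_matrix(features):
    """features: (C, H, W) -> (C, C) channel-correlation matrix."""
    c, h, w = features.shape
    flat = features.reshape(c, h * w)
    return flat @ flat.T / (c * h * w)

def style_loss(feat_pred, feat_target):
    """Mean absolute difference between the two Gram matrices."""
    return float(np.mean(np.abs(gram_matrix(feat_pred) - gram_matrix(feat_target))))

def l1_loss(pred, target):
    """Pixel-wise (or feature-wise) mean absolute error."""
    return float(np.mean(np.abs(pred - target)))

rng = np.random.default_rng(0)
features = rng.standard_normal((4, 6, 6))
flipped = features[:, ::-1, :]  # same statistics, different spatial layout

pixel_term = l1_loss(features, flipped)
style_term = style_loss(features, flipped)
```

Flipping the feature map changes the element-wise L1 term but leaves the Gram matrix essentially unchanged, which is why the style term captures texture statistics rather than spatial layout.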

---

### **Tools and Libraries**
   - **Frameworks**: PyTorch (1.10.1), TensorFlow for alternate implementations.
   - **Visualization**: Matplotlib or OpenCV for displaying masks and inpainted results.
   - **GPU Hardware**: Tested on NVIDIA GeForce RTX 3090 for performance.

---

### **Considerations**
   - **Mask Diversity**: Critical for generalization across various corruption scenarios.
   - **Computational Efficiency**: Striking a balance between accuracy and runtime, particularly with Transformer integration.
   - **Evaluation Metrics**:
     - Quantitative: PSNR, SSIM, FID, LPIPS.
     - Qualitative: Visual coherence and texture consistency.
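
Of the quantitative metrics, PSNR is simple enough to compute directly. The sketch below assumes images scaled to `[0, max_value]`; SSIM, FID, and LPIPS need dedicated libraries and are not shown.

```python
import numpy as np

def psnr(pred, target, max_value=1.0):
    """Peak signal-to-noise ratio in dB; higher means closer to the target."""
    diff = pred.astype(np.float64) - target.astype(np.float64)
    mse = np.mean(diff ** 2)
    if mse == 0:
        return float("inf")  # identical images
    return float(10.0 * np.log10(max_value ** 2 / mse))

target = np.zeros((8, 8, 3), dtype=np.float64)
pred = np.full((8, 8, 3), 0.1, dtype=np.float64)
score = psnr(pred, target)  # MSE is 0.01, so about 20 dB
```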



## Dataset

In [8]:
!pip install datasets -q

In [9]:
import os
import glob
import random
from PIL import Image
from torch.utils.data import Dataset, random_split
# Alias the Hugging Face class so it does not shadow torch's Dataset.
from datasets import Dataset as HFDataset, DatasetDict
import torchvision.transforms as transforms
import matplotlib.pyplot as plt

In [10]:
model_name="caltech256-BIDS"

from google.colab import drive
drive.mount('/content/drive')

import os
import torch  # needed by torch.save / torch.load below

CHECKPOINTS_DIR = '/content/drive/MyDrive/ckpts'

def save_checkpoint(model, optimizer, epoch):
    os.makedirs(CHECKPOINTS_DIR, exist_ok=True)
    checkpoint_path = f'{CHECKPOINTS_DIR}/{model_name}.pth'
    torch.save({
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
    }, checkpoint_path)
    print(f"ckpt saved for {model_name} at epoch {epoch}.")

def load_checkpoint(model, optimizer):
    ckpt_path = f'{CHECKPOINTS_DIR}/{model_name}.pth'
    if not os.path.exists(ckpt_path):
        print(f"no ckpt found for {model_name}; starting from epoch 0.")
        return 0

    # map_location lets a GPU-saved checkpoint load on a CPU-only runtime
    checkpoint = torch.load(ckpt_path, map_location='cpu')
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    start_epoch = checkpoint['epoch'] + 1
    print(f"ckpt loaded for {model_name} from {ckpt_path}. resuming from epoch {start_epoch}.")

    return start_epoch

Mounted at /content/drive


In [11]:
import kagglehub
# Download latest version
path = kagglehub.dataset_download("jessicali9530/caltech256")
print("Path to dataset files:", path)

Downloading from https://www.kaggle.com/api/v1/datasets/download/jessicali9530/caltech256?dataset_version_number=2...
100%|██████████| 2.12G/2.12G [00:28<00:00, 79.4MB/s]
Extracting files...
Path to dataset files: /root/.cache/kagglehub/datasets/jessicali9530/caltech256/versions/2


In [1]:
import cv2
import numpy as np
import random

def generate_irregular_mask(height, width, max_vertices=12, max_brush_width=50):
    """
    Generates an irregular mask with random shapes.
    Args:
        height (int): Height of the mask.
        width (int): Width of the mask.
        max_vertices (int): Maximum number of vertices for random polygons.
        max_brush_width (int): Maximum brush width for freehand-like curves.
    Returns:
        mask (np.array): Binary mask with irregular shapes.
    """
    mask = np.zeros((height, width), dtype=np.uint8)
    # A polygon needs at least three vertices; keep coordinates inside the image.
    num_vertices = random.randint(3, max(3, max_vertices))
    vertices = np.array([[
        random.randint(0, width - 1),
        random.randint(0, height - 1)
    ] for _ in range(num_vertices)], dtype=np.int32)
    cv2.fillPoly(mask, [vertices], 1)

    # Add freehand-like strokes (straight segments of random thickness)
    for _ in range(random.randint(1, 5)):
        start_point = (random.randint(0, width - 1), random.randint(0, height - 1))
        end_point = (random.randint(0, width - 1), random.randint(0, height - 1))
        thickness = random.randint(min(10, max_brush_width), max_brush_width)
        cv2.line(mask, start_point, end_point, 1, thickness)

    return mask

def generate_large_hole_mask(height, width, min_size=0.3, max_size=0.6):
    """
    Generates a mask with large holes.
    Args:
        height (int): Height of the mask.
        width (int): Width of the mask.
        min_size (float): Minimum hole side length as a fraction of the image side.
        max_size (float): Maximum hole side length as a fraction of the image side.
    Returns:
        mask (np.array): Binary mask containing one rectangular hole.
    """
    mask = np.zeros((height, width), dtype=np.uint8)
    hole_size = random.uniform(min_size, max_size)
    hole_height = int(height * hole_size)
    hole_width = int(width * hole_size)

    x = random.randint(0, width - hole_width)
    y = random.randint(0, height - hole_height)
    cv2.rectangle(mask, (x, y), (x + hole_width, y + hole_height), 1, -1)

    return mask

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class PreActResidualBlock(nn.Module):
    """Pre-activation Residual Block for CNN stream."""
    def __init__(self, in_channels, out_channels):
        super(PreActResidualBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)
        self.norm1 = nn.BatchNorm2d(in_channels)
        self.norm2 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        identity = x
        out = self.norm1(x)
        out = self.relu(out)
        out = self.conv1(out)
        out = self.norm2(out)
        out = self.relu(out)
        out = self.conv2(out)
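        # NOTE: the skip connection has no projection, so this addition only
        # works when in_channels == out_channels.
        return out + identity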

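class FastGlobalSelfAttention(nn.Module):
    """Multi-head self-attention for the Transformer stream.

    NOTE: this is standard softmax attention, quadratic in H*W; it does not
    yet use the kernelized linear attention described in the notes. Input
    and output both have `dim` channels.
    """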
    def __init__(self, dim, num_heads):
        super(FastGlobalSelfAttention, self).__init__()
        self.num_heads = num_heads
        self.scale = (dim // num_heads) ** -0.5
        self.to_qkv = nn.Linear(dim, dim * 3)
        self.to_out = nn.Linear(dim, dim)

    def forward(self, x):
        B, C, H, W = x.shape
        x = x.flatten(2).transpose(1, 2)  # (B, H*W, C)
        qkv = self.to_qkv(x).chunk(3, dim=-1)
        q, k, v = map(lambda t: t.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2), qkv)
        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        out = (attn @ v).transpose(1, 2).reshape(B, H * W, C)
        out = self.to_out(out)
        out = out.transpose(1, 2).view(B, C, H, W)
        return out

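class BFI(nn.Module):
    """Simplified feature interaction: normalize, concatenate, 1x1 conv.

    NOTE: returns a single fused map; it is a placeholder for the
    bidirectional SFF-based interaction described in the notes.
    """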
    def __init__(self, channels):
        super(BFI, self).__init__()
        self.cnn_norm = nn.BatchNorm2d(channels)
        self.trans_norm = nn.BatchNorm2d(channels)
        self.fusion = nn.Conv2d(channels * 2, channels, kernel_size=1)

    def forward(self, cnn_feat, trans_feat):
        cnn_feat = self.cnn_norm(cnn_feat)
        trans_feat = self.trans_norm(trans_feat)
        fused = torch.cat([cnn_feat, trans_feat], dim=1)
        return self.fusion(fused)

class BIDS_Net(nn.Module):
    """BIDS-Net architecture."""
    def __init__(self, in_channels=3, out_channels=3, num_heads=8):
        super(BIDS_Net, self).__init__()
        # CNN Stream
        self.cnn_stream = nn.Sequential(
            PreActResidualBlock(in_channels, 64),
            PreActResidualBlock(64, 128),
            PreActResidualBlock(128, 256)
        )
        # Transformer Stream
        self.trans_stream = nn.Sequential(
            FastGlobalSelfAttention(64, num_heads),
            FastGlobalSelfAttention(128, num_heads),
            FastGlobalSelfAttention(256, num_heads)
        )
        # BFI Modules
        self.bfi1 = BFI(64)
        self.bfi2 = BFI(128)
        self.bfi3 = BFI(256)
        # Output Projection
        self.output = nn.Conv2d(256, out_channels, kernel_size=1)

    def forward(self, x, mask):
        # CNN Stream
        cnn_feat1 = self.cnn_stream[0](x)
        cnn_feat2 = self.cnn_stream[1](cnn_feat1)
        cnn_feat3 = self.cnn_stream[2](cnn_feat2)
        # Transformer Stream
        trans_feat1 = self.trans_stream[0](x)
        trans_feat2 = self.trans_stream[1](trans_feat1)
        trans_feat3 = self.trans_stream[2](trans_feat2)
        # BFI Fusion
        fused1 = self.bfi1(cnn_feat1, trans_feat1)
        fused2 = self.bfi2(cnn_feat2, trans_feat2)
        fused3 = self.bfi3(cnn_feat3, trans_feat3)
        # Output
        out = self.output(fused3)
        return out * mask + x * (1 - mask)

In [3]:
# Example usage
model = BIDS_Net()
input_image = torch.randn(1, 3, 256, 256)  # Example input
mask = torch.from_numpy(generate_irregular_mask(256, 256)).unsqueeze(0).unsqueeze(0).float()
output = model(input_image, mask)

# Loss functions
criterion_l1 = nn.L1Loss()
criterion_perceptual = nn.MSELoss()  # Replace with VGG-based perceptual loss
loss = criterion_l1(output, input_image) + criterion_perceptual(output, input_image)
loss.backward()

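RuntimeError: The size of tensor a (64) must match the size of tensor b (3) at non-singleton dimension 1

The mismatch comes from `PreActResidualBlock.forward`: with `in_channels=3` and `out_channels=64`, `out + identity` tries to add a 64-channel tensor to the 3-channel input.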

In [4]:
class BIDS_Net(nn.Module):
    """BIDS-Net architecture."""
    def __init__(self, in_channels=3, out_channels=3, num_heads=8):
        super(BIDS_Net, self).__init__()
        # CNN Stream
        self.cnn_stream = nn.Sequential(
            PreActResidualBlock(in_channels, 64),
            PreActResidualBlock(64, 128),
            PreActResidualBlock(128, 256)
        )
        # Transformer Stream
        self.trans_proj = nn.Conv2d(in_channels, 64, kernel_size=1)  # Projection layer
        self.trans_stream = nn.Sequential(
            FastGlobalSelfAttention(64, num_heads),
            FastGlobalSelfAttention(128, num_heads),
            FastGlobalSelfAttention(256, num_heads)
        )
        # BFI Modules
        self.bfi1 = BFI(64)
        self.bfi2 = BFI(128)
        self.bfi3 = BFI(256)
        # Output Projection
        self.output = nn.Conv2d(256, out_channels, kernel_size=1)

    def forward(self, x, mask):
        # CNN Stream
        cnn_feat1 = self.cnn_stream[0](x)
        cnn_feat2 = self.cnn_stream[1](cnn_feat1)
        cnn_feat3 = self.cnn_stream[2](cnn_feat2)

        # Transformer Stream
        trans_feat = self.trans_proj(x)  # Project input to match channels
        trans_feat1 = self.trans_stream[0](trans_feat)
        trans_feat2 = self.trans_stream[1](trans_feat1)
        trans_feat3 = self.trans_stream[2](trans_feat2)

        # BFI Fusion
        fused1 = self.bfi1(cnn_feat1, trans_feat1)
        fused2 = self.bfi2(cnn_feat2, trans_feat2)
        fused3 = self.bfi3(cnn_feat3, trans_feat3)

        # Output
        out = self.output(fused3)
        return out * mask + x * (1 - mask)

In [5]:
# Ensure mask has the same spatial dimensions as the input image
mask = torch.from_numpy(generate_irregular_mask(256, 256)).unsqueeze(0).unsqueeze(0).float()
mask = mask.expand(input_image.size(0), -1, -1, -1)  # Match batch size

In [6]:
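# Debug version of forward. Defined at module level, it only takes effect
# once it is assigned to the class, e.g. BIDS_Net.forward = forward.
def forward(self, x, mask):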
    print(f"Input shape: {x.shape}, Mask shape: {mask.shape}")

    # CNN Stream
    cnn_feat1 = self.cnn_stream[0](x)
    print(f"CNN Feat1 shape: {cnn_feat1.shape}")
    cnn_feat2 = self.cnn_stream[1](cnn_feat1)
    print(f"CNN Feat2 shape: {cnn_feat2.shape}")
    cnn_feat3 = self.cnn_stream[2](cnn_feat2)
    print(f"CNN Feat3 shape: {cnn_feat3.shape}")

    # Transformer Stream
    trans_feat = self.trans_proj(x)
    print(f"Trans Feat shape: {trans_feat.shape}")
    trans_feat1 = self.trans_stream[0](trans_feat)
    print(f"Trans Feat1 shape: {trans_feat1.shape}")
    trans_feat2 = self.trans_stream[1](trans_feat1)
    print(f"Trans Feat2 shape: {trans_feat2.shape}")
    trans_feat3 = self.trans_stream[2](trans_feat2)
    print(f"Trans Feat3 shape: {trans_feat3.shape}")

    # BFI Fusion
    fused1 = self.bfi1(cnn_feat1, trans_feat1)
    print(f"Fused1 shape: {fused1.shape}")
    fused2 = self.bfi2(cnn_feat2, trans_feat2)
    print(f"Fused2 shape: {fused2.shape}")
    fused3 = self.bfi3(cnn_feat3, trans_feat3)
    print(f"Fused3 shape: {fused3.shape}")

    # Output
    out = self.output(fused3)
    print(f"Output shape: {out.shape}")
    return out * mask + x * (1 - mask)

In [7]:
# Example usage
model = BIDS_Net()
input_image = torch.randn(1, 3, 256, 256)  # Example input
mask = torch.from_numpy(generate_irregular_mask(256, 256)).unsqueeze(0).unsqueeze(0).float()
mask = mask.expand(input_image.size(0), -1, -1, -1)  # Match batch size

output = model(input_image, mask)
print(f"Final output shape: {output.shape}")

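RuntimeError: The size of tensor a (64) must match the size of tensor b (3) at non-singleton dimension 1

The error is identical to the first run: `trans_proj` only changes the Transformer stream's input, while the failure occurs earlier, in the residual addition of the first CNN block (3 input channels versus 64 output channels).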
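
One possible way past this error is to give the residual block a 1x1 projection on the skip path whenever the channel count changes. The sketch below is a suggestion under that assumption, not the article's block; the class name `ProjectedPreActResidualBlock` is hypothetical and a small 32x32 input is used to keep the check cheap.

```python
import torch
import torch.nn as nn

class ProjectedPreActResidualBlock(nn.Module):
    """Pre-activation residual block whose skip path can change channel count."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.norm1 = nn.BatchNorm2d(in_channels)
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
        self.norm2 = nn.BatchNorm2d(out_channels)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)
        self.relu = nn.ReLU()
        # Identity when shapes already agree, 1x1 conv otherwise
        self.shortcut = (
            nn.Identity() if in_channels == out_channels
            else nn.Conv2d(in_channels, out_channels, kernel_size=1)
        )

    def forward(self, x):
        out = self.conv1(self.relu(self.norm1(x)))
        out = self.conv2(self.relu(self.norm2(out)))
        return out + self.shortcut(x)  # both operands now have out_channels

block = ProjectedPreActResidualBlock(3, 64)
features = block(torch.randn(1, 3, 32, 32))
print(tuple(features.shape))  # (1, 64, 32, 32)
```

Swapping this block into the CNN stream should clear the reported size mismatch. The Transformer stream as written would likely still need attention: each `FastGlobalSelfAttention(dim, ...)` expects `dim` input channels, so the 64-, 128- and 256-dimensional stages need a projection between them, and softmax attention over all 256x256 positions builds a 65,536 x 65,536 matrix per head, so downsampling or a linear-attention variant is probably required first.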