<a href="https://colab.research.google.com/github/fjadidi2001/Image_Inpaint/blob/main/Dec21_ge.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import os
import glob
import random
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from PIL import Image
from tqdm import tqdm
from sklearn.model_selection import train_test_split

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
import kagglehub

# Setup Device
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {DEVICE}")

# Constants
IMG_WIDTH = 256
IMG_HEIGHT = 256
BATCH_SIZE = 16  # Reduced batch size for safety
LEARNING_RATE = 0.001
SEED = 42  # Single knob for every random number generator seeded below

# Reproducibility: seed each RNG library this notebook imports so that a
# Restart & Run All produces the same random draws on every run.
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Download Data
# kagglehub returns the local directory of the downloaded dataset version;
# the images live in the `celeba_hq_256` sub-directory of that root.
dataset_root = kagglehub.dataset_download("badasstechie/celebahq-resized-256x256")
dataset_path = os.path.join(dataset_root, 'celeba_hq_256')
print(f"Dataset path: {dataset_path}")

Using device: cuda
Downloading from https://www.kaggle.com/api/v1/datasets/download/badasstechie/celebahq-resized-256x256?dataset_version_number=1...


100%|██████████| 283M/283M [00:13<00:00, 21.3MB/s]

Extracting files...





Dataset path: /root/.cache/kagglehub/datasets/badasstechie/celebahq-resized-256x256/versions/1/celeba_hq_256


In [2]:
class InpaintingDataset(Dataset):
    """Dataset yielding ``(masked_img, mask, img)`` triples for inpainting.

    Each item loads one image from disk as RGB, optionally applies
    ``transform``, and zeroes out a randomly placed rectangular hole.

    Args:
        image_paths: Indexable collection of image file paths.
        transform: Optional callable applied to the loaded PIL image. If it
            is omitted (or returns a PIL image), the image is converted to a
            float ``(C, H, W)`` tensor scaled to ``[0, 1]`` so that it can be
            multiplied by the tensor mask.
    """

    def __init__(self, image_paths, transform=None):
        self.image_paths = image_paths
        self.transform = transform

    def create_mask(self, height=None, width=None):
        """Create a mask of ones with one random rectangular hole of zeros.

        Args:
            height: Mask height in pixels; defaults to ``IMG_HEIGHT``.
            width: Mask width in pixels; defaults to ``IMG_WIDTH``.

        Returns:
            A ``(1, height, width)`` float tensor equal to 1 everywhere
            except inside a ``(height // 3) x (width // 3)`` hole, which is 0.
        """
        height = IMG_HEIGHT if height is None else height
        width = IMG_WIDTH if width is None else width

        mask = torch.ones((1, height, width))
        h_hole, w_hole = height // 3, width // 3

        # random.randint is inclusive at both ends, so the largest offset
        # still leaves the whole hole inside the mask.
        y1 = random.randint(0, height - h_hole)
        x1 = random.randint(0, width - w_hole)

        mask[:, y1:y1 + h_hole, x1:x1 + w_hole] = 0
        return mask

    def __getitem__(self, idx):
        """Return ``(masked_img, mask, img)`` for the image at position ``idx``."""
        img_path = self.image_paths[idx]
        # convert() loads the pixel data into a new image, so the file handle
        # can be released as soon as the block exits.
        with Image.open(img_path) as src:
            img = src.convert('RGB')

        if self.transform:
            img = self.transform(img)

        if isinstance(img, Image.Image):
            # A PIL image cannot be multiplied by a tensor mask: convert
            # HWC uint8 pixels to a CHW float tensor in [0, 1].
            img = torch.from_numpy(np.array(img)).permute(2, 0, 1).float().div(255.0)

        # Size the mask from the actual image so it always broadcasts, even
        # when the transform yields a size other than IMG_HEIGHT x IMG_WIDTH.
        mask = self.create_mask(img.shape[-2], img.shape[-1])
        masked_img = img * mask

        return masked_img, mask, img

    def __len__(self):
        """Number of images in the dataset."""
        return len(self.image_paths)

In [3]:
class UnetE(nn.Module):
    """Three-level U-Net mapping a 3-channel image to a 3-channel image.

    The encoder applies three double-convolution stages separated by 2x max
    pooling. The decoder upsamples bilinearly, concatenates the matching
    encoder features (skip connections), and finishes with a 1x1 convolution
    followed by a sigmoid, so outputs lie in ``[0, 1]``.

    Upsampling targets the spatial size of the skip tensor rather than a
    fixed factor of two, so inputs whose height/width are not multiples of 4
    are handled as well. Height and width must be at least 4 so that two
    pooling steps leave a non-empty feature map.
    """

    def __init__(self):
        super().__init__()

        # Encoder
        self.enc1 = self.double_conv(3, 64)
        self.enc2 = self.double_conv(64, 128)
        self.enc3 = self.double_conv(128, 256)

        # Decoder (input channels = upsampled features + skip features)
        self.dec3 = self.double_conv(256 + 128, 128)
        self.dec2 = self.double_conv(128 + 64, 64)
        self.dec1 = nn.Conv2d(64, 3, kernel_size=1)

        self.pool = nn.MaxPool2d(2)
        # Kept so existing code that accesses `model.upsample` keeps working;
        # forward() resizes to the skip tensor's size instead (see below).
        self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)

    def double_conv(self, in_channels, out_channels):
        """Return two 3x3 same-padding convolutions, each followed by ReLU."""
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, 3, padding=1),
            nn.ReLU(inplace=True)
        )

    @staticmethod
    def _upsample_like(x, ref):
        """Bilinearly resize ``x`` to the spatial size of ``ref``.

        With ``align_corners=True`` the sampling grid depends only on the
        input and output sizes, so for even-sized ``ref`` this matches the
        fixed 2x ``nn.Upsample``; for odd sizes it avoids the off-by-one
        mismatch that would make the skip concatenation fail.
        """
        return nn.functional.interpolate(
            x, size=ref.shape[-2:], mode='bilinear', align_corners=True
        )

    def forward(self, x):
        """Run the network.

        Args:
            x: Tensor of shape ``(B, 3, H, W)``.

        Returns:
            Tensor of shape ``(B, 3, H, W)`` with values in ``[0, 1]``.
        """
        e1 = self.enc1(x)
        e2 = self.enc2(self.pool(e1))
        e3 = self.enc3(self.pool(e2))

        d3 = self.dec3(torch.cat([self._upsample_like(e3, e2), e2], dim=1))
        d2 = self.dec2(torch.cat([self._upsample_like(d3, e1), e1], dim=1))

        return torch.sigmoid(self.dec1(d2))

In [4]:
class HINTE(nn.Module):
    """Conv encoder -> transformer bottleneck -> transposed-conv decoder.

    The masked image and its mask are concatenated into a 4-channel input,
    reduced spatially by four stride-2 convolutions, processed as a token
    sequence by two attention + MLP blocks, and expanded back by four
    stride-2 transposed convolutions ending in a sigmoid.

    Args:
        dim: Channel width of the bottleneck / transformer embedding size.
        num_heads: Number of attention heads in each transformer block.
    """

    def __init__(self, dim=128, num_heads=4):
        super().__init__()

        # Downsample: each stride-2 conv halves H and W (256 -> 128 -> 64 -> 32 -> 16)
        enc_widths = [4, 32, 64, 128, dim]
        enc_layers = []
        for c_in, c_out in zip(enc_widths[:-1], enc_widths[1:]):
            enc_layers.append(nn.Conv2d(c_in, c_out, kernel_size=4, stride=2, padding=1))
            enc_layers.append(nn.ReLU())
        self.encoder = nn.Sequential(*enc_layers)

        # Transformer: nn.Sequential is used only as an indexable container
        # here; forward() invokes the sub-layers one by one.
        def make_block():
            return nn.Sequential(
                nn.MultiheadAttention(dim, num_heads, batch_first=True),
                nn.LayerNorm(dim),
                nn.Linear(dim, dim * 2),
                nn.ReLU(),
                nn.Linear(dim * 2, dim),
                nn.LayerNorm(dim),
            )

        self.transformer_blocks = nn.ModuleList([make_block() for _ in range(2)])

        # Upsample: each stride-2 transposed conv doubles H and W (16 -> ... -> 256)
        dec_widths = [dim, 128, 64, 32, 3]
        dec_layers = []
        final_stage = len(dec_widths) - 2
        for stage, (c_in, c_out) in enumerate(zip(dec_widths[:-1], dec_widths[1:])):
            dec_layers.append(nn.ConvTranspose2d(c_in, c_out, kernel_size=4, stride=2, padding=1))
            # The last stage squashes to [0, 1]; earlier stages use ReLU.
            dec_layers.append(nn.Sigmoid() if stage == final_stage else nn.ReLU())
        self.decoder = nn.Sequential(*dec_layers)

    def forward(self, x, mask):
        """Inpaint ``x`` given ``mask``.

        Args:
            x: Image tensor with 3 channels, ``(B, 3, H, W)``.
            mask: Mask tensor with 1 channel, ``(B, 1, H, W)``.

        Returns:
            The decoder output, a 3-channel tensor with values in ``[0, 1]``.
        """
        # Stack image and mask along channels -> (B, 4, H, W)
        stacked = torch.cat((x, mask), dim=1)

        # Encode -> (B, dim, H/16, W/16)
        feat = self.encoder(stacked)
        batch, channels, rows, cols = feat.shape

        # One token per spatial location -> (B, rows * cols, dim)
        tokens = feat.flatten(2).transpose(1, 2)

        for attention, norm_attn, fc_in, act, fc_out, norm_mlp in self.transformer_blocks:
            # Self-attention with residual connection, normalised afterwards
            attended, _ = attention(tokens, tokens, tokens)
            tokens = norm_attn(tokens + attended)
            # Feed-forward with residual connection, normalised afterwards
            tokens = norm_mlp(tokens + fc_out(act(fc_in(tokens))))

        # Back to a feature map -> (B, dim, rows, cols)
        grid = tokens.transpose(1, 2).view(batch, channels, rows, cols)

        # Decode -> (B, 3, H, W)
        return self.decoder(grid)