<a href="https://colab.research.google.com/github/fjadidi2001/Image_Inpaint/blob/main/Dec21_ge.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import copy
import glob
import os
import random

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from PIL import Image
from tqdm import tqdm
from sklearn.model_selection import train_test_split

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
import kagglehub

# Setup Device
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {DEVICE}")

# Reproducibility: seed every RNG the notebook relies on.
# `random` drives the hole placement in InpaintingDataset.create_mask,
# `torch` drives weight initialisation and DataLoader shuffling,
# `numpy` is seeded for completeness.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Constants
IMG_WIDTH = 256
IMG_HEIGHT = 256
BATCH_SIZE = 16  # kept small to limit GPU memory use
LEARNING_RATE = 0.001

# Download Data (kagglehub caches the archive locally and returns its root folder)
dataset_path = os.path.join(
    kagglehub.dataset_download("badasstechie/celebahq-resized-256x256"),
    'celeba_hq_256',
)
print(f"Dataset path: {dataset_path}")

Using device: cuda
Downloading from https://www.kaggle.com/api/v1/datasets/download/badasstechie/celebahq-resized-256x256?dataset_version_number=1...


100%|██████████| 283M/283M [00:13<00:00, 21.3MB/s]

Extracting files...





Dataset path: /root/.cache/kagglehub/datasets/badasstechie/celebahq-resized-256x256/versions/1/celeba_hq_256


In [None]:
class InpaintingDataset(Dataset):
    """Image dataset yielding ``(masked_img, mask, img)`` triples for inpainting.

    Each item loads one image as RGB, applies ``transform`` and blanks out one
    randomly placed rectangular hole. The hole is re-sampled on every access, so
    the same image gets a different hole each time it is read.

    Args:
        image_paths: Sequence of image file paths.
        transform: Optional callable applied to the PIL image. It is expected to
            return a ``(C, H, W)`` float tensor (e.g. ``transforms.ToTensor()``).
            When ``None``, the image is converted to a float32 tensor in ``[0, 1]``
            with shape ``(3, H, W)``.
        hole_divisor: The hole is ``H // hole_divisor`` by ``W // hole_divisor``
            pixels (at least 1x1). The default of 3 reproduces the original
            one-third-sized hole.
    """

    def __init__(self, image_paths, transform=None, hole_divisor=3):
        self.image_paths = image_paths
        self.transform = transform
        self.hole_divisor = hole_divisor

    def create_mask(self, height=None, width=None):
        """Return a ``(1, height, width)`` float mask: 1 = known pixel, 0 = hole.

        ``height``/``width`` default to the global ``IMG_HEIGHT``/``IMG_WIDTH``, so
        the original no-argument call still works. ``__getitem__`` passes the real
        image size, so the mask always matches the transformed image.
        """
        if height is None:
            height = IMG_HEIGHT
        if width is None:
            width = IMG_WIDTH

        mask = torch.ones((1, height, width))
        # max(1, ...) guarantees a hole even for images smaller than hole_divisor.
        h_hole = max(1, height // self.hole_divisor)
        w_hole = max(1, width // self.hole_divisor)

        # randint is inclusive, so the hole always fits inside the image.
        y1 = random.randint(0, height - h_hole)
        x1 = random.randint(0, width - w_hole)

        mask[:, y1:y1 + h_hole, x1:x1 + w_hole] = 0
        return mask

    def __getitem__(self, idx):
        """Return ``(img * mask, mask, img)`` for the image at ``idx``."""
        # Context manager closes the file handle; convert() returns an
        # independent, fully loaded image.
        with Image.open(self.image_paths[idx]) as pil_img:
            rgb = pil_img.convert('RGB')

        if self.transform:
            img = self.transform(rgb)
        else:
            # Without a transform, still produce a (3, H, W) tensor so that the
            # multiplication with the mask below is well defined.
            img = torch.from_numpy(np.asarray(rgb, dtype=np.float32) / 255.0).permute(2, 0, 1)

        mask = self.create_mask(img.shape[-2], img.shape[-1])
        masked_img = img * mask

        return masked_img, mask, img

    def __len__(self):
        return len(self.image_paths)

In [None]:
class UnetE(nn.Module):
    """Small three-level U-Net mapping an RGB image to an RGB prediction in (0, 1).

    The encoder applies double 3x3 conv blocks (64 / 128 / 256 channels) with two
    2x max-pool downsamplings. The decoder bilinearly upsamples, concatenates the
    matching skip features and applies double conv blocks. A 1x1 conv plus sigmoid
    produces the output.

    Upsampled maps are resized to the exact spatial size of their skip tensor, so
    inputs whose height/width are not divisible by 4 are supported too. For even
    sizes this is numerically identical to ``scale_factor=2`` with
    ``align_corners=True``.
    """

    def __init__(self):
        super(UnetE, self).__init__()

        # Encoder
        self.enc1 = self.double_conv(3, 64)
        self.enc2 = self.double_conv(64, 128)
        self.enc3 = self.double_conv(128, 256)

        # Decoder: input channels = upsampled deeper features + skip features
        self.dec3 = self.double_conv(256 + 128, 128)
        self.dec2 = self.double_conv(128 + 64, 64)
        self.dec1 = nn.Conv2d(64, 3, kernel_size=1)

        self.pool = nn.MaxPool2d(2)
        # Parameter-free; kept for backward compatibility. forward() uses
        # _upsample_to so that odd spatial sizes still line up with the skips.
        self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)

    def double_conv(self, in_channels, out_channels):
        """Two 3x3 conv + ReLU layers; padding=1 keeps the spatial size."""
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, 3, padding=1),
            nn.ReLU(inplace=True)
        )

    @staticmethod
    def _upsample_to(x, ref):
        """Bilinearly resize ``x`` to ``ref``'s spatial size (same mode as ``self.upsample``)."""
        return nn.functional.interpolate(x, size=ref.shape[-2:], mode='bilinear', align_corners=True)

    def forward(self, x):
        e1 = self.enc1(x)
        e2 = self.enc2(self.pool(e1))
        e3 = self.enc3(self.pool(e2))

        d3 = self.dec3(torch.cat([self._upsample_to(e3, e2), e2], dim=1))
        d2 = self.dec2(torch.cat([self._upsample_to(d3, e1), e1], dim=1))

        return torch.sigmoid(self.dec1(d2))

In [None]:
class HINTE(nn.Module):
    """Hybrid conv/transformer inpainting network.

    The image and its mask are stacked into a 4-channel input. Four stride-2
    convolutions then shrink it 16x (256x256 -> 16x16 for the training images).
    Every spatial cell becomes one token for two post-norm transformer blocks, and
    four stride-2 transposed convolutions decode back to full size, ending in a
    sigmoid.

    Args:
        dim: Token width used by the transformer.
        num_heads: Attention heads per block (``dim`` must be divisible by it).
    """

    def __init__(self, dim=128, num_heads=4):
        super(HINTE, self).__init__()

        # Conv stem: each stage halves H and W (kernel 4, stride 2, padding 1).
        down_channels = [4, 32, 64, 128, dim]
        down_layers = []
        for c_in, c_out in zip(down_channels[:-1], down_channels[1:]):
            down_layers.append(nn.Conv2d(c_in, c_out, 4, stride=2, padding=1))
            down_layers.append(nn.ReLU())
        self.encoder = nn.Sequential(*down_layers)

        # Each block is a Sequential used only as a container; forward() unpacks it
        # as (attention, norm, linear, relu, linear, norm) rather than calling it.
        def make_block():
            return nn.Sequential(
                nn.MultiheadAttention(dim, num_heads, batch_first=True),
                nn.LayerNorm(dim),
                nn.Linear(dim, dim * 2),
                nn.ReLU(),
                nn.Linear(dim * 2, dim),
                nn.LayerNorm(dim),
            )

        self.transformer_blocks = nn.ModuleList([make_block() for _ in range(2)])

        # Mirror of the stem: each transposed conv doubles H and W; the last
        # stage emits 3 channels squashed by a sigmoid.
        up_channels = [dim, 128, 64, 32, 3]
        n_up = len(up_channels) - 1
        up_layers = []
        for stage, (c_in, c_out) in enumerate(zip(up_channels[:-1], up_channels[1:])):
            up_layers.append(nn.ConvTranspose2d(c_in, c_out, 4, stride=2, padding=1))
            up_layers.append(nn.Sigmoid() if stage == n_up - 1 else nn.ReLU())
        self.decoder = nn.Sequential(*up_layers)

    def forward(self, x, mask):
        # Channel-stack image and mask -> 4 channels expected by the stem.
        stacked = torch.cat((x, mask), dim=1)
        grid = self.encoder(stacked)
        batch, channels, height, width = grid.shape

        # (B, C, h, w) -> (B, h*w, C): one token per spatial cell.
        tokens = grid.flatten(start_dim=2).transpose(1, 2)

        for attention, norm_attn, expand, activation, project, norm_mlp in self.transformer_blocks:
            attended, _ = attention(tokens, tokens, tokens)
            tokens = norm_attn(tokens + attended)
            tokens = norm_mlp(tokens + project(activation(expand(tokens))))

        # Back to (B, C, h, w) for the transposed-conv decoder.
        grid = tokens.transpose(1, 2).reshape(batch, channels, height, width)
        return self.decoder(grid)

In [None]:
class ConvLSTMCellE(nn.Module):
    """A single convolutional LSTM cell.

    The input and previous hidden state are concatenated along channels. One
    convolution produces all four gate pre-activations at once, each with
    ``hidden_dim`` channels, in the order input, forget, output, candidate.
    ``padding = kernel_size // 2`` keeps H and W unchanged for odd kernel sizes.

    Args:
        input_dim: Channels of the input tensor.
        hidden_dim: Channels of the hidden and cell states.
        kernel_size: Kernel size (int) of the gate convolution.
        bias: Whether the gate convolution uses a bias.
    """

    def __init__(self, input_dim, hidden_dim, kernel_size, bias):
        super(ConvLSTMCellE, self).__init__()
        self.input_dim, self.hidden_dim = input_dim, hidden_dim
        self.padding = kernel_size // 2
        self.bias = bias

        gate_channels = 4 * hidden_dim
        self.conv = nn.Conv2d(
            in_channels=input_dim + hidden_dim,
            out_channels=gate_channels,
            kernel_size=kernel_size,
            padding=self.padding,
            bias=bias,
        )

    def forward(self, input_tensor, cur_state):
        """Advance one step from ``cur_state = (h, c)``; return ``(h_next, c_next)``."""
        h_prev, c_prev = cur_state
        gates = self.conv(torch.cat((input_tensor, h_prev), dim=1))
        # 4 * hidden_dim channels -> four equal hidden_dim-channel chunks.
        in_gate, forget_gate, out_gate, candidate = gates.chunk(4, dim=1)

        c_new = torch.sigmoid(forget_gate) * c_prev + torch.sigmoid(in_gate) * torch.tanh(candidate)
        h_new = torch.sigmoid(out_gate) * torch.tanh(c_new)
        return h_new, c_new

class FastXLSTM(nn.Module):
    """Conv encoder -> gated ConvLSTM refinement -> conv decoder with sigmoid output.

    ``ConvLSTMCellE`` is applied to the encoded features of a single image,
    starting from an all-zero hidden/cell state. With ``refine_steps > 1`` the same
    features are fed repeatedly while the recurrent state evolves. This is an
    iterative refinement; the default of 1 reproduces the original single pass.

    Args:
        hidden_size: Channel width of the encoded features and the LSTM state.
        refine_steps: Number of ConvLSTM steps to run (>= 1).

    Raises:
        ValueError: If ``refine_steps`` is smaller than 1.
    """

    def __init__(self, hidden_size=64, refine_steps=1):
        super(FastXLSTM, self).__init__()
        if refine_steps < 1:
            raise ValueError(f"refine_steps must be >= 1, got {refine_steps}")
        self.refine_steps = refine_steps

        # Feature extractor (padding=1 keeps H and W)
        self.enc = nn.Sequential(
            nn.Conv2d(3, 32, 3, padding=1),
            nn.ReLU(),
            nn.Conv2d(32, hidden_size, 3, padding=1)
        )

        # ConvLSTM cell (vectorised replacement for per-pixel loops)
        self.conv_lstm = ConvLSTMCellE(hidden_size, hidden_size, 3, True)

        # Output decoder
        self.dec = nn.Sequential(
            nn.Conv2d(hidden_size, 32, 3, padding=1),
            nn.ReLU(),
            nn.Conv2d(32, 3, 1),
            nn.Sigmoid()
        )

    def forward(self, x):
        features = self.enc(x)

        # zeros_like inherits shape, device AND dtype from the features.
        # torch.zeros(...).to(device) allocated on the CPU first and was always
        # float32, which breaks under half/double precision.
        h_state = torch.zeros_like(features)
        c_state = torch.zeros_like(features)

        for _ in range(self.refine_steps):
            h_state, c_state = self.conv_lstm(features, (h_state, c_state))

        return self.dec(h_state)

In [None]:
class CombinedModel(nn.Module):
    """Blend three inpainting models with learnable softmax weights.

    output = w0 * unet(x) + w1 * hint(x, mask) + w2 * xlstm(x),
    where w = softmax(self.weights).

    ``weights`` starts at zeros, so the initial blend is an equal 1/3 average. The
    given submodules are registered as-is (not copied), so their parameters are
    part of this model and are updated when it is trained.
    """

    def __init__(self, unet, hint, xlstm):
        super(CombinedModel, self).__init__()
        self.unet, self.hint, self.xlstm = unet, hint, xlstm
        self.weights = nn.Parameter(torch.zeros(3))

    def forward(self, x, mask):
        # Only the HINT branch consumes the mask.
        unet_pred = self.unet(x)
        hint_pred = self.hint(x, mask)
        xlstm_pred = self.xlstm(x)

        w_unet, w_hint, w_xlstm = torch.softmax(self.weights, dim=0)
        return w_unet * unet_pred + w_hint * hint_pred + w_xlstm * xlstm_pred

def _predict(model, masked, mask):
    """Call ``model`` with the inputs it expects; mask-aware models also receive ``mask``."""
    if isinstance(model, (HINTE, CombinedModel)):
        return model(masked, mask)
    return model(masked)


def train_one_epoch(model, loader, criterion, optimizer, epoch, name):
    """Train ``model`` for one pass over ``loader``.

    Args:
        model: Network to train (moved batches are placed on ``DEVICE``).
        loader: Iterable of ``(masked, mask, original)`` batches.
        criterion: Loss comparing the prediction with ``original``.
        optimizer: Optimizer over ``model``'s parameters.
        epoch: Epoch number, used only for the progress-bar label.
        name: Model name, used only for the progress-bar label.

    Returns:
        Mean of the per-batch loss values for this epoch.

    Raises:
        ValueError: If ``loader`` yields no batches.
    """
    model.train()
    running_loss = 0.0
    n_batches = 0

    pbar = tqdm(loader, desc=f"{name} Epoch {epoch}")
    for masked, mask, original in pbar:
        masked, mask, original = masked.to(DEVICE), mask.to(DEVICE), original.to(DEVICE)

        optimizer.zero_grad()
        output = _predict(model, masked, mask)

        loss = criterion(output, original)
        loss.backward()
        optimizer.step()

        batch_loss = loss.item()
        running_loss += batch_loss
        n_batches += 1
        pbar.set_postfix({'loss': batch_loss})

    if n_batches == 0:
        raise ValueError(f"{name}: training loader yielded no batches")
    return running_loss / n_batches


def evaluate(models, loader):
    """Compute mean per-image PSNR and L1 for each model on ``loader``.

    PSNR is computed per image with a peak value of 1 (images and sigmoid outputs
    lie in [0, 1]), i.e. ``-10 * log10(MSE)``. It is then averaged over all images.
    Averaging batch-level PSNRs would make the result depend on the batch size and
    on the composition of the last batch. L1 is the mean absolute error per image,
    averaged over images. A loader with no images yields NaN metrics.

    Args:
        models: Mapping ``name -> model``.
        loader: Iterable of ``(masked, mask, original)`` batches.

    Returns:
        ``pd.DataFrame`` with columns ``Model``, ``PSNR``, ``L1`` (one row per model).
    """
    metrics = {'Model': [], 'PSNR': [], 'L1': []}

    for name, model in models.items():
        model.eval()
        psnr_sum, l1_sum, n_images = 0.0, 0.0, 0

        with torch.no_grad():
            for masked, mask, original in loader:
                masked, mask, original = masked.to(DEVICE), mask.to(DEVICE), original.to(DEVICE)
                output = _predict(model, masked, mask)

                # One row per image so every metric is computed image-wise.
                diff = (original - output).flatten(start_dim=1)
                mse = diff.pow(2).mean(dim=1)
                psnr_sum += (-10 * torch.log10(mse + 1e-8)).sum().item()  # 1e-8 caps PSNR at 80 dB
                l1_sum += diff.abs().mean(dim=1).sum().item()
                n_images += diff.size(0)

        metrics['Model'].append(name)
        metrics['PSNR'].append(psnr_sum / n_images if n_images else float('nan'))
        metrics['L1'].append(l1_sum / n_images if n_images else float('nan'))

    return pd.DataFrame(metrics)

In [8]:
# --- Main Execution ---

# 1. Prepare Data
# glob() order is filesystem-dependent; sort so the seeded split below is reproducible.
image_paths = sorted(glob.glob(os.path.join(dataset_path, '*.jpg')))
MAX_IMAGES = None  # e.g. 1000 for a quick smoke test; None uses every image
if MAX_IMAGES is not None:
    image_paths = image_paths[:MAX_IMAGES]
if not image_paths:
    raise FileNotFoundError(f"No .jpg images found in {dataset_path}")

train_paths, test_paths = train_test_split(image_paths, test_size=0.1, random_state=42)

print(f"Train size: {len(train_paths)}, Test size: {len(test_paths)}")

train_loader = DataLoader(InpaintingDataset(train_paths, transforms.ToTensor()),
                          batch_size=BATCH_SIZE, shuffle=True)
test_loader = DataLoader(InpaintingDataset(test_paths, transforms.ToTensor()),
                         batch_size=BATCH_SIZE, shuffle=False)

# 2. Initialize the standalone models (the combined model is built after they are trained)
print("Initializing models...")
unet = UnetE().to(DEVICE)
hint = HINTE().to(DEVICE)
xlstm = FastXLSTM().to(DEVICE)

# 3. Training Setup
criterion = nn.MSELoss()
models_config = {
    'unet': (unet, optim.Adam(unet.parameters(), lr=LEARNING_RATE)),
    'hint': (hint, optim.Adam(hint.parameters(), lr=LEARNING_RATE)),
    'xlstm': (xlstm, optim.Adam(xlstm.parameters(), lr=LEARNING_RATE)),
}

# 4. Training Loop
EPOCHS = 1  # Set to desired number


def train_and_save(name, model, optimizer):
    """Train ``model`` for EPOCHS epochs, then save ``<name>_final.pth`` and ``<name>_loss.png``."""
    print(f"\nTraining {name}...")
    loss_history = []

    for epoch in range(EPOCHS):
        loss = train_one_epoch(model, train_loader, criterion, optimizer, epoch + 1, name)
        loss_history.append(loss)
        print(f"Epoch {epoch+1} Avg Loss: {loss:.4f}")

    # Save Model
    torch.save(model.state_dict(), f"{name}_final.pth")

    # Plot Loss (marker so a single-epoch run still shows a visible point)
    fig, ax = plt.subplots()
    ax.plot(range(1, len(loss_history) + 1), loss_history, marker='o')
    ax.set(title=f"{name} Loss", xlabel="Epoch", ylabel="Avg loss")
    fig.savefig(f"{name}_loss.png")
    plt.close(fig)


for name, (model, optimizer) in models_config.items():
    train_and_save(name, model, optimizer)

# CombinedModel registers its submodules without copying them. Built from the
# originals, its training would keep updating unet/hint/xlstm in place, and their
# evaluation below would no longer reflect standalone training. Blend deep copies
# instead: the combined model still starts from the trained weights, while the
# standalone models stay untouched.
combined = CombinedModel(copy.deepcopy(unet), copy.deepcopy(hint), copy.deepcopy(xlstm)).to(DEVICE)
models_config['combined'] = (combined, optim.Adam(combined.parameters(), lr=LEARNING_RATE))
train_and_save('combined', *models_config['combined'])

# 5. Evaluation
print("\nRunning Evaluation...")
results_df = evaluate({n: m for n, (m, _) in models_config.items()}, test_loader)
print(results_df)


# 6. Visualize Result
def show_sample():
    """Plot input / prediction / ground truth for the first test image using the combined model."""
    model = combined
    model.eval()
    masked, mask, orig = next(iter(test_loader))
    masked, mask, orig = masked.to(DEVICE), mask.to(DEVICE), orig.to(DEVICE)

    with torch.no_grad():
        out = model(masked, mask)

    panels = [("Input", masked), ("Prediction", out), ("Ground Truth", orig)]
    fig, axes = plt.subplots(1, 3, figsize=(12, 4))
    for ax, (title, batch) in zip(axes, panels):
        # CHW -> HWC on the CPU for imshow; clamp guards against tiny overshoots.
        ax.imshow(batch[0].cpu().permute(1, 2, 0).clamp(0, 1))
        ax.set_title(title)
        ax.axis('off')
    plt.show()


show_sample()

Train size: 27000, Test size: 3000
Initializing models...

Training unet...


unet Epoch 1:  10%|█         | 169/1688 [01:28<13:15,  1.91it/s, loss=0.00522]


KeyboardInterrupt: 