<a href="https://colab.research.google.com/github/fjadidi2001/Image_Inpaint/blob/main/Image_inpaint_New_approach.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install datasets

Collecting datasets
  Downloading datasets-3.1.0-py3-none-any.whl.metadata (20 kB)
Collecting dill<0.3.9,>=0.3.0 (from datasets)
  Downloading dill-0.3.8-py3-none-any.whl.metadata (10 kB)
Collecting xxhash (from datasets)
  Downloading xxhash-3.5.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (12 kB)
Collecting multiprocess<0.70.17 (from datasets)
  Downloading multiprocess-0.70.16-py310-none-any.whl.metadata (7.2 kB)
Collecting fsspec<=2024.9.0,>=2023.1.0 (from fsspec[http]<=2024.9.0,>=2023.1.0->datasets)
  Downloading fsspec-2024.9.0-py3-none-any.whl.metadata (11 kB)
Downloading datasets-3.1.0-py3-none-any.whl (480 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m480.6/480.6 kB[0m [31m10.2 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading dill-0.3.8-py3-none-any.whl (116 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m116.3/116.3 kB[0m [31m7.4 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading fsspec-2024.9.0-py3-none-any.whl (

In [15]:
from huggingface_hub import login

login('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', add_to_git_credential=True)

Token is valid (permission: fineGrained).
Your token has been saved in your configured git credential helpers (store).
Your token has been saved to /root/.cache/huggingface/token
Login successful


In [24]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class MaskAwareEncoding(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.proj = nn.Conv2d(4, dim, 1)  # Changed to Conv2d for spatial dimensions

    def forward(self, x, mask):
        # Combine image and mask
        mask = mask.mean(dim=1, keepdim=True)  # [B, 1, H, W]
        combined = torch.cat([x, mask], dim=1)  # [B, C+1, H, W]
        return self.proj(combined)

class EnhancedAttention(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.qkv = nn.Conv2d(dim, dim * 3, 1)
        self.proj = nn.Conv2d(dim, dim, 1)

    def forward(self, x, mask):
        B, C, H, W = x.shape
        qkv = self.qkv(x)
        qkv = qkv.reshape(B, 3, C, H * W).permute(1, 0, 2, 3)
        q, k, v = qkv[0], qkv[1], qkv[2]

        attn = (q @ k.transpose(-2, -1)) * (1.0 / C ** 0.5)
        mask_flat = mask.mean(dim=1).view(B, 1, H * W)
        attn = attn * (1 - mask_flat)
        attn = F.softmax(attn, dim=-1)

        x = (attn @ v).reshape(B, C, H, W)
        return self.proj(x)

class UNetHINT(nn.Module):
    def __init__(self, in_channels=3, out_channels=3, features=32):
        super().__init__()

        # Encoder
        self.enc1 = self._block(in_channels, features)
        self.enc2 = self._block(features, features * 2)
        self.enc3 = self._block(features * 2, features * 4)
        self.enc4 = self._block(features * 4, features * 8)

        # Bottleneck
        self.bottleneck = self._block(features * 8, features * 16)

        # HINT components
        self.mask_encoding = MaskAwareEncoding(features * 16)
        self.attention = EnhancedAttention(features * 16)

        # Decoder
        self.dec4 = self._block(features * 24, features * 8)  # 16 + 8 from skip
        self.dec3 = self._block(features * 12, features * 4)  # 8 + 4 from skip
        self.dec2 = self._block(features * 6, features * 2)   # 4 + 2 from skip
        self.dec1 = self._block(features * 3, features)       # 2 + 1 from skip

        # Final layers
        self.final = nn.Sequential(
            nn.Conv2d(features, out_channels, 1),
            nn.Tanh()
        )

        # Pooling and upsampling
        self.pool = nn.MaxPool2d(2)
        self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)

    def _block(self, in_channels, out_channels):
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, 3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )

    def forward(self, x, mask):
        # Encoder path
        e1 = self.enc1(x)           # features
        e2 = self.enc2(self.pool(e1))  # features * 2
        e3 = self.enc3(self.pool(e2))  # features * 4
        e4 = self.enc4(self.pool(e3))  # features * 8

        # Bottleneck
        b = self.bottleneck(self.pool(e4))  # features * 16

        # HINT processing
        encoded = self.mask_encoding(b, mask)
        attended = self.attention(encoded, mask)

        # Decoder path with skip connections
        d4 = self.dec4(torch.cat([self.upsample(attended), e4], dim=1))  # features * 8
        d3 = self.dec3(torch.cat([self.upsample(d4), e3], dim=1))       # features * 4
        d2 = self.dec2(torch.cat([self.upsample(d3), e2], dim=1))       # features * 2
        d1 = self.dec1(torch.cat([self.upsample(d2), e1], dim=1))       # features

        # Final output
        return self.final(d1)

def test_model():
    """Test the model with dummy data to verify dimensions"""
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = UNetHINT().to(device)

    # Create dummy input
    batch_size = 2
    channels = 3
    height = 256
    width = 256

    x = torch.randn(batch_size, channels, height, width).to(device)
    mask = torch.ones(batch_size, channels, height, width).to(device)
    mask[:, :, 100:150, 100:150] = 0  # Create a sample mask

    # Test forward pass
    try:
        output = model(x, mask)
        print(f"Input shape: {x.shape}")
        print(f"Output shape: {output.shape}")
        print("Model test successful!")
    except Exception as e:
        print(f"Error during model test: {e}")

if __name__ == "__main__":
    test_model()

Error during model test: Sizes of tensors must match except in dimension 1. Expected size 16 but got size 256 for tensor number 1 in the list.


In [25]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.cuda.amp import autocast, GradScaler
from torch.utils.data import Dataset, DataLoader, random_split
import torchvision.transforms as transforms
from datasets import load_dataset
from PIL import Image
import matplotlib.pyplot as plt
import numpy as np
from tqdm import tqdm
import os
import pickle
from skimage.metrics import structural_similarity as ssim
from skimage.metrics import peak_signal_noise_ratio as psnr

class CelebAHQDataset(Dataset):
    def __init__(self, dataset, img_size=256, mask_size=100, cache_size=1000):
        self.ds = dataset
        self.img_size = img_size
        self.mask_size = mask_size
        self.cache_size = cache_size
        self.cache = {}

        self.transform = transforms.Compose([
            transforms.Resize((img_size, img_size)),
            transforms.ToTensor(),
        ])

    def __len__(self):
        return len(self.ds)

    def create_mask(self):
        """Creates a random square mask"""
        # Initialize mask with ones
        mask = torch.ones(3, self.img_size, self.img_size)

        # Ensure mask_size is not larger than image_size
        effective_mask_size = min(self.mask_size, self.img_size - 1)

        # Generate random position for mask
        try:
            top = torch.randint(0, max(1, self.img_size - effective_mask_size), (1,)).item()
            left = torch.randint(0, max(1, self.img_size - effective_mask_size), (1,)).item()

            # Create the mask
            mask[:, top:top+effective_mask_size, left:left+effective_mask_size] = 0
        except RuntimeError as e:
            print(f"Error in create_mask: {e}")
            print(f"Image size: {self.img_size}, Mask size: {effective_mask_size}")
            # Fallback to a centered mask
            center = self.img_size // 2
            half_mask = effective_mask_size // 2
            mask[:, center-half_mask:center+half_mask, center-half_mask:center+half_mask] = 0

        return mask

    def __getitem__(self, idx):
        if idx in self.cache:
            image = self.cache[idx]
        else:
            image = self.transform(self.ds[idx]['image'])
            if len(self.cache) < self.cache_size:
                self.cache[idx] = image

        mask = self.create_mask()
        masked_image = image * mask
        return image, masked_image, mask

class MetricsTracker:
    def __init__(self):
        self.train_losses = []
        self.val_losses = []
        self.psnr_scores = []
        self.ssim_scores = []

    def update_train(self, loss):
        self.train_losses.append(loss)

    def update_val(self, loss):
        self.val_losses.append(loss)

    def update_metrics(self, psnr, ssim):
        self.psnr_scores.append(psnr)
        self.ssim_scores.append(ssim)

    def plot_metrics(self, save_path='metrics.png'):
        fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=(15, 10))

        # Plot training loss
        ax1.plot(self.train_losses)
        ax1.set_title('Training Loss')
        ax1.set_xlabel('Iteration')
        ax1.set_ylabel('Loss')

        # Plot validation loss
        ax2.plot(self.val_losses)
        ax2.set_title('Validation Loss')
        ax2.set_xlabel('Epoch')
        ax2.set_ylabel('Loss')

        # Plot PSNR
        ax3.plot(self.psnr_scores)
        ax3.set_title('PSNR Score')
        ax3.set_xlabel('Epoch')
        ax3.set_ylabel('PSNR')

        # Plot SSIM
        ax4.plot(self.ssim_scores)
        ax4.set_title('SSIM Score')
        ax4.set_xlabel('Epoch')
        ax4.set_ylabel('SSIM')

        plt.tight_layout()
        plt.savefig(save_path)
        plt.close()

def save_model(model, optimizer, metrics, epoch, filename='model.pkl'):
    """Save the model, optimizer state, and metrics"""
    state = {
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'metrics': metrics.__dict__,
        'epoch': epoch
    }
    with open(filename, 'wb') as f:
        pickle.dump(state, f)

def load_model(model, optimizer, filename='model.pkl'):
    """Load the model and optimizer state"""
    with open(filename, 'rb') as f:
        state = pickle.load(f)

    model.load_state_dict(state['model_state_dict'])
    optimizer.load_state_dict(state['optimizer_state_dict'])
    metrics = MetricsTracker()
    metrics.__dict__.update(state['metrics'])
    return model, optimizer, metrics, state['epoch']

def main():
    # Set device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")

    # Load dataset
    full_dataset = load_dataset("saitsharipov/CelebA-HQ", split='train[:1000]')

    # Create splits
    total_size = len(full_dataset)
    train_size = int(0.8 * total_size)
    val_size = int(0.1 * total_size)
    test_size = total_size - train_size - val_size

    train_data, val_data, test_data = random_split(
        full_dataset,
        [train_size, val_size, test_size],
        generator=torch.Generator().manual_seed(42)
    )

    # Create datasets and dataloaders
    batch_size = 32
    train_dataset = CelebAHQDataset(train_data)
    val_dataset = CelebAHQDataset(val_data)
    test_dataset = CelebAHQDataset(test_data)

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size)
    test_loader = DataLoader(test_dataset, batch_size=batch_size)

    # Initialize model, optimizer, and metrics
    model = UNetHINT().to(device)
    optimizer = optim.Adam(model.parameters(), lr=0.001)
    metrics = MetricsTracker()

    # Training loop
    num_epochs = 10
    best_val_loss = float('inf')

    for epoch in range(num_epochs):
        # Training
        model.train()
        for batch in tqdm(train_loader, desc=f'Epoch {epoch+1}/{num_epochs}'):
            images, masked_images, masks = [x.to(device) for x in batch]

            optimizer.zero_grad()
            outputs = model(masked_images, masks)
            loss = F.l1_loss(outputs, images)

            loss.backward()
            optimizer.step()

            metrics.update_train(loss.item())

        # Validation
        model.eval()
        val_loss = 0
        with torch.no_grad():
            for batch in val_loader:
                images, masked_images, masks = [x.to(device) for x in batch]
                outputs = model(masked_images, masks)
                val_loss += F.l1_loss(outputs, images).item()

        val_loss /= len(val_loader)
        metrics.update_val(val_loss)

        # Calculate PSNR and SSIM
        with torch.no_grad():
            batch = next(iter(test_loader))
            images, masked_images, masks = [x.to(device) for x in batch]
            outputs = model(masked_images, masks)

            # Convert to numpy for metric calculation
            img = images[0].cpu().numpy().transpose(1, 2, 0)
            out = outputs[0].cpu().numpy().transpose(1, 2, 0)

            psnr_score = psnr(img, out, data_range=1.0)
            ssim_score = ssim(img, out, channel_axis=2, data_range=1.0)
            metrics.update_metrics(psnr_score, ssim_score)

        # Save best model
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            save_model(model, optimizer, metrics, epoch)

        # Plot and save metrics
        metrics.plot_metrics()

        print(f'Epoch {epoch+1}/{num_epochs}:')
        print(f'Training Loss: {metrics.train_losses[-1]:.4f}')
        print(f'Validation Loss: {val_loss:.4f}')
        print(f'PSNR: {psnr_score:.2f}, SSIM: {ssim_score:.4f}')

if __name__ == '__main__':
    main()

Using device: cpu


Epoch 1/10:   0%|          | 0/25 [00:12<?, ?it/s]


RuntimeError: Sizes of tensors must match except in dimension 1. Expected size 16 but got size 256 for tensor number 1 in the list.