<a href="https://colab.research.google.com/github/fjadidi2001/Image_Inpaint/blob/main/FastInpaintingNet-Jan25.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [4]:
!pip install datasets
!pip install torch torchvision

Collecting datasets
  Downloading datasets-3.2.0-py3-none-any.whl.metadata (20 kB)
Collecting dill<0.3.9,>=0.3.0 (from datasets)
  Downloading dill-0.3.8-py3-none-any.whl.metadata (10 kB)
Collecting xxhash (from datasets)
  Downloading xxhash-3.5.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (12 kB)
Collecting multiprocess<0.70.17 (from datasets)
  Downloading multiprocess-0.70.16-py311-none-any.whl.metadata (7.2 kB)
Collecting fsspec<=2024.9.0,>=2023.1.0 (from fsspec[http]<=2024.9.0,>=2023.1.0->datasets)
  Downloading fsspec-2024.9.0-py3-none-any.whl.metadata (11 kB)
Downloading datasets-3.2.0-py3-none-any.whl (480 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m480.6/480.6 kB[0m [31m24.9 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading dill-0.3.8-py3-none-any.whl (116 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m116.3/116.3 kB[0m [31m9.9 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading fsspec-2024.9.0-py3-none-any.whl (

In [1]:
import os

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from datasets import load_dataset
from torch.nn import functional as F
from torch.utils.data import DataLoader, Dataset, random_split
from torchvision import transforms, models
from tqdm import tqdm

# Step 1: Inspect the Dataset
def inspect_dataset():
    """
    Inspect the CelebA-HQ dataset to understand its structure and contents.

    Loads ``saitsharipov/CelebA-HQ`` with ``load_dataset`` and prints the
    dataset structure, the first few training rows, the column names and a
    pandas preview (head and dtypes) of the first rows.

    The preview sizes are capped at the length of the training split so that
    a small split does not raise an out-of-range error.

    Returns:
    - The loaded dataset object, or None if loading or inspection raised.
    """
    try:
        # Load the CelebA-HQ dataset
        ds = load_dataset("saitsharipov/CelebA-HQ")

        # Check if dataset loaded correctly
        if ds is None:
            raise ValueError("Dataset not loaded correctly")

        # Print basic information about the dataset
        print("Dataset structure:")
        print(ds)

        train_split = ds['train']
        num_rows = len(train_split)

        # Inspect the first few items in the training set (at most 5)
        print("\nFirst few items in the training set:")
        preview_rows = train_split.select(range(min(5, num_rows)))
        for i, item in enumerate(preview_rows):
            print(f"Item {i}:")
            for key, value in item.items():
                if key == 'image':
                    # Avoid dumping the image repr; the key is enough here.
                    print(f"  {key}: Image object")
                else:
                    print(f"  {key}: {value}")
            print()

        # Get all column names (attributes)
        column_names = list(train_split.features.keys())
        print("\nColumn names:")
        print(column_names)

        # Convert a small subset (at most 100 rows) to a pandas DataFrame
        df = pd.DataFrame(train_split.select(range(min(100, num_rows))))
        print("\nDataFrame head:")
        print(df.head())

        # Print data types of columns
        print("\nColumn data types:")
        print(df.dtypes)

        return ds

    except Exception as e:
        # Best-effort inspection: report the problem and let the caller
        # decide what to do with a None result.
        print(f"Error inspecting dataset: {e}")
        return None

# Step 2: Load and Preprocess the Dataset
def load_and_preprocess_dataset(ds, max_images=1000, img_size=128):
    """
    Load and preprocess the CelebA-HQ dataset.

    Rows are read one at a time by index, so only the first ``max_images``
    images are decoded. (Indexing the whole ``'image'`` column and slicing
    afterwards materialises every image in the split first.)

    Args:
    - ds: The dataset object returned by `load_dataset`
    - max_images: Maximum number of images to load
    - img_size: Target image size for resizing

    Returns:
    - float32 numpy array of shape (N, img_size, img_size, 3) with values in
      [0, 1], or None if the split is empty or an error occurred.
    """
    try:
        train_split = ds['train']
        num_images = max(0, min(max_images, len(train_split)))

        # Check if images are available
        if num_images == 0:
            print("No images found in the dataset.")
            return None

        processed = []
        for idx in tqdm(range(num_images)):
            img = train_split[idx]['image']
            # Force 3 channels so every array has the same shape, then resize.
            img = img.convert('RGB').resize((img_size, img_size))
            # Normalize to [0, 1]
            processed.append(np.asarray(img, dtype=np.float32) / 255.0)

        return np.stack(processed).astype(np.float32, copy=False)

    except Exception as e:
        print(f"Dataset loading error: {e}")
        return None

# Step 3: Define the Model Architecture
class InpaintingNet(nn.Module):
    """
    Encoder/decoder inpainting network with two skip connections.

    The forward pass concatenates the image and the mask along the channel
    axis (6 input channels in total), downsamples three times by a factor of
    two, applies a 512-channel bottleneck, and upsamples three times, joining
    the outputs of `encoder2` and `encoder1` back in on the way up. The result
    is squashed with tanh, so outputs lie in [-1, 1].
    """

    @staticmethod
    def _down_block(in_ch, out_ch):
        """3x3 conv followed by a stride-2 3x3 conv (halves H and W)."""
        return nn.Sequential(
            nn.Conv2d(in_ch, out_ch, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_ch, out_ch, kernel_size=3, stride=2, padding=1),
            nn.ReLU(inplace=True),
        )

    @staticmethod
    def _up_block(in_ch, out_ch):
        """Stride-2 transposed conv (doubles H and W) followed by a 3x3 conv."""
        return nn.Sequential(
            nn.ConvTranspose2d(in_ch, out_ch, kernel_size=4, stride=2, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_ch, out_ch, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True),
        )

    def __init__(self):
        super().__init__()

        # Encoder: 6 -> 64 -> 128 -> 256 channels, each stage halves H and W.
        self.encoder1 = self._down_block(6, 64)
        self.encoder2 = self._down_block(64, 128)
        self.encoder3 = self._down_block(128, 256)

        # Bottleneck keeps the spatial size and widens to 512 channels.
        self.middle = nn.Sequential(
            nn.Conv2d(256, 512, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True),
        )

        # Decoder input widths account for the concatenated skip tensors:
        # decoder2 sees 128 (decoder3) + 128 (encoder2) = 256 channels,
        # decoder1 sees 128 (decoder2) + 64 (encoder1) = 192 channels.
        self.decoder3 = self._up_block(512, 128)
        self.decoder2 = self._up_block(256, 128)
        self.decoder1 = self._up_block(192, 64)

        # Projection back to 3 output channels.
        self.final = nn.Conv2d(64, 3, kernel_size=3, stride=1, padding=1)

    def forward(self, x, mask):
        """Return the tanh-activated reconstruction for `x` given `mask`."""
        net_in = torch.cat([x, mask], dim=1)

        # Downsampling path
        skip_half = self.encoder1(net_in)        # (batch, 64, H/2, W/2)
        skip_quarter = self.encoder2(skip_half)  # (batch, 128, H/4, W/4)
        deepest = self.encoder3(skip_quarter)    # (batch, 256, H/8, W/8)

        bottleneck = self.middle(deepest)        # (batch, 512, H/8, W/8)

        # Upsampling path with skip connections
        up_quarter = torch.cat([self.decoder3(bottleneck), skip_quarter], dim=1)
        up_half = torch.cat([self.decoder2(up_quarter), skip_half], dim=1)
        up_full = self.decoder1(up_half)         # (batch, 64, H, W)

        return torch.tanh(self.final(up_full))   # (batch, 3, H, W) in [-1, 1]

# Step 4: Define VGG-based Perceptual Loss
class VGGLoss(nn.Module):
    """
    Perceptual loss: L1 distance between frozen VGG16 feature maps.

    Inputs are assumed to be in [-1, 1] (the range produced by the
    ``Normalize((0.5, ...), (0.5, ...))`` transform and by the tanh output of
    the inpainting network in this notebook). They are mapped to [0, 1] and
    then standardised with the ImageNet mean/std that the pretrained VGG
    weights expect, before features are extracted.
    """

    def __init__(self):
        super(VGGLoss, self).__init__()
        # `weights=` replaces the deprecated `pretrained=True` flag.
        vgg = models.vgg16(weights=models.VGG16_Weights.IMAGENET1K_V1)
        # Keep only the first 16 layers of the convolutional feature stack.
        self.vgg = vgg.features[:16].eval()
        for param in self.vgg.parameters():
            param.requires_grad = False

        # ImageNet statistics; non-persistent buffers follow `.to(device)`
        # without being written into the state dict.
        self.register_buffer(
            'imagenet_mean',
            torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1),
            persistent=False,
        )
        self.register_buffer(
            'imagenet_std',
            torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1),
            persistent=False,
        )

    def _to_vgg_input(self, x):
        """Map a [-1, 1] image batch to ImageNet-standardised VGG input."""
        x = (x + 1.0) / 2.0
        return (x - self.imagenet_mean) / self.imagenet_std

    def forward(self, output, target):
        """Return the mean absolute difference between VGG features."""
        vgg_output = self.vgg(self._to_vgg_input(output))
        vgg_target = self.vgg(self._to_vgg_input(target))

        # Compute L1 loss between features
        return F.l1_loss(vgg_output, vgg_target)

# Step 5: Custom Dataset Class
class CelebAHQDataset(Dataset):
    """
    Map-style dataset over an in-memory collection of images.

    Each item is passed through ``ToTensor`` followed by a per-channel
    ``Normalize`` with mean 0.5 and std 0.5.
    NOTE(review): presumably `images` is the (N, H, W, 3) float array built by
    `load_and_preprocess_dataset` — confirm against the caller.
    """

    def __init__(self, images):
        half = (0.5, 0.5, 0.5)
        pipeline = [
            transforms.ToTensor(),
            transforms.Normalize(half, half),
        ]
        self.images = images
        self.transform = transforms.Compose(pipeline)

    def __len__(self):
        """Number of stored images."""
        return len(self.images)

    def __getitem__(self, idx):
        """Return the transformed image at position `idx`."""
        return self.transform(self.images[idx])

# Step 6: Mask Generation
def create_irregular_mask(image):
    """
    Build a binary mask for a single (C, H, W) image tensor.

    The mask starts as all ones (same shape, dtype and device as `image`) and
    between 5 and 9 randomly placed rectangles are set to zero across all
    channels. Each rectangle's height is drawn from [H//4, H//2) and its width
    from [W//4, W//2). Randomness comes from NumPy's global RNG.
    """
    _, height, width = image.shape
    keep = torch.ones_like(image)

    hole_count = np.random.randint(5, 10)
    for _hole in range(hole_count):
        # Draw order (height, width, top, left) matters for reproducibility
        # under a fixed NumPy seed.
        hole_h = np.random.randint(height // 4, height // 2)
        hole_w = np.random.randint(width // 4, width // 2)
        row0 = np.random.randint(0, height - hole_h)
        col0 = np.random.randint(0, width - hole_w)

        rows = slice(row0, row0 + hole_h)
        cols = slice(col0, col0 + hole_w)
        keep[:, rows, cols] = 0

    return keep

# Step 7: Training Loop
def train_model(model, train_loader, val_loader, num_epochs=20, device='cuda'):
    """
    Train an inpainting model with an L1 + 0.1 * VGG perceptual loss.

    For every batch a random mask is generated per image, the masked image
    and mask are fed to the model, and the loss is computed against the
    unmasked image. After each epoch the model is evaluated with L1 loss on
    `val_loader`; the weights with the lowest validation loss are written to
    'best_inpainting_model.pth'. The learning rate is halved every 5 epochs.

    Validation masks are drawn from a fixed NumPy seed so that validation
    losses are comparable between epochs (the global NumPy RNG state used for
    training masks is saved and restored around validation).

    Args:
    - model: network called as ``model(masked_images, masks)``
    - train_loader: iterable of image batches used for optimisation
    - val_loader: iterable of image batches used for validation
    - num_epochs: number of passes over `train_loader`
    - device: device the model and batches are moved to

    Returns:
    - dict with per-epoch lists under 'train_loss' and 'val_loss'

    Raises:
    - ValueError: if either loader has no batches (instead of a late
      ZeroDivisionError when averaging the loss).
    """
    if len(train_loader) == 0:
        raise ValueError("train_loader is empty; nothing to train on")
    if len(val_loader) == 0:
        raise ValueError("val_loader is empty; cannot compute validation loss")

    model = model.to(device)
    optimizer = optim.Adam(model.parameters(), lr=0.001)
    criterion_l1 = nn.L1Loss()
    criterion_vgg = VGGLoss().to(device)

    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.5)

    best_val_loss = float('inf')
    history = {'train_loss': [], 'val_loss': []}

    for epoch in range(num_epochs):
        model.train()
        train_loss = 0
        for batch_idx, batch in enumerate(train_loader):
            images = batch.to(device)
            masks = torch.stack([create_irregular_mask(img) for img in images]).to(device)
            masked_images = images * masks

            optimizer.zero_grad()
            outputs = model(masked_images, masks)

            # Compute losses
            loss_l1 = criterion_l1(outputs, images)
            loss_vgg = criterion_vgg(outputs, images)
            loss = loss_l1 + 0.1 * loss_vgg  # Weighted combination

            loss.backward()
            optimizer.step()

            train_loss += loss.item()

            # Visualize intermediate results every 100 batches
            if batch_idx % 100 == 0:
                visualize_results(images, masked_images, outputs, epoch, batch_idx)

        avg_train_loss = train_loss / len(train_loader)

        # Validation (fixed mask seed; training RNG state is restored after)
        model.eval()
        val_loss = 0
        rng_state = np.random.get_state()
        np.random.seed(0)
        try:
            with torch.no_grad():
                for batch in val_loader:
                    images = batch.to(device)
                    masks = torch.stack([create_irregular_mask(img) for img in images]).to(device)
                    masked_images = images * masks

                    outputs = model(masked_images, masks)
                    loss = criterion_l1(outputs, images)
                    val_loss += loss.item()
        finally:
            np.random.set_state(rng_state)

        avg_val_loss = val_loss / len(val_loader)

        history['train_loss'].append(avg_train_loss)
        history['val_loss'].append(avg_val_loss)

        # Update learning rate
        scheduler.step()

        print(f'Epoch [{epoch+1}/{num_epochs}]')
        print(f'Train Loss: {avg_train_loss:.4f}, Val Loss: {avg_val_loss:.4f}')
        print(f'Learning Rate: {optimizer.param_groups[0]["lr"]:.6f}')

        # Save best model
        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            torch.save(model.state_dict(), 'best_inpainting_model.pth')

        torch.cuda.empty_cache()

    return history

# Step 8: Visualization Function
def visualize_results(images, masked_images, outputs, epoch, batch_idx):
    """
    Show the first sample of a batch as Original / Masked / Inpainted panels.

    The tensors are assumed to be (batch, C, H, W) in [-1, 1] (the range used
    by the dataset normalisation and the model's tanh output in this
    notebook). They are mapped back to [0, 1] before plotting; passing
    [-1, 1] floats straight to ``imshow`` would clip all negative values.
    Masked-out pixels (value 0 in normalised space) therefore appear mid-gray.

    Args:
    - images: batch of ground-truth images
    - masked_images: batch of images with holes zeroed out
    - outputs: batch of model reconstructions
    - epoch: zero-based epoch index (displayed as epoch + 1)
    - batch_idx: batch index shown in the figure title
    """
    def _to_display(batch):
        # Only the first sample is plotted, so only that one is moved to CPU.
        sample = batch[0].detach().float().cpu()
        sample = (sample * 0.5 + 0.5).clamp(0.0, 1.0)
        return sample.permute(1, 2, 0).numpy()

    panels = (
        ('Original', images),
        ('Masked', masked_images),
        ('Inpainted', outputs),
    )

    fig, axes = plt.subplots(1, 3, figsize=(15, 5))
    for ax, (title, batch) in zip(axes, panels):
        ax.imshow(_to_display(batch))
        ax.set_title(title)
        ax.axis('off')

    fig.suptitle(f'Epoch {epoch+1}, Batch {batch_idx}')
    plt.show()
    # Release the figure; this function is called repeatedly during training.
    plt.close(fig)

# Step 9: Main Function
def main():
    """
    Run the full pipeline: inspect, preprocess, split, and train.

    Loads up to 1000 images at 128x128, splits them 70/15/15 into
    train/validation/test subsets with a fixed seed (so the split is the same
    on every run), and trains `InpaintingNet` for 20 epochs. Returns early if
    the dataset cannot be inspected or loaded.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")

    # Inspect the dataset
    ds = inspect_dataset()
    if ds is None:
        print("Failed to inspect dataset. Exiting.")
        return

    # Load and preprocess the dataset
    images = load_and_preprocess_dataset(ds, max_images=1000, img_size=128)
    if images is None:
        print("Failed to load dataset. Exiting.")
        return

    # Create dataset and data loaders
    dataset = CelebAHQDataset(images)
    train_size = int(0.7 * len(dataset))
    val_size = int(0.15 * len(dataset))
    # The test split takes whatever remains so the sizes always sum to len(dataset).
    test_size = len(dataset) - train_size - val_size

    # Seeded generator -> reproducible train/val/test membership.
    split_generator = torch.Generator().manual_seed(42)
    train_dataset, val_dataset, test_dataset = random_split(
        dataset, [train_size, val_size, test_size], generator=split_generator
    )

    train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True, num_workers=2, pin_memory=False)
    val_loader = DataLoader(val_dataset, batch_size=16, shuffle=False, num_workers=2, pin_memory=False)
    # Held-out loader; not consumed by train_model.
    test_loader = DataLoader(test_dataset, batch_size=16, shuffle=False, num_workers=2, pin_memory=False)

    # Create and train model
    model = InpaintingNet()
    train_model(model, train_loader, val_loader, num_epochs=20, device=device)

if __name__ == "__main__":
    main()

Using device: cuda


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


KeyboardInterrupt: 

In [None]:
import os
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.nn import functional as F
from torchvision import transforms, models
from torch.utils.data import Dataset, DataLoader, random_split
from torch.cuda.amp import GradScaler, autocast
from datasets import load_dataset
import matplotlib.pyplot as plt

class ImageInpaintingDataset(Dataset):
    """
    In-memory dataset of preprocessed image tensors.

    At construction time the first `max_images` rows of ``dataset['train']``
    are read one by one, converted to RGB, resized to
    (`img_size`, `img_size`), converted to tensors and normalised with
    mean 0.5 / std 0.5 per channel. Rows are accessed by index so that only
    the requested images are decoded, rather than the whole image column.

    Args:
    - dataset: object whose ``['train']`` split supports ``len()`` and integer
      indexing returning a mapping with an ``'image'`` entry
    - max_images: maximum number of images to keep
    - img_size: side length of the square output images
    """

    def __init__(self, dataset, max_images=500, img_size=128):
        self.transform = transforms.Compose([
            transforms.Resize((img_size, img_size)),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])

        train_split = dataset['train']
        num_images = max(0, min(max_images, len(train_split)))

        self.images = []
        for idx in range(num_images):
            # convert('RGB') guarantees 3 channels, matching the 3-channel
            # Normalize above.
            img = train_split[idx]['image'].convert('RGB')
            self.images.append(self.transform(img))

    def __len__(self):
        """Number of cached image tensors."""
        return len(self.images)

    def __getitem__(self, idx):
        """Return the preprocessed image tensor at position `idx`."""
        return self.images[idx]

class InpaintingNet(nn.Module):
    """
    Compact encoder/decoder inpainting network.

    The image and mask are concatenated on the channel axis, downsampled
    three times by stride-2 convolutions (two in `encoder`, one in `middle`)
    and upsampled three times by stride-2 transposed convolutions, so for
    inputs whose height and width are divisible by 8 the output has the same
    spatial size as the input. The final tanh keeps outputs in [-1, 1].

    Args:
    - in_channels: channels of the concatenated image + mask input
    - out_channels: channels of the reconstructed image
    """

    def __init__(self, in_channels=6, out_channels=3):
        super().__init__()
        # H -> H/2 -> H/4
        self.encoder = nn.Sequential(
            nn.Conv2d(in_channels, 64, 4, 2, 1),
            nn.LeakyReLU(0.2),
            nn.Conv2d(64, 128, 4, 2, 1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2)
        )

        # H/4 -> H/8
        self.middle = nn.Sequential(
            nn.Conv2d(128, 256, 4, 2, 1),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2)
        )

        # H/8 -> H/4 -> H/2 -> H. The third upsampling stage balances the
        # three stride-2 downsamples above; with only two, the output would
        # be half the input resolution and could not be compared with the
        # target image in the loss.
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(256, 128, 4, 2, 1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2),
            nn.ConvTranspose2d(128, 64, 4, 2, 1),
            nn.BatchNorm2d(64),
            nn.LeakyReLU(0.2),
            nn.ConvTranspose2d(64, 64, 4, 2, 1),
            nn.BatchNorm2d(64),
            nn.LeakyReLU(0.2),
            nn.Conv2d(64, out_channels, 3, 1, 1),
            nn.Tanh()
        )

    def forward(self, x, mask):
        """Return the reconstruction for masked image `x` and its `mask`."""
        x = torch.cat([x, mask], dim=1)
        x = self.encoder(x)
        x = self.middle(x)
        x = self.decoder(x)
        return x

class PerceptualLoss(nn.Module):
    """
    Perceptual loss: L1 distance between frozen VGG16 feature maps.

    `pred` and `target` are assumed to be in [-1, 1] (the range produced by
    the dataset's ``Normalize((0.5, ...), (0.5, ...))`` and the network's tanh
    in this notebook). They are rescaled to [0, 1] and standardised with the
    ImageNet mean/std expected by the pretrained VGG weights before feature
    extraction.
    """

    def __init__(self):
        super().__init__()
        # `weights=` replaces the deprecated `pretrained=True` flag.
        vgg = models.vgg16(weights=models.VGG16_Weights.IMAGENET1K_V1).features[:16].eval()
        self.features = nn.Sequential(*list(vgg.children())).eval()
        self.features.requires_grad_(False)

        # Non-persistent buffers move with `.to(device)` but stay out of the
        # state dict.
        self.register_buffer(
            'imagenet_mean',
            torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1),
            persistent=False,
        )
        self.register_buffer(
            'imagenet_std',
            torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1),
            persistent=False,
        )

    def _to_vgg_input(self, x):
        """Map a [-1, 1] image batch to ImageNet-standardised VGG input."""
        x = (x + 1.0) / 2.0
        return (x - self.imagenet_mean) / self.imagenet_std

    def forward(self, pred, target):
        """Return the mean absolute difference between VGG features."""
        pred_features = self.features(self._to_vgg_input(pred))
        target_features = self.features(self._to_vgg_input(target))
        return F.l1_loss(pred_features, target_features)

def create_mask(image):
    """
    Build a binary mask for a single (C, H, W) image tensor.

    Starts from all ones (same shape, dtype and device as `image`) and zeroes
    out between 3 and 6 random rectangles across every channel. Rectangle
    heights are drawn from [H//8, H//4) and widths from [W//8, W//4), using
    NumPy's global RNG.
    """
    keep = torch.ones_like(image)
    _, height, width = image.shape

    for _hole in range(np.random.randint(3, 7)):
        # Draw order (height, width, row, column) is significant under a
        # fixed NumPy seed.
        hole_h = np.random.randint(height // 8, height // 4)
        hole_w = np.random.randint(width // 8, width // 4)
        row0 = np.random.randint(0, height - hole_h)
        col0 = np.random.randint(0, width - hole_w)

        row_span = slice(row0, row0 + hole_h)
        col_span = slice(col0, col0 + hole_w)
        keep[:, row_span, col_span] = 0

    return keep

def train(model, train_loader, val_loader, device, epochs=10):
    """
    Train an inpainting model with an L1 + 0.1 * perceptual loss.

    Each batch gets freshly generated random masks; the model receives the
    masked images and masks and is optimised against the unmasked images.
    Mixed precision (autocast + gradient scaling) is enabled only when
    `device` is a CUDA device. After every epoch the L1 loss on `val_loader`
    is computed and printed, if that loader is provided and non-empty.

    Args:
    - model: network called as ``model(masked_images, masks)``
    - train_loader: iterable of image batches used for optimisation
    - val_loader: iterable of image batches used for validation (may be None)
    - device: device the model and batches are moved to
    - epochs: number of passes over `train_loader`

    Returns:
    - dict with per-epoch lists under 'train_loss' and 'val_loss'
      ('val_loss' entries are None when no validation was run)

    Raises:
    - ValueError: if `train_loader` has no batches.
    """
    if len(train_loader) == 0:
        raise ValueError("train_loader is empty; nothing to train on")

    model.to(device)
    optimizer = optim.Adam(model.parameters(), lr=0.0002, betas=(0.5, 0.999))
    l1_loss = nn.L1Loss()
    perceptual_loss = PerceptualLoss().to(device)

    # AMP from torch.cuda.amp only applies to CUDA; disable it elsewhere so
    # the same code path runs on CPU without warnings.
    use_amp = torch.device(device).type == 'cuda'
    scaler = GradScaler(enabled=use_amp)

    has_val = val_loader is not None and len(val_loader) > 0
    history = {'train_loss': [], 'val_loss': []}

    for epoch in range(epochs):
        model.train()
        total_loss = 0

        for batch in train_loader:
            images = batch.to(device)
            masks = torch.stack([create_mask(img) for img in images]).to(device)
            masked_images = images * masks

            optimizer.zero_grad()

            with autocast(enabled=use_amp):
                output = model(masked_images, masks)
                pixel_loss = l1_loss(output, images)
                perc_loss = perceptual_loss(output, images)
                loss = pixel_loss + 0.1 * perc_loss

            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()

            total_loss += loss.item()

        history['train_loss'].append(total_loss / len(train_loader))
        print(f"Epoch {epoch+1}/{epochs}, Loss: {total_loss/len(train_loader):.4f}")

        # Validation: pixel-wise L1 only, no gradient tracking.
        avg_val_loss = None
        if has_val:
            model.eval()
            val_total = 0.0
            with torch.no_grad():
                for batch in val_loader:
                    images = batch.to(device)
                    masks = torch.stack([create_mask(img) for img in images]).to(device)
                    masked_images = images * masks
                    with autocast(enabled=use_amp):
                        output = model(masked_images, masks)
                        val_total += l1_loss(output, images).item()
            avg_val_loss = val_total / len(val_loader)
            print(f"Epoch {epoch+1}/{epochs}, Val L1: {avg_val_loss:.4f}")
        history['val_loss'].append(avg_val_loss)

    return history

def main():
    """
    Entry point for the compact training run.

    Picks CUDA when available, loads the CelebA-HQ dataset, caches the first
    500 preprocessed images, splits them 70/30 into train/validation subsets,
    trains `InpaintingNet`, and writes the weights to 'inpainting_model.pth'.
    """
    use_cuda = torch.cuda.is_available()
    device = torch.device('cuda' if use_cuda else 'cpu')
    print(f"Using device: {device}")

    # Fetch the dataset and keep only a small preprocessed subset in memory.
    hf_dataset = load_dataset("saitsharipov/CelebA-HQ")
    full_dataset = ImageInpaintingDataset(hf_dataset, max_images=500)

    # 70% train; validation receives the remainder.
    total = len(full_dataset)
    train_size = int(0.7 * total)
    split_sizes = [train_size, total - train_size]
    train_dataset, val_dataset = random_split(full_dataset, split_sizes)

    loader_kwargs = dict(batch_size=32, num_workers=2)
    train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
    val_loader = DataLoader(val_dataset, shuffle=False, **loader_kwargs)

    model = InpaintingNet()
    train(model, train_loader, val_loader, device)

    # Persist the trained weights.
    torch.save(model.state_dict(), 'inpainting_model.pth')

if __name__ == "__main__":
    main()

Using device: cuda
