<a href="https://colab.research.google.com/github/fjadidi2001/Image_Inpaint/blob/main/FastInpaint_Jan14.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Import Libraries

In [7]:
!pip install datasets


import os
import glob
import random
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from PIL import Image
from torch.utils.data import Dataset, random_split, DataLoader
from torchvision import transforms
import matplotlib.pyplot as plt
import numpy as np
import kagglehub
from google.colab import drive
import torchvision.datasets as datasets




In [4]:
# Fetch the Caltech-256 dataset from Kaggle through kagglehub and print the
# directory that dataset_download() returned.
# NOTE(review): setup_data() calls dataset_download() again for the same
# dataset, and `path` is not read by any other cell visible here.
path = kagglehub.dataset_download("jessicali9530/caltech256")
print("Path to dataset files:", path)

Downloading from https://www.kaggle.com/api/v1/datasets/download/jessicali9530/caltech256?dataset_version_number=2...


100%|██████████| 2.12G/2.12G [00:30<00:00, 75.4MB/s]

Extracting files...





Path to dataset files: /root/.cache/kagglehub/datasets/jessicali9530/caltech256/versions/2


In [5]:
# Run-level configuration shared by the checkpoint helpers.
# `model_name` is used as the checkpoint file stem: save_checkpoint() and
# load_checkpoint() both build the path f'{CHECKPOINTS_DIR}/{model_name}.pth'.
model_name = "caltech256-fastInpaint"
# Mount Google Drive (Colab only) so that CHECKPOINTS_DIR below points inside it.
drive.mount('/content/drive')
# Directory that holds the resumable training checkpoint.
CHECKPOINTS_DIR = '/content/drive/MyDrive/ckpts'

Mounted at /content/drive


In [9]:
def setup_data(img_size=256, batch_size=8, seed=42):
    """Build train/validation/test data loaders for Caltech256.

    The dataset is downloaded through kagglehub, read with ``ImageFolder``
    from the ``256_ObjectCategories`` sub-directory, and split 70% / 15% /
    remainder into train / validation / test subsets.

    Args:
        img_size: Images are resized to ``(img_size, img_size)``.
        batch_size: Batch size used by all three loaders.
        seed: Seed for the train/val/test split. With a fixed seed the same
            images land in the same subset on every run, so resuming training
            from a checkpoint never mixes former validation/test images into
            the training set. Pass ``None`` to draw the split from the global
            RNG instead.

    Returns:
        Tuple ``(train_loader, val_loader, test_loader)``. Only the training
        loader shuffles.
    """
    # Download dataset
    base_path = kagglehub.dataset_download("jessicali9530/caltech256")
    image_dir = os.path.join(base_path, '256_ObjectCategories')

    print(f"Base path: {base_path}")
    print(f"Image directory: {image_dir}")

    # Resize, convert to a tensor, then normalise each channel with
    # mean 0.5 / std 0.5 (the visualisation helpers undo this with * 0.5 + 0.5).
    transform = transforms.Compose([
        transforms.Resize((img_size, img_size)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
    ])

    # Create dataset
    dataset = datasets.ImageFolder(root=image_dir, transform=transform)
    print(f"\nLoaded dataset with {len(dataset)} images")

    # Split sizes: the test split takes whatever the integer rounding leaves over.
    train_size = int(0.7 * len(dataset))
    val_size = int(0.15 * len(dataset))
    test_size = len(dataset) - train_size - val_size

    # A dedicated generator keeps the split reproducible without touching the
    # global RNG used elsewhere (e.g. for mask sampling).
    split_kwargs = {}
    if seed is not None:
        split_kwargs['generator'] = torch.Generator().manual_seed(seed)

    train_dataset, val_dataset, test_dataset = random_split(
        dataset, [train_size, val_size, test_size], **split_kwargs
    )

    # Create data loaders
    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=2
    )

    val_loader = DataLoader(
        val_dataset,
        batch_size=batch_size,
        num_workers=2
    )

    test_loader = DataLoader(
        test_dataset,
        batch_size=batch_size,
        num_workers=2
    )

    print(f"Train size: {len(train_dataset)}")
    print(f"Val size: {len(val_dataset)}")
    print(f"Test size: {len(test_dataset)}")

    return train_loader, val_loader, test_loader


In [10]:
def visualize_batch(original_images, masked_images, outputs, epoch, batch_idx, save_dir='visualization'):
    """Save a 3-row comparison figure: original, masked and reconstructed images.

    At most the first four images of the batch are drawn, one per column.
    The figure is written to ``{save_dir}/epoch_{epoch}_batch_{batch_idx}.png``
    and then closed.

    Args:
        original_images: Batch tensor indexed as (N, C, H, W).
        masked_images: Same layout, the masked network input.
        outputs: Same layout, the network reconstruction.
        epoch: Epoch number, used only in the file name.
        batch_idx: Batch index, used only in the file name.
        save_dir: Output directory; created if missing.
    """
    os.makedirs(save_dir, exist_ok=True)

    num_cols = min(4, original_images.shape[0])
    if num_cols == 0:
        # Nothing to draw; plt.subplots() rejects a zero-column grid.
        return

    def _to_display(batch):
        # Detach / move to CPU, go channels-last, undo the 0.5/0.5
        # normalisation, and clip so imshow receives valid [0, 1] floats.
        array = batch.cpu().detach().numpy()
        array = np.transpose(array, (0, 2, 3, 1)) * 0.5 + 0.5
        return np.clip(array, 0.0, 1.0)

    rows = (
        ('Original', _to_display(original_images)),
        ('Masked', _to_display(masked_images)),
        ('Reconstructed', _to_display(outputs)),
    )

    # squeeze=False keeps `axes` two-dimensional even for a single column, so
    # axes[row, col] indexing also works for a batch with one image.
    fig, axes = plt.subplots(3, num_cols, figsize=(15, 10), squeeze=False)

    for row_idx, (title, images) in enumerate(rows):
        for col_idx in range(num_cols):
            axes[row_idx, col_idx].imshow(images[col_idx])
            axes[row_idx, col_idx].axis('off')
        # Only the first column carries the row title.
        axes[row_idx, 0].set_title(title)

    fig.tight_layout()
    fig.savefig(f'{save_dir}/epoch_{epoch}_batch_{batch_idx}.png')
    plt.close(fig)

In [11]:
def visualize_sample_batch(data_loader, save_path='sample_batch.png'):
    """Save a 2x4 grid showing up to eight images from one batch.

    Args:
        data_loader: Loader whose batches are indexable, with the image
            tensor at position 0 (ImageFolder batches are (images, labels)).
        save_path: File the figure is written to.
    """
    # Get a batch of images
    images = next(iter(data_loader))[0]  # [0] because ImageFolder returns (images, labels)

    fig, axes = plt.subplots(2, 4, figsize=(15, 8))
    axes = axes.ravel()

    # zip() stops at the shorter sequence, so at most eight images are drawn
    # and a batch with fewer than eight images does not raise.
    for ax, img in zip(axes, images):
        # detach()/cpu() make the conversion work for any tensor, not just
        # CPU tensors without gradients; then undo the 0.5/0.5 normalisation.
        img_np = img.detach().cpu().numpy().transpose(1, 2, 0) * 0.5 + 0.5
        ax.imshow(np.clip(img_np, 0.0, 1.0))

    # Hide ticks and frames on every panel, including the ones left empty
    # when the batch holds fewer than eight images.
    for ax in axes:
        ax.axis('off')

    fig.tight_layout()
    fig.savefig(save_path)
    plt.close(fig)
    print(f"Sample batch visualization saved to {save_path}")

In [12]:
def save_checkpoint(model, optimizer, epoch):
    """Write a resumable training checkpoint to ``CHECKPOINTS_DIR``.

    The checkpoint stores the epoch number plus the model and optimizer
    state dicts under the keys ``'epoch'``, ``'model_state_dict'`` and
    ``'optimizer_state_dict'``, in ``{CHECKPOINTS_DIR}/{model_name}.pth``.

    Args:
        model: Module whose ``state_dict()`` is saved.
        optimizer: Optimizer whose ``state_dict()`` is saved.
        epoch: Index of the epoch that has just finished.
    """
    os.makedirs(CHECKPOINTS_DIR, exist_ok=True)
    checkpoint_path = f'{CHECKPOINTS_DIR}/{model_name}.pth'
    # Write to a sibling temp file first and rename it into place afterwards:
    # if the save is interrupted, the previous checkpoint stays intact instead
    # of being left half-overwritten.
    tmp_path = f'{checkpoint_path}.tmp'
    try:
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
        }, tmp_path)
        os.replace(tmp_path, checkpoint_path)
    finally:
        # Only still present when torch.save or the rename failed.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
    print(f"Checkpoint saved for {model_name} at epoch {epoch}")

In [13]:
def load_checkpoint(model, optimizer):
    """Restore model and optimizer state from the checkpoint, if one exists.

    Reads ``{CHECKPOINTS_DIR}/{model_name}.pth`` as written by
    ``save_checkpoint``.

    Args:
        model: Module that receives ``'model_state_dict'``.
        optimizer: Optimizer that receives ``'optimizer_state_dict'``.

    Returns:
        The epoch index to resume from: the stored epoch plus one, or ``0``
        when no checkpoint file exists.
    """
    ckpt_path = f'{CHECKPOINTS_DIR}/{model_name}.pth'
    if not os.path.exists(ckpt_path):
        print(f"No checkpoint found for {model_name}, starting from epoch 0")
        return 0

    # Map stored tensors onto the device the model currently lives on, so a
    # checkpoint written on a GPU runtime can still be loaded on a CPU-only one.
    first_param = next(model.parameters(), None)
    map_location = first_param.device if first_param is not None else 'cpu'

    checkpoint = torch.load(ckpt_path, map_location=map_location)
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    # The stored epoch is the last completed one, so continue with the next.
    start_epoch = checkpoint['epoch'] + 1
    print(f"Checkpoint loaded for {model_name}, resuming from epoch {start_epoch}")
    return start_epoch

In [20]:
class ResidualBlock(nn.Module):
    """Two 3x3 conv + instance-norm stages wrapped in a skip connection.

    Both convolutions keep the channel count, and padding=1 keeps the
    spatial size, so the output has the same shape as the input.
    """

    def __init__(self, channels):
        super().__init__()
        conv_kwargs = dict(kernel_size=3, padding=1)
        self.conv1 = nn.Conv2d(channels, channels, **conv_kwargs)
        self.norm1 = nn.InstanceNorm2d(channels)
        self.conv2 = nn.Conv2d(channels, channels, **conv_kwargs)
        self.norm2 = nn.InstanceNorm2d(channels)

    def forward(self, x):
        """Return ``relu(x + norm2(conv2(relu(norm1(conv1(x))))))``."""
        out = self.conv1(x)
        out = self.norm1(out)
        out = F.relu(out)
        out = self.conv2(out)
        out = self.norm2(out)
        # Skip connection followed by the final activation.
        return F.relu(out + x)

In [14]:
class InpaintingNet(nn.Module):
    """Encoder / residual-bottleneck / decoder network for inpainting.

    ``forward`` concatenates the image and the mask along the channel axis,
    so together they must provide the 6 channels the first convolution
    expects. The output has 3 channels and passes through ``Tanh``.
    """

    def __init__(self):
        super().__init__()

        def norm_relu(channels):
            # Normalisation + activation pair that follows every hidden conv.
            return [nn.InstanceNorm2d(channels), nn.ReLU(inplace=True)]

        # Encoder: one 7x7 conv at full resolution, then two stride-2 convs.
        encoder_layers = [
            nn.Conv2d(6, 64, kernel_size=7, padding=3),
            *norm_relu(64),
            nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1),
            *norm_relu(128),
            nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1),
            *norm_relu(256),
        ]
        self.encoder = nn.Sequential(*encoder_layers)

        # Middle blocks: six residual blocks at 256 channels.
        residual_blocks = [ResidualBlock(256) for _ in range(6)]
        self.middle = nn.Sequential(*residual_blocks)

        # Decoder: two stride-2 transposed convs, then a 7x7 conv to 3 channels.
        decoder_layers = [
            nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1),
            *norm_relu(128),
            nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1),
            *norm_relu(64),
            nn.Conv2d(64, 3, kernel_size=7, padding=3),
            nn.Tanh(),
        ]
        self.decoder = nn.Sequential(*decoder_layers)

    def forward(self, x, mask):
        """Run the network on ``x`` stacked with ``mask`` along dim 1."""
        stacked = torch.cat((x, mask), dim=1)
        features = self.middle(self.encoder(stacked))
        return self.decoder(features)


In [15]:
def create_random_mask(image):
    """Create one random rectangular hole mask per image in the batch.

    Args:
        image: Batch tensor indexed as (batch, channels, height, width).

    Returns:
        Tensor with the same shape, dtype and device as ``image``: 1 where
        pixels are kept and 0 inside the hole. The hole covers all channels.
        Its height is drawn from ``[height // 4, height // 2)`` and its width
        from ``[width // 4, width // 2)``; its position is uniform over every
        placement that fits inside the image, including ones touching the
        bottom and right borders.
    """
    batch_size, _, height, width = image.shape
    mask = torch.ones_like(image)

    def _sample_length(size):
        low = size // 4
        # torch.randint needs low < high; this only matters for very small
        # sizes where size // 4 == size // 2.
        high = max(size // 2, low + 1)
        return min(torch.randint(low, high, (1,)).item(), size)

    def _sample_start(size, length):
        # randint's upper bound is exclusive, so "+ 1" makes the last valid
        # offset (size - length) reachable; without it the hole could never
        # touch the bottom row / right column.
        return torch.randint(0, size - length + 1, (1,)).item()

    for i in range(batch_size):
        h = _sample_length(height)
        w = _sample_length(width)

        top = _sample_start(height, h)
        left = _sample_start(width, w)

        mask[i, :, top:top+h, left:left+w] = 0

    return mask

In [16]:
def plot_losses(train_losses, val_losses, save_dir='visualization'):
    """Draw both loss curves on one axes and save them as ``losses.png``.

    Args:
        train_losses: Sequence of training losses, one value per epoch.
        val_losses: Sequence of validation losses, one value per epoch.
        save_dir: Output directory; created if missing.
    """
    os.makedirs(save_dir, exist_ok=True)

    fig, ax = plt.subplots(figsize=(10, 5))
    curves = (
        (train_losses, 'Training Loss'),
        (val_losses, 'Validation Loss'),
    )
    for values, label in curves:
        ax.plot(values, label=label)

    ax.set_xlabel('Epoch')
    ax.set_ylabel('Loss')
    ax.set_title('Training and Validation Losses')
    ax.legend()

    fig.savefig(f'{save_dir}/losses.png')
    plt.close(fig)

In [24]:
def _inpaint_step(model, criterion, images, device):
    """Mask one batch, run the model on it and compute the loss.

    Returns ``(images, masked_images, outputs, loss)`` with ``images`` already
    moved to ``device``.
    """
    images = images.to(device)
    masks = create_random_mask(images).to(device)
    masked_images = images * masks

    outputs = model(masked_images, masks)
    # The loss compares the full reconstruction with the full original image.
    loss = criterion(outputs, images)
    return images, masked_images, outputs, loss


def train_model(model, train_loader, val_loader, num_epochs=30, device='cuda'):
    """Train the inpainting model with an L1 reconstruction loss.

    Uses Adam (lr=0.001) and ``ReduceLROnPlateau`` driven by the average
    validation loss. Training resumes from the checkpoint returned by
    ``load_checkpoint`` if there is one. After each epoch the function
    saves a checkpoint, writes ``best_inpainting.pth`` whenever the
    validation loss improves, and redraws the loss plot.

    Args:
        model: Network called as ``model(masked_images, masks)``.
        train_loader: Loader yielding ``(images, labels)`` batches for training.
        val_loader: Loader yielding ``(images, labels)`` batches for validation.
        num_epochs: Total number of epochs; epochs before the resumed
            ``start_epoch`` are skipped.
        device: Device the model and batches are moved to.

    Returns:
        Tuple ``(train_losses, val_losses)`` with one average loss per epoch
        run in this call (epochs completed before a resume are not included).
    """
    model = model.to(device)
    optimizer = optim.Adam(model.parameters(), lr=0.001)
    criterion = nn.L1Loss()
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', patience=3)

    # NOTE(review): best_val_loss starts at infinity on every call, so after a
    # resume the first epoch always overwrites 'best_inpainting.pth'.
    best_val_loss = float('inf')
    train_losses = []
    val_losses = []

    # Load checkpoint if exists
    start_epoch = load_checkpoint(model, optimizer)

    for epoch in range(start_epoch, num_epochs):
        # Training
        model.train()
        train_loss = 0.0
        for batch_idx, (images, _) in enumerate(train_loader):
            optimizer.zero_grad()
            images, masked_images, outputs, loss = _inpaint_step(
                model, criterion, images, device
            )

            loss.backward()
            optimizer.step()

            train_loss += loss.item()

            # Visualize every 100 batches
            if batch_idx % 100 == 0:
                visualize_batch(images, masked_images, outputs, epoch, batch_idx)

        # max(..., 1) avoids a ZeroDivisionError when a loader yields no batches.
        avg_train_loss = train_loss / max(len(train_loader), 1)
        train_losses.append(avg_train_loss)

        # Validation
        model.eval()
        val_loss = 0.0
        with torch.no_grad():
            for images, _ in val_loader:
                _, _, _, loss = _inpaint_step(model, criterion, images, device)
                val_loss += loss.item()

        avg_val_loss = val_loss / max(len(val_loader), 1)
        val_losses.append(avg_val_loss)

        # Update learning rate
        scheduler.step(avg_val_loss)

        print(f'Epoch [{epoch+1}/{num_epochs}]')
        print(f'Train Loss: {avg_train_loss:.4f}, Val Loss: {avg_val_loss:.4f}')

        # Save checkpoint and best model
        save_checkpoint(model, optimizer, epoch)
        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            torch.save(model.state_dict(), 'best_inpainting.pth')

        # Plot losses
        plot_losses(train_losses, val_losses)

    return train_losses, val_losses



In [25]:
def main():
    """Entry point: pick a device, build the loaders, preview a batch, train.

    Any exception raised along the way is printed and then re-raised.
    """
    # Prefer the GPU when one is available.
    if torch.cuda.is_available():
        device = torch.device('cuda')
    else:
        device = torch.device('cpu')
    print(f"Using device: {device}")

    try:
        # The test loader is built by setup_data but not used here.
        train_loader, val_loader, _test_loader = setup_data(img_size=256, batch_size=8)

        # Save a preview grid of one training batch.
        visualize_sample_batch(train_loader)

        # Build a fresh network and train it.
        network = InpaintingNet()
        train_model(network, train_loader, val_loader, num_epochs=10, device=device)

    except Exception as e:
        print(f"Error in main: {str(e)}")
        raise

In [26]:
# Run the full pipeline (data setup, sample preview, training) via main().
# The guard keeps main() from running if this code is imported as a module.
if __name__ == "__main__":
    main()

Using device: cuda
Base path: /root/.cache/kagglehub/datasets/jessicali9530/caltech256/versions/2
Image directory: /root/.cache/kagglehub/datasets/jessicali9530/caltech256/versions/2/256_ObjectCategories

Loaded dataset with 30607 images
Train size: 21424
Val size: 4591
Test size: 4592
Sample batch visualization saved to sample_batch.png
No checkpoint found for caltech256-fastInpaint, starting from epoch 0
Error in main: list indices must be integers or slices, not str


TypeError: list indices must be integers or slices, not str