In [None]:
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
from torchvision.utils import save_image
from torch.utils.data import DataLoader, Dataset, random_split
from torch.autograd import Variable
from PIL import Image
import kagglehub
import numpy as np
from torch.nn.utils import spectral_norm

# Mount Google Drive so checkpoints and outputs can be written under /content/drive.
# NOTE(review): google.colab is imported here, so this cell presumably only runs inside Colab.
from google.colab import drive
drive.mount('/content/drive')

# Download dataset using kagglehub
def download_dataset():
    """Download the CelebA-HQ 256x256 dataset through kagglehub.

    The location reported by kagglehub is stored in the module-level ``path``
    variable (other cells read it) and is also returned, so callers do not
    have to depend on the global.

    Returns:
        The value returned by ``kagglehub.dataset_download`` for the dataset.
    """
    global path  # Kept as a global because later cells read `path` directly
    path = kagglehub.dataset_download("badasstechie/celebahq-resized-256x256")
    print("Path to dataset files:", path)
    return path

# Call the function to download the dataset
download_dataset()

# Check that the dataset directory exists *before* listing it: os.listdir()
# on a missing directory raises its own error, so a check placed after the
# listing could never report this clearer message.
if not os.path.exists(path):
    raise FileNotFoundError(f"Dataset directory not found at: {path}")

# Debug: List files in the dataset directory
print("Files in dataset directory:", os.listdir(path))

# Check for subdirectories (e.g., 'images', 'train', 'test').
# Sorted so that the "first" subdirectory does not depend on the arbitrary
# order in which os.listdir() returns entries.
subdirs = sorted(d for d in os.listdir(path) if os.path.isdir(os.path.join(path, d)))
if subdirs:
    print("Subdirectories found:", subdirs)
    # Update the path to point to the subdirectory containing images
    path = os.path.join(path, subdirs[0])  # Use the first subdirectory (e.g., 'images')
    print("Updated dataset path:", path)

# List image files in the dataset directory
files = [f for f in os.listdir(path) if f.endswith(('png', 'jpg', 'jpeg'))]
print(f"Number of images found: {len(files)}")  # Debug: Print the number of images

# Ensure the dataset is not empty
if not files:
    raise ValueError("No images found in the dataset directory. Please check the path.")

# Create directories for saving images and checkpoints
os.makedirs("images", exist_ok=True)
os.makedirs("/content/drive/MyDrive/gan_unet_checkpoints", exist_ok=True)

Mounted at /content/drive
Path to dataset files: /kaggle/input/celebahq-resized-256x256
Files in dataset directory: ['celeba_hq_256']
Subdirectories found: ['celeba_hq_256']
Updated dataset path: /kaggle/input/celebahq-resized-256x256/celeba_hq_256
Number of images found: 30000


In [None]:
# U-Net Components
class DoubleConv(nn.Module):
    """Two stacked 3x3 convolutions, each followed by batch norm and ReLU.

    ``mid_channels`` is the channel width between the two convolutions; a
    falsy value (the default ``None``) makes it equal to ``out_channels``.
    """

    def __init__(self, in_channels, out_channels, mid_channels=None):
        super().__init__()
        hidden = mid_channels or out_channels
        first_stage = [
            nn.Conv2d(in_channels, hidden, kernel_size=3, padding=1),
            nn.BatchNorm2d(hidden),
            nn.ReLU(inplace=True),
        ]
        second_stage = [
            nn.Conv2d(hidden, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        ]
        # Attribute name and layer order are kept so state_dict keys are unchanged.
        self.double_conv = nn.Sequential(*first_stage, *second_stage)

    def forward(self, x):
        features = self.double_conv(x)
        return features

class Down(nn.Module):
    """Encoder stage: 2x2 max pooling followed by a DoubleConv."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        pool = nn.MaxPool2d(2)
        conv = DoubleConv(in_channels, out_channels)
        self.maxpool_conv = nn.Sequential(pool, conv)

    def forward(self, x):
        downsampled = self.maxpool_conv(x)
        return downsampled

class Up(nn.Module):
    """Decoder stage: upsample, pad to the skip tensor's size, concatenate, DoubleConv.

    With ``bilinear=True`` the upsampling is a parameter-free bilinear
    ``nn.Upsample``; otherwise a stride-2 transposed convolution that halves
    the channel count is used.
    """

    def __init__(self, in_channels, out_channels, bilinear=True):
        super().__init__()
        if not bilinear:
            self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2)
            self.conv = DoubleConv(in_channels, out_channels)
        else:
            self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
            self.conv = DoubleConv(in_channels, out_channels, in_channels // 2)

    def forward(self, x1, x2):
        # x1: tensor to upsample; x2: skip connection it is merged with.
        upsampled = self.up(x1)
        pad_h = x2.size(2) - upsampled.size(2)
        pad_w = x2.size(3) - upsampled.size(3)
        left, top = pad_w // 2, pad_h // 2
        # F.pad takes (left, right, top, bottom) for the last two dimensions.
        upsampled = F.pad(upsampled, [left, pad_w - left, top, pad_h - top])
        merged = torch.cat([x2, upsampled], dim=1)
        return self.conv(merged)

class OutConv(nn.Module):
    """Final 1x1 convolution mapping feature channels to output channels."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)

    def forward(self, x):
        projected = self.conv(x)
        return projected

class UNet(nn.Module):
    """U-Net with four downsampling and four upsampling stages plus skip connections.

    Args:
        n_channels: Number of input channels.
        n_classes: Number of output channels produced by the final 1x1 conv.
        bilinear: Passed to every ``Up`` stage; when true the channel widths
            of the bottleneck and decoder stages are halved (``factor = 2``).
    """

    def __init__(self, n_channels, n_classes, bilinear=True):
        super().__init__()
        factor = 2 if bilinear else 1

        # Encoder (created in the same order as before so parameter
        # initialisation and state_dict ordering are unchanged).
        self.inc = DoubleConv(n_channels, 64)
        self.down1 = Down(64, 128)
        self.down2 = Down(128, 256)
        self.down3 = Down(256, 512)
        self.down4 = Down(512, 1024 // factor)

        # Decoder
        self.up1 = Up(1024, 512 // factor, bilinear)
        self.up2 = Up(512, 256 // factor, bilinear)
        self.up3 = Up(256, 128 // factor, bilinear)
        self.up4 = Up(128, 64, bilinear)
        self.outc = OutConv(64, n_classes)

    def forward(self, x):
        # Run the encoder, keeping every intermediate activation as a skip.
        skips = [self.inc(x)]
        for down in (self.down1, self.down2, self.down3, self.down4):
            skips.append(down(skips[-1]))

        # The deepest activation seeds the decoder; remaining skips are
        # consumed from deepest to shallowest.
        out = skips.pop()
        for up in (self.up1, self.up2, self.up3, self.up4):
            out = up(out, skips.pop())

        return self.outc(out)

In [None]:
# Generator using U-Net
class UNetGenerator(nn.Module):
    """Generator that delegates directly to a ``UNet`` stored as ``self.unet``."""

    def __init__(self, n_channels, n_classes):
        super().__init__()
        self.unet = UNet(n_channels, n_classes)

    def forward(self, x):
        generated = self.unet(x)
        return generated

# Discriminator with Spectral Normalization
class Discriminator(nn.Module):
    """Convolutional discriminator with spectral normalisation on every conv.

    Four stride-2 4x4 convolutions (64, 128, 256, 512 channels), each followed
    by LeakyReLU(0.2), then a stride-1 4x4 convolution down to one channel.
    The resulting score map is averaged over its spatial dimensions and passed
    through a sigmoid, giving one value per image.

    Args:
        img_shape: Sequence whose first element is the number of input channels.
    """

    def __init__(self, img_shape):
        super().__init__()
        widths = [img_shape[0], 64, 128, 256, 512]
        layers = []
        for c_in, c_out in zip(widths[:-1], widths[1:]):
            layers.append(spectral_norm(nn.Conv2d(c_in, c_out, kernel_size=4, stride=2, padding=1)))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
        layers.append(spectral_norm(nn.Conv2d(512, 1, kernel_size=4, stride=1, padding=0)))
        self.model = nn.Sequential(*layers)
        self.sigmoid = nn.Sigmoid()

    def forward(self, img):
        score_map = self.model(img)
        pooled = torch.mean(score_map, dim=(2, 3))  # Global average pooling
        return self.sigmoid(pooled)

In [None]:
# Gradient Penalty
def gradient_penalty(discriminator, real_imgs, fake_imgs):
    """Compute a gradient penalty on random interpolations of real and fake images.

    Each sample is mixed as ``alpha * real + (1 - alpha) * fake`` with one
    uniform ``alpha`` per sample. The penalty is the mean over the batch of
    ``(||grad_x D(x)||_2 - 1) ** 2`` evaluated at the mixed samples.

    Args:
        discriminator: Callable mapping an image batch to a score tensor.
        real_imgs: Batch of real images; its device is used for all new tensors.
        fake_imgs: Batch of generated images, same shape as ``real_imgs``.

    Returns:
        Scalar tensor holding the penalty (differentiable w.r.t. the
        discriminator parameters because ``create_graph=True``).
    """
    # Follow the input's device instead of hard-coding .cuda(), so the
    # function also works on CPU-only runtimes.
    alpha = torch.rand(real_imgs.size(0), 1, 1, 1, device=real_imgs.device)
    interpolated = (alpha * real_imgs + (1 - alpha) * fake_imgs).requires_grad_(True)
    validity = discriminator(interpolated)
    gradients = torch.autograd.grad(
        outputs=validity,
        inputs=interpolated,
        grad_outputs=torch.ones_like(validity),
        create_graph=True,
        retain_graph=True,
    )[0]
    # reshape (not view): the returned gradient is not guaranteed to be contiguous.
    gradients = gradients.reshape(gradients.size(0), -1)
    penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean()
    return penalty

# Loss functions: BCE for the adversarial term, L1 for pixel-wise reconstruction
adversarial_loss = nn.BCELoss()
pixelwise_loss = nn.L1Loss()

# Initialize models (RGB in, RGB out; discriminator is told the image shape)
generator = UNetGenerator(3, 3)
discriminator = Discriminator((3, 256, 256))

# Move everything to the GPU when one is available. This happens before the
# optimizers are built, as in the original ordering.
if torch.cuda.is_available():
    generator = generator.cuda()
    discriminator = discriminator.cuda()
    adversarial_loss = adversarial_loss.cuda()
    pixelwise_loss = pixelwise_loss.cuda()

# Optimizers: the generator uses a lower learning rate than the discriminator
optimizer_G = torch.optim.Adam(generator.parameters(), lr=1e-4, betas=(0.5, 0.999))
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=4e-4, betas=(0.5, 0.999))

# Transform: resize to 256x256, convert to tensor, normalize with mean 0.5 / std 0.5
transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5], std=[0.5]),
])

In [None]:
# Custom dataset class with masking
class CustomDataset(Dataset):
    """Image-folder dataset yielding ``(masked_image, original_image)`` pairs.

    Every item is an image from ``root`` in which a randomly placed square
    region has been set to zero, together with the unmasked image.

    Args:
        root: Directory whose files ending in png/jpg/jpeg are used.
        transform: Callable applied to each PIL image. It must return a
            3-D ``(C, H, W)`` tensor, because the mask is built from it.
        gap_size: Side length (in pixels) of the zeroed square. Defaults to 64.

    Raises:
        ValueError: If ``gap_size`` is not positive.
    """

    def __init__(self, root, transform=None, gap_size=64):
        if gap_size <= 0:
            raise ValueError(f"gap_size must be positive, got {gap_size}")
        # Sorted so the index -> file mapping does not depend on the arbitrary
        # order returned by os.listdir().
        self.files = sorted(
            os.path.join(root, file)
            for file in os.listdir(root)
            if file.endswith(('png', 'jpg', 'jpeg'))
        )
        self.transform = transform
        self.gap_size = gap_size

    def __len__(self):
        return len(self.files)

    def _mask_image(self, img):
        """Return a copy of ``img`` with one random ``gap_size`` square zeroed out."""
        if not torch.is_tensor(img) or img.dim() != 3:
            raise TypeError(
                "CustomDataset needs a transform that returns a (C, H, W) tensor "
                f"(e.g. one ending in ToTensor()); got {type(img).__name__}"
            )
        # Use the actual image size rather than assuming 256x256.
        height, width = img.shape[-2:]
        gap = self.gap_size
        if gap > height or gap > width:
            raise ValueError(f"gap_size {gap} does not fit in a {height}x{width} image")

        mask = torch.ones_like(img)
        # randint's upper bound is exclusive; +1 lets the gap reach the last row/column.
        top = np.random.randint(0, height - gap + 1)
        left = np.random.randint(0, width - gap + 1)
        mask[:, top:top + gap, left:left + gap] = 0  # Create a square gap
        return img * mask  # Apply the mask to the image

    def __getitem__(self, index):
        # convert() returns a fully loaded copy, so the file can be closed right away.
        with Image.open(self.files[index]) as raw:
            img = raw.convert("RGB")
        if self.transform:
            img = self.transform(img)

        masked_img = self._mask_image(img)
        return masked_img, img  # Return masked image and original image

# Create the dataset
dataset = CustomDataset(path, transform=transform)

# Split the dataset into training and testing sets
train_size = 24000  # Number of images used for training

# Ensure train_size is valid before deriving the test size from it
if train_size <= 0 or train_size > len(dataset):
    raise ValueError(f"Invalid train_size: {train_size}. Dataset has only {len(dataset)} images.")
test_size = len(dataset) - train_size  # Remaining images for testing

# Perform the split with a fixed seed. An unseeded random_split draws a
# different partition on every run, so images held out as "test" in one
# session could be trained on in the next (e.g. when resuming from a checkpoint).
# The split is repeatable as long as the dataset's file ordering is stable.
split_generator = torch.Generator().manual_seed(42)
train_dataset, test_dataset = random_split(
    dataset, [train_size, test_size], generator=split_generator
)

# Create DataLoaders
train_dataloader = DataLoader(train_dataset, batch_size=8, shuffle=True)
test_dataloader = DataLoader(test_dataset, batch_size=8, shuffle=False)

print(f"Training dataset size: {len(train_dataset)}")
print(f"Testing dataset size: {len(test_dataset)}")

Training dataset size: 24000
Testing dataset size: 6000


In [None]:
# Training loop
def train(epochs, start_epoch=0, patience=5, log_interval=1):
    """Train the generator and discriminator on ``train_dataloader``.

    Per batch, the generator is updated first (adversarial loss plus
    10x L1 reconstruction loss), then the discriminator (mean of real/fake
    BCE losses plus the gradient penalty). Sample images are written to
    ``images/`` every 100 batches and checkpoints are saved after each epoch.

    Args:
        epochs: Exclusive upper bound of the epoch counter.
        start_epoch: First epoch index. Note that no checkpoint is loaded
            here; it only affects the epoch numbering.
        patience: Stop once the average generator loss has not improved for
            this many consecutive epochs.
        log_interval: Print the batch losses every ``log_interval`` batches.
            The default of 1 prints every batch; a falsy value disables the
            per-batch print.
    """
    best_loss = float('inf')
    epochs_without_improvement = 0

    # Debug: Print dataset size
    print(f"Training dataset size: {len(train_dataset)}")
    print(f"Testing dataset size: {len(test_dataset)}")

    # Debug: Print a batch from the dataloader
    for batch in train_dataloader:
        print("Batch shape:", batch[0].shape)  # Print the shape of the masked images
        break  # Only print the first batch

    # Verify device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # Move models to the correct device
    generator.to(device)
    discriminator.to(device)

    # Explicitly enter training mode: test_model_with_random_image() leaves the
    # generator in eval mode, which would otherwise freeze its BatchNorm
    # statistics if training is (re)started afterwards in the same kernel.
    generator.train()
    discriminator.train()

    # Guard the per-epoch average against an empty dataloader.
    num_batches = max(len(train_dataloader), 1)

    for epoch in range(start_epoch, epochs):
        epoch_g_loss = 0.0
        epoch_d_loss = 0.0

        for i, (masked_imgs, real_imgs) in enumerate(train_dataloader):
            # Move data to the correct device
            masked_imgs = masked_imgs.to(device)
            real_imgs = real_imgs.to(device)

            # Prepare ground truths with label smoothing. Plain tensors are
            # used; the Variable wrapper is deprecated and not needed.
            batch_size = real_imgs.size(0)
            valid = torch.full((batch_size, 1), 0.9, device=device)  # Smooth real labels
            fake = torch.zeros(batch_size, 1, device=device)         # Fake labels remain 0

            # Train generator
            optimizer_G.zero_grad()
            gen_imgs = generator(masked_imgs)
            g_adv_loss = adversarial_loss(discriminator(gen_imgs), valid)  # Adversarial loss
            g_pixel_loss = pixelwise_loss(gen_imgs, real_imgs)  # Pixel-wise loss
            g_loss = g_adv_loss + 10 * g_pixel_loss  # Adjusted loss weighting
            g_loss.backward()
            optimizer_G.step()

            # Train discriminator. zero_grad() also clears the gradients the
            # generator step accumulated in the discriminator's parameters.
            optimizer_D.zero_grad()
            fake_imgs = gen_imgs.detach()  # Do not backpropagate into the generator
            real_loss = adversarial_loss(discriminator(real_imgs), valid)
            fake_loss = adversarial_loss(discriminator(fake_imgs), fake)
            gp = gradient_penalty(discriminator, real_imgs, fake_imgs)  # Gradient penalty
            d_loss = (real_loss + fake_loss) / 2 + 1 * gp  # Reduced gradient penalty weight
            d_loss.backward()
            optimizer_D.step()

            epoch_g_loss += g_loss.item()
            epoch_d_loss += d_loss.item()

            # Print loss values (every batch by default)
            if log_interval and i % log_interval == 0:
                print(f"[Epoch {epoch}/{epochs}] [Batch {i}/{len(train_dataloader)}] [D loss: {d_loss.item()}] [G loss: {g_loss.item()}] [GP: {gp.item()}]")

            if i % 100 == 0:
                save_image(fake_imgs[:25], f"images/{epoch}_{i}.png", nrow=5, normalize=True)

        # Calculate average losses for the epoch
        epoch_g_loss /= num_batches
        epoch_d_loss /= num_batches
        print(f"[Epoch {epoch}/{epochs}] [Avg D loss: {epoch_d_loss}] [Avg G loss: {epoch_g_loss}]")

        # Save checkpoints at the end of each epoch
        save_checkpoints(epoch)

        # Early stopping logic (based on the average generator training loss)
        if epoch_g_loss < best_loss:
            best_loss = epoch_g_loss
            epochs_without_improvement = 0
        else:
            epochs_without_improvement += 1

        if epochs_without_improvement >= patience:
            print(f"Early stopping at epoch {epoch} (no improvement for {patience} epochs).")
            break

# Save checkpoints
def save_checkpoints(epoch, checkpoint_dir="/content/drive/MyDrive/gan_unet_checkpoints"):
    """Save the generator and discriminator state dicts for ``epoch``.

    Files are written as ``generator_epoch_{epoch}.pth`` and
    ``discriminator_epoch_{epoch}.pth`` inside ``checkpoint_dir``.

    Args:
        epoch: Epoch index embedded in the file names.
        checkpoint_dir: Target directory; created if it does not exist.
            Defaults to the Google Drive folder used by this notebook.
    """
    # Do not rely on another cell having created the directory.
    os.makedirs(checkpoint_dir, exist_ok=True)
    torch.save(generator.state_dict(), os.path.join(checkpoint_dir, f"generator_epoch_{epoch}.pth"))
    torch.save(discriminator.state_dict(), os.path.join(checkpoint_dir, f"discriminator_epoch_{epoch}.pth"))
    print(f"Checkpoints saved for epoch {epoch}.")

In [None]:
# Test the model with a custom image
def test_model_with_random_image(image_path, last_saved_epoch=49, gap_size=96):
    """Inpaint a centred square hole in one image with a saved generator.

    Loads ``generator_epoch_{last_saved_epoch}.pth`` from the Drive checkpoint
    folder into the global ``generator``, masks a centred ``gap_size`` square
    of the (transformed) image, runs the generator on it, and writes
    ``mask.png``, ``masked_image.png`` and ``filled_image.png`` to the
    current working directory.

    Side effect: the global ``generator`` is left in eval mode with the
    loaded weights.

    Args:
        image_path: Path of the image to test with.
        last_saved_epoch: Epoch index of the generator checkpoint to load.
        gap_size: Side length in pixels of the centred square mask.

    Raises:
        FileNotFoundError: If the image or the checkpoint does not exist.
        ValueError: If ``gap_size`` is not positive or exceeds the image size.
    """
    # Verify image path
    print(f"Testing image path: {image_path}")
    if not os.path.exists(image_path):
        raise FileNotFoundError(f"Image not found at: {image_path}")

    # Ensure consistent device usage
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Verify checkpoint loading
    checkpoint_path = f"/content/drive/MyDrive/gan_unet_checkpoints/generator_epoch_{last_saved_epoch}.pth"
    print(f"Loading checkpoint from: {checkpoint_path}")
    if not os.path.exists(checkpoint_path):
        raise FileNotFoundError(f"Checkpoint not found at: {checkpoint_path}")
    # map_location lets a checkpoint saved from a GPU run load on a CPU-only runtime.
    generator.load_state_dict(torch.load(checkpoint_path, map_location=device))
    generator.to(device)

    # Set the generator to evaluation mode
    generator.eval()

    # Load the custom image
    with Image.open(image_path) as raw:
        img = raw.convert("RGB")
    img = transform(img).unsqueeze(0).to(device)  # Add batch dimension and move to device
    print(f"Image tensor shape: {img.shape}")  # Debug: Print image tensor shape

    # Create a centred square mask
    height, width = img.shape[2], img.shape[3]
    if gap_size <= 0 or gap_size > height or gap_size > width:
        raise ValueError(f"gap_size {gap_size} does not fit in a {height}x{width} image")
    mask = torch.ones_like(img)  # Create a mask of the same shape as the image
    top = (height - gap_size) // 2   # Center the mask vertically
    left = (width - gap_size) // 2   # Center the mask horizontally
    mask[:, :, top:top + gap_size, left:left + gap_size] = 0
    print(f"Mask tensor shape: {mask.shape}")  # Debug: Print mask tensor shape

    # Save the mask for visualization
    save_image(mask, "mask.png", normalize=True)
    print("Mask saved as 'mask.png'")

    # Apply the mask to the image
    masked_img = img * mask  # Element-wise multiplication
    print(f"Masked image tensor shape: {masked_img.shape}")  # Debug: Print masked image tensor shape

    # Generate the filled image
    with torch.no_grad():
        filled_img = generator(masked_img)

    # Save and visualize results
    save_image(masked_img, "masked_image.png", normalize=True)
    save_image(filled_img, "filled_image.png", normalize=True)
    print("Masked and filled images saved as 'masked_image.png' and 'filled_image.png'")

    # Verify output directory
    print(f"Current working directory: {os.getcwd()}")
    if not os.path.exists("masked_image.png") or not os.path.exists("filled_image.png"):
        print("Error: Images were not saved. Check the output directory.")

In [None]:
# Example usage
if __name__ == "__main__":
    # Training is currently disabled; uncomment the call below to train for 50 epochs.
    # train(epochs=50, start_epoch=0)

    # Test the model with a custom image (loads a previously saved generator checkpoint)
    custom_image_path = "/content/drive/MyDrive/Testing/0.jpg"  # Update this path
    test_model_with_random_image(custom_image_path)

Testing image path: /content/drive/MyDrive/Testing/0.jpg
Loading checkpoint from: /content/drive/MyDrive/gan_unet_checkpoints/generator_epoch_49.pth
Image tensor shape: torch.Size([1, 3, 256, 256])
Mask tensor shape: torch.Size([1, 3, 256, 256])
Mask saved as 'mask.png'
Masked image tensor shape: torch.Size([1, 3, 256, 256])
Masked and filled images saved as 'masked_image.png' and 'filled_image.png'
Current working directory: /content
