In [None]:
import os

import torch
import torch.nn as nn
import torch.optim as optim
from PIL import Image
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
from torchvision import transforms, datasets

# Define the U-Net Generator
class UNetGenerator(nn.Module):
    """Small encoder-decoder generator with additive skip connections.

    Three stride-2 convolutions each halve the spatial size, then three
    stride-2 transposed convolutions double it back. Encoder feature maps
    are *added* to (not concatenated with) the decoder feature maps of the
    same resolution before the next up-sampling step.

    Args:
        in_channels: Number of channels of the input image (default 3).
        out_channels: Number of channels of the generated image (default 3).
        base_channels: Width of the first encoder stage; the next two
            stages use 2x and 4x this value (default 64).

    With the default arguments the layers, their attribute names and the
    resulting ``state_dict`` keys are the same as the original fixed-size
    network.
    """

    def __init__(self, in_channels=3, out_channels=3, base_channels=64):
        super().__init__()
        c1, c2, c3 = base_channels, base_channels * 2, base_channels * 4

        # Down-sampling (encoder) layers
        self.down1 = self._down_block(in_channels, c1)
        self.down2 = self._down_block(c1, c2)
        self.down3 = self._down_block(c2, c3)

        # Up-sampling (decoder) layers
        self.up1 = self._up_block(c3, c2)
        self.up2 = self._up_block(c2, c1)
        # Final layer has no norm/activation; tanh is applied in forward().
        self.up3 = nn.ConvTranspose2d(c1, out_channels, kernel_size=4, stride=2, padding=1)

    @staticmethod
    def _down_block(in_ch, out_ch):
        """Return a stride-2 conv -> batch norm -> ReLU stage."""
        return nn.Sequential(
            nn.Conv2d(in_ch, out_ch, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(),
        )

    @staticmethod
    def _up_block(in_ch, out_ch):
        """Return a stride-2 transposed conv -> batch norm -> ReLU stage."""
        return nn.Sequential(
            nn.ConvTranspose2d(in_ch, out_ch, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(),
        )

    def forward(self, x):
        """Map an input batch to a generated batch with values in [-1, 1].

        The input height and width are expected to be multiples of 8 so
        that the skip additions line up and the output has the same
        spatial size as the input.
        """
        # Forward pass through down-sampling layers
        x1 = self.down1(x)
        x2 = self.down2(x1)
        x3 = self.down3(x2)

        # Forward pass through up-sampling layers with skip connections
        x = self.up1(x3)
        x = self.up2(x + x2)  # Skip connection (element-wise sum)
        x = self.up3(x + x1)  # Skip connection (element-wise sum)
        return torch.tanh(x)  # Output is squashed into the range [-1, 1]

# Define the PatchGAN Discriminator
class PatchGANDiscriminator(nn.Module):
    """Convolutional discriminator that scores an image pair patch-wise.

    The two images are concatenated along the channel dimension and passed
    through three stride-2 convolutions followed by a stride-1 convolution
    that produces a single-channel map. The map holds raw scores (no
    sigmoid is applied here).

    Args:
        in_channels: Total channel count of the concatenated pair
            (default 6, i.e. two 3-channel images).
        base_channels: Width of the first stage; the next two stages use
            2x and 4x this value (default 64).

    With the default arguments the layers, their attribute names and the
    resulting ``state_dict`` keys are the same as the original fixed-size
    network.
    """

    def __init__(self, in_channels=6, base_channels=64):
        super().__init__()
        c1, c2, c3 = base_channels, base_channels * 2, base_channels * 4

        # First stage deliberately has no batch norm.
        self.conv1 = nn.Sequential(
            nn.Conv2d(in_channels, c1, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2),
        )
        self.conv2 = nn.Sequential(
            nn.Conv2d(c1, c2, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(c2),
            nn.LeakyReLU(0.2),
        )
        self.conv3 = nn.Sequential(
            nn.Conv2d(c2, c3, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(c3),
            nn.LeakyReLU(0.2),
        )
        self.conv4 = nn.Conv2d(c3, 1, kernel_size=4, stride=1, padding=1)

    def forward(self, x, y):
        """Return the per-patch score map for the pair ``(x, y)``.

        Args:
            x: Conditioning image batch.
            y: Real or generated image batch with the same batch and
                spatial size as ``x``.
        """
        # Concatenate the conditioning and generated/real images along the
        # channel dimension (named `pair` to avoid shadowing builtin `input`).
        pair = torch.cat([x, y], dim=1)
        features = self.conv1(pair)
        features = self.conv2(features)
        features = self.conv3(features)
        return self.conv4(features)

# Loss functions: the adversarial term works on raw discriminator scores
# (sigmoid is folded into the loss), the L1 term compares images directly.
adversarial_loss = nn.BCEWithLogitsLoss()
l1_loss = nn.L1Loss()

# Run on the GPU when PyTorch reports one, otherwise on the CPU.
_device_name = "cuda" if torch.cuda.is_available() else "cpu"
device = torch.device(_device_name)

# Build both networks and move them to the selected device.
generator = UNetGenerator().to(device)
discriminator = PatchGANDiscriminator().to(device)

# Both networks use Adam with identical hyper-parameters.
_adam_settings = {"lr": 0.0002, "betas": (0.5, 0.999)}
g_optimizer = optim.Adam(generator.parameters(), **_adam_settings)
d_optimizer = optim.Adam(discriminator.parameters(), **_adam_settings)

# Define the training loop
def train(dataloader, epochs, save_interval=10, l1_weight=100):
    """Alternately train the module-level discriminator and generator.

    Uses the module-level ``generator``, ``discriminator``, ``g_optimizer``,
    ``d_optimizer``, ``adversarial_loss``, ``l1_loss`` and ``device``.

    Args:
        dataloader: Iterable yielding ``(masked, unmasked)`` tensor batches.
        epochs: Number of passes over ``dataloader``.
        save_interval: Print the losses every this many batches. Despite
            the name nothing is saved. A falsy value (``0`` or ``None``)
            disables the progress output instead of raising
            ``ZeroDivisionError``.
        l1_weight: Factor applied to the L1 reconstruction term in the
            generator loss (default 100, the value previously hard-coded).
    """
    # Make sure batch-norm layers use batch statistics while training.
    generator.train()
    discriminator.train()

    for epoch in range(epochs):
        for i, (masked, unmasked) in enumerate(dataloader):
            masked, unmasked = masked.to(device), unmasked.to(device)

            # A single generator forward pass serves both updates: the
            # generator weights do not change during the discriminator
            # step, so running it a second time would only repeat the work
            # (and update the batch-norm running statistics twice).
            fake_unmasked = generator(masked)

            # ---------------------
            # Train Discriminator
            # ---------------------
            d_optimizer.zero_grad()

            # Real images
            real_validity = discriminator(masked, unmasked)
            real_loss = adversarial_loss(real_validity, torch.ones_like(real_validity))

            # Fake images; detach so this step does not back-propagate
            # into the generator.
            fake_validity = discriminator(masked, fake_unmasked.detach())
            fake_loss = adversarial_loss(fake_validity, torch.zeros_like(fake_validity))

            d_loss = (real_loss + fake_loss) / 2
            d_loss.backward()
            d_optimizer.step()

            # -----------------
            # Train Generator
            # -----------------
            g_optimizer.zero_grad()

            # Score the (non-detached) fakes with the updated discriminator;
            # the generator wants them to be classified as real.
            fake_validity = discriminator(masked, fake_unmasked)
            g_adv_loss = adversarial_loss(fake_validity, torch.ones_like(fake_validity))

            # L1 loss
            g_l1_loss = l1_loss(fake_unmasked, unmasked)

            # Total generator loss
            g_loss = g_adv_loss + l1_weight * g_l1_loss
            g_loss.backward()
            g_optimizer.step()

            if save_interval and i % save_interval == 0:
                print(f"Epoch [{epoch}/{epochs}] Batch {i}/{len(dataloader)} "
                      f"D Loss: {d_loss.item():.4f}, G Loss: {g_loss.item():.4f}")

# Image preprocessing pipeline: resize every image to 256x256, then convert
# it to a tensor.
# NOTE(review): ToTensor presumably yields values in [0, 1] for these images,
# while UNetGenerator ends in tanh ([-1, 1]) -- confirm whether a
# normalisation step is intended here.
_preprocessing_steps = [
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
]
transform = transforms.Compose(_preprocessing_steps)

# Dataset root and the two sibling folders holding the masked / unmasked faces
dataset_path = "MaskTheFace/faces"
masked_path, unmasked_path = (
    os.path.join(dataset_path, folder)
    for folder in ("Celebrity Faces Dataset_masked", "Celebrity Faces Dataset")
)

class MaskedFaceDataset(Dataset):
    """Paired dataset of masked face images and their unmasked originals.

    Both directories are scanned recursively for ``.png`` / ``.jpg`` /
    ``.jpeg`` files (hidden files are skipped). A masked file is paired with
    the unmasked file whose name equals the masked name with ``_surgical``
    removed. Masked files without a partner are dropped.

    Args:
        masked_dir: Root directory of the masked images.
        unmasked_dir: Root directory of the unmasked images.
        transform: Optional callable applied to each image independently.

    Attributes set on the instance:
        masked_images / unmasked_images: Equal-length lists of file paths;
        index ``i`` of both lists refers to the same pair.
    """

    # Lower-case file extensions treated as images while scanning.
    _IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg')

    def __init__(self, masked_dir, unmasked_dir, transform=None):
        self.masked_dir = masked_dir
        self.unmasked_dir = unmasked_dir
        self.transform = transform

        # Recursively find all image paths in subdirectories
        self.masked_images, self.unmasked_images = self._get_paired_image_paths()

        print(f"Found {len(self.masked_images)} masked images and {len(self.unmasked_images)} unmasked images.")
        assert len(self.masked_images) == len(self.unmasked_images), "Mismatch between masked and unmasked images count!"

    def _list_images(self, directory):
        """Return the paths of all non-hidden image files below ``directory``."""
        found = []
        for root, _, files in os.walk(directory):
            found.extend(
                os.path.join(root, file)
                for file in files
                if file.lower().endswith(self._IMAGE_EXTENSIONS) and not file.startswith('.')
            )
        return found

    def _get_paired_image_paths(self):
        """Return ``(masked_paths, unmasked_paths)`` aligned index by index.

        Matching prefers an unmasked file at the same relative location
        (same sub-folder, name without ``_surgical``) so that identical
        file names in different sub-folders are not mixed up. If no such
        file exists, the first unmasked file with that base name is used.
        The pairs are sorted together by masked path, which keeps every
        masked image next to its own partner.
        """
        masked_files = self._list_images(self.masked_dir)
        unmasked_files = self._list_images(self.unmasked_dir)

        # Index the unmasked files once instead of scanning the whole list
        # for every masked file. setdefault keeps the first occurrence.
        by_relative_path = {}
        by_name = {}
        for path in unmasked_files:
            by_relative_path.setdefault(os.path.relpath(path, self.unmasked_dir), path)
            by_name.setdefault(os.path.basename(path), path)

        pairs = []
        for masked_file in masked_files:
            base_name = os.path.basename(masked_file).replace('_surgical', '')
            relative_dir = os.path.dirname(os.path.relpath(masked_file, self.masked_dir))

            match = by_relative_path.get(os.path.join(relative_dir, base_name))
            if match is None:
                # Directory layouts differ: fall back to name-only matching.
                match = by_name.get(base_name)

            if match is not None:
                pairs.append((masked_file, match))

        # Sort the pairs as units; sorting the two path lists separately
        # would reorder them differently and break the pairing.
        pairs.sort()
        return [masked for masked, _ in pairs], [unmasked for _, unmasked in pairs]

    def __len__(self):
        """Return the number of image pairs."""
        return len(self.masked_images)

    def __getitem__(self, idx):
        """Load pair ``idx`` as RGB and return ``(masked, unmasked)``.

        The transform, when given, is applied to each image separately.
        """
        # Read the images using PIL; the context manager closes the file
        # once the pixel data has been converted into a new RGB image.
        with Image.open(self.masked_images[idx]) as img:
            masked_image = img.convert("RGB")
        with Image.open(self.unmasked_images[idx]) as img:
            unmasked_image = img.convert("RGB")

        if self.transform is not None:
            masked_image = self.transform(masked_image)
            unmasked_image = self.transform(unmasked_image)

        return masked_image, unmasked_image


# Build the paired dataset and wrap it in a shuffling loader of 16-image batches
dataset = MaskedFaceDataset(
    masked_dir=masked_path,
    unmasked_dir=unmasked_path,
    transform=transform,
)
dataloader = DataLoader(dataset=dataset, batch_size=16, shuffle=True)

# Train the GAN for 50 epochs, reporting the losses every 100 batches
train(dataloader=dataloader, epochs=50, save_interval=100)


Found 1780 masked images and 1780 unmasked images.
Epoch [0/50] Batch 0/112 D Loss: 0.7234, G Loss: 74.1440
Epoch [0/50] Batch 100/112 D Loss: 0.5183, G Loss: 11.6424
Epoch [1/50] Batch 0/112 D Loss: 0.4288, G Loss: 11.4126
Epoch [1/50] Batch 100/112 D Loss: 0.3030, G Loss: 10.5109
Epoch [2/50] Batch 0/112 D Loss: 0.1533, G Loss: 9.6724
Epoch [2/50] Batch 100/112 D Loss: 0.1510, G Loss: 9.7579
Epoch [3/50] Batch 0/112 D Loss: 0.2139, G Loss: 9.7484
Epoch [3/50] Batch 100/112 D Loss: 0.0123, G Loss: 11.9706
Epoch [4/50] Batch 0/112 D Loss: 0.0191, G Loss: 12.2753
Epoch [4/50] Batch 100/112 D Loss: 0.0053, G Loss: 12.4170
Epoch [5/50] Batch 0/112 D Loss: 0.0082, G Loss: 11.4801
Epoch [5/50] Batch 100/112 D Loss: 0.0074, G Loss: 12.6663
Epoch [6/50] Batch 0/112 D Loss: 0.0082, G Loss: 14.4509
Epoch [6/50] Batch 100/112 D Loss: 1.3093, G Loss: 6.5933
Epoch [7/50] Batch 0/112 D Loss: 0.6324, G Loss: 6.8912
Epoch [7/50] Batch 100/112 D Loss: 0.0136, G Loss: 10.8017
Epoch [8/50] Batch 0/112 D