In [1]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [2]:
# --- Unzip the dataset ---
# IMPORTANT: Update this path to match the location of your zip file in Drive.
zip_path = '/content/drive/MyDrive/ColabNotebooks/Vision/inpainting/ffhq256_10ksubset.zip'

# The destination folder in the local Colab environment.
destination_path = '/content'

print("Unzipping dataset...")
# The -q flag makes the output cleaner (quiet mode)
!unzip -q {zip_path} -d {destination_path}

print(f"✅ Dataset unzipped to {destination_path}")

Unzipping dataset...
✅ Dataset unzipped to /content


In [3]:
# Install
!pip -q install pytorch_wavelets torchmetrics lpips torch-fidelity

[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/54.9 kB[0m [31m?[0m eta [36m-:--:--[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m54.9/54.9 kB[0m [31m5.5 MB/s[0m eta [36m0:00:00[0m
[?25h[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/983.0 kB[0m [31m?[0m eta [36m-:--:--[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m983.0/983.0 kB[0m [31m58.5 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m53.8/53.8 kB[0m [31m5.5 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m363.4/363.4 MB[0m [31m2.4 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m13.8/13.8 MB[0m [31m81.5 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m24.6/24.6 MB[0m [31m97.4 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[

In [4]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import torch.optim as optim
import torch._dynamo
from torch.utils.data import Dataset, DataLoader, Subset
from torch.nn.utils import spectral_norm
from torchvision.utils import save_image, make_grid
from torchvision import transforms
from PIL import Image
import pandas as pd
import time
import random
import math
from tqdm.auto import tqdm
from tqdm.autonotebook import tqdm
import os
import glob
import matplotlib.pyplot as plt

import torchmetrics
from torchmetrics.image.fid import FrechetInceptionDistance
import lpips

import torch.fft
from pytorch_wavelets import DWTForward
from einops import rearrange

# Enable TensorFloat32
torch.set_float32_matmul_precision('high')

# --- Configuration & Hyperparameters ---

# Set the path to your image folder in Google Drive
# IMPORTANT: Update this path to match where you saved your dataset.
DATASET_PATH = '/content/ffhq_subset_10k'
NUM_IMAGES_TO_USE_CNN = 10000
NUM_IMAGES_TO_USE_GAN = 1000

# Training settings
NUM_EPOCHS_CNN = 40
NUM_EPOCHS_GAN = 200
LEARNING_RATE = 1e-4
BATCH_SIZE = 4

# Set the device (use GPU if available)
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {DEVICE}")

Using device: cuda


In [5]:
class FFHQDataset(Dataset):
    """Custom PyTorch Dataset for loading FFHQ images."""
    def __init__(self, img_dir, transform=None, num_images=None): # Add num_images parameter
        """
        Args:
            img_dir (string): Directory with all the images.
            transform (callable, optional): Optional transform to be applied on a sample.
            num_images (int, optional): Number of images to use. If None, use all images.
        """
        # Find all files with .png or .jpg extension
        self.img_paths = glob.glob(os.path.join(img_dir, '*.png'))
        self.img_paths.extend(glob.glob(os.path.join(img_dir, '*.jpg')))
        self.transform = transform

        if num_images:
            # If a number is specified, shuffle all paths and take a random subset
            random.shuffle(self.img_paths)
            self.img_paths = self.img_paths[:num_images]

    def __len__(self):
        return len(self.img_paths)

    def __getitem__(self, idx):
        img_path = self.img_paths[idx]
        image = Image.open(img_path).convert('RGB')
        if self.transform:
            image = self.transform(image)
        return image

# Your transform definition remains the same
transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])

In [6]:
def tensor_to_image(tensor):
    """
    Converts a PyTorch tensor to a displayable NumPy image.
    It denormalizes, moves to CPU, and changes dimension order.
    """
    # Denormalize the image from [-1, 1] to [0, 1]
    image = tensor * 0.5 + 0.5
    # Move tensor to CPU and convert to NumPy array
    image = image.cpu().numpy()
    # Transpose dimensions from (C, H, W) to (H, W, C) for plotting
    image = image.transpose(1, 2, 0)
    # Clip values to be in the valid [0, 1] range for images
    image = np.clip(image, 0, 1)
    return image

In [7]:
def create_mask(image, mask_percentage=0.025):
    """
    Creates masks that are more likely to cover facial features.
    """
    batch_size, _, height, width = image.shape
    mask = torch.ones_like(image)

    # Face regions typically in center 60% of image
    center_bias = 0.3  # 30% border on each side

    for i in range(batch_size):
        mask_h = int(np.sqrt(height * width * mask_percentage))
        mask_w = mask_h

        # Bias towards center
        top_min = int(height * center_bias)
        top_max = int(height * (1 - center_bias)) - mask_h
        left_min = int(width * center_bias)
        left_max = int(width * (1 - center_bias)) - mask_w

        top = np.random.randint(top_min, max(top_min + 1, top_max))
        left = np.random.randint(left_min, max(left_min + 1, left_max))

        mask[i, :, top:top+mask_h, left:left+mask_w] = 0

    masked_image = image * mask
    return masked_image, mask

In [9]:
class SelfAttention(nn.Module):
    """ A simple self-attention layer """
    def __init__(self, channels):
        super().__init__()
        self.query = nn.Conv2d(channels, channels // 8, 1)
        self.key   = nn.Conv2d(channels, channels // 8, 1)
        self.value = nn.Conv2d(channels, channels, 1)
        self.gamma = nn.Parameter(torch.zeros(1))

    def forward(self, x):
        batch_size, C, width, height = x.size()
        q = self.query(x).view(batch_size, -1, width * height).permute(0, 2, 1)
        k = self.key(x).view(batch_size, -1, width * height)
        v = self.value(x).view(batch_size, -1, width * height)

        attention_map = F.softmax(torch.bmm(q, k), dim=-1)

        out = torch.bmm(v, attention_map.permute(0, 2, 1))
        out = out.view(batch_size, C, width, height)

        return self.gamma * out + x # Add skip connection

class UpsampleBlock(nn.Module):
    """An upsampling block using Conv2d and PixelShuffle."""
    def __init__(self, in_channels, out_channels):
        super().__init__()
        # The Conv2d layer produces 4x the channels for a 2x upscale
        self.conv = nn.Conv2d(in_channels, out_channels * 4, kernel_size=3, padding=1)
        self.pixel_shuffle = nn.PixelShuffle(2) # Rearranges channels to upscale by 2x
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        return self.relu(self.pixel_shuffle(self.conv(x)))

class GatedConv2d(nn.Module):
    """
    A Gated Convolutional Layer.
    It learns a dynamic feature mask for each channel at every location.
    """
    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1):
        super().__init__()
        # Convolution for the features
        self.conv_feature = nn.Conv2d(
            in_channels, out_channels, kernel_size, stride, padding, dilation
        )
        # Convolution for the gating mechanism
        self.conv_gate = nn.Conv2d(
            in_channels, out_channels, kernel_size, stride, padding, dilation
        )

    def forward(self, x):
        # Get the features and the gate
        features = self.conv_feature(x)
        gate = torch.sigmoid(self.conv_gate(x)) # Gate values are between 0 and 1

        # Element-wise multiplication to apply the learned gate
        return features * gate

class GatedResidualBlock(nn.Module):
    """A Residual Block that uses Gated Convolutions."""
    def __init__(self, channels, dilation=1):
        super().__init__()
        padding = dilation

        # Replace nn.Conv2d with GatedConv2d
        self.conv1 = GatedConv2d(channels, channels, kernel_size=3, padding=padding, dilation=dilation)
        self.bn1 = nn.BatchNorm2d(channels)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = GatedConv2d(channels, channels, kernel_size=3, padding=padding, dilation=dilation)
        self.bn2 = nn.BatchNorm2d(channels)

    def forward(self, x):
        residual = x
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        out = self.conv2(out)
        out = self.bn2(out)
        out = out + residual  # Residual connection
        out = self.relu(out)
        return out

class UNetSR(nn.Module):
    """
    A U-Net architecture with corrected channel dimensions for the decoder.
    """
    def __init__(self, in_channels=4, out_channels=3, num_channels=64):
        super().__init__()

        # --- Initial Convolution ---
        self.init_conv = nn.Conv2d(in_channels, num_channels, kernel_size=3, padding=1)

        # --- Encoder Path ---
        self.enc1 = GatedResidualBlock(num_channels, dilation=1)
        self.enc2 = GatedResidualBlock(num_channels, dilation=1)
        self.pool = nn.MaxPool2d(2)

        # --- Bottleneck with Dilation and Attention ---
        self.bottleneck = nn.Sequential(
            GatedResidualBlock(num_channels, dilation=2),
            # SelfAttention(num_channels), # Add attention layer
            GatedResidualBlock(num_channels, dilation=4)
        )

        # --- Decoder Path ---
        self.upconv2 = UpsampleBlock(num_channels, num_channels)
        # Input channels = upsampled (64) + skip connection from e2 (64) = 128
        self.dec2 = GatedResidualBlock(num_channels * 2, dilation=1)

        self.upconv1 = UpsampleBlock(num_channels * 2, num_channels)
        # Input channels = upsampled (64) + skip connection from e1 (64) = 128
        self.dec1 = GatedResidualBlock(num_channels * 2, dilation=1)

        # --- Final Output Layer ---
        # The input to this layer comes from dec1, which outputs 128 channels.
        self.out_conv = nn.Conv2d(num_channels * 2, out_channels, kernel_size=3, padding=1)

    def forward(self, x):
        # Initial feature extraction
        x0 = self.init_conv(x)

        # Encoder
        e1 = self.enc1(x0)
        p1 = self.pool(e1)

        e2 = self.enc2(p1)
        p2 = self.pool(e2)

        # Bottleneck
        b = self.bottleneck(p2)

        # Decoder with Skip Connections
        d2 = self.upconv2(b)
        d2 = torch.cat([d2, e2], dim=1)
        d2 = self.dec2(d2)

        d1 = self.upconv1(d2)
        d1 = torch.cat([d1, e1], dim=1)
        d1 = self.dec1(d1)

        # Final Output
        out = self.out_conv(d1)

        return torch.tanh(out)

In [10]:
# --- 1. Create the main dataset ---
# This should use the LARGER number of images you intend to work with.
# Let's assume NUM_IMAGES_TO_USE_CNN is the total pool of images.
print("Creating the main dataset...")
full_dataset = FFHQDataset(
    img_dir=DATASET_PATH,
    transform=transform,
    num_images=NUM_IMAGES_TO_USE_CNN # Use the total number of images available for the experiment
)
print(f"✅ Main dataset created with {len(full_dataset)} images.")

# --- 2. Split the dataset into Training, Validation, and Test sets ---
print("\nSplitting data into training, validation, and test sets...")
dataset_size = len(full_dataset)
indices = list(range(dataset_size))
np.random.seed(42) # for reproducibility
np.random.shuffle(indices)

# Define split points for an 80/10/10 split
train_split = int(np.floor(0.8 * dataset_size))
val_split = int(np.floor(0.9 * dataset_size))

# Create indices for each set
train_indices = indices[:train_split]
val_indices = indices[train_split:val_split]
test_indices = indices[val_split:]

# Create PyTorch Subsets
train_data = Subset(full_dataset, train_indices)
val_data = Subset(full_dataset, val_indices)
test_data = Subset(full_dataset, test_indices)

print(f"✅ Training set size: {len(train_data)}")
print(f"✅ Validation set size: {len(val_data)}")
print(f"✅ Test set size: {len(test_data)}")

# --- 3. Create the Diffusion Model's training subset ---
# This should be a subset of the TRAINING data.
print("\nCreating a subset of the training data for the GAN model...")
gan_indices = train_indices[:NUM_IMAGES_TO_USE_GAN] # Take from the start of shuffled train indices
gan_data = Subset(full_dataset, gan_indices)
print(f"✅ Diffusion training set size: {len(gan_data)}")


# --- 4. Create DataLoaders for each set ---
print("\nCreating DataLoaders...")
# The main CNN will now train on the 'train_data' subset
cnn_dataloader = DataLoader(
    train_data,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=4,
    pin_memory=True
)
# New DataLoader for validation
val_dataloader = DataLoader(
    val_data,
    batch_size=BATCH_SIZE,
    shuffle=False, # No need to shuffle validation data
    num_workers=4,
    pin_memory=True
)
# New DataLoader for testing
test_dataloader = DataLoader(
    test_data,
    batch_size=BATCH_SIZE,
    shuffle=False, # No need to shuffle test data
    num_workers=4,
    pin_memory=True
)
# GAN DataLoader uses its own subset of the training data
gan_dataloader = DataLoader(
    gan_data,
    batch_size=32,
    shuffle=True,
    num_workers=4,
    pin_memory=True,
    drop_last=True
)
print("✅ All DataLoaders created.")

# --- 5. Initialize models and optimizers (unchanged) ---
print("\nInitializing models and optimizers...")
cnn_model = UNetSR().to(DEVICE)

# Compile the models for a speed boost
cnn_model = torch.compile(cnn_model)

# Initialize optimizers
optimizer_cnn = optim.Adam(cnn_model.parameters(), lr=LEARNING_RATE)

print("\nSetup complete. Ready for CNN training!")

Creating the main dataset...
✅ Main dataset created with 10000 images.

Splitting data into training, validation, and test sets...
✅ Training set size: 8000
✅ Validation set size: 1000
✅ Test set size: 1000

Creating a subset of the training data for the GAN model...
✅ Diffusion training set size: 1000

Creating DataLoaders...
✅ All DataLoaders created.

Initializing models and optimizers...

Setup complete. Ready for CNN training!


In [11]:
# ============================================
# 1. GENERATOR - Residual Refinement Network
# ============================================

class ResidualBlock(nn.Module):
    """Residual block with spectral normalization"""
    def __init__(self, channels):
        super().__init__()
        self.conv1 = spectral_norm(nn.Conv2d(channels, channels, 3, 1, 1))
        self.conv2 = spectral_norm(nn.Conv2d(channels, channels, 3, 1, 1))
        self.norm1 = nn.InstanceNorm2d(channels)
        self.norm2 = nn.InstanceNorm2d(channels)

    def forward(self, x):
        residual = x
        x = F.leaky_relu(self.norm1(self.conv1(x)), 0.2)
        x = self.norm2(self.conv2(x))
        return x + residual

class SelfAttention(nn.Module):
    """Self-attention layer for capturing long-range dependencies"""
    def __init__(self, channels):
        super().__init__()
        self.query = spectral_norm(nn.Conv2d(channels, channels // 8, 1))
        self.key = spectral_norm(nn.Conv2d(channels, channels // 8, 1))
        self.value = spectral_norm(nn.Conv2d(channels, channels, 1))
        self.gamma = nn.Parameter(torch.zeros(1))

    def forward(self, x):
        b, c, h, w = x.shape

        q = self.query(x).view(b, -1, h * w).permute(0, 2, 1)
        k = self.key(x).view(b, -1, h * w)
        v = self.value(x).view(b, -1, h * w)

        attention = F.softmax(torch.bmm(q, k), dim=-1)
        out = torch.bmm(v, attention.permute(0, 2, 1))
        out = out.view(b, c, h, w)

        return self.gamma * out + x

class RefinementGenerator(nn.Module):
    """Generator that refines CNN output"""
    def __init__(self, in_channels=7, out_channels=3, base_channels=64):
        super().__init__()

        # Input: CNN output (3) + original masked (3) + mask (1) = 7 channels
        self.encoder = nn.Sequential(
            spectral_norm(nn.Conv2d(in_channels, base_channels, 7, 1, 3)),
            nn.LeakyReLU(0.2, True)
        )

        # Downsampling
        self.down1 = nn.Sequential(
            spectral_norm(nn.Conv2d(base_channels, base_channels * 2, 4, 2, 1)),
            nn.InstanceNorm2d(base_channels * 2),
            nn.LeakyReLU(0.2, True)
        )

        self.down2 = nn.Sequential(
            spectral_norm(nn.Conv2d(base_channels * 2, base_channels * 4, 4, 2, 1)),
            nn.InstanceNorm2d(base_channels * 4),
            nn.LeakyReLU(0.2, True)
        )

        # Residual blocks with attention
        self.res_blocks = nn.Sequential(
            ResidualBlock(base_channels * 4),
            ResidualBlock(base_channels * 4),
            SelfAttention(base_channels * 4),
            ResidualBlock(base_channels * 4),
            ResidualBlock(base_channels * 4),
        )

        # Upsampling
        self.up1 = nn.Sequential(
            spectral_norm(nn.ConvTranspose2d(base_channels * 4, base_channels * 2, 4, 2, 1)),
            nn.InstanceNorm2d(base_channels * 2),
            nn.LeakyReLU(0.2, True)
        )

        self.up2 = nn.Sequential(
            spectral_norm(nn.ConvTranspose2d(base_channels * 2, base_channels, 4, 2, 1)),
            nn.InstanceNorm2d(base_channels),
            nn.LeakyReLU(0.2, True)
        )

        # Output residual
        self.output = nn.Sequential(
            spectral_norm(nn.Conv2d(base_channels, out_channels, 7, 1, 3)),
            nn.Tanh()
        )

        # Learnable residual weight
        self.residual_weight = nn.Parameter(torch.tensor(0.01))

    def forward(self, cnn_output, masked_image, mask):
        """
        Args:
            cnn_output: CNN's coarse prediction [B, 3, H, W]
            masked_image: Original image with mask applied [B, 3, H, W]
            mask: Binary mask [B, 1, H, W]
        """
        # Concatenate inputs
        x = torch.cat([cnn_output, masked_image, mask], dim=1)

        # Encode
        x = self.encoder(x)

        # Downsample
        d1 = self.down1(x)
        d2 = self.down2(d1)

        # Process with residual blocks
        x = self.res_blocks(d2)

        # Upsample with skip connections
        x = self.up1(x)
        x = self.up2(x + d1)  # Skip connection

        # Generate residual
        residual = self.output(x)

        # Add weighted residual to CNN output
        refined = cnn_output + self.residual_weight * residual

        # Ensure output is in [-1, 1]
        return torch.tanh(refined)

In [12]:
# ============================================
# 2. DISCRIMINATOR - Multi-scale PatchGAN
# ============================================

class MultiscaleDiscriminator(nn.Module):
    """Multi-scale discriminator for better gradient flow"""
    def __init__(self, in_channels=3, base_channels=64, num_scales=3):
        super().__init__()

        self.num_scales = num_scales
        self.discriminators = nn.ModuleList()

        for _ in range(num_scales):
            self.discriminators.append(self._make_discriminator(in_channels, base_channels))

        self.downsample = nn.AvgPool2d(2)

    def _make_discriminator(self, in_channels, base_channels):
        return nn.Sequential(
            spectral_norm(nn.Conv2d(in_channels, base_channels, 4, 2, 1)),
            nn.LeakyReLU(0.2, True),

            spectral_norm(nn.Conv2d(base_channels, base_channels * 2, 4, 2, 1)),
            nn.InstanceNorm2d(base_channels * 2),
            nn.LeakyReLU(0.2, True),

            spectral_norm(nn.Conv2d(base_channels * 2, base_channels * 4, 4, 2, 1)),
            nn.InstanceNorm2d(base_channels * 4),
            nn.LeakyReLU(0.2, True),

            spectral_norm(nn.Conv2d(base_channels * 4, base_channels * 8, 4, 1, 1)),
            nn.InstanceNorm2d(base_channels * 8),
            nn.LeakyReLU(0.2, True),

            spectral_norm(nn.Conv2d(base_channels * 8, 1, 4, 1, 1))
        )

    def forward(self, x):
        outputs = []

        for i in range(self.num_scales):
            outputs.append(self.discriminators[i](x))
            if i < self.num_scales - 1:
                x = self.downsample(x)

        return outputs

In [13]:
# ============================================
# 3. LOSS FUNCTIONS
# ============================================

class PerceptualLoss(nn.Module):
    """VGG-based perceptual loss"""
    def __init__(self):
        super().__init__()
        import torchvision.models as models
        from torchvision.models import VGG19_Weights

        vgg = models.vgg19(weights=VGG19_Weights.IMAGENET1K_V1).features

        # Extract specific layers
        self.layers = nn.ModuleList([
            vgg[:4],   # relu1_2
            vgg[4:9],  # relu2_2
            vgg[9:18], # relu3_4
        ])

        # Freeze VGG
        for param in self.parameters():
            param.requires_grad = False

        # Normalization
        self.register_buffer('mean', torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1))
        self.register_buffer('std', torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1))

    def forward(self, pred, target):
        # Normalize from [-1, 1] to [0, 1]
        pred = (pred + 1) / 2
        target = (target + 1) / 2

        # VGG normalization
        pred = (pred - self.mean) / self.std
        target = (target - self.mean) / self.std

        loss = 0
        x_pred, x_target = pred, target

        for layer in self.layers:
            x_pred = layer(x_pred)
            x_target = layer(x_target)
            loss += F.l1_loss(x_pred, x_target)

        return loss

def hinge_loss_d(real_pred, fake_pred):
    """Hinge loss for discriminator"""
    loss = 0
    for real_p, fake_p in zip(real_pred, fake_pred):
        loss += torch.mean(F.relu(1 - real_p)) + torch.mean(F.relu(1 + fake_p))
    return loss / len(real_pred)

def hinge_loss_g(fake_pred):
    """Hinge loss for generator"""
    loss = 0
    for fake_p in fake_pred:
        loss += -torch.mean(fake_p)
    return loss / len(fake_pred)

In [14]:
# ============================================
# 4. TRAINING FUNCTION
# ============================================

def train_gan(generator, discriminator, cnn_model, train_loader, val_loader,
              num_epochs=80, device='cuda'):
    """Train the GAN for refinement"""

    # Optimizers
    g_optimizer = optim.Adam(generator.parameters(), lr=1e-4, betas=(0.0, 0.999))
    d_optimizer = optim.Adam(discriminator.parameters(), lr=1e-4, betas=(0.0, 0.999))

    # Add GradScalers for mixed precision
    g_scaler = torch.amp.GradScaler('cuda')
    d_scaler = torch.amp.GradScaler('cuda')

    # Loss functions
    l1_loss = nn.L1Loss()
    perceptual_loss = PerceptualLoss().to(device)

    # Training history
    history = {'g_loss': [], 'd_loss': [], 'val_psnr': []}

    # Training loop
    for epoch in range(num_epochs):
        generator.train()
        discriminator.train()

        epoch_g_loss = 0
        epoch_d_loss = 0

        pbar = tqdm(train_loader, desc=f'Epoch {epoch+1}/{num_epochs}')
        for batch_idx, batch in enumerate(pbar):
            real_images = batch.to(device)
            batch_size = real_images.size(0)

            # Create masks and get CNN predictions
            with torch.no_grad():
                masked_images, masks = create_mask(real_images, mask_percentage=0.025)
                masks = masks.to(device)

                cnn_input = torch.cat([masked_images, masks[:, 0:1]], dim=1)
                coarse_output = cnn_model(cnn_input)

            # ==================
            # Train Discriminator
            # ==================
            d_optimizer.zero_grad()

            # Wrap forward passes in autocast
            with torch.amp.autocast('cuda'):
                # Generate refined images
                refined_images = generator(coarse_output, masked_images, masks[:, 0:1])

                # Discriminator predictions
                real_pred = discriminator(real_images)
                fake_pred = discriminator(refined_images.detach())

                # Hinge loss
                d_loss = hinge_loss_d(real_pred, fake_pred)

            # Use scaler for backward pass
            d_scaler.scale(d_loss).backward()
            d_scaler.step(d_optimizer)
            d_scaler.update()

            # ==================
            # Train Generator
            # ==================
            # Train generator every 1 discriminator steps
            # Modified loss to heavily prioritize hole improvement
            if batch_idx % 1 == 0:
                g_optimizer.zero_grad()

                with torch.amp.autocast('cuda'):
                    refined_images = generator(coarse_output, masked_images, masks[:, 0:1])
                    fake_pred = discriminator(refined_images)

                    # AGGRESSIVE HOLE FOCUS:
                    hole_mask = 1 - masks
                    valid_mask = masks

                    # 1. Adversarial - let it be stronger for realism
                    g_adv_loss = hinge_loss_g(fake_pred) * 0.1

                    # 2. Hole-only reconstruction (PRIMARY LOSS)
                    g_hole_loss = l1_loss(refined_images * hole_mask, real_images * hole_mask) * 200  # HUGE weight

                    # 3. Perceptual loss ONLY in holes
                    g_perc_hole = perceptual_loss(refined_images * hole_mask, real_images * hole_mask) * 1.0

                    # 4. Edge/boundary loss for better blending
                    dilated_mask = F.max_pool2d(1 - masks[:, 0:1], 5, stride=1, padding=2)
                    boundary = dilated_mask - (1 - masks[:, 0:1])
                    g_boundary_loss = l1_loss(refined_images * boundary, real_images * boundary) * 50

                    # 5. Small penalty for changing non-hole regions
                    g_preserve_loss = l1_loss(refined_images * valid_mask, coarse_output * valid_mask) * 10

                    # Total loss focuses on improving holes
                    g_loss = g_adv_loss + g_hole_loss + g_perc_hole + g_boundary_loss + g_preserve_loss

                # Use scaler for backward pass
                g_scaler.scale(g_loss).backward()
                g_scaler.step(g_optimizer)
                g_scaler.update()

                # Update residual weight (optional)
                generator.residual_weight.data.clamp_(0.01, 0.1)

            # Update progress bar
            epoch_g_loss += g_loss.item()
            epoch_d_loss += d_loss.item()
            pbar.set_postfix({
                'G': f'{g_loss.item():.4f}',
                'D': f'{d_loss.item():.4f}',
                'α': f'{generator.residual_weight.item():.3f}'
            })

        # Validation
        if epoch % 5 == 0:
            val_psnr = validate_gan(generator, cnn_model, val_loader, device)
            history['val_psnr'].append(val_psnr)
            print(f"\nValidation PSNR: {val_psnr:.2f} dB")

            # Save sample images
            save_samples(generator, cnn_model, val_loader, epoch, device)

        # Save checkpoint with scaler states
        if epoch % 10 == 0:
            torch.save({
                'generator': generator.state_dict(),
                'discriminator': discriminator.state_dict(),
                'g_optimizer': g_optimizer.state_dict(),
                'd_optimizer': d_optimizer.state_dict(),
                'g_scaler': g_scaler.state_dict(),  # Save scaler state
                'd_scaler': d_scaler.state_dict(),  # Save scaler state
                'epoch': epoch
            }, f'gan_checkpoint_epoch_{epoch}.pth')

        history['g_loss'].append(epoch_g_loss / len(train_loader))
        history['d_loss'].append(epoch_d_loss / len(train_loader))

    return history

# ============================================
# 5. VALIDATION AND INFERENCE
# ============================================

def validate_gan(generator, cnn_model, val_loader, device):
    """Calculate validation metrics with mixed precision"""
    generator.eval()
    total_psnr = 0
    count = 0

    with torch.no_grad():
        for batch in val_loader:
            real_images = batch.to(device)

            # Add autocast for validation too
            with torch.amp.autocast('cuda'):
                masked_images, masks = create_mask(real_images, mask_percentage=0.025)
                masks = masks.to(device)

                cnn_input = torch.cat([masked_images, masks[:, 0:1]], dim=1)
                coarse_output = cnn_model(cnn_input)

                refined_images = generator(coarse_output, masked_images, masks[:, 0:1])

            # Calculate metrics in FP32 (outside autocast)
            mse = F.mse_loss((refined_images + 1) / 2, (real_images + 1) / 2)
            psnr = 20 * torch.log10(torch.tensor(1.0)) - 10 * torch.log10(mse)

            total_psnr += psnr.item() * real_images.size(0)
            count += real_images.size(0)

    return total_psnr / count

def save_samples(generator, cnn_model, val_loader, epoch, device):
    """Save sample images"""
    generator.eval()

    with torch.no_grad():
        batch = next(iter(val_loader))
        real_images = batch[:4].to(device)

        # Create masks and get CNN predictions
        masked_images, masks = create_mask(real_images, mask_percentage=0.025)
        masks = masks.to(device)

        cnn_input = torch.cat([masked_images, masks[:, 0:1]], dim=1)
        coarse_output = cnn_model(cnn_input)

        # Generate refined images
        refined_images = generator(coarse_output, masked_images, masks[:, 0:1])

        # Create comparison grid
        comparison = torch.cat([
            (real_images + 1) / 2,
            (masked_images + 1) / 2,
            (coarse_output + 1) / 2,
            (refined_images + 1) / 2
        ], dim=0)

        save_image(comparison, f'gan_samples_epoch_{epoch}.png', nrow=4)

# ============================================
# 6. INFERENCE FUNCTION
# ============================================

def inference_gan(generator, cnn_model, image, mask, device):
    """Run inference with the trained GAN"""
    generator.eval()
    cnn_model.eval()

    with torch.no_grad():
        # Get CNN prediction
        masked_image = image * mask
        cnn_input = torch.cat([masked_image, mask[:, 0:1]], dim=1)
        coarse_output = cnn_model(cnn_input)

        # Refine with GAN
        refined = generator(coarse_output, masked_image, mask[:, 0:1])

        # Ensure unmasked regions are preserved
        final_output = refined * (1 - mask) + image * mask

        return final_output, coarse_output

In [None]:
# ============================================
# USAGE
# ============================================

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Initialize models
generator = RefinementGenerator(in_channels=7, out_channels=3).to(device)
discriminator = MultiscaleDiscriminator(in_channels=3).to(device)

# Load your pre-trained CNN
cnn_model = UNetSR().to(device)  # Use the same initialization as training!

# Load and fix the state dict
cnn_state_dict = torch.load('/content/drive/MyDrive/ColabNotebooks/Vision/inpainting/final_cnn_model.pth', map_location=device)

# Remove _orig_mod. prefix if it exists
if any(k.startswith('_orig_mod.') for k in cnn_state_dict.keys()):
    cnn_state_dict = {k.replace('_orig_mod.', ''): v for k, v in cnn_state_dict.items()}

# Now load the fixed state dict
cnn_model.load_state_dict(cnn_state_dict)

# Set to eval and freeze
cnn_model.eval()
for param in cnn_model.parameters():
    param.requires_grad = False

print("✅ CNN loaded successfully")

# Train
history = train_gan(
    generator,
    discriminator,
    cnn_model,
    train_loader=gan_dataloader,
    val_loader=val_dataloader,
    num_epochs=80,
    device=device
)

# Save final models
torch.save(generator.state_dict(), 'gan_generator_final.pth')
torch.save(discriminator.state_dict(), 'gan_discriminator_final.pth')

✅ CNN loaded successfully


Downloading: "https://download.pytorch.org/models/vgg19-dcbb9e9d.pth" to /root/.cache/torch/hub/checkpoints/vgg19-dcbb9e9d.pth
100%|██████████| 548M/548M [00:02<00:00, 233MB/s]


Epoch 1/80:   0%|          | 0/31 [00:00<?, ?it/s]


Validation PSNR: 26.85 dB


Epoch 2/80:   0%|          | 0/31 [00:00<?, ?it/s]

Epoch 3/80:   0%|          | 0/31 [00:00<?, ?it/s]

Epoch 4/80:   0%|          | 0/31 [00:00<?, ?it/s]