In [57]:
import sys
sys.path.append('../src')
import main
import numpy as np
from PIL import Image
import config
import utils.ckpt as ckpt
import utils.misc as misc
import models.model as model
import torch
import os
import matplotlib.pyplot as plt
import matplotlib as mpl

# Use a larger default font so figure text stays legible when the plots are scaled down.
plt.rcParams['font.size'] = 18

In [67]:
def compute_batch_image_gradients(generator, discriminator, z_dim, num_classes, batch_size=64, device='cuda'):
    """
    Compute input-space gradients of each discriminator for one batch of fake images.

    A batch of latent vectors and class labels is sampled, pushed through the
    generator without tracking gradients, and the resulting images are then
    treated as leaf tensors. For every discriminator the batch mean of its
    ``"adv_output"`` is differentiated with respect to those images.

    Note that the differentiated scalar is the *mean* over the batch, so the
    gradient stored for each sample carries a ``1 / batch_size`` factor.

    Args:
        generator: The generator model. If it exposes a ``module`` attribute
            (e.g. a parallel wrapper), the wrapped module is called instead.
        discriminator: The discriminator model. If it exposes a
            ``discriminators`` attribute, each member is evaluated separately;
            otherwise it is treated as a single discriminator.
        z_dim: Dimension of the latent space.
        num_classes: Number of classes to sample labels from.
        batch_size: Number of fake images to generate.
        device: Device on which latents and labels are created ('cuda' or 'cpu').

    Returns:
        dict: with keys
            ``'images'``    - generated images as a detached CPU tensor,
            ``'labels'``    - sampled class labels as a CPU tensor,
            ``'D_outs'``    - list of floats, batch-mean adversarial output per discriminator,
            ``'gradients'`` - list of numpy arrays (one per discriminator) holding
                              d(mean adv_output)/d(images), same shape as the image batch.

    Side effects:
        Both models are switched to eval mode. Parameter ``.grad`` buffers of
        the models are left untouched.
    """
    # Set models to eval mode
    generator.eval()
    discriminator.eval()

    # Sample latents first, then labels (order matters for seeded reproducibility)
    z = torch.randn(batch_size, z_dim, device=device)
    fake_labels = torch.randint(0, num_classes, (batch_size,), device=device)

    # Generate the images; no graph is needed through the generator
    gen_module = generator.module if hasattr(generator, 'module') else generator
    with torch.no_grad():
        fake_images = gen_module(z, fake_labels)

    # The images become the leaf tensor we differentiate with respect to
    fake_images.requires_grad_(True)

    # An ensemble exposes its members via `.discriminators`; otherwise wrap the single model
    if hasattr(discriminator, 'discriminators'):
        discriminators = discriminator.discriminators
    else:
        discriminators = [discriminator]

    D_outs = []
    fake_grads = []
    for discr in discriminators:
        D_out_mean = discr(fake_images, fake_labels)["adv_output"].mean()
        D_outs.append(D_out_mean.item())

        # torch.autograd.grad returns the gradient w.r.t. the images only. Unlike
        # .backward(), it does not accumulate into the discriminator parameters'
        # .grad, and each discriminator's graph is freed right after use.
        (image_grad,) = torch.autograd.grad(D_out_mean, fake_images)
        fake_grads.append(image_grad.detach().cpu().numpy())

    return {
        'images': fake_images.detach().cpu(),
        'labels': fake_labels.detach().cpu(),
        'D_outs': D_outs,
        'gradients': fake_grads
    }


def analyze_gradient_statistics(results):
    """
    Summarize absolute-gradient statistics per discriminator and compare discriminators.

    Args:
        results: Dictionary from compute_batch_image_gradients. Only
            ``results['gradients']`` is read: a list with one numpy array per
            discriminator, indexed as ``[sample, channel, ...]``.

    Returns:
        dict: Flat mapping of statistic name to value:
            ``disc_{i}_mean_abs_grad`` / ``_std_abs_grad`` / ``_max_abs_grad``
                - mean, std and max of |gradient| over the whole array,
            ``disc_{i}_channel_{c}_mean_abs_grad``
                - mean |gradient| of channel ``c``, for every channel present,
            ``disc_{i}_{j}_correlation`` (only for i < j)
                - Pearson correlation, across samples, of the per-sample
                  mean |gradient| of discriminators ``i`` and ``j``.
    """
    # Take the absolute value once per discriminator; every statistic below uses it
    abs_grads = [np.abs(grad) for grad in results['gradients']]
    stats = {}

    # Overall and per-channel statistics for each discriminator
    for i, abs_grad in enumerate(abs_grads):
        stats[f'disc_{i}_mean_abs_grad'] = abs_grad.mean()
        stats[f'disc_{i}_std_abs_grad'] = abs_grad.std()
        stats[f'disc_{i}_max_abs_grad'] = abs_grad.max()

        # Iterate over the channels actually present instead of assuming RGB
        for c in range(abs_grad.shape[1]):
            stats[f'disc_{i}_channel_{c}_mean_abs_grad'] = abs_grad[:, c].mean()

    # One scalar per sample (mean |gradient| over all non-batch axes), computed once
    per_sample_means = [a.reshape(a.shape[0], -1).mean(axis=1) for a in abs_grads]

    # Pairwise correlation of those per-sample magnitudes between discriminators
    num_discs = len(per_sample_means)
    for i in range(num_discs):
        for j in range(i + 1, num_discs):
            corr = np.corrcoef(per_sample_means[i], per_sample_means[j])[0, 1]
            stats[f'disc_{i}_{j}_correlation'] = corr

    return stats


def visualize_batch_gradients(results, sample_idx=0, channel_idx=0, save_path='./batch_gradients'):
    """
    Plot one generated image next to each discriminator's gradient map and save it as a PDF.

    The figure has ``2 + N`` panels for ``N`` discriminators: the RGB image,
    the image with only the selected channel kept, and one |gradient| heat map
    per discriminator for that channel. Each heat map uses its own color scale
    from 0 to that map's maximum, so colors are not comparable across panels;
    read the colorbars for absolute magnitudes.

    Args:
        results: Dictionary from compute_batch_image_gradients
            (``'images'`` and ``'gradients'`` are read).
        sample_idx: Index of the sample image to visualize.
        channel_idx: Index of the color channel to visualize (0=Red, 1=Green, 2=Blue).
        save_path: Directory to save the figure in; created if missing. The file is
            named ``sample_{sample_idx}_{channel}_channel_gradients.pdf``.

    Raises:
        ValueError: If ``sample_idx`` is beyond the batch or ``channel_idx`` is
            not 0, 1 or 2. Raised before anything is written to disk.
    """
    images = results['images']
    gradients = results['gradients']

    # Validate arguments first so a bad call leaves no directory or figure behind
    if sample_idx >= len(images):
        raise ValueError(f"Sample index {sample_idx} out of range (batch size: {len(images)})")
    if channel_idx not in [0, 1, 2]:
        raise ValueError(f"Channel index must be 0 (Red), 1 (Green), or 2 (Blue), got {channel_idx}")

    os.makedirs(save_path, exist_ok=True)

    channel_names = ['Red', 'Green', 'Blue']
    channel_cmaps = ['Reds', 'Greens', 'Blues']  # Channel-specific colormaps

    # Channels-first -> channels-last, then map to uint8.
    # Assumes generator outputs lie in [-1, 1]; anything outside is clipped.
    img = images[sample_idx].numpy().transpose(1, 2, 0)
    img = np.clip((img + 1) / 2.0 * 255, 0, 255).astype(np.uint8)

    num_discs = len(gradients)
    fig, axes = plt.subplots(1, 2 + num_discs, figsize=(5 * (2 + num_discs), 5))
    try:
        # 1. The full RGB image
        axes[0].imshow(img)
        axes[0].axis('off')

        # 2. The same image with every channel except the selected one zeroed
        channel_img = np.zeros_like(img)
        channel_img[:, :, channel_idx] = img[:, :, channel_idx]
        axes[1].imshow(channel_img)
        axes[1].axis('off')

        # 3. |gradient| of the selected channel, one panel per discriminator
        cmap = channel_cmaps[channel_idx]
        for disc_idx, ax in enumerate(axes[2:]):
            abs_grad = np.abs(gradients[disc_idx][sample_idx][channel_idx])

            # Scale runs from 0 (lightest) to this map's own maximum
            im = ax.imshow(abs_grad, cmap=cmap, vmin=0, vmax=abs_grad.max())
            ax.set_title(f"Discriminator {disc_idx}")
            fig.colorbar(im, ax=ax, shrink=0.8)

        fig.tight_layout()
        fig.savefig(
            os.path.join(save_path, f'sample_{sample_idx}_{channel_names[channel_idx]}_channel_gradients.pdf'),
            dpi=150,
        )
    finally:
        # Always release this specific figure, even if plotting or saving failed
        plt.close(fig)

In [64]:
# Resolve the compute device once. `device.type` is the plain 'cuda' / 'cpu'
# string that the model and checkpoint loaders below receive.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Experiment config file and the checkpoint directory of the trained run
config_path = '../src/configs/CIFAR10-Ensemble/DCGAN-ens-3-ew3.yaml'
ckpt_dir = '../logs/CIFAR10-Ensemble-trash/checkpoints/CIFAR10-DCGAN-ens-3-ew3-train-2024_06_06_15_10_31/'

cfgs = config.Configurations(config_path)

# Run-time options layered on top of the YAML config (stored under the "RUN" section)
run_configs = dict(
    ckpt_dir=ckpt_dir,
    mixed_precision=False,  # should match the training settings
    train=False,  # inference only: we just load a checkpoint
    distributed_data_parallel=False,
    seed=42,
    cfg_file=config_path,
    freezeD=-1,
    langevin_sampling=False,
)
cfgs.update_cfgs(run_configs, super="RUN")

# Instantiate generator / discriminator (and the EMA generator, if configured)
(
    Gen,
    Gen_mapping,
    Gen_synthesis,
    Dis,
    Gen_ema,
    Gen_ema_mapping,
    Gen_ema_synthesis,
    ema,
) = model.load_generator_discriminator(
    DATA=cfgs.DATA,
    OPTIMIZATION=cfgs.OPTIMIZATION,
    MODEL=cfgs.MODEL,
    STYLEGAN=cfgs.STYLEGAN,
    MODULES=cfgs.MODULES,
    RUN=cfgs.RUN,
    device=device.type,
    logger=None,
)

# The checkpoint loader is handed optimizer objects, so they must exist first
cfgs.define_optimizer(Gen, Dis)

# Restore weights from the best checkpoint found in `ckpt_dir`
(
    run_name,
    step,
    epoch,
    topk,
    aa_p,
    best_step,
    best_fid,
    best_ckpt_path,
    lecam_emas,
    logger,
) = ckpt.load_StudioGAN_ckpts(
    ckpt_dir=ckpt_dir,
    load_best=True,
    Gen=Gen,
    Dis=Dis,
    g_optimizer=cfgs.OPTIMIZATION.g_optimizer,
    d_optimizer=cfgs.OPTIMIZATION.d_optimizer,
    run_name="test",
    apply_g_ema=cfgs.MODEL.apply_g_ema,
    Gen_ema=Gen_ema,
    ema=ema,
    is_train=False,
    RUN=cfgs.RUN,
    logger=None,
    global_rank=0,
    device=device.type,
    cfg_file=cfgs.RUN.cfg_file,
)

print(f"Loaded checkpoint from step {step}")

# Prefer the EMA generator when the config enables it and one was created
if cfgs.MODEL.apply_g_ema and Gen_ema is not None:
    generator = Gen_ema
else:
    generator = Gen
generator.eval()

Loaded checkpoint from step 74000


In [65]:
# Seed torch before sampling so the latent batch (and therefore which image a
# given sample index refers to in later cells) is identical on every re-run.
# The seed value is the one already declared in `run_configs`.
torch.manual_seed(run_configs['seed'])

# Run the batch gradient computation
batch_results = compute_batch_image_gradients(
    generator=generator,
    discriminator=Dis,
    z_dim=cfgs.MODEL.z_dim,
    num_classes=cfgs.DATA.num_classes,
    batch_size=64,  # Match the batch size in the metrics code
    device=device
)

# Analyze gradient statistics and print one "name: value" line per statistic
stats = analyze_gradient_statistics(batch_results)
for key, value in stats.items():
    print(f"{key}: {value:.6f}")


disc_0_mean_abs_grad: 0.001772
disc_0_std_abs_grad: 0.001734
disc_0_max_abs_grad: 0.025737
disc_0_channel_0_mean_abs_grad: 0.001957
disc_0_channel_1_mean_abs_grad: 0.001722
disc_0_channel_2_mean_abs_grad: 0.001639
disc_1_mean_abs_grad: 0.018019
disc_1_std_abs_grad: 0.029191
disc_1_max_abs_grad: 0.563287
disc_1_channel_0_mean_abs_grad: 0.025705
disc_1_channel_1_mean_abs_grad: 0.005081
disc_1_channel_2_mean_abs_grad: 0.023272
disc_2_mean_abs_grad: 0.002806
disc_2_std_abs_grad: 0.002696
disc_2_max_abs_grad: 0.045883
disc_2_channel_0_mean_abs_grad: 0.002962
disc_2_channel_1_mean_abs_grad: 0.002693
disc_2_channel_2_mean_abs_grad: 0.002762
disc_0_1_correlation: 0.105877
disc_0_2_correlation: -0.063537
disc_1_2_correlation: -0.099517


In [80]:
# Render the gradient maps for a single sample of the batch (index 13, red channel);
# the figure is written as a PDF into ./03_gradient_map rather than shown inline.
visualize_batch_gradients(
    batch_results, 
    sample_idx=13, 
    channel_idx=0,  # 0=Red, 1=Green, 2=Blue
    save_path='./03_gradient_map')