In [2]:
import glob
import math
import os
import subprocess

import matplotlib.pyplot as plt
import numpy as np
import torch
from torch.utils.data import random_split
from torchvision import datasets, transforms

In [3]:

# Hyper-parameters
latent_size = 64
hidden_size = 256
image_size = 784
num_epochs = 20000
set_size = 60000
batch_size = 2
sample_dir = 'samples'
save_dir = 'save'

# Create the output directories if they do not exist yet
os.makedirs(sample_dir, exist_ok=True)
os.makedirs(save_dir, exist_ok=True)

# Image processing.
# MNIST images have a single channel, so Normalize must receive exactly one
# mean/std value. A 3-element mean/std (RGB style) raises a broadcast error as
# soon as a sample is fetched through the dataset or a DataLoader.
transform = transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize(mean=(0.5,), std=(0.5,))])

# Download MNIST dataset
mnist_dataset = datasets.MNIST(root='data',
                               train=True,
                               transform=transform,
                               download=True)

# Split original training set into 70% train and 30% validation
train_size = int(0.7 * len(mnist_dataset))
val_size = len(mnist_dataset) - train_size
train_dataset, val_dataset = random_split(mnist_dataset, [train_size, val_size])

# Select a random image index from the new training set
random_index = np.random.randint(len(train_dataset))

# NOTE: the loaders use a fixed batch size of 64, not `batch_size` above.
train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=64, shuffle=True)
val_dataloader = torch.utils.data.DataLoader(val_dataset, batch_size=64, shuffle=False)

# `.data` / `.targets` are read from the dataset underlying the split, so they
# cover the WHOLE downloaded set (train + validation part), not only the 70%
# subset, and they bypass `transform`.
mnist_images = mnist_dataset.data
mnist_labels = mnist_dataset.targets


In [4]:

# Hyper-parameters (same values as in the MNIST cell; repeated so that this
# cell can be executed on its own)
latent_size = 64
hidden_size = 256
image_size = 784
num_epochs = 20000
set_size = 60000
batch_size = 2
sample_dir = 'samples'
save_dir = 'save'

# Create the output directories if they do not exist yet
os.makedirs(sample_dir, exist_ok=True)
os.makedirs(save_dir, exist_ok=True)

# Image processing.
# Fashion-MNIST images have a single channel, so Normalize must receive
# exactly one mean/std value. A 3-element mean/std (RGB style) raises a
# broadcast error as soon as a sample is fetched through the dataset.
transform = transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize(mean=(0.5,), std=(0.5,))])

# Download Fashion-MNIST dataset.
# NOTE: this rebinds `mnist_dataset`, `train_dataset`, `val_dataset` and both
# data loaders; after this cell they refer to Fashion-MNIST.
mnist_dataset = datasets.FashionMNIST(root='data',
                               train=True,
                               transform=transform,
                               download=True)

# Split original training set into 70% train and 30% validation
train_size = int(0.7 * len(mnist_dataset))
val_size = len(mnist_dataset) - train_size
train_dataset, val_dataset = random_split(mnist_dataset, [train_size, val_size])

# Select a random image index from the new training set
random_index = np.random.randint(len(train_dataset))

# NOTE: the loaders use a fixed batch size of 64, not `batch_size` above.
train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=64, shuffle=True)
val_dataloader = torch.utils.data.DataLoader(val_dataset, batch_size=64, shuffle=False)

# `.data` / `.targets` are read from the dataset underlying the split, so they
# cover the WHOLE downloaded set (train + validation part), not only the 70%
# subset, and they bypass `transform`.
fashion_mnist_images = mnist_dataset.data
fashion_mnist_labels = mnist_dataset.targets

In [5]:
# Per-sample pipeline used when CIFAR-10 items are fetched through the dataset
transform = transforms.Compose([
    transforms.ToTensor(),                        # Convert to tensor
    transforms.Grayscale(num_output_channels=1),  # Convert RGB to grayscale
    transforms.Normalize(mean=[0.5], std=[0.5])   # Shift/scale with mean 0.5, std 0.5
])

# Full CIFAR-10 training set; kept under its own name so that `train_dataset`
# only ever refers to the 70% split below.
cifar_dataset = datasets.CIFAR10(root='data',
                                 train=True,
                                 transform=transform,
                                 download=True)

# Split original training set into 70% train and 30% validation
train_size = int(0.7 * len(cifar_dataset))
val_size = len(cifar_dataset) - train_size
train_dataset, val_dataset = random_split(cifar_dataset, [train_size, val_size])

train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=64, shuffle=True)
val_dataloader = torch.utils.data.DataLoader(val_dataset, batch_size=64, shuffle=False)

# Image array of the WHOLE dataset (not only the 70% split), bypassing
# `transform`: channels moved from the last axis to axis 1 and values divided
# by 255.
cifar_images = torch.tensor(cifar_dataset.data).permute(0, 3, 1, 2) / 255.0
cifar_labels = cifar_dataset.targets

# Grayscale version of every image. Grayscale accepts tensors of shape
# (..., 3, H, W), so the whole stack is converted in one call instead of one
# Python-level call per image.
to_grayscale = transforms.Grayscale(num_output_channels=1)
cifar_gray = to_grayscale(cifar_images)

In [None]:
def compute_cdf_mapping(images):
    """
    Compute the empirical CDF mapping for each pixel across all images.

    Args:
        images (torch.Tensor): Stack of images with the image index on axis 0,
            either already flattened (n, d) or not, e.g. (n, H, W).

    Returns:
        dict: Maps each flattened pixel index i to a tuple
            (sorted_values, cdf) where sorted_values holds the n values of
            pixel i in ascending order and cdf is linspace(0, 1, n). The same
            cdf tensor object is shared by all entries.
    """
    # One row per image, one column per pixel. Flattening from axis 1 keeps
    # (n, d) inputs unchanged and also handles (n, H, W) stacks, which a
    # view(-1, images.shape[1]) would silently mis-shape to (n*H, W).
    flat_images = images.flatten(start_dim=1)
    num_images, vector_size = flat_images.shape

    sorted_pixels, _ = torch.sort(flat_images, dim=0)
    cdf = torch.linspace(0, 1, num_images)

    return {i: (sorted_pixels[:, i], cdf) for i in range(vector_size)}

def transform_original_to_uniform(image, pixel_cdf_map, reverse=False):
    """
    Map a single square image through the per-pixel CDF tables.

    Args:
        image (torch.Tensor): Image whose flattened length is a perfect
            square; it is flattened with view(-1).
        pixel_cdf_map: Indexable by flattened pixel position; each entry is a
            (sorted_pixel_values, cdf_values) pair.
        reverse (bool): False looks up the CDF level of each pixel value,
            True looks up the pixel value belonging to each CDF level.

    Returns:
        torch.Tensor: float32 tensor of shape (side, side).
    """
    flat = image.view(-1)
    side = int(flat.shape[0] ** 0.5)
    result = torch.zeros_like(flat, dtype=torch.float32)

    for position in range(side * side):
        sorted_values, cdf_levels = pixel_cdf_map[position]

        # Forward: search among the sorted pixel values, read the CDF level.
        # Reverse: search among the CDF levels, read the pixel value.
        if reverse:
            search_keys, lookup = cdf_levels, sorted_values
        else:
            search_keys, lookup = sorted_values, cdf_levels

        slot = torch.searchsorted(search_keys, flat[position])
        # A query beyond the last key yields len(keys); fall back to the last entry.
        result[position] = lookup[min(slot, len(lookup) - 1)]

    return result.view(side, side)

def compute_mean_and_covariance(images, epsilon=1e-4):
    """
    Compute the per-pixel mean and the regularised sample covariance of a
    stack of images.

    Args:
        images (torch.Tensor): Stack with the image index on axis 0; all
            remaining axes are flattened into one feature axis of length d.
        epsilon (float): Value added to the diagonal of the covariance so
            that it is positive definite.

    Returns:
        tuple: (covariance_matrix, mean_vector) -- note the order --
            with shapes (d, d) and (1, d), both float32.
    """
    images = images.float()                 # Ensure floating-point type
    flat_images = images.flatten(start_dim=1)
    num_images = flat_images.shape[0]

    mean_vector = torch.mean(flat_images, dim=0, keepdim=True)
    centered_images = flat_images - mean_vector

    # Unbiased estimator (n - 1); the max() guard avoids a 0/0 = NaN matrix
    # when only a single image is supplied.
    covariance_matrix = torch.matmul(centered_images.T, centered_images) / max(num_images - 1, 1)

    # Regularisation: add epsilon to the diagonal. The identity is created on
    # the covariance's own device/dtype so GPU inputs do not hit a device
    # mismatch.
    covariance_matrix += epsilon * torch.eye(
        covariance_matrix.shape[0],
        dtype=covariance_matrix.dtype,
        device=covariance_matrix.device,
    )

    return covariance_matrix, mean_vector

def generate_random_uniform_images_from_gaussian(covariance_matrix_gaussian, num_samples = 1):
    """
    Draw zero-mean multivariate normal vectors and push every component
    through the standard normal CDF.

    Args:
        covariance_matrix_gaussian (torch.Tensor): (d, d) covariance handed to
            MultivariateNormal; d must be a perfect square.
        num_samples (int): Number of vectors to draw.

    Returns:
        torch.Tensor: Values in [0, 1], shape (num_samples, sqrt(d), sqrt(d)).
    """
    dim = covariance_matrix_gaussian.size(0)
    side = int(dim**0.5)

    sampler = torch.distributions.MultivariateNormal(
        loc=torch.zeros(dim, device=covariance_matrix_gaussian.device),
        covariance_matrix=covariance_matrix_gaussian
    )
    draws = sampler.rsample((num_samples,))

    # Standard normal CDF: Phi(z) = (1 + erf(z / sqrt(2))) / 2
    cdf_values = (torch.erf(draws / np.sqrt(2)) + 1) * 0.5
    return cdf_values.reshape(num_samples, side, side)


def generate_random_uniform_images(covariance_matrix, num_samples = 1):
    """
    Sample from a zero-mean multivariate normal with the given covariance and
    map each component to [0, 1] with the standard normal CDF.

    Args:
        covariance_matrix (torch.Tensor): Covariance matrix of shape (d, d),
            where d is a perfect square.
        num_samples (int): Number of samples to generate.

    Returns:
        torch.Tensor: Tensor of shape (num_samples, sqrt(d), sqrt(d)) holding
            the CDF-transformed samples.
    """
    n_features = covariance_matrix.size(0)
    edge = int(n_features**0.5)

    zero_mean = torch.zeros(n_features, device=covariance_matrix.device)
    normal = torch.distributions.MultivariateNormal(
        loc=zero_mean,
        covariance_matrix=covariance_matrix
    )

    z = normal.rsample((num_samples,))
    # Gaussian -> [0, 1] through Phi(z) = 0.5 * (1 + erf(z / sqrt(2)))
    scaled = z / np.sqrt(2)
    u = 0.5 * (torch.erf(scaled) + 1)
    return u.reshape(num_samples, edge, edge)


# def transform_dataset_to_uniform_vectorized(dataset):
#     """
#     Transform an entire dataset of images (tensor of shape (n, p, p)) so that for each pixel location,
#     the empirical distribution of pixel values becomes uniform in [0,1].
    
#     The transformation is performed using a double argsort to obtain the rank of each pixel value.
    
#     Args:
#         dataset (torch.Tensor): Tensor of shape (n, p, p)
        
#     Returns:
#         torch.Tensor: Transformed dataset with shape (n, p, p) where each pixel's value is in [0,1].
#     """
#     n, p, _ = dataset.shape
#     # Flatten images: shape (n, p*p)
#     flat_data = dataset.view(n, -1)
    
#     # Compute the rank for each pixel across the dataset.
#     # First argsort sorts the values; the second argsort recovers the rank.
#     ranks = flat_data.argsort(dim=0).argsort(dim=0).float()
    
#     # Normalize ranks to [0,1]
#     uniform_flat = ranks / (n - 1)
    
#     # Reshape back to (n, p, p)
#     return uniform_flat.view(n, p, p)
# import torch

def transform_dataset_to_uniform_and_gaussian_vectorized(dataset):
    """
    Rank-transform a stack of images pixel by pixel to uniform scores in
    [0, 1], then convert those scores to Gaussian scores with
    z = sqrt(2) * erfinv(2u - 1).

    Ranks come from a double argsort along the image axis, so equal pixel
    values receive distinct ranks in whatever order argsort places them.

    Args:
        dataset (torch.Tensor): Stack with the image index on axis 0,
            e.g. (n, p, p).

    Returns:
        uniform_dataset (torch.Tensor): Same shape as `dataset`, values
            rank / (n - 1) in [0, 1] (0 and 1 are attained).
        gaussian_dataset (torch.Tensor): Same shape, finite Gaussian scores.
            The two extreme ranks are evaluated at 0.5/n and 1 - 0.5/n
            instead of 0 and 1, because erfinv(+-1) is +-inf.
    """
    n = dataset.shape[0]
    # Flatten images: shape (n, number of pixels). flatten() also copes with
    # non-contiguous inputs, which view() rejects.
    flat_data = dataset.flatten(start_dim=1)

    # First argsort sorts the values; the second argsort recovers the rank.
    ranks = flat_data.argsort(dim=0).argsort(dim=0).float()

    # Normalise ranks to [0, 1]; the max() guard avoids 0/0 for a single image.
    uniform_flat = ranks / max(n - 1, 1)
    uniform_dataset = uniform_flat.reshape(dataset.shape)

    # Keep the scores strictly inside (0, 1) before inverting the normal CDF.
    # Interior scores k/(n-1), 1 <= k <= n-2, already lie inside the clamp
    # range and are therefore left untouched.
    half_step = 0.5 / max(n, 1)
    interior_scores = uniform_dataset.clamp(min=half_step, max=1.0 - half_step)

    # z = sqrt(2) * erfinv(2u - 1)
    gaussian_dataset = torch.erfinv(2 * interior_scores - 1) * torch.sqrt(torch.tensor(2.0))

    return uniform_dataset, gaussian_dataset


# def generate_uniform_iman_conover(covariance_matrix, n_samples=10000):
#     """
#     Generate a sample from a d-dimensional distribution with Uniform(0,1) marginals 
#     whose covariance is approximately the given covariance_matrix.
    
#     The method uses the Iman–Conover procedure.
    
#     Args:
#         covariance_matrix (np.ndarray): A d x d target covariance matrix.
#             (For Uniform[0,1], the variance is 1/12, so the diagonal of covariance_matrix
#             should be about 1/12.)
#         n_samples (int): Number of samples to generate for the Iman–Conover adjustment.
        
#     Returns:
#         sample (np.ndarray): A d x 1 column vector drawn from the adjusted Uniform(0,1) distribution.
#     """
#     # Dimension (d) inferred from the covariance matrix
#     d = covariance_matrix.shape[0]
    
#     # For Uniform[0,1], variance = 1/12. So the target correlation matrix is:
#     R_target = covariance_matrix * 12.0
    
#     # Step 1: Generate independent Uniform(0,1) samples: shape (n_samples, d)
#     U = np.random.uniform(0, 1, size=(n_samples, d))
    
#     # Step 2: Standardize each column of U
#     U_std = (U - U.mean(axis=0)) / U.std(axis=0, ddof=1)
    
#     # Compute the empirical correlation matrix of U_std
#     R_empirical = np.corrcoef(U_std, rowvar=False)
    
#     # Step 3: Compute Cholesky factors for the empirical and target correlation matrices
#     L_empirical = np.linalg.cholesky(R_empirical)
#     L_target = np.linalg.cholesky(R_target)
    
#     # Step 4: Compute the adjustment matrix
#     A = L_target @ np.linalg.inv(L_empirical)
    
#     # Adjust the standardized samples
#     Z = U_std @ A.T
    
#     # Step 5: For each variable (column), reassign values based on the ranks of Z,
#     # so that the marginals remain Uniform(0,1) but the correlation structure is adjusted.
#     U_adjusted = np.empty_like(U)
#     for j in range(d):
#         order = np.argsort(Z[:, j])
#         sorted_vals = np.sort(U[:, j])
#         U_adjusted[order, j] = sorted_vals
    
#     # Step 6: Choose one sample (e.g., the first row) and return it as a column vector.
#     sample = U_adjusted[0, :].reshape(-1, 1)
#     uniform_random_image = torch.tensor(sample).view(-1)
#     return uniform_random_image

# def generate_uniform_from_gaussian(covariance_matrix_uniform, epsilon=1e-4):
#     """
#     Generates a column vector of uniform samples with the specified covariance matrix,
#     using a multivariate Gaussian transformation. Handles PyTorch tensors and allows
#     diagonal entries close to (but not exactly) 1/12.
#     """
#     d = covariance_matrix_uniform.size(0)
    
#     # Compute standard deviations for each uniform variable
#     sigma = torch.sqrt(torch.diag(covariance_matrix_uniform))  # Shape: (d,)
    
#     # Compute correlation matrix for the uniforms
#     outer_sigma = torch.outer(sigma, sigma)
#     R_uniform = covariance_matrix_uniform / outer_sigma  # Shape: (d, d)
    
#     # Compute Gaussian correlation matrix using adjusted formula
#     R_gaussian = 2 * torch.sin((math.pi / 6) * R_uniform)
#     R_gaussian.fill_diagonal_(1.0)  # Ensure diagonal is exactly 1
    
#     # Cholesky decomposition (requires positive definite matrix)
#     try:
#         L = torch.linalg.cholesky(R_gaussian + epsilon*torch.eye(d))
#     except RuntimeError as e:
#         raise ValueError("Invalid covariance: Resulting Gaussian correlation is not positive definite.") from e
    
#     # Generate multivariate Gaussian sample
#     z = torch.randn(d)  # Standard normal sample
#     gaussian_sample = L @ z  # Shape: (d,)
    
#     # Transform Gaussian to uniform using CDF
#     uniform_sample = 0.5 * (1 + torch.erf(gaussian_sample / math.sqrt(2)))  # Shape: (d,)
    
#     # Scale to match desired covariance (adjust variances and covariances)
#     scale_factor = torch.sqrt(12 * torch.diag(covariance_matrix_uniform))
#     scaled_uniform = uniform_sample * scale_factor + 0.5 * (1 - scale_factor)
    
#     return scaled_uniform.reshape(-1, 1)  # Return as column vector
def uniform_to_gaussian_covariance(sigma_u, epsilon = 1e-4):
    """
    Convert the covariance matrix of a multivariate uniform (0,1) distribution
    to the covariance matrix of the multivariate normal distribution obtained
    after applying the inverse Gaussian CDF transformation.

    Each correlation r of the uniforms is mapped to 2 * sin(pi * r / 6).

    Args:
        sigma_u (torch.Tensor): Covariance matrix of the uniform distribution (d x d).
            Every diagonal entry must be > 0, otherwise the normalisation
            divides by zero.
        epsilon (float): Ridge added to the diagonal of the result. The
            element-wise sine map does not guarantee a positive-definite
            matrix, so callers that need one may have to raise this value.

    Returns:
        sigma_n (torch.Tensor): Covariance matrix of the corresponding Gaussian
            distribution (d x d): unit diagonal plus epsilon.
    """
    # Compute the correlation matrix from sigma_u
    std_u = torch.sqrt(torch.diag(sigma_u))
    corr_u = sigma_u / (std_u[:, None] * std_u[None, :])

    # Apply the Gaussian copula transformation
    corr_n = 2 * torch.sin((torch.pi / 6) * corr_u)
    # 2 * sin(pi / 6) is 1 only up to rounding; make the diagonal exact
    corr_n.fill_diagonal_(1.0)

    # Standard normal marginals have variance 1, so the correlation matrix is
    # already the covariance. Only the epsilon ridge is added -- not a full
    # identity, which would double the variances.
    ridge = epsilon * torch.eye(sigma_u.shape[0], dtype=corr_n.dtype, device=corr_n.device)
    sigma_n = corr_n + ridge

    return sigma_n
def compute_gaussian_covariance_from_uniform(Sigma_u, epsilon = 1e-4):
    """
    Compute the Gaussian covariance matrix from the uniform covariance matrix.

    Each Pearson correlation r of the uniforms is mapped to
    2 * sin(pi * r / 6); the diagonal is set to exactly 1 and `epsilon` is
    added to it as a ridge.

    Args:
        Sigma_u (torch.Tensor): Covariance matrix of the uniforms, shape (d, d).
            Every diagonal entry must be > 0, otherwise the normalisation
            divides by zero. The input is not modified.
        epsilon (float): Ridge added to the diagonal of the result. The
            element-wise sine map does not guarantee a positive-definite
            matrix, so callers that need one may have to raise this value.

    Returns:
        torch.Tensor: Gaussian covariance matrix, shape (d, d), with diagonal
            1 + epsilon.
    """
    # Compute Pearson correlation matrix of the uniforms
    diag_var_u = torch.diag(Sigma_u)  # Variances of uniforms
    std_u = torch.sqrt(diag_var_u)    # Standard deviations of uniforms
    outer_std = torch.outer(std_u, std_u)
    R_u = Sigma_u / outer_std         # Pearson correlation matrix of uniforms

    # Compute Gaussian correlation matrix
    R_n = 2 * torch.sin((math.pi / 6) * R_u)

    # Ensure diagonal is exactly 1 (due to numerical precision)
    R_n.fill_diagonal_(1.0)

    # Add only the requested ridge; adding a full identity here would turn the
    # unit variances into 2.
    return R_n + epsilon * torch.eye(R_u.shape[0], dtype=R_n.dtype, device=R_n.device)


def _select_class_band_images(set_name, label_idx):
    """Return ([one (n, p, p) image stack per colour band], p) for one class, or None for an unknown set_name."""
    if set_name == 'mnist':
        return [mnist_images[mnist_labels == label_idx]], mnist_images.shape[2]
    if set_name == 'fashion-mnist':
        # Both the images and the class mask come from Fashion-MNIST.
        class_images = fashion_mnist_images[fashion_mnist_labels == label_idx]
        return [class_images], fashion_mnist_images.shape[2]
    if set_name == 'cifar-gray':
        class_mask = torch.tensor(cifar_labels) == label_idx
        # cifar_gray[class_mask, :, :] keeps the length-1 channel axis on
        # axis 1; squeeze(dim=1) drops it.
        return [cifar_gray[class_mask, :, :].squeeze(dim=1)], cifar_images.shape[2]
    if set_name == 'cifar-rgb':
        class_mask = torch.tensor(cifar_labels) == label_idx
        # One stack per colour band, so each band keeps its own data.
        return [cifar_images[class_mask, band_idx, :, :] for band_idx in range(3)], cifar_images.shape[2]
    return None


def generate_uniform_random_cifar_rgb_per_class(label_idx, num_samples = 5000, set_name = 'cifar-rgb', ifPlot = False):
    """
    Build rank-uniform "real" images and randomly sampled "fake" uniform
    images for one class of a dataset, band by band, and optionally save a
    grid of synthetic images mapped back to pixel values.

    Reads the module-level tensors mnist_images / mnist_labels,
    fashion_mnist_images / fashion_mnist_labels, cifar_images, cifar_gray and
    cifar_labels, depending on `set_name`.

    Args:
        label_idx (int): Class label to select.
        num_samples (int): Number of real images used and of fake images drawn.
        set_name (str): 'mnist', 'fashion-mnist', 'cifar-gray' or 'cifar-rgb'.
        ifPlot (bool): If True, write 'fake-<set_name>-class-<label_idx>.pdf'
            with int(sqrt(num_samples))**2 synthetic images.

    Returns:
        tuple: (uniform_real_images, uniform_fake_images), both of shape
            (num_samples, num_bands, p, p); ([], []) for an unknown set_name.

    Raises:
        ValueError: If the class holds fewer than num_samples images.
    """
    selection = _select_class_band_images(set_name, label_idx)
    if selection is None:
        return [], []
    band_images, pixel_size = selection
    num_bands = len(band_images)

    num_available = band_images[0].shape[0]
    if num_available < num_samples:
        raise ValueError(
            f"class {label_idx} of '{set_name}' has only {num_available} images, "
            f"but num_samples={num_samples} were requested"
        )

    uniform_real_images = torch.zeros(num_samples, num_bands, pixel_size, pixel_size)
    uniform_fake_images = torch.zeros(num_samples, num_bands, pixel_size, pixel_size)
    pixel_cdf_map = [None] * num_bands
    for band_idx, train_images in enumerate(band_images):
        # The CDF tables use every image of the class; the uniform scores and
        # the covariance use the first num_samples images only.
        pixel_cdf_map[band_idx] = compute_cdf_mapping(train_images.reshape(-1, pixel_size**2))
        uniform_real_images[:, band_idx, :, :], _ = transform_dataset_to_uniform_and_gaussian_vectorized(train_images[0:num_samples])
        covariance_matrix_uniform, _ = compute_mean_and_covariance(uniform_real_images[:, band_idx, :, :], epsilon = 0.000001)

        # NOTE(review): the sampler is deliberately still fed the covariance
        # of the *uniform* scores, although its parameter is named
        # covariance_matrix_gaussian. A Gaussian-copula correlation (see
        # compute_gaussian_covariance_from_uniform) is presumably intended,
        # but MultivariateNormal needs a positive-definite matrix -- confirm
        # the regularisation before switching.
        #
        # num_samples must be passed explicitly: with the default of 1 a
        # single draw is broadcast into every fake image.
        uniform_fake_images[:, band_idx, :, :] = generate_random_uniform_images_from_gaussian(
            covariance_matrix_uniform, num_samples = num_samples
        )

    if ifPlot and num_samples > 0:
        number_of_rows_or_columns = int(num_samples**0.5)
        # squeeze=False keeps axs two-dimensional even for a 1x1 grid.
        fig, axs = plt.subplots(
            number_of_rows_or_columns,
            number_of_rows_or_columns,
            figsize=(number_of_rows_or_columns, number_of_rows_or_columns),
            squeeze=False,
        )
        for idx in range(number_of_rows_or_columns**2):
            idx_i, idx_j = divmod(idx, number_of_rows_or_columns)
            print(f"Class: {label_idx}, image {idx+1}/{number_of_rows_or_columns**2}" , end="\r")

            # Fresh buffer per panel so no two panels can share storage.
            synthetic_image = torch.zeros(num_bands, pixel_size, pixel_size)
            for band_idx in range(num_bands):
                # Map the uniform scores back to pixel values of this band.
                synthetic_image[band_idx, :, :] = transform_original_to_uniform(
                    uniform_fake_images[idx, band_idx, :, :], pixel_cdf_map[band_idx], reverse=True
                )

            ax = axs[idx_i, idx_j]
            if num_bands == 3:
                ax.imshow(synthetic_image.permute(1, 2, 0))
            else:
                ax.imshow(synthetic_image[0], cmap='gray')
            ax.set_xticks([])  # Hide x-ticks
            ax.set_yticks([])  # Hide y-ticks
        fig.suptitle('Fake ' + set_name + ' images')
        fig.savefig('fake-' + set_name + '-class-'+ str(label_idx) +'.pdf')
        plt.close(fig)  # Close figure to free memory
    return uniform_real_images, uniform_fake_images

num_samples = 100
num_classes = 2
set_name = 'cifar-rgb'   # other supported values: 'mnist', 'fashion-mnist', 'cifar-gray'
# Plot only for small runs; the grid has int(sqrt(num_samples))**2 panels.
ifPlot = num_samples <= 400

real_batches = []
fake_batches = []
for label_idx in range(num_classes):
    print(f"Class: {label_idx}" , end="\r")
    class_real, class_fake = generate_uniform_random_cifar_rgb_per_class(
        label_idx, num_samples = num_samples, set_name = set_name, ifPlot = ifPlot
    )
    real_batches.append(class_real)
    fake_batches.append(class_fake)

# Classes stacked along axis 0: rows [k*num_samples, (k+1)*num_samples) hold class k.
uniform_cifar_real_images = torch.cat(real_batches)
uniform_cifar_fake_images = torch.cat(fake_batches)

if ifPlot:
    # Merge the per-class PDFs into one file. The per-class files are removed
    # only after pdftk reported success, so a missing or failing pdftk does
    # not destroy the figures.
    class_pdfs = sorted(glob.glob(f"fake-{set_name}-class-*.pdf"))
    merged_pdf = f"fake-{set_name}.pdf"
    merged_ok = False
    if class_pdfs:
        try:
            merged_ok = subprocess.run(["pdftk", *class_pdfs, "output", merged_pdf]).returncode == 0
        except FileNotFoundError:
            merged_ok = False  # pdftk is not installed
    if merged_ok:
        for class_pdf in class_pdfs:
            os.remove(class_pdf)
    else:
        print(f"Could not merge the per-class PDFs into {merged_pdf}; keeping {len(class_pdfs)} per-class file(s).")

Class: 0