In [1]:
import numpy as np
import h5py
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import matplotlib.pyplot as plt

# --------------------------
# Parameter Settings
# --------------------------
# Geometry of the image samples fed to the network.
img_size = 125
channels = 3

# Linear noise schedule: T variances evenly spaced over [beta_start, beta_end].
# A smaller T lowers the training cost.
T = 50
beta_start = 1e-4
beta_end = 0.02
betas = np.linspace(start=beta_start, stop=beta_end, num=T, dtype=np.float32)

# Per-step signal retention and its running product over steps 0..t.
alphas = 1.0 - betas
alpha_bars = np.cumprod(alphas, axis=0)

# --------------------------
# Use Local Data: Only take the first 5000 samples
# --------------------------
# Work on the first `subset_size` samples only: the first `train_size` are
# used for training and the remainder for validation.
# NOTE: subset_size must be larger than train_size, otherwise
# np.arange(train_size, subset_size) is empty and there is no validation data.
subset_size = 5000
train_size = 4000

# Open the file once just to report how many samples are available.
with h5py.File('quark-gluon_data-set_n139306.hdf5', 'r') as f:
    total_samples = f['X_jets'].shape[0]
    print("Total number of samples:", total_samples)

# Define local dataset indices: first 4000 for training, next 1000 for validation
train_subset = np.arange(0, train_size)
val_subset = np.arange(train_size, subset_size)

# --------------------------
# Define Diffusion Data Generator
# --------------------------
# For each image sample, randomly choose a diffusion step t,
# generate a noisy image using the formula, and return (noisy_image, t) as input;
# the target is the added noise
def diffusion_data_generator(indices):
    """Yield forward-diffused training examples for the given sample indices.

    For every index the sample is read from the 'X_jets' dataset, cast to
    float32 and divided by its own maximum (plus 1e-8 to avoid dividing by
    zero). A diffusion step ``t`` is drawn uniformly from ``[0, T)`` and
    standard-normal noise is mixed in according to ``alpha_bars[t]``.

    Args:
        indices: Iterable of integer positions into the 'X_jets' dataset.

    Yields:
        ``((noisy_image, t), noise)`` where ``t`` is an int32 array of shape
        ``(1,)`` and ``noise`` is the float32 noise that was added, i.e. the
        regression target.
    """
    with h5py.File('quark-gluon_data-set_n139306.hdf5', 'r') as h5_file:
        jets = h5_file['X_jets']
        for idx in indices:
            # Load one sample and scale it by its own peak value.
            sample = jets[idx].astype('float32')
            sample = sample / (np.max(sample) + 1e-8)

            # Draw the step first, then the noise (keeps the RNG call order).
            step = np.random.randint(0, T)
            eps = np.random.normal(0, 1, size=sample.shape).astype('float32')

            # Closed-form forward process: sqrt(a_bar) * x0 + sqrt(1 - a_bar) * eps
            a_bar = alpha_bars[step]
            signal_scale = np.sqrt(a_bar)
            noise_scale = np.sqrt(1 - a_bar)
            noisy = signal_scale * sample + noise_scale * eps

            step_array = np.array([step], dtype=np.int32)
            yield (noisy, step_array), eps

# Create tf.data.Dataset objects with batching and prefetching.
# `output_signature` replaces the deprecated `output_types` / `output_shapes`
# arguments of `tf.data.Dataset.from_generator`.
def _make_diffusion_dataset(indices, batch_size=64):
    """Wrap `diffusion_data_generator(indices)` in a batched, prefetched dataset.

    Args:
        indices: Sample indices handed to `diffusion_data_generator`.
        batch_size: Number of examples per batch.

    Returns:
        A `tf.data.Dataset` yielding `((noisy_image, t), noise)` batches.
    """
    image_spec = tf.TensorSpec(shape=(img_size, img_size, channels), dtype=tf.float32)
    step_spec = tf.TensorSpec(shape=(1,), dtype=tf.int32)
    dataset = tf.data.Dataset.from_generator(
        lambda: diffusion_data_generator(indices),
        output_signature=((image_spec, step_spec), image_spec),
    )
    return dataset.batch(batch_size).prefetch(tf.data.AUTOTUNE)


train_dataset = _make_diffusion_dataset(train_subset)
val_dataset = _make_diffusion_dataset(val_subset)

# --------------------------
# Define the Diffusion Network Model
# --------------------------
def get_diffusion_model():
    """Build the convolutional noise-prediction network.

    The model takes two inputs, ``noisy_image`` of shape
    ``(img_size, img_size, channels)`` and an int32 step ``t`` of shape
    ``(1,)``. The step is embedded, projected to one extra image plane,
    concatenated with the image, and passed through a small conv stack.

    Returns:
        A ``keras.Model`` named ``'diffusion_model'`` whose output has the
        same spatial size and channel count as the image input.
    """
    noisy_in = keras.Input(shape=(img_size, img_size, channels), name='noisy_image')
    step_in = keras.Input(shape=(1,), dtype=tf.int32, name='t')

    # Time conditioning: embedding -> flat vector -> one (img_size, img_size, 1) plane.
    step_plane = layers.Embedding(input_dim=T, output_dim=32)(step_in)  # (batch, 1, 32)
    step_plane = layers.Flatten()(step_plane)                           # (batch, 32)
    step_plane = layers.Dense(img_size * img_size, activation='relu')(step_plane)
    step_plane = layers.Reshape((img_size, img_size, 1))(step_plane)

    # Stack the time plane onto the image as an additional channel.
    features = layers.Concatenate(axis=-1)([noisy_in, step_plane])

    # Hidden 3x3 convolutions, then a linear projection back to image channels.
    for n_filters in (64, 64, 32):
        features = layers.Conv2D(
            n_filters, kernel_size=3, activation='relu', padding='same'
        )(features)
    predicted_noise = layers.Conv2D(
        channels, kernel_size=3, activation='linear', padding='same'
    )(features)

    return keras.Model([noisy_in, step_in], predicted_noise, name='diffusion_model')

# Instantiate the network and print its layer-by-layer summary.
diffusion_model = get_diffusion_model()
diffusion_model.summary()

# The network regresses the injected noise, so it is trained with an MSE loss.
diffusion_model.compile(loss='mse', optimizer='adam')

# --------------------------
# Train the Model
# --------------------------
# A 10-epoch run keeps the demo short; increase it as needed.
epochs = 10
history_diffusion = diffusion_model.fit(
    train_dataset,
    validation_data=val_dataset,
    epochs=epochs,
)

# --------------------------
# Reconstruction Function and Evaluation
# --------------------------
def reconstruct_image(original_image, t_fixed):
    """Noise an image at step ``t_fixed`` and invert it with the model's estimate.

    Fresh standard-normal noise is mixed into ``original_image`` using
    ``alpha_bars[t_fixed]``. The global ``diffusion_model`` then predicts
    that noise, and the forward formula is solved for the clean image.

    Args:
        original_image: Image array to be noised.
        t_fixed: Integer diffusion step used to index ``alpha_bars``.

    Returns:
        Tuple ``(noisy_image, reconstructed)``.
    """
    a_bar = alpha_bars[t_fixed]
    signal_scale = np.sqrt(a_bar)
    noise_scale = np.sqrt(1 - a_bar)

    # Forward process with freshly drawn noise.
    eps = np.random.normal(0, 1, size=original_image.shape).astype('float32')
    noisy_image = signal_scale * original_image + noise_scale * eps

    # The model expects a batch dimension on both inputs.
    image_batch = np.expand_dims(noisy_image, axis=0)
    step_batch = np.array([[t_fixed]])
    predicted_noise = diffusion_model.predict([image_batch, step_batch])

    # Solve x_t = sqrt(a_bar) * x0 + sqrt(1 - a_bar) * eps for x0.
    reconstructed = (noisy_image - noise_scale * predicted_noise[0]) / signal_scale
    return noisy_image, reconstructed

# Reconstruct the first n validation images at a single fixed diffusion step
# and record the per-image mean squared error against the clean image.
n = 10
t_fixed = 25  # Use a middle diffusion step for reconstruction
original_images, reconstructed_images, mse_list = [], [], []

with h5py.File('quark-gluon_data-set_n139306.hdf5', 'r') as f:
    ds = f['X_jets']
    for i in val_subset[:n]:
        # Same preprocessing as in training: float32, scaled by the image's own maximum.
        img = ds[i].astype('float32')
        img = img / (np.max(img) + 1e-8)

        noisy_img, recon_img = reconstruct_image(img, t_fixed)

        original_images.append(img)
        reconstructed_images.append(recon_img)
        mse_list.append(np.mean(np.square(img - recon_img)))

avg_mse = np.mean(mse_list)
print("Average MSE over {} samples: {:.6f}".format(n, avg_mse))

# --------------------------
# Image Display: Compare Original and Reconstructed Images Side-by-Side
# --------------------------
# Two-row grid: originals on the top row, reconstructions directly below them.
plt.figure(figsize=(20, 4))
panels = (
    (0, original_images, "Original"),
    (n, reconstructed_images, "Reconstructed"),
)
for i in range(n):
    for row_offset, row_images, row_title in panels:
        # Subplot slots are 1-based; the second row starts n slots later.
        ax = plt.subplot(2, n, i + 1 + row_offset)
        plt.imshow(row_images[i])
        plt.title(row_title)
        plt.axis("off")
plt.show()


Total number of samples: 139306
Instructions for updating:
Use output_signature instead
Instructions for updating:
Use output_signature instead


Epoch 1/10
      5/Unknown [1m1431s[0m 298s/step - loss: 0.9734

KeyboardInterrupt: 