<a href="https://colab.research.google.com/github/ShaheemJ/CelestAI/blob/main/VQGAN.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install datasets pyarrow

In [None]:
import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from datasets import load_dataset

#hyperparameters (improved)
IMG_SIZE        = 64     # side length (pixels) each image is resized to
BATCH_SIZE      = 32
NUM_IMAGES      = 500    # examples pulled from the streaming dataset
LATENT_DIM      = 256    # encoder output channels == codebook vector size D
NUM_EMBEDDINGS  = 1024   # codebook size K
COMMITMENT_COST = 0.5    # weight of the encoder commitment term in the VQ loss
EPOCHS          = 200
SEED            = 42

# Seed Python's `random`, NumPy and TensorFlow in one call so weight init,
# dataset shuffling and sampling are reproducible across Restart & Run All.
keras.utils.set_random_seed(SEED)

In [None]:
#load & buffer the first NUM_IMAGES from legacysurvey dataset
from itertools import islice

ds_stream = load_dataset("MultimodalUniverse/legacysurvey", split="train", streaming=True)
# islice stops cleanly if the stream is shorter than NUM_IMAGES (a bare
# next() loop would raise an uninformative StopIteration instead).
raw = list(islice(ds_stream, NUM_IMAGES))
if not raw:
    raise RuntimeError("legacysurvey stream yielded no examples")
if len(raw) < NUM_IMAGES:
    print(f"Warning: stream ended after {len(raw)} of {NUM_IMAGES} requested examples")

In [None]:
from PIL import Image
import matplotlib.pyplot as plt
import numpy as np

# IMG_SIZE comes from the hyperparameter cell; it is deliberately not redefined
# here so that cell stays the single source of truth.
def preprocess(example, img_size=None):
    """Turn one legacysurvey record into a single-channel float32 image in [0, 1].

    The ``flux`` array is averaged over its first axis, clipped to [0, 1],
    quantized to 8 bits for PIL resizing and rescaled back to [0, 1].

    Args:
        example: Dataset record exposing ``example["image"]["flux"]``; assumed
            to be laid out (bands, H, W) — TODO confirm against the schema.
        img_size: Output side length; ``None`` uses the IMG_SIZE hyperparameter.

    Returns:
        float32 array of shape (img_size, img_size, 1).
    """
    size = IMG_SIZE if img_size is None else img_size
    flux = np.array(example["image"]["flux"])
    gray = np.mean(flux, axis=0)  # average bands
    # NaN survives np.clip and has no defined uint8 value, so map non-finite
    # pixels into the clip range first (NaN -> 0, +inf -> 1, -inf -> 0).
    gray = np.nan_to_num(gray, nan=0.0, posinf=1.0, neginf=0.0)
    gray = np.clip(gray, 0, 1)
    img = Image.fromarray((gray * 255).astype(np.uint8))
    img = img.resize((size, size))
    arr = np.array(img, dtype=np.float32) / 255.0
    return np.expand_dims(arr, -1)

# Preprocess every buffered record and stack them along a new leading axis.
train_images = np.stack(list(map(preprocess, raw)))
print("Dataset shape:", train_images.shape)

#display a grid of images
def display_image_grid(images, grid_size=5, title="Processed Images"):
    """Plot the first ``grid_size**2`` entries of ``images`` on a square grid.

    Each image is squeezed to 2-D, drawn in grayscale without axes and
    captioned with its 1-based position; fewer images leave empty slots.
    """
    plt.figure(figsize=(10, 10))
    plt.suptitle(title, fontsize=16)

    capacity = grid_size * grid_size
    for position, image in enumerate(images[:min(capacity, len(images))], start=1):
        plt.subplot(grid_size, grid_size, position)
        plt.imshow(image.squeeze(), cmap='gray')
        plt.axis('off')
        plt.title(f"Image {position}")

    plt.tight_layout()
    plt.subplots_adjust(top=0.9)
    plt.show()

# Show a 5x5 sample of the preprocessed training images.
display_image_grid(train_images, grid_size=5, title="Processed Images")

In [None]:
#create tf.data pipeline
# The shuffle buffer spans the whole in-memory training set, so every epoch is a
# full shuffle regardless of NUM_IMAGES. A fixed buffer of 1000 would only
# partially shuffle once the dataset grows past 1000 images.
dataset = (
    tf.data.Dataset
      .from_tensor_slices(train_images)
      .shuffle(buffer_size=max(len(train_images), 1))
      .batch(BATCH_SIZE)
      .prefetch(tf.data.AUTOTUNE)
)

In [None]:
#vector quantizer layer
class VectorQuantizer(layers.Layer):
    """Nearest-neighbour vector quantizer with a learnable codebook (VQ-VAE style).

    Every D-dimensional vector along the last axis of the input is replaced by
    the closest (squared-L2) codebook entry. The codebook and commitment losses
    are registered with ``add_loss``. The output uses a straight-through
    estimator, so gradients reach the encoder unchanged.

    Args:
        num_embeddings: Codebook size K.
        embedding_dim: Code-vector size D; must equal the input's last dimension.
        commitment_cost: Weight on the commitment term that pulls encoder
            outputs towards their assigned codes.
    """

    def __init__(self, num_embeddings, embedding_dim, commitment_cost, **kwargs):
        super().__init__(**kwargs)
        self.num_embeddings  = num_embeddings
        self.embedding_dim   = embedding_dim
        self.commitment_cost = commitment_cost
        # Codebook laid out as [D, K]. It is created with add_weight rather than a
        # bare tf.Variable because Keras 3 only reports weights it created itself
        # in `trainable_weights`. The generator step collects the variables to
        # update from exactly that list.
        # "random_uniform" is U(-0.05, 0.05), the same default range as
        # tf.random_uniform_initializer().
        self.embeddings = self.add_weight(
            shape=(embedding_dim, num_embeddings),
            initializer="random_uniform",
            trainable=True,
            name="embeddings",
        )

    def call(self, inputs):
        """Quantize ``inputs`` (last axis of size D); output has the same shape."""
        flat = tf.reshape(inputs, [-1, self.embedding_dim])  # [N, D]
        # Squared L2 distance to every code at once: ||x||^2 - 2 x.e + ||e||^2.
        distances = (
            tf.reduce_sum(flat**2, axis=1, keepdims=True)
            - 2 * tf.matmul(flat, self.embeddings)
            + tf.reduce_sum(self.embeddings**2, axis=0, keepdims=True)
        )  # [N, K]
        indices = tf.argmin(distances, axis=1)                          # [N]
        one_hot = tf.one_hot(indices, self.num_embeddings)              # [N, K]
        quantized = tf.matmul(one_hot, tf.transpose(self.embeddings))   # [N, D]
        quantized = tf.reshape(quantized, tf.shape(inputs))

        # Commitment term: moves encoder outputs towards the (frozen) codes.
        e_latent_loss = tf.reduce_mean((tf.stop_gradient(quantized) - inputs)**2)
        # Codebook term: moves codes towards the (frozen) encoder outputs.
        q_latent_loss = tf.reduce_mean((quantized - tf.stop_gradient(inputs))**2)
        self.add_loss(q_latent_loss + self.commitment_cost * e_latent_loss)

        # Straight-through estimator: the forward value is `quantized`, and the
        # gradient w.r.t. inputs is the identity.
        return inputs + tf.stop_gradient(quantized - inputs)

In [None]:
#encoder & decoder

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers

def build_encoder():
    """Functional encoder mapping a (IMG_SIZE, IMG_SIZE, 1) image to LATENT_DIM-channel latents.

    Two stride-2 4x4 ReLU convolutions (128 then 256 filters) shrink H and W by
    4. A linear 1x1 convolution then projects to LATENT_DIM channels.
    """
    image_in = keras.Input((IMG_SIZE, IMG_SIZE, 1))
    features = image_in
    for n_filters in (128, 256):
        features = layers.Conv2D(n_filters, 4, strides=2, padding="same", activation="relu")(features)
    # No activation: the quantizer compares these raw values to codebook vectors.
    latents = layers.Conv2D(LATENT_DIM, 1, padding="same")(features)
    return keras.Model(inputs=image_in, outputs=latents, name="encoder")

def build_decoder():
    """Functional decoder from (IMG_SIZE//4, IMG_SIZE//4, LATENT_DIM) latents to a 1-channel image."""
    side = IMG_SIZE // 4  # the encoder applies two stride-2 convolutions
    latents_in = keras.Input(shape=(side, side, LATENT_DIM))
    h = latents_in
    for n_filters in (256, 128):
        h = layers.Conv2DTranspose(n_filters, 4, strides=2, padding="same", activation="relu")(h)
    # Sigmoid keeps output pixels in (0, 1), matching the [0, 1] input scaling.
    image_out = layers.Conv2DTranspose(1, 3, strides=1, padding="same", activation="sigmoid")(h)
    return keras.Model(inputs=latents_in, outputs=image_out, name="decoder")

# Build both halves of the autoencoder, then print their layer summaries.
encoder = build_encoder()
decoder = build_decoder()
for net in (encoder, decoder):
    net.summary()

In [None]:
#discriminator
def build_discriminator():
    """Image -> single real/fake logit.

    Two stride-2 4x4 leaky-ReLU convolutions (64, 128 filters), then Flatten
    and a linear Dense(1).
    """
    image_in = keras.Input((IMG_SIZE, IMG_SIZE, 1))
    h = image_in
    for width in (64, 128):
        h = layers.Conv2D(width, 4, strides=2, padding="same", activation="leaky_relu")(h)
    h = layers.Flatten()(h)
    logit = layers.Dense(1)(h)  # raw logit; the loss uses from_logits=True
    return keras.Model(image_in, logit, name="discriminator")

discriminator = build_discriminator()

In [None]:
#VQGAN custom Model
class VQGAN(keras.Model):
    """VQ autoencoder (encoder -> quantizer -> decoder) trained against a discriminator.

    The generator (encoder, quantizer codebook, decoder) minimizes a weighted sum
    of three terms: the L1 reconstruction loss, the adversarial loss, and the
    losses the quantizer registers via ``add_loss``. The discriminator is trained
    with binary cross-entropy on real images vs. reconstructions.

    Args:
        encoder, quantizer, decoder, discriminator: Sub-models / layers.
        recon_weight: Weight on the L1 reconstruction term.
        adv_weight: Weight on the generator's adversarial term.
    """

    def __init__(self, encoder, quantizer, decoder, discriminator,
                 recon_weight=10.0, adv_weight=0.1, **kwargs):
        super().__init__(**kwargs)
        self.encoder       = encoder
        self.quantizer     = quantizer
        self.decoder       = decoder
        self.discriminator = discriminator
        self.recon_weight  = recon_weight  # emphasize reconstruction quality
        self.adv_weight    = adv_weight    # keep the adversarial signal small
        self.l1_loss       = keras.losses.MeanAbsoluteError()
        self.bce_loss      = keras.losses.BinaryCrossentropy(from_logits=True)
        # Running means, so logged values and `history` are per-epoch averages
        # rather than whatever the last batch happened to produce.
        self.g_loss_tracker     = keras.metrics.Mean(name="g_loss")
        self.d_loss_tracker     = keras.metrics.Mean(name="d_loss")
        self.recon_loss_tracker = keras.metrics.Mean(name="recon_loss")

    @property
    def metrics(self):
        # Listing the trackers here lets Keras reset them at each epoch start.
        return [self.g_loss_tracker, self.d_loss_tracker, self.recon_loss_tracker]

    def compile(self, g_optimizer, d_optimizer):
        super().compile()
        self.g_optimizer = g_optimizer
        self.d_optimizer = d_optimizer

    def _generator_variables(self):
        """Weights updated by the generator optimizer (encoder + codebook + decoder)."""
        return (self.encoder.trainable_weights
                + self.quantizer.trainable_weights
                + self.decoder.trainable_weights)

    def train_step(self, real):
        """One joint generator / discriminator update on a batch of real images.

        Returns:
            dict with running means for ``g_loss``, ``d_loss`` and ``recon_loss``.
        """
        with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
            z_e = self.encoder(real)
            z_q = self.quantizer(z_e)
            recon = self.decoder(z_q)

            recon_loss = self.l1_loss(real, recon)

            fake_logits = self.discriminator(recon)
            # The generator wants reconstructions labelled as real.
            adv_loss_g = self.bce_loss(tf.ones_like(fake_logits), fake_logits)

            # self.quantizer.losses: codebook + commitment terms added via add_loss.
            g_loss = (self.recon_weight * recon_loss
                      + self.adv_weight * adv_loss_g
                      + sum(self.quantizer.losses))

            real_logits = self.discriminator(real)
            d_loss = (
                self.bce_loss(tf.ones_like(real_logits), real_logits)
                + self.bce_loss(tf.zeros_like(fake_logits), fake_logits)
            )

        gen_vars = self._generator_variables()
        disc_vars = self.discriminator.trainable_weights
        # Compute both gradients before applying either update.
        grads_g = gen_tape.gradient(g_loss, gen_vars)
        grads_d = disc_tape.gradient(d_loss, disc_vars)
        self.g_optimizer.apply_gradients(zip(grads_g, gen_vars))
        self.d_optimizer.apply_gradients(zip(grads_d, disc_vars))

        self.g_loss_tracker.update_state(g_loss)
        self.d_loss_tracker.update_state(d_loss)
        self.recon_loss_tracker.update_state(recon_loss)
        return {m.name: m.result() for m in self.metrics}

In [None]:
#instantiate & compile VQGAN
quantizer = VectorQuantizer(
    num_embeddings=NUM_EMBEDDINGS,
    embedding_dim=LATENT_DIM,
    commitment_cost=COMMITMENT_COST,
)
vqgan = VQGAN(encoder=encoder, quantizer=quantizer, decoder=decoder, discriminator=discriminator)

# The discriminator uses a 4x larger learning rate than the generator.
vqgan.compile(
    g_optimizer=keras.optimizers.Adam(learning_rate=1e-4),
    d_optimizer=keras.optimizers.Adam(learning_rate=4e-4),
)

In [None]:
#training loop with results printed every 10 epochs
import tensorflow as tf

#custom callback that prints all logged metrics every `n` epochs
class PrintEveryNEpochs(tf.keras.callbacks.Callback):
    """Print every entry of the epoch logs once every ``n`` epochs.

    Args:
        n: Print interval in epochs; must be >= 1.

    Raises:
        ValueError: if ``n`` < 1 (it would otherwise fail with a modulo-by-zero
            error at the first epoch end).
    """

    def __init__(self, n=5):
        super().__init__()
        if n < 1:
            raise ValueError(f"n must be a positive integer, got {n}")
        self.n = n

    def on_epoch_end(self, epoch, logs=None):
        if (epoch + 1) % self.n:
            return
        print(f"\n=== Results after epoch {epoch + 1} ===")
        # `logs` defaults to None in the Callback API; treat that as empty.
        for metric_name, metric_value in (logs or {}).items():
            print(f"{metric_name}: {metric_value:.4f}")
        print("="*30)

#create the callback
print_callback = PrintEveryNEpochs(n=10)

# Train, reporting the logged losses every 10 epochs via print_callback.
history = vqgan.fit(dataset, epochs=EPOCHS, callbacks=[print_callback])
# Mean per-value L1 reconstruction error over the whole training set.
# It is computed inline because compute_average_reconstruction_loss is only
# defined in the last cell of the notebook, so calling it here raises NameError
# under Restart & Run All.
l1_error_sum, n_values = 0.0, 0
for eval_batch in dataset:
    eval_recon = vqgan.decoder(vqgan.quantizer(vqgan.encoder(eval_batch)))
    l1_error_sum += float(tf.reduce_sum(tf.abs(eval_recon - eval_batch)))
    n_values += int(tf.size(eval_batch))
average_loss = l1_error_sum / n_values
print(f"\nFinal Average Reconstruction Loss (L1): {average_loss:.6f}")

#to plot the training history after completion
def plot_training_history(history):
    """Plot each logged training metric per epoch on one shared axis.

    Keys beginning with ``encoder_``, ``decoder_`` or ``discriminator_``
    (per-sub-model losses) are skipped.

    Args:
        history: The History object returned by ``Model.fit``.
    """
    import matplotlib.pyplot as plt

    skipped_prefixes = ('encoder_', 'decoder_', 'discriminator_')
    plt.figure(figsize=(12, 8))
    for name, values in history.history.items():
        if name.startswith(skipped_prefixes):
            continue
        plt.plot(values, label=name)

    plt.title('Training Metrics')
    plt.xlabel('Epoch')
    plt.ylabel('Value')
    plt.legend()
    plt.grid(True)
    plt.show()



def generate_improved(n=5, temperature=0.8):
    """Decode noise-perturbed quantized latents of real training images.

    This does not sample new codes. It encodes ``n`` real images from the
    global ``dataset``, quantizes them, adds Gaussian noise and decodes the
    result, giving smooth variations of real examples.

    Args:
        n: Number of images to produce; may exceed the batch size (capped only
            by the dataset size).
        temperature: Standard deviation of the noise added to the latents.

    Returns:
        numpy array of decoded images, one per selected input image.
    """
    # Re-batch across batch boundaries and keep exactly n images, instead of
    # encoding a full batch and discarding everything after the first n.
    real_images = next(iter(dataset.unbatch().batch(max(n, 1))))[:n]

    z_q = quantizer(encoder(real_images))
    noise = tf.random.normal(shape=tf.shape(z_q), stddev=temperature)
    samples = decoder(z_q + noise)
    return samples.numpy()

def denoise_images(images, strength=0.7):
    """Gaussian-blur each image, then min-max rescale it to [0, 1].

    Args:
        images: Iterable of images shaped (H, W) or (H, W, 1).
        strength: Gaussian sigma (pixels) for scipy.ndimage.gaussian_filter.

    Returns:
        numpy array of the processed images; 2-D results get a trailing
        channel axis, giving (N, H, W, 1) for single-channel input.

    Raises:
        ImportError: if SciPy is unavailable (the caller catches this to skip
            denoising).
    """
    # Imported lazily so a missing SciPy surfaces as ImportError at call time.
    from scipy.ndimage import gaussian_filter

    processed = []
    for img in images:
        denoised = gaussian_filter(img.squeeze(), sigma=strength)
        lo, hi = denoised.min(), denoised.max()
        if hi > lo:
            denoised = (denoised - lo) / (hi - lo)
        else:
            # A flat image has zero range; min-max scaling would produce 0/0 = NaN.
            # Keep its constant value, clamped into [0, 1].
            denoised = np.clip(denoised, 0.0, 1.0)
        if denoised.ndim == 2:
            denoised = np.expand_dims(denoised, -1)
        processed.append(denoised)

    return np.array(processed)


In [None]:
#generate new samples
def generate(n=5):
    """Decode grids of uniformly random codebook indices into images.

    Each latent position independently picks one of the NUM_EMBEDDINGS codebook
    vectors at random. The resulting (n, IMG_SIZE//4, IMG_SIZE//4, D) tensor is
    passed through the global ``decoder``.

    Args:
        n: Number of images to generate.

    Returns:
        numpy array of n decoded images.
    """
    grid = IMG_SIZE // 4  # the decoder's input resolution (two stride-2 stages)
    idx = tf.random.uniform((n, grid, grid), maxval=NUM_EMBEDDINGS, dtype=tf.int32)
    # The codebook is stored [D, K]; transposing to [K, D] lets gather pick code
    # vectors directly, giving [n, h, w, D]. This equals one_hot(idx) @ E^T
    # without materialising an [n*h*w, K] one-hot matrix and a dense matmul.
    codebook = tf.transpose(quantizer.embeddings)
    quantized = tf.gather(codebook, idx)
    return decoder(quantized).numpy()

#alternative implementation if the above doesn't work
def generate_alt(n=5):
    """Decode i.i.d. standard-normal latents shaped like the decoder input.

    Unlike ``generate``, this bypasses the codebook entirely.
    """
    side = IMG_SIZE // 4
    print(f"Using LATENT_DIM: {LATENT_DIM}")
    latents = tf.random.normal((n, side, side, LATENT_DIM))
    print(f"Random latents shape: {latents.shape}")
    return decoder(latents).numpy()

# Produce 10 samples. If a method raises, report the error and fall back to
# the next, simpler generator.
print("Generating with improved method...")
try:
    imgs = generate_improved(10, temperature=0.5)
except Exception as improved_err:
    print(f"Improved method failed: {improved_err}")
    print("Trying alternative method...")
    try:
        imgs = generate(10)
    except Exception as original_err:
        print(f"Original method failed: {original_err}")
        print("Falling back to simplest method...")
        imgs = generate_alt(10)

# Optional smoothing pass; skipped (with a message) when SciPy is missing.
try:
    imgs = denoise_images(imgs, strength=0.5)
except ImportError:
    print("Scipy not available for denoising")
else:
    print("Applied denoising filter")
#display the generated images
import matplotlib.pyplot as plt
# squeeze=False keeps `axs` 2-D even when only one image was generated. With the
# default, a single subplot comes back as a bare Axes and axs[i] fails.
fig, axs = plt.subplots(1, len(imgs), figsize=(15, 3), squeeze=False)
for ax, im in zip(axs[0], imgs):
    if im.shape[-1] == 3:
        ax.imshow(im)  # RGB image
    else:
        ax.imshow(im.squeeze(), cmap="gray")  # single-channel image
    ax.axis("off")
plt.show()

In [None]:
def compute_average_reconstruction_loss(model, dataset):
    """Mean absolute (L1) reconstruction error per element over a dataset.

    Each batch is passed through ``model.encoder`` -> ``model.quantizer`` ->
    ``model.decoder``. Absolute errors are summed over every element and divided
    by the total element count, so a smaller final batch is weighted correctly.

    Args:
        model: Object exposing ``encoder``, ``quantizer`` and ``decoder``
            callables (e.g. the VQGAN instance).
        dataset: Iterable of array-like image batches (a tf.data.Dataset, a
            list of NumPy arrays, ...).

    Returns:
        float: average |recon - input| per element.

    Raises:
        ValueError: if ``dataset`` yields no elements.
    """
    abs_error_sum = 0.0
    num_values = 0  # counts individual pixel values (B*H*W*C), not images
    for batch in dataset:
        recon = model.decoder(model.quantizer(model.encoder(batch)))
        target = np.asarray(batch, dtype=np.float64)
        # Accumulate in float64 so summing many float32 errors keeps precision.
        abs_error_sum += float(np.abs(np.asarray(recon, dtype=np.float64) - target).sum())
        num_values += target.size

    if num_values == 0:
        raise ValueError("dataset yielded no values; cannot average reconstruction loss")
    return abs_error_sum / num_values
