<a href="https://colab.research.google.com/github/ShaheemJ/CelestAI/blob/main/VAE.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install datasets pyarrow

In [None]:
from datasets import load_dataset

from itertools import islice

# Open the train split in streaming mode and iterate over it record by record.
dataset = load_dataset("MultimodalUniverse/legacysurvey", split="train", streaming=True)

# Kept as a named iterator so the notebook namespace is unchanged.
dataset_iter = iter(dataset)

# islice stops cleanly if the stream yields fewer than 500 records, whereas a
# bare next() inside a comprehension would abort the cell with StopIteration.
first_500 = list(islice(dataset_iter, 500))

In [None]:
# Print the first streamed record to inspect its structure.
print(first_500[0])

In [None]:
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt

IMG_SIZE = 64

def preprocess_image(example):
    """Turn one dataset record into a normalized single-channel image.

    The flux array is averaged over its first axis (presumably the band axis --
    confirm against the dataset schema), resized to IMG_SIZE x IMG_SIZE, and
    min-max scaled to [0, 1].

    Args:
        example: Record whose ``example["image"]["flux"]`` holds the flux data.

    Returns:
        The resized image with a trailing channel axis of length 1. A constant
        image, which has no dynamic range to stretch, is returned as all zeros.
    """
    flux_array = np.array(example["image"]["flux"])
    flux_array = np.mean(flux_array, axis=0)

    # tf.image.resize needs a channel axis, hence the trailing [..., None].
    flux_resized = tf.image.resize(
        tf.convert_to_tensor(flux_array)[..., None], (IMG_SIZE, IMG_SIZE)
    ).numpy()

    lowest = np.min(flux_resized)
    span = np.max(flux_resized) - lowest
    if span == 0:
        # Dividing by a zero range would fill the whole image with NaNs.
        return np.zeros_like(flux_resized)

    return (flux_resized - lowest) / span

# Preprocess every streamed record and collect the results into one array.
train_images = np.array(list(map(preprocess_image, first_500)))

print(f"Dataset shape: {train_images.shape}")

In [None]:
# Show the first preprocessed image in grayscale with the axes hidden.
plt.imshow(train_images[0], cmap="gray")
plt.axis("off")
plt.show()

In [None]:
import numpy as np
import tensorflow as tf
from tensorflow.keras import layers, Model
import tensorflow.keras.backend as K
import matplotlib.pyplot as plt
from PIL import Image
import io
import datasets

In [None]:
from itertools import islice

# Open the train split in streaming mode and iterate over it record by record.
dataset = load_dataset("MultimodalUniverse/legacysurvey", split="train", streaming=True)

# Kept as a named iterator so the notebook namespace is unchanged.
dataset_iter = iter(dataset)

# islice stops cleanly if the stream yields fewer than 500 records, whereas a
# bare next() inside a comprehension would abort the cell with StopIteration.
first_500 = list(islice(dataset_iter, 500))

IMG_SIZE = 64

def preprocess_image(example):
    """Build a small grayscale image from the first flux plane of a record.

    Only ``example["image"]["flux"][0]`` is used. Values are clamped to
    [0, 1], quantized to 8 bits, resized to IMG_SIZE x IMG_SIZE with PIL and
    rescaled back to floats in [0, 1].

    Args:
        example: Record whose ``example["image"]["flux"]`` holds the flux data.

    Returns:
        The resized image as a float array with values in [0, 1].
    """
    first_plane = np.array(example["image"]["flux"])[0]

    # Clamp and quantize so the plane can be handed to PIL as 8-bit data.
    as_bytes = (np.clip(first_plane, 0, 1) * 255).astype(np.uint8)
    resized = Image.fromarray(as_bytes).resize((IMG_SIZE, IMG_SIZE))

    # Undo the 8-bit scaling to get floats in [0, 1] again.
    return np.array(resized) / 255.0

# Preprocess all records in memory; nothing is written to disk.
train_images = np.array(list(map(preprocess_image, first_500)))

# Append a trailing channel axis so the array matches the model's input layout.
train_images = train_images[..., np.newaxis]

print(f"Dataset shape: {train_images.shape}")

In [None]:
class Sampling(layers.Layer):
    """Draw a latent sample via the reparameterization trick.

    Given ``(z_mean, z_log_var)``, returns ``z_mean + exp(0.5 * z_log_var) * eps``
    where ``eps`` is standard normal noise with the same (batch, dim) shape
    as ``z_mean``.
    """

    def call(self, inputs):
        mean, log_var = inputs

        # Noise is shaped from the runtime tensor so any batch size works.
        mean_shape = tf.shape(mean)
        noise = tf.keras.backend.random_normal(shape=(mean_shape[0], mean_shape[1]))

        # exp(0.5 * log variance) is the standard deviation.
        sigma = tf.exp(0.5 * log_var)
        return mean + sigma * noise

In [None]:
import tensorflow as tf
from tensorflow.keras.layers import Input, Conv2D, Flatten, Dense, Lambda, Reshape, Conv2DTranspose
from tensorflow.keras.models import Model
import tensorflow.keras.backend as K

LATENT_DIM = 16  #latent space dimension

#encoder network
def build_encoder(input_shape=(64, 64, 1), latent_dim=None):
    """Build the convolutional encoder of the VAE.

    Two stride-2 convolutions are followed by a dense layer and two linear
    heads that produce the mean and log-variance of the latent distribution.
    A latent sample is drawn with the file's ``Sampling`` layer rather than a
    duplicated ``Lambda`` closure, so the reparameterization logic lives in a
    single place.

    Args:
        input_shape: Shape of one input image, without the batch axis.
        latent_dim: Size of the latent space. Defaults to the module-level
            ``LATENT_DIM``, looked up at call time.

    Returns:
        A Keras model named "Encoder" mapping an image batch to the list
        ``[z_mean, z_log_var, z]``.
    """
    if latent_dim is None:
        latent_dim = LATENT_DIM

    input_img = Input(shape=input_shape)

    x = Conv2D(32, (3, 3), activation="relu", strides=2, padding="same")(input_img)
    x = Conv2D(64, (3, 3), activation="relu", strides=2, padding="same")(x)
    x = Flatten()(x)
    x = Dense(128, activation="relu")(x)

    # Parameters of the approximate posterior.
    z_mean = Dense(latent_dim, name="z_mean")(x)
    z_log_var = Dense(latent_dim, name="z_log_var")(x)

    # Reparameterization: z = mean + sigma * eps, keeping the layer name "z".
    z = Sampling(name="z")([z_mean, z_log_var])

    return Model(input_img, [z_mean, z_log_var, z], name="Encoder")

# Build the encoder and print its layer summary.
encoder = build_encoder()
encoder.summary()

In [None]:
#decoder network
def build_decoder():
    """Build the decoder that maps a latent vector back to an image.

    A dense layer expands the LATENT_DIM-sized input to a 16x16x64 feature
    map, two stride-2 transposed convolutions upsample it, and a final
    sigmoid transposed convolution produces a single output channel.

    Returns:
        A Keras model named "Decoder".
    """
    z_in = Input(shape=(LATENT_DIM,), name="z_sampling")

    upsampling_stack = [
        Dense(16 * 16 * 64, activation="relu"),
        Reshape((16, 16, 64)),
        Conv2DTranspose(64, (3, 3), activation="relu", strides=2, padding="same"),
        Conv2DTranspose(32, (3, 3), activation="relu", strides=2, padding="same"),
        Conv2DTranspose(1, (3, 3), activation="sigmoid", padding="same"),
    ]

    # Thread the latent input through the layers in order.
    features = z_in
    for layer in upsampling_stack:
        features = layer(features)

    return Model(z_in, features, name="Decoder")

# Build the decoder and print its layer summary.
decoder = build_decoder()
decoder.summary()

In [None]:
from tensorflow.keras.losses import binary_crossentropy

class VAE(tf.keras.Model):
    """Keras model that chains an encoder and a decoder.

    Args:
        encoder: Model returning ``(z_mean, z_log_var, z)`` for an input batch.
        decoder: Model mapping a latent sample ``z`` to a reconstruction.
        **kwargs: Forwarded to ``tf.keras.Model``.
    """

    def __init__(self, encoder, decoder, **kwargs):
        super().__init__(**kwargs)
        self.encoder = encoder
        self.decoder = decoder

    def call(self, inputs):
        """Encode ``inputs`` and return the decoded reconstruction."""
        # Only the sampled latent vector is decoded; mean and log-variance
        # are not used here.
        _, _, latent_sample = self.encoder(inputs)
        return self.decoder(latent_sample)

# Combine the encoder and decoder built above into one model.
vae = VAE(encoder, decoder)

In [None]:
def vae_loss(y_true, y_pred):
    """Return the VAE objective: reconstruction term plus KL divergence.

    Args:
        y_true: Batch of target images; it is also fed to the module-level
            ``encoder`` to obtain the latent mean and log-variance.
        y_pred: Batch of reconstructions produced by the model.

    Returns:
        Scalar tensor: the mean of (reconstruction term + per-sample KL term).
    """
    # Binary cross-entropy over the flattened tensors, rescaled by 64 * 64.
    mean_bce = binary_crossentropy(K.flatten(y_true), K.flatten(y_pred))
    reconstruction_term = mean_bce * (64 * 64)

    # KL divergence between the encoder's Gaussian and a standard normal,
    # summed over the latent axis.
    mu, log_var, _ = encoder(y_true)
    kl_term = -0.5 * K.sum(1 + log_var - K.square(mu) - K.exp(log_var), axis=-1)

    return K.mean(reconstruction_term + kl_term)

# Compile with the Adam optimizer and the custom VAE loss.
vae.compile(optimizer="adam", loss=vae_loss)

In [None]:
# Train as an autoencoder: the images serve as both inputs and targets.
vae.fit(train_images, train_images, epochs=50, batch_size=32)

In [None]:
import numpy as np

def generate_images(decoder, num_images=5):
    """Decode random latent vectors and display the results in one row.

    Args:
        decoder: Model whose ``predict`` maps latent vectors of length
            ``LATENT_DIM`` to images.
        num_images: Number of images to generate; must be at least 1.

    Raises:
        ValueError: If ``num_images`` is less than 1.
    """
    if num_images < 1:
        raise ValueError("num_images must be at least 1")

    random_latent_vectors = np.random.normal(size=(num_images, LATENT_DIM))
    generated_images = decoder.predict(random_latent_vectors)

    # squeeze=False keeps `axes` two-dimensional. By default a single subplot
    # comes back as a bare Axes object, so indexing it would fail when
    # num_images == 1.
    _fig, axes = plt.subplots(1, num_images, figsize=(15, 3), squeeze=False)
    for ax, img in zip(axes[0], generated_images):
        ax.imshow(img.squeeze(), cmap="gray")
        ax.axis("off")
    plt.show()

# Decode and display ten random latent vectors.
generate_images(decoder, num_images=10)

In [None]:
import numpy as np

#compute reconstruction loss (MSE) between original and reconstructed images
def compute_reconstruction_loss(vae, dataset, num_samples=100):
    """Return the mean per-image MSE between inputs and their reconstructions.

    Args:
        vae: Callable model mapping an image batch to a reconstruction batch.
        dataset: Iterable of raw records accepted by ``preprocess_image``.
        num_samples: Maximum number of records to evaluate. Fewer are used if
            ``dataset`` runs out first.

    Returns:
        The average of the per-image mean squared errors.

    Raises:
        ValueError: If no record could be evaluated, or if the model output
            shape differs from the input batch shape.
    """
    from itertools import islice

    losses = []

    # islice ends the loop quietly when the dataset is shorter than
    # num_samples instead of letting next() raise StopIteration.
    for example in islice(dataset, num_samples):
        original_image = np.asarray(preprocess_image(example))

        # A 2-D image gets an explicit channel axis. Without it, subtracting a
        # (1, H, W, 1) reconstruction from a (1, H, W) input broadcasts to
        # (1, H, W, W) and the mean is no longer a per-pixel MSE.
        if original_image.ndim == 2:
            original_image = original_image[..., np.newaxis]
        batch = original_image[np.newaxis, ...]  # add batch dim

        reconstructed = np.asarray(vae(batch))
        if reconstructed.shape != batch.shape:
            raise ValueError(
                f"reconstruction shape {reconstructed.shape} does not match "
                f"input shape {batch.shape}"
            )

        losses.append(np.mean((batch - reconstructed) ** 2))

    if not losses:
        raise ValueError("dataset yielded no examples to evaluate")

    return np.mean(losses)

# Evaluate the model on the first 100 of the streamed records.
reconstruction_loss = compute_reconstruction_loss(vae, first_500, num_samples=100)
print(f"Average Reconstruction Loss: {reconstruction_loss}")