In [9]:
import os
import numpy as np
import tensorflow as tf
from tensorflow.keras import layers, models
from tensorflow.keras.preprocessing.image import ImageDataGenerator
import matplotlib.pyplot as plt
from PIL import Image

# Set the dimensions for input images
IMG_WIDTH = 256
IMG_HEIGHT = 256
RGB_CHANNELS = 3
HSI_CHANNELS = 31  # Number of spectral bands for hyperspectral images

# 1. Data Acquisition and Preprocessing
def load_rgb_images(image_path, target_size=(IMG_WIDTH, IMG_HEIGHT)):
    """Load and resize RGB images from a specified directory."""
    images = []
    for filename in os.listdir(image_path):
        if filename.endswith('_clean.png'):  # Update to match the RGB naming convention
            img = tf.keras.preprocessing.image.load_img(os.path.join(image_path, filename), target_size=target_size)
            img = tf.keras.preprocessing.image.img_to_array(img) / 255.0  # Normalize
            images.append(img)
    return np.array(images)

def load_hsi_images_from_all_folders(base_folder, target_size=(IMG_WIDTH, IMG_HEIGHT)):
    all_hsi_images = []

    # Iterate over each folder in the base folder
    for folder in os.listdir(base_folder):
        folder_path = os.path.join(base_folder, folder)
        if os.path.isdir(folder_path):
            images = []

            # Load each HSI image in the folder
            for file in sorted(os.listdir(folder_path)):
                if file.endswith('.tiff') or file.endswith('.tif'):
                    image_path = os.path.join(folder_path, file)
                    img = Image.open(image_path).convert('L')  # Convert to grayscale
                    img = img.resize(target_size)  # Resize to target size
                    img = np.array(img)[..., np.newaxis]  # Add channel dimension
                    images.append(img)

            # Check the number of images loaded
            if len(images) == 31:
                stacked_images = np.array(images)  # Shape will be (31, height, width, 1)
                print(f"Stacked images shape before transpose for folder {folder}: {stacked_images.shape}")
                # Ensure that the shape is (31, 256, 256, 1)
                if stacked_images.shape == (31, IMG_HEIGHT, IMG_WIDTH, 1):
                    # Transpose to get shape (256, 256, 31)
                    all_hsi_images.append(stacked_images.reshape(IMG_HEIGHT, IMG_WIDTH, HSI_CHANNELS))  # Now it will be (height, width, 31)
                else:
                    print(f"Unexpected shape for stacked images in folder {folder}: {stacked_images.shape}")
            else:
                print(f"Warning: Expected 31 images but got {len(images)} for folder {folder}")

    return np.array(all_hsi_images)

# Load RGB images
rgb_images = load_rgb_images(r"C:\Harshi\ECS-II\Dataset\RGB_7_files")

# Load HSI images from all folders
hsi_images = load_hsi_images_from_all_folders(r"C:\Harshi\ECS-II\Dataset\HSI_Dataset_TIFF")

# Data Augmentation
data_gen = ImageDataGenerator(rotation_range=20,
                               width_shift_range=0.1,
                               height_shift_range=0.1,
                               shear_range=0.1,
                               zoom_range=0.1,
                               horizontal_flip=True,
                               fill_mode='nearest')

# Example of augmented RGB images
augmented_rgb_images = next(data_gen.flow(rgb_images, batch_size=8))

# 2. Model Architecture
def build_generator():
    inputs = layers.Input(shape=(IMG_HEIGHT, IMG_WIDTH, RGB_CHANNELS))

    # Encoder
    down1 = layers.Conv2D(64, (4, 4), strides=2, padding='same')(inputs)
    down1 = layers.ReLU()(down1)
    down2 = layers.Conv2D(128, (4, 4), strides=2, padding='same')(down1)
    down2 = layers.ReLU()(down2)

    # Bottleneck
    bottleneck = layers.Conv2D(256, (4, 4), padding='same')(down2)
    bottleneck = layers.ReLU()(bottleneck)

    # Decoder
    up1 = layers.Conv2DTranspose(128, (4, 4), strides=2, padding='same')(bottleneck)
    up1 = layers.ReLU()(up1)
    up2 = layers.Conv2DTranspose(64, (4, 4), strides=2, padding='same')(up1)
    up2 = layers.ReLU()(up2)

    outputs = layers.Conv2D(HSI_CHANNELS, (3, 3), padding='same', activation='sigmoid')(up2)  # Hyperspectral output (31 bands)
    
    return models.Model(inputs, outputs)

def build_discriminator():
    inputs = layers.Input(shape=(IMG_HEIGHT, IMG_WIDTH, RGB_CHANNELS + HSI_CHANNELS))  # RGB + HSI
    x = layers.Conv2D(64, (4, 4), strides=2, padding='same')(inputs)
    x = layers.LeakyReLU(0.2)(x)
    x = layers.Conv2D(128, (4, 4), strides=2, padding='same')(x)
    x = layers.LeakyReLU(0.2)(x)
    x = layers.Conv2D(256, (4, 4), padding='same')(x)
    x = layers.LeakyReLU(0.2)(x)
    outputs = layers.Conv2D(1, (1, 1), activation='sigmoid')(x)  # PatchGAN output
    
    return models.Model(inputs, outputs)

# 3. Loss Functions
def discriminator_loss(real_output, fake_output):
    real_loss = tf.keras.losses.binary_crossentropy(tf.ones_like(real_output), real_output)
    fake_loss = tf.keras.losses.binary_crossentropy(tf.zeros_like(fake_output), fake_output)
    total_loss = real_loss + fake_loss
    return tf.reduce_mean(total_loss)

def generator_loss(fake_output):
    return tf.reduce_mean(tf.keras.losses.binary_crossentropy(tf.ones_like(fake_output), fake_output))

def pixel_loss(generated, target):
    return tf.reduce_mean(tf.square(generated - target))

# Evaluation Metrics
def mean_squared_error(y_true, y_pred):
    return tf.reduce_mean(tf.square(y_true - y_pred))

def peak_signal_to_noise_ratio(y_true, y_pred):
    return tf.reduce_mean(tf.image.psnr(y_true, y_pred, max_val=1.0))

def spectral_angle_mapper(y_true, y_pred):
    dot_product = tf.reduce_sum(y_true * y_pred, axis=-1)
    norm_true = tf.norm(y_true, axis=-1)
    norm_pred = tf.norm(y_pred, axis=-1)
    cos_theta = dot_product / (norm_true * norm_pred)
    cos_theta = tf.clip_by_value(cos_theta, -1.0, 1.0)  # Clip values to avoid NaNs
    sam = tf.acos(cos_theta)
    return tf.reduce_mean(sam)

# 4. Training Loop

epochs = 50
batch_size = 16

generator = build_generator()
discriminator = build_discriminator()
generator_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)
discriminator_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)

# Setup TensorBoard logging
log_dir = "logs/"
if not os.path.exists(log_dir):
    os.makedirs(log_dir)
summary_writer = tf.summary.create_file_writer(log_dir)

# Print log directory for debugging
print(f"Logging to directory: {log_dir}")

# Setup model checkpointing
checkpoint_dir = "checkpoints/"
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt")
checkpoint = tf.train.Checkpoint(generator_optimizer=generator_optimizer,
                                 discriminator_optimizer=discriminator_optimizer,
                                 generator=generator,
                                 discriminator=discriminator)

def train_gan(rgb_images, hsi_images):
    # Convert images to tensors
    rgb_images = tf.convert_to_tensor(rgb_images, dtype=tf.float32)
    hsi_images = tf.convert_to_tensor(hsi_images, dtype=tf.float32)

    # Enable tracing
    tf.summary.trace_on(graph=True, profiler=True)

    for epoch in range(epochs):
        for i in range(0, len(rgb_images), batch_size):
            rgb_batch = rgb_images[i:i + batch_size]
            hsi_batch = hsi_images[i:i + batch_size]

            # Augment the RGB batch using data_gen.flow
            augmented_rgb_batch = next(data_gen.flow(rgb_batch.numpy(), batch_size=batch_size))

            # Convert augmented batch back to tensor
            augmented_rgb_batch = tf.convert_to_tensor(augmented_rgb_batch, dtype=tf.float32)

            # Generate HSI images from the generator
            generated_hsi = generator(augmented_rgb_batch)

            # Resize the generated HSI and augmented RGB batch to the same shape
            target_shape = tf.shape(hsi_batch)[1:3]  # Assuming hsi_batch has the correct shape
            generated_hsi_resized = tf.image.resize(generated_hsi, target_shape)
            augmented_rgb_batch_resized = tf.image.resize(augmented_rgb_batch, target_shape)

            # Create input for discriminator
            combined_real = tf.concat([hsi_batch, rgb_batch], axis=-1)
            combined_fake = tf.concat([generated_hsi_resized, augmented_rgb_batch_resized], axis=-1)

            with tf.GradientTape() as disc_tape:
                disc_real = discriminator(combined_real)
                disc_fake = discriminator(combined_fake)
                disc_loss = discriminator_loss(disc_real, disc_fake)

            gradients_of_discriminator = disc_tape.gradient(disc_loss, discriminator.trainable_variables)
            discriminator_optimizer.apply_gradients(zip(gradients_of_discriminator, discriminator.trainable_variables))

            with tf.GradientTape() as gen_tape:
                generated_hsi = generator(augmented_rgb_batch)
                generated_hsi_resized = tf.image.resize(generated_hsi, target_shape)
                combined_fake = tf.concat([generated_hsi_resized, augmented_rgb_batch_resized], axis=-1)
                gen_loss = generator_loss(discriminator(combined_fake))

            gradients_of_generator = gen_tape.gradient(gen_loss, generator.trainable_variables)
            generator_optimizer.apply_gradients(zip(gradients_of_generator, generator.trainable_variables))

            # Calculate evaluation metrics
            mse = mean_squared_error(hsi_batch, generated_hsi_resized)
            psnr = peak_signal_to_noise_ratio(hsi_batch, generated_hsi_resized)
            sam = spectral_angle_mapper(hsi_batch, generated_hsi_resized)

            # Log metrics to TensorBoard
            with summary_writer.as_default():
                tf.summary.scalar('Discriminator Loss', tf.reduce_mean(disc_loss), step=epoch * len(rgb_images) // batch_size + i // batch_size)
                tf.summary.scalar('Generator Loss', tf.reduce_mean(gen_loss), step=epoch * len(rgb_images) // batch_size + i // batch_size)
                tf.summary.scalar('MSE', tf.reduce_mean(mse), step=epoch * len(rgb_images) // batch_size + i // batch_size)
                tf.summary.scalar('PSNR', tf.reduce_mean(psnr), step=epoch * len(rgb_images) // batch_size + i // batch_size)
                tf.summary.scalar('SAM', tf.reduce_mean(sam), step=epoch * len(rgb_images) // batch_size + i // batch_size)

            print(f'Epoch: {epoch}, Batch: {i // batch_size}, Discriminator Loss: {disc_loss.numpy()}, Generator Loss: {gen_loss.numpy()}, MSE: {mse.numpy()}, PSNR: {psnr.numpy()}, SAM: {sam.numpy()}')

        # Save model checkpoint at the end of each epoch
        checkpoint.save(file_prefix=checkpoint_prefix)

    # Export the trace
    with summary_writer.as_default():
        tf.summary.trace_export(name="GeneratorGraph", step=0, profiler_outdir=log_dir)

# Call the train function with your RGB and HSI images
train_gan(rgb_images, hsi_images)

Stacked images shape before transpose for folder ARAD_HS_0151: (31, 256, 256, 1)
Stacked images shape before transpose for folder ARAD_HS_0152: (31, 256, 256, 1)
Stacked images shape before transpose for folder ARAD_HS_0153: (31, 256, 256, 1)
Stacked images shape before transpose for folder ARAD_HS_0155: (31, 256, 256, 1)
Stacked images shape before transpose for folder ARAD_HS_0160: (31, 256, 256, 1)
Stacked images shape before transpose for folder ARAD_HS_0161: (31, 256, 256, 1)
Stacked images shape before transpose for folder ARAD_HS_0163: (31, 256, 256, 1)
Logging to directory: logs/
Epoch: 0, Batch: 0, Discriminator Loss: 1.3419339656829834, Generator Loss: 1.2493051290512085, MSE: 0.25001752376556396, PSNR: 6.020295143127441, SAM: nan
Epoch: 1, Batch: 0, Discriminator Loss: 1.0383168458938599, Generator Loss: 1.9509830474853516, MSE: 0.24992474913597107, PSNR: 6.021907329559326, SAM: nan
Epoch: 2, Batch: 0, Discriminator Loss: 0.8570003509521484, Generator Loss: 2.883414030075073

KeyboardInterrupt: 

### Steps to View Logs on TensorBoard

Run your training script:

```python your_training_script.py```

### Start TensorBoard:

```tensorboard --logdir=logs/```

Open TensorBoard in a Web Browser: Go to http://localhost:6006/ in your web browser.