In [None]:
import sys
import os
sys.path.append(os.path.abspath("/data2/eranario/scratch/rgb-to-multispectral-unet"))

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
import matplotlib.pyplot as plt

from src.dataset import PotatoDatasetSpectra
from src.model import UNeTransformedSpectral
from torch.utils.data import DataLoader
from tqdm import tqdm

In [None]:
# Seed PyTorch's global RNG (used e.g. by default weight init and DataLoader shuffling)
# so that re-running the notebook is repeatable.
torch.random.manual_seed(42)

# Dataset

In [None]:
# All inputs live under one Multispectral-Potato root: RGB and spectral image folders, plus the spectra CSV.
_potato_root = "/data2/eranario/data/Multispectral-Potato"
rgb_dir = f"{_potato_root}/Dataset/RGB_Images"
spectral_dir = f"{_potato_root}/Dataset/Spectral_Images"
spectral_file = f"{_potato_root}/Signals/spectra.csv"

In [None]:
# Preprocessing applied by the datasets: resize to the 224x224 model input, then convert to a tensor.
_transform_steps = [
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
]
transform = transforms.Compose(_transform_steps)

In [None]:
# One dataset per split; all splits share the same folders, transform and alignment setting.
train_dataset, val_dataset, test_dataset = (
    PotatoDatasetSpectra(rgb_dir, spectral_dir, spectral_file, transform=transform, mode=split, align=True)
    for split in ("train", "val", "test")
)

# print the size of the datasets
for split_label, split_dataset in (
    ("Train", train_dataset),
    ("Validation", val_dataset),
    ("Test", test_dataset),
):
    print(f"{split_label} dataset size: {len(split_dataset)}")

In [None]:
batch_size = 24
num_workers = 4


def _make_loader(dataset, shuffle):
    """Wrap `dataset` in a DataLoader that uses the shared batch size and worker count."""
    return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers)


# Only the training split is shuffled; val/test keep a fixed order, so the "first batch"
# used by the visualisation cells is always the same.
train_dataloader = _make_loader(train_dataset, shuffle=True)
val_dataloader = _make_loader(val_dataset, shuffle=False)
test_dataloader = _make_loader(test_dataset, shuffle=False)

In [None]:
# Check the first training batch. Each batch is (rgb, signal, *band_images); star-unpacking
# accepts any number of spectral band images (consistent with the training/test loops),
# whereas a fixed 6-way unpack raises ValueError unless there are exactly four.
for rgb, signal, *spectral_bands in train_dataloader:
    print(f"RGB shape: {rgb.shape}")
    print(f"Signal shape: {signal.shape}")
    print(f"Spectral band images per sample: {len(spectral_bands)}")
    break

In [None]:
# Display the dataset's `num_bands`; the model below is built with num_bands=train_dataset.num_bands.
# NOTE(review): presumably the length of the per-sample spectral signal — confirm in src.dataset.
train_dataset.num_bands

# Model

In [None]:
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# out_channels and num_bands come from the dataset, so the model's output matches the
# spectral targets and its spectral-signal input size.
# NOTE(review): num_tokens=196 is a 14x14 grid; with 224x224 inputs and patch_size=4 that
# only lines up if the model downsamples before tokenising — confirm against src.model.
model = UNeTransformedSpectral(
    in_channels=3,
    out_channels=len(train_dataset.channels),
    num_bands=train_dataset.num_bands,
    spectral_dim=256,
    num_tokens=196,
    patch_size=4,
)
model = model.to(device)

# MSE between predicted and target spectral images; Adam with a fixed learning rate.
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=1e-3)

# Training

In [11]:
from skimage.metrics import structural_similarity as ssim
from IPython.display import clear_output
import numpy as np
import time

def evaluateEuclideanDistance(predictedImage, groundTruthImage):
    """Mean per-pixel Euclidean distance between a predicted and a ground-truth image.

    Args:
        predictedImage: array-like image.
        groundTruthImage: array-like image with the same shape as `predictedImage`.
            A 2-D array is treated as a single-channel (H, W) image, so each pixel's
            distance is |pred - gt|. Any other rank is treated as channels-last: the
            last axis holds each pixel's channel vector and is reduced with an L2 norm.

    Returns:
        NumPy float: the mean over pixels of the per-pixel Euclidean distance.
    """
    diff = np.asarray(predictedImage) - np.asarray(groundTruthImage)
    if diff.ndim == 2:
        # Single band: every pixel is a scalar. Reducing over axis -1 here (the previous
        # behaviour) summed across each image row and returned the mean row norm instead
        # of a per-pixel distance — and the training/test loops call this per band.
        pixelDifferences = np.abs(diff)
    else:
        # Channels-last: L2 norm over each pixel's channel vector
        pixelDifferences = np.sqrt(np.sum(diff ** 2, axis=-1))
    # Average distance over all pixels of the image
    return np.mean(pixelDifferences)

num_epochs = 100
train_losses = []
val_losses = []
train_similarities = []         # per-epoch mean SSIM, training
val_similarities = []           # per-epoch mean SSIM, validation
train_euclidean_distances = []  # per-epoch mean Euclidean distance, training
val_euclidean_distances = []    # per-epoch mean Euclidean distance, validation


def _prepare_batch(batch):
    """Move one (rgb, signal, *band_images) batch to `device` and stack the band images into one target.

    The band images are stacked on dim 1 and dim 2 is squeezed (a no-op unless it has size 1;
    presumably each band image is single-channel — confirm against src.dataset).
    """
    rgb_images, spectral_signal, *band_images = batch
    targets = torch.stack(band_images, dim=1).squeeze(2)
    return rgb_images.to(device), spectral_signal.to(device), targets.to(device)


def _batch_metric_sums(outputs, targets):
    """Return (SSIM sum, Euclidean-distance sum) over the samples of one batch.

    Each sample is min-max normalised to [0, 1] jointly over its bands, so SSIM can use
    data_range=1.0. SSIM and Euclidean distance are then computed per band and averaged
    over bands. Sums (not means) are returned so the caller can divide by the dataset size.
    """
    preds = outputs.detach().cpu().numpy()
    gts = targets.detach().cpu().numpy()
    ssim_total = 0.0
    dist_total = 0.0
    for pred, gt in zip(preds, gts):
        pred = (pred - pred.min()) / (pred.max() - pred.min() + 1e-8)
        gt = (gt - gt.min()) / (gt.max() - gt.min() + 1e-8)
        ssim_total += np.mean([ssim(p, g, data_range=1.0) for p, g in zip(pred, gt)])
        dist_total += np.mean([evaluateEuclideanDistance(p, g) for p, g in zip(pred, gt)])
    return ssim_total, dist_total


def _run_epoch(loader, desc, train):
    """Run one pass over `loader` (optimising when `train` is True).

    Returns:
        (mean loss, mean SSIM, mean Euclidean distance) per sample. Batch losses are
        weighted by batch size, because MSELoss returns a per-batch mean and a smaller
        final batch would otherwise be over-weighted.
    """
    model.train(train)  # train(False) is the same as eval()
    loss_sum = 0.0
    ssim_sum = 0.0
    dist_sum = 0.0
    with torch.set_grad_enabled(train):
        for batch in tqdm(loader, desc=desc, leave=True):
            rgb_images, spectral_signal, spectral_images = _prepare_batch(batch)

            if train:
                optimizer.zero_grad()
            outputs = model(rgb_images, spectral_signal)
            loss = criterion(outputs, spectral_images)
            if train:
                loss.backward()
                optimizer.step()

            loss_sum += loss.item() * outputs.size(0)
            batch_ssim, batch_dist = _batch_metric_sums(outputs, spectral_images)
            ssim_sum += batch_ssim
            dist_sum += batch_dist

    num_samples = len(loader.dataset)
    return loss_sum / num_samples, ssim_sum / num_samples, dist_sum / num_samples


for epoch in range(num_epochs):
    epoch_tag = f"Epoch [{epoch+1}/{num_epochs}]"

    # Training phase
    avg_train_loss, avg_train_similarity, avg_train_euclidean_distance = _run_epoch(
        train_dataloader, f"{epoch_tag} - Training", train=True
    )
    train_losses.append(avg_train_loss)
    train_similarities.append(avg_train_similarity)
    train_euclidean_distances.append(avg_train_euclidean_distance)

    # Validation phase
    avg_val_loss, avg_val_similarity, avg_val_euclidean_distance = _run_epoch(
        val_dataloader, f"{epoch_tag} - Validation", train=False
    )
    val_losses.append(avg_val_loss)
    val_similarities.append(avg_val_similarity)
    val_euclidean_distances.append(avg_val_euclidean_distance)

    # Log results. With wait=True the summary stays visible until the next epoch's
    # clear_output, so no pause is needed after printing.
    clear_output(wait=True)
    print(f"Epoch [{epoch+1}/{num_epochs}]")
    print(f"  Training Loss: {avg_train_loss:.4f}, Validation Loss: {avg_val_loss:.4f}")
    print(f"  Train SSIM: {avg_train_similarity:.4f}, Val SSIM: {avg_val_similarity:.4f}")
    print(f"  Train Euclidean: {avg_train_euclidean_distance:.4f}, Val Euclidean: {avg_val_euclidean_distance:.4f}")

In [None]:
def _plot_epoch_curves(curves, title, ylabel, ylim=None):
    """Draw and show one figure of per-epoch curves.

    Args:
        curves: iterable of (values, marker, linestyle, label) tuples. Each curve's x-axis is
            1..len(values), so plotting still works if training stopped before num_epochs.
        title: figure title.
        ylabel: y-axis label.
        ylim: optional (bottom, top) y-axis limits.
    """
    _, ax = plt.subplots(figsize=(10, 6))
    for values, marker, linestyle, label in curves:
        ax.plot(range(1, len(values) + 1), values, marker=marker, linestyle=linestyle, label=label)
    ax.set_title(title, fontsize=16)
    ax.set_xlabel("Epoch", fontsize=14)
    ax.set_ylabel(ylabel, fontsize=14)
    ax.grid(True)
    ax.legend(fontsize=12)
    if ylim is not None:
        ax.set_ylim(*ylim)
    plt.show()


loss_curves = [
    (train_losses, "o", "-", "Training Loss"),
    (val_losses, "s", "--", "Validation Loss"),
]

# Training and validation loss curves
_plot_epoch_curves(loss_curves, "Training and Validation Loss per Epoch", "Loss")

# SSIM and Euclidean distance for training
_plot_epoch_curves(
    [
        (train_similarities, "o", "-", "Train SSIM"),
        (train_euclidean_distances, "^", "--", "Train Euclidean Distance"),
    ],
    "Training SSIM and Euclidean Distance per Epoch",
    "Metric Value",
)

# SSIM and Euclidean distance for validation
_plot_epoch_curves(
    [
        (val_similarities, "o", "-", "Validation SSIM"),
        (val_euclidean_distances, "^", "--", "Validation Euclidean Distance"),
    ],
    "Validation SSIM and Euclidean Distance per Epoch",
    "Metric Value",
)

# Loss curves again, zoomed: the y-axis tops out at 1.5x the lowest loss seen,
# so late-epoch differences are visible.
min_loss = min(min(train_losses), min(val_losses))
_plot_epoch_curves(loss_curves, "Training and Validation Loss per Epoch", "Loss", ylim=(0, min_loss * 1.5))

In [None]:
# Test phase: evaluate the trained model on the held-out split and keep the first
# batch's RGB inputs and attention weights for the visualisation cells below.
model.eval()
test_loss = 0.0
test_similarity_score = 0.0
test_euclidean_distance_score = 0.0

test_loop = tqdm(test_dataloader, desc="Testing", leave=True)

# Filled from the first test batch only
stored_rgb_images = None
stored_self_attn_weights = None
stored_cross_attn_weights = None

with torch.no_grad():
    for batch in test_loop:
        rgb_images, spectral_signal, *spectral_images = batch
        rgb_images = rgb_images.to(device)  # Input images
        spectral_signal = spectral_signal.to(device)  # Spectral signal
        # Stack per-band targets on dim 1; squeeze(2) drops dim 2 only if it has size 1
        spectral_images = torch.stack(spectral_images, dim=1).squeeze(2).to(device)

        outputs, attn_weights = model(rgb_images, spectral_signal, get_weights=True)
        loss = criterion(outputs, spectral_images)

        # criterion returns a per-batch mean; weight by batch size so a smaller final
        # batch is not over-represented in the per-sample average below.
        test_loss += loss.item() * outputs.size(0)

        # Keep visualisation data from the first batch only. Keying on the RGB tensor means
        # a model that returns None attention cannot make later batches overwrite it.
        if stored_rgb_images is None:
            stored_rgb_images = rgb_images.cpu()  # Store as CPU tensor
            stored_self_attn_weights = attn_weights["self"]  # Self-attention weights
            stored_cross_attn_weights = attn_weights["cross"]  # Cross-attention weights

        # SSIM and Euclidean distance per sample (one host transfer per batch)
        preds_np = outputs.cpu().numpy()
        targets_np = spectral_images.cpu().numpy()
        for output_img, groundtruth_img in zip(preds_np, targets_np):
            # Normalize each sample to [0, 1] jointly over its bands for SSIM with data_range=1.0
            output_img = (output_img - output_img.min()) / (output_img.max() - output_img.min() + 1e-8)
            groundtruth_img = (groundtruth_img - groundtruth_img.min()) / (groundtruth_img.max() - groundtruth_img.min() + 1e-8)

            # Per-band SSIM and Euclidean distance, averaged over bands
            test_similarity_score += np.mean([
                ssim(pred_band, gt_band, data_range=1.0)
                for pred_band, gt_band in zip(output_img, groundtruth_img)
            ])
            test_euclidean_distance_score += np.mean([
                evaluateEuclideanDistance(pred_band, gt_band)
                for pred_band, gt_band in zip(output_img, groundtruth_img)
            ])

num_test_samples = len(test_dataloader.dataset)
avg_test_loss = test_loss / num_test_samples
avg_test_similarity = test_similarity_score / num_test_samples
avg_test_euclidean_distance = test_euclidean_distance_score / num_test_samples

# Log test results
print("\nTest Results:")
print(f"  Test Loss: {avg_test_loss:.4f}")
print(f"  Test SSIM: {avg_test_similarity:.4f}")
print(f"  Test Euclidean Distance: {avg_test_euclidean_distance:.4f}")


In [None]:
def show_predictions(dataloader, model, device, channels=None, sample_idx=1):
    """
    Display the RGB input, ground-truth spectral channels and model predictions for one
    sample of the first batch of `dataloader`, one row per spectral channel.

    Args:
        dataloader: DataLoader yielding (rgb, signal, *band_images) batches.
        model: Trained model, called as model(rgb, signal).
        device: Device (CPU/GPU) to run the model on.
        channels: Channel names in output order (default ['Green', 'NIR', 'Red', 'Red Edge']).
        sample_idx: Index of the sample within the first batch to display. Defaults to 1
            (the second sample), which is what this function has always shown.

    Raises:
        IndexError: if `sample_idx` is outside the first batch.
    """
    model.eval()  # Set model to evaluation mode

    channels = channels or ['Green', 'NIR', 'Red', 'Red Edge']  # Default channel names
    num_spectral_channels = len(channels)

    # Get one batch of data
    rgb_images, spectral_signal, *spectral_images = next(iter(dataloader))
    batch_len = rgb_images.size(0)
    if not -batch_len <= sample_idx < batch_len:
        raise IndexError(f"sample_idx {sample_idx} is out of range for a batch of {batch_len} samples")

    rgb_images = rgb_images.to(device)
    spectral_signal = spectral_signal.to(device)
    spectral_images = torch.stack(spectral_images, dim=1).squeeze(2).to(device)  # Ground truth
    # Inference only: no_grad avoids building an autograd graph for the whole batch
    with torch.no_grad():
        predictions = model(rgb_images, spectral_signal)

    rgb_image = rgb_images[sample_idx].permute(1, 2, 0).cpu().numpy()  # CHW -> HWC for imshow
    ground_truth = spectral_images[sample_idx].cpu().numpy()  # indexed by channel below
    prediction = predictions[sample_idx].cpu().numpy()

    # squeeze=False keeps `axs` 2-D even when there is only one channel row
    _, axs = plt.subplots(
        num_spectral_channels, 3, figsize=(15, 5 * num_spectral_channels), squeeze=False
    )
    for channel_idx, channel_name in enumerate(channels):
        # RGB input is shown only in the first row; other rows leave column 0 empty
        if channel_idx == 0:
            axs[channel_idx, 0].imshow(rgb_image)
            axs[channel_idx, 0].set_title("RGB Input")

        axs[channel_idx, 1].imshow(ground_truth[channel_idx], cmap="viridis")
        axs[channel_idx, 1].set_title(f"GT: {channel_name}")

        axs[channel_idx, 2].imshow(prediction[channel_idx], cmap="viridis")
        axs[channel_idx, 2].set_title(f"Pred: {channel_name}")

        # Remove axes for cleaner visualization
        for ax in axs[channel_idx]:
            ax.axis("off")

    plt.tight_layout()
    plt.show()

In [None]:
# Display names for the model's output channels, in output order.
# NOTE(review): presumably matches the band order of train_dataset.channels — confirm.
channels = ["Green", "Near Infrared", "Red", "Red Edge"]
show_predictions(test_dataloader, model, device, channels)

In [None]:
import math
import matplotlib.pyplot as plt
import numpy as np
import torch

def attention_received_grid(attn):
    """Collapse one layer's self-attention into a (grid, grid) map of attention received per patch.

    Args:
        attn (torch.Tensor): shape (num_heads, num_patches, num_patches), assumed to be in
            (query, key) order with each query row normalised over keys, as
            torch.nn.MultiheadAttention returns it. TODO confirm against src.model.

    Returns:
        torch.Tensor: CPU tensor of shape (grid_size, grid_size). Each entry is the attention
        a key patch receives, averaged over heads and over queries.

    Raises:
        ValueError: if num_patches is not a perfect square.
    """
    avg_attn = attn.mean(dim=0).detach().cpu()  # average over heads -> (num_patches, num_patches)
    num_patches = avg_attn.shape[0]
    grid_size = int(math.sqrt(num_patches))
    if grid_size * grid_size != num_patches:
        raise ValueError(f"Number of patches ({num_patches}) is not a perfect square.")
    # Average over queries (dim 0). Averaging over keys (dim 1), as before, yields exactly
    # 1/num_patches for every patch when rows sum to 1, i.e. a flat, uninformative map.
    return avg_attn.mean(dim=0).reshape(grid_size, grid_size)


def visualize_self_attention_with_image(attention_weights, num_layers, input_image, layer_names=None):
    """
    Visualize, for each Transformer layer, how much self-attention each image patch receives
    (averaged over heads and queries), overlaid on the input image.

    Args:
        attention_weights (list[torch.Tensor]): Per-layer attention weights, each expected to
            have shape (num_heads, num_patches, num_patches).
        num_layers (int): Number of layers (from the start of the list) to visualize.
        input_image (torch.Tensor or np.ndarray): Input image with shape (H, W, C) or (H, W).
        layer_names (list[str], optional): Names of the layers for display.

    Raises:
        ValueError: if a layer's number of patches is not a perfect square.
    """
    plt.figure(figsize=(15, 6 * num_layers))

    # Normalize the image for display
    if isinstance(input_image, torch.Tensor):
        input_image = input_image.detach().cpu().numpy()
    if input_image.shape[-1] == 3:  # RGB image: rescale to [0, 1]; skip constant images (0/0)
        value_range = input_image.max() - input_image.min()
        if value_range > 0:
            input_image = (input_image - input_image.min()) / value_range
    image_cmap = "gray" if input_image.ndim == 2 else None

    for i, attn in enumerate(attention_weights[:num_layers]):
        attn_grid = attention_received_grid(attn)

        # Upsample the patch-grid map to the image's spatial size for the overlay
        resized_attn = torch.nn.functional.interpolate(
            attn_grid.float().unsqueeze(0).unsqueeze(0),  # add batch and channel dims
            size=tuple(input_image.shape[:2]),
            mode="bilinear",
            align_corners=False,
        ).squeeze().numpy()

        # Input image alone
        plt.subplot(num_layers, 2, 2 * i + 1)
        plt.imshow(input_image, cmap=image_cmap)
        plt.axis("off")
        plt.title("Input Image")

        # Input image with the attention map overlaid
        plt.subplot(num_layers, 2, 2 * i + 2)
        plt.imshow(input_image, cmap=image_cmap)
        plt.imshow(resized_attn, cmap="viridis", alpha=0.5)
        plt.colorbar()
        layer_title = layer_names[i] if layer_names else f"Transformer Block {i + 1}"
        plt.title(f"Self-Attention Map - {layer_title}")

    plt.tight_layout()
    plt.show()


In [None]:
# Overlay the stored self-attention maps on one RGB sample from the first test batch
if stored_self_attn_weights is not None and stored_rgb_images is not None:
    # Index 1 is the *second* image of the stored batch; permute CHW -> HWC for matplotlib
    input_image = stored_rgb_images[1].permute(1, 2, 0).numpy()
    # Min-max normalise to [0, 1] for display
    input_image = (input_image - input_image.min()) / (input_image.max() - input_image.min())

    self_attentions = stored_self_attn_weights
    num_layers = len(self_attentions)
    layer_names = [f"Transformer Block {i + 1}" for i in range(num_layers)]

    visualize_self_attention_with_image(self_attentions, num_layers, input_image, layer_names)


In [None]:
def patch_index(row, col, grid_size):
    """Return the row-major flat index of patch (row, col) in a grid_size x grid_size grid.

    Raises:
        IndexError: if row or col is outside [0, grid_size). Checking each coordinate
            separately stops an out-of-range column from silently wrapping to the next row.
    """
    if not (0 <= row < grid_size and 0 <= col < grid_size):
        raise IndexError(
            f"Patch ({row}, {col}) is out of bounds for grid size {grid_size}x{grid_size}."
        )
    return row * grid_size + col


def visualize_cross_attention_for_patch(
    cross_attention_weights_all_blocks,
    input_image,
    patch_coords,
    layer_names=None
):
    """
    Visualize cross-attention scores for a specific patch in the image across Transformer blocks.

    Args:
        cross_attention_weights_all_blocks (list[torch.Tensor]): List of cross-attention weights for each block.
            Each tensor should have shape (num_heads, num_patches, num_keys).
        input_image (torch.Tensor or np.ndarray): Input image in HWC format.
        patch_coords (tuple[int, int]): Coordinates (row, col) of the selected patch in the grid.
        layer_names (list[str], optional): Names of the Transformer blocks.

    Raises:
        ValueError: if a block's number of patches is not a perfect square.
        IndexError: if patch_coords lies outside the patch grid.
    """
    num_blocks = len(cross_attention_weights_all_blocks)
    selected_row, selected_col = patch_coords

    # Prepare the display image once; it is identical for every block. torch.Tensor has no
    # .copy(), so convert tensors to NumPy first (detach/cpu also handles GPU tensors).
    if isinstance(input_image, torch.Tensor):
        input_image_display = input_image.detach().cpu().numpy()
    else:
        input_image_display = np.array(input_image, copy=True)
    lowest = input_image_display.min()
    value_range = input_image_display.max() - lowest
    # Min-max normalise; a constant image becomes zeros instead of 0/0 NaNs
    input_image_display = (input_image_display - lowest) / value_range if value_range > 0 else input_image_display - lowest
    image_cmap = "gray" if input_image_display.ndim == 2 else None

    plt.figure(figsize=(15, 6 * num_blocks))

    for block_idx, cross_attention_weights in enumerate(cross_attention_weights_all_blocks):
        # Average attention weights across heads -> (num_patches, num_keys)
        avg_cross_attention = cross_attention_weights.mean(dim=0).detach().cpu()

        # Patch grid is assumed square
        num_patches = avg_cross_attention.shape[0]
        grid_size = int(math.sqrt(num_patches))
        if grid_size * grid_size != num_patches:
            raise ValueError(f"Number of patches ({num_patches}) is not a perfect square.")

        # Attention scores of the selected patch over all keys
        patch_attention_scores = avg_cross_attention[patch_index(selected_row, selected_col, grid_size)].numpy()

        # Pixel rectangle of the selected patch on the input image
        patch_height = input_image_display.shape[0] // grid_size
        patch_width = input_image_display.shape[1] // grid_size
        top, left = selected_row * patch_height, selected_col * patch_width

        plt.subplot(num_blocks, 2, 2 * block_idx + 1)
        plt.imshow(input_image_display, cmap=image_cmap)
        plt.gca().add_patch(plt.Rectangle(
            (left, top), patch_width, patch_height, edgecolor='red', facecolor='none', linewidth=2))
        plt.axis("off")
        plt.title(f"Input Image with Patch Highlighted (Block {block_idx + 1})")

        # Bar chart of the patch's attention over the keys
        plt.subplot(num_blocks, 2, 2 * block_idx + 2)
        plt.bar(range(len(patch_attention_scores)), patch_attention_scores, color="blue", alpha=0.7)
        plt.xlabel("Attention Keys")
        plt.ylabel("Attention Score")
        block_name = layer_names[block_idx] if layer_names else f"Transformer Block {block_idx + 1}"
        plt.title(f"Cross-Attention Scores for Selected Patch - {block_name}")

    plt.tight_layout()
    plt.show()


In [None]:
# Visualize cross-attention scores for the centre patch of one stored test image
if stored_rgb_images is not None and stored_cross_attn_weights is not None:
    # Index 1 is the *second* image of the stored first batch; permute CHW -> HWC for matplotlib
    input_image = stored_rgb_images[1].permute(1, 2, 0).numpy()

    # Normalize the input image for visualization
    input_image = (input_image - input_image.min()) / (input_image.max() - input_image.min())

    # Cross-attention weights
    cross_attentions = stored_cross_attn_weights

    # Number of layers and optional layer names
    num_layers = len(cross_attentions)
    layer_names = [f"Transformer Block {i + 1}" for i in range(num_layers)]

    # Select the centre patch. The plotting function averages each tensor over dim 0 and
    # reads num_patches from the next dim (dim 1 of the stored tensor). Deriving the grid
    # the same way keeps the choice centred for any token count; the previous
    # hard-coded (3, 3) is only the centre of a 7x7 grid.
    grid_size = int(math.sqrt(cross_attentions[0].shape[1])) if num_layers else 0
    selected_patch_coords = (grid_size // 2, grid_size // 2)

    # Visualize cross-attention scores for the selected patch
    visualize_cross_attention_for_patch(
        cross_attention_weights_all_blocks=cross_attentions,
        input_image=input_image,
        patch_coords=selected_patch_coords,
        layer_names=layer_names
    )
