In [None]:
# Install the azure-storage-blob and azure-identity packages; --quiet reduces pip's output.
!pip install azure-storage-blob azure-identity --quiet

In [None]:
import cv2
import io
import numpy as np
import pandas as pd
import os
import random
import time
import torch
import torch.nn.functional as F
import torchvision.transforms.functional as TF
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
from datetime import datetime
from io import BytesIO
from PIL import Image, ImageEnhance, ImageOps
from azure.storage.blob import BlobServiceClient
from google.colab import userdata
import plotly.graph_objects as go
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import MinMaxScaler
from torchvision import transforms
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
from transformers import AutoImageProcessor, ConvNextModel, SwinModel

In [None]:
# Date folders (YYYYMMDD blob-name prefixes) used to build the dataset.
# Other dates noted here earlier: "20170422", "20170508", "20170512".

train_dates = ["20170418"]

In [None]:
# from google.colab import drive

# # Mount Google Drive
# drive.mount('/content/drive')

In [None]:
# Authentication details, read from Colab's `userdata` store so that no
# credentials are hard-coded in the notebook.
account_name = userdata.get('storage_account_name')
account_key = userdata.get('storage_account_key')
container_name = userdata.get('blob_container_name')

# Connection string to Azure Blob Storage.
# NOTE: it embeds the account key, so avoid printing or logging it.
connection_string = f"DefaultEndpointsProtocol=https;AccountName={account_name};AccountKey={account_key};EndpointSuffix=core.windows.net"

# Module-level clients: `container_client` is the handle the loader functions
# and dataset classes in this notebook use to list and download blobs.
blob_service_client = BlobServiceClient.from_connection_string(connection_string)
container_client = blob_service_client.get_container_client(container_name)

In [None]:
# Columns kept from the aircraft metadata CSV. 'DateTime_UTC' is used as the
# time key for aligning metadata with images; the other entries are the flight
# parameters carried into the dataset.
aircraft_metadata_params = ['DateTime_UTC', 'GPS_MSL_Alt', 'Drift', 'Pitch', 'Roll', 'Vert_Velocity']
# Column of the LiDAR validation file that is read as the validation height.
CTH_col = 'top_height'

# Aircraft Metadata
def load_metadata(blob_name):
    """
    Download the aircraft metadata blob and parse it as CSV.

    Args:
        blob_name (str): name of the metadata blob inside the container
    Returns:
        (pd.DataFrame): the parsed CSV contents
    """
    raw_bytes = container_client.get_blob_client(blob_name).download_blob().readall()
    return pd.read_csv(io.BytesIO(raw_bytes))

# LiDAR Validation Heights
def load_validation_heights(blob_name):
    """
    Download the LiDAR validation-height blob and parse it as CSV.

    Args:
        blob_name (str): name of the validation blob inside the container
    Returns:
        (pd.DataFrame): the parsed CSV contents
    """
    raw_bytes = container_client.get_blob_client(blob_name).download_blob().readall()
    return pd.read_csv(io.BytesIO(raw_bytes))

# CloudDataset Class

In [None]:
# CloudDataset class integrating all three data sources (FEGS images, aircraft metadata and LiDAR validation heights) with temporal alignment
class CloudDataset(Dataset):
    """
    Per-image dataset that aligns three sources stored in Azure Blob Storage:
    tight-cropped .jpg frames, aircraft metadata and LiDAR validation heights.

    Relies on the module-level `container_client`, `load_metadata`,
    `load_validation_heights`, `aircraft_metadata_params` and `CTH_col`.

    Args:
        date_folders (list[str]): blob-name prefixes, one per date folder.
        transform (callable, optional): applied to each PIL image in `__getitem__`.

    Attributes:
        data_df (pd.DataFrame | None): one row per image, built eagerly in
            `__init__` (this lists and downloads blobs). None only if the
            internal length check fails.
    """

    def __init__(self, date_folders, transform=None):
        self.date_folders = date_folders
        self.transform = transform
        self.data_df = self._prepare_dataframe()

    def _prepare_dataframe(self):
        """
        Iterates over the date folders in azure blob storage and loads:
          1. .jpg Images from each sub-directory in the folder with '_tight_crop' in the name.
          2. Aircraft Metadata with 1-1 time alignment with the images.
          3. LiDAR Validation Heights, mapped using timestamp, if not available filled with NaN.
        Creates a df with following columns:
            timestamp, image_path, [...aircraft_metadata_params without DateTime_UTC...],
            validation_height, sequence_length

        Raises:
            FileNotFoundError: if a folder has no aircraft metadata file or no
                LiDAR validation file.
        """
        image_paths, timestamps, metadata_rows, validation_heights = [], [], [], []

        for folder in self.date_folders:
            print(f"Processing folder: {folder}")
            folder_image_paths, folder_timestamps, folder_validation_heights = [], [], []

            metadata_path, validation_path = None, None
            for blob in container_client.list_blobs(name_starts_with=folder):
                # extract image paths of all .jpg images in cropped folders
                if blob.name.endswith(".jpg") and "_tight_crop" in blob.name:
                    folder_image_paths.append(blob.name)
                    folder_timestamps.append(self._extract_timestamp_from_filename(blob.name))
                # extract the aircraft metadata file path
                if blob.name.startswith(f"{folder}/IWG1.") and "processed" in blob.name:
                    metadata_path = blob.name
                # extract the LiDAR validation file path
                if blob.name.startswith(f"{folder}/goesrplt_CPL_layers_") and blob.name.endswith("_processed.txt"):
                    validation_path = blob.name

            # Fail loudly when a required file is missing. Previously the frames were
            # left unbound (NameError) or silently reused from the previous folder.
            if not metadata_path:
                raise FileNotFoundError(f"No processed aircraft metadata file (IWG1.*) found under '{folder}'")
            if not validation_path:
                raise FileNotFoundError(f"No processed LiDAR validation file (goesrplt_CPL_layers_*) found under '{folder}'")
            metadata_df = load_metadata(metadata_path)
            validation_df = load_validation_heights(validation_path)

            # prepare LiDAR validation data: date + time, fractional seconds dropped
            combined = validation_df['date'] + ' ' + validation_df['timestamp']
            validation_df['datetime_combined'] = pd.to_datetime(
                combined.str.split('.').str[0], format="%Y-%m-%d %H:%M:%S"
            )

            # prepare aircraft metadata; .copy() so the column assignment below does
            # not write into a slice of the loaded frame
            metadata_df = metadata_df[aircraft_metadata_params].copy()
            metadata_df['DateTime_UTC'] = metadata_df['DateTime_UTC'].str.split('.').str[0]
            metadata_timestamps = pd.to_datetime(metadata_df['DateTime_UTC'], format="%Y-%m-%d %H:%M:%S")
            metadata_df = metadata_df.set_index(metadata_timestamps)
            # Keep one metadata row per second; duplicated timestamps would otherwise
            # multiply rows in the join and break the 1-1 alignment with the images.
            metadata_df = metadata_df[~metadata_df.index.duplicated(keep='first')]
            aligned_metadata = pd.DataFrame(index=pd.to_datetime(folder_timestamps, format="%H:%M:%S"))
            aligned_metadata = aligned_metadata.join(metadata_df, how='left')
            aligned_metadata = aligned_metadata[aircraft_metadata_params]

            folder_metadata_rows = aligned_metadata.values.tolist()

            # extract LiDAR validation height exactly matching the timestamp where available, else NaN
            for ts in folder_timestamps:
                folder_validation_heights.append(self._map_timestamp_to_lidar(ts, validation_df))

            # Create a folder-level DataFrame
            folder_df = pd.DataFrame({
                'timestamp': folder_timestamps,
                'image_path': folder_image_paths,
                **{param: [row[i] for row in folder_metadata_rows] for i, param in enumerate(aircraft_metadata_params)},
                'validation_height': folder_validation_heights
            })

            # Remove rows after the last valid validation_height in this folder
            last_valid_index = folder_df['validation_height'].last_valid_index()
            if last_valid_index is not None:
                folder_df_cleaned = folder_df.loc[:last_valid_index].copy()
            else:
                folder_df_cleaned = folder_df.copy()  # In case there are no valid entries

            # Extend to the global lists
            image_paths.extend(folder_df_cleaned['image_path'].tolist())
            timestamps.extend(folder_df_cleaned['timestamp'].tolist())
            metadata_rows.extend(folder_df_cleaned[aircraft_metadata_params].values.tolist())
            validation_heights.extend(folder_df_cleaned['validation_height'].tolist())

            # Print the lengths for the current folder
            print(f"Folder {folder}:")
            print(f"  Number of images: {len(folder_df_cleaned['image_path'])}")
            print(f"  Number of timestamps: {len(folder_df_cleaned['timestamp'])}")
            print(f"  Number of metadata rows: {len(folder_df_cleaned)}")
            print(f"  Number of validation heights: {len(folder_df_cleaned['validation_height'])}")

        # Print the final lengths after processing all folders
        print("After processing all folders combined:")
        print(f"  Total number of images: {len(image_paths)}")
        print(f"  Total number of timestamps: {len(timestamps)}")
        print(f"  Total number of metadata rows: {len(metadata_rows)}")
        print(f"  Total number of validation heights: {len(validation_heights)}")

        # Check for any mismatches
        if not (len(image_paths) == len(timestamps) == len(metadata_rows) == len(validation_heights)):
            print("Error: Length mismatch detected!")
            print(f"  Images: {len(image_paths)}")
            print(f"  Timestamps: {len(timestamps)}")
            print(f"  Metadata rows: {len(metadata_rows)}")
            print(f"  Validation heights: {len(validation_heights)}")
            return None

        # combine all aligned data in a df
        data = {
            'timestamp': timestamps,
            'image_path': image_paths,
            **{param: [row[i] for row in metadata_rows] for i, param in enumerate(aircraft_metadata_params)},
            'validation_height': validation_heights
        }
        df = pd.DataFrame(data)
        df = df.drop(columns=['DateTime_UTC'])

        # Add sequence length information for RNN
        self._add_sequence_length_column(df)
        return df

    def _extract_timestamp_from_filename(self, filename):
        """
        Extracts the timestamp from the image filename on the blob.
        path/to/blob/YYYYMMDD_HHMMSS_frame_n_cropped.jpg -> datetime(YYYY, MM, DD, HH, MM, SS)
        """
        name_parts = os.path.basename(filename).split("_")
        return datetime.strptime(name_parts[0] + name_parts[1], "%Y%m%d%H%M%S")

    def _map_timestamp_to_lidar(self, timestamp, validation_df):
        """
        extract LiDAR validation height exactly matching the timestamp where available, else NaN

        `validation_df['datetime_combined']` is parsed in place only if it is not
        already a datetime column, so repeated calls do not re-parse it.
        """
        if not pd.api.types.is_datetime64_any_dtype(validation_df['datetime_combined']):
            validation_df['datetime_combined'] = pd.to_datetime(validation_df['datetime_combined'], format="%Y-%m-%d %H:%M:%S")
        timestamp_dt = pd.to_datetime(timestamp, format="%Y-%m-%d %H:%M:%S")
        exact_match = validation_df[validation_df['datetime_combined'] == timestamp_dt]
        return exact_match[CTH_col].values[0] if not exact_match.empty else np.nan

    def __len__(self):
        return len(self.data_df)

    def __getitem__(self, idx):
        """
        Retrieve a data record from the dataset for a given index.
        Returns loaded and transformed image, metadata and validation height associated with that image if available.
        """
        row = self.data_df.iloc[idx]

        blob_client = container_client.get_blob_client(row['image_path'])
        img_data = blob_client.download_blob().readall()
        img = Image.open(io.BytesIO(img_data)).convert("RGB")
        if self.transform:
            img = self.transform(img)

        # 'DateTime_UTC' is dropped from data_df in _prepare_dataframe, so only the
        # parameters that are still columns are used (indexing it raised KeyError).
        present_params = [param for param in aircraft_metadata_params if param in row.index]
        metadata = torch.tensor([row[param] for param in present_params], dtype=torch.float32)
        validation_height = torch.tensor([row['validation_height']], dtype=torch.float32)

        return img, metadata, validation_height

    def _add_sequence_length_column(self, df):
        """
        Add a 'sequence_length' column in place and return `df`.

        A sequence is a run of consecutive rows that ends at a row with a non-NaN
        'validation_height'; every row in the run gets the run's row count. Rows
        after the last valid height belong to no sequence and stay missing.
        """
        df['sequence_length'] = np.nan
        length_col = df.columns.get_loc('sequence_length')

        # Track the start of each sequence
        sequence_start = 0
        for end in np.flatnonzero(df['validation_height'].notna().to_numpy()):
            end = int(end)
            # Rows sequence_start..end (inclusive) form one sequence
            df.iloc[sequence_start:end + 1, length_col] = end - sequence_start + 1
            sequence_start = end + 1

        if df['sequence_length'].isna().any():
            # Plain int cannot hold missing values (astype(int) raised here before);
            # the nullable integer dtype keeps trailing rows as <NA>.
            df['sequence_length'] = df['sequence_length'].astype('Int64')
        else:
            df['sequence_length'] = df['sequence_length'].astype(int)

        return df

In [None]:
# Image transform pipeline: resize to 224x224 and convert to a tensor.
# The Normalize step is currently disabled.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    #transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# Augmentation pipeline: currently empty (the ColorJitter step is commented out).
augmentations = transforms.Compose([
    #transforms.ColorJitter(brightness=0.9, contrast=1.6, saturation=1.2, hue=0.1),
])

# Create the full CloudDataset without splitting initially.
# NOTE: transform=None here, so the `transform` defined above is not used by this dataset.
# Building the dataset lists blobs and downloads the metadata / validation files.
full_dataset = CloudDataset(train_dates, transform=None)

# Extract the full dataframe from CloudDataset
full_dataframe = full_dataset.data_df

In [None]:
# Save the aligned dataframe to a local CSV (with header, without the index column).
full_dataframe.to_csv('train_dataset.csv', index=False, header=True)

In [None]:
# Preview the first 20 aligned rows.
full_dataframe.head(20)

In [None]:
# Preview the last 20 aligned rows.
full_dataframe.tail(20)

# Image loading and augmentation

In [None]:
def load_image_from_blob_cv(blob_img):
    """
    Loads the image from the Azure Blob Storage using OpenCV and returns it as a numpy array.
    Args:
        blob_img (str): name of the blob image in the container
    Returns:
        (numpy.ndarray): decoded image in RGB channel order (OpenCV decodes to BGR;
            the channels are swapped before returning)
    Raises:
        ValueError: if the blob is empty or its bytes cannot be decoded as an image
    """
    blob_client = container_client.get_blob_client(blob_img)
    blob_data = blob_client.download_blob().readall()

    # View the downloaded bytes as a uint8 buffer without an intermediate bytearray copy
    image_array = np.frombuffer(blob_data, dtype=np.uint8)

    # imdecode returns None when the data is not a decodable image; check before
    # cvtColor so the failure names the offending blob.
    img_bgr = cv2.imdecode(image_array, cv2.IMREAD_COLOR) if image_array.size else None
    if img_bgr is None:
        raise ValueError(f"Could not decode blob '{blob_img}' as an image")

    return cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)

def augment_color_image(img, contrast_factor=1.2, brightness_beta=10, kernel_size=(5, 5), blur=True, sharpen=True):
    """
    Augments the color image to enhance cloud features. Includes:
      1. Contrast and Brightness Adjustment
      2. Gaussian Blur (optional)
      3. CLAHE (Contrast Limited Adaptive Histogram Equalization) on the L channel in LAB space
      4. Sharpening (optional)

    Args:
        img (numpy.ndarray): color image in RGB format
        contrast_factor (float): contrast factor (alpha of cv2.convertScaleAbs)
        brightness_beta (int): brightness offset (beta of cv2.convertScaleAbs)
        kernel_size (tuple): kernel size for Gaussian Blur
        blur (bool): whether to apply Gaussian Blur
        sharpen (bool): whether to apply the 3x3 sharpening filter. Defaults to
            True, which matches the previous behaviour where sharpening was
            always applied.

    Returns:
        (numpy.ndarray): augmented image
    """
    # Step 1: Adjust contrast and brightness
    img_enhanced = cv2.convertScaleAbs(img, alpha=contrast_factor, beta=brightness_beta)

    # Step 2: Optionally apply Gaussian Blur
    if blur:
        img_enhanced = cv2.GaussianBlur(img_enhanced, kernel_size, 0)

    # Step 3: Apply CLAHE to the L (lightness) channel only, leaving A and B untouched
    img_lab = cv2.cvtColor(img_enhanced, cv2.COLOR_RGB2LAB)
    l, a, b = cv2.split(img_lab)
    clahe = cv2.createCLAHE(clipLimit=1.5, tileGridSize=(4, 4))
    img_clahe = cv2.merge((clahe.apply(l), a, b))

    # Step 4: Convert back to RGB color space
    img_clahe_rgb = cv2.cvtColor(img_clahe, cv2.COLOR_LAB2RGB)

    # Step 5: Optionally apply sharpening
    if not sharpen:
        return img_clahe_rgb
    kernel = np.array([[0, -1, 0], [-1, 5, -1], [0, -1, 0]])
    return cv2.filter2D(img_clahe_rgb, -1, kernel)

# Modify augment_color_image to use the original CLAHE-only method
def augment_color_image_clahe_only(img, clip_limit=1.5, tile_grid_size=(4, 4)):
    """
    Apply CLAHE (Contrast Limited Adaptive Histogram Equalization) to a color image
    and nothing else.

    The image is converted to LAB, CLAHE is run on the L channel only, and the
    result is converted back to RGB.

    Args:
        img (numpy.ndarray): color image in RGB format
        clip_limit (float): clip limit passed to cv2.createCLAHE
        tile_grid_size (tuple): tile grid size passed to cv2.createCLAHE

    Returns:
        (numpy.ndarray): RGB image with CLAHE applied to its lightness channel
    """
    lightness, chroma_a, chroma_b = cv2.split(cv2.cvtColor(img, cv2.COLOR_RGB2LAB))

    equalizer = cv2.createCLAHE(clipLimit=clip_limit, tileGridSize=tile_grid_size)
    equalized_lightness = equalizer.apply(lightness)

    lab_equalized = cv2.merge((equalized_lightness, chroma_a, chroma_b))
    return cv2.cvtColor(lab_equalized, cv2.COLOR_LAB2RGB)

def augment_greyscale_image(img, contrast_factor=1.5, brightness_beta=30, kernel_size=(5, 5), blur=True):
    """
    Enhance a greyscale image. The steps, in order:
      1. Contrast and brightness adjustment (cv2.convertScaleAbs)
      2. Gaussian blur (only when `blur` is True)
      3. Global histogram equalization
      4. CLAHE (clipLimit=2.0, 8x8 tiles)
      5. Sharpening with a 3x3 kernel
    Args:
        img (numpy.ndarray): greyscale image
        contrast_factor (float): contrast factor
        brightness_beta (int): brightness factor
        kernel_size (tuple): kernel size for Gaussian Blur
        blur (bool): whether to apply Gaussian Blur
    Returns:
        (numpy.ndarray): augmented image
    """
    adjusted = cv2.convertScaleAbs(img, alpha=contrast_factor, beta=brightness_beta)
    smoothed = cv2.GaussianBlur(adjusted, kernel_size, 0) if blur else adjusted
    equalized = cv2.equalizeHist(smoothed)
    locally_equalized = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8)).apply(equalized)
    sharpen_kernel = np.array([[0, -1, 0], [-1, 5, -1], [0, -1, 0]])
    return cv2.filter2D(locally_equalized, -1, sharpen_kernel)

def undistort_fisheye_image(img, K, D, balance=0.5):
    """
    Undistort a fisheye image with OpenCV's fisheye camera model.

    Args:
        img (numpy.ndarray): The distorted fisheye image.
        K (numpy.ndarray): The camera matrix.
        D (numpy.ndarray): The distortion coefficients.
        balance (float): Forwarded to
            cv2.fisheye.estimateNewCameraMatrixForUndistortRectify when the new
            camera matrix is estimated.

    Returns:
        undistorted_img (numpy.ndarray): The undistorted image, remapped at the
            input's width and height.
    """
    height, width = img.shape[:2]
    frame_size = (width, height)  # OpenCV expects (width, height)
    no_rotation = np.eye(3)       # identity rectification transform

    # Estimate the camera matrix to use for the undistorted output
    new_K = cv2.fisheye.estimateNewCameraMatrixForUndistortRectify(K, D, frame_size, no_rotation, balance=balance)

    # Build the remap tables once, then resample the image through them
    map_a, map_b = cv2.fisheye.initUndistortRectifyMap(K, D, no_rotation, new_K, frame_size, cv2.CV_16SC2)
    return cv2.remap(img, map_a, map_b, interpolation=cv2.INTER_LINEAR, borderMode=cv2.BORDER_CONSTANT)

def visualize_augmentations_cv(image_list, enhance=True):
    """
    Visualizes original, greyscale and augmented greyscale images side by side,
    one row per image.
    Args:
        image_list (list): List of image paths in Azure Blob Storage
        enhance (bool): Whether to apply image augmentation; when False the third
            column shows the unmodified greyscale image
    Returns:
        None
    """
    # Nothing to draw; plt.subplots cannot create a grid with zero rows
    if not image_list:
        return

    # squeeze=False keeps `axes` two-dimensional even for a single image, so the
    # axes[idx, col] indexing below works for any list length.
    fig, axes = plt.subplots(len(image_list), 3, figsize=(15, 5 * len(image_list)), squeeze=False)

    for idx, blob_img in enumerate(image_list):
        original_image = load_image_from_blob_cv(blob_img)
        greyscale_image = cv2.cvtColor(original_image, cv2.COLOR_RGB2GRAY)
        augmented_image = greyscale_image.copy()

        if enhance:
            augmented_image = augment_greyscale_image(augmented_image, contrast_factor=1.5, brightness_beta=30, kernel_size=(5, 5))

        axes[idx, 0].imshow(original_image)
        axes[idx, 0].set_title("Original Image (RGB)")
        axes[idx, 0].axis('off')

        axes[idx, 1].imshow(greyscale_image, cmap='gray')
        axes[idx, 1].set_title("Greyscale Image")
        axes[idx, 1].axis('off')

        axes[idx, 2].imshow(augmented_image, cmap='gray')
        axes[idx, 2].set_title("Augmented Image")
        axes[idx, 2].axis('off')

    plt.tight_layout()
    plt.show()

def random_crop_sequence(image_sequence, crop_size_minimum=224):
    """
    Apply one identical random square crop to every image in a sequence.
    The crop side is drawn uniformly between `crop_size_minimum` and the shorter
    image side; the crop position is drawn uniformly among the positions that fit.

    Args:
        image_sequence: A list or tensor of images in the sequence (each image should be the same size;
            the last two dimensions are read as height and width).
        crop_size_minimum: The minimum crop size (default is 224). If the images are
            smaller than this, the minimum is lowered to the shorter image side
            instead of raising ValueError from random.randint.

    Returns:
        cropped_sequence: The cropped sequence of images.
        new_center_coords: (x, y) of the original image center expressed in the
            cropped image's coordinates, clamped to the crop bounds.
    """
    # Get the height and width of the original images
    orig_height, orig_width = image_sequence[0].shape[-2:]

    # Randomly select the (square) crop size; the lower bound can never exceed
    # the largest square that fits inside the image.
    max_crop_size = min(orig_height, orig_width)
    min_crop_size = min(crop_size_minimum, max_crop_size)
    crop_h = random.randint(min_crop_size, max_crop_size)
    crop_w = crop_h  # Keeping the crop square for simplicity

    # Original center coordinates of the image
    orig_center_x = orig_width // 2
    orig_center_y = orig_height // 2

    # Randomly select the top-left corner for the crop
    top = random.randint(0, orig_height - crop_h)
    left = random.randint(0, orig_width - crop_w)

    # Apply the same crop to each image in the sequence
    cropped_sequence = [TF.crop(img, top, left, crop_h, crop_w) for img in image_sequence]

    # Shift the original center by the crop offset, then clamp it into the crop
    # (the center can fall outside when the crop window does not cover it).
    new_center_x = max(0, min(orig_center_x - left, crop_w - 1))
    new_center_y = max(0, min(orig_center_y - top, crop_h - 1))

    return cropped_sequence, (new_center_x, new_center_y)

def resize_sequence_and_adjust_center(cropped_sequence, new_center_coords, target_size=(224, 224)):
    """
    Resize every image of a cropped sequence and rescale a center point to match.

    Args:
        cropped_sequence: The list of cropped images; the last two dimensions of
            the first image are read as (height, width).
        new_center_coords: The (x, y) coordinates of the center in the cropped image.
        target_size: The desired output size (height, width), default is (224, 224).

    Returns:
        resized_sequence: The resized sequence of images.
        resized_center_coords: The (x, y) center scaled by the per-axis resize
            factor and truncated to int.
    """
    source_h, source_w = cropped_sequence[0].shape[-2:]
    out_h, out_w = target_size
    center_x, center_y = new_center_coords

    resized_sequence = []
    for frame in cropped_sequence:
        resized_sequence.append(TF.resize(frame, size=target_size))

    # Per-axis scale factors map crop coordinates to resized coordinates
    scaled_center = (
        int(center_x * (out_w / source_w)),
        int(center_y * (out_h / source_h)),
    )
    return resized_sequence, scaled_center

In [None]:
# Sample blobs from the 20170418 folder used to eyeball the greyscale augmentation:
# two from the `_frames` sub-folder, three from `_frames_cropped` and three from
# `_frames_tight_crop` (the latter two sets share the same frame names).
images = [
    "20170418/170418_175706_183328_frames/20170418_175706_frame_0.jpg",
    "20170418/170418_175706_183328_frames/20170418_175707_frame_60.jpg",
    "20170418/170418_175706_183328_frames_cropped/20170418_175708_frame_120_cropped.jpg",
    "20170418/170418_175706_183328_frames_cropped/20170418_175709_frame_180_cropped.jpg",
    "20170418/170418_175706_183328_frames_cropped/20170418_175710_frame_240_cropped.jpg",
    "20170418/170418_175706_183328_frames_tight_crop/20170418_175708_frame_120_cropped.jpg",
    "20170418/170418_175706_183328_frames_tight_crop/20170418_175709_frame_180_cropped.jpg",
    "20170418/170418_175706_183328_frames_tight_crop/20170418_175710_frame_240_cropped.jpg",
]

# Downloads each blob and shows original / greyscale / augmented columns, one row per image.
visualize_augmentations_cv(images, True)

# CNN-RNN

In [None]:
class Cloud2CloudDataset(Dataset):
    """
    Sequence-level dataset: each item is a run of consecutive images ending at a
    row that carries a validation height, together with the flight data of the run.

    Relies on the module-level `load_image_from_blob_cv`,
    `augment_color_image_clahe_only`, `random_crop_sequence` and
    `resize_sequence_and_adjust_center`.
    """

    def __init__(self,
                 dataframe,
                 normalization_params=None,
                 transform=None, augmentations=None,
                 apply_normalization=True,
                 apply_crop_and_scale=True,
                 resize_size=224):  # Add resize_size parameter with default value 224
        """
        Args:
            dataframe (pd.DataFrame): dataframe containing image paths, flight data, ground truth
                and a 'sequence_length' column.
            normalization_params (dict, optional): {column: {'min': ..., 'max': ...}} for all the
                normalized fields. Computed from `dataframe` when omitted and normalization is on.
            transform (callable, optional): stored on the instance; not applied in `__getitem__`.
            augmentations (callable, optional): optional augmentations to apply to each image.
            apply_normalization (bool): whether to min-max normalize the dataframe columns.
            apply_crop_and_scale (bool): whether `__getitem__` applies a random crop before resizing.
            resize_size (int): side of the square output images; also the minimum random crop size.
        """
        self.dataframe = dataframe
        self.transform = transform
        self.augmentations = augmentations
        self.apply_normalization = apply_normalization
        self.apply_crop_and_scale = apply_crop_and_scale
        self.resize_size = resize_size  # Save resize_size to be used later
        self.columns_to_normalize = ['validation_height', 'GPS_MSL_Alt', 'Drift', 'Pitch', 'Roll', 'Vert_Velocity']
        # Always defined, so the denormalize helpers can report a clear error
        # when normalization was never set up.
        self.normalization_params = normalization_params

        if self.apply_normalization:
            # Work on a private copy: normalizing in place rewrote the caller's
            # dataframe, so building a second dataset from it normalized twice.
            self.dataframe = dataframe.copy()

            # Calculate or use provided normalization parameters
            self.normalization_params = normalization_params or self._calculate_normalization_params(dataframe, self.columns_to_normalize)

            for col in self.columns_to_normalize:
                col_min = self.normalization_params[col]['min']
                col_max = self.normalization_params[col]['max']
                col_range = col_max - col_min
                if col_range == 0:
                    # Constant column: avoid 0/0 and map every value to 0.0
                    col_range = 1
                self.dataframe[col] = (self.dataframe[col] - col_min) / col_range

        # Track the indices where sequences start
        self.sequence_indices = self._generate_sequence_indices()

    def _calculate_normalization_params(self, dataframe, columns_to_normalize):
        """
        Manually calculate min and max values for the columns and store them for consistency across datasets.
        """
        return {
            col: {'min': dataframe[col].min(), 'max': dataframe[col].max()}
            for col in columns_to_normalize
        }

    def _require_normalization_params(self):
        """Return the normalization parameters, or raise if none are available."""
        if self.normalization_params is None:
            raise ValueError("No normalization parameters available: dataset was created without normalization.")
        return self.normalization_params

    def denormalize_validation_height(self, normalized_height):
        """
        Denormalize the validation height using the stored min and max values for validation_height.

        Args:
            normalized_height (float or np.array): The normalized validation height(s) to denormalize.

        Returns:
            float or np.array: The denormalized validation height(s).

        Raises:
            ValueError: if the dataset has no normalization parameters.
        """
        height_params = self._require_normalization_params()['validation_height']
        col_min = height_params['min']
        col_max = height_params['max']

        print(f"  Max: {col_max} Min: {col_min}")

        # Invert the min-max scaling
        return normalized_height * (col_max - col_min) + col_min

    def denormalize_flight_data(self, normalized_flight_data):
        """
        Denormalize flight data using the stored min and max values for each column.

        Args:
            normalized_flight_data (np.array): Normalized flight data for one time step, ordered as
                GPS_MSL_Alt, Drift, Pitch, Roll, Vert_Velocity.

        Returns:
            np.array: Denormalized flight data.

        Raises:
            ValueError: if the dataset has no normalization parameters.
        """
        params = self._require_normalization_params()
        denormalized_flight_data = []
        for i, col in enumerate(self.columns_to_normalize[1:]):  # Skip validation_height
            col_min = params[col]['min']
            col_max = params[col]['max']
            denormalized_flight_data.append(normalized_flight_data[i] * (col_max - col_min) + col_min)
        return np.array(denormalized_flight_data)

    def _generate_sequence_indices(self):
        """
        Generate a list of indices where each sequence starts.
        If a NaN is encountered in the sequence_length, the process stops.

        Raises:
            ValueError: if a sequence_length below 1 is found (it would never advance).
        """
        sequence_indices = []
        idx = 0

        while idx < len(self.dataframe):
            sequence_length = self.dataframe.iloc[idx]['sequence_length']

            if pd.isna(sequence_length):
                # Stop if sequence_length is NaN (no more sequences)
                break

            sequence_length = int(sequence_length)
            if sequence_length < 1:
                raise ValueError(f"Invalid sequence_length {sequence_length} at row {idx}")
            sequence_indices.append(idx)
            idx += sequence_length  # Move to the start of the next sequence

        return sequence_indices

    def __len__(self):
        return len(self.sequence_indices)

    def __getitem__(self, idx):
        """
        Load one image sequence.

        Returns:
            image_sequence (torch.Tensor): stacked images, each resize_size x resize_size
            flight_data (np.ndarray): one row of flight parameters per image
            validation_height: value from the last row of the sequence
            resized_center_coords (tuple): (x, y) of the original image center in the output images
        """
        # Get the starting index of the sequence
        start_idx = self.sequence_indices[idx]

        # Get the sequence length for this specific starting point
        sequence_length = int(self.dataframe.iloc[start_idx]['sequence_length'])
        sequence_rows = self.dataframe.iloc[start_idx:start_idx + sequence_length]

        # Fetch the image sequence from Azure Blob
        image_sequence = []
        for image_path in sequence_rows['image_path']:
            # Load image using OpenCV (RGB array)
            img_rgb = load_image_from_blob_cv(image_path)

            # Fisheye undistortion is disabled; the camera matrix and distortion
            # coefficients that were built per image were unused and are removed.
            # To re-enable, call undistort_fisheye_image(img_rgb, K, D, balance=0.5) here.

            # Apply the color augmentation instead of grayscale
            img_rgb_converted = augment_color_image_clahe_only(img_rgb, clip_limit=3.0, tile_grid_size=(8, 8))

            # Apply additional augmentations to the RGB image if any
            if self.augmentations:
                img_rgb_converted = self.augmentations(img_rgb_converted)

            # Convert to tensor and append to the sequence
            image_sequence.append(TF.to_tensor(img_rgb_converted))

        output_size = (self.resize_size, self.resize_size)

        if self.apply_crop_and_scale:
            # Apply random cropping to the entire sequence with varying crop sizes (minimum of self.resize_size, up to full image size)
            cropped_sequence, new_center_coords = random_crop_sequence(image_sequence, crop_size_minimum=self.resize_size)

            # Resize to resize_size (not the helper's 224 default) so both branches
            # produce images of the same size, and adjust the center accordingly.
            resized_sequence, resized_center_coords = resize_sequence_and_adjust_center(
                cropped_sequence, new_center_coords, target_size=output_size
            )
        else:
            # Resize original images (whatever their original size) to resize_size x resize_size
            resized_sequence = [TF.resize(img, size=output_size) for img in image_sequence]

            # Since no cropping is applied, the center is the center of the resized square image
            resized_center_coords = (self.resize_size // 2, self.resize_size // 2)

        # Stack images into a tensor
        image_sequence = torch.stack(resized_sequence)

        # Fetch the additional flight data from the dataframe
        flight_data = sequence_rows[['GPS_MSL_Alt', 'Drift', 'Pitch', 'Roll', 'Vert_Velocity']].values

        # Fetch the validation height (target) from the last row of the sequence
        validation_height = self.dataframe.iloc[start_idx + sequence_length - 1]['validation_height']

        return image_sequence, flight_data, validation_height, resized_center_coords

In [None]:
# Create the full Cloud2CloudDataset (no normalization yet).
# It is used here only for its `sequence_indices` and `sequence_length` values.
full_cloud2cloud_dataset = Cloud2CloudDataset(dataframe=full_dataframe, transform=transform, augmentations=augmentations, apply_normalization=False, resize_size=512)

# Split sequence start indices (not individual rows) into training and validation
# sets, so every sequence stays whole; random_state fixes the split.
train_sequence_indices, val_sequence_indices = train_test_split(full_cloud2cloud_dataset.sequence_indices, test_size=0.2, random_state=42)

# Extract rows for training dataset based on sequence indices
train_rows = []
for seq_start in train_sequence_indices:
    sequence_length = full_cloud2cloud_dataset.dataframe.iloc[seq_start]['sequence_length']
    train_rows.extend(range(seq_start, seq_start + sequence_length))

# reset_index keeps the positional sequence starts valid in the new frame
train_dataframe = full_dataframe.iloc[train_rows].reset_index(drop=True)

# Create the training dataset and calculate normalization parameters
# (no normalization_params passed, so min/max come from the training rows)
train_cloud2cloud_dataset = Cloud2CloudDataset(dataframe=train_dataframe, transform=transform, augmentations=augmentations, apply_normalization=True, apply_crop_and_scale=True, resize_size=512)

# Get normalization parameters from the training dataset
normalization_params = train_cloud2cloud_dataset.normalization_params

# Extract rows for validation dataset based on sequence indices
val_rows = []
for seq_start in val_sequence_indices:
    sequence_length = full_cloud2cloud_dataset.dataframe.iloc[seq_start]['sequence_length']
    val_rows.extend(range(seq_start, seq_start + sequence_length))

val_dataframe = full_dataframe.iloc[val_rows].reset_index(drop=True)

# Create the validation dataset using the same normalization parameters
# (training min/max are reused; random cropping is turned off for validation)
val_cloud2cloud_dataset = Cloud2CloudDataset(dataframe=val_dataframe, normalization_params=normalization_params, transform=transform, augmentations=augmentations, apply_normalization=True, apply_crop_and_scale=False, resize_size=512)

# Create DataLoaders for training and validation sets (one sequence per batch)
train_dataloader_single_batch = DataLoader(train_cloud2cloud_dataset, batch_size=1, shuffle=True)
val_dataloader_single_batch = DataLoader(val_cloud2cloud_dataset, batch_size=1, shuffle=False)

In [None]:
# Function to display a grid of images with a red dot at the revised center coordinates only on the last image
def show_image_grid_with_center(images, center_coords, titles=None):
    """Show an image sequence in a single row and mark a point on the last frame.

    Args:
        images (torch.Tensor): Images indexed along dim 0. Each ``images[i]``
            is permuted with ``(1, 2, 0)`` before plotting, so it is expected
            to be channel-first (C, H, W). The tensor may live on any device
            and may require grad; it is detached and copied to the CPU here.
        center_coords: ``(x, y)`` pair drawn as a red dot on the last image only.
        titles (sequence of str, optional): One title per image.
    """
    # Number of images in the sequence
    num_images = images.shape[0]

    # Set up the grid (1 row, num_images columns)
    fig, axes = plt.subplots(1, num_images, figsize=(15, 15))

    if num_images == 1:
        axes = [axes]  # Ensure axes is iterable even with one image

    for i, ax in enumerate(axes):
        # .numpy() raises on CUDA tensors and on tensors that require grad,
        # so detach and move to the CPU first (both are no-ops otherwise).
        img = images[i].detach().cpu().permute(1, 2, 0).numpy()  # CHW -> HWC
        img = np.clip(img, 0, 1)  # Keep values inside imshow's [0, 1] float range
        ax.imshow(img)

        # Overlay red dot only on the last image in the sequence
        if i == num_images - 1:
            center_x, center_y = center_coords
            ax.plot(center_x, center_y, 'ro')  # 'ro' = red circle marker

        ax.axis('off')

        if titles:
            ax.set_title(titles[i])

    plt.show()

# Preview the first five training sequences: flight data (normalized and
# denormalized), validation height, and the frames with the center point marked.
for i, (images, flight_data, validation_height, resized_center_coords) in enumerate(train_dataloader_single_batch):
    if i >= 5:  # Stop once five sequences have been shown
        break

    print(f"Sequence {i + 1}:")

    # Flight data exactly as produced by the loader
    print("Normalized Flight Data:")
    flight_data_np = flight_data.squeeze(0).numpy()  # Drop the batch axis, then go to NumPy
    columns = ['GPS_MSL_Alt', 'Drift', 'Pitch', 'Roll', 'Vert_Velocity']

    for step, data in enumerate(flight_data_np):
        flight_info = {name: value for name, value in zip(columns, data)}
        print(f"  Step {step + 1}: {flight_info}")

    # Same steps again, mapped back through the training dataset's denormalizer
    print("Denormalized Flight Data:")
    for step, data in enumerate(flight_data_np):
        denormalized_data = train_cloud2cloud_dataset.denormalize_flight_data(data)
        flight_info_denorm = {name: value for name, value in zip(columns, denormalized_data)}
        print(f"  Step {step + 1} (denormalized): {flight_info_denorm}")

    # Validation height, first as loaded ...
    print("Normalized Validation Height:", validation_height.item())

    # ... then mapped back through the training dataset's denormalizer
    validation_height_denorm = train_cloud2cloud_dataset.denormalize_validation_height(
        validation_height.item()
    )
    print(f"Denormalized Validation Height: {validation_height_denorm:.2f} meters")

    # The loader yields the center as a two-item sequence, so unpack it directly
    resized_center_x, resized_center_y = resized_center_coords
    print(f"Resized Center Coordinates: (x: {resized_center_x}, y: {resized_center_y})")

    # Plot the frames (batch axis removed) with the red dot on the last one
    show_image_grid_with_center(
        images.squeeze(0),
        (resized_center_x, resized_center_y),
    )

# Height Field Model

In [None]:
import torch
import numpy as np
import plotly.graph_objects as go
import matplotlib.pyplot as plt

# Download MiDaS model using torch.hub
def load_midas_model(device):
    """Load the MiDaS ``DPT_Large`` entry point from torch.hub for inference.

    Args:
        device: Target passed to the loaded model's ``.to()``.

    Returns:
        The object returned by ``torch.hub.load``, after ``.to(device)`` and
        ``.eval()`` have been called on it.
    """
    model = torch.hub.load("intel-isl/MiDaS", "DPT_Large")
    model.to(device)
    model.eval()  # inference mode
    return model

# Function to generate a relative height map with the center pixel set to zero
def generate_relative_height_map(midas_model, image_tensor, device):
    """
    Generates a relative height map using MiDaS, setting the center pixel to zero.

    Args:
        midas_model (nn.Module): The MiDaS model for depth estimation. It is
            called directly on the batched image tensor.
        image_tensor (torch.Tensor): A single 512x512 image, shaped either
            (channels, height, width) or (1, channels, height, width).
        device (torch.device): Device the image is moved to before inference.

    Returns:
        relative_height_map (np.array): The model output with singleton axes
            squeezed away, minus its own center value, so the center pixel is zero.

    Raises:
        AssertionError: If the image is not 512x512 or is not a single image.

    NOTE(review): the image is fed to the model as-is, with no MiDaS-specific
    preprocessing in this function -- confirm upstream transforms cover it.
    """
    # Print the original image size before any processing
    print("Original image size:", image_tensor.shape)

    # Validate the trailing (H, W) axes. Checking shape[1:] here would reject
    # every already-batched (1, C, 512, 512) input before the batch handling
    # below could ever run.
    assert tuple(image_tensor.shape[-2:]) == (512, 512), "Image must be 512x512 for the large MiDaS model."

    # Ensure the input tensor has batch and channel dimensions
    if image_tensor.dim() == 3:
        image_tensor = image_tensor.unsqueeze(0)  # Add batch dimension

    # The center-pixel logic below works on one 2D map, so allow only one image
    assert image_tensor.dim() == 4 and image_tensor.shape[0] == 1, (
        "Expected a single image shaped (C, H, W) or (1, C, H, W)."
    )

    # Pass image through MiDaS directly without resizing
    input_image = image_tensor.to(device)
    with torch.no_grad():
        depth_map = midas_model(input_image)

    # Remove batch and channel dimensions to get a 2D depth map
    depth_map_2d = depth_map.squeeze().cpu().numpy()

    # Set center pixel to zero for relative height
    center_y, center_x = depth_map_2d.shape[0] // 2, depth_map_2d.shape[1] // 2
    center_value = depth_map_2d[center_y, center_x]
    relative_height_map = depth_map_2d - center_value

    return relative_height_map

# Helper function to display 3D topographical plot of height field using Plotly
def plotly_3d_height_field(height_field, title="Predicted Relative Height Field"):
    """Render a height field as a Plotly surface and show it.

    Args:
        height_field: 2-D array of heights, or a 3-D array whose first entry
            along axis 0 is the one plotted.
        title (str): Figure title.
    """
    # A 3-D input is reduced to its first entry; anything else is used as given
    surface_z = height_field[0] if height_field.ndim == 3 else height_field

    # Pixel-index coordinate grids matching the surface (columns -> x, rows -> y)
    n_rows, n_cols = surface_z.shape[0], surface_z.shape[1]
    grid_x, grid_y = np.meshgrid(np.arange(n_cols), np.arange(n_rows))

    surface = go.Surface(z=surface_z, x=grid_x, y=grid_y, colorscale='Viridis')
    fig = go.Figure(data=[surface])

    # Axis labels plus a tight margin that leaves room for the title
    axis_labels = dict(
        xaxis_title='X',
        yaxis_title='Y',
        zaxis_title='Relative Height',
    )
    fig.update_layout(
        title=title,
        scene=axis_labels,
        margin=dict(l=0, r=0, b=0, t=30),
    )

    fig.show()

# Function to display random samples (5 by default) with their 3D height fields
def display_5_random_samples_with_3d_plotly(dataset, midas_model, device, num_samples=5):
    """Plot random dataset samples next to their MiDaS relative height fields.

    For each sampled index, the last image of the sequence is passed to
    ``generate_relative_height_map``; the sequence is then shown with
    ``show_image_grid_with_center`` and the height field with
    ``plotly_3d_height_field``.

    Args:
        dataset: Indexable dataset whose items unpack into four values, the
            first being the image sequence tensor and the last the center
            coordinates for the final frame.
        midas_model: Model forwarded to ``generate_relative_height_map``.
        device: Device the image sequence is moved to for inference.
        num_samples (int): How many distinct samples to draw. Clamped to the
            dataset size, since sampling without replacement cannot draw more.
    """
    # Sampling without replacement raises if more items are requested than exist
    num_samples = min(num_samples, len(dataset))
    indices = np.random.choice(len(dataset), num_samples, replace=False)
    for idx in indices:
        # Keep the dataset's own center coordinates instead of discarding them:
        # a fixed (112, 112) dot is not the center of the 512x512 frames that
        # generate_relative_height_map requires.
        image_sequence, _, _, center_coords = dataset[idx]
        image_sequence = image_sequence.to(device)

        # Use the last image in the sequence for MiDaS
        last_image = image_sequence[-1]

        # Generate the relative height field using MiDaS (returned as NumPy)
        relative_height_field_np = generate_relative_height_map(midas_model, last_image, device)

        # Print the height field values for debugging
        print(f"Height field values for sample {idx + 1}:\n", relative_height_field_np)

        # Plot the image sequence with the center dot
        show_image_grid_with_center(image_sequence.cpu(), center_coords=center_coords)

        # Plot the relative height field in 3D using Plotly
        plotly_3d_height_field(relative_height_field_np, title=f"Sample {idx + 1}: Relative Height Field")

# Pick the GPU when torch reports one, otherwise fall back to the CPU
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Load the MiDaS model onto that device
midas_model = load_midas_model(device)

# Show random validation samples with their relative height fields
display_5_random_samples_with_3d_plotly(
    val_cloud2cloud_dataset,
    midas_model,
    device,
)