In [None]:
import os
import random
import numpy as np
from rasterio.windows import Window
import matplotlib.pyplot as plt
from tqdm import tqdm 

In [None]:
def create_overlapping_cubes(hsi_data, num_cubes, cube_size=32, stride=16, rng=None):
    """
    Extract randomly chosen, possibly overlapping cubes from HSI data.

    Candidate top-left corners lie on a regular grid with step ``stride`` in
    both spatial directions. A candidate is kept only if every pixel of its
    ``cube_size x cube_size`` footprint has at least one non-zero band. Up to
    ``num_cubes`` of the kept positions are then drawn at random without
    replacement.

    Args:
        hsi_data: 3D numpy array (height, width, bands)
        num_cubes: Maximum number of cubes to extract (``None`` keeps every
            valid position)
        cube_size: Spatial dimensions of the cube (cube_size x cube_size)
        stride: Step size between cube positions (controls overlap); a stride
            smaller than ``cube_size`` makes neighbouring cubes overlap
        rng: Optional ``random.Random`` instance used for the selection, for
            reproducible sampling. Defaults to the global ``random`` module.
    Returns:
        List of 3D cubes (cube_size x cube_size x bands). The cubes are
        slices (views) of ``hsi_data``, not copies. Empty when no valid
        position exists.
    Raises:
        ValueError: if ``hsi_data`` is not 3D, if ``cube_size`` or ``stride``
            is not positive, or if ``num_cubes`` is negative.
    """
    if hsi_data.ndim != 3:
        raise ValueError(
            f"hsi_data must be 3D (height, width, bands), got shape {hsi_data.shape}"
        )
    if cube_size <= 0 or stride <= 0:
        raise ValueError(
            f"cube_size and stride must be positive, got cube_size={cube_size}, stride={stride}"
        )
    if num_cubes is not None and num_cubes < 0:
        # A negative value would otherwise silently slice from the end below.
        raise ValueError(f"num_cubes must be non-negative, got {num_cubes}")

    height, width, _ = hsi_data.shape

    # A pixel is valid when at least one of its bands is non-zero.
    valid_mask = np.any(hsi_data != 0, axis=2)

    # Stride-spaced top-left corners whose whole footprint is valid.
    valid_positions = [
        (top, left)
        for top in range(0, height - cube_size + 1, stride)
        for left in range(0, width - cube_size + 1, stride)
        if valid_mask[top:top + cube_size, left:left + cube_size].all()
    ]

    if not valid_positions:
        print("Warning: No valid regions found for cube extraction.")
        return []

    # Shuffle then truncate == random selection without replacement.
    (random if rng is None else rng).shuffle(valid_positions)
    selected_positions = valid_positions[:num_cubes]

    return [
        hsi_data[top:top + cube_size, left:left + cube_size, :]
        for top, left in selected_positions
    ]

def save_cube(cube, output_path):
    """Persist a single cube array to ``output_path`` via ``np.save``."""
    np.save(file=output_path, arr=cube)

def process_hsi_file(file_path, output_dir, num_cubes, cube_size=32, stride=16):
    """
    Cut one HSI ``.npy`` file into cubes and save each cube as its own file.

    Cubes come from ``create_overlapping_cubes`` and are written to
    ``output_dir`` as ``<base>_cube_<k>.npy`` (``k`` starting at 1), where
    ``<base>`` is the input file name with its trailing ``.npy`` removed.

    Args:
        file_path: Path to the input ``.npy`` file; presumably a
            (height, width, bands) array, as ``create_overlapping_cubes``
            expects.
        output_dir: Directory that receives the cube files. It is not created
            here, so it must already exist.
        num_cubes: Forwarded to ``create_overlapping_cubes``.
        cube_size: Forwarded to ``create_overlapping_cubes``.
        stride: Forwarded to ``create_overlapping_cubes``.
    Returns:
        List of the paths written (empty when no valid cube was found).
    """
    hsi_data = np.load(file_path)

    cubes = create_overlapping_cubes(
        hsi_data,
        num_cubes=num_cubes,
        cube_size=cube_size,
        stride=stride
    )

    # Strip only a trailing ".npy": str.replace would also delete ".npy"
    # appearing elsewhere in the name (e.g. "scan.npy_v2.npy" -> "scan_v2").
    base_filename = os.path.basename(file_path)
    if base_filename.endswith('.npy'):
        base_filename = base_filename[:-len('.npy')]

    written_paths = []
    for i, cube in enumerate(cubes, start=1):
        output_path = os.path.join(output_dir, f"{base_filename}_cube_{i}.npy")
        save_cube(cube, output_path)
        written_paths.append(output_path)
    return written_paths

def process_folder_structure(input_dir, output_dir, num_cubes, cube_size=32, stride=16):
    """
    Walk ``input_dir`` recursively and cube every ``.npy`` file found.

    Each cube file is written under ``output_dir`` at the same relative
    sub-path as its source file, mirroring the input folder layout. An output
    sub-directory is created only for folders that contain at least one
    ``.npy`` file.

    Args:
        input_dir: Root directory to scan for ``.npy`` files.
        output_dir: Root directory for the mirrored cube output. If it lies
            inside ``input_dir`` it is skipped during the walk.
        num_cubes: Forwarded to ``process_hsi_file``.
        cube_size: Forwarded to ``process_hsi_file``.
        stride: Forwarded to ``process_hsi_file``.
    Raises:
        FileNotFoundError: if ``input_dir`` is not an existing directory.
            ``os.walk`` ignores such errors by default, so a wrong path would
            otherwise look like an empty dataset.
    """
    if not os.path.isdir(input_dir):
        raise FileNotFoundError(f"Input directory not found or not a directory: {input_dir}")

    output_real = os.path.realpath(output_dir)

    for root, dirs, files in tqdm(os.walk(input_dir), desc="Processing folders"):
        # Prune output_dir from the (top-down) walk when it is nested inside
        # input_dir; otherwise a re-run would cube the previously written cubes.
        dirs[:] = [
            d for d in dirs
            if os.path.realpath(os.path.join(root, d)) != output_real
        ]

        npy_files = [file for file in files if file.endswith('.npy')]
        if not npy_files:
            continue

        # Maintain folder structure in output (computed once per directory).
        relative_path = os.path.relpath(root, input_dir)
        output_subdir = os.path.join(output_dir, relative_path)
        os.makedirs(output_subdir, exist_ok=True)

        for file in npy_files:
            process_hsi_file(
                os.path.join(root, file),
                output_subdir,
                num_cubes=num_cubes,
                cube_size=cube_size,
                stride=stride
            )

# ---------------------------------------------------------------------------
# Configuration
# ---------------------------------------------------------------------------
input_directory = '<path_to_file>'   # root of the source .npy HSI files
output_directory = '<path_to_file>'  # root where the cube files are mirrored

num_cubes = 10               # upper bound on cubes saved per input file
cube_size = 32               # cube height and width, in pixels
stride = cube_size // 2      # grid step; half the cube size -> 50% overlap

# ---------------------------------------------------------------------------
# Cube every .npy file under input_directory
# ---------------------------------------------------------------------------
process_folder_structure(
    input_directory,
    output_directory,
    num_cubes=num_cubes,
    cube_size=cube_size,
    stride=stride,
)

In [None]:
def visualize_random_images(output_dir, num_images=5, band_index=399):
    """
    Plot one spectral band of randomly chosen ``.npy`` files side by side.

    Args:
        output_dir: Directory searched recursively for ``.npy`` files.
        num_images: Number of random files to show. It is reduced to the
            number of available files when fewer exist.
        band_index: Index into the last axis of each loaded array. The default
            of 399 presumes arrays with at least 400 bands (TODO confirm for
            your data); arrays with fewer bands raise IndexError.
    Returns:
        None. When no ``.npy`` file is found, prints a message and returns
        without plotting.
    """
    # Sorted so that the selection is reproducible under random.seed(...);
    # raw os.walk order depends on the filesystem.
    file_paths = sorted(
        os.path.join(root, name)
        for root, _, files in os.walk(output_dir)
        for name in files
        if name.endswith('.npy')
    )

    if not file_paths:
        # Nothing to plot; plt.subplots would otherwise be asked for 0 columns.
        print(f"No .npy files found in {output_dir}. Nothing to visualize.")
        return

    if len(file_paths) < num_images:
        print(f"Only {len(file_paths)} files found. Visualizing all available files.")
        num_images = len(file_paths)

    # Randomly select files to visualize
    selected_files = random.sample(file_paths, num_images)

    _, axes = plt.subplots(1, num_images, figsize=(15, 5))
    if num_images == 1:
        axes = [axes]  # Ensure axes is always iterable

    for i, (ax, file_path) in enumerate(zip(axes, selected_files), start=1):
        band_image = np.load(file_path)[:, :, band_index]
        ax.imshow(band_image, cmap='viridis')
        ax.set_title(f"File {i}\n{os.path.basename(file_path)}")
        ax.axis('off')

    plt.tight_layout()
    plt.show()

# Directory whose cube files should be previewed
output_directory = '<path_to_file>'

# Preview a handful of randomly picked cubes
visualize_random_images(output_dir=output_directory)