# In [None]:
import os
import numpy as np
import rasterio
import logging

# Set up logging to print live updates
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

def _read_band_values(dataset, band, ignore_nodata):
    """Return the pixels of 1-based ``band`` of an open rasterio dataset as a flat float64 array.

    With ``ignore_nodata`` the band is read as a masked array and only the
    unmasked pixels are kept. Converting to float64 keeps the later sums and
    squared deviations in double precision whatever the raster's dtype is
    (float32 rasters would otherwise be accumulated in float32).
    """
    if ignore_nodata:
        values = np.ma.compressed(dataset.read(band, masked=True))
    else:
        values = dataset.read(band).ravel()
    return values.astype(np.float64, copy=False)


def _iter_band_values(image_paths, num_bands, ignore_nodata, progress_label):
    """Yield ``(band_index, values)`` for bands 0..num_bands-1 of every image, logging each image."""
    total = len(image_paths)
    for i, path in enumerate(image_paths):
        logging.info(f"{progress_label} {i+1}/{total}: {path}")
        with rasterio.open(path) as dataset:
            for band_index in range(num_bands):
                # rasterio band indexes are 1-based.
                yield band_index, _read_band_values(dataset, band_index + 1, ignore_nodata)


def calculate_mean_std_per_band(image_paths, num_bands, *, ignore_nodata=False):
    """Compute the per-band mean and population standard deviation over a set of rasters.

    Pixels from all images are pooled per band, so larger images contribute
    proportionally more. The files are read twice: once for the means, then
    once for the sum of squared deviations from those means (a two-pass scheme
    that avoids the cancellation error of a running sum of squares).

    Args:
        image_paths: Paths of raster files readable by ``rasterio.open``.
        num_bands: Number of bands to read from each image (bands 1..num_bands);
            every image must have at least this many bands.
        ignore_nodata: If True, read bands as masked arrays and exclude masked
            (nodata) pixels. Defaults to False, which includes every pixel.

    Returns:
        ``(means, stds)``: two lists of ``num_bands`` floats. The std divides
        by the pixel count (population std, ddof=0).

    Raises:
        ValueError: If ``image_paths`` is empty or some band has no pixels.
    """
    if not image_paths:
        raise ValueError("image_paths is empty; cannot compute per-band statistics")

    band_sums = np.zeros(num_bands)
    band_pixel_counts = np.zeros(num_bands)

    logging.info(f"Starting calculation on {len(image_paths)} images")

    # Pass 1: accumulate per-band sums and pixel counts.
    for band_index, values in _iter_band_values(
        image_paths, num_bands, ignore_nodata, "Processing image"
    ):
        band_sums[band_index] += values.sum()
        band_pixel_counts[band_index] += values.size

    # Guard against 0/0 -> NaN (possible when every pixel of a band is masked).
    empty_bands = np.flatnonzero(band_pixel_counts == 0)
    if empty_bands.size:
        raise ValueError(f"No pixels to average for band(s) {(empty_bands + 1).tolist()}")

    band_means = band_sums / band_pixel_counts
    logging.info(f"Means calculated: {band_means}")

    # Pass 2: accumulate squared deviations from the final means.
    band_sq_devs = np.zeros(num_bands)
    for band_index, values in _iter_band_values(
        image_paths, num_bands, ignore_nodata, "Calculating standard deviation for image"
    ):
        band_sq_devs[band_index] += np.sum((values - band_means[band_index]) ** 2)

    band_stds = np.sqrt(band_sq_devs / band_pixel_counts)
    logging.info(f"Standard deviations calculated: {band_stds}")

    return band_means.tolist(), band_stds.tolist()

def get_image_paths(directory, extension=".tif", include_keyword="_merged"):
    """Recursively collect image file paths under ``directory``.

    A file is kept when its name ends with ``extension`` and contains
    ``include_keyword`` anywhere in the name. ``extension`` may also be a tuple
    of suffixes, since it is passed straight to ``str.endswith``.

    Args:
        directory: Root directory to walk (subdirectories are included).
        extension: Required filename suffix (or tuple of suffixes).
        include_keyword: Substring the filename must contain.

    Returns:
        Sorted list of matching paths, each joined onto the walked root.
        Sorting makes the result independent of the filesystem's listing order,
        so downstream choices (e.g. which image is first) are reproducible.
    """
    logging.info(f"Collecting image paths from directory: {directory}")
    if not os.path.isdir(directory):
        # os.walk ignores errors by default and would just yield nothing here,
        # hiding a mistyped path; make it visible without aborting.
        logging.warning(f"Directory does not exist or is not a directory: {directory}")

    image_paths = sorted(
        os.path.join(root, name)
        for root, _, files in os.walk(directory)
        for name in files
        if name.endswith(extension) and include_keyword in name
    )
    logging.info(f"Found {len(image_paths)} images in {directory}")
    return image_paths

# Define the directories containing the training and validation images
train_dir = "datasets/fire_scars_train_val/train"
val_dir = "datasets/fire_scars_train_val/validation"

# Get all .tif image paths from both directories, filtered for multi-band files
train_image_paths = get_image_paths(train_dir)
val_image_paths = get_image_paths(val_dir)

# Combine train and validation image paths.
# NOTE(review): statistics are pooled over train AND validation images. If they
# are used to normalize model inputs, computing them on the training split only
# avoids leaking validation data into preprocessing — confirm this is intended.
all_image_paths = train_image_paths + val_image_paths

# Fail with a clear message instead of an IndexError on all_image_paths[0].
if not all_image_paths:
    raise FileNotFoundError(
        f"No '_merged' .tif images found under {train_dir!r} or {val_dir!r}"
    )

# Get number of bands from the first image; every other image is assumed to
# have at least this many bands.
with rasterio.open(all_image_paths[0]) as dataset:
    num_bands = dataset.count

logging.info(f"Number of bands in the images: {num_bands}")

# Calculate mean and standard deviation per band
means, stds = calculate_mean_std_per_band(all_image_paths, num_bands)

# Output results in the specified format.
# "num_classes" is a fixed value, not derived from the data (presumably fire
# scar vs. background — confirm against the label definition).
result = {
    "means": means,
    "stds": stds,
    "num_classes": 2
}

logging.info(f"Final result: {result}")
print(result)