In [4]:
import cv2
import numpy as np
import os
from itertools import combinations
import csv
from collections import defaultdict

cwd = os.getcwd()
dataset_dir = 'dataset_aligned'

# Band identifiers. A band file is recognised when its name (without the
# extension) ends in "_<BAND>" with the band in upper case, e.g. "img1_NIR.tif".
multispectral_bands = ['r', 'g', 'nir', 're']

# One dict per complete sample: band name -> image as returned by cv2.imread,
# plus the metadata keys 'fname_base' and 'class'.
images = []

for class_name in sorted(os.listdir(dataset_dir)):
    class_dir = os.path.join(dataset_dir, class_name)
    if not os.path.isdir(class_dir):
        continue

    # Group the band files of this class by their shared base filename.
    samples = defaultdict(dict)
    # Sorted so that the sample order (and every result derived from it) is
    # reproducible; os.listdir() returns entries in arbitrary order.
    for fname in sorted(os.listdir(class_dir)):
        if not fname.lower().endswith('.tif'):
            continue

        base, _ext = os.path.splitext(fname)
        for band in multispectral_bands:
            suffix = f"_{band.upper()}"
            if not base.endswith(suffix):
                continue

            base_name = base[:-len(suffix)]  # e.g., "img1"
            image_path = os.path.join(class_dir, fname)
            image = cv2.imread(image_path, cv2.IMREAD_UNCHANGED)
            if image is None:
                # cv2.imread reports an unreadable file by returning None
                # instead of raising. Do not register the band, so the sample
                # is rejected by the completeness check below rather than
                # failing later during processing.
                print(f"Warning: could not read '{image_path}', skipping this band")
                break

            sample = samples[base_name]
            sample[band] = image
            sample['fname_base'] = base_name
            sample['class'] = class_name
            # The band suffixes are mutually exclusive, so stop at the first match.
            break

    # Convert dict to list
    for sample in samples.values():
        # Only include samples that have every band in multispectral_bands
        if all(b in sample for b in multispectral_bands):
            images.append(sample)
            if len(images) % 20 == 0:
                print(f"Loaded {len(images)} images")

Loaded 20 images
Loaded 40 images
Loaded 60 images
Loaded 80 images
Loaded 100 images
Loaded 120 images
Loaded 140 images
Loaded 160 images
Loaded 180 images
Loaded 200 images
Loaded 220 images


In [5]:
outdir = os.path.join(cwd, 'phase_correlation_results')
# Create the output directory if needed; opening the CSV for writing would
# otherwise raise FileNotFoundError on a fresh checkout.
os.makedirs(outdir, exist_ok=True)

# Order in which bands are paired. It determines which band of each pair is
# written as Band1 and which as Band2, so it is kept fixed.
band_order = ('g', 'r', 're', 'nir')

with open(os.path.join(outdir, f'{dataset_dir}.csv'), mode='w', newline='') as csvfile:
    csv_writer = csv.writer(csvfile)
    header = ['Image Fname', 'Class', 'Band1', 'Band2', 'Shift_X', 'Shift_Y', 'Response']
    csv_writer.writerow(header)

    # Magnitude of the estimated shift for every band pair of every image,
    # in the same order as the CSV rows.
    shift_magnitudes = []
    for image in images:
        band_images = [(band, image[band]) for band in band_order]
        # Every unordered pair of bands is compared once per image.
        for (band1, img1), (band2, img2) in combinations(band_images, 2):
            # phaseCorrelate is given float32 copies of both band images.
            shift, response = cv2.phaseCorrelate(np.float32(img1), np.float32(img2))
            csv_writer.writerow(
                [image['fname_base'], image['class'], band1, band2, shift[0], shift[1], response]
            )
            shift_magnitude = np.sqrt(shift[0]**2 + shift[1]**2)

            shift_magnitudes.append(shift_magnitude)

In [6]:
# Remove heavy outliers from the shift magnitudes, then print summary statistics.
#
# A magnitude counts as an outlier when it is not below
# OUTLIER_SDEV_FACTOR * (standard deviation of all magnitudes). The outliers
# are dropped only when they make up less than MAX_OUTLIER_FRACTION of the data.
# NOTE(review): the threshold is measured from zero, not from the mean - confirm
# this is the intended criterion.
# NOTE: this cell rebinds `shift_magnitudes`. Re-running it without first
# re-running the cell that computes the magnitudes filters the already
# filtered list again.
OUTLIER_SDEV_FACTOR = 0.5
MAX_OUTLIER_FRACTION = 0.1

if shift_magnitudes:
    outlier_threshold = np.std(shift_magnitudes) * OUTLIER_SDEV_FACTOR
    inliers = [mag for mag in shift_magnitudes if mag < outlier_threshold]
    n_outliers = len(shift_magnitudes) - len(inliers)
    if n_outliers / len(shift_magnitudes) < MAX_OUTLIER_FRACTION:
        print(f"Detected {n_outliers} heavy outliers, removing them from Phase correlation")
        print()
        shift_magnitudes = inliers

print(f"Phase correlation results saved to 'phase_correlation_results/{dataset_dir}.csv'")
if shift_magnitudes:
    print(f"Average shift magnitude: {np.mean(shift_magnitudes)}")
    print(f"Median shift magnitude: {np.median(shift_magnitudes)}")
    print(f"StDev shift magnitude: {np.std(shift_magnitudes)}")
else:
    # With no band pairs processed, the ratio above would divide by zero and
    # the statistics would be NaN, so say so explicitly.
    print("No shift magnitudes were computed; summary statistics are unavailable")

Detected 23 heavy outliers, removing them from Phase correlation

Phase correlation results saved to 'phase_correlation_results/dataset_aligned.csv'
Average shift magnitude: 2.174128568055843
Median shift magnitude: 1.648155625944087
StDev shift magnitude: 2.311916409374624
