In [4]:
import cv2
import numpy as np
import subprocess
import json
import os
from collections import defaultdict
from concurrent.futures import ThreadPoolExecutor, as_completed

cwd = os.getcwd()

# Suffix appended to a sample's base filename to form each band's TIFF name.
fpath_end = {
    'r': '_R.TIF',
    'g': '_G.TIF',
    'nir': '_NIR.TIF',
    're': '_RE.TIF'
}

multispectral_bands = ['r', 'g', 'nir', 're']

# Dataset root: one sub-directory per class holding the per-band TIFFs.
# Set MULTISPECTRAL_DATASET to point elsewhere instead of editing this cell.
dataset_path = os.environ.get(
    'MULTISPECTRAL_DATASET',
    '/media/iittp/new volume/multispectral_validation_set',
)

# Per-band calibration metadata: the first element of the JSON list stored in
# '<band>_params.json' in the working directory.
metadata = {}
for band in multispectral_bands:
    with open(f'{band}_params.json', 'r') as f:
        metadata[band] = json.load(f)[0]

SAVE_SHIFTED = False
SAVE_UNDISTORTED = False

In [5]:
def get_relative_center(path):
    """Read the relative optical-centre offset of an image via exiftool.

    Args:
        path: image file to inspect.

    Returns:
        ``(x, y)`` floats taken from the RelativeOpticalCenterX/Y tags; a tag
        that exiftool does not report defaults to 0.
    """
    raw_json = subprocess.check_output(["exiftool", "-j", path])
    tags = json.loads(raw_json)[0]
    offset_x = float(tags.get("RelativeOpticalCenterX", 0))
    offset_y = float(tags.get("RelativeOpticalCenterY", 0))
    return offset_x, offset_y

def translate_band(img, dx, dy):
    """Shift ``img`` by ``(-dx, -dy)`` pixels, replicating edge pixels into the gap.

    Returns an array with the same height and width as ``img``.
    """
    rows, cols = img.shape[0], img.shape[1]
    shift_matrix = np.array([[1, 0, -dx],
                             [0, 1, -dy]], dtype=np.float32)
    return cv2.warpAffine(
        img,
        shift_matrix,
        (cols, rows),
        flags=cv2.INTER_LINEAR,
        borderMode=cv2.BORDER_REPLICATE,
    )

def undistort_image(image, cam_matrix, distortion_coeffs, crop_percent=0.0):
    """Remove lens distortion from ``image`` and optionally crop its borders.

    The same ``cam_matrix`` is used as the new camera matrix, so the output keeps
    the input's intrinsics and size before cropping.

    Args:
        image: image to undistort.
        cam_matrix: 3x3 camera intrinsics.
        distortion_coeffs: distortion coefficients in OpenCV order.
        crop_percent: fraction removed from each side afterwards (see crop_image).
    """
    height, width = image.shape[:2]
    map_x, map_y = cv2.initUndistortRectifyMap(
        cam_matrix, distortion_coeffs, None, cam_matrix, (width, height), cv2.CV_32FC1
    )
    remapped = cv2.remap(image, map_x, map_y, cv2.INTER_LINEAR)
    return crop_image(remapped, crop_percent)

def crop_image(image, crop_percent):
    """Trim ``crop_percent`` of the height/width from every side of ``image``.

    A non-positive ``crop_percent`` returns ``image`` itself (no copy); otherwise
    a slice (view) of the input is returned.
    """
    if crop_percent <= 0:
        return image

    height, width = image.shape[:2]
    margin_y = int(height * crop_percent)
    margin_x = int(width * crop_percent)
    return image[margin_y:height - margin_y, margin_x:width - margin_x]

def de_vignette_image(image, vignetting_coeffs, center=None):
    """Compensate radial vignetting by multiplying the image by a radial gain.

    The gain is ``np.polyval(vignetting_coeffs, r)``, i.e. coefficients are
    given highest degree first, evaluated at the Euclidean distance ``r`` (in
    pixels) from ``center``. For coefficients ``[k5, ..., k0, 1]`` this is
    ``1 + k0*r + ... + k5*r**6``.

    NOTE(review): this has the form of DJI's published P4 Multispectral
    vignetting model, which measures ``r`` from the calibrated optical centre
    in full-resolution pixels — confirm ``center`` and the pixel units of ``r``
    match the coefficients being passed.

    Args:
        image: 2-D intensity image.
        vignetting_coeffs: gain polynomial coefficients, highest degree first.
        center: optional ``(x, y)`` centre in pixels; defaults to ``(w / 2, h / 2)``.

    Returns:
        The corrected image as ``uint16``, clipped to ``[0, 65535]``.
    """
    h, w = image.shape[:2]
    x_center, y_center = (w / 2, h / 2) if center is None else center
    y_indices, x_indices = np.indices((h, w))
    # Vignetting is radially symmetric: use the Euclidean (not Manhattan) radius.
    r = np.hypot(x_indices - x_center, y_indices - y_center)

    gain = np.polyval(vignetting_coeffs, r)
    # Bound the gain so the correction never darkens and never amplifies by
    # more than 10x.
    gain = np.clip(gain, 1.0, 10.0)

    corrected_image = image.astype(np.float32) * gain
    return np.clip(corrected_image, 0, 65535).astype(np.uint16)

def to_edges(img):
    """Return a float32 gradient-magnitude map (|Gx| + |Gy|) of a smoothed ``img``.

    The image is Gaussian-blurred (7x7, sigma 1.8) before Sobel filtering to
    suppress noise.
    """
    smoothed = cv2.GaussianBlur(img.astype(np.float32), (7, 7), 1.8)
    grad_x = cv2.Sobel(smoothed, cv2.CV_32F, 1, 0)
    grad_y = cv2.Sobel(smoothed, cv2.CV_32F, 0, 1)
    return np.abs(grad_x) + np.abs(grad_y)

def downscale(img, scale=0.5):
    """Resize ``img`` by ``scale`` with area interpolation (suited to shrinking).

    Accepts single- or multi-channel arrays (only the first two axes are
    resized); each output side is at least 1 pixel so tiny inputs do not
    produce an invalid zero-size request.
    """
    h, w = img.shape[:2]
    new_size = (max(1, int(w * scale)), max(1, int(h * scale)))
    return cv2.resize(img, new_size, interpolation=cv2.INTER_AREA)

def upscale(img, scale=2.0):
    """Resize ``img`` by ``scale`` with bicubic interpolation (suited to enlarging).

    Accepts single- or multi-channel arrays (only the first two axes are
    resized); each output side is at least 1 pixel.
    """
    h, w = img.shape[:2]
    new_size = (max(1, int(w * scale)), max(1, int(h * scale)))
    return cv2.resize(img, new_size, interpolation=cv2.INTER_CUBIC)

def align_image(base_image, input_image, scale=0.5):
    """Register ``input_image`` onto ``base_image`` with an ECC-estimated translation.

    The translation is estimated on edge maps downscaled by ``scale`` (faster,
    and edges are comparable across spectral bands) and then converted back to
    full-resolution pixels before the full-size ``input_image`` is warped.

    Args:
        base_image: 2-D reference band (ECC template).
        input_image: 2-D band to warp into ``base_image``'s frame.
        scale: downscale factor used only for estimation; must be > 0.

    Returns:
        The warped ``input_image`` with its original size; uncovered borders use
        warpAffine's default constant (0) border.

    Raises:
        Whatever ``cv2.findTransformECC`` raises when it cannot converge.
    """
    base_edges = downscale(to_edges(base_image), scale=scale)
    input_edges = downscale(to_edges(input_image), scale=scale)

    # Normalise each edge map by its 95th percentile so both bands share a
    # comparable range. Cast explicitly: findTransformECC needs float32, and
    # NumPy 2 promotes float32 / float64-scalar to float64.
    base_small = (base_edges / (np.percentile(base_edges, 95) + 1e-7)).astype(np.float32)
    input_small = (input_edges / (np.percentile(input_edges, 95) + 1e-7)).astype(np.float32)

    warp_matrix = np.eye(2, 3, dtype=np.float32)
    criteria = (cv2.TERM_CRITERIA_EPS | cv2.TERM_CRITERIA_COUNT, 400, 1e-6)
    _, warp_matrix = cv2.findTransformECC(
        base_small,
        input_small,
        warp_matrix,
        cv2.MOTION_TRANSLATION,
        criteria,
    )

    # The shift was measured on images downscaled by `scale`; express it in
    # full-resolution pixels before warping the full-size band.
    warp_matrix[:, 2] /= scale

    height, width = input_image.shape[:2]
    return cv2.warpAffine(
        input_image,
        warp_matrix,
        (width, height),
        flags=cv2.INTER_LINEAR + cv2.WARP_INVERSE_MAP,
    )

def save_images(images, output_dir, scale=1.0):
    """Write every band of each sample to ``<output_dir>/<class>/<fname_base><suffix>``.

    Args:
        images: iterable of sample dicts holding 'class', 'fname_base' and one
            array per band in ``['r', 'g', 'nir', 're']``.
        output_dir: root directory; it and the class sub-directories are
            created when missing.
        scale: resize factor applied before writing; 1 writes arrays as-is.

    Raises:
        OSError: if OpenCV reports that a file could not be written.
    """
    # exist_ok=True: this function runs concurrently in worker threads, and a
    # separate exists() check followed by makedirs() races between them.
    os.makedirs(output_dir, exist_ok=True)

    for image in images:
        class_dir = os.path.join(output_dir, image['class'])
        os.makedirs(class_dir, exist_ok=True)

        for band in ['r', 'g', 'nir', 're']:
            fpath_out = os.path.join(class_dir, image['fname_base'] + fpath_end[band])
            # A same-size resize only copies the pixels, so skip it at scale 1.
            band_image = image[band] if scale == 1 else downscale(image[band], scale=scale)
            # cv2.imwrite signals failure through its return value, not an exception.
            if not cv2.imwrite(fpath_out, band_image):
                raise OSError(f"Failed to write {fpath_out}")

In [6]:
images = []
LOADING_DOWNSAMPLE_SCALE = 0.5

for class_name in sorted(os.listdir(dataset_path)):
    class_dir = os.path.join(dataset_path, class_name)
    if not os.path.isdir(class_dir):
        continue

    # Group band files by their shared base filename, e.g. "img1_R.TIF" -> "img1".
    samples = defaultdict(dict)
    # Sorted so sample order (and therefore images[0], later used as the
    # reference sample) does not depend on filesystem listing order.
    for fname in sorted(os.listdir(class_dir)):
        if not fname.lower().endswith('.tif'):
            continue

        base = os.path.splitext(fname)[0]
        for band in multispectral_bands:
            suffix = f"_{band.upper()}"
            if not base.endswith(suffix):
                continue
            base_name = base[:-len(suffix)]  # e.g., "img1"
            fpath_in = os.path.join(class_dir, fname)
            image = cv2.imread(fpath_in, cv2.IMREAD_UNCHANGED)
            if image is None:
                # cv2.imread returns None instead of raising on unreadable files;
                # skipping leaves the sample incomplete so it is dropped below.
                print(f"Warning: could not read {fpath_in}")
                break

            samples[base_name][band] = downscale(image, scale=LOADING_DOWNSAMPLE_SCALE)
            samples[base_name]['fname_base'] = base_name
            samples[base_name]['class'] = class_name
            break  # the band suffixes are mutually exclusive

    # Keep only samples that have all four bands.
    for sample in samples.values():
        if all(b in sample for b in multispectral_bands):
            images.append(sample)
            if len(images) % 20 == 0:
                print(f"Loaded {len(images)} images")

Loaded 20 images
Loaded 40 images
Loaded 60 images
Loaded 80 images
Loaded 100 images
Loaded 120 images
Loaded 140 images
Loaded 160 images
Loaded 180 images
Loaded 200 images
Loaded 220 images
Loaded 240 images
Loaded 260 images


In [7]:
# Per-band optical-centre offsets read from the first loaded sample and reused
# for every sample (assumes the offset is fixed per band — TODO confirm).
# Rescaled because the bands were loaded at LOADING_DOWNSAMPLE_SCALE.
reference_sample = images[0]
relative_centers = {}
for band in ['r', 'g', 'nir', 're']:
    band_path = os.path.join(
        dataset_path,
        reference_sample['class'],
        reference_sample['fname_base'] + fpath_end[band],
    )
    center_x, center_y = get_relative_center(band_path)
    relative_centers[band] = (
        center_x * LOADING_DOWNSAMPLE_SCALE,
        center_y * LOADING_DOWNSAMPLE_SCALE,
    )

In [8]:
# Translate every band of every sample by its optical-centre offset.
images_shifted = []
for sample in images:
    shifted_sample = {
        band: translate_band(sample[band], *relative_centers[band])
        for band in ['r', 'g', 'nir', 're']
    }
    shifted_sample['fname_base'] = sample['fname_base']
    shifted_sample['class'] = sample['class']
    images_shifted.append(shifted_sample)

# The offsets are identical for every sample, so the largest absolute shift per
# axis comes straight from relative_centers (0 when there were no samples).
applied_shifts = [relative_centers[band] for band in ['r', 'g', 'nir', 're']] if images else []
max_shift_x = max([0] + [abs(dx) for dx, _ in applied_shifts])
max_shift_y = max([0] + [abs(dy) for _, dy in applied_shifts])

max_shift_percent = max(max_shift_x / images_shifted[0]['r'].shape[1], max_shift_y / images_shifted[0]['r'].shape[0])

In [9]:
# The raw (unshifted) bands are no longer needed once images_shifted exists;
# deleting them releases their memory. Not idempotent: re-running this cell
# raises NameError because `images` is already gone.
del images  # Free memory

In [10]:
shifted_dataset_dir = 'dataset_shifted'
shifted_root = os.path.join(cwd, shifted_dataset_dir)
os.makedirs(shifted_root, exist_ok=True)

if SAVE_SHIFTED:
    for img_shifted in images_shifted:
        class_out_dir = os.path.join(shifted_root, img_shifted['class'])
        os.makedirs(class_out_dir, exist_ok=True)

        for band in ['r', 'g', 're', 'nir']:
            fpath_out = os.path.join(class_out_dir, img_shifted['fname_base'] + fpath_end[band])
            # cv2.imwrite signals failure through its return value, not an exception.
            if not cv2.imwrite(fpath_out, img_shifted[band]):
                raise OSError(f"Failed to write {fpath_out}")

In [11]:
undistorted_dataset_dir = 'dataset_undistorted'
os.makedirs(os.path.join(cwd, undistorted_dataset_dir), exist_ok=True)

# Per-band calibration does not vary between samples, so build it once.
# The intrinsics (focal lengths, principal point) are pixel quantities of the
# original resolution, while the bands were loaded downscaled by
# LOADING_DOWNSAMPLE_SCALE (the relative optical centres were rescaled the same
# way). Distortion coefficients act on normalised coordinates and stay as-is.
band_calibration = {}
for band in ['r', 'g', 'nir', 're']:
    _, dewarp_data = metadata[band]['DewarpData'].split(';')
    fx, fy, cx, cy, k1, k2, p1, p2, k3 = map(float, dewarp_data.split(','))
    Cx = float(metadata[band]['CalibratedOpticalCenterX'])
    Cy = float(metadata[band]['CalibratedOpticalCenterY'])

    s = LOADING_DOWNSAMPLE_SCALE
    cam_matrix = np.array([[fx * s, 0, (Cx - cx) * s],
                           [0, fy * s, (Cy - cy) * s],
                           [0, 0, 1]])
    dist_coeffs = np.array([k1, k2, p1, p2, k3])

    # VignettingData reversed with a constant 1 appended: highest-order term first.
    # NOTE(review): confirm this ordering matches what de_vignette_image expects,
    # and that its radius uses the same pixel units as these coefficients (the
    # bands here are downscaled by LOADING_DOWNSAMPLE_SCALE).
    vignette_coeffs = list(map(float, metadata[band]['VignettingData'].split(',')))[::-1]
    vignette_coeffs.append(1)

    band_calibration[band] = (cam_matrix, dist_coeffs, vignette_coeffs)

# Crop slightly more than the largest band shift so replicated borders are removed.
undistort_crop_percent = max_shift_percent * 1.05

images_undistorted = []
for image in images_shifted:
    image_undistorted = {}
    for band, (cam_matrix, dist_coeffs, vignette_coeffs) in band_calibration.items():
        undistorted_band = undistort_image(image[band], cam_matrix, dist_coeffs,
                                           crop_percent=undistort_crop_percent)
        image_undistorted[band] = de_vignette_image(undistorted_band, vignette_coeffs)

    image_undistorted['fname_base'] = image['fname_base']
    image_undistorted['class'] = image['class']
    images_undistorted.append(image_undistorted)

In [12]:
if SAVE_UNDISTORTED:
    undistorted_output_root = os.path.join(cwd, undistorted_dataset_dir)
    # Create every class directory up front so concurrent save_images calls
    # never race to create the same directory.
    for class_name in {image['class'] for image in images_undistorted}:
        os.makedirs(os.path.join(undistorted_output_root, class_name), exist_ok=True)

    num_threads = 16
    with ThreadPoolExecutor(max_workers=num_threads) as executor:
        futures = [
            executor.submit(save_images, [image], undistorted_output_root)
            for image in images_undistorted
        ]
        # result() re-raises any exception raised inside a worker thread.
        for future in as_completed(futures):
            future.result()

In [13]:
# The shifted bands are no longer needed once images_undistorted exists;
# deleting them releases their memory. Not idempotent: re-running this cell
# raises NameError because `images_shifted` is already gone.
del images_shifted

In [14]:
aligned_dataset_dir = 'dataset_aligned'
os.makedirs(os.path.join(cwd, aligned_dataset_dir), exist_ok=True)

# Every other band is registered onto the NIR band, which is only cropped.
base_band = 'nir'
bands_to_align = ['g', 'r', 're']
# Fraction trimmed from each side after alignment to drop warp borders.
ALIGN_CROP_PERCENT = 0.05

aligned_images = []
num_failed_alignments = 0
for image in images_undistorted:
    base_image = image[base_band]
    image_aligned = {base_band: crop_image(base_image, ALIGN_CROP_PERCENT)}

    try:
        for band in bands_to_align:
            aligned_band = align_image(base_image, image[band], scale=0.50)
            image_aligned[band] = crop_image(aligned_band, ALIGN_CROP_PERCENT)
    except Exception as e:  # e.g. an OpenCV error when ECC does not converge
        print(f"Failed to align image {image['fname_base']} (Class: {image['class']}): {e}")
        num_failed_alignments += 1
        continue

    image_aligned['fname_base'] = image['fname_base']
    image_aligned['class'] = image['class']
    aligned_images.append(image_aligned)

# Report failures relative to the number of images attempted, not the number
# that succeeded; this also avoids dividing by zero when every alignment fails.
num_attempted = len(images_undistorted)
failure_percent = 100 * num_failed_alignments / num_attempted if num_attempted else 0.0
print(f"Failed Alignments: {num_failed_alignments}, Percentage: {failure_percent:.2f}%")

Failed Alignments: 0, Percentage: 0.00%


In [15]:
# Saving aligned images is mandatory
aligned_output_root = os.path.join(cwd, aligned_dataset_dir)
# Create every class directory up front so concurrent save_images calls never
# race to create the same directory.
for class_name in {image['class'] for image in aligned_images}:
    os.makedirs(os.path.join(aligned_output_root, class_name), exist_ok=True)

num_threads = 16
with ThreadPoolExecutor(max_workers=num_threads) as executor:
    futures = [
        executor.submit(save_images, [image], aligned_output_root, scale=1)
        for image in aligned_images
    ]
    # result() re-raises any exception raised inside a worker thread.
    for future in as_completed(futures):
        future.result()