## 1 Importing Libs and Configuring Parameters

In [1]:
import os
import sys
import shutil
import numpy as np
from PIL import Image
import glob
from sklearn.model_selection import train_test_split
from osgeo import gdal
from tqdm.notebook import tqdm

Image.MAX_IMAGE_PIXELS = None

DATA_FOLDER = "cloud_removal_dataset_california"
OUTPUT_DATASET = "dataset"

TRAIN_RATIO = 0.70
VAL_RATIO = 0.15
TEST_RATIO = 0.15

MIN_IMAGE_SIZE = 512
MAX_NODATA_RATIO_IN_CROP = 0.05

## 2 Define the processing function

In [2]:
def load_tif_as_array(file_path):
    dataset = gdal.Open(file_path)
    if dataset is None:
        raise ValueError(f"Cannot open {file_path}")
    
    bands = []
    for i in range(1, min(dataset.RasterCount + 1, 4)):
        band = dataset.GetRasterBand(i)
        band_array = band.ReadAsArray()
        bands.append(band_array)
    
    image = np.stack(bands, axis=-1)
    return image

def create_common_valid_mask(img1, img2, threshold=0):
    if img1.ndim == 3:
        mask1 = np.any(img1 > threshold, axis=2)
    else:
        mask1 = img1 > threshold
    
    if img2.ndim == 3:
        mask2 = np.any(img2 > threshold, axis=2)
    else:
        mask2 = img2 > threshold
    
    return mask1 & mask2

def crop_to_common_valid_region(img1, img2):
    common_mask = create_common_valid_mask(img1, img2, threshold=0)
    
    if not np.any(common_mask):
        return None, None, "no_valid_pixels"
    
    rows = np.any(common_mask, axis=1)
    cols = np.any(common_mask, axis=0)
    
    if not np.any(rows) or not np.any(cols):
        return None, None, "no_valid_rows_cols"
    
    rmin, rmax = np.where(rows)[0][[0, -1]]
    cmin, cmax = np.where(cols)[0][[0, -1]]
    
    cropped1 = img1[rmin:rmax+1, cmin:cmax+1]
    cropped2 = img2[rmin:rmax+1, cmin:cmax+1]
    
    crop_height, crop_width = cropped1.shape[0], cropped1.shape[1]
    
    if crop_height < MIN_IMAGE_SIZE or crop_width < MIN_IMAGE_SIZE:
        return None, None, f"too_small_{crop_height}x{crop_width}"
    
    cropped_mask = common_mask[rmin:rmax+1, cmin:cmax+1]
    valid_pixels_in_crop = np.sum(cropped_mask)
    total_pixels_in_crop = cropped_mask.size
    nodata_ratio = 1 - (valid_pixels_in_crop / total_pixels_in_crop)
    
    if nodata_ratio > MAX_NODATA_RATIO_IN_CROP:
        return None, None, f"too_much_nodata_{nodata_ratio:.2f}"
    
    return cropped1, cropped2, "success"

def normalize_image_pair(cloudy_img, clear_img):
    combined = np.concatenate([cloudy_img, clear_img], axis=0)
    
    valid_mask = np.any(combined > 0, axis=2)
    valid_pixels = combined[valid_mask]
    
    p_low = np.percentile(valid_pixels, 2)
    p_high = np.percentile(valid_pixels, 98)
    
    cloudy_norm = np.clip((cloudy_img - p_low) / (p_high - p_low) * 255, 0, 255).astype(np.uint8)
    clear_norm = np.clip((clear_img - p_low) / (p_high - p_low) * 255, 0, 255).astype(np.uint8)
    
    return cloudy_norm, clear_norm

## 3 data preprocessing

In [3]:
from joblib import Parallel, delayed
import multiprocessing

train_files = glob.glob(os.path.join(OUTPUT_DATASET, "train", "*.png"))
val_files = glob.glob(os.path.join(OUTPUT_DATASET, "val", "*.png"))
test_files = glob.glob(os.path.join(OUTPUT_DATASET, "test", "*.png"))

if train_files or val_files or test_files:
    print(f"Dataset already exists:")
    print(f"  Train: {len(train_files)} images")
    print(f"  Val: {len(val_files)} images")
    print(f"  Test: {len(test_files)} images")
    print("\nSkipping preprocessing. Delete the 'dataset' folder to reprocess.")
else:
    print("No existing dataset found. Starting preprocessing...\n")
    
    os.makedirs(os.path.join(OUTPUT_DATASET, "train"), exist_ok=True)
    os.makedirs(os.path.join(OUTPUT_DATASET, "val"), exist_ok=True)
    os.makedirs(os.path.join(OUTPUT_DATASET, "test"), exist_ok=True)
    
    clear_files = sorted(glob.glob(os.path.join(DATA_FOLDER, "clear_*.tif")))
    cloudy_files = sorted(glob.glob(os.path.join(DATA_FOLDER, "cloudy_*.tif")))
    
    print(f"Found {len(clear_files)} clear and {len(cloudy_files)} cloudy images")
    
    pair_indices = []
    skip_reasons = {}
    
    print("Phase 1: Validating pairs...")
    for i in tqdm(range(min(len(clear_files), len(cloudy_files)))):
        clear_path = os.path.join(DATA_FOLDER, f"clear_{i}.tif")
        cloudy_path = os.path.join(DATA_FOLDER, f"cloudy_{i}.tif")
        
        if not (os.path.exists(clear_path) and os.path.exists(cloudy_path)):
            skip_reasons['file_not_found'] = skip_reasons.get('file_not_found', 0) + 1
            continue
        
        try:
            cloudy_raw = load_tif_as_array(cloudy_path)
            clear_raw = load_tif_as_array(clear_path)
            
            cloudy_cropped, clear_cropped, reason = crop_to_common_valid_region(cloudy_raw, clear_raw)
            
            if cloudy_cropped is None:
                skip_reasons[reason] = skip_reasons.get(reason, 0) + 1
                continue
            
            pair_indices.append(i)
                
        except Exception as e:
            skip_reasons['exception'] = skip_reasons.get('exception', 0) + 1
            continue
    
    print(f"Valid pairs: {len(pair_indices)}, Skipped: {sum(skip_reasons.values())}")
    
    if len(pair_indices) < 10:
        raise ValueError("Too few valid pairs for training")
    
    train_val_idx, test_idx = train_test_split(pair_indices, test_size=TEST_RATIO, random_state=42)
    relative_val_size = VAL_RATIO / (TRAIN_RATIO + VAL_RATIO)
    train_idx, val_idx = train_test_split(train_val_idx, test_size=relative_val_size, random_state=42)
    
    split_mapping = {}
    for idx in train_idx:
        split_mapping[idx] = 'train'
    for idx in val_idx:
        split_mapping[idx] = 'val'
    for idx in test_idx:
        split_mapping[idx] = 'test'
    
    print(f"Split: {len(train_idx)} train, {len(val_idx)} val, {len(test_idx)} test")
    
    print("\nPhase 2: Processing and saving images in parallel...")
    
    def process_single_pair(i, split, counter_start):
        try:
            clear_path = os.path.join(DATA_FOLDER, f"clear_{i}.tif")
            cloudy_path = os.path.join(DATA_FOLDER, f"cloudy_{i}.tif")
            
            cloudy_raw = load_tif_as_array(cloudy_path)
            clear_raw = load_tif_as_array(clear_path)
            
            cloudy_cropped, clear_cropped, _ = crop_to_common_valid_region(cloudy_raw, clear_raw)
            
            cloudy_norm, clear_norm = normalize_image_pair(cloudy_cropped, clear_cropped)
            
            combined = np.concatenate([cloudy_norm, clear_norm], axis=1)
            
            combined_pil = Image.fromarray(combined)
            output_path = os.path.join(OUTPUT_DATASET, split, f"pair_{counter_start:04d}.png")
            combined_pil.save(output_path)
            
            return True
            
        except Exception as e:
            return False
    
    splits_to_process = [
        ('train', [i for i in pair_indices if split_mapping[i] == 'train']),
        ('val', [i for i in pair_indices if split_mapping[i] == 'val']),
        ('test', [i for i in pair_indices if split_mapping[i] == 'test'])
    ]
    
    n_jobs = min(multiprocessing.cpu_count() - 1, 8)
    print(f"Using {n_jobs} parallel workers")
    
    total_saved = 0
    
    for split_name, indices in splits_to_process:
        print(f"\nProcessing {split_name} set ({len(indices)} images)...")
        
        results = Parallel(n_jobs=n_jobs)(
            delayed(process_single_pair)(idx, split_name, counter) 
            for counter, idx in enumerate(tqdm(indices, desc=split_name))
        )
        
        saved = sum(results)
        total_saved += saved
        print(f"  Saved {saved}/{len(indices)} images")
    
    print(f"\nTotal saved: {total_saved} images")

No existing dataset found. Starting preprocessing...

Found 856 clear and 860 cloudy images
Phase 1: Validating pairs...


  0%|          | 0/856 [00:00<?, ?it/s]

Valid pairs: 200, Skipped: 656
Split: 139 train, 31 val, 30 test

Phase 2: Processing and saving images in parallel...
Using 8 parallel workers

Processing train set (139 images)...


train:   0%|          | 0/139 [00:00<?, ?it/s]

  Saved 139/139 images

Processing val set (31 images)...


val:   0%|          | 0/31 [00:00<?, ?it/s]

  Saved 31/31 images

Processing test set (30 images)...


test:   0%|          | 0/30 [00:00<?, ?it/s]

  Saved 30/30 images

Total saved: 200 images
