# Image chipping for all RGB images


In [1]:
import os
import glob
import numpy as np
import tifffile
from PIL import Image
from concurrent.futures import ThreadPoolExecutor
from tqdm.notebook import tqdm  # Use tqdm.notebook for Jupyter
from functools import partial
import re 

def read_tifpngjpg(file):
    lower_file = file.lower()
    if lower_file.endswith(('.png', '.jpg', '.jpeg')):
        with Image.open(file) as img:
            return np.array(img)
    elif lower_file.endswith('.tif'):
        return tifffile.imread(file)

def clean_filename(filename: str):
    """Remove any extra '.tif' or '.TIF' extensions from the filename."""
    return re.sub(r'(\.tif)+$', '', filename, flags=re.IGNORECASE)


def process_image(file, output_dir=None, chip_size= 512, i =0, input_folder= None, out = None):
    tif_path = file
    raw_filename = os.path.basename(file)
    filedir = os.path.dirname(file)
    filename = clean_filename(os.path.splitext(raw_filename)[0]) 
    
    out_chip_folders = os.path.normpath(file).strip(input_folder).split(os.sep)
    path_identifier = '_'.join(out_chip_folders[4:-1])           


    pattern = os.path.join(output_dir, f"{path_identifier}_{filename}")
    
    existing_chips = [f for f in out if f.startswith(pattern)]
    if len(existing_chips) > 0:
        return file, False
    

    # Read the 4-channel image with tifffile
    img = read_tifpngjpg(file)
    # Expected shape: (height, width, channels) or (channels, height, width)

    # If channels are first, transpose to (H, W, C)
    # so we end up with shape: (height, width, 3)
    if img.shape[0] == 3 and (img.shape[1] != 3 and img.shape[2] != 3):
        img = np.transpose(img, (1, 2, 0))
        
    # Ensure uint16
    if img.dtype != np.uint8:
        img = img.astype(np.uint8)

    if len(img.shape)!= 3:
        return file, False

    height, width, channels = img.shape

    # How many full patches along each dimension?
    n_rows = height // chip_size
    n_cols = width // chip_size

    # If the image is too small for even one 224x224 patch, skip
    if n_rows == 0 or n_cols == 0:
        return file, False


    # Total covered area
    covered_height = n_rows * chip_size
    covered_width = n_cols * chip_size

    # Offsets to center the grid
    offset_row = (height - covered_height) // 2
    offset_col = (width - covered_width) // 2

    # Nested progress bars for row and column tiling
    for row_idx in range(n_rows):
        for col_idx in range(n_cols):
            row_start = offset_row + row_idx * chip_size
            col_start = offset_col + col_idx * chip_size
            
            out_chip_name = f"{path_identifier}_{filename}_{row_start}_{col_start}.jpg"
            out_chip_path = os.path.join(output_dir, out_chip_name)

            if os.path.exists(out_chip_path):
                return file, False
        


            # Extract the chip
            chip = img[row_start:row_start + chip_size,
                       col_start:col_start + chip_size, :]

            # Output name: originalfilename_row_col.tif

                                
            rgb_im = Image.fromarray(chip)
            if rgb_im.mode in ("RGBA", "P"): rgb_im = rgb_im.convert("RGB")
                
            rgb_im.save(out_chip_path, quality=95)
            
    
    return file, True  # Success
    

def collect_all_images(input_folder):
    all_image_files = []
    for root, _, _ in os.walk(input_folder):
        patterns = ["*.tif", "*.TIF", "*.jpg", "*.JPG", "*.png", "*.PNG"]
        for pattern in patterns:
            all_image_files.extend(glob.glob(os.path.join(root, pattern)))
    return all_image_files

def parallel_process_images(input_folder, output_dir=None, max_workers=8, chip_size=512):
    # Get all image files
    all_image_files = collect_all_images(input_folder)
    total_files = len(all_image_files)
    print(f"Found {total_files} image files")
    out = os.listdir(output_dir)
    
    # Define a wrapper function that handles exceptions
    def safe_process_image(file):
        try:
            return process_image(file, output_dir=output_dir, chip_size=chip_size, input_folder=input_folder, out=out)
        except Exception as e:
            print(f"Error processing {file}: {str(e)}")
            # Return a failed result
            return (file, False)
    
    # Process all files with ThreadPoolExecutor
    results = []
    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        # Create a list of futures using the safe wrapper
        futures = [executor.submit(safe_process_image, file) for file in all_image_files]
        
        # Process with progress bar and timeout
        for future in tqdm(futures, total=len(futures), desc="Processing images", unit="file"):
            try:
                # Add timeout to prevent hanging on a single file
                result = future.result(timeout=300)  # 5-minute timeout per image
                results.append(result)
            except TimeoutError:
                print(f"Processing timed out for a file")
                results.append((None, False))
            except Exception as e:
                print(f"Unexpected error in thread: {str(e)}")
                results.append((None, False))
    
    # Report on results
    successes = sum(1 for _, status in results if status is True)
    print(f"Successfully processed {successes} of {total_files} files")
    
    # Return results for further analysis if needed
    return results

def process_images_sequential(input_folder, output_dir=None, chip_size=512, max_workers =1):
    # Get all image files
    all_image_files = collect_all_images(input_folder)
    total_files = len(all_image_files)
    print(f"Found {total_files} image files")
    out = os.listdir(output_dir)
    # Process all files sequentially with a progress bar
    results = []
    for file in tqdm(all_image_files, desc="Processing images", unit="file"):
        result = process_image(file, output_dir=output_dir, chip_size=chip_size, input_folder=input_folder,out=out)
        results.append(result)
    
    # Report on results
    successes = sum(1 for _, status in results if status is True)
    print(f"Successfully processed {successes} of {total_files} files")
    
    # Return results for further analysis if needed
    return results


In [2]:
# Create a processed folder if it doesn't exist
processed_dir = "../../data/output_multi/RGB2"
os.makedirs(processed_dir, exist_ok=True)
input_folder = "../../data/RGB_pretraining"
parallel_process_images(
    input_folder = input_folder,
    output_dir = processed_dir,
    chip_size = 512,
    max_workers = 8
)


Found 122250 image files


Processing images:   0%|          | 0/122250 [00:00<?, ?file/s]

Successfully processed 4 of 122250 files


[('../../data/RGB_pretraining\\RGB_apple_orchard\\01_UAV images\\data2018\\15m\\DJI_0061.JPG',
  False),
 ('../../data/RGB_pretraining\\RGB_apple_orchard\\01_UAV images\\data2018\\15m\\DJI_0062.JPG',
  False),
 ('../../data/RGB_pretraining\\RGB_apple_orchard\\01_UAV images\\data2018\\15m\\DJI_0063.JPG',
  False),
 ('../../data/RGB_pretraining\\RGB_apple_orchard\\01_UAV images\\data2018\\15m\\DJI_0064.JPG',
  False),
 ('../../data/RGB_pretraining\\RGB_apple_orchard\\01_UAV images\\data2018\\15m\\DJI_0065.JPG',
  False),
 ('../../data/RGB_pretraining\\RGB_apple_orchard\\01_UAV images\\data2018\\15m\\DJI_0066.JPG',
  False),
 ('../../data/RGB_pretraining\\RGB_apple_orchard\\01_UAV images\\data2018\\15m\\DJI_0067.JPG',
  False),
 ('../../data/RGB_pretraining\\RGB_apple_orchard\\01_UAV images\\data2018\\15m\\DJI_0068.JPG',
  False),
 ('../../data/RGB_pretraining\\RGB_apple_orchard\\01_UAV images\\data2018\\15m\\DJI_0069.JPG',
  False),
 ('../../data/RGB_pretraining\\RGB_apple_orchard\\01_UA

In [9]:
import os
import glob
from PIL import Image
import numpy as np

def is_mostly_mono(image_path, threshold=0.1, check_white=True, check_black=True):
    """
    Check if an image contains more than threshold percentage of white or black pixels.
    
    Args:
        image_path: Path to the image file
        threshold: Threshold percentage (0.0 to 1.0)
        check_white: Whether to check for white pixels
        check_black: Whether to check for black pixels
    
    Returns:
        True if image has more than threshold percentage of white or black pixels
    """
    try:
        # Open the image
        img = Image.open(image_path)
        
        # Convert to numpy array for faster processing
        img_array = np.array(img)
        
        # Calculate total number of pixels
        total_pixels = img_array.shape[0] * img_array.shape[1]
        
        # Count white pixels (255, 255, 255)
        white_pixels = 0
        if check_white:
            white_pixels = np.sum(np.all(img_array == [255, 255, 255], axis=2))
            white_percentage = white_pixels / total_pixels
            if white_percentage > threshold:
                return True
        
        # Count black pixels (0, 0, 0)
        black_pixels = 0
        if check_black:
            black_pixels = np.sum(np.all(img_array == [0, 0, 0], axis=2))
            black_percentage = black_pixels / total_pixels
            if black_percentage > threshold:
                return True
        
        return False
    
    except Exception as e:
        print(f"Error processing {image_path}: {e}")
        return False

def delete_mono_images(folder_path, threshold=0.1, check_white=True, check_black=True, dry_run=True):
    """
    Delete images in folder that have more than threshold percentage of white or black pixels.
    
    Args:
        folder_path: Path to folder containing images
        threshold: Threshold percentage (0.0 to 1.0)
        check_white: Whether to check for white pixels
        check_black: Whether to check for black pixels
        dry_run: If True, just print files to be deleted without actually deleting
    
    Returns:
        Number of files deleted
    """
    # Supported image extensions
    extensions = ['*.jpg']
    
    files_to_delete = []
    
    # Process each image file
    for ext in extensions:
        image_pattern = os.path.join(folder_path, ext)
        for image_path in tqdm(glob.glob(image_pattern), desc="Processing images", unit="file"):
            if is_mostly_mono(image_path, threshold, check_white, check_black):
                files_to_delete.append(image_path)
    
    # Delete files or just print them
    if dry_run:
        print(f"Would delete {len(files_to_delete)} files:")
        for file in files_to_delete:
            print(f"  {file}")
    else:
        print(f"Deleting {len(files_to_delete)} files:")
        for file in files_to_delete:
            print(f"  {file}")
            os.remove(file)
    
    return len(files_to_delete)    



In [None]:
# Run with parameters
source = "../../data/output_multi/RGB2"

count = delete_mono_images(
    source, 
    threshold=0.1,
    check_white=True,
    check_black=True,
    dry_run=False
)

Processing images:   0%|          | 0/443017 [00:00<?, ?file/s]

In [None]:
import os
import random
import shutil
from collections import defaultdict
import re

def sample_images_by_origin(source_dir, target_dir, target_sample_size=100000):
    """
    Sample images from a source directory to create a smaller representative dataset.
    
    Args:
        source_dir (str): Directory containing original images
        target_dir (str): Directory where sampled images will be copied
        target_sample_size (int): Desired size of the final dataset
    """
    # Create target directory if it doesn't exist
    os.makedirs(target_dir, exist_ok=True)
    
    # Group files by origin
    origin_to_files = defaultdict(list)
    
    # Pattern to extract origin from filenames like ORIGIN_filename.jpg
    pattern = re.compile(r'^([^_]+)_.*$')
    
    # Scan source directory and group files by origin
    for filename in os.listdir(source_dir):
        if not filename.lower().endswith(('.jpg', '.jpeg', '.png', '.gif', '.bmp')):
            continue
        
        match = pattern.match(filename)
        if match:
            origin = match.group(1)
            origin_to_files[origin].append(filename)
    
    # Calculate how many images to sample from each origin
    total_origins = len(origin_to_files)
    total_files = sum(len(files) for files in origin_to_files.values())
    
    print(f"Found {total_files} images across {total_origins} different origins")
    
    # Determine sampling strategy
    samples_per_origin = {}
    
    if total_origins > 0:
        # Option 1: Equal sampling (same number from each origin)
        # samples_per_origin = {origin: target_sample_size // total_origins for origin in origin_to_files}
        
        # Option 2: Proportional sampling (based on original distribution)
        for origin, files in origin_to_files.items():
            proportion = len(files) / total_files
            samples_per_origin[origin] = max(1, int(proportion * target_sample_size))
        
        # Adjust to hit target exactly
        current_total = sum(samples_per_origin.values())
        if current_total != target_sample_size:
            # Sort origins by number of files (descending)
            sorted_origins = sorted(origin_to_files.keys(), 
                                   key=lambda o: len(origin_to_files[o]), 
                                   reverse=True)
            
            diff = target_sample_size - current_total
            # Add or remove from largest origins first
            for origin in sorted_origins:
                if diff == 0:
                    break
                elif diff > 0:
                    # Need to add more samples
                    available = len(origin_to_files[origin]) - samples_per_origin[origin]
                    adjustment = min(diff, available)
                    samples_per_origin[origin] += adjustment
                    diff -= adjustment
                else:
                    # Need to remove samples
                    adjustment = min(abs(diff), samples_per_origin[origin] - 1)
                    samples_per_origin[origin] -= adjustment
                    diff += adjustment
    
    # Sample and copy files
    sampled_count = 0
    for origin, num_samples in samples_per_origin.items():
        files = origin_to_files[origin]
        
        # Cap the samples to available files
        actual_samples = min(num_samples, len(files))
        
        # Randomly sample without replacement
        selected_files = random.sample(files, actual_samples)
        
        # Copy selected files to target directory
        for filename in selected_files:
            source_path = os.path.join(source_dir, filename)
            target_path = os.path.join(target_dir, filename)
            shutil.copy2(source_path, target_path)
            sampled_count += 1
        
        print(f"Sampled {actual_samples} images from origin '{origin}'")
    
    print(f"Total images sampled: {sampled_count}")
    return sampled_count

In [None]:
# Create a processed folder if it doesn't exist
sampled_dir = "../../data/output_multi/RGB"
os.makedirs(sampled_dir, exist_ok=True)
source = "../../data/output_multi/RGB2"

sample_images_by_origin(source, sampled_dir, , 100_000)