In [5]:
# ----------------------------------------------------------------------------
# User-tunable settings: edit these before running the cell.
# ----------------------------------------------------------------------------
# Root of the raw dataset; the script looks for <split>/images below it.
INPUT_DIR: str = "/cluster/home/akmarala/data/TEXMET"
# Root under which the processed <split>/images folders are written.
OUTPUT_DIR: str = "/cluster/home/akmarala/data/TEXMET_processed"
# Longest allowed image edge, in pixels; larger images are downscaled.
MAX_SIZE: int = 2048
# Quality setting passed to the JPEG encoder when saving.
JPEG_QUALITY: int = 90
# When True, preprocessing is skipped and only validation runs.
VALIDATE_ONLY: bool = False

# ============================================================================
# TEXMET Dataset Preprocessing Script
# ============================================================================

import os
import shutil
from PIL import Image
from pathlib import Path
from tqdm import tqdm
import json

def get_image_info(image_path):
    """Return basic metadata for the image at *image_path*.

    On success the result has the keys 'size', 'mode', 'format' (as reported
    by PIL) and 'file_size' (bytes on disk). If anything goes wrong while
    opening or inspecting the file, a dict with the single key 'error'
    holding the exception text is returned instead of raising.
    """
    try:
        with Image.open(image_path) as picture:
            dimensions = picture.size
            colour_mode = picture.mode
            container = picture.format
            byte_count = os.path.getsize(image_path)
    except Exception as exc:
        return {'error': str(exc)}

    return {
        'size': dimensions,
        'mode': colour_mode,
        'format': container,
        'file_size': byte_count,
    }

def _list_image_files(directory, extensions):
    """Return the regular files in *directory* whose suffix is in *extensions*.

    The suffix comparison is case-insensitive, every file appears exactly
    once, and the result is sorted so runs are reproducible.
    """
    return sorted(
        entry for entry in directory.iterdir()
        if entry.is_file() and entry.suffix.lower() in extensions
    )


def _record_error(stats, split_stats, img_file, message):
    """Store a per-file failure in both stat dicts and echo it to stdout."""
    error_msg = f"{img_file}: {message}"
    stats['errors'].append(error_msg)
    split_stats['errors'] += 1
    print(f"❌ Error: {error_msg}")


def preprocess_texmet_dataset(input_dir, output_dir, max_size=2048, quality=90):
    """
    Preprocess TEXMET dataset by resizing large images.

    For each of the 'train', 'val' and 'test' splits, every image found in
    ``<input_dir>/<split>/images`` is converted to RGB, downscaled (aspect
    ratio preserved) when either edge exceeds ``max_size``, and saved as
    ``<output_dir>/<split>/images/<stem>.jpg``.

    Args:
        input_dir: Root directory of the raw dataset.
        output_dir: Root directory for the processed dataset; created as
            needed.
        max_size: Maximum allowed width/height in pixels.
        quality: JPEG quality used when saving.

    Returns:
        A JSON-serialisable dict with the keys:
        'total_processed' (image files discovered, including failures),
        'resized_count', 'copied_count' (images within the size limit; they
        are still re-encoded as JPEG, not byte-copied), 'error_count',
        'splits' (per-split counters), 'size_distribution' (before/after
        size histograms of resized images only) and 'errors' (messages).

    Per-file failures are recorded and reported; they never abort the run.
    """
    input_path = Path(input_dir)
    output_path = Path(output_dir)
    image_extensions = {'.jpg', '.jpeg', '.png', '.bmp', '.tiff', '.tif'}

    # Statistics tracking
    stats = {
        'total_processed': 0,
        'resized_count': 0,
        'copied_count': 0,
        'error_count': 0,
        'splits': {},
        'size_distribution': {'before': {}, 'after': {}},
        'errors': []
    }

    print("Preprocessing TEXMET dataset...")
    print(f"Input:  {input_dir}")
    print(f"Output: {output_dir}")
    print(f"Max size: {max_size}px")
    print(f"JPEG quality: {quality}")
    print("-" * 50)

    # Process each split
    for split in ['train', 'val', 'test']:
        split_input_dir = input_path / split / 'images'
        split_output_dir = output_path / split / 'images'

        if not split_input_dir.is_dir():
            print(f"⚠️  {split} split not found, skipping...")
            continue

        # Create output directory
        split_output_dir.mkdir(parents=True, exist_ok=True)

        # Filtering on the lower-cased suffix avoids both double-listing on
        # case-insensitive filesystems and missing mixed-case names (.Jpg).
        image_files = _list_image_files(split_input_dir, image_extensions)

        print(f"Processing {len(image_files)} images in {split} split...")

        split_stats = {
            'total': len(image_files),
            'resized': 0,
            'copied': 0,
            'errors': 0
        }

        # Output stem -> source file already written under that stem. Every
        # output is named "<stem>.jpg", so e.g. a.png and a.jpg would clash.
        written_from = {}
        before_hist = stats['size_distribution']['before']
        after_hist = stats['size_distribution']['after']

        for img_file in tqdm(image_files, desc=f"{split} split"):
            earlier = written_from.get(img_file.stem)
            if earlier is not None:
                # Refuse to silently overwrite an output produced this run.
                _record_error(
                    stats, split_stats, img_file,
                    f"output name collides with {earlier.name}; skipped",
                )
                continue

            output_file = split_output_dir / f"{img_file.stem}.jpg"
            try:
                with Image.open(img_file) as img:
                    rgb = img.convert('RGB')
                    original_size = rgb.size

                    # Check if image needs resizing
                    needs_resize = (original_size[0] > max_size
                                    or original_size[1] > max_size)
                    if needs_resize:
                        # Resize in place, maintaining aspect ratio
                        rgb.thumbnail((max_size, max_size), Image.Resampling.LANCZOS)

                    # Save image as JPEG
                    rgb.save(output_file, 'JPEG', quality=quality, optimize=True)
            except Exception as e:
                _record_error(stats, split_stats, img_file, str(e))
                continue

            # Counters are only updated once the file was written successfully.
            written_from[img_file.stem] = img_file
            if needs_resize:
                split_stats['resized'] += 1
                size_key = f"{original_size[0]}x{original_size[1]}"
                new_size_key = f"{rgb.size[0]}x{rgb.size[1]}"
                before_hist[size_key] = before_hist.get(size_key, 0) + 1
                after_hist[new_size_key] = after_hist.get(new_size_key, 0) + 1
            else:
                split_stats['copied'] += 1

        # Update global stats
        stats['total_processed'] += split_stats['total']
        stats['resized_count'] += split_stats['resized']
        stats['copied_count'] += split_stats['copied']
        stats['error_count'] += split_stats['errors']
        stats['splits'][split] = split_stats

        print(f"✅ {split} split complete:")
        print(f"   Total: {split_stats['total']}")
        print(f"   Resized: {split_stats['resized']}")
        print(f"   Copied: {split_stats['copied']}")
        print(f"   Errors: {split_stats['errors']}")
        print()

    return stats

def validate_processed_dataset(input_dir, output_dir, max_size=2048):
    """
    Validate the processed dataset against the raw one.

    For each of the 'train', 'val' and 'test' splits whose
    ``<input_dir>/<split>/images`` directory exists, the file stems of the
    input images are compared with the stems of the files in
    ``<output_dir>/<split>/images``, and every ``.jpg`` in the output is
    opened to confirm neither edge exceeds ``max_size``.

    Only files with an image extension count as inputs, so stray non-image
    files in the raw dataset are not reported as missing.

    Args:
        input_dir: Root directory of the raw dataset.
        output_dir: Root directory of the processed dataset.
        max_size: Maximum allowed width/height in pixels.

    Returns:
        A dict with the keys 'splits' (per-split counters), 'size_violations'
        (dicts with 'file' and 'size'), 'missing_files' and 'extra_files'
        ("<split>/<stem>" strings, sorted within each split) and
        'unreadable_files' (output images that could not be opened).
    """
    print("🔍 Validating processed dataset...")
    input_path = Path(input_dir)
    output_path = Path(output_dir)
    image_extensions = {'.jpg', '.jpeg', '.png', '.bmp', '.tiff', '.tif'}

    validation_results = {
        'splits': {},
        'size_violations': [],
        'missing_files': [],
        'extra_files': [],
        'unreadable_files': []
    }

    for split in ['train', 'val', 'test']:
        input_split_dir = input_path / split / 'images'
        output_split_dir = output_path / split / 'images'

        if not input_split_dir.is_dir():
            continue

        # Expected stems: input files that are images (case-insensitive suffix).
        input_files = {
            f.stem for f in input_split_dir.iterdir()
            if f.is_file() and f.suffix.lower() in image_extensions
        }

        # A missing output directory simply means nothing was produced.
        if output_split_dir.is_dir():
            output_entries = sorted(
                f for f in output_split_dir.iterdir() if f.is_file()
            )
        else:
            output_entries = []
        output_files = {f.stem for f in output_entries}

        # Check for missing/extra files
        missing = sorted(input_files - output_files)
        extra = sorted(output_files - input_files)
        validation_results['missing_files'].extend(f"{split}/{f}" for f in missing)
        validation_results['extra_files'].extend(f"{split}/{f}" for f in extra)

        # Check image sizes
        oversized_images = []
        unreadable = []
        for img_file in output_entries:
            if img_file.suffix.lower() != '.jpg':
                continue
            try:
                with Image.open(img_file) as img:
                    if img.size[0] > max_size or img.size[1] > max_size:
                        oversized_images.append({
                            'file': str(img_file),
                            'size': img.size
                        })
            except Exception as e:
                # Record the failure so callers can act on it, not just read it.
                unreadable.append(str(img_file))
                print(f"❌ Error checking {img_file}: {e}")

        validation_results['size_violations'].extend(oversized_images)
        validation_results['unreadable_files'].extend(unreadable)

        validation_results['splits'][split] = {
            'input_count': len(input_files),
            'output_count': len(output_files),
            'missing_count': len(missing),
            'extra_count': len(extra),
            'oversized_count': len(oversized_images),
            'unreadable_count': len(unreadable)
        }

    return validation_results

# ============================================================================
# Main Execution
# ============================================================================

def _report_issues(heading, entries, limit):
    """Print *heading*, then at most *limit* entries and a remainder count."""
    print(heading)
    for entry in entries[:limit]:
        print(f"  {entry}")
    if len(entries) > limit:
        print(f"  ... and {len(entries) - limit} more")


if not VALIDATE_ONLY:
    # Preprocess the dataset
    stats = preprocess_texmet_dataset(INPUT_DIR, OUTPUT_DIR, MAX_SIZE, JPEG_QUALITY)

    # Save processing stats. OUTPUT_DIR is only created while processing a
    # split, so make sure it exists even when no split was found.
    stats_file = Path(OUTPUT_DIR) / 'processing_stats.json'
    stats_file.parent.mkdir(parents=True, exist_ok=True)
    with open(stats_file, 'w', encoding='utf-8') as f:
        json.dump(stats, f, indent=2)

    print("📊 Processing Summary:")
    print(f"Total images processed: {stats['total_processed']}")
    print(f"Images resized: {stats['resized_count']}")
    print(f"Images copied: {stats['copied_count']}")
    print(f"Errors: {stats['error_count']}")
    print(f"Stats saved to: {stats_file}")
    print()

# Validate the processed dataset
validation = validate_processed_dataset(INPUT_DIR, OUTPUT_DIR, MAX_SIZE)

print("✅ Validation Results:")
if not validation['splits']:
    # Without this, an empty result would fall through to the success banner.
    print(f"❌ No splits found under {INPUT_DIR}; nothing was validated.")

for split, results in validation['splits'].items():
    print(f"{split} split:")
    print(f"  Input files: {results['input_count']}")
    print(f"  Output files: {results['output_count']}")
    print(f"  Missing files: {results['missing_count']}")
    print(f"  Extra files: {results['extra_count']}")
    print(f"  Oversized images: {results['oversized_count']}")

    if results['input_count'] == results['output_count']:
        print("  ✅ File count matches!")
    else:
        print("  ❌ File count mismatch!")

# Report any issues (only the first few entries of each kind are listed)
missing_files = validation['missing_files']
extra_files = validation['extra_files']
size_violations = validation['size_violations']
# Optional key: tolerate validators that do not report unreadable outputs.
unreadable_files = validation.get('unreadable_files', [])

if missing_files:
    _report_issues(f"\n❌ Missing files ({len(missing_files)}):", missing_files, 10)

if extra_files:
    _report_issues(f"\n⚠️  Extra files ({len(extra_files)}):", extra_files, 10)

if size_violations:
    _report_issues(
        f"\n❌ Oversized images ({len(size_violations)}):",
        [f"{img['file']}: {img['size']}" for img in size_violations],
        5,
    )

if unreadable_files:
    _report_issues(f"\n❌ Unreadable images ({len(unreadable_files)}):", unreadable_files, 10)

if (validation['splits'] and
    not missing_files and
    not extra_files and
    not size_violations and
    not unreadable_files):
    print("\n🎉 Dataset preprocessing completed successfully!")
    print("All files processed, no size violations, file counts match!")
    print("\nYou can now use the processed dataset:")
    print(f"--data_path {OUTPUT_DIR}")

Preprocessing TEXMET dataset...
Input:  /cluster/home/akmarala/data/TEXMET
Output: /cluster/home/akmarala/data/TEXMET_processed
Max size: 2048px
JPEG quality: 90
--------------------------------------------------
Processing 14915 images in train split...


train split:  74%|███████▍  | 11096/14915 [29:11<10:02,  6.33it/s] 


KeyboardInterrupt: 