In [31]:
import os
import io
import json
import time
import platform
import requests
import h5py
import psutil
import imagehash
import numpy as np
import pandas as pd
from PIL import Image
from tqdm import tqdm
from pathlib import Path
from collections import defaultdict
from datetime import datetime
from sklearn.model_selection import train_test_split, StratifiedGroupKFold, StratifiedKFold
from concurrent.futures import ThreadPoolExecutor
import albumentations as A
from torchvision import transforms

from pygbif import occurrences
from pyinaturalist.node_api import get_observations

In [32]:
# Root data directory; each species lives in its own sub-directory named after
# the species with spaces replaced by underscores.
DATA_DIR = "full_image_dataset"
# Sub-directory name (joined under the data directory) that receives augmented images.
AUGMENTED_DIR = "augmented_dataset"
# Minimum length in pixels of an image's shortest side; smaller images are
# rejected by isValidImage().
IMG_SIZE_THRESHOLD = 200
# Two images whose perceptual-hash difference is strictly below this value are
# treated as duplicates by cleanData().
HASH_THRESHOLD = 8

# Species name -> numeric key passed as `taxonKey` to the GBIF occurrence search.
# The position of a species in this dict is also its integer class label in
# createDataset(), so do not reorder entries once a dataset has been built.
species_keys = {
    "Carduelis carduelis": 2494686,
    "Ciconia ciconia": 2481912,
    "Columba livia": 2495414,
    "Delichon urbicum": 2489214,
    "Emberiza calandra":7634625,
    "Hirundo rustica": 7192162,
    "Passer domesticus": 5231190,
    "Serinus serinus":2494200,
    "Streptopelia decaocto": 2495696,
    "Sturnus unicolor":2489104,
    "Turdus merula": 6171845   
}

# Pipeline settings. The whole dict is written into the JSON metadata log by
# initLogging().
CONFIG = {
    # Size every image is resized to before being stored in the HDF5 file.
    'IMG_SIZE': (224, 224),
    # Intended test / train / validation fractions of the dataset.
    'TEST_SIZE': 0.15,
    'TRAIN_SIZE': 0.7,
    'VAL_SIZE': 0.15,
    # Number of stratified cross-validation folds built from the training split.
    'N_SPLITS': 5,
    # Compression filter and level used for every HDF5 dataset.
    'COMPRESSION': 'gzip',
    'COMPRESSION_LEVEL': 6,
    # NOTE(review): SAVE_AS_JPEG is not referenced by the code visible in this
    # notebook - confirm whether it is still needed.
    'SAVE_AS_JPEG': True,
    # JPEG quality used when augmented images are written by processImage().
    'JPEG_QUALITY': 80,
    'AUGMENTATION': {
        # Each entry names an albumentations transform ('name') followed by its
        # keyword arguments. getAugmentation() depends on the order of this list.
        'train': [
            {'name': 'RandomResizedCrop','size':(224,224) , 'scale': (0.8, 1.0)},
            {'name': 'HorizontalFlip', 'p': 0.5},
            {'name': 'ShiftScaleRotate', 'shift_limit': 0.05, 'scale_limit': 0.1, 'rotate_limit': 20, 'p': 0.7},
            {'name': 'ColorJitter', 'brightness': 0.1, 'contrast': 0.1, 'saturation': 0.1, 'hue': 0.05, 'p': 0.8},
            {'name': 'CoarseDropout', 'max_holes':1, 'max_height': 48, 'max_width': 48, 'p': 0.4},
        ]
    }
}

In [33]:
def getSystemInfo():
    """Return a JSON-serialisable snapshot of the host machine.

    Returns:
        dict: current ISO timestamp, OS name and release, processor string,
        physical core count, total and available RAM in GiB (rounded to two
        decimals) and the running Python version.
    """
    bytes_per_gib = 1024**3
    memory = psutil.virtual_memory()

    info = {}
    info["timestamp"] = datetime.now().isoformat()
    info["os"] = platform.system()
    info["os_version"] = platform.release()
    info["cpu"] = platform.processor()
    # logical=False counts physical cores only
    info["cpu_cores"] = psutil.cpu_count(logical=False)
    info["ram_total_gb"] = round(memory.total / bytes_per_gib, 2)
    info["ram_available_gb"] = round(memory.available / bytes_per_gib, 2)
    info["python_version"] = platform.python_version()
    return info

def initLogging(output_dir):
    """Create (or overwrite) today's metadata log and return its path.

    The log is a JSON file named ``dataset_prep_<YYYYMMDD>.json`` inside
    ``output_dir``. It starts with the pipeline configuration, a snapshot of
    the host system and empty sections for the later pipeline stages.

    Args:
        output_dir: Directory that receives the log file. It is created when
            missing; an empty string means the current working directory.

    Returns:
        str: Path of the freshly written metadata file.
    """
    metadata = {
        "config": CONFIG,
        "system": getSystemInfo(),
        "download": {},
        "cleaning": {},
        "dataset_stats": {},
    }
    # os.makedirs("") raises, so only create a directory when one was named.
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    # Build the date stamp outside the f-string: reusing the enclosing quote
    # character inside an f-string expression is a SyntaxError before Python 3.12.
    date_stamp = datetime.now().strftime("%Y%m%d")
    metadata_path = os.path.join(output_dir, f"dataset_prep_{date_stamp}.json")
    with open(metadata_path, 'w') as f:
        json.dump(metadata, f, indent=2)

    return metadata_path

def updateLogging(metadata_path, updates):
    """Merge ``updates`` into the JSON metadata log.

    Sections are merged one level deep: when both the stored value and the new
    value of a top-level key are dicts, the new entries are added to the stored
    ones (same-named entries are replaced). This keeps per-species records from
    earlier calls, e.g. ``{"download": {species: stats}}`` adds one species
    instead of replacing the whole "download" section. Any other value simply
    replaces the stored one.

    Args:
        metadata_path: Path of the metadata JSON file. When it does not exist,
            a fresh log is created with initLogging() in the same directory
            and the updates are applied to that file.
        updates: Mapping of top-level section name to new content.

    Returns:
        str: Path of the metadata file that was written.
    """
    if not os.path.exists(metadata_path):
        # The new log gets initLogging()'s own file name, so track its path.
        metadata_path = initLogging(os.path.dirname(metadata_path))

    with open(metadata_path, 'r') as f:
        metadata = json.load(f)

    for section, payload in updates.items():
        stored = metadata.get(section)
        if isinstance(stored, dict) and isinstance(payload, dict):
            stored.update(payload)
        else:
            metadata[section] = payload

    with open(metadata_path, 'w') as f:
        json.dump(metadata, f, indent=2)

    return metadata_path

In [34]:
def downloadImages(species_name, output_dir, limit=500, metadata_path=None):
    """Download up to ``limit`` images of one species and log the outcome.

    iNaturalist is queried first; GBIF is then asked only for the images still
    missing. Each source is attempted independently, so a failure of one does
    not prevent the other from running.

    Args:
        species_name: Scientific name; must be a key of ``species_keys`` for
            the GBIF step.
        output_dir: Root directory; images are written to
            ``<output_dir>/<species_name with underscores>``. A falsy value
            falls back to ``DATA_DIR``.
        limit: Maximum number of images to download across both sources.
        metadata_path: Metadata log to update; a new one is created in
            ``DATA_DIR`` when omitted.

    Returns:
        int: Total number of images saved.
    """
    start_time = time.time()
    # Honour the caller's directory (it used to be silently replaced by DATA_DIR).
    base_dir = output_dir or DATA_DIR
    species_dir = os.path.join(base_dir, species_name.replace(" ", "_"))
    os.makedirs(species_dir, exist_ok=True)
    if metadata_path is None:
        metadata_path = initLogging(DATA_DIR)

    print(f"\nDownloading images for: {species_name}")
    stats = {
        'iNaturalist': 0,
        'GBIF': 0,
        'start_time': datetime.now().strftime("%Y%m%d_%H%M%S")
    }

    # iNaturalist download
    try:
        stats['iNaturalist'] = downloadImages_INaturalist(species_name, species_dir, limit)
    except Exception as e:
        print(f"Error during download: {e}")

    # GBIF download: only for the remainder, numbered after the iNaturalist files
    already_saved = stats['iNaturalist']
    remaining = limit - already_saved
    if remaining > 0:
        try:
            stats['GBIF'] = downloadImages_GBIF(species_name, already_saved, species_dir, remaining)
        except Exception as e:
            print(f"Error during download: {e}")

    # Update metadata
    stats.update({
        'end_time': datetime.now().strftime("%Y%m%d_%H%M%S"),
        'total_downloaded': stats['iNaturalist'] + stats['GBIF'],
        'time_seconds': time.time() - start_time
    })

    updateLogging(metadata_path, {"download": {species_name: stats}})
    print(f"Total images downloaded for {species_name}: {stats['total_downloaded']}")
    return stats['total_downloaded']

def guessImageExtension(url, default="jpg"):
    """Return a safe, lower-case file extension for an image URL.

    The query string and fragment are ignored and only the last path component
    is inspected. Anything that is not a known image extension (including URLs
    with no extension at all) yields ``default``; downloaded images are
    converted to RGB, so saving them as JPEG always works.
    """
    known_extensions = {"jpg", "jpeg", "png", "bmp", "gif", "webp", "tif", "tiff"}
    path = url.split("?", 1)[0].split("#", 1)[0]
    extension = os.path.splitext(path)[1].lstrip(".").lower()
    return extension if extension in known_extensions else default


def fetchAndSaveImage(url, output_dir, file_stem, timeout=10):
    """Download one image, convert it to RGB and save it; return True on success.

    Errors are printed and reported as False so that one bad URL never aborts
    a whole download run. Non-200 responses are skipped silently.
    """
    try:
        response = requests.get(url, timeout=timeout)
        if response.status_code != 200:
            return False
        img = Image.open(io.BytesIO(response.content)).convert('RGB')
        filename = f"{file_stem}.{guessImageExtension(url)}"
        img.save(os.path.join(output_dir, filename))
        return True
    except Exception as e:
        print(f"Error: {e}")
        return False


def saveImagesFromUrls(urls, species_name, output_dir, limit, start_index=0):
    """Save up to ``limit`` distinct images from ``urls``; return how many were saved.

    Files are named ``<species with underscores>_<n>.<ext>`` where ``n`` counts
    up from ``start_index``. Empty and repeated URLs are skipped.
    """
    if limit <= 0:
        return 0

    file_prefix = species_name.replace(' ', '_')
    seen_urls = set()
    images_downloaded = 0
    for url in urls:
        if not url or url in seen_urls:
            continue
        seen_urls.add(url)
        file_stem = f"{file_prefix}_{start_index + images_downloaded}"
        if fetchAndSaveImage(url, output_dir, file_stem):
            images_downloaded += 1
            if images_downloaded >= limit:
                break
    return images_downloaded


def downloadImages_INaturalist(species_name, output_dir, limit=500):
    """Download up to ``limit`` research-grade iNaturalist photos of a species.

    Args:
        species_name: Taxon name to search for.
        output_dir: Existing directory that receives the image files.
        limit: Maximum number of images to save; values <= 0 download nothing.

    Returns:
        int: Number of images saved.
    """
    images_downloaded = 0
    if limit > 0:
        # NOTE(review): a single request is made with per_page=limit - confirm
        # that the API honours page sizes this large.
        results = get_observations(
            taxon_name=species_name,
            per_page=limit,
            quality_grade="research",
            media_type="photo",
            license=["CC-BY","CC-BY-NC"]
        )
        # The API returns "square" thumbnail URLs; ask for the original instead.
        photo_urls = (
            (photo.get("url") or "").replace("square", "original")
            for obs in tqdm(results.get("results", []))
            for photo in obs.get("photos", [])
        )
        images_downloaded = saveImagesFromUrls(photo_urls, species_name, output_dir, limit)

    print(f"Downloaded {images_downloaded} images from iNaturalist for {species_name}")
    return images_downloaded


def downloadImages_GBIF(species_name, downloadedValue, output_dir, limit=500):
    """Download up to ``limit`` still images of a species from GBIF.

    Args:
        species_name: Key of ``species_keys``; its value is the GBIF taxon key.
        downloadedValue: Number of images already saved for this species; file
            numbering continues from here.
        output_dir: Existing directory that receives the image files.
        limit: Maximum number of images to save; values <= 0 download nothing
            and skip the API request.

    Returns:
        int: Number of images saved by this call.
    """
    images_downloaded = 0
    if limit > 0:
        results = occurrences.search(
            taxonKey=species_keys[species_name],
            mediaType="StillImage",
            limit=limit
        )
        media_urls = (
            media.get("identifier")
            for obs in tqdm(results.get("results", []))
            for media in obs.get("media", [])
        )
        images_downloaded = saveImagesFromUrls(
            media_urls, species_name, output_dir, limit, start_index=downloadedValue
        )

    print(f"\nDownloaded {images_downloaded} images from GBIF for {species_name}")
    return images_downloaded

In [35]:
def getAugmentation():
    """Build the albumentations training pipeline described in CONFIG.

    Every entry of ``CONFIG['AUGMENTATION']['train']`` names a transform class
    of the albumentations module (``'name'``); all remaining keys are passed to
    that class as keyword arguments. Transforms are applied in list order, so
    adding, removing or reordering entries in CONFIG takes effect here without
    code changes.

    Returns:
        A.Compose: The composed augmentation pipeline.

    Raises:
        ValueError: If an entry names a transform albumentations does not have.
    """
    steps = []
    for step in CONFIG['AUGMENTATION']['train']:
        transform_cls = getattr(A, step['name'], None)
        if transform_cls is None:
            raise ValueError(f"Unknown augmentation in CONFIG: {step['name']!r}")
        params = {key: value for key, value in step.items() if key != 'name'}
        steps.append(transform_cls(**params))
    return A.Compose(steps)

def processImage(img_path, output_dir, transform, save_augmented=True):
    """Augment one image and optionally write the result as a JPEG.

    Args:
        img_path: Path of the source image.
        output_dir: Directory for the augmented file ``<stem>_aug.jpg``.
        transform: Callable invoked as ``transform(image=array)`` that returns
            a mapping with the augmented array under ``'image'``.
        save_augmented: When False the augmentation is still run but nothing
            is written.

    Returns:
        bool: True when processing succeeded, False when any error occurred
        (the error is printed).
    """
    try:
        pixels = np.array(Image.open(img_path).convert("RGB"))
        transformed = transform(image=pixels)['image']

        if not save_augmented:
            return True

        target_path = os.path.join(output_dir, f"{Path(img_path).stem}_aug.jpg")
        Image.fromarray(transformed).save(
            target_path,
            quality=CONFIG['JPEG_QUALITY'],
            optimize=True,
        )
        return True
    except Exception as e:
        print(f"Error processing {img_path}: {e}")
        return False


def transformImagesFromDirectory(species_name, data_dir, metadata_path=None, save_augmented=True):
    """Augment every image of one species and log the result.

    Source images are read from ``<data_dir>/<species with underscores>``.
    Augmented copies go to ``<data_dir>/<AUGMENTED_DIR>/<species with
    underscores>`` when ``save_augmented`` is True. A missing source directory
    is reported and treated as containing no images.

    Args:
        species_name: Scientific name of the species.
        data_dir: Root data directory.
        metadata_path: Metadata log to update; a new one is created in
            ``data_dir`` when omitted.
        save_augmented: Whether augmented images are written to disk.

    Returns:
        int: Number of images processed successfully.
    """
    time_format = '%Y%m%d_%H%M%S'
    start_time = time.time()
    species_folder = species_name.replace(" ", "_")
    species_dir = os.path.join(data_dir, species_folder)
    if metadata_path is None:
        metadata_path = initLogging(data_dir)
    stats = {
        'species': species_name,
        'original_count': 0,
        'augmented_saved': 0,
        'start_time': datetime.now().strftime(time_format),
    }

    # Create output directory
    if save_augmented:
        output_dir = os.path.join(data_dir, AUGMENTED_DIR, species_folder)
        os.makedirs(output_dir, exist_ok=True)
    else:
        output_dir = species_dir

    # Get augmentation pipeline
    transform = getAugmentation()

    # Collect source images (sorted so runs are repeatable); skip metadata files
    if os.path.isdir(species_dir):
        image_paths = [
            os.path.join(species_dir, name)
            for name in sorted(os.listdir(species_dir))
            if os.path.isfile(os.path.join(species_dir, name)) and not name.endswith(".json")
        ]
    else:
        print(f"No image directory for {species_name}: {species_dir}")
        image_paths = []

    stats['original_count'] = len(image_paths)

    # Process images in parallel
    with ThreadPoolExecutor(max_workers=4) as executor:
        results = list(tqdm(
            executor.map(
                lambda p: processImage(p, output_dir, transform, save_augmented),
                image_paths
            ),
            total=len(image_paths),
            desc=f"Augmenting {species_name}"
        ))

    stats['augmented_saved'] = sum(results)
    stats.update({
        # same format as 'start_time' so both values can be parsed alike
        'end_time': datetime.now().strftime(time_format),
        'time_seconds': time.time() - start_time
    })

    updateLogging(metadata_path, {"augmentation": {species_name: stats}})
    return stats['augmented_saved']


In [36]:
def isValidImage(path):
    """Tell whether ``path`` is a decodable image that is large enough.

    The image is fully decoded (converted to RGB), so corrupt files are
    rejected as well. Returns True only when the shortest side is at least
    ``IMG_SIZE_THRESHOLD`` pixels; any error is printed and yields False.
    """
    try:
        rgb_image = Image.open(path).convert("RGB")
        shortest_side = min(rgb_image.size)
        return shortest_side >= IMG_SIZE_THRESHOLD
    except Exception as e:
        print(f"Error processing {path}: {e}")
        return False

def getPhash(path):
    """Return the perceptual hash of the image at ``path``.

    Any failure while opening, decoding or hashing is printed and reported as
    None.
    """
    digest = None
    try:
        rgb_image = Image.open(path).convert("RGB")
        digest = imagehash.phash(rgb_image)
    except Exception as e:
        print(f"Error generating hash for {path}: {e}")
    return digest

def cleanData(species_name, dir, metadata_path=None):
    """Delete unusable and near-duplicate images of one species.

    WARNING: files are removed from disk. An image is deleted when it cannot
    be decoded or hashed, when its shortest side is below
    ``IMG_SIZE_THRESHOLD``, or when its perceptual hash differs from an
    already kept image by less than ``HASH_THRESHOLD``. Files are visited in
    sorted order, so the copy that survives among duplicates is deterministic.

    Args:
        species_name: Scientific name of the species.
        dir: Directory that contains the ``<species with underscores>``
            folder. (The name shadows the builtin but is kept so keyword
            callers keep working.)
        metadata_path: Metadata log to update; a new one is created in
            ``DATA_DIR`` when omitted.

    Returns:
        dict: Counters 'removed', 'remaining', 'duplicates' and 'invalid'
        (removed == duplicates + invalid) plus 'time_seconds' and 'timestamp'.
    """
    start_time = time.time()
    if metadata_path is None:
        metadata_path = initLogging(DATA_DIR)

    species_path = os.path.join(dir, species_name.replace(" ", "_"))
    stats = {
        'removed': 0,
        'remaining': 0,
        'duplicates': 0,
        'invalid': 0
    }

    # Regular files only: "*.*" also matches directories whose name contains a
    # dot, and os.remove() on a directory would abort the whole run.
    image_paths = sorted(p for p in Path(species_path).glob("*.*") if p.is_file())

    def inspect(img_path):
        # Hash only images that passed validation; invalid ones are deleted
        # anyway, and hashing them would decode (and fail on) the file twice.
        is_valid = isValidImage(img_path)
        return img_path, is_valid, (getPhash(img_path) if is_valid else None)

    # Validate and hash images in parallel
    with ThreadPoolExecutor(max_workers=4) as executor:
        results = list(tqdm(
            executor.map(inspect, image_paths),
            total=len(image_paths),
            desc=f"Cleaning {species_name}"
        ))

    # Decide sequentially so duplicate detection sees a consistent hash list
    hash_db = []
    for img_path, is_valid, phash in results:
        if not is_valid or phash is None:
            # unreadable, too small, or unhashable
            os.remove(img_path)
            stats['invalid'] += 1
            stats['removed'] += 1
        elif any(phash - existing < HASH_THRESHOLD for existing in hash_db):
            os.remove(img_path)
            stats['duplicates'] += 1
            stats['removed'] += 1
        else:
            hash_db.append(phash)
            stats['remaining'] += 1

    stats.update({
        'time_seconds': time.time() - start_time,
        'timestamp': datetime.now().strftime("%Y%m%d_%H%M%S")
    })
    updateLogging(metadata_path, {"cleaning": {species_name: stats}})
    return stats


In [None]:

def createDataset(metadata_path=None):
    """Build the train/val/test HDF5 dataset from the augmented images.

    Images are read from ``<DATA_DIR>/<AUGMENTED_DIR>/<species>`` for every
    species in ``species_keys`` (missing folders are skipped), resized to
    ``CONFIG['IMG_SIZE']`` and labelled with the species' position in
    ``species_keys``. The data is split with stratification using the
    fractions in CONFIG, and the training split is additionally divided into
    ``CONFIG['N_SPLITS']`` stratified folds.

    HDF5 layout:
        test/X_test, test/y_test
        train/X_train, train/y_train
        val/X_val, val/y_val
        cross_validation/fold_<k>/{X_train, y_train, X_val, y_val}
        file attributes: species, image_size, augmentation, creation_time

    Args:
        metadata_path: Metadata log to update; a new one is created in
            ``DATA_DIR`` when omitted.

    Raises:
        ValueError: If no image could be loaded.
    """
    start_time = time.time()

    # Initialize log if not provided
    if metadata_path is None:
        metadata_path = initLogging(DATA_DIR)

    # Data collection structures
    images = []
    labels = []
    species_counts = defaultdict(int)
    for species_idx, species_name in enumerate(species_keys):
        species_dir = os.path.join(DATA_DIR, AUGMENTED_DIR, species_name.replace(" ", "_"))
        if not os.path.isdir(species_dir):
            continue

        # Sorted: os.listdir() order is arbitrary, which would make the seeded
        # splits below differ between runs and machines.
        for img_name in sorted(os.listdir(species_dir)):
            img_path = os.path.join(species_dir, img_name)
            try:
                img = Image.open(img_path).convert('RGB').resize(CONFIG['IMG_SIZE'])
            except Exception as e:
                # Unreadable files are skipped, but not silently
                print(f"Skipping unreadable image {img_path}: {e}")
                continue
            images.append(np.array(img))
            labels.append(species_idx)
            species_counts[species_name] += 1

    if not images:
        raise ValueError(
            f"No images found under {os.path.join(DATA_DIR, AUGMENTED_DIR)}; nothing to split"
        )

    # Convert to numpy arrays for HDF5
    X = np.array(images)
    y = np.array(labels)

    # 1. First split: train vs. hold-out (val + test)
    holdout_fraction = CONFIG['VAL_SIZE'] + CONFIG['TEST_SIZE']
    X_train, X_temp, y_train, y_temp = train_test_split(
        X, y,
        test_size=holdout_fraction,
        stratify=y,
        random_state=42
    )

    # 2. Second split: divide the hold-out part into val and test
    X_val, X_test, y_val, y_test = train_test_split(
        X_temp, y_temp,
        test_size=CONFIG['TEST_SIZE'] / holdout_fraction,
        stratify=y_temp,
        random_state=42
    )

    # Fold generator for cross-validation on the training split
    skf = StratifiedKFold(
        n_splits=CONFIG['N_SPLITS'],
        shuffle=True,
        random_state=42
    )

    def write_arrays(group, arrays):
        # Store every named array in `group` with the configured compression
        for dataset_name, data in arrays.items():
            group.create_dataset(dataset_name, data=data,
                                 compression=CONFIG['COMPRESSION'],
                                 compression_opts=CONFIG['COMPRESSION_LEVEL'])

    # Create HDF5 dataset
    timestamp = datetime.now().strftime("%Y%m%d")
    h5_path = os.path.join(DATA_DIR, f"dataset_{timestamp}.h5")

    with h5py.File(h5_path, 'w') as hf:
        write_arrays(hf.create_group('test'), {'X_test': X_test, 'y_test': y_test})
        write_arrays(hf.create_group('train'), {'X_train': X_train, 'y_train': y_train})
        write_arrays(hf.create_group('val'), {'X_val': X_val, 'y_val': y_val})

        # Cross-validation splits
        cv_group = hf.create_group('cross_validation')
        for fold, (train_idx, val_idx) in enumerate(skf.split(X_train, y_train)):
            write_arrays(cv_group.create_group(f'fold_{fold + 1}'), {
                'X_train': X_train[train_idx],
                'y_train': y_train[train_idx],
                'X_val': X_train[val_idx],
                'y_val': y_train[val_idx],
            })

        # Save metadata
        hf.attrs['species'] = json.dumps(list(species_keys.keys()))
        hf.attrs['image_size'] = json.dumps(CONFIG['IMG_SIZE'])
        hf.attrs['augmentation'] = json.dumps(CONFIG['AUGMENTATION'])
        hf.attrs['creation_time'] = timestamp

    # Update metadata log
    dataset_stats = {
        'total_images': len(images),
        'species_counts': dict(species_counts),
        'h5_path': h5_path,
        'train_samples': len(X_train),
        'val_samples': len(X_val),
        'test_samples': len(X_test),
        'compression': CONFIG['COMPRESSION'],
        'compression_level': CONFIG['COMPRESSION_LEVEL'],
        'processing_time_seconds': time.time() - start_time,
        'timestamp': timestamp
    }

    updateLogging(metadata_path, {
        "dataset_stats": dataset_stats
    })

    print("Dataset created with multiple formats:")
    print(f"- HDF5 file: {h5_path}")
    print(f"Total processing time: {time.time() - start_time:.2f} seconds")

In [38]:
# Stage switches: set to False to (re)run the corresponding stage.
SKIP_DOWNLOAD = True
SKIP_AUGMENTATION = True
# Directory whose per-species folders are cleaned. Named `clean_dir` rather
# than `dir` so the builtin dir() stays usable for the rest of the session.
# Use DATA_DIR instead to clean the raw downloads.
clean_dir = os.path.join(DATA_DIR, AUGMENTED_DIR)

print("Initiating dataset creation...")
metadata_path = initLogging(DATA_DIR)
for species in species_keys:
    if not SKIP_DOWNLOAD:
        downloadImages(species, DATA_DIR, limit=600, metadata_path=metadata_path)
    if not SKIP_AUGMENTATION:
        transformImagesFromDirectory(species, DATA_DIR, metadata_path, True)
    cleanData(species, clean_dir, metadata_path)
createDataset(metadata_path)
print("Tasks completed")


Initiating dataset creation...


Cleaning Carduelis carduelis: 100%|██████████| 599/599 [00:01<00:00, 572.76it/s]
Cleaning Ciconia ciconia: 100%|██████████| 600/600 [00:01<00:00, 548.17it/s]
Cleaning Columba livia: 100%|██████████| 600/600 [00:00<00:00, 611.61it/s]
Cleaning Delichon urbicum: 100%|██████████| 596/596 [00:00<00:00, 661.65it/s]
Cleaning Emberiza calandra: 100%|██████████| 600/600 [00:00<00:00, 635.53it/s]
Cleaning Hirundo rustica: 100%|██████████| 597/597 [00:00<00:00, 615.42it/s]
Cleaning Passer domesticus: 100%|██████████| 600/600 [00:01<00:00, 579.83it/s]
Cleaning Serinus serinus: 100%|██████████| 600/600 [00:01<00:00, 458.92it/s]
Cleaning Streptopelia decaocto: 100%|██████████| 600/600 [00:01<00:00, 580.45it/s]
Cleaning Sturnus unicolor: 100%|██████████| 600/600 [00:00<00:00, 610.08it/s]
Cleaning Turdus merula: 100%|██████████| 596/596 [00:01<00:00, 554.60it/s]


Dataset created with multiple formats:
- HDF5 file: full_image_dataset\dataset_20250524_131708.h5
Total processing time: 144.13 seconds
Tasks completed
