#### Download and deduplication

In [3]:
from ddgs import DDGS
import os
import requests

def collect_url_and_download_images(specie: str, max_pages: int=50, out_folder: str='dataset'):
    """Search DuckDuckGo images for a species and download every unique hit.

    Two queries are issued (``specie`` and ``"dinosaur {specie}"``), each for up
    to ``max_pages`` result pages. The image URLs are de-duplicated and then
    downloaded one by one into ``{out_folder}/{specie}/raw``. Files are named
    with a zero-padded running index (``0001.jpg``, ``0002.png``, ...).

    Search and download failures are reported with ``print`` and skipped, so a
    single bad page or URL never aborts the whole collection.

    Args:
        specie: Species name used both as search query and as sub-folder name.
        max_pages: Number of result pages requested per query.
        out_folder: Root folder of the dataset.

    Returns:
        Path of the directory the images were written to.
    """
    outdir = f"{out_folder}/{specie}/raw"
    os.makedirs(outdir, exist_ok=True)

    # Extensions the later pipeline stages look for. Anything else found in a
    # URL (".php", ".JPG", ".jpg!d", ...) would be silently ignored downstream,
    # so such files are stored as ".jpg" / lower-cased instead.
    known_exts = {".jpg", ".jpeg", ".png", ".gif", ".bmp", ".webp"}

    urls = set()
    for query in [specie, f"dinosaur {specie}"]:
        # NOTE(review): pages are requested as 0..max_pages-1; confirm whether
        # the ddgs API numbers result pages from 0 or from 1.
        for i in range(max_pages):
            try:
                page_results = DDGS().images(
                    query=query,
                    region="us-en",
                    safesearch="off",
                    max_results=1000,
                    page=i)
            except Exception as e:
                # A failing page (e.g. no more results) must not discard the
                # URLs gathered so far; stop paging this query and move on.
                print(f"Search failed for '{query}' page {i} -> {e}")
                break
            urls.update(
                result["image"] for result in (page_results or []) if result.get("image")
            )

    # Sorted so that the index -> URL mapping is stable across runs (set
    # iteration order of strings changes between interpreter sessions).
    for idx, url in enumerate(sorted(urls), start=1):
        try:
            r = requests.get(url, timeout=20, headers={"User-Agent": "Mozilla/5.0"})
            r.raise_for_status()
            ext = os.path.splitext(url.split("?")[0])[1].lower()
            if ext not in known_exts:
                ext = ".jpg"
            filename = os.path.join(outdir, f"{idx:04d}{ext}")
            with open(filename, "wb") as f:
                f.write(r.content)
            print(f"Downloaded {filename}")
        except Exception as e:
            # Best effort: dead links and timeouts are expected for scraped URLs.
            print(f"Failed {url} -> {e}")

    return outdir
            
import imagehash
from PIL import Image
import glob
import os

def remove_duplicate_images(folder_path: str, similarity_threshold: int = 4):
    """
    Remove duplicate/similar images using perceptual hashing.

    Every image file directly inside ``folder_path`` is hashed with
    ``imagehash.phash``. An image whose hash is within ``similarity_threshold``
    bits (Hamming distance) of an already kept image is deleted. Files that
    cannot be opened or hashed are deleted as well.

    Args:
        folder_path: Path to folder containing images
        similarity_threshold: Maximum Hamming distance still considered a
            duplicate. Lower = more strict (0=identical, 4=default, 10+=lenient)

    Returns:
        None. Progress and a final summary are printed.
    """
    valid_exts = ('.jpg', '.jpeg', '.png', '.gif', '.bmp', '.webp')
    # Sorted so that the copy which survives among duplicates does not depend
    # on the arbitrary order returned by glob.
    image_files = sorted(
        f for f in glob.glob(os.path.join(folder_path, "*"))
        if f.lower().endswith(valid_exts)
    )

    if not image_files:
        print("No images found in folder")
        return

    print(f"Checking {len(image_files)} images for duplicates...")

    kept_hashes = []
    duplicates = []
    unreadable = 0

    for img_path in image_files:
        try:
            with Image.open(img_path) as img:
                rgb = img if img.mode == 'RGB' else img.convert('RGB')
                img_hash = imagehash.phash(rgb)
        except Exception as e:
            # Deliberate best effort: a file that cannot be decoded is useless
            # for the dataset, so it is removed right away.
            print(f"Error processing {os.path.basename(img_path)}: {e}")
            print(f"Deleting {os.path.basename(img_path)}")
            os.remove(img_path)
            unreadable += 1
            continue

        # Subtracting two hashes yields their Hamming distance.
        if any(img_hash - kept <= similarity_threshold for kept in kept_hashes):
            duplicates.append(img_path)
            print(f"Duplicate: {os.path.basename(img_path)}")
        else:
            kept_hashes.append(img_hash)

    # Remove duplicates
    for duplicate in duplicates:
        os.remove(duplicate)

    # Unreadable files were deleted too, so they must not be counted as remaining.
    remaining = len(image_files) - len(duplicates) - unreadable
    print(f"Removed {len(duplicates)} duplicates. {remaining} unique images remain.")

### Dataset cleaning

In [4]:
# !pip install torch torchvision ftfy regex tqdm pillow
# !pip install git+https://github.com/openai/CLIP.git

import os, shutil, glob
import torch
import clip
from PIL import Image
from tqdm import tqdm

@torch.no_grad()
def score_image(img_path, model, preprocess, device, 
                TXT_IS_DINO, TXT_NOT_DINO, TXT_REAL, TXT_NONREAL,
                is_dino_prompts, not_dino_prompts, realistic_prompts, non_realistic_prompts,
                verbose=False):
    """Score one image against the "is a dinosaur" and "is realistic" prompt groups.

    The image is encoded with ``model.encode_image`` and compared with four sets
    of pre-computed text features. For each pair of opposing groups the best
    (maximum) logit of the positive group competes against the best logit of
    the negative group in a two-way softmax.

    Args:
        img_path: Path of the image file to score.
        model: Model exposing ``encode_image`` and ``logit_scale``.
        preprocess: Callable turning a PIL image into a tensor for ``model``.
        device: Device the image tensor is moved to.
        TXT_IS_DINO, TXT_NOT_DINO, TXT_REAL, TXT_NONREAL: Text feature matrices,
            one row per prompt (expected to be L2-normalised by the caller).
        is_dino_prompts, not_dino_prompts, realistic_prompts,
        non_realistic_prompts: Prompt strings matching the feature matrices;
            only used for the verbose report.
        verbose: When true, print every prompt logit and the resulting scores.

    Returns:
        ``(is_dino_margin, realism_margin)``, each in [−1, 1] where 0 is a tie
        and positive values favour the positive group, or ``(None, None)``
        when the image cannot be opened or preprocessed.
    """
    try:
        # Context manager so the file handle is released as soon as the pixel
        # data has been converted; otherwise handles pile up over a big folder.
        with Image.open(img_path) as raw:
            image = preprocess(raw.convert("RGB")).unsqueeze(0).to(device)
    except Exception:
        return None, None

    img_feat = model.encode_image(image)
    img_feat = img_feat / img_feat.norm(dim=-1, keepdim=True)

    # CLIP temperature (sharpens the distribution); identical for every group.
    scale = model.logit_scale.exp()

    def logits_for(txt_feats):
        # cosine similarities scaled to logits, shape [1, num_prompts]
        return (img_feat @ txt_feats.T) * scale

    def best_vs_best(pos_logits, neg_logits):
        # Two-way softmax between the strongest positive and negative prompt.
        best_pos = pos_logits.max().item()
        best_neg = neg_logits.max().item()
        probs = torch.softmax(torch.tensor([best_pos, best_neg]), dim=0)
        return best_pos, best_neg, probs[0].item(), probs[1].item()

    # Get logits for each category
    L_is   = logits_for(TXT_IS_DINO)
    L_not  = logits_for(TXT_NOT_DINO)
    L_real = logits_for(TXT_REAL)
    L_non  = logits_for(TXT_NONREAL)

    # 1) Dinosaur classification: max positive vs max negative
    max_is_dino, max_not_dino, p_is_dino, p_not_dino = best_vs_best(L_is, L_not)

    # 2) Realism classification: max positive vs max negative
    max_real, max_non_real, p_real, p_non_real = best_vs_best(L_real, L_non)

    # margins in [−1, 1]; 0 means tie, >0 favors positives
    is_dino_margin = p_is_dino - p_not_dino
    realism_margin = p_real - p_non_real

    if verbose:
        print(f"\n👉 {os.path.basename(img_path)}")
        report = (
            ("  [DINO LOGITS]", is_dino_prompts, L_is),
            ("  [NOT_DINO LOGITS]", not_dino_prompts, L_not),
            ("  [REALISTIC LOGITS]", realistic_prompts, L_real),
            ("  [NON_REALISTIC LOGITS]", non_realistic_prompts, L_non),
        )
        for title, prompts, logits in report:
            print(title)
            best = logits.argmax().item()
            for i, (s, logit) in enumerate(zip(prompts, logits.squeeze(0).tolist())):
                mark = "★" if i == best else " "
                print(f"    {mark} {s:<50} {logit:.3f}")
        print(f"  Max dino: {max_is_dino:.3f} vs Max not-dino: {max_not_dino:.3f}")
        print(f"  Max real: {max_real:.3f} vs Max non-real: {max_non_real:.3f}")
        print(f"  p_is_dino={p_is_dino:.3f}  p_real={p_real:.3f}")
        print(f"  is_dino_margin={is_dino_margin:.3f}  realism_margin={realism_margin:.3f}")

    return is_dino_margin, realism_margin

def filter_folder(in_dir, out_good, out_bad, specie,
                  dino_thr=0.20, realism_thr=0.15, verbose=False):
    """Sort the images of a folder into "good" and "bad" using CLIP prompt scores.

    Loads CLIP ViT-B/32, encodes four prompt groups (dinosaur / not dinosaur,
    realistic / non realistic) and scores every image in ``in_dir`` with
    ``score_image``. An image is copied to ``out_good`` when both margins reach
    their thresholds, otherwise to ``out_bad``. Images that cannot be scored
    are copied to ``out_bad`` as well. The source files are left in place.

    Args:
        in_dir: Folder whose ``*.jpg, *.jpeg, *.png, *.webp, *.bmp`` files are scored.
        out_good: Destination folder for accepted images (created if missing).
        out_bad: Destination folder for rejected images (created if missing).
        specie: Species name inserted into the species-specific prompts.
        dino_thr: Minimum "is a dinosaur" margin required to keep an image.
        realism_thr: Minimum realism margin required to keep an image.
        verbose: Forwarded to ``score_image`` to print per-prompt logits.

    Returns:
        ``(kept, rejected)`` image counts.
    """
    os.makedirs(out_good, exist_ok=True)
    os.makedirs(out_bad, exist_ok=True)
    
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model, preprocess = clip.load("ViT-B/32", device=device)

    # ---------- PROMPTS ----------
    # 1) Is it a dinosaur?
    is_dino_prompts = [
        "a realistic illustration of a dinosaur",
        "a realistic toy dinosaur figure",
        "a photo of a dinosaur",
        "a paleoart illustration of a dinosaur",
        f"a realistic illustration of a {specie}",
        f"a realistic toy {specie} figure",
        f"a photo of a {specie}",
        f"a paleoart illustration of a {specie}",
    ]
    not_dino_prompts = [
        "a photo of a modern animal",
        "a person or human",
        "a landscape without animals",
        "a vehicle or building",
        "a toy",
        "a cloth",
        "an abstract image",
        "a geometric figure",
    ]


    # 2) Realistic vs. non-realistic (cartoon, plush toy, pixel art, line drawing)
    realistic_prompts = [
        "a realistic illustration of a dinosaur",
        "a dinosaur fossil skeleton",
        "a realistic toy dinosaur figure",
        "a high-quality render of a dinosaur",
        
        f"a realistic illustration of a {specie}",
        f"a {specie} fossil skeleton",
        f"a realistic toy {specie} figure",
        f"a high-quality render of a {specie}",
    ]
    non_realistic_prompts = [
        "a cartoon dinosaur for kids",
        "a peluche toy dinosaur figure",
        "a pixel art dinosaur",
        "a simple line drawing of a dinosaur",
        "a non realistic drawing of a dinosaur"
    ]

    @torch.no_grad()
    def encode_text_batch(prompts):
        # Tokenise, encode and L2-normalise a list of prompts in one batch.
        tokens = clip.tokenize(prompts).to(device)
        text_features = model.encode_text(tokens)
        text_features /= text_features.norm(dim=-1, keepdim=True)
        return text_features

    # Text features are computed once and reused for every image.
    TXT_IS_DINO = encode_text_batch(is_dino_prompts)
    TXT_NOT_DINO = encode_text_batch(not_dino_prompts)
    TXT_REAL = encode_text_batch(realistic_prompts)
    TXT_NONREAL = encode_text_batch(non_realistic_prompts)

    imgs = []
    for ext in ("*.jpg","*.jpeg","*.png","*.webp","*.bmp"):
        imgs += glob.glob(os.path.join(in_dir, ext))
    kept = 0
    rejected = 0

    for p in tqdm(imgs, desc=f"Filtering {os.path.basename(in_dir)}"):
        res = score_image(p, model, preprocess, device, 
                        TXT_IS_DINO, TXT_NOT_DINO, TXT_REAL, TXT_NONREAL,
                        is_dino_prompts, not_dino_prompts, realistic_prompts, non_realistic_prompts, 
                        verbose=verbose)
        # score_image signals an unreadable image with (None, None); a bare
        # None is accepted too. Without this check the threshold comparison
        # below would raise TypeError on the first broken file.
        is_dino_m, realism_m = res if res is not None else (None, None)
        if is_dino_m is None or realism_m is None:
            rejected += 1
            shutil.copy(p, os.path.join(out_bad, os.path.basename(p)))
            continue

        if (is_dino_m >= dino_thr) and (realism_m >= realism_thr):
            kept += 1
            shutil.copy(p, os.path.join(out_good, os.path.basename(p)))
        else:
            rejected += 1
            shutil.copy(p, os.path.join(out_bad, os.path.basename(p)))

    return kept, rejected



# pipeline for the 15 species

In [None]:
# Full list of species the dataset is built for.
species = [
    "Ankylosaurus",
    "Brachiosaurus",
    "Compsognathus",
    "Corythosaurus",
    "Dilophosaurus",
    "Dimorphodon",
    "Gallimimus",
    "Microceratus",
    "Pachycephalosaurus",
    "Parasaurolophus",
    "Spinosaurus",
    "Stegosaurus",
    "Triceratops",
    "Tyrannosaurus",
    "Velociraptor"
]
# NOTE: this override replaces the full list above, so only Tyrannosaurus is
# processed; remove this line to run the pipeline for every species.
species = ["Tyrannosaurus"]

# Per species: download -> perceptual de-duplication -> CLIP-based filtering.
for specie in species:
    print(f"processing dinosaur {specie}")
    specie_raw_dir = collect_url_and_download_images(specie=specie)
    remove_duplicate_images(folder_path=specie_raw_dir)
    print(f"filtering dinosaur {specie}")
    kept, rej = filter_folder(
        specie_raw_dir,
        out_good=os.path.join(specie_raw_dir, "clean"),
        out_bad=os.path.join(specie_raw_dir, "rejected"),
        specie=specie,
        dino_thr=0.20,
        realism_thr=0.20)
    print(f"Specie: {specie}. Kept {kept}, rejected {rej}")

In [24]:
# Copy the CLIP-accepted images of every species into a flat per-species
# folder under dataset/to_phone.
for dinosaurs in species:
    print(f'dino: {dinosaurs}')
    folder = f'./dataset/{dinosaurs}/raw/clean'
    out_folder = f'dataset/to_phone/{dinosaurs}'
    # The destination must exist before copying, otherwise shutil.copy fails
    # on a fresh run; exist_ok keeps the cell re-runnable.
    os.makedirs(out_folder, exist_ok=True)
    for file_name in tqdm(os.listdir(folder)):
        # construct full file path
        source = os.path.join(folder, file_name)
        destination = os.path.join(out_folder, file_name)
        # copy only files
        if os.path.isfile(source):
            shutil.copy(source, destination)
    

dino: Ankylosaurus


100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 597/597 [00:00<00:00, 665.71it/s]


dino: Brachiosaurus


100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 518/518 [00:00<00:00, 725.86it/s]


dino: Compsognathus


100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 480/480 [00:04<00:00, 102.02it/s]


dino: Corythosaurus


100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 390/390 [00:05<00:00, 76.57it/s] 


dino: Dilophosaurus


100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 518/518 [00:07<00:00, 70.90it/s]


dino: Dimorphodon


100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 414/414 [00:05<00:00, 70.15it/s]


dino: Gallimimus


100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 494/494 [00:06<00:00, 72.91it/s]


dino: Microceratus


100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 425/425 [00:06<00:00, 68.22it/s]


dino: Pachycephalosaurus


100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 435/435 [00:06<00:00, 70.41it/s]


dino: Parasaurolophus


100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 504/504 [00:07<00:00, 67.62it/s]


dino: Spinosaurus


100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 545/545 [00:08<00:00, 67.43it/s]


dino: Stegosaurus


100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 647/647 [00:09<00:00, 69.61it/s]


dino: Triceratops


100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 650/650 [00:09<00:00, 68.36it/s]


dino: Tyrannosaurus_Rex


100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 392/392 [00:05<00:00, 66.70it/s]


dino: Velociraptor


100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 518/518 [00:07<00:00, 65.14it/s]


# After manual curation of the dataset - Split

In [12]:
import random
import os
import shutil
from PIL import Image

curated_path = './dataset/hand_curated_datasets'
final_dataset = './dataset/dataset'
seed = 42
random.seed(seed)

# Split every curated species folder 85/15 into train/test and store all
# images as PNG under final_dataset/<split>/<species>/.
# Both listings are sorted: os.listdir order is arbitrary, so without sorting
# the fixed seed would not give the same split on another machine or run.
for species_sub_folder in sorted(os.listdir(curated_path)):
    species_dir = os.path.join(curated_path, species_sub_folder)
    # Skip hidden entries and stray files (.DS_Store, .ipynb_checkpoints, ...):
    # they are not classes and listing a plain file would raise.
    if species_sub_folder.startswith('.') or not os.path.isdir(species_dir):
        continue

    print(species_sub_folder)
    species_images = sorted(
        image for image in os.listdir(species_dir) if not image.startswith('.')
    )
    # Positions of the images that go to the test split (15%, rounded down).
    indices = set(random.sample(range(len(species_images)), int(0.15*len(species_images))))
    os.makedirs(os.path.join(final_dataset, 'train', species_sub_folder), exist_ok=True)
    os.makedirs(os.path.join(final_dataset, 'test', species_sub_folder), exist_ok=True)

    for idx, image in enumerate(species_images):
        path_to_image = os.path.join(species_dir, image)
        split = 'test' if idx in indices else 'train'
        # NOTE: two sources that differ only by extension ("a.jpg", "a.png")
        # map to the same PNG name and the later one overwrites the earlier.
        png_filename = os.path.splitext(image)[0] + '.png'
        dst_path = os.path.join(final_dataset, split, species_sub_folder, png_filename)
        try:
            # Context manager so the source file handle is closed after saving.
            with Image.open(path_to_image) as img:
                img.save(dst_path, 'PNG')
        except Exception as e:
            print(f"Error processing {image}: {e}")
         

        


Microceratus
Pachycephalosaurus
Parasaurolophus
Spinosaurus
Stegosaurus
Triceratops
Tyrannosaurus
Velociraptor


In [None]:
import torch
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

# Channel statistics used to normalise inputs (ImageNet convention).
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD  = [0.229, 0.224, 0.225]

# No explicit RGB conversion step: ImageFolder's default PIL loader already
# returns RGB images, and a `transforms.Lambda(lambda ...)` cannot be pickled,
# which breaks DataLoader workers (num_workers > 0) on platforms that start
# worker processes with "spawn" (Windows, macOS).
train_tfms = transforms.Compose([
    transforms.RandomResizedCrop(224, antialias=True),       # random area + aspect-ratio crop, resized to 224x224
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.ToTensor(),
    transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD),
])

test_tfms = transforms.Compose([
    transforms.Resize(256, antialias=True),                  # resize short side to 256
    transforms.CenterCrop(224),                              # now 224x224
    transforms.ToTensor(),
    transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD),
])

root = "./dataset/dataset"

# One sub-folder per class under train/ and test/.
train_ds = datasets.ImageFolder(root=f"{root}/train", transform=train_tfms)
test_ds  = datasets.ImageFolder(root=f"{root}/test",  transform=test_tfms)

train_loader = DataLoader(train_ds, batch_size=4, shuffle=True, num_workers=4, pin_memory=True)
test_loader  = DataLoader(test_ds,  batch_size=4, shuffle=False, num_workers=4, pin_memory=True)
