In [1]:
from datasets import load_dataset
from transformers import CLIPProcessor, CLIPModel
import numpy as np
import evaluate
import torch
import huggingface_hub
from tqdm import tqdm
from sklearn.metrics import accuracy_score
import os


In [2]:
# Pick the compute device and report what PyTorch can see.
CUDA_AVAILABLE = torch.cuda.is_available()
print(f"CUDA={CUDA_AVAILABLE}")
device = "cuda" if CUDA_AVAILABLE else "cpu"
print(f"count={torch.cuda.device_count()}")
# current_device()/get_device_name() raise when no CUDA device is present, so only
# query them when CUDA is available; otherwise the cell would crash on CPU-only hosts.
if CUDA_AVAILABLE:
    print(f"current={torch.cuda.get_device_name(torch.cuda.current_device())}")
else:
    print("current=cpu")

CUDA=True
count=1
current=NVIDIA GeForce RTX 4070


In [20]:
# Hub repo id and split for each evaluation set, loaded in this order.
# Note: k78k comes from the "test" split; the benchmark sets only ship a "train" split.
ds_k78k, ds_im2gps, ds_im2gps2k, ds_im2gps3k, ds_yfcc4k = [
    load_dataset(repo_id, split=split_name)
    for repo_id, split_name in [
        ("jrheiner/geo_training_v3", "test"),
        ("jrheiner/im2gps", "train"),
        ("jrheiner/im2gps2k", "train"),
        ("jrheiner/im2gps3k", "train"),
        ("jrheiner/yfcc4k", "train"),
    ]
]


Resolving data files:   0%|          | 0/23 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/23 [00:00<?, ?it/s]

Downloading readme:   0%|          | 0.00/469 [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/280M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/289M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/4536 [00:00<?, ? examples/s]

In [21]:
def decode_class_labels(dataset, columns=("country", "continent")):
    """Return a copy of `dataset` whose ClassLabel `columns` hold label names instead of ids.

    Columns that are not ClassLabel features are left unchanged. This covers
    columns that were already decoded by a previous run of this cell, so a
    re-run is a no-op instead of an AttributeError.

    Args:
        dataset: a Hugging Face `datasets.Dataset`.
        columns: names of the columns to decode.

    Returns:
        A new dataset in which each decoded column keeps its name but contains
        strings. Decoded columns move to the end of the column list, as with
        the previous remove/rename approach.
    """
    # Only ClassLabel-like features expose int2str; string columns are skipped.
    label_features = {
        col: dataset.features[col]
        for col in columns
        if hasattr(dataset.features[col], "int2str")
    }
    if not label_features:
        return dataset

    def to_label_names(batch):
        # Batched map: batch[col] is a list of ids; int2str maps a list to a list of names.
        return {
            f"{col}_str": feature.int2str(batch[col])
            for col, feature in label_features.items()
        }

    decoded = dataset.map(to_label_names, batched=True)
    decoded = decoded.remove_columns(list(label_features))
    return decoded.rename_columns({f"{col}_str": col for col in label_features})


ds_im2gps = decode_class_labels(ds_im2gps)
ds_im2gps2k = decode_class_labels(ds_im2gps2k)


In [22]:
# Inspect the features and row count of the k78k test split.
ds_k78k

Dataset({
    features: ['image', 'image_id', 'longitude', 'latitude', 'country', 'continent', 'source'],
    num_rows: 7020
})

In [23]:
# Inspect im2gps after the label-decoding cell turned country/continent into names.
ds_im2gps

Dataset({
    features: ['image', 'im2gps_category', 'latitude', 'longitude', 'country', 'continent'],
    num_rows: 237
})

In [24]:
# Inspect im2gps2k after the label-decoding cell turned country/continent into names.
ds_im2gps2k

Dataset({
    features: ['image', 'latitude', 'longitude', 'country', 'continent'],
    num_rows: 2000
})

In [25]:
# Inspect the features and row count of im2gps3k (not passed through the label-decoding cell).
ds_im2gps3k

Dataset({
    features: ['image', 'latitude', 'longitude', 'country', 'continent'],
    num_rows: 2997
})

In [26]:
# Inspect the features and row count of yfcc4k (not passed through the label-decoding cell).
ds_yfcc4k

Dataset({
    features: ['image', 'id', 'longitude', 'latitude', 'country', 'continent'],
    num_rows: 4536
})

In [27]:
# Root directory for exported predictions; one sub-directory per model.
EVAL_EXPORT_BASE_DIR = "evals"
# Number of images per forward pass; tune to available GPU memory.
BATCH_SIZE = 8

# Every loaded evaluation set. Choose which ones to run via EVAL_DATASET_NAMES
# instead of commenting dict entries in and out.
ALL_EVAL_DATASETS = {
    "k78k": ds_k78k,
    "im2gps": ds_im2gps,
    "im2gps2k": ds_im2gps2k,
    "im2gps3k": ds_im2gps3k,
    "yfcc4k": ds_yfcc4k,
}
EVAL_DATASET_NAMES = ["im2gps3k", "yfcc4k"]
EVAL_DATASETS = {name: ALL_EVAL_DATASETS[name] for name in EVAL_DATASET_NAMES}

# Dataset columns whose values are used as zero-shot class labels.
TARGET_LABELS = ["continent", "country"]

# Processor (tokenizer + image preprocessing) hub ids, shared by all checkpoints of a backbone.
PROCESSOR_B16 = "openai/clip-vit-base-patch16"
PROCESSOR_L14_336 = "openai/clip-vit-large-patch14-336"

# Model name (also the export sub-directory) -> checkpoint path and processor.
# NOTE(review): the "continent"/"combinedlabels"/"osv5m" name parts presumably describe the
# fine-tuning setup; the "*-zeroshot" entries load the hub checkpoint itself — confirm naming.
EVAL_MODELS = {
    "CLIP-ViT-B-16-continent": {
        "model_path": "trainings/clip-ft-1e-6/clip-vit-base-patch16-continent-ft-best_LR_1e-6",
        "processor": PROCESSOR_B16,
    },
    "CLIP-ViT-B-16-combinedlabels": {
        "model_path": "trainings/clip-fit-base-combinedlabels/clip-vit-base-patch16-combinedlabels-ft-e3",
        "processor": PROCESSOR_B16,
    },
    "CLIP-ViT-B-16-continent-osv5m": {
        "model_path": "trainings/clip-ft-base-osv5m-continent/clip-vit-base-patch16-continent-ft-osv5m-best",
        "processor": PROCESSOR_B16,
    },
    "CLIP-ViT-B-16-combinedlabels-osv5m": {
        "model_path": "trainings/clip-ft-base-osv5m-combinedlabels/clip-vit-base-patch16-osv5m-combinedlabels-ft-e3_FINAL",
        "processor": PROCESSOR_B16,
    },
    "CLIP-ViT-B-16-continent-zeroshot": {
        "model_path": "openai/clip-vit-base-patch16",
        "processor": PROCESSOR_B16,
    },
    "CLIP-ViT-L-14-336-continent": {
        "model_path": "trainings/clip-fit-large-continent/clip-vit-large-patch14-336-continent-ft-best",
        "processor": PROCESSOR_L14_336,
    },
    "CLIP-ViT-L-14-336-combinedlabels": {
        "model_path": "trainings/clip-ft-large-combinedlabels/clip-vit-large-patch14-336-combinedlabels-ft-e3",
        "processor": PROCESSOR_L14_336,
    },
    "CLIP-ViT-L-14-336-continent-osv5m": {
        "model_path": "trainings/clip-ft-large-osv5m-continent/clip-large-e1",
        "processor": PROCESSOR_L14_336,
    },
    "CLIP-ViT-L-14-336-combinedlabels-osv5m": {
        "model_path": "trainings/clip-ft-large-osv5m-combinedlabels/clip-vit-large-patch14-336-osv5m-combinedlabels-ft-e3-BEST",
        "processor": PROCESSOR_L14_336,
    },
    "CLIP-ViT-L-14-336-continent-zeroshot": {
        "model_path": "openai/clip-vit-large-patch14-336",
        "processor": PROCESSOR_L14_336,
    },
}

In [28]:
def encode_label_prompts(model, processor, labels):
    """Embed one "A photo from {label}." prompt per candidate label.

    Returns L2-normalised text embeddings, one row per entry of `labels` (same order).
    The prompts do not depend on the images, so this runs once per (dataset, target)
    instead of once per image batch.
    """
    text_prompts = [f"A photo from {geo}." for geo in labels]
    text_inputs = processor(text=text_prompts, return_tensors="pt", padding=True).to(device)
    with torch.no_grad():
        text_embeds = model.get_text_features(**text_inputs)
    return text_embeds / text_embeds.norm(p=2, dim=-1, keepdim=True)


def predict_labels(model, processor, dataset, target, desc):
    """Zero-shot classify every image of `dataset` against the values of column `target`.

    The candidate classes are the unique values of `target` found in `dataset` itself.
    NOTE(review): this restricts predictions to labels present in the evaluated set —
    confirm that is the intended protocol.

    Returns:
        (pred_labels, gt_labels): predicted and ground-truth label values, in dataset order.
    """
    labels = dataset.unique(target)
    text_embeds = encode_label_prompts(model, processor, labels)

    pred_labels, gt_labels = [], []
    n_correct = 0
    # Ceil division: int(n / B) + 1 overcounts by one when n is a multiple of B.
    n_batches = (dataset.num_rows + BATCH_SIZE - 1) // BATCH_SIZE
    batches = dataset.to_iterable_dataset().iter(batch_size=BATCH_SIZE)
    pbar = tqdm(batches, total=n_batches, desc=desc)
    for batch in pbar:
        image_inputs = processor(images=batch["image"], return_tensors="pt").to(device)
        with torch.no_grad():
            image_embeds = model.get_image_features(**image_inputs)
            image_embeds = image_embeds / image_embeds.norm(p=2, dim=-1, keepdim=True)
            # CLIP's logits_per_image are a positive scale times these cosine similarities;
            # neither that scale nor a softmax changes the argmax.
            batch_pred_idx = (image_embeds @ text_embeds.T).argmax(dim=-1).tolist()

        batch_pred = [labels[i] for i in batch_pred_idx]
        batch_gt = batch[target]
        pred_labels.extend(batch_pred)
        gt_labels.extend(batch_gt)
        # Running count keeps the live accuracy O(batch) instead of rescoring all predictions.
        n_correct += sum(pred == gt for pred, gt in zip(batch_pred, batch_gt))
        pbar.set_postfix_str(f"Accuracy: {n_correct / len(gt_labels):.4f}")
    return pred_labels, gt_labels


for model_name, model_cfg in EVAL_MODELS.items():
    print(model_name)
    torch.cuda.empty_cache()
    model = CLIPModel.from_pretrained(model_cfg["model_path"]).to(device)
    processor = CLIPProcessor.from_pretrained(model_cfg["processor"])

    eval_export_dir = os.path.join(EVAL_EXPORT_BASE_DIR, model_name)
    os.makedirs(eval_export_dir, exist_ok=True)

    for ds_name, dataset in EVAL_DATASETS.items():
        for target in TARGET_LABELS:
            pred_labels, _ = predict_labels(
                model, processor, dataset, target, desc=f"{ds_name}-{target}"
            )
            # Saved as "<dataset>-<target>.npy" holding the predicted label values.
            np.save(os.path.join(eval_export_dir, f"{ds_name}-{target}"), pred_labels)


CLIP-ViT-B-16-continent


im2gps3k-continent: 100%|██████████| 375/375 [00:39<00:00,  9.42it/s, Accuracy: 0.6563]
im2gps3k-country: 100%|██████████| 375/375 [00:39<00:00,  9.43it/s, Accuracy: 0.3530]
yfcc4k-continent: 100%|█████████▉| 567/568 [00:35<00:00, 16.10it/s, Accuracy: 0.6133]
yfcc4k-country: 100%|█████████▉| 567/568 [00:41<00:00, 13.58it/s, Accuracy: 0.2848]


CLIP-ViT-B-16-combinedlabels


im2gps3k-continent: 100%|██████████| 375/375 [00:38<00:00,  9.77it/s, Accuracy: 0.6373]
im2gps3k-country: 100%|██████████| 375/375 [00:39<00:00,  9.50it/s, Accuracy: 0.4064]
yfcc4k-continent: 100%|█████████▉| 567/568 [00:34<00:00, 16.47it/s, Accuracy: 0.5763]
yfcc4k-country: 100%|█████████▉| 567/568 [00:41<00:00, 13.74it/s, Accuracy: 0.3234]


CLIP-ViT-B-16-continent-osv5m


im2gps3k-continent: 100%|██████████| 375/375 [00:38<00:00,  9.83it/s, Accuracy: 0.6720]
im2gps3k-country: 100%|██████████| 375/375 [00:39<00:00,  9.51it/s, Accuracy: 0.3493]
yfcc4k-continent: 100%|█████████▉| 567/568 [00:34<00:00, 16.50it/s, Accuracy: 0.6080]
yfcc4k-country: 100%|█████████▉| 567/568 [00:41<00:00, 13.77it/s, Accuracy: 0.3228]


CLIP-ViT-B-16-combinedlabels-osv5m


im2gps3k-continent: 100%|██████████| 375/375 [00:38<00:00,  9.84it/s, Accuracy: 0.6400]
im2gps3k-country: 100%|██████████| 375/375 [00:38<00:00,  9.63it/s, Accuracy: 0.3901]
yfcc4k-continent: 100%|█████████▉| 567/568 [00:34<00:00, 16.45it/s, Accuracy: 0.5578]
yfcc4k-country: 100%|█████████▉| 567/568 [00:41<00:00, 13.60it/s, Accuracy: 0.3020]


CLIP-ViT-B-16-continent-zeroshot


im2gps3k-continent: 100%|██████████| 375/375 [00:39<00:00,  9.59it/s, Accuracy: 0.5465]
im2gps3k-country: 100%|██████████| 375/375 [00:39<00:00,  9.38it/s, Accuracy: 0.4558]
yfcc4k-continent: 100%|█████████▉| 567/568 [00:34<00:00, 16.24it/s, Accuracy: 0.4147]
yfcc4k-country: 100%|█████████▉| 567/568 [00:41<00:00, 13.63it/s, Accuracy: 0.3810]


CLIP-ViT-L-14-336-continent


im2gps3k-continent: 100%|██████████| 375/375 [02:10<00:00,  2.88it/s, Accuracy: 0.7928]
im2gps3k-country: 100%|██████████| 375/375 [02:17<00:00,  2.72it/s, Accuracy: 0.5045]
yfcc4k-continent: 100%|█████████▉| 567/568 [02:58<00:00,  3.17it/s, Accuracy: 0.6825]
yfcc4k-country: 100%|█████████▉| 567/568 [03:13<00:00,  2.93it/s, Accuracy: 0.4231]


CLIP-ViT-L-14-336-combinedlabels


im2gps3k-continent: 100%|██████████| 375/375 [02:13<00:00,  2.81it/s, Accuracy: 0.7908]
im2gps3k-country: 100%|██████████| 375/375 [02:21<00:00,  2.66it/s, Accuracy: 0.5759]
yfcc4k-continent: 100%|█████████▉| 567/568 [02:58<00:00,  3.18it/s, Accuracy: 0.6049]
yfcc4k-country: 100%|█████████▉| 567/568 [03:15<00:00,  2.90it/s, Accuracy: 0.4888]


CLIP-ViT-L-14-336-continent-osv5m


im2gps3k-continent: 100%|██████████| 375/375 [02:13<00:00,  2.81it/s, Accuracy: 0.8098]
im2gps3k-country: 100%|██████████| 375/375 [02:21<00:00,  2.64it/s, Accuracy: 0.5472]
yfcc4k-continent: 100%|█████████▉| 567/568 [03:02<00:00,  3.11it/s, Accuracy: 0.7189]
yfcc4k-country: 100%|█████████▉| 567/568 [03:13<00:00,  2.93it/s, Accuracy: 0.4563]


CLIP-ViT-L-14-336-combinedlabels-osv5m


im2gps3k-continent: 100%|██████████| 375/375 [02:15<00:00,  2.77it/s, Accuracy: 0.7998]
im2gps3k-country: 100%|██████████| 375/375 [02:23<00:00,  2.62it/s, Accuracy: 0.5712]
yfcc4k-continent: 100%|█████████▉| 567/568 [03:02<00:00,  3.11it/s, Accuracy: 0.6510]
yfcc4k-country: 100%|█████████▉| 567/568 [03:10<00:00,  2.97it/s, Accuracy: 0.4755]


CLIP-ViT-L-14-336-continent-zeroshot


im2gps3k-continent: 100%|██████████| 375/375 [02:10<00:00,  2.87it/s, Accuracy: 0.6880]
im2gps3k-country: 100%|██████████| 375/375 [02:17<00:00,  2.73it/s, Accuracy: 0.5202]
yfcc4k-continent: 100%|█████████▉| 567/568 [02:56<00:00,  3.21it/s, Accuracy: 0.5331]
yfcc4k-country: 100%|█████████▉| 567/568 [03:10<00:00,  2.98it/s, Accuracy: 0.3836]
