In [1]:
from datasets import load_dataset
from transformers import ViTForImageClassification, ViTImageProcessor
import numpy as np
import evaluate
import torch
import huggingface_hub
from tqdm import tqdm
from sklearn.metrics import accuracy_score
import os


In [2]:
CUDA_AVAILABLE=torch.cuda.is_available()
print(f"CUDA={CUDA_AVAILABLE}")
device = "cuda" if CUDA_AVAILABLE else "cpu"
print(f"count={torch.cuda.device_count()}")
print(f"current={torch.cuda.get_device_name(torch.cuda.current_device())}")

CUDA=True
count=1
current=NVIDIA GeForce RTX 4070


In [3]:
ds_k78k = load_dataset("jrheiner/geo_training_v3", split="test")
ds_im2gps = load_dataset("jrheiner/im2gps", split="train")
ds_im2gps2k = load_dataset("jrheiner/im2gps2k", split="train")
ds_im2gps3k = load_dataset("jrheiner/im2gps3k", split="train")
ds_yfcc4k = load_dataset("jrheiner/yfcc4k", split="train")


Resolving data files:   0%|          | 0/23 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/23 [00:00<?, ?it/s]

In [4]:
def map_id2label(example):
    example["country_str"] = ds_im2gps.features["country"].int2str(example["country"])
    example["continent_str"] = ds_im2gps.features["continent"].int2str(example["continent"])
    return example

ds_im2gps = ds_im2gps.map(map_id2label, batched=True)
ds_im2gps = ds_im2gps.remove_columns(["country", "continent"])
ds_im2gps = ds_im2gps.rename_column("country_str", "country")
ds_im2gps = ds_im2gps.rename_column("continent_str", "continent")

def map_id2label(example):
    example["country_str"] = ds_im2gps2k.features["country"].int2str(example["country"])
    example["continent_str"] = ds_im2gps2k.features["continent"].int2str(example["continent"])
    return example

ds_im2gps2k = ds_im2gps2k.map(map_id2label, batched=True)
ds_im2gps2k = ds_im2gps2k.remove_columns(["country", "continent"])
ds_im2gps2k = ds_im2gps2k.rename_column("country_str", "country")
ds_im2gps2k = ds_im2gps2k.rename_column("continent_str", "continent")


In [5]:
ds_k78k

Dataset({
    features: ['image', 'image_id', 'longitude', 'latitude', 'country', 'continent', 'source'],
    num_rows: 7020
})

In [6]:
ds_im2gps

Dataset({
    features: ['image', 'im2gps_category', 'latitude', 'longitude', 'country', 'continent'],
    num_rows: 237
})

In [7]:
ds_im2gps2k

Dataset({
    features: ['image', 'latitude', 'longitude', 'country', 'continent'],
    num_rows: 2000
})

In [8]:
ds_im2gps3k

Dataset({
    features: ['image', 'latitude', 'longitude', 'country', 'continent'],
    num_rows: 2997
})

In [9]:
ds_yfcc4k

Dataset({
    features: ['image', 'id', 'longitude', 'latitude', 'country', 'continent'],
    num_rows: 4536
})

In [73]:
EVAL_EXPORT_BASE_DIR = "evals"
BATCH_SIZE=1

EVAL_DATASETS = {
    "k78k": ds_k78k,
    "im2gps": ds_im2gps,
    "im2gps2k": ds_im2gps2k,
    "im2gps3k": ds_im2gps3k,
    "yfcc4k": ds_yfcc4k
}

TARGET_LABELS = ["continent", "country"]

EVAL_MODELS = {
    "ViT-T-16-continent-scratch": {
        "model_path": "models/vit-tiny-16-224-continent-base",
        "processor": "WinKawaks/vit-tiny-patch16-224"
    },
    "ViT-T-16-country-scratch": {
        "model_path": "models/vit-tiny-16-224-country-base",
        "processor": "WinKawaks/vit-tiny-patch16-224"
    },
    "ViT-T-16-continent-finetune": {
        "model_path": "models/vit-tiny-16-224-continent-pretraining",
        "processor": "WinKawaks/vit-tiny-patch16-224"
    },
    "ViT-T-16-country-finetune": {
        "model_path": "models/vit-tiny-16-224-country-pretraining",
        "processor": "WinKawaks/vit-tiny-patch16-224"
    },
    "ViT-B-16-continent-finetune": {
        "model_path": "models/vit-base-16-224-continent-finetune",
        "processor": "google/vit-base-patch16-224"
    },
    "ViT-B-16-country-finetune": {
        "model_path": "models/vit-base-16-224-country-finetune",
        "processor": "google/vit-base-patch16-224"
    }
}

In [74]:
for eval_model in EVAL_MODELS:
    print(eval_model)
    model_path = EVAL_MODELS[eval_model]["model_path"]
    processor_path = EVAL_MODELS[eval_model]["processor"]
    torch.cuda.empty_cache()
    model = ViTForImageClassification.from_pretrained(model_path)
    processor = ViTImageProcessor.from_pretrained(processor_path)
    model = model.to(device)

    eval_export_dir = os.path.join(EVAL_EXPORT_BASE_DIR, eval_model)
    if not os.path.exists(eval_export_dir):
        os.makedirs(eval_export_dir)

    for ds in EVAL_DATASETS.keys():
        i_ds = EVAL_DATASETS[ds].to_iterable_dataset()
        for target in TARGET_LABELS:
            pred_ids = []
            gt_ids = []
            pbar = tqdm(enumerate(i_ds.iter(batch_size=BATCH_SIZE)), total=int(EVAL_DATASETS[ds].num_rows/BATCH_SIZE)+1, desc=f"{ds}-{target}")
            for idx, batch in pbar:
                if np.array(batch["image"]).shape[-1] != 3:
                    inputs = processor(images=np.stack((np.array(ds_yfcc4k[20]["image"]),)*3, axis=-1), return_tensors='pt')
                else:
                    inputs = processor(images=batch["image"], return_tensors='pt')
                batch_gt_ids = batch[target]
                gt_ids.extend(batch_gt_ids)
                inputs = inputs.to(device)
                with torch.no_grad():
                    outputs = model(**inputs)
                    batch_pred_ids = outputs.logits.argmax(axis=-1).to("cpu")

                pred_ids.extend([model.config.id2label[pred.item()] for pred in batch_pred_ids])
                pbar.set_postfix_str(f"Accuracy: {accuracy_score(y_true=gt_ids, y_pred=pred_ids):.4f}")
            np.save(os.path.join(eval_export_dir, f"{ds}-{target}"), pred_ids)


ViT-T-16-continent-scratch


k78k-continent: 100%|█████████▉| 7020/7021 [02:17<00:00, 51.16it/s, Accuracy: 0.5721]
k78k-country: 100%|█████████▉| 7020/7021 [02:35<00:00, 45.28it/s, Accuracy: 0.0000]
im2gps-continent: 100%|█████████▉| 237/238 [00:03<00:00, 59.33it/s, Accuracy: 0.2532]
im2gps-country: 100%|█████████▉| 237/238 [00:03<00:00, 60.53it/s, Accuracy: 0.0000]
im2gps2k-continent: 100%|█████████▉| 2000/2001 [00:33<00:00, 59.59it/s, Accuracy: 0.1970]
im2gps2k-country: 100%|█████████▉| 2000/2001 [00:32<00:00, 61.92it/s, Accuracy: 0.0000]
im2gps3k-continent: 100%|█████████▉| 2997/2998 [00:50<00:00, 59.80it/s, Accuracy: 0.2289]
im2gps3k-country: 100%|█████████▉| 2997/2998 [00:50<00:00, 59.86it/s, Accuracy: 0.0000]
yfcc4k-continent: 100%|█████████▉| 4536/4537 [00:58<00:00, 77.06it/s, Accuracy: 0.1825]
yfcc4k-country: 100%|█████████▉| 4536/4537 [00:57<00:00, 78.64it/s, Accuracy: 0.0000]


ViT-T-16-country-scratch


k78k-continent: 100%|█████████▉| 7020/7021 [02:34<00:00, 45.55it/s, Accuracy: 0.0000]
k78k-country: 100%|█████████▉| 7020/7021 [02:16<00:00, 51.56it/s, Accuracy: 0.4652]
im2gps-continent: 100%|█████████▉| 237/238 [00:02<00:00, 86.09it/s, Accuracy: 0.0000]
im2gps-country: 100%|█████████▉| 237/238 [00:02<00:00, 82.37it/s, Accuracy: 0.0338]
im2gps2k-continent: 100%|█████████▉| 2000/2001 [00:25<00:00, 78.15it/s, Accuracy: 0.0000]
im2gps2k-country: 100%|█████████▉| 2000/2001 [00:26<00:00, 75.04it/s, Accuracy: 0.0695]
im2gps3k-continent: 100%|█████████▉| 2997/2998 [00:39<00:00, 75.82it/s, Accuracy: 0.0000]
im2gps3k-country: 100%|█████████▉| 2997/2998 [00:40<00:00, 73.18it/s, Accuracy: 0.0507]
yfcc4k-continent: 100%|█████████▉| 4536/4537 [00:48<00:00, 94.09it/s, Accuracy: 0.0000] 
yfcc4k-country: 100%|█████████▉| 4536/4537 [00:54<00:00, 82.50it/s, Accuracy: 0.0657] 


ViT-T-16-continent-finetune


k78k-continent: 100%|█████████▉| 7020/7021 [02:11<00:00, 53.22it/s, Accuracy: 0.7976]
k78k-country: 100%|█████████▉| 7020/7021 [02:12<00:00, 52.90it/s, Accuracy: 0.0000]
im2gps-continent: 100%|█████████▉| 237/238 [00:03<00:00, 70.46it/s, Accuracy: 0.2700]
im2gps-country: 100%|█████████▉| 237/238 [00:02<00:00, 83.23it/s, Accuracy: 0.0000]
im2gps2k-continent: 100%|█████████▉| 2000/2001 [00:29<00:00, 66.80it/s, Accuracy: 0.2400]
im2gps2k-country: 100%|█████████▉| 2000/2001 [00:28<00:00, 70.55it/s, Accuracy: 0.0000]
im2gps3k-continent: 100%|█████████▉| 2997/2998 [00:51<00:00, 58.73it/s, Accuracy: 0.2926]
im2gps3k-country: 100%|█████████▉| 2997/2998 [00:50<00:00, 59.76it/s, Accuracy: 0.0000]
yfcc4k-continent: 100%|█████████▉| 4536/4537 [00:54<00:00, 83.54it/s, Accuracy: 0.2240] 
yfcc4k-country: 100%|█████████▉| 4536/4537 [00:53<00:00, 85.26it/s, Accuracy: 0.0000]


ViT-T-16-country-finetune


k78k-continent: 100%|█████████▉| 7020/7021 [02:10<00:00, 53.75it/s, Accuracy: 0.0000]
k78k-country: 100%|█████████▉| 7020/7021 [02:22<00:00, 49.38it/s, Accuracy: 0.6775]
im2gps-continent: 100%|█████████▉| 237/238 [00:03<00:00, 61.11it/s, Accuracy: 0.0000]
im2gps-country: 100%|█████████▉| 237/238 [00:03<00:00, 63.76it/s, Accuracy: 0.0464]
im2gps2k-continent: 100%|█████████▉| 2000/2001 [00:33<00:00, 59.51it/s, Accuracy: 0.0000]
im2gps2k-country: 100%|█████████▉| 2000/2001 [00:31<00:00, 64.13it/s, Accuracy: 0.0735]
im2gps3k-continent: 100%|█████████▉| 2997/2998 [00:47<00:00, 63.66it/s, Accuracy: 0.0000]
im2gps3k-country: 100%|█████████▉| 2997/2998 [00:48<00:00, 61.86it/s, Accuracy: 0.0697]
yfcc4k-continent: 100%|█████████▉| 4536/4537 [00:56<00:00, 80.85it/s, Accuracy: 0.0000]
yfcc4k-country: 100%|█████████▉| 4536/4537 [00:56<00:00, 79.73it/s, Accuracy: 0.0626]


ViT-B-16-continent-finetune


preprocessor_config.json:   0%|          | 0.00/160 [00:00<?, ?B/s]

k78k-continent: 100%|█████████▉| 7020/7021 [02:04<00:00, 56.51it/s, Accuracy: 0.8507]
k78k-country: 100%|█████████▉| 7020/7021 [02:05<00:00, 55.84it/s, Accuracy: 0.0000]
im2gps-continent: 100%|█████████▉| 237/238 [00:03<00:00, 75.89it/s, Accuracy: 0.4304]
im2gps-country: 100%|█████████▉| 237/238 [00:02<00:00, 79.47it/s, Accuracy: 0.0000]
im2gps2k-continent: 100%|█████████▉| 2000/2001 [00:26<00:00, 75.23it/s, Accuracy: 0.3415]
im2gps2k-country: 100%|█████████▉| 2000/2001 [00:27<00:00, 73.11it/s, Accuracy: 0.0000]
im2gps3k-continent: 100%|█████████▉| 2997/2998 [00:42<00:00, 70.73it/s, Accuracy: 0.3917]
im2gps3k-country: 100%|█████████▉| 2997/2998 [00:43<00:00, 68.68it/s, Accuracy: 0.0000]
yfcc4k-continent: 100%|█████████▉| 4536/4537 [00:52<00:00, 86.72it/s, Accuracy: 0.3104]
yfcc4k-country: 100%|█████████▉| 4536/4537 [00:52<00:00, 86.96it/s, Accuracy: 0.0000]


ViT-B-16-country-finetune


k78k-continent: 100%|█████████▉| 7020/7021 [02:07<00:00, 55.12it/s, Accuracy: 0.0000]
k78k-country: 100%|█████████▉| 7020/7021 [02:09<00:00, 54.23it/s, Accuracy: 0.7450]
im2gps-continent: 100%|█████████▉| 237/238 [00:03<00:00, 63.46it/s, Accuracy: 0.0000]
im2gps-country: 100%|█████████▉| 237/238 [00:03<00:00, 74.68it/s, Accuracy: 0.1266]
im2gps2k-continent: 100%|█████████▉| 2000/2001 [00:27<00:00, 73.71it/s, Accuracy: 0.0000]
im2gps2k-country: 100%|█████████▉| 2000/2001 [00:27<00:00, 72.99it/s, Accuracy: 0.1205]
im2gps3k-continent: 100%|█████████▉| 2997/2998 [00:42<00:00, 70.68it/s, Accuracy: 0.0000]
im2gps3k-country: 100%|█████████▉| 2997/2998 [00:42<00:00, 71.11it/s, Accuracy: 0.1158]
yfcc4k-continent: 100%|█████████▉| 4536/4537 [00:51<00:00, 88.93it/s, Accuracy: 0.0000] 
yfcc4k-country: 100%|█████████▉| 4536/4537 [00:51<00:00, 88.89it/s, Accuracy: 0.1254] 
