In [11]:
import collections
from pathlib import Path

import pandas
import tqdm

from mlutil import label_mapping, s2cell_mapping

s2mapping = s2cell_mapping.S2CellMapping.from_label_mapping(
    label_mapping.LabelMapping.read_csv(Path.home() / "datasets/img2loc/s2cell_930_ml.csv")
)

def annotate_s2_classes(img_root, df):
    # Assign each row a list of classes
    def class_list_str(row):
        one_hot = s2mapping.lat_lng_to_multihot_list(row.latitude, row.longitude)
        class_numbers = [i for i, v in enumerate(one_hot) if v]
        class_numbers.sort()
        return ",".join(str(i) for i in class_numbers)

    # Make sure every img_path is valid (otherwise class index doesn't make sense)
    exists = []
    for img_path in tqdm.tqdm(df.img_path, desc="Checking images"):
        exists.append((img_root / img_path).exists())
    df["img_exists"] = exists
    df = df[df.img_exists].copy()

    s2_classes = []
    for row in tqdm.tqdm(df.itertuples(), total=len(df), desc="Annotating s2_classes"):
        s2_classes.append(class_list_str(row))
    df["s2_classes"] = s2_classes

    return df

def annotate_classindex(df):
    # Shuffle and pick the first N of each class
    df = df.sample(frac=1, ignore_index=True)

    s2_classes_count = collections.Counter() # how many of each classlist have we seen already
    sameclass_index = []
    for row in tqdm.tqdm(df.itertuples(), total=len(df), desc="Class index"):
        sameclass_index.append(s2_classes_count[row.s2_classes])
        s2_classes_count[row.s2_classes] += 1
    df["sameclass_index"] = sameclass_index

    return df

def assign_split(df, n):
    # Pick the first N sameclass_index as the validation set
    split = []
    for row in tqdm.tqdm(df.itertuples()):
        split.append("val" if row.sameclass_index < n else "train")
    df["split"] = split

    return df

In [4]:
# n=8 for im2gps combined v2
df = pandas.read_pickle(Path.home() / "LocalProjects/datasets/im2gps/outputs/clustered/im2gps_1_filtered.pkl")
df = annotate_s2_classes(df)
df = annotate_classindex(Path.home() / "LocalProjects/datasets/im2gps/outputs/img", df)
df.to_pickle(Path.home() / "LocalProjects/datasets/im2gps/outputs/clustered/im2gps_2_classindex.pkl")
shuffled = assign_split(df, 8)
shuffled.to_pickle(Path.home() / "LocalProjects/datasets/im2gps/outputs/clustered/im2gps_3_split.pkl")

(shuffled["split"] == "val").sum(), (shuffled["split"] == "train").sum()

100%|██████████| 773879/773879 [00:53<00:00, 14479.25it/s]
100%|██████████| 773879/773879 [00:06<00:00, 119713.00it/s]
100%|██████████| 773879/773879 [00:02<00:00, 312361.64it/s]
773879it [00:02, 329910.56it/s]


(6970, 766909)

In [13]:
# n=2 for world1 sv
df = pandas.read_pickle(Path.home() / "LocalProjects/datasets/img2loc/outputs/world1/s3_parameterized.pkl")
df = annotate_s2_classes(Path.home() / "LocalProjects/datasets/img2loc/outputs/world1/img", df)
df = annotate_classindex(df)
shuffled = assign_split(df, 2)
shuffled.to_pickle(Path.home() / "LocalProjects/datasets/img2loc/outputs/world1/s4_split.pkl")

(shuffled["split"] == "val").sum(), (shuffled["split"] == "train").sum()

Checking images: 100%|██████████| 203572/203572 [00:01<00:00, 138465.96it/s]
Annotating s2_classes: 100%|██████████| 40033/40033 [00:02<00:00, 15695.86it/s]
Class index: 100%|██████████| 40033/40033 [00:00<00:00, 555711.90it/s]
40033it [00:00, 622340.48it/s]


(1527, 38506)