# im2gps small dataset (for overfitting)

In [1]:
import collections
from pathlib import Path
import json

import pandas
import s2sphere
import webdataset
import tqdm

import mlutil.label_mapping
import mlutil.s2cell_mapping

# Root directory under the user's home that holds all local datasets.
DATASETS = Path.home() / "LocalProjects" / "datasets"

# Read the label mapping from CSV and derive the S2 cell mapping from it.
label_mapping = mlutil.label_mapping.LabelMapping.read_csv(DATASETS / "im2gps/outputs/s2cell_930_ml.csv")
s2cell_mapping = mlutil.s2cell_mapping.S2CellMapping.from_label_mapping(label_mapping)
# Notebook display: how many cell ids the mapping contains.
len(s2cell_mapping.all_cell_ids)

930

In [15]:
# Load the source dataframe and keep only its first 40000 rows as the pool
# that the small subsets are drawn from.
# Previous data source (im2gps 2023), kept for reference:
#all_df = pandas.read_pickle(DATASETS / "im2gps/outputs/im2gps_2023.pkl")
all_df = pandas.read_pickle(DATASETS / "img2loc/outputs/world1/s3_parameterized.pkl")
all_df = all_df.head(40000)
# Notebook display: the available column names.
all_df.columns

Index(['raw_lat', 'raw_lng', 'latitude', 'longitude', 'status', 'pano_id',
       'fov', 'pitch', 'heading', 'img_path'],
      dtype='object')

In [13]:
def select_n_per_label(df, n):
    """Return a subset of ``df`` with at most ``n`` rows per S2 cell token.

    Rows are visited in ``df`` order, so the first ``n`` rows seen for each
    token are the ones kept. Rows for which
    ``s2cell_mapping.lat_lng_to_token`` returns ``None`` are skipped.

    Args:
        df: DataFrame with ``latitude`` and ``longitude`` columns.
        n: Maximum number of rows to keep per S2 cell token.

    Returns:
        The selected rows of ``df``, grouped by token in first-seen order.
        The result is empty when no row maps to a token.
    """
    label_to_index = collections.defaultdict(list)
    for row in df.itertuples():
        s2cell_token = s2cell_mapping.lat_lng_to_token(row.latitude, row.longitude)
        if s2cell_token is None:
            continue
        if len(label_to_index[s2cell_token]) < n:
            label_to_index[s2cell_token].append(row.Index)

    # Index into the frame that was passed in rather than the notebook-global
    # all_df, and flatten the saved labels directly: no concat round trip is
    # needed, and an empty selection yields an empty frame instead of raising.
    selected = [idx for indices in label_to_index.values() for idx in indices]
    return df.loc[selected]

def select_n_per_multilabel(df, n):
    """Return a subset of ``df`` with at most ``n`` rows per multi-label vector.

    Each row's label vector comes from
    ``s2cell_mapping.lat_lng_to_multihot_list`` (presumably a 0/1 multi-hot
    list -- confirm against mlutil). Rows whose vector sums to zero are
    skipped. Rows are visited in ``df`` order, so the first ``n`` rows seen
    for each distinct vector are the ones kept.

    Scanning stops early, printing "Oh no", once more than 10000 distinct
    vectors have been collected; whatever was gathered so far is still
    returned. The number of distinct vectors is printed before returning.

    Args:
        df: DataFrame with ``latitude`` and ``longitude`` columns.
        n: Maximum number of rows to keep per distinct label vector.

    Returns:
        The selected rows of ``df``, grouped by label vector in first-seen
        order.
    """
    tensor_to_index = collections.defaultdict(list)
    for row in tqdm.tqdm(df.itertuples(), total=len(df)):
        label_tensor = s2cell_mapping.lat_lng_to_multihot_list(row.latitude, row.longitude)
        if sum(label_tensor) == 0:
            continue

        # Lists are unhashable; a tuple can be used as the grouping key.
        label_tuple = tuple(label_tensor)
        if len(tensor_to_index[label_tuple]) < n:
            tensor_to_index[label_tuple].append(row.Index)
        # Safety valve against an unexpectedly large number of label combinations.
        if len(tensor_to_index) > 10000:
            print("Oh no")
            break

    print(len(tensor_to_index))
    # Flatten the saved index labels and select them from the frame that was
    # passed in, not from the notebook-global all_df.
    all_indices = [idx for indices in tensor_to_index.values() for idx in indices]
    return df.loc[all_indices]

In [16]:
# Select 1 example of each distinct multi-label combination
one_example = select_n_per_multilabel(all_df, 1)
# Notebook display: preview the selected rows.
one_example

100%|██████████| 40000/40000 [00:02<00:00, 17344.22it/s]

790





Unnamed: 0,raw_lat,raw_lng,latitude,longitude,status,pano_id,fov,pitch,heading,img_path
282448,48.775398,9.173911,48.775396,9.173885,OK,CAoSLEFGMVFpcE1sNzA1RktLd2M2MF9lQmJ1MTV2TzdsVE...,45,0,122,sv_CAoSLEFGMVFpcE1sNzA1RktLd2M2MF9lQmJ1MTV2Tzd...
440540,35.062452,-106.446064,35.062524,-106.446016,OK,wYOMND22d1-zpuIKvWAZBA,45,0,218,sv_wYOMND22d1-zpuIKvWAZBA.jpg
414898,43.142103,-2.968665,43.142191,-2.968652,OK,9aSWf2Ho5IPnMAqzay4ZwA,45,0,280,sv_9aSWf2Ho5IPnMAqzay4ZwA.jpg
363439,36.888929,-83.055294,36.888921,-83.055268,OK,T3foQFtbgjHXo-8UQzI2BA,45,0,93,sv_T3foQFtbgjHXo-8UQzI2BA.jpg
308927,60.265634,6.623198,60.265604,6.622986,OK,NGuN8t8w6p85mU-M0h32wg,45,0,127,sv_NGuN8t8w6p85mU-M0h32wg.jpg
...,...,...,...,...,...,...,...,...,...,...
334538,51.506658,-0.182455,51.506663,-0.182439,OK,mFdFz7dM4qvP4cAypawOkw,45,0,106,sv_mFdFz7dM4qvP4cAypawOkw.jpg
386809,25.721980,119.376174,25.721577,119.376546,OK,CAoSLEFGMVFpcFBwRkczWjNnX0NaR1l2a1lybjNwcmRlWm...,45,0,245,sv_CAoSLEFGMVFpcFBwRkczWjNnX0NaR1l2a1lybjNwcmR...
21026,33.884977,10.106528,33.884819,10.106134,OK,tVObhx_XR5UJ2Qzql5ZbOQ,45,0,75,sv_tVObhx_XR5UJ2Qzql5ZbOQ.jpg
426647,64.892482,-147.692661,64.892773,-147.692707,OK,0le3iWUiV0pkHCNrgKkXuQ,45,0,274,sv_0le3iWUiV0pkHCNrgKkXuQ.jpg


In [18]:
def write_dataset_as_wds(dataset_df, img_base_dir, out_pattern):
    """Write the rows of ``dataset_df`` as WebDataset shards.

    The rows are shuffled first. For each row, the image file at
    ``img_base_dir / row.img_path`` is stored under the ``jpg`` key, and the
    whole row (as produced by ``itertuples``) is stored as UTF-8 JSON under
    the ``json`` key. The sample key is the image file name without its
    extension.

    Args:
        dataset_df: DataFrame with an ``img_path`` column; every row value
            must be JSON-serializable.
        img_base_dir: Directory (``Path``) that ``img_path`` is relative to.
        out_pattern: Shard file name pattern passed to
            ``webdataset.ShardWriter``.

    Raises:
        AssertionError: If an image file does not exist.
    """
    shuffled_df = dataset_df.sample(frac=1)  # shuffle
    row_count = len(shuffled_df.index)

    # encoder=False: the values are already bytes, so the writer is asked not
    # to encode them again.
    with webdataset.ShardWriter(out_pattern, encoder=False) as sink:
        for record in tqdm.tqdm(shuffled_df.itertuples(), total=row_count):
            image_file = img_base_dir / record.img_path
            assert image_file.exists()

            sink.write({
                "__key__": image_file.stem,
                "jpg": image_file.read_bytes(),
                "json": json.dumps(record._asdict()).encode("utf-8"),
            })


# Previous invocation for the im2gps 2023 images, kept for reference:
#write_dataset_as_wds(
#    one_example,
#    DATASETS / "im2gps/outputs/img2023",
#    str(DATASETS / "im2gps_overfit/wds/im2gps_2023_overfit_one_%03d.tar")
#)

# Write the one-example-per-label subset, reading images from the world1
# image directory; "%03d" in the pattern is the shard number.
write_dataset_as_wds(
    one_example,
    DATASETS / "img2loc/outputs/world1/img",
    str(DATASETS / "im2gps_overfit/wds/world1_overfit_one_%03d.tar"),
)

# writing /home/fyhuang/LocalProjects/datasets/im2gps_overfit/wds/world1_overfit_one_000.tar 0 0.0 GB 0


100%|██████████| 790/790 [00:01<00:00, 764.52it/s]


In [None]:
# Select 5 examples of each label
five_example = select_n_per_multilabel(all_df, 5)

# all_df is loaded from the world1 pickle, so the images must be read from the
# world1 image directory (the old im2gps directory would fail the existence
# check), and the shards are named to match the one-example world1 set.
write_dataset_as_wds(
    five_example,
    DATASETS / "img2loc/outputs/world1/img",
    str(DATASETS / "im2gps_overfit/wds/world1_overfit_five_%03d.tar"),
)