# im2gps small dataset (for overfitting)

In [17]:
import collections
from pathlib import Path
import json

import pandas
import s2sphere
import webdataset
import tqdm

import mlutil

# Root directory under which all local datasets for this notebook live.
DATASETS = Path.home().joinpath("LocalProjects", "datasets")

# Read the label mapping CSV, then derive the S2 cell mapping from it.
label_mapping = mlutil.label_mapping.LabelMapping.read_csv(
    DATASETS.joinpath("im2gps", "outputs", "s2cell_2007", "cells.csv")
)
s2cell_mapping = mlutil.s2cell_mapping.S2CellMapping.from_label_mapping(label_mapping)

# Cell output: number of cell ids held by the mapping.
len(s2cell_mapping.all_cell_ids)

1776

In [8]:
# Read the full pickled dataframe; the small subsets are selected from it below.
# NOTE(review): unpickling can execute arbitrary code -- only load a trusted file.
all_df = pandas.read_pickle(
    DATASETS.joinpath("im2gps", "outputs", "im2gps_2007.pkl")
)
all_df.columns

Index(['id', 'owner', 'secret', 'server', 'farm', 'title', 'ispublic',
       'isfriend', 'isfamily', 'dateupload', 'latitude', 'longitude',
       'accuracy', 'context', 'place_id', 'woeid', 'geo_is_public',
       'geo_is_contact', 'geo_is_friend', 'geo_is_family', 'interestingness',
       'tag', 'split'],
      dtype='object')

In [18]:
def select_n_per_label(df, n):
    """Return a subset of ``df`` holding at most ``n`` rows per S2 cell label.

    Each row's ``latitude``/``longitude`` is converted to a cell token with the
    module-level ``s2cell_mapping``; rows for which the mapping returns ``None``
    are dropped. For every token the first ``n`` rows (in ``df`` order) are kept.

    Args:
        df: DataFrame with ``latitude`` and ``longitude`` columns.
        n: Maximum number of rows to keep for each label.

    Returns:
        A DataFrame with the selected rows of ``df``, grouped by label in
        first-seen order. Empty (same columns) when no row maps to a label.
    """
    label_to_index = collections.defaultdict(list)
    # Iterate over the frame that was passed in, not the notebook-global frame.
    for row in df.itertuples():
        s2cell_token = s2cell_mapping.lat_lng_to_token(row.latitude, row.longitude)
        if s2cell_token is None:
            continue
        bucket = label_to_index[s2cell_token]
        if len(bucket) < n:
            bucket.append(row.Index)

    # Flatten the per-label index lists and select once; an empty list yields
    # an empty frame instead of the error pandas.concat([]) would raise.
    selected = [index for indices in label_to_index.values() for index in indices]
    return df.loc[selected]

# Subset with a single example for every label
one_example = select_n_per_label(df=all_df, n=1)

# Subset with up to five examples for every label
five_example = select_n_per_label(df=all_df, n=5)

one_example

Unnamed: 0,id,owner,secret,server,farm,title,ispublic,isfriend,isfamily,dateupload,...,context,place_id,woeid,geo_is_public,geo_is_contact,geo_is_friend,geo_is_family,interestingness,tag,split
0,398492752,17392647@N00,e3152820b2,169,1,Tokyo roads,1,0,0,1172128070,...,0,cLK5.HBQU7vUgaJM,1112372.0,1,0,0,0,0,Tokyo,train
2,397737291,88468856@N00,f7572f28ed,164,1,fall into despair,1,0,0,1172074105,...,0,FRthiQZQU7uKHvmP,1118370.0,1,0,0,0,2,Tokyo,train
3,391903677,58776404@N00,0d983cd099,131,1,[Kichijouji]_070213_085 Enpty Line,1,0,0,1171618333,...,0,2Eh6._NQV7rv6f0AjQ,15015379.0,1,0,0,0,3,Tokyo,train
4,397665565,36516818@N00,c3160578b6,149,1,City View,1,0,0,1172068796,...,0,,26198557.0,1,0,0,0,4,Tokyo,train
5,398556956,32448339@N00,afe2d5b146,127,1,Yoshi at work,1,0,0,1172134776,...,0,tSbdQrlWU7oG9AQ,710281.0,1,0,0,0,5,Tokyo,train
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
590235,400014411,15432472@N00,db3674f6e8,141,1,beinn dearg & cona mheal,1,0,0,1172258873,...,0,iToXzFpQULwFedvtwQ,12602203.0,1,0,0,0,81,scotland,train
590242,391523966,30942843@N00,3ed5f38aab,181,1,IMG_0012,1,0,0,1171580430,...,0,1VBbW35SULs9gA,32109.0,1,0,0,0,88,scotland,train
590273,396348599,57099173@N00,0fa921634e,163,1,Loch lomond,1,0,0,1171972228,...,0,.5sqqr5QU7zXzQ,11691.0,1,0,0,0,119,scotland,train
590319,393180310,42624857@N00,908510034a,136,1,Leaderfoot Bridges,1,0,0,1171739311,...,0,BaKOF0tSUbqOMA,33077.0,1,0,0,0,165,scotland,train


In [19]:
def write_dataset_as_wds(dataset_df, img_base_dir, out_pattern, split=None, seed=None):
    """Write the rows of ``dataset_df`` and their images as WebDataset shards.

    Rows are shuffled, then each row whose image file exists is written as one
    sample containing the raw JPEG bytes and the row serialised as JSON.
    Rows whose image file is missing are skipped; the number skipped is printed.

    Args:
        dataset_df: DataFrame with at least ``id``, ``secret``, ``server``,
            ``owner`` and ``tag`` columns (plus ``split`` when ``split`` is given).
        img_base_dir: ``pathlib.Path`` of the image root; images are looked up
            at ``<img_base_dir>/<tag>/<index // 1000, 5 digits>/<stem>.jpg``.
        out_pattern: Shard filename pattern passed to ``webdataset.ShardWriter``.
        split: If not ``None``, only rows whose ``split`` column equals this
            value are written. Defaults to writing every row.
        seed: Optional ``random_state`` for the shuffle, for reproducible shard
            order. Defaults to ``None`` (unseeded shuffle).
    """
    def row_subdir(row):
        return img_base_dir / row.tag / '{:05d}'.format(row.Index//1000)

    def row_filename_stem(row):
        return f"{row.id}_{row.secret}_{row.server}_{row.owner}"

    def write_wds_row(row, sink):
        """Write one sample to ``sink``; return False if its image is missing."""
        stem = row_filename_stem(row)
        img_path = row_subdir(row) / f"{stem}.jpg"
        if not img_path.exists():
            return False

        wds_object = {
            "__key__": stem,
            "jpg": img_path.read_bytes(),
            # NOTE: json.dumps emits the non-standard token NaN for float NaN
            # values; Python's json module reads it back, strict parsers may not.
            "json": json.dumps(row._asdict()).encode("utf-8"),
        }
        sink.write(wds_object)
        return True

    # Apply the optional split filter up front so the progress total is accurate.
    if split is not None:
        dataset_df = dataset_df[dataset_df["split"] == split]

    shuffled_df = dataset_df.sample(frac=1, random_state=seed) # shuffle

    num_rows = len(shuffled_df.index)
    num_written = 0
    with webdataset.ShardWriter(out_pattern, encoder=False) as sink:
        for row in tqdm.tqdm(shuffled_df.itertuples(), total=num_rows):
            if write_wds_row(row, sink):
                num_written += 1

    # Surface silently dropped rows so an empty or partial shard is noticed.
    num_skipped = num_rows - num_written
    if num_skipped:
        print(f"# skipped {num_skipped} of {num_rows} rows: image file not found")


# Shards for the one-example-per-label subset.
write_dataset_as_wds(
    dataset_df=one_example,
    img_base_dir=DATASETS.joinpath("im2gps", "outputs", "img"),
    out_pattern=str(DATASETS.joinpath("im2gps_overfit", "wds", "im2gps_overfit_one_%03d.tar")),
)

# Shards for the five-examples-per-label subset.
write_dataset_as_wds(
    dataset_df=five_example,
    img_base_dir=DATASETS.joinpath("im2gps", "outputs", "img"),
    out_pattern=str(DATASETS.joinpath("im2gps_overfit", "wds", "im2gps_overfit_five_%03d.tar")),
)

# writing /home/fyhuang/LocalProjects/datasets/im2gps_overfit/wds/im2gps_overfit_one_000.tar 0 0.0 GB 0


100%|██████████| 1776/1776 [00:00<00:00, 2033.74it/s]


# writing /home/fyhuang/LocalProjects/datasets/im2gps_overfit/wds/im2gps_overfit_five_000.tar 0 0.0 GB 0


100%|██████████| 8880/8880 [00:04<00:00, 2130.98it/s]
