# Annotate a (web)dataset with the S2 cell of each example

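Build a tree of S2 cells adaptively: any cell holding more than `CELL_MAX_EXAMPLES` examples is split into its child cells, and cells left with fewer than `CELL_MIN_EXAMPLES` examples are discarded.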

Apply the S2 cell tree to a webdataset, labeling each example with the smallest S2 cell in the tree that contains the lat/lng of the example.

In [1]:
%load_ext autoreload
%autoreload 2
import collections
from pathlib import Path

import pandas
import webdataset
import s2sphere

import tqdm

from mlutil import label_mapping

# S2 levels spanned by the cell tree: every example is first bucketed into its
# MIN_CELL_LEVEL ancestor and is never resolved finer than MAX_CELL_LEVEL.
MIN_CELL_LEVEL = 6
MAX_CELL_LEVEL = 23

# A cell holding more than CELL_MAX_EXAMPLES entries gets split into its
# children; a cell holding fewer than CELL_MIN_EXAMPLES entries is dropped.
CELL_MAX_EXAMPLES = 10_000
CELL_MIN_EXAMPLES = 100

# Both training sources live under ~/datasets as pickled DataFrames.
_DATASETS_ROOT = Path.home() / "datasets"
IM2GPS_2007_PATH = _DATASETS_ROOT / "im2gps" / "outputs" / "im2gps_2007.pkl"
SV_WORLD1_PATH = _DATASETS_ROOT / "img2loc" / "outputs" / "world1" / "s3_parameterized.pkl"

# The training set is the row-wise concatenation of the two frames, in this order.
train_df = pandas.concat(
    [pandas.read_pickle(pickle_path) for pickle_path in (IM2GPS_2007_PATH, SV_WORLD1_PATH)]
)

## Analyse the train dataset to build the S2 cell tree

In [2]:
# Tree storage, keyed by S2 level and then by cell id:
#   {level: {cell_id: {leaf_cell_id, ...}}}
# Each leaf is the cell of one example, truncated to MAX_CELL_LEVEL.
# NOTE(review): leaves are kept in a set, so examples that fall into the same
# MAX_CELL_LEVEL cell are counted once -- confirm this is intended.
cells_by_level = collections.defaultdict(lambda: collections.defaultdict(set))

# Seed the tree: file every training example under its MIN_CELL_LEVEL ancestor.
for example in tqdm.tqdm(train_df.itertuples(), total=len(train_df.index)):
    leaf_cell = s2sphere.CellId.from_lat_lng(
        s2sphere.LatLng.from_degrees(example.latitude, example.longitude)
    ).parent(MAX_CELL_LEVEL)

    root_cell = leaf_cell.parent(MIN_CELL_LEVEL)
    cells_by_level[root_cell.level()][root_cell.id()].add(leaf_cell.id())

100%|██████████| 795013/795013 [00:13<00:00, 57119.30it/s]


In [5]:
# Walk the tree top-down and split every over-full cell into its children.
# The walk starts at MIN_CELL_LEVEL because that is the level the tree is
# seeded at; no coarser cell can hold any example.
for level in range(MIN_CELL_LEVEL, MAX_CELL_LEVEL):
    next_level = level + 1
    # Children are inserted into cells_by_level[next_level], which is a
    # different dict from the one being iterated, so the iteration stays valid.
    for cellset in tqdm.tqdm(cells_by_level[level].values(), desc=f"Level {level}"):
        if len(cellset) <= CELL_MAX_EXAMPLES:
            continue

        child_cells = cells_by_level[next_level]
        for leaf_id in cellset:
            child_cell = s2sphere.CellId(leaf_id).parent(next_level)
            child_cells[child_cell.id()].add(leaf_id)
        # The leaves now live in the children. Emptying the parent (rather than
        # deleting it) keeps the dict size stable during iteration; the empty
        # set is filtered out below by the CELL_MIN_EXAMPLES threshold.
        cellset.clear()

# Flatten the per-level dicts and keep only the cells with enough examples.
# NOTE(review): when a split yields a child below CELL_MIN_EXAMPLES, that
# child's examples end up in no surviving cell, because the parent was emptied.
candidate_celldicts = {
    cell_id: cellset
    for celldict in cells_by_level.values()
    for cell_id, cellset in celldict.items()
    if len(cellset) >= CELL_MIN_EXAMPLES
}

print(f"Number of cells = {len(candidate_celldicts)}")
cell_counts = pandas.DataFrame(
    [(cell_id, len(cellset)) for cell_id, cellset in candidate_celldicts.items()],
    columns=["cell_id", "count"],
)
print(cell_counts.describe())

# Tokens are emitted in ascending cell-id order so the label list is deterministic.
candidate_tokens = [s2sphere.CellId(cell_id).to_token() for cell_id in sorted(candidate_celldicts)]

Level 1: 0it [00:00, ?it/s]
Level 2: 0it [00:00, ?it/s]

Level 2: 0it [00:00, ?it/s]
Level 3: 0it [00:00, ?it/s]
Level 4: 0it [00:00, ?it/s]
Level 5: 0it [00:00, ?it/s]
Level 6: 100%|██████████| 4746/4746 [00:00<00:00, 1500087.93it/s]
Level 7: 100%|██████████| 4/4 [00:00<00:00, 104857.60it/s]
Level 8: 0it [00:00, ?it/s]
Level 9: 0it [00:00, ?it/s]
Level 10: 0it [00:00, ?it/s]
Level 11: 0it [00:00, ?it/s]
Level 12: 0it [00:00, ?it/s]
Level 13: 0it [00:00, ?it/s]
Level 14: 0it [00:00, ?it/s]
Level 15: 0it [00:00, ?it/s]
Level 16: 0it [00:00, ?it/s]
Level 17: 0it [00:00, ?it/s]
Level 18: 0it [00:00, ?it/s]
Level 19: 0it [00:00, ?it/s]
Level 20: 0it [00:00, ?it/s]
Level 21: 0it [00:00, ?it/s]
Level 22: 0it [00:00, ?it/s]

Number of cells = 773
            cell_id        count
count  7.730000e+02   773.000000
mean   5.946330e+18   491.102199
std    3.103764e+18   775.948650
min    4.306567e+16   100.000000
25%    3.843541e+18   156.000000
50%    5.182799e+18   253.000000
75%    9.716235e+18   506.000000
max    1.223769e+19  9391.000000





In [4]:
# Emit every token on a single comma-separated line for visualisation
# (paste the output into s2.inair.space).
print(*candidate_tokens, sep=",")

0099,009b,00a7,00b7,00b9,0717,07ab,07c7,0c41,0c47,0c61,0c6b,0d0d,0d11,0d13,0d17,0d19,0d1b,0d1f,0d23,0d25,0d2f,0d31,0d37,0d3b,0d3d,0d3f,0d41,0d43,0d45,0d47,0d49,0d4f,0d51,0d55,0d57,0d59,0d61,0d63,0d65,0d6b,0d6d,0d71,0d73,0da7,0daf,0ec1,0fdf,1039,103b,1297,12a1,12a3,12a5,12a7,12a9,12ab,12ad,12af,12b1,12b3,12b5,12b7,12bb,12c9,12cb,12cd,12cf,12d3,12d5,12dd,12e7,1311,1313,1315,1325,1329,132b,132d,132f,1331,1335,1339,133b,1345,1347,134b,134d,1351,1359,135b,135d,135f,1361,1449,1459,1495,1499,149b,149d,149f,14a1,14a3,14a9,14ad,14bb,14bf,14c3,14cb,14d3,14df,14e7,14f5,1501,1503,1519,151d,151f,152b,1559,177d,1829,182d,182f,1835,185d,19cd,19dd,1dcd,1e95,1ec3,1ef7,2a33,2d95,2dbf,2dcd,2dd1,2dd3,2dd7,2de5,2e41,2e43,2e65,2e69,2e6b,2e6f,2e71,2e77,2e79,2e7b,2fd5,3031,304b,3051,3055,30d7,30db,30e3,3103,3109,3111,3119,311d,3135,3143,3175,31cd,31db,3287,32f9,338f,3391,3397,33a1,33a9,33ab,33bd,3401,3403,3405,3443,345d,3469,346f,34e5,353f,3541,3543,3545,354f,3551,3553,3555,3557,355b,356b,357b,357d,35b3,35f1,

In [None]:
# Save the cell mapping to disk
TRAIN_CELLS_PATH = Path.home() / "datasets" / "im2gps" / "outputs" / "s2cell_2007"
# Create the output directory up front so the write below also works on a
# fresh machine (presumably LabelMapping.to_csv does not create missing
# directories itself -- harmless if it does).
TRAIN_CELLS_PATH.mkdir(parents=True, exist_ok=True)

# The label mapping is built from the S2 cell tokens of the surviving cells.
mapping = label_mapping.LabelMapping(candidate_tokens)
mapping.to_csv(TRAIN_CELLS_PATH / "cells.csv")

## Annotate the target dataset using the S2 cell tree built above

In [5]:
# The cell set is rebuilt from the CSV saved on disk rather than from the
# in-memory candidate_celldicts, presumably so this cell can run without
# re-building the tree first.
mapping = label_mapping.LabelMapping.read_csv(TRAIN_CELLS_PATH / "cells.csv")
cell_set = {s2sphere.CellId.from_token(token).id() for token in mapping.name_to_label.keys()}
print(len(cell_set))
# NOTE(review): TARGET_DF_PATH and DATASET_OUT_PATH are not defined in any cell
# visible in this notebook; they must be set before this cell is run.
target_df = pandas.read_pickle(TARGET_DF_PATH)

s2cell_labels = []
for row in tqdm.tqdm(target_df.itertuples(), total=len(target_df.index)):
    latlng = s2sphere.LatLng.from_degrees(row.latitude, row.longitude)
    s2_cell_id = s2sphere.CellId.from_lat_lng(latlng)

    # Climb from the leaf cell towards the root until a cell of the tree is
    # found. The climb stops at MIN_CELL_LEVEL: the tree has no coarser cells,
    # and stopping there avoids calling parent() on a level-0 (face) cell.
    while s2_cell_id.level() > MIN_CELL_LEVEL and s2_cell_id.id() not in cell_set:
        s2_cell_id = s2_cell_id.parent()

    if s2_cell_id.id() in cell_set:
        token = s2_cell_id.to_token()
        # Sanity check: the token must round-trip back to a known label.
        assert token in mapping.name_to_label
        s2cell_labels.append(token)
    else:
        # This example can't be labeled: none of its ancestors is in the tree
        s2cell_labels.append(None)

# Surface how many rows were left without a label instead of hiding them.
print(f"{s2cell_labels.count(None)} of {len(s2cell_labels)} examples could not be labeled")

target_df["s2cell"] = s2cell_labels
target_df.to_pickle(DATASET_OUT_PATH / "s2_annotated.pkl")

1776


100%|██████████| 2997/2997 [00:00<00:00, 10372.90it/s]
