# Annotate a (web)dataset with the S2 cell of each example

Create a tree of S2 cells adaptively, making sure each cell doesn't have too many/few examples.

Apply the S2 cell tree to a webdataset, labeling each example with the smallest S2 cell in the tree that contains the lat/lng of the example.

In [4]:
%load_ext autoreload
%autoreload 2
import collections
from pathlib import Path

import pandas
import webdataset
import s2sphere

import tqdm

import label_mapping

TRAIN_DF_PATH = Path.home() / "datasets" / "im2gps" / "outputs" / "im2gps_2007.pkl"
TRAIN_CELLS_PATH = Path.home() / "datasets" / "im2gps" / "outputs" / "s2cell_2007"

TARGET_DF_PATH = Path.home() / "datasets" / "im2gps3ktest" / "im2gps3ktest.pkl"
DATASET_OUT_PATH = Path.home() / "datasets" / "im2gps3ktest"

MIN_CELL_LEVEL = 6
MAX_CELL_LEVEL = 23

CELL_MAX_EXAMPLES = 500
CELL_MIN_EXAMPLES = 25

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


## Analyse the train dataset to build the S2 cell tree

In [38]:
train_df = pandas.read_pickle(TRAIN_DF_PATH)

# Format: {level: {cell_id: (child_cell_id, ...)}}
cell_by_id = {}
cells_by_level = collections.defaultdict(lambda: collections.defaultdict(set))

# Build the cell_by_id index and initialize cells_by_level
for index, row in tqdm.tqdm(train_df.iterrows(), total=len(train_df.index)):
    latlng = s2sphere.LatLng.from_degrees(row["latitude"], row["longitude"])
    s2_cell_id = s2sphere.CellId.from_lat_lng(latlng).parent(MAX_CELL_LEVEL)
    cell_by_id[row["id"]] = s2_cell_id

    parent_cell = s2_cell_id.parent(MIN_CELL_LEVEL)
    cells_by_level[parent_cell.level()][parent_cell.id()].add(s2_cell_id.id())

100%|██████████| 635626/635626 [00:45<00:00, 13883.30it/s]


In [53]:
# Walk the tree and split cells
for level in range(1, MAX_CELL_LEVEL):
    print(f"Splitting at level {level}")
    celldict = cells_by_level[level]
    for cell_id, cellset in tqdm.tqdm(celldict.items()):
        if len(cellset) <= CELL_MAX_EXAMPLES:
            continue

        next_level = level+1
        for child_id in cellset:
            child_cell = s2sphere.CellId(child_id)
            next_level_parent = child_cell.parent(next_level)
            cells_by_level[next_level][next_level_parent.id()].add(child_id)
        cellset.clear()

# Flatten the cells_by_level dict and remove cells with too few examples
candidate_celldicts = {}
for level, celldict in cells_by_level.items():
    for cell_id, cellset in celldict.items():
        if len(cellset) >= CELL_MIN_EXAMPLES:
            candidate_celldicts[cell_id] = cellset

print(f"Number of cells = {len(candidate_celldicts)}")

# Save the cell mapping to disk
tokens = [s2sphere.CellId(cell_id).to_token() for cell_id in sorted(candidate_celldicts.keys())]
mapping = label_mapping.LabelMapping(tokens)
mapping.to_csv(DATASET_OUT_PATH / "cells.csv")

Splitting at level 1


0it [00:00, ?it/s]


Splitting at level 2


0it [00:00, ?it/s]


Splitting at level 3


0it [00:00, ?it/s]


Splitting at level 4


0it [00:00, ?it/s]


Splitting at level 5


0it [00:00, ?it/s]


Splitting at level 6


100%|██████████| 4060/4060 [00:00<00:00, 1421086.06it/s]


Splitting at level 7


100%|██████████| 324/324 [00:00<00:00, 2418068.50it/s]


Splitting at level 8


100%|██████████| 220/220 [00:00<00:00, 930188.39it/s]


Splitting at level 9


100%|██████████| 185/185 [00:00<00:00, 762975.65it/s]


Splitting at level 10


100%|██████████| 186/186 [00:00<00:00, 803440.31it/s]


Splitting at level 11


100%|██████████| 158/158 [00:00<00:00, 607424.41it/s]


Splitting at level 12


100%|██████████| 140/140 [00:00<00:00, 895125.85it/s]


Splitting at level 13


100%|██████████| 108/108 [00:00<00:00, 1143901.09it/s]


Splitting at level 14


100%|██████████| 16/16 [00:00<00:00, 224444.36it/s]


Splitting at level 15


0it [00:00, ?it/s]


Splitting at level 16


0it [00:00, ?it/s]


Splitting at level 17


0it [00:00, ?it/s]


Splitting at level 18


0it [00:00, ?it/s]


Splitting at level 19


0it [00:00, ?it/s]


Splitting at level 20


0it [00:00, ?it/s]


Splitting at level 21


0it [00:00, ?it/s]


Splitting at level 22


0it [00:00, ?it/s]

Number of cells = 1776





In [54]:
# Print tokens for viz (paste into s2.inair.space)
print(",".join(tokens))

0097,0099,009b,009d,00a7,0717,07ab,0b43,0b47,0b5d,0c41,0c47,0c61,0c6b,0d05,0d0b,0d0d,0d11,0d13,0d15,0d17,0d18c,0d1931,0d1933,0d1935,0d195,0d19c,0d1b,0d1f,0d23,0d25,0d2f,0d31,0d37,0d39,0d3b,0d3d,0d3f,0d41,0d4224,0d42284,0d4228c,0d42294,0d422f,0d4234,0d45,0d47,0d49,0d4f,0d51,0d55,0d57,0d59,0d5b,0d5d,0d5f,0d61,0d63,0d6b,0d6d,0d6f,0d71,0d73,0d7b,0d97,0d9f,0da1,0da7,0dad,0daf,0db1,0db3,0dbb,0dbd,0e39,0e3b,0ec1,0ec3,0fdd,0fdf,103b,1061,1257,1297,1299,129f,12a1,12a3,12a44,12a49,12a4a25,12a4a27,12a4a29,12a4a2b,12a4a2d,12a4a2e4,12a4a2ec,12a4a2f4,12a4a2fc,12a4a34,12a4a3c,12a4bc,12a4f,12a54,12a5c,12a7,12a9,12ab,12ad,12af,12b1,12b3,12b5,12b7,12bb,12c9,12cb,12cd,12cf,12d1,12d3,12d5,12d7,12d9,12db,12dd,12e3,12fd,130f,1311,1313,1315,1317,1319,1325,1329,132a3,132a51,132a53,132a55,132a57,132ac,132bc,132d,132ec,132f5,132f603,132f6044,132f604c,132f6054,132f605c,132f607,132f60c,132f614,132f61a4,132f61ac,132f61b4,132f61bc,132f61d,1331,1335,133b4,133bc,1347,134b,134d,134f,1351,1355,1357,1359,135b,135d,135f,

## Annotate the target dataset using the above s2 cell tree

In [5]:
#cell_set = set(candidate_celldicts.keys())
mapping = label_mapping.LabelMapping.read_csv(TRAIN_CELLS_PATH / "cells.csv")
cell_set = set([s2sphere.CellId.from_token(token).id() for token in mapping.name_to_label.keys()])
print(len(cell_set))
target_df = pandas.read_pickle(TARGET_DF_PATH)

s2cell_labels = []
for row in tqdm.tqdm(target_df.itertuples(), total=len(target_df.index)):
    latlng = s2sphere.LatLng.from_degrees(row.latitude, row.longitude)
    s2_cell_id = s2sphere.CellId.from_lat_lng(latlng)

    while s2_cell_id.id() not in cell_set:
        if s2_cell_id.level() < MIN_CELL_LEVEL:
            break
        s2_cell_id = s2_cell_id.parent()

    if s2_cell_id.id() not in cell_set:
        # This example can't be labeled
        s2cell_labels.append(None)
    else:
        s2cell_labels.append(s2_cell_id.to_token())
        assert s2_cell_id.to_token() in mapping.name_to_label

target_df["s2cell"] = s2cell_labels
target_df.to_pickle(DATASET_OUT_PATH / "s2_annotated.pkl")

1776


100%|██████████| 2997/2997 [00:00<00:00, 10372.90it/s]
