# Annotate a (web)dataset with the S2 cell of each example

Build a tree of S2 cells adaptively: a cell that holds too many examples is split into its child cells, and a cell that holds too few examples is dropped.

Apply the S2 cell tree to a webdataset, labeling each example with the smallest S2 cell in the tree that contains the lat/lng of the example.

In [4]:
%load_ext autoreload
%autoreload 2
import collections
from pathlib import Path

import pandas
import webdataset
import s2sphere

import tqdm

from mlutil import label_mapping

# Coarsest S2 level of the tree: every training example is first filed under
# its ancestor cell at this level.
MIN_CELL_LEVEL = 6
# Finest S2 level considered: example positions are truncated to this level.
MAX_CELL_LEVEL = 23

# A cell holding more than CELL_MAX_EXAMPLES entries is split into its children.
CELL_MAX_EXAMPLES = 10000
# A cell holding fewer than CELL_MIN_EXAMPLES entries is left out of the final tree.
CELL_MIN_EXAMPLES = 100

# Pickled DataFrame of the training examples (read below via its latitude and
# longitude columns).
TRAIN_DF_PATH = Path.home().joinpath("datasets", "im2gps", "outputs", "im2gps_2007.pkl")
train_df = pandas.read_pickle(TRAIN_DF_PATH)

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


## Analyse the train dataset to build the S2 cell tree

In [5]:
# cells_by_level maps {level: {cell_id: {fine_cell_id, ...}}}: each inner set
# holds the MAX_CELL_LEVEL cell ids of the training examples that fall inside
# cell_id.
# NOTE(review): because the entries live in a set, examples that share the
# same MAX_CELL_LEVEL cell are counted once -- confirm that this is intended.
cells_by_level = collections.defaultdict(lambda: collections.defaultdict(set))

# Seed the tree: file every example under its MIN_CELL_LEVEL ancestor.
coords = zip(train_df["latitude"], train_df["longitude"])
for lat_deg, lng_deg in tqdm.tqdm(coords, total=len(train_df.index)):
    point = s2sphere.LatLng.from_degrees(lat_deg, lng_deg)
    fine_cell = s2sphere.CellId.from_lat_lng(point).parent(MAX_CELL_LEVEL)

    root_cell = fine_cell.parent(MIN_CELL_LEVEL)
    cells_by_level[root_cell.level()][root_cell.id()].add(fine_cell.id())

100%|██████████| 591441/591441 [00:13<00:00, 44250.32it/s]


In [8]:
# Walk the tree top-down and split every cell that holds too many entries.
# The tree is seeded at MIN_CELL_LEVEL, so coarser levels are never populated
# and are not visited. Cells at MAX_CELL_LEVEL are never split.
for level in range(MIN_CELL_LEVEL, MAX_CELL_LEVEL):
    next_level = level + 1
    for cellset in tqdm.tqdm(cells_by_level[level].values(), desc=f"Level {level}"):
        if len(cellset) <= CELL_MAX_EXAMPLES:
            continue

        # Hand every entry over to the child cell (one level down) that
        # contains it, then empty this cell so it no longer holds anything.
        for child_id in cellset:
            next_level_parent = s2sphere.CellId(child_id).parent(next_level)
            cells_by_level[next_level][next_level_parent.id()].add(child_id)
        cellset.clear()

# Flatten the cells_by_level dict and remove cells with too few examples
candidate_celldicts = {}
for celldict in cells_by_level.values():
    for cell_id, cellset in celldict.items():
        if len(cellset) >= CELL_MIN_EXAMPLES:
            candidate_celldicts[cell_id] = cellset

print(f"Number of cells = {len(candidate_celldicts)}")
# Tokens of the retained cells, ordered by numeric cell id.
candidate_tokens = [s2sphere.CellId(cell_id).to_token() for cell_id in sorted(candidate_celldicts)]

Level 1: 0it [00:00, ?it/s]
Level 2: 0it [00:00, ?it/s]
Level 3: 0it [00:00, ?it/s]
Level 4: 0it [00:00, ?it/s]
Level 5: 0it [00:00, ?it/s]
Level 6:   0%|          | 0/4060 [00:00<?, ?it/s]

Level 6: 100%|██████████| 4060/4060 [00:00<00:00, 1573978.58it/s]
Level 7: 100%|██████████| 4/4 [00:00<00:00, 28149.69it/s]
Level 8: 0it [00:00, ?it/s]
Level 9: 0it [00:00, ?it/s]
Level 10: 0it [00:00, ?it/s]
Level 11: 0it [00:00, ?it/s]
Level 12: 0it [00:00, ?it/s]
Level 13: 0it [00:00, ?it/s]
Level 14: 0it [00:00, ?it/s]
Level 15: 0it [00:00, ?it/s]
Level 16: 0it [00:00, ?it/s]
Level 17: 0it [00:00, ?it/s]
Level 18: 0it [00:00, ?it/s]
Level 19: 0it [00:00, ?it/s]
Level 20: 0it [00:00, ?it/s]
Level 21: 0it [00:00, ?it/s]
Level 22: 0it [00:00, ?it/s]

Number of cells = 418





In [9]:
# Emit the tokens as one comma-separated line for visualisation
# (the output can be pasted into s2.inair.space).
print(*candidate_tokens, sep=",")

0099,009b,0717,0d0d,0d11,0d13,0d19,0d1b,0d1f,0d23,0d25,0d2f,0d37,0d3b,0d41,0d43,0d49,0d51,0d61,0d6d,0d71,0d73,0da7,0daf,0ec1,1297,12a1,12a5,12af,12b5,12bb,12c9,12cd,12d5,1325,132b,132f,1335,133b,134b,134d,135b,1449,1459,1495,1499,149b,149d,14a1,14a3,14a9,14bf,14cb,14d3,14e7,14f5,1501,1503,1519,151d,151f,1559,182d,1835,185d,19cd,19dd,1dcd,1e95,1ef7,2a33,2dd3,2e69,3051,3055,30db,30e3,3103,3109,3111,3135,3175,31cd,31db,3397,3401,3403,3405,3443,345d,3469,346f,357d,35b3,35f1,3693,390d,3919,396d,3975,39eb,3a53,3baf,3be7,3e5f,3f8f,3fbd,3fcf,40a5,40ab,40d5,414b,45df,463d,463f,4641,464f,4653,465f,468d,468f,4693,4697,46b5,46dd,46ef,46fd,4709,470b,470f,4715,4717,471f,4741,4761,4763,4765,476d,4771,4773,4777,4779,477b,477d,477f,4781,4783,4785,4787,4789,478d,478f,4791,4795,4797,4799,479b,479d,479f,47a9,47b1,47b3,47b7,47b9,47bd,47bf,47c1,47c3,47c5,47c7,47d9,47dd,47df,47e7,4843,4845,484f,4859,485b,485d,485f,4861,4863,4865,4867,4869,486b,486d,486f,4871,4873,4875,4877,4879,487b,487d,487f,4885,4887,4889,

In [None]:
# Save the cell mapping to disk
TRAIN_CELLS_PATH = Path.home() / "datasets" / "im2gps" / "outputs" / "s2cell_2007"
# Create the output directory up front so that writing cells.csv does not
# depend on it already existing; this is a no-op when it is already there.
TRAIN_CELLS_PATH.mkdir(parents=True, exist_ok=True)

# Build the label mapping from the cell tokens and write it as cells.csv.
mapping = label_mapping.LabelMapping(candidate_tokens)
mapping.to_csv(TRAIN_CELLS_PATH / "cells.csv")

## Annotate the target dataset using the above s2 cell tree

In [5]:
# The cell tree is read back from cells.csv (instead of reusing the in-memory
# candidate_celldicts), and the stored tokens are converted back to cell ids.
# NOTE(review): TARGET_DF_PATH and DATASET_OUT_PATH are not defined in the
# cells shown above -- they must be set before running this cell.
mapping = label_mapping.LabelMapping.read_csv(TRAIN_CELLS_PATH / "cells.csv")
cell_set = {s2sphere.CellId.from_token(token).id() for token in mapping.name_to_label.keys()}
print(len(cell_set))
target_df = pandas.read_pickle(TARGET_DF_PATH)

# One entry per row of target_df: the token of the enclosing tree cell, or
# None when no cell of the tree contains the example.
s2cell_labels = []
for row in tqdm.tqdm(target_df.itertuples(), total=len(target_df.index)):
    latlng = s2sphere.LatLng.from_degrees(row.latitude, row.longitude)
    s2_cell_id = s2sphere.CellId.from_lat_lng(latlng)

    # Climb from the finest cell towards coarser ancestors until a cell of the
    # tree is found. MIN_CELL_LEVEL is the coarsest level examined, so the
    # climb never goes past it.
    while s2_cell_id.id() not in cell_set and s2_cell_id.level() > MIN_CELL_LEVEL:
        s2_cell_id = s2_cell_id.parent()

    if s2_cell_id.id() in cell_set:
        token = s2_cell_id.to_token()
        # Sanity check: the token must round-trip to a key of the mapping.
        assert token in mapping.name_to_label
        s2cell_labels.append(token)
    else:
        # This example can't be labeled
        s2cell_labels.append(None)

target_df["s2cell"] = s2cell_labels
target_df.to_pickle(DATASET_OUT_PATH / "s2_annotated.pkl")

1776


100%|██████████| 2997/2997 [00:00<00:00, 10372.90it/s]
