# Annotate a (web)dataset with the S2 cell of each example

Adaptively build a tree of S2 cells so that no cell holds too many or too few examples.

Apply the S2 cell tree to a webdataset, labeling each example with the smallest S2 cell in the tree that contains the lat/lng of the example.

In [1]:
%load_ext autoreload
%autoreload 2
import copy
import collections
from pathlib import Path

import pandas
import webdataset
import s2sphere

import tqdm

from mlutil import label_mapping

# S2 level at which the training examples are first bucketed.
MIN_CELL_LEVEL = 5
# Upper bound on the levels visited when over-full cells are split.
MAX_CELL_LEVEL = 23

# Pickled DataFrames that are concatenated into the training set.
IM2GPS_2007_PATH = Path.home().joinpath(
    "datasets", "im2gps", "outputs", "im2gps_2007.pkl"
)
SV_WORLD1_PATH = Path.home().joinpath(
    "datasets", "img2loc", "outputs", "world1", "s3_parameterized.pkl"
)
IM2GPS_2023_PATH = Path.home().joinpath(
    "datasets", "im2gps", "outputs", "im2gps_2023.pkl"
)

# Only the first 20000 rows of the world1 frame take part in the analysis.
train_df = pandas.concat(
    [
        pandas.read_pickle(IM2GPS_2007_PATH),
        pandas.read_pickle(SV_WORLD1_PATH).head(20000),
        pandas.read_pickle(IM2GPS_2023_PATH),
    ]
)

## Analyse the train dataset to build the S2 cell tree

In [2]:
# Format: {level: {cell_id: {leaf_cell_id, ...}}}
# Both layers create their entries on first access.
cells_by_level = collections.defaultdict(
    lambda: collections.defaultdict(set)
)

# Seed the tree: file the leaf cell of every training example under its
# ancestor at MIN_CELL_LEVEL.
for row in tqdm.tqdm(train_df.itertuples(), total=len(train_df.index)):
    latlng = s2sphere.LatLng.from_degrees(row.latitude, row.longitude)
    s2_cell_id = s2sphere.CellId.from_lat_lng(latlng)
    parent_cell = s2_cell_id.parent(MIN_CELL_LEVEL)

    leaf_ids = cells_by_level[parent_cell.level()][parent_cell.id()]
    leaf_ids.add(s2_cell_id.id())

  0%|          | 0/1668402 [00:00<?, ?it/s]

100%|██████████| 1668402/1668402 [00:25<00:00, 64630.74it/s]


In [3]:
# A cell holding more than this many leaf ids is split into its children.
CELL_MAX_EXAMPLES = 5000
# A cell holding fewer than this many leaf ids is dropped from the result.
CELL_MIN_EXAMPLES = 100

def split_tree_leaves_only(
    input_cells_by_level,
    max_examples=None,
    min_examples=None,
    max_level=None,
):
    """Split over-full cells and keep only the cells that still hold leaf ids.

    A cell holding more than ``max_examples`` leaf ids hands all of them down
    to its children one level deeper and is then emptied, so every leaf id
    lives in exactly one cell. Cells left with fewer than ``min_examples``
    leaf ids (which includes every emptied parent) are discarded. A summary
    of the surviving cells is printed.

    Args:
        input_cells_by_level: ``{level: {cell_id: set_of_leaf_cell_ids}}``.
            It is deep-copied and never modified. Both layers must create
            missing entries on access (nested ``defaultdict``), because
            levels and child cells are looked up before they exist.
        max_examples: Largest number of leaf ids a cell may hold without
            being split. Defaults to ``CELL_MAX_EXAMPLES``.
        min_examples: Smallest number of leaf ids a cell needs to be kept.
            Defaults to ``CELL_MIN_EXAMPLES``.
        max_level: Deepest level that cells may be split into. Defaults to
            ``MAX_CELL_LEVEL``.

    Returns:
        The S2 tokens of the surviving cells, ordered by ascending cell id.
    """
    # Resolve defaults at call time so edits to the notebook constants are
    # picked up without redefining this function.
    if max_examples is None:
        max_examples = CELL_MAX_EXAMPLES
    if min_examples is None:
        min_examples = CELL_MIN_EXAMPLES
    if max_level is None:
        max_level = MAX_CELL_LEVEL

    cells_by_level = copy.deepcopy(input_cells_by_level)

    # Walk the tree top-down and push over-full cells one level deeper
    for level in range(1, max_level):
        next_level = level + 1
        level_cellsets = cells_by_level[level].values()
        for cellset in tqdm.tqdm(level_cellsets, desc=f"Level {level}"):
            if len(cellset) <= max_examples:
                continue

            children = cells_by_level[next_level]
            for leaf_id in cellset:
                child_id = s2sphere.CellId(leaf_id).parent(next_level).id()
                children[child_id].add(leaf_id)
            # The parent keeps nothing; the size filter below removes it.
            cellset.clear()

    # Flatten the levels and remove cells with too few leaf ids
    candidate_celldicts = {}
    for celldict in cells_by_level.values():
        for cell_id, cellset in celldict.items():
            if len(cellset) >= min_examples:
                candidate_celldicts[cell_id] = cellset

    print(f"Number of cells = {len(candidate_celldicts)}")
    cell_counts = pandas.DataFrame(
        [(cell_id, len(cellset)) for cell_id, cellset in candidate_celldicts.items()],
        columns=["cell_id", "count"],
    )
    print(cell_counts.describe())

    return [
        s2sphere.CellId(cell_id).to_token()
        for cell_id in sorted(candidate_celldicts)
    ]

def split_tree_keep_parents(
    input_cells_by_level,
    max_examples=None,
    min_examples=None,
    max_level=None,
):
    """Split over-full cells but keep the parents alongside their children.

    A cell holding more than ``max_examples`` leaf ids copies all of them
    into its children one level deeper while keeping its own set intact.
    Afterwards every cell, at every level, with fewer than ``min_examples``
    leaf ids is removed. A summary of the surviving cells is printed.

    Because parents keep their leaf ids, the printed per-cell counts overlap:
    a leaf id is counted in a split cell and again in each surviving
    descendant that contains it.

    Args:
        input_cells_by_level: ``{level: {cell_id: set_of_leaf_cell_ids}}``.
            It is deep-copied and never modified. Both layers must create
            missing entries on access (nested ``defaultdict``), because
            levels and child cells are looked up before they exist.
        max_examples: Largest number of leaf ids a cell may hold without
            being split. Defaults to ``CELL_MAX_EXAMPLES``.
        min_examples: Smallest number of leaf ids a cell needs to be kept.
            Defaults to ``CELL_MIN_EXAMPLES``.
        max_level: Deepest level that cells may be split into. Defaults to
            ``MAX_CELL_LEVEL``.

    Returns:
        The S2 tokens of the surviving cells, ordered by ascending cell id.
    """
    # Resolve defaults at call time so edits to the notebook constants are
    # picked up without redefining this function.
    if max_examples is None:
        max_examples = CELL_MAX_EXAMPLES
    if min_examples is None:
        min_examples = CELL_MIN_EXAMPLES
    if max_level is None:
        max_level = MAX_CELL_LEVEL

    cells_by_level = copy.deepcopy(input_cells_by_level)

    # Walk the tree top-down and copy over-full cells one level deeper,
    # leaving the parent's own set untouched
    for level in range(1, max_level):
        next_level = level + 1
        level_cellsets = cells_by_level[level].values()
        for cellset in tqdm.tqdm(level_cellsets, desc=f"Level {level}"):
            if len(cellset) <= max_examples:
                continue

            children = cells_by_level[next_level]
            for leaf_id in cellset:
                child_id = s2sphere.CellId(leaf_id).parent(next_level).id()
                children[child_id].add(leaf_id)

    # Starting from the bottom, remove cells with too few leaf ids. Every
    # level present is visited, so no level escapes the size filter.
    for level in sorted(cells_by_level, reverse=True):
        celldict = cells_by_level[level]
        # Collect first, delete afterwards: the dict must not change size
        # while it is being iterated.
        sparse_ids = [
            cell_id
            for cell_id, cellset in tqdm.tqdm(celldict.items(), desc=f"Level {level}")
            if len(cellset) < min_examples
        ]
        for cell_id in sparse_ids:
            del celldict[cell_id]

    # Flatten without filtering
    candidate_celldicts = {}
    for celldict in cells_by_level.values():
        candidate_celldicts.update(celldict)

    print(f"Number of cells = {len(candidate_celldicts)}")
    cell_counts = pandas.DataFrame(
        [(cell_id, len(cellset)) for cell_id, cellset in candidate_celldicts.items()],
        columns=["cell_id", "count"],
    )
    print(cell_counts.describe())

    return [
        s2sphere.CellId(cell_id).to_token()
        for cell_id in sorted(candidate_celldicts)
    ]


# Pick one of the two splitting strategies. The commented-out call is the
# alternative that empties split cells instead of keeping them.
#candidate_tokens = split_tree_leaves_only(cells_by_level)
candidate_tokens = split_tree_keep_parents(cells_by_level)

Level 1: 0it [00:00, ?it/s]
Level 2: 0it [00:00, ?it/s]
Level 3: 0it [00:00, ?it/s]
Level 4: 0it [00:00, ?it/s]
Level 5: 100%|██████████| 1995/1995 [00:00<00:00, 3855.69it/s]
Level 6: 100%|██████████| 132/132 [00:00<00:00, 433.00it/s]
Level 7: 100%|██████████| 86/86 [00:00<00:00, 394.75it/s]
Level 8: 100%|██████████| 52/52 [00:00<00:00, 360.09it/s]
Level 9: 100%|██████████| 35/35 [00:00<00:00, 283.71it/s]
Level 10: 100%|██████████| 39/39 [00:00<00:00, 476.81it/s]
Level 11: 100%|██████████| 32/32 [00:00<00:00, 3256.29it/s]
Level 12: 100%|██████████| 4/4 [00:00<00:00, 118987.35it/s]
Level 13: 0it [00:00, ?it/s]
Level 14: 0it [00:00, ?it/s]
Level 15: 0it [00:00, ?it/s]
Level 16: 0it [00:00, ?it/s]
Level 17: 0it [00:00, ?it/s]
Level 18: 0it [00:00, ?it/s]
Level 19: 0it [00:00, ?it/s]
Level 20: 0it [00:00, ?it/s]
Level 21: 0it [00:00, ?it/s]
Level 22: 0it [00:00, ?it/s]
Level 23: 0it [00:00, ?it/s]
Level 22: 0it [00:00, ?it/s]
Level 21: 0it [00:00, ?it/s]
Level 20: 0it [00:00, ?it/s]
Level 

Number of cells = 930
            cell_id         count
count  9.300000e+02    930.000000
mean   6.013823e+18   1875.522581
std    3.161695e+18   3495.953480
min    4.391010e+16    100.000000
25%    4.073506e+18    242.250000
50%    5.220657e+18    551.500000
75%    9.669791e+18   1888.750000
max    1.374836e+19  41297.000000


In [45]:
# Print tokens for viz (paste into s2.inair.space)
# All tokens go on a single comma-separated line.
print(",".join(candidate_tokens))

009c,00a4,0714,07ac,094c,0b44,0c3c,0c44,0c4c,0c64,0c6c,0d04,0d0c,0d14,0d19,0d1b,0d1c,0d1f,0d24,0d2c,0d34,0d3c,0d41,0d424,0d43,0d434,0d44,0d45,0d47,0d4c,0d54,0d5c,0d64,0d6c,0d74,0d7c,0d94,0d9c,0da4,0dac,0db4,0dbc,0e84,0e94,0ec4,0fdc,103c,1254,1294,129c,12a1,12a3,12a4,12a44,12a49,12a4a24,12a4a2c,12a4a3,12a4a34,12a4a3c,12a4a4,12a4b,12a4bc,12a4c,12a4f,12a5,12a54,12a5c,12a7,12ac,12b4,12bc,12c9,12cb,12cc,12cd,12cf,12d4,12dc,12e4,12fc,1304,130c,1314,131c,1324,1329,132b,132c,132d,132f,1334,133c,1344,134c,1354,135c,1364,13ac,1434,144c,1454,145c,1494,149c,14a4,14ac,14b4,14bc,14c4,14cc,14d4,14dc,14e4,14f4,1504,1514,151c,1524,152c,1534,1554,155c,15a4,15c4,1764,1774,177c,1814,1824,182c,1834,183c,1844,185c,18dc,194c,1954,195c,1964,19c4,19cc,19dc,1a4c,1b8c,1b94,1b9c,1ba4,1bf4,1c0c,1c14,1c3c,1c6c,1c74,1dcc,1dd4,1e64,1e7c,1e94,1ebc,1ec4,1ee4,1eec,1ef4,1f64,21dc,21e4,21f4,2214,2a34,2dd4,2ddc,2e44,2e6c,2e7c,302c,3034,304c,3054,30c4,30cc,30d4,30dc,30e4,30fc,3104,310c,3114,311c,3124,312c,3134,313c,3144,314

In [4]:
# Persist the selected cell tokens as a label mapping CSV.
# NOTE(review): the file name hard-codes 930, the cell count printed by the
# run above; confirm it still matches if the thresholds are changed.
cell_mapping_path = Path.home().joinpath(
    "datasets", "im2gps", "outputs", "s2cell_930_ml.csv"
)

mapping = label_mapping.LabelMapping(candidate_tokens)
mapping.to_csv(cell_mapping_path)

## Annotate the target dataset using the above s2 cell tree

In [5]:
# Load the cell tree to annotate with and index it by integer cell id.
# NOTE(review): TRAIN_CELLS_PATH, TARGET_DF_PATH and DATASET_OUT_PATH are not
# defined in the cells shown above; confirm they are set before running.
mapping = label_mapping.LabelMapping.read_csv(TRAIN_CELLS_PATH / "cells.csv")
cell_set = {
    s2sphere.CellId.from_token(token).id()
    for token in mapping.name_to_label.keys()
}
print(len(cell_set))
target_df = pandas.read_pickle(TARGET_DF_PATH)

# Coarsest level present in the loaded tree. The ancestor walk stops there
# rather than at this notebook's MIN_CELL_LEVEL, because the mapping is read
# from disk and may have been built with different settings. The default of
# 0 also guarantees parent() is never taken on a level-0 cell.
coarsest_level = min(
    (s2sphere.CellId(cell_id).level() for cell_id in cell_set),
    default=0,
)

s2cell_labels = []
for row in tqdm.tqdm(target_df.itertuples(), total=len(target_df.index)):
    latlng = s2sphere.LatLng.from_degrees(row.latitude, row.longitude)
    s2_cell_id = s2sphere.CellId.from_lat_lng(latlng)

    # Climb from the leaf cell towards the root; the first ancestor found in
    # the tree is the smallest tree cell containing the example.
    while s2_cell_id.id() not in cell_set and s2_cell_id.level() > coarsest_level:
        s2_cell_id = s2_cell_id.parent()

    if s2_cell_id.id() in cell_set:
        token = s2_cell_id.to_token()
        # Sanity check: the token must round-trip to a known label name.
        assert token in mapping.name_to_label
        s2cell_labels.append(token)
    else:
        # This example can't be labeled
        s2cell_labels.append(None)

target_df["s2cell"] = s2cell_labels
target_df.to_pickle(DATASET_OUT_PATH / "s2_annotated.pkl")

1776


100%|██████████| 2997/2997 [00:00<00:00, 10372.90it/s]
