In [9]:
import os
from collections import defaultdict

import pandas as pd
import plotly.graph_objects as go
from s2sphere import CellId, LatLng, Cell

In [None]:
# http://s2geometry.io/about/overview
# http://s2geometry.io/devguide/s2cell_hierarchy.html
def adaptive_partition(df: pd.DataFrame, t1: int = 10_000, t2: int = 50, max_level: int = 18) -> set[int]:
    """Select S2 cells (as integer ids) for an adaptive geographic partition.

    Every point in ``df`` is counted in its enclosing S2 cell at each level
    ``0..max_level``. A cell is kept when it holds at least ``t2`` points AND
    either holds at most ``t1`` points or already sits at ``max_level``.

    NOTE(review): each level is considered independently (a cell with <= t1
    points does not stop its descendants from being considered), so a kept
    cell and some of its kept descendants can both be in the result; the set
    is not a disjoint partition. Points are later assigned to their finest kept
    ancestor. If a strict top-down split was intended (stop descending once a
    cell has <= t1 points), only level-0 cells should seed the traversal.

    Args:
        df: frame with ``city_lat`` and ``city_lon`` columns, in degrees.
        t1: cells with more than ``t1`` points are only kept at ``max_level``.
        t2: cells with fewer than ``t2`` points are dropped.
        max_level: finest S2 level considered (level 0 is the six cube faces).

    Returns:
        Set of integer S2 cell ids (``CellId.id()`` values).
    """
    # Only the number of points per cell is ever used, so count instead of
    # storing every (lat, lon) tuple once per level.
    counts = defaultdict(int)
    levels = range(max_level + 1)
    for lat, lon in zip(df['city_lat'], df['city_lon']):
        leaf = CellId.from_lat_lng(LatLng.from_degrees(lat, lon))
        for level in levels:
            counts[leaf.parent(level).id()] += 1

    # Keep/drop depends only on a cell's own count and level, so a single pass
    # over the counted cells selects exactly what a level-by-level recursive
    # traversal over all of them would.
    final_cells = set()
    for cid, n_points in counts.items():
        if n_points < t2:
            continue  # too sparse to be a class of its own
        if n_points <= t1 or CellId(cid).level() >= max_level:
            final_cells.add(cid)
    return final_cells

def get_cell_vertices(cell_id_val) -> list[tuple[float, float]]:
    """Return the four corner vertices of an S2 cell as (lat, lng) pairs in degrees.

    Args:
        cell_id_val: the cell, either as its integer id (``CellId.id()``, which
            is what the rest of this notebook passes around) or as a ``CellId``.

    Returns:
        Four ``(lat, lng)`` tuples in the order ``Cell.get_vertex(0..3)`` yields
        them. The ring is not closed (the first vertex is not repeated).
    """
    # Callers pass integer ids; only wrap values that are not already a CellId.
    cell_id_obj = cell_id_val if isinstance(cell_id_val, CellId) else CellId(cell_id_val)
    cell = Cell(cell_id_obj)
    corners = (LatLng.from_point(cell.get_vertex(k)) for k in range(4))
    return [(corner.lat().degrees, corner.lng().degrees) for corner in corners]

In [11]:
# Load the sampled dataset; the first CSV column is the saved index.
df = pd.read_csv('imgs/sampled.csv', index_col=0)

# Everything below is driven by the city coordinates; fail fast if they are missing.
missing_cols = {'city_lat', 'city_lon'} - set(df.columns)
assert not missing_cols, f"sampled.csv is missing columns: {sorted(missing_cols)}"

print("Data shape:", df.shape)
display(df.head())

Data shape: (9999, 10)


Unnamed: 0,uuid,source,orig_id,city_lat,city_lon,country,count_per_country,iso3,s2_cell_id,label
0,f8f0e783-3978-4d00-9a3a-6dac4f80d140,Mapillary,876347396733383,35.7,139.7167,Japan,770,JPN,6935543426150563840,3
1,12fbfe8d-44fc-43f0-a237-870012dbc9ad,Mapillary,334319938328909,38.9047,-77.0163,United States,1421,USA,9925933578724573184,4
2,2db989bc-349d-4c98-9847-66ddc403f2cb,Mapillary,132885805478331,25.0478,121.5319,Taiwan,206,TWN,3765009288481734656,1
3,fd6d9b0b-bdaf-4625-8aa8-b62cee406b88,Mapillary,809001326698493,6.7833,-58.1667,Guyana,2,GUY,10160120759347838976,5
4,392f8489-3ae8-46b4-bce5-3f92f5248d67,Mapillary,128697239221256,37.9842,23.7281,Greece,128,GRC,1495195076287004672,0


In [12]:
# Tuning notes:
#   t1=5_000, t2=50, max_level=30 looked good for ~1,000,000 points.
# S2 hierarchy sizes: level 0 holds 6 cells (the cube faces), level 1 holds 24, ...
# i.e. a single level L holds 6 * 4**L cells.
max_level = 3
final_cells = adaptive_partition(
    df,
    t1=5_000,
    t2=50,
    max_level=max_level,
)

In [13]:
def latlon_to_cellid(lat: float, lon: float, cells: "set[int] | None" = None, level_cap: "int | None" = None):
    """Return the id of the finest selected S2 cell containing (lat, lon), or None.

    Args:
        lat: latitude in degrees.
        lon: longitude in degrees.
        cells: selected integer cell ids to search. Defaults to the notebook-level
            ``final_cells``, read at call time.
        level_cap: finest level to try first. Defaults to the notebook-level
            ``max_level``, read at call time.

    Returns:
        The integer id of the deepest ancestor (level ``level_cap`` down to 0) of
        the point's leaf cell that is in ``cells``; None if no level matches.
    """
    # Explicit arguments let callers avoid the hidden dependence on notebook globals.
    if cells is None:
        cells = final_cells
    if level_cap is None:
        level_cap = max_level

    leaf = CellId.from_lat_lng(LatLng.from_degrees(lat, lon))
    # Finest level first: selected cells can be nested, and a point should go to
    # the deepest one that contains it.
    for level in range(level_cap, -1, -1):
        candidate = leaf.parent(level).id()
        if candidate in cells:
            return candidate
    return None

# Assign each point to its finest selected cell (None when no selected cell contains it).
# Collect the ids as Python ints in an object column: S2 ids are unsigned 64-bit,
# and letting pandas infer a dtype from a mix of ints and None may coerce the
# column to float64, which cannot represent fine-level (e.g. level-30) ids exactly.
assigned_ids = [
    latlon_to_cellid(lat, lon)
    for lat, lon in zip(df['city_lat'], df['city_lon'])
]
df['s2_cell_id'] = pd.Series(assigned_ids, index=df.index, dtype=object)

# Drop points outside every selected cell, then store the ids as exact uint64.
df = df.dropna(subset=['s2_cell_id']).copy()
df['s2_cell_id'] = df['s2_cell_id'].astype('uint64')

# Sorted so the label numbering built from this list is canonical
# (S2 id order) rather than dependent on set iteration order.
cell_ids = sorted(int(cid) for cid in df['s2_cell_id'].unique())
print(f"Number of unique cells: {len(cell_ids)}")

Number of unique cells: 78


In [14]:
def visualize_s2_cells(final_cells) -> go.Figure:
    """Plot the outlines of the given S2 cells on a world map.

    Args:
        final_cells: sized collection (e.g. list or set) of integer S2 cell ids;
            each one is passed to ``get_cell_vertices``.

    Returns:
        A plotly ``Figure`` whose title reports ``len(final_cells)``.
    """
    # Put every outline into one pair of coordinate lists separated by None, so
    # plotly builds a single trace instead of one trace per cell (figures with
    # thousands of traces are slow to build and render).
    lats, lons = [], []
    for cell_id_val in final_cells:
        corners = get_cell_vertices(cell_id_val)
        closed_ring = corners + corners[:1]  # repeat the first vertex to close the outline
        lats += [lat for lat, _ in closed_ring] + [None]
        lons += [lon for _, lon in closed_ring] + [None]

    fig = go.Figure()
    if lats:
        fig.add_trace(go.Scattergeo(
            lon=lons,
            lat=lats,
            mode='lines',
            line=dict(width=1, color='blue'),
            connectgaps=False,  # the None entries break the line between cells
            showlegend=False,
        ))

    fig.update_layout(
        title=f'Adaptive S2 Cell Partitioning - {len(final_cells)} Cells',
        geo=dict(
            projection_type='natural earth',
            showland=True,
            showcountries=True,
        ),
        height=720,
        width=1280,
    )
    return fig


fig = visualize_s2_cells(cell_ids)
fig.show()

In [15]:
# Map each S2 cell id to a dense class label 0..len(cell_ids)-1, following cell_ids order.
replacement_dict = {cell_id: label for label, cell_id in enumerate(cell_ids)}

# Look each id up via int() so the dict's Python-int keys match whatever numeric
# dtype the column has. An id without a label becomes missing and trips the assert,
# instead of a raw 64-bit id silently ending up in the label column.
labels = df['s2_cell_id'].map(lambda cid: replacement_dict.get(int(cid)))
assert labels.notna().all(), "some s2_cell_id values have no label"
df['label'] = labels.astype(int)

print(f"Number of unique cells: {df['s2_cell_id'].nunique()}")
print(df['label'].unique())

Number of unique cells: 78
[ 3  4  1  5  0  2  6  7  8  9 10 11 12 13 21 14 15 34 22 35 16 36 23 37
 24 25 38 39 40 26 62 63 64 17 27 32 33 28 18 31 30 19 65 52 49 53 66 20
 54 67 41 55 68 42 69 70 50 71 72 56 73 57 58 74 43 75 44 76 59 51 45 60
 29 46 77 47 48 61]


In [16]:
# Save the recomputed cell ids and class labels back to the sampled CSV.
# NOTE: this overwrites the same file the notebook loads at the top, so rows
# dropped above (points outside every selected cell) are gone on the next run.
# Write to a temporary file first and swap it in, so a failed or interrupted
# write cannot leave the only copy of the dataset truncated.
sampled_path = 'imgs/sampled.csv'
tmp_path = sampled_path + '.tmp'
df.to_csv(tmp_path)
os.replace(tmp_path, sampled_path)