In [22]:
import pandas as pd
import plotly.graph_objects as go
from collections import defaultdict
from s2sphere import CellId, LatLng, Cell

In [23]:
# http://s2geometry.io/about/overview
# http://s2geometry.io/devguide/s2cell_hierarchy.html
def adaptive_partition(df: pd.DataFrame, t1: int = 10_000, t2: int = 50, max_level: int = 18):
    all_cells = defaultdict(list)
    
    # Loop through all lat,lon values
    for _, row in df.iterrows():
        lat, lon = row['lat'], row['lon']
        latlng = LatLng.from_degrees(lat, lon)
        cell_id = CellId.from_lat_lng(latlng)
        
        # Add the lat,lon to be cells at each hierarchy
        for level in range(max_level + 1):
            parent = cell_id.parent(level)
            all_cells[parent.id()].append((lat, lon))
            
    final_cells = set()
    processed = set()
    
    def recurse(cell_id_val: CellId, depth: int):
        """Recursive function to calculate and slice cells of the grid"""
        # Cell already processed
        if cell_id_val in processed:
            return
        processed.add(cell_id_val)
        
        points_list = all_cells.get(cell_id_val, [])
        # No points exist in the cell
        if len(points_list) == 0:
            return
            
        # Not big enough to split or done splitting
        if len(points_list) <= t1 or depth >= max_level:
            # Only keep cells that contain enough points
            if len(points_list) >= t2:
                final_cells.add(cell_id_val)
            return
        
        # Too many points, recurse through each split of the cell
        cell_id_obj = CellId(cell_id_val)
        for i in range(4):
            child = cell_id_obj.child(i)
            recurse(child.id(), depth + 1)
    
    # Iterate through each level and recurve through the cells
    for level in range(max_level + 1):
        level_cells = [cid for cid in all_cells.keys() if CellId(cid).level() == level]
        for cell_id_val in level_cells:
            if cell_id_val not in processed:
                recurse(cell_id_val, level)
    
    return final_cells

def get_cell_vertices(cell_id_val: CellId):
    """Grabs all 4 vertices of a cell"""
    cell_id_obj = CellId(cell_id_val)
    cell = Cell(cell_id_obj)
    vertices = []
    for i in range(4):
        vertex = cell.get_vertex(i)
        latlng = LatLng.from_point(vertex)
        vertices.append((latlng.lat().degrees, latlng.lng().degrees))
    return vertices

In [24]:
sampled_df = pd.read_csv('imgs/sampled.csv', index_col=0)
points_df = pd.read_csv('points.csv')
points_df = points_df.rename(columns={'id': 'orig_id'})
city_df = pd.merge(sampled_df, points_df, on='orig_id')
print("Data shape:", city_df.shape)
display(city_df.head())

Data shape: (2482, 20)



You are merging on int and float columns where the float values are not equal to their int representation.



Unnamed: 0,uuid,source,orig_id,city_lat,city_lon,city,country,count_per_country,iso3,s2_cell_id,label,captured_at,compass_angle,creator_id,is_pano,sequence_id,organization_id,city_id,lat,lon
0,67e0b6e7-1407-4622-94e7-f217ffc18a3c,Mapillary,381162100000000.0,38.9047,-77.0163,Washington,United States,10000,USA,9925933578724573184,0,1566759633000,285.2,102667878645138,True,6p5mkhzd59221lmjtip7yk,818974700000000.0,1840006060,38.904126,-77.027485
1,528d4572-9482-4090-abd0-bea9c9d21135,Mapillary,159777400000000.0,38.9047,-77.0163,Washington,United States,10000,USA,9925933578724573184,0,1568211689500,89.7,101980082048553,False,xcjga0j635lrh552q2te4y,818974700000000.0,1840006060,38.900281,-77.032979
2,b5c0d547-b909-43ba-950c-584c18180fe9,Mapillary,465616400000000.0,38.9047,-77.0163,Washington,United States,10000,USA,9925933578724573184,0,1564584179000,269.6,102667878645138,True,46vxjl4fgot12f6ur92mhp,818974700000000.0,1840006060,38.902639,-77.029663
3,ed2ec1ef-b57e-48e1-bfd7-9fe321311e6a,Mapillary,733205000000000.0,38.9047,-77.0163,Washington,United States,10000,USA,9925933578724573184,0,1603570812000,178.7,102833118765299,True,lUpXgnsO6TBeGidf7IEAMH,250193700000000.0,1840006060,38.900886,-77.024036
4,14bd8029-ec44-4477-9253-1c584d57d4c2,Mapillary,2048604000000000.0,38.9047,-77.0163,Washington,United States,10000,USA,9925933578724573184,0,1639746731000,180.624806,107669525342328,False,7UTlciuyaGYpVXAK4ZJxL5,,1840006060,38.897642,-77.027201


In [25]:
# t1=5_000, t2=50, max_level=30 for 1,000,000 looks pretty good
# Level 0: 6 cells
# Level 1: 24 cells
# ...
# Total cells = 6 * 4^level
max_level = 25
final_cells = adaptive_partition(city_df, t1=100, t2=5, max_level=max_level)

In [26]:
def latlon_to_cellid(lat: float, lon: float):
    """Find which final cell this lat/lon belongs to"""
    latlng = LatLng.from_degrees(lat, lon)
    cell_id = CellId.from_lat_lng(latlng)
    
    # Traverse from max_level down to find the corresponding cell
    # Each lat,lon only belongs to the highest level cell
    for level in range(max_level, -1, -1):
        parent = cell_id.parent(level)
        parent_id = parent.id()
        
        if parent_id in final_cells:
            return parent_id
    
    return None

# Add the cell id corresponding to each latitude and longitude
city_df['s2_cell_id'] = city_df.apply(
    lambda row: latlon_to_cellid(row['lat'], row['lon']), 
    axis=1
)

# Remove null cell_ids
city_df.dropna(subset=['s2_cell_id'], inplace=True)
cell_ids = [int(id) for id in set(city_df['s2_cell_id'].unique())]
print(f"Number of unique cells: {len(cell_ids)}")

Number of unique cells: 488


In [27]:
def visualize_s2_cells(final_cells: set[CellId]):
    fig = go.Figure()
    
    for cell_id_val in final_cells:
        vertices = get_cell_vertices(cell_id_val)
        lats = [v[0] for v in vertices] + [vertices[0][0]]
        lons = [v[1] for v in vertices] + [vertices[0][1]]
        
        fig.add_trace(go.Scattergeo(
            lon=lons,
            lat=lats,
            mode='lines',
            line=dict(width=1, color='blue'),
            showlegend=False
        ))
    
    fig.update_layout(
        title=f'Adaptive S2 Cell Partitioning - {len(final_cells)} Cells',
        geo=dict(
            projection_type='natural earth',
            showland=True,
            showcountries=True,
        ),
        height=720,
        width=1280,
    )
    return fig

fig = visualize_s2_cells(cell_ids)
fig.show()

In [28]:
# Create the labels as a sequence from 0 to num_unique_cells
replacement_dict = dict()
for i, cell_id in enumerate(cell_ids):
    replacement_dict[cell_id] = i
    
city_df['label'] = city_df['s2_cell_id'].replace(replacement_dict).astype(int)

print(f"Number of unique cells: {len(city_df['s2_cell_id'].unique())}")
print(city_df['label'].unique())

Number of unique cells: 488
[  6   7   4   5   3   8   9  10  14  22  42  43  44  45  46 466 255 446
 245 429 214 357  63 453 332  93 425 463 316 134  40  41  68  34 458 372
 467 238  99  75 283  20 388 389 327 459 210  92  21 280 420 367 350 312
 168 351  78  48 436 135 281  23 437  65 157 120 456 226 163 284  39   2
 171 248 352 234  47 136 137  15 219  55 138 333  17 353  94 208  56 215
 368 139 470 204 418 117 471 303 403 335 475 304  79 442 172 140 118  57
 249 227 298 141 404 348 310 341  80 299 190  16 142 419 365 345 358  58
  18 329 112  86 174 256  30 305 447 211 457  12 191 390 320 101 216 464
  13 110 192 282 290 197 391 392 217 260 448 111 369  25 426 479 431 218
 483 148 401 209 189  76 469 449 123 336 400 233 432 373 212 439 160 330
 285 206 235 362  51 393 182 272 397 257 402 225  81 250 224  19  96 121
  90 444 428 412  52 147 213 279  95 188 451 321 405 301 143 106 394 440
 395 347 481 366 115 116 408 251 167 480 164 374 476 445  66 276 252 322
 149 313 207 193 468 42

In [29]:
# Saves class labels back to the sampled csv
city_df.to_csv('imgs/sampled.csv')