In [2]:
import time
from functools import reduce
from pathlib import Path

import networkx as nx
import polars as pl

from locus.utils.pl_utils import batch_iter
from tqdm import tqdm

%matplotlib widget

In [3]:
PROJECT_ROOT = Path().cwd().parent
PROCESSED_DATA_DIR = PROJECT_ROOT / "data" / "processed"

In [4]:
df = pl.scan_parquet(PROCESSED_DATA_DIR / "LDoGI/shards/*.parquet")
print(df.head().collect())
c = df.select(pl.len()).collect()["len"][0] # count
c

shape: (5, 4)
┌─────┬────────────┬────────────┬───────────────────────────────────┐
│ id  ┆ latitude   ┆ longitude  ┆ image                             │
│ --- ┆ ---        ┆ ---        ┆ ---                               │
│ i64 ┆ f64        ┆ f64        ┆ binary                            │
╞═════╪════════════╪════════════╪═══════════════════════════════════╡
│ 0   ┆ 41.906     ┆ 12.455     ┆ b"\xff\xd8\xff\xe0\x00\x10JFIF\x… │
│ 1   ┆ 48.211072  ┆ 16.36736   ┆ b"\xff\xd8\xff\xe0\x00\x10JFIF\x… │
│ 2   ┆ 43.942876  ┆ 12.774091  ┆ b"\xff\xd8\xff\xe0\x00\x10JFIF\x… │
│ 3   ┆ 41.339055  ┆ 14.507789  ┆ b"\xff\xd8\xff\xe0\x00\x10JFIF\x… │
│ 4   ┆ -23.210269 ┆ -44.693223 ┆ b"\xff\xd8\xff\xe0\x00\x10JFIF\x… │
└─────┴────────────┴────────────┴───────────────────────────────────┘


3993900

In [5]:
df = pl.scan_parquet(PROCESSED_DATA_DIR / "LDoGI/shards/*.parquet")
df = df.drop("id", "image")
print(df.head().collect())
c = df.select(pl.len()).collect()["len"][0] # count
c

shape: (5, 2)
┌────────────┬────────────┐
│ latitude   ┆ longitude  │
│ ---        ┆ ---        │
│ f64        ┆ f64        │
╞════════════╪════════════╡
│ 41.906     ┆ 12.455     │
│ 48.211072  ┆ 16.36736   │
│ 43.942876  ┆ 12.774091  │
│ 41.339055  ┆ 14.507789  │
│ -23.210269 ┆ -44.693223 │
└────────────┴────────────┘


3993900

In [6]:
G = nx.read_gml(PROCESSED_DATA_DIR / "LDoGI/quadtrees/qt_min50_max5000_df100pct.gml")

In [10]:
G = nx.read_gml(PROCESSED_DATA_DIR / "LDoGI/quadtrees/qt_min10_max1000_df10pct.gml")
active_cells = [node for node in list(G.nodes) if G.nodes[node]["state"] == CellState.ACTIVE]
len(active_cells)


1462

In [8]:
# create enum of cell states
class CellState:
    STOPPED = 0
    EVALUATING = 1
    ACTIVE = 2

In [8]:
active_cells = [node for node in list(G.nodes) if G.nodes[node]["state"] == CellState.ACTIVE]

In [9]:
cell = "132023110331"
print(G[cell])
print(G.nodes[cell])

{}
{'state': 2}


In [10]:
{'132023110': {}, '132023111': {}, '132023112': {}, '132023113': {}}

{'132023110': {}, '132023111': {}, '132023112': {}, '132023113': {}}

In [11]:
len(active_cells)

2857

In [12]:
active_cells[0]

'11'

In [13]:
sum_path_lens = reduce(lambda s, x: s + len(x), active_cells, 0)
max_path_len = max([len(x) for x in active_cells])
count_active_cells = len(active_cells)

avg_path_len = sum_path_lens / count_active_cells
print(max_path_len)
print(avg_path_len)

15
9.334616730836542


In [14]:
import torch
def calc_enclosing_cell(lon: float, lat: float, active_cells: list[str]):
    """
    Given a point (lon, lat) and a graph, return the cell that encloses the point.
    """

    def get_next_cell(lon: float, lat: float, west_lon: float, east_lon: float, south_lat: float, north_lat: float):
        ret_west_lon = west_lon
        ret_east_lon = east_lon
        ret_south_lat = south_lat
        ret_north_lat = north_lat

        quad = 0

        half_lon = (west_lon + east_lon) / 2
        if lon >= half_lon:
            quad += 1
            ret_west_lon = half_lon
        else:
            ret_east_lon = half_lon

        half_lat = (south_lat + north_lat) / 2
        if lat < half_lat:
            quad += 2
            ret_north_lat = half_lat
        else:
            ret_south_lat = half_lat

        return quad, (ret_west_lon, ret_east_lon, ret_south_lat, ret_north_lat)

    west_lon = torch.tensor(-180, dtype=torch.float32)
    east_lon = torch.tensor(180, dtype=torch.float32)
    south_lat = torch.tensor(-90, dtype=torch.float32)
    north_lat = torch.tensor(90, dtype=torch.float32)

    cell = ""
    cell_pool = [c for c in active_cells]

    while True:
        if cell == '13202311033':
            pass
        quad, (west_lon, east_lon, south_lat, north_lat) = get_next_cell(
            lon, lat, west_lon, east_lon, south_lat, north_lat
        )

        cell += str(quad)
        cell_pool = [c for c in cell_pool if c.startswith(cell)]

        if len(cell_pool) == 1 and cell == cell_pool[0]:
            return cell

        if len(cell_pool) == 0:
            f"Not found: {cell}"
            return None

In [15]:
calc_enclosing_cell(100.458984, 13.75806, active_cells)

'132023110331'

In [34]:
h = df.head(128).collect()
print(h)

shape: (128, 2)
┌────────────┬────────────┐
│ latitude   ┆ longitude  │
│ ---        ┆ ---        │
│ f64        ┆ f64        │
╞════════════╪════════════╡
│ 41.906     ┆ 12.455     │
│ 48.211072  ┆ 16.36736   │
│ 43.942876  ┆ 12.774091  │
│ 41.339055  ┆ 14.507789  │
│ -23.210269 ┆ -44.693223 │
│ …          ┆ …          │
│ 55.679973  ┆ 12.571996  │
│ 47.954639  ┆ 13.500051  │
│ 40.769967  ┆ -73.993327 │
│ 43.633838  ┆ 1.381359   │
│ 35.778306  ┆ -78.633828 │
└────────────┴────────────┘


In [17]:
# get 14th element as list
h.row(14)

(38.700515, -9.056854)

In [18]:
problem_line = calc_enclosing_cell(h.row(14)[1], h.row(14)[0], active_cells)

In [19]:
print(problem_line)

031130013


In [35]:
avg_time = 0
runs = 5

for run_i in range(runs):
    start_time = time.time()

    for i, row in enumerate(h.iter_rows()):
        calc_enclosing_cell(row[1], row[0], active_cells)


    end_time = time.time()
    avg_time += end_time - start_time
    print(f"run {run_i} took: {end_time - start_time}s")

print()
print(f"average: {avg_time/runs}s")

run 0 took: 0.0889139175415039s
run 1 took: 0.07351016998291016s
run 2 took: 0.06662344932556152s
run 3 took: 0.06803321838378906s
run 4 took: 0.06723761558532715s

average: 0.07286367416381836s


In [32]:
5.35202*128/10000

0.068505856

In [33]:
4_200_000*5.35202/10000

2247.8484000000003

In [24]:
len(active_cells)

2857

In [25]:
def cell_bounds(cell: str):
    west_long = -180
    east_long = 180
    south_lat = -90
    north_lat = 90

    for divide in cell:
        if int(divide) < 2:
            south_lat = (south_lat + north_lat) / 2
        else:
            north_lat = (south_lat + north_lat) / 2

        if int(divide) % 2 == 0:
            east_long = (west_long + east_long) / 2
        else:
            west_long = (west_long + east_long) / 2

    return south_lat, north_lat, west_long, east_long

In [26]:
avg_time = 0
runs = 5

for run_i in range(runs):
    start_time = time.time()

    for cell in enumerate(active_cells):
        cell_bounds(cell)


    end_time = time.time()
    avg_time += end_time - start_time
    print(f"run {run_i} took: {end_time - start_time}s")

print()
print(f"average: {avg_time/runs}s")

run 0 took: 0.0017485618591308594s
run 1 took: 0.0016384124755859375s
run 2 took: 0.0020799636840820312s
run 3 took: 0.0016498565673828125s
run 4 took: 0.0016698837280273438s

average: 0.001757335662841797s
