# KDTree

This notebook times nearest-record look-ups using the `KDTree` implementation. Note that this implementation is effectively a `2DTree`, since the only valid distance comparison it can compute is between longitude/latitude positions.

The `KDTree` object is used to find the closest neighbour to a position; this implementation uses the Haversine distance to compare positions.

In [None]:
import os


os.environ["POLARS_MAX_THREADS"] = "1"

import inspect
import random
from datetime import datetime
from string import ascii_letters, digits

import numpy as np
import polars as pl

from geotrees import KDTree, Record

## Set-up functions

Helpers for generating random data and for brute-force comparisons against the `KDTree`.

In [None]:
def randnum() -> float:
    """Return a uniformly distributed random float in the interval [-1, 1)."""
    draw = np.random.rand()  # uniform in [0, 1)
    return (draw - 0.5) * 2


def generate_uid(n: int) -> str:
    """Build an ``n``-character pseudo-random identifier from ASCII letters and digits."""
    alphabet = ascii_letters + digits
    picked = [random.choice(alphabet) for _ in range(n)]
    return "".join(picked)


def random_record() -> Record:
    """
    Create a Record at a random position.

    The longitude is a random integer in [-179, 179] plus a ``randnum()``
    offset in [-1, 1), and the latitude is a random integer in [-89, 89] plus
    another such offset. Longitudes therefore fall in [-180, 180) and
    latitudes in [-90, 90).
    """
    lon = random.choice(range(-179, 180)) + randnum()
    lat = random.choice(range(-89, 90)) + randnum()
    return Record(lon, lat)


def check_cols(
    df: pl.DataFrame | pl.LazyFrame,
    cols: list[str],
    var_name: str = "dataframe",
) -> None:
    """
    Check that a dataframe contains a list of columns. Raises an error if not.

    Parameters
    ----------
    df : polars Frame
        DataFrame or LazyFrame to check
    cols : list[str]
        Required columns
    var_name : str
        Name of the Frame - used for displaying in any error.

    Raises
    ------
    TypeError
        If `df` is neither a polars DataFrame nor a polars LazyFrame.
    ValueError
        If any of `cols` is not a column of `df`. The message names the
        calling function, the required columns and the missing ones.
    """
    if isinstance(df, pl.DataFrame):
        have_cols = set(df.columns)
    elif isinstance(df, pl.LazyFrame):
        # Resolve the schema without collecting the data.
        have_cols = set(df.collect_schema().names())
    else:
        raise TypeError("Input Frame is not a polars Frame")

    missing = [c for c in cols if c not in have_cols]
    if not missing:
        return

    # Look up the caller's name only on the error path, and cheaply.
    # `inspect.stack()` builds a FrameInfo, including source-line context,
    # for every frame on the stack. The original paid that cost on every
    # call, including inside the timed brute-force benchmarks.
    frame = inspect.currentframe()
    caller = frame.f_back if frame is not None else None
    calling_func = caller.f_code.co_name if caller is not None else "<unknown>"
    del frame, caller  # don't keep frame objects alive (avoids reference cycles)

    err_str = f"({calling_func}) - {var_name} missing required columns. "
    err_str += f"Require: {', '.join(cols)}. "
    err_str += f"Missing: {', '.join(missing)}."
    raise ValueError(err_str)


def haversine_df(
    df: pl.DataFrame | pl.LazyFrame,
    lon: float,
    lat: float,
    radius: float = 6371,
    lon_col: str = "lon",
    lat_col: str = "lat",
) -> pl.DataFrame | pl.LazyFrame:
    """
    Compute haversine distance on earth surface between lon-lat positions
    in a polars Frame and a single lon-lat position.

    Parameters
    ----------
    df : polars.DataFrame | polars.LazyFrame
        The data, containing required columns:
            * lon_col
            * lat_col
    lon : float
        The longitude of the position, in degrees.
    lat : float
        The latitude of the position, in degrees.
    radius : float
        Radius of earth. Default 6371 (km).
    lon_col : str
        Name of the longitude column (degrees)
    lat_col : str
        Name of the latitude column (degrees)

    Returns
    -------
    polars.DataFrame | polars.LazyFrame
        The same kind of Frame as `df`, with an additional column "_dist"
        holding the great-circle distance from each row's position to
        (lon, lat). Distances are in the same units as `radius` and rounded
        to 2 decimal places.

    Raises
    ------
    TypeError
        If `df` is not a polars DataFrame / LazyFrame (raised by `check_cols`).
    ValueError
        If `df` lacks `lon_col` or `lat_col` (raised by `check_cols`).
    """
    required_cols = [lon_col, lat_col]

    check_cols(df, required_cols, "df")
    return (
        df.with_columns(
            [
                pl.col(lat_col).radians().alias("_lat0"),
                pl.lit(lat).radians().alias("_lat1"),
                (pl.col(lon_col) - lon).radians().alias("_dlon"),
                (pl.col(lat_col) - lat).radians().alias("_dlat"),
            ]
        )
        .with_columns(
            (
                (pl.col("_dlat") / 2).sin().pow(2)
                + pl.col("_lat0").cos()
                * pl.col("_lat1").cos()
                * (pl.col("_dlon") / 2).sin().pow(2)
            )
            # Floating-point rounding can push `a` marginally above 1 for
            # (near-)antipodal points, which would make arcsin(sqrt(a)) NaN.
            # Clamp it to its mathematical domain [0, 1].
            .clip(0.0, 1.0)
            .alias("_a")
        )
        .with_columns(
            (2 * radius * (pl.col("_a").sqrt().arcsin()))
            .round(2)
            .alias("_dist")
        )
        # Drop the intermediate helper columns; only "_dist" is added.
        .drop(["_lat0", "_lat1", "_dlon", "_dlat", "_a"])
    )


def intersect(a, b) -> set:
    """Return the set of items that appear in both ``a`` and ``b``."""
    return set(a).intersection(b)


def nearest_ship(
    lon: float,
    lat: float,
    df: pl.DataFrame,
    lon_col: str = "lon",
    lat_col: str = "lat",
) -> pl.DataFrame:
    """
    Find the observation(s) nearest to a position in space.

    Computes the haversine distance (default earth radius, so km) from every
    row of `df` to (lon, lat). It then keeps the row(s) whose distance equals
    the minimum. Distances are rounded to 2 decimal places before the
    comparison, so every row tied at the minimum is returned.

    Parameters
    ----------
    lon : float
        The longitude of the position, in degrees.
    lat : float
        The latitude of the position, in degrees.
    df : polars.DataFrame
        The pool of records to search.
    lon_col : str
        Name of the longitude column in the pool DataFrame
    lat_col : str
        Name of the latitude column in the pool DataFrame

    Returns
    -------
    polars.DataFrame
        The rows of `df` closest to the input point, with only the original
        columns (the temporary "_dist" column is dropped). More than one row
        is returned when several rows share the minimum (rounded) distance.

    Raises
    ------
    TypeError
        If `df` is not a polars Frame (raised by `check_cols`).
    ValueError
        If `df` lacks `lon_col` or `lat_col` (raised by `check_cols`).
    """
    required_cols = [lon_col, lat_col]
    check_cols(df, required_cols, "df")

    return (
        df.pipe(
            haversine_df,
            lon=lon,
            lat=lat,
            lon_col=lon_col,
            lat_col=lat_col,
        )
        # Keep every row at the minimum distance (ties included).
        .filter(pl.col("_dist").eq(pl.col("_dist").min()))
        .drop(["_dist"])
    )

## Initialise random data

In [None]:
N = 16_000  # number of synthetic records

# Integer candidate positions: longitudes -180..179, latitudes -90..89
# (int_range excludes its end value).
lons = pl.int_range(-180, 180, eager=True)
lats = pl.int_range(-90, 90, eager=True)

# Hourly timestamps over January 1900; only referenced by the disabled
# datetime column below.
dates = pl.datetime_range(
    start=datetime(1900, 1, 1, 0),
    end=datetime(1900, 1, 31, 23),
    interval="1h",
    eager=True,
)

# Draw N positions (with replacement) independently for each coordinate.
lons_use = lons.sample(n=N, with_replacement=True).rename("lon")
lats_use = lats.sample(n=N, with_replacement=True).rename("lat")
# Optional extra columns, currently disabled:
# dates_use = dates.sample(N, with_replacement=True).alias("datetime")
# uids = pl.Series("uid", [generate_uid(8) for _ in range(N)])

df = pl.DataFrame([lons_use, lats_use])
print(df.shape)
print(df.head())

In [None]:
# One Record per row of `df`; each row's column names become keyword arguments.
records = [Record(**row) for row in df.iter_rows(named=True)]

## Initialise the `KDTree`

Constructing a `KDTree` object has an up-front cost, so the performance gain only pays off when many look-ups are made against the same tree.

In [None]:
%%time
# Build the tree once from all records. This construction cost is measured
# here, separately from the per-query timings in the cells below.
kt = KDTree(records)

## Compare with brute force approach

In [None]:
%%timeit test_record = random_record()
# Time one KDTree nearest-neighbour query. The statement on the magic line is
# setup code: it draws a random test record and is not included in the timing.
kt.query(test_record)

In [None]:
%%timeit test_record = test_record = random_record()
np.argmin([test_record.distance(p) for p in records])

In [None]:
%%timeit test_record = random_record()
# Brute force with polars: `nearest_ship` computes the haversine distance from
# the test position to every row of `df` and keeps the row(s) at the minimum.
nearest_ship(lon=test_record.lon, lat=test_record.lat, df=df)

## Verify that results are correct

In [None]:
%%time
n_samples = 1000
tol = 1e-8
test_records = [
    Record(
        random.choice(range(-179, 180)) + randnum(),
        random.choice(range(-89, 90)) + randnum(),
    )
    for _ in range(n_samples)
]
kd_res = [kt.query(r) for r in test_records]
kd_recs = [_[0][0] for _ in kd_res]
kd_dists = [_[1] for _ in kd_res]
tr_recs = [
    records[np.argmin([r.distance(p) for p in records])] for r in test_records
]
tr_dists = [min([r.distance(p) for p in records]) for r in test_records]

if not all([abs(k - t) < tol for k, t in zip(kd_dists, tr_dists)]):
    raise ValueError("NOT MATCHING?")

In [None]:
test_lons = [r.lon for r in test_records]
test_lats = [r.lat for r in test_records]

kd_lons = [r.lon for r in kd_recs]
kd_lats = [r.lat for r in kd_recs]

tr_lons = [r.lon for r in tr_recs]
tr_lats = [r.lat for r in tr_recs]

# Table of every test point where the KDTree and brute-force distances differ
# by at least `tol`. An empty frame means the two approaches agree. It is bound
# to a new name rather than `df`: overwriting `df` would discard the benchmark
# frame, and re-running the polars timing cell would then fail because this
# table has no "lon"/"lat" columns.
mismatches = pl.DataFrame(
    {
        "test_lon": test_lons,
        "test_lat": test_lats,
        "kd_dist": kd_dists,
        "kd_lon": kd_lons,
        "kd_lat": kd_lats,
        "tr_dist": tr_dists,
        "tr_lon": tr_lons,
        "tr_lat": tr_lats,
    }
).filter((pl.col("kd_dist") - pl.col("tr_dist")).abs().ge(tol))
mismatches