In [None]:
import json
import math
import os
from urllib.parse import urlencode

import cv2
import polars as pl
import requests
from deepforest import main
from dotenv import load_dotenv
from geopy.distance import geodesic
from matplotlib import pyplot as plt
from owslib.wms import WebMapService

In [None]:
# silence some deepforest warnings
import warnings

warnings.filterwarnings(
    "ignore", message=".*root_dir argument for the location of images.*"
)
warnings.filterwarnings(
    "ignore",
    message=".*An image was passed directly to predict_tile, the results.root_dir attribute will be None in the output dataframe, to use visualize.plot_results, please assign results.root_dir*",
)

In [None]:
# Load variables from a .env file into os.environ (python-dotenv) so that
# GOOGLE_MAPS_API_KEY and WMS_URL, read in later cells, are available.
load_dotenv()

In [None]:
# (width, height) in pixels requested from the WMS and Static Maps services.
IMG_SIZE = (500, 500)

## Get tree data

In [None]:
# Root of the planning.data.gov.uk file store; dataset CSVs are fetched from
# "<root>dataset/<name>.csv" in the next cell.
planning_base_csv_url = "https://files.planning.data.gov.uk/"

In [None]:
# Download the planning.data.gov.uk tree dataset and cache it locally.
dataset = "tree"
r = requests.get(f"{planning_base_csv_url}dataset/{dataset}.csv", timeout=120)
# Fail loudly rather than caching an error page as trees.csv.
r.raise_for_status()

filename = "data/trees.csv"
# The data/ directory may not exist on a fresh checkout.
os.makedirs(os.path.dirname(filename), exist_ok=True)
with open(filename, "wb") as f_out:
    f_out.write(r.content)

In [None]:
# Keep only the columns used below: tree name, the `point` geometry string
# (parsed later into lon/lat) and the address text.
data = pl.read_csv(filename).select("name", "point", "address-text")
data

In [None]:
# pick some example trees
example_indices = [
    765,
    7865,
    8030,
    12458,
    14120,
    17346,
    20745,
    21555,
    22483,
    24920,
    26152,
    31622,
    35518,
    41236,
    46031,
    56861,
    57073,
    63443,
    71083,
    80506,
    81889,
    84328,
]
single_example_index = 765

## Get data from Google map tiles

Images are of worse quality than those from the WMS link from Paul, and the tile system uses a different coordinate system, which makes it harder to work with. We will default to WMS.

Usage limit for this API is 100,000 requests per month. Do not run this on the whole dataset.

You can monitor usage here (select Map tiles API): https://console.cloud.google.com/google/maps-apis/quotas?invt=AbueGQ&project=mhclg-data-quality&api=static-maps-backend.googleapis.com

In [None]:
# Google Maps Platform key used by the Map Tiles and Static Maps requests; None if unset.
api_key = os.getenv("GOOGLE_MAPS_API_KEY")

In [None]:
def getXY(lon, lat, zoom):
    """Return the (x, y) Web Mercator tile indices containing a lon/lat point.

    There are 2**zoom tiles per axis; x grows eastwards from lon=-180 and y
    grows southwards from the top edge of the map.
    """
    tiles_per_axis = 2.0**zoom
    phi = math.radians(lat)
    mercator_y = math.log(math.tan(phi) + 1 / math.cos(phi))
    x_index = int((lon + 180.0) / 360.0 * tiles_per_axis)
    y_index = int((1.0 - mercator_y / math.pi) / 2.0 * tiles_per_axis)
    return x_index, y_index

In [None]:
def get_session_token():
    """Create a Google Map Tiles API session for satellite imagery.

    Uses the notebook-level ``api_key``.

    Returns:
        The session token string, or None (after printing a message) if the
        request fails, the body is not JSON, or it contains no ``session``.
    """
    url = "https://tile.googleapis.com/v1/createSession"
    payload = {"mapType": "satellite"}
    headers = {"Content-Type": "application/json"}

    try:
        # A timeout stops a stalled connection from hanging the notebook forever.
        response = requests.post(
            url, params={"key": api_key}, json=payload, headers=headers, timeout=30
        )
        body = response.json()
    except (requests.RequestException, ValueError) as exc:
        print(f"Couldn't get session token: {exc}")
        return None

    session_token = body.get("session", "")
    if session_token == "":
        print("Couldn't get session token.")
        return None
    return session_token

In [None]:
def download_google_image(lat, long, zoom, filename, session_token=None):
    """Download the Google satellite map tile that contains (lat, long).

    Args:
        lat: Latitude of the point of interest (degrees).
        long: Longitude of the point of interest (degrees).
        zoom: Tile zoom level.
        filename: Destination path for the tile; its directory is created.
        session_token: Optional existing Map Tiles session to reuse. When
            omitted, a new session is created for this call.

    Returns:
        True if the tile was written, False otherwise (a message is printed).
    """
    # Use this function's own longitude argument, not a notebook global.
    x, y = getXY(long, lat, zoom)
    if session_token is None:
        session_token = get_session_token()
    if session_token is None:
        return False

    url = f"https://tile.googleapis.com/v1/2dtiles/{zoom}/{x}/{y}"
    res = requests.get(
        url, params={"session": session_token, "key": api_key}, timeout=30
    )
    if not res.ok:
        # Don't save an error body under a .png name.
        print(f"Tile request failed for {filename}: HTTP {res.status_code}")
        return False

    os.makedirs(os.path.dirname(filename) or ".", exist_ok=True)
    with open(filename, "wb") as f_out:
        f_out.write(res.content)
    return True

In [None]:
# Download one Google map tile per example tree (counts against the Map Tiles quota).
points = data.get_column("point")
for example_index in example_indices:
    point = points[example_index].split("(")[1].split(")")[0]
    lon, lat = (float(part) for part in point.split(" "))
    filename = f"data/google_images/tree_{example_index}.png"
    download_google_image(lat, lon, zoom=18, filename=filename)

## Get data from Google static maps

This is another alternative to map tiles, which did not give the best quality. It is easier to work with; however, the service is more expensive.

#### DO NOT RUN unless you have to.

Usage limit for this API is 10,000 requests per month. Do not run this on the whole dataset.

You can monitor usage here (select Maps static API): https://console.cloud.google.com/google/maps-apis/quotas?invt=AbueGQ&project=mhclg-data-quality&api=static-maps-backend.googleapis.com

In [None]:
# Static Maps endpoint; download_static_map appends the URL-encoded query string.
static_maps_base_url = "https://maps.googleapis.com/maps/api/staticmap?"

In [None]:
def download_static_map(params, filename):
    """Download a Google Static Maps image and save it to ``filename``.

    Args:
        params: Static Maps query parameters (center, zoom, size, maptype,
            scale, key, ...); URL-encoded onto ``static_maps_base_url``.
        filename: Destination path; its parent directory is created if needed.

    Returns:
        True if the image was written, False if the request failed.
    """
    full_url = f"{static_maps_base_url}{urlencode(params)}"

    res = requests.get(full_url, timeout=30)
    if not res.ok:
        # Don't save an error body under a .png name; later cells read these
        # files with cv2.imread.
        print(f"Static map request failed for {filename}: HTTP {res.status_code}")
        return False

    os.makedirs(os.path.dirname(filename) or ".", exist_ok=True)
    with open(filename, "wb") as f_out:
        f_out.write(res.content)
    return True

In [None]:
# Static Maps downloads are billable and quota-limited, so they are disabled
# by default. Set this to True only when the images are really needed.
DOWNLOAD_STATIC_IMAGES = False

if DOWNLOAD_STATIC_IMAGES:
    for example_index in example_indices:
        point = data.get_column("point")[example_index].split("(")[1].split(")")[0]
        lon, lat = map(float, point.split(" "))
        params = {
            "center": f"{lat},{lon}",
            "zoom": 18,
            "size": f"{IMG_SIZE[0]}x{IMG_SIZE[1]}",
            "maptype": "satellite",
            "scale": 2,
            "key": api_key,
        }
        filename = f"data/google_static_images/tree_{example_index}.png"
        download_static_map(params, filename)

## Get data from WMS

In [None]:
# WMS endpoint for the aerial imagery, configured via the environment / .env.
wms_url = os.environ.get("WMS_URL")
if not wms_url:
    # Fail with a clear message instead of an obscure error inside OWSLib.
    raise RuntimeError("WMS_URL is not set; add it to your environment or .env file.")
wms = WebMapService(wms_url)

In [None]:
def get_bbox(lat, lon, offset):
    """Return the box (xmin, ymin, xmax, ymax) extending ``offset`` around a point.

    ``offset`` is a half-width in degrees, applied to longitude (x) and
    latitude (y) alike. Because it is in degrees, the box covers less ground
    east-west than north-south away from the equator.
    """
    return (lon - offset, lat - offset, lon + offset, lat + offset)

In [None]:
def download_wms_image(bbox, filename):
    """Fetch an aerial image covering ``bbox`` from the WMS service and save it.

    Args:
        bbox: (xmin, ymin, xmax, ymax) in EPSG:4326 degrees (lon, lat order,
            as produced by ``get_bbox``).
        filename: Destination PNG path; its parent directory is created if needed.
    """
    # Unpacking validates that bbox has exactly four values.
    xmin, ymin, xmax, ymax = bbox

    img = wms.getmap(
        layers=["APGB_Latest_UK_125mm"],
        srs="EPSG:4326",
        bbox=(xmin, ymin, xmax, ymax),
        size=IMG_SIZE,
        format="image/png",
    )

    # Create the directory the caller actually asked for, not a fixed one.
    os.makedirs(os.path.dirname(filename) or ".", exist_ok=True)
    with open(filename, "wb") as f_out:
        f_out.write(img.read())

In [None]:
# Download a WMS image around each example tree and record its metadata.
offset = 0.0005  # half-width of each image bbox, in degrees
points = data.get_column("point")
trees = []
for example_index in example_indices:
    point = points[example_index].split("(")[1].split(")")[0]
    lon, lat = (float(part) for part in point.split(" "))
    bbox = get_bbox(lat, lon, offset)

    filename = f"data/wms_images/tree_{example_index}.png"
    download_wms_image(bbox, filename=filename)
    trees.append(
        dict(id=example_index, wms_filename=filename, lat=lat, lon=lon, bbox=bbox)
    )

In [None]:
# Build the per-tree table: WMS metadata plus the paths where the Google
# static-map and map-tile images for the same tree id are stored.
trees_df = (
    pl.from_records(trees)
    .with_columns(
        # Native string formatting instead of a per-row Python lambda.
        pl.format("data/google_static_images/tree_{}.png", pl.col("id")).alias(
            "static_filename"
        ),
        pl.format("data/google_images/tree_{}.png", pl.col("id")).alias(
            "tile_filename"
        ),
    )
    .select(
        ["id", "lat", "lon", "bbox", "wms_filename", "static_filename", "tile_filename"]
    )
)

In [None]:
# Cache the tree table so later cells can reload it without re-downloading images.
trees_df.write_parquet("data/trees_df.parquet")

## Load tree data if the image download didn't run

In [None]:
# Reload the cached tree table written by the cell above.
trees_df = pl.read_parquet("data/trees_df.parquet")
trees_df

## DeepForest model fine-tuned on trees in urban areas in Berlin

In [None]:
# DeepForest checkpoint fine-tuned on urban trees in Berlin. The default
# location can be overridden with the DEEPFOREST_MODEL_PATH environment variable.
# "~" is expanded explicitly. NOTE(review): the checkpoint loader may not expand it.
MODEL_PATH = os.path.expanduser(
    os.environ.get(
        "DEEPFOREST_MODEL_PATH",
        "~/Downloads/model.opendata_luftbild_dop60.patch400.ckpt",
    )
)
model = main.deepforest.load_from_checkpoint(checkpoint_path=MODEL_PATH)

In [None]:
def predict_boxes(image, **kwargs):
    """Run the DeepForest ``model`` over ``image`` and return its predicted boxes.

    Args:
        image: Image array passed straight to ``model.predict_tile``.
            NOTE(review): DeepForest documents RGB input; arrays from
            ``cv2.imread`` are BGR, so confirm callers convert.
        **kwargs: Forwarded to ``predict_tile`` (this notebook passes
            patch_size, patch_overlap and iou_threshold).

    Returns:
        The result of ``predict_tile`` with ``return_plot=False``. Callers here
        iterate it as a DataFrame with xmin/ymin/xmax/ymax columns.
        NOTE(review): some DeepForest versions return None when nothing is
        detected; confirm for the installed version.
    """
    return model.predict_tile(image=image, return_plot=False, **kwargs)


def show_boxes(image, pred_boxes, index):
    """Plot ``image`` with every predicted box drawn on a copy of it.

    Args:
        image: Image array to draw on (it is copied, not modified).
        pred_boxes: DataFrame with xmin/ymin/xmax/ymax pixel columns, as
            returned by ``predict_boxes``. None or empty means nothing is drawn.
        index: Tree id shown in the plot title.
    """
    annotated = image.copy()

    if pred_boxes is not None:
        # Row labels are unused. Don't shadow ``index`` (the tree id in the title).
        for _, row in pred_boxes.iterrows():
            cv2.rectangle(
                annotated,
                (int(row["xmin"]), int(row["ymin"])),
                (int(row["xmax"]), int(row["ymax"])),
                (255, 165, 0),
                thickness=1,
                lineType=cv2.LINE_AA,
            )

    plt.imshow(annotated)
    plt.title(f"Tree {index}")
    plt.axis("off")
    plt.show()

#### With WMS images

In [None]:
MODEL_INFERENCE = {"patch_size": 5000, "patch_overlap": 0.9, "iou_threshold": 1}

for index, filename in trees_df.select(["id", "wms_filename"]).rows():
    example_image = cv2.imread(filename)
    if example_image is None:
        print(f"Could not read {filename}; skipping tree {index}.")
        continue
    # cv2 loads BGR; DeepForest and matplotlib both expect RGB.
    example_image = cv2.cvtColor(example_image, cv2.COLOR_BGR2RGB)
    pred_boxes = predict_boxes(example_image, **MODEL_INFERENCE)
    show_boxes(example_image, pred_boxes, index)

#### With Google static images

In [None]:
MODEL_INFERENCE = {"patch_size": 5000, "patch_overlap": 0.9, "iou_threshold": 1}

for index, filename in trees_df.select(["id", "static_filename"]).rows():
    example_image = cv2.imread(filename)
    if example_image is None:
        # Static images only exist if the static-map download cell was run.
        print(f"Could not read {filename}; skipping tree {index}.")
        continue
    # cv2 loads BGR; DeepForest and matplotlib both expect RGB.
    example_image = cv2.cvtColor(example_image, cv2.COLOR_BGR2RGB)
    pred_boxes = predict_boxes(example_image, **MODEL_INFERENCE)
    show_boxes(example_image, pred_boxes, index)

## Get the closest box

In [None]:
def pixel_to_epsg4326(x, y, img_width, img_height, bbox):
    """Map a pixel position to (lon, lat) by linear interpolation across ``bbox``.

    ``bbox`` is (xmin, ymin, xmax, ymax) in degrees. Pixel y grows downwards,
    so y=0 is the top edge (ymax).
    """
    west, south, east, north = bbox
    lon = west + (east - west) * (x / img_width)
    lat = north - (north - south) * (y / img_height)
    return lon, lat

In [None]:
def espg4326_to_pixel(lat, lon, bbox, img_size):
    """Map (lat, lon) to integer pixel coordinates (x, y) in an image of ``bbox``.

    This is the inverse of the linear mapping in ``pixel_to_epsg4326``: x grows
    east from xmin and y grows south from ymax. ``img_size`` is (width, height).
    Fractional pixels are truncated with int().
    """
    west, south, east, north = bbox
    width, height = img_size
    column = int((lon - west) / (east - west) * width)
    row = int((1 - (lat - south) / (north - south)) * height)
    return column, row

#### With WMS images

In [None]:
# For each tree, keep the predicted box whose centre is closest to the recorded
# point and record that distance (metres) as the location error.
MODEL_INFERENCE = {"patch_size": 5000, "patch_overlap": 0.9, "iou_threshold": 1}
BOX_COLOR = (255, 165, 0)  # orange (image is RGB)
POINT_COLOR = (0, 0, 255)  # blue (image is RGB)

wms_stats = []

for index, lat, lon, bbox, wms_filename in trees_df.select(
    ["id", "lat", "lon", "bbox", "wms_filename"]
).rows():
    image = cv2.imread(wms_filename)
    if image is None:
        print(f"Could not read {wms_filename}; skipping tree {index}.")
        continue
    # cv2 loads BGR; DeepForest and matplotlib both expect RGB.
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    img_height, img_width = image.shape[:2]

    pred_boxes = predict_boxes(image, **MODEL_INFERENCE)
    if pred_boxes is None or len(pred_boxes) == 0:
        print(f"No boxes predicted for tree {index}; skipping.")
        continue

    candidates = []
    for _, row in pred_boxes.iterrows():
        x_center = (row["xmin"] + row["xmax"]) / 2
        y_center = (row["ymin"] + row["ymax"]) / 2
        pred_lon, pred_lat = pixel_to_epsg4326(
            x_center, y_center, img_width, img_height, bbox
        )
        box_dist = geodesic((pred_lat, pred_lon), (lat, lon)).meters
        pred_box = (row["xmin"], row["ymin"], row["xmax"], row["ymax"])
        candidates.append((box_dist, pred_box))

    # min() returns the first of equally close boxes.
    best_dist, best_box = min(candidates, key=lambda c: c[0])

    annotated = image.copy()
    cv2.rectangle(
        annotated,
        (int(best_box[0]), int(best_box[1])),
        (int(best_box[2]), int(best_box[3])),
        BOX_COLOR,
        thickness=1,
        lineType=cv2.LINE_AA,
    )

    x_pixel, y_pixel = espg4326_to_pixel(lat, lon, bbox, (img_width, img_height))
    cv2.circle(annotated, (x_pixel, y_pixel), radius=5, color=POINT_COLOR, thickness=-1)

    wms_stats.append(best_dist)

    print(lat, lon)
    plt.imshow(annotated)
    plt.title(f"Tree {index}; distance = {best_dist:.2f}m")
    plt.axis("off")
    plt.show()

#### With Google static images

In [None]:
# Google Static Maps images were requested centred on each tree with
# zoom=STATIC_ZOOM, size=IMG_SIZE and scale=2. scale=2 gives twice as many
# pixels per axis for the same ground area.
STATIC_ZOOM = 18
MODEL_INFERENCE = {"patch_size": 5000, "patch_overlap": 0.9, "iou_threshold": 1}
BOX_COLOR = (255, 165, 0)  # orange (image is RGB)
POINT_COLOR = (0, 0, 255)  # blue (image is RGB)


def static_map_bbox(lat, lon, zoom, size):
    """Return the (xmin, ymin, xmax, ymax) EPSG:4326 extent of a Static Maps image.

    Static Maps renders Web Mercator, where the whole world is 256 * 2**zoom
    logical pixels wide. ``size`` is the logical (width, height) requested,
    independent of ``scale``. The image is centred on (lat, lon).
    """
    world_px = 256 * 2**zoom
    width, height = size
    lat_rad = math.radians(lat)
    center_x = (lon + 180.0) / 360.0 * world_px
    center_y = (
        (1.0 - math.log(math.tan(lat_rad) + 1 / math.cos(lat_rad)) / math.pi)
        / 2.0
        * world_px
    )

    def x_to_lon(px):
        return px / world_px * 360.0 - 180.0

    def y_to_lat(py):
        return math.degrees(math.atan(math.sinh(math.pi * (1.0 - 2.0 * py / world_px))))

    # Pixel y grows southwards: the bottom edge gives ymin, the top edge ymax.
    return (
        x_to_lon(center_x - width / 2),
        y_to_lat(center_y + height / 2),
        x_to_lon(center_x + width / 2),
        y_to_lat(center_y - height / 2),
    )


static_stats = []

for index, lat, lon, static_filename in trees_df.select(
    ["id", "lat", "lon", "static_filename"]
).rows():
    image = cv2.imread(static_filename)
    if image is None:
        # Static images only exist if the static-map download cell was run.
        print(f"Could not read {static_filename}; skipping tree {index}.")
        continue
    # cv2 loads BGR; DeepForest and matplotlib both expect RGB.
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    img_height, img_width = image.shape[:2]
    # The stored `bbox` column describes the WMS image (a fixed offset in
    # degrees). It does not describe this image, whose extent is set by zoom and size.
    bbox = static_map_bbox(lat, lon, STATIC_ZOOM, IMG_SIZE)

    pred_boxes = predict_boxes(image, **MODEL_INFERENCE)
    if pred_boxes is None or len(pred_boxes) == 0:
        print(f"No boxes predicted for tree {index}; skipping.")
        continue

    candidates = []
    for _, row in pred_boxes.iterrows():
        x_center = (row["xmin"] + row["xmax"]) / 2
        y_center = (row["ymin"] + row["ymax"]) / 2
        pred_lon, pred_lat = pixel_to_epsg4326(
            x_center, y_center, img_width, img_height, bbox
        )
        box_dist = geodesic((pred_lat, pred_lon), (lat, lon)).meters
        pred_box = (row["xmin"], row["ymin"], row["xmax"], row["ymax"])
        candidates.append((box_dist, pred_box))

    # min() returns the first of equally close boxes.
    best_dist, best_box = min(candidates, key=lambda c: c[0])

    annotated = image.copy()
    cv2.rectangle(
        annotated,
        (int(best_box[0]), int(best_box[1])),
        (int(best_box[2]), int(best_box[3])),
        BOX_COLOR,
        thickness=2,
        lineType=cv2.LINE_AA,
    )

    x_pixel, y_pixel = espg4326_to_pixel(lat, lon, bbox, (img_width, img_height))
    cv2.circle(annotated, (x_pixel, y_pixel), radius=10, color=POINT_COLOR, thickness=-1)

    static_stats.append(best_dist)

    print(lat, lon)
    plt.imshow(annotated)
    plt.title(f"Tree {index}; distance = {best_dist:.2f}m")
    plt.axis("off")
    plt.show()

## Plot distances from true coordinates

In [None]:
color = "#00625E"

# Shared 1 m bins up to the largest error from either source. Default to a
# single 0-1 m bin so the cell still runs if a stats list is empty.
max_distance = max([*wms_stats, *static_stats], default=1)
max_distance_rounded = max(math.ceil(max_distance), 1)

bins = list(range(0, max_distance_rounded + 1))

fig, (ax1, ax2) = plt.subplots(2, 1, sharex=True, figsize=(6, 6), height_ratios=[1, 1])

ax1.hist(wms_stats, bins=bins, color=color, label="WMS images")
ax1.set_title("WMS error distribution")
ax1.set_ylabel("Count")
ax1.legend(loc="best")

ax2.hist(static_stats, bins=bins, color=color, label="Google static API images")
ax2.set_title("Static error distribution")
ax2.set_ylabel("Count")
ax2.legend(loc="best")
ax2.set_xlabel("Distance from true point (meters)")

ax2.set_xticks(bins)

# tight_layout sets the subplot spacing; a prior subplots_adjust would be overridden.
fig.tight_layout()
fig.savefig("data/distance_histograms.png", dpi=300, bbox_inches="tight")
plt.show()

## Find outliers

In [None]:
# Split each point string "... (lon lat)" once into [lon, lat] strings and
# reuse that for both columns. Sorting by longitude puts extreme trees at the ends.
_point_coords = (
    pl.col("point")
    .str.split("(")
    .list.get(1)
    .str.split(")")
    .list.get(0)
    .str.split(" ")
)
(
    data.filter(pl.col("name").is_not_null())
    .with_columns(
        lat=_point_coords.list.get(1).cast(pl.Float64),
        lon=_point_coords.list.get(0).cast(pl.Float64),
    )
    .sort("lon")
)