In [19]:
import math
import os
from pathlib import Path

import h3
import numpy as np
import polars as pl
import pandas as pd
import tensorflow as tf
from tqdm.auto import tqdm

In [2]:
params = {
    "export_dir": "vision-export-20240929050006-aka-2.17",
    "export_short_version": "2.17",
    "train_only_cid_data": True,
    "train_only_wild_data": False,
    "h3_resolution": 6,
    "num_random_samples": 100_000,
    "elevation_file": "elevation_h3_resolution6.csv",
    "experiment_dir": "/data-ssd/alex/experiments/geo_prior_tf/2_17",
    "batch_size": 1024,
    "num_epochs": 200,
    "initial_lr": 0.0005,
    "shuffle_buffer_size": 50_000,
    "full_shuffle_before_tfrecords": False,
    "lr_warmup_cosine_decay": True,
    "wandb_project": "geomodel_tf",
}

In [3]:
h3_column_name = "h3_0{}".format(params["h3_resolution"])

In [4]:
export_dir = Path(params["export_dir"])

In [5]:
df = pl.read_parquet(
    export_dir / "spatial_data.parquet",
    columns=["latitude", "longitude", "spatial_class_id", "community", "captive"]
)

In [6]:
df.sample(3)

latitude,longitude,spatial_class_id,community,captive
f64,f64,i64,i64,i64
-36.8509,174.7869,96396,1,0
34.4892,-84.8568,134325,0,0
19.0619,-104.3046,81171,1,0


In [7]:
tax = pl.read_csv(
    export_dir / "taxonomy.csv",
    columns=["leaf_class_id", "spatial_class_id", "taxon_id", "name"],
)
leaf_tax = tax.filter(
    ~pl.col("leaf_class_id").is_null()
)
leaf_tax

taxon_id,leaf_class_id,spatial_class_id,name
i64,i64,i64,str
129726,95877,95877,"""Cephalochordata"""
48272,40739,40739,"""Ciona intestinalis"""
81614,30262,30262,"""Ciona savignyi"""
712964,54943,54943,"""Ciona robusta"""
472689,62312,62312,"""Ascidia mentula"""
…,…,…,…
783080,95309,95309,"""Firstpapillomavirinae"""
914165,90922,90922,"""Begomovirus"""
1538530,29137,29137,"""Bracoviriform congregatae"""
1538531,45009,45009,"""Bracoviriform glomeratae"""


In [8]:
# only leaves
df = df.filter(
    pl.col("spatial_class_id").is_in(leaf_tax["spatial_class_id"])
)

In [9]:
# drop non-cid data
df = df.filter(
    pl.col("community") == 1
)

In [10]:
df = df.rename({
    "latitude": "lat",
    "longitude": "lng",
})

In [11]:
# drop invalid locations
num_obs = df.shape[0]
df = df.filter([
    pl.col("lat") < 90,
    pl.col("lat") > -90,
    pl.col("lng") < 180,
    pl.col("lng") > -180,
])

if num_obs - df.shape[0] > 0:
    print(" ", num_obs - df.shape[0], "items filtered due to invalid locations")
else:
    print("no invalid locations found")

  6 items filtered due to invalid locations


In [12]:
# drop invalid locations
num_obs = df.shape[0]
df = df.filter([
    pl.col("lat") != 0,
    pl.col("lng") != 0,
])

if num_obs - df.shape[0] > 0:
    print(" ", num_obs - df.shape[0], "items filtered due to null island")
else:
    print("no invalid locations found")

  474 items filtered due to null island


In [13]:
df = df[["lat", "lng", "spatial_class_id"]]
df.sample(3)

lat,lng,spatial_class_id
f64,f64,i64
42.0373,-88.0097,64223
43.2931,-79.8786,72230
37.0545,-122.0213,8922


In [14]:
%%time 

dfh3 = df.with_columns(
    pl.struct("lat", "lng")
    .map_elements(
        lambda x: h3.geo_to_h3(x["lat"], x["lng"], params["h3_resolution"]),
        return_dtype=str
    )
    .alias(h3_column_name)
)

CPU times: user 1min 10s, sys: 1.54 s, total: 1min 12s
Wall time: 1min 12s


In [15]:
dfh3_dense = (
    dfh3
    .group_by(h3_column_name)
    .agg(
        pl.col("spatial_class_id")
        .unique()
    )
)

In [16]:
dfh3_dense

h3_06,spatial_class_id
str,list[i64]
"""86bc58897ffffff""",[87658]
"""8628b322fffffff""","[16230, 23291, … 83700]"
"""86488e6f7ffffff""","[793, 916, … 93286]"
"""8628a4777ffffff""","[18007, 19839, … 70627]"
"""8626b1497ffffff""",[52949]
…,…
"""862810ad7ffffff""","[793, 825, … 86095]"
"""862b1a727ffffff""","[430, 801, … 86896]"
"""8660b4977ffffff""","[4678, 8481, … 79449]"
"""86b242aefffffff""","[727, 1542, … 93635]"


# random samples for negatives

In [20]:
def make_samples(batch_size):
    rand_loc = np.random.uniform(size=(batch_size, 2))

    theta1 = 2.0 * math.pi * rand_loc[:, 0]
    theta2 = np.arccos(2.0 * rand_loc[:, 1] - 1.0)

    lat = 1.0 - 2.0 * theta2 / math.pi
    lng = (theta1 / math.pi) - 1.0

    return list(zip(lng, lat))

In [21]:
samples = make_samples(params["num_random_samples"])
scaled_samples = [(x * 180, y * 90) for x, y in samples]
scaled_samples = list(zip(*scaled_samples))

In [23]:
negatives_df = pl.DataFrame(
    scaled_samples,
    ["lng", "lat"]
)
negatives_dfh3 = negatives_df.with_columns(
    pl.struct("lat", "lng")
    .map_elements(
        lambda x: h3.geo_to_h3(x["lat"], x["lng"], params["h3_resolution"]),
        return_dtype=str
    )
    .alias(h3_column_name)
)
negatives_dfh3.sample(5)

lng,lat,h3_06
f64,f64,str
-19.883455,19.905386,"""865583647ffffff"""
-7.259882,-13.855492,"""869800b97ffffff"""
-14.258337,-5.392309,"""867d86c37ffffff"""
-131.941698,-1.167177,"""86780d46fffffff"""
44.114396,9.860125,"""8652d332fffffff"""


In [24]:
negatives_dfh3 = negatives_dfh3.unique(subset=h3_column_name)

In [25]:
# merge negatives/empties with spatial data

In [26]:
# trim any negatives that would duplicate positives
negatives_dfh3 = negatives_dfh3.filter(
    ~pl.col(h3_column_name).is_in(dfh3_dense[h3_column_name])
)

In [29]:
dfh3_dense.sample(5)

h3_06,spatial_class_id
str,list[i64]
"""86480daa7ffffff""","[6820, 9105, … 78664]"
"""8630f319fffffff""","[573, 1507, … 86586]"
"""86195492fffffff""","[44, 532, … 87574]"
"""860c0636fffffff""","[3995, 12311, … 81862]"
"""8644f3a4fffffff""","[18238, 31641, … 75219]"


In [30]:
negatives_dfh3 = negatives_dfh3.with_columns(
    pl.lit([], dtype=pl.datatypes.List(pl.datatypes.Int64))
    .alias("spatial_class_id")
)

In [31]:
negatives_dfh3 = negatives_dfh3[[h3_column_name, "spatial_class_id"]]
negatives_dfh3.sample(3)

h3_06,spatial_class_id
str,list[i64]
"""86cd4934fffffff""",[]
"""86cfab937ffffff""",[]
"""86ee1c06fffffff""",[]


In [32]:
combined = pl.concat([
    dfh3_dense,
    negatives_dfh3,
])

In [33]:
len(combined)

925122

# add elevation

In [34]:
elevation = pl.read_csv(
    params["elevation_file"]
)

In [36]:
elevation.sample(5)

h3_06,elevation
str,f64
"""868b6b5a7ffffff""",437.265306
"""86c746317ffffff""",-32768.0
"""86646b00fffffff""",-32768.0
"""86b75d1a7ffffff""",-32768.0
"""86d4726afffffff""",-32768.0


In [37]:
elevation = elevation.with_columns(
    pl.when(pl.col("elevation") > 0)
    .then(pl.col("elevation") / pl.col("elevation").max())
    .otherwise(pl.col("elevation") / pl.col("elevation").min() * -1)
)

In [38]:
combined_with_elevation = combined.join(elevation, on=h3_column_name)

In [39]:
combined_with_elevation.sample(5)

h3_06,spatial_class_id,elevation
str,list[i64],f64
"""860885a67ffffff""",[68854],-0.015893
"""8628a58efffffff""",[327],0.17583
"""866619d07ffffff""","[151, 535, … 93432]",0.326015
"""86c2e8927ffffff""",[42802],0.009541
"""8626eb097ffffff""",[71982],0.038421


# convert to geo and normalize

In [40]:
combined_with_elevation = combined_with_elevation.with_columns(
    pl.col(h3_column_name)
    .map_elements(lambda x: h3.h3_to_geo(x))
    .alias("xy"),
)

  combined_with_elevation = combined_with_elevation.with_columns(


In [41]:
combined_with_elevation = combined_with_elevation.with_columns([
    pl.col("xy")
    .map_elements(lambda x: x[0], return_dtype=float)
    .alias("lat"),
    pl.col("xy")
    .map_elements(lambda x: x[1], return_dtype=float)
    .alias("lng")
])

In [42]:
combined_with_elevation = combined_with_elevation.with_columns([
    (pl.col("lng") * math.pi).sin().alias("a"),
    (pl.col("lat") * math.pi).sin().alias("b"),
    (pl.col("lng") * math.pi).cos().alias("c"),
    (pl.col("lat") * math.pi).cos().alias("d"),
])

In [43]:
combined_with_elevation.sample(5)

h3_06,spatial_class_id,elevation,xy,lat,lng,a,b,c,d
str,list[i64],f64,list[f64],f64,f64,f64,f64,f64,f64
"""861ea6147ffffff""","[3961, 4991, … 83535]",0.003419,"[44.840486, 10.562251]",44.840486,10.562251,0.980938,0.480415,-0.194323,-0.877041
"""862773af7ffffff""","[9344, 12918, 30816]",0.052999,"[46.896504, -88.786053]",46.896504,-88.786053,-0.622658,0.319443,-0.782494,-0.947606
"""860ecbd77ffffff""","[801, 35742, … 73878]",0.033728,"[49.427997, -82.075484]",49.427997,-82.075484,-0.234924,-0.974525,0.972014,-0.22428
"""86441304fffffff""","[1292, 3834, … 81980]",-0.568126,"[26.542405, -82.055294]",26.542405,-82.055294,-0.17284,0.991139,0.98495,-0.132826
"""862ab6acfffffff""","[712, 2774, … 91647]",0.030936,"[42.82344, -81.871845]",42.82344,-81.871845,0.391823,0.526671,0.920041,-0.850069


# make tfrecords

In [44]:
len(combined_with_elevation)

925122

In [45]:
def create_tf_example(l0, l1, l2, l3, elevation, leaf_class_ids):
    tf_example = tf.train.Example(
        features=tf.train.Features(
            feature={
                "l0": tf.train.Feature(
                    float_list=tf.train.FloatList(value=[l0])
                ),
                "l1": tf.train.Feature(
                    float_list=tf.train.FloatList(value=[l1])
                ),
                "l2": tf.train.Feature(
                    float_list=tf.train.FloatList(value=[l2])
                ),
                "l3": tf.train.Feature(
                    float_list=tf.train.FloatList(value=[l3])
                ),
                "elevation": tf.train.Feature(
                    float_list=tf.train.FloatList(value=[elevation])
                ),
                "leaf_class_ids": tf.train.Feature(
                    int64_list=tf.train.Int64List(value=leaf_class_ids)
                ),
            }
        )
    )
    return tf_example


In [46]:
os.makedirs(
    os.path.join(params["export_dir"], "geo_spatial_grid_datasets"),
    exist_ok=True,
)
tfrecord_file = os.path.join(
    params["export_dir"],
    "geo_spatial_grid_datasets",
    "r{}_empty_cells_with_elevation.tf".format(params["h3_resolution"]),
)

In [47]:

print("  writing tfrecords")
with tf.io.TFRecordWriter(tfrecord_file) as writer:
    i = 0
    for i in tqdm(range(len(combined_with_elevation))):
        row = combined_with_elevation[i]

        l0 = row["a"].item()
        l1 = row["b"].item()
        l2 = row["c"].item()
        l3 = row["d"].item()

        elevation = row["elevation"].item()
        cids = list(row["spatial_class_id"].item())
        
        example = create_tf_example(l0, l1, l2, l3, elevation, cids)
        writer.write(example.SerializeToString())
    writer.close()


  writing tfrecords


  0%|          | 0/925122 [00:00<?, ?it/s]