In [4]:
import os
from pathlib import Path

import h3
import numpy as np
import polars as pl
import pandas as pd
import tensorflow as tf
from tqdm.auto import tqdm

In [5]:
params = {
    "export_dir": "vision-export-20240929050006-aka-2.17",
    "export_short_version": "2.17",
    "train_only_cid_data": True,
    "train_only_wild_data": False,
    "h3_resolution": 6,
    "num_random_samples": 100_000,
    "elevation_file": "elevation_h3_resolution6.csv",
    "experiment_dir": "/data-ssd/alex/experiments/geo_prior_tf/2_17",
    "batch_size": 1024,
    "num_epochs": 200,
    "initial_lr": 0.0005,
    "shuffle_buffer_size": 5_000,
    "full_shuffle_before_tfrecords": False,
    "lr_warmup_cosine_decay": True,
    "wandb_project": "geomodel_tf",
}

In [6]:
h3_column_name = "h3_0{}".format(params["h3_resolution"])
export_dir = Path(params["export_dir"])

In [7]:
tax = pl.read_csv(
    export_dir / "taxonomy.csv",
    columns=["leaf_class_id", "spatial_class_id", "taxon_id", "name"],
)
leaf_tax = tax.filter(
    ~pl.col("leaf_class_id").is_null()
)
leaf_tax

taxon_id,leaf_class_id,spatial_class_id,name
i64,i64,i64,str
129726,95877,95877,"""Cephalochordata"""
48272,40739,40739,"""Ciona intestinalis"""
81614,30262,30262,"""Ciona savignyi"""
712964,54943,54943,"""Ciona robusta"""
472689,62312,62312,"""Ascidia mentula"""
…,…,…,…
783080,95309,95309,"""Firstpapillomavirinae"""
914165,90922,90922,"""Begomovirus"""
1538530,29137,29137,"""Bracoviriform congregatae"""
1538531,45009,45009,"""Bracoviriform glomeratae"""


In [8]:
num_classes = len(leaf_tax)
num_classes

95903

In [9]:
tfrecord_file = os.path.join(
    params["export_dir"],
    "geo_spatial_grid_datasets",
    "r{}_empty_cells_with_elevation.tf".format(params["h3_resolution"]),
)

# cool let's train a model

In [10]:
class ResLayer(tf.keras.layers.Layer):
    def __init__(self):
        super(ResLayer, self).__init__()
        self.w1 = tf.keras.layers.Dense(
            256, activation="relu", kernel_initializer="he_normal", use_bias=False
        )
        self.w2 = tf.keras.layers.Dense(
            256, activation="relu", kernel_initializer="he_normal", use_bias=False,
        )
        self.dropout = tf.keras.layers.Dropout(rate=0.5)
        self.add = tf.keras.layers.Add()

    def call(self, inputs):
        x = self.w1(inputs)
        x = self.dropout(x)
        x = self.w2(x)
        x = self.add([x, inputs])
        return x

    def get_config(self):
        return {}

In [11]:
fcnet = tf.keras.models.Sequential(
    [
        tf.keras.layers.Input(
            5,
        ),
        # encode_location_layer,
        tf.keras.layers.Dense(
            256, activation="relu", kernel_initializer="he_normal", use_bias=False
        ),
        ResLayer(),
        ResLayer(),
        ResLayer(),
        ResLayer(),
        tf.keras.layers.Dense(num_classes, use_bias=False),
        tf.keras.layers.Dropout(rate=0.2),
        tf.keras.layers.Activation("sigmoid", dtype="float32", name="predictions"),
    ]
)
fcnet.summary()

Model: "sequential"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 dense (Dense)               (None, 256)               1280      
                                                                 
 res_layer (ResLayer)        (None, 256)               131072    
                                                                 
 res_layer_1 (ResLayer)      (None, 256)               131072    
                                                                 
 res_layer_2 (ResLayer)      (None, 256)               131072    
                                                                 
 res_layer_3 (ResLayer)      (None, 256)               131072    
                                                                 
 dense_9 (Dense)             (None, 95903)             24551168  
                                                                 
 dropout_4 (Dropout)         (None, 95903)             0

2024-10-25 14:43:25.421153: I metal_plugin/src/device/metal_device.cc:1154] Metal device set to: Apple M3 Pro
2024-10-25 14:43:25.421193: I metal_plugin/src/device/metal_device.cc:296] systemMemory: 18.00 GB
2024-10-25 14:43:25.421200: I metal_plugin/src/device/metal_device.cc:313] maxCacheSize: 6.00 GB
2024-10-25 14:43:25.421230: I tensorflow/core/common_runtime/pluggable_device/pluggable_device_factory.cc:306] Could not identify NUMA node of platform GPU ID 0, defaulting to 0. Your kernel may not have been built with NUMA support.
2024-10-25 14:43:25.421248: I tensorflow/core/common_runtime/pluggable_device/pluggable_device_factory.cc:272] Created TensorFlow device (/job:localhost/replica:0/task:0/device:GPU:0 with 0 MB memory) -> physical PluggableDevice (device: 0, name: METAL, pci bus id: <undefined>)


#### make tf.data dataset

In [12]:
gp_grid_feature_description = {
    "l0": tf.io.FixedLenFeature([], tf.float32),
    "l1": tf.io.FixedLenFeature([], tf.float32),
    "l2": tf.io.FixedLenFeature([], tf.float32),
    "l3": tf.io.FixedLenFeature([], tf.float32),
    "elevation": tf.io.FixedLenFeature([], tf.float32),
    "leaf_class_ids": tf.io.VarLenFeature(tf.int64),
}
def grid_parse_function(example_proto):
    # Parse the input tf.train.Example proto using the dictionary above.
    return tf.io.parse_single_example(
        example_proto, gp_grid_feature_description
    )
def preprocess_line(line):
    l0 = tf.expand_dims(line["l0"], axis=0)
    l1 = tf.expand_dims(line["l1"], axis=0)
    l2 = tf.expand_dims(line["l2"], axis=0)
    l3 = tf.expand_dims(line["l3"], axis=0)
    elevation = tf.expand_dims(line["elevation"], axis=0)
    encoded_loc = tf.concat([l0, l1, l2, l3, elevation], axis=0)
    leaf_class_ids = multi_hot(line["leaf_class_ids"])
    return encoded_loc, leaf_class_ids

In [13]:
raw_dataset = tf.data.TFRecordDataset(tfrecord_file)

In [14]:
!ls -lah {tfrecord_file}

-rw-r--r--@ 1 alex  staff   231M Oct 25 14:39 vision-export-20240929050006-aka-2.17/geo_spatial_grid_datasets/r6_empty_cells_with_elevation.tf


In [15]:
# isn't there a better way to do this? cardinality?
num_examples = len(list(raw_dataset))

In [16]:
multi_hot = tf.keras.layers.CategoryEncoding(
    num_tokens=num_classes, output_mode="multi_hot"
)
ds = raw_dataset.map(grid_parse_function)
ds = ds.map(preprocess_line)
ds = ds.shuffle(params["shuffle_buffer_size"], reshuffle_each_iteration=True)
ds = ds.batch(params["shuffle_buffer_size"])
ds = ds.repeat()
ds = ds.prefetch(tf.data.AUTOTUNE)

In [17]:
optimizer = tf.keras.optimizers.legacy.Adam(learning_rate=params["initial_lr"])
bce = tf.keras.losses.BinaryCrossentropy(from_logits=False)
fcnet.compile(
    optimizer=optimizer,
    loss=bce,
    metrics=["Precision", "Recall"],
)


In [18]:
steps_per_epoch = int(np.ceil(num_examples / params["batch_size"]))
steps_per_epoch

904

In [None]:
history = fcnet.fit(
    ds,
    steps_per_epoch=steps_per_epoch,
    epochs=3,
)

Epoch 1/3


2024-10-25 14:44:03.303220: I tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc:117] Plugin optimizer for device_type GPU is enabled.


  7/904 [..............................] - ETA: 2:25:23 - loss: 0.3096 - precision: 6.4094e-05 - recall: 0.0244

In [None]:
# with tf.device('cpu:0'):

#     history = fcnet.fit(
#         ds,
#         steps_per_epoch=steps_per_epoch,
#         epochs=3,
#     )
#     pass

Epoch 1/3
