# Jeo Quickstart

This notebook provides a hands-on introduction to the `jeo` library, a framework for training deep learning models for geospatial remote sensing and Earth Observation using JAX and Flax. It is designed to work with large-scale datasets, often created using Google Earth Engine (GEE) via the accompanying `geeflow` library.


In this quickstart, you will walk through the fundamental steps of the **training flow**. To keep things simple and fast, the notebook uses the standard CIFAR-10 image dataset instead of a large geospatial one. You will learn how to:
1.  Load a configuration file.
2.  Instantiate a training dataset using a `tf.data` input pipeline.
3.  Define a model and initialize its parameters.
4.  Set up an optimizer and a learning rate schedule.
5.  Execute a single training step to see how the model's weights are updated.

Finally, you will run the end-to-end training process using the main `train.main()` function.

In [None]:
import functools
import importlib

import logging
import sys
from absl import logging as absl_logging


import jax
import jax.numpy as jnp
import flax
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf

import tree

%matplotlib inline

tf.compat.v1.enable_eager_execution()
print("tf.executing_eagerly(): ", tf.executing_eagerly())
print("JAX devices:\n  " + "\n  ".join([repr(d) for d in jax.devices()]))

In [None]:
# Make absl.logging use the standard Python logging setup.
absl_logging.use_python_logging(quiet=True)
logger = logging.getLogger()
logger.setLevel(logging.INFO)
for h in logger.handlers[:]:
    logger.removeHandler(h)
handler = logging.StreamHandler(sys.stderr)
formatter = logging.Formatter(
    '%(levelname)s: %(asctime)s %(filename)s:%(lineno)d] %(message)s',
    datefmt='%Y-%m-%d %H:%M:%S'
)
handler.setFormatter(formatter)
logger.addHandler(handler)
absl_logging.info("ABSL logging is now configured for this notebook.")
logging.info("Standard logging is also configured.")

In [None]:
"""# Following the trainer flow

This section walks through an example trainer workflow with the following steps:

1. Load the config `jeo/configs/tests/tiny_bit.py`, a simple classification setup using a BiT model on the CIFAR-10 dataset.
2. Instantiate the train dataset.
3. Inspect the train dataset.
4. Instantiate the model.
5. Initialize the model parameters.
6. Set up the optimizer.
7. Perform a training step.

Note: You can also follow the `jeo/train.py` module, which performs these steps.
"""
# Get config

# Has tree_info for quick inspection of nested structures.
from jeo.tools import inspect
from jeo.configs.tests import tiny_bit

tiny_bit = importlib.reload(tiny_bit)

config = tiny_bit.get_config()

# We can adjust it even more if needed.
config.batch_size = 256
config.val_steps = 2  # Don't perform more than 2 steps at evaluation.

# Show config.
config

In [None]:
# Let's quickly look into the specification of preprocessing:
print(config.pp_train)

# This string describes ops to be applied on the input data within the tf.data.Dataset input pipeline.
print("Input pipeline ops are applied in sequential order:")
for x in config.pp_train.split("|"):
  print(f"  {x}")
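
To make the idea of a pipe-separated preprocessing string concrete, here is a minimal, library-independent sketch of how such a string could be turned into a sequence of functions. The op names (`scale`, `shift`), the `OPS` registry, and `build_preprocess_fn` are hypothetical and only illustrate the concept; the real parsing is done by `pp_builder.get_preprocess_fn`, whose ops and syntax may differ.

```python
import re

# Hypothetical registry of op factories; this is not the jeo op set.
OPS = {
    "scale": lambda factor: (lambda x: x * factor),
    "shift": lambda offset: (lambda x: x + offset),
}


def build_preprocess_fn(pp_string):
  """Builds a function that applies the ops in `pp_string` from left to right."""
  fns = []
  for spec in pp_string.split("|"):
    match = re.fullmatch(r"(\w+)\((.*)\)", spec.strip())
    if match is None:
      raise ValueError(f"Cannot parse op spec: {spec!r}")
    name, args = match.groups()
    # In this toy version, arguments are comma-separated floats.
    parsed_args = [float(a) for a in args.split(",") if a.strip()]
    fns.append(OPS[name](*parsed_args))

  def preprocess(x):
    for fn in fns:
      x = fn(x)
    return x

  return preprocess


# Map a value from [0, 1] to [-1, 1]: first scale, then shift.
pp_fn = build_preprocess_fn("scale(2.0)|shift(-1.0)")
print(pp_fn(0.0), pp_fn(0.5), pp_fn(1.0))
```

Because the ops run in sequence, their order matters: swapping the two ops in the string above would produce a different result for the same input.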

## Train dataset

In [None]:
# Let's specify batch sizes and set the random seed.

rng = jax.random.PRNGKey(0)
rng, rng_loop = jax.random.split(rng, 2)

batch_size_train = config.batch_size
batch_size_eval = config.get("batch_size_eval", default=batch_size_train)
local_batch_size_train = batch_size_train // jax.process_count()
local_batch_size_eval = batch_size_eval // jax.process_count()

# Get the dataset (using TFDS for CIFAR-10).
from jeo import input_pipeline
from jeo.pp import pp_builder  # Preprocessing fn builder.

# The fillin() function replaces XM-related tokens in paths.
# Since we don't use XM here, the identity function is sufficient.
fillin = lambda x: x

train_ds, num_train_examples = input_pipeline.get_data(
    train=True,
    dataset=config.dataset,
    split=config.train_split,
    data_dir=fillin(config.get("dataset_dir")),
    dataset_module=config.get("dataset_module"),
    **config.get("dataset_kwargs", default={}),
    batch_size=local_batch_size_train,
    preprocess_fn=pp_builder.get_preprocess_fn(config.pp_train),
    shuffle_buffer_size=config.get("shuffle_buffer_size"),
    prefetch=config.get("prefetch_to_host", 2),
    cache_raw=False)

# Start prefetching already.
train_iter = input_pipeline.start_input_pipeline(train_ds, config.get("prefetch_to_device", 1))

# Let's look into the spec of the data (single batch):
train_ds.element_spec

In [None]:
# Get some examples.

# Get 8 batches:
ex = [jax.tree.map(np.array, next(train_iter)) for _ in range(8)]

# Show content shapes:
inspect.tree_info(ex[0])
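
The arrays above carry two leading dimensions because each host batch is split across the local devices. As a rough illustration using NumPy only, the sketch below reshapes a flat batch into `(num_devices, per_device_batch, ...)`. The helper `shard_batch` and the device count of 4 are hypothetical; in this notebook the input pipeline takes care of this layout.

```python
import numpy as np


def shard_batch(batch, num_devices):
  """Reshapes each array from (B, ...) to (num_devices, B // num_devices, ...)."""

  def _shard(x):
    if x.shape[0] % num_devices != 0:
      raise ValueError(
          f"Batch size {x.shape[0]} is not divisible by {num_devices} devices.")
    return x.reshape((num_devices, x.shape[0] // num_devices) + x.shape[1:])

  return {name: _shard(x) for name, x in batch.items()}


# Toy batch: 8 RGB images of size 32x32 and 10-class one-hot labels.
batch = {
    "image": np.zeros((8, 32, 32, 3), np.float32),
    "labels": np.zeros((8, 10), np.float32),
}
sharded = shard_batch(batch, num_devices=4)
print({name: x.shape for name, x in sharded.items()})
```

Indexing such an array with `[0, 0]` selects the first example on the first device.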

In [None]:
# Note: each example has 2 batch dimensions right now: (num_devices,
# batch_size_per_device).

# Visualize
plt.figure(figsize=(13, 6))
for i in range(8):
  plt.subplot(2, 4, 1+i)
  # The images are scaled to [-1, 1], so rescale them to [0, 1] for visualization.
  plt.imshow(ex[i]['image'][0, 0] / 2.0 + 0.5)  # Only the first image of each batch.
  # Labels are already one-hot encoded (as specified in the pp_train string).
  plt.title(str(ex[i]['labels'][0, 0]))

# Get the number of images, steps per epoch, and total steps.
ntrain_img = input_pipeline.get_num_examples(
      config.dataset, config.train_split,
      data_dir=config.get("dataset_dir"))
steps_per_epoch = ntrain_img / batch_size_train

if config.get("total_epochs"):
  total_steps = int(config.total_epochs * steps_per_epoch)
else:
  total_steps = config.total_steps

ntrain_img, steps_per_epoch, total_steps
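
As a worked example of the arithmetic above, the sketch below computes the step counts for a hypothetical split of 50,000 images with a batch size of 256 and 2 epochs. The real number of examples depends on `config.train_split`, and the helper name `compute_total_steps` is not part of `jeo`.

```python
def compute_total_steps(num_examples, batch_size, total_epochs=None,
                        total_steps=None):
  """Returns (steps_per_epoch, total_steps), preferring epochs when given."""
  steps_per_epoch = num_examples / batch_size  # May be fractional.
  if total_epochs:
    return steps_per_epoch, int(total_epochs * steps_per_epoch)
  if total_steps is None:
    raise ValueError("Either total_epochs or total_steps must be set.")
  return steps_per_epoch, total_steps


steps_per_epoch, total_steps = compute_total_steps(
    num_examples=50_000, batch_size=256, total_epochs=2)
print(steps_per_epoch, total_steps)  # 195.3125 390
```

Because `steps_per_epoch` is kept as a float, truncation happens only once, when the total is converted to an integer.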

## Model

In [None]:
# Let's get the model
from jeo import train_utils


print("Get module based on the one specified in config: ", config.model_name)
model_mod = train_utils.import_module(config.model_name, "models")

# Instantiate model with config overwrites in config.model
model = model_mod.Model(**config.model)

# Let's have a look at the model - that's a very simple ResNet with just a few args.
model

In [None]:
# Let's initialize the model.

@functools.partial(jax.jit, backend="cpu")
def init(rng):
  image_size = tuple(train_ds.element_spec["image"].shape[1:])
  x = jnp.zeros((local_batch_size_train,) + image_size, jnp.float32)
  variables = model.init(rng, x, train=True)
  model_state, params = flax.core.pop(variables, "params")
  params = flax.core.unfreeze(params)
  # Set bias in the head to a low value, such that loss is small initially.
  if "init_head_bias" in config:
    params["head"]["bias"] = jnp.full_like(params["head"]["bias"],
                                            config["init_head_bias"])
  return params, model_state

rng, rng_init, rng_dropout, rng_mask = jax.random.split(rng, 4)
rng_init = {"params": rng_init, "dropout": rng_dropout, "mask": rng_mask,
            "masking": rng_mask}
params_cpu, state_cpu = init(rng_init)

In [None]:
# Let's inspect the params - model weights.
inspect.tree_info(params_cpu)

# This particular model doesn't have state variables (e.g. as used for BatchNorm).
print(state_cpu)

## Optimizer

In [None]:
# Get optimizer
from jeo.tools import bv_optax

tx, sched_fns = bv_optax.make(config, params_cpu, sched_kw=dict(
    global_batch_size=batch_size_train,
    total_steps=total_steps,
    steps_per_epoch=steps_per_epoch))

# We jit this, such that the arrays are created on the CPU, not device[0].
opt_cpu = jax.jit(tx.init, backend="cpu")(params_cpu)
sched_fns_cpu = [jax.jit(sched_fn, backend="cpu") for sched_fn in sched_fns]

# Let's visualize the (relative!) learning rate schedule - the scaling with base_lr happens later.
plt.plot([sched_fns_cpu[0](x) for x in range(total_steps)])

# This particular shape is determined by the warmup specification
# (warmup_epochs/steps, e.g. 1 epoch) and the decay type (cosine).
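
For intuition about the plotted curve, here is a pure-Python sketch of a relative schedule with linear warmup followed by cosine decay to zero. This is an assumed, simplified formula for illustration only; the schedule built by `bv_optax.make` may differ in details such as the warmup shape or the final value. The function name `warmup_cosine` and the step counts are hypothetical.

```python
import math


def warmup_cosine(step, total_steps, warmup_steps):
  """Relative learning rate in [0, 1] at `step`."""
  if warmup_steps > 0 and step < warmup_steps:
    return step / warmup_steps  # Linear ramp from 0 to 1.
  # Fraction of the post-warmup phase completed, clipped to [0, 1].
  progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
  progress = min(max(progress, 0.0), 1.0)
  return 0.5 * (1.0 + math.cos(math.pi * progress))


# Hypothetical run: 390 steps in total, of which the first 195 are warmup.
total_steps, warmup_steps = 390, 195
schedule = [warmup_cosine(s, total_steps, warmup_steps)
            for s in range(total_steps)]
```

The absolute learning rate at a given step is this relative value multiplied by the base learning rate.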

## Train step

In [None]:
# Let's define an update train step:
import jeo.tasks.classification
from jeo.tasks import task_builder
import optax

# Get the problem task.
task = task_builder.from_config(config)

In [None]:
@functools.partial(jax.pmap, axis_name="batch", donate_argnums=(0, 1))
def update_fn(params, state, opt, rng, batch):
  """Update step."""
  measurements = {}
  # Get device-specific loss rng.
  rng, rng_model = jax.random.split(rng, 2)
  rng_model_local = jax.random.fold_in(rng_model, jax.lax.axis_index("batch"))

  def loss_fn(params, state, batch):
    model_inputs = task.model_inputs(batch)
    model_outputs, mutated_state = model.apply(
        {"params": flax.core.freeze(params), **state}, *model_inputs,
        train=True, mutable=list(state.keys()),
        rngs={"dropout": rng_model_local, "mask": rng_model_local,
              "masking": rng_model_local})
    loss, aux = task.get_loss_and_aux(model_outputs, batch, train=True)
    aux["state"] = mutated_state
    return loss, aux

  (l, aux), grads = jax.value_and_grad(loss_fn, has_aux=True)(params, state,
                                                              batch)
  l, aux, grads = jax.lax.pmean((l, aux, grads), axis_name="batch")
  updates, opt = tx.update(grads, opt, params)
  params = optax.apply_updates(params, updates)
  state = aux.pop("state")

  gs = jax.tree.leaves(bv_optax.replace_frozen(config.schedule, grads, 0.))
  measurements["l2_grads"] = jnp.sqrt(sum([jnp.vdot(g, g) for g in gs]))
  ps = jax.tree.leaves(params)
  measurements["l2_params"] = jnp.sqrt(sum([jnp.vdot(p, p) for p in ps]))
  us = jax.tree.leaves(updates)
  measurements["l2_updates"] = jnp.sqrt(sum([jnp.vdot(u, u) for u in us]))
  st = jax.tree.leaves(state)
  measurements["l2_state"] = jnp.sqrt(sum([jnp.vdot(s, s) for s in st]))
  for k, v in aux.items():
    measurements[f"train_{k}"] = v

  return params, state, opt, rng, l, measurements
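
The `l2_*` measurements above are global L2 norms taken over all leaves of a nested structure. The NumPy sketch below mirrors that computation on a toy parameter tree. The helper `tree_leaves` is a simplified, hypothetical stand-in for `jax.tree.leaves` that only handles nested dicts.

```python
import numpy as np


def tree_leaves(tree):
  """Returns the arrays stored in a nested dict, visiting keys in sorted order."""
  if isinstance(tree, dict):
    leaves = []
    for key in sorted(tree):
      leaves.extend(tree_leaves(tree[key]))
    return leaves
  return [np.asarray(tree)]


def global_l2_norm(tree):
  """Square root of the sum of squared entries over all leaves."""
  return float(np.sqrt(sum(np.vdot(x, x) for x in tree_leaves(tree))))


params = {
    "head": {
        "kernel": np.array([[3.0, 0.0], [0.0, 0.0]]),
        "bias": np.array([4.0]),
    },
}
print(global_l2_norm(params))  # 5.0
```

`np.vdot` flattens its inputs, so each term is the sum of squared entries of one leaf regardless of its shape.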

In [None]:
# Before doing any training, we need to replicate the parameters across devices:
import flax.jax_utils as flax_utils

params_repl = flax_utils.replicate(params_cpu)
state_repl = flax_utils.replicate(state_cpu)
opt_repl = flax_utils.replicate(opt_cpu)
rngs_loop = flax_utils.replicate(rng_loop)

# Get next train batch

train_batch = next(train_iter)

# Do a training step

out = update_fn(params_repl, state_repl, opt_repl, rngs_loop, train_batch)

# The train step update_fn returns the updated params, the (potentially
# updated) model state, the updated optimizer state, a new rng key, the loss,
# and the measurement metrics.

params, state, opt, rng, l, measurements = out

# You can inspect any of them, e.g. with inspect.tree_info(params).

# Let's have a look at computed measurements:
measurements

In [None]:
# Note that the values within the arrays are the same, just replicated across
# devices. So, a natural thing to do is to just take the first element (or use
# unreplicate).

flax_utils.unreplicate(measurements)
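
Conceptually, replication stacks identical copies of every array along a new leading device axis, and unreplication takes the first copy back. The NumPy sketch below illustrates this on a flat dict with made-up values. The `replicate` and `unreplicate` functions here are simplified stand-ins written for this example, not the `flax.jax_utils` functions, which also handle nested structures and device placement.

```python
import numpy as np


def replicate(tree, num_devices):
  """Stacks `num_devices` identical copies along a new leading axis."""
  return {k: np.stack([v] * num_devices) for k, v in tree.items()}


def unreplicate(tree):
  """Takes the copy at index 0 of the leading (device) axis."""
  return {k: v[0] for k, v in tree.items()}


# Made-up metric values and a hypothetical device count of 4.
metrics = {"l2_grads": np.float32(1.5), "train_loss": np.float32(2.25)}
replicated = replicate(metrics, num_devices=4)
print({k: v.shape for k, v in replicated.items()})
print(unreplicate(replicated))
```

Taking index 0 is only valid when all copies are identical, which holds for values averaged across devices within the train step.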

# End-to-End run with train.main()

In [None]:
from jeo import evaluators
from jeo import train
from jeo.configs.tests import tiny_bit

config = tiny_bit.get_config()
# Adjust the config to run in Colab (depending on the available accelerators).
config.xprof = False
config.batch_size = 64
config.total_epochs = 2  # 2 epochs is enough for demonstration.
if "fewshot" in config.evals:
  del config.evals.fewshot  # fewshot eval is currently not available.

from absl import flags
FLAGS = flags.FLAGS

FLAGS.config = config
FLAGS.cleanup = False
FLAGS.workdir = "/tmp/jeo_test/run1"

train.main(None)