# Using `xbatcher` to train an autoencoder

---

## Imports

In [None]:
import os
from importlib import reload

# DL stuff
import matplotlib
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.utils.data as data

torch.set_default_dtype(torch.float64)

# Geospatial stuff
import xarray as xr
import xbatcher
import rioxarray
from xbatcher.loaders.torch import MapDataset

# Etc
import numpy as np
from numpy.linalg import norm
from tqdm.autonotebook import tqdm

# Locals
import functions
import autoencoder

## Get data

We will start by loading an ASTER Global DEM (ASTGTM v003) tile that covers part of Washington's Olympic Peninsula.

In [None]:
dem = rioxarray.open_rasterio("../ASTGTMV003_N47W124_dem.tif")
# Rasterio adds a blank edge; trim off the last row and column.
dem = dem.isel(y=slice(0, -1), x=slice(0, -1))
# Min-max normalize elevations to the range [0, 1].
dem = (dem - dem.min()) / (dem.max() - dem.min())
dem

In [None]:
dem.isel(band=0).plot.imshow(cmap="terrain")

## Generate training examples
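Here, we use xbatcher to cut the DEM into 32-by-32 pixel patches, with neighboring patches overlapping by 16 pixels. An autoencoder is trained to reconstruct its own input, so each patch serves as both input and target, and a single `X_generator` is enough.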

In [None]:
bgen_x = xbatcher.BatchGenerator(
    dem,
    input_dims=dict(x=32, y=32),
    input_overlap=dict(x=16, y=16)
)

ds = MapDataset(
    X_generator=bgen_x
)

loader = torch.utils.data.DataLoader(ds, batch_size=16, shuffle=True)

In [None]:
X = next(iter(loader))

print("Input tensor shape:", X.shape)
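
As a sanity check on how many patches to expect, the sliding-window arithmetic can be sketched in plain Python. This is only an illustration: it assumes windows start every `window - overlap` pixels and that partial windows at the edge are dropped, and `count_windows` is a hypothetical helper rather than part of xbatcher.

```python
def count_windows(dim_size: int, window: int, overlap: int) -> int:
    """Count the full windows that fit along one axis."""
    stride = window - overlap  # distance between consecutive window starts
    if stride <= 0:
        raise ValueError("overlap must be smaller than the window size")
    if dim_size < window:
        return 0
    return (dim_size - window) // stride + 1


# Values from this notebook: 3600-pixel axes, 32-pixel windows, 16-pixel overlap.
per_axis = count_windows(3600, window=32, overlap=16)
n_patches = per_axis * per_axis
print(per_axis, n_patches)
```

Under these assumptions, comparing `n_patches` with `len(ds)` is a quick way to confirm that the generator is windowing the array as intended.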

## Model setup

In [None]:
m = autoencoder.Autoencoder(base_channel_size=32, latent_dim=64, num_input_channels=1, width=32, height=32)
opt = m._configure_optimizers()

In [None]:
out = m(X)
print(out.shape)

## Model training

To keep the notebook environment lean, we do not use PyTorch Lightning, and we load a pre-trained model rather than training one here. For your own project, we highly recommend using a framework to abstract away much of the training boilerplate.

:::{danger}
This model is certainly overfitted. For brevity we have omitted a validation dataset, which is essential for building models that generalize well on unseen data.
:::
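
One way to add the missing validation set is to hold out a spatial block of patches rather than a random sample, because neighboring patches overlap by 16 pixels and a random split would share terrain between the two sets. The sketch below is only an illustration: `split_patch_indices` is a hypothetical helper, and it assumes that patches are indexed row-major on an `n_rows` by `n_cols` grid, which may not match the ordering xbatcher actually uses.

```python
def split_patch_indices(n_rows, n_cols, val_fraction=0.2, gap_rows=1):
    """Split a row-major patch grid into training and validation index lists.

    The last `val_fraction` of patch rows is held out for validation, and
    `gap_rows` rows in between are discarded so that overlapping patches
    do not straddle the boundary.
    """
    n_val_rows = max(1, round(n_rows * val_fraction))
    first_val_row = n_rows - n_val_rows
    train_row_stop = first_val_row - gap_rows  # exclusive upper bound
    if train_row_stop <= 0:
        raise ValueError("grid too small for the requested split")
    train = [r * n_cols + c for r in range(train_row_stop) for c in range(n_cols)]
    val = [r * n_cols + c for r in range(first_val_row, n_rows) for c in range(n_cols)]
    return train, val


# Assumed grid: 224 x 224 patches (3600-pixel axes, 32-pixel windows, 16-pixel overlap).
train_idx, val_idx = split_patch_indices(224, 224)
print(len(train_idx), len(val_idx))
```

The two index lists could then be wrapped with `torch.utils.data.Subset(ds, train_idx)` and `torch.utils.data.Subset(ds, val_idx)` to build separate loaders.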

In [None]:
m.load_state_dict(torch.load("../autoencoder.torch", weights_only=True))

In [None]:
m.eval()
n_examples = 4
inputs = next(iter(loader))
outputs = m(inputs)

inputs = inputs.detach().cpu().numpy()
outputs = outputs.detach().cpu().numpy()

In [None]:
fig, axes = plt.subplots(n_examples, 2)

for i_row in range(n_examples):
    axes[i_row, 0].imshow(inputs[i_row, 0, ...])
    axes[i_row, 1].imshow(outputs[i_row, 0, ...])

for a in axes.flat:
    a.set_xticks([])
    a.set_yticks([])

axes[0, 0].set_title("Original patch")
axes[0, 1].set_title("Reconstruction")

fig.tight_layout()
plt.show()

## Reconstruction 1: Getting the full array back

Suppose we would like to evaluate how well the autoencoder reconstructs the entire DEM tile by combining outputs across all input patches. To do so, we can use the `predict_on_array` function described in the previous notebook. Our model outputs tensors with shape `(band=1, y=32, x=32)`, and we need to account for each of these axes in the call to `predict_on_array`. `band` does not change size and is not used by the `BatchGenerator`, so it goes in `core_dim`. Both `x` and `y` are used by the `BatchGenerator`, so although they do not change size, they still go in `resample_dim`. That accounts for all tensor axes, so we can leave the `new_dim` argument as an empty list.
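
To build intuition for what has to happen along the resampled axes, here is a minimal NumPy sketch that stitches overlapping patches back into a full array by averaging wherever patches overlap. It is not the implementation of `predict_on_array`, which lives in the local `functions` module and is not shown here; `stitch_patches` is a hypothetical helper, and averaging is an assumed way of resolving overlaps.

```python
import numpy as np


def stitch_patches(patches, starts, out_shape):
    """Average overlapping 2D patches into one array.

    patches: list of (h, w) arrays; starts: list of (row, col) upper-left corners.
    """
    total = np.zeros(out_shape, dtype=float)  # running sum of patch values
    count = np.zeros(out_shape, dtype=float)  # number of patches covering each pixel
    for patch, (r, c) in zip(patches, starts):
        h, w = patch.shape
        total[r:r + h, c:c + w] += patch
        count[r:r + h, c:c + w] += 1
    # Pixels covered by no patch stay NaN instead of triggering a division by zero.
    return np.divide(total, count, out=np.full(out_shape, np.nan), where=count > 0)


# Toy example: an 8 x 8 array cut into 4 x 4 windows with a 2-pixel overlap.
image = np.arange(64, dtype=float).reshape(8, 8)
starts = [(r, c) for r in range(0, 5, 2) for c in range(0, 5, 2)]
patches = [image[r:r + 4, c:c + 4] for r, c in starts]
restored = stitch_patches(patches, starts, image.shape)
print(np.allclose(restored, image))
```

With an identity "model" the stitched array matches the input exactly; with a real model, the overlapping predictions differ slightly and the average smooths the seams.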

In [None]:
dem_reconst = functions.predict_on_array(
    dataset=ds,
    model=m,
    output_tensor_dim=dict(band=1, y=32, x=32),
    new_dim=[],
    core_dim=["band"],
    resample_dim=["x", "y"]
)

In [None]:
dem_reconst.isel(band=0).plot.imshow(cmap="terrain")

That certainly looks like the original DEM. Let's try plotting the error in the reconstruction.

In [None]:
err = (dem_reconst - dem)
err.isel(band=0).plot.imshow()
plt.show()

In [None]:
err.plot.hist()
plt.show()

Not bad!

## Reconstruction 2: Getting the latent dimension

A common application of autoencoders is to use the latent representation as a compact feature vector for downstream tasks. Let's turn our encoder's outputs into a data cube. To do so, we modify the batch generator so that windows do not overlap. We also have to slightly clip the input DEM, because each 32-by-32 patch becomes a single pixel, which effectively downscales the spatial axes by a factor of 32. Since `3600 / 32` is not an integer, `predict_on_array` will not know how to rescale the array size, so we clip the DEM to a multiple of 32. The largest multiple of 32 that fits within 3600 is 3584, which we reach by clipping 8 pixels from each side.
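
The clipping arithmetic can be checked with a few lines of Python. This is a small sketch, and `symmetric_clip` is a hypothetical helper introduced only for illustration.

```python
def symmetric_clip(dim_size: int, factor: int) -> tuple:
    """Return (left, right, new_size) such that new_size is a multiple of factor."""
    remainder = dim_size % factor
    left = remainder // 2
    right = remainder - left  # takes the extra pixel when the remainder is odd
    return left, right, dim_size - remainder


left, right, new_size = symmetric_clip(3600, 32)
print(left, right, new_size, new_size // 32)
```

When applying such a clip in general, `slice(left, dim_size - right)` is safer than `slice(left, -right)`, because the latter selects nothing when `right` is 0.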

In [None]:
bgen_no_overlap = xbatcher.BatchGenerator(
    dem.isel(x=slice(8, -8), y=slice(8, -8)),
    input_dims=dict(x=32, y=32),
    input_overlap=dict(x=0, y=0)
)

ds_no_overlap = MapDataset(
    X_generator=bgen_no_overlap
)

loader = torch.utils.data.DataLoader(ds_no_overlap, batch_size=16, shuffle=True)

ex_input = next(iter(loader))

# Same as before
print("Input shape:", ex_input.shape)

Next, we will write a function that calls the encoder arm of the autoencoder and adds singleton `y` and `x` dimensions to its output.

In [None]:
def infer_with_encoder(x):
    return m.encoder(x)[:, None, None, :]

ex_output = infer_with_encoder(ex_input)
print("Output shape:", ex_output.shape)
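
Indexing with `None` inserts a length-1 axis at that position, and it behaves the same way for NumPy arrays and PyTorch tensors. The sketch below uses a zero-filled NumPy array as a stand-in for the encoder output, assuming a batch of 16 latent vectors of length 64.

```python
import numpy as np

# Stand-in for the encoder output: 16 latent vectors of length 64.
latent = np.zeros((16, 64))

# Insert singleton axes after the batch axis: (batch, y=1, x=1, channel).
expanded = latent[:, None, None, :]

# np.expand_dims expresses the same operation by naming the new axis positions.
same = np.expand_dims(latent, axis=(1, 2))

print(expanded.shape, same.shape)
```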

Now we combine the outputs together into a new data cube.

In [None]:
latent_dim_cube = functions.predict_on_array(
    dataset=ds_no_overlap,
    model=infer_with_encoder,
    output_tensor_dim=dict(y=1, x=1, channel=64),
    new_dim=["channel"],
    core_dim=[],
    resample_dim=["x", "y"]
)

In [None]:
latent_dim_cube

Note that despite substantially re-arranging the input `DataArray`, we have retained the coordinate information at a resampled resolution.

If we simply sum the output over the channel dimension, we see that the encoder clearly distinguishes between upland and lowland areas.

In [None]:
latent_dim_cube.sum(dim="channel").plot.imshow()

As a final demonstration of this workflow, let's compute the cosine similarity between the latent vector of each pixel in the cube and the latent encoding of [Mt. Olympus](https://en.wikipedia.org/wiki/Mount_Olympus_(Washington)).

In [None]:
olympus = dict(x=-123.7066, y=47.7998)
olympus_latent = latent_dim_cube.sel(**olympus, method="nearest")
olympus_latent

In [None]:
def numpy_cosine_similarity(x, y):
    return np.dot(x, y)/(norm(x)*norm(y))

In [None]:
olympus_similarity = xr.apply_ufunc(
    numpy_cosine_similarity,
    latent_dim_cube,
    input_core_dims = [["channel"]],
    output_core_dims = [[]],
    vectorize=True,
    kwargs=dict(y=olympus_latent.data)
)
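
With `vectorize=True`, `apply_ufunc` wraps the function with `numpy.vectorize`, so it is called once per pixel. The same similarity map can also be computed in one shot with broadcasting. The sketch below runs on a random stand-in array rather than the real latent cube; `cosine_similarity_map` is a hypothetical helper, and it assumes the channel axis is last.

```python
import numpy as np
from numpy.linalg import norm


def cosine_similarity_map(cube, reference):
    """Cosine similarity between `reference` and every vector along the last axis of `cube`."""
    dots = cube @ reference                        # shape (y, x)
    norms = norm(cube, axis=-1) * norm(reference)  # shape (y, x)
    return dots / norms


rng = np.random.default_rng(0)
cube = rng.normal(size=(5, 7, 64))  # stand-in for a (y, x, channel) latent cube
reference = cube[2, 3]              # latent vector of one chosen pixel

similarity = cosine_similarity_map(cube, reference)
print(similarity.shape, np.isclose(similarity[2, 3], 1.0))
```

In principle, a function like this could be passed to `xr.apply_ufunc` without `vectorize=True`, since `input_core_dims` moves `channel` to the last axis before the function is called.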

In [None]:
olympus_similarity.plot.imshow()
plt.scatter(olympus["x"], olympus["y"], marker="*", c="purple", edgecolor="black", s=200)
plt.title("Cosine similarity with Mt. Olympus, WA")
plt.show()

Similarly, we can identify foothills with similar topography to Grisdale, WA.

In [None]:
grisdale = dict(y=47.356625678465925, x=-123.61183314426664)
grisdale_latent = latent_dim_cube.sel(**grisdale, method="nearest")

grisdale_similarity = xr.apply_ufunc(
    numpy_cosine_similarity,
    latent_dim_cube,
    input_core_dims = [["channel"]],
    output_core_dims = [[]],
    vectorize=True,
    kwargs=dict(y=grisdale_latent.data)
)

grisdale_similarity.plot.imshow()
plt.scatter(grisdale["x"], grisdale["y"], marker="*", c="purple", edgecolor="black", s=200)
plt.show()

Admittedly, this result is very similar to what we would get by simply selecting elevation bands :)

---

## Summary
Our goal with this notebook has been to show how xbatcher supports linking `xarray` objects with deep learning models and converting model output back into labeled `xarray` objects. We have demonstrated two examples of reconstructing model output: one where the tensor shape is unchanged and one where it changes.

If you encounter any issues, please open an issue on the GitHub repository for this cookbook. Other feedback is welcome!