In [None]:
import os
import torch
import numpy as np
import rioxarray as rxr
import matplotlib.pyplot as plt
from matplotlib.colors import hex2color, LinearSegmentedColormap
from terratorch import FULL_MODEL_REGISTRY

# Select device: start from the CPU fallback and upgrade to an accelerator
# when one is available (CUDA is preferred over Apple MPS)
device = 'cpu'
if torch.cuda.is_available():
    device = 'cuda'
elif torch.backends.mps.is_available():
    device = 'mps'

In [None]:
# Load input data
examples = [
    '38D_378R_2_3.tif',
    '282D_485L_3_3.tif',
    '433D_629L_3_1.tif',
    '637U_59R_1_3.tif',
    '609U_541L_3_0.tif',
]

# Select example between 0 and 4
file = examples[0]

# Define modalities
modalities = ['S2L2A', 'S1RTC', 'DEM', 'LULC', 'NDVI']

# Read one raster per modality from ../examples/<modality>/<file>
rasters = {m: rxr.open_rasterio(f'../examples/{m}/{file}') for m in modalities}

# Tensor with shape [B, C, 224, 224]
# The values are cast to float32 explicitly instead of going through the legacy
# torch.Tensor constructor, so the result does not depend on the global default
# tensor type. astype() returns a fresh array, so the tensors do not share
# memory with the rasters.
data = {
    m: torch.from_numpy(raster.values.astype(np.float32)).unsqueeze(0)
    for m, raster in rasters.items()
}

In [None]:
# Run any-to-any generation (this can take a while without a GPU, consider reducing timesteps for faster inference)
# For every modality, a model is built that takes this modality as the only
# input and generates all remaining ones. Results are collected in `outputs`,
# keyed by the input modality.
outputs = {}
for m in modalities:
    print(f'Processing {m}')
    # Every modality except the current input is a generation target
    out_modalities = [other for other in modalities if other != m]

    # Init model
    model = FULL_MODEL_REGISTRY.build(
        'terramind_v1_base_generate',
        modalities=[m],
        output_modalities=out_modalities,
        pretrained=True,
        standardize=True,
    )
    model.to(device)
    # Inference only: put train-time layers (dropout, normalisation statistics)
    # into evaluation mode. This is a no-op if the model has none.
    model.eval()

    # Work on a copy so the tensors in `data` stay untouched on the CPU.
    # Named `model_input` to avoid shadowing the builtin `input`.
    model_input = data[m].clone().to(device)
    with torch.no_grad():
        generated = model(model_input, verbose=True, timesteps=10)
    outputs[m] = generated

In [None]:
# Plotting utils
# Palette of ten colours; the colormap built from it is used for the LULC map
COLORBLIND_HEX = [
    "#000000",
    "#3171AD",
    "#469C76",
    '#83CA70',
    "#EAE159",
    "#C07CB8",
    "#C19368",
    "#6FB2E4",
    "#F1F1F1",
    "#C66526",
]
# Same palette as RGB float tuples
COLORBLIND_RGB = list(map(hex2color, COLORBLIND_HEX))
# N=10 quantises the colormap into ten entries
esri_cmap = LinearSegmentedColormap.from_list('esri', COLORBLIND_RGB, N=10)

def plot_modality_data(modality, data, ax):
    """Draw the first sample of a batched tensor as an image on ``ax``.

    Args:
        modality: Modality name. Supported values are 'S2L1C', 'S2L2A',
            'S1RTC', 'S1GRD', 'NDVI', 'DEM' and 'LULC'.
        data: Tensor whose first dimension is the batch; only sample 0 is
            drawn. Tensors on an accelerator or attached to an autograd graph
            are accepted.
        ax: Axis-like object providing ``imshow`` and ``axis``.

    Raises:
        ValueError: If ``modality`` is not one of the supported names.
    """
    # Remove batch dim. detach() allows tensors that require grad, which
    # cannot be converted to NumPy directly.
    data = data[0].detach().clone().cpu().numpy()
    interpolation = cmap = vmin = vmax = None
    if modality in ['S2L1C', 'S2L2A']:
        # Channels 3, 2, 1 become R, G, B
        # NOTE(review): presumably the red/green/blue bands - confirm band order
        data = data[[3,2,1]]
        # Clip to [0, 2000] and rescale to 8-bit
        data = data.clip(0, 2000)
        data = (data / 2000 * 255).astype(np.uint8)
    elif modality in ['S1RTC', 'S1GRD']:
        # RGB bands [VH, VV, VH]
        vv = data[0:1]
        vh = data[1:2]
        # Each channel is clipped to a fixed range and rescaled to 8-bit
        # NOTE(review): ranges look like backscatter in dB - confirm units
        vv = vv.clip(-30, 5)
        vv = ((vv + 30) / 35 * 255).astype(np.uint8)
        vh = vh.clip(-40, 0)
        vh = ((vh + 40) / 40 * 255).astype(np.uint8)
        data = np.concatenate([vh, vv, vh], axis=0)
    elif modality in ['NDVI']:
        cmap = 'RdYlGn'
        vmin, vmax = -1, 1
    elif modality in ['DEM']:
        # Per-image min/max stretch with a margin of 5 on both sides; the
        # epsilon avoids a division by zero
        data_min, data_max = np.min(data) - 5, np.max(data) + 5
        data = (data - data_min) / (data_max - data_min + 1e-6) * 255
        cmap = 'BrBG_r'
        vmin, vmax= 0, 255
    elif modality in ['LULC']:
        # Several channels are reduced to the index of the largest channel;
        # a single channel is plotted as is
        if data.shape[0] > 1:
            data = data.argmax(axis=0)
        cmap = esri_cmap
        vmin, vmax= 0, 9
        # Avoid blending neighbouring class indices
        interpolation = 'nearest'
    else:
        raise ValueError(f'Unknown modality: {modality}')

    # Channel-first [C, H, W] -> channel-last [H, W, C] as expected by imshow
    if len(data.shape) == 3:
        data = data.transpose(1,2,0)
    
    ax.imshow(data, cmap=cmap, vmin=vmin, vmax=vmax, interpolation=interpolation)
    ax.axis('off')

In [None]:
# Plot generations
# Grid layout: one row per input modality; column 0 shows the input and the
# remaining columns hold one modality each.
n_mod = len(modalities)
fig, axes = plt.subplots(nrows=n_mod, ncols=n_mod + 1, figsize=[12, 10])

# Column headers go on the top row only
for col, title in enumerate(['Input'] + modalities):
    axes[0][col].set_title(title)

# Draw every input in the first column and hide the frame of each panel in
# its row, so panels that receive no generation stay blank
for row_axes, (m, tensor) in zip(axes, data.items()):
    plot_modality_data(m, tensor, row_axes[0])
    for panel in row_axes:
        panel.axis('off')

# Row k holds the generations stored under the k-th key of `outputs`; each
# generated modality goes to the column that carries its name
for k, m_output in enumerate(outputs.values()):
    for m, out in m_output.items():
        plot_modality_data(m, out, axes[k][modalities.index(m) + 1])

plt.savefig(f'any_to_any_{os.path.basename(file)}.pdf')
plt.show()

# Note: TerraMind uses chained generations (see 4M for details)