In [None]:
import os
import torch
import numpy as np
import rioxarray as rxr
import matplotlib.pyplot as plt
from terratorch import FULL_MODEL_REGISTRY
from plotting_utils import plot_modality

# Select the compute device: prefer CUDA, then Apple MPS, and fall back to CPU.
device = 'cpu'
if torch.cuda.is_available():
    device = 'cuda'
elif torch.backends.mps.is_available():
    device = 'mps'

In [None]:
# Load input data: one example tile per modality from ../examples/<modality>/<file>.
examples = [
    '38D_378R_2_3.tif',
    '282D_485L_3_3.tif',
    '433D_629L_3_1.tif',
    '637U_59R_1_3.tif',
    '609U_541L_3_0.tif',
]

# Select example index between 0 and len(examples) - 1
example_idx = 0
file = examples[example_idx]

# Define modalities
modalities = ['S2L2A', 'S1RTC', 'DEM', 'LULC', 'NDVI']

# Build a dict {modality: float32 tensor of shape [1, C, H, W]} on the CPU
# (H, W presumably 224 for these example tiles — verify if using other data).
data = {}
for m in modalities:
    # The context manager closes the raster's file handle once the pixels are read.
    with rxr.open_rasterio(f'../examples/{m}/{file}') as raster:
        # Cast in NumPy first so the tensor conversion does not depend on the raster's
        # on-disk dtype; astype also copies, so the tensor does not alias the raster buffer.
        values = raster.values.astype(np.float32)
    # Prepend the batch axis: [C, H, W] -> [1, C, H, W].
    data[m] = torch.from_numpy(values).unsqueeze(0)

assert all(t.ndim == 4 and t.shape[0] == 1 for t in data.values()), 'expected [1, C, H, W] tensors'

In [None]:
# Run any-to-any generation: for each modality, condition on it alone and generate all the others.
# This can take a while without a GPU; reduce TIMESTEPS for faster inference.
TIMESTEPS = 10
SEED = 0
# NOTE(review): generation is presumably stochastic; seeding makes re-runs reproducible — TODO confirm.
torch.manual_seed(SEED)

outputs = {}
for m in modalities:
    print(f'Processing {m}')
    out_modalities = [o for o in modalities if o != m]

    # A fresh model is built per input modality because input/output modalities are passed at build time.
    model = FULL_MODEL_REGISTRY.build(
        'terramind_v1_base_generate',
        modalities=[m],
        output_modalities=out_modalities,
        pretrained=True,
        standardize=True,
    )
    model.to(device)
    model.eval()  # inference only: disable training-mode behavior such as dropout

    model_input = data[m].clone().to(device)
    with torch.no_grad():
        generated = model(model_input, verbose=True, timesteps=TIMESTEPS)
    outputs[m] = generated

    # Release this model before the next build so two models are never resident at the same time.
    del model, model_input, generated
    if device == 'cuda':
        torch.cuda.empty_cache()

In [None]:
# Plot any-to-any generations: one row per input modality. Column 0 shows the input and
# columns 1..n show generated modalities. The cell where input == output is left blank.
n_mod = len(modalities)
fig, axes = plt.subplots(nrows=n_mod, ncols=n_mod + 1, figsize=[12, 10])
axes[0][0].set_title('Input')
for j, m in enumerate(modalities, start=1):
    axes[0][j].set_title(m)

for row, m_in in zip(axes, modalities):
    plot_modality(m_in, data[m_in], ax=row[0])
    for ax in row:
        ax.axis('off')
    # Look up generations by input modality explicitly instead of relying on dict insertion order.
    for m_out, generated in outputs[m_in].items():
        plot_modality(m_out, generated, ax=row[modalities.index(m_out) + 1])

# Strip the raster extension so the file is 'any_to_any_<tile>.pdf', not '<tile>.tif.pdf'.
stem = os.path.splitext(os.path.basename(file))[0]
fig.savefig(f'any_to_any_{stem}.pdf')
plt.show()