In [None]:
import torch
import rioxarray as rxr
import matplotlib.pyplot as plt
from terratorch.registry import FULL_MODEL_REGISTRY
from plotting_utils import plot_s2, plot_modality

# Pick the compute device in order of preference: CUDA GPU, then Apple MPS, then CPU.
# The chained conditional short-circuits, so MPS is only queried when CUDA is unavailable.
device = (
    'cuda' if torch.cuda.is_available()
    else 'mps' if torch.backends.mps.is_available()
    else 'cpu'
)

In [None]:
# Assemble the generation model: raw S-2 L2A goes in; S-2 L2A (via tokenizer), S-1 GRD,
# DEM, LULC, and NDVI come out.
build_kwargs = dict(
    modalities=['S2L2A'],  # Input modalities
    output_modalities=['S2L2A', 'S1GRD', 'DEM', 'LULC', 'NDVI'],  # Modalities to generate
    pretrained=True,
    standardize=True,  # With standardize=True you don't need to standardize the data yourself.
    # offset={'S2L2A': 1000},  # Optional offset in your data. The offset is also applied to the generation.
)
model = FULL_MODEL_REGISTRY.build('terramind_v1_base_generate', **build_kwargs)

# Assigning to `_` keeps the return value of .to() out of the cell output
_ = model.to(device)

# The standardization values can be imported from here if you need them
from terratorch.models.backbones.terramind.model.terramind_register import v1_pretraining_mean, v1_pretraining_std

In [None]:
# Paths of the bundled S-2 L2A example tiles
examples = [
    '../examples/S2L2A/38D_378R_2_3.tif',
    '../examples/S2L2A/282D_485L_3_3.tif',
    '../examples/S2L2A/433D_629L_3_1.tif',
    '../examples/S2L2A/637U_59R_1_3.tif',
    '../examples/S2L2A/609U_541L_3_0.tif',
]

# Index of the example to load; change this to any value between 0 and 4
EXAMPLE_INDEX = 1

# Read the raster under its own name so `data` only ever refers to the tensor
raster = rxr.open_rasterio(examples[EXAMPLE_INDEX])

# Convert to a float32 tensor of shape [B, C, 224, 224].
# The cast happens in NumPy, so the conversion does not rely on the legacy
# torch.Tensor constructor or on torch supporting the raster's on-disk dtype.
# unsqueeze(0) prepends the batch dimension.
data = torch.from_numpy(raster.values.astype('float32')).unsqueeze(0)

In [None]:
# Visualize the S-2 L2A input as an RGB image (`data` is the batched tensor created in the loading cell)
plot_s2(data)

In [None]:
# Move the batch to the selected device and run the generation with autograd disabled.
# NOTE: `input` shadows the Python builtin; the name is kept because the plotting cell reads it.
input = data.to(device)
generation_kwargs = dict(verbose=True, timesteps=10)  # timesteps: presumably the number of diffusion steps
with torch.no_grad():
    generated = model(input, **generation_kwargs)

In [None]:
# Show the input next to every generated modality: one panel per image, in a single row
n_plots = 1 + len(generated)
fig, ax = plt.subplots(1, n_plots, figsize=(5 * n_plots, 5))

# First panel: the S-2 L2A input
plot_s2(input, ax=ax[0])
ax[0].set_title('Input')

# Remaining panels: pair each generated modality with the next free axis
for panel, (mod, value) in zip(ax[1:], generated.items()):
    plot_modality(mod, value, ax=panel)
    panel.set_title('generated ' + mod)

plt.show()