In [None]:
import torch
import numpy as np
import rioxarray as rxr
import matplotlib.pyplot as plt
from matplotlib.colors import hex2color, LinearSegmentedColormap
from terratorch import FULL_MODEL_REGISTRY

# Choose the compute device: prefer CUDA, fall back to Apple MPS, otherwise use the CPU.
device = (
    'cuda' if torch.cuda.is_available()
    else 'mps' if torch.backends.mps.is_available()
    else 'cpu'
)

In [None]:
# Build a model that generates S-2 L2A (via tokenizer), S-1 GRD and LULC maps from raw S-2 L2A input.
model = FULL_MODEL_REGISTRY.build(
    'terramind_v1_base_generate',
    modalities=['S2L2A'],  # Define the input
    output_modalities=['S2L2A', 'S1GRD', 'LULC'],  # Define the output
    pretrained=True,
    standardize=True,  # If standardize=True, you don't need to do the standardization yourself.
    # offset={'S2L2A': 1000}  # Optional offset in your data. The offset is also applied to the generation.
)

_ = model.to(device)
# The model is only used for inference in this notebook. nn.Module starts in
# training mode, so switch to eval mode explicitly; this disables training-only
# layer behaviour (e.g. dropout, if the model uses any).
_ = model.eval()

# If you need the standardization values, see
from terratorch.models.backbones.terramind.model.terramind_register import v1_pretraining_mean, v1_pretraining_std

In [None]:
# Load an S-2 L2A example
examples = [
    '../examples/S2L2A/38D_378R_2_3.tif',
    '../examples/S2L2A/282D_485L_3_3.tif',
    '../examples/S2L2A/433D_629L_3_1.tif',
    '../examples/S2L2A/637U_59R_1_3.tif',
    '../examples/S2L2A/609U_541L_3_0.tif',
]

# Select example between 0 and len(examples) - 1
EXAMPLE_INDEX = 1
raster = rxr.open_rasterio(examples[EXAMPLE_INDEX])

# Convert the [C, H, W] raster values to a float32 tensor of shape [B, C, H, W] with B = 1
# (the comment in the original notebook states [B, C, 224, 224] for these examples).
# Casting in numpy first works for any integer raster dtype (e.g. uint16), which the
# legacy `torch.Tensor(ndarray)` constructor does not reliably accept.
data = torch.from_numpy(np.ascontiguousarray(raster.values, dtype=np.float32)).unsqueeze(0)

In [None]:
# Quick look at the S-2 L2A input: channels 3, 2, 1 are displayed as R, G, B,
# with values in [0, 2000] stretched linearly to the 0-255 display range.
rgb = data[0, [3, 2, 1]].clone()
rgb = rgb.permute(1, 2, 0).div(2000).clamp(0, 1).mul(255)
rgb = np.round(rgb.cpu().numpy()).astype(np.uint8)

plt.imshow(rgb)
plt.axis('off')
plt.show()

In [None]:
# Run the model; `timesteps` sets the number of diffusion steps used for generation.
DIFFUSION_TIMESTEPS = 10

# Named `model_input` rather than `input` so the Python built-in is not shadowed.
model_input = data.to(device)
with torch.no_grad():
    generated = model(model_input, verbose=True, timesteps=DIFFUSION_TIMESTEPS)

In [None]:
# Pull each generated modality out of the output dict by its key
s2l2a, s1grd, lulc = (generated[key] for key in ('S2L2A', 'S1GRD', 'LULC'))

In [None]:
def s2l2a_to_rgb(image, scale=2000):
    """Convert an S-2 L2A tensor to a displayable uint8 RGB image.

    Args:
        image: tensor indexed as [batch, channel, H, W]; the first batch item is
            used, and channels 3, 2, 1 are displayed as R, G, B.
        scale: value mapped to full brightness; values are clipped to [0, scale]
            before being stretched to 0-255.

    Returns:
        numpy uint8 array of shape [H, W, 3].
    """
    # Division creates a new tensor, so the caller's tensor is never modified.
    rgb = image[0, [3, 2, 1]].permute(1, 2, 0)
    rgb = (rgb / scale).clip(0, 1) * 255
    return rgb.cpu().numpy().round().astype(np.uint8)


fig, ax = plt.subplots(1, 4, figsize=(20, 5))

# Visualize S-2 L2A input as RGB
ax[0].imshow(s2l2a_to_rgb(data))
ax[0].axis('off')
ax[0].set_title('S-2 L2A Input')

# Visualize S-2 L2A generation as RGB
ax[1].imshow(s2l2a_to_rgb(s2l2a))
ax[1].axis('off')
ax[1].set_title('S-2 L2A Generation')

# Visualize the generated S-1 GRD output as a false-colour RGB composite (VH, VV, VH)
def stretch_to_uint8(band, low, high):
    """Clip `band` to [low, high] and stretch it linearly to 0-255.

    Args:
        band: numpy array of backscatter values (the clip ranges used below
            suggest dB values; not verified here).
        low: value mapped to 0.
        high: value mapped to 255.

    Returns:
        uint8 numpy array of the same shape (fractional values are truncated).
    """
    band = band.clip(low, high)
    return ((band - low) / (high - low) * 255).astype(np.uint8)


# Channel 0 is treated as VV and channel 1 as VH; slicing with 0:1 / 1:2 keeps the channel axis.
vv = stretch_to_uint8(s1grd[0, 0:1].cpu().numpy(), -30, 5)
vh = stretch_to_uint8(s1grd[0, 1:2].cpu().numpy(), -40, 0)
rgb = np.concatenate([vh, vv, vh], axis=0).transpose(1, 2, 0)

ax[2].imshow(rgb)
ax[2].axis('off')
ax[2].set_title('S-1 GRD Generation')

# Visualize LULC
# Class names, in the same order as the colours below:
# 'No Data', 'Water', 'Trees', 'Flooded vegetation', 'Crops', 'Built area', 'Bare ground', 'Snow/ice', 'Clouds', 'Rangeland'
COLORBLIND_HEX = ["#000000", "#3171AD", "#469C76", "#83CA70", "#EAE159", "#C07CB8", "#C19368", "#6FB2E4", "#F1F1F1", "#C66526"]
COLORBLIND_RGB = [hex2color(hex_code) for hex_code in COLORBLIND_HEX]
num_lulc_classes = len(COLORBLIND_RGB)
# One colormap entry per class; together with vmin=0 / vmax=num_lulc_classes - 1 below,
# every integer class index maps to exactly one colour.
esri_cmap = LinearSegmentedColormap.from_list('esri', COLORBLIND_RGB, N=num_lulc_classes)

lulc_map = lulc.argmax(dim=1)  # LULC is always returned as logits.
lulc_map = lulc_map.cpu().numpy()[0]
ax[3].imshow(lulc_map, cmap=esri_cmap, vmin=0, vmax=num_lulc_classes - 1, interpolation='nearest')
ax[3].axis('off')
ax[3].set_title('LULC Generation')

plt.show()