In [None]:
import torch
import numpy as np
import rioxarray as rxr
import matplotlib.pyplot as plt
from terratorch import FULL_MODEL_REGISTRY
from terratorch.models.backbones.terramind.model.terramind_register import v1_pretraining_mean, v1_pretraining_std

# Choose the compute device: prefer CUDA, then MPS (Apple Metal), and fall back to CPU.
# The MPS check only runs when CUDA is unavailable.
device = (
    'cuda' if torch.cuda.is_available()
    else 'mps' if torch.backends.mps.is_available()
    else 'cpu'
)

In [None]:
# Build the pretrained TerraMind v1 tokenizer for Sentinel-2 L2A
model = FULL_MODEL_REGISTRY.build('terramind_v1_tokenizer_s2l2a', pretrained=True)

# For other modalities:
# model = FULL_MODEL_REGISTRY.build('terramind_v1_tokenizer_s1rtc', pretrained=True)
# model = FULL_MODEL_REGISTRY.build('terramind_v1_tokenizer_dem', pretrained=True)
# model = FULL_MODEL_REGISTRY.build('terramind_v1_tokenizer_lulc', pretrained=True)
# model = FULL_MODEL_REGISTRY.build('terramind_v1_tokenizer_ndvi', pretrained=True)

# Move the model to the selected device and switch it to inference mode, so that
# train-only layers (e.g. dropout) are disabled for the reconstruction below.
# Assumes the built model is a torch.nn.Module (it is moved with `.to` and called directly later).
model = model.to(device).eval()

In [None]:
# Example tiles (replace S2L2A in the file paths to load other modalities)
examples = [
    '../examples/S2L2A/38D_378R_2_3.tif',
    '../examples/S2L2A/282D_485L_3_3.tif',
    '../examples/S2L2A/433D_629L_3_1.tif',
    '../examples/S2L2A/637U_59R_1_3.tif',
    '../examples/S2L2A/609U_541L_3_0.tif',
]

# Index of the example to use (0-4)
EXAMPLE_IDX = 1
assert 0 <= EXAMPLE_IDX < len(examples), f"EXAMPLE_IDX must be between 0 and {len(examples) - 1}"

raster = rxr.open_rasterio(examples[EXAMPLE_IDX])
# Convert to a float32 tensor of shape [1, C, H, W] (a batch of one; the original note
# states the example tiles are 224 x 224). Casting in numpy first avoids relying on
# torch supporting the raster's on-disk integer dtype; `astype` copies, so the tensor
# does not alias the xarray buffer.
data = torch.from_numpy(raster.values.astype(np.float32)).unsqueeze(0)

In [None]:
# Preview the S-2 L2A input as an RGB image: band indices 3, 2, 1 become the
# R, G, B channels, laid out as (H, W, 3) for imshow
rgb_bands = data[0, [3, 2, 1]].permute(1, 2, 0)
# Divide by 2000 and clip, so values >= 2000 saturate, then quantize to 8-bit
rgb_scaled = (rgb_bands / 2000).clip(0, 1) * 255
rgb = rgb_scaled.cpu().numpy().round().astype(np.uint8)

plt.imshow(rgb)
plt.axis('off')
plt.show()

In [None]:
# Standardize each band with the TerraMind v1 pretraining statistics for S-2 L2A
norm_key = 'untok_sen2l2a@224'
mean = torch.Tensor(v1_pretraining_mean[norm_key])
std = torch.Tensor(v1_pretraining_std[norm_key])
# Reshape the per-band stats to (1, C, 1, 1) so they broadcast over batch, height and width
input = (data - mean.view(1, -1, 1, 1)) / std.view(1, -1, 1, 1)

# Keys for other modalities can be listed with:
# v1_pretraining_mean.keys()

In [None]:
# Reconstruct the input: the tokenizer encodes the image and decodes it again
# using 10 diffusion steps; gradients are not needed for inference
input = input.to(device)
with torch.no_grad():
    reconstruction = model(input, timesteps=10)

    # To inspect the tokens, run the two stages separately instead:
    # _, _, tokens = model.encode(input)
    # reconstruction = model.decode_tokens(tokens, verbose=True, timesteps=10)

# Bring the result back to the CPU and undo the per-band normalization
reconstruction = reconstruction.cpu() * std.view(1, -1, 1, 1) + mean.view(1, -1, 1, 1)

In [None]:
def s2_to_rgb(image: torch.Tensor, scale: float = 2000.0) -> np.ndarray:
    """Convert a [C, H, W] S-2 L2A tensor into an [H, W, 3] uint8 RGB array for display.

    Band indices 3, 2, 1 become the red, green and blue channels. Values are divided
    by `scale` and clipped to [0, 1] (so anything >= `scale` saturates), then
    quantized to 0-255.
    """
    rgb = image[[3, 2, 1]].permute(1, 2, 0)
    rgb = (rgb / scale).clip(0, 1) * 255
    # detach() guards against tensors that still carry autograd history
    return rgb.detach().cpu().numpy().round().astype(np.uint8)


# Show the input and its reconstruction side by side with identical RGB scaling
fig, ax = plt.subplots(1, 2, figsize=(10, 5))
for axis, image, title in zip(ax, (data[0], reconstruction[0]), ('Input', 'Reconstruction')):
    axis.imshow(s2_to_rgb(image))
    axis.axis('off')
    axis.set_title(title)

plt.show()