In [1]:
# needs terratorch installed `pip install terratorch`
import sys

sys.path.append("..")

import os

import kornia.augmentation as K
import numpy as np
import torch
from landsatbench.datamodule import LandsatDataModule
from landsatbench.embed import extract_features
from landsatbench.eval import eval_knn, eval_linear_probe
from terratorch import BACKBONE_REGISTRY

# Dataset root, plus the directory where extracted embeddings are cached
# (created on first run; exist_ok makes re-running this cell harmless).
root = "../data"
output_dir = "../embeddings"
os.makedirs(output_dir, exist_ok=True)

# Neighbour count for the KNN evaluation.
k = 5

# Use the GPU when torch can see one, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [2]:
# Names of all backbones available from the "terratorch" registry source.
[*BACKBONE_REGISTRY["terratorch"]]

['prithvi_eo_tiny',
 'prithvi_eo_v1_100',
 'prithvi_eo_v2_300',
 'prithvi_eo_v2_600',
 'prithvi_eo_v2_300_tl',
 'prithvi_eo_v2_600_tl',
 'dofa_small_patch16_224',
 'dofa_base_patch16_224',
 'dofa_large_patch16_224',
 'dofa_huge_patch16_224',
 'satlas_swin_t_sentinel2_mi_ms',
 'satlas_swin_t_sentinel2_mi_rgb',
 'satlas_swin_t_sentinel2_si_ms',
 'satlas_swin_t_sentinel2_si_rgb',
 'satlas_swin_b_sentinel2_mi_ms',
 'satlas_swin_b_sentinel2_mi_rgb',
 'satlas_swin_b_sentinel2_si_ms',
 'satlas_swin_b_sentinel2_si_rgb',
 'satlas_swin_b_naip_mi_rgb',
 'satlas_swin_b_naip_si_rgb',
 'satlas_swin_b_landsat_mi_ms',
 'satlas_swin_b_landsat_mi_rgb',
 'satlas_swin_b_sentinel1_mi',
 'satlas_swin_b_sentinel1_si',
 'ssl4eol_resnet18_landsat_tm_toa_moco',
 'ssl4eol_resnet18_landsat_tm_toa_simclr',
 'ssl4eol_resnet18_landsat_etm_toa_moco',
 'ssl4eol_resnet18_landsat_etm_toa_simclr',
 'ssl4eol_resnet18_landsat_etm_sr_moco',
 'ssl4eol_resnet18_landsat_etm_sr_simclr',
 'ssl4eol_resnet18_landsat_oli_tirs_toa_m

### EuroSAT-L

In [None]:
from landsatbench.datasets.eurosat import dataset_statistics

# Prithvi preprocessing: standardise with the EuroSAT-L mean/std, then resize to
# 224x224. An explicit (height, width) tuple is used because an int size in
# kornia's Resize only fixes the *shorter* side, which would not guarantee the
# square 224x224 input the backbone is fed for non-square images.
prithvi_transforms = K.ImageSequential(
    K.Normalize(mean=dataset_statistics["mean"], std=dataset_statistics["std"]),
    K.Resize((224, 224)),
    keepdim=True,
)

dm = LandsatDataModule(name="eurosat", root=root, batch_size=32, num_workers=8, download=False)
# Each split is set up (dm.setup("fit") / dm.setup("test")) in the
# feature-extraction cell right before its dataloader is requested, so the
# datamodule is only prepared here to avoid setting up the "fit" stage twice.
dm.prepare_data()


class TerraTorchModel(torch.nn.Module):
    """Wrap a TerraTorch backbone so it yields one embedding vector per image.

    Args:
        model: backbone whose forward returns an indexable sequence of outputs;
            only the last element is used. NOTE(review): for Prithvi this is
            assumed to be a (batch, tokens, dim) tensor whose token 0 is the
            class token — confirm against the TerraTorch backbone.
        pool: if True, average every token after index 0 along dim 1;
            otherwise return the token at index 0.
    """

    def __init__(self, model, pool=False):
        super().__init__()
        self.model = model
        self.pool = pool

    def forward(self, x):
        # Drop channel 0 (SR_1), which Prithvi does not use, and keep the rest.
        bands = x[:, 1:, ...]
        tokens = self.model(bands)[-1]
        if not self.pool:
            return tokens[:, 0, :]
        # Mean over all tokens except the first one.
        return tokens[:, 1:, :].mean(dim=1)


model_name = "prithvi_eo_v1_100"

# Build the pretrained single-frame backbone, switch it to eval mode on `device`
# (Module.eval / Module.to return the module itself), then wrap it so each
# image maps to one mean-pooled embedding vector.
model = TerraTorchModel(
    BACKBONE_REGISTRY.build(model_name, pretrained=True, num_frames=1).eval().to(device),
    pool=True,
)

config.json:   0%|          | 0.00/895 [00:00<?, ?B/s]

Prithvi_EO_V1_100M.pt:   0%|          | 0.00/454M [00:00<?, ?B/s]

In [3]:
# Embed the training split ("fit" stage) and then the test split with the same
# wrapped backbone and preprocessing.
dm.setup("fit")
x_train, y_train = extract_features(
    model, dm.train_dataloader(), device, transforms=prithvi_transforms
)
dm.setup("test")
x_test, y_test = extract_features(
    model, dm.test_dataloader(), device, transforms=prithvi_transforms
)

# Cache all four arrays in one .npz (labels stored as int16) so the evaluation
# cells can reload them without re-running the backbone.
filename = os.path.join(output_dir, f"eurosat-{model_name}.npz")
np.savez(
    filename,
    **{
        "x_train": x_train,
        "y_train": y_train.astype(np.int16),
        "x_test": x_test,
        "y_test": y_test.astype(np.int16),
    },
)

  0%|          | 0/507 [00:00<?, ?it/s]

100%|██████████| 507/507 [01:01<00:00,  8.29it/s]
100%|██████████| 169/169 [00:20<00:00,  8.26it/s]


In [4]:
# KNN eval
filename = os.path.join(output_dir, f"eurosat-{model_name}.npz")
# np.load on an .npz returns an NpzFile that keeps the archive open until it is
# closed; indexing it reads each array fully into memory, so the arrays stay
# valid after the `with` block releases the file handle.
with np.load(filename) as embeddings:
    x_train, y_train, x_test, y_test = (
        embeddings["x_train"],
        embeddings["y_train"],
        embeddings["x_test"],
        embeddings["y_test"],
    )
# Use the neighbour count from the config cell instead of a duplicated literal.
metrics = eval_knn(x_train, y_train, x_test, y_test, k=k, scale=True)

+--------------------+---------------------+---------------------+------------------+------------------+--------------+--------------+
|   overall_accuracy |   overall_precision |   average_precision |   overall_recall |   average_recall |   overall_f1 |   average_f1 |
|           0.864444 |            0.864444 |            0.870047 |         0.864444 |         0.861133 |     0.864444 |     0.860225 |
+--------------------+---------------------+---------------------+------------------+------------------+--------------+--------------+


In [None]:
# Linear Probe
filename = os.path.join(output_dir, f"eurosat-{model_name}.npz")
# np.load on an .npz returns an NpzFile that keeps the archive open until it is
# closed; indexing it reads each array fully into memory, so the arrays stay
# valid after the `with` block releases the file handle.
with np.load(filename) as embeddings:
    x_train, y_train, x_test, y_test = (
        embeddings["x_train"],
        embeddings["y_train"],
        embeddings["x_test"],
        embeddings["y_test"],
    )
metrics = eval_linear_probe(x_train, y_train, x_test, y_test, scale=True)



+--------------------+---------------------+---------------------+------------------+------------------+--------------+--------------+
|   overall_accuracy |   overall_precision |   average_precision |   overall_recall |   average_recall |   overall_f1 |   average_f1 |
|           0.911111 |            0.911111 |              0.9183 |         0.911111 |         0.908047 |     0.911111 |      0.91125 |
+--------------------+---------------------+---------------------+------------------+------------------+--------------+--------------+
