In [None]:
# needs terratorch installed `pip install terratorch`
import sys

sys.path.append("..")

import os

import kornia.augmentation as K
import numpy as np
import torch
from landsatbench.datamodule import LandsatDataModule
from landsatbench.embed import extract_features
from landsatbench.eval import eval_knn, eval_linear_probe

# Paths: the dataset is read from `root`; cached embeddings go to `output_dir`
# (created up front so later cells can write into it).
root = "../data"
output_dir = "../embeddings"
os.makedirs(output_dir, exist_ok=True)

# Run settings.
k = 5  # neighbour count intended for the KNN evaluation
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model_name = "satlas_swinv2b"  # embedded in the cached .npz filenames

### EuroSAT-L

In [2]:
from torchgeo.models import Swin_V2_B_Weights, swin_v2_b
from torchvision.models import swin_v2_b


def create_model(weights, **kwargs):
    """Build a headless Swin-V2-B backbone initialised from pretrained weights.

    The ``swin_v2_b`` used here is torchvision's (its import shadows the torchgeo
    import of the same name). Its 4x4 patch-embedding conv is replaced so the
    model accepts ``weights.meta["in_chans"]`` input bands, the pretrained state
    dict is loaded, and the classification head is replaced with ``Identity`` so
    the model returns embeddings instead of class logits.

    Args:
        weights: weights enum entry (e.g. ``Swin_V2_B_Weights.LANDSAT_SI_SATLAS``)
            providing ``meta["in_chans"]`` and ``get_state_dict(progress=...)``.
        **kwargs: forwarded to ``swin_v2_b``.

    Returns:
        The backbone with ``head`` replaced by ``torch.nn.Identity``.

    Raises:
        RuntimeError: if the checkpoint leaves any non-head parameter missing or
            contains keys the model does not define. Without this check,
            ``strict=False`` would silently keep randomly initialised layers.
    """
    model = swin_v2_b(weights=None, **kwargs)
    num_channels = weights.meta["in_chans"]
    out_channels = model.features[0][0].out_channels
    # Swap the patch-embedding conv so it takes `num_channels` bands instead of RGB.
    model.features[0][0] = torch.nn.Conv2d(
        num_channels, out_channels, kernel_size=(4, 4), stride=(4, 4)
    )
    state_dict = weights.get_state_dict(progress=True)
    # strict=False so a checkpoint without classifier weights still loads (the head
    # is discarded below); every other mismatch is treated as an error.
    missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
    missing_backbone = [key for key in missing_keys if not key.startswith("head.")]
    if missing_backbone or unexpected_keys:
        raise RuntimeError(
            "Pretrained weights do not match the Swin-V2-B backbone: "
            f"missing={missing_backbone}, unexpected={list(unexpected_keys)}"
        )
    model.head = torch.nn.Identity()  # remove head
    return model


# Build the backbone and switch it to eval mode for feature extraction:
# torchvision's Swin blocks use stochastic depth, which is only disabled in eval
# mode. (extract_features may also call eval(); doing it here is harmless.)
model = create_model(weights=Swin_V2_B_Weights.LANDSAT_SI_SATLAS).eval()

In [21]:
# Sanity check: with the head replaced by Identity, the backbone should return a
# single 1024-d vector per image (no spatial feature map).
with torch.inference_mode():
    in_chans = model.features[0][0].in_channels  # bands the patch-embedding conv accepts
    x = torch.randn(1, in_chans, 256, 256, device=next(model.parameters()).device)
    embedding = model(x)
print(embedding.shape)  # expected: torch.Size([1, 1024])
assert embedding.shape == (1, 1024), f"unexpected embedding shape {tuple(embedding.shape)}"

torch.Size([1, 1024])


In [None]:
import torch.nn as nn
from kornia.contrib import Lambda
from torchgeo.transforms.transforms import _Clamp

# Preprocessing for the Satlas Landsat backbone: resize (256), normalise with
# (x - 4000) / 16320, clip to [0, 1], then zero-pad the band axis up to the
# number of input bands the backbone was built for.
SATLAS_IN_CHANS = 11  # the 7 Landsat bands are zero-padded up to this count


def _zero_pad_channels(x: torch.Tensor, num_channels: int = SATLAS_IN_CHANS) -> torch.Tensor:
    """Append all-zero bands along dim 1 so the batch has ``num_channels`` bands.

    Raises:
        ValueError: if ``x`` already has more than ``num_channels`` bands.
    """
    n_missing = num_channels - x.shape[1]
    if n_missing < 0:
        raise ValueError(f"expected at most {num_channels} channels, got {x.shape[1]}")
    # new_zeros keeps x's dtype and device, like the previous zeros_like padding.
    padding = x.new_zeros(x.shape[0], n_missing, *x.shape[2:])
    return torch.cat([x, padding], dim=1)


satlas_transforms = nn.Sequential(
    K.Resize(256),
    K.Normalize(mean=torch.tensor(4000), std=torch.tensor(16320)),
    _Clamp(p=1, min=0, max=1),
    # pad from (b, 7, h, w) to (b, 11, h, w) with zero-valued bands
    Lambda(_zero_pad_channels),
)

# EuroSAT-L datamodule, read from the local copy under `root` (download=False, so
# the data is presumably expected to be present already). Stage-specific setup()
# calls are made in the extraction cell, right before each dataloader is built,
# so they are not repeated here.
dm = LandsatDataModule(name="eurosat", root=root, batch_size=32, num_workers=8, download=False)
dm.prepare_data()

In [None]:
# Embed the train split, then the test split, with the backbone, and cache the
# arrays so the evaluation cells can reload them without re-extracting.
dm.setup("fit")
train_loader = dm.train_dataloader()
x_train, y_train = extract_features(model, train_loader, device, transforms=satlas_transforms)

dm.setup("test")
test_loader = dm.test_dataloader()
x_test, y_test = extract_features(model, test_loader, device, transforms=satlas_transforms)

# Labels are stored as int16; features are saved as returned by extract_features.
label_dtype = np.int16
filename = os.path.join(output_dir, f"eurosat-{model_name}.npz")
np.savez(
    filename,
    x_train=x_train,
    y_train=y_train.astype(label_dtype),
    x_test=x_test,
    y_test=y_test.astype(label_dtype),
)

  0%|          | 0/1013 [00:00<?, ?it/s]

100%|██████████| 1013/1013 [01:44<00:00,  9.71it/s]
100%|██████████| 338/338 [00:35<00:00,  9.57it/s]


In [None]:
# KNN eval on the cached embeddings.
filename = os.path.join(output_dir, f"eurosat-{model_name}.npz")
# np.load on an .npz returns an NpzFile that keeps the archive open; the context
# manager closes it once the arrays have been read into memory.
with np.load(filename) as embeddings:
    x_train, y_train, x_test, y_test = (
        embeddings["x_train"],
        embeddings["y_train"],
        embeddings["x_test"],
        embeddings["y_test"],
    )
# Use the configured neighbour count rather than repeating the literal.
metrics = eval_knn(x_train, y_train, x_test, y_test, k=k, scale=True)

In [None]:
# Linear Probe on the cached embeddings.
filename = os.path.join(output_dir, f"eurosat-{model_name}.npz")
# np.load on an .npz returns an NpzFile that keeps the archive open; the context
# manager closes it once the arrays have been read into memory.
with np.load(filename) as embeddings:
    x_train, y_train, x_test, y_test = (
        embeddings["x_train"],
        embeddings["y_train"],
        embeddings["x_test"],
        embeddings["y_test"],
    )
metrics = eval_linear_probe(x_train, y_train, x_test, y_test, scale=True)



+--------------------+---------------------+---------------------+------------------+------------------+--------------+--------------+
|   overall_accuracy |   overall_precision |   average_precision |   overall_recall |   average_recall |   overall_f1 |   average_f1 |
|           0.895926 |            0.895926 |            0.901938 |         0.895926 |         0.893021 |     0.895926 |     0.895491 |
+--------------------+---------------------+---------------------+------------------+------------------+--------------+--------------+
