In [None]:
import sys

sys.path.append("..")

import os

import kornia.augmentation as K
import numpy as np
import torch
from landsatbench.datamodule import LandsatDataModule
from landsatbench.embed import extract_features
from landsatbench.eval import eval_knn, eval_linear_probe
from torchgeo.models import DOFALarge16_Weights, dofa_large_patch16_224

# Paths: datasets are read from `root`; extracted embeddings are written as
# .npz archives under `output_dir`, which is created if it does not exist yet.
root = "../data"
output_dir = "../embeddings"
os.makedirs(output_dir, exist_ok=True)

# Landsat surface-reflectance band names SR_B1 .. SR_B7.
bands = tuple(f"SR_B{i}" for i in range(1, 8))

# Run settings: backbone identifier, intended KNN neighbour count, and the
# compute device (first CUDA device when available, otherwise CPU).
model_name = "dofa_large_patch16_224"
k = 5
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

## EuroSAT-L

In [7]:
from landsatbench.datasets.eurosat import dataset_statistics

# DOFA preprocessing: standardize each band with the EuroSAT-L mean/std from
# landsatbench, then resize to 224 px. The size presumably matches the input
# resolution of the *_patch16_224 DOFA checkpoint loaded below (confirm the
# resize semantics for non-square inputs against the kornia docs).
transforms = K.ImageSequential(
    K.Normalize(mean=dataset_statistics["mean"], std=dataset_statistics["std"]),
    K.Resize(224),
    keepdim=True,
)

# EuroSAT-L datamodule; the data is expected to already exist under `root`
# (download=False).
dm = LandsatDataModule(name="eurosat", root=root, batch_size=32, num_workers=8, download=False)
dm.prepare_data()
dm.setup("fit")


class DOFAModel(torch.nn.Module):
    """Wrap a DOFA backbone so it can be called with image tensors alone.

    The wrapped model's ``forward_features`` is called with an explicit
    ``wavelengths`` argument. This wrapper stores those wavelengths once and
    supplies them on every call, so the module can be passed to
    ``extract_features`` like any single-input model.

    Args:
        model: Backbone exposing ``forward_features(x, wavelengths=...)``.
        wavelengths: Centre wavelengths (micrometres), one per input channel
            and in channel order. Defaults to the Landsat 8/9 OLI centres for
            SR_B1..SR_B7, which matches this notebook's ``bands``.

    Raises:
        ValueError: If ``wavelengths`` is an empty sequence.
    """

    # Landsat 8/9 OLI centre wavelengths (um) for SR_B1..SR_B7, in band order.
    DEFAULT_WAVELENGTHS = (0.443, 0.482, 0.561, 0.655, 0.865, 1.610, 2.200)

    def __init__(self, model, wavelengths=None):
        super().__init__()
        self.model = model
        # Store a fresh list. Later mutation of the caller's sequence cannot
        # change the model's inputs, and instances never share one list.
        self.wavelengths = list(self.DEFAULT_WAVELENGTHS if wavelengths is None else wavelengths)
        if not self.wavelengths:
            raise ValueError("wavelengths must contain at least one value")

    def forward(self, x):
        """Return ``self.model.forward_features(x)`` using the stored wavelengths."""
        return self.model.forward_features(x, wavelengths=self.wavelengths)


# Backbone under evaluation. `model_name` is interpolated into the embeddings
# filename, so keep it in sync with the constructor call. To evaluate the base
# variant instead, use dofa_base_patch16_224(weights=DOFABase16_Weights.DOFA_MAE)
# and import both names from torchgeo.models.
model_name = "dofa_large_patch16_224"
model = DOFAModel(dofa_large_patch16_224(weights=DOFALarge16_Weights.DOFA_MAE))

# Inference only: put the module in eval mode and move it to the target device.
model = model.eval().to(device)

Downloading: "https://hf.co/torchgeo/dofa/resolve/b8db318b64a90b9e085ec04ba8851233c5893666/dofa_large_patch16_224-0ff904d3.pth" to /home/ubuntu/.cache/torch/hub/checkpoints/dofa_large_patch16_224-0ff904d3.pth


100%|██████████| 1.26G/1.26G [00:25<00:00, 53.4MB/s]


In [8]:
# Embed the training split and then the test split with the same transforms.
# The arrays are written to an .npz archive so the evaluation cells can reload
# them without re-running the backbone.
dm.setup("fit")
train_loader = dm.train_dataloader()
x_train, y_train = extract_features(model, train_loader, device, transforms=transforms)

dm.setup("test")
test_loader = dm.test_dataloader()
x_test, y_test = extract_features(model, test_loader, device, transforms=transforms)

# Labels are down-cast to int16 before saving; features are stored as returned.
filename = os.path.join(output_dir, f"eurosat-{model_name}.npz")
arrays = {
    "x_train": x_train,
    "y_train": y_train.astype(np.int16),
    "x_test": x_test,
    "y_test": y_test.astype(np.int16),
}
np.savez(filename, **arrays)

  0%|          | 0/507 [00:00<?, ?it/s]

100%|██████████| 507/507 [02:31<00:00,  3.35it/s]
100%|██████████| 169/169 [00:51<00:00,  3.30it/s]


In [9]:
# KNN eval: reload the cached embeddings so this cell works on a fresh kernel
# without re-running feature extraction.
filename = os.path.join(output_dir, f"eurosat-{model_name}.npz")
# np.load on an .npz returns an NpzFile that keeps the archive open. Read every
# array inside the context manager so the file handle is released afterwards.
with np.load(filename) as embeddings:
    x_train = embeddings["x_train"]
    y_train = embeddings["y_train"]
    x_test = embeddings["x_test"]
    y_test = embeddings["y_test"]
# Use the configured neighbour count `k` instead of repeating the literal 5.
metrics = eval_knn(x_train, y_train, x_test, y_test, k=k, scale=True)

+--------------------+---------------------+---------------------+------------------+------------------+--------------+--------------+
|   overall_accuracy |   overall_precision |   average_precision |   overall_recall |   average_recall |   overall_f1 |   average_f1 |
|            0.84537 |             0.84537 |             0.85927 |          0.84537 |         0.841297 |      0.84537 |     0.845984 |
+--------------------+---------------------+---------------------+------------------+------------------+--------------+--------------+


In [10]:
# Linear probe: reload the cached embeddings so this cell works on a fresh
# kernel without re-running feature extraction.
filename = os.path.join(output_dir, f"eurosat-{model_name}.npz")
# np.load on an .npz returns an NpzFile that keeps the archive open. Read every
# array inside the context manager so the file handle is released afterwards.
with np.load(filename) as embeddings:
    x_train = embeddings["x_train"]
    y_train = embeddings["y_train"]
    x_test = embeddings["x_test"]
    y_test = embeddings["y_test"]
metrics = eval_linear_probe(x_train, y_train, x_test, y_test, scale=True)



+--------------------+---------------------+---------------------+------------------+------------------+--------------+--------------+
|   overall_accuracy |   overall_precision |   average_precision |   overall_recall |   average_recall |   overall_f1 |   average_f1 |
|           0.897593 |            0.897593 |            0.904595 |         0.897593 |         0.895392 |     0.897593 |     0.898124 |
+--------------------+---------------------+---------------------+------------------+------------------+--------------+--------------+
