In [1]:
import sys

sys.path.append("..")

import os

import kornia.augmentation as K
import numpy as np
import pandas as pd
import timm
import torch
from landsatbench.datamodule import LandsatDataModule
from landsatbench.datasets.eurosat import dataset_statistics
from landsatbench.embed import extract_features
from landsatbench.eval import eval_knn, eval_linear_probe
from torchgeo.models import (
    ResNet18_Weights,
    ResNet50_Weights,
    ViTSmall16_Weights,
    resnet18,
    resnet50,
    vit_small_patch16_224,
)
from tqdm import tqdm

# Directory where the per-model embedding archives (.npz) are written.
output_dir = "../embeddings"
os.makedirs(output_dir, exist_ok=True)

# Number of neighbours passed to the KNN evaluation.
k = 5
# Prefer the GPU, but fall back to CPU so the notebook still runs on machines
# without CUDA instead of failing at the first `.to(device)` call.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Root directory handed to the data module for the EuroSAT data.
root = "../data"
dm = LandsatDataModule(name="eurosat", root=root, batch_size=16, num_workers=8)
dm.prepare_data()

### Transforms

In [2]:
# Imagenet transforms
# First divide by the dataset maximum, then standardize with the dataset
# mean/std expressed in that same rescaled range, and finally resize to 224.
mins = dataset_statistics["min"]
maxs = dataset_statistics["max"]
means = dataset_statistics["mean"] / maxs
stds = dataset_statistics["std"] / maxs
imagenet_transforms = K.ImageSequential(
    K.Normalize(mean=0.0, std=maxs),
    K.Normalize(mean=means, std=stds),
    K.Resize(224),
    keepdim=True,
)

# SSL4EO-L transforms
# Subtract the lower bound, then divide by the (upper - lower) range.
# The bounds are deliberately NOT called `min`/`max`: binding those names at
# notebook level would shadow the Python builtins for every later cell.
ssl4eol_min = 7272.72727272727272727272
ssl4eol_max = 18181.81818181818181818181
ssl4eol_transforms = K.ImageSequential(
    K.Normalize(mean=ssl4eol_min, std=1.0),
    K.Normalize(mean=0.0, std=ssl4eol_max - ssl4eol_min),
    K.Resize(224),
    keepdim=True,
)

### Models

In [16]:
def create_model(model_name, weights, **kwargs):
    """Build a timm model and initialise it from a weights object.

    Args:
        model_name: Architecture name passed to ``timm.create_model``.
        weights: Object exposing ``get_state_dict(progress=...)`` that returns
            the checkpoint to load.
        **kwargs: Extra keyword arguments forwarded to ``timm.create_model``.

    Returns:
        The model after the checkpoint has been loaded non-strictly.
    """
    checkpoint = weights.get_state_dict(progress=True)

    # Checkpoints that store the final norm under "norm.*" get those entries
    # moved to "fc_norm.*" before loading.
    if "norm.weight" in checkpoint:
        for suffix in ("weight", "bias"):
            checkpoint[f"fc_norm.{suffix}"] = checkpoint.pop(f"norm.{suffix}")

    network = timm.create_model(model_name, pretrained=False, **kwargs)
    # strict=False: keys that do not match the model are ignored rather than raising.
    network.load_state_dict(checkpoint, strict=False)
    return network

In [None]:
# Registry of the backbones to benchmark. Each run name maps to:
#   "model":      factory callable that builds the network,
#   "transforms": preprocessing pipeline paired with those weights,
#   "kwargs":     keyword arguments forwarded to the factory.
models = {
    # ImageNet-pretrained baselines, built through timm.
    "resnet18-imagenet": {
        "model": timm.create_model,
        "transforms": imagenet_transforms,
        "kwargs": {"model_name": "resnet18", "pretrained": True},
    },
    "resnet50-imagenet": {
        "model": timm.create_model,
        "transforms": imagenet_transforms,
        "kwargs": {"model_name": "resnet50", "pretrained": True},
    },
    "vits-imagenet": {
        "model": timm.create_model,
        "transforms": imagenet_transforms,
        "kwargs": {"model_name": "vit_small_patch16_224", "pretrained": True},
    },
    # SSL4EO-L (Landsat OLI SR) pretrained weights, built through torchgeo.
    "resnet18-ssl4eol-moco": {
        "model": resnet18,
        "transforms": ssl4eol_transforms,
        "kwargs": {"weights": ResNet18_Weights.LANDSAT_OLI_SR_MOCO},
    },
    "resnet18-ssl4eol-simclr": {
        "model": resnet18,
        "transforms": ssl4eol_transforms,
        "kwargs": {"weights": ResNet18_Weights.LANDSAT_OLI_SR_SIMCLR},
    },
    "resnet50-ssl4eol-moco": {
        "model": resnet50,
        "transforms": ssl4eol_transforms,
        "kwargs": {"weights": ResNet50_Weights.LANDSAT_OLI_SR_MOCO},
    },
    "resnet50-ssl4eol-simclr": {
        "model": resnet50,
        "transforms": ssl4eol_transforms,
        "kwargs": {"weights": ResNet50_Weights.LANDSAT_OLI_SR_SIMCLR},
    },
    "vits-ssl4eol-moco": {
        "model": vit_small_patch16_224,
        "transforms": ssl4eol_transforms,
        "kwargs": {"weights": ViTSmall16_Weights.LANDSAT_OLI_SR_MOCO},
    },
    "vits-ssl4eol-simclr": {
        "model": vit_small_patch16_224,
        "transforms": ssl4eol_transforms,
        "kwargs": {"weights": ViTSmall16_Weights.LANDSAT_OLI_SR_SIMCLR},
    },
}

### Generate Embeddings

In [19]:
# For every registered backbone: build it, extract train/test features with the
# matching transforms, and store them in one .npz archive per model.
for name, v in tqdm(models.items(), total=len(models)):
    print(f"Generating embeddings for {name}...")
    # num_classes=0 and in_chans=7 are passed to every factory in addition to
    # its registry kwargs.
    model = v["model"](**v["kwargs"], num_classes=0, in_chans=7).to(device)
    # Freshly built modules start in training mode; switch to eval so layers
    # such as BatchNorm/Dropout behave deterministically during extraction.
    # NOTE(review): harmless no-op if extract_features already does this.
    model.eval()
    transforms = v["transforms"]

    dm.setup("fit")
    x_train, y_train = extract_features(model, dm.train_dataloader(), device, transforms)

    dm.setup("test")
    x_test, y_test = extract_features(model, dm.test_dataloader(), device, transforms)

    filename = os.path.join(output_dir, f"eurosat-{name}.npz")
    # Labels are stored as int16 in the archive.
    np.savez(
        filename,
        x_train=x_train,
        y_train=y_train.astype(np.int16),
        x_test=x_test,
        y_test=y_test.astype(np.int16),
    )

  0%|          | 0/2 [00:00<?, ?it/s]

Generating embeddings for vits-ssl4eol-moco...


100%|██████████| 1013/1013 [00:23<00:00, 42.82it/s]
100%|██████████| 338/338 [00:08<00:00, 39.60it/s]
 50%|█████     | 1/2 [00:32<00:32, 32.84s/it]

Generating embeddings for vits-ssl4eol-simclr...


100%|██████████| 1013/1013 [00:23<00:00, 42.87it/s]
100%|██████████| 338/338 [00:08<00:00, 40.68it/s]
100%|██████████| 2/2 [01:05<00:00, 32.71s/it]


### Compute Metrics

In [20]:
# Evaluate the saved embeddings of every model with a KNN classifier and
# write the collected metrics to a CSV file.
all_metrics = dict()
for name in tqdm(models, total=len(models)):
    print(name)
    filename = os.path.join(output_dir, f"eurosat-{name}.npz")
    # np.load on an .npz returns a lazily-read archive that holds an open file
    # handle; the context manager closes it once the arrays have been read.
    with np.load(filename) as embeddings:
        x_train, y_train = embeddings["x_train"], embeddings["y_train"]
        x_test, y_test = embeddings["x_test"], embeddings["y_test"]
    metrics = eval_knn(x_train, y_train, x_test, y_test, k=k, scale=False)
    all_metrics[name] = metrics

# One row per model, one column per metric.
pd.DataFrame(all_metrics).T.to_csv("eurosat-knn-results.csv")

  0%|          | 0/2 [00:00<?, ?it/s]

vits-ssl4eol-moco
+--------------------+---------------------+---------------------+------------------+------------------+--------------+--------------+
|   overall_accuracy |   overall_precision |   average_precision |   overall_recall |   average_recall |   overall_f1 |   average_f1 |
|           0.822037 |            0.822037 |            0.825975 |         0.822037 |         0.817863 |     0.822037 |     0.818444 |
+--------------------+---------------------+---------------------+------------------+------------------+--------------+--------------+


 50%|█████     | 1/2 [00:00<00:00,  1.11it/s]

vits-ssl4eol-simclr


100%|██████████| 2/2 [00:01<00:00,  1.19it/s]

+--------------------+---------------------+---------------------+------------------+------------------+--------------+--------------+
|   overall_accuracy |   overall_precision |   average_precision |   overall_recall |   average_recall |   overall_f1 |   average_f1 |
|           0.824074 |            0.824074 |            0.824365 |         0.824074 |         0.817476 |     0.824074 |     0.818038 |
+--------------------+---------------------+---------------------+------------------+------------------+--------------+--------------+





In [None]:
# Evaluate the saved embeddings of every model with a linear probe and write
# the collected metrics to a CSV file.
all_metrics = dict()
for name in tqdm(models, total=len(models)):
    print(f"Evaluating {name}...")
    filename = os.path.join(output_dir, f"eurosat-{name}.npz")
    # np.load on an .npz returns a lazily-read archive that holds an open file
    # handle; the context manager closes it once the arrays have been read.
    with np.load(filename) as embeddings:
        x_train, y_train = embeddings["x_train"], embeddings["y_train"]
        x_test, y_test = embeddings["x_test"], embeddings["y_test"]
    metrics = eval_linear_probe(x_train, y_train, x_test, y_test)
    all_metrics[name] = metrics

# One row per model, one column per metric.
pd.DataFrame(all_metrics).T.to_csv("eurosat-lp-results.csv")

  0%|          | 0/9 [00:00<?, ?it/s]

Evaluating resnet18-imagenet...


 11%|█         | 1/9 [00:45<06:00, 45.12s/it]

+--------------------+---------------------+---------------------+------------------+------------------+--------------+--------------+---------------+---------------+
|   overall_accuracy |   overall_precision |   average_precision |   overall_recall |   average_recall |   overall_f1 |   average_f1 |   overall_map |   average_map |
|           0.809815 |            0.809815 |            0.810086 |         0.809815 |         0.804619 |     0.809815 |     0.804464 |      0.891767 |      0.868487 |
+--------------------+---------------------+---------------------+------------------+------------------+--------------+--------------+---------------+---------------+
Evaluating resnet50-imagenet...


 22%|██▏       | 2/9 [15:44<1:03:52, 547.55s/it]

+--------------------+---------------------+---------------------+------------------+------------------+--------------+--------------+---------------+---------------+
|   overall_accuracy |   overall_precision |   average_precision |   overall_recall |   average_recall |   overall_f1 |   average_f1 |   overall_map |   average_map |
|               0.79 |                0.79 |            0.792642 |             0.79 |         0.784506 |         0.79 |     0.784881 |      0.882747 |      0.860991 |
+--------------------+---------------------+---------------------+------------------+------------------+--------------+--------------+---------------+---------------+
Evaluating vits-imagenet...


 33%|███▎      | 3/9 [18:37<37:38, 376.37s/it]  

+--------------------+---------------------+---------------------+------------------+------------------+--------------+--------------+---------------+---------------+
|   overall_accuracy |   overall_precision |   average_precision |   overall_recall |   average_recall |   overall_f1 |   average_f1 |   overall_map |   average_map |
|           0.863889 |            0.863889 |            0.869926 |         0.863889 |         0.860897 |     0.863889 |     0.863187 |      0.931559 |      0.920612 |
+--------------------+---------------------+---------------------+------------------+------------------+--------------+--------------+---------------+---------------+
Evaluating resnet18-ssl4eol-moco...


 44%|████▍     | 4/9 [22:35<26:49, 321.85s/it]

+--------------------+---------------------+---------------------+------------------+------------------+--------------+--------------+---------------+---------------+
|   overall_accuracy |   overall_precision |   average_precision |   overall_recall |   average_recall |   overall_f1 |   average_f1 |   overall_map |   average_map |
|           0.828333 |            0.828333 |             0.82928 |         0.828333 |         0.823269 |     0.828333 |       0.8237 |      0.912005 |      0.891593 |
+--------------------+---------------------+---------------------+------------------+------------------+--------------+--------------+---------------+---------------+
Evaluating resnet18-ssl4eol-simclr...


 56%|█████▌    | 5/9 [24:24<16:20, 245.13s/it]

+--------------------+---------------------+---------------------+------------------+------------------+--------------+--------------+---------------+---------------+
|   overall_accuracy |   overall_precision |   average_precision |   overall_recall |   average_recall |   overall_f1 |   average_f1 |   overall_map |   average_map |
|              0.685 |               0.685 |            0.671323 |            0.685 |         0.682865 |        0.685 |      0.67236 |      0.772236 |       0.72561 |
+--------------------+---------------------+---------------------+------------------+------------------+--------------+--------------+---------------+---------------+
Evaluating resnet50-ssl4eol-moco...


 67%|██████▋   | 6/9 [33:36<17:28, 349.49s/it]

+--------------------+---------------------+---------------------+------------------+------------------+--------------+--------------+---------------+---------------+
|   overall_accuracy |   overall_precision |   average_precision |   overall_recall |   average_recall |   overall_f1 |   average_f1 |   overall_map |   average_map |
|            0.82463 |             0.82463 |             0.82589 |          0.82463 |         0.820793 |      0.82463 |     0.820525 |      0.907131 |      0.885899 |
+--------------------+---------------------+---------------------+------------------+------------------+--------------+--------------+---------------+---------------+
Evaluating resnet50-ssl4eol-simclr...


 78%|███████▊  | 7/9 [44:15<14:48, 444.20s/it]

+--------------------+---------------------+---------------------+------------------+------------------+--------------+--------------+---------------+---------------+
|   overall_accuracy |   overall_precision |   average_precision |   overall_recall |   average_recall |   overall_f1 |   average_f1 |   overall_map |   average_map |
|            0.73463 |             0.73463 |            0.727464 |          0.73463 |         0.729573 |      0.73463 |     0.724979 |      0.830385 |      0.790142 |
+--------------------+---------------------+---------------------+------------------+------------------+--------------+--------------+---------------+---------------+
Evaluating vits-ssl4eol-moco...


 89%|████████▉ | 8/9 [47:19<06:01, 361.34s/it]

+--------------------+---------------------+---------------------+------------------+------------------+--------------+--------------+---------------+---------------+
|   overall_accuracy |   overall_precision |   average_precision |   overall_recall |   average_recall |   overall_f1 |   average_f1 |   overall_map |   average_map |
|           0.887963 |            0.887963 |            0.894178 |         0.887963 |         0.885518 |     0.887963 |     0.887755 |      0.953285 |      0.943742 |
+--------------------+---------------------+---------------------+------------------+------------------+--------------+--------------+---------------+---------------+
Evaluating vits-ssl4eol-simclr...


100%|██████████| 9/9 [48:25<00:00, 322.83s/it]

+--------------------+---------------------+---------------------+------------------+------------------+--------------+--------------+---------------+---------------+
|   overall_accuracy |   overall_precision |   average_precision |   overall_recall |   average_recall |   overall_f1 |   average_f1 |   overall_map |   average_map |
|           0.776481 |            0.776481 |            0.768097 |         0.776481 |         0.768796 |     0.776481 |     0.763816 |      0.858516 |      0.823774 |
+--------------------+---------------------+---------------------+------------------+------------------+--------------+--------------+---------------+---------------+



