In [1]:
import sys

sys.path.append("..")

import os

import kornia.augmentation as K
import numpy as np
import pandas as pd
import timm
import torch
from landsatbench.datamodule import LandsatDataModule
from landsatbench.datasets.eurosat import dataset_statistics
from landsatbench.embed import extract_features
from landsatbench.eval import eval_linear_probe
from torchgeo.models import (
    ResNet18_Weights,
    ResNet50_Weights,
    ViTSmall16_Weights,
    resnet18,
    resnet50,
    vit_small_patch16_224,
)
from tqdm import tqdm

# Directory where one embedding archive (.npz) per model is written.
output_dir = "../embeddings"
os.makedirs(output_dir, exist_ok=True)

# Number of neighbours passed to the kNN evaluation further down.
k = 5

# Pick the best available accelerator: Apple-silicon MPS first (the original
# target), then CUDA, otherwise CPU. Hard-coding "mps" makes `.to(device)`
# fail on machines without an MPS backend.
if torch.backends.mps.is_available():
    device = torch.device("mps")
elif torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# EuroSAT data module shared by every model below.
root = "../data"
dm = LandsatDataModule(name="eurosat", root=root, batch_size=16, num_workers=8)
# Lightning-style hook; presumably downloads/extracts the dataset under `root`
# (TODO confirm in landsatbench.datamodule).
dm.prepare_data()

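### Transforms

Each pretraining source expects a different input normalization: ImageNet backbones get per-band standardization using the EuroSAT dataset statistics, while SSL4EO-L backbones get a fixed min/max rescaling. Both pipelines resize the images to 224 pixels.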

In [2]:
# ImageNet-pretrained models: per-band standardization with the EuroSAT
# dataset statistics. The two chained Normalize ops compute
#   ((x - 0) / max - mean / max) / (std / max) == (x - mean) / std
mins = dataset_statistics["min"]  # currently unused in this notebook
maxs = dataset_statistics["max"]
means = dataset_statistics["mean"] / maxs
stds = dataset_statistics["std"] / maxs
imagenet_transforms = K.ImageSequential(
    K.Normalize(mean=0.0, std=maxs),
    K.Normalize(mean=means, std=stds),
    K.Resize(224),
    keepdim=True,
)

# SSL4EO-L-pretrained models: fixed min/max rescaling, (x - lo) / (hi - lo).
# 7272.72... == 0.2 / 2.75e-5 and 18181.81... == 0.5 / 2.75e-5, i.e. the
# reflectance window [0, 0.3] under the Landsat Collection 2 L2 SR scaling
# (reflectance = DN * 2.75e-5 - 0.2) -- TODO confirm against the SSL4EO-L
# preprocessing. Values outside the window are NOT clipped here.
# Named ssl4eol_min / ssl4eol_max (not min / max) so the notebook does not
# shadow Python's built-in min() / max() in every later cell.
ssl4eol_min = 7272.72727272727272727272
ssl4eol_max = 18181.81818181818181818181
ssl4eol_transforms = K.ImageSequential(
    K.Normalize(mean=ssl4eol_min, std=1.0),
    K.Normalize(mean=0.0, std=ssl4eol_max - ssl4eol_min),
    K.Resize(224),
    keepdim=True,
)

## Models

In [3]:
def _model_spec(builder, preprocess, **builder_kwargs):
    """Bundle a model factory, its input transforms and the factory kwargs.

    The embedding loop later calls
    ``spec["model"](**spec["kwargs"], num_classes=0, in_chans=7)`` and passes
    ``spec["transforms"]`` to ``extract_features``.
    """
    return {"model": builder, "transforms": preprocess, "kwargs": builder_kwargs}


# Backbones to benchmark, keyed "<arch>-<pretraining>[-<ssl method>]".
models = {
    # ImageNet-pretrained backbones, built through timm.
    "resnet18-imagenet": _model_spec(
        timm.create_model, imagenet_transforms, model_name="resnet18", pretrained=True
    ),
    "resnet50-imagenet": _model_spec(
        timm.create_model, imagenet_transforms, model_name="resnet50", pretrained=True
    ),
    "vits-imagenet": _model_spec(
        timm.create_model,
        imagenet_transforms,
        model_name="vit_small_patch16_224",
        pretrained=True,
    ),
    # SSL4EO-L Landsat OLI SR weights (MoCo / SimCLR), built through torchgeo.
    "resnet18-ssl4eol-moco": _model_spec(
        resnet18, ssl4eol_transforms, weights=ResNet18_Weights.LANDSAT_OLI_SR_MOCO
    ),
    "resnet18-ssl4eol-simclr": _model_spec(
        resnet18, ssl4eol_transforms, weights=ResNet18_Weights.LANDSAT_OLI_SR_SIMCLR
    ),
    "resnet50-ssl4eol-moco": _model_spec(
        resnet50, ssl4eol_transforms, weights=ResNet50_Weights.LANDSAT_OLI_SR_MOCO
    ),
    "resnet50-ssl4eol-simclr": _model_spec(
        resnet50, ssl4eol_transforms, weights=ResNet50_Weights.LANDSAT_OLI_SR_SIMCLR
    ),
    "vits-ssl4eol-moco": _model_spec(
        vit_small_patch16_224,
        ssl4eol_transforms,
        weights=ViTSmall16_Weights.LANDSAT_OLI_SR_MOCO,
    ),
    "vits-ssl4eol-simclr": _model_spec(
        vit_small_patch16_224,
        ssl4eol_transforms,
        weights=ViTSmall16_Weights.LANDSAT_OLI_SR_SIMCLR,
    ),
}

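### Generate Embeddings

For every model, extract train and test features and cache them as `eurosat-<name>.npz` in `output_dir`.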

In [None]:
# Extract train/test embeddings for every model and cache them to
# <output_dir>/eurosat-<name>.npz for the metric cells below.
for name, v in tqdm(models.items(), total=len(models)):
    # tqdm.write prints above the progress bar instead of breaking it.
    tqdm.write(f"Evaluating {name}...")
    # num_classes=0 asks the factory for a headless feature extractor (timm
    # convention); in_chans=7 builds a 7-band input stem.
    # Freshly constructed nn.Modules are in training mode; .eval() makes
    # BatchNorm use running statistics and disables Dropout so embeddings are
    # deterministic, whether or not extract_features switches modes itself.
    model = v["model"](**v["kwargs"], num_classes=0, in_chans=7).to(device).eval()
    transforms = v["transforms"]

    dm.setup("fit")
    x_train, y_train = extract_features(model, dm.train_dataloader(), device, transforms)

    dm.setup("test")
    x_test, y_test = extract_features(model, dm.test_dataloader(), device, transforms)

    filename = os.path.join(output_dir, f"eurosat-{name}.npz")
    np.savez(
        filename,
        x_train=x_train,
        y_train=y_train.astype(np.int16),
        x_test=x_test,
        y_test=y_test.astype(np.int16),
    )

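### Compute Metrics

Evaluate the cached embeddings with two protocols: k-nearest neighbours (saved to `eurosat-knn-results.csv`) and a linear probe (saved to `eurosat-lp-results.csv`).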

In [None]:
# kNN evaluation of the cached embeddings; one row of metrics per model is
# written to eurosat-knn-results.csv.
# Explicit import: without it, `eval` resolves to Python's built-in eval(),
# which raises TypeError for these arguments.
# NOTE(review): assumes landsatbench.eval exposes `eval_knn` -- confirm.
from landsatbench.eval import eval_knn

all_metrics = dict()
for name in tqdm(models, total=len(models)):
    filename = os.path.join(output_dir, f"eurosat-{name}.npz")
    # np.load on an .npz returns an NpzFile that keeps the file open; the
    # context manager closes it once the arrays have been read into memory.
    with np.load(filename) as embeddings:
        x_train, y_train, x_test, y_test = (
            embeddings["x_train"],
            embeddings["y_train"],
            embeddings["x_test"],
            embeddings["y_test"],
        )
    metrics = eval_knn(x_train, y_train, x_test, y_test, k=k, scale=False)
    all_metrics[name] = metrics

pd.DataFrame(all_metrics).T.to_csv("eurosat-knn-results.csv")

 11%|█         | 1/9 [00:00<00:01,  5.68it/s]

+--------------+--------------+---------------------+---------------------+------------------+------------------+------------+
|   f1_average |   f1_overall |   precision_overall |   precision_average |   recall_overall |   recall_average |   accuracy |
|     0.753939 |     0.762037 |            0.762037 |            0.763046 |         0.762037 |         0.752965 |   0.762037 |
+--------------+--------------+---------------------+---------------------+------------------+------------------+------------+


 33%|███▎      | 3/9 [00:01<00:01,  3.00it/s]

+--------------+--------------+---------------------+---------------------+------------------+------------------+------------+
|   f1_average |   f1_overall |   precision_overall |   precision_average |   recall_overall |   recall_average |   accuracy |
|     0.761879 |     0.764074 |            0.764074 |            0.782292 |         0.764074 |          0.75653 |   0.764074 |
+--------------+--------------+---------------------+---------------------+------------------+------------------+------------+
+--------------+--------------+---------------------+---------------------+------------------+------------------+------------+
|   f1_average |   f1_overall |   precision_overall |   precision_average |   recall_overall |   recall_average |   accuracy |
|     0.793088 |     0.794259 |            0.794259 |            0.811936 |         0.794259 |          0.78818 |   0.794259 |
+--------------+--------------+---------------------+---------------------+------------------+-----------------

 44%|████▍     | 4/9 [00:01<00:01,  3.62it/s]

+--------------+--------------+---------------------+---------------------+------------------+------------------+------------+
|   f1_average |   f1_overall |   precision_overall |   precision_average |   recall_overall |   recall_average |   accuracy |
|     0.784216 |         0.79 |                0.79 |            0.791724 |             0.79 |         0.785625 |       0.79 |
+--------------+--------------+---------------------+---------------------+------------------+------------------+------------+


 56%|█████▌    | 5/9 [00:01<00:00,  4.01it/s]

+--------------+--------------+---------------------+---------------------+------------------+------------------+------------+
|   f1_average |   f1_overall |   precision_overall |   precision_average |   recall_overall |   recall_average |   accuracy |
|     0.695199 |     0.697963 |            0.697963 |            0.706667 |         0.697963 |         0.691464 |   0.697963 |
+--------------+--------------+---------------------+---------------------+------------------+------------------+------------+


 67%|██████▋   | 6/9 [00:02<00:01,  2.47it/s]

+--------------+--------------+---------------------+---------------------+------------------+------------------+------------+
|   f1_average |   f1_overall |   precision_overall |   precision_average |   recall_overall |   recall_average |   accuracy |
|     0.741937 |     0.751111 |            0.751111 |            0.746433 |         0.751111 |         0.742933 |   0.751111 |
+--------------+--------------+---------------------+---------------------+------------------+------------------+------------+


 89%|████████▉ | 8/9 [00:03<00:00,  2.47it/s]

+--------------+--------------+---------------------+---------------------+------------------+------------------+------------+
|   f1_average |   f1_overall |   precision_overall |   precision_average |   recall_overall |   recall_average |   accuracy |
|     0.717163 |     0.726111 |            0.726111 |             0.72301 |         0.726111 |         0.718711 |   0.726111 |
+--------------+--------------+---------------------+---------------------+------------------+------------------+------------+
+--------------+--------------+---------------------+---------------------+------------------+------------------+------------+
|   f1_average |   f1_overall |   precision_overall |   precision_average |   recall_overall |   recall_average |   accuracy |
|     0.812096 |     0.815926 |            0.815926 |            0.820409 |         0.815926 |         0.811567 |   0.815926 |
+--------------+--------------+---------------------+---------------------+------------------+-----------------

100%|██████████| 9/9 [00:03<00:00,  2.82it/s]

+--------------+--------------+---------------------+---------------------+------------------+------------------+------------+
|   f1_average |   f1_overall |   precision_overall |   precision_average |   recall_overall |   recall_average |   accuracy |
|     0.830212 |     0.834444 |            0.834444 |            0.839676 |         0.834444 |           0.8279 |   0.834444 |
+--------------+--------------+---------------------+---------------------+------------------+------------------+------------+





In [None]:
# Linear-probe evaluation of the cached embeddings; one row of metrics per
# model (presumably a mapping of metric name -> value) is written to
# eurosat-lp-results.csv.
all_metrics = dict()
for name in tqdm(models, total=len(models)):
    # tqdm.write prints above the progress bar instead of breaking it.
    tqdm.write(f"Evaluating {name}...")
    filename = os.path.join(output_dir, f"eurosat-{name}.npz")
    # np.load on an .npz returns an NpzFile that keeps the file open; the
    # context manager closes it once the arrays have been read into memory.
    with np.load(filename) as embeddings:
        x_train, y_train, x_test, y_test = (
            embeddings["x_train"],
            embeddings["y_train"],
            embeddings["x_test"],
            embeddings["y_test"],
        )
    metrics = eval_linear_probe(x_train, y_train, x_test, y_test)
    all_metrics[name] = metrics

pd.DataFrame(all_metrics).T.to_csv("eurosat-lp-results.csv")

  0%|          | 0/9 [00:00<?, ?it/s]

Evaluating resnet18-imagenet...


 11%|█         | 1/9 [00:45<06:00, 45.12s/it]

+--------------------+---------------------+---------------------+------------------+------------------+--------------+--------------+---------------+---------------+
|   overall_accuracy |   overall_precision |   average_precision |   overall_recall |   average_recall |   overall_f1 |   average_f1 |   overall_map |   average_map |
|           0.809815 |            0.809815 |            0.810086 |         0.809815 |         0.804619 |     0.809815 |     0.804464 |      0.891767 |      0.868487 |
+--------------------+---------------------+---------------------+------------------+------------------+--------------+--------------+---------------+---------------+
Evaluating resnet50-imagenet...


 22%|██▏       | 2/9 [15:44<1:03:52, 547.55s/it]

+--------------------+---------------------+---------------------+------------------+------------------+--------------+--------------+---------------+---------------+
|   overall_accuracy |   overall_precision |   average_precision |   overall_recall |   average_recall |   overall_f1 |   average_f1 |   overall_map |   average_map |
|               0.79 |                0.79 |            0.792642 |             0.79 |         0.784506 |         0.79 |     0.784881 |      0.882747 |      0.860991 |
+--------------------+---------------------+---------------------+------------------+------------------+--------------+--------------+---------------+---------------+
Evaluating vits-imagenet...


 33%|███▎      | 3/9 [18:37<37:38, 376.37s/it]  

+--------------------+---------------------+---------------------+------------------+------------------+--------------+--------------+---------------+---------------+
|   overall_accuracy |   overall_precision |   average_precision |   overall_recall |   average_recall |   overall_f1 |   average_f1 |   overall_map |   average_map |
|           0.863889 |            0.863889 |            0.869926 |         0.863889 |         0.860897 |     0.863889 |     0.863187 |      0.931559 |      0.920612 |
+--------------------+---------------------+---------------------+------------------+------------------+--------------+--------------+---------------+---------------+
Evaluating resnet18-ssl4eol-moco...


 44%|████▍     | 4/9 [22:35<26:49, 321.85s/it]

+--------------------+---------------------+---------------------+------------------+------------------+--------------+--------------+---------------+---------------+
|   overall_accuracy |   overall_precision |   average_precision |   overall_recall |   average_recall |   overall_f1 |   average_f1 |   overall_map |   average_map |
|           0.828333 |            0.828333 |             0.82928 |         0.828333 |         0.823269 |     0.828333 |       0.8237 |      0.912005 |      0.891593 |
+--------------------+---------------------+---------------------+------------------+------------------+--------------+--------------+---------------+---------------+
Evaluating resnet18-ssl4eol-simclr...


 56%|█████▌    | 5/9 [24:24<16:20, 245.13s/it]

+--------------------+---------------------+---------------------+------------------+------------------+--------------+--------------+---------------+---------------+
|   overall_accuracy |   overall_precision |   average_precision |   overall_recall |   average_recall |   overall_f1 |   average_f1 |   overall_map |   average_map |
|              0.685 |               0.685 |            0.671323 |            0.685 |         0.682865 |        0.685 |      0.67236 |      0.772236 |       0.72561 |
+--------------------+---------------------+---------------------+------------------+------------------+--------------+--------------+---------------+---------------+
Evaluating resnet50-ssl4eol-moco...


 67%|██████▋   | 6/9 [33:36<17:28, 349.49s/it]

+--------------------+---------------------+---------------------+------------------+------------------+--------------+--------------+---------------+---------------+
|   overall_accuracy |   overall_precision |   average_precision |   overall_recall |   average_recall |   overall_f1 |   average_f1 |   overall_map |   average_map |
|            0.82463 |             0.82463 |             0.82589 |          0.82463 |         0.820793 |      0.82463 |     0.820525 |      0.907131 |      0.885899 |
+--------------------+---------------------+---------------------+------------------+------------------+--------------+--------------+---------------+---------------+
Evaluating resnet50-ssl4eol-simclr...


 78%|███████▊  | 7/9 [44:15<14:48, 444.20s/it]

+--------------------+---------------------+---------------------+------------------+------------------+--------------+--------------+---------------+---------------+
|   overall_accuracy |   overall_precision |   average_precision |   overall_recall |   average_recall |   overall_f1 |   average_f1 |   overall_map |   average_map |
|            0.73463 |             0.73463 |            0.727464 |          0.73463 |         0.729573 |      0.73463 |     0.724979 |      0.830385 |      0.790142 |
+--------------------+---------------------+---------------------+------------------+------------------+--------------+--------------+---------------+---------------+
Evaluating vits-ssl4eol-moco...


 89%|████████▉ | 8/9 [47:19<06:01, 361.34s/it]

+--------------------+---------------------+---------------------+------------------+------------------+--------------+--------------+---------------+---------------+
|   overall_accuracy |   overall_precision |   average_precision |   overall_recall |   average_recall |   overall_f1 |   average_f1 |   overall_map |   average_map |
|           0.887963 |            0.887963 |            0.894178 |         0.887963 |         0.885518 |     0.887963 |     0.887755 |      0.953285 |      0.943742 |
+--------------------+---------------------+---------------------+------------------+------------------+--------------+--------------+---------------+---------------+
Evaluating vits-ssl4eol-simclr...


100%|██████████| 9/9 [48:25<00:00, 322.83s/it]

+--------------------+---------------------+---------------------+------------------+------------------+--------------+--------------+---------------+---------------+
|   overall_accuracy |   overall_precision |   average_precision |   overall_recall |   average_recall |   overall_f1 |   average_f1 |   overall_map |   average_map |
|           0.776481 |            0.776481 |            0.768097 |         0.776481 |         0.768796 |     0.776481 |     0.763816 |      0.858516 |      0.823774 |
+--------------------+---------------------+---------------------+------------------+------------------+--------------+--------------+---------------+---------------+



