In [None]:
import sys

sys.path.append("..")

import os

import kornia.augmentation as K
import numpy as np
import pandas as pd
import timm
import torch
from landsatbench.datamodule import LandsatDataModule
from landsatbench.datasets.bigearthnet import dataset_statistics
from landsatbench.embed import extract_features
from landsatbench.eval import eval_knn, eval_linear_probe
from torchgeo.models import (
    ResNet18_Weights,
    ResNet50_Weights,
    ViTSmall16_Weights,
    resnet18,
    resnet50,
    vit_small_patch16_224,
)
from tqdm import tqdm

# Directory where the per-model embedding archives (bigearthnet-<model>.npz) are written.
output_dir = "../embeddings"
os.makedirs(output_dir, exist_ok=True)

# Number of neighbours used by the kNN evaluation.
k = 5
# Prefer the GPU, but fall back to CPU so the notebook still runs (slowly) on machines
# without CUDA instead of failing at the first `.to(device)`.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# BigEarthNet through the project's Landsat datamodule. download=False, so the data
# presumably must already exist under `root` -- TODO confirm.
root = "../data"
dm = LandsatDataModule(name="bigearthnet", root=root, batch_size=16, num_workers=4, download=False)
dm.prepare_data()

### Transforms

In [2]:
# Transforms for the ImageNet-pretrained (timm) models:
#   1. divide every band by its dataset maximum,
#   2. standardize with the dataset mean/std expressed in that same rescaled unit,
#   3. resize to 224.
maxs = dataset_statistics["max"]
means = dataset_statistics["mean"] / maxs
stds = dataset_statistics["std"] / maxs
imagenet_transforms = K.ImageSequential(
    K.Normalize(mean=0.0, std=maxs),
    K.Normalize(mean=means, std=stds),
    K.Resize(224),
    keepdim=True,
)

# Transforms for the SSL4EO-L models: min-max scale raw values from
# [ssl4eol_min, ssl4eol_max] to [0, 1], then resize to 224.
# NOTE(review): the constants appear to equal 0.2 / 0.0000275 and 0.5 / 0.0000275,
# i.e. reflectance 0.0 and 0.3 under the Landsat Collection 2 SR scale/offset
# (0.0000275, -0.2) -- confirm against the SSL4EO-L preprocessing.
# Deliberately not named `min` / `max` so the Python builtins are not shadowed for
# the rest of the kernel session.
ssl4eol_min = 7272.72727272727272727272
ssl4eol_max = 18181.81818181818181818181
ssl4eol_transforms = K.ImageSequential(
    K.Normalize(mean=ssl4eol_min, std=1.0),
    K.Normalize(mean=0.0, std=ssl4eol_max - ssl4eol_min),
    K.Resize(224),
    keepdim=True,
)

### Models

In [3]:
# Backbones to benchmark. Each entry maps a result key to the callable that builds the
# model, the input transform it needs, and the keyword arguments for that call.
_imagenet_specs = [
    # (result key, timm architecture name)
    ("resnet18-imagenet", "resnet18"),
    ("resnet50-imagenet", "resnet50"),
    ("vits-imagenet", "vit_small_patch16_224"),
]
_ssl4eol_specs = [
    # (result key, torchgeo constructor, torchgeo weights enum member)
    ("resnet18-ssl4eol-moco", resnet18, ResNet18_Weights.LANDSAT_OLI_SR_MOCO),
    ("resnet18-ssl4eol-simclr", resnet18, ResNet18_Weights.LANDSAT_OLI_SR_SIMCLR),
    ("resnet50-ssl4eol-moco", resnet50, ResNet50_Weights.LANDSAT_OLI_SR_MOCO),
    ("resnet50-ssl4eol-simclr", resnet50, ResNet50_Weights.LANDSAT_OLI_SR_SIMCLR),
    ("vits-ssl4eol-moco", vit_small_patch16_224, ViTSmall16_Weights.LANDSAT_OLI_SR_MOCO),
    ("vits-ssl4eol-simclr", vit_small_patch16_224, ViTSmall16_Weights.LANDSAT_OLI_SR_SIMCLR),
]

# ImageNet baselines are built through timm and share the dataset-statistics transform.
models = {
    key: dict(
        model=timm.create_model,
        transforms=imagenet_transforms,
        kwargs=dict(model_name=arch, pretrained=True),
    )
    for key, arch in _imagenet_specs
}
# SSL4EO-L backbones are built through torchgeo and share the SSL4EO-L min-max transform.
models.update(
    {
        key: dict(model=builder, transforms=ssl4eol_transforms, kwargs=dict(weights=weights))
        for key, builder, weights in _ssl4eol_specs
    }
)

### Generate Embeddings

In [None]:
# Embed the BigEarthNet train and test splits with every backbone and cache the
# features plus labels to <output_dir>/bigearthnet-<name>.npz.
for name, v in tqdm(models.items(), total=len(models)):
    print(f"Embedding {name}...")
    # num_classes=0 presumably drops the classifier head (timm convention) so the model
    # outputs features; in_chans=7 matches the 7-band Landsat input -- TODO confirm.
    model = v["model"](**v["kwargs"], num_classes=0, in_chans=7).to(device)
    # Newly constructed models start in training mode. Switch to eval so BatchNorm uses
    # running statistics and dropout is off, making embeddings independent of batch
    # composition even if extract_features does not do this itself (idempotent if it does).
    model.eval()
    transforms = v["transforms"]

    dm.setup("fit")
    x_train, y_train = extract_features(model, dm.train_dataloader(), device, transforms)

    dm.setup("test")
    x_test, y_test = extract_features(model, dm.test_dataloader(), device, transforms)

    filename = os.path.join(output_dir, f"bigearthnet-{name}.npz")
    # Labels are stored compactly as int16.
    np.savez(
        filename,
        x_train=x_train,
        y_train=y_train.astype(np.int16),
        x_test=x_test,
        y_test=y_test.astype(np.int16),
    )

    # Release this backbone before the next one is built, so two models are never
    # resident on the device at the same time. empty_cache is a no-op without CUDA.
    del model
    torch.cuda.empty_cache()

  0%|          | 0/9 [00:00<?, ?it/s]

Evaluating resnet18-imagenet...




### Compute Metrics

In [None]:
# Evaluate every model's cached embeddings with a multilabel kNN classifier and save
# one row of metrics per model to CSV.
all_metrics = dict()
for name in tqdm(models, total=len(models)):
    print(f"Evaluating {name}...")
    filename = os.path.join(output_dir, f"bigearthnet-{name}.npz")
    # np.load on an .npz returns an NpzFile that keeps the archive open. Indexing it
    # reads each array into memory, so the handle can be closed right after.
    with np.load(filename) as embeddings:
        x_train, y_train, x_test, y_test = (
            embeddings["x_train"],
            embeddings["y_train"],
            embeddings["x_test"],
            embeddings["y_test"],
        )
    metrics = eval_knn(
        x_train, y_train, x_test, y_test, k=k, scale=False, multilabel=True, faiss=True
    )
    all_metrics[name] = metrics

pd.DataFrame(all_metrics).T.to_csv("bigearthnet-knn-results.csv")

  0%|          | 0/9 [00:00<?, ?it/s]

Evaluating resnet18-imagenet...
+--------------------+---------------------+---------------------+------------------+------------------+--------------+--------------+---------------+---------------+
|   overall_accuracy |   overall_precision |   average_precision |   overall_recall |   average_recall |   overall_f1 |   average_f1 |   overall_map |   average_map |
|           0.216977 |               0.638 |            0.528234 |         0.545188 |         0.419432 |     0.587954 |     0.456209 |      0.575238 |      0.427049 |
+--------------------+---------------------+---------------------+------------------+------------------+--------------+--------------+---------------+---------------+


 11%|█         | 1/9 [00:34<04:35, 34.47s/it]

Evaluating resnet50-imagenet...
+--------------------+---------------------+---------------------+------------------+------------------+--------------+--------------+---------------+---------------+
|   overall_accuracy |   overall_precision |   average_precision |   overall_recall |   average_recall |   overall_f1 |   average_f1 |   overall_map |   average_map |
|            0.22377 |             0.64043 |            0.535399 |         0.552848 |         0.427657 |     0.593425 |      0.46347 |      0.582858 |      0.434346 |
+--------------------+---------------------+---------------------+------------------+------------------+--------------+--------------+---------------+---------------+


 22%|██▏       | 2/9 [01:34<05:44, 49.27s/it]

Evaluating vits-imagenet...
+--------------------+---------------------+---------------------+------------------+------------------+--------------+--------------+---------------+---------------+
|   overall_accuracy |   overall_precision |   average_precision |   overall_recall |   average_recall |   overall_f1 |   average_f1 |   overall_map |   average_map |
|           0.237618 |            0.667313 |            0.561163 |         0.585549 |         0.456159 |     0.623763 |     0.490793 |      0.619212 |      0.463364 |
+--------------------+---------------------+---------------------+------------------+------------------+--------------+--------------+---------------+---------------+


 33%|███▎      | 3/9 [01:55<03:39, 36.54s/it]

Evaluating resnet18-ssl4eol-moco...
+--------------------+---------------------+---------------------+------------------+------------------+--------------+--------------+---------------+---------------+
|   overall_accuracy |   overall_precision |   average_precision |   overall_recall |   average_recall |   overall_f1 |   average_f1 |   overall_map |   average_map |
|            0.23293 |            0.661783 |              0.5612 |         0.586431 |         0.459789 |     0.621832 |     0.496511 |      0.614236 |      0.461983 |
+--------------------+---------------------+---------------------+------------------+------------------+--------------+--------------+---------------+---------------+


 44%|████▍     | 4/9 [02:19<02:37, 31.44s/it]

Evaluating resnet18-ssl4eol-simclr...
+--------------------+---------------------+---------------------+------------------+------------------+--------------+--------------+---------------+---------------+
|   overall_accuracy |   overall_precision |   average_precision |   overall_recall |   average_recall |   overall_f1 |   average_f1 |   overall_map |   average_map |
|           0.192776 |            0.581964 |            0.453455 |         0.497602 |          0.37139 |     0.536487 |     0.400562 |      0.514084 |      0.368676 |
+--------------------+---------------------+---------------------+------------------+------------------+--------------+--------------+---------------+---------------+


 56%|█████▌    | 5/9 [02:42<01:53, 28.47s/it]

Evaluating resnet50-ssl4eol-moco...
+--------------------+---------------------+---------------------+------------------+------------------+--------------+--------------+---------------+---------------+
|   overall_accuracy |   overall_precision |   average_precision |   overall_recall |   average_recall |   overall_f1 |   average_f1 |   overall_map |   average_map |
|           0.222332 |            0.642362 |            0.534392 |         0.570517 |         0.444464 |     0.604311 |     0.476548 |        0.5934 |      0.443103 |
+--------------------+---------------------+---------------------+------------------+------------------+--------------+--------------+---------------+---------------+


 67%|██████▋   | 6/9 [03:40<01:55, 38.42s/it]

Evaluating resnet50-ssl4eol-simclr...
+--------------------+---------------------+---------------------+------------------+------------------+--------------+--------------+---------------+---------------+
|   overall_accuracy |   overall_precision |   average_precision |   overall_recall |   average_recall |   overall_f1 |   average_f1 |   overall_map |   average_map |
|           0.202247 |             0.59981 |            0.470933 |         0.517634 |         0.392173 |     0.555701 |     0.420365 |      0.538564 |      0.387963 |
+--------------------+---------------------+---------------------+------------------+------------------+--------------+--------------+---------------+---------------+


 78%|███████▊  | 7/9 [04:35<01:27, 43.90s/it]

Evaluating vits-ssl4eol-moco...
+--------------------+---------------------+---------------------+------------------+------------------+--------------+--------------+---------------+---------------+
|   overall_accuracy |   overall_precision |   average_precision |   overall_recall |   average_recall |   overall_f1 |   average_f1 |   overall_map |   average_map |
|           0.248812 |            0.682044 |            0.586692 |         0.615951 |          0.49983 |     0.647315 |     0.529707 |      0.644567 |      0.499811 |
+--------------------+---------------------+---------------------+------------------+------------------+--------------+--------------+---------------+---------------+


 89%|████████▉ | 8/9 [04:56<00:36, 36.59s/it]

Evaluating vits-ssl4eol-simclr...
+--------------------+---------------------+---------------------+------------------+------------------+--------------+--------------+---------------+---------------+
|   overall_accuracy |   overall_precision |   average_precision |   overall_recall |   average_recall |   overall_f1 |   average_f1 |   overall_map |   average_map |
|           0.262239 |            0.700201 |            0.624772 |         0.642703 |         0.537459 |     0.670222 |     0.570857 |      0.670367 |      0.537914 |
+--------------------+---------------------+---------------------+------------------+------------------+--------------+--------------+---------------+---------------+


100%|██████████| 9/9 [05:16<00:00, 35.20s/it]


In [None]:
# Evaluate every model's cached embeddings with a multilabel linear probe and save
# one row of metrics per model to CSV.
all_metrics = dict()
for name in tqdm(models, total=len(models)):
    print(f"Evaluating {name}...")
    filename = os.path.join(output_dir, f"bigearthnet-{name}.npz")
    # Close the NpzFile handle as soon as the arrays have been read into memory.
    with np.load(filename) as embeddings:
        x_train, y_train, x_test, y_test = (
            embeddings["x_train"],
            embeddings["y_train"],
            embeddings["x_test"],
            embeddings["y_test"],
        )
    metrics = eval_linear_probe(x_train, y_train, x_test, y_test, multilabel=True)
    all_metrics[name] = metrics

pd.DataFrame(all_metrics).T.to_csv("bigearthnet-lp-results.csv")

  0%|          | 0/9 [00:00<?, ?it/s]

Evaluating resnet18-imagenet...


  0%|          | 0/9 [00:05<?, ?it/s]
