In [None]:
import sys

sys.path.append("..")

import os

import kornia.augmentation as K
import numpy as np
import pandas as pd
import timm
import torch
from landsatbench.datamodule import LandsatDataModule
from landsatbench.datasets.lc100 import dataset_statistics
from landsatbench.embed import extract_features
from landsatbench.eval import eval_knn, eval_linear_probe
from torchgeo.models import (
    ResNet18_Weights,
    ResNet50_Weights,
    ViTSmall16_Weights,
    resnet18,
    resnet50,
    vit_small_patch16_224,
)
from tqdm import tqdm

# Directory holding one embedding archive (.npz) per model.
output_dir = "../embeddings"
os.makedirs(output_dir, exist_ok=True)

# Neighbour count for the k-NN evaluation.
k = 5

# Prefer Apple-silicon MPS (the original target), then CUDA, then CPU, so the
# notebook still runs on machines that have no MPS backend.
if torch.backends.mps.is_available():
    device = torch.device("mps")
elif torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# LC100 data module; with download=False the data presumably must already
# exist under `root` -- NOTE(review): confirm prepare_data() does not fetch it.
root = "../data"
dm = LandsatDataModule(name="lc100", root=root, batch_size=16, num_workers=4, download=False)
dm.prepare_data()

### Transforms

In [2]:
# ImageNet-initialised backbones: preprocess with the LC100 dataset statistics.
# NOTE(review): step two divides by `maxs` rather than `maxs - mins` (compare
# the SSL4EO-L pipeline below), and the final standardisation applies
# `means`/`stds` after that rescale -- confirm the statistics are expressed in
# the rescaled units and that this three-step pipeline is intended.
mins = dataset_statistics["min"]
maxs = dataset_statistics["max"]
means = dataset_statistics["mean"]
stds = dataset_statistics["std"]
imagenet_transforms = K.ImageSequential(
    K.Normalize(mean=mins, std=1.0),  # x - min
    K.Normalize(mean=0.0, std=maxs),  # x / max
    K.Normalize(mean=means, std=stds),  # (x - mean) / std
)

# SSL4EO-L backbones: linearly map [ssl4eol_min, ssl4eol_max] onto [0, 1].
# The bounds are presumably the fixed raw range used when the SSL4EO-L weights
# were pretrained -- NOTE(review): confirm against torchgeo.
# Named ssl4eol_min/ssl4eol_max (not min/max) so the builtins min()/max() are
# not shadowed for the rest of the kernel session.
ssl4eol_min = 7272.72727272727272727272
ssl4eol_max = 18181.81818181818181818181
ssl4eol_transforms = K.ImageSequential(
    K.Normalize(mean=ssl4eol_min, std=1.0),
    K.Normalize(mean=0.0, std=ssl4eol_max - ssl4eol_min),
)

### Models

In [3]:
# Registry of backbones to benchmark, keyed by "<arch>-<pretraining>".
# Every entry provides:
#   "model"      -- factory, invoked as model(**kwargs, num_classes=0, in_chans=7)
#   "transforms" -- kornia preprocessing handed to extract_features
#   "kwargs"     -- factory arguments that select the pretrained weights
# Insertion order is the order models are embedded, evaluated and written to CSV.
models = {
    # --- ImageNet-pretrained baselines, loaded through timm ---
    "resnet18-imagenet": {
        "model": timm.create_model,
        "transforms": imagenet_transforms,
        "kwargs": {"model_name": "resnet18", "pretrained": True},
    },
    "resnet50-imagenet": {
        "model": timm.create_model,
        "transforms": imagenet_transforms,
        "kwargs": {"model_name": "resnet50", "pretrained": True},
    },
    "vits-imagenet": {
        "model": timm.create_model,
        "transforms": imagenet_transforms,
        "kwargs": {"model_name": "vit_small_patch16_224", "pretrained": True},
    },
    # --- SSL4EO-L weights (Landsat OLI SR), loaded through torchgeo ---
    "resnet18-ssl4eol-moco": {
        "model": resnet18,
        "transforms": ssl4eol_transforms,
        "kwargs": {"weights": ResNet18_Weights.LANDSAT_OLI_SR_MOCO},
    },
    "resnet18-ssl4eol-simclr": {
        "model": resnet18,
        "transforms": ssl4eol_transforms,
        "kwargs": {"weights": ResNet18_Weights.LANDSAT_OLI_SR_SIMCLR},
    },
    "resnet50-ssl4eol-moco": {
        "model": resnet50,
        "transforms": ssl4eol_transforms,
        "kwargs": {"weights": ResNet50_Weights.LANDSAT_OLI_SR_MOCO},
    },
    "resnet50-ssl4eol-simclr": {
        "model": resnet50,
        "transforms": ssl4eol_transforms,
        "kwargs": {"weights": ResNet50_Weights.LANDSAT_OLI_SR_SIMCLR},
    },
    "vits-ssl4eol-moco": {
        "model": vit_small_patch16_224,
        "transforms": ssl4eol_transforms,
        "kwargs": {"weights": ViTSmall16_Weights.LANDSAT_OLI_SR_MOCO},
    },
    "vits-ssl4eol-simclr": {
        "model": vit_small_patch16_224,
        "transforms": ssl4eol_transforms,
        "kwargs": {"weights": ViTSmall16_Weights.LANDSAT_OLI_SR_SIMCLR},
    },
}

### Generate Embeddings

In [None]:
# For every registered backbone: build it as a 7-band feature extractor
# (num_classes=0), embed the train and test splits, and save the features plus
# int16 labels into output_dir/lc100-<name>.npz.
# NOTE(review): model.eval() is never called in this cell -- confirm that
# extract_features switches the model to eval mode (BatchNorm/Dropout).
for name, spec in tqdm(models.items(), total=len(models)):
    print(f"Embedding {name}...")
    build_model = spec["model"]
    model = build_model(**spec["kwargs"], num_classes=0, in_chans=7).to(device)
    transforms = spec["transforms"]

    dm.setup("fit")
    train_feats, train_labels = extract_features(model, dm.train_dataloader(), device, transforms)

    dm.setup("test")
    test_feats, test_labels = extract_features(model, dm.test_dataloader(), device, transforms)

    archive = {
        "x_train": train_feats,
        "y_train": train_labels.astype(np.int16),
        "x_test": test_feats,
        "y_test": test_labels.astype(np.int16),
    }
    filename = os.path.join(output_dir, f"lc100-{name}.npz")
    np.savez(filename, **archive)

  0%|          | 0/9 [00:00<?, ?it/s]

Evaluating resnet18-imagenet...




### k-NN Evaluation

In [None]:
# Score each model's cached embeddings with a multilabel k-NN classifier and
# collect one row of metrics per model into lc100-knn-results.csv.
all_metrics = dict()
for name in tqdm(models, total=len(models)):
    print(f"Evaluating {name}...")
    filename = os.path.join(output_dir, f"lc100-{name}.npz")
    # np.load on an .npz returns an NpzFile that holds the file open; the
    # context manager closes it once the arrays have been read into memory.
    with np.load(filename) as embeddings:
        x_train, y_train, x_test, y_test = (
            embeddings["x_train"],
            embeddings["y_train"],
            embeddings["x_test"],
            embeddings["y_test"],
        )
    # Use the configured neighbour count instead of a duplicated literal.
    metrics = eval_knn(x_train, y_train, x_test, y_test, k=k, scale=False, multilabel=True)
    all_metrics[name] = metrics

pd.DataFrame(all_metrics).T.to_csv("lc100-knn-results.csv")

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Evaluating resnet18-imagenet...
+--------------------+---------------------+---------------------+------------------+------------------+--------------+--------------+---------------+---------------+
|   overall_accuracy |   overall_precision |   average_precision |   overall_recall |   average_recall |   overall_f1 |   average_f1 |   overall_map |   average_map |
|         0.00753187 |            0.588987 |            0.358604 |         0.567311 |         0.347003 |     0.577946 |     0.350185 |      0.565084 |      0.364424 |
+--------------------+---------------------+---------------------+------------------+------------------+--------------+--------------+---------------+---------------+


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Evaluating resnet50-imagenet...


 22%|██▏       | 2/9 [00:00<00:02,  3.32it/s]

+--------------------+---------------------+---------------------+------------------+------------------+--------------+--------------+---------------+---------------+
|   overall_accuracy |   overall_precision |   average_precision |   overall_recall |   average_recall |   overall_f1 |   average_f1 |   overall_map |   average_map |
|         0.00637312 |            0.588854 |            0.360049 |         0.562564 |         0.344989 |     0.575409 |      0.34978 |      0.565839 |      0.367034 |
+--------------------+---------------------+---------------------+------------------+------------------+--------------+--------------+---------------+---------------+
Evaluating vits-imagenet...
+--------------------+---------------------+---------------------+------------------+------------------+--------------+--------------+---------------+---------------+
|   overall_accuracy |   overall_precision |   average_precision |   overall_recall |   average_recall |   overall_f1 |   average_f1 |   

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
 33%|███▎      | 3/9 [00:00<00:01,  3.83it/s]

Evaluating resnet18-ssl4eol-moco...
+--------------------+---------------------+---------------------+------------------+------------------+--------------+--------------+---------------+---------------+
|   overall_accuracy |   overall_precision |   average_precision |   overall_recall |   average_recall |   overall_f1 |   average_f1 |   overall_map |   average_map |
|         0.00637312 |            0.593656 |            0.359798 |         0.581894 |         0.355066 |     0.587716 |     0.354165 |       0.56896 |      0.366258 |
+--------------------+---------------------+---------------------+------------------+------------------+--------------+--------------+---------------+---------------+


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
 56%|█████▌    | 5/9 [00:01<00:00,  4.44it/s]

Evaluating resnet18-ssl4eol-simclr...
+--------------------+---------------------+---------------------+------------------+------------------+--------------+--------------+---------------+---------------+
|   overall_accuracy |   overall_precision |   average_precision |   overall_recall |   average_recall |   overall_f1 |   average_f1 |   overall_map |   average_map |
|         0.00579374 |            0.603768 |            0.373922 |         0.586366 |          0.36013 |      0.59494 |     0.362416 |      0.576867 |      0.368457 |
+--------------------+---------------------+---------------------+------------------+------------------+--------------+--------------+---------------+---------------+
Evaluating resnet50-ssl4eol-moco...


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
 67%|██████▋   | 6/9 [00:01<00:00,  3.59it/s]

+--------------------+---------------------+---------------------+------------------+------------------+--------------+--------------+---------------+---------------+
|   overall_accuracy |   overall_precision |   average_precision |   overall_recall |   average_recall |   overall_f1 |   average_f1 |   overall_map |   average_map |
|         0.00579374 |            0.586428 |            0.363905 |         0.579005 |          0.35896 |     0.582693 |     0.358915 |       0.56239 |      0.365622 |
+--------------------+---------------------+---------------------+------------------+------------------+--------------+--------------+---------------+---------------+
Evaluating resnet50-ssl4eol-simclr...


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
 89%|████████▉ | 8/9 [00:02<00:00,  3.76it/s]

+--------------------+---------------------+---------------------+------------------+------------------+--------------+--------------+---------------+---------------+
|   overall_accuracy |   overall_precision |   average_precision |   overall_recall |   average_recall |   overall_f1 |   average_f1 |   overall_map |   average_map |
|         0.00637312 |            0.594939 |            0.360234 |         0.575703 |         0.350599 |     0.585163 |     0.352064 |      0.567751 |      0.365065 |
+--------------------+---------------------+---------------------+------------------+------------------+--------------+--------------+---------------+---------------+
Evaluating vits-ssl4eol-moco...
+--------------------+---------------------+---------------------+------------------+------------------+--------------+--------------+---------------+---------------+
|   overall_accuracy |   overall_precision |   average_precision |   overall_recall |   average_recall |   overall_f1 |   average_f1 

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
100%|██████████| 9/9 [00:02<00:00,  3.87it/s]

+--------------------+---------------------+---------------------+------------------+------------------+--------------+--------------+---------------+---------------+
|   overall_accuracy |   overall_precision |   average_precision |   overall_recall |   average_recall |   overall_f1 |   average_f1 |   overall_map |   average_map |
|         0.00753187 |            0.584982 |            0.367167 |         0.579831 |         0.361184 |     0.582395 |     0.361073 |      0.557752 |      0.363224 |
+--------------------+---------------------+---------------------+------------------+------------------+--------------+--------------+---------------+---------------+





### Linear Probing

In [4]:
# Fit a multilabel linear probe on each model's cached embeddings and collect
# one row of metrics per model into lc100-lp-results.csv.
all_metrics = dict()
for name in tqdm(models, total=len(models)):
    print(f"Evaluating {name}...")
    filename = os.path.join(output_dir, f"lc100-{name}.npz")
    # np.load on an .npz returns an NpzFile that holds the file open; the
    # context manager closes it once the arrays have been read into memory.
    with np.load(filename) as embeddings:
        x_train, y_train, x_test, y_test = (
            embeddings["x_train"],
            embeddings["y_train"],
            embeddings["x_test"],
            embeddings["y_test"],
        )
    metrics = eval_linear_probe(x_train, y_train, x_test, y_test, multilabel=True)
    all_metrics[name] = metrics

pd.DataFrame(all_metrics).T.to_csv("lc100-lp-results.csv")

  0%|          | 0/9 [00:00<?, ?it/s]

Evaluating resnet18-imagenet...


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
 11%|█         | 1/9 [00:32<04:16, 32.09s/it]

+--------------------+---------------------+---------------------+------------------+------------------+--------------+--------------+---------------+---------------+
|   overall_accuracy |   overall_precision |   average_precision |   overall_recall |   average_recall |   overall_f1 |   average_f1 |   overall_map |   average_map |
|         0.00579374 |            0.594923 |            0.359801 |         0.562633 |         0.342108 |     0.578328 |     0.348068 |      0.620584 |      0.370534 |
+--------------------+---------------------+---------------------+------------------+------------------+--------------+--------------+---------------+---------------+
Evaluating resnet50-imagenet...


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
 22%|██▏       | 2/9 [02:27<09:28, 81.28s/it]

+--------------------+---------------------+---------------------+------------------+------------------+--------------+--------------+---------------+---------------+
|   overall_accuracy |   overall_precision |   average_precision |   overall_recall |   average_recall |   overall_f1 |   average_f1 |   overall_map |   average_map |
|         0.00637312 |            0.584639 |            0.364494 |         0.567105 |         0.352645 |     0.575739 |     0.356569 |      0.606489 |      0.365946 |
+--------------------+---------------------+---------------------+------------------+------------------+--------------+--------------+---------------+---------------+
Evaluating vits-imagenet...


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
 33%|███▎      | 3/9 [03:40<07:42, 77.15s/it]

+--------------------+---------------------+---------------------+------------------+------------------+--------------+--------------+---------------+---------------+
|   overall_accuracy |   overall_precision |   average_precision |   overall_recall |   average_recall |   overall_f1 |   average_f1 |   overall_map |   average_map |
|         0.00695249 |            0.572441 |            0.359735 |         0.561257 |         0.352972 |     0.566794 |     0.355588 |      0.605113 |      0.366522 |
+--------------------+---------------------+---------------------+------------------+------------------+--------------+--------------+---------------+---------------+
Evaluating resnet18-ssl4eol-moco...


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
 44%|████▍     | 4/9 [06:28<09:26, 113.22s/it]

+--------------------+---------------------+---------------------+------------------+------------------+--------------+--------------+---------------+---------------+
|   overall_accuracy |   overall_precision |   average_precision |   overall_recall |   average_recall |   overall_f1 |   average_f1 |   overall_map |   average_map |
|         0.00521437 |            0.598416 |            0.367635 |         0.576804 |         0.352345 |     0.587411 |     0.355091 |      0.616008 |      0.367968 |
+--------------------+---------------------+---------------------+------------------+------------------+--------------+--------------+---------------+---------------+
Evaluating resnet18-ssl4eol-simclr...


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
 56%|█████▌    | 5/9 [07:59<07:00, 105.16s/it]

+--------------------+---------------------+---------------------+------------------+------------------+--------------+--------------+---------------+---------------+
|   overall_accuracy |   overall_precision |   average_precision |   overall_recall |   average_recall |   overall_f1 |   average_f1 |   overall_map |   average_map |
|         0.00405562 |            0.623402 |             0.37087 |         0.583821 |         0.343188 |     0.602963 |     0.344339 |       0.63796 |      0.368235 |
+--------------------+---------------------+---------------------+------------------+------------------+--------------+--------------+---------------+---------------+
Evaluating resnet50-ssl4eol-moco...


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
 67%|██████▋   | 6/9 [23:40<19:28, 389.41s/it]

+--------------------+---------------------+---------------------+------------------+------------------+--------------+--------------+---------------+---------------+
|   overall_accuracy |   overall_precision |   average_precision |   overall_recall |   average_recall |   overall_f1 |   average_f1 |   overall_map |   average_map |
|         0.00637312 |            0.593004 |            0.364798 |           0.5691 |         0.349295 |     0.580806 |       0.3534 |      0.604643 |      0.366445 |
+--------------------+---------------------+---------------------+------------------+------------------+--------------+--------------+---------------+---------------+
Evaluating resnet50-ssl4eol-simclr...


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
 78%|███████▊  | 7/9 [33:30<15:09, 454.97s/it]

+--------------------+---------------------+---------------------+------------------+------------------+--------------+--------------+---------------+---------------+
|   overall_accuracy |   overall_precision |   average_precision |   overall_recall |   average_recall |   overall_f1 |   average_f1 |   overall_map |   average_map |
|         0.00463499 |            0.617504 |            0.369826 |         0.576598 |         0.340756 |      0.59635 |     0.343679 |      0.621088 |      0.363875 |
+--------------------+---------------------+---------------------+------------------+------------------+--------------+--------------+---------------+---------------+
Evaluating vits-ssl4eol-moco...


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
 89%|████████▉ | 8/9 [35:01<05:38, 338.98s/it]

+--------------------+---------------------+---------------------+------------------+------------------+--------------+--------------+---------------+---------------+
|   overall_accuracy |   overall_precision |   average_precision |   overall_recall |   average_recall |   overall_f1 |   average_f1 |   overall_map |   average_map |
|         0.00695249 |            0.576286 |            0.363949 |         0.572401 |         0.361582 |     0.574337 |     0.362018 |      0.598746 |       0.36679 |
+--------------------+---------------------+---------------------+------------------+------------------+--------------+--------------+---------------+---------------+
Evaluating vits-ssl4eol-simclr...


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
100%|██████████| 9/9 [35:28<00:00, 236.54s/it]

+--------------------+---------------------+---------------------+------------------+------------------+--------------+--------------+---------------+---------------+
|   overall_accuracy |   overall_precision |   average_precision |   overall_recall |   average_recall |   overall_f1 |   average_f1 |   overall_map |   average_map |
|         0.00579374 |            0.608711 |            0.363405 |         0.577767 |         0.344686 |     0.592836 |     0.345042 |      0.633713 |      0.364635 |
+--------------------+---------------------+---------------------+------------------+------------------+--------------+--------------+---------------+---------------+



