In [24]:
import os
import sys
import json
from itertools import product
import pickle
from pprint import pprint

import pandas as pd
import kornia.augmentation as K
import numpy as np
import torch
import torch.nn as nn
from sklearn.neighbors import KNeighborsClassifier
from sklearn.metrics import (
    accuracy_score,
    average_precision_score,
    f1_score,
    precision_score,
    recall_score,
)
from sklearn.preprocessing import label_binarize


sys.path.append("..")
from src.models import get_model_by_name
from src.datasets.treesatai import TreeSatAI, TreeSatAIDataModule
from src.utils import extract_features

# Prefer the GPU, but fall back to CPU so the notebook still runs (slowly) on
# machines without CUDA instead of failing at the first `.to(device)`.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [7]:
# Fit: for every (encoder, band set, input size) combination, extract features
# for the train and test splits and cache them in one pickle per run.
model_names = [
    "resnet50_pretrained_moco",
    "resnet18_pretrained_moco",
    "resnet50_pretrained_imagenet",
    "resnet50_randominit",
    "imagestats",
    "mosaiks_512_3",
]
rgbs = [False, True]
sizes = [34, 224]

for model_name, rgb, size in product(model_names, rgbs, sizes):
    run = f"{model_name}{'_rgb' if rgb else ''}_{size}"
    print(f"Extracting features for {run}")

    # RGB runs use the RGB band list without padding; all other runs use the
    # full band ordering and request padding of the missing band.
    bands = TreeSatAI.rgb_bands if rgb else TreeSatAI.correct_band_order
    pad_missing_band = not rgb

    dm = TreeSatAIDataModule(
        root="../data/treesatai/",
        bands=bands,
        multilabel=False,
        batch_size=32,
        num_workers=16,
        pad_missing_band=pad_missing_band,
        seed=0,
    )
    dm.setup()

    model = get_model_by_name(model_name, rgb, device=device)

    # "imagestats" receives the inputs untouched (so `size` has no effect on
    # it); every other encoder gets the inputs resized to `size`.
    transforms = nn.Sequential(
        nn.Identity() if model_name == "imagestats" else K.Resize(size)
    ).to(device)

    x_train, y_train = extract_features(
        model, dm.train_dataloader(), device, transforms=transforms
    )
    x_test, y_test = extract_features(
        model, dm.test_dataloader(), device, transforms=transforms
    )

    # Cache the features under the run name so evaluation can be re-run
    # without repeating the extraction.
    data = {
        "x_train": x_train,
        "y_train": y_train,
        "x_test": x_test,
        "y_test": y_test,
    }
    with open(f"{run}.pkl", "wb") as f:
        pickle.dump(data, f)

Evaluating resnet50_pretrained_moco_34


100%|██████████| 1417/1417 [00:24<00:00, 58.22it/s]
100%|██████████| 158/158 [00:03<00:00, 45.91it/s]


Evaluating resnet50_pretrained_moco_224


100%|██████████| 1417/1417 [02:14<00:00, 10.55it/s]
100%|██████████| 158/158 [00:16<00:00,  9.69it/s]


Evaluating resnet50_pretrained_moco_rgb_34


100%|██████████| 1417/1417 [00:21<00:00, 66.75it/s]
100%|██████████| 158/158 [00:03<00:00, 45.25it/s]


Evaluating resnet50_pretrained_moco_rgb_224


100%|██████████| 1417/1417 [02:04<00:00, 11.39it/s]
100%|██████████| 158/158 [00:14<00:00, 10.79it/s]


Evaluating resnet18_pretrained_moco_34


100%|██████████| 1417/1417 [00:16<00:00, 84.73it/s] 
100%|██████████| 158/158 [00:02<00:00, 58.30it/s] 


Evaluating resnet18_pretrained_moco_224


100%|██████████| 1417/1417 [00:48<00:00, 29.14it/s]
100%|██████████| 158/158 [00:06<00:00, 24.93it/s]


Evaluating resnet18_pretrained_moco_rgb_34


100%|██████████| 1417/1417 [00:16<00:00, 86.67it/s] 
100%|██████████| 158/158 [00:02<00:00, 58.92it/s] 


Evaluating resnet18_pretrained_moco_rgb_224


100%|██████████| 1417/1417 [00:41<00:00, 33.82it/s]
100%|██████████| 158/158 [00:05<00:00, 27.57it/s]


Evaluating resnet50_pretrained_imagenet_34


100%|██████████| 1417/1417 [00:22<00:00, 64.31it/s]
100%|██████████| 158/158 [00:03<00:00, 43.01it/s]


Evaluating resnet50_pretrained_imagenet_224


100%|██████████| 1417/1417 [02:09<00:00, 10.98it/s]
100%|██████████| 158/158 [00:15<00:00, 10.35it/s]


Evaluating resnet50_pretrained_imagenet_rgb_34


100%|██████████| 1417/1417 [00:20<00:00, 67.49it/s]
100%|██████████| 158/158 [00:03<00:00, 44.96it/s]


Evaluating resnet50_pretrained_imagenet_rgb_224


100%|██████████| 1417/1417 [02:03<00:00, 11.45it/s]
100%|██████████| 158/158 [00:14<00:00, 10.70it/s]


Evaluating resnet50_randominit_34


100%|██████████| 1417/1417 [00:21<00:00, 65.09it/s]
100%|██████████| 158/158 [00:03<00:00, 46.42it/s]


Evaluating resnet50_randominit_224


100%|██████████| 1417/1417 [02:15<00:00, 10.43it/s]
100%|██████████| 158/158 [00:16<00:00,  9.73it/s]


Evaluating resnet50_randominit_rgb_34


100%|██████████| 1417/1417 [00:21<00:00, 67.06it/s]
100%|██████████| 158/158 [00:03<00:00, 43.02it/s]


Evaluating resnet50_randominit_rgb_224


100%|██████████| 1417/1417 [02:08<00:00, 11.00it/s]
100%|██████████| 158/158 [00:14<00:00, 10.74it/s]


Evaluating imagestats_34


100%|██████████| 1417/1417 [00:11<00:00, 126.77it/s]
100%|██████████| 158/158 [00:02<00:00, 75.91it/s] 


Evaluating imagestats_224


100%|██████████| 1417/1417 [00:10<00:00, 130.22it/s]
100%|██████████| 158/158 [00:02<00:00, 78.68it/s] 


Evaluating imagestats_rgb_34


100%|██████████| 1417/1417 [00:10<00:00, 134.70it/s]
100%|██████████| 158/158 [00:02<00:00, 78.14it/s] 


Evaluating imagestats_rgb_224


100%|██████████| 1417/1417 [00:10<00:00, 134.56it/s]
100%|██████████| 158/158 [00:01<00:00, 79.01it/s] 


Evaluating mosaiks_512_3_34


100%|██████████| 1417/1417 [00:14<00:00, 97.07it/s] 
100%|██████████| 158/158 [00:02<00:00, 61.43it/s] 


Evaluating mosaiks_512_3_224


100%|██████████| 1417/1417 [02:52<00:00,  8.21it/s]
100%|██████████| 158/158 [00:20<00:00,  7.79it/s]


Evaluating mosaiks_512_3_rgb_34


100%|██████████| 1417/1417 [00:13<00:00, 104.43it/s]
100%|██████████| 158/158 [00:02<00:00, 64.87it/s] 


Evaluating mosaiks_512_3_rgb_224


100%|██████████| 1417/1417 [02:32<00:00,  9.27it/s]
100%|██████████| 158/158 [00:17<00:00,  8.84it/s]


In [32]:
# Eval: fit a k-nearest-neighbours classifier on each cached feature set and
# score its predictions on the test split. Runs whose feature file is missing
# are skipped, so a partial Fit still yields partial results.
model_names = [
    "resnet50_pretrained_moco",
    "resnet18_pretrained_moco",
    "resnet50_pretrained_imagenet",
    "resnet50_randominit",
    "imagestats",
    "mosaiks_512_3",
]
rgbs = [False, True]
sizes = [34, 224]

results = {}
# Deliberately NOT named `K`: `K` is the import alias for `kornia.augmentation`
# that the Fit cell uses (`K.Resize`). Rebinding it to an int here made every
# later re-run of the Fit cell fail with an AttributeError.
N_NEIGHBORS = 5
for model_name, rgb, size in product(model_names, rgbs, sizes):
    run = f"{model_name}{'_rgb' if rgb else ''}_{size}"

    filename = f"{run}.pkl"
    if not os.path.exists(filename):
        # Say so explicitly instead of printing "Evaluating ..." and then
        # silently producing no metrics for this run.
        print(f"Skipping {run}: {filename} not found")
        continue
    print(f"Evaluating {run}")

    # NOTE: pickle.load can execute arbitrary code; only load feature files
    # that were written by the Fit cell of this notebook.
    with open(filename, "rb") as f:
        data = pickle.load(f)

    x_train = data["x_train"]
    y_train = data["y_train"]
    x_test = data["x_test"]
    y_test = data["y_test"]

    knn_model = KNeighborsClassifier(n_neighbors=N_NEIGHBORS, n_jobs=8)
    knn_model.fit(X=x_train, y=y_train)

    # One-hot targets for the average-precision metrics.
    # NOTE(review): the columns of `y_score` follow `knn_model.classes_`, so
    # they only line up with these one-hot columns if every class index
    # 0..len(TreeSatAI.classes)-1 occurs in y_train — confirm for new splits.
    y_test_onehot = label_binarize(y_test, classes=np.arange(len(TreeSatAI.classes)))
    y_pred = knn_model.predict(x_test)
    y_score = knn_model.predict_proba(x_test)

    metrics = {
        "mAP_weighted": average_precision_score(
            y_test_onehot, y_score, average="weighted"
        ),
        "mAP_micro": average_precision_score(y_test_onehot, y_score, average="micro"),
        "f1_weighted": f1_score(y_test, y_pred, average="weighted"),
        "f1_micro": f1_score(y_test, y_pred, average="micro"),
        "precision_micro": precision_score(y_test, y_pred, average="micro"),
        "precision_weighted": precision_score(y_test, y_pred, average="weighted"),
        "recall_micro": recall_score(y_test, y_pred, average="micro"),
        "recall_weighted": recall_score(y_test, y_pred, average="weighted"),
        "accuracy": accuracy_score(y_test, y_pred),
    }
    pprint(metrics)
    results[run] = metrics

Evaluating resnet50_pretrained_moco_34
{'accuracy': 0.2985725614591594,
 'f1_micro': 0.2985725614591594,
 'f1_weighted': 0.3081249888902874,
 'mAP_micro': 0.22489636870060092,
 'mAP_weighted': 0.24418412576051687,
 'precision_micro': 0.2985725614591594,
 'precision_weighted': 0.35023996267735086,
 'recall_micro': 0.2985725614591594,
 'recall_weighted': 0.2985725614591594}
Evaluating resnet50_pretrained_moco_224
{'accuracy': 0.3788659793814433,
 'f1_micro': 0.3788659793814433,
 'f1_weighted': 0.3908721677492111,
 'mAP_micro': 0.3089993997091709,
 'mAP_weighted': 0.3220018407854935,
 'precision_micro': 0.3788659793814433,
 'precision_weighted': 0.4281310472782435,
 'recall_micro': 0.3788659793814433,
 'recall_weighted': 0.3788659793814433}
Evaluating resnet50_pretrained_moco_rgb_34
{'accuracy': 0.20400475812846947,
 'f1_micro': 0.20400475812846947,
 'f1_weighted': 0.2084757708299295,
 'mAP_micro': 0.14423790640022532,
 'mAP_weighted': 0.16770630238605036,
 'precision_micro': 0.2040047581

  _warn_prf(average, modifier, msg_start, len(result))


{'accuracy': 0.16911181601903252,
 'f1_micro': 0.16911181601903252,
 'f1_weighted': 0.1700258364285811,
 'mAP_micro': 0.11898073786110604,
 'mAP_weighted': 0.1452629063007441,
 'precision_micro': 0.16911181601903252,
 'precision_weighted': 0.19998218991710706,
 'recall_micro': 0.16911181601903252,
 'recall_weighted': 0.16911181601903252}
Evaluating resnet50_pretrained_imagenet_rgb_224
{'accuracy': 0.20816812053925457,
 'f1_micro': 0.20816812053925457,
 'f1_weighted': 0.21351026662814473,
 'mAP_micro': 0.14389766554648364,
 'mAP_weighted': 0.1695238495888373,
 'precision_micro': 0.20816812053925457,
 'precision_weighted': 0.24366718098779167,
 'recall_micro': 0.20816812053925457,
 'recall_weighted': 0.20816812053925457}
Evaluating resnet50_randominit_34
{'accuracy': 0.24682791435368756,
 'f1_micro': 0.24682791435368756,
 'f1_weighted': 0.24378438443912817,
 'mAP_micro': 0.16758347009280145,
 'mAP_weighted': 0.189026658603769,
 'precision_micro': 0.24682791435368756,
 'precision_weighted

In [35]:
# Dump metrics: serialise the per-run metrics dict to indented JSON on disk.
with open("treesatai-results.json", "w") as f:
    f.write(json.dumps(results, indent=2))

In [39]:
# Clean metrics: reload the dumped JSON, turn it into one row per run, and
# derive three descriptor columns from the run name before sorting and
# exporting to CSV.
with open("treesatai-results.json") as f:
    results = json.load(f)

df = (
    pd.DataFrame.from_dict(results)
    .T  # runs become rows, metrics become columns
    .assign(
        # Band set: any run name containing "rgb" is RGB, everything else MSI.
        rgb=lambda frame: [
            "RGB" if "rgb" in run_name else "MSI" for run_name in frame.index
        ],
        # "no" for runs whose trailing size token is 34, "yes" for all others.
        size=lambda frame: [
            "yes" if run_name.rpartition("_")[2] != "34" else "no"
            for run_name in frame.index
        ],
        # Encoder name: drop the trailing size token, then the "_rgb" marker.
        encoder=lambda frame: [
            run_name.rsplit("_", 1)[0].replace("_rgb", "")
            for run_name in frame.index
        ],
    )
    .sort_values(["rgb", "encoder", "size"], ascending=True)
)
df.to_csv("treesatai-results.csv")
df

Unnamed: 0,mAP_weighted,mAP_micro,f1_weighted,f1_micro,precision_micro,precision_weighted,recall_micro,recall_weighted,accuracy,rgb,size,encoder
imagestats_34,0.286503,0.270979,0.359495,0.349524,0.349524,0.393434,0.349524,0.349524,0.349524,MSI,no,imagestats
imagestats_224,0.286503,0.270979,0.359495,0.349524,0.349524,0.393434,0.349524,0.349524,0.349524,MSI,yes,imagestats
mosaiks_512_3_34,0.286215,0.270447,0.355273,0.345163,0.345163,0.388608,0.345163,0.345163,0.345163,MSI,no,mosaiks_512_3
mosaiks_512_3_224,0.283875,0.268987,0.351697,0.342982,0.342982,0.382046,0.342982,0.342982,0.342982,MSI,yes,mosaiks_512_3
resnet18_pretrained_moco_34,0.261722,0.245449,0.337014,0.328311,0.328311,0.38003,0.328311,0.328311,0.328311,MSI,no,resnet18_pretrained_moco
resnet18_pretrained_moco_224,0.280094,0.264168,0.349149,0.337034,0.337034,0.395083,0.337034,0.337034,0.337034,MSI,yes,resnet18_pretrained_moco
resnet50_pretrained_imagenet_34,0.179225,0.153408,0.225946,0.219469,0.219469,0.255205,0.219469,0.219469,0.219469,MSI,no,resnet50_pretrained_imagenet
resnet50_pretrained_imagenet_224,0.226174,0.20494,0.287149,0.283703,0.283703,0.314323,0.283703,0.283703,0.283703,MSI,yes,resnet50_pretrained_imagenet
resnet50_pretrained_moco_34,0.244184,0.224896,0.308125,0.298573,0.298573,0.35024,0.298573,0.298573,0.298573,MSI,no,resnet50_pretrained_moco
resnet50_pretrained_moco_224,0.322002,0.308999,0.390872,0.378866,0.378866,0.428131,0.378866,0.378866,0.378866,MSI,yes,resnet50_pretrained_moco
