In [1]:
import os
import sys

sys.path.append("..")

import json
from itertools import product
import pickle
from pprint import pprint

import pandas as pd
import kornia.augmentation as K
import numpy as np
import torch
import torch.nn as nn
from sklearn.neighbors import KNeighborsClassifier
from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score
from sklearn.preprocessing import label_binarize
from torchgeo.datasets import So2Sat

from src.models import get_model_by_name
from src.datasets import So2SatDataModule
from src.utils import extract_features

device = torch.device("cuda")

In [None]:
# Fit
model_names = [
    "resnet50_pretrained_moco",
    "imagestats",
    "resnet50_pretrained_imagenet",
    "resnet50_randominit",
    "mosaiks_512_3",
]
rgbs = [False, True]
sizes = [34, 224]
version = "3_random"

for model_name, rgb, size in product(model_names, rgbs, sizes):
    run = f"{model_name}{'_rgb' if rgb else ''}_{size}"
    print(f"Extracting features for {run}")
    if rgb:
        bands = So2Sat.rgb_bands
        pad_missing_bands = False
    else:
        bands = So2Sat.all_s2_band_names
        pad_missing_bands = True

    dm = So2SatDataModule(
        root="../data/so2sat/",
        bands=bands,
        version=version,
        batch_size=32,
        num_workers=16,
        pad_missing_bands=pad_missing_bands,
        seed=0,
    )
    dm.setup()

    model = get_model_by_name(model_name, rgb, device=device)

    if model_name == "imagestats":
        transforms = nn.Sequential(nn.Identity()).to(device)
    else:
        transforms = nn.Sequential(K.Resize(size)).to(device)

    x_train, y_train = extract_features(
        model, dm.train_dataloader(), device, transforms=transforms
    )
    x_test, y_test = extract_features(
        model, dm.test_dataloader(), device, transforms=transforms
    )
    data = dict(x_train=x_train, y_train=y_train, x_test=x_test, y_test=y_test)
    with open(f"{run}.pkl", "wb") as f:
        pickle.dump(data, f)

Extracting features for imagestats_34


100%|██████████| 10025/10025 [03:12<00:00, 52.13it/s]
100%|██████████| 2497/2497 [00:46<00:00, 53.97it/s]


Extracting features for imagestats_224


100%|██████████| 10025/10025 [02:55<00:00, 57.12it/s]
100%|██████████| 2497/2497 [00:43<00:00, 57.90it/s]


Extracting features for imagestats_rgb_34


100%|██████████| 10025/10025 [02:55<00:00, 57.28it/s]
100%|██████████| 2497/2497 [00:42<00:00, 58.18it/s]


Extracting features for imagestats_rgb_224


100%|██████████| 10025/10025 [03:05<00:00, 53.96it/s]
100%|██████████| 2497/2497 [00:46<00:00, 53.14it/s]


Extracting features for resnet50_pretrained_imagenet_34


100%|██████████| 10025/10025 [03:10<00:00, 52.71it/s]
100%|██████████| 2497/2497 [00:48<00:00, 51.91it/s]


Extracting features for resnet50_pretrained_imagenet_224


100%|██████████| 10025/10025 [16:14<00:00, 10.29it/s]
100%|██████████| 2497/2497 [04:07<00:00, 10.11it/s]


Extracting features for resnet50_pretrained_imagenet_rgb_34


100%|██████████| 10025/10025 [02:57<00:00, 56.60it/s]
100%|██████████| 2497/2497 [00:45<00:00, 54.28it/s]


Extracting features for resnet50_pretrained_imagenet_rgb_224


100%|██████████| 10025/10025 [15:26<00:00, 10.82it/s]
100%|██████████| 2497/2497 [03:53<00:00, 10.72it/s]


Extracting features for resnet50_randominit_34


100%|██████████| 10025/10025 [03:07<00:00, 53.47it/s]
100%|██████████| 2497/2497 [00:47<00:00, 52.53it/s]


Extracting features for resnet50_randominit_224


100%|██████████| 10025/10025 [16:20<00:00, 10.23it/s]
100%|██████████| 2497/2497 [03:55<00:00, 10.60it/s]


Extracting features for resnet50_randominit_rgb_34


100%|██████████| 10025/10025 [02:55<00:00, 56.97it/s]
100%|██████████| 2497/2497 [00:45<00:00, 54.46it/s]


Extracting features for resnet50_randominit_rgb_224


100%|██████████| 10025/10025 [15:03<00:00, 11.09it/s]
100%|██████████| 2497/2497 [03:42<00:00, 11.24it/s]


Extracting features for mosaiks_512_3_34


100%|██████████| 10025/10025 [02:38<00:00, 63.27it/s]
100%|██████████| 2497/2497 [00:45<00:00, 55.20it/s]


Extracting features for mosaiks_512_3_224


100%|██████████| 10025/10025 [21:12<00:00,  7.88it/s]
100%|██████████| 2497/2497 [05:21<00:00,  7.77it/s]


Extracting features for mosaiks_512_3_rgb_34


100%|██████████| 10025/10025 [03:01<00:00, 55.21it/s]
100%|██████████| 2497/2497 [00:44<00:00, 55.76it/s]


Extracting features for mosaiks_512_3_rgb_224


100%|██████████| 10025/10025 [18:27<00:00,  9.06it/s]
100%|██████████| 2497/2497 [04:37<00:00,  9.01it/s]


In [2]:
# Evaluate features
model_names = [
    # "resnet50_pretrained_moco",
    # "resnet50_pretrained_imagenet",
    "resnet50_randominit",
    "imagestats",
    "mosaiks_512_3",
]
rgbs = [False, True]
sizes = [34, 224]

results = {}
K = 5
for model_name, rgb, size in product(model_names, rgbs, sizes):
    run = f"{model_name}{'_rgb' if rgb else ''}_{size}"
    print(f"Evaluating {run}")

    filename = f"{run}.pkl"
    if not os.path.exists(filename):
        continue

    with open(filename, "rb") as f:
        data = pickle.load(f)

    x_train = data["x_train"]
    y_train = data["y_train"]
    x_test = data["x_test"]
    y_test = data["y_test"]

    knn_model = KNeighborsClassifier(n_neighbors=K, n_jobs=8)
    knn_model.fit(X=x_train, y=y_train)

    y_test_onehot = label_binarize(y_test, classes=np.arange(len(So2Sat.classes)))
    y_pred = knn_model.predict(x_test)
    y_score = knn_model.predict_proba(x_test)

    metrics = {
        "f1_weighted": f1_score(y_test, y_pred, average="weighted"),
        "f1_macro": f1_score(y_test, y_pred, average="macro"),
        "f1_micro": f1_score(y_test, y_pred, average="micro"),
        "precision_micro": precision_score(y_test, y_pred, average="micro"),
        "precision_macro": precision_score(y_test, y_pred, average="macro"),
        "precision_weighted": precision_score(y_test, y_pred, average="weighted"),
        "recall_micro": recall_score(y_test, y_pred, average="micro"),
        "recall_macro": recall_score(y_test, y_pred, average="macro"),
        "recall_weighted": recall_score(y_test, y_pred, average="weighted"),
        "accuracy": accuracy_score(y_test, y_pred),
    }
    pprint(metrics)
    results[run] = metrics

Evaluating resnet50_randominit_34


In [None]:
# Dump metrics
with open("so2sat-results.json", "w") as f:
    json.dump(results, f, indent=2)

In [None]:
# Clean metrics
with open("so2sat-results.json") as f:
    results = json.load(f)

df = pd.DataFrame.from_dict(results).transpose()
df["rgb"] = ["RGB" if "rgb" in model_name else "MSI" for model_name in df.index]
df["size"] = [int(model_name.split("_")[-1]) for model_name in df.index]
df["encoder"] = [
    model_name.rsplit("_", 1)[0].replace("_rgb", "") for model_name in df.index
]
df = df.sort_values(["rgb", "encoder", "size"], ascending=True)
df.to_csv("sat6-results.csv")
df