In [1]:
import os
import sys

sys.path.append("..")

import json
from itertools import product
import pickle
from pprint import pprint

import pandas as pd
import kornia.augmentation as K
import numpy as np
import torch
import torch.nn as nn
from sklearn.neighbors import KNeighborsClassifier
from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score
from sklearn.preprocessing import label_binarize


from src.models import get_model_by_name
from src.datasets.sat6 import SAT6, SAT6DataModule
from src.utils import extract_features

# Use the GPU when one is present; fall back to CPU so the notebook still runs
# (much more slowly) on machines without CUDA instead of failing at `.to(device)`.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [2]:
# Fit and save features: embed the SAT-6 train and test splits for every
# (encoder, input size) pair and cache the result as `<model_name>_<size>.pkl`.
# Runs whose cache file already exists are skipped, so the cell is resumable.
model_names = [
    "resnet50_pretrained_moco",
    "resnet18_pretrained_moco",
    "resnet50_pretrained_seco",
    "resnet50_pretrained_imagenet",
    "resnet50_randominit",
    "imagestats",
    "mosaiks_512_3",
]
sizes = [34, 224]

for model_name, size in product(model_names, sizes):
    run = f"{model_name}_{size}"
    filename = f"{run}.pkl"

    # Check the cache before announcing work, so the log only claims an
    # extraction when one actually happens.
    if os.path.exists(filename):
        print(f"Skipping {run}: {filename} already exists")
        continue
    print(f"Extracting features for {run}")

    dm = SAT6DataModule(root="../data/sat6/", batch_size=64, num_workers=16, seed=0)
    dm.setup()

    model = get_model_by_name(model_name, rgb=True, device=device)

    # imagestats receives the images unchanged (no resize), so its features do
    # not depend on `size`; every other encoder sees inputs resized to `size`.
    if model_name == "imagestats":
        transforms = nn.Sequential(nn.Identity()).to(device)
    else:
        transforms = nn.Sequential(K.Resize(size)).to(device)

    x_train, y_train = extract_features(
        model, dm.train_dataloader(), device, transforms=transforms
    )
    x_test, y_test = extract_features(
        model, dm.test_dataloader(), device, transforms=transforms
    )
    data = dict(x_train=x_train, y_train=y_train, x_test=x_test, y_test=y_test)

    # Write to a temporary file and rename it into place. Dumping straight to
    # `filename` would leave a truncated .pkl if interrupted, and the existence
    # check above would then treat that broken file as a finished run.
    tmp_filename = f"{filename}.tmp"
    with open(tmp_filename, "wb") as f:
        pickle.dump(data, f)
    os.replace(tmp_filename, filename)

Extracting features for resnet50_pretrained_moco_34
Extracting features for resnet50_pretrained_moco_224
Extracting features for resnet18_pretrained_moco_34
Extracting features for resnet18_pretrained_moco_224
Extracting features for resnet50_pretrained_seco_34


100%|██████████| 5063/5063 [01:41<00:00, 50.05it/s]
100%|██████████| 1266/1266 [00:26<00:00, 48.18it/s]


Extracting features for resnet50_pretrained_seco_224


100%|██████████| 5063/5063 [14:40<00:00,  5.75it/s]
100%|██████████| 1266/1266 [03:44<00:00,  5.63it/s]


Extracting features for resnet50_pretrained_imagenet_34


100%|██████████| 5063/5063 [01:41<00:00, 49.73it/s]
100%|██████████| 1266/1266 [00:26<00:00, 47.45it/s]


Extracting features for resnet50_pretrained_imagenet_224


100%|██████████| 5063/5063 [14:59<00:00,  5.63it/s]
100%|██████████| 1266/1266 [03:36<00:00,  5.84it/s]


Extracting features for resnet50_randominit_34


100%|██████████| 5063/5063 [01:38<00:00, 51.34it/s]
100%|██████████| 1266/1266 [00:27<00:00, 46.22it/s]


Extracting features for resnet50_randominit_224


100%|██████████| 5063/5063 [14:37<00:00,  5.77it/s]
100%|██████████| 1266/1266 [03:43<00:00,  5.66it/s]


Extracting features for imagestats_34


100%|██████████| 5063/5063 [00:09<00:00, 552.36it/s]
100%|██████████| 1266/1266 [00:03<00:00, 369.86it/s]


Extracting features for imagestats_224


100%|██████████| 5063/5063 [00:08<00:00, 599.58it/s]
100%|██████████| 1266/1266 [00:03<00:00, 359.49it/s]


Extracting features for mosaiks_512_3_34


100%|██████████| 5063/5063 [00:57<00:00, 88.27it/s]
100%|██████████| 1266/1266 [00:15<00:00, 82.30it/s]


Extracting features for mosaiks_512_3_224


100%|██████████| 5063/5063 [17:52<00:00,  4.72it/s]
100%|██████████| 1266/1266 [04:29<00:00,  4.70it/s]


In [3]:
# Evaluate features: fit a k-nearest-neighbours classifier on each cached set of
# train features and score it on the matching test features. Metrics accumulate
# in a JSON file, so re-running the cell only evaluates runs that are missing.
model_names = [
    "resnet50_pretrained_moco",
    "resnet18_pretrained_moco",
    "resnet50_pretrained_seco",
    "resnet50_pretrained_imagenet",
    "resnet50_randominit",
    "imagestats",
    "mosaiks_512_3",
]
sizes = [34, 224]


output = "sat6-results.json"
if not os.path.exists(output):
    with open(output, "w") as f:
        json.dump({}, f, indent=2)

# Number of neighbours for the k-NN probe. Deliberately not named `K`: that name
# is the `kornia.augmentation` alias from the imports cell, and rebinding it to
# an int breaks `K.Resize` if the extraction cell is re-run afterwards.
N_NEIGHBORS = 5
for model_name, size in product(model_names, sizes):
    # Reload every iteration so entries written by earlier iterations are kept.
    with open(output) as f:
        results = json.load(f)

    # Feature caches are named `<model>_<size>.pkl`. The extraction cell builds
    # every encoder with rgb=True, so the results key carries an "_rgb" marker:
    # the clean-up cell labels a run "RGB" only if "rgb" appears in its key.
    # NOTE(review): entries previously written under the un-suffixed key are
    # not migrated; remove them from the JSON if present.
    run = f"{model_name}_{size}"
    result_key = f"{model_name}_rgb_{size}"
    print(f"Evaluating {result_key}")

    if result_key in results:
        continue

    filename = f"{run}.pkl"
    if not os.path.exists(filename):
        print(f"  skipping: {filename} not found")
        continue

    # Only load pickles produced locally by the extraction cell; unpickling
    # untrusted files can execute arbitrary code.
    with open(filename, "rb") as f:
        data = pickle.load(f)

    x_train = data["x_train"]
    y_train = data["y_train"]
    x_test = data["x_test"]
    y_test = data["y_test"]

    knn_model = KNeighborsClassifier(n_neighbors=N_NEIGHBORS, n_jobs=8)
    knn_model.fit(X=x_train, y=y_train)

    y_pred = knn_model.predict(x_test)

    metrics = {
        "f1_weighted": f1_score(y_test, y_pred, average="weighted"),
        "f1_macro": f1_score(y_test, y_pred, average="macro"),
        "f1_micro": f1_score(y_test, y_pred, average="micro"),
        "precision_micro": precision_score(y_test, y_pred, average="micro"),
        "precision_macro": precision_score(y_test, y_pred, average="macro"),
        "precision_weighted": precision_score(y_test, y_pred, average="weighted"),
        "recall_micro": recall_score(y_test, y_pred, average="micro"),
        "recall_macro": recall_score(y_test, y_pred, average="macro"),
        "recall_weighted": recall_score(y_test, y_pred, average="weighted"),
        "accuracy": accuracy_score(y_test, y_pred),
    }
    pprint(metrics)
    results[result_key] = metrics

    with open(output, "w") as f:
        json.dump(results, f, indent=2)

Evaluating resnet50_pretrained_moco_34
{'accuracy': 0.9815185185185186,
 'f1_macro': 0.9725980414997126,
 'f1_micro': 0.9815185185185186,
 'f1_weighted': 0.9814887932465317,
 'precision_macro': 0.9756063151805705,
 'precision_micro': 0.9815185185185186,
 'precision_weighted': 0.9815127251398028,
 'recall_macro': 0.9698622395579832,
 'recall_micro': 0.9815185185185186,
 'recall_weighted': 0.9815185185185186}
Evaluating resnet50_pretrained_moco_224
{'accuracy': 0.9986172839506173,
 'f1_macro': 0.9975171316701115,
 'f1_micro': 0.9986172839506173,
 'f1_weighted': 0.9986174517459343,
 'precision_macro': 0.9974722916113595,
 'precision_micro': 0.9986172839506173,
 'precision_weighted': 0.9986184650591721,
 'recall_macro': 0.9975629979552622,
 'recall_micro': 0.9986172839506173,
 'recall_weighted': 0.9986172839506173}
Evaluating resnet18_pretrained_moco_34
{'accuracy': 0.9704567901234568,
 'f1_macro': 0.9602830970200888,
 'f1_micro': 0.9704567901234568,
 'f1_weighted': 0.9703735248397836,
 'p

In [9]:
# Clean metrics: turn the per-run metric dicts into a table with one row per run,
# add descriptive columns parsed from the run key, sort, and export to CSV.
with open("sat6-results.json") as results_file:
    results = json.load(results_file)

# Outer JSON keys (run names) become the row index; metric names become columns.
df = pd.DataFrame.from_dict(results).T

# Run keys look like "<encoder>[_rgb]_<size>": the size is the last "_" field,
# and keys without "rgb" are tagged "MSI" (presumably all spectral bands).
run_keys = list(df.index)
df = df.assign(
    rgb=["RGB" if "rgb" in key else "MSI" for key in run_keys],
    size=[int(key.rsplit("_", 1)[-1]) for key in run_keys],
    encoder=[key.rsplit("_", 1)[0].replace("_rgb", "") for key in run_keys],
)

df = df.sort_values(by=["rgb", "encoder", "size"])
df.to_csv("sat6-results.csv")
df

Unnamed: 0,f1_weighted,f1_micro,precision_micro,precision_weighted,recall_micro,recall_weighted,accuracy,rgb,size,encoder
imagestats_rgb_34,0.996775,0.996778,0.996778,0.996778,0.996778,0.996778,0.996778,RGB,34,imagestats
imagestats_rgb_224,0.996775,0.996778,0.996778,0.996778,0.996778,0.996778,0.996778,RGB,224,imagestats
mosaiks_512_3_rgb_34,0.98621,0.986198,0.986198,0.986253,0.986198,0.986198,0.986198,RGB,34,mosaiks_512_3
mosaiks_512_3_rgb_224,0.984657,0.984642,0.984642,0.984709,0.984642,0.984642,0.984642,RGB,224,mosaiks_512_3
resnet18_pretrained_moco_rgb_34,0.970833,0.970889,0.970889,0.970831,0.970889,0.970889,0.970889,RGB,34,resnet18_pretrained_moco
resnet18_pretrained_moco_rgb_224,0.998827,0.998827,0.998827,0.998828,0.998827,0.998827,0.998827,RGB,224,resnet18_pretrained_moco
resnet50_pretrained_imagenet_rgb_34,0.931054,0.931123,0.931123,0.931122,0.931123,0.931123,0.931123,RGB,34,resnet50_pretrained_imagenet
resnet50_pretrained_imagenet_rgb_224,0.99773,0.997728,0.997728,0.997739,0.997728,0.997728,0.997728,RGB,224,resnet50_pretrained_imagenet
resnet50_pretrained_moco_rgb_34,0.981685,0.981704,0.981704,0.9817,0.981704,0.981704,0.981704,RGB,34,resnet50_pretrained_moco
resnet50_pretrained_moco_rgb_224,0.998988,0.998988,0.998988,0.998988,0.998988,0.998988,0.998988,RGB,224,resnet50_pretrained_moco
