In [1]:
import os
import sys

sys.path.append("..")

import json
from itertools import product
import pickle
from pprint import pprint

import json
import pandas as pd
import pandas as pd
import kornia.augmentation as K
import numpy as np
import torch
import torch.nn as nn
from sklearn.neighbors import KNeighborsClassifier
from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score
from sklearn.preprocessing import label_binarize


from src.models import get_model_by_name
from src.datasets.sat6 import SAT6, SAT6DataModule
from src.utils import extract_features

# Use the GPU when one is available; otherwise fall back to the CPU so the
# notebook still runs (slowly) on machines without CUDA.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [15]:
# Extract features with every encoder / input-size combination and save them.
# Each run writes "<run>.pkl" in the current working directory containing a
# dict with the keys x_train, y_train, x_test and y_test.
model_names = [
    "resnet50_pretrained_moco",
    "resnet18_pretrained_moco",
    "resnet50_pretrained_imagenet",
    "resnet50_randominit",
    "imagestats",
    "mosaiks_512_3",
]
rgbs = [True]
sizes = [34, 224]

# The data module is configured identically for every run (same root, batch
# size, worker count and seed), so build and set it up once instead of
# repeating that work on each loop iteration.
dm = SAT6DataModule(root="data/sat6/", batch_size=64, num_workers=16, seed=0)
dm.setup()

for model_name, rgb, size in product(model_names, rgbs, sizes):
    # The run name doubles as the stem of the output file.
    run = f"{model_name}{'_rgb' if rgb else ''}_{size}"
    print(f"Extracting features for {run}")

    model = get_model_by_name(model_name, rgb, device=device)

    # "imagestats" receives the images untouched (identity transform), so
    # `size` has no effect on that encoder; every other encoder gets its
    # inputs resized to `size` first.
    if model_name == "imagestats":
        transforms = nn.Sequential(nn.Identity()).to(device)
    else:
        transforms = nn.Sequential(K.Resize(size)).to(device)

    x_train, y_train = extract_features(
        model, dm.train_dataloader(), device, transforms=transforms
    )
    x_test, y_test = extract_features(
        model, dm.test_dataloader(), device, transforms=transforms
    )

    data = dict(x_train=x_train, y_train=y_train, x_test=x_test, y_test=y_test)
    with open(f"{run}.pkl", "wb") as f:
        pickle.dump(data, f)

Extracting features for resnet50_pretrained_moco_rgb_34


100%|██████████| 5063/5063 [01:46<00:00, 47.51it/s]
100%|██████████| 1266/1266 [00:26<00:00, 48.61it/s]


Extracting features for resnet50_pretrained_moco_rgb_224


100%|██████████| 5063/5063 [14:54<00:00,  5.66it/s]
100%|██████████| 1266/1266 [03:34<00:00,  5.90it/s]


Extracting features for resnet18_pretrained_moco_rgb_34


100%|██████████| 5063/5063 [00:56<00:00, 88.97it/s]
100%|██████████| 1266/1266 [00:15<00:00, 82.88it/s]


Extracting features for resnet18_pretrained_moco_rgb_224


100%|██████████| 5063/5063 [04:35<00:00, 18.41it/s]
100%|██████████| 1266/1266 [01:11<00:00, 17.59it/s]


Extracting features for resnet50_pretrained_imagenet_rgb_34


100%|██████████| 5063/5063 [01:39<00:00, 50.64it/s]
100%|██████████| 1266/1266 [00:25<00:00, 49.28it/s]


Extracting features for resnet50_pretrained_imagenet_rgb_224


100%|██████████| 5063/5063 [14:19<00:00,  5.89it/s]
100%|██████████| 1266/1266 [03:48<00:00,  5.54it/s]


Extracting features for resnet50_randominit_rgb_34


100%|██████████| 5063/5063 [01:45<00:00, 48.13it/s]
100%|██████████| 1266/1266 [00:26<00:00, 47.60it/s]


Extracting features for resnet50_randominit_rgb_224


100%|██████████| 5063/5063 [14:21<00:00,  5.88it/s]
100%|██████████| 1266/1266 [03:31<00:00,  5.98it/s]


Extracting features for imagestats_rgb_34


100%|██████████| 5063/5063 [00:08<00:00, 571.93it/s]
100%|██████████| 1266/1266 [00:03<00:00, 356.26it/s]


Extracting features for imagestats_rgb_224


100%|██████████| 5063/5063 [00:09<00:00, 549.38it/s]
100%|██████████| 1266/1266 [00:03<00:00, 358.21it/s]


Extracting features for mosaiks_512_3_rgb_34


100%|██████████| 5063/5063 [01:00<00:00, 83.49it/s]
100%|██████████| 1266/1266 [00:15<00:00, 81.07it/s]


Extracting features for mosaiks_512_3_rgb_224


100%|██████████| 5063/5063 [17:56<00:00,  4.70it/s]
100%|██████████| 1266/1266 [04:32<00:00,  4.65it/s]


In [7]:
# Evaluate the saved features with a k-nearest-neighbours classifier.
# For every run whose "<run>.pkl" file exists, fit KNN on the training
# features, predict the test labels and collect the metrics in `results`
# (keyed by run name).
model_names = [
    "resnet50_pretrained_moco",
    "resnet18_pretrained_moco",
    "resnet50_pretrained_imagenet",
    "resnet50_randominit",
    "imagestats",
    "mosaiks_512_3",
]
rgbs = [True]
sizes = [34, 224]

results = {}
# Deliberately NOT called `K`: that name is the `kornia.augmentation` alias
# from the imports cell, and rebinding it to an int breaks any later use of
# `K.Resize` in the same kernel session.
N_NEIGHBORS = 3
for model_name, rgb, size in product(model_names, rgbs, sizes):
    run = f"{model_name}{'_rgb' if rgb else ''}_{size}"
    print(f"Evaluating {run}")

    filename = f"{run}.pkl"
    if not os.path.exists(filename):
        # Make the skip visible instead of silently dropping the run.
        print(f"  skipped: {filename} not found")
        continue

    # NOTE: pickle can execute arbitrary code on load; only open feature
    # files that were produced by a trusted run of this notebook.
    with open(filename, "rb") as f:
        data = pickle.load(f)

    x_train = data["x_train"]
    y_train = data["y_train"]
    x_test = data["x_test"]
    y_test = data["y_test"]

    knn_model = KNeighborsClassifier(n_neighbors=N_NEIGHBORS)
    knn_model.fit(X=x_train, y=y_train)

    # Only hard predictions are needed for the metrics below, so the unused
    # one-hot labels and the extra `predict_proba` pass were dropped.
    y_pred = knn_model.predict(x_test)

    metrics = {
        "f1_weighted": f1_score(y_test, y_pred, average="weighted"),
        "f1_micro": f1_score(y_test, y_pred, average="micro"),
        "precision_micro": precision_score(y_test, y_pred, average="micro"),
        "precision_weighted": precision_score(y_test, y_pred, average="weighted"),
        "recall_micro": recall_score(y_test, y_pred, average="micro"),
        "recall_weighted": recall_score(y_test, y_pred, average="weighted"),
        "accuracy": accuracy_score(y_test, y_pred),
    }
    pprint(metrics)
    results[run] = metrics

Evaluating resnet50_pretrained_moco_34
Evaluating resnet50_pretrained_moco_224
Evaluating resnet50_pretrained_moco_rgb_34
{'accuracy': 0.9817037037037037,
 'f1_micro': 0.9817037037037037,
 'f1_weighted': 0.9816853699645429,
 'precision_micro': 0.9817037037037037,
 'precision_weighted': 0.9817002484112021,
 'recall_micro': 0.9817037037037037,
 'recall_weighted': 0.9817037037037037}
Evaluating resnet50_pretrained_moco_rgb_224
{'accuracy': 0.9989876543209877,
 'f1_micro': 0.9989876543209877,
 'f1_weighted': 0.9989877017631823,
 'precision_micro': 0.9989876543209877,
 'precision_weighted': 0.9989880426730673,
 'recall_micro': 0.9989876543209877,
 'recall_weighted': 0.9989876543209877}
Evaluating resnet18_pretrained_moco_34
Evaluating resnet18_pretrained_moco_224
Evaluating resnet18_pretrained_moco_rgb_34
{'accuracy': 0.9708888888888889,
 'f1_micro': 0.9708888888888889,
 'f1_weighted': 0.9708325797091503,
 'precision_micro': 0.9708888888888889,
 'precision_weighted': 0.9708306401189012,
 'r

In [8]:
# Persist the per-run metrics as pretty-printed JSON (2-space indent).
with open("sat6-results.json", "w") as out_file:
    out_file.write(json.dumps(results, indent=2))

In [9]:
# Reload the saved metrics and tidy them into a table: one row per run, with
# the run name split into its band label, input size and encoder name.
with open("sat6-results.json") as f:
    results = json.load(f)

df = (
    pd.DataFrame.from_dict(results)
    .T  # runs become rows, metrics become columns
    .assign(
        # Run names look like "<encoder>[_rgb]_<size>".
        rgb=lambda frame: ["RGB" if "rgb" in name else "MSI" for name in frame.index],
        size=lambda frame: [int(name.split("_")[-1]) for name in frame.index],
        encoder=lambda frame: [
            name.rsplit("_", 1)[0].replace("_rgb", "") for name in frame.index
        ],
    )
    .sort_values(["rgb", "encoder", "size"])
)
df.to_csv("sat6-results.csv")
df

Unnamed: 0,f1_weighted,f1_micro,precision_micro,precision_weighted,recall_micro,recall_weighted,accuracy,rgb,size,encoder
imagestats_rgb_34,0.996775,0.996778,0.996778,0.996778,0.996778,0.996778,0.996778,RGB,34,imagestats
imagestats_rgb_224,0.996775,0.996778,0.996778,0.996778,0.996778,0.996778,0.996778,RGB,224,imagestats
mosaiks_512_3_rgb_34,0.98621,0.986198,0.986198,0.986253,0.986198,0.986198,0.986198,RGB,34,mosaiks_512_3
mosaiks_512_3_rgb_224,0.984657,0.984642,0.984642,0.984709,0.984642,0.984642,0.984642,RGB,224,mosaiks_512_3
resnet18_pretrained_moco_rgb_34,0.970833,0.970889,0.970889,0.970831,0.970889,0.970889,0.970889,RGB,34,resnet18_pretrained_moco
resnet18_pretrained_moco_rgb_224,0.998827,0.998827,0.998827,0.998828,0.998827,0.998827,0.998827,RGB,224,resnet18_pretrained_moco
resnet50_pretrained_imagenet_rgb_34,0.931054,0.931123,0.931123,0.931122,0.931123,0.931123,0.931123,RGB,34,resnet50_pretrained_imagenet
resnet50_pretrained_imagenet_rgb_224,0.99773,0.997728,0.997728,0.997739,0.997728,0.997728,0.997728,RGB,224,resnet50_pretrained_imagenet
resnet50_pretrained_moco_rgb_34,0.981685,0.981704,0.981704,0.9817,0.981704,0.981704,0.981704,RGB,34,resnet50_pretrained_moco
resnet50_pretrained_moco_rgb_224,0.998988,0.998988,0.998988,0.998988,0.998988,0.998988,0.998988,RGB,224,resnet50_pretrained_moco
