In [1]:
from datasets.eurosat_datamodule import EurosatDataModule
from models.moco2_module import MocoV2
import torchvision.transforms as T
from torchvision.models import resnet18
import torch
import pandas as pd
from copy import deepcopy
import torch.nn as nn
from tqdm import tqdm
from sklearn.neighbors import KNeighborsClassifier
from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score
from pprint import pprint

# Per-channel statistics handed to T.Normalize in the ImageNet preprocessing pipeline.
IMAGENET_MEAN = torch.tensor([0.485, 0.456, 0.406])
IMAGENET_STD = torch.tensor([0.229, 0.224, 0.225])

# Baseline encoder: pretrained torchvision ResNet-18 with its last child module
# dropped and the remaining output flattened, moved to the GPU in eval mode.
imagenet_layers = list(resnet18(pretrained=True).children())[:-1]
backbone_imagenet = nn.Sequential(*imagenet_layers, nn.Flatten()).to("cuda").eval()

# SeCo encoder: an independent copy of the query encoder stored in the MoCo v2
# checkpoint, likewise on the GPU and in eval mode.
model = MocoV2.load_from_checkpoint("checkpoints/seco_resnet18_1m.ckpt")
backbone_seco = deepcopy(model.encoder_q).to("cuda").eval()

# Resizing is currently a no-op; swap in T.Resize((224, 224)) to upsample inputs.
resize = nn.Identity()

# The SeCo pipeline feeds un-normalized tensors; the ImageNet pipeline
# additionally normalizes with the statistics above.
transforms_seco = T.Compose([T.ToTensor(), resize])
transforms_imagenet = T.Compose(
    [
        T.ToTensor(),
        resize,
        T.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
    ]
)


  from .autonotebook import tqdm as notebook_tqdm
  stdout_func(
  stdout_func(
  stdout_func(


In [2]:
def _backbone_device(backbone):
    """Return the device of ``backbone``'s first parameter, or CUDA if it has none."""
    params = getattr(backbone, "parameters", None)
    first_param = next(params(), None) if callable(params) else None
    if first_param is None:
        # Parameter-free (or non-Module) callables keep the original behaviour.
        return torch.device("cuda")
    return first_param.device


def _embed_split(backbone, dataloader, device):
    """Run ``backbone`` over every batch of ``dataloader``.

    Returns a ``(features, labels)`` pair of NumPy arrays, each concatenated
    along dimension 0 in dataloader order.
    """
    labels, features = [], []
    # Inference only: no autograd graph is needed for feature extraction.
    with torch.no_grad():
        for batch in tqdm(dataloader, total=len(dataloader)):
            # batch[0] holds the inputs, batch[1] the targets.
            labels.append(batch[1])
            features.append(backbone(batch[0].to(device)).detach().cpu())
    return torch.cat(features, dim=0).numpy(), torch.cat(labels, dim=0).numpy()


def extract_features(backbone, transforms):
    """Embed the EuroSAT train and validation splits with ``backbone``.

    The inputs are moved to whatever device the backbone's parameters live on
    (falling back to CUDA when the backbone exposes no parameters), so the
    function no longer requires the model to sit on the GPU. The backbone's
    train/eval mode is not touched here; the caller is expected to set it.

    Args:
        backbone: Callable mapping a batch of image tensors to feature tensors.
        transforms: Transform pipeline passed to ``EurosatDataModule``.

    Returns:
        ``(train_x, train_y, test_x, test_y)`` as NumPy arrays, where the
        "test" arrays come from the data module's validation dataloader.
    """
    dm = EurosatDataModule("data/eurosat-rgb", transforms)
    dm.setup()
    device = _backbone_device(backbone)

    train_x, train_y = _embed_split(backbone, dm.train_dataloader(), device)
    test_x, test_y = _embed_split(backbone, dm.val_dataloader(), device)
    return train_x, train_y, test_x, test_y

In [3]:
# Features from the SeCo encoder (inputs are not normalized).
(
    x_train_seco,
    y_train_seco,
    x_test_seco,
    y_test_seco,
) = extract_features(backbone_seco, transforms_seco)

# Features from the ImageNet-pretrained encoder (inputs are ImageNet-normalized).
(
    x_train_imagenet,
    y_train_imagenet,
    x_test_imagenet,
    y_test_imagenet,
) = extract_features(backbone_imagenet, transforms_imagenet)

100%|██████████| 506/506 [00:07<00:00, 64.70it/s]
100%|██████████| 169/169 [00:02<00:00, 60.30it/s]
100%|██████████| 506/506 [00:07<00:00, 66.32it/s]
100%|██████████| 169/169 [00:02<00:00, 60.88it/s]


In [4]:
def knn_metrics(x_train, y_train, x_test, y_test, k, n_jobs=4):
    """Fit a k-NN classifier on the training features and score it on the test features.

    Args:
        x_train, y_train: Features and labels used to fit the classifier.
        x_test, y_test: Features and labels used for evaluation.
        k: Number of neighbours.
        n_jobs: Parallel jobs handed to ``KNeighborsClassifier``.

    Returns:
        Dict with F1, precision and recall (micro, macro and weighted averages)
        plus accuracy. The key order is fixed because it determines the column
        order of the exported results table.
    """
    knn = KNeighborsClassifier(n_neighbors=k, n_jobs=n_jobs)
    knn.fit(x_train, y_train)
    y_pred = knn.predict(x_test)
    return {
        "f1_weighted": f1_score(y_test, y_pred, average="weighted"),
        "f1_macro": f1_score(y_test, y_pred, average="macro"),
        "f1_micro": f1_score(y_test, y_pred, average="micro"),
        "precision_micro": precision_score(y_test, y_pred, average="micro"),
        "precision_macro": precision_score(y_test, y_pred, average="macro"),
        "precision_weighted": precision_score(y_test, y_pred, average="weighted"),
        "recall_micro": recall_score(y_test, y_pred, average="micro"),
        "recall_macro": recall_score(y_test, y_pred, average="macro"),
        "recall_weighted": recall_score(y_test, y_pred, average="weighted"),
        "accuracy": accuracy_score(y_test, y_pred),
    }


# One entry per pretrained weight set; both go through the same evaluation.
feature_sets = {
    "seco": (x_train_seco, y_train_seco, x_test_seco, y_test_seco),
    "imagenet": (x_train_imagenet, y_train_imagenet, x_test_imagenet, y_test_imagenet),
}

# Keys are "<weights>_<k>"; the export cell parses both parts back out.
# The classifier lives inside knn_metrics, so the MoCo checkpoint bound to
# `model` in the setup cell is no longer overwritten by this loop.
results = {}
for k in tqdm([3, 5, 10, 20]):
    for weights_name, splits in feature_sets.items():
        results[f"{weights_name}_{k}"] = knn_metrics(*splits, k=k)

100%|██████████| 4/4 [00:06<00:00,  1.55s/it]


In [6]:
# One row per "<weights>_<k>" run, one column per metric.
df = pd.DataFrame.from_dict(results).T

# Split the run name back into its parts: the text after the last underscore is
# k, the text before the first underscore is the weight set.
run_names = list(df.index)
df["k"] = [int(run_name.rsplit("_", 1)[-1]) for run_name in run_names]
df["weights"] = [run_name.partition("_")[0] for run_name in run_names]

df.to_csv("knn_64_results.csv")