In [1]:
from dinov2.models import DINOv2, DINOv1

In [2]:
import torch

dinov2 = DINOv2.from_pretrained(
    'KevinCha/dinov2-vit-small-remote-sensing',
    device_map=0,
    torch_dtype=torch.float32,
)
dinov2 = dinov2.eval()

In [3]:
def read_filename(filename):

    lines = []
    f = open(filename, 'r')
    while True:
        line = f.readline()
        if not line: break
        line = line.replace('\n', '')
        lines.append(line)
    f.close()
    return lines

In [4]:
import torchvision
from PIL import Image
from glob import glob

In [5]:
class KNNDataset(torch.utils.data.Dataset):
    def __init__(self, features, gts):
        self.features = features
        self.gts = gts

    def __getitem__(self, idx):
        return self.features[idx], self.gts[idx]

    def __len__(self):
        return len(self.features)

In [6]:
class BasicDataset(torch.utils.data.Dataset):
    def __init__(self, image_lists, class_idx, transforms):
        self.image_lists = image_lists
        self.class_idx = class_idx
        self.transforms = transforms

    def __getitem__(self, idx):
        path = self.image_lists[idx]
        return self.transforms(Image.open(path).convert('RGB')), self.class_idx[idx]

    def __len__(self):
        return len(self.image_lists)

In [7]:
import numpy as np
import os
from tqdm import tqdm

def extract_features(model, dataset_dir):
    image_paths, labels = [], []

    # whu-rs19
    for class_id, class_path in enumerate(glob('/nas/Dataset/WHU-RS19/WHU-RS19/*')):
    
        print(class_id, class_path)
        images_in_class = glob(os.path.join(class_path, '*'))
        image_paths.extend(images_in_class)
        labels.extend([class_id] * len(images_in_class))

    transforms = torchvision.transforms.Compose(
        [
            torchvision.transforms.Resize((224, 224)),
            torchvision.transforms.ToTensor(),
            torchvision.transforms.Normalize(
                mean=[0.485, 0.456, 0.406],
                std=[0.229, 0.224, 0.225],
            ),
        ]
    )

    dataset = BasicDataset(image_paths, labels, transforms)
    dataloader = torch.utils.data.DataLoader(
        dataset,
        batch_size=128,
        num_workers=4,
        shuffle=False,
        drop_last=False,
    )

    features, ground_truths = [], []
    for images, labels in tqdm(dataloader):

        with torch.no_grad():
            feature_output = model.backbone(images.cuda()).x_norm_clstoken
            # feature_output = model.dino_head(feature_output)

        features.append(feature_output.cpu().numpy())
        ground_truths.append(labels.numpy())

    features = np.vstack(features)
    ground_truths = np.concatenate(ground_truths)

    return features, ground_truths

In [8]:
features, ground_truths = extract_features(dinov2.student, None)

0 /nas/Dataset/WHU-RS19/WHU-RS19/Meadow
1 /nas/Dataset/WHU-RS19/WHU-RS19/Commercial
2 /nas/Dataset/WHU-RS19/WHU-RS19/Industrial
3 /nas/Dataset/WHU-RS19/WHU-RS19/Park
4 /nas/Dataset/WHU-RS19/WHU-RS19/Farmland
5 /nas/Dataset/WHU-RS19/WHU-RS19/Viaduct
6 /nas/Dataset/WHU-RS19/WHU-RS19/Airport
7 /nas/Dataset/WHU-RS19/WHU-RS19/Mountain
8 /nas/Dataset/WHU-RS19/WHU-RS19/Residential
9 /nas/Dataset/WHU-RS19/WHU-RS19/Desert
10 /nas/Dataset/WHU-RS19/WHU-RS19/Forest
11 /nas/Dataset/WHU-RS19/WHU-RS19/River
12 /nas/Dataset/WHU-RS19/WHU-RS19/footballField
13 /nas/Dataset/WHU-RS19/WHU-RS19/Beach
14 /nas/Dataset/WHU-RS19/WHU-RS19/Bridge
15 /nas/Dataset/WHU-RS19/WHU-RS19/railwayStation
16 /nas/Dataset/WHU-RS19/WHU-RS19/Pond
17 /nas/Dataset/WHU-RS19/WHU-RS19/Parking
18 /nas/Dataset/WHU-RS19/WHU-RS19/Port



00%|██████████████████████████████████████████████████████████████████| 8/8 [00:08<00:00,  1.10s/it]

In [9]:
import faiss
from sklearn.model_selection import StratifiedKFold, StratifiedShuffleSplit

def knn(features, gts, ratios=[0.1, 0.3, 0.5, 0.7, 0.9], num_trials=5):
    skf = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)

    for ratio in ratios:
        total_accuracies = []
        for idx_fold, (train_idx, test_idx) in enumerate(skf.split(features, gts)):
            train_features, test_features = features[train_idx], features[test_idx]
            train_gts, test_gts = gts[train_idx], gts[test_idx]

            for trial in range(num_trials):
                random_state = trial

                # Stratified Sampling
                if ratio < 1.0:
                    sss = StratifiedShuffleSplit(
                        n_splits=1, train_size=ratio, random_state=random_state
                    )
                    subset_idx, _ = next(sss.split(train_features, train_gts))
                    train_features_subset = train_features[subset_idx]
                    train_gts_subset = train_gts[subset_idx]
                else:
                    train_features_subset, train_gts_subset = train_features, train_gts

                index = faiss.IndexFlatL2(train_features_subset.shape[1])
                index.add(train_features_subset)

                knn_dist, knn_idx = index.search(test_features, k=20)
                preds = train_gts_subset[knn_idx]

                preds = np.array([np.bincount(pred).argmax() for pred in preds])

                accuracy = (preds == test_gts).mean() * 100.0
                total_accuracies.append(accuracy)

                print(f'Ratio: {ratio}, Fold: {idx_fold}, Trial: {trial}, Accuracy: {accuracy:.3f}')

        overall_avg_acc, overall_std_acc = np.mean(total_accuracies), np.std(total_accuracies)
        print(f'[Total] Ratio: {ratio}, Overall Mean Accuracy: {overall_avg_acc:.3f}, STD: {overall_std_acc:.3f}')

In [10]:
knn(features, ground_truths)

Ratio: 0.1, Fold: 0, Trial: 0, Accuracy: 63.682
Ratio: 0.1, Fold: 0, Trial: 1, Accuracy: 59.701
Ratio: 0.1, Fold: 0, Trial: 2, Accuracy: 54.229
Ratio: 0.1, Fold: 0, Trial: 3, Accuracy: 55.224
Ratio: 0.1, Fold: 0, Trial: 4, Accuracy: 63.184
Ratio: 0.1, Fold: 1, Trial: 0, Accuracy: 63.682
Ratio: 0.1, Fold: 1, Trial: 1, Accuracy: 54.229
Ratio: 0.1, Fold: 1, Trial: 2, Accuracy: 61.194
Ratio: 0.1, Fold: 1, Trial: 3, Accuracy: 53.731
Ratio: 0.1, Fold: 1, Trial: 4, Accuracy: 57.711
Ratio: 0.1, Fold: 2, Trial: 0, Accuracy: 68.159
Ratio: 0.1, Fold: 2, Trial: 1, Accuracy: 57.711
Ratio: 0.1, Fold: 2, Trial: 2, Accuracy: 53.731
Ratio: 0.1, Fold: 2, Trial: 3, Accuracy: 58.706
Ratio: 0.1, Fold: 2, Trial: 4, Accuracy: 54.229
Ratio: 0.1, Fold: 3, Trial: 0, Accuracy: 57.711
Ratio: 0.1, Fold: 3, Trial: 1, Accuracy: 59.204
Ratio: 0.1, Fold: 3, Trial: 2, Accuracy: 53.234
Ratio: 0.1, Fold: 3, Trial: 3, Accuracy: 57.711
Ratio: 0.1, Fold: 3, Trial: 4, Accuracy: 56.716
Ratio: 0.1, Fold: 4, Trial: 0, Accuracy: