# KNN evaluation - Oxford 102 Flowers

- [Dataset homepage](https://www.robots.ox.ac.uk/~vgg/data/flowers/102/)

In this setup we evaluate a model trained with Barlow Twins by using the embeddings produced by its backbone (ResNet50); the projection layers are dropped.

This evaluation tests the representational power of the embeddings from the trained model. We expect embeddings from the same class to be closer together, as measured by `L2 distance`.

Setup:
- Model training (not part of this notebook)
- Setup dataset and model
- Generate embeddings for both training and testing datasets
- Given an embedding from the test dataset, we find the closest `N` embeddings in the train set, and based on the labels we assign a new label to the test example with majority voting

In [1]:
import numpy as np
import tensorflow as tf
import scipy.io
from tqdm import tqdm
import pandas as pd
from sklearn.neighbors import KNeighborsClassifier
from sklearn.metrics import top_k_accuracy_score, f1_score, classification_report

import barlow_twins

In [2]:
# Run tf.function-decorated code eagerly instead of as traced graphs.
# NOTE(review): presumably enabled for debugging; eager execution is usually
# slower, so confirm it is still needed for this evaluation.
tf.config.run_functions_eagerly(True)

# Constants

In [None]:
# Data
# NOTE(review): absolute paths; adjust to the local environment before running.
IMAGE_FOLDER = "/data"  # folder scanned for the dataset images
TRAIN_TEST_SPLIT_IDS_FILE = "/data/setid.mat"  # .mat file holding the "trnid"/"valid"/"tstid" id arrays
LABELS_FILE = "/data/imagelabels.mat"  # .mat file holding the "labels" array

# Image
# Spatial size passed to the model constructor and used for the dummy input.
IMAGE_HEIGHT = 224
IMAGE_WIDTH = 224

# Model & Eval
MODEL_WEIGHTS_PATH = "/code/logs_full_run/test_adam/checkpoint.h5"  # checkpoint loaded into the model
BATCH_SIZE = 64  # images per batch when generating embeddings
NB_NEIGHBORS = 5  # number of neighbours used by the KNN classifier

# Dataset

In [3]:
# Sorted so that the file order is deterministic.
# NOTE(review): the labels are later paired with these paths by position, which
# presumably relies on zero-padded file names sorting in image-id order - confirm.
image_paths = sorted(barlow_twins.data._get_image_paths(IMAGE_FOLDER))
# The numeric image id is the part of the file stem after the last underscore.
image_ids = [int(path.stem.rsplit("_", 1)[-1]) for path in image_paths]

## Train test split

In [4]:
# Official split definition: one array of image ids per split.
train_test_dict = scipy.io.loadmat(TRAIN_TEST_SPLIT_IDS_FILE)
# Each entry is stored as a single-row 2-D array, hence the [0]; ids are sorted ascending.
train_ids, val_ids, test_ids = (
    sorted(train_test_dict[split_key][0])
    for split_key in ("trnid", "valid", "tstid")
)

In [5]:
# Sanity check: number of image ids in the train / validation / test splits.
len(train_ids), len(val_ids), len(test_ids)

(1020, 1020, 6149)

## Labels

In [6]:
# Class labels for the whole dataset, stored as a single-row 2-D array (hence the [0]).
# NOTE(review): presumably one label per image, ordered by image id - confirm,
# because the dataframe below pairs labels with the sorted image paths by position.
labels_dict = scipy.io.loadmat(LABELS_FILE)
labels = labels_dict["labels"][0]

## Dataframes

In [8]:
# One row per image, indexed by image id so the split id lists can be used with .loc.
df = pd.DataFrame(
    {
        "image_path": [str(path) for path in image_paths],
        "image_id": image_ids,
        "label": labels,
    }
).set_index("image_id")

In [9]:
# Select the rows of each split by image id (the dataframe index).
train_df, val_df, test_df = (
    df.loc[split_ids] for split_ids in (train_ids, val_ids, test_ids)
)
# Train and validation rows together form the reference set for the KNN classifier.
train_val_df = pd.concat([train_df, val_df])

## tf.data.Dataset

In [10]:
# NOTE(review): dead code - this whole cell is commented out and nothing visible
# in the notebook calls make_dataset. Consider deleting the cell.
# def make_dataset(df, augment:bool=False, batch_size:int=4):
#     dataset_images = tf.data.Dataset.from_tensor_slices(df["image_path"].values)
# 
#     dataset_images = dataset_images.map(barlow_twins.data._read_image_from_path,
#                                         num_parallel_calls=tf.data.AUTOTUNE)
#     if augment:
#         dataset_images = dataset_images.map(tf.image.random_flip_left_right,
#                                             num_parallel_calls=tf.data.AUTOTUNE)
# 
#     dataset_labels = tf.data.Dataset.from_tensor_slices(df["label"].values)
#     dataset = tf.data.Dataset.zip((dataset_images, dataset_labels))
#     
#     dataset = dataset.batch(batch_size).prefetch(tf.data.experimental.AUTOTUNE)
#     return dataset

In [11]:
def image_dataset(df, augment:bool=False, batch_size:int=4, image_size=None):
    """Build a batched ``tf.data.Dataset`` of the images listed in ``df``.

    Args:
        df: DataFrame with an ``image_path`` column of image file paths.
        augment: If True, apply a random horizontal flip to every image.
        batch_size: Number of images per batch.
        image_size: ``(height, width)`` every image is resized to. Defaults to
            ``(IMAGE_HEIGHT, IMAGE_WIDTH)`` so the images match the model input.

    Returns:
        A dataset yielding batches of images only (no labels), produced from
        the rows of ``df`` in order.
    """
    if image_size is None:
        # Resolved at call time so the notebook-level constants stay the single
        # source of truth for the input resolution.
        image_size = (IMAGE_HEIGHT, IMAGE_WIDTH)

    dataset_images = tf.data.Dataset.from_tensor_slices(df["image_path"].values)

    dataset_images = dataset_images.map(barlow_twins.data._read_image_from_path,
                                        num_parallel_calls=tf.data.AUTOTUNE)
    dataset_images = dataset_images.map(lambda x: tf.image.resize(x, image_size),
                                        num_parallel_calls=tf.data.AUTOTUNE)
    if augment:
        dataset_images = dataset_images.map(tf.image.random_flip_left_right,
                                            num_parallel_calls=tf.data.AUTOTUNE)

    return dataset_images.batch(batch_size).prefetch(tf.data.AUTOTUNE)

# Model

In [12]:
# Build the model without the projection head (no projection units, projection
# layer dropped), so calling it returns the backbone embeddings used for KNN.
model = barlow_twins.BarlowTwinsModel(input_height=IMAGE_HEIGHT,
                                      input_width=IMAGE_WIDTH,
                                      projection_units=None,
                                      drop_projection_layer=True)

In [13]:
# One forward pass on an all-zero image batch of size 1.
# NOTE(review): presumably needed so the model creates its variables before
# the checkpoint weights are loaded - confirm.
dummy_input = np.zeros((1, IMAGE_HEIGHT, IMAGE_WIDTH, 3), dtype=np.float32)
dummy_output = model(dummy_input)

In [14]:
# by_name=True loads weights only into layers whose names match the checkpoint.
# NOTE(review): layers with no matching name keep their initial weights and no
# error is raised - verify that the backbone weights were actually restored.
model.load_weights(MODEL_WEIGHTS_PATH, by_name=True)

# KNN

## Generating the embeddings

In [15]:
def generate_embeddings(model, dataset, batch_size:int=1):
    """Run ``model`` over every batch of ``dataset`` and stack the outputs.

    Args:
        model: Callable mapping one batch of inputs to one batch of embeddings.
        dataset: Iterable yielding input batches.
        batch_size: Unused; the batching is defined by ``dataset`` itself. Kept
            so that existing calls passing it keep working.

    Returns:
        ``np.ndarray`` with one entry per example along the first axis, in
        iteration order. An empty ``dataset`` yields an empty 1-D array.
    """
    # Convert each batch output to NumPy once, rather than collecting one
    # object per example and converting them all at the end.
    batches = [np.asarray(model(x)) for x in tqdm(dataset)]
    if not batches:
        return np.array([])
    return np.concatenate(batches, axis=0)

In [17]:
# Reference set for KNN: embeddings of the combined train + validation images,
# generated without augmentation.
train_dataset = image_dataset(train_val_df, batch_size=BATCH_SIZE)
train_embeddings = generate_embeddings(model, train_dataset, BATCH_SIZE)

  "Even though the tf.config.experimental_run_functions_eagerly "
32it [00:02, 11.03it/s]


In [18]:
# Query set for KNN: embeddings of the test images, generated without augmentation.
test_dataset = image_dataset(test_df, batch_size=BATCH_SIZE)
test_embeddings = generate_embeddings(model, test_dataset, BATCH_SIZE)

97it [00:07, 13.59it/s]


## Evaluation

In [19]:
# Majority vote over the NB_NEIGHBORS nearest training embeddings. No metric is
# passed, so scikit-learn's default applies (Minkowski with p=2, i.e. L2 distance).
knn = KNeighborsClassifier(n_neighbors=NB_NEIGHBORS)

In [20]:
# Labels are taken from train_val_df, the same frame the training embeddings were
# generated from, so rows and labels line up. The trailing ';' hides the cell output.
knn.fit(train_embeddings, train_val_df["label"].values);

In [21]:
# Per-class scores (used for the top-k accuracies) and hard label predictions
# (used for the classification report and F1) for every test embedding.
pred_label_scores = knn.predict_proba(test_embeddings)
pred_labels = knn.predict(test_embeddings)

In [22]:
test_labels = test_df["label"].values

# The columns of predict_proba follow knn.classes_. Passing that ordering
# explicitly keeps the scores aligned with the labels even if some class is
# missing from the test labels.
# NOTE(review): each sample has at most NB_NEIGHBORS non-zero class scores, so
# the ranking contains many ties; top-1 computed from the scores can differ from
# the accuracy of knn.predict when votes tie.
top_1_acc = top_k_accuracy_score(test_labels, pred_label_scores, k=1, labels=knn.classes_)
top_5_acc = top_k_accuracy_score(test_labels, pred_label_scores, k=5, labels=knn.classes_)
report = classification_report(test_labels, pred_labels)
# Micro-averaged F1 over single-label multiclass predictions.
f1 = f1_score(test_labels, pred_labels, average="micro")

In [25]:
# Summary metrics on the test split, rounded to four decimals.
print(f"Top 1 accuracy: {top_1_acc:.4f}\nTop 5 accuracy: {top_5_acc:.4f}\nF1: {f1:.4f}")

Top 1 accuracy: 0.1699
Top 5 accuracy: 0.4062
F1: 0.1612


In [24]:
# Per-class precision / recall / F1 / support on the test split.
print(report)

              precision    recall  f1-score   support

           1       0.03      0.25      0.05        20
           2       0.16      0.42      0.23        40
           3       0.03      0.25      0.05        20
           4       0.02      0.11      0.04        36
           5       0.13      0.49      0.20        45
           6       0.04      0.24      0.07        25
           7       0.08      0.45      0.13        20
           8       0.08      0.22      0.12        65
           9       0.10      0.38      0.16        26
          10       0.23      0.72      0.35        25
          11       0.03      0.10      0.05        67
          12       0.27      0.30      0.28        67
          13       0.08      0.17      0.11        29
          14       0.11      0.21      0.15        28
          15       0.05      0.31      0.09        29
          16       0.03      0.14      0.05        21
          17       0.07      0.12      0.09        65
          18       0.10    