Copyright 2024 Google LLC

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    https://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

# Plots and analyses for "Towards flexible perception with visual memory"

In [None]:
import matplotlib.pyplot as plt
import matplotlib as mpl
import seaborn as sns
import numpy as np
import pandas as pd
import json
import time
import math
import copy
from tqdm import tqdm
from collections import Counter

In [None]:
def file_opener_default(path, mode):
  """Open `path` with the builtin `open` in the given `mode`.

  Args:
    path: Location of the file to open.
    mode: Mode string forwarded to `open` (e.g. 'r', 'wb').

  Returns:
    The file object returned by `open`.
  """
  handle = open(path, mode)
  return handle


# All file access in this notebook goes through `file_opener`; it is an alias
# of the default opener above.
file_opener = file_opener_default

In [None]:
def print_viewing_path_default(save_fig_path):
  """Print a one-line hint stating where the saved figure can be viewed.

  Args:
    save_fig_path: Path of the figure that was just written.
  """
  hint = f'View at: {save_fig_path}'
  print(hint)


# Plotting code calls `print_viewing_path`; it is an alias of the default above.
print_viewing_path = print_viewing_path_default

In [None]:
# Directory that saved figures are written to (looks like a placeholder path;
# presumably meant to be edited before running).
FIGURE_DIR = '/path/to/figures/'
# Parent directory of all input data (also looks like a placeholder path).
DATA_PARENT_DIR = '/path/to/data_parent_dir/'
# Directory of the neighbor-info JSON files loaded by `read_scaling_df`.
DATA_DIR = f'{DATA_PARENT_DIR}/data/'
# Default directory of the pruning-metric CSV files used by `read_pruning_data`.
PRUNING_METRICS_DIR = f'{DATA_PARENT_DIR}/dataset-pruning-metrics'

## Defining color scheme & naming

In [None]:
# https://brand.google/brand-foundations/brand-identity/color/
# Hex plot color for each featurizer key (the values found in the dataframes'
# 'featurizer' column); used when drawing one curve per model.
featurizer_to_color = {'dinov2_vits14': '#669DF6',
                       'dinov2_vitb14': '#1A73E8',
                       'dinov2_vitl14': '#185ABC',
                       'clip-vit_b16': '#EA4335',
                       'clip-vit_l14': '#B31412',
                       }
# Human-readable legend label for each featurizer key.
featurizer_to_name = {'dinov2_vits14': 'DinoV2 ViT-S/14',
                      'dinov2_vitb14': 'DinoV2 ViT-B/14',
                      'dinov2_vitl14': 'DinoV2 ViT-L/14',
                      'clip-vit_b16': 'CLIP ViT-B/16',
                      'clip-vit_l14': 'CLIP ViT-L/14',
                       }
# Hex plot color for each aggregator name (presumably the neighbor-voting
# schemes; their use is not visible in this part of the file).
aggregator_to_color = {'PluralityVoting': '#e7298a',
                       'DistanceVoting': '#1b9e77',
                       'SoftmaxVoting': '#7570b3',
                       'RankVoting': '#d95f02',
                       }

## Loading data

In [None]:
def read_neighbors_info_JSON(paths: list[str]) -> pd.DataFrame:
  """Load a nearest-neighbor JSON file into a dataframe sorted by image ID.

  Args:
    paths: List holding exactly one path; the path must end in '.json'.

  Returns:
    Dataframe with one row per record, ordered by 'image_id', with a fresh
    0..n-1 index and the columns in a fixed order.
  """
  column_order = ['featurizer', 'image_id', 'image_class',
                  'neighbor_image_ids', 'neighbor_classes', 'neighbor_distances']

  started = time.time()
  assert len(paths) == 1, 'Only one JSON file is supported'

  json_path = paths[0]
  assert json_path.endswith('.json')

  frame = pd.read_json(json_path, orient='index')
  # Sort rows by image ID, then turn the ID back into an ordinary column.
  frame = frame.set_index('image_id').sort_index().reset_index(drop=False)
  frame = frame[column_order]

  elapsed = time.time() - started
  print(f'Loading time: {round(elapsed)} seconds')

  return frame

In [None]:
def remove_neighbors_identical_to_query(df: pd.DataFrame) -> pd.DataFrame:
  """Remove neighbors identical to queries from the dataframe.

  If neighbors are derived from the training set, the first neighbor
  is usually identical to the query image and needs to be removed.

  For every row, the entry of 'neighbor_image_ids' equal to 'image_id' is
  dropped, together with the entries at the same position in
  'neighbor_classes' and 'neighbor_distances'. If the query is not among its
  neighbors, the last neighbor is dropped instead so that all rows keep the
  same list length.

  Args:
    df: Dataframe with an 'image_id' column and list-valued
      'neighbor_image_ids', 'neighbor_classes' and 'neighbor_distances'
      columns.

  Returns:
    A new dataframe with shortened neighbor lists. The input dataframe and
    the lists it holds are left untouched.
  """
  neighbor_columns = ['neighbor_image_ids', 'neighbor_classes', 'neighbor_distances']

  def _query_position(image_id, neighbor_ids):
    # Position of the query among its neighbors, or None if it is absent.
    if image_id in neighbor_ids:
      return neighbor_ids.index(image_id)
    return None

  def _without_position(values, position):
    # Slicing builds a NEW list. Copying the dataframe does not copy the list
    # objects stored in its cells, so deleting in place would also alter the
    # caller's data.
    if position is None:
      return values[:-1]
    return values[:position] + values[position + 1:]

  positions = [
      _query_position(image_id, neighbor_ids)
      for image_id, neighbor_ids in zip(df['image_id'], df['neighbor_image_ids'])
  ]

  df_tmp = df.copy()
  for c in neighbor_columns:
    trimmed = [_without_position(values, pos) for values, pos in zip(df[c], positions)]
    df_tmp[c] = pd.Series(trimmed, index=df_tmp.index, dtype=object)

  return df_tmp

In [None]:
def calculate_accuracy(df: pd.DataFrame, k: int, prediction_column: str) -> float:
  """Return the fraction of rows whose k-th prediction matches the label.

  Args:
    df: Dataframe with an 'image_class' column and a list-valued
      `prediction_column`.
    k: Position inside each row's prediction list to evaluate.
    prediction_column: Name of the column holding the prediction lists.

  Returns:
    Mean correctness over all rows. When 'image_class' holds lists
    (multi-label ground truth), a prediction counts as correct if it is
    contained in the row's label list; otherwise it must equal the label.
  """
  labels = df['image_class']
  # Inspect the first row by POSITION: a label lookup (`labels[0]`) fails when
  # the index has no label 0 and returns a Series when label 0 is duplicated,
  # e.g. after concatenating several dataframes.
  multi_label = len(labels) > 0 and isinstance(labels.iloc[0], list)

  if multi_label:
    accuracy = df.apply(lambda x: x[prediction_column][k] in x['image_class'], axis=1).mean()
  else:
    matches = df[prediction_column].apply(lambda x: x[k]) == labels
    accuracy = matches.mean()
  return accuracy

In [None]:
def print_df_stats(df: pd.DataFrame) -> None:
  """Print summary statistics for the dataframe of a single featurizer.

  Reports the model name, the number of samples and of neighbors per sample,
  the smallest and largest neighbor distance, and the accuracy of the
  neighbors at positions 0 and 1.

  Args:
    df: Neighbor dataframe containing exactly one distinct 'featurizer' value.
  """
  featurizers = df['featurizer'].unique()
  assert len(featurizers) == 1
  print(f'Dataframe stats for model {featurizers[0]}:')

  first_row = df.loc[0]
  num_neighbors = len(first_row['neighbor_image_ids'])
  print(f'Found {len(df)} samples, {num_neighbors} neighbors available.')

  row_minima = df['neighbor_distances'].apply(lambda y: np.min(y))
  row_maxima = df['neighbor_distances'].apply(lambda y: np.max(y))
  min_dist = np.min(row_minima)
  max_dist = np.max(row_maxima)
  print(f'min_dist: {min_dist}, max_dist: {max_dist}')

  # Compute both accuracies before printing either of them.
  accuracies = [calculate_accuracy(df, k, 'neighbor_classes') for k in (0, 1)]
  for k, acc in enumerate(accuracies):
    print(f'acc of neighbor {k}: {acc}')

In [None]:
def distance_normalization(df: pd.DataFrame) -> pd.DataFrame:
  """Normalize distances to [0, 1].

  Applies min-max scaling using the global minimum and maximum over all
  lists in the 'neighbor_distances' column.

  Args:
    df: Dataframe with a list-valued 'neighbor_distances' column.

  Returns:
    A new dataframe whose 'neighbor_distances' lists are rescaled to [0, 1].
    The input dataframe and the lists it holds are left untouched.

  Raises:
    AssertionError: If the rescaled values do not span [0, 1], e.g. when all
      distances are identical.
  """

  print('Normalizing distances to [0, 1]')

  distances = df['neighbor_distances']
  min_dist = np.min(distances.apply(lambda y: np.min(y)))
  max_dist = np.max(distances.apply(lambda y: np.max(y)))
  dist_range = max_dist - min_dist

  # Build NEW lists. Copying the dataframe does not copy the list objects in
  # its cells, so rescaling in place would also change the caller's data.
  normalized = [
      [(d - min_dist) / dist_range for d in row_distances]
      for row_distances in distances
  ]

  df_tmp = df.copy()
  df_tmp['neighbor_distances'] = pd.Series(normalized, index=df_tmp.index, dtype=object)

  new_min_dist = np.min(df_tmp['neighbor_distances'].apply(lambda y: np.min(y)))
  new_max_dist = np.max(df_tmp['neighbor_distances'].apply(lambda y: np.max(y)))
  assert np.isclose(new_min_dist, 0.0)
  assert np.isclose(new_max_dist, 1.0)

  return df_tmp

In [None]:
def read_scaling_df(model,
                    query_dataset='imagenet2012', query_split='validation',
                    memory_dataset='imagenet2012', memory_split='train',
                    size='full',
                    verbose=True,
                    remove_identical_neighbors=False,
                    normalize_distances=False):
  """Load the neighbor dataframe for one model / memory / query combination.

  The JSON file name is assembled from the arguments and looked up in
  `DATA_DIR`.

  Args:
    model: Featurizer identifier that appears in the file name.
    query_dataset: Dataset the query images come from.
    query_split: Split of the query dataset.
    memory_dataset: Dataset the neighbors are retrieved from.
    memory_split: Split of the memory dataset.
    size: 'full', or a value that is inserted as 'subsampled_<size>'.
    verbose: If true, print basic statistics about the loaded data.
    remove_identical_neighbors: If true, drop neighbors equal to the query.
      This is always done when both memory and query split are 'train'.
    normalize_distances: If true, rescale neighbor distances to [0, 1].

  Returns:
    The loaded (and optionally post-processed) dataframe.
  """
  allowed_memory_splits = ('train', 'validation', 'test', 'train-and-test')
  allowed_query_splits = ('train', 'validation', 'test')
  allowed_memory_datasets = (
      'imagenet2012',
      'imagenet-jft-extension-20-classes',
      'imagenet2012-and-ninco',
      'ninco',
      'jft-with-vit22b-labels',
      'inaturalist',
  )
  allowed_query_datasets = (
      'imagenet2012',
      'imagenet-v2',
      'imagenet-r',
      'imagenet-a',
      'imagenet-sketch',
      'ninco',
      'imagenet-real',
      'inaturalist',
  )

  assert memory_split in allowed_memory_splits
  assert query_split in allowed_query_splits
  assert memory_dataset in allowed_memory_datasets
  assert query_dataset in allowed_query_datasets

  size_tag = size if size == 'full' else f"subsampled_{size}"

  file_name = (
      f'memory-{memory_dataset}_msplit-{memory_split}'
      f'_query-{query_dataset}_qsplit-{query_split}'
      f'_{model}_{size_tag}_neighbor_info.json'
  )
  json_path = f'{DATA_DIR}/{file_name}'

  frame = read_neighbors_info_JSON(paths=[json_path])

  # Querying the train split against itself returns each image as its own
  # neighbor, so those entries are always stripped.
  train_against_train = memory_split == "train" and query_split == 'train'
  if train_against_train or remove_identical_neighbors:
    print('Removing neighbors identical to query')
    frame = remove_neighbors_identical_to_query(df=frame)

  if normalize_distances:
    frame = distance_normalization(df=frame)

  if verbose:
    print_df_stats(df=frame)
  return frame

In [None]:
def read_multiple_scaling_dfs(models, query_split, query_dataset, memory_dataset, memory_split, verbose=True):
  """Load the neighbor dataframes of several models and stack them.

  Args:
    models: Iterable of featurizer identifiers to load.
    query_split: Split of the query dataset.
    query_dataset: Dataset the query images come from.
    memory_dataset: Dataset the neighbors are retrieved from.
    memory_split: Split of the memory dataset.
    verbose: Forwarded to `read_scaling_df`.

  Returns:
    The per-model dataframes concatenated row-wise in the order of `models`;
    each part keeps its own index.
  """
  per_model_frames = [
      read_scaling_df(model=model_name,
                      query_split=query_split,
                      query_dataset=query_dataset,
                      memory_dataset=memory_dataset,
                      memory_split=memory_split,
                      verbose=verbose)
      for model_name in models
  ]
  return pd.concat(per_model_frames)

In [None]:
def read_pruning_data(metric_name, directory=PRUNING_METRICS_DIR):
  """Load the CSV of one pruning metric as a dataframe.

  Args:
    metric_name: Metric identifier; the file read is
      '<directory>/ImageNet-1K_<metric_name>.csv'.
    directory: Directory containing the metric CSV files.

  Returns:
    The parsed CSV as a dataframe.
  """
  csv_path = f'{directory}/ImageNet-1K_{metric_name}.csv'
  with file_opener(csv_path, 'r') as handle:
    metrics = pd.read_csv(handle)
  return metrics

## Accuracy based on class

In [None]:
synset_to_index = {'n01440764': 0, 'n01443537': 1, 'n01484850': 2, 'n01491361': 3, 'n01494475': 4, 'n01496331': 5, 'n01498041': 6, 'n01514668': 7, 'n01514859': 8, 'n01518878': 9, 'n01530575': 10, 'n01531178': 11, 'n01532829': 12, 'n01534433': 13, 'n01537544': 14, 'n01558993': 15, 'n01560419': 16, 'n01580077': 17, 'n01582220': 18, 'n01592084': 19, 'n01601694': 20, 'n01608432': 21, 'n01614925': 22, 'n01616318': 23, 'n01622779': 24, 'n01629819': 25, 'n01630670': 26, 'n01631663': 27, 'n01632458': 28, 'n01632777': 29, 'n01641577': 30, 'n01644373': 31, 'n01644900': 32, 'n01664065': 33, 'n01665541': 34, 'n01667114': 35, 'n01667778': 36, 'n01669191': 37, 'n01675722': 38, 'n01677366': 39, 'n01682714': 40, 'n01685808': 41, 'n01687978': 42, 'n01688243': 43, 'n01689811': 44, 'n01692333': 45, 'n01693334': 46, 'n01694178': 47, 'n01695060': 48, 'n01697457': 49, 'n01698640': 50, 'n01704323': 51, 'n01728572': 52, 'n01728920': 53, 'n01729322': 54, 'n01729977': 55, 'n01734418': 56, 'n01735189': 57, 'n01737021': 58, 'n01739381': 59, 'n01740131': 60, 'n01742172': 61, 'n01744401': 62, 'n01748264': 63, 'n01749939': 64, 'n01751748': 65, 'n01753488': 66, 'n01755581': 67, 'n01756291': 68, 'n01768244': 69, 'n01770081': 70, 'n01770393': 71, 'n01773157': 72, 'n01773549': 73, 'n01773797': 74, 'n01774384': 75, 'n01774750': 76, 'n01775062': 77, 'n01776313': 78, 'n01784675': 79, 'n01795545': 80, 'n01796340': 81, 'n01797886': 82, 'n01798484': 83, 'n01806143': 84, 'n01806567': 85, 'n01807496': 86, 'n01817953': 87, 'n01818515': 88, 'n01819313': 89, 'n01820546': 90, 'n01824575': 91, 'n01828970': 92, 'n01829413': 93, 'n01833805': 94, 'n01843065': 95, 'n01843383': 96, 'n01847000': 97, 'n01855032': 98, 'n01855672': 99, 'n01860187': 100, 'n01871265': 101, 'n01872401': 102, 'n01873310': 103, 'n01877812': 104, 'n01882714': 105, 'n01883070': 106, 'n01910747': 107, 'n01914609': 108, 'n01917289': 109, 'n01924916': 110, 'n01930112': 111, 'n01943899': 112, 'n01944390': 113, 'n01945685': 114, 'n01950731': 115, 
'n01955084': 116, 'n01968897': 117, 'n01978287': 118, 'n01978455': 119, 'n01980166': 120, 'n01981276': 121, 'n01983481': 122, 'n01984695': 123, 'n01985128': 124, 'n01986214': 125, 'n01990800': 126, 'n02002556': 127, 'n02002724': 128, 'n02006656': 129, 'n02007558': 130, 'n02009229': 131, 'n02009912': 132, 'n02011460': 133, 'n02012849': 134, 'n02013706': 135, 'n02017213': 136, 'n02018207': 137, 'n02018795': 138, 'n02025239': 139, 'n02027492': 140, 'n02028035': 141, 'n02033041': 142, 'n02037110': 143, 'n02051845': 144, 'n02056570': 145, 'n02058221': 146, 'n02066245': 147, 'n02071294': 148, 'n02074367': 149, 'n02077923': 150, 'n02085620': 151, 'n02085782': 152, 'n02085936': 153, 'n02086079': 154, 'n02086240': 155, 'n02086646': 156, 'n02086910': 157, 'n02087046': 158, 'n02087394': 159, 'n02088094': 160, 'n02088238': 161, 'n02088364': 162, 'n02088466': 163, 'n02088632': 164, 'n02089078': 165, 'n02089867': 166, 'n02089973': 167, 'n02090379': 168, 'n02090622': 169, 'n02090721': 170, 'n02091032': 171, 'n02091134': 172, 'n02091244': 173, 'n02091467': 174, 'n02091635': 175, 'n02091831': 176, 'n02092002': 177, 'n02092339': 178, 'n02093256': 179, 'n02093428': 180, 'n02093647': 181, 'n02093754': 182, 'n02093859': 183, 'n02093991': 184, 'n02094114': 185, 'n02094258': 186, 'n02094433': 187, 'n02095314': 188, 'n02095570': 189, 'n02095889': 190, 'n02096051': 191, 'n02096177': 192, 'n02096294': 193, 'n02096437': 194, 'n02096585': 195, 'n02097047': 196, 'n02097130': 197, 'n02097209': 198, 'n02097298': 199, 'n02097474': 200, 'n02097658': 201, 'n02098105': 202, 'n02098286': 203, 'n02098413': 204, 'n02099267': 205, 'n02099429': 206, 'n02099601': 207, 'n02099712': 208, 'n02099849': 209, 'n02100236': 210, 'n02100583': 211, 'n02100735': 212, 'n02100877': 213, 'n02101006': 214, 'n02101388': 215, 'n02101556': 216, 'n02102040': 217, 'n02102177': 218, 'n02102318': 219, 'n02102480': 220, 'n02102973': 221, 'n02104029': 222, 'n02104365': 223, 'n02105056': 224, 'n02105162': 225, 'n02105251': 226, 
'n02105412': 227, 'n02105505': 228, 'n02105641': 229, 'n02105855': 230, 'n02106030': 231, 'n02106166': 232, 'n02106382': 233, 'n02106550': 234, 'n02106662': 235, 'n02107142': 236, 'n02107312': 237, 'n02107574': 238, 'n02107683': 239, 'n02107908': 240, 'n02108000': 241, 'n02108089': 242, 'n02108422': 243, 'n02108551': 244, 'n02108915': 245, 'n02109047': 246, 'n02109525': 247, 'n02109961': 248, 'n02110063': 249, 'n02110185': 250, 'n02110341': 251, 'n02110627': 252, 'n02110806': 253, 'n02110958': 254, 'n02111129': 255, 'n02111277': 256, 'n02111500': 257, 'n02111889': 258, 'n02112018': 259, 'n02112137': 260, 'n02112350': 261, 'n02112706': 262, 'n02113023': 263, 'n02113186': 264, 'n02113624': 265, 'n02113712': 266, 'n02113799': 267, 'n02113978': 268, 'n02114367': 269, 'n02114548': 270, 'n02114712': 271, 'n02114855': 272, 'n02115641': 273, 'n02115913': 274, 'n02116738': 275, 'n02117135': 276, 'n02119022': 277, 'n02119789': 278, 'n02120079': 279, 'n02120505': 280, 'n02123045': 281, 'n02123159': 282, 'n02123394': 283, 'n02123597': 284, 'n02124075': 285, 'n02125311': 286, 'n02127052': 287, 'n02128385': 288, 'n02128757': 289, 'n02128925': 290, 'n02129165': 291, 'n02129604': 292, 'n02130308': 293, 'n02132136': 294, 'n02133161': 295, 'n02134084': 296, 'n02134418': 297, 'n02137549': 298, 'n02138441': 299, 'n02165105': 300, 'n02165456': 301, 'n02167151': 302, 'n02168699': 303, 'n02169497': 304, 'n02172182': 305, 'n02174001': 306, 'n02177972': 307, 'n02190166': 308, 'n02206856': 309, 'n02219486': 310, 'n02226429': 311, 'n02229544': 312, 'n02231487': 313, 'n02233338': 314, 'n02236044': 315, 'n02256656': 316, 'n02259212': 317, 'n02264363': 318, 'n02268443': 319, 'n02268853': 320, 'n02276258': 321, 'n02277742': 322, 'n02279972': 323, 'n02280649': 324, 'n02281406': 325, 'n02281787': 326, 'n02317335': 327, 'n02319095': 328, 'n02321529': 329, 'n02325366': 330, 'n02326432': 331, 'n02328150': 332, 'n02342885': 333, 'n02346627': 334, 'n02356798': 335, 'n02361337': 336, 'n02363005': 337, 
'n02364673': 338, 'n02389026': 339, 'n02391049': 340, 'n02395406': 341, 'n02396427': 342, 'n02397096': 343, 'n02398521': 344, 'n02403003': 345, 'n02408429': 346, 'n02410509': 347, 'n02412080': 348, 'n02415577': 349, 'n02417914': 350, 'n02422106': 351, 'n02422699': 352, 'n02423022': 353, 'n02437312': 354, 'n02437616': 355, 'n02441942': 356, 'n02442845': 357, 'n02443114': 358, 'n02443484': 359, 'n02444819': 360, 'n02445715': 361, 'n02447366': 362, 'n02454379': 363, 'n02457408': 364, 'n02480495': 365, 'n02480855': 366, 'n02481823': 367, 'n02483362': 368, 'n02483708': 369, 'n02484975': 370, 'n02486261': 371, 'n02486410': 372, 'n02487347': 373, 'n02488291': 374, 'n02488702': 375, 'n02489166': 376, 'n02490219': 377, 'n02492035': 378, 'n02492660': 379, 'n02493509': 380, 'n02493793': 381, 'n02494079': 382, 'n02497673': 383, 'n02500267': 384, 'n02504013': 385, 'n02504458': 386, 'n02509815': 387, 'n02510455': 388, 'n02514041': 389, 'n02526121': 390, 'n02536864': 391, 'n02606052': 392, 'n02607072': 393, 'n02640242': 394, 'n02641379': 395, 'n02643566': 396, 'n02655020': 397, 'n02666196': 398, 'n02667093': 399, 'n02669723': 400, 'n02672831': 401, 'n02676566': 402, 'n02687172': 403, 'n02690373': 404, 'n02692877': 405, 'n02699494': 406, 'n02701002': 407, 'n02704792': 408, 'n02708093': 409, 'n02727426': 410, 'n02730930': 411, 'n02747177': 412, 'n02749479': 413, 'n02769748': 414, 'n02776631': 415, 'n02777292': 416, 'n02782093': 417, 'n02783161': 418, 'n02786058': 419, 'n02787622': 420, 'n02788148': 421, 'n02790996': 422, 'n02791124': 423, 'n02791270': 424, 'n02793495': 425, 'n02794156': 426, 'n02795169': 427, 'n02797295': 428, 'n02799071': 429, 'n02802426': 430, 'n02804414': 431, 'n02804610': 432, 'n02807133': 433, 'n02808304': 434, 'n02808440': 435, 'n02814533': 436, 'n02814860': 437, 'n02815834': 438, 'n02817516': 439, 'n02823428': 440, 'n02823750': 441, 'n02825657': 442, 'n02834397': 443, 'n02835271': 444, 'n02837789': 445, 'n02840245': 446, 'n02841315': 447, 'n02843684': 448, 
'n02859443': 449, 'n02860847': 450, 'n02865351': 451, 'n02869837': 452, 'n02870880': 453, 'n02871525': 454, 'n02877765': 455, 'n02879718': 456, 'n02883205': 457, 'n02892201': 458, 'n02892767': 459, 'n02894605': 460, 'n02895154': 461, 'n02906734': 462, 'n02909870': 463, 'n02910353': 464, 'n02916936': 465, 'n02917067': 466, 'n02927161': 467, 'n02930766': 468, 'n02939185': 469, 'n02948072': 470, 'n02950826': 471, 'n02951358': 472, 'n02951585': 473, 'n02963159': 474, 'n02965783': 475, 'n02966193': 476, 'n02966687': 477, 'n02971356': 478, 'n02974003': 479, 'n02977058': 480, 'n02978881': 481, 'n02979186': 482, 'n02980441': 483, 'n02981792': 484, 'n02988304': 485, 'n02992211': 486, 'n02992529': 487, 'n02999410': 488, 'n03000134': 489, 'n03000247': 490, 'n03000684': 491, 'n03014705': 492, 'n03016953': 493, 'n03017168': 494, 'n03018349': 495, 'n03026506': 496, 'n03028079': 497, 'n03032252': 498, 'n03041632': 499, 'n03042490': 500, 'n03045698': 501, 'n03047690': 502, 'n03062245': 503, 'n03063599': 504, 'n03063689': 505, 'n03065424': 506, 'n03075370': 507, 'n03085013': 508, 'n03089624': 509, 'n03095699': 510, 'n03100240': 511, 'n03109150': 512, 'n03110669': 513, 'n03124043': 514, 'n03124170': 515, 'n03125729': 516, 'n03126707': 517, 'n03127747': 518, 'n03127925': 519, 'n03131574': 520, 'n03133878': 521, 'n03134739': 522, 'n03141823': 523, 'n03146219': 524, 'n03160309': 525, 'n03179701': 526, 'n03180011': 527, 'n03187595': 528, 'n03188531': 529, 'n03196217': 530, 'n03197337': 531, 'n03201208': 532, 'n03207743': 533, 'n03207941': 534, 'n03208938': 535, 'n03216828': 536, 'n03218198': 537, 'n03220513': 538, 'n03223299': 539, 'n03240683': 540, 'n03249569': 541, 'n03250847': 542, 'n03255030': 543, 'n03259280': 544, 'n03271574': 545, 'n03272010': 546, 'n03272562': 547, 'n03290653': 548, 'n03291819': 549, 'n03297495': 550, 'n03314780': 551, 'n03325584': 552, 'n03337140': 553, 'n03344393': 554, 'n03345487': 555, 'n03347037': 556, 'n03355925': 557, 'n03372029': 558, 'n03376595': 559, 
'n03379051': 560, 'n03384352': 561, 'n03388043': 562, 'n03388183': 563, 'n03388549': 564, 'n03393912': 565, 'n03394916': 566, 'n03400231': 567, 'n03404251': 568, 'n03417042': 569, 'n03424325': 570, 'n03425413': 571, 'n03443371': 572, 'n03444034': 573, 'n03445777': 574, 'n03445924': 575, 'n03447447': 576, 'n03447721': 577, 'n03450230': 578, 'n03452741': 579, 'n03457902': 580, 'n03459775': 581, 'n03461385': 582, 'n03467068': 583, 'n03476684': 584, 'n03476991': 585, 'n03478589': 586, 'n03481172': 587, 'n03482405': 588, 'n03483316': 589, 'n03485407': 590, 'n03485794': 591, 'n03492542': 592, 'n03494278': 593, 'n03495258': 594, 'n03496892': 595, 'n03498962': 596, 'n03527444': 597, 'n03529860': 598, 'n03530642': 599, 'n03532672': 600, 'n03534580': 601, 'n03535780': 602, 'n03538406': 603, 'n03544143': 604, 'n03584254': 605, 'n03584829': 606, 'n03590841': 607, 'n03594734': 608, 'n03594945': 609, 'n03595614': 610, 'n03598930': 611, 'n03599486': 612, 'n03602883': 613, 'n03617480': 614, 'n03623198': 615, 'n03627232': 616, 'n03630383': 617, 'n03633091': 618, 'n03637318': 619, 'n03642806': 620, 'n03649909': 621, 'n03657121': 622, 'n03658185': 623, 'n03661043': 624, 'n03662601': 625, 'n03666591': 626, 'n03670208': 627, 'n03673027': 628, 'n03676483': 629, 'n03680355': 630, 'n03690938': 631, 'n03691459': 632, 'n03692522': 633, 'n03697007': 634, 'n03706229': 635, 'n03709823': 636, 'n03710193': 637, 'n03710637': 638, 'n03710721': 639, 'n03717622': 640, 'n03720891': 641, 'n03721384': 642, 'n03724870': 643, 'n03729826': 644, 'n03733131': 645, 'n03733281': 646, 'n03733805': 647, 'n03742115': 648, 'n03743016': 649, 'n03759954': 650, 'n03761084': 651, 'n03763968': 652, 'n03764736': 653, 'n03769881': 654, 'n03770439': 655, 'n03770679': 656, 'n03773504': 657, 'n03775071': 658, 'n03775546': 659, 'n03776460': 660, 'n03777568': 661, 'n03777754': 662, 'n03781244': 663, 'n03782006': 664, 'n03785016': 665, 'n03786901': 666, 'n03787032': 667, 'n03788195': 668, 'n03788365': 669, 'n03791053': 670, 
'n03792782': 671, 'n03792972': 672, 'n03793489': 673, 'n03794056': 674, 'n03796401': 675, 'n03803284': 676, 'n03804744': 677, 'n03814639': 678, 'n03814906': 679, 'n03825788': 680, 'n03832673': 681, 'n03837869': 682, 'n03838899': 683, 'n03840681': 684, 'n03841143': 685, 'n03843555': 686, 'n03854065': 687, 'n03857828': 688, 'n03866082': 689, 'n03868242': 690, 'n03868863': 691, 'n03871628': 692, 'n03873416': 693, 'n03874293': 694, 'n03874599': 695, 'n03876231': 696, 'n03877472': 697, 'n03877845': 698, 'n03884397': 699, 'n03887697': 700, 'n03888257': 701, 'n03888605': 702, 'n03891251': 703, 'n03891332': 704, 'n03895866': 705, 'n03899768': 706, 'n03902125': 707, 'n03903868': 708, 'n03908618': 709, 'n03908714': 710, 'n03916031': 711, 'n03920288': 712, 'n03924679': 713, 'n03929660': 714, 'n03929855': 715, 'n03930313': 716, 'n03930630': 717, 'n03933933': 718, 'n03935335': 719, 'n03937543': 720, 'n03938244': 721, 'n03942813': 722, 'n03944341': 723, 'n03947888': 724, 'n03950228': 725, 'n03954731': 726, 'n03956157': 727, 'n03958227': 728, 'n03961711': 729, 'n03967562': 730, 'n03970156': 731, 'n03976467': 732, 'n03976657': 733, 'n03977966': 734, 'n03980874': 735, 'n03982430': 736, 'n03983396': 737, 'n03991062': 738, 'n03992509': 739, 'n03995372': 740, 'n03998194': 741, 'n04004767': 742, 'n04005630': 743, 'n04008634': 744, 'n04009552': 745, 'n04019541': 746, 'n04023962': 747, 'n04026417': 748, 'n04033901': 749, 'n04033995': 750, 'n04037443': 751, 'n04039381': 752, 'n04040759': 753, 'n04041544': 754, 'n04044716': 755, 'n04049303': 756, 'n04065272': 757, 'n04067472': 758, 'n04069434': 759, 'n04070727': 760, 'n04074963': 761, 'n04081281': 762, 'n04086273': 763, 'n04090263': 764, 'n04099969': 765, 'n04111531': 766, 'n04116512': 767, 'n04118538': 768, 'n04118776': 769, 'n04120489': 770, 'n04125021': 771, 'n04127249': 772, 'n04131690': 773, 'n04133789': 774, 'n04136333': 775, 'n04141076': 776, 'n04141327': 777, 'n04141975': 778, 'n04146614': 779, 'n04147183': 780, 'n04149813': 781, 
'n04152593': 782, 'n04153751': 783, 'n04154565': 784, 'n04162706': 785, 'n04179913': 786, 'n04192698': 787, 'n04200800': 788, 'n04201297': 789, 'n04204238': 790, 'n04204347': 791, 'n04208210': 792, 'n04209133': 793, 'n04209239': 794, 'n04228054': 795, 'n04229816': 796, 'n04235860': 797, 'n04238763': 798, 'n04239074': 799, 'n04243546': 800, 'n04251144': 801, 'n04252077': 802, 'n04252225': 803, 'n04254120': 804, 'n04254680': 805, 'n04254777': 806, 'n04258138': 807, 'n04259630': 808, 'n04263257': 809, 'n04264628': 810, 'n04265275': 811, 'n04266014': 812, 'n04270147': 813, 'n04273569': 814, 'n04275548': 815, 'n04277352': 816, 'n04285008': 817, 'n04286575': 818, 'n04296562': 819, 'n04310018': 820, 'n04311004': 821, 'n04311174': 822, 'n04317175': 823, 'n04325704': 824, 'n04326547': 825, 'n04328186': 826, 'n04330267': 827, 'n04332243': 828, 'n04335435': 829, 'n04336792': 830, 'n04344873': 831, 'n04346328': 832, 'n04347754': 833, 'n04350905': 834, 'n04355338': 835, 'n04355933': 836, 'n04356056': 837, 'n04357314': 838, 'n04366367': 839, 'n04367480': 840, 'n04370456': 841, 'n04371430': 842, 'n04371774': 843, 'n04372370': 844, 'n04376876': 845, 'n04380533': 846, 'n04389033': 847, 'n04392985': 848, 'n04398044': 849, 'n04399382': 850, 'n04404412': 851, 'n04409515': 852, 'n04417672': 853, 'n04418357': 854, 'n04423845': 855, 'n04428191': 856, 'n04429376': 857, 'n04435653': 858, 'n04442312': 859, 'n04443257': 860, 'n04447861': 861, 'n04456115': 862, 'n04458633': 863, 'n04461696': 864, 'n04462240': 865, 'n04465501': 866, 'n04467665': 867, 'n04476259': 868, 'n04479046': 869, 'n04482393': 870, 'n04483307': 871, 'n04485082': 872, 'n04486054': 873, 'n04487081': 874, 'n04487394': 875, 'n04493381': 876, 'n04501370': 877, 'n04505470': 878, 'n04507155': 879, 'n04509417': 880, 'n04515003': 881, 'n04517823': 882, 'n04522168': 883, 'n04523525': 884, 'n04525038': 885, 'n04525305': 886, 'n04532106': 887, 'n04532670': 888, 'n04536866': 889, 'n04540053': 890, 'n04542943': 891, 'n04548280': 892, 
'n04548362': 893, 'n04550184': 894, 'n04552348': 895, 'n04553703': 896, 'n04554684': 897, 'n04557648': 898, 'n04560804': 899, 'n04562935': 900, 'n04579145': 901, 'n04579432': 902, 'n04584207': 903, 'n04589890': 904, 'n04590129': 905, 'n04591157': 906, 'n04591713': 907, 'n04592741': 908, 'n04596742': 909, 'n04597913': 910, 'n04599235': 911, 'n04604644': 912, 'n04606251': 913, 'n04612504': 914, 'n04613696': 915, 'n06359193': 916, 'n06596364': 917, 'n06785654': 918, 'n06794110': 919, 'n06874185': 920, 'n07248320': 921, 'n07565083': 922, 'n07579787': 923, 'n07583066': 924, 'n07584110': 925, 'n07590611': 926, 'n07613480': 927, 'n07614500': 928, 'n07615774': 929, 'n07684084': 930, 'n07693725': 931, 'n07695742': 932, 'n07697313': 933, 'n07697537': 934, 'n07711569': 935, 'n07714571': 936, 'n07714990': 937, 'n07715103': 938, 'n07716358': 939, 'n07716906': 940, 'n07717410': 941, 'n07717556': 942, 'n07718472': 943, 'n07718747': 944, 'n07720875': 945, 'n07730033': 946, 'n07734744': 947, 'n07742313': 948, 'n07745940': 949, 'n07747607': 950, 'n07749582': 951, 'n07753113': 952, 'n07753275': 953, 'n07753592': 954, 'n07754684': 955, 'n07760859': 956, 'n07768694': 957, 'n07802026': 958, 'n07831146': 959, 'n07836838': 960, 'n07860988': 961, 'n07871810': 962, 'n07873807': 963, 'n07875152': 964, 'n07880968': 965, 'n07892512': 966, 'n07920052': 967, 'n07930864': 968, 'n07932039': 969, 'n09193705': 970, 'n09229709': 971, 'n09246464': 972, 'n09256479': 973, 'n09288635': 974, 'n09332890': 975, 'n09399592': 976, 'n09421951': 977, 'n09428293': 978, 'n09468604': 979, 'n09472597': 980, 'n09835506': 981, 'n10148035': 982, 'n10565667': 983, 'n11879895': 984, 'n11939491': 985, 'n12057211': 986, 'n12144580': 987, 'n12267677': 988, 'n12620546': 989, 'n12768682': 990, 'n12985857': 991, 'n12998815': 992, 'n13037406': 993, 'n13040303': 994, 'n13044778': 995, 'n13052670': 996, 'n13054560': 997, 'n13133613': 998, 'n15075141': 999, 'test.py': 1000}

In [None]:
# Inverse lookup of `synset_to_index`: integer index -> synset key.
index_to_synset = dict(zip(synset_to_index.values(), synset_to_index.keys()))

In [None]:
def plurality_element(lst):
  """Return the most frequent element of `lst`.

  Ties go to the element that occurs first in `lst`. An empty `lst` raises
  IndexError.
  """
  top_element, _ = Counter(lst).most_common(1)[0]
  return top_element

In [None]:
def get_accuracy_based_on_class(df, max_num_neighbors = 1):
  """Compute the per-class accuracy of a plurality vote over the neighbors.

  Each row is predicted as the most frequent class among its first
  `max_num_neighbors` entries of 'neighbor_classes'; the prediction is
  correct when it equals the row's 'image_class'.

  Args:
    df: Dataframe with 'image_class' and list-valued 'neighbor_classes'.
    max_num_neighbors: Number of leading neighbors that take part in the vote.

  Returns:
    List of (class, accuracy) tuples, sorted by ascending accuracy.
  """
  hits_per_class = {}

  def _record(row):
    true_class = row['image_class']
    predicted = plurality_element(row['neighbor_classes'][:max_num_neighbors])
    is_hit = 1 if true_class == predicted else 0
    hits_per_class.setdefault(true_class, []).append(is_hit)

  # Used only for its side effect of filling `hits_per_class`.
  df.apply(_record, axis=1)

  class_to_acc = {cls: np.mean(hits) for cls, hits in hits_per_class.items()}
  return sorted(class_to_acc.items(), key=lambda item: item[1])

## Reliability

In [None]:
def get_reliability_at_k(df: pd.DataFrame, max_num_neighbors: int = 100) -> pd.DataFrame:
  """Calculate the reliability at k (neighbor index) for each model.

  For every featurizer and every neighbor position k in
  [0, max_num_neighbors), this computes the fraction of rows whose k-th
  entry of 'neighbor_classes' equals 'image_class'.

  Args:
    df: Dataframe with 'featurizer', 'image_class' and list-valued
      'neighbor_classes' columns.
    max_num_neighbors: Number of leading neighbor positions to evaluate;
      every row must hold at least this many neighbors.

  Returns:
    Dataframe with columns 'featurizer', 'k_index' and 'accuracy_at_k', one
    row per (featurizer, k). 'k_index' is 1-based.
  """
  columns = ['featurizer', 'k_index', 'accuracy_at_k']
  rows = []

  for featurizer in df['featurizer'].unique():
    featurizer_df = df.loc[df['featurizer'] == featurizer]
    # Pull the two columns out once instead of a row-wise apply per k.
    image_classes = featurizer_df['image_class'].tolist()
    neighbor_classes = featurizer_df['neighbor_classes'].tolist()
    num_rows = len(image_classes)

    for k in range(max_num_neighbors):
      num_correct = sum(
          1 for image_class, neighbors in zip(image_classes, neighbor_classes)
          if image_class == neighbors[k]
      )
      # Note: this is using k+1 to avoid division by zero
      # when fitting a log function to the data in plotting
      rows.append((featurizer, k + 1, num_correct / num_rows))

  # Build the frame in one go rather than growing it row by row.
  return pd.DataFrame(rows, columns=columns)


def get_reliability_data_for_pruning(models,
                                     query_split='validation', query_dataset='imagenet2012',
                                     memory_dataset='imagenet2012', memory_split='train',
                                     max_num_neighbors=100):
  """Return a mapping from neighbor index to neighbor accuracy for one model.

  Args:
    models: List holding exactly one featurizer identifier.
    query_split: Split of the query dataset.
    query_dataset: Dataset the query images come from.
    memory_dataset: Dataset the neighbors are retrieved from.
    memory_split: Split of the memory dataset.
    max_num_neighbors: Number of leading neighbor positions to evaluate.

  Returns:
    Dict mapping the 1-based 'k_index' to 'accuracy_at_k'.
  """
  assert len(models) == 1, 'only 1 model supported at this time'

  loader_kwargs = dict(models=models,
                       query_split=query_split,
                       query_dataset=query_dataset,
                       memory_dataset=memory_dataset,
                       memory_split=memory_split)
  neighbors_df = read_multiple_scaling_dfs(**loader_kwargs)

  per_k = get_reliability_at_k(df=neighbors_df, max_num_neighbors=max_num_neighbors)
  accuracy_by_k = per_k.set_index('k_index')['accuracy_at_k']
  return accuracy_by_k.to_dict()


def plot_reliability_at_k(df: pd.DataFrame, ylim: float = 0.0, fit: bool = True, save_fig_path: str = None) -> None:
  """Plot reliability at k (neighbor index) for each model.

  Draws one accuracy-vs-neighbor-index curve per featurizer and, optionally,
  a black least-squares fit of the form a * log(k) + b on top of each curve.

  Args:
    df: Dataframe with 'featurizer', 'k_index' and 'accuracy_at_k' columns
      (accuracy as a fraction; it is plotted in percent).
    ylim: Forwarded unchanged to `plt.ylim`; callers in this file also pass a
      (bottom, top) tuple.
    fit: If true, overlay the logarithmic fit for every featurizer.
    save_fig_path: If set, the figure is written there as a PDF through
      `file_opener`; if None or empty, nothing is saved.
  """

  plt.figure(figsize=(8, 5))

  # Slice once so the data curves and the fits use the same per-model frames.
  per_featurizer = [
      (featurizer, df[df['featurizer'] == featurizer])
      for featurizer in df['featurizer'].unique()
  ]

  for featurizer, featurizer_df in per_featurizer:
    plt.plot(featurizer_df['k_index'], featurizer_df['accuracy_at_k']*100.0,
             marker='o', linestyle='-',
             linewidth=2, markersize=8,
             color=featurizer_to_color[featurizer],
             label=featurizer_to_name[featurizer])

  if fit:
    for _, featurizer_df in per_featurizer:
      log_k = np.log(featurizer_df['k_index'])
      slope, intercept = np.polyfit(log_k, featurizer_df['accuracy_at_k'], 1)
      # Separate name for the fitted values so the `fit` flag is not overwritten.
      fitted_percent = (slope * log_k + intercept) * 100.0
      plt.plot(featurizer_df['k_index'], fitted_percent, linestyle='-', linewidth=1.5, markersize=8, color='black')

  plt.ylim(ylim)
  plt.gca().spines['top'].set_visible(False)
  plt.gca().spines['right'].set_visible(False)
  plt.xticks(fontsize=12)
  plt.yticks(fontsize=12)
  plt.legend(fontsize=12)

  plt.xlabel('Neighbor index', fontsize=14)
  plt.ylabel('Neighbor accuracy (%)', fontsize=14)

  if save_fig_path:
    # Close the handle once the figure is written so the file is flushed and
    # the descriptor is not leaked.
    with file_opener(save_fig_path, 'wb') as fig_file:
      plt.savefig(fig_file, format='pdf', bbox_inches='tight', pad_inches=0)
    print(f'Saved figure to {save_fig_path}')
    print_viewing_path(save_fig_path)

In [None]:
# Load the neighbor data of all five featurizers: ImageNet validation images
# as queries against the ImageNet train split as memory.
models = ['dinov2_vitl14', 'dinov2_vitb14', 'dinov2_vits14', 'clip-vit_l14', 'clip-vit_b16']
df_combined = read_multiple_scaling_dfs(models=models, query_split='validation', query_dataset='imagenet2012', memory_dataset='imagenet2012', memory_split='train')

In [None]:
# Accuracy of the neighbor at each of the first 100 positions, per featurizer.
reliability_data = get_reliability_at_k(df=df_combined, max_num_neighbors=100)

In [None]:
# Preview the first ten rows of the reliability table.
reliability_data[:10]

In [None]:
# Plot the ImageNet reliability curves (y-axis limited to 43-83) and save the
# figure as a PDF in FIGURE_DIR.
plot_reliability_at_k(df=reliability_data,
                      ylim=(43, 83),
                      save_fig_path=f'{FIGURE_DIR}/imagenet_reliability_at_k.pdf')

In [None]:
# Same five featurizers, now with iNaturalist: validation images as queries
# against the train split as memory.
models = ['dinov2_vitl14', 'dinov2_vitb14', 'dinov2_vits14', 'clip-vit_l14', 'clip-vit_b16']
df_combined = read_multiple_scaling_dfs(models=models, query_dataset='inaturalist', query_split='validation', memory_dataset='inaturalist', memory_split='train')

In [None]:
# Accuracy of the neighbor at each of the first 100 positions, per featurizer.
reliability_data = get_reliability_at_k(df=df_combined, max_num_neighbors=100)

In [None]:
# Preview the first ten rows of the reliability table.
reliability_data[:10]

In [None]:
# Plot the iNaturalist reliability curves (y-axis limited to 0-62) and save
# the figure as a PDF in FIGURE_DIR.
plot_reliability_at_k(df=reliability_data,
                      ylim=(0, 62),
                      save_fig_path=f'{FIGURE_DIR}/inaturalist_reliability_at_k.pdf')

In [None]:
# Single model, ImageNet validation queries against the
# 'jft-with-vit22b-labels' train memory. No save_fig_path is passed, so the
# figure is only drawn, not written to disk.
models = ['dinov2_vitl14']
df_combined = read_multiple_scaling_dfs(models=models,
                                        query_split='validation', query_dataset='imagenet2012',
                                        memory_dataset='jft-with-vit22b-labels', memory_split='train')
reliability_data = get_reliability_at_k(df=df_combined, max_num_neighbors=100)
plot_reliability_at_k(df=reliability_data,
                      ylim=(70.0, 84.5))

## Calibration

In [None]:
def get_majority_element(l):
  """Return the most frequent element of `l`.

  Ties go to the element that occurs first in `l`. An empty `l` raises
  IndexError.
  """
  ranked = Counter(l).most_common()
  top_element, _ = ranked[0]
  return top_element

def get_majority_count(l):
  """Return how often the most frequent element of `l` occurs.

  Raises IndexError if `l` is empty.
  """
  return Counter(l).most_common(1)[0][1]

from scipy.stats import entropy

def get_entropy(labels, base=2):
  """Return the Shannon entropy of the empirical distribution of `labels`.

  Args:
    labels: Sequence of labels; only their relative frequencies matter.
    base: Logarithm base handed to scipy.stats.entropy (2 yields bits).

  Returns:
    Entropy of the label frequencies; 0 if all labels are identical.
  """
  label_counts = np.unique(labels, return_counts=True)[1]
  return entropy(label_counts, base=base)

def get_entropy_wrong(l):
    """Return the Shannon entropy (in bits) of the values in `l`.

    Hand-rolled counterpart of `get_entropy` using only the standard library.
    The previous version normalized by `len(values)`, i.e. by an unrelated
    global instead of the argument, which raised NameError or silently produced
    wrong probabilities; it now normalizes by `len(l)`.

    Args:
      l: Sequence of hashable values.

    Returns:
      -sum(p * log2(p)) over the relative frequencies p of the distinct values
      (0 for an empty sequence).
    """
    # Count the occurrences of each value
    counter = Counter(l)

    # Total number of values in the *argument* (not in any global list)
    total_count = len(l)

    # Relative frequency of each distinct value
    probabilities = [count / total_count for count in counter.values()]

    # Entropy using the formula: -sum(p * log2(p))
    return -sum(p * math.log2(p) for p in probabilities)

def get_fraction_of_majority_class(df: pd.DataFrame, max_num_neighbors: int = 100) -> pd.DataFrame:
  """Compute plurality-voting accuracy as a function of the plurality count.

  For every row, the most frequent class among the first `max_num_neighbors`
  neighbors is the prediction and the number of neighbors carrying that class
  is the "count". Rows are grouped by count and the fraction of correct
  predictions is reported per (featurizer, count).

  Args:
    df: Dataframe with columns 'featurizer', 'neighbor_classes' and
      'image_class'.
    max_num_neighbors: Number of nearest neighbors considered per row.

  Returns:
    Dataframe with columns 'featurizer', 'count' and 'accuracy' (a fraction in
    [0, 1]). Only counts that actually occur are listed, in ascending order
    per featurizer.
  """
  rows = []

  for featurizer in tqdm(df['featurizer'].unique()):
    featurizer_df = df.loc[df['featurizer'] == featurizer]

    # Maps plurality count -> list of 0/1 correctness flags.
    count_to_acc = {}

    for neighbor_classes, ground_truth in zip(featurizer_df['neighbor_classes'],
                                              featurizer_df['image_class']):
      neighbors = first_k(neighbor_classes, k=max_num_neighbors)
      # most_common(1) breaks ties in favor of the class seen first, exactly
      # like get_majority_element / get_majority_count.
      elem, count = Counter(neighbors).most_common(1)[0]
      count_to_acc.setdefault(count, []).append(1 if ground_truth == elem else 0)

    # A count lies in 1..max_num_neighbors inclusive. The upper bound must be
    # included, otherwise the unanimous rows (all neighbors agree) are dropped.
    for c in range(1, max_num_neighbors + 1):
      if c in count_to_acc:
        rows.append({'featurizer': featurizer, 'count': c, 'accuracy': np.mean(count_to_acc[c])})

  return pd.DataFrame(rows, columns=['featurizer', 'count', 'accuracy'])

def plot_accuracy_from_count(df, ylim=(0, 100), save_fig_path = None):
  """Plot plurality-voting accuracy against the count of the plurality class.

  Args:
    df: Dataframe with columns 'featurizer', 'count' and 'accuracy' (accuracy
      as a fraction; it is scaled to percent for plotting).
    ylim: y-axis limits in percent.
    save_fig_path: If set, the figure is additionally written there as PDF.
  """

  plt.figure(figsize=(8, 5))

  for featurizer in df['featurizer'].unique():
    featurizer_df = df[df['featurizer'] == featurizer]

    plt.plot(featurizer_df['count'], featurizer_df['accuracy']*100.0,
             marker='o', linestyle='-',
             linewidth=2, markersize=8,
             color=featurizer_to_color[featurizer], label=featurizer)

  plt.ylim(ylim)
  plt.gca().spines['top'].set_visible(False)
  plt.gca().spines['right'].set_visible(False)
  plt.xticks(fontsize=12)
  plt.yticks(fontsize=12)
  # The legend is created before the diagonal is drawn so that the reference
  # line does not get a legend entry.
  plt.legend(fontsize=12)

  # Reference diagonal from (0, 0) to (100, 100)
  plt.plot([0, 100], [0, 100], color='black')

  plt.xlabel('Count of plurality class', fontsize=14)
  plt.ylabel('Plurality voting accuracy (%)', fontsize=14)

  if save_fig_path:
    # Close the handle deterministically so the PDF is fully flushed to disk
    # instead of relying on garbage collection of the file object.
    with file_opener(save_fig_path, 'wb') as f:
      plt.savefig(f, format='pdf', bbox_inches='tight')
    print(f'Saved figure to {save_fig_path}')
    print_viewing_path(save_fig_path)

In [None]:
# Read neighbor data for a single featurizer ('dinov2_vitl14'): ImageNet
# validation queries retrieved against the ImageNet train split as memory.
df = read_scaling_df(model='dinov2_vitl14',
                     query_split='validation', query_dataset='imagenet2012',
                     memory_dataset='imagenet2012', memory_split='train')

In [None]:
# Example usage: entropy (base 2, i.e. in bits) of a label list in which
# eight of the nine labels agree.
values = [795, 795, 795, 970, 795, 795, 795, 795, 795]
print("Entropy:", get_entropy(values))

In [None]:
# Per row: entropy of the class labels of all stored neighbors (the full
# neighbor list is used here, without truncation via first_k).
df['entropy'] = df.apply(lambda x: get_entropy(x.neighbor_classes), axis=1)

In [None]:
# Largest per-row neighbor-label entropy.
max(df['entropy'])

In [None]:
# Smallest per-row neighbor-label entropy.
min(df['entropy'])

In [None]:
def plot_entropy_vs_accuracy(df, featurizer, save_fig_path=None):
    """Plot plurality-voting accuracy against entropy-based confidence.

    Confidence is the "reverse normalized entropy" `1 - entropy / max(entropy)`
    of each row's neighbor labels, grouped into 20 bins of width 0.05. For each
    bin the fraction of rows whose plurality class equals the true class is
    plotted at the bin midpoint.

    Args:
      df: Dataframe with columns 'entropy', 'neighbor_classes' and
        'image_class'. It is not modified.
      featurizer: Key into `featurizer_to_color`; also used as legend label.
      save_fig_path: If set, the figure is written there as PDF before it is
        shown.
    """
    bin_width = 0.05
    bins = np.linspace(0.0, 1.0, 21)  # 0, 0.05, 0.1, ..., 1.0
    # pd.cut(..., right=False) builds half-open bins [left, right). Nudge the
    # last edge just above 1.0 so that rows with confidence exactly 1.0 (zero
    # entropy, i.e. all neighbors agree) land in the last bin instead of being
    # dropped as NaN.
    bins[-1] = np.nextafter(1.0, 2.0)

    max_entropy = df['entropy'].max()
    if max_entropy > 0:
        confidence = 1 - df['entropy'] / max_entropy
    else:
        # All rows have zero entropy: everything is maximally confident.
        confidence = pd.Series(1.0, index=df.index)

    # Whether the plurality class of the neighbors matches the true class.
    is_correct = pd.Series(
        [get_majority_element(neighbor_classes) == image_class
         for neighbor_classes, image_class in zip(df['neighbor_classes'], df['image_class'])],
        index=df.index, dtype=bool)

    # Mean accuracy per confidence bin; observed=False keeps empty bins (as NaN)
    # so that bins and midpoints stay aligned.
    binned_data = is_correct.groupby(pd.cut(confidence, bins, right=False), observed=False).mean()

    # Convert interval index to midpoint for plotting
    midpoints = [interval.left + bin_width / 2 for interval in binned_data.index]

    # Plot accuracy vs. confidence
    plt.plot(midpoints, binned_data.values * 100, marker='o',
             linewidth=2, markersize=8,
             color=featurizer_to_color[featurizer], label=featurizer)

    # Reference diagonal from (0, 0 %) to (1, 100 %)
    plt.plot([0, 1], [0, 100], color='black')

    # Set labels
    plt.xlabel('Reverse normalized entropy', fontsize=14)
    plt.ylabel('Accuracy (%)', fontsize=14)

    # Remove top and right frame
    plt.gca().spines['top'].set_visible(False)
    plt.gca().spines['right'].set_visible(False)

    # Remove gridlines
    plt.grid(False)

    # Set axis limits
    plt.xlim(0, 1)
    plt.ylim(0, 100)
    plt.xticks(fontsize=12)
    plt.yticks(fontsize=12)
    plt.legend(fontsize=12)

    if save_fig_path:
        # Save before plt.show(), and close the handle deterministically.
        with file_opener(save_fig_path, 'wb') as f:
            plt.savefig(f, format='pdf', bbox_inches='tight')
        print(f'Saved figure to {save_fig_path}')
        print_viewing_path(save_fig_path)

    plt.show()

In [None]:
# Per row of the combined dataframe: entropy of the class labels of all
# stored neighbors; read by plot_entropy_vs_accuracy via the 'entropy' column.
df_combined['entropy'] = df_combined.apply(lambda x: get_entropy(x.neighbor_classes), axis=1)

In [None]:
# Calibration plot for 'dinov2_vitl14'. (A stray closing parenthesis after the
# `featurizer` argument previously made this cell a syntax error.)
plot_entropy_vs_accuracy(df_combined.loc[df_combined['featurizer'] == 'dinov2_vitl14'],
                         featurizer='dinov2_vitl14',
                         save_fig_path=f'{FIGURE_DIR}/calibration_entropy_vs_accuracy.pdf')

In [None]:
# Calibration plot for 'dinov2_vitb14'. The `featurizer` argument now matches
# the rows being plotted (it selects color and legend label), and the figure
# gets its own file name so it cannot overwrite the plot of another model.
plot_entropy_vs_accuracy(df_combined.loc[df_combined['featurizer'] == 'dinov2_vitb14'],
                         featurizer='dinov2_vitb14',
                         save_fig_path=f'{FIGURE_DIR}/calibration_entropy_vs_accuracy_dinov2_vitb14.pdf')

In [None]:
# Calibration plot for 'dinov2_vits14'. The `featurizer` argument now matches
# the rows being plotted (it selects color and legend label), and the figure
# gets its own file name so it cannot overwrite the plot of another model.
plot_entropy_vs_accuracy(df_combined.loc[df_combined['featurizer'] == 'dinov2_vits14'],
                         featurizer='dinov2_vits14',
                         save_fig_path=f'{FIGURE_DIR}/calibration_entropy_vs_accuracy_dinov2_vits14.pdf')

In [None]:
# Plurality-voting accuracy grouped by plurality count, for every featurizer
# in df_combined and the first 100 neighbors.
result_df = get_fraction_of_majority_class(df=df_combined, max_num_neighbors=100)

In [None]:
# Display the resulting (featurizer, count, accuracy) table.
result_df

In [None]:
# One figure per featurizer, each saved to its own PDF (the featurizer name is
# part of the file name).
for featurizer in result_df['featurizer'].unique():
  plot_accuracy_from_count(df=result_df.loc[result_df['featurizer'] == featurizer],
                          ylim=(0, 100),
                          save_fig_path=f'{FIGURE_DIR}/calibration_accuracy_from_count_{featurizer}.pdf')

## Defining different aggregation methods

In [None]:
def first_k(x: pd.Series, k: int) -> pd.Series:
    """Return first k elements of x; raise error if k is too large.

    This makes sure that we notice if x is too short (for whatever reason)
    since [1,2,3][:42] would simply return [1,2,3] instead of raising an error.

    Args:
      x: Any sized object supporting slicing (e.g. a list, array or Series).
      k: Number of leading elements to return; must be positive.

    Returns:
      `x[:k]`.

    Raises:
      AssertionError: If `k` is not positive or `x` has fewer than `k` elements.
    """
    # Raised explicitly (rather than via `assert`) so the check also runs under
    # `python -O`; AssertionError is kept for backward compatibility.
    if k <= 0:
        raise AssertionError(f'k must be positive, got k={k}')
    if len(x) < k:
        raise AssertionError(f'Requested the first {k} elements, but input only has {len(x)}')
    return x[:k]

In [None]:
def plot_class_accuracy(df, max_num_neighbors=100):
  """For a given dataframe and max_num_neighbors, plot accuracy for each class.

  Here, max_num_neighbors determines the maximum number of neighbors that were
  considered when aggregating predictions.

  Args:
    df: Dataframe with columns 'image_class' and 'prediction_at_k', where
      'prediction_at_k' holds, per row, a sequence whose entry at index k-1 is
      the prediction obtained from the first k neighbors.
    max_num_neighbors: The k whose predictions are evaluated (>= 1).
  """

  assert max_num_neighbors >= 1

  accuracies_per_class = {}
  for image_class, class_df in df.groupby('image_class'):

    correct_predictions = class_df['prediction_at_k'].apply(lambda x: x[max_num_neighbors-1] == image_class)
    accuracies_per_class[image_class] = correct_predictions.mean() * 100

  sorted_accuracies = sorted(accuracies_per_class.values())

  plt.figure(figsize=(10, 6))
  # Derive the x-axis from the number of classes actually present instead of
  # assuming exactly 1000 classes.
  plt.plot(range(len(sorted_accuracies)), sorted_accuracies)

  plt.xlabel('Sorted class index')
  plt.ylabel('Class-conditional accuracy (%)')
  # No legend: the single curve carries no label, so a legend would be empty.
  plt.show()

In [None]:
def get_acc_at_k_table(df, aggregators, k_list):
  """Print LaTeX table of accuracy @k nearest neighbors for aggregators.

  Args:
    df: Dataframe for a single featurizer, as expected by `get_acc`.
    aggregators: Aggregator objects providing `predict` and `get_name`.
    k_list: Neighbor counts (1-based) to report, one table column each.

  The best value of every @k column is typeset in bold.
  """

  k_cols = [f'@{k}' for k in k_list]
  cols = ['Aggregation'] + k_cols
  max_k = max(k_list)

  # Collect one row per aggregator and build the frame once at the end
  # (avoids repeatedly concatenating onto an empty dataframe).
  rows = []
  for aggregator in tqdm(aggregators):
    accuracies = get_acc(df=df, aggregator=aggregator, max_num_neighbors=max_k)
    row = {'Aggregation': aggregator.get_name()}
    for k, col in zip(k_list, k_cols):
      # get_acc is keyed by 0-based index: entry k-1 uses the first k neighbors.
      row[col] = accuracies[k-1]
    rows.append(row)
  df_at_k = pd.DataFrame(rows, columns=cols)

  def format_numbers(y, num_digits=1):
    return f'{y:.{num_digits}f}'

  # Per column: bold the entry equal to the column maximum. `m=m` binds the
  # current maximum at definition time (avoids the late-binding closure trap).
  formatters = {}
  for c in k_cols:
    m = max(df_at_k[c])
    formatters[c] = lambda y, m=m: "\\textbf{" + format_numbers(y) + "}" if y == m else format_numbers(y)

  latex_table = df_at_k.to_latex(escape=False, formatters=formatters,
                                 float_format="%.1f", index=False)
  print(latex_table)

In [None]:
def plot_neighbor_scaling(df: pd.DataFrame,
                          max_num_neighbors: int,
                          aggregators,
                          ylim = (71.0, 77.0),
                          label_aggregator=True,
                          save_fig_path: str = None) -> pd.DataFrame:
  """Plot accuracy as a function of number of nearest neighbors.

  One curve is drawn per (aggregator, featurizer) combination.

  Args:
    df: Dataframe with columns 'featurizer', 'neighbor_classes',
      'neighbor_distances' and 'neighbor_image_ids' (plus whatever
      `calculate_accuracy` reads). A 'prediction_at_k' column is added in place.
    max_num_neighbors: Largest number of neighbors (k) to evaluate.
    aggregators: Aggregator objects providing `predict` and `get_name`.
    ylim: y-axis limits in percent.
    label_aggregator: If True, curves are colored/labelled by aggregator;
      otherwise by featurizer.
    save_fig_path: If set, the figure is additionally written there as PDF.

  Returns:
    `df`, whose 'prediction_at_k' column holds the predictions of the last
    aggregator in `aggregators`.
  """

  plt.figure(figsize=(8, 5))
  colors  = mpl.colormaps['tab10'].colors
  counter = 0
  k_list = list(range(1, max_num_neighbors + 1))

  for aggregator in aggregators:
    aggregator_name = aggregator.get_name()

    # aggregate results for the first k neighbors and store in list
    df['prediction_at_k'] = df.apply(
          lambda x: aggregator.predict(predictions=first_k(x.neighbor_classes, max_num_neighbors),
                                       distances=first_k(x.neighbor_distances, max_num_neighbors),
                                       neighbor_image_ids=first_k(x.neighbor_image_ids, max_num_neighbors),
                                       featurizer=x.featurizer),
          axis=1)

    for featurizer in df['featurizer'].unique():
      featurizer_df = copy.deepcopy(df.loc[df['featurizer'] == featurizer])

      accuracy_at_k = []
      for k in range(max_num_neighbors):
        accuracy = calculate_accuracy(featurizer_df, k, 'prediction_at_k')
        accuracy_at_k.append(100.0 * accuracy)
        # '\r' keeps the progress output on a single, continuously updated line.
        print(f'\rCalculated {aggregator_name} accuracy for {featurizer} at k={k}: {100.0 * accuracy}', end='', flush=True)

      print('\r', end='', flush=True)
      print(f'Max {aggregator_name} accuracy for {featurizer}: {np.round(np.max(accuracy_at_k), 3)} for k={k_list[np.argmax(accuracy_at_k)]}')

      if label_aggregator:
        if aggregator_name in aggregator_to_color:
          color = aggregator_to_color[aggregator_name]
        else:
          # Wrap around so that more curves than palette entries do not raise
          # an IndexError.
          color = colors[counter % len(colors)]
        label = aggregator_name
      else:
        color = featurizer_to_color[featurizer]
        label = featurizer_to_name[featurizer]

      plt.plot(k_list, accuracy_at_k, marker='o', linestyle='-',
              linewidth=2, markersize=8, color=color, label=label)
      counter += 1

  plt.gca().set_ylim(ylim)

  plt.gca().spines['top'].set_visible(False)
  plt.gca().spines['right'].set_visible(False)
  plt.xticks(fontsize=12)
  plt.yticks(fontsize=12)
  plt.legend(fontsize=12)

  plt.xlabel('Number of neighbors (k)', fontsize=14)
  plt.ylabel('Top-1 accuracy (%)', fontsize=14)

  if save_fig_path:
    # Close the handle deterministically so the PDF is fully flushed to disk.
    with file_opener(save_fig_path, 'wb') as f:
      plt.savefig(f, format='pdf', bbox_inches='tight', pad_inches=0)
    print(f'Saved figure to {save_fig_path}')
    print_viewing_path(save_fig_path)

  return df

In [None]:
class PredictionAggregation():
  """Base class for aggregating predictions from nearest neighbors.

  Subclasses implement `index_to_weight`; `predict` then performs weighted
  voting over the neighbor labels and returns the winning class after each
  additional neighbor.
  """

  def __init__(self):
    # Maps featurizer name -> collection of neighbor image ids to ignore.
    self.exclude_sets = dict()
    # Optional hyperparameter shown by get_name(state_hyperparam=True).
    self.hyperparam = None

  def index_to_weight(self, index: int, *args, **kwargs) -> float:
    """Return the voting weight of the neighbor at position `index`."""
    raise NotImplementedError()

  def add_exclude_sets(self, exclude_sets):
    """Register per-featurizer sets of neighbor image ids to skip in `predict`."""
    self.exclude_sets = exclude_sets

  def get_name(self, state_hyperparam=False) -> str:
    """Return the class name, optionally followed by the hyperparameter."""
    # `is not None` so that a legitimate hyperparameter of 0 / 0.0 is shown.
    if state_hyperparam and self.hyperparam is not None:
      return f'{self.__class__.__name__} ({self.hyperparam})'
    return self.__class__.__name__

  def predict(self,
              predictions: list,
              neighbor_image_ids: list,
              featurizer: str,
              *args, **kwargs) -> list[str]:
    """Return the weighted-vote winner after each of the first k neighbors.

    Args:
      predictions: Class labels of the neighbors, nearest first.
      neighbor_image_ids: Image ids of the neighbors, aligned with
        `predictions`.
      featurizer: Selects the exclude set registered via `add_exclude_sets`.
      *args, **kwargs: Forwarded to `index_to_weight` (e.g. `distances`).

    Returns:
      List of the same length as `predictions`; entry i is the class with the
      highest accumulated weight among neighbors 0..i. On ties, the class that
      reached the weight first wins. While no neighbor has been counted yet
      (leading excluded neighbors), `predictions[0]` is used.

    Raises:
      ValueError: If the result length differs from `len(predictions)`.
    """

    predictions_at_k = []
    counts = {}
    max_count = -np.inf
    highest_weight_class = None
    i_count = 0

    # In case add_exclude_sets() was called, exclude 'bad' neighbors
    exclude_neighbors = featurizer in self.exclude_sets
    exclude_set = set()
    if exclude_neighbors:
      exclude_set = self.exclude_sets[featurizer]
      num_intersecting_bad_neighbors = len(set(neighbor_image_ids).intersection(exclude_set))

      # If there's no good neighbor, proceed as usual without excluding any.
      # Note that this determination is done based on all neighbors,
      # not just on the information for k=1 etc.
      if len(neighbor_image_ids) - num_intersecting_bad_neighbors <= 0:
        exclude_neighbors = False

    # Loop-invariant: make the ids available to index_to_weight implementations.
    kwargs['neighbor_image_ids'] = neighbor_image_ids

    for i, p in enumerate(predictions):

      if not (exclude_neighbors and neighbor_image_ids[i] in exclude_set):

        # NOTE(review): `i_count` is the rank among *kept* neighbors. Subclasses
        # that use `index` to look up `distances` / `neighbor_image_ids` read
        # the entry of a different neighbor once something was excluded.
        weight = self.index_to_weight(index=i_count, *args, **kwargs)

        counts[p] = counts.get(p, 0) + weight

        if counts[p] > max_count:
          max_count = counts[p]
          highest_weight_class = p

        # Note: i_count won't be increased in the event of
        # a 'bad neighbor' that needs to be excluded
        i_count += 1

      # Compare against None explicitly: a falsy class label such as 0 is a
      # valid winner and must not be replaced by predictions[0].
      if highest_weight_class is None:
        predictions_at_k.append(predictions[0])
      else:
        predictions_at_k.append(highest_weight_class)

    if len(predictions_at_k) != len(predictions):
      raise ValueError(len(predictions_at_k), len(predictions))

    return predictions_at_k

In [None]:
class PluralityVoting(PredictionAggregation):
  """Simple plurality voting, ignoring distances.

  Every neighbor contributes a vote of weight 1. The inherited `predict` only
  switches to another class when that class gets a strictly larger count, so
  if two classes are tied for first place, the one that reached the tied
  count first is returned.
  Examples:
  [0, 1, 1] -> return 1 # returning majority class (1)
  [0, 1] -> return 0 # tie; 0 reached a count of 1 first
  [2, 3, 3, 1, 1] -> return 3 # tie; 3 reached a count of 2 first
  [1, 3, 3, 1] -> return 3 # tie; 3 reached a count of 2 first
  """

  def index_to_weight(self, index: int, *args, **kwargs) -> float:
    return 1

  def get_name(self, state_hyperparam: bool = False) -> str:
    # `state_hyperparam` is accepted for compatibility with the base class
    # signature; plurality voting has no hyperparameter to report.
    return 'PluralityVoting'

In [None]:
class DistanceWeightedVoting(PredictionAggregation):
  """Weight predictions by math.exp(-distance)**exponent.

  Note that if exponent=0.0, every weight is 1, i.e. this includes
  PluralityVoting as a special case.
  If exponent=1.0, this is identical to the voting done by
  Khandelwal et al. (2020), Generalization through memorization: nearest
  neighbor language models.
  If exponent > 1.0, this 'sharpens the distribution' by giving more weight
  to low-distance neighbors, and less weight to high-distance neighbors.
  """

  def __init__(self, exponent: float = 1.0):
    super().__init__()
    self.exponent = exponent
    self.hyperparam = self.exponent

  def index_to_weight(self, index: int, distances: list, *args, **kwargs) -> float:
    """Return exp(-distances[index])**exponent."""
    return math.exp(-distances[index])**self.exponent

  def get_name(self, state_hyperparam: bool = False) -> str:
    """Return 'DistanceVoting', optionally followed by the exponent.

    The optional flag mirrors the base class signature so that
    get_name(state_hyperparam=True) works on every aggregator.
    """
    if state_hyperparam and self.hyperparam is not None:
      return f'DistanceVoting ({self.hyperparam})'
    return 'DistanceVoting'

In [None]:
class RankWeightedVoting(PredictionAggregation):
  """Weight predictions by the rank of their neighbors.

  The neighbor at (0-based) rank `index` gets weight 1 / (index + offset), so
  a larger offset flattens the weighting across ranks.
  """

  def __init__(self, offset: float = 2.0):
    super().__init__()
    assert offset >= 0.1, f'offset must be >= 0.1 to avoid division by zero, got {offset}'
    self.offset = offset
    self.hyperparam = self.offset

  def index_to_weight(self, index: int, *args, **kwargs) -> float:
    """Return 1 / (index + offset)."""
    return 1/(index + self.offset)

  def get_name(self, state_hyperparam: bool = False) -> str:
    """Return 'RankVoting', optionally followed by the offset.

    The optional flag mirrors the base class signature so that
    get_name(state_hyperparam=True) works on every aggregator.
    """
    if state_hyperparam and self.hyperparam is not None:
      return f'RankVoting ({self.hyperparam})'
    return 'RankVoting'

In [None]:
class SoftmaxWeightedVoting(PredictionAggregation):
  """Weight predictions by the softmax of distances.

  This is the default for kNN evaluation of SSL models such as DinoV2.
  High temperature will make weights more similar,
  low temperature will make weights closer to a one-hot distribution.
  """

  def __init__(self, temperature: float = 0.07):
    # the default of temperature = 0.07 comes from:
    # https://github.com/facebookresearch/dinov2/blob/main/dinov2/eval/knn.py#L91
    # Exemplary implementations:
    # https://github.com/facebookresearch/dino/blob/main/eval_knn.py#L143
    # https://github.com/facebookresearch/dinov2/blob/main/dinov2/eval/knn.py#L99
    super().__init__()
    self.temperature = temperature
    self.hyperparam = self.temperature

  def index_to_weight(self, index: int, distances: list, *args, **kwargs) -> float:
    """Get softmax weight exp((1 - distance) / temperature).

    Note that this implementation assumes distances to be normalized to [0, 1].
    It doesn't perform a full softmax: we don't need probabilities, only
    relative differences, so dividing by the sum as in the softmax isn't
    necessary.
    """

    if len(distances) > 1:
      assert distances[0] <= distances[1], 'Distances must be sorted'

    distance = distances[index]

    # This code assumes distances to be normalized to [0, 1]. Only the distance
    # that is actually used is range-checked: scanning all distances on every
    # call made voting over k neighbors O(k^2).
    assert -0.01 <= distance <= 1.01, f'Distance {distance} is not normalized to [0, 1]'

    similarity = 1 - distance
    return np.exp(similarity / self.temperature)

  def get_name(self, state_hyperparam: bool = False) -> str:
    """Return 'SoftmaxVoting', optionally followed by the temperature.

    The optional flag mirrors the base class signature so that
    get_name(state_hyperparam=True) works on every aggregator.
    """
    if state_hyperparam and self.hyperparam is not None:
      return f'SoftmaxVoting ({self.hyperparam})'
    return 'SoftmaxVoting'

In [None]:
def get_acc(df, aggregator, max_num_neighbors):
  """Return accuracy (%) for every number of neighbors up to max_num_neighbors.

  The aggregator's predictions are stored in place in df['prediction_at_k'].

  Args:
    df: Dataframe of exactly one featurizer with columns 'neighbor_classes',
      'neighbor_distances', 'neighbor_image_ids' and 'featurizer'.
    aggregator: Object whose `predict` turns neighbor information into a list
      of predictions.
    max_num_neighbors: Number of nearest neighbors passed to the aggregator.

  Returns:
    Dict mapping the 0-based index k (0 .. max_num_neighbors-1) to
    100 * calculate_accuracy(df, k, 'prediction_at_k').
  """

  assert len(df['featurizer'].unique()) == 1

  def _predict_row(row):
    return aggregator.predict(
        predictions=first_k(row.neighbor_classes, max_num_neighbors),
        distances=first_k(row.neighbor_distances, max_num_neighbors),
        neighbor_image_ids=first_k(row.neighbor_image_ids, max_num_neighbors),
        featurizer=row.featurizer)

  df['prediction_at_k'] = df.apply(_predict_row, axis=1)

  return {k: 100.0 * calculate_accuracy(df, k, 'prediction_at_k')
          for k in range(max_num_neighbors)}

def get_max_acc(df, aggregator, max_num_neighbors):
  """Return the best accuracy (%) over all neighbor counts up to max_num_neighbors."""
  accuracy_at_k = get_acc(df=df, aggregator=aggregator, max_num_neighbors=max_num_neighbors)
  return max(accuracy_at_k.values())

## Aggregation method results

In [None]:
# ImageNet: for every featurizer, compare the four aggregation methods for
# up to 100 neighbors and save one figure per featurizer.
models = ['dinov2_vitl14', 'dinov2_vitb14', 'dinov2_vits14', 'clip-vit_l14', 'clip-vit_b16']

# Hand-picked y-axis limits (accuracy in %) per featurizer.
model_to_ylim = {'dinov2_vitl14': (80.5, 84.0),
                 'dinov2_vitb14': (78.5, 82.5),
                 'dinov2_vits14': (75, 79.5),
                 'clip-vit_l14': (75, 80.5),
                 'clip-vit_b16': (68, 74.5)}

for FEATURIZER in models:
  df = read_scaling_df(model=FEATURIZER,
                       query_dataset='imagenet2012', query_split='validation',
                       memory_dataset='imagenet2012', memory_split='train')
  aggregators = [
    PluralityVoting(),
    DistanceWeightedVoting(exponent=1.0),
    SoftmaxWeightedVoting(),
    RankWeightedVoting(offset=2.0),
    ]
  # The trailing ';' keeps the notebook from echoing the returned dataframe.
  plot_neighbor_scaling(df=df,
                        max_num_neighbors=100,
                        aggregators=aggregators,
                        ylim = model_to_ylim[FEATURIZER],
                        save_fig_path=f'{FIGURE_DIR}/imagenet2012_aggregators_{FEATURIZER}_no_pruning.pdf');

In [None]:
# Load the ImageNet neighbor data of all five featurizers (validation queries,
# train memory) for the all-models comparison below.
models = ['dinov2_vitl14', 'dinov2_vitb14', 'dinov2_vits14', 'clip-vit_l14', 'clip-vit_b16']
df_combined = read_multiple_scaling_dfs(models=models, query_split='validation', query_dataset='imagenet2012', memory_dataset='imagenet2012', memory_split='train')

In [None]:
# Single aggregator, all featurizers: label_aggregator=False makes the curves
# colored and labelled by featurizer instead of by aggregator.
aggregators = [
    RankWeightedVoting(offset=2.0),
    ]
plot_neighbor_scaling(df=df_combined,
                      max_num_neighbors=100,
                      aggregators=aggregators,
                      ylim = (67, 84.0),
                      label_aggregator=False,
                      save_fig_path=f'{FIGURE_DIR}/imagenet2012_all_models_no_pruning.pdf');

In [None]:
# iNaturalist: for every featurizer, compare the four aggregation methods for
# up to 100 neighbors and save one figure per featurizer.
models = ['dinov2_vitl14', 'dinov2_vitb14', 'dinov2_vits14', 'clip-vit_l14', 'clip-vit_b16']

# Hand-picked y-axis limits (accuracy in %) per featurizer.
model_to_ylim = {'dinov2_vitl14': (48, 65),
                 'dinov2_vitb14': (48, 63),
                 'dinov2_vits14': (42, 58),
                 'clip-vit_l14': (30, 42),
                 'clip-vit_b16': (21, 24)}

for FEATURIZER in models:
  df = read_scaling_df(model=FEATURIZER,
                       query_dataset='inaturalist', query_split='validation',
                       memory_dataset='inaturalist', memory_split='train')
  aggregators = [
    PluralityVoting(),
    DistanceWeightedVoting(exponent=1.0),
    SoftmaxWeightedVoting(),
    RankWeightedVoting(offset=2.0),
    ]
  # The trailing ';' keeps the notebook from echoing the returned dataframe.
  plot_neighbor_scaling(df=df,
                        max_num_neighbors=100,
                        aggregators=aggregators,
                        ylim = model_to_ylim[FEATURIZER],
                        save_fig_path=f'{FIGURE_DIR}/inaturalist_aggregators_{FEATURIZER}_no_pruning.pdf');

In [None]:
# Load the iNaturalist neighbor data of all five featurizers (validation
# queries, train memory) for the all-models comparison below.
models = ['dinov2_vitl14', 'dinov2_vitb14', 'dinov2_vits14', 'clip-vit_l14', 'clip-vit_b16']
df_combined = read_multiple_scaling_dfs(models=models, query_dataset='inaturalist', query_split='validation', memory_dataset='inaturalist', memory_split='train')

In [None]:
# Single aggregator, all featurizers: label_aggregator=False makes the curves
# colored and labelled by featurizer instead of by aggregator.
aggregators = [
    RankWeightedVoting(offset=2.0),
    ]
plot_neighbor_scaling(df=df_combined,
                      max_num_neighbors=100,
                      aggregators=aggregators,
                      ylim = (20, 65),
                      label_aggregator=False,
                      save_fig_path=f'{FIGURE_DIR}/inaturalist_all_models_no_pruning.pdf');

In [None]:
# Print one LaTeX accuracy table per featurizer (ImageNet validation queries,
# train memory): rows are aggregation methods, columns are k = 10, 20, ..., 100.
models = ['dinov2_vitl14', 'dinov2_vitb14', 'dinov2_vits14', 'clip-vit_l14', 'clip-vit_b16']

for model in models:
  df_val = read_scaling_df(model=model,
                           query_dataset='imagenet2012', query_split='validation',
                           memory_dataset='imagenet2012', memory_split='train')
  # Blank line plus model name as a header above each printed table.
  print()
  print(model)
  get_acc_at_k_table(df=df_val,
                     aggregators=[PluralityVoting(),
                                  DistanceWeightedVoting(exponent=1.0),
                                  SoftmaxWeightedVoting(),
                                  RankWeightedVoting(offset=2.0)],
                     k_list=[10, 20, 30, 40, 50, 60, 70, 80, 90, 100])

## Hyperparameter sensitivity

In [None]:
def plot_hyperparameter_accuracy(df, aggregator_fn, hyperparams, max_num_neighbors, verbose=False, save_fig_path=None):
  """Plot accuracy as a function of aggregation hyperparameter.

  For every featurizer and every hyperparameter value, the plotted accuracy is
  the best accuracy over all neighbor counts up to `max_num_neighbors`.

  Args:
    df: Dataframe with a 'featurizer' column plus the columns `get_acc` needs.
    aggregator_fn: Callable that builds an aggregator from one hyperparameter
      value (e.g. an aggregator class).
    hyperparams: Hyperparameter values to evaluate (x-axis).
    max_num_neighbors: Largest number of neighbors considered.
    verbose: If True, print every (featurizer, hyperparameter, accuracy).
    save_fig_path: If set, the figure is additionally written there as PDF.
  """

  plt.figure(figsize=(8, 5))

  for featurizer in df['featurizer'].unique():
    # Work on a copy: get_acc writes a 'prediction_at_k' column into the
    # dataframe it receives.
    featurizer_df = copy.deepcopy(df[df['featurizer'] == featurizer])

    acc_list = []
    for hyperparam in hyperparams:
      acc = get_max_acc(df=featurizer_df, aggregator=aggregator_fn(hyperparam), max_num_neighbors=max_num_neighbors)
      acc_list.append(acc)
      if verbose:
        print(featurizer, hyperparam, acc)

    plt.plot(hyperparams, acc_list, marker='o', linestyle='-', linewidth=1.5, markersize=8,
             color=featurizer_to_color[featurizer],
             label=featurizer_to_name[featurizer])

  ax = plt.gca()
  ax.spines['top'].set_visible(False)
  ax.spines['right'].set_visible(False)
  plt.xticks(fontsize=12)
  plt.yticks(fontsize=12)
  plt.legend(fontsize=12)

  plt.xlabel('Hyperparameter', fontsize=14)
  plt.ylabel('Top-1 accuracy (%)', fontsize=14)

  if save_fig_path:
    # Close the handle deterministically so the PDF is fully flushed to disk.
    with file_opener(save_fig_path, 'wb') as f:
      plt.savefig(f, format='pdf', bbox_inches='tight', pad_inches=0.0)
    print(f'Saved figure to {save_fig_path}')
    print_viewing_path(save_fig_path)

In [None]:
# Load the ImageNet neighbor data of all five featurizers (validation queries,
# train memory) for the hyperparameter sweeps below.
models = ['dinov2_vitl14', 'dinov2_vitb14', 'dinov2_vits14', 'clip-vit_l14', 'clip-vit_b16']
df_combined = read_multiple_scaling_dfs(models=models,
                                        query_dataset='imagenet2012', query_split='validation',
                                        memory_dataset='imagenet2012', memory_split='train')

In [None]:
# Sweep the `offset` of RankWeightedVoting (its only constructor argument),
# using up to 100 neighbors.
plot_hyperparameter_accuracy(df=df_combined,
                             aggregator_fn=RankWeightedVoting,
                             hyperparams = [1.0, 1.5, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0, 17.0, 18.0, 19.0, 20.0],
                             max_num_neighbors=100,
                             save_fig_path=f'{FIGURE_DIR}/hyperparameters_RankVoting.pdf');

In [None]:
# Sweep the `exponent` of DistanceWeightedVoting.
# NOTE(review): this sweep uses max_num_neighbors=10 whereas the RankVoting
# sweep uses 100 — confirm that the difference is intentional.
plot_hyperparameter_accuracy(df=df_combined,
                             aggregator_fn=DistanceWeightedVoting,
                             hyperparams = [0.0, 2.5, 5.0, 7.5, 10.0, 12.5, 15.0, 17.5, 20.0, 22.5, 25.0, 27.5, 30, 32.5, 35, 37.5, 40],
                             max_num_neighbors=10,
                             save_fig_path=f'{FIGURE_DIR}/hyperparameters_DistanceVoting.pdf');

In [None]:
# Sweep the `temperature` of SoftmaxWeightedVoting.
# NOTE(review): this sweep uses max_num_neighbors=10 whereas the RankVoting
# sweep uses 100 — confirm that the difference is intentional.
plot_hyperparameter_accuracy(df=df_combined,
                             aggregator_fn=SoftmaxWeightedVoting,
                             hyperparams = [0.005, 0.01, 0.025, 0.05, 0.07, 0.09, 0.12, 0.15, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0],
                             max_num_neighbors=10,
                             save_fig_path=f'{FIGURE_DIR}/hyperparameters_SoftmaxVoting.pdf');

## Memory pruning plots

#### General pruning functionality

In [None]:
def get_pruning_exclude_set(featurizer_list, metric_name, frac_data_excluded):
  """Build per-featurizer exclude sets from a dataset-pruning metric.

  The `frac_data_excluded` fraction of images with the lowest `metric_name`
  scores is selected for exclusion.

  Args:
    featurizer_list: Featurizer names that should receive the exclude set.
    metric_name: Pruning metric to load; also the name of the score column.
    frac_data_excluded: Fraction of the rows to exclude.

  Returns:
    Dict mapping every featurizer to the same set of excluded image names.
  """
  scores = read_pruning_data(metric_name)

  # Scores are sorted such that high scores mean keeping example is beneficial
  # See https://github.com/rgeirhos/dataset-pruning-metrics/
  # Hence the rows to exclude are the lowest-scoring ones.
  num_rows_excluded = int(len(scores) * frac_data_excluded)
  lowest_scoring_rows = scores.sort_values(by=metric_name, ascending=True).head(num_rows_excluded)

  excluded_img_names = set(lowest_scoring_rows['img_name'])

  # Every featurizer refers to the same set object.
  return {featurizer: excluded_img_names for featurizer in featurizer_list}

In [None]:
def get_class_from_imagename(image_id, dataset='imagenet'):
  """Return the class part of an image id.

  For 'imagenet', this is everything before the first underscore of
  `image_id`. Any other dataset raises ValueError.
  """
  if dataset != 'imagenet':
    raise ValueError(f'Unknown dataset: {dataset}')
  class_name, _, _ = image_id.partition('_')
  return class_name

In [None]:
def get_wrong_and_correct_neighbors(df: pd.DataFrame, max_num_neighbors: int = 100) -> tuple[dict, dict]:
  """Return dicts of {img_id: count} for wrong and correct neighbors.

  A neighbor is "correct" for a row if its class equals the row's
  'image_class', and "wrong" otherwise. Counts are accumulated over all rows.

  Args:
    df: Non-empty dataframe with columns 'image_class', 'neighbor_classes' and
      'neighbor_image_ids'; every row must hold equally many neighbors.
    max_num_neighbors: If set, only use the first max_num_neighbors neighbors
      of every row; if None, use all of them.

  Returns:
    Tuple (wrong_neighbors, correct_neighbors), each mapping a neighbor image
    id to the number of times it occurred as a wrong / correct neighbor.

  Raises:
    AssertionError: If `df` is empty or the rows hold fewer than
      `max_num_neighbors` neighbors.
  """

  assert len(df) > 0, 'df must contain at least one row'

  # convert to np.arrays for speed; [:, :None] keeps all neighbors
  neighbor_classes = np.array(df['neighbor_classes'].tolist())[:, :max_num_neighbors]
  neighbor_image_ids = np.array(df['neighbor_image_ids'].tolist())[:, :max_num_neighbors]
  image_class = np.array(df['image_class'].tolist())

  if max_num_neighbors is None:
    max_num_neighbors = neighbor_image_ids.shape[1]

  # Broadcasting (num_rows, 1) against (num_rows, k) compares every neighbor
  # class with the class of its row.
  is_correct = image_class[:, np.newaxis] == neighbor_classes

  def _get_neighbor_count_dict(mask):
    # Count how often each neighbor image id occurs at the selected positions.
    unique_elements, counts = np.unique(neighbor_image_ids[mask], return_counts=True)
    return dict(zip(unique_elements, counts))

  wrong_neighbors = _get_neighbor_count_dict(~is_correct)
  num_wrong_neighbors = sum(wrong_neighbors.values())
  print(f'Found {len(wrong_neighbors.keys())} wrong neighbors, occurring a total of {num_wrong_neighbors} times.')

  correct_neighbors = _get_neighbor_count_dict(is_correct)
  num_correct_neighbors = sum(correct_neighbors.values())
  print(f'Found {len(correct_neighbors.keys())} correct neighbors, occurring a total of {num_correct_neighbors} times.')
  print(f'Total number of unique neighbors found: {len(set(correct_neighbors.keys()).union(set(wrong_neighbors.keys())))}, occuring a total of {num_wrong_neighbors + num_correct_neighbors} times.')

  # Sanity check: every requested neighbor slot was classified exactly once.
  # This fails if the rows hold fewer neighbors than requested, because the
  # slicing above truncates silently.
  assert int(len(df) * max_num_neighbors) == num_correct_neighbors + num_wrong_neighbors, (
      f'Expected {len(df) * max_num_neighbors} neighbors in total, '
      f'found {num_correct_neighbors + num_wrong_neighbors}')

  return wrong_neighbors, correct_neighbors

In [None]:
def get_exclude_set(neighbor_dict, threshold):
  """Return the set of neighbors whose count is at least `threshold`."""
  return {neighbor for neighbor, count in neighbor_dict.items() if count >= threshold}

In [None]:
def get_exclude_set_wrong_minus_correct_difference(wrong_neighbor_dict, correct_neighbor_dict, threshold):
  """Return neighbors that were wrong at least `threshold` times more often than correct.

  Neighbors missing from one of the dicts count as 0 there.
  """
  all_neighbors = set(wrong_neighbor_dict.keys()).union(set(correct_neighbor_dict.keys()))

  def _surplus(neighbor):
    # `or 0` turns a missing (None) or zero entry into 0.
    wrong_count = wrong_neighbor_dict.get(neighbor) or 0
    correct_count = correct_neighbor_dict.get(neighbor) or 0
    return wrong_count - correct_count

  return {neighbor for neighbor in all_neighbors if _surplus(neighbor) >= threshold}

#### Hard memory pruning

In [None]:
# Evaluation data: ImageNet validation queries against the ImageNet train
# memory, for 'dinov2_vitl14'.
df_val = read_scaling_df(model='dinov2_vitl14',
                         query_dataset='imagenet2012', query_split='validation',
                         memory_dataset='imagenet2012', memory_split='train')

In [None]:
# Train-vs-train retrieval: ImageNet *train* images are the queries against
# the ImageNet train memory; used below to find unreliable memory entries.
# NOTE(review): remove_identical_neighbors=True presumably drops each query
# image from its own neighbor list — confirm in read_scaling_df.
df_train = read_scaling_df(model='dinov2_vitl14',
                           query_dataset='imagenet2012', query_split='train',
                           memory_dataset='imagenet2012', memory_split='train',
                           remove_identical_neighbors=True)

In [None]:
# Count, per training image, how often it appeared among the first 100
# neighbors of a training query with a different class (wrong) or with the
# same class (correct).
wrong_neighbors_train, correct_neighbors_train = get_wrong_and_correct_neighbors(df=df_train, max_num_neighbors=100)

In [None]:
# PluralityVoting
# no pruning: 83.22
# hard pruning: 83.33 for k=10 and excluding 26257 neighbors based on threshold 128
# The first aggregator is the unpruned baseline; each following one ignores
# every training image that was a wrong neighbor at least `i` times.
# NOTE(review): all aggregators return the same get_name(), so their legend
# entries in the plot are indistinguishable — confirm whether that is intended.
aggregators = [PluralityVoting()]
for i in [32, 64, 128, 192, 256, 320, 384, 448, 512]:
  aggregator = PluralityVoting()
  exclude_set = get_exclude_set(neighbor_dict=wrong_neighbors_train, threshold=i)
  # Threshold and number of excluded training images.
  print(i, len(exclude_set))
  aggregator.add_exclude_sets({'dinov2_vitl14': exclude_set})
  aggregators.append(aggregator)

plot_neighbor_scaling(df=df_val,
                      max_num_neighbors=100,
                      aggregators=aggregators,
                      ylim = (80.0, 84.5));

In [None]:
# RankWeightedVoting(2.0)
# no pruning: 83.62
# hard pruning: 83.73 for k=20 and excluding 26257 neighbors based on threshold 128

# The first aggregator is the unpruned baseline; each following one ignores
# every training image that was a wrong neighbor at least `i` times.
# NOTE(review): all aggregators return the same get_name(), so their legend
# entries in the plot are indistinguishable — confirm whether that is intended.
aggregators = [RankWeightedVoting(offset=2.0)]
for i in [32, 64, 128, 192, 256, 320, 384, 448, 512]:
  aggregator = RankWeightedVoting(offset=2.0)
  exclude_set = get_exclude_set(neighbor_dict=wrong_neighbors_train, threshold=i)
  # Threshold and number of excluded training images.
  print(i, len(exclude_set))
  aggregator.add_exclude_sets({'dinov2_vitl14': exclude_set})
  aggregators.append(aggregator)

plot_neighbor_scaling(df=df_val,
                      max_num_neighbors=100,
                      aggregators=aggregators,
                      ylim = (80.0, 84.5));

#### Soft memory pruning

In [None]:
def swarm_plot(data_as_list, sample_size=100):
  """Show a seaborn swarm plot of the values in `data_as_list`.

  Swarm plots are slow for many points, so if the list holds more than
  `sample_size` values, a random subset of `sample_size` values (drawn without
  replacement) is plotted instead.
  """

  assert type(data_as_list) is list

  needs_subsampling = len(data_as_list) > sample_size
  plotted_values = (np.random.choice(data_as_list, size=sample_size, replace=False)
                    if needs_subsampling else data_as_list)

  plt.figure(figsize=(10, 6))
  sns.swarmplot(x='Value', data=pd.DataFrame({'Value': plotted_values}))
  plt.show()

In [None]:
class SoftPluralityVoting(PredictionAggregation):
  """Weight each neighbor by a per-image reliability weight.

  The weight of a neighbor is looked up by its image id in
  `image_id_to_weight`; neighbors without an entry get weight 1.0.
  """

  def __init__(self, image_id_to_weight: dict[str, float]):
    """Store the mapping from image id to soft weight."""
    super().__init__()
    self.image_id_to_weight = image_id_to_weight

  def index_to_weight(self, index: int, neighbor_image_ids, *args, **kwargs) -> float:
    """Return the soft weight of the neighbor at position `index`.

    Args:
      index: Position of the neighbor within `neighbor_image_ids`.
      neighbor_image_ids: Image ids of the retrieved neighbors.
      *args: Ignored.
      **kwargs: Ignored.

    Returns:
      The stored weight for that image id, or 1.0 if there is none.
    """
    return self.image_id_to_weight.get(neighbor_image_ids[index], 1.0)

In [None]:
class SoftRankVoting(PredictionAggregation):
  """Weight each neighbor by its rank and by a per-image reliability weight.

  A neighbor's weight is `1 / (index + offset)` multiplied by the soft weight
  stored for its image id (1.0 if the image id has no entry).
  """

  def __init__(self, image_id_to_weight: dict[str, float], offset: float = 2.0):
    """Store the soft weights and the rank offset.

    Args:
      image_id_to_weight: Mapping from image id to soft weight.
      offset: Added to the neighbor index before inverting; must be >= 0.1.
    """
    super().__init__()
    assert offset >= 0.1 # avoid division by zero
    self.offset = offset
    self.image_id_to_weight = image_id_to_weight
    # The offset is also exposed under the generic `hyperparam` attribute.
    self.hyperparam = self.offset

  def index_to_weight(self, index: int, neighbor_image_ids, *args, **kwargs) -> float:
    """Return the rank weight times the soft weight of neighbor `index`.

    Args:
      index: Position of the neighbor within `neighbor_image_ids`.
      neighbor_image_ids: Image ids of the retrieved neighbors.
      *args: Ignored.
      **kwargs: Ignored.
    """
    soft_weight = self.image_id_to_weight.get(neighbor_image_ids[index], 1.0)
    rank_weight = 1/(index + self.offset)
    return rank_weight * soft_weight

  def get_name(self):
    """Return the display name of this aggregator."""
    return 'SoftRankVoting (ours)'

In [None]:
def get_soft_weights(wrong_neighbors, constant = 1.0, dividend=1.75):
  """Turn per-key values into soft weights `dividend / (constant + value)`.

  Args:
    wrong_neighbors: Mapping from key to a numeric value.
      NOTE(review): presumably image id -> number of times the image was a
      wrong neighbor; confirm against get_wrong_and_correct_neighbors.
    constant: Added to every value before dividing.
    dividend: Numerator of the weight.

  Returns:
    A new dict with the same keys (in the same order) and the computed weights.
  """
  return {key: dividend / (constant + value)
          for key, value in wrong_neighbors.items()}

In [None]:
# Train-vs-train retrieval: ImageNet train images are both queries and memory.
# remove_identical_neighbors=True is passed for this query==memory setting.
# NOTE(review): presumably it drops the query image itself from its neighbor
# list — confirm in read_scaling_df.
df_train = read_scaling_df(model='dinov2_vitl14',
                           query_dataset='imagenet2012', query_split='train',
                           memory_dataset='imagenet2012', memory_split='train',
                           remove_identical_neighbors=True)

In [None]:
# Evaluation setting: ImageNet validation queries against ImageNet train memory.
df_val = read_scaling_df(model='dinov2_vitl14',
                         query_dataset='imagenet2012', query_split='validation',
                         memory_dataset='imagenet2012', memory_split='train')

In [None]:
# Recompute the neighbor statistics with only the first 10 neighbors per query.
# This rebinds wrong_neighbors_train / correct_neighbors_train, which the
# hard-pruning cells computed with max_num_neighbors=100.
wrong_neighbors_train, correct_neighbors_train = get_wrong_and_correct_neighbors(df=df_train, max_num_neighbors=10)

In [None]:
# Soft weights with the default constant (1.0) and dividend (1.75).
soft_weights = get_soft_weights(wrong_neighbors_train)

In [None]:
# Inspect the distribution of the raw wrong-neighbor values (random subsample).
x = list(wrong_neighbors_train.values())
swarm_plot(x, sample_size=100)

In [None]:
# Inspect the distribution of the derived soft weights (random subsample).
x = list(soft_weights.values())
swarm_plot(x, sample_size=150)

In [None]:
# Soft pruning vs. no pruning for rank-weighted voting (both use offset 2.0).
aggregators = [SoftRankVoting(image_id_to_weight=soft_weights), RankWeightedVoting(offset=2.0)]
_ = plot_neighbor_scaling(df=df_val,
                          max_num_neighbors=100,
                          aggregators=aggregators,
                          ylim = (81.0, 84.5));

In [None]:
# Soft pruning vs. no pruning for plurality voting.
aggregators = [SoftPluralityVoting(image_id_to_weight=soft_weights), PluralityVoting()]
_ = plot_neighbor_scaling(df=df_val,
                          max_num_neighbors=100,
                          aggregators=aggregators,
                          ylim = (81.0, 84.5));

In [None]:
# SoftRankVoting against the baseline aggregators; the figure is saved to
# FIGURE_DIR.
aggregators = [
    PluralityVoting(),
    DistanceWeightedVoting(exponent=1.0),
    SoftmaxWeightedVoting(),
    SoftRankVoting(image_id_to_weight=soft_weights, offset=2.0)
]

_ = plot_neighbor_scaling(df=df_val,
                          max_num_neighbors=100,
                          aggregators=aggregators,
                          ylim = (80.0, 84.5),
                          save_fig_path=f'{FIGURE_DIR}/soft_pruning_vs_baselines_dinov2_vitl14.pdf');

In [None]:
# NOTE(review): the lambda ignores its argument `y`, so every value in
# `hyperparams` builds an identical SoftRankVoting (constant=1.0,
# dividend=1.75, default offset). If a sweep is intended, `y` presumably has
# to be passed on as `constant`, `dividend` or `offset` — confirm which one.
plot_hyperparameter_accuracy(df=df_val,
                             aggregator_fn=lambda y: SoftRankVoting(image_id_to_weight=get_soft_weights(wrong_neighbors=wrong_neighbors_train, constant=1.0, dividend=1.75)),
                             hyperparams = np.linspace(1, 1.2, 3),
                             max_num_neighbors=100,
                             verbose=True);

## OOD robustness analysis

In [None]:
def get_accuracy_df(featurizer_list, datasets, aggregators, normalize_distances=False, memory_dataset='imagenet2012', max_num_neighbors=100):
  """Build a table of accuracies per featurizer, query dataset and aggregator.

  Args:
    featurizer_list: Model names passed to `read_scaling_df` as `model`.
    datasets: Mapping from query dataset name to the query split to use.
    aggregators: Aggregators to evaluate on every (featurizer, dataset) pair.
    normalize_distances: Forwarded to `read_scaling_df`.
    memory_dataset: Forwarded to `read_scaling_df`.
    max_num_neighbors: Forwarded to `get_max_acc`.

  Returns:
    A DataFrame with one row per (featurizer, query dataset, aggregator) and
    the columns 'featurizer', 'qdataset', 'qsplit', 'aggregator', 'accuracy'.
  """
  records = []

  for featurizer in featurizer_list:
    for qdataset, qsplit in datasets.items():
      # The retrieval results are read once and shared by all aggregators.
      query_df = read_scaling_df(model=featurizer,
                                 query_split=qsplit, query_dataset=qdataset,
                                 memory_dataset=memory_dataset,
                                 normalize_distances=normalize_distances)
      for aggregator in aggregators:
        best_acc = get_max_acc(df=query_df,
                               aggregator=aggregator,
                               max_num_neighbors=max_num_neighbors)
        records.append(dict(featurizer=featurizer,
                            qdataset=qdataset,
                            qsplit=qsplit,
                            aggregator=aggregator.get_name(),
                            accuracy=best_acc))

  return pd.DataFrame(records)

In [None]:
# Configuration for the OOD robustness tables: one featurizer, a mapping from
# query dataset to the split to evaluate, and the aggregators to compare.
featurizer_list = ['dinov2_vitl14']
datasets = {'imagenet2012': 'validation',
            'imagenet-v2': 'test',
            'imagenet-r': 'test',
            'imagenet-a': 'test',
            'imagenet-sketch': 'test'}
            #'imagenet-real': 'test'}
# All aggregators use their default hyperparameters here.
aggregators = [PluralityVoting(), DistanceWeightedVoting(), SoftmaxWeightedVoting(), RankWeightedVoting()]

In [None]:
# Accuracies with the 'jft-with-vit22b-labels' memory, normalized distances
# and at most 10 neighbors.
OOD_df_JFT = get_accuracy_df(featurizer_list, datasets, aggregators, memory_dataset='jft-with-vit22b-labels', max_num_neighbors=10, normalize_distances=True)

In [None]:
# Display the JFT-memory accuracy table.
OOD_df_JFT

In [None]:
# Accuracies on all `datasets` with the ImageNet memory and up to 100
# neighbors (distances are not normalized here).
OOD_df_IN = get_accuracy_df(featurizer_list, datasets, aggregators, memory_dataset='imagenet2012', max_num_neighbors=100)

In [None]:
# 'imagenet-real' is evaluated separately from the other query datasets, with
# normalized distances and at most 10 neighbors. The result gets its own name:
# assigning it to OOD_df_IN would discard the table computed for `datasets`
# before any cell displays it.
OOD_df_IN_real = get_accuracy_df(featurizer_list=featurizer_list, datasets={'imagenet-real': 'test'}, aggregators=aggregators, memory_dataset='imagenet2012', max_num_neighbors=10, normalize_distances=True)
OOD_df_IN_real

In [None]:
# Display the accuracy table currently bound to OOD_df_IN.
OOD_df_IN

In [None]:
# Display OOD_df_IN again (same expression as the previous cell).
OOD_df_IN

## NINCO out-of-distribution classes

In [None]:
# Baseline: ImageNet validation queries, ImageNet train memory.
df_IN_baseline = read_scaling_df(model='dinov2_vitl14', query_split='validation', query_dataset='imagenet2012', memory_dataset='imagenet2012', memory_split='train')

In [None]:
# ImageNet validation queries against the combined 'imagenet2012-and-ninco'
# memory.
df_IN_val = read_scaling_df(model='dinov2_vitl14', query_split='validation', query_dataset='imagenet2012', memory_dataset='imagenet2012-and-ninco', memory_split='train-and-test')

In [None]:
# NINCO test queries against the combined memory, which contains the NINCO
# test split itself; hence remove_identical_neighbors=True.
# NOTE(review): presumably this drops each query's own image from its
# neighbors — confirm in read_scaling_df.
df_NINCO_test = read_scaling_df(model='dinov2_vitl14', query_split='test', query_dataset='ninco', memory_dataset='imagenet2012-and-ninco', memory_split='train-and-test', remove_identical_neighbors=True)

In [None]:
# NINCO test queries against the ImageNet-only memory (no NINCO images in
# memory); used below as the OOD group.
df_NINCO_OOD = read_scaling_df(model='dinov2_vitl14', query_split='test', query_dataset='ninco', memory_dataset='imagenet2012', memory_split='train')

In [None]:
# Names of the NINCO classes; the printed length is a quick sanity check.
# NOTE(review): the spellings (e.g. 'high heels', 'empty_water_dispencer')
# presumably mirror the dataset's own label strings — do not normalize them
# without checking.
NINCO_classes = ['Caracal caracal caracal', 'amphiuma_means', 'aphanizomenon_flosaquae', 'araneus_gemma',
                 'arctocephalus_galapagoensis', 'bagpipe', 'batrachoseps_attenuatus', 'cable', 'chicken_quesadilla',
                 'cirsium_pitcheri', 'creme_brulee', 'ctenolepisma_longicaudata', 'cup_cakes',
                 'darlingtonia_californica', 'dendrolagus_lumholtzi', 'donuts', 'door', 'empty_water_dispencer',
                 'epithelantha_micromeris', 'erysimum_franciscanum', 'f_field_road', 'f_forest_path',
                 'ferocactus_pilosus', 'fire_extinguisher', 'fireworks', 'french_fries', 'glass_of_milk',
                 'gramophone', 'haemulon_sciurus', 'high heels', 'hindu_temple', 'hippopus_hippopus',
                 'lasionycteris_noctivagans', 'lathyrus_odoratus', 'lepomis_auritus', 'leptoglossus_phyllopus',
                 'mbira', 'microcystis_wesenbergii', 'octopus_bimaculoides', 'octopus_rubescens',
                 'ozotoceros_bezoarticus', 'platycephalus_fuscus', 'polistes_dominula', 'pseudorca_crassidens',
                 'pyramid', 's_sky', 'sarpa_salpa', 'sarracenia_alata', 'scissors', 'sepia_apama',
                 'sepia_officinalis', 'sepioteuthis_australis', 'shuttlecock', 'skipper_caterpillar',
                 'spaghetti_bolognese', 'stapler', 'streptopus_lanceolatus', 'tapirus_bairdii', 'triturus_marmoratus',
                 'tursiops_aduncus', 'vaccinium_reticulatum', 'waffles', 'walker', 'windsor_chair']
print(len(NINCO_classes))

In [None]:
# Baseline: memory=IN-train, query=IN-val
# Reference curve for the NINCO experiments below (same four aggregators).
aggregators = [
    PluralityVoting(),
    DistanceWeightedVoting(exponent=1.0),
    SoftmaxWeightedVoting(),
    RankWeightedVoting(),
    ]
plot_neighbor_scaling(df=df_IN_baseline,
                      max_num_neighbors=100,
                      aggregators=aggregators,
                      ylim = (80.0, 84.0));

In [None]:
# memory=IN-train-and-NINCO, query=IN-val
# Same axis limits as the baseline plot, so the two can be compared directly.
aggregators = [
    PluralityVoting(),
    DistanceWeightedVoting(exponent=1.0),
    SoftmaxWeightedVoting(),
    RankWeightedVoting(),
    ]
plot_neighbor_scaling(df=df_IN_val,
                      max_num_neighbors=100,
                      aggregators=aggregators,
                      ylim = (80.0, 84.0));

In [None]:
# memory=IN-train-and-NINCO, query=NINCO
# NOTE(review): max_num_neighbors is 99 rather than 100, presumably because
# df_NINCO_test was read with remove_identical_neighbors=True and therefore
# has one neighbor fewer per query — confirm.
aggregators = [
    PluralityVoting(),
    DistanceWeightedVoting(exponent=1.0),
    SoftmaxWeightedVoting(),
    RankWeightedVoting(),
    ]
plot_neighbor_scaling(df=df_NINCO_test,
                      max_num_neighbors=99,
                      aggregators=aggregators,
                      ylim = (70.0, 88.0));

In [None]:
# memory=IN-train-and-NINCO, query=IN-val-and-NINCO
# The two query sets are concatenated row-wise, so the accuracy is computed
# over all ImageNet validation and NINCO test queries together.
aggregators = [
    PluralityVoting(),
    DistanceWeightedVoting(exponent=1.0),
    SoftmaxWeightedVoting(),
    RankWeightedVoting(),
    ]
plot_neighbor_scaling(df=pd.concat([df_IN_val, df_NINCO_test], ignore_index=True),
                      max_num_neighbors=99,
                      aggregators=aggregators,
                      ylim = (70.0, 88.0));

#### OOD detection analysis

In [None]:
def plot_distance_histogram(df):
  """Print distance statistics and plot a histogram of mean neighbor distances.

  For every featurizer in `df`, summary statistics of the neighbor distances
  are printed, and a histogram of the per-query mean neighbor distance is drawn
  in its own panel.

  Args:
    df: DataFrame with a 'featurizer' column and a 'neighbor_distances' column
      that holds one sequence of distances per row.
  """
  featurizers = df['featurizer'].unique()

  # squeeze=False keeps `axs` two-dimensional even for a single panel, so no
  # spare (always empty) subplot is needed to make it indexable.
  fig, axs = plt.subplots(1, max(len(featurizers), 1), figsize=(15, 4), squeeze=False)
  # The histograms show the mean over each query's neighbor distances, not the
  # distance of the first neighbor only.
  fig.suptitle('Histogram of mean neighbor distances')

  for ax, featurizer in zip(axs[0], featurizers):
    distances = df.loc[df['featurizer'] == featurizer, 'neighbor_distances']

    # Each statistic is first computed per query and then aggregated with the
    # same statistic across queries (e.g. the median of per-query medians).
    min_distance = np.min(distances.apply(lambda y: np.min(y)))
    max_distance = np.max(distances.apply(lambda y: np.max(y)))
    per_query_mean = distances.apply(lambda y: np.mean(y))
    mean_distance = np.mean(per_query_mean)
    median_distance = np.median(distances.apply(lambda y: np.median(y)))
    std_distance = np.std(distances.apply(lambda y: np.std(y)))
    print(f'{featurizer}, min_distance: {min_distance}, max_distance: {max_distance}, mean_distance: {mean_distance}, median_distance: {median_distance}, std_distance: {std_distance}')

    # To plot the first-neighbor distance instead, use
    # distances.apply(lambda x: x[0]) here (and adjust the title).
    ax.hist(per_query_mean, bins=20)
    ax.set_title(featurizer)

In [None]:
# Distance distribution for NINCO queries against the ImageNet-only memory.
plot_distance_histogram(df_NINCO_OOD)

In [None]:
# Distance distribution for ImageNet validation queries, for comparison.
plot_distance_histogram(df_IN_baseline)

#### OOD distance boxplot

In [None]:
def plot_distance_boxplot(df_baseline, df_OOD, statistic='median', save_fig_path=None, num_neighbors=100):
  """Compare neighbor distances of in-distribution and OOD queries in a boxplot.

  For every query row, `statistic` is computed over the first `num_neighbors`
  entries of its 'neighbor_distances'. The resulting two distributions are
  drawn side by side as boxes.

  Args:
    df_baseline: DataFrame with a 'neighbor_distances' column; plotted as
      'in-distribution (ImageNet)'. Not modified.
    df_OOD: DataFrame with a 'neighbor_distances' column; plotted as
      'OOD (NINCO)'. Not modified.
    statistic: 'median' or 'mean'.
    save_fig_path: If set, the figure is written there as a PDF via
      `file_opener`.
    num_neighbors: Number of leading neighbor distances to summarize per query.

  Note:
    The seaborn theme is changed globally and stays in effect for later plots.
  """
  assert statistic in ['median', 'mean']
  reduce_fn = np.median if statistic == 'median' else np.mean
  ylabel = f'{statistic.capitalize()} distance to first {num_neighbors} neighbors'

  def summarize(df, group):
    # Build a small frame holding only the plotted columns. The inputs are
    # left untouched, so they do not have to be deep-copied.
    per_query = df['neighbor_distances'].apply(lambda d: reduce_fn(d[:num_neighbors]))
    return pd.DataFrame({'group': group, statistic: per_query})

  combined_df = pd.concat([summarize(df_baseline, 'in-distribution (ImageNet)'),
                           summarize(df_OOD, 'OOD (NINCO)')])

  # Apply the theme before the figure is created; set afterwards it would only
  # influence figures drawn later, making the first call look different.
  sns.set(font_scale=1.6)  # Adjust the scaling factor as needed
  sns.set_style("white")

  # Create the boxplot
  plt.figure(figsize=(10, 6))
  sns.boxplot(data=combined_df, x='group', y=statistic)

  ax = plt.gca()
  ax.patch.set_facecolor('None')
  ax.spines['top'].set_visible(False)
  ax.spines['right'].set_visible(False)

  plt.xlabel('', fontsize=20)
  plt.ylabel(ylabel, fontsize=17)

  if save_fig_path:
    # Close the handle explicitly so the PDF is fully flushed and not leaked.
    # Assumes the object returned by file_opener is a context manager, as the
    # default `open` is.
    with file_opener(save_fig_path, 'wb') as fig_file:
      plt.savefig(fig_file, format='pdf', bbox_inches='tight', pad_inches=0)
    print(f'Saved figure to {save_fig_path}')
    print_viewing_path(save_fig_path)

In [None]:
# Median neighbor distance: ImageNet validation vs. NINCO queries, both
# against the ImageNet train memory.
plot_distance_boxplot(df_baseline=df_IN_baseline,
                      df_OOD=df_NINCO_OOD,
                      statistic='median',
                      save_fig_path=f'{FIGURE_DIR}/boxplot_OOD_detection_NINCO_median.pdf')

In [None]:
# Same comparison using the mean instead of the median.
plot_distance_boxplot(df_baseline=df_IN_baseline,
                      df_OOD=df_NINCO_OOD,
                      statistic='mean',
                      save_fig_path=f'{FIGURE_DIR}/boxplot_OOD_detection_NINCO_mean.pdf')

## Scaling dataset size

In [None]:
def read_imagenet_scaling_results_for_dataset_scaling_experiment(featurizers, sample_sizes, aggregator, max_num_neighbors=100):
  """Collect the best accuracy per featurizer and ImageNet memory size.

  Args:
    featurizers: Model names passed to `read_scaling_df` as `model`.
    sample_sizes: Memory sizes passed to `read_scaling_df` as `size`; the
      string 'full' selects the full train set and is reported as 1_281_167.
    aggregator: Aggregator handed to `get_max_acc`.
    max_num_neighbors: Forwarded to `get_max_acc`.

  Returns:
    A DataFrame with the columns 'featurizer', 'sample_size' and 'accuracy',
    one row per (featurizer, sample size) in iteration order.
  """
  full_train_size = 1_281_167 # full IN train set

  # Rows are collected in a list and turned into a frame once. Growing an
  # empty DataFrame via `.loc` re-allocates on every row; building from a list
  # also lets pandas infer the column dtypes from the collected values.
  rows = []
  for featurizer in featurizers:
    for sample_size in sample_sizes:
      scaling_df = read_scaling_df(model=featurizer,
                                   query_dataset='imagenet2012',
                                   query_split='validation',
                                   memory_dataset='imagenet2012',
                                   memory_split='train',
                                   size=sample_size,
                                   verbose=False)

      acc = get_max_acc(df=scaling_df, aggregator=aggregator, max_num_neighbors=max_num_neighbors)
      num_images = full_train_size if sample_size == 'full' else sample_size
      rows.append((featurizer, num_images, acc))

  return pd.DataFrame(rows, columns=['featurizer', 'sample_size', 'accuracy'])

def read_jft_scaling_results_for_dataset_scaling_experiment(data_dir=DATA_DIR):
  """Load the JFT scaling CSVs for DinoV2 ViT-L/14 and ViT-S/14 as one frame.

  Args:
    data_dir: Directory that contains the two '*_JFT_scaling.csv' files.

  Returns:
    The rows of both CSV files (ViT-L/14 first) with a fresh integer index.
  """
  csv_names = ('dinov2_vitl14_JFT_scaling.csv', 'dinov2_vits14_JFT_scaling.csv')
  frames = [pd.read_csv(f'{data_dir}/{csv_name}') for csv_name in csv_names]
  return pd.concat(frames, ignore_index=True)

In [None]:
def plot_data_scaling(df: pd.DataFrame,
                      x_range = (0, 1, 2, 3, 4, 5, 6),
                      x_ticks = ('1', '10', '100', '1K', '10K', '100K', '1M'),
                      linestyle='-',
                      multiply_accuracy_by_100 = False,
                      save_fig_path = None,
                      plot_error_rate = True,
                      fit = False,
                      plot_yaxis_log_scale = True) -> None:
  """Plot accuracy as a function of memory dataset size.

  One curve is drawn per featurizer, with log10('sample_size') on the x-axis.

  Args:
    df: DataFrame with the columns 'featurizer', 'sample_size', 'accuracy'.
    x_range: Tick positions on the x-axis, in log10(sample_size) units.
    x_ticks: Tick labels; must have the same length as `x_range`.
    linestyle: Line style connecting the measured points ('' for none).
    multiply_accuracy_by_100: Scale accuracies by 100 before plotting.
    save_fig_path: If set, the figure is written there as a PDF via
      `file_opener`.
    plot_error_rate: Plot `100 - accuracy` instead of the accuracy.
    fit: Additionally fit and draw, per featurizer, a line through
      log10(y) as a function of log10(log10(sample_size)). Requires
      `plot_yaxis_log_scale`, `plot_error_rate` and no accuracy scaling.
    plot_yaxis_log_scale: Use a base-10 logarithmic y-axis.

  Raises:
    ValueError: If `x_range` and `x_ticks` differ in length.
  """
  if len(x_range) != len(x_ticks):
    raise ValueError(
        f'x_range has {len(x_range)} entries but x_ticks has {len(x_ticks)}')
  if fit:
    assert plot_yaxis_log_scale and plot_error_rate and (not multiply_accuracy_by_100)

  plt.figure(figsize=(8, 5))

  for featurizer in df['featurizer'].unique():
    featurizer_df = df[df['featurizer'] == featurizer]
    color = featurizer_to_color[featurizer]

    x = np.log10(featurizer_df['sample_size'])
    y = featurizer_df['accuracy'].to_numpy()
    if multiply_accuracy_by_100:
      y = y * 100

    if plot_error_rate:
      y = 100 - y # convert accuracy to error rate

    if fit:
      # x is already log10(sample_size), so this fits
      # log10(y) = a * log10(log10(sample_size)) + b.
      a, b = np.polyfit(np.log10(x), np.log10(y), deg=1)
      print(f'{featurizer}: a = {a}, b = {b}')
      x_range_for_fit = np.linspace(x_range[0], x_range[-1], 100).tolist()
      # Kept in its own name so the boolean `fit` flag is not overwritten.
      fitted_y = [10**(a * np.log10(x_val) + b) for x_val in x_range_for_fit]
      plt.plot(x_range_for_fit, fitted_y, linestyle='-', linewidth=1.75, color=color)

    plt.plot(x, y, marker='o', linestyle=linestyle, linewidth=2, markersize=10,
             color=color,
             label=featurizer_to_name[featurizer])

  # Use log y scale with human-readable accuracies
  if plot_yaxis_log_scale:
    plt.yscale('log', base=10)
  ax = plt.gca()
  ax.yaxis.set_minor_formatter('{x:.0f}')
  ax.tick_params(axis='y', which='minor', labelsize=12)

  plt.xticks(x_range, x_ticks)

  ax.spines['top'].set_visible(False)
  ax.spines['right'].set_visible(False)
  plt.xticks(fontsize=12)
  plt.yticks(fontsize=12)
  plt.legend(fontsize=12)

  plt.xlabel('Number of images in memory', fontsize=14)
  # Label the axis according to what is actually plotted.
  ylabel = 'Top-1 error rate (%)' if plot_error_rate else 'Top-1 accuracy (%)'
  plt.ylabel(ylabel, fontsize=14)

  if save_fig_path:
    # Close the handle explicitly so the PDF is fully flushed and not leaked.
    # Assumes the object returned by file_opener is a context manager, as the
    # default `open` is.
    with file_opener(save_fig_path, 'wb') as fig_file:
      plt.savefig(fig_file, format='pdf', bbox_inches='tight')
    print(f'Saved figure to {save_fig_path}')
    print_viewing_path(save_fig_path)

In [None]:
# Best PluralityVoting accuracy for five featurizers at four ImageNet memory
# sizes ('full' is the complete train set).
featurizer_list = ['dinov2_vitl14', 'dinov2_vitb14', 'dinov2_vits14', 'clip-vit_l14', 'clip-vit_b16']

data_IN = read_imagenet_scaling_results_for_dataset_scaling_experiment(
    featurizers=featurizer_list,
    sample_sizes=[1_000, 10_000, 100_000, 'full'],
    aggregator=PluralityVoting())

In [None]:
# Display the ImageNet memory-size scaling table.
data_IN

In [None]:
# Tick positions are given in log10 units because plot_data_scaling plots
# log10(sample_size) on the x-axis.
plot_data_scaling(df=data_IN,
                  x_range=[np.log10(i) for i in [1_000, 10_000, 100_000, 1_281_167]],
                  x_ticks=['1K', '10K', '100K', '1.28M'],
                  save_fig_path=f'{FIGURE_DIR}/memory_size_scaling_imagenet.pdf')

In [None]:
# Load the JFT scaling results from the CSV files in DATA_DIR.
data_JFT = read_jft_scaling_results_for_dataset_scaling_experiment()

In [None]:
# Display the JFT memory-size scaling table.
data_JFT

In [None]:
# Linear fit in log-log space: with fit=True, plot_data_scaling fits
# log10(error) against log10(log10(sample_size)) and overlays the fitted
# curve. linestyle='' leaves the measured points unconnected.
plot_data_scaling(df=data_JFT,
                  x_range=[np.log10(i) for i in [1_000, 10_000, 100_000, 1_000_000, 10_000_000, 100_000_000, 1_000_000_000]],
                  x_ticks=['1K', '10K', '100K', '1M', '10M', '100M', '1B'],
                  linestyle='',
                  fit = True,
                  save_fig_path=f'{FIGURE_DIR}/memory_size_scaling_jft.pdf')

#### Raw JFT data

In [None]:
# Raw x-values of the ViT-L/14 curve: log10 of the JFT memory sizes.
np.log10(data_JFT.loc[data_JFT['featurizer'] == 'dinov2_vitl14']['sample_size']).tolist()

In [None]:
# Raw y-values of the ViT-L/14 curve: 100 minus the stored accuracy.
[100 - x for x in data_JFT.loc[data_JFT['featurizer'] == 'dinov2_vitl14']['accuracy'].tolist()]