Copyright 2024 Google LLC

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    https://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

# Nearest neighbor evaluation

This Colab evaluates DINOv2 and CLIP image models on image classification datasets with nearest-neighbor (kNN) retrieval, storing each query's nearest-neighbor information as JSON.

In [None]:
#@title Imports

import json
import os
import time
import tqdm
import random
import numpy as np
import functools

import jax
from jax.sharding import PartitionSpec as P
from jax.experimental import mesh_utils
import jax.numpy as jnp
import numpy as np
import tensorflow as tf

In [None]:
#@title Jax sharding utils

NamedSharding = jax.sharding.NamedSharding

# A one-dimensional logical mesh over every device. Its single axis, 'data',
# is the axis that the leading (batch) dimension of sharded arrays is split on.
mesh_shape = (jax.device_count(),)
mesh = jax.sharding.Mesh(
    mesh_utils.create_device_mesh(mesh_shape, devices=jax.devices()),
    axis_names=('data',),
)
# Row-wise sharding: axis 0 is split over 'data', axis 1 is not partitioned.
p = NamedSharding(mesh, P('data', None))

print(mesh)


@functools.lru_cache(maxsize=None)
def get_shard_array_fn(sharding):
  """Returns a jitted identity function whose output is laid out by `sharding`.

  The jitted function is cached per sharding. Wrapping a fresh lambda in
  `jax.jit` on every call would miss JAX's compilation cache and recompile the
  identity for every batch that gets sharded.

  Args:
    sharding: a hashable `jax.sharding.Sharding` (e.g. a NamedSharding).

  Returns:
    A callable that copies its argument onto devices with `sharding`.
  """
  return jax.jit(lambda x: x, out_shardings=sharding)


def shard_array(arr):
  """Distributes `arr` over the device mesh along its leading axis, if possible.

  The array is laid out with the module-level row sharding `p` only when its
  leading dimension divides evenly by the size of the mesh's 'data' axis.
  Otherwise it is returned unchanged. The check uses the mesh itself (built
  from all `jax.devices()`) rather than the local device count, so the
  condition matches the sharding that is actually applied.

  Args:
    arr: array with at least two dimensions (P('data', None) names two axes).

  Returns:
    The sharded array, or `arr` itself when it cannot be split evenly.
  """
  num_shards = mesh.shape['data']
  if arr.shape[0] % num_shards != 0:
    return arr
  return get_shard_array_fn(p)(arr)

print('num devices = ', jax.local_device_count())

In [None]:
#@title Resizing

def resize_smaller_side(
    images: tf.Tensor, size: int, antialias: bool = False
) -> tf.Tensor:
  """Bicubically rescales images so that their shorter side equals `size`.

  Height and width are both multiplied by the same factor,
  size / min(height, width), and rounded to the nearest integer, so the
  aspect ratio is preserved.

  Args:
    images: image batch of shape [B, H, W, 3]. Height and width are read from
      axes -3 and -2, so a single [H, W, 3] image is handled as well.
    size: target length of the shorter image side.
    antialias: whether to apply an anti-aliasing filter when downsampling.

  Returns:
    The resized images.
  """
  dims = tf.shape(images)
  height, width = dims[-3], dims[-2]

  scale = tf.cast(size, tf.float32) / tf.cast(tf.minimum(height, width), tf.float32)
  new_height = tf.cast(tf.round(tf.cast(height, tf.float32) * scale), tf.int32)
  new_width = tf.cast(tf.round(tf.cast(width, tf.float32) * scale), tf.int32)

  return tf.image.resize(
      images,
      [new_height, new_width],
      method=tf.image.ResizeMethod.BICUBIC,
      antialias=antialias,
  )


def central_crop(images: tf.Tensor, size: int) -> tf.Tensor:
  """Crops the central `size` x `size` window of an image or an image batch.

  Args:
    images: a single image [H, W, C] or a batch [B, H, W, C].
      `dataset_preprocess` maps preprocessing over individual examples before
      batching, so rank-3 inputs are what that pipeline passes in.
    size: height and width of the crop. H and W must be at least `size`.

  Returns:
    Cropped images of shape [size, size, C] or [B, size, size, C].

  Raises:
    ValueError: if the static rank of `images` is known and is neither 3 nor 4.
  """
  # Use the static rank. `len(tf.shape(images))` is not reliable on symbolic
  # tensors inside `Dataset.map`, and the old rank-4-only assert rejected the
  # unbatched examples this function is actually applied to.
  rank = images.shape.rank
  if rank is not None and rank not in (3, 4):
    raise ValueError(
        f'Expected images of rank 3 or 4 but got shape {images.shape}')

  dims = tf.shape(images)
  top = (dims[-3] - size) // 2
  left = (dims[-2] - size) // 2
  return tf.image.crop_to_bounding_box(images, top, left, size, size)

In [None]:
#@title Define DINO featurizer

class DinoFeatures():
  """A class to extract L2-normalized DINO features from images.

  Model loading is not included in this notebook: `__init__` is a stub, and
  `_apply_fn` must be provided by the omitted loading code.
  """

  # DINO was trained on ImageNet data that is centered and scaled using the
  # following mean and stddev. Image values in the [0.0, 1.0] range are
  # transformed per R, G, B channel as: val = (val - mean) / std.
  # Stored as float32 so normalizing float32 images does not silently promote
  # NumPy inputs to float64.
  MEAN_RGB = np.array([0.485, 0.456, 0.406], dtype=np.float32)
  STDDEV_RGB = np.array([0.229, 0.224, 0.225], dtype=np.float32)

  def __init__(
      self,
      model_name: str = '',
      is_tpu_inference: bool = True,
  ):
    # load model weights, not included here
    pass

  def preprocess_tpu(
      self,
      images: dict,
      aspect_ratio_size: int,
      central_crop_size: int,
      antialias: bool,
  ) -> dict:
    """Resizes, then center-crops, the 'image' entry of a feature dict.

    Args:
      images: feature dict (e.g. a TFDS example) holding an 'image' tensor.
        Other entries are passed through untouched.
      aspect_ratio_size: length of the shorter image side after resizing.
      central_crop_size: side length of the square central crop.
      antialias: whether to antialias while resizing.

    Returns:
      The same dict, with 'image' replaced by the resized and cropped image.
    """
    resized_images = resize_smaller_side(
        images['image'], size=aspect_ratio_size, antialias=antialias
    )
    images['image'] = central_crop(resized_images, size=central_crop_size)
    return images

  def extract_batch(self, image: np.ndarray) -> np.ndarray:
    """Public entry point. Delegates to `_extract_batch`."""
    return self._extract_batch(image)

  def _extract_batch(self, images: np.ndarray) -> np.ndarray:
    """Computes L2-normalized DINO features for an image batch.

    Args:
      images: image batch [B, H, W, C] (NumPy or JAX array) with values in
        the [0, 255] range. Channels must be ordered RGB, not BGR.

    Returns:
      features: pooled features (second output of `_apply_fn`), with each
        row scaled to unit L2 norm.
    """
    # Preprocessing follows the DINO kNN evaluation code, e.g.
    # https://github.com/facebookresearch/dino/blob/main/eval_knn.py
    # exact preprocessing from here:
    # https://github.com/facebookresearch/dino/issues/149
    images = images.astype(np.float32) / 255.0
    images = (images - DinoFeatures.MEAN_RGB) / DinoFeatures.STDDEV_RGB
    # Calling model's apply function to encode the images.
    _, features = self._apply_fn(images)
    features = jnp.asarray(features)
    # Unit-norm rows make the dot products used for neighbor search cosine
    # similarities. This is a vectorized form of the former per-row mapping.
    return features / jnp.linalg.norm(features, axis=1, keepdims=True)

In [None]:
#@title Define CLIP featurizer

class ClipFeatures():
  """A class to extract CLIP features from images (and text).

  Model loading is not included in this notebook: `__init__` is a stub. The
  attributes used below (`_apply_fn`, `_model`, `_params`, `_tokenizer`) and
  the `clip` module must come from the omitted loading code.
  """

  def __init__(
      self,
      model_name: str = '',
      is_tpu_inference: bool = False,
  ):
    # load model weights, not included here
    pass

  def preprocess_tpu(
      self,
      images: dict,
      aspect_ratio_size: int,
      central_crop_size: int,
      antialias: bool,
  ) -> dict:
    """Resizes, then center-crops, the 'image' entry of a feature dict.

    Args:
      images: feature dict (e.g. a TFDS example) holding an 'image' tensor.
        Other entries are passed through untouched.
      aspect_ratio_size: length of the shorter image side after resizing.
      central_crop_size: side length of the square central crop.
      antialias: whether to antialias while resizing.

    Returns:
      The same dict, with 'image' replaced by the resized and cropped image.
    """
    resized_images = resize_smaller_side(
        images['image'], size=aspect_ratio_size, antialias=antialias
    )
    images['image'] = central_crop(resized_images, size=central_crop_size)
    return images

  def extract_batch(self, image: np.ndarray) -> np.ndarray:
    """Public entry point. Delegates to `_extract_batch`."""
    return self._extract_batch(image)

  def _extract_batch(self, images: np.ndarray) -> np.ndarray:
    """Computes L2-normalized CLIP image embeddings for a batch.

    Args:
      images: image batch [B, H, W, 3] (NumPy or JAX array) with values in
        the [0, 255] range. Channels must be ordered RGB, not BGR.

    Returns:
      The first output of the model's apply function (presumably [B, D]
      image embeddings, D depending on the model variant), with each row
      scaled to unit L2 norm.

    Raises:
      ValueError: if `images` is not rank 4 or does not have 3 channels.
    """
    if len(images.shape) != 4:
      raise ValueError(f'Image must be (B, H, W, 3) but got {images.shape}')
    if images.shape[-1] != 3:
      raise ValueError(f'Image must be 3 channels but got {images.shape}')

    images = images.astype(np.float32) / 255.0
    images = jnp.array(images)
    images = clip.normalize_image(images) # CLIP normalization, not included here

    embedding, _ = self._apply_fn(image=images)
    # The neighbor search uses 1 - dot product as a distance, which assumes
    # unit-norm features (DinoFeatures normalizes explicitly). Normalizing is
    # idempotent, so this is harmless if the model already returns unit
    # vectors. NOTE(review): confirm against the omitted model code.
    embedding = jnp.asarray(embedding)
    return embedding / jnp.linalg.norm(embedding, axis=1, keepdims=True)

  def extract_text_features(self, text_queries) -> np.ndarray:
    """Returns the model's text features for `text_queries`.

    The queries are tokenized with `self._tokenizer`, and the second output of
    `self._model.apply` (called with image=None) is returned as-is, without
    normalization.
    """
    tokens = self._tokenizer(text_queries)
    _, text_features = self._model.apply(self._params, image=None,
                                         text=tokens)
    return text_features

In [None]:
#@title Define preprocessing

def dataset_preprocess(dataset, featurizer, batch_size, resizing_size, antialias,
                       central_crop_size=224):
  """Resizes and center-crops every example, then batches the dataset.

  Args:
    dataset: tf.data.Dataset of feature dicts with an 'image' entry.
    featurizer: object whose `preprocess_tpu(example, resize, crop,
      antialias=...)` transforms a single example (e.g. DinoFeatures).
    batch_size: number of examples per output batch.
    resizing_size: length of the shorter image side after resizing.
    antialias: whether to antialias while resizing.
    central_crop_size: side of the square central crop (defaults to 224,
      the value previously hard-coded here).

  Returns:
    The preprocessed, batched dataset. The last batch may be smaller.
  """
  dataset_preprocessed = dataset.map(
      lambda example: featurizer.preprocess_tpu(
          example,
          resizing_size,
          central_crop_size,
          antialias=antialias,
      ),
      num_parallel_calls=tf.data.AUTOTUNE,
  )

  # Keep the final partial batch so that every example gets an embedding.
  return dataset_preprocessed.batch(batch_size, drop_remainder=False)

In [None]:
#@title Utils for subsampling

def subsample_train_set_embeddings(train_embeddings, num_per_class=100, seed=None):
  """Randomly keeps at most `num_per_class` examples of every class.

  Args:
    train_embeddings: dict with row-aligned 'file_name', 'labels' and
      'features' entries (indexable by position).
    num_per_class: maximum number of examples kept per label. Classes with
      fewer examples are kept in full.
    seed: optional seed for a private `random.Random`. When None, the global
      `random` module state is used (the previous behavior).

  Returns:
    dict with the same three keys holding NumPy arrays. Classes appear in the
    order their label is first seen. Within a class, rows are in random order.
  """
  rng = random if seed is None else random.Random(seed)

  # Bucket row positions by label. Dicts preserve first-appearance order.
  positions_by_label = {}
  for idx, label in enumerate(train_embeddings['labels']):
    positions_by_label.setdefault(label, []).append(idx)

  # Shuffle each bucket and keep its first `num_per_class` positions.
  selected = []
  for positions in positions_by_label.values():
    rng.shuffle(positions)
    selected.extend(positions[:num_per_class])
  selected = np.asarray(selected, dtype=np.int64)

  # One fancy-index per array instead of copying rows into Python lists.
  return {
      'file_name': np.asarray(train_embeddings['file_name'])[selected],
      'labels': np.asarray(train_embeddings['labels'])[selected],
      'features': np.asarray(train_embeddings['features'])[selected],
  }

In [None]:
def scale_memory(featurizer_name,
                 train_embeddings,
                 validation_embeddings,
                 file_path_prefix,
                 num_imgs_per_class_list=(1, 10, 100, 1000),
                 top_k=100,
                 batch_size=1024):
  """Repeats the kNN evaluation on class-balanced subsamples of the memory.

  For each n in `num_imgs_per_class_list`, the memory (train embeddings) is
  subsampled to at most n examples per class. Neighbors are then computed
  memory-vs-memory (top_k + 1 neighbors) and query-vs-memory (top_k
  neighbors). Top-1 accuracy is printed, and both results are written to JSON.

  Args:
    featurizer_name: name stored in the JSON output and used in printed labels.
    train_embeddings: memory embeddings dict ('features', 'labels', 'file_name').
    validation_embeddings: query embeddings dict with the same keys.
    file_path_prefix: output path prefix. With size = (number of classes) * n,
      the files written are:
        {prefix}_{size}_neighbor_info.json (query-vs-memory), and
        {prefix}_{size}_memory_x_memory_neighbor_info.json (memory-vs-memory).
    num_imgs_per_class_list: subsample sizes per class.
    top_k: number of neighbors kept per query.
    batch_size: row batch size for the dot-product search. Results do not
      depend on it.
  """
  # Nominal memory size for file names: classes * images per class
  # (1000 * n for ImageNet, matching the previous hard-coded value).
  num_classes = len(np.unique(train_embeddings['labels']))

  for num_imgs_per_class in num_imgs_per_class_list:
    subsampled_train_embeddings = subsample_train_set_embeddings(
        train_embeddings, num_per_class=num_imgs_per_class)

    # One extra neighbor for memory-vs-memory, presumably because each
    # example's nearest neighbor is itself.
    train_x_train_subsampled = get_nearest_neighbors(
        subsampled_train_embeddings, subsampled_train_embeddings,
        k=top_k + 1, batch_size=batch_size)
    val_x_train_subsampled = get_nearest_neighbors(
        validation_embeddings, subsampled_train_embeddings,
        k=top_k, batch_size=batch_size)
    print_accuracy(label=f'{featurizer_name}_train_x_train_subsampled_{num_imgs_per_class}',
                   neighbors=train_x_train_subsampled)
    print_accuracy(label=f'{featurizer_name}_val_x_train_subsampled_{num_imgs_per_class}',
                   neighbors=val_x_train_subsampled)

    size = num_classes * num_imgs_per_class
    val_file_path = f'{file_path_prefix}_{size}_neighbor_info.json'
    write_neighbor_info_to_json(val_x_train_subsampled, val_file_path, featurizer_name)

    # Use a distinct file so the memory-vs-memory results do not overwrite
    # the query-vs-memory results written just above.
    train_file_path = f'{file_path_prefix}_{size}_memory_x_memory_neighbor_info.json'
    write_neighbor_info_to_json(train_x_train_subsampled, train_file_path, featurizer_name)

In [None]:
#@title Loading dataset from TFDS
# Re-imports TensorFlow through its compat.v2 API. This rebinds the
# module-level `tf` name that the functions defined above look up at call
# time, then enables TF2 behavior.
# NOTE(review): in a TF2 runtime this is presumably a no-op — confirm.
import tensorflow.compat.v2 as tf
tf.enable_v2_behavior()

import tensorflow_datasets as tfds

def rename_feature(example):
  """Renames an example's 'image/filename' feature to 'file_name'.

  The dict is modified in place and also returned, so the function can be
  passed straight to `Dataset.map`.
  """
  filename = example.pop('image/filename')
  example['file_name'] = filename
  return example

def load_dataset(dataset_name, split):
  """Loads one split of a dataset as an unshuffled tf.data.Dataset.

  'ninco' is read from a local image-folder layout with
  `tfds.folder_dataset.ImageFolder`, and its 'image/filename' feature is
  renamed to 'file_name'. Any other name is loaded with `tfds.load`.
  """
  if dataset_name not in ['ninco']:
    return tfds.load(dataset_name, split=split, shuffle_files=False)

  # https://www.tensorflow.org/datasets/api_docs/python/tfds/folder_dataset/ImageFolder
  folder_builder = tfds.folder_dataset.ImageFolder(f'/path/to/{dataset_name}/')
  folder_dataset = folder_builder.as_dataset(split=split, shuffle_files=False)
  return folder_dataset.map(rename_feature)

def load_combined_datasets(list_of_names_and_splits):
  """Loads several (name, split) pairs and concatenates them in order.

  Args:
    list_of_names_and_splits: non-empty sequence of (dataset_name, split)
      pairs.

  Returns:
    A single dataset yielding the examples of each entry in turn.
  """
  head_name, head_split = list_of_names_and_splits[0]
  remaining = list_of_names_and_splits[1:]

  combined = load_dataset(dataset_name=head_name, split=head_split)
  for next_name, next_split in remaining:
    combined = combined.concatenate(
        load_dataset(dataset_name=next_name, split=next_split))
  return combined

def load_dataset_batches(list_of_names_and_splits, featurizer, batch_size, resizing_size, antialias):
  """Loads and concatenates the given datasets, then preprocesses and batches them."""
  return dataset_preprocess(
      load_combined_datasets(list_of_names_and_splits),
      featurizer,
      batch_size,
      resizing_size,
      antialias,
  )

In [None]:
#@title Get model embeddings

def get_model_embeddings(featurizer, batched_dataset):
  """Computes embeddings, labels and file names for every example.

  Args:
    featurizer: model with an `extract_batch(images)` method (e.g.
      DinoFeatures / ClipFeatures) that returns one feature row per image.
    batched_dataset: preprocessed, batched dataset whose elements are dicts
      with 'image', 'label' and either 'file_name' or 'image/filename'.

  Returns:
    features_dict: dict with 'features' (rows stacked with np.vstack),
    'labels' and 'file_name', all row-aligned NumPy arrays.

  Raises:
    ValueError: if a batch has no file-name feature, or the dataset is empty.
  """
  features, labels, file_name = [], [], []
  for batch in tqdm.tqdm(batched_dataset):
    # Resolve the file-name key first so a malformed dataset fails before any
    # expensive featurization.
    if 'file_name' in batch:
      name_key = 'file_name'
    elif 'image/filename' in batch:
      name_key = 'image/filename'
    else:
      raise ValueError('file_name or image/filename not found in batch')

    batch_sharded = shard_array(batch['image'].numpy())
    features.append(featurizer.extract_batch(batch_sharded))
    labels.append(batch['label'].numpy())
    file_name.append(batch[name_key].numpy())

  if not features:
    raise ValueError('batched_dataset is empty; no embeddings were computed.')

  return {
      'features': np.vstack(features),
      'labels': np.concatenate(labels, axis=0),
      'file_name': np.concatenate(file_name, axis=0),
  }

In [None]:
#@title Utils to compute top K neighbors

@jax.jit
def jitted_dot_fn(x, y):
  """Returns the similarity matrix x @ y.T, sharded along its rows.

  Args:
    x: query feature rows [Nq, D].
    y: neighbor feature rows [Nn, D].

  Returns:
    [Nq, Nn] dot products, computed at HIGHEST precision and constrained to
    be split along the query axis over the mesh's 'data' axis.
  """
  row_sharding = NamedSharding(mesh, P('data', None))
  similarities = jnp.dot(x, y.T, precision=jax.lax.Precision.HIGHEST)
  return jax.lax.with_sharding_constraint(similarities, row_sharding)


# Extracts the Top-K most similar neighbors (values and column indices).
@functools.partial(jax.jit, donate_argnums=(0,), static_argnums=(1,))
def extract_min_indices(dot_products, top_k=100):
  """Returns the `top_k` largest similarities per row and their column indices.

  Despite the name, this selects the *largest* dot products. For L2-normalized
  features these are the nearest neighbors, i.e. the smallest 1 - dot
  distances.

  Args:
    dot_products: list of [Nq, Nb_i] similarity blocks that share the query
      axis. They are concatenated along axis 1 into [Nq, sum(Nb_i)]. The
      list is donated, so its buffers must not be used after the call.
    top_k: number of neighbors to keep (static under jit; may be passed by
      keyword).

  Returns:
    (values, indices), each [Nq, top_k], ordered by descending similarity.
    Indices refer to columns of the concatenated matrix.
  """
  similarities = jnp.concatenate(dot_products, axis=1)
  # lax.top_k operates on the last axis and batches over leading axes, so no
  # explicit per-row mapping is needed.
  values, indices = jax.lax.top_k(similarities, top_k)
  return values, indices

def _batch_features(features_dict, batch_size):
  """Splits a features dict into consecutive row batches, sharding the features.

  Args:
    features_dict: dict with row-aligned 'features', 'labels' and 'file_name'.
    batch_size: maximum number of rows per batch.

  Returns:
    list of dicts with 'features' (passed through `shard_array`), 'labels'
    and 'filenames' for each consecutive slice of at most `batch_size` rows.
  """
  num_rows = features_dict['features'].shape[0]
  batches = []
  for start in tqdm.tqdm(range(0, num_rows, batch_size)):
    stop = min(start + batch_size, num_rows)
    batches.append({
        'features': shard_array(features_dict['features'][start:stop]),
        'labels': features_dict['labels'][start:stop],
        'filenames': features_dict['file_name'][start:stop],
    })
  return batches


def get_nearest_neighbors(query_features, neighbor_features, k, batch_size):
  """Finds the k most similar neighbor examples for every query example.

  Similarity is the dot product of feature rows (cosine similarity for
  L2-normalized features). The reported distance is 1 - similarity.

  Args:
    query_features: dict with row-aligned 'features', 'labels' and
      'file_name' (bytes) arrays for the query set.
    neighbor_features: dict with the same keys for the neighbor (memory) set.
    k: number of nearest neighbors to keep per query.
    batch_size: rows per batch for the dot-product computation.

  Returns:
    dict keyed by the UTF-8 decoded query file name. Each value holds
    'image_id', 'image_class', 'neighbor_image_ids', 'neighbor_classes' and
    'neighbor_distances', the last three ordered nearest first. Queries that
    share a file name overwrite each other.

  Raises:
    ValueError: if there are queries but no neighbors.
  """
  print('Sharding query batches ....')
  query_batches = _batch_features(query_features, batch_size)
  print('Sharding neighbor batches ....')
  neighbor_batches = _batch_features(neighbor_features, batch_size)

  neighbor_info = {}
  print('Extracting top-K neighbors ....')
  if not query_batches:
    return neighbor_info
  if not neighbor_batches:
    raise ValueError('neighbor_features is empty; cannot compute neighbors.')

  # Neighbor metadata is the same for every query batch, so concatenate it
  # once. Column j of a similarity row corresponds to entry j here.
  all_neighbor_labels = np.concatenate(
      [batch['labels'] for batch in neighbor_batches], axis=0)
  all_neighbor_filenames = np.concatenate(
      [batch['filenames'] for batch in neighbor_batches], axis=0)

  for query_batch in tqdm.tqdm(query_batches):
    similarity_blocks = [
        jitted_dot_fn(query_batch['features'], neighbor_batch['features'])
        for neighbor_batch in neighbor_batches
    ]
    # extract_min_indices donates (and thereby frees) the similarity blocks.
    top_similarities, top_indices = extract_min_indices(similarity_blocks, top_k=k)
    top_similarities = jax.device_get(top_similarities)
    top_indices_host = jax.device_get(top_indices)
    top_indices.delete()  # Free the device buffer eagerly.

    # Fancy indexing with the [rows, k] index matrix gathers, per query row,
    # the metadata of its selected neighbors.
    top_labels = all_neighbor_labels[top_indices_host]
    top_filenames = all_neighbor_filenames[top_indices_host]

    for row, query_filename in enumerate(query_batch['filenames']):
      key = query_filename.decode('utf-8')
      neighbor_info[key] = {
          'image_id': key,
          'image_class': query_batch['labels'][row],
          'neighbor_image_ids': top_filenames[row],
          'neighbor_classes': top_labels[row],
          'neighbor_distances': 1.0 - top_similarities[row],
      }

  return neighbor_info

In [None]:
#@title Write neighbor info to JSON
def write_neighbor_info_to_json(neighbors, file_path, featurizer_name):
  """Serializes nearest-neighbor records to a JSON file and reports timing.

  Args:
    neighbors: dict of records as produced by `get_nearest_neighbors`. Each
      record holds 'image_id', 'image_class' and NumPy arrays
      'neighbor_image_ids' (bytes), 'neighbor_classes' and
      'neighbor_distances'.
    file_path: destination path. The file is overwritten.
    featurizer_name: stored as the 'featurizer' field of every record.
  """
  start_time = time.time()

  data = {}
  for record in neighbors.values():
    image_id = record['image_id']
    # Convert NumPy values to plain Python types so json can encode them.
    data[image_id] = {
        'featurizer': featurizer_name,
        'image_id': image_id,
        'image_class': int(record['image_class']),
        'neighbor_image_ids': [
            raw_id.decode('utf-8') for raw_id in record['neighbor_image_ids'].tolist()
        ],
        'neighbor_classes': record['neighbor_classes'].tolist(),
        'neighbor_distances': record['neighbor_distances'].tolist(),
    }

  print(f'Writing json to {file_path}')
  with open(file_path, mode='w') as out_file:
    json.dump(data, out_file)

  elapsed = time.time() - start_time
  print(f'Time to write: {round(elapsed, 1)} seconds')

In [None]:
#@title Evaluate k=1 accuracy

def print_accuracy(label, neighbors):
  """Prints top-1 nearest-neighbor accuracy in percent and returns it.

  A query counts as correct when the class of its first (nearest) neighbor
  equals its own 'image_class'. For memory-vs-memory results the first
  neighbor is presumably the query itself, so that accuracy is near-trivial.

  Args:
    label: text identifying the run in the printed line.
    neighbors: dict of records with 'image_class' and 'neighbor_classes'.

  Returns:
    Accuracy in percent, or NaN when `neighbors` is empty (previously a
    ZeroDivisionError).
  """
  total = len(neighbors)
  correct = sum(
      1 for info in neighbors.values()
      if info['neighbor_classes'][0] == info['image_class'])
  accuracy = correct * 100 / total if total else float('nan')

  print('label=%s, correct: %d, total: %d, accuracy: %f' % (label, correct, total, accuracy))
  return accuracy

In [None]:
#@title Assert valid dataset names

def assert_valid_memory_dataset(memory_dataset: str):
  """Raises ValueError unless `memory_dataset` is a supported memory dataset.

  An explicit raise is used instead of `assert`, which is stripped under
  `python -O` and gives no message.
  """
  valid_names = ('imagenet2012', 'ninco')
  if memory_dataset not in valid_names:
    raise ValueError(
        f'Unsupported memory dataset {memory_dataset!r}; '
        f'expected one of {valid_names}.')

def assert_valid_query_dataset(query_dataset: str):
  """Raises ValueError unless `query_dataset` is a supported query dataset.

  An explicit raise is used instead of `assert`, which is stripped under
  `python -O` and gives no message.
  """
  valid_names = ('imagenet2012',
                 'imagenet_v2',
                 'imagenet_r',
                 'imagenet_sketch',
                 'imagenet_a',
                 'imagenet2012_real',
                 'ninco')
  if query_dataset not in valid_names:
    raise ValueError(
        f'Unsupported query dataset {query_dataset!r}; '
        f'expected one of {valid_names}.')

def get_memory_name_and_split(memory_list):
  """Joins the dataset names and the splits of a memory with '-and-'.

  Args:
    memory_list: non-empty sequence of (dataset_name, split) pairs, e.g.
      [('imagenet2012', 'train'), ('ninco', 'test')].

  Returns:
    (name, split), e.g. ('imagenet2012-and-ninco', 'train-and-test'). A
    single-entry memory yields its own name and split.
  """
  first_entry = memory_list[0]
  joined_name, joined_split = first_entry[0], first_entry[1]

  for extra_name, extra_split in memory_list[1:]:
    joined_name = f'{joined_name}-and-{extra_name}'
    joined_split = f'{joined_split}-and-{extra_split}'

  return joined_name, joined_split

In [None]:
#@title Load multiple models and datasets

# NOTE: multiple memory datasets will be added as a combined memory
# Here is an example for multiple datasets in memory:
# MEMORY = [('imagenet2012', 'train'), ('ninco', 'test')]
# Here is an example for a single dataset in memory:
# (The memory is the set of embeddings that neighbors are searched in. Entries
# are (dataset name, split) pairs; several entries are concatenated in order by
# load_combined_datasets.)
MEMORY = [('imagenet2012', 'train')]

SCALE_MEMORY = False # set to True for storing down-scaled memory in addition to full memory results
TOP_K = 100 # Note that for train, TOP_K+1 is saved automatically below.
# Used both as the tf.data batch size when embedding images and as the row
# batch size of the dot-product neighbor search.
batch_size=1024

# Output directory for the *_neighbor_info.json files (placeholder path).
RESULT_DIR = '/path/to/result/directory'

# Maps featurizer name -> zero-argument loader, called once per featurizer by
# the main loop. The main loop also derives preprocessing from the name:
# 'dino' turns on antialiasing, and 'clip' selects a 224 resize (else 256).
FEATURIZER_DICT = {
    'dinov2_vitl14': lambda: DinoFeatures(model_name='dinov2_vitl14', is_tpu_inference=True),
    'dinov2_vitb14': lambda: DinoFeatures(model_name='dinov2_vitb14', is_tpu_inference=True),
    'dinov2_vits14': lambda: DinoFeatures(model_name='dinov2_vits14', is_tpu_inference=True),
    'clip-vit_l14':  lambda: ClipFeatures(model_name='vit_l14', is_tpu_inference=True),
    'clip-vit_b16':  lambda: ClipFeatures(model_name='vit_b16', is_tpu_inference=True),
                   }
# Query datasets as (dataset name, split) pairs, each searched against the
# memory. In output file names, '_' becomes '-' and 'imagenet2012_real' is
# written as 'imagenet-real'.
QUERY_DATASET_LIST = [
    ('imagenet2012', 'validation'),
    ('imagenet_v2', 'test'),
    ('imagenet_r', 'test'),
    ('imagenet_sketch', 'test'),
    ('imagenet_a', 'test'),
    ('ninco', 'test'),
    ('imagenet2012_real', 'validation'),
    ]

# assert dataset names are valid (fails fast, before any model is loaded)
for m, _ in MEMORY:
  assert_valid_memory_dataset(m)
for (q, _) in QUERY_DATASET_LIST:
  assert_valid_query_dataset(q)

In [None]:
#@title Main run loop

# Memory identifiers depend only on MEMORY, so compute them once up front.
MEMORY_DATASET, MEMORY_SPLIT = get_memory_name_and_split(MEMORY)
memory_tag = MEMORY_DATASET.replace("_", "-")

for FEATURIZER_NAME, model_loader in FEATURIZER_DICT.items():

  # Load model
  print('Loading model: ', FEATURIZER_NAME)
  featurizer = model_loader()

  # Set preprocessing details
  # Dino preprocessing: resize(256) then crop(224)
  # https://github.com/facebookresearch/dinov2/blob/main/dinov2/data/transforms.py#L77
  # CLIP preprocessing: resize(224) then crop(224)
  # https://github.com/openai/CLIP/blob/main/clip/clip.py#L79
  antialias = 'dino' in FEATURIZER_NAME
  print('antialias: ', antialias)
  resizing_size = 224 if 'clip' in FEATURIZER_NAME else 256
  print('resizing size: ', resizing_size)

  # Embed the memory and search for neighbors within the memory itself.
  print(f'Loading memory dataset {MEMORY_DATASET} with split {MEMORY_SPLIT}')
  train_batches = load_dataset_batches(list_of_names_and_splits=MEMORY,
                                       featurizer=featurizer,
                                       batch_size=batch_size,
                                       resizing_size=resizing_size,
                                       antialias=antialias)
  train_embeddings = get_model_embeddings(featurizer, train_batches)
  print(train_embeddings['features'].shape)

  # One extra neighbor is requested, presumably because every memory example
  # finds itself as its own nearest neighbor.
  train_x_train_full_neighbors = get_nearest_neighbors(
      train_embeddings, train_embeddings, k=TOP_K + 1, batch_size=batch_size)
  print_accuracy(label=f'{FEATURIZER_NAME}_train_x_train_full',
                 neighbors=train_x_train_full_neighbors)

  file_path = (
      f'{RESULT_DIR}/memory-{memory_tag}_msplit-{MEMORY_SPLIT}'
      f'_query-{memory_tag}_qsplit-{MEMORY_SPLIT}'
      f'_{FEATURIZER_NAME}_full_neighbor_info.json')
  write_neighbor_info_to_json(train_x_train_full_neighbors, file_path, FEATURIZER_NAME)

  for QUERY_DATASET, QUERY_SPLIT in QUERY_DATASET_LIST:

    print(f'Loading query dataset {QUERY_DATASET} with split {QUERY_SPLIT}')
    query_batches = load_dataset_batches(list_of_names_and_splits=[(QUERY_DATASET, QUERY_SPLIT)],
                                         featurizer=featurizer,
                                         batch_size=batch_size,
                                         resizing_size=resizing_size,
                                         antialias=antialias)
    query_embeddings = get_model_embeddings(featurizer, query_batches)
    val_x_train_full_neighbors = get_nearest_neighbors(
        query_embeddings, train_embeddings, k=TOP_K, batch_size=batch_size)

    # File names use dashes, and imagenet2012_real is shortened to imagenet-real.
    qdataset = QUERY_DATASET.replace("_", "-")
    if qdataset == 'imagenet2012-real':
      qdataset = 'imagenet-real'

    # Include the query dataset in the label so the per-query accuracy lines
    # can be told apart.
    print_accuracy(label=f'{FEATURIZER_NAME}_query-{qdataset}-{QUERY_SPLIT}_val_x_train_full',
                   neighbors=val_x_train_full_neighbors)

    file_path = (
        f'{RESULT_DIR}/memory-{memory_tag}_msplit-{MEMORY_SPLIT}'
        f'_query-{qdataset}_qsplit-{QUERY_SPLIT}'
        f'_{FEATURIZER_NAME}_full_neighbor_info.json')
    write_neighbor_info_to_json(val_x_train_full_neighbors, file_path, FEATURIZER_NAME)

    if SCALE_MEMORY:
      scale_memory(featurizer_name=FEATURIZER_NAME,
                   train_embeddings=train_embeddings,
                   validation_embeddings=query_embeddings,
                   file_path_prefix=file_path.replace('_full_neighbor_info.json', ''),
                   num_imgs_per_class_list=[1, 10, 100, 1000],
                   top_k=TOP_K)