In [1]:
import os
import pickle
from os.path import exists

import numpy as np
import torch
from keras.datasets import cifar10
from robustbench.utils import load_model

from manifold_angles import ManifoldAngles

In [4]:
def get_model_predictions(model, data, batch_size=50):
  """Run `model` over `data` in mini-batches and collect the outputs.

  Args:
    model: callable (typically a `torch.nn.Module`) that maps a float32
      batch tensor to an output tensor.
    data: torch tensor whose first dimension indexes the samples.
    batch_size: number of samples per forward pass. Defaults to 50, the
      value that was previously hard-coded.

  Returns:
    np.ndarray holding one model output per input sample, in input order
    (an empty array when `data` has no samples).
  """
  # Send each batch to the device the model's weights live on. Callables
  # that expose no parameters fall back to the device `data` is already on.
  try:
    device = next(model.parameters()).device
  except (AttributeError, StopIteration):
    device = data.device

  predictions = []

  # Inference only: without no_grad every forward pass would record an
  # autograd graph and keep activations alive, wasting (GPU) memory.
  with torch.no_grad():
    for i in range(0, data.shape[0], batch_size):
      print(f'BATCH: {i} to {i + batch_size}')
      input_slice = data[i : i + batch_size].to(device=device, dtype=torch.float32)
      output_slice = model(input_slice)
      del input_slice

      # Move results to host memory right away so device memory is freed
      # before the next batch is processed.
      predictions.extend(output_slice.cpu().numpy())
      del output_slice

  return np.array(predictions)

In [2]:
# Model identifiers processed by this notebook. Each name is handed to
# robustbench's `load_model` together with dataset='cifar10' and
# threat_model='L2', and is also embedded in the cache file names.
models = [
  'Wang2023Better_WRN-70-16',
  'Wang2023Better_WRN-28-10',
  'Rebuffi2021Fixing_70_16_cutmix_extra',
  'Gowal2020Uncovering_extra',
  'Rebuffi2021Fixing_70_16_cutmix_ddpm',
  'Rebuffi2021Fixing_28_10_cutmix_ddpm',
  'Sehwag2021Proxy',
  'Rade2021Helper_R18_ddpm',
  'Rebuffi2021Fixing_R18_cutmix_ddpm',
  'Gowal2020Uncovering',
]

In [None]:
# For every model, two cached stages are produced:
#   1. the model's outputs on the CIFAR-10 training images, and
#   2. one curvature array per class, computed from those outputs.
# A stage is skipped when its output file already exists.
#
# NOTE: prediction files written before the channel-order fix below were
# computed on pixel-scrambled inputs. Delete them (and the curvature files
# derived from them) to have them regenerated.
for model_name in models:
  print(f'Checking existing predictions for {model_name}')
  predictions_path_name = f'./predictions/cifar_predictions_{model_name}'
  if exists(predictions_path_name):
    print(f'  Predictions for {model_name} already exists.')
  else:
    print(f'  Predictions for {model_name} not found.')
    # load model into gpu
    print('  Loading model...')
    model = load_model(model_name=model_name, dataset='cifar10', threat_model='L2').cuda()
    print('  Loading dataset...')
    (cifar_X_train, _), (_, _) = cifar10.load_data()
    del _ # don't save stuff we're not using
    cifar_X_train = cifar_X_train / 255
    # Keras returns the images channels-last, (N, 32, 32, 3), by default.
    # The model is fed (N, 3, 32, 32), so the channel axis has to be MOVED
    # with a transpose; a plain reshape would keep the memory order and
    # scramble the pixels. The guard leaves data that is already
    # channels-first untouched.
    if cifar_X_train.shape[-1] == 3:
      cifar_X_train = np.transpose(cifar_X_train, (0, 3, 1, 2))
    model_inputs = torch.from_numpy(cifar_X_train)

    print('  Generating predictions...')
    predictions = get_model_predictions(model, model_inputs)
    print('len predictions:', len(predictions))

    print(f'  Saving predictions to {predictions_path_name}')
    # Create the output directory first so a long inference run is not lost
    # to a missing folder.
    os.makedirs(os.path.dirname(predictions_path_name), exist_ok=True)
    with open(predictions_path_name, 'wb') as file:
      pickle.dump(predictions, file)

    # clean up
    del model, cifar_X_train, model_inputs, predictions

  print(f'Checking existing curvature sets for {model_name}')
  curvatures_path_name = f'./output_curv/cifar_output_curv_{model_name}'
  if exists(curvatures_path_name):
    print(f'  Curvatures for {model_name} already exists.')
  else:
    print(f'  Curvatures for {model_name} not found.')
    # Always read the predictions back from the cache file, whether they
    # were just written above or left by an earlier run.
    with open(predictions_path_name, 'rb') as file:
      predictions = np.array(pickle.load(file))

    print('  Loading dataset...')
    (_, cifar_train_y), (_, _) = cifar10.load_data()
    del _ # don't save stuff we're not using

    curvatures = []
    print('   Generating curvatures:')
    for y_class in range(10):
      print(f'     Generating curvature for class {y_class}...')
      # Labels are indexed as cifar_train_y[:, 0]; the mask keeps only the
      # predictions of the current class. ManifoldAngles is a project-local
      # helper; only its second return value is stored.
      _, manifold_neighbour_angle_sum = ManifoldAngles([predictions[cifar_train_y[:, 0] == y_class]], classsize=1, neighboursize1=10, dim_reduc_size=5)
      curvatures.append(np.array(manifold_neighbour_angle_sum))
      del _, manifold_neighbour_angle_sum

    print(f'  Saving curvatures to {curvatures_path_name}')
    os.makedirs(os.path.dirname(curvatures_path_name), exist_ok=True)
    with open(curvatures_path_name, 'wb') as file:
      pickle.dump(np.array(curvatures), file)

    # clean up
    del predictions, cifar_train_y, curvatures

### Save the average output curvature across all models to file

In [None]:
# Load every model's cached output curvatures, average them element-wise
# across models (axis 0 of the stacked list), and pickle the result.
#
# NOTE(review): this cell reads from './cached/output_curv/', whereas the
# processing loop writes to './output_curv/'. Presumably the files are moved
# by hand between the two steps -- confirm, or unify the paths.
curvs_list = []
# The loop variable is `model_name`, not `model`: the entries are name
# strings, and `model` is the identifier used for the loaded network.
for model_name in models:
  with open(f'./cached/output_curv/cifar_output_curv_{model_name}', 'rb') as file:
    curvs_list.append(pickle.load(file))

averages = np.average(curvs_list, axis=0)
with open('./cached/output_curv/cifar_output_curv_AVG', 'wb') as file:
  pickle.dump(averages, file)

In [147]:
# Re-save the cached input curvatures as a single squeezed ndarray:
# unpickle, convert to an array, drop length-1 axes, report the shape,
# then pickle the array under a new file name.
with open('./cached/cifar_input_curv', 'rb') as file:
  input_curv = np.squeeze(np.array(pickle.load(file)))

print(input_curv.shape)

with open('./cached/cifar_input_curv_ndarray', 'wb') as file:
  pickle.dump(input_curv, file)

(10, 5000)


In [149]:
# Sanity check: read the re-saved ndarray back and report its shape.
with open('./cached/cifar_input_curv_ndarray', 'rb') as file:
  input_curv_new = pickle.load(file)

print(input_curv_new.shape)

(10, 5000)
