In [88]:
import os
import sys
import numpy as np
import pandas as pd
from pandas.core.common import flatten

# Put the parent of the current working directory on sys.path (only once), so
# that the project-local imports below can be resolved — presumably the
# `dejavuoob` package lives one level above this notebook; verify if it moves.
module_path = os.path.abspath(os.path.join('..'))
if module_path not in sys.path:
    sys.path.append(module_path)

from dejavuoob.utils.image_common import ImageFolderIndexWithPath, SSL_Transform
from dejavuoob.utils.common import most_conf_frac
from scipy import stats

from torch.utils.data import DataLoader

In [67]:
def get_confidence_and_topk(neighb_labels, k_neighbs = 100, topk = [1], num_classes = 1000):
    """Turn nearest-neighbour labels into a confidence score and top-k predictions.

    For every row of ``neighb_labels`` the first ``k_neighbs`` labels are
    histogrammed into ``num_classes`` bins. The negated entropy of that
    histogram is the confidence (0 when all neighbours agree, more negative
    the more the labels are spread out), and the most frequent labels are
    the predictions.

    Args:
        neighb_labels: 2-D array of non-negative integer labels, one row per
            query and one column per neighbour.
        k_neighbs: number of leading columns (neighbours) to use per row.
        topk: iterable of ``k`` values; one prediction array is produced for
            each. The default list is never mutated.
        num_classes: number of histogram bins. Every label must be smaller
            than this, otherwise rows get different lengths and numpy raises.

    Returns:
        tuple ``(confidence, preds)`` where ``confidence`` is a 1-D float
        array with one entry per row and ``preds`` maps ``'top_{k}'`` to an
        integer array of shape ``(n_rows, k)`` holding the k most frequent
        labels of each row, most frequent first.
    """
    from scipy.stats import entropy
    import torch

    # Per-row class histogram of the k nearest neighbours' labels.
    class_cts = np.apply_along_axis(np.bincount, axis=1,
                            arr=neighb_labels[:,:k_neighbs], minlength=num_classes)

    # entropy() normalises each row of counts to a distribution itself.
    attk_uncert = entropy(class_cts, axis = 1)

    # The float tensor does not depend on k, so build it once for all k.
    class_cts_t = torch.Tensor(class_cts)
    preds = {}
    for k in topk:
        _, topk_preds = torch.topk(class_cts_t, k, dim = 1)
        preds[f'top_{k}'] = np.array(topk_preds).astype(int)

    return -attk_uncert, preds

In [68]:
# Folder that the four .npy files below are read from.
VICREG_KNN_ROOT_PATH_MEM_OOB = '<PATH-VICREG-OUTPUT>/dejavu/vicreg/attack_sweeps/NN_attk_vicregoob_model_bbox_A_bbox_B_blurred_05.11.2024'

# NOTE(review): the variables are named `attk_B` but every file is
# `valid_attk_A_*` — confirm which set these arrays actually describe.
vicreg_bbox_attk_B_idxs_mem_oob = np.load(f'{VICREG_KNN_ROOT_PATH_MEM_OOB}/valid_attk_A_attk_idxs.npy')
vicreg_bbox_attk_B_labels_mem_oob = np.load(f'{VICREG_KNN_ROOT_PATH_MEM_OOB}/valid_attk_A_labels.npy')
vicreg_bbox_attk_B_neighb_labels_mem_oob = np.load(f'{VICREG_KNN_ROOT_PATH_MEM_OOB}/valid_attk_A_neighb_labels.npy')
vicreg_bbox_attk_B_neighb_idx_mem_oob = np.load(f'{VICREG_KNN_ROOT_PATH_MEM_OOB}/valid_attk_A_neighb_idxs.npy')

# Cell output: number of loaded label entries.
len(vicreg_bbox_attk_B_labels_mem_oob)

139183

In [69]:
# Build the image dataset for the loaded attack indices and wrap it in a
# DataLoader. batch_size=1 with shuffle=False means samples are yielded one at
# a time in dataset order; a later cell zips this loader with the label and
# prediction arrays and relies on that order.
# NOTE(review): presumably the dataset is restricted to, and ordered like,
# `vicreg_bbox_attk_B_idxs_mem_oob` — confirm in ImageFolderIndexWithPath.
img_dir = '<PATH-TO-IMAGENET-TRAIN>/train_blurred'
vicreg_dataset = ImageFolderIndexWithPath(img_dir, SSL_Transform(), vicreg_bbox_attk_B_idxs_mem_oob)
vicreg_dataloader_mem_oob = DataLoader(vicreg_dataset, batch_size = 1, shuffle = False, num_workers=8)


In [70]:
# Confidence (negated neighbour-label entropy) and top-1 neighbour-vote
# prediction for every attack sample.
vicreg_confidences_mem_oob, vicreg_pred_idxes_mem_oob = get_confidence_and_topk(vicreg_bbox_attk_B_neighb_labels_mem_oob)

# Only the top-1 predictions are used from here on.
vicreg_pred_idxes_mem_oob = vicreg_pred_idxes_mem_oob['top_1']


# Walk the loader in step with the label / confidence / prediction arrays to
# pair each sample's dataset index and file path with its k-NN result.
# NOTE(review): zip() stops silently at the shortest input and this pairing
# assumes the loader yields samples in the same order as the arrays — confirm.
predictions_to_save = []
for (_x, _y, idx, path), label, confidence, pred in zip(vicreg_dataloader_mem_oob, vicreg_bbox_attk_B_labels_mem_oob, vicreg_confidences_mem_oob, vicreg_pred_idxes_mem_oob):
    predictions_to_save.append([idx.item(), label[0], pred[0], confidence, path[0]])
columns = ['idx_vmo', 'label_vmo', 'pred_vmo', 'conf_vmo', 'path_vmo']

# Build the frame from the records directly. Routing mixed int/float/str rows
# through np.array() first would coerce every cell to a string and lose the
# numeric dtypes of the index, label, prediction and confidence columns.
vicreg_predictions_knn_mem_oob = pd.DataFrame.from_records(predictions_to_save, columns=columns)


In [None]:
# Copy of the k-NN prediction table ordered by image path, ascending (the default).
vicreg_predictions_knn_mem_oob_srt = vicreg_predictions_knn_mem_oob.sort_values('path_vmo')


In [71]:
# Saved evaluation output of the first supervised ResNet-50 run, loaded as a
# table with one row per image.
columns = ['label', 'pred', 'conf', 'path']
resnet_eval_npy = '<PATH-RESNET-OUTPUT>/dejavu/resnet50_crops/eval/supervised_train_A_test_B_set_15_05_2024_wd_0.1_lars_mom_0.9_w_aug_eval.npy'

predictions_resnet = pd.DataFrame.from_records(np.array(np.load(resnet_eval_npy)), columns=columns)

# Reduce every path to a bare name: keep what follows the last '/', then what
# precedes the first '.'. This is the key used to join against the other tables.
resnet_basenames = predictions_resnet['path'].str.split('/').str[-1]
predictions_resnet['path'] = resnet_basenames.str.split('.').str[0]
predictions_resnet_srt = predictions_resnet.sort_values('path')


In [72]:
# Same processing as for the first ResNet run, applied to a second saved
# evaluation file; the column names carry a '2' suffix so both tables can be
# merged side by side.
columns = ['label2', 'pred2', 'conf2', 'path2']
predictions_resnet2 = pd.DataFrame.from_records(
    np.array(
        np.load('<PATH-RESNET-OUTPUT2>/dejavu/resnet50_crops/eval/supervised_train_A_test_B_set_07_11_2024_wd_0.1_lars_mom_0.9_w_aug_eval.npy')
    ),
    columns=columns,
)
# Bare name: text after the last '/', cut at the first '.'.
predictions_resnet2['path2'] = (
    predictions_resnet2['path2']
    .str.split('/').str[-1]
    .str.split('.').str[0]
)
predictions_resnet2_srt = predictions_resnet2.sort_values(by='path2', ascending=True)


In [73]:
import csv
import json

imagenet_label_path = '<PATH-IMAGENET-LABELS>/labels.txt'
nb_output_path = '<HOME-GROUNDED-SEGANY>/Grounded-Segment-Anything/output_May_03_2023_trainA_testB'
class_id2index = {}
classes = []

# One entry (a list) per class; a later cell flattens these into flat lists.
file_names_all = []
prediction_scores_all = []
prediction_label_all = []
targets_all = []

# Read the whole "class_id,class_name" table up front so the label file is
# closed before the per-class JSON files are opened. newline='' is what the
# csv module asks for; blank rows are skipped so they cannot break the
# two-field unpacking below.
with open(imagenet_label_path, newline='') as folder2label:
    folder2label_lst = [line for line in csv.reader(folder2label, delimiter=',') if line]

for i, (class_id, class_name) in enumerate(folder2label_lst):
    # The row position in the label file is the integer class index.
    class_id2index[class_id] = i
    labels_input_path = os.path.join(nb_output_path, 'predictions_top_5_before_Aug_2nd', f'{class_id}_{class_name}_predictions.json')
    # Context manager so each per-class file is closed as soon as it is parsed.
    with open(labels_input_path) as labels_file:
        labels_input = json.load(labels_file)

    # The JSON maps a file name to a prediction whose element 0 is used as the
    # score and element 1 as the predicted label.
    file_names = list(labels_input.keys())
    predictions = list(labels_input.values())
    prediction_scores = [prediction[0] for prediction in predictions]
    prediction_label = [prediction[1] for prediction in predictions]
    # Every file in a class's JSON gets that class as its target.
    targets = [f'{class_id}_{class_name}'] * len(predictions)
    classes.append(class_name)

    file_names_all.append(file_names)
    prediction_scores_all.append(prediction_scores)
    prediction_label_all.append(prediction_label)
    targets_all.append(targets)


In [74]:
# Collapse the per-class nested collections into four flat, parallel lists.
file_names_all, prediction_scores_all, prediction_label_all, targets_all = [
    list(flatten(nested))
    for nested in (file_names_all, prediction_scores_all, prediction_label_all, targets_all)
]
# Cell output: total number of target entries.
len(targets_all)

139183

In [75]:
# Visual sanity check: the flattened targets and file names should have the
# same length, since they become parallel columns of one DataFrame below.
len(targets_all), len(file_names_all)

(139183, 139183)

In [87]:
columns = ['label_n', 'pred_n', 'conf_n', 'path_n']

# One row per file taken from the per-class prediction JSONs.
predictions_nb = pd.DataFrame({
    'label_n': targets_all,
    'pred_n': prediction_label_all,
    'conf_n': prediction_scores_all,
    'path_n': file_names_all,
})

# For both class columns keep the text before the first '_' and replace it
# with its integer index; an id missing from class_id2index raises KeyError.
for class_col in ('label_n', 'pred_n'):
    class_ids = predictions_nb[class_col].str.split('_').str[0]
    predictions_nb[class_col] = class_ids.apply(lambda class_id: class_id2index[class_id])

# Cut file names at the first '.' so they can serve as a join key.
predictions_nb['path_n'] = predictions_nb['path_n'].str.split('.').str[0]

predictions_nb_srt = predictions_nb.sort_values('path_n')


In [77]:

# Reduce the k-NN table's paths to the same bare-name form as the other
# tables (text after the last '/', cut at the first '.') so they can be joined.
vicreg_basenames = vicreg_predictions_knn_mem_oob_srt['path_vmo'].str.split('/').str[-1]
vicreg_predictions_knn_mem_oob_srt['path_vmo'] = vicreg_basenames.str.split('.').str[0]

# Inner-join all four prediction tables on the bare image name, keeping only
# images that appear in every table.
result = vicreg_predictions_knn_mem_oob_srt.merge(predictions_resnet_srt, how="inner", left_on='path_vmo', right_on='path')
result = result.merge(predictions_resnet2_srt, how="inner", left_on='path', right_on='path2')
result = result.merge(predictions_nb_srt, how="inner", left_on='path', right_on='path_n')

# Successively keep only rows where each model's prediction equals its label:
# first ResNet run, second ResNet run, then the JSON-based predictions.
cond6 = result['pred'].eq(result['label'])
result = result.loc[cond6]

cond7 = result['pred2'].eq(result['label2'])
result = result.loc[cond7]

cond8 = result['pred_n'].eq(result['label_n'])
result = result.loc[cond8]


In [78]:
# Dataset indices of the rows that survived the filtering, cast to int (the
# original note says they arrive as strings). They select the images loaded
# below so that some of them can be visualised.
corr_idx = np.array(result['idx_vmo'].tolist(), dtype=int)

img_dir = '<PATH-TO-IMAGENET-TRAIN>/train_blurred'
corr_dataset = ImageFolderIndexWithPath(img_dir, SSL_Transform(), corr_idx)
# Despite its name this is a DataLoader: one sample per batch, fixed order.
corr_dataset_mem_oob = DataLoader(
    corr_dataset,
    batch_size=1,
    shuffle=False,
    num_workers=8,
)


In [79]:
#np.save('DejaVuOOB/data/dataset_level_correlations_indices.npy', corr_idx)

In [81]:
import torchvision

import matplotlib.pyplot as plt

class InverseTransform:
    """Callable that undoes the mean/std normalization of the SSL transform.

    Two Normalize steps are chained: the first multiplies each channel by its
    std (normalizing with mean 0 and std 1/s), the second adds the channel
    mean back (normalizing with mean -m and std 1).
    """
    def __init__(self):
        transforms = torchvision.transforms
        undo_std = transforms.Normalize(
            mean=[0., 0., 0.],
            std=[1/0.229, 1/0.224, 1/0.225],
        )
        undo_mean = transforms.Normalize(
            mean=[-0.485, -0.456, -0.406],
            std=[1., 1., 1.],
        )
        self.invTrans = transforms.Compose([undo_std, undo_mean])

    def __call__(self, x):
        """Return ``x`` with the normalization reversed."""
        return self.invTrans(x)

In [82]:
# Shared de-normalization instance; showImg() below reads this global.
iTrans = InverseTransform()

In [83]:
def showImg(image, label):
    """Print the label and tensor shape, then display the de-normalized image.

    Args:
        image: image tensor; all singleton dimensions are squeezed away before
            de-normalizing, and the result is expected to be channel-first
            with three dimensions.
        label: value printed next to the image.
    """
    print('label:', label)
    print('image shape: ', image.shape)
    # Undo the normalization with the global iTrans, then move the channel
    # axis last, which is the layout imshow expects.
    restored = iTrans(image.squeeze()).permute(1, 2, 0)
    plt.imshow(restored)
    plt.axis('off')
    plt.show()

In [86]:
# Disabled preview loop, kept as a bare string literal so it does not run:
# it would print the path and show the first images from corr_dataset_mem_oob.
# Because the string is the cell's last expression, it is echoed as the output.
'''for i, (x, y, idx, path) in enumerate(corr_dataset_mem_oob):
    print(path)
    showImg(x, y)
    if i > 30:
        break
'''

'for i, (x, y, idx, path) in enumerate(corr_dataset_mem_oob):\n    print(path)\n    showImg(x, y)\n    if i > 30:\n        break\n'