In [None]:
import json
import os

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch

from minio_obj_storage import get_numpy_from_cloud
# Load the FZ memorization scores for CIFAR-100 (Feldman & Zhang, downloaded from
# https://pluskid.github.io/influence-memorization/#cifar100-dl) and keep the indices
# of the `top_k` training examples with the highest memorization score ('tr_mem').
top_k = 5000
# NOTE(review): allow_pickle=True lets np.load unpickle object arrays, which can run
# arbitrary code; keep it only if this file really contains object arrays and is trusted.
# The `with` block closes the underlying .npz file once the two arrays have been read.
with np.load('./analysis_checkpoints/cifar100/cifar100_infl_matrix.npz', allow_pickle=True) as npz:
    fz_scores = pd.DataFrame.from_dict({item: npz[item] for item in ['tr_labels', 'tr_mem']})
# Highest memorization first. from_dict gives the frame a default 0..N-1 index, so after
# sorting the index holds the original training-set positions of the examples.
fz_scores = fz_scores.sort_values(by='tr_mem', ascending=False)
top_k_indices = fz_scores.index[:top_k]

# Per-epoch learning-dynamics scores for CIFAR-100, read from object storage.
epochs = range(0, 294, 1)
# NOTE(review): "leraning" looks like a typo, but this string must match the existing
# bucket/container name; verify before renaming it.
container_name = 'leraning-dynamics-scores'
container_dir = 'cifar100'
arch = 'resnet18'
metric_type = 'loss'

# Load the precomputed per-epoch `metric_type` scores (one .npy file per epoch).
# Rows are collected in lists and stacked once after the loop. Calling np.row_stack on
# a growing array copies it every epoch (quadratic), and np.row_stack is a deprecated
# alias of np.vstack.
all_epoch_scores = []
top_epoch_scores = []
for epoch in epochs:
    score_file_name = f'{metric_type}_{arch}_wd1_{epoch}.npy'
    scores_for_epoch = get_numpy_from_cloud(container_name, container_dir, score_file_name)
    all_epoch_scores.append(scores_for_epoch)
    top_epoch_scores.append(scores_for_epoch[top_k_indices])

# Row i holds the scores for epochs[i]. This assumes each file is a 1-D array with one
# score per training example (TODO confirm). np.vstack keeps the result 2-D even when
# only one epoch is loaded, which the `.mean(1)` calls below rely on.
epoch_vs_scores = np.vstack(all_epoch_scores)
fz_top_scores = np.vstack(top_epoch_scores)

# Mean score per epoch: all training examples vs. the top-k FZ-memorized examples.
fig_curves, ax_curves = plt.subplots()
ax_curves.plot(epochs, epoch_vs_scores.mean(1), label='all training examples')
ax_curves.plot(epochs, fz_top_scores.mean(1), label=f'top-{top_k} memorized (FZ)')
ax_curves.set_xlabel('epoch')
ax_curves.set_ylabel(f'mean {metric_type}')
ax_curves.legend()
plt.show()

In [None]:
import numpy as np
from utils.load_dataset import load_dataset
import logging
import json

# Average each example's score over a late-training window (epochs 150..289; arange
# excludes 290) and rank the examples by that average, highest first.
epochs_to_avg = np.arange(150, 290, 1)


def epoch_to_idx(epoch):
    """Return the row of `epoch_vs_scores` that holds the scores saved at `epoch`.

    The epoch is looked up in `epochs` (the range the scores were loaded for) rather
    than assuming row == epoch. Changing the start or stride of `epochs` therefore
    cannot silently select the wrong rows.
    Raises ValueError if `epoch` is not in `epochs`.
    """
    return epochs.index(int(epoch))


idxs = [epoch_to_idx(epoch) for epoch in epochs_to_avg]
# torch.Tensor(...) converts the selected (epochs x examples) rows to float32 before averaging.
avg_score = torch.Tensor(epoch_vs_scores[idxs]).mean(0)
# stable=True keeps tied scores in index order, so the ranking is deterministic.
sorted_score, indices = torch.sort(avg_score, stable=True, descending=True)
# NOTE(review): `dataset_order` is loaded but not used in the cells shown here. Confirm
# whether score indices must be mapped through it before indexing the dataset below.
# torch.load is pickle-based, so only load trusted files.
dataset_order = torch.load("./index/data_index_cifar100.pt")
with open('./config.json', 'r') as f:
    config = json.load(f)

logger = logging.getLogger('Analyze Duplicates')
# Un-augmented, unshuffled CIFAR-100 with no validation split. This presumably makes
# positions in `dataset.train_loader.dataset` line up with the per-example score
# indices (TODO confirm against how the scores were produced).
dataset = load_dataset(
    logger=logger,
    dataset="cifar100",
    train_batch_size=256,
    test_batch_size=256,
    val_split=0.0,
    root_path=config['data_dir'],
    augment=False,
    shuffle=False,
    random_seed=0)

# Fetch the images of the highest-ranked examples (by `avg_score`) for visual inspection.
num_top_images = 200
images = []
labels = []  # (class label, dataset index) pairs, aligned with `images`

# .tolist() yields plain Python ints instead of 0-d tensors for indexing and storage.
for idx in indices[:num_top_images].tolist():
    image, label = dataset.train_loader.dataset[idx]
    images.append(image)
    labels.append((label, idx))

In [None]:
# Class-id -> class-name table; the `with` block closes the file after reading it.
with open("./analysis_checkpoints/cifar100/cifar100_class_index.json") as f:
    class_idx = json.load(f)
idx2label = class_idx['names']

# Show the first r*c collected images in an r x c grid. Each title is the dataset
# index plus the first '_'-separated word of the class name.
r = 6
c = 8
# squeeze=False keeps `axes` 2-D even when r or c is 1.
fig, axes = plt.subplots(r, c, figsize=(8, 9), squeeze=False)
# Hide every subplot's axes up front (plt.axis('off') only affects the current, i.e.
# last, subplot). Grid cells left empty when there are fewer than r*c images then
# show no frame.
for ax in axes.flat:
    ax.set_axis_off()

# zip stops at the shortest input, so at most r*c images are drawn, in row-major order.
for ax, image, label in zip(axes.flat, images, labels):
    # permute(1, 2, 0): presumably (C, H, W) -> (H, W, C), the layout imshow expects.
    image_np = image.permute(1, 2, 0).numpy()
    # Min-max rescale to [0, 1] for display; guard against a constant image (max == min).
    lo, hi = image_np.min(), image_np.max()
    image_np = (image_np - lo) / max(hi - lo, 1e-12)
    ax.imshow(image_np)
    ax.set_title(f"{label[1]}\n{idx2label[label[0]].split('_')[0].capitalize()}")

os.makedirs("./output", exist_ok=True)
fig.savefig(f"./output/{metric_type}_duplicates.svg")
# fig.show() warns on non-GUI backends such as the notebook inline backend; plt.show()
# is the portable way to display the figure.
plt.show()
