In [1]:
import os
import numpy as np
from lib.utils import model_results_from_npz

CIFAR10_CLASSES = [
    "airplane", "automobile", "bird", "cat", "deer",
    "dog", "frog", "horse", "ship", "truck",
]
NUM_CLASSES = len(CIFAR10_CLASSES)
# Seeds and epoch count that appear in the result-file names searched below.
SEEDS = [42, 602, 311, 637, 800, 543, 969, 122, 336, 93]
EPOCHS = 50

# The two result directories being compared. File names differ only in the
# "_r{k}" (DIR_RETRAIN) vs "_rd{k}" (DIR_REDIS) suffix, where k is the removed class.
DIR_RETRAIN = "../results/cifar10/"
DIR_REDIS   = "../results/redis_cifar10/"


def result_paths(seed: int, removed_class: int) -> tuple[str, str]:
    """Return ``(retrain_path, redis_path)`` for one seed / removed-class combination."""
    stem = f"cifar_resnet_s{seed}_e{EPOCHS}"
    retrain_path = os.path.join(DIR_RETRAIN, f"{stem}_r{removed_class}.npz")
    redis_path = os.path.join(DIR_REDIS, f"{stem}_rd{removed_class}.npz")
    return retrain_path, redis_path


def all_class_preds(model) -> np.ndarray:
    """Concatenate ``model.preds[cls]`` for every class index 0..NUM_CLASSES-1 (along axis 0)."""
    return np.concatenate([model.preds[cls] for cls in range(NUM_CLASSES)])


def prediction_match_ratio(preds_a, preds_b) -> float:
    """Return the fraction of positions at which two prediction arrays agree.

    Parameters
    ----------
    preds_a, preds_b : array-like
        Prediction arrays to compare element-wise; they must have identical shapes.

    Returns
    -------
    float
        Mean of ``preds_a == preds_b``, as a plain Python float.

    Raises
    ------
    ValueError
        If the shapes differ or the arrays are empty. Without this check, ``==``
        on mismatched arrays either broadcasts silently, returns a scalar
        ``False`` (older NumPy), or raises an opaque broadcasting error
        (NumPy >= 1.25). Empty arrays would yield NaN.
    """
    preds_a = np.asarray(preds_a)
    preds_b = np.asarray(preds_b)
    if preds_a.shape != preds_b.shape:
        raise ValueError(f"prediction shapes differ: {preds_a.shape} vs {preds_b.shape}")
    if preds_a.size == 0:
        raise ValueError("cannot compute a match ratio over zero predictions")
    return float(np.mean(preds_a == preds_b))


# Sentinel values remain if no result-file pair is found.
best_match = {"seed": None, "removed_class": None, "match_ratio": -1.0}
n_pairs_compared = 0

for seed in SEEDS:
    for removed_class in range(NUM_CLASSES):
        retrain_path, redis_path = result_paths(seed, removed_class)

        # Skip combinations for which either result file is missing.
        if not (os.path.exists(retrain_path) and os.path.exists(redis_path)):
            continue

        model_retrain = model_results_from_npz(retrain_path, NUM_CLASSES)
        model_redis   = model_results_from_npz(redis_path, NUM_CLASSES)

        # Compare predictions over every class index, including `removed_class` itself.
        try:
            match_ratio = prediction_match_ratio(
                all_class_preds(model_retrain), all_class_preds(model_redis)
            )
        except ValueError as err:
            # Re-raise with the offending file pair so the failure is traceable.
            raise ValueError(f"{retrain_path} vs {redis_path}: {err}") from err
        n_pairs_compared += 1

        # Strict ">" keeps the first pair encountered when ratios tie.
        if match_ratio > best_match["match_ratio"]:
            best_match.update({"seed": seed, "removed_class": removed_class, "match_ratio": match_ratio})

if n_pairs_compared == 0:
    print(f"No result-file pairs found under {DIR_RETRAIN!r} and {DIR_REDIS!r}.")
else:
    print(f"Compared {n_pairs_compared} result-file pairs.")
print("Best matching seed and removed class:")
print(best_match)


Best matching seed and removed class:
{'seed': 311, 'removed_class': 3, 'match_ratio': np.float64(0.9233)}
