In [1]:
import os
import numpy as np
import pandas as pd
from tqdm import tqdm
import time

from lib.utils import model_results_from_npz
from lib.projection import gram_schmidt, project_confidences
from lib.redistribute import redistribute_confidences_of_class

# ===== Dataset Config =====
CIFAR10_CLASSES = [
    "airplane", "automobile", "bird", "cat", "deer",
    "dog", "frog", "horse", "ship", "truck",
]
NUM_CLASSES = len(CIFAR10_CLASSES)

SEEDS = [42, 602, 311, 637, 800, 543, 969, 122, 336, 93]  # one saved run per seed
EPOCHS = 50
# Indices into CIFAR10_CLASSES; each one produces its own "_rd{idx}" output file.
REMOVED_CLASSES = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]

# ===== Paths =====
DIR_ORI = "../results/cifar10/"
DIR_OUT = "../results/redis_cifar10/"
TIMER_DIR = "../analytics/CIFAR10/timer/"
os.makedirs(DIR_OUT, exist_ok=True)
# The timing CSV is written into TIMER_DIR at the end; create it up front so
# DataFrame.to_csv does not fail on a fresh checkout.
os.makedirs(TIMER_DIR, exist_ok=True)


def _per_class_mean(confs_by_class):
    """Return a (NUM_CLASSES, NUM_CLASSES) float32 matrix of per-class means.

    Row ``i`` is ``np.mean(confs_by_class[i], axis=0)``, cast into a float32
    row. ``confs_by_class`` must be indexable by class index 0..NUM_CLASSES-1;
    presumably each entry is an (n_samples_i, NUM_CLASSES) array — TODO confirm
    against ``model_results_from_npz`` / ``project_confidences``.
    """
    means = np.zeros((NUM_CLASSES, NUM_CLASSES), dtype=np.float32)
    for cls in range(NUM_CLASSES):
        means[cls] = np.mean(confs_by_class[cls], axis=0)
    return means


# ===== Timer Store =====
# One dict per (seed, removed_class) pair; dumped to CSV after all seeds.
timer_records = []

# ===== Redistribute Loop =====
for seed in SEEDS:
    file = f"cifar_resnet_s{seed}_e{EPOCHS}.npz"
    filepath = os.path.join(DIR_ORI, file)
    if not os.path.exists(filepath):
        print(f"Skip missing: {filepath}")
        continue

    # Load the original model outputs once per seed; every removed class reuses them.
    model_ori = model_results_from_npz(filepath=filepath, num_classes=NUM_CLASSES)

    # Mean confidence vector of each true class, orthogonalised. This does not
    # depend on removed_class, so it is computed outside the inner loop.
    avg_confidences = _per_class_mean(model_ori.confidences)
    avg_confidences_ortho = gram_schmidt(avg_confidences)

    for removed_class in tqdm(REMOVED_CLASSES, desc=f"Seed {seed}", leave=False):
        # Only projection + redistribution are timed; loading and saving are excluded.
        start_time = time.perf_counter()

        proj_confidences = project_confidences(
            model_ori.confidences,
            avg_confidences_ortho,
            removed_class,
        )
        avg_proj_confidences = _per_class_mean(proj_confidences)

        # Redistribute each true class's confidences independently.
        redis = {
            cls: redistribute_confidences_of_class(
                conf=model_ori.confidences[cls],
                proj_conf=proj_confidences[cls],
                avg_proj_conf=avg_proj_confidences[cls],
                removed_class=removed_class,
            )
            for cls in range(NUM_CLASSES)
        }

        duration_sec = time.perf_counter() - start_time
        timer_records.append({
            "seed": seed,
            "removed_class": removed_class,
            "removed_name": CIFAR10_CLASSES[removed_class],
            "duration_sec": duration_sec,
        })

        # Stack classes back in class order; targets are rebuilt from the
        # per-class grouping, and predictions are the argmax of the new confidences.
        all_confs = np.vstack([redis[cls] for cls in range(NUM_CLASSES)])
        all_targets = np.concatenate([
            np.full(redis[cls].shape[0], cls, dtype=int) for cls in range(NUM_CLASSES)
        ])
        all_preds = np.argmax(all_confs, axis=1)

        # Save (always with _rd suffix)
        out_path = os.path.join(DIR_OUT, f"cifar_resnet_s{seed}_e{EPOCHS}_rd{removed_class}.npz")
        np.savez_compressed(out_path, preds=all_preds, targets=all_targets, confs=all_confs)
        # tqdm.write keeps the active progress bar intact (plain print tears it).
        tqdm.write(f"[Saved] {out_path}")

timer_df = pd.DataFrame(timer_records)
timer_csv_path = os.path.join(TIMER_DIR, "redistribution_timer.csv")
timer_df.to_csv(timer_csv_path, index=False)
print(f"\nAll timers saved to {timer_csv_path}")


Seed 42:  40%|████      | 4/10 [00:00<00:00, 30.95it/s]

[Saved] ../results/redis_cifar10/cifar_resnet_s42_e50_rd0.npz
[Saved] ../results/redis_cifar10/cifar_resnet_s42_e50_rd1.npz
[Saved] ../results/redis_cifar10/cifar_resnet_s42_e50_rd2.npz
[Saved] ../results/redis_cifar10/cifar_resnet_s42_e50_rd3.npz
[Saved] ../results/redis_cifar10/cifar_resnet_s42_e50_rd4.npz
[Saved] ../results/redis_cifar10/cifar_resnet_s42_e50_rd5.npz
[Saved] ../results/redis_cifar10/cifar_resnet_s42_e50_rd6.npz


                                                       

[Saved] ../results/redis_cifar10/cifar_resnet_s42_e50_rd7.npz
[Saved] ../results/redis_cifar10/cifar_resnet_s42_e50_rd8.npz
[Saved] ../results/redis_cifar10/cifar_resnet_s42_e50_rd9.npz


Seed 602:   0%|          | 0/10 [00:00<?, ?it/s]

[Saved] ../results/redis_cifar10/cifar_resnet_s602_e50_rd0.npz
[Saved] ../results/redis_cifar10/cifar_resnet_s602_e50_rd1.npz
[Saved] ../results/redis_cifar10/cifar_resnet_s602_e50_rd2.npz


Seed 602:  40%|████      | 4/10 [00:00<00:00, 32.77it/s]

[Saved] ../results/redis_cifar10/cifar_resnet_s602_e50_rd3.npz
[Saved] ../results/redis_cifar10/cifar_resnet_s602_e50_rd4.npz
[Saved] ../results/redis_cifar10/cifar_resnet_s602_e50_rd5.npz
[Saved] ../results/redis_cifar10/cifar_resnet_s602_e50_rd6.npz


                                                        

[Saved] ../results/redis_cifar10/cifar_resnet_s602_e50_rd7.npz
[Saved] ../results/redis_cifar10/cifar_resnet_s602_e50_rd8.npz
[Saved] ../results/redis_cifar10/cifar_resnet_s602_e50_rd9.npz


Seed 311:   0%|          | 0/10 [00:00<?, ?it/s]

[Saved] ../results/redis_cifar10/cifar_resnet_s311_e50_rd0.npz
[Saved] ../results/redis_cifar10/cifar_resnet_s311_e50_rd1.npz
[Saved] ../results/redis_cifar10/cifar_resnet_s311_e50_rd2.npz


Seed 311:  40%|████      | 4/10 [00:00<00:00, 34.14it/s]

[Saved] ../results/redis_cifar10/cifar_resnet_s311_e50_rd3.npz
[Saved] ../results/redis_cifar10/cifar_resnet_s311_e50_rd4.npz
[Saved] ../results/redis_cifar10/cifar_resnet_s311_e50_rd5.npz
[Saved] ../results/redis_cifar10/cifar_resnet_s311_e50_rd6.npz


                                                        

[Saved] ../results/redis_cifar10/cifar_resnet_s311_e50_rd7.npz
[Saved] ../results/redis_cifar10/cifar_resnet_s311_e50_rd8.npz
[Saved] ../results/redis_cifar10/cifar_resnet_s311_e50_rd9.npz


Seed 637:  40%|████      | 4/10 [00:00<00:00, 34.36it/s]

[Saved] ../results/redis_cifar10/cifar_resnet_s637_e50_rd0.npz
[Saved] ../results/redis_cifar10/cifar_resnet_s637_e50_rd1.npz
[Saved] ../results/redis_cifar10/cifar_resnet_s637_e50_rd2.npz
[Saved] ../results/redis_cifar10/cifar_resnet_s637_e50_rd3.npz
[Saved] ../results/redis_cifar10/cifar_resnet_s637_e50_rd4.npz
[Saved] ../results/redis_cifar10/cifar_resnet_s637_e50_rd5.npz
[Saved] ../results/redis_cifar10/cifar_resnet_s637_e50_rd6.npz


                                                        

[Saved] ../results/redis_cifar10/cifar_resnet_s637_e50_rd7.npz
[Saved] ../results/redis_cifar10/cifar_resnet_s637_e50_rd8.npz
[Saved] ../results/redis_cifar10/cifar_resnet_s637_e50_rd9.npz


Seed 800:   0%|          | 0/10 [00:00<?, ?it/s]

[Saved] ../results/redis_cifar10/cifar_resnet_s800_e50_rd0.npz
[Saved] ../results/redis_cifar10/cifar_resnet_s800_e50_rd1.npz
[Saved] ../results/redis_cifar10/cifar_resnet_s800_e50_rd2.npz


Seed 800:  40%|████      | 4/10 [00:00<00:00, 33.78it/s]

[Saved] ../results/redis_cifar10/cifar_resnet_s800_e50_rd3.npz
[Saved] ../results/redis_cifar10/cifar_resnet_s800_e50_rd4.npz
[Saved] ../results/redis_cifar10/cifar_resnet_s800_e50_rd5.npz
[Saved] ../results/redis_cifar10/cifar_resnet_s800_e50_rd6.npz


Seed 800:  80%|████████  | 8/10 [00:00<00:00, 34.20it/s]

[Saved] ../results/redis_cifar10/cifar_resnet_s800_e50_rd7.npz


                                                        

[Saved] ../results/redis_cifar10/cifar_resnet_s800_e50_rd8.npz
[Saved] ../results/redis_cifar10/cifar_resnet_s800_e50_rd9.npz


Seed 543:   0%|          | 0/10 [00:00<?, ?it/s]

[Saved] ../results/redis_cifar10/cifar_resnet_s543_e50_rd0.npz
[Saved] ../results/redis_cifar10/cifar_resnet_s543_e50_rd1.npz
[Saved] ../results/redis_cifar10/cifar_resnet_s543_e50_rd2.npz


Seed 543:  40%|████      | 4/10 [00:00<00:00, 32.16it/s]

[Saved] ../results/redis_cifar10/cifar_resnet_s543_e50_rd3.npz
[Saved] ../results/redis_cifar10/cifar_resnet_s543_e50_rd4.npz
[Saved] ../results/redis_cifar10/cifar_resnet_s543_e50_rd5.npz
[Saved] ../results/redis_cifar10/cifar_resnet_s543_e50_rd6.npz


                                                        

[Saved] ../results/redis_cifar10/cifar_resnet_s543_e50_rd7.npz
[Saved] ../results/redis_cifar10/cifar_resnet_s543_e50_rd8.npz
[Saved] ../results/redis_cifar10/cifar_resnet_s543_e50_rd9.npz


Seed 969:   0%|          | 0/10 [00:00<?, ?it/s]

[Saved] ../results/redis_cifar10/cifar_resnet_s969_e50_rd0.npz
[Saved] ../results/redis_cifar10/cifar_resnet_s969_e50_rd1.npz
[Saved] ../results/redis_cifar10/cifar_resnet_s969_e50_rd2.npz


Seed 969:  40%|████      | 4/10 [00:00<00:00, 34.24it/s]

[Saved] ../results/redis_cifar10/cifar_resnet_s969_e50_rd3.npz
[Saved] ../results/redis_cifar10/cifar_resnet_s969_e50_rd4.npz
[Saved] ../results/redis_cifar10/cifar_resnet_s969_e50_rd5.npz
[Saved] ../results/redis_cifar10/cifar_resnet_s969_e50_rd6.npz


                                                        

[Saved] ../results/redis_cifar10/cifar_resnet_s969_e50_rd7.npz
[Saved] ../results/redis_cifar10/cifar_resnet_s969_e50_rd8.npz
[Saved] ../results/redis_cifar10/cifar_resnet_s969_e50_rd9.npz


Seed 122:   0%|          | 0/10 [00:00<?, ?it/s]

[Saved] ../results/redis_cifar10/cifar_resnet_s122_e50_rd0.npz
[Saved] ../results/redis_cifar10/cifar_resnet_s122_e50_rd1.npz
[Saved] ../results/redis_cifar10/cifar_resnet_s122_e50_rd2.npz


Seed 122:  40%|████      | 4/10 [00:00<00:00, 33.76it/s]

[Saved] ../results/redis_cifar10/cifar_resnet_s122_e50_rd3.npz
[Saved] ../results/redis_cifar10/cifar_resnet_s122_e50_rd4.npz
[Saved] ../results/redis_cifar10/cifar_resnet_s122_e50_rd5.npz
[Saved] ../results/redis_cifar10/cifar_resnet_s122_e50_rd6.npz


                                                        

[Saved] ../results/redis_cifar10/cifar_resnet_s122_e50_rd7.npz
[Saved] ../results/redis_cifar10/cifar_resnet_s122_e50_rd8.npz
[Saved] ../results/redis_cifar10/cifar_resnet_s122_e50_rd9.npz


Seed 336:   0%|          | 0/10 [00:00<?, ?it/s]

[Saved] ../results/redis_cifar10/cifar_resnet_s336_e50_rd0.npz
[Saved] ../results/redis_cifar10/cifar_resnet_s336_e50_rd1.npz
[Saved] ../results/redis_cifar10/cifar_resnet_s336_e50_rd2.npz


Seed 336:  40%|████      | 4/10 [00:00<00:00, 34.09it/s]

[Saved] ../results/redis_cifar10/cifar_resnet_s336_e50_rd3.npz
[Saved] ../results/redis_cifar10/cifar_resnet_s336_e50_rd4.npz
[Saved] ../results/redis_cifar10/cifar_resnet_s336_e50_rd5.npz
[Saved] ../results/redis_cifar10/cifar_resnet_s336_e50_rd6.npz


                                                        

[Saved] ../results/redis_cifar10/cifar_resnet_s336_e50_rd7.npz
[Saved] ../results/redis_cifar10/cifar_resnet_s336_e50_rd8.npz
[Saved] ../results/redis_cifar10/cifar_resnet_s336_e50_rd9.npz


Seed 93:   0%|          | 0/10 [00:00<?, ?it/s]

[Saved] ../results/redis_cifar10/cifar_resnet_s93_e50_rd0.npz
[Saved] ../results/redis_cifar10/cifar_resnet_s93_e50_rd1.npz
[Saved] ../results/redis_cifar10/cifar_resnet_s93_e50_rd2.npz


Seed 93:  40%|████      | 4/10 [00:00<00:00, 33.80it/s]

[Saved] ../results/redis_cifar10/cifar_resnet_s93_e50_rd3.npz
[Saved] ../results/redis_cifar10/cifar_resnet_s93_e50_rd4.npz
[Saved] ../results/redis_cifar10/cifar_resnet_s93_e50_rd5.npz
[Saved] ../results/redis_cifar10/cifar_resnet_s93_e50_rd6.npz


                                                       

[Saved] ../results/redis_cifar10/cifar_resnet_s93_e50_rd7.npz
[Saved] ../results/redis_cifar10/cifar_resnet_s93_e50_rd8.npz
[Saved] ../results/redis_cifar10/cifar_resnet_s93_e50_rd9.npz

All timers saved to ../analytics/CIFAR10/timer/redistribution_timer.csv


