In [None]:
import pickle
from torchvision import transforms
import sys
import os
sys.path.append(os.path.abspath(".."))
from models import model_dict
from utils import NormalizeByChannelMeanStd
import numpy as np
from torchvision.datasets import CIFAR10, CIFAR100, ImageFolder
from dataset import prepare_train_test_dataset
from torch.utils.data import DataLoader, Dataset, Subset
import torch
import pickle
from itertools import cycle
from utils.evaluation import Hook_handle, analysis, get_micro_eval, get_acc, get_micro_eval_seperate_correct
import pandas as pd
import argparse
import random
import copy
from types import SimpleNamespace
import seaborn as sns
import matplotlib.pyplot as plt

# Evaluation grid: dataset[k] is paired with architecture[k] (the lists are consumed
# pairwise via zip() in later cells), so TinyImagenet appears twice -- once for
# resnet18 and once for vgg16_bn. Keep the two lists aligned when editing.
dataset = ['cifar10', 'cifar100'] + ['TinyImagenet'] * 2
architecture = ['resnet18', 'resnet50'] + ['resnet18', 'vgg16_bn']




In [None]:

# Load cached confidence-change results, keyed by (dataset, architecture, unlearn seed).
# NOTE(review): pickle.load can execute arbitrary code -- only load files produced locally.
backdata_ret = {}
for arch, data in zip(architecture, dataset):
    for us in range(3):
        pickle_path = (
            "assets/figures/conf_change/backdata/"
            f"{data}_{arch}_unlearn_seed_{us}_confidence_change.pickle"
        )
        with open(pickle_path, "rb") as f:
            seed_dict = pickle.load(f)
        backdata_ret[(data, arch, us)] = seed_dict

In [None]:
from statsmodels.nonparametric.smoothers_lowess import lowess
from matplotlib.lines import Line2D
from matplotlib.gridspec import GridSpec

colors = ['tab:blue', 'tab:orange', 'tab:green']
colors_line = ['darkblue', 'darkorange', 'darkgreen']

# Number of unlearning runs per (dataset, architecture) pair; must match the loader cell.
N_UNLEARN_SEEDS = 3
FIG_DIR = "assets/figures/conf_change"
X_LABEL = "Sample Index (Original Confidence Ascending Order)"


def _new_single_axes_figure():
    """Create a 4x2-inch figure with a single axes and the margins shared by both plots.

    Returns:
        (fig, ax) tuple.
    """
    fig = plt.figure(figsize=(4, 2))
    fig.subplots_adjust(left=0.15, right=0.95, bottom=0.15, top=0.9)
    ax = fig.add_subplot(GridSpec(1, 1, figure=fig)[0])
    return fig, ax


def _save_and_close(fig, out_path):
    """Save ``fig`` as a PDF at ``out_path``, show it inline, then release it."""
    fig.savefig(out_path, dpi=300, bbox_inches='tight', format='pdf')
    plt.show()
    # Closing keeps memory bounded while this loop produces many figures.
    plt.close(fig)


def _plot_confidence_diff(seed_dict, out_path):
    """Plot |retrain - original| forget-set confidence per seed, sorted by original confidence.

    Args:
        seed_dict: results indexable by seed number 0..len(colors)-1. Each entry has
            'orig_confidence' and 'retrain_confidence' dicts whose "forget" values expose
            ``.numpy()`` (presumably torch tensors -- TODO confirm).
        out_path: destination PDF path.
    """
    fig, ax = _new_single_axes_figure()
    legend_elements = []

    for i, (color, color_line) in enumerate(zip(colors, colors_line)):
        orig = seed_dict[i]['orig_confidence']["forget"].numpy()
        retrain = seed_dict[i]['retrain_confidence']["forget"].numpy()

        # Sort samples by original confidence so the x-axis is comparable across seeds.
        diff = np.abs(retrain - orig)[np.argsort(orig)]
        ax.plot(diff, 'o', label=f'Seed {i}', alpha=0.15, markersize=4, color=color)

        # LOWESS trend line over the sorted scatter.
        smoothed = lowess(diff, np.arange(len(diff)), frac=0.2)
        ax.plot(smoothed[:, 0], smoothed[:, 1], linewidth=2, color=color_line, linestyle='--')

        legend_elements += [
            Line2D([0], [0], marker='o', color='none', label=f'Seed {i}',
                   markerfacecolor=color, markersize=6, markeredgewidth=0),
            Line2D([0], [0], linestyle='--', color=color_line, label=f'Seed {i} (trend)', linewidth=2),
        ]

    ax.set_ylabel("|Retrain - Original|")
    ax.set_xlabel(X_LABEL)
    ax.grid(True)

    # One legend column per seed: each column holds that seed's marker and trend entries.
    fig.legend(handles=legend_elements, loc='upper center', ncol=len(colors),
               bbox_to_anchor=(0.5, 1.23), columnspacing=0.8, handletextpad=0.4)
    _save_and_close(fig, out_path)


def _plot_kde(ax, positions, color, linestyle, clip):
    """Draw a KDE of sample positions on ``ax``; skip it when there are fewer than two points.

    A density cannot be estimated from 0 or 1 points (seaborn warns and skips them),
    so the guard only suppresses that noise.
    """
    if positions.size > 1:
        sns.kdeplot(positions, ax=ax, label=None, color=color,
                    linestyle=linestyle, bw_adjust=1.5, clip=clip)


def _plot_correctness_kde(seed_dict, out_path):
    """Plot retrained-model correctness on the forget set, sorted by original confidence.

    The twin (right) axis shows per-sample correctness as a 0/1 scatter (0 = wrong,
    1 = correct). The left axis shows KDEs of where correct (solid) and wrong (dashed)
    samples fall along the sorted index.

    Args:
        seed_dict: results indexable by seed number 0..len(colors)-1. Each entry has
            'orig_confidence' and 'retrain_correct' dicts with a "forget" value exposing
            ``.numpy()``.
        out_path: destination PDF path.
    """
    fig, ax_density = _new_single_axes_figure()
    ax_correct = ax_density.twinx()

    legend_scatter, legend_correct, legend_wrong = [], [], []

    for i, color in enumerate(colors):
        orig = seed_dict[i]['orig_confidence']["forget"].numpy()
        corr = seed_dict[i]['retrain_correct']["forget"].numpy()
        corr_sorted = corr[np.argsort(orig)]

        ax_correct.plot(corr_sorted, 'o', label=f'Seed {i}', alpha=0.7, markersize=4, color=color)
        legend_scatter.append(
            Line2D([0], [0], marker='o', color='none', label=f'Seed {i}',
                   markerfacecolor=color, markersize=6, markeredgewidth=0)
        )

        indices = np.arange(len(corr_sorted))
        # Keep the densities within the valid index range.
        clip = (0, len(corr_sorted) - 1)

        _plot_kde(ax_density, indices[corr_sorted == 1], color, '-', clip)
        legend_correct.append(
            Line2D([0], [0], linestyle='-', color=color, label=f'Seed {i} (correct)', linewidth=2)
        )

        _plot_kde(ax_density, indices[corr_sorted == 0], color, '--', clip)
        legend_wrong.append(
            Line2D([0], [0], linestyle='--', color=color, label=f'Seed {i} (wrong)', linewidth=2)
        )

    ax_density.set_ylabel("Density")
    ax_density.set_xlabel(X_LABEL)
    ax_correct.set_yticks([0, 1])
    ax_correct.set_yticklabels(['Wrong', 'Correct'])
    ax_density.grid(True)

    # Three legend groups (scatter / correct KDE / wrong KDE), one column each.
    fig.legend(handles=legend_scatter + legend_correct + legend_wrong, loc='upper center',
               ncol=3, bbox_to_anchor=(0.55, 1.33), columnspacing=1.0, handletextpad=0.4)
    _save_and_close(fig, out_path)


for arch, data in zip(architecture, dataset):
    for us in range(N_UNLEARN_SEEDS):
        seed_dict = backdata_ret[(data, arch, us)]
        out_stem = f"{FIG_DIR}/{data}_{arch}_unlearn_seed_{us}"

        _plot_confidence_diff(seed_dict, f"{out_stem}_confidence_diff.pdf")
        _plot_correctness_kde(seed_dict, f"{out_stem}_correctness_kde.pdf")

        print("done")
