In [1]:
import pickle
from torchvision import transforms
import sys
import os
sys.path.append(os.path.abspath(".."))
from models import model_dict
from utils import NormalizeByChannelMeanStd
import numpy as np
from torchvision.datasets import CIFAR10, CIFAR100, ImageFolder
from dataset import prepare_train_test_dataset
from torch.utils.data import DataLoader, Dataset, Subset
import torch
import pickle
from itertools import cycle
from utils.evaluation import Hook_handle, analysis, get_micro_eval, get_acc, get_micro_eval_seperate_correct
import pandas as pd
import argparse
import random
import copy
from types import SimpleNamespace
import seaborn as sns
import matplotlib.pyplot as plt
import torch.nn.functional as F
from statsmodels.nonparametric.smoothers_lowess import lowess
from matplotlib.lines import Line2D
from matplotlib.gridspec import GridSpec


# Each (dataset, architecture) pair is one experimental setting. The two lists
# are consumed pairwise via zip(), so they are built from a single table to
# keep them aligned.
dataset, architecture = map(list, zip(*[
    ('cifar10', 'resnet18'),
    ('cifar100', 'resnet50'),
    ('TinyImagenet', 'resnet18'),
    ('TinyImagenet', 'vgg16_bn'),
]))
# Unlearning methods to plot ("neggrad" and "GAGD" are currently commented out).
methods = ["randomlabel", "randomlabel_salun"]



In [None]:
def compute_js_distance(softmax: torch.Tensor, labels: torch.Tensor, num_classes=None) -> torch.Tensor:
    """
    Per-sample Jensen-Shannon distance between predicted probabilities and one-hot labels.

    softmax: Tensor (or array-like) of shape (N, C), output probabilities.
    labels: Tensor (or array-like) of shape (N,), integer class labels in [0, C).
    num_classes: number of classes C. Defaults to ``softmax.shape[1]``. It is
        deliberately NOT inferred from ``labels``: a subset (e.g. a forget set)
        may not contain the highest class index, which would give a one-hot
        matrix narrower than ``softmax`` that silently broadcasts or fails.

    Returns:
        js_distance: Tensor of shape (N,). The divergence uses the natural log,
        so distances lie in [0, sqrt(ln 2)] (about 0.8326).

    Raises:
        ValueError: if ``softmax`` is not 2-D or ``labels`` does not have shape (N,).
    """
    softmax = torch.as_tensor(softmax)
    # Compute in at least float32: eps=1e-8 underflows to 0 in fp16/bf16, which
    # would reintroduce log(0) and yield NaN.
    softmax = softmax.to(torch.promote_types(softmax.dtype, torch.float32))
    # F.one_hot only accepts int64 index tensors; keep labels on softmax's device.
    labels = torch.as_tensor(labels, device=softmax.device).long()

    if softmax.ndim != 2 or labels.shape != softmax.shape[:1]:
        raise ValueError(
            f"expected softmax (N, C) and labels (N,), got {tuple(softmax.shape)} and {tuple(labels.shape)}"
        )
    if num_classes is None:
        num_classes = softmax.shape[1]

    one_hot = F.one_hot(labels, num_classes=num_classes).float()  # shape: (N, num_classes)

    # Clamp to avoid log(0)
    eps = 1e-8
    P = softmax.clamp(min=eps)
    Q = one_hot.clamp(min=eps)

    M = 0.5 * (P + Q)

    # KL divergence: sum over classes
    kl_PM = torch.sum(P * torch.log(P / M), dim=1)
    kl_QM = torch.sum(Q * torch.log(Q / M), dim=1)

    # Jensen-Shannon divergence; clamping and rounding can push nearly identical
    # pairs a hair below zero, which sqrt would turn into NaN.
    js_divergence = (0.5 * (kl_PM + kl_QM)).clamp(min=0.0)
    js_distance = torch.sqrt(js_divergence)

    return js_distance  # shape: (N,)

# Line colours for Retrain / Baseline / Ours; "Original" is always drawn in red.
colors = ['tab:blue', 'tab:orange', 'tab:green']

SOFTMAX_ROOT = "assets/softmaxspace"
FIGURE_DIR = f"{SOFTMAX_ROOT}/figure"
N_UNLEARN_SEEDS = 3  # outer seed `us`: one figure per seed
N_INNER_RUNS = 3     # entries 0..2 of each pickled dict, all overlaid on the same axis


def _load_confidence_change(target, method, data, arch, us):
    """Load the pickled confidence-change dict for one target/method/dataset/arch/seed.

    The "retrain" filename carries no method prefix, so the same retrain file is
    reused for every method.
    NOTE(review): pickle.load can execute arbitrary code; only use on trusted local files.
    """
    prefix = "" if target == "retrain" else f"{method}_"
    path = f"{SOFTMAX_ROOT}/backdata/{target}/{prefix}{data}_{arch}_unlearn_seed_{us}_confidence_change.pickle"
    with open(path, "rb") as f:
        return pickle.load(f)


def _kde_jsd(ax, softmax, labels, linestyle, color, zorder):
    """Overlay a KDE of per-sample JS distances (softmax vs. one-hot labels) on `ax`."""
    distances = compute_js_distance(softmax, labels)
    # Hand seaborn a plain CPU array (works even if the dumped tensors live on GPU).
    sns.kdeplot(distances.detach().cpu().numpy(), ax=ax,
                linestyle=linestyle, bw_adjust=1, clip=(0, None), label=None,
                color=color, linewidth=2, zorder=zorder)


os.makedirs(FIGURE_DIR, exist_ok=True)  # loop-invariant: create the output folder once

for obs in ['retain', 'forget', 'test']:
    for method in methods:
        for arch, data in zip(architecture, dataset):
            for us in range(N_UNLEARN_SEEDS):
                fig2 = plt.figure(figsize=(4, 2))
                fig2.subplots_adjust(left=0.15, right=0.95, bottom=0.15, top=0.9)
                gs2 = GridSpec(1, 1, figure=fig2)
                ax_density = fig2.add_subplot(gs2[0])
                legends = []
                for color, target, b_target in zip(colors, ["retrain", "basic", "sailency"], ["Retrain", "Baseline", "Ours"]):
                    seed_dict = _load_confidence_change(target, method, data, arch, us)

                    if target == "retrain":
                        # The retrain dump also holds the original model's outputs,
                        # so the "Original" curves are drawn from it.
                        legends.append(Line2D([0], [0], linestyle='-', color="tab:red", label="Original", linewidth=2))
                        legends.append(Line2D([0], [0], linestyle=':', color=color, label=b_target, linewidth=2))
                        for i in range(N_INNER_RUNS):
                            run = seed_dict[i]
                            _kde_jsd(ax_density, run['orig_softmax'][obs], run['orig_label'][obs], '-', "tab:red", 1)
                            _kde_jsd(ax_density, run['retrain_softmax'][obs], run['retrain_label'][obs], ':', color, 3)
                    else:
                        legends.append(Line2D([0], [0], linestyle='--', color=color, label=b_target, linewidth=2))
                        for i in range(N_INNER_RUNS):
                            run = seed_dict[i]
                            _kde_jsd(ax_density, run['retrain_softmax'][obs], run['retrain_label'][obs], '--', color, 2)

                # Include the seed so the three per-seed figures are distinguishable.
                print(f"{data.upper()} - {arch.upper()} - {method.upper()} - {obs.upper()} - SEED {us}")

                ax_density.set_ylabel("Density")
                ax_density.set_xlabel("JSD")
                ax_density.grid(True)
                fig2.legend(handles=legends, loc='upper center', ncol=4, bbox_to_anchor=(0.55, 1.1), columnspacing=0.8, handletextpad=0.3)

                fig2.savefig(
                    f"{FIGURE_DIR}/{method}_{data}_{arch}_unlearn_seed_{us}_softmaxspace_{obs}.pdf",
                    dpi=300, bbox_inches='tight', format='pdf'
                )
                plt.show()
                plt.close(fig2)
