In [None]:
import pickle
import numpy as np
import matplotlib.pyplot as plt
from utils import COCOVis
import torch
def plot_influential_images(query_path, indices, title_prefix, results, temperature, coco_vis):
    """
    Plot a query image followed by a row of training images with their attribution scores.

    Each training-image panel is titled with its raw influence value and its
    temperature-scaled softmax score. The softmax is taken over *all* entries of
    ``results["influence"]``, not only over the plotted ones.

    Args:
        query_path (str): Path to the query image file (read with ``plt.imread``).
        indices (sequence of int): Dataset indices of the training images to plot.
        title_prefix (str): Heading printed (as ``==title_prefix==``) before the figure is shown.
        results (dict): Attribution results; ``results["influence"]`` must be a NumPy
            array indexable by the entries of ``indices``.
        temperature (float): Softmax temperature; the influence values are divided by it
            before the softmax, so smaller values give a sharper distribution.
        coco_vis (object): Dataset wrapper; ``coco_vis[idx][0]`` must be an image that
            ``np.asarray`` can convert.
    """
    influence = results["influence"]
    # The softmax normalises over the whole influence array, so it does not depend
    # on which image is being plotted: compute it once instead of once per panel.
    scores = torch.from_numpy(influence / temperature).softmax(0)

    n_panels = len(indices) + 1
    plt.figure(figsize=(20, 2))
    plt.subplot(1, n_panels, 1)
    plt.imshow(plt.imread(query_path))
    plt.title('Query')
    plt.axis('off')
    plt.subplots_adjust(wspace=0.01)

    # Panel 1 is the query; training images occupy panels 2..n_panels.
    for panel, idx in enumerate(indices, start=2):
        plt.subplot(1, n_panels, panel)
        plt.imshow(np.asarray(coco_vis[idx][0]))
        plt.title(f'infl: {influence[idx]:.2e}\n{scores[idx].item():.2e}')
        plt.axis('off')
    print(f"=={title_prefix}==")
    plt.show()


In [2]:
# NOTE(review): pretrain_loss is not used by any cell visible here; keep only if needed later.
pretrain_loss = np.load('results/pretrain_loss.npy')

# Directory holding the per-epoch Fisher attribution pickles -- adjust for your machine.
FISHER_DIR = '/tss/kinwai/AttributeByUnlearning/mscoco/fisher_reproduce'
# Pickles are numbered fisher_1 ... fisher_N_FISHER_EPOCHS.
N_FISHER_EPOCHS = 5

# NOTE: pickle.load can execute arbitrary code; only load files you produced yourself.
attribution_list = []
for i in range(1, N_FISHER_EPOCHS + 1):
    filename = f'{FISHER_DIR}/fisher_{i}_influence_0_ulearnsteps_1_accum625_cross-attn-kv.pkl'
    with open(filename, 'rb') as f:
        attribution_list.append(pickle.load(f))

# Later cells index results["influence"] and results["rank"]; fail early if either is missing.
assert all('influence' in r and 'rank' in r for r in attribution_list), \
    "each attribution pickle must provide 'influence' and 'rank'"

# Index of the generated query sample; presumably matches the "influence_0" part of
# the pickle file names -- confirm before changing.
sample_idx = 0
dataroot = 'data/mscoco'
# Softmax temperature used when turning influence values into scores for display.
temperature = 1e-4

In [None]:
# Visualize the top, middle, and bottom 10 ranked training images for each epoch.
coco_vis = COCOVis(path=dataroot, split='train')
# The query image is the same for every epoch, so build its path once.
query_img_path = f'{dataroot}/sample/{sample_idx}.png'
# Number epochs from 1 so they line up with the fisher_{i} pickle names (i = 1..5)
# and with the "Epoch {frame + 1}" labels used by the histogram animation.
for epoch, results in enumerate(attribution_list, start=1):
    print(f'****************Epoch: {epoch}****************')
    rank = results['rank']
    mid = rank.shape[0] // 2
    top_10_indices = rank[:10]
    middle_10_indices = rank[mid:mid + 10]
    bottom_10_indices = rank[-10:]
    for label, indices in (
        ("Top 10 Influential Images", top_10_indices),
        ("Middle 10 Influential Images", middle_10_indices),
        ("Bottom 10 Influential Images", bottom_10_indices),
    ):
        plot_influential_images(query_img_path, indices, label, results, temperature, coco_vis)


In [None]:
# Plot one wide, short density histogram of the attribution values per epoch.
n_epochs = len(attribution_list)
for epoch, results in enumerate(attribution_list, start=1):
    fig, ax = plt.subplots(figsize=(16, 3))
    ax.hist(
        results["influence"],
        bins=1000,
        density=True,
        alpha=0.7,
        color='blue',
    )
    # Include the epoch: otherwise all figures share the same title and cannot be told apart.
    ax.set_title(f"Distribution of attribution (Epoch {epoch}/{n_epochs})")
    # Larger axis labels and tick labels for readability.
    ax.set_xlabel("Value", fontsize=14)
    ax.set_ylabel("Density", fontsize=14)
    ax.tick_params(axis='both', which='major', labelsize=14)
    plt.show()

In [None]:
# Load the fisher_5 reproduction and the original ablation-run attributions for comparison.
# NOTE: pickle.load can execute arbitrary code; only load files you trust.
filename = '/tss/kinwai/AttributeByUnlearning/mscoco/fisher_reproduce/fisher_5_influence_0_ulearnsteps_1_accum625_cross-attn-kv.pkl'
with open(filename, 'rb') as f:
    reproduce = pickle.load(f)

filename = '/tss/kinwai/AttributeByUnlearning/mscoco/results_ablation/influence_0_ulearnsteps_1_cross-attn-kv.pkl'
with open(filename, 'rb') as f:
    original = pickle.load(f)


In [None]:
# Wide, short density histogram of the original run's attribution values.
fig, ax = plt.subplots(figsize=(16, 3))
ax.hist(
    original["influence"],
    bins=1000,
    density=True,
    alpha=0.7,
    color='blue',
    # Without a label, ax.legend() below has no entries and matplotlib warns.
    label='original',
)

ax.set_title("Distribution of attribution")
# Larger axis labels and tick labels for readability.
ax.set_xlabel("Value", fontsize=14)
ax.set_ylabel("Density", fontsize=14)
ax.tick_params(axis='both', which='major', labelsize=14)
ax.legend(loc='upper left')
# Fixed limits so this plot is directly comparable with the reproduced-run histogram.
ax.set_xlim(-0.0001, 4e-3)
ax.set_ylim(0, 850)
ax.grid(True)

In [None]:
# Wide, short density histogram of the reproduced run's attribution values.
fig, ax = plt.subplots(figsize=(16, 3))
ax.hist(
    reproduce["influence"],
    bins=1000,
    density=True,
    alpha=0.7,
    color='red',
    # Without a label, ax.legend() below has no entries and matplotlib warns.
    label='reproduced',
)

ax.set_title("Distribution of attribution")
# Larger axis labels and tick labels for readability.
ax.set_xlabel("Value", fontsize=14)
ax.set_ylabel("Density", fontsize=14)
ax.tick_params(axis='both', which='major', labelsize=14)
ax.legend(loc='upper left')
# Fixed limits so this plot is directly comparable with the original-run histogram.
ax.set_xlim(-0.0001, 4e-3)
ax.set_ylim(0, 850)
ax.grid(True)

In [None]:
# Wide, short density histograms of both runs drawn together in one axes.
fig, ax = plt.subplots(figsize=(16, 3))
ax.hist(
    [original["influence"], reproduce["influence"]],
    bins=1000,
    density=True,
    alpha=0.7,
    color=['blue', 'red'],
    label=['original', 'reproduced'],
)

ax.set_title("Distribution of attribution: original vs reproduced")
# Larger axis labels and tick labels for readability.
ax.set_xlabel("Value", fontsize=14)
ax.set_ylabel("Density", fontsize=14)
ax.tick_params(axis='both', which='major', labelsize=14)
ax.legend(loc='upper left')
ax.set_xlim(-0.0001, 4e-3)
# Grid for consistency with the single-run histograms.
ax.grid(True)


In [6]:
import matplotlib.pyplot as plt
import matplotlib.animation as animation

def animate_fisher_histograms(fisher_infos, bins=1000, interval=500, title="Fisher Info Distribution",
                              save_path="fisher_info_histogram_animation.gif",
                              xlim=(-0.0001, 4e-3), ylim=(0, 850)):
    """
    Animate one density histogram per entry of ``fisher_infos`` and save it as a GIF.

    Args:
        fisher_infos (list of array-like): One 1-D array of values per frame
            (in this notebook, ``results["influence"]`` for each epoch).
        bins (int): Number of bins for each histogram.
        interval (int): Delay between frames in milliseconds.
        title (str): Title prefix; " (Epoch k/N)" with a 1-based k is appended per frame.
        save_path (str): Output file, written with the Pillow writer.
        xlim (tuple): x-axis limits shared by every frame.
        ylim (tuple): y-axis limits shared by every frame.

    Returns:
        matplotlib.animation.FuncAnimation: The animation object, e.g. for inline display.

    Raises:
        ValueError: If ``fisher_infos`` is empty.
    """
    n_frames = len(fisher_infos)
    if n_frames == 0:
        raise ValueError("fisher_infos must contain at least one array")

    fig, ax = plt.subplots(figsize=(12, 4))

    def update(frame):
        # Redraw from scratch each frame; fixed limits keep frames visually comparable.
        ax.clear()
        ax.hist(fisher_infos[frame], bins=bins, density=True, alpha=0.7, color='blue')
        ax.set_title(f"{title} (Epoch {frame + 1}/{n_frames})")
        ax.set_xlabel("Value", fontsize=14)
        ax.set_ylabel("Density", fontsize=14)
        ax.set_xlim(*xlim)
        ax.set_ylim(*ylim)
        ax.tick_params(axis='both', which='major', labelsize=14)
        ax.grid(True)

    ani = animation.FuncAnimation(fig, update, frames=n_frames, interval=interval)
    ani.save(save_path, writer="pillow")
    # Close the figure so the notebook does not also render a stray static copy of it.
    plt.close(fig)
    return ani


In [None]:
# Gather each epoch's influence array and render them as one animated histogram.
fisher_infos = [attribution["influence"] for attribution in attribution_list]
animate_fisher_histograms(fisher_infos)
