In [1]:
import matplotlib.pyplot as plt
import torch
import os
import numpy as np

# Target device for every loaded tensor. Fall back to CPU so the notebook still runs on a
# machine without a GPU (a hard-coded 'cuda:0' makes torch.load fail there).
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

In [None]:
# Base directories for this analysis, with the user's home expanded once here:
#   root       -> where the recorded runs are read from
#   output_dir -> where the rendered histogram images are written
root, output_dir = (os.path.expanduser(path) for path in ('~/sparsity/runs/', '~/sparsity/dumps/'))

In [None]:
def merge_splits(dir_to_merge: str, map_location=None):
    """Load every file in ``dir_to_merge`` with ``torch.load`` and return their element-wise sum.

    Files are read in sorted-name order so the floating-point accumulation order is
    reproducible across runs (``os.listdir`` returns entries in arbitrary order). All loaded
    objects must be addable to each other — presumably same-shaped per-split tensors; TODO confirm.

    Args:
        dir_to_merge: directory whose every entry is a file loadable by ``torch.load``.
        map_location: forwarded to ``torch.load``; when ``None`` the notebook-level ``device``
            is used.

    Returns:
        The sum of all loaded splits.

    Raises:
        FileNotFoundError: if ``dir_to_merge`` does not exist or contains no entries (an empty
            directory would otherwise yield the plain int ``0`` and fail much later).
    """
    if map_location is None:
        map_location = device
    split_names = sorted(os.listdir(dir_to_merge))
    if not split_names:
        raise FileNotFoundError(f"No split files to merge in {dir_to_merge!r}")
    total = 0
    for split_name in split_names:
        split = torch.load(os.path.join(dir_to_merge, split_name), map_location=map_location)
        total = total + split
    return total

In [None]:
def get_histc(dir: str, checkpoint: int):
    """Load the merged ``g`` and ``g_activated`` data of every layer at one checkpoint.

    The layer count is the number of entries in ``<dir>/g``; layers are then addressed as
    ``0 .. count-1``, so layer sub-directories are assumed to be named by consecutive integers.
    For each layer the splits under ``<dir>/g/<layer>/<checkpoint>`` and
    ``<dir>/g_activated/<layer>/<checkpoint>`` are summed with ``merge_splits``.

    Returns:
        ``(gs, gs_activated)``: two lists holding one merged result per layer, in layer order.
    """
    g_root = os.path.join(dir, 'g')
    g_activated_root = os.path.join(dir, 'g_activated')
    checkpoint_name = str(checkpoint)

    gs, gs_activated = [], []
    for layer in range(len(os.listdir(g_root))):
        layer_name = str(layer)
        gs.append(merge_splits(os.path.join(g_root, layer_name, checkpoint_name)))
        gs_activated.append(merge_splits(os.path.join(g_activated_root, layer_name, checkpoint_name)))
    return gs, gs_activated

In [None]:
def get_checkpoints(dir: str):
    """Return the checkpoint ids recorded under ``<dir>/g/0`` as ints, in ascending order.

    Every entry of that directory must have an integer-like name; otherwise ``int()`` raises
    ``ValueError``.
    """
    layer0_dir = os.path.join(dir, 'g', '0')
    checkpoint_ids = list(map(int, os.listdir(layer0_dir)))
    checkpoint_ids.sort()
    return checkpoint_ids

In [None]:
# Notebook-level plotting defaults. Note that histogram() declares its own `n_bin` and `color`
# parameters, which shadow these names inside the function.
n_bin, color = 1000, 'blue'

def histogram(gs, gs_activated, color, n_bin=1000, save_path=None, bin_range=(-10, 5), display_range=(-6, 3)):
    """Draw one stacked subplot per layer comparing the distribution of g with that of g_activated.

    Each element of ``gs`` / ``gs_activated`` is reshaped to ``(n_bin, -1)`` and summed along
    the last axis, i.e. groups of adjacent entries are merged into ``n_bin`` coarse bins laid out
    uniformly over ``bin_range``. The inputs are presumably finer-grained ``torch.histc`` counts
    whose length is a multiple of ``n_bin`` — TODO confirm against the producer.

    Args:
        gs: per-layer tensors for all entries of g (drawn translucent).
        gs_activated: per-layer tensors for the activated entries (drawn opaque); one subplot is
            created per element.
        color: matplotlib colour shared by both bar series.
        n_bin: number of displayed bins.
        save_path: if given, the figure is saved there and then closed, so batch exports do not
            accumulate open figures.
        bin_range: ``(low, high)`` log10 range covered by the binned data.
        display_range: ``(low, high)`` x-axis limits of the plot.
    """
    # squeeze=False keeps `axs` 2-D even for a single layer, so it is always iterable.
    fig, axs = plt.subplots(len(gs_activated), 1, sharex=True, squeeze=False)
    axs = axs[:, 0]
    for ax in axs:
        ax.set_yticks([])
    axs[-1].set_xlabel("Log10 absolute value of entries in g")

    low, high = bin_range
    bin_edges = np.linspace(low, high, n_bin + 1)
    bin_centers = (bin_edges[:-1] + bin_edges[1:]) / 2
    # n_bin equal bins spanning [low, high] are each (high - low) / n_bin wide.
    bin_width = (high - low) / n_bin

    for ax, g, g_activated in zip(axs, gs, gs_activated):
        g = g.reshape(n_bin, -1).sum(dim=-1)
        g_activated = g_activated.reshape(n_bin, -1).sum(dim=-1)
        activated_total = g_activated.sum()
        if activated_total > 0:
            # Rescale so the activated bars hold one third of g's total mass, keeping them
            # visible against the full distribution (skipped when empty to avoid NaNs).
            g_activated = g_activated * (g.sum() / activated_total / 3)
        ax.bar(bin_centers, g.cpu(), width=bin_width, align='center', alpha=0.3, color=color)
        ax.bar(bin_centers, g_activated.cpu(), width=bin_width, align='center', color=color)

    # sharex=True propagates this limit to every subplot.
    axs[-1].set_xlim(*display_range)
    if save_path:
        fig.savefig(save_path)
        plt.close(fig)

# Interactive preview of a single run (latest checkpoint) before the batch export below.
# gs / gs_activated are loaded explicitly so this cell also works on a fresh kernel.
preview_run_dir = os.path.join(root, 'imagenet1k', 'gradient_density', 'vanilla', 'pth')
gs, gs_activated = get_histc(preview_run_dir, get_checkpoints(preview_run_dir)[-1])
histogram(gs, gs_activated, color='blue')

In [None]:
# Batch export: one image per (task, model type, checkpoint), written to
# <output_dir>/<task>/gradient_density/<model_type>/<checkpoint>.jpg
for task in ["imagenet1k", "T5"]:
    for model_type in ["sparsified", "vanilla"]:
        run_dir = os.path.join(root, task, 'gradient_density', model_type, "pth")
        checkpoints = get_checkpoints(run_dir)
        # One output directory per (task, model_type); each checkpoint becomes a file inside it.
        output_directory = os.path.join(output_dir, task, 'gradient_density', model_type)
        os.makedirs(output_directory, exist_ok=True)
        for checkpoint in checkpoints:
            gs, gs_activated = get_histc(run_dir, checkpoint)
            save_path = os.path.join(output_directory, str(checkpoint) + '.jpg')
            histogram(gs, gs_activated, 'red' if model_type == 'sparsified' else 'blue', save_path=save_path)