In [19]:
import numpy as np
import matplotlib.pyplot as plt
import plotly.express as px
import pandas as pd
# Column schema of the long-format results table `df`; the per-class density
# loop in a later cell supplies the rows.
df = pd.DataFrame(
    columns=[
        'class', 'grids', 'density', 'name',
        'l_max', 'l_min', 'l_ratio', 'val_acc',
    ]
)

In [20]:
# Experiment configuration. Each combination of the lists below selects one
# checkpoint directory whose per-class Hessian spectra are loaded and plotted.
dataset = 'cifar10'  # cifar10 or cifar100
loss_types = ['CE']
train_rules = ['DRW']  # paired element-wise with loss_types (the loading loop zips them)
sams = [False, True]  # False -> plain checkpoint directory, True -> the `sam_0.8` one

# Class ids whose spectra are analysed.
# NOTE(review): -1 is presumably the all-classes aggregate (its recorded
# accuracy equals the mean of the per-class accuracies) -- confirm.
if dataset == 'cifar10':
    classes = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, -1]
elif dataset == 'cifar100':
    classes = [0, 49, 99, -1]
else:
    # Fail loudly: without this branch a typo in `dataset` would leave
    # `classes` undefined, or silently reuse a stale value from the kernel.
    raise ValueError(
        f"unsupported dataset {dataset!r}; expected 'cifar10' or 'cifar100'"
    )
dataloader_hesss = ['val']

In [21]:
def gaussian(x, x0, sigma_squared):
    """Evaluate a normal probability density.

    Args:
        x: Scalar or NumPy array of evaluation points.
        x0: Centre of the bell curve; scalar or array that broadcasts with ``x``.
        sigma_squared: Variance of the curve (not the standard deviation).

    Returns:
        ``exp(-(x0 - x)**2 / (2 * sigma_squared)) / sqrt(2 * pi * sigma_squared)``,
        with the broadcast shape of ``x`` and ``x0``.
    """
    offset = x0 - x
    exponent = -offset**2 / (2.0 * sigma_squared)
    norm = np.sqrt(2 * np.pi * sigma_squared)
    return np.exp(exponent) / norm

def density_generate(eigenvalues,
                     weights,
                     num_bins=10000,
                     sigma_squared=1e-5,
                     overhead=0.01):
    """Turn eigenvalue/weight samples into a smoothed, normalised density curve.

    Every eigenvalue contributes a Gaussian bump, scaled by its weight, to a
    regular grid. The per-run curves are averaged and the result is rescaled
    so that ``sum(density) * grid_spacing == 1``.

    Args:
        eigenvalues: Array-like of shape ``(num_runs, k)``. A 1-D sequence of
            length ``k`` is treated as a single run.
        weights: Array-like with the same layout as ``eigenvalues``; entry
            ``[i, j]`` scales the bump of ``eigenvalues[i, j]``.
        num_bins: Number of grid points; must be at least 2.
        sigma_squared: Base variance of the smoothing Gaussian. It is
            multiplied by ``max(1, grid span)`` before use.
        overhead: Margin added on each side of the run-averaged largest and
            smallest eigenvalue when laying out the grid.

    Returns:
        Tuple ``(density, grids)`` of two 1-D arrays of length ``num_bins``.

    Raises:
        ValueError: If ``num_bins`` is smaller than 2 (the grid spacing used
            for normalisation would be undefined).
    """
    if num_bins < 2:
        raise ValueError(f"num_bins must be at least 2, got {num_bins}")

    # Promote a single 1-D run to shape (1, k) so the axis=1 reductions work.
    eigenvalues = np.atleast_2d(np.asarray(eigenvalues))
    weights = np.atleast_2d(np.asarray(weights))

    lambda_max = np.mean(np.max(eigenvalues, axis=1), axis=0) + overhead
    lambda_min = np.mean(np.min(eigenvalues, axis=1), axis=0) - overhead

    grids = np.linspace(lambda_min, lambda_max, num=num_bins)
    # This value is handed to `gaussian` as its variance argument, so it is
    # named accordingly; it widens with the grid span once that exceeds 1.
    variance = sigma_squared * max(1, (lambda_max - lambda_min))

    num_runs = eigenvalues.shape[0]
    density_output = np.zeros((num_runs, num_bins))

    # Column vector of grid points: broadcasting against one row of
    # eigenvalues gives a (num_bins, k) kernel matrix per run, replacing the
    # Python-level loop over bins. Runs are still handled one at a time to
    # keep the temporary bounded at num_bins * k values.
    grid_column = grids[:, np.newaxis]
    for i in range(num_runs):
        kernel = gaussian(eigenvalues[i][np.newaxis, :], grid_column, variance)
        density_output[i] = np.sum(kernel * weights[i], axis=1)

    density = np.mean(density_output, axis=0)
    # Rectangle-rule normalisation on the uniform grid.
    normalization = np.sum(density) * (grids[1] - grids[0])
    density = density / normalization
    return density, grids

In [22]:
# Build the long-format table `df`: for every (loss, rule) pair, SAM setting
# and Hessian dataloader, load the per-class eigenvalue/weight dumps, smooth
# them into a density curve and attach summary statistics to every row.
#
# The per-class frames are gathered in a list and concatenated once at the
# end. Re-running this cell therefore rebuilds `df` from scratch instead of
# appending duplicate rows, and the table is not re-copied on every iteration.
density_frames = []
for loss_type, train_rule in zip(loss_types, train_rules):
    for sam in sams:
        for dataloader_hess in dataloader_hesss:
            # SAM checkpoints live in a directory with an extra
            # `sam_0.8_sched_none` tag in its name.
            if sam:
                path = f'checkpoint/hessian_{dataset}_resnet32_{loss_type}_{train_rule}_exp_0.01_sam_0.8_sched_none_seed_None_0_{dataloader_hess}_{loss_type}_None_sample/'
            else:
                path = f'checkpoint/hessian_{dataset}_resnet32_{loss_type}_{train_rule}_exp_0.01_seed_None_0_{dataloader_hess}_{loss_type}_None_sample/'
            val_accs = pd.read_csv(path+'accuracies.csv')
            for cls in classes:
                eigenvalues = np.load(path+f'{cls}_density_eigen.npy')
                weights = np.load(path+f'{cls}_density_weights.npy')
                print(cls,eigenvalues.shape)
                density, grids = density_generate(eigenvalues, weights)
                # Small positive floor; presumably keeps zero-density bins
                # representable on the log-scaled y axis used by the plot.
                density = density + 1e-7
                lambda_min = np.min(eigenvalues)
                lambda_max = np.max(eigenvalues)
                lambda_ratio = lambda_max/lambda_min
                # Accuracy row for this class id in accuracies.csv.
                val_acc = val_accs.loc[val_accs['Class'] == cls, 'Accuracy'].values[0]
                print(cls,val_acc)
                # `<br>` separators make the label wrap in the HTML figure.
                name = f'dataset:{dataset}<br>loss_type:{loss_type}<br>train_rule:{train_rule}<br>sam:{sam}<br>dataloader_hess:{dataloader_hess}<br>'
                # Scalars are broadcast along the grid, giving one row per grid point.
                density_frames.append(pd.DataFrame({
                    'class': cls,
                    'grids': grids,
                    'density': density,
                    'name': name,
                    'l_max': lambda_max,
                    'l_min': lambda_min,
                    'l_ratio': lambda_ratio,
                    'val_acc': val_acc,
                    'text': cls,
                }))

# sort by class and x values
df = pd.concat(density_frames, ignore_index=True).sort_values(by=['class','grids'])

0 (1, 100)
0 94.9
1 (1, 100)
1 97.8
2 (1, 100)
2 81.3
3 (1, 100)
3 74.700005
4 (1, 100)
4 80.3
5 (1, 100)
5 68.9
6 (1, 100)
6 78.4
7 (1, 100)
7 65.600006
8 (1, 100)
8 56.9
9 (1, 100)
9 62.000004
-1 (1, 100)
-1 76.08
0 (1, 100)
0 87.4
1 (1, 100)
1 93.3
2 (1, 100)
2 70.4
3 (1, 100)
3 60.800003
4 (1, 100)
4 79.0
5 (1, 100)
5 68.5
6 (1, 100)
6 82.100006
7 (1, 100)
7 65.700005
8 (1, 100)
8 78.100006
9 (1, 100)
9 75.9
-1 (1, 100)
-1 76.119995


In [23]:
# Line plot of the densities: one facet column per class, one facet row (and
# colour) per run label, logarithmic density axis. Saved as standalone HTML.
fig = px.line(
    df,
    x="grids",
    y="density",
    color='name',
    facet_row='name',
    facet_col='class',
    hover_data=['l_max','l_min','l_ratio', 'val_acc'],
    log_y=True,
)
# Left disabled by the author:
# fig.update_xaxes(matches=None)
# fig.update_yaxes(matches=None)
fig.write_html(f'checkpoint/{dataset}.html')
# fig.show()