In [None]:
%load_ext autoreload
%autoreload 2
import sys
sys.path.append("../../project")
from collections import defaultdict
import matplotlib.pyplot as plt
import numpy as np
import wandb

# W&B public-API client plus the path prefix that runs_from_sweeps prepends
# to every sweep id (presumably "<entity>/<project>/" — confirm against W&B).
api = wandb.Api()
base = "mxmn/concat_moons/"

In [None]:
def runs_from_sweeps(sweep_ids):
    """Return a flat list of every W&B run belonging to the given sweeps.

    Each id is resolved as ``base + sweep_id`` through the module-level
    ``api`` client. Runs are ordered by sweep first, then in the order the
    API yields them.
    """
    return [
        run
        for sweep_id in sweep_ids
        for run in api.sweep(base + sweep_id).runs
    ]

def to_numpy(d: dict):
    """Stack the values of ``d`` (in insertion order) into one NumPy array."""
    values = [value for value in d.values()]
    return np.array(values)

def mean_and_std(d: dict):
    """Row-wise mean and standard deviation of the stacked values of ``d``.

    The values are stacked with ``to_numpy`` and reduced along axis 1, so each
    value must be an equal-length sequence (one row per key).

    Returns:
        (means, stds): two 1-D arrays with one entry per key of ``d``.
    """
    matrix = to_numpy(d)
    row_mean = matrix.mean(axis=1)
    row_std = matrix.std(axis=1)
    return row_mean, row_std

def to_list(d: dict):
    """Split a mapping into two parallel lists ordered by ascending key.

    Returns:
        (keys, values): ``keys`` sorted ascending and ``values[i] == d[keys[i]]``.
    """
    ordered_keys = sorted(d)
    return ordered_keys, [d[k] for k in ordered_keys]

def bar(ax, x, y, title, color, ylabel=None):
    """Draw one bar panel of an outcome histogram on ``ax``.

    Args:
        ax: matplotlib Axes to draw on.
        x: bar positions; they must line up with the *sorted* keys of ``y``.
        y: mapping category -> bar height; heights are taken in sorted-key order.
        title: bold panel title, placed to the right of the axes (x=1.1, y=0.5).
        color: bar face colour.
        ylabel: optional y-axis label; the label is left untouched when None.
    """
    _, heights = to_list(y)
    # grid() documents a bool visibility flag; pass one rather than the string 'on'.
    ax.grid(True, color='#CCCCCC', linestyle=':')
    ax.set_title(title, pad=0, fontweight="bold", x=1.1, y=0.5)
    if ylabel is not None:
        ax.set_ylabel(ylabel)
    ax.bar(x, heights, color=color, alpha=1, linewidth=1, edgecolor='black')
    # Visually hide the right/top spines by painting them white.
    ax.spines['right'].set_color((1., 1., 1.))
    ax.spines['top'].set_color((1.0, 1.0, 1.0))

In [None]:
data_1 = None  # reset data_1: discard the cached experiment-1 results read by the experiment-1 cell

In [None]:
data_2 = None  # reset data_2: discard the cached experiment-2 results read by the experiment-2 cell

In [None]:
data_3 = None  # reset data_3: discard the cached experiment-3 results so the experiment-3 cell re-fetches them

In [None]:
# Experiment 1: a fixed architecture pruned to the same target using
# different numbers of pruning levels (details in `description`).
description = '''
Histograms of models with same architecture (4, 410, 410, 2).
but different number of pruning levels to reach the same pruning target of (4, 8, 8, 2).
'''

# Fetch every run of the experiment-1 sweep through the W&B API (network call).
runs = runs_from_sweeps(['ix4onq8c'])

def get_data_experiment_1(runs):
    """Count split/degrade outcomes per number of pruning levels.

    A run "split" if its 'untapped-potential' history ever equals 0 and
    "degraded" if it ever drops below 0. Every run adds 1 to exactly one of
    the categories 's' (split only), 'sd' (split and degraded), 'd'
    (degraded only) or 'no' (neither), and 0 to the other three, so every
    level is present in all four category mappings.

    Args:
        runs: iterable of W&B runs with ``config['pruning_levels']``,
            ``config['pruning_rate']`` and a ``history(keys=...)`` method.

    Returns:
        (data, x_axis): ``data[category][level]`` holds counts,
        ``data['pr'][level]`` the pruning rate of the last run seen for that
        level; ``x_axis`` is the set of pruning-level values.
    """
    levels_seen = set()
    data = defaultdict(lambda: defaultdict(int))
    for run in runs:
        level = run.config['pruning_levels']
        levels_seen.add(level)

        potential = run.history(keys=['untapped-potential'])['untapped-potential']
        degraded = (potential < 0).any()
        split = (potential == 0).any()

        if split and not degraded:
            outcome = 's'
        elif split:
            outcome = 'sd'
        elif degraded:
            outcome = 'd'
        else:
            outcome = 'no'

        for category in ('s', 'sd', 'd', 'no'):
            data[category][level] += int(category == outcome)
        data['pr'][level] = run.config['pruning_rate']

    return data, levels_seen

def plot_experiment_1(data, x_axis, save=False):
    """Bar-plot experiment-1 outcome counts per number of pruning levels.

    Draws four stacked panels (Split, Split & Degrade, Degrade, None) that
    share both axes. Each x tick shows the pruning-level count and, below it,
    the recorded pruning rate as a percentage rounded to the nearest int.

    Args:
        data: mapping as built by get_data_experiment_1 — 's', 'sd', 'd' and
            'no' map level -> network count; 'pr' maps level -> pruning rate.
        x_axis: set of pruning-level counts to show (sorted ascending here).
        save: if True, also write the figure to 'exp1.svg'.
    """
    pruning_levels = sorted(x_axis)
    # Round, don't truncate: int(0.29 * 100) is 28 because the product is
    # 28.999999999999996, which contradicted the "round to int" axis label.
    pruning_rates = [f'{int(round(data["pr"][i] * 100))}%' for i in pruning_levels]

    xticks = list(range(len(pruning_levels)))
    xticklabels = [f'{level}\n{rate}' for level, rate in zip(pruning_levels, pruning_rates)]

    fig, axs = plt.subplots(4, 1, figsize=(10, 12), sharex=True, sharey=True)
    plt.setp(axs, xticks=xticks, xticklabels=xticklabels, yticks=[0, 1, 2, 3, 4])

    # (panel title, data key, bar colour), top to bottom.
    panels = [
        ('Split', 's', '#84DCC6'),
        ('Split &\n Degrade', 'sd', '#A5FFD6'),
        ('Degrade', 'd', '#FF686B'),
        ('None', 'no', '#FFDEC0'),
    ]
    for ax, (title, key, color) in zip(axs, panels):
        bar(ax=ax, x=xticks, y=data[key], title=title, color=color)

    fig.text(0.04, 0.5, 'num networks', va='center', rotation='vertical', fontdict={'size': 12})
    fig.text(0.5, 0.04, 'pruning levels \n pruning rate (round to int)', ha='center', fontdict={'size': 12})

    # Save before plt.show(), which may close the figure depending on the backend.
    if save:
        fig.savefig("exp1.svg", format='svg', dpi=300)
    plt.show()

# Fetch experiment-1 data from W&B only when nothing is cached (fresh kernel
# or after the reset cell set data_1 = None); otherwise reuse the cached result.
if data_1 is None:
    data_1, x_axis_1 = get_data_experiment_1(runs)

plot_experiment_1(data_1, x_axis_1, save=False)

In [None]:
# Experiment 2: fetch every run of both sweeps through the W&B API (network call).
runs = runs_from_sweeps(sweep_ids=['p30nbq46', 'r0b16unk'])

def get_data_experiment_2(runs):
    """Count split/degrade outcomes per hidden-layer width.

    The width is ``config['model_shape'][1]``. A run "split" if its
    'untapped-potential' history ever equals 0 and "degraded" if it ever
    drops below 0; each run adds 1 to exactly one of 's', 'sd', 'd', 'no'
    and 0 to the others, so every width appears in all four mappings.

    Args:
        runs: iterable of W&B runs with ``config['model_shape']``,
            ``config['pruning_rate']`` and a ``history(keys=...)`` method.

    Returns:
        (data, x_axis): ``data[category][width]`` holds counts,
        ``data['pr'][width]`` the pruning rate of the last run for that
        width; ``x_axis`` is the set of widths.
    """
    widths = set()
    counts = defaultdict(lambda: defaultdict(int))

    for run in runs:
        width = run.config['model_shape'][1]
        widths.add(width)

        series = run.history(keys=['untapped-potential'])['untapped-potential']
        went_negative = (series < 0).any()
        hit_zero = (series == 0).any()

        outcome = {
            's': hit_zero and not went_negative,
            'sd': went_negative and hit_zero,
            'd': went_negative and not hit_zero,
            'no': not (went_negative or hit_zero),
        }
        for label, happened in outcome.items():
            counts[label][width] += 1 if happened else 0
        counts['pr'][width] = run.config['pruning_rate']

    return counts, widths

def plot_experiment_2(data, x_axis, save=False):
    """Bar-plot experiment-2 outcome counts per hidden-layer width.

    Draws four stacked panels (Split, Split & Degrade, Degrade, None) that
    share both axes. Each x tick shows the hidden width and, below it, the
    recorded pruning rate as a percentage rounded to the nearest int.

    Args:
        data: mapping as built by get_data_experiment_2 — 's', 'sd', 'd' and
            'no' map width -> network count; 'pr' maps width -> pruning rate.
        x_axis: set of hidden widths to show (sorted ascending here).
        save: if True, also write the figure to 'exp2.svg'.
    """
    hidden_dims = sorted(x_axis)
    # Round, don't truncate: int(0.29 * 100) is 28 because the product is
    # 28.999999999999996, which contradicted the "round to int" axis label.
    pruning_rates = [f'{int(round(data["pr"][i] * 100))}%' for i in hidden_dims]

    xticks = list(range(len(hidden_dims)))
    xticklabels = [f'{hd}\n{pr}' for hd, pr in zip(hidden_dims, pruning_rates)]

    fig, axs = plt.subplots(4, 1, figsize=(10, 12), sharex=True, sharey=True)
    plt.setp(axs, xticks=xticks, xticklabels=xticklabels, yticks=[0, 1, 2, 3, 4])

    # (panel title, data key, bar colour), top to bottom.
    panels = [
        ('Split', 's', '#84DCC6'),
        ('Split &\n Degrade', 'sd', '#A5FFD6'),
        ('Degrade', 'd', '#FF686B'),
        ('None', 'no', '#FFDEC0'),
    ]
    for ax, (title, key, color) in zip(axs, panels):
        bar(ax=ax, x=xticks, y=data[key], title=title, color=color)

    fig.text(0.04, 0.5, 'num networks', va='center', rotation='vertical', fontdict={'size': 12})
    fig.text(0.5, 0.04, 'num hidden neurons \n pruning rate (round to int)', ha='center', fontdict={'size': 12})

    # Save before plt.show(), which may close the figure depending on the backend.
    if save:
        fig.savefig("exp2.svg", format='svg', dpi=300)
    plt.show()

# Fetch experiment-2 data from W&B only when nothing is cached (fresh kernel
# or after the reset cell set data_2 = None); otherwise reuse the cached result.
if data_2 is None:
    data_2, x_axis_2 = get_data_experiment_2(runs)

plot_experiment_2(data_2, x_axis_2, save=False)

In [None]:
# Experiment 3: fetch every run of the three sweeps through the W&B API (network call).
runs = runs_from_sweeps(sweep_ids=['wpcowdl5', 'j9lvvxjg', 'a5kr3muw'])

def get_data_experiment_3(runs, exclude=('bright-sweep-19',)):
    """Collect every logged metric at the first and last "split" step of each run.

    A step counts as a split step when its 'untapped-potential' equals 0. For
    each run with at least one split step, every history column ``<key>``
    contributes its value at the first such step to
    ``data['<key>-begin'][hidden_dim]`` and at the last such step to
    ``data['<key>-end'][hidden_dim]``.

    Args:
        runs: iterable of W&B runs exposing ``.name``, ``.history()`` and
            ``config['model_shape']`` / ``config['extension_levels']``.
        exclude: run names to skip entirely. The default is the single run
            this analysis excluded by hand.

    Returns:
        (data, x_axis): ``data`` maps '<key>-begin'/'<key>-end' ->
        {hidden_dim: [values, ...]} plus ``data['extension_levels'][hidden_dim]``
        (last run per width wins); ``x_axis`` is the set of hidden widths of
        all non-excluded runs, including runs that never split.
    """
    x_axis = set()
    data = defaultdict(lambda: defaultdict(list))

    for run in runs:
        if run.name in exclude:
            continue
        x = run.config['model_shape'][1]
        # NOTE(review): W&B's run.history() samples the history by default;
        # confirm split steps cannot be dropped for long runs.
        hist = run.history()
        idc = np.where(hist['untapped-potential'].values == 0)[0]
        x_axis.add(x)

        if len(idc) > 0:
            begin, end = idc[0], idc[-1]
            for key in hist.keys():
                # np.where returns positions, so index positionally; a plain
                # hist[key][begin] is a *label* lookup on the frame's index.
                column = hist[key]
                data[f'{key}-begin'][x].append(column.iloc[begin])
                data[f'{key}-end'][x].append(column.iloc[end])

        data['extension_levels'][x] = run.config['extension_levels']

    return data, x_axis

def mean_std(xrange, x_values, y_values):
    """Per-x mean and standard deviation of samples, aligned to ``xrange``.

    Args:
        xrange: x positions to report, in output order.
        x_values: x values that have samples (hashable).
        y_values: sample sequence for each entry of ``x_values`` (same order).
            Not modified.

    Returns:
        (y_means, y_stds): lists with one entry per element of ``xrange``;
        both are NaN where ``xrange`` has no matching entry in ``x_values``.
    """
    # Pair by key instead of popping y_values in order: the old pop(0) mutated
    # the caller's list and silently misaligned when x_values was not in
    # xrange order.
    samples_by_x = dict(zip(x_values, y_values))

    y_means, y_stds = [], []
    for x in xrange:
        if x in samples_by_x:
            y = samples_by_x[x]
            y_means.append(np.mean(y))
            y_stds.append(np.std(y))
        else:
            y_means.append(np.nan)
            y_stds.append(np.nan)

    return y_means, y_stds

def mean_min_max(xrange, x_values, y_values):
    """Per-x mean with asymmetric min/max error offsets, aligned to ``xrange``.

    Args:
        xrange: x positions to report, in output order.
        x_values: x values that have samples (hashable).
        y_values: sample sequence for each entry of ``x_values`` (same order).
            Not modified.

    Returns:
        (y_means, yerr): ``y_means`` is a list with one entry per element of
        ``xrange``; ``yerr`` is a (2, len(xrange)) array whose rows are
        ``mean - min`` and ``max - mean`` (the layout matplotlib's errorbar
        takes for asymmetric errors). Missing x positions are NaN.
    """
    # Pair by key instead of popping y_values in order: the old pop(0) mutated
    # the caller's list and silently misaligned when x_values was not in
    # xrange order.
    samples_by_x = dict(zip(x_values, y_values))

    y_means, ymins, ymaxs = [], [], []
    for x in xrange:
        if x in samples_by_x:
            y = samples_by_x[x]
            y_mean = np.mean(y)
            ymins.append(y_mean - np.min(y))
            ymaxs.append(np.max(y) - y_mean)
        else:
            y_mean = np.nan
            ymins.append(np.nan)
            ymaxs.append(np.nan)
        y_means.append(y_mean)

    return y_means, np.stack([ymins, ymaxs])

def format(title, ax,  xlabel=None, ylabel=None):
    """Apply the notebook's shared axes styling and labels to ``ax``.

    NOTE: this name shadows the builtin ``format`` for the rest of the
    notebook; it is kept because the plotting code calls it by this name.

    Args:
        title: bold axes title (no padding).
        ax: matplotlib Axes to style.
        xlabel, ylabel: axis labels, passed straight to set_xlabel/set_ylabel.

    Returns:
        The same ``ax``, for chaining.
    """
    ax.set_ylabel(ylabel)
    ax.set_xlabel(xlabel)
    ax.set_title(title, pad=0, fontweight="bold")
    # grid() documents a bool visibility flag; pass one rather than the string 'on'.
    ax.grid(True, color='#CCCCCC', linestyle=':')
    # Visually hide the right/top spines by painting them white.
    ax.spines['right'].set_color((1., 1., 1.))
    ax.spines['top'].set_color((1.0, 1.0, 1.0))
    return ax

def errorbar(x, xticklabels, y, yerr, ax, color):
    """Draw a semi-transparent line-and-marker error-bar series on ``ax``.

    ``x`` serves both as the plot positions and as the tick locations; each
    tick is labelled with the matching entry of ``xticklabels``, rotated 45°.
    """
    line_style = {'fmt': '-o', 'capsize': 2, 'color': color, 'alpha': 0.5, 'ecolor': 'black'}
    ax.errorbar(x, y, yerr=yerr, **line_style)
    ax.set_xticks(x)
    ax.set_xticklabels(xticklabels, rotation=45)

def plot_experiment_3(data, x_axis, save=False):
    """Plot experiment-3 metrics at the first and last split step per width.

    Produces three figures, each with a left panel built from '<metric>-begin'
    (first step with untapped-potential == 0) and a right panel built from
    '<metric>-end' (last such step), over hidden-layer width:

    1. 'active-weights-rel'  — mean ± std across runs,
    2. 'accuracy'            — mean with min/max error bars,
    3. 'active-features-abs' — mean with min/max error bars.

    Args:
        data: mapping '<metric>-begin' / '<metric>-end' -> {hidden_dim: [values]},
            as built by get_data_experiment_3.
        x_axis: iterable of hidden widths for the x axis (sorted here).
        save: if True, figure n is also written to 'exp3_<n>.svg'
            (previously this flag was accepted but ignored).
    """
    hidden_dims = sorted(x_axis)
    positions = list(range(len(hidden_dims)))
    xlabel = 'hidden neurons per layer'

    def draw_panel(key, ax, aggregate, color, title, ylabel):
        # One error-bar panel: aggregate data[key] per width (NaN where absent).
        x_values, y_values = to_list(data[key])
        y, yerr = aggregate(hidden_dims, x_values, y_values)
        errorbar(positions, hidden_dims, y, yerr, ax, color=color)
        format(title=title, ylabel=ylabel, xlabel=xlabel, ax=ax)

    # (metric prefix, aggregation, y-axis label, title stem) per figure.
    # Labels now describe the plotted metric: previously the accuracy and
    # active-features figures were titled/labelled "prunable parameters" /
    # "active rel", and figure 1 carried a placeholder suptitle 'test'.
    figures = [
        ('active-weights-rel', mean_std, 'prunable parameters', 'Prunable parameters'),
        ('accuracy', mean_min_max, 'accuracy', 'Accuracy'),
        ('active-features-abs', mean_min_max, 'active features (abs)', 'Active features'),
    ]
    for number, (metric, aggregate, ylabel, stem) in enumerate(figures, start=1):
        fig, (ax_begin, ax_end) = plt.subplots(1, 2, figsize=(10, 3), sharex=False, sharey=False)
        draw_panel(f'{metric}-begin', ax_begin, aggregate, 'blue', f'{stem} at first split', ylabel)
        draw_panel(f'{metric}-end', ax_end, aggregate, 'red', f'{stem} at last split', ylabel)
        # Save before plt.show(), which may close the figure depending on the backend.
        if save:
            fig.savefig(f"exp3_{number}.svg", format='svg', dpi=300)
        plt.show()

# Fetch experiment-3 data from W&B only when nothing is cached yet.
if data_3 is None:
    data_3, x_axis_3 = get_data_experiment_3(runs)

plot_experiment_3(data_3, x_axis_3)