In [None]:
%load_ext autoreload
%autoreload 2
import sys
sys.path.append("../../project")
from collections import defaultdict
import matplotlib.pyplot as plt
import numpy as np
import wandb

# Prefix prepended to every sweep id below (presumably "<entity>/<project>/").
base = "mxmn/concat_moons/"
# Client for the W&B public API; used to fetch sweeps, their runs and run histories.
api = wandb.Api()

In [None]:
def to_numpy(d: dict):
    """Stack the values of ``d`` (in its iteration order) into one NumPy array."""
    return np.array([*d.values()])

def mean_and_std(d: dict):
    """Per-key mean and (population, ddof=0) standard deviation of the samples in ``d``.

    Args:
        d: mapping from a group key to a sequence of numeric samples. The
            sequences may have different lengths (e.g. a different number of
            runs per group).

    Returns:
        ``(means, stds)``: two 1-D float arrays with one entry per key, in
        ``d``'s iteration order. A key with no samples yields ``nan`` in both.
    """
    # Reduce each group on its own: stacking all groups into a single 2-D array
    # and reducing along axis=1 breaks as soon as the groups are ragged.
    groups = [np.asarray(samples, dtype=float) for samples in d.values()]
    means = np.array([g.mean() if g.size else np.nan for g in groups], dtype=float)
    stds = np.array([g.std() if g.size else np.nan for g in groups], dtype=float)
    return means, stds

def bar(ax, y, title, ylabel, color):
    """Draw a bar chart of ``y`` on ``ax``, one bar per value at x = 0, 1, 2, ...

    Args:
        ax: matplotlib Axes to draw on.
        y: iterable of bar heights (a ``dict.values()`` view or generator is fine).
        title: axes title.
        ylabel: y-axis label.
        color: bar colour, passed straight to ``ax.bar``.
    """
    # Materialise first: dict views are not sequences and some matplotlib
    # versions fail to convert them to numeric arrays.
    heights = list(y)
    ax.set_title(title)
    ax.set_ylabel(ylabel)
    ax.bar(range(len(heights)), heights, color=color)

In [None]:
# Sweep over `pruning_levels`: classify every run by the sign of its logged
# 'untapped-potential' metric.
sweep_id = 'ix4onq8c'  # named sweep_id so the builtin `id` is not shadowed
sweep = api.sweep(base + sweep_id)

splitters = defaultdict(int)         # hit exactly 0 at some step, never went below 0
degraders = defaultdict(int)         # hit exactly 0 at some step AND went below 0 at some step
direct_degraders = defaultdict(int)  # went below 0 but never hit exactly 0
nonsplitters = defaultdict(int)      # never hit 0 and never went below 0
skipped_runs = []                    # runs that never logged 'untapped-potential'

for run in sweep.runs:
    # scan_history walks every logged row; run.history() returns only a sample
    # of rows, which can miss the single step where the metric reaches 0.
    potential = np.array(
        [row['untapped-potential'] for row in run.scan_history(keys=['untapped-potential'])],
        dtype=float,
    )
    if potential.size == 0:
        skipped_runs.append(run.name)
        continue

    degraded = bool((potential < 0).any())
    split = bool((potential == 0).any())
    level = run.config['pruning_levels']
    # Every counter is touched for every run so all four dicts share the same keys
    # in the same order; the plotting cell relies on that to align the bars.
    splitters[level] += int(split and not degraded)
    degraders[level] += int(degraded and split)
    direct_degraders[level] += int(degraded and not split)
    nonsplitters[level] += int(not degraded and not split)

# Order by level so the bar charts read left to right
# (assumes the level values are mutually comparable, e.g. ints).
splitters, degraders, direct_degraders, nonsplitters = (
    dict(sorted(counts.items()))
    for counts in (splitters, degraders, direct_degraders, nonsplitters)
)
if skipped_runs:
    print(f"skipped {len(skipped_runs)} run(s) without 'untapped-potential': {skipped_runs}")

In [None]:
# One bar chart per outcome category, one bar per `pruning_levels` value.
panels = [
    # (counts per level, title, y label, colour)
    (splitters, 'Number of networks that split', '#split', 'green'),
    (degraders, 'Number of networks that degraded', '#degrade after split', 'lightgreen'),
    (direct_degraders, 'Number of networks that degraded without splitting before',
     '#degrade before split', 'red'),
    (nonsplitters, "Number of networks that didn't split", '#no split', 'gray'),
]
fig, axs = plt.subplots(len(panels), 1, figsize=(10, 8), sharex=True, sharey=True)

for ax, (counts, title, ylabel, color) in zip(axs, panels):
    # list(): hand matplotlib a real sequence rather than a dict view
    bar(ax=ax, y=list(counts.values()), title=title, ylabel=ylabel, color=color)

axs[-1].set_xticks(range(len(nonsplitters)))
axs[-1].set_xticklabels(list(nonsplitters.keys()))
axs[-1].set_xlabel('pruning_levels')
fig.tight_layout()
# plt.show() rather than fig.show(): the latter warns on the non-GUI inline backend.
plt.show()

In [None]:
# Sweep over the hidden-layer width (run.config['model_shape'][1]): classify
# every run by the sign of its logged 'untapped-potential' metric.
sweep_id = 'p30nbq46'  # named sweep_id so the builtin `id` is not shadowed
sweep = api.sweep(base + sweep_id)

splitters = defaultdict(int)         # hit exactly 0 at some step, never went below 0
degraders = defaultdict(int)         # hit exactly 0 at some step AND went below 0 at some step
direct_degraders = defaultdict(int)  # went below 0 but never hit exactly 0
nonsplitters = defaultdict(int)      # never hit 0 and never went below 0
skipped_runs = []                    # runs that never logged 'untapped-potential'

for run in sweep.runs:
    # scan_history walks every logged row; run.history() returns only a sample
    # of rows, which can miss the single step where the metric reaches 0.
    potential = np.array(
        [row['untapped-potential'] for row in run.scan_history(keys=['untapped-potential'])],
        dtype=float,
    )
    if potential.size == 0:
        skipped_runs.append(run.name)
        continue

    degraded = bool((potential < 0).any())
    split = bool((potential == 0).any())
    level = run.config['model_shape'][1]
    # Every counter is touched for every run so all four dicts share the same keys
    # in the same order; the plotting cell relies on that to align the bars.
    splitters[level] += int(split and not degraded)
    degraders[level] += int(degraded and split)
    direct_degraders[level] += int(degraded and not split)
    nonsplitters[level] += int(not degraded and not split)

# Order by width so the bar charts read left to right
# (assumes the width values are mutually comparable, e.g. ints).
splitters, degraders, direct_degraders, nonsplitters = (
    dict(sorted(counts.items()))
    for counts in (splitters, degraders, direct_degraders, nonsplitters)
)
if skipped_runs:
    print(f"skipped {len(skipped_runs)} run(s) without 'untapped-potential': {skipped_runs}")

In [None]:
# One bar chart per outcome category, one bar per hidden-layer width.
panels = [
    # (counts per width, title, y label, colour)
    (splitters, 'Number of networks that split', '#split', 'green'),
    (degraders, 'Number of networks that degraded', '#degrade after split', 'lightgreen'),
    (direct_degraders, 'Number of networks that degraded without splitting before',
     '#degrade before split', 'red'),
    (nonsplitters, "Number of networks that didn't split", '#no split', 'gray'),
]
fig, axs = plt.subplots(len(panels), 1, figsize=(10, 8), sharex=True, sharey=True)

for ax, (counts, title, ylabel, color) in zip(axs, panels):
    # list(): hand matplotlib a real sequence rather than a dict view
    bar(ax=ax, y=list(counts.values()), title=title, ylabel=ylabel, color=color)

axs[-1].set_xticks(range(len(nonsplitters)))
axs[-1].set_xticklabels(list(nonsplitters.keys()))
axs[-1].set_xlabel('number of hidden neurons')
fig.tight_layout()
# Ending with plt.show() keeps the tick-label list from being echoed as cell output.
plt.show()

In [None]:
# For each initial parameter count, gather the metrics logged at the first and at
# the last step where 'untapped-potential' was exactly 0 (the "split" steps).
sweep_ids = ['wpcowdl5']

# api.sweep(...).runs is itself a collection of runs, so flatten across sweeps to
# obtain one list of individual Run objects.
runs = [run for sweep_id in sweep_ids for run in api.sweep(base + sweep_id).runs]

first_split = defaultdict(list)   # 'pparams' at the first split step
last_split = defaultdict(list)    # 'pparams' at the last split step
first_acc = defaultdict(list)     # 'val-loss' (a loss, despite the name) at the first split step
last_acc = defaultdict(list)      # 'val-loss' at the last split step
first_active = defaultdict(list)  # 'active-abs' at the first split step
last_active = defaultdict(list)   # 'active-abs' at the last split step

for run in runs:
    total_pparams = run.config['param_trajectory'][0]

    # scan_history yields every logged row (run.history() returns only a sample,
    # which can skip the exact split steps). Keys not logged in a row read as NaN.
    rows = list(run.scan_history())
    potential = np.array([row.get('untapped-potential', np.nan) for row in rows], dtype=float)
    split_steps = np.flatnonzero(potential == 0)
    if split_steps.size == 0:
        continue  # this run never split

    first, last = rows[split_steps[0]], rows[split_steps[-1]]
    first_split[total_pparams].append(first.get('pparams', np.nan))
    last_split[total_pparams].append(last.get('pparams', np.nan))
    first_acc[total_pparams].append(first.get('val-loss', np.nan))
    last_acc[total_pparams].append(last.get('val-loss', np.nan))
    first_active[total_pparams].append(first.get('active-abs', np.nan))
    last_active[total_pparams].append(last.get('active-abs', np.nan))

# Order every dict by initial parameter count so the plots read left to right.
first_split, last_split, first_acc, last_acc, first_active, last_active = (
    dict(sorted(d.items()))
    for d in (first_split, last_split, first_acc, last_acc, first_active, last_active)
)


In [None]:
# Mean ± std over runs of each quantity at the first / last split step,
# one point per initial parameter count.
fig, axs = plt.subplots(2, 2, figsize=(10, 9), sharex=True, sharey=True)
x = range(len(first_split))

panels = [
    # (axis, samples per parameter count, y label, title)
    (axs[0][0], first_split, '#pparams min', 'pparams on first split'),
    (axs[1][0], last_split, '#pparams min', 'pparams on last split'),
    (axs[0][1], first_active, '#active params', 'Number of active parameters on first split'),
    (axs[1][1], last_active, '#active params', 'Number of active parameters on last split'),
]
for ax, samples, ylabel, title in panels:
    y, yerr = mean_and_std(samples)
    ax.errorbar(x, y, yerr=yerr, fmt='-o', capsize=2, color='darkgreen')
    ax.set_ylabel(ylabel)
    ax.set_title(title)

# Without tick labels the x positions 0..n-1 carry no meaning; label them with the
# parameter counts (all dicts share first_split's keys and order).
for ax in axs[1]:
    ax.set_xticks(x)
    ax.set_xticklabels(list(first_split.keys()))
    ax.set_xlabel('number of pparams')
fig.tight_layout()
plt.show()