In [None]:
import wandb
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

# W&B public-API client used below to look up the sweeps and their runs.
# NOTE(review): `timeout=60` is forwarded to wandb.Api — presumably the
# per-request HTTP timeout in seconds; confirm against the installed wandb.
api = wandb.Api(timeout=60)

In [None]:
# W&B metric keys shared by the analysis cells below.
_active_weights_abs = "active-weights-abs"
_pparams = "pparams"
_untapped_potential = 'untapped-potential'

# Gather every run of the four `concat_moons` sweeps, then fetch each run's
# logged history (as a DataFrame) and its config.
# NOTE(review): Run.history() is documented as returning a *sampled* history
# (default samples=500). If runs log more rows than that, the "first step
# where ..." lookups in later cells are approximate; run.scan_history() would
# return every row. Confirm for the installed wandb version.
runs = []
for sweep_id in ('wpcowdl5', 'j9lvvxjg', 'a5kr3muw', 'iat6x9bh'):
    runs.extend(api.sweep('concat_moons/' + sweep_id).runs)

histories = [r.history() for r in runs]
configs = [r.config for r in runs]

In [None]:
# Plotting cells that check this flag also write their figure to a PNG file.
# The file paths are relative, so files land in the current working directory.
SAVE = True

In [None]:
# Active vs. available ("pparams") weights at the first logged step where a
# run reaches the selected event, aggregated (mean, min..max) per extension
# level. By default the event is the first split step (untapped potential == 0).
# Set USE_DEGRADED = True to use the first degraded step (potential < 0).
# NOTE(review): the output filenames say "at-degrade" although the default
# event is the split step. Check which one the saved figures should show.
USE_DEGRADED = False

active, pparams, extensionlevels = [], [], []
for h, c in zip(histories, configs):
    potential = h[_untapped_potential].values
    event_rows = np.flatnonzero(potential < 0 if USE_DEGRADED else potential == 0)
    if event_rows.size:
        first = event_rows[0]
        # `first` is a row *position*, so look it up positionally with .iloc.
        # Plain [] does a label lookup, which only matches for a RangeIndex.
        active.append(h[_active_weights_abs].iloc[first])
        pparams.append(h[_pparams].iloc[first])
        extensionlevels.append(c['extension_levels'])

assert active, "no run reached the selected event; nothing to plot"

group = pd.DataFrame({'active': active, 'pparams': pparams, 'level': extensionlevels}).groupby('level')
# One aggregation instead of recomputing group.mean()/min()/max() per series.
stats = group.agg(['mean', 'min', 'max'])

plt.figure(figsize=(4, 4))

x = stats.index
# Shift the two series slightly left/right so their error bars don't overlap.
for column, offset, label in (
    ('active', -0.05, 'active weights'),
    ('pparams', +0.05, 'available weights'),
):
    y = stats[(column, 'mean')]
    # Asymmetric error bars spanning the per-level min..max.
    yerr = np.stack([y - stats[(column, 'min')], stats[(column, 'max')] - y])
    plt.errorbar(x=x + offset, y=y, yerr=yerr, marker='o', alpha=0.6, label=label)

plt.xlabel('extension level')
plt.ylabel('number of weights')
plt.legend(loc='upper left')
plt.tight_layout()
plt.gca().xaxis.get_major_locator().set_params(integer=True)

plt.grid(which='minor', axis='y')
plt.grid(which='major', axis='y')
if SAVE: plt.savefig('2-layer-active-available-at-degrade.png')

plt.yscale('log')
if SAVE: plt.savefig('2-layer-active-available-at-degrade-log.png')

# Summary of the per-level means: mean, max, min -- pparams first, then active.
for column in ('pparams', 'active'):
    level_means = stats[(column, 'mean')]
    print(int(level_means.mean()))
    print(int(level_means.max()))
    print(int(level_means.min()))

In [None]:
# Share of runs per extension level in each connectivity class, drawn as a
# stacked bar chart. A run counts as "split" if its untapped potential is
# exactly 0 at some logged step. It counts as "degraded" if the potential is
# negative at some step. The four classes below are mutually exclusive and
# cover every run.
split, degrade, extensionlevels, split_degrade, connected = [], [], [], [], []
for h, c in zip(histories, configs):
    potential = h[_untapped_potential].values
    was_split = bool((potential == 0).any())
    was_degraded = bool((potential < 0).any())

    split.append(int(was_split and not was_degraded))
    degrade.append(int(was_degraded and not was_split))
    split_degrade.append(int(was_split and was_degraded))
    connected.append(int(not was_split and not was_degraded))
    extensionlevels.append(c['extension_levels'])

# Column order fixes the stacking order and must line up with `color` below.
df = pd.DataFrame({
    'separated': split,
    'separated-degraded': split_degrade,
    'interconnected': connected,
    'degraded': degrade,
    'extension level': extensionlevels,
})
fractions = df.groupby('extension level').mean()

ax = fractions.plot.bar(
    width=1.,
    rot=0,
    stacked=True,
    color=['#68B684', '#508B65', '#FFC107', '#B84A62'],
    figsize=(9, 2.1),
    linewidth=0.1,
    edgecolor='k',
)
ax.grid(axis='y', linestyle='--', color='k', linewidth=0.1)
ax.set_ylabel('percentage of runs')
# Bar heights are fractions in [0, 1]. Label the ticks as percentages so they
# agree with the axis label.
yticks = [0, 0.25, .5, .75, 1]
ax.set_yticks(yticks)
ax.set_yticklabels([f'{t:.0%}' for t in yticks])
ax.set_ylim((0, 1))
ax.set_xlim((-0.5, len(fractions) - 0.5))
ax.legend(loc='upper left')
plt.tight_layout()
if SAVE: plt.savefig('2-layer-histogram-split-behaviour.png')

In [None]:
# Final number of active weights per run vs. extension level. Each point is
# coloured by the run's connectivity class: split if untapped potential == 0
# at some step, degraded if < 0 at some step.
split, degrade, extensionlevels, split_degrade, connected = [], [], [], [], []
for h, c in zip(histories, configs):
    potential = h[_untapped_potential].values
    was_split = bool((potential == 0).any())
    was_degraded = bool((potential < 0).any())

    # Last logged value of the metric.
    # NOTE(review): wandb history rows can hold NaN for metrics not logged at
    # that step. Confirm the final row always carries this metric.
    point = (c['extension_levels'], h[_active_weights_abs].iloc[-1])
    if was_split and was_degraded:
        split_degrade.append(point)
    elif was_split:
        split.append(point)
    elif was_degraded:
        degrade.append(point)
    else:
        connected.append(point)

alpha = 1
marker = 'o'
size = 30
plt.figure(figsize=(7, 2.5))
for points, color, label in (
    (split, '#68B684', 'separated'),
    (split_degrade, '#508B65', 'separated-degraded'),
    (degrade, '#B84A62', 'degraded'),
    (connected, '#FFC107', 'interconnected'),
):
    if not points:
        # plt.scatter(*zip(*[])) would be called without x/y and raise.
        continue
    xs, ys = zip(*points)
    plt.scatter(xs, ys, color=color, label=label, alpha=alpha, marker=marker, s=size)
plt.xlabel('extension level')
plt.ylabel('number of active weights')
plt.legend(loc='lower left')
plt.tight_layout()
if SAVE: plt.savefig('2-layer-compund-damage.png')

In [None]:
# Validation loss per run phase. A logged step's phase is given by its
# untapped potential:
#   > 0  -> before the split (interconnected)
#   == 0 -> during the split (separated)
#   < 0  -> after degradation (degraded)
# `d` collects, per run, the lowest non-NaN validation loss in each phase, or
# None if the run has no such loss in that phase. The flat active/loss lists
# gather every (active weights, loss) pair per phase for the scatter plot.
d = {
    'before': [],
    'during': [],
    'after': [],
}
active, loss = [], []
active_degrade, loss_degrade = [], []
active_before, loss_before = [], []


def _best_loss(losses):
    """Return the lowest non-NaN value of `losses`, or None if there is none.

    Series.min() skips NaN. The builtin min() does not: with NaN present its
    result depends on element order. A `.any()` emptiness guard would also
    treat an all-zero loss series as missing.
    """
    return losses.min() if losses.notna().any() else None


for h, config in zip(histories, configs):
    potential = h[_untapped_potential].values
    val_loss = h['val-loss']
    n_active = h[_active_weights_abs]

    before_mask = potential > 0
    split_mask = potential == 0
    degraded_mask = potential < 0

    loss_before += list(val_loss.iloc[before_mask])
    active_before += list(n_active.iloc[before_mask])

    loss += list(val_loss.iloc[split_mask])
    active += list(n_active.iloc[split_mask])

    loss_degrade += list(val_loss.iloc[degraded_mask])
    active_degrade += list(n_active.iloc[degraded_mask])

    d['before'].append(_best_loss(val_loss.iloc[before_mask]))
    d['during'].append(_best_loss(val_loss.iloc[split_mask]))
    d['after'].append(_best_loss(val_loss.iloc[degraded_mask]))

# Partition runs by which phases produced a best loss.
_df = pd.DataFrame.from_dict(d)
has_during = _df['during'].notna()
has_after = _df['after'].notna()

# Neither a split-phase nor a degraded-phase loss.
before = _df[~has_during & ~has_after].drop(['during', 'after'], axis=1).reset_index(drop=True)
# A split-phase loss but no degraded-phase loss.
during = _df[has_during & ~has_after].drop('after', axis=1).reset_index(drop=True)
# Both split-phase and degraded-phase losses.
during_after = _df[has_during & has_after].reset_index(drop=True)
# A degraded-phase loss but no split-phase loss.
after = _df[~has_during & has_after].drop('during', axis=1).reset_index(drop=True)

# TODO
The performance plots have not been made yet.

In [None]:
# How does a run's best validation loss change once it splits?
# delta = best loss before splitting - best loss while split, so a positive
# delta is an improvement. Only runs with a split-phase loss are considered.
_df = pd.DataFrame.from_dict(d)

df = _df.loc[_df['during'].notna()]
delta = df['before'] - df['during']
n_improve = sum(delta > 0)
n_worsen = sum(delta < 0)
mean_from_connect_to_split = delta.mean()
std_from_connect_to_split = delta.std()

# TODO: put the plot into the thesis, add the improvement rate at the splitting iteration
mean_from_connect_to_split, std_from_connect_to_split, n_worsen, n_improve

In [None]:
# Validation loss vs. number of active weights for every logged step. Points
# are coloured by phase (interconnected / separated / degraded), using the
# per-phase lists built in the loss-collection cell.
alpha = 0.35
marker = 'o'
size = 20
plt.figure(figsize=(6, 2))
for xs, ys, point_alpha, label, color in (
    (active_before, loss_before, alpha, 'interconnected', '#FFC107'),
    (active, loss, alpha * 2, 'separated', '#68B684'),
    (active_degrade, loss_degrade, alpha * 2, 'degraded', '#B84A62'),
):
    plt.scatter(xs, ys, alpha=point_alpha, marker=marker, label=label, color=color, s=size)
plt.xscale('log')
plt.yscale('log')
plt.legend(loc='upper right')
plt.xlabel('active weights')
plt.ylabel('validation loss')
plt.tight_layout()
plt.show()