In [None]:
import pickle
import os
import sys

In [None]:
from collections import defaultdict
from utils.dict_utils import dict_to_defaultdict
from utils.data_loaders import get_shape

In [None]:
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline

In [None]:
# Experiment configuration. The trailing comment lists on train_size / num_epochs /
# batch_size presumably record the settings of the other runs in the sweep (read column-wise).

# --- data ---
dataset = 'cifar2_binary'
num_classes = get_shape(dataset)[-1]
train_size = 1024 # 1024, 1024, 1024, 1024, 1024, 1024, 8192, 8192

# --- architecture ---
num_hidden = 1
bias = False
normalization = 'none'
activation = 'lrelu'

# --- optimisation ---
optimizer = 'sgd'
# Human-readable optimizer name; optimizers without an entry keep their raw name.
optimizer_name = {'sgd': 'GD', 'rmsprop': 'RMSProp'}.get(optimizer, optimizer)
lr = 0.02
num_epochs = 100  # 2000, 100,  500,  25,   125,  6,    125,  6
batch_size = 1024 # 1024, 1024, 256,  256,  64,   64,   512,  512
num_seeds = 10

# Optimizer steps per epoch; the plotting cells multiply epoch indices by this to get step counts.
steps_per_epoch = train_size // batch_size

title = 'train size = {}, batch size = {}'.format(train_size, batch_size)

In [None]:
# Results directory for this configuration:
# results/equiv_models_test/<dataset>_<train_size>/<architecture settings>/<optimisation settings>
log_dir = os.path.join(
    'results',
    'equiv_models_test',
    '{dataset}_{size}'.format(dataset=dataset, size=train_size),
    'num_hidden={nh}_activation={act}_bias={b}_normalization={norm}'.format(
        nh=num_hidden, act=activation, b=bias, norm=normalization,
    ),
    '{opt}_lr={rate}_batch_size={bs}_num_epochs={ne}'.format(
        opt=optimizer, rate=lr, bs=batch_size, ne=num_epochs,
    ),
)
log_dir

In [None]:
results_all_path = os.path.join(log_dir, 'results_all.dat')

# Six levels of nested defaultdicts. The plotting cells index it as
# results_all[scaling_mode][ref_width][correction_epoch][real_width][seed][key][epoch];
# a key missing at the innermost (sixth) level yields None.
results_all = defaultdict(
    lambda: defaultdict(
        lambda: defaultdict(
            lambda: defaultdict(
                lambda: defaultdict(
                    lambda: defaultdict(lambda: None)
                )
            )
        )
    )
)

if os.path.exists(results_all_path):
    # NOTE: pickle.load can execute arbitrary code; only open result files you produced yourself.
    with open(results_all_path, 'rb') as f:
        # dict_to_defaultdict is project-local; presumably it rebuilds the loaded nested
        # dicts using results_all as the defaultdict template — confirm in utils.dict_utils.
        results_all = dict_to_defaultdict(pickle.load(f), results_all)

In [None]:
correction_epochs = [0]

# Only single-hidden-layer networks are configured below.
if num_hidden != 1:
    raise NotImplementedError

# Widths of the trained networks, and the width(s) used as the reference they are compared against.
real_widths = [128, 65536]
ref_widths = [128]

# Parameterisations to compare. Only the string entries are active; the commented tuples are
# alternative points, annotated with the region they belong to.
# NOTE(review): the plotting cells call scaling_mode.startswith(...) and look up
# scaling_mode_names[scaling_mode], so re-enabling a tuple entry needs changes there too.
scaling_modes = [
    # zero-dimensional regions:
    'mean_field',
    'ntk',
    'mean_field_simple_init_corrected',
    'default',
    'default_sym',
    # one-dimensional regions:
#         (-0.75,0.5), #MF-NTK
#         (-0.5,0.25), #NTK-sym_default
#         (-0.75,0.75), #sym_default-MF
#         (-1.5,1.5), #post-MF
#         (0,0), #pre-sym_default
#         (-1,0.5), #post-NTK
#         (0,-0.5), #pre-NTK
    # two-dimensional regions:
#         (-0.66,0.5), #baricenter
#         (-1,0.75), #post-MF-NTK
#         (-10,9.75), #post-MF-NTK
#         (0,-0.25), #pre-NTK-sym_default, far
#         (10,-10.25), #pre-NTK-sym_default, far
    # ill-defined:
#         (1,1), #diverge
#         (-1,-1), #stagnate
]

# Legend names for each active scaling mode.
scaling_mode_names = {
    'mean_field': 'MF',
    'ntk': 'NTK',
    'mean_field_simple_init_corrected': 'IC-MF',
    'default': 'default',
    'default_sym': 'sym-default',
}

In [None]:
# Larger default font for all figures in this notebook.
plt.rcParams['font.size'] = 18

In [None]:
def ewma(a, alpha):
    """Exponentially weighted moving average along the first axis.

        out[0] = a[0]
        out[i] = alpha * a[i] + (1 - alpha) * out[i - 1]

    Parameters
    ----------
    a : array_like
        Input sequence. For N-d input the recurrence runs over axis 0 and is
        applied element-wise across the remaining axes.
    alpha : float
        Weight of the current sample (alpha = 1 means no smoothing).

    Returns
    -------
    numpy.ndarray
        A new array; `a` itself is never modified. Non-floating input is
        promoted to float so averaged values are not truncated.
    """
    # Work on a copy: the previous version aliased `a` and overwrote the caller's data.
    averaged = np.array(a, copy=True)
    if not np.issubdtype(averaged.dtype, np.inexact):
        averaged = averaged.astype(float)
    for i in range(1, len(averaged)):
        # averaged[i] still holds the raw input value here; only indices < i were updated.
        averaged[i] = averaged[i] * alpha + averaged[i - 1] * (1 - alpha)
    return averaged

In [None]:
def draw_curve(scaling_mode, ref_width, real_width, correction_epoch, key, 
               idx=None, threshold=1000, smoothening_factor=0, label=None, **kwargs):
    """Plot the seed-mean of a per-epoch metric with a +/- 1 std band.

    Reads results_all[scaling_mode][ref_width][correction_epoch][real_width][seed][key][epoch]
    for every seed in range(num_seeds) and epoch in range(num_epochs). It also uses the
    notebook globals `num_seeds`, `num_epochs`, `steps_per_epoch` and `ewma`.

    Parameters
    ----------
    scaling_mode, ref_width, real_width, correction_epoch, key
        Index path into `results_all`.
    idx : ignored
        Unused; kept so existing calls keep working.
    threshold : float
        Values are clipped to [-threshold, threshold] before plotting.
    smoothening_factor : float
        0 disables smoothing. Otherwise the curves are smoothed in log space with
        `ewma(..., alpha=1 - smoothening_factor)`, running backwards in time.
    label : str, optional
        Legend label for the mean line (the std band is never labelled).
    **kwargs
        Forwarded to both `plt.plot` and `plt.fill_between`.
    """
    # (num_seeds, num_epochs), assuming each stored entry is a scalar.
    data = np.array([
        [
            results_all[scaling_mode][ref_width][correction_epoch][real_width][seed][key][epoch]
            for epoch in range(num_epochs)
        ] for seed in range(num_seeds)
    ])
    data = np.clip(data, -threshold, threshold)
    if smoothening_factor:
        # Smooth in log space, with the epoch axis reversed. This is skipped when
        # smoothening_factor == 0: then ewma's (1 - alpha) == 0 would multiply log(0) = -inf,
        # and a single exact-zero value would turn every earlier epoch into NaN.
        data = np.exp(ewma(np.log(data.T)[::-1], alpha=1 - smoothening_factor)[::-1].T)
    data_mean = data.mean(axis=0)
    data_std = data.std(axis=0)
    steps = np.arange(1, num_epochs + 1) * steps_per_epoch
    plt.plot(steps, data_mean, label=label, **kwargs)
    plt.fill_between(
        steps,
        data_mean - data_std, data_mean + data_std,
        alpha=0.3, **kwargs
    )
        

In [None]:
# Line style per key modifier (indexed by position in key_modifiers); the tuples are custom
# dash patterns in matplotlib's (offset, (on, off)) form.
linestyles = [
    'solid',
    'dashed',
    'dotted',
    'dashdot',
    (0, (1, 5)),
    (0, (3, 5)),
]
# Colour per scaling mode (indexed by position in scaling_modes).
cmap = plt.get_cmap('tab10')

# Metric keys '<split>_<metric>', in the order test_losses, test_accs, train_losses, train_accs.
key_bases = ['{}_{}'.format(split, metric) for split in ('test', 'train') for metric in ('losses', 'accs')]
# (prefix, suffix) pairs wrapped around each key base; the empty pair uses the base unchanged.
key_modifiers = [('', '')]
# Suggested y-limits aligned with key_bases; only applied if the plt.ylim lines are uncommented.
ylims = [(0.35, 0.45), (0.8, 0.9), (0.6, 0.7), (0.7, 1.0)]

In [None]:
to_draw = None  # None = draw every scaling mode

# One figure per metric in key_bases. Every curve carries its own legend label, so plt.legend()
# pairs labels with the right lines. The previous positional plt.legend(labels) depended on how
# matplotlib orders the mean lines and the fill_between bands among the axes' artists.
for key_base, ylim in zip(key_bases, ylims):
    _ = plt.figure(figsize=(12,6))

    plt.xlabel('training step, k+{}'.format(steps_per_epoch))
    #plt.ylim(ylim) # uncomment to adjust y-limits manually
    plt.grid(True)

    if key_base.endswith('_losses'):
        plt.yscale('log')
        plt.ylabel("CE loss")
    elif key_base.endswith('_accs'):
        plt.ylabel("accuracy")
    plt.xscale('log')

    for ref_width in ref_widths:
        for real_width in real_widths[::-1]:
            for k, scaling_mode in enumerate(scaling_modes):
                if to_draw is not None and scaling_mode not in to_draw and (ref_width != real_width):
                    continue
                # At the reference width only the default parameterisation is drawn.
                if (ref_width == real_width) and (scaling_mode != 'default'):
                    continue
                for correction_epoch in (
                    correction_epochs if scaling_mode == 'mean_field' else [0] if scaling_mode.startswith('mean_field') else [None]
                ):
                    for i, key_mod in enumerate(key_modifiers):
                        key = key_mod[0] + key_base + key_mod[1]
                        # Label only the first key modifier so each model appears once in the legend.
                        if (scaling_mode == 'default') and (ref_width == real_width):
                            draw_curve(
                                scaling_mode, None, 
                                real_width, correction_epoch, key, color='black', 
                                linestyle='dashed', lw=3,
                                label='reference; d*={}'.format(ref_width) if i == 0 else None
                            )
                        else:
                            draw_curve(
                                scaling_mode, ref_width if scaling_mode != 'default' else None, 
                                real_width, correction_epoch, key, color=cmap(k), 
                                linestyle=linestyles[i], lw=3,
                                label=scaling_mode_names[scaling_mode] if i == 0 else None
                            )

    plt.legend()
    plt.show()

In [None]:
def draw_logits(
    scaling_mode, ref_width, real_width, correction_epoch, 
    add_displacement=0, mul_displacement=1,
    label=None, **kwargs
):
    """Plot the seed-mean of the mean absolute test logit per epoch, with a +/- 1 std band.

    For each seed and epoch the plotted quantity is mean_x |f(x)| over the stored test logits,
    then transformed as (value + add_displacement) * mul_displacement. The function uses the
    notebook globals results_all, num_seeds, num_epochs and steps_per_epoch. `label` is passed
    to the line only; **kwargs go to both plt.plot and plt.fill_between.
    """
    # Rows are seeds and columns are epochs. squeeze(axis=-1) drops a trailing size-1 output
    # dimension (and raises if that dimension is not size 1).
    logits = np.array([
        [
            results_all[scaling_mode][ref_width][correction_epoch][real_width][seed]['test_logits'][epoch]
            for epoch in range(num_epochs)
        ]
        for seed in range(num_seeds)
    ]).squeeze(axis=-1)

    per_seed = np.abs(logits).mean(axis=-1)
    per_seed += add_displacement
    per_seed *= mul_displacement

    center, spread = per_seed.mean(axis=0), per_seed.std(axis=0)
    steps = np.arange(1, num_epochs + 1) * steps_per_epoch
    plt.plot(steps, center, label=label, **kwargs)
    plt.fill_between(steps, center - spread, center + spread, alpha=0.3, **kwargs)
        

In [None]:
def get_label(scaling_mode, ref_width=None, real_width=None):
    """Build a legend label for a curve.

    The base name is 'reference' when scaling_mode == 'reference', otherwise
    scaling_mode_names[scaling_mode]. When both widths are given, ': $d = ...$'
    is appended. It shows ref_width for the reference curve and real_width
    otherwise. Powers of two are rendered as 2^{k}; any other width is written
    out in full.
    """
    name = scaling_mode if scaling_mode == 'reference' else scaling_mode_names[scaling_mode]
    if ref_width is None or real_width is None:
        return name
    width = ref_width if scaling_mode == 'reference' else real_width
    log2_width = float(np.log2(width))
    if log2_width.is_integer():
        width_text = '2^{{{}}}'.format(int(log2_width))
    else:
        # int(log2(width)) would silently mislabel e.g. 100 as 2^{6}.
        width_text = '{}'.format(width)
    return name + ': ' + r'$d = {}$'.format(width_text)

In [None]:
key_bases = ['losses', 'accs']
key_modifiers = [('test_',''), ('train_','')]
# ylims = [(0.35,0.45), (0.8, 0.9), (0.6,0.7), (0.7,1.0)]

to_draw = ['mean_field', 'ntk', 'default_sym', 'mean_field_simple_init_corrected',]

# One figure per metric, with test and train curves on the same axes. They use linestyles[0]
# and linestyles[1] respectively (the order of key_modifiers). Only the test curve of each model
# is labelled, so every model appears once in the legend. The loop no longer zips with the
# `ylims` list from an earlier cell: those limits belong to different keys and were never applied.
for key_base in key_bases:
    _ = plt.figure(figsize=(6,4))

    plt.xlabel('training step, k+{}'.format(steps_per_epoch))
    # plt.ylim(...) # uncomment and pass limits to adjust y-limits manually
    plt.grid(True)

    if key_base == 'losses':
        plt.ylabel("CE loss")
    elif key_base == 'accs':
        plt.ylabel("accuracy")
    #plt.yscale('log')
    plt.xscale('log')

    for ref_width in ref_widths:
        for real_width in real_widths[::-1]:
            for k, scaling_mode in enumerate(scaling_modes):
                if to_draw is not None and scaling_mode not in to_draw and (ref_width != real_width):
                    continue
                # At the reference width only the default parameterisation is drawn.
                if (ref_width == real_width) and (scaling_mode != 'default'):
                    continue
                for correction_epoch in (
                    correction_epochs if scaling_mode == 'mean_field' else [0] if scaling_mode.startswith('mean_field') else [None]
                ):
                    for i, key_mod in enumerate(key_modifiers):
                        key = key_mod[0] + key_base + key_mod[1]
                        if (scaling_mode == 'default') and (ref_width == real_width):
                            draw_curve(
                                scaling_mode, None, 
                                real_width, correction_epoch, key, color='black', 
                                linestyle=linestyles[i], lw=3, label=get_label('reference', ref_width, real_width) if i == 0 else None
                            )
                        else:
                            draw_curve(
                                scaling_mode, ref_width if scaling_mode != 'default' else None, 
                                real_width, correction_epoch, key, color=cmap(k), 
                                linestyle=linestyles[i], lw=3, label=get_label(scaling_mode, ref_width, real_width) if i == 0 else None
                            )

    if key_base == 'losses':
        plt.legend()
    plt.title(title)
    plt.show()

In [None]:
to_draw = ['mean_field', 'ntk', 'default_sym', 'mean_field_simple_init_corrected',]

# Mean absolute test logit over training, one curve per scaling mode plus the reference.
_ = plt.figure(figsize=(6,4))

plt.xlabel('training step, k+{}'.format(steps_per_epoch))
#plt.ylim(ylim) # uncomment to adjust y-limits manually
plt.grid(True)

plt.ylabel('mean abs logit, ' + r"$\mathbb{E}_x |f(x)|$")
plt.yscale('log')
plt.xscale('log')

for ref_width in ref_widths:
    for real_width in real_widths[::-1]:
        for k, scaling_mode in enumerate(scaling_modes):
            is_reference = ref_width == real_width
            # At the reference width only the default parameterisation is drawn.
            if is_reference and scaling_mode != 'default':
                continue
            if not is_reference and to_draw is not None and scaling_mode not in to_draw:
                continue

            if scaling_mode == 'mean_field':
                epochs_to_plot = correction_epochs
            elif scaling_mode.startswith('mean_field'):
                epochs_to_plot = [0]
            else:
                epochs_to_plot = [None]

            for correction_epoch in epochs_to_plot:
                if is_reference:
                    draw_logits(
                        scaling_mode, None,
                        real_width, correction_epoch, color='black',
                        linestyle='dotted',
                        lw=3, label=get_label('reference', ref_width, real_width)
                    )
                else:
                    # mul_displacement = 1.05**k scales each mode's curve slightly; presumably
                    # this keeps curves that would coincide on the log axis visually apart.
                    draw_logits(
                        scaling_mode, ref_width if scaling_mode != 'default' else None,
                        real_width, correction_epoch, color=cmap(k),
                        lw=3, label=get_label(scaling_mode, ref_width, real_width),
                        mul_displacement=1.05**k
                    )

# plt.legend()
plt.title(title)
plt.show()

In [None]:
from scipy.special import kl_div, expit, digamma, gamma

In [None]:
def mean_and_var(a, **kwargs):
    """Return (mean, unbiased variance with ddof=1) of `a`.

    Extra keyword arguments (e.g. `axis`) are passed to both np.mean and np.var.
    """
    sample_mean = np.mean(a, **kwargs)
    sample_var = np.var(a, ddof=1, **kwargs)
    return sample_mean, sample_var

def kl_between_normals(mean, var, mean_ref, var_ref):
    """Element-wise KL(N(mean, var) || N(mean_ref, var_ref)).

    Closed form: 0.5 * (r + (mean_ref - mean)^2 / var_ref - 1 - log r), with r = var / var_ref.
    A zero reference variance produces inf/nan.
    """
    var_ratio = var / var_ref
    squared_shift = (mean_ref - mean) ** 2 / var_ref
    return 0.5 * (var_ratio + squared_shift - 1 - np.log(var_ratio))

def estimate_beta_distr_params(mean, var, eps=1e-2):
    """Method-of-moments Beta(alpha, beta) fit to a given mean and variance, regularised by eps.

    Both moment terms are shifted by eps before forming their ratio:
        c = (mean * (1 - mean) + eps) / (var + eps) - 1
        alpha = mean * c,  beta = (1 - mean) * c
    Each parameter is then clipped to [eps, 1 / eps].
    """
    concentration = (mean * (1 - mean) + eps) / (var + eps) - 1
    alpha = np.clip(mean * concentration, a_min=eps, a_max=1/eps)
    beta = np.clip((1 - mean) * concentration, a_min=eps, a_max=1/eps)
    return alpha, beta

def B(alpha, beta):
    """Euler Beta function B(alpha, beta) = Gamma(alpha) * Gamma(beta) / Gamma(alpha + beta).

    Delegates to scipy.special.beta instead of forming the Gamma ratio explicitly.
    Gamma overflows to inf for arguments above ~171, so the naive ratio becomes
    inf / inf = nan (e.g. B(100, 100)) even though B itself is small and finite.
    """
    from scipy import special  # module import: the bare name `beta` is taken by the parameter
    return special.beta(alpha, beta)

def kl_between_betas(alpha, beta, alpha_ref, beta_ref):
    """Element-wise KL(Beta(alpha, beta) || Beta(alpha_ref, beta_ref)).

    Closed form, with a' = alpha_ref, b' = beta_ref and psi = digamma:
        ln B(a', b') - ln B(a, b) + (a - a') psi(a) + (b - b') psi(b) - (a - a' + b - b') psi(a + b)

    The log-Beta difference uses scipy.special.betaln. Computing log(B(a', b') / B(a, b))
    from Gamma functions gives nan once Gamma overflows or B underflows for large parameters.
    """
    from scipy.special import betaln, digamma
    log_beta_ratio = betaln(alpha_ref, beta_ref) - betaln(alpha, beta)
    return (
        log_beta_ratio
        + (alpha - alpha_ref) * digamma(alpha)
        + (beta - beta_ref) * digamma(beta)
        - (alpha - alpha_ref + beta - beta_ref) * digamma(alpha + beta)
    )

In [None]:
def discrepancy_between_logits(logits, logits_ref, discrepancy_type, seed_axis=0):
    """Per-entry discrepancy between the across-seed distributions of two logit arrays.

    Statistics are taken over `seed_axis`, so the result has the shape of the remaining axes.

    discrepancy_type:
        'logit' -- fit a Normal (mean, unbiased var) to the logits of each side and
                   return KL(fit || reference fit); a zero reference variance yields inf/nan.
        'prob'  -- fit a Beta (regularised method of moments) to sigmoid(logits) of each
                   side and return KL(fit || reference fit).
        'class' -- |P(logit > 0) - P_ref(logit > 0)|, with probabilities estimated over seeds.

    Raises
    ------
    ValueError
        For any other discrepancy_type.
    """
    if discrepancy_type == 'logit':
        mean, var = mean_and_var(logits, axis=seed_axis)
        mean_ref, var_ref = mean_and_var(logits_ref, axis=seed_axis)
        discrepancy = kl_between_normals(mean, var, mean_ref, var_ref)
    elif discrepancy_type == 'prob':
        mean, var = mean_and_var(expit(logits), axis=seed_axis)
        mean_ref, var_ref = mean_and_var(expit(logits_ref), axis=seed_axis)
        alpha, beta = estimate_beta_distr_params(mean, var)
        alpha_ref, beta_ref = estimate_beta_distr_params(mean_ref, var_ref)
        discrepancy = kl_between_betas(alpha, beta, alpha_ref, beta_ref)
    elif discrepancy_type == 'class':
        # Fraction of seeds predicting the positive class.
        prob = np.mean(logits > 0, axis=seed_axis)
        prob_ref = np.mean(logits_ref > 0, axis=seed_axis)
        discrepancy = np.abs(prob - prob_ref)
    else:
        raise ValueError(
            "unknown discrepancy_type {!r}; expected 'logit', 'prob' or 'class'".format(discrepancy_type)
        )
    return discrepancy

In [None]:
def draw_mean_discrepancy(scaling_mode, ref_width, real_width, correction_epoch, discrepancy_type, **kwargs):
    """Plot the test-set mean of the seed-distribution discrepancy to the reference, per epoch.

    The model's test logits come from
    results_all[scaling_mode][ref_width][correction_epoch][real_width]. They are always compared
    against the 'default' scaling at width ref_widths[0], i.e.
    results_all['default'][None][None][ref_widths[0]], regardless of `ref_width` (callers pass
    None for the 'default' mode). discrepancy_between_logits gives one value per epoch and test
    point; the plotted curve is its mean over test points. **kwargs go to plt.plot.
    """
    def load_test_logits(mode, ref, corr, width):
        # Rows are seeds, columns epochs; the trailing size-1 output dimension is squeezed away.
        return np.array([[
            results_all[mode][ref][corr][width][seed]['test_logits'][epoch]
            for epoch in range(num_epochs)
        ] for seed in range(num_seeds)]).squeeze(axis=-1)

    test_logits = load_test_logits(scaling_mode, ref_width, correction_epoch, real_width)
    test_logits_ref = load_test_logits('default', None, None, ref_widths[0])

    data = discrepancy_between_logits(test_logits, test_logits_ref, discrepancy_type=discrepancy_type)
    plt.plot(np.arange(1, num_epochs + 1) * steps_per_epoch, data.mean(axis=-1), **kwargs)
        

In [None]:
discrepancy_types = ['logit', 'prob', 'class']
ylabels = ['KL(limit, ref)', 'KL(limit, ref)', '|p_limit - p_ref|']

# One figure per discrepancy measure. Every scaling mode at a non-reference width is compared
# with the default-parameterised reference network (see draw_mean_discrepancy).
for discrepancy_type, ylabel in zip(discrepancy_types, ylabels):
    _ = plt.figure(figsize=(12,6))

    plt.xlabel('training step, k+{}'.format(steps_per_epoch))
    #plt.ylim(ylim) # uncomment to adjust y-limits manually
    # All three measures ('logit', 'prob', 'class') use a logarithmic y-axis.
    plt.yscale('log')
    plt.xscale('log')
    plt.grid(True)

    plt.ylabel(ylabel)

    for ref_width in ref_widths:
        for real_width in real_widths[::-1]:
            for k, scaling_mode in enumerate(scaling_modes):
                if ref_width == real_width:
                    continue

                if scaling_mode == 'mean_field':
                    epochs_to_plot = correction_epochs
                elif scaling_mode.startswith('mean_field'):
                    epochs_to_plot = [0]
                else:
                    epochs_to_plot = [None]

                for correction_epoch in epochs_to_plot:
                    draw_mean_discrepancy(
                        scaling_mode, None if scaling_mode == 'default' else ref_width,
                        real_width, correction_epoch, discrepancy_type=discrepancy_type,
                        color=cmap(k), lw=3, label=get_label(scaling_mode)
                    )

    plt.legend()
    plt.title(title)
    plt.show()

In [None]:
def draw_tangent_kernels(
    scaling_mode, ref_width, real_width, correction_epoch, 
    add_displacement=0, mul_displacement=1,
    label=None, layer='sum', relative=False, **kwargs
):
    """Plot the seed-mean of E_x |K(x, x)| over training, with a +/- 1 std band.

    Reads results_all[scaling_mode][ref_width][correction_epoch][real_width][seed]['test_tangent_kernels'][epoch]
    and uses the notebook globals num_seeds, num_epochs and steps_per_epoch.

    Parameters
    ----------
    layer : {'input', 'output', 'sum'}
        Takes index 0 or index 1 of the second-to-last axis (after the [::2] subsampling),
        or the sum over that axis. 'hidden' raises NotImplementedError; anything else raises
        ValueError.
    relative : bool
        If True, divide every epoch by the value at epoch index 0 (per seed and test point)
        before averaging.
    add_displacement, mul_displacement : float
        The per-seed curve is transformed as (value + add_displacement) * mul_displacement.
    label : str, optional
        Legend label for the mean line only.
    **kwargs
        Forwarded to both plt.plot and plt.fill_between.
    """
    # NOTE(review): [::2] keeps every other entry of each stored kernel record; what the
    # skipped entries hold is not visible here — confirm. The resulting array is presumably
    # (num_seeds, num_epochs, num_layer_terms, num_test_points).
    kernels = np.array([[
        results_all[scaling_mode][ref_width][correction_epoch][real_width][seed]['test_tangent_kernels'][epoch][::2]
        for epoch in range(num_epochs)
    ] for seed in range(num_seeds)])

    if layer == 'input':
        kernels = kernels[..., 0, :]
    elif layer == 'output':
        kernels = kernels[..., 1, :]
    elif layer == 'sum':
        kernels = np.sum(kernels, axis=-2)
    elif layer == 'hidden':
        raise NotImplementedError("layer='hidden' is not supported")
    else:
        raise ValueError(
            "unknown layer {!r}; expected 'input', 'output' or 'sum'".format(layer)
        )

    if relative:
        kernels = kernels / kernels[:, 0:1]

    data = np.mean(np.abs(kernels), axis=-1)
    data += add_displacement
    data *= mul_displacement

    data_mean = data.mean(axis=0)
    data_std = data.std(axis=0)
    steps = np.arange(1, num_epochs + 1) * steps_per_epoch
    plt.plot(steps, data_mean, label=label, **kwargs)
    plt.fill_between(
        steps,
        data_mean - data_std, data_mean + data_std,
        alpha=0.5, **kwargs
    )
        

In [None]:
# Mode filter for this figure. It is set explicitly so the cell no longer depends on whatever
# `to_draw` the last-executed cell left behind; this is the value the preceding cells assign.
# Alternative subset: to_draw = ['mean_field', 'ntk', 'default_sym']
to_draw = ['mean_field', 'ntk', 'default_sym', 'mean_field_simple_init_corrected']

for relative in [True, False]:
    for layer in ['sum']:
        _ = plt.figure(figsize=(6,4))
#         plt.title(layer + (' relative' if relative else ''))

        plt.xlabel('training step, k+{}'.format(steps_per_epoch))
        #plt.ylim(ylim) # uncomment to adjust y-limits manually
        plt.grid(True)

        plt.yscale('log')
        plt.xscale('log')
        if relative:
            plt.ylabel(r"$\mathbb{E}_x (K(x,x) / K_{init}(x,x))$")
            plt.yticks([1,2], ['1', '2'])
        else:
            plt.ylabel('mean diag kernel, ' + r"$\mathbb{E}_x K(x,x)$")

        for ref_width in ref_widths:
            for real_width in real_widths[::-1]:
                for k, scaling_mode in enumerate(scaling_modes):
                    if to_draw is not None and scaling_mode not in to_draw and (ref_width != real_width):
                        continue
                    # At the reference width only the default parameterisation is drawn.
                    if (ref_width == real_width) and (scaling_mode != 'default'):
                        continue
                    for correction_epoch in (
                        correction_epochs if scaling_mode == 'mean_field' else [0] if scaling_mode.startswith('mean_field') else [None]
                    ):
                        if (scaling_mode == 'default') and (ref_width == real_width):
                            draw_tangent_kernels(
                                scaling_mode, None, 
                                real_width, correction_epoch, color='black', layer=layer, relative=relative,
                                linestyle='dotted', lw=3, label=get_label('reference', ref_width, real_width)
                            )
                        else:
                            # 1.03**k is a small per-mode multiplicative offset on the log axis,
                            # presumably to keep otherwise-overlapping curves distinguishable.
                            draw_tangent_kernels(
                                scaling_mode, ref_width if scaling_mode != 'default' else None, 
                                real_width, correction_epoch, color=cmap(k), layer=layer, relative=relative,
                                lw=3, label=get_label(scaling_mode, ref_width, real_width),
                                mul_displacement=1.03**k
                            )

        #plt.legend()
        plt.title(title)
        plt.show()

In [None]:
def draw_logits_by_tangent_kernels(
    scaling_mode, ref_width, real_width, correction_epoch, 
    add_displacement=0, mul_displacement=1,
    label=None, layer='sum', relative=False, **kwargs
):
    """Plot the seed-mean of E_x |f(x) / K(x, x)| over training, with a +/- 1 std band.

    Reads 'test_logits' and 'test_tangent_kernels' from
    results_all[scaling_mode][ref_width][correction_epoch][real_width][seed] and uses the
    notebook globals num_seeds, num_epochs and steps_per_epoch.

    Parameters
    ----------
    layer : {'input', 'output', 'sum'}
        Kernel term to divide by: index 0 or index 1 of the second-to-last kernel axis (after
        the [::2] subsampling), or their sum. 'hidden' raises NotImplementedError; anything
        else raises ValueError.
    relative : bool
        Accepted for signature parity with draw_tangent_kernels; currently ignored.
    add_displacement, mul_displacement : float
        The per-seed curve is transformed as (value + add_displacement) * mul_displacement.
    label : str, optional
        Legend label for the mean line only.
    **kwargs
        Forwarded to both plt.plot and plt.fill_between.
    """
    # Rows are seeds, columns epochs; squeeze drops the trailing size-1 output dimension.
    test_logits = np.array([[
        results_all[scaling_mode][ref_width][correction_epoch][real_width][seed]['test_logits'][epoch]
        for epoch in range(num_epochs)
    ] for seed in range(num_seeds)]).squeeze(axis=-1)

    # NOTE(review): [::2] keeps every other entry of each stored kernel record — confirm what is skipped.
    kernels = np.array([[
        results_all[scaling_mode][ref_width][correction_epoch][real_width][seed]['test_tangent_kernels'][epoch][::2]
        for epoch in range(num_epochs)
    ] for seed in range(num_seeds)])

    if layer == 'input':
        kernels = kernels[..., 0, :]
    elif layer == 'output':
        kernels = kernels[..., 1, :]
    elif layer == 'sum':
        kernels = np.sum(kernels, axis=-2)
    elif layer == 'hidden':
        raise NotImplementedError("layer='hidden' is not supported")
    else:
        raise ValueError(
            "unknown layer {!r}; expected 'input', 'output' or 'sum'".format(layer)
        )

    # Kernels may cover fewer test points than the logits; presumably the first
    # kernels.shape[-1] test points, so the logits are truncated to match.
    num_kernel_points = kernels.shape[-1]
    data = np.mean(np.abs(test_logits[..., :num_kernel_points] / kernels), axis=-1)
    data += add_displacement
    data *= mul_displacement

    data_mean = data.mean(axis=0)
    data_std = data.std(axis=0)
    steps = np.arange(1, num_epochs + 1) * steps_per_epoch
    plt.plot(steps, data_mean, label=label, **kwargs)
    plt.fill_between(
        steps,
        data_mean - data_std, data_mean + data_std,
        alpha=0.3, **kwargs
    )
        

In [None]:
# Mode filter for this figure. It is set explicitly so the cell no longer depends on whatever
# `to_draw` the last-executed cell left behind; this is the value the preceding cells assign.
# Alternative subset: to_draw = ['mean_field', 'ntk', 'default_sym']
to_draw = ['mean_field', 'ntk', 'default_sym', 'mean_field_simple_init_corrected']

for relative in [False]:
    for layer in ['sum', 'input', 'output']:
        _ = plt.figure(figsize=(6,4))
#         plt.title(layer + (' relative' if relative else ''))

        plt.xlabel('training step, k+{}'.format(steps_per_epoch))
        #plt.ylim(ylim) # uncomment to adjust y-limits manually
        plt.grid(True)

        plt.ylabel(r"$\mathbb{E}_x |f(x) / K(x,x)|$")
        plt.yscale('log')
        plt.xscale('log')
#         plt.yticks([1,2], ['1', '2'])

        for ref_width in ref_widths:
            for real_width in real_widths[::-1]:
                for k, scaling_mode in enumerate(scaling_modes):
                    if to_draw is not None and scaling_mode not in to_draw and (ref_width != real_width):
                        continue
                    # At the reference width only the default parameterisation is drawn.
                    if (ref_width == real_width) and (scaling_mode != 'default'):
                        continue
                    for correction_epoch in (
                        correction_epochs if scaling_mode == 'mean_field' else [0] if scaling_mode.startswith('mean_field') else [None]
                    ):
                        if (scaling_mode == 'default') and (ref_width == real_width):
                            draw_logits_by_tangent_kernels(
                                scaling_mode, None, 
                                real_width, correction_epoch, color='black', layer=layer, relative=relative,
                                linestyle='dotted', lw=3, label=get_label('reference', ref_width, real_width)
                            )
                        else:
                            # 1.05**k is a small per-mode multiplicative offset on the log axis,
                            # presumably to keep otherwise-overlapping curves distinguishable.
                            draw_logits_by_tangent_kernels(
                                scaling_mode, ref_width if scaling_mode != 'default' else None, 
                                real_width, correction_epoch, color=cmap(k), layer=layer, relative=relative,
                                lw=3, label=get_label(scaling_mode, ref_width, real_width),
                                mul_displacement=1.05**k
                            )

        # plt.legend()
        plt.title(title)
        plt.show()