In [1]:
# Enable IPython's autoreload extension in mode 2, so imported modules are
# reloaded before each cell executes and edits to project code are picked up
# without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
# Figures are rendered with the inline backend; the commented-out line below
# selects the `notebook` backend instead.
#%matplotlib notebook
%matplotlib inline

In [3]:
import os

# All paths are derived from the kernel's working directory, so nothing
# absolute is hard-coded in the notebook.
cwd = os.getcwd()

# NOTEBOOK_DIR is the parent of the working directory.
NOTEBOOK_DIR = os.path.dirname(cwd)

# ROOT lies three directory levels above NOTEBOOK_DIR; it is the directory
# that gets appended to sys.path in the following cell.
ROOT = os.path.dirname(NOTEBOOK_DIR)
ROOT = os.path.dirname(ROOT)
ROOT = os.path.dirname(ROOT)

# Locations under ROOT used by this notebook: the YAML model configuration
# and the figures directory for this experiment.
CONFIG_PATH = os.path.join(ROOT, 'pytorch/configs/abc_parameterizations/fc_abc.yaml')
FIGURES_DIR = os.path.join(ROOT, 'figures/abc_parameterizations/initialization')

In [4]:
import sys

# Make the packages located under ROOT importable from this notebook.
# The membership check keeps the cell idempotent: executing it several times
# no longer piles up duplicate ROOT entries in sys.path.
if ROOT not in sys.path:
    sys.path.append(ROOT)

In [5]:
import torch
import pickle
from utils.tools import load_pickle

from utils.tools import read_yaml, set_random_seeds
from utils.plot.abc_parameterizations.initializations import *
from utils.plot.abc_parameterizations.one_d_functions import *
from pytorch.configs.base import BaseConfig
from pytorch.configs.model import ModelConfig
from pytorch.models.abc_params.fully_connected import ntk, ip, muP, ipllr
from pytorch.models.abc_params.fully_connected.standard_fc_ip import StandardFCIP

## Set variables

In [6]:
# Notebook-wide constants.
SEED = 42
N_SAMPLES = 100
N_TRAIN = 10
BASE_LR = 0.1

# Seed first so that everything executed afterwards is reproducible.
set_random_seeds(SEED)

# Load the shared YAML configuration and wrap it in a ModelConfig.
config_dict = read_yaml(CONFIG_PATH)
base_config = ModelConfig(config_dict)

# Notebook-specific overrides of the base configuration:
# ReLU activation, MSE loss and the base learning rate ...
base_config.activation.name = 'relu'
base_config.loss.name = 'mse'
base_config.optimizer.params['lr'] = BASE_LR
# ... and a bias-free architecture with input and output size 1.
base_config.architecture['bias'] = False
base_config.architecture['input_size'] = 1
base_config.architecture['output_size'] = 1

In [7]:
# Generate the 1-d data set that is later passed to fit_model and plot_model.
# The sample count comes from the N_TRAIN constant (currently 10) instead of a
# repeated literal, so the data size is controlled from the constants cell.
xs, ys = generate_1d_data(n_samples=N_TRAIN)

## Define models

In [8]:
# Experiment settings. A trailing number in a comment is the alternative value
# that was noted next to the setting.
WIDTHS = [1024]        # widths for which models are built
L = 4                  # the models are configured with n_layers = L + 1
N_TRIALS = 10          # model instances per (name, width) pair; alternative: 5
N_WARMUP_STEPS = 1     # n_warmup_steps of the IPLLR scheduler; alternative: 4
N_EPOCHS = 3000        # n_epochs handed to fit_model; alternative: 6000

In [9]:
# Maps the display name of each parameterization to the class used to build it.
# Insertion order is preserved and determines the iteration order below.
name_to_model_dict = dict(
    StandardIP=StandardFCIP,
    NTK=ntk.FCNTK,
    muP=muP.FCmuP,
    IPLLR=ipllr.FcIPLLR,
)

In [10]:
# models_dict[name][width] holds N_TRIALS separately constructed instances of
# the model class registered under `name`, all built from the same config.
models_dict = dict()
for name, model_class in name_to_model_dict.items():
    models_dict[name] = dict()
    for width in WIDTHS:
        # Work on a private copy so the shared base_config is left untouched.
        config = deepcopy(base_config)
        config.name = name
        config.architecture['width'] = width
        config.architecture['n_layers'] = L + 1

        # Only IPLLR is given a scheduler configuration.
        if name == 'IPLLR':
            config.scheduler = BaseConfig({
                'name': 'warmup_switch',
                'params': {'n_warmup_steps': N_WARMUP_STEPS},
            })

        # Learning-rate overrides: the substring test matches both 'StandardIP'
        # and 'IPLLR'; 'muP' gets 0.1; 'NTK' keeps the value from base_config.
        if 'IP' in name:
            config.optimizer.params['lr'] = 0.4
        elif name == 'muP':
            config.optimizer.params['lr'] = 0.1

        # One instance per trial (described in the original notebook as random
        # initializations of the same model).
        models_dict[name][width] = [model_class(config) for _ in range(N_TRIALS)]

### Set U and v to be the same for all models

In [11]:
# For every non-NTK parameterization, pair each model with the NTK model of the
# same width and trial index and copy the initial parameters over from it.
for name in models_dict.keys():
    if name == 'NTK':
        continue  # the NTK models are the source of the copy
    for width in WIDTHS:
        for i in range(N_TRIALS):
            # Named `ntk_model` (not `ntk`) so the imported `ntk` module, which
            # the model registry accesses as `ntk.FCNTK`, is not overwritten
            # and earlier cells remain re-runnable.
            ntk_model = models_dict['NTK'][width][i]
            model = models_dict[name][width][i]
            with torch.no_grad():
                model.copy_initial_params_from_model(ntk_model, check_model=True)
                # NOTE(review): presumably re-applies the model's own scaling on
                # top of the copied parameters — confirm against the model class.
                model.initialize_params()

In [12]:
# Snapshot of every model in models_dict, taken (in notebook order) before the
# training cell runs. Deep copies keep the snapshot independent of later
# changes to the originals.
init_dict = dict()
for name, models_by_width in models_dict.items():
    init_dict[name] = {
        width: [deepcopy(model) for model in models_by_width[width]]
        for width in WIDTHS
    }

In [13]:
# Independent deep copy of the whole models_dict structure.
# NOTE(review): not referenced again in the visible cells — confirm it is still needed.
models_dict_copy = deepcopy(models_dict)

In [None]:
# Insert a dimension at position 1 of the inputs and targets before handing
# them to fit_model.
batch_xs = torch.unsqueeze(xs, 1)
batch_ys = torch.unsqueeze(ys, 1)

# Parameterizations fitted by this cell. Edit this list instead of commenting
# loops in and out: use list(models_dict.keys()) to fit every model, or e.g.
# ['muP'] for a single one. Names left out stay exactly as they were built.
TRAINED_MODEL_NAMES = ['IPLLR']

for name in TRAINED_MODEL_NAMES:
    for width in WIDTHS:
        for model in models_dict[name][width]:
            fit_model(model, batch_xs, batch_ys, n_epochs=N_EPOCHS)

In [None]:
# INIT: the NTK models stored in init_dict, i.e. the deep copies taken before
# the training cell in notebook order.
name = 'NTK'
plt.figure(figsize=(12, 6))
# Index by WIDTHS[0] rather than a literal 1024 so the cell does not raise a
# KeyError once WIDTHS is edited.
plot_model(init_dict[name][WIDTHS[0]], xs, ys, label=name, scatter=True)
plt.show()  # explicit show, matching the other plotting cells

In [None]:
# Current state of the NTK models in models_dict.
# NOTE(review): the training cell only fits 'IPLLR', so these models have not
# been passed to fit_model — confirm that is intended.
name = 'NTK'
plt.figure(figsize=(12, 6))
# Index by WIDTHS[0] rather than a literal 1024 so the cell does not raise a
# KeyError once WIDTHS is edited.
plot_model(models_dict[name][WIDTHS[0]], xs, ys, label=name, scatter=True)
plt.show()  # explicit show, matching the other plotting cells

In [None]:
# INIT: the muP models stored in init_dict, i.e. the deep copies taken before
# the training cell in notebook order.
name = 'muP'
plt.figure(figsize=(12, 6))
# Index by WIDTHS[0] rather than a literal 1024 so the cell does not raise a
# KeyError once WIDTHS is edited.
plot_model(init_dict[name][WIDTHS[0]], xs, ys, label=name, scatter=True)
plt.show()  # explicit show, matching the other plotting cells

In [None]:
# Current state of all muP models in models_dict.
# NOTE(review): the training cell only fits 'IPLLR' unless it is edited, so
# these models may not have been passed to fit_model — confirm.
name = 'muP'
plt.figure(figsize=(12, 6))
# Index by WIDTHS[0] rather than a literal 1024 so the cell does not raise a
# KeyError once WIDTHS is edited.
plot_model(models_dict[name][WIDTHS[0]], xs, ys, label=name, scatter=True)
plt.show()  # explicit show, matching the other plotting cells

In [None]:
# muP model of trial 0 only (a single model rather than the whole list).
name = 'muP'
plt.figure(figsize=(12, 6))
# Index by WIDTHS[0] rather than a literal 1024 so the cell does not raise a
# KeyError once WIDTHS is edited.
plot_model(models_dict[name][WIDTHS[0]][0], xs, ys, label=name, scatter=True)
plt.show()  # explicit show, matching the other plotting cells

In [None]:
# muP model of trial 1 only (a single model rather than the whole list).
name = 'muP'
plt.figure(figsize=(12, 6))
# Index by WIDTHS[0] rather than a literal 1024 so the cell does not raise a
# KeyError once WIDTHS is edited.
plot_model(models_dict[name][WIDTHS[0]][1], xs, ys, label=name, scatter=True)
plt.show()  # explicit show, matching the other plotting cells

In [None]:
# muP model of trial 2 only (a single model rather than the whole list).
name = 'muP'
plt.figure(figsize=(12, 6))
# Index by WIDTHS[0] rather than a literal 1024 so the cell does not raise a
# KeyError once WIDTHS is edited.
plot_model(models_dict[name][WIDTHS[0]][2], xs, ys, label=name, scatter=True)
plt.show()  # explicit show, matching the other plotting cells

In [None]:
# INIT: the IPLLR models stored in init_dict, i.e. the deep copies taken before
# the training cell in notebook order.
name = 'IPLLR'
plt.figure(figsize=(12, 6))
# Index by WIDTHS[0] rather than a literal 1024 so the cell does not raise a
# KeyError once WIDTHS is edited.
plot_model(init_dict[name][WIDTHS[0]], xs, ys, label=name, scatter=True)
# Restrict the y-axis to +/- 1e-5.
plt.ylim(-0.00001, 0.00001)
plt.show()

In [None]:
# IPLLR model of trial 0 only, in its current state in models_dict.
name = 'IPLLR'
plt.figure(figsize=(12, 6))
# Index by WIDTHS[0] rather than a literal 1024 so the cell does not raise a
# KeyError once WIDTHS is edited.
plot_model(models_dict[name][WIDTHS[0]][0], xs, ys, label=name, scatter=True)
# plt.ylim(-0.0001, 0.0001)
plt.show()

In [None]:
# IPLLR model of trial 1 only, in its current state in models_dict.
name = 'IPLLR'
plt.figure(figsize=(12, 6))
# Index by WIDTHS[0] rather than a literal 1024 so the cell does not raise a
# KeyError once WIDTHS is edited.
plot_model(models_dict[name][WIDTHS[0]][1], xs, ys, label=name, scatter=True)
# plt.ylim(-0.0001, 0.0001)
plt.show()

In [None]:
# IPLLR model of trial 2 only, in its current state in models_dict.
name = 'IPLLR'
plt.figure(figsize=(12, 6))
# Index by WIDTHS[0] rather than a literal 1024 so the cell does not raise a
# KeyError once WIDTHS is edited.
plot_model(models_dict[name][WIDTHS[0]][2], xs, ys, label=name, scatter=True)
# plt.ylim(-0.0001, 0.0001)
plt.show()

In [None]:
# WEIGHTS OF THE LAST LAYER 

In [None]:
# All IPLLR models in models_dict, in their current state.
name = 'IPLLR'
plt.figure(figsize=(12, 6))
# Index by WIDTHS[0] rather than a literal 1024 so the cell does not raise a
# KeyError once WIDTHS is edited.
plot_model(models_dict[name][WIDTHS[0]], xs, ys, label=name, scatter=True)
# plt.ylim(-0.0001, 0.0001)
plt.show()

In [None]:
# INIT: the StandardIP models stored in init_dict, i.e. the deep copies taken
# before the training cell in notebook order.
name = 'StandardIP'
plt.figure(figsize=(12, 6))
# Index by WIDTHS[0] rather than a literal 1024 so the cell does not raise a
# KeyError once WIDTHS is edited.
plot_model(init_dict[name][WIDTHS[0]], xs, ys, label=name, scatter=True)
# Restrict the y-axis to +/- 1e-4.
plt.ylim(-0.0001, 0.0001)
plt.show()

In [None]:
# Current state of the StandardIP models in models_dict.
# NOTE(review): the training cell only fits 'IPLLR', so these models have not
# been passed to fit_model — confirm that is intended.
name = 'StandardIP'
plt.figure(figsize=(12, 6))
# Index by WIDTHS[0] rather than a literal 1024 so the cell does not raise a
# KeyError once WIDTHS is edited.
plot_model(models_dict[name][WIDTHS[0]], xs, ys, label=name, scatter=True)
#plt.ylim(-0.0001, 0.0001)
plt.show()

### Training

In [None]:
# models = deepcopy(init_dict)

In [None]:
#name = 'NTK'
#plt.figure(figsize=(12, 6))
#plot_training(models[name][1024], xs, ys, label=name)
#plt.show()

In [None]:
#x = np.linspace(0, 6*np.pi, 100)
#y = np.sin(x)

# You probably won't need this if you're embedding things in a tkinter plot...
#plt.ion()

#fig = plt.figure()
#ax = fig.add_subplot(111)
#line1, = ax.plot(x, y, 'r-') # Returns a tuple of line objects, thus the comma

#for phase in np.linspace(0, 10*np.pi, 500):
#    line1.set_ydata(np.sin(x + phase))
#    fig.canvas.draw()
#    fig.canvas.flush_events()