In [None]:
%load_ext autoreload 

from tqdm import tqdm
import matplotlib.pyplot as plt
import plotly.express as px
import numpy as np
from easydict import EasyDict as edict
import pandas as pd
import os
from collections import defaultdict
from joblib import Parallel, delayed
import multiprocessing as mp
from IPython.core.debugger import set_trace
from IPython.display import clear_output

import copy

from sklearn.decomposition import PCA, KernelPCA, FastICA
from sklearn.metrics import r2_score, make_scorer
from sklearn.linear_model import Ridge, Lasso, LinearRegression
from sklearn.preprocessing import MinMaxScaler, StandardScaler, RobustScaler
from sklearn.model_selection import cross_val_score, train_test_split
from sklearn.datasets import make_swiss_roll,\
                             make_s_curve,\
                             make_moons

from umap import UMAP

import torch
from torch import nn
from torch import optim
from torch import autograd

from torch.utils.data import DataLoader

from train_utils import *
from metric_utils import calculate_Q_metrics, \
                         strain, \
                         l2_loss, \
                         to_numpy, \
                         numpy_metric, \
                         cosine_sim

from input_utils import DataGenerator, make_random_affine
from mlp_model import MLP_NonlinearEncoder
from models_utils import init_weights, \
                         universal_approximator, \
                         dJ_criterion, \
                         gained_function, \
                         sigmoid, \
                         initialize_nonlinearities

from embedding_utils import ConstructUMAPGraph, UMAPLoss, UMAPDataset, umap_criterion_compatibility

import warnings
warnings.filterwarnings("ignore")

plt.rcParams['font.size'] = 20
device = torch.device('cuda:0')
N_CPU = mp.cpu_count()
SEED = 42
%autoreload 2

# Load data 

In [None]:
# Number of points held out for validation / plotting.
N_TEST = 1000

SCALER = StandardScaler()

# Synthetic manifold data produced by DataGenerator; the commented values are
# alternatives the author experimented with.
input_parameters = dict(
    generator=make_swiss_roll,                            # alternative: make_s_curve
    generator_kwargs=dict(n_samples=10000, noise=1e-2),   # alternative noise: 1e-1
    unsupervised=True,
    whiten=True,
    scaler=SCALER,
    use_outpt_color=True,
)

create_data = DataGenerator(**input_parameters)
# The second returned value is not used in this notebook.
inpt, _, color = create_data()

In [None]:
# Per-row (axis 1) mean and standard deviation of the generated inputs.
(inpt.mean(1),
 inpt.std(1))

In [None]:
# Per-row range, plus the uncentred second-moment matrix averaged over axis 1
# (expected to be close to identity if the whitening did its job — verify).
(inpt.max(1),
 inpt.min(1),
 (inpt @ inpt.T) / inpt.shape[1])

In [None]:
# Hold out N_TEST points. `inpt` is laid out [features, samples] (the encoder
# uses input_dim=inpt.shape[0]), hence the transpose to sklearn's
# [samples, features]. The split is seeded from the notebook-wide SEED so that
# changing SEED in the config cell changes it consistently.
inpt_train, inpt_test, color_train, color_test = train_test_split(inpt.T,
                                                                  color,
                                                                  random_state=SEED,
                                                                  test_size=N_TEST)

In [None]:
# float32 copies of both splits, placed directly on the training device.
inpt_train_torch = torch.tensor(inpt_train, dtype=torch.float32, device=device)
inpt_test_torch = torch.tensor(inpt_test, dtype=torch.float32, device=device)

In [None]:
# Rounded uncentred second-moment matrix of the training split (rows are samples).
(inpt_train_torch.T @ inpt_train_torch).div(inpt_train_torch.shape[0]).round()

In [None]:
# Per-feature mean and std over the training samples (dim 0).
(inpt_train_torch.mean(dim=0),
 inpt_train_torch.std(dim=0))

In [None]:
# Shapes of the train / test tensors.
tuple(split.shape for split in (inpt_train_torch, inpt_test_torch))

In [None]:
# plt.ioff()
# plt.figure()
# df = pd.DataFrame(inpt.T, columns=['x','y', 'z'])
# if color is not None:
#     df['target'] = color
# fig = px.scatter_3d(df, x='x', y='y', z='z', color='target' if 'target' in df else None)

# fig.show()

# Metalearning: gradient

### Setup training

In [None]:
# Encoder hyper-parameters. Gradient descent only trains the parametrised
# nonlinearities (f_requires_grad=True); W and W_r have requires_grad=False and
# are presumably changed only through the `hebbian_update` rule — confirm in
# MLP_NonlinearEncoder.
encoder_parameters = dict(
    input_dim=inpt.shape[0],
    hidden_dim=60,
    embedding_dim=2,
    add_readout=False,
    add_recurrent_connections=False,
    add_recurrent_nonlinearity=False,
    hebbian_update=criterion_rule,
    inplace_update=False,
    normalize_hebbian_update=True,
    lr_hebb=1e-5,
    W_requires_grad=False,
    W_r_requires_grad=False,
    f_requires_grad=True,
    final_nonlinearity=False,
    parametrized_f=True,
    nonlinearity=universal_approximator,  # alternative: nn.Tanh()
    f_kwargs=dict(hidden_dim=10, requires_grad=True),
    layers_number=4,
    add_bn=True,
    seed=None,
    set_seed=False,
)

network = MLP_NonlinearEncoder(**encoder_parameters).to(device)

def weight_saver(network):
    """Snapshot every named parameter of ``network``.

    Returns a dict mapping each parameter name (in ``named_parameters`` order)
    to ``to_numpy`` of the flattened parameter.
    """
    return {name: to_numpy(param.flatten())
            for name, param in network.named_parameters()}

# Settings consumed by `train` (from train_utils); `epochs`, `maxiter`,
# `weight_saver` and the learning-mode flags are overwritten later per phase.
training_parameters = edict(dict(
    epochs=None,
    enable_grad_train=True,
    enable_grad_val=True,
    backprop_learning=True,
    hebbian_learning=False,
    lr=1e-4,
    wd=0,
    maxiter=None,  # maxiter
    progress_bar=True,
    weight_saver=None,
    calculate_grad=True,
    clip_grad_value=None,
    val_metrics=None,
    device=device,
))

criterion_kwargs = defaultdict(dict)
criterion_kwargs.update(skip_train=False, skip_val=False)

# Optimizer only exists when backprop learning is enabled.
opt = (optim.Adam(get_grad_params(network.parameters()),
                  lr=training_parameters.lr,
                  weight_decay=training_parameters.wd)
       if training_parameters['backprop_learning']
       else None)

In [None]:
# Text summary of every parameter, plus an overlaid histogram of their values.
print('Net capacity:', get_capacity(network))
print('Parameters:')

fig, ax = plt.subplots()
for name, param in network.named_parameters():
    print(name, param.shape, 'requires_grad:', param.requires_grad, 'Device:', param.device)
    values = to_numpy(param.data).flatten()
    ax.hist(values, bins=20, alpha=0.5, label=f'{name}')
ax.legend()
plt.show()

In [None]:
# Warm start from a previous backprop run. The checkpoint name encodes
# hdim-60 / lnum-4 / bn-1, which should match encoder_parameters above.
# initialize_nonlinearities presumably copies only the nonlinearity (f_s)
# weights — confirm in models_utils.
PRETRAINED_PATH = './results/mlp_nonlinear/swiss_roll/MLP_bp_hdim-60_lnum-4_Wgrad-1_fgrad-1_universal_approximator_bn-1'

# map_location remaps tensors saved on a GPU to the current `device`, so the
# checkpoint also loads on a CPU-only machine.
state_dict = torch.load(PRETRAINED_PATH, map_location=device)
initialize_nonlinearities(network, state_dict)
# network.load_state_dict(state_dict)

### Build UMAP graphs and criterion

In [None]:
# Neighbour graphs used by the UMAP loss, built independently for each split.
graph_constructor = ConstructUMAPGraph(metric='euclidean',
                                       n_neighbors=20,
                                       random_state=SEED)

# each *_graph_data is (epochs_per_sample, head, tail, weight)
train_graph_data, test_graph_data = (graph_constructor(split)
                                     for split in (inpt_train, inpt_test))

# batch size passed to UMAPDataset in the backprop phase
BATCH_SIZE_BP = 10000

criterion_umap = UMAPLoss(
    device=device,
    min_dist=0.1,
    negative_sample_rate=5,
    edge_weight=None,
    repulsion_strength=1.0,
)
# adapt the UMAP loss to the criterion interface expected by `train`
criterion = umap_criterion_compatibility(criterion_umap)

### Test-set embedding before meta-training

In [None]:
# Scatter the last forward output for the held-out points, coloured by `color_test`.
X_s = network.forward(inpt_test_torch.to(training_parameters['device']))
outpt_val_pred = to_numpy(X_s[-1])

fig, ax = plt.subplots()
points = ax.scatter(*outpt_val_pred[:2], c=color_test)
fig.colorbar(points, ax=ax)
plt.show()

### Training Grad

In [None]:
# network, _, metric_dict = train(network, 
#                               opt=None, 
#                               criterion=criterion,
#                               criterion_kwargs=criterion_kwargs,
#                               parameters=training_parameters,
#                               train_dataloader=train_hebb_dataloader,
#                               val_dataloader=dataset_test, 
#                               metric_dict=None,
#                               val_metrics=None
#                               )
# plt.plot(metric_dict['criterion_val'])

In [None]:
# for k in network.state_dict().keys():
#     if not 'f_s' in k:
#         w_s = [w_dict[k] for w_dict in metric_dict['weights']]
#         w_s = np.stack(w_s, 0)
#         plt.figure()
#         plt.plot(w_s)
#         plt.title(k)
#         plt.show()

In [None]:
# X_s = network.forward(inpt_test_torch.to(training_parameters['device']))
# outpt_val_pred = to_numpy(X_s[-1])

# plt.figure()
# plt.scatter(outpt_val_pred[0],
#             outpt_val_pred[1], 
#             c=color_test
#            )
# plt.colorbar()
# plt.show()

### Meta-iterations

In [None]:
# Fresh Adam optimizer for the meta-training loop, replacing the one created
# during setup. It honours both `lr` and `wd` from training_parameters, so it
# stays consistent with the setup-cell optimizer when the weight decay is changed.
opt = optim.Adam(get_grad_params(network.parameters()),
                 lr=training_parameters['lr'],
                 weight_decay=training_parameters['wd'])

In [None]:
# Alternating meta-training. Each meta-iteration runs
#   1) EPOCHS_META epochs of backprop on the UMAP criterion, then
#   2) EPOCHS_HEBB epochs of the local (Hebbian) rule with batch size 1,
# and keeps a snapshot of the weights with the lowest validation criterion.
best_state_dict = None
best_metric = np.inf

ITER_META = 50
EPOCHS_META = 1
EPOCHS_HEBB = 1

metric_dict = defaultdict(list)

# Lengths of metric_dict['criterion_*'] when each phase starts (for plotting).
meta_switch_times = defaultdict(list)
ordinary_switch_times = defaultdict(list)

training_parameters['progress_bar'] = False
training_parameters['weight_saver'] = None

# Single-sample loader for the Hebbian phase. Its data never change, and a
# shuffling DataLoader reshuffles on every pass, so it is built only once.
train_hebb_dataloader = DataLoader(inpt_train_torch, batch_size=1, shuffle=True)


def _snapshot_if_improved(network, metric_dict, best_metric, best_state_dict):
    """Return the (possibly updated) best validation criterion and weight snapshot.

    The snapshot is a deep copy: ``state_dict()`` returns references to the
    live parameter/buffer tensors, which later optimizer steps (and any
    in-place update) overwrite. Storing it directly would always mirror the
    *current* weights instead of the best ones.
    """
    val_history = metric_dict.get('criterion_val', [])
    if len(val_history) == 0:
        return best_metric, best_state_dict
    current_val_criterion = val_history[-1]
    if current_val_criterion < best_metric:
        return current_val_criterion, copy.deepcopy(network.state_dict())
    return best_metric, best_state_dict


for meta_iter in tqdm(range(ITER_META)):
    # tqdm.write keeps the progress bar intact (a plain print breaks it).
    tqdm.write(f'Doing {meta_iter} META iteration')

    meta_switch_times['train'].append(len(metric_dict['criterion_train']))
    meta_switch_times['val'].append(len(metric_dict['criterion_val']))

    # Built unconditionally: the Hebbian phase below also validates on
    # dataset_test, which previously was undefined when EPOCHS_META == 0.
    dataset_train = UMAPDataset(inpt_train,
                                *train_graph_data,
                                device=device,
                                batch_size=BATCH_SIZE_BP,
                                shuffle=True)

    dataset_test = UMAPDataset(inpt_test,
                               *test_graph_data,
                               device=device,
                               batch_size=BATCH_SIZE_BP,
                               shuffle=True)

    # --- phase 1: train meta-parameters using BP ---
    training_parameters['backprop_learning'] = True
    training_parameters['hebbian_learning'] = False
    training_parameters['epochs'] = EPOCHS_META
    criterion_kwargs['skip_train'] = False
    criterion_kwargs['skip_val'] = False

    if EPOCHS_META > 0:
        network, opt, metric_dict = train(network,
                                          opt=opt,
                                          criterion=criterion,
                                          criterion_kwargs=criterion_kwargs,
                                          parameters=training_parameters,
                                          train_dataloader=dataset_train,
                                          val_dataloader=dataset_test,
                                          metric_dict=metric_dict,
                                          val_metrics=None
                                          )

    ordinary_switch_times['train'].append(len(metric_dict['criterion_train']))
    ordinary_switch_times['val'].append(len(metric_dict['criterion_val']))

    best_metric, best_state_dict = _snapshot_if_improved(network, metric_dict,
                                                         best_metric, best_state_dict)

    # --- phase 2: train connectivity using the local rule ---
    training_parameters['backprop_learning'] = False
    training_parameters['hebbian_learning'] = True
    # NOTE(review): 'batch_size' stays 1 for later BP phases too; harmless only
    # if `train` ignores it there — confirm in train_utils.
    training_parameters['batch_size'] = 1
    training_parameters['epochs'] = EPOCHS_HEBB
    criterion_kwargs['skip_train'] = True
    criterion_kwargs['skip_val'] = False

    network, _, metric_dict = train(network,
                                    opt=None,
                                    criterion=criterion,
                                    criterion_kwargs=criterion_kwargs,
                                    parameters=training_parameters,
                                    train_dataloader=train_hebb_dataloader,
                                    val_dataloader=dataset_test,
                                    metric_dict=metric_dict,
                                    val_metrics=None
                                    )

    best_metric, best_state_dict = _snapshot_if_improved(network, metric_dict,
                                                         best_metric, best_state_dict)

In [None]:
# Restore the weights with the lowest validation criterion seen during
# meta-training. best_state_dict stays None if no value ever compared below the
# initial np.inf (e.g. the criterion was NaN throughout); keep the current
# weights in that case instead of crashing.
if best_state_dict is None:
    print('No validation improvement was recorded; keeping the current weights.')
else:
    network.load_state_dict(best_state_dict)


### Visualization

In [None]:
# Embedding of the held-out points after restoring the best weights:
# scatter of the last forward output, coloured by `color_test`.
test_inputs = inpt_test_torch.to(training_parameters['device'])
X_s = network.forward(test_inputs)
outpt_val_pred = to_numpy(X_s[-1])

fig, ax = plt.subplots()
sc = ax.scatter(outpt_val_pred[0], outpt_val_pred[1], c=color_test)
fig.colorbar(sc, ax=ax)
plt.show()

In [None]:
# Validation-criterion history over the whole meta-training run.
# Set SHOW_PHASE_SWITCHES = True to mark where each backprop (purple) and
# Hebbian (green) phase starts.
SHOW_PHASE_SWITCHES = False

fig, ax2 = plt.subplots(figsize=(10, 5), dpi=300)
criterion_val_history = metric_dict['criterion_val']
ax2.plot(criterion_val_history)
ax2.set_title('Criterion val')
ax2.set_xlabel('validation record')
ax2.set_ylabel('criterion')

if SHOW_PHASE_SWITCHES and len(criterion_val_history) > 0:
    cval_min = min(criterion_val_history)
    cval_max = max(criterion_val_history)
    ax2.vlines(meta_switch_times['val'], cval_min, cval_max,
               color='purple', alpha=0.5, label='backprop phase start')
    ax2.vlines(ordinary_switch_times['val'], cval_min, cval_max,
               color='green', alpha=0.5, label='Hebbian phase start')
    ax2.legend()

plt.tight_layout()
plt.show()

In [None]:
# criterion_meta = defaultdict(list)
# criterion_ordinary = defaultdict(list)

# for phase in ['val']:
#     for t in meta_switch_times[phase]:
#         criterion_meta[phase] += metric_dict[f'criterion_{phase}'][t:t+EPOCHS_META] 
        
#     for t in ordinary_switch_times[phase]:
#         criterion_ordinary[phase] += metric_dict[f'criterion_{phase}'][t:t+EPOCHS_HEBB] 

In [None]:
# Gram matrix (row-by-row inner products) of the last layer's weight matrix.
last_W = network.W_s[-1]
last_W @ last_W.T

In [None]:
# Plot the learned per-neuron nonlinearity of every layer, evaluated on the
# shared grid ξ ∈ [-4, 4]; one figure per layer, one panel per neuron.
n_fs = len(network.f_s)
ξ = torch.linspace(-4, 4, 1000).to(device)
ξ_np = to_numpy(ξ)  # loop-invariant x-values for every panel

for layer in range(n_fs):
    f_theta = network.f_s[layer]
    n_neurons = f_theta.input_dim
    y = ξ.repeat(n_neurons, 1)  # same grid for every neuron: [n_neurons, T]

    f = to_numpy(f_theta(y))  # [n_neurons, T]

    # squeeze=False keeps `axes` 2-D even when n_neurons == 1; otherwise
    # plt.subplots returns a bare Axes that cannot be iterated over.
    fig, axes = plt.subplots(ncols=n_neurons, nrows=1,
                             figsize=(n_neurons * 3, 3), squeeze=False)

    for j, ax in enumerate(axes[0]):
        ax.plot(ξ_np, f[j])
        ax.set_title(f'Neuron: {j}')

    fig.suptitle(f'Layer: {layer}', y=1.1, color='blue')
    # Lay out every figure, not just the last one created.
    fig.tight_layout()

plt.show()