## Exact solutions, decoupling


In [None]:
#!/usr/bin/env python
# coding: utf-8

%matplotlib inline
import matplotlib
import matplotlib.pyplot as plt

from mpl_toolkits.mplot3d import axes3d

import argparse
import os
import datetime
import pathlib
import random
import json
import numpy as np
import math

import torch
from scipy.stats import ortho_group

import sys
sys.path.append('../code/')
from linear_utils import linear_model, get_modulation_matrix
from train_utils import save_config

In [None]:
# Experiment configuration, written as the command line that argparse will parse below.
# Each entry is one flag together with its value(s); they are joined with single spaces.
cli_args = ' '.join([
    '--seed 12',
    '--save-results',
    '--risk-loss L2',
    '-t 100000',
    '-w 0.1 0.1',
    '--lr 0.001 0.001',
    '-d 2',
    '-n 50',
    '--hidden 50',
    '--sigmas 1',
    '--kappa 10.0',
])
sigma_noise = 0.0      # passed as sigma_noise to linear_model
transform_data = True  # passed as transform_data to linear_model
cont_eigs = False      # passed as cont_eigs to linear_model

# Alternative configuration (kept for reference):
#cli_args = '--seed 12 --save-results --jacobian --risk-loss L2 -t 20000 -w 0.1 0.1 --lr 0.00001 -d 50 -n 1000 --hidden 50 --sigmas 1 --kappa 3'
#sigma_noise = 1.0



In [None]:

# get CLI parameters
parser = argparse.ArgumentParser(description='CLI parameters for training')
parser.add_argument('--root', type=str, default='', metavar='DIR',
                    help='Root directory (default: parent of the current working directory)')
# int default so that the value matches type=int even when the flag is omitted
parser.add_argument('-t', '--iterations', type=int, default=10000, metavar='ITERATIONS',
                    help='Iterations (default: 1e4)')
parser.add_argument('-n', '--samples', type=int, default=100, metavar='N',
                    help='Number of samples (default: 100)')
parser.add_argument('--print-freq', type=int, default=1000,
                    help='CLI output printing frequency (default: 1000)')
parser.add_argument('--gpu', type=int, default=None,
                    help='Number of GPUS to use')
parser.add_argument('--seed', type=int, default=None,
                    help='Random seed')
parser.add_argument('-d', '--dim', type=int, default=50, metavar='DIMENSION',
                    help='Feature dimension (default: 50)')
parser.add_argument('--hidden', type=int, default=200, metavar='DIMENSION',
                    help='Hidden layer dimension (default: 200)')
parser.add_argument('--sigmas', type=str, default=None,
                    help='Sigmas')
parser.add_argument('-r', '--s-range', nargs='*', type=float,
                    help='Range for sigmas')
parser.add_argument('--kappa', type=float,
                    help='Eigenvalue ratio')
parser.add_argument('-w', '--scales', nargs='*', type=float,
                    help='scale of the weights')
# One learning rate per layer. The default is a list because the training code
# indexes (stepsize[i]) and copies (stepsize.copy()) args.lr.
parser.add_argument('--lr', type=float, default=[1e-4, 1e-4], nargs='*', metavar='LR',
                    help='learning rate per layer (default: 1e-4)')
parser.add_argument('--normalized', action='store_true', default=False,
                    help='normalize sample norm across features')
parser.add_argument('--risk-loss', type=str, default='MSE', metavar='LOSS',
                    help='Loss for validation')
parser.add_argument('--jacobian', action='store_true', default=False,
                    help='compute the SVD of the jacobian of the network')
parser.add_argument('--save-results', action='store_true', default=False,
                    help='Save the results for plots')
parser.add_argument('--details', type=str, metavar='N',
                    default='no_detail_given',
                    help='details about the experimental setup')


args = parser.parse_args(cli_args.split())

# A single --lr value is shared by both layers of the two-layer network.
if len(args.lr) == 1:
    args.lr = args.lr * 2

# directories
root = pathlib.Path(args.root) if args.root else pathlib.Path.cwd().parent

current_date = datetime.datetime.today().strftime('%Y-%m-%d-%H-%M-%S')
# Results go under the (possibly user-supplied) root rather than always cwd().parent.
args.outpath = root / 'results' / 'two_layer_nn' / current_date

if args.save_results:
    args.outpath.mkdir(exist_ok=True, parents=True)

if args.seed is not None:
    random.seed(args.seed)
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)

device = torch.device('cpu')
# device = torch.device('cuda') # Uncomment this to run on GPU

In [None]:
# Number of dimensions in excess of the sample count (0 when n >= d); presumably the count
# of zero eigenvalues of the empirical covariance, which has rank <= n — confirm in linear_model.
zero_eigs = max(0, args.dim - args.samples)
# ceil((dim - zero_eigs) / 2) computed with integer arithmetic; forwarded to linear_model as `p`
p = (args.dim - zero_eigs + 1) // 2
#k = args.kappa
#F = get_modulation_matrix(args.dim, p, k)


d_out = 2  # number of output columns of y

beta = np.ones((args.dim, d_out))

# Set to False if you want to transform after data sampling
scale_beta = False

lin_model = linear_model(
    args.dim,
    dy=d_out,
    sigma_noise=sigma_noise,
    beta=beta,
    scale_beta=scale_beta,
    normalized=False,
    sigmas=args.sigmas,
    s_range=args.s_range,
    coupled_noise=False,
    transform_data=transform_data,
    kappa=args.kappa,
    p=p,
    cont_eigs=cont_eigs,
    zero_eigs=zero_eigs,
)

# training set drawn from the linear model
Xs, ys = lin_model.sample(args.samples, train=True)

# held-out set, 100x the training size, used for the empirical risk
Xt, yt = lin_model.sample(args.samples * 100, train=False)
beta = lin_model.beta

In [None]:
# Replace the sampled teacher with a hand-picked decoupled one, stored as (dim, d_out):
# output 0 depends only on feature 0 (weight 0.01001), output 1 only on feature 1 (weight 1).
beta = np.zeros((args.dim, d_out))
beta[0, 0] = 0.01001
beta[1, 1] = 1

# regenerate noiseless targets for both the training and the held-out inputs
ys = Xs @ beta
yt = Xt @ beta

In [None]:
# Convert inputs to float32 tensors on `device`.
Xs, Xt = (torch.tensor(X_arr, dtype=torch.float32).to(device) for X_arr in (Xs, Xt))

# Convert targets the same way, forcing shape (n, d_out).
ys, yt = (torch.tensor(y_arr.reshape((-1, d_out)), dtype=torch.float32).to(device)
          for y_arr in (ys, yt))

In [None]:
np.linalg.svd((ys.T @ Xs).numpy())  # SVD of the input-output cross-covariance Y^T X

In [None]:
# Express the input covariance X^T X in the basis of the right singular vectors of Y^T X:
# Vh (X^T X) Vh^T. (The original referenced the not-yet-defined name `Vhxy` and dropped the transpose.)
_, _, Vhyx = np.linalg.svd(ys.T @ Xs)
Vhyx = torch.tensor(Vhyx)
np.linalg.svd(Vhyx @ Xs.T @ Xs @ Vhyx.T)

In [None]:
# define loss functions
loss_fn = torch.nn.MSELoss(reduction='sum')
risk_fn = torch.nn.L1Loss(reduction='sum') if args.risk_loss == 'L1' else torch.nn.MSELoss(reduction='sum')

## Empirical, diagonal

In [None]:
def diagonal_init(model, X, y, g_cpu, args, synaptic=False):
    """Initialise a two-layer linear network with decoupled (diagonal) weights.

    The first ``torch.nn.Linear`` layer is set to ``R @ D1`` and the second to
    ``D2 @ R.T``. Here ``R`` is a random orthogonal (hidden x hidden) matrix from
    ``scipy.stats.ortho_group``, which uses NumPy's global RNG. ``D1`` and ``D2``
    are zero matrices of the layer shapes whose main-diagonal entries are drawn
    from N(0, args.scales[k]^2) with ``g_cpu``. Since ``R.T @ R = I``, the
    product ``W2 @ W1`` equals ``D2 @ D1`` up to rounding.

    If ``synaptic`` is False, the weights are also rotated into the singular bases
    of the least-squares map ``y^T X (X^T X)^{-1} = U S Vh``. That gives
    ``W1 = R D1 Vh`` and ``W2 = U D2 R^T``, so ``U^T (W2 W1) Vh^T`` is diagonal.

    Args:
        model: iterable of modules (e.g. ``torch.nn.Sequential``). Its Linear
            layers are overwritten in order; only the first two are initialised.
        X: 2-D input tensor (n_samples, d_in); ``X^T X`` must be invertible.
        y: 2-D target tensor (n_samples, d_out).
        g_cpu: ``torch.Generator`` used for the diagonal entries.
        args: namespace with ``hidden`` (width of R) and ``scales`` (stds of
            layers 1 and 2).
        synaptic: if True, skip the rotation into the singular bases.

    Returns:
        The same ``model``, modified in place.

    Raises:
        ValueError: if ``model`` contains no ``torch.nn.Linear`` layer.
    """
    # Singular bases of the least-squares map, computed with NumPy on CPU. They are cast to
    # float32 so they can be multiplied with the float32 R and D below even for float64 data.
    yx_ls = (y.T @ X @ torch.inverse(X.T @ X)).detach().cpu().numpy()
    Uxy, _, Vhxy = np.linalg.svd(yx_ls)
    Uxy = torch.as_tensor(Uxy, dtype=torch.float32)
    Vhxy = torch.as_tensor(Vhxy, dtype=torch.float32)

    if args.hidden > 1:
        R = torch.tensor(ortho_group.rvs(dim=args.hidden), dtype=torch.float32)
    else:
        # ortho_group needs dim >= 2; use the trivial 1x1 orthogonal matrix
        R = torch.tensor([[1]], dtype=torch.float32)

    layer = 0
    w_check = None  # running product W2 @ W1, kept on CPU for the debug print
    with torch.no_grad():
        for m in model:
            if not isinstance(m, torch.nn.Linear):
                continue
            if layer > 1:
                print("Initialisation only supported for two layers.")
                layer += 1
                continue

            # Zero matrix of the weight's shape with N(0, scale^2) entries on its main diagonal
            D = torch.zeros(m.weight.shape, dtype=torch.float32)
            torch.diagonal(D).copy_(
                torch.normal(mean=0, std=args.scales[layer],
                             size=(min(m.weight.shape),), generator=g_cpu))

            if layer == 0:
                W = R @ D if synaptic else R @ D @ Vhxy
            else:
                W = D @ R.T if synaptic else Uxy @ (D @ R.T)

            m.weight.data = W.to(dtype=m.weight.dtype, device=m.weight.device)
            w_check = W.clone() if layer == 0 else W @ w_check
            layer += 1

        if w_check is None:
            raise ValueError("diagonal_init: model contains no torch.nn.Linear layer")

        print("Initial weight: \n {}".format(w_check))
        if synaptic:
            # Weights already live in the synaptic basis; show them in the original space, U W Vh.
            transformed = Uxy @ w_check @ Vhxy
        else:
            # Weights live in the original space; show them in the synaptic basis, U^T W Vh^T.
            transformed = Uxy.T @ w_check @ Vhxy.T
        print("Transformed initial weight: \n {}".format(transformed.numpy()))

    return model

In [None]:
model = torch.nn.Sequential(
    torch.nn.Linear(args.dim, args.hidden, bias=False),
    #torch.nn.ReLU(),
    torch.nn.Linear(args.hidden, d_out, bias=False),
).to(device)

# initialization
synaptic = True        # train W2 @ W1 in the singular ("synaptic") basis of the least-squares map
manual_grads = False   # hand-derived gradients instead of autograd (only consulted if not synaptic)
decaying_lr = False    # enable the 1 / (1 + gamma * t) learning-rate decay during training

XX = Xs.T @ Xs
yX = ys.T @ Xs
Uxy, S_vals, Vhxy = np.linalg.svd(yX @ torch.inverse(XX))
Uxy, S_vals, Vhxy = torch.tensor(Uxy), torch.tensor(S_vals), torch.tensor(Vhxy)
# Rectangular (d_out x dim) singular-value matrix, so that Uxy @ Sxy @ Vhxy reproduces the map
# and Sxy - W2 @ W1 is well-defined for any dim (a square diag only works when dim == d_out).
Sxy = torch.zeros((Uxy.shape[0], Vhxy.shape[0]), dtype=S_vals.dtype)
torch.diagonal(Sxy).copy_(S_vals)
# Input covariance in the right-singular basis, Vh (X^T X) Vh^T. With W_true = Uxy W Vhxy, the
# squared-loss gradient w.r.t. W is -2 (Sxy - W) Vh X^T X Vh^T, so this is the matrix the
# synaptic update multiplies by.
VXXV = Vhxy @ XX @ Vhxy.T

g_cpu = torch.Generator()
if args.seed is not None:
    g_cpu.manual_seed(args.seed)
model = diagonal_init(model, Xs, ys, g_cpu, args, synaptic=synaptic)


# per-layer learning rates from --lr (index 0: first layer, index 1: second layer);
# uncomment to force the same rate for both layers:
#if isinstance(args.lr, list):
#    stepsize = [max(args.lr)] * 2
stepsize = args.lr

In [None]:
# End-to-end linear map of the network at initialisation: I @ W1^T @ W2^T (shape dim x d_out).
Wtot = torch.eye(args.dim)
for param in model.parameters():
    if param.ndim <= 1:
        continue  # only weight matrices contribute
    Wtot = Wtot @ param.data.t()

print(Wtot)

In [None]:
# train the network
losses_emp = []        # training loss per iteration
risks_emp = []         # loss on the held-out set (Xt, yt) per iteration
mse_weights_emp = []   # element-wise squared error between end-to-end weights and beta
ws = []                # end-to-end weights (same shape as beta) after every update

base_stepsize = stepsize.copy()
print(stepsize[0])
gamma = 0.0001  # decay rate of the optional 1 / (1 + gamma * (t + 1)) schedule
for t in range(int(args.iterations)):

    if synaptic:
        # The weights live in the synaptic basis; map W2 @ W1 back with Uxy (.) Vhxy.
        y_pred = Xs @ torch.matmul(Uxy, torch.matmul(torch.matmul(model[1].weight, model[0].weight), Vhxy)).T
    else:
        y_pred = model(Xs)

    loss = loss_fn(y_pred, ys)
    losses_emp.append(loss.item())

    if not t % args.print_freq:
        print(t, loss.item())

    model.zero_grad()
    loss.backward()

    with torch.no_grad():
        layer_grads = None
        if synaptic or manual_grads:
            W1, W2 = model[0].weight, model[1].weight
            # Evaluate the residual ONCE from the step-t weights, so both layers are updated
            # simultaneously, as autograd does in the other branch. Recomputing it after updating
            # layer 1 would turn the update into a sequential (Gauss-Seidel) sweep.
            if synaptic:
                grad_diff = torch.matmul(Sxy - torch.matmul(W2, W1), VXXV)
            else:
                grad_diff = yX - torch.matmul(torch.matmul(W2, W1), XX)
            layer_grads = [-torch.matmul(W2.T, grad_diff), -torch.matmul(grad_diff, W1.T)]

        i = 0
        w_tot = torch.diag(torch.ones(args.dim))
        for param in model.parameters():
            if layer_grads is None:
                param.data -= stepsize[i] * param.grad
            elif i < len(layer_grads):
                param.data -= stepsize[i] * layer_grads[i]
            else:
                print("Training in synaptic weight space and manual gradients only supported for two layers.")

            w_tot = w_tot @ param.data.t()
            if len(param.shape) > 1:
                i += 1

        w_tot = w_tot.squeeze()
        ws.append(w_tot)
        assert w_tot.shape == beta.squeeze().shape
        mse_weights_emp.append((w_tot - beta.squeeze())**2)

        if synaptic:
            yt_pred = Xt @ torch.matmul(Uxy, torch.matmul(torch.matmul(model[1].weight, model[0].weight), Vhxy)).T
        else:
            yt_pred = model(Xt)

        risk = risk_fn(yt_pred, yt)
        risks_emp.append(risk.item())

        if not t % args.print_freq:
            print(t, risk.item())

    if decaying_lr:
        stepsize = [base_lr / (1 + gamma * (t+1)) for base_lr in base_stepsize]

In [None]:
# 700 log-spaced iteration indices (truncated to int, so small t may repeat) for log-x plots
geo_samples = list(map(int, np.geomspace(1, len(risks_emp) - 1, num=700)))

In [None]:
# Final squared weight error (as stored, then flattened) followed by the final end-to-end weights.
print(mse_weights_emp[-1], mse_weights_emp[-1].reshape(-1), w_tot, sep='\n')

In [None]:
risks = np.array(risks_emp)
losses = np.array(losses_emp)
# One row per iteration, one column per entry of the flattened weight matrix.
# np.vstack replaces np.row_stack, which is deprecated in NumPy 2.0.
risks_w = np.vstack([mse_w.reshape(-1) for mse_w in mse_weights_emp])

#Uyx, _, Vhyx = np.linalg.svd(ys.T @ Xs)
#Uyx, Vhyx = torch.tensor(Uyx, dtype=torch.float32), torch.tensor(Vhyx, dtype=torch.float32)
#ws_transf = [(Uyx @ w.T @ Vhyx).T for w in ws]
#risks_w_transf = np.row_stack([((w-beta)**2).reshape(-1) for w in ws_transf])

# pyplot.get_cmap works across Matplotlib versions (matplotlib.cm.get_cmap was removed in 3.9)
cmap = plt.get_cmap('viridis')
colorList = [cmap(50/1000), cmap(350/1000), cmap(700/1000)]
labelList = ['empirical', 'theoretical']

plot_all_dims = True

# panels: risk, loss, [one per weight entry,] mean weight MSE
num_axs = 3 + risks_w.shape[-1] if plot_all_dims else 3

fig, ax = plt.subplots(num_axs, 1, figsize=(12, 4 * num_axs))


def _plot_curve(axis, values, color, ylabel):
    """Plot ``values`` at the iterations in ``geo_samples`` on a log-scaled x axis."""
    axis.set_xscale('log')
    axis.plot(geo_samples, values[geo_samples], color=color, label=labelList[0], lw=4)
    axis.set_ylabel(ylabel)
    axis.set_xlabel(r'$t$ iterations')


_plot_curve(ax[0], risks, colorList[1], 'risk')
ax[0].legend(loc=1, bbox_to_anchor=(1, 1), fontsize='x-large',
             frameon=False, fancybox=True, shadow=True, ncol=1)

_plot_curve(ax[1], losses, colorList[1], 'loss')

if plot_all_dims:
    for i in range(risks_w.shape[-1]):
        _plot_curve(ax[2 + i], risks_w[:, i], colorList[2], 'MSE weights, ' + str(i))

_plot_curve(ax[-1], risks_w.mean(axis=-1), colorList[2], 'MSE weights')

plt.show()

## Visualisation 

In [None]:
# SYNAPTIC WEIGHT SPACE
# Trajectories of the end-to-end weights in the synaptic basis, flattened row-major from the
# (d_out x dim) matrix W, i.e. "Weight k" is flat index k-1. The first four entries are plotted.

if synaptic:
    track_w = np.vstack([w.T.reshape(-1) for w in ws])  # (Uxy @ w.T @ Vhxy) would give the original space
else:
    track_w = np.vstack([(Uxy.T @ w.T.reshape(d_out, -1) @ Vhxy.T).reshape(-1) for w in ws])

fig, ax = plt.subplots(2, 3, figsize=(18, 12))

# One panel per pair of the first four weights; '*' marks the starting point.
weight_pairs = [(0, 1), (0, 2), (0, 3), (1, 2), (1, 3), (2, 3)]
for axis, (wi, wj) in zip(ax.reshape(-1), weight_pairs):
    axis.plot(track_w[:, wi], track_w[:, wj])
    axis.plot(track_w[0, wi], track_w[0, wj], '*')
    axis.set_xlabel(f"Weight {wi + 1}")
    axis.set_ylabel(f"Weight {wj + 1}")

# Curvature directions in synaptic space. With W_true = Uxy W Vhxy, each row of W sees the loss
# Hessian 2 * Vhxy (X^T X) Vhxy^T, so draw its two leading singular (= eigen) vectors.
# (The previous Uxy.T @ XX @ Vhxy.T mixed the output basis into an input covariance.)
_, S, Vh = np.linalg.svd((Vhxy @ XX @ Vhxy.T).numpy())
print(S)
for axis in ax.reshape(-1):
    for vec in Vh[:2]:
        axis.arrow(0, 0, vec[0], vec[1])
        axis.annotate("", xy=(vec[0], vec[1]), xytext=(0, 0),
                      arrowprops=dict(arrowstyle="->"))
    

In [None]:
# ORIGINAL WEIGHT SPACE
# Trajectories of the end-to-end weights in the original basis, flattened row-major from the
# (d_out x dim) matrix, i.e. "Weight k" is flat index k-1. The first four entries are plotted.

if synaptic:
    track_w = np.vstack([(Uxy @ w.T.reshape(d_out, -1) @ Vhxy).reshape(-1) for w in ws])
else:
    track_w = np.vstack([w.T.reshape(-1) for w in ws])

fig, ax = plt.subplots(2, 3, figsize=(18, 12))

# One panel per pair of the first four weights; '*' marks the starting point.
weight_pairs = [(0, 1), (0, 2), (0, 3), (1, 2), (1, 3), (2, 3)]
for axis, (wi, wj) in zip(ax.reshape(-1), weight_pairs):
    axis.plot(track_w[:, wi], track_w[:, wj])
    axis.plot(track_w[0, wi], track_w[0, wj], '*')
    axis.set_xlabel(f"Weight {wi + 1}")
    axis.set_ylabel(f"Weight {wj + 1}")

# Arrows: the first two rows of Vhxy (right singular vectors of the least-squares map).
for axis in ax.reshape(-1):
    for vec in Vhxy[:2]:
        axis.arrow(0, 0, vec[0], vec[1])
        axis.annotate("", xy=(vec[0], vec[1]), xytext=(0, 0),
                      arrowprops=dict(arrowstyle="->"))
    

In [None]:
def _report_final_weights():
    """Print the learned weights next to the least-squares optimum, in both weight spaces."""
    w_final = ws[-1]
    print("Final weights:")
    print(w_final.T)

    if not synaptic:
        # closed-form least-squares solution (d_out x dim)
        w_star = ys.T @ Xs @ torch.inverse(Xs.T @ Xs)
        print(f"Global minimum: \n {w_star}")
        print(f"Transformed final weights (synaptic weight space): \n {Uxy.T @ w_final.T @ Vhxy.T}")
        print(f"Transformed global minimum (synaptic weight space): \n {Sxy}")
        print(f"Final loss: \n {loss_fn(Xs @ w_final, ys)}")
    else:
        print(f"Global minimum: \n {Sxy}")
        w_true = Uxy @ w_final.T @ Vhxy  # synaptic weights mapped back to the original space
        print(f"Transformed final weights (true space): \n {w_true}")
        w_star = ys.T @ Xs @ torch.inverse(Xs.T @ Xs)
        print(f"Transformed global minimum (true space): \n {w_star}")
        print(f"Final loss: {loss_fn(Xs @ w_true.T, ys)}")

    print(f"Loss of global minimum: \n {loss_fn(Xs @ w_star.T, ys)}")


_report_final_weights()


## 