## High-dimensional dynamics of generalization error in neural networks

Attempt to reproduce Figure 5B of the paper and, with it, the double-descent behavior of the generalization error.


In [None]:
#!/usr/bin/env python
# coding: utf-8

%matplotlib inline
import matplotlib
import matplotlib.pyplot as plt

from mpl_toolkits.mplot3d import axes3d

import argparse
import os
import datetime
import pathlib
import random
import json
import numpy as np
import math

import torch


import sys
sys.path.append('../code/')
from linear_utils import linear_model, get_modulation_matrix
from train_utils import save_config

In [None]:
# argument written in command line format
# Counterexample from the ES paper (?)
#cli_args = '--seed 12 --save-results --risk-loss L2 -t 200000 -w 0.1 0.001 --lr 0.00001 0.00001 -d 50 -n 100 --hidden 250 --sigmas 1 --kappa 10.0'
cli_args = '--seed 12 --save-results --risk-loss L2 -t 100000 -w 0.01 0.01 --lr 0.001 0.001 -d 2 -n 10 --hidden 50 --sigmas 1 --kappa 10.0'
# Noise level passed to linear_model as `sigma_noise`.
sigma_noise = 1.0
# Passed to linear_model; also selects how the singular basis is built in the
# theoretical section (V = I when True).
transform_data = True
# Passed to linear_model as `cont_eigs` (meaning defined in linear_utils).
cont_eigs = False

#cli_args = '--seed 12 --save-results --jacobian --risk-loss L2 -t 20000 -w 0.1 0.1 --lr 0.00001 -d 50 -n 1000 --hidden 50 --sigmas 1 --kappa 3'
#sigma_noise = 1.0



In [None]:

# get CLI parameters (parsed from the `cli_args` string rather than sys.argv)
parser = argparse.ArgumentParser(description='CLI parameters for training')
parser.add_argument('--root', type=str, default='', metavar='DIR',
                    help='Root directory')
# NB: argparse does not apply `type` to non-string defaults, so the default
# must already be an int.
parser.add_argument('-t', '--iterations', type=int, default=10000, metavar='ITERATIONS',
                    help='Iterations (default: 1e4)')
parser.add_argument('-n', '--samples', type=int, default=100, metavar='N',
                    help='Number of samples (default: 100)')
parser.add_argument('--print-freq', type=int, default=1000,
                    help='CLI output printing frequency (default: 1000)')
parser.add_argument('--gpu', type=int, default=None,
                    help='Number of GPUS to use')
parser.add_argument('--seed', type=int, default=None,
                    help='Random seed')
parser.add_argument('-d', '--dim', type=int, default=50, metavar='DIMENSION',
                    help='Feature dimension (default: 50)')
parser.add_argument('--hidden', type=int, default=200, metavar='DIMENSION',
                    help='Hidden layer dimension (default: 200)')
parser.add_argument('--sigmas', type=str, default=None,
                    help='Sigmas')
parser.add_argument('-r','--s-range', nargs='*', type=float,
                    help='Range for sigmas')
parser.add_argument('--kappa', type=float,
                    help='Eigenvalue ratio')
parser.add_argument('-w','--scales', nargs='*', type=float,
                    help='scale of the weights')
# The training code indexes one learning rate per weight matrix
# (stepsize[i], args.lr[0], args.lr[1]), so the default is a two-element list.
parser.add_argument('--lr', type=float, default=[1e-4, 1e-4], nargs='*', metavar='LR',
                    help='learning rate per layer (default: 1e-4)')
parser.add_argument('--normalized', action='store_true', default=False,
                    help='normalize sample norm across features')
parser.add_argument('--risk-loss', type=str, default='MSE', metavar='LOSS',
                    help='Loss for validation')
parser.add_argument('--jacobian', action='store_true', default=False,
                    help='compute the SVD of the jacobian of the network')
parser.add_argument('--save-results', action='store_true', default=False,
                    help='Save the results for plots')
parser.add_argument('--details', type=str, metavar='N',
                    default='no_detail_given',
                    help='details about the experimental setup')


args = parser.parse_args(cli_args.split())

# A single learning rate applies to both layers.
if len(args.lr) == 1:
    args.lr = args.lr * 2

# directories: results go under <root>/results/two_layer_nn/<timestamp>,
# where <root> is --root or, by default, the parent of the working directory.
root = pathlib.Path(args.root) if args.root else pathlib.Path.cwd().parent

current_date = datetime.datetime.today().strftime('%Y-%m-%d-%H-%M-%S')
args.outpath = root / 'results' / 'two_layer_nn' / current_date

if args.save_results:
    args.outpath.mkdir(exist_ok=True, parents=True)

# Seed every RNG in use when a seed is given.
if args.seed is not None:
    random.seed(args.seed)
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)

device = torch.device('cpu')
# device = torch.device('cuda') # Uncomment this to run on GPU

In [None]:
def get_jacobian_two_layer(X, y, model, crit):
    """Per-sample gradients of the network output w.r.t. its weight matrices.

    For every sample the gradient of the sum of the network's outputs is taken
    w.r.t. each parameter with more than one dimension (bias vectors are
    skipped). These gradients are flattened and concatenated in
    ``model.parameters()`` order.

    Args:
        X: inputs, iterated one sample (row) at a time.
        y: targets, iterated alongside ``X``; only ``len`` of each row is used,
            to size the vector passed to ``backward``.
        model: the network. Its ``.grad`` fields are overwritten.
        crit: unused; kept for interface compatibility.

    Returns:
        A NumPy array with one row per sample.
    """
    rows = []
    for sample_x, sample_y in zip(X, y):
        model.zero_grad()
        output = model(sample_x)
        output.backward(torch.ones(len(sample_y)))
        weight_grads = [
            prm.grad.data.numpy().flatten()
            for prm in model.parameters()
            if prm.grad is not None and len(prm.data.shape) > 1
        ]
        rows.append(np.concatenate(weight_grads))
    return np.array(rows)


def compute_jacobian(Xs, ys, model, loss_fn, args):
    """SVD of the per-sample Jacobian, split by layer.

    Returns ``(sv, v1, v2, vTrec)``:

    * ``sv``: singular values of the Jacobian.
    * ``v1``: for each singular value, the norm of the matching right-singular
      vector restricted to its first ``hidden * dim`` entries (first-layer
      weights ``W``).
    * ``v2``: the norm restricted to its last ``hidden`` entries (second-layer
      weights ``v``).
    * ``vTrec``: ``sqrt(v1**2 + v2**2)``.
    """
    jac = get_jacobian_two_layer(Xs, ys, model, loss_fn)
    _, sv, right_vecs_t = np.linalg.svd(jac)

    n_first = np.prod([args.hidden, args.dim])
    n_second = np.prod([1, args.hidden])
    # Only the first len(sv) right-singular vectors pair with a singular value.
    paired = right_vecs_t[:sv.shape[0]]

    v1 = np.array([np.linalg.norm(row[:n_first]) for row in paired])
    v2 = np.array([np.linalg.norm(row[-n_second:]) for row in paired])
    vTrec = np.linalg.norm(np.stack((v1, v2)), axis=0)

    return sv, v1, v2, vTrec


def plot_jacobian(sv, v1, v2, vTrec):
    """Scatter each Jacobian singular value against its right-singular vector's layer norms.

    Both axes are logarithmic.

    Args:
        sv: singular values (x axis).
        v1: norm of the first-layer (``W``) part of each right-singular vector.
        v2: norm of the second-layer (``v``) part.
        vTrec: norm over both parts together.
    """
    # matplotlib.cm.get_cmap was removed in Matplotlib 3.9; pyplot.get_cmap
    # works on old and new versions.
    cmap = plt.get_cmap('viridis')
    series = (
        (v1, r'$W$', cmap(50 / 1000), 4),
        (v2, r'$v$', cmap(350 / 1000), 4),
        (vTrec, r'$W + v$', cmap(650 / 1000), 2),
    )

    fig, ax = plt.subplots(figsize=(12, 8))
    for values, label, color, lw in series:
        ax.scatter(sv, values, color=color, label=label, lw=lw)

    ax.legend(loc=0, bbox_to_anchor=(1, 0.5), fontsize='x-large',
              frameon=True, fancybox=True, shadow=True, ncol=1)
    # The plotted values come from np.linalg.norm: plain 2-norms, not squared.
    ax.set_ylabel(r'$\Vert v \Vert_2$')
    ax.set_xlabel(r'$\sigma_i$')
    ax.set_xscale('log')
    ax.set_yscale('log')
    plt.show()
    

In [None]:
# Number of zero singular directions of the training design matrix: with
# fewer samples than features (n < d), the n x d matrix Xs has rank at most n.
zero_eigs = 0 if args.samples >= args.dim else (args.dim - args.samples)
# Half of the remaining directions (rounded up). Passed to linear_model below
# and reused later as the column split of the first-layer weights (later plots
# label the first p coordinates "large" and the rest "small").
p = int(np.ceil((args.dim - zero_eigs) / 2))
#k = args.kappa
#F = get_modulation_matrix(args.dim, p, k)

# True coefficient vector handed to linear_model (all ones).
beta = np.ones((args.dim,))# * 0.01
#beta[1] *= -1.0
print(beta)

# Set to False if you want to transform after data sampling
scale_beta = False

d_out = 1      # dimension of y

# sample training set from the linear model
# NOTE(review): the meaning of kappa / p / zero_eigs / transform_data is defined
# in linear_utils.linear_model (not visible here) — confirm there.
lin_model = linear_model(args.dim, sigma_noise=sigma_noise, beta=beta, scale_beta=scale_beta, normalized=False, sigmas=args.sigmas, s_range=args.s_range, coupled_noise=False, transform_data=transform_data, kappa=args.kappa, p=p, cont_eigs=cont_eigs, zero_eigs=zero_eigs)
Xs, ys = lin_model.sample(args.samples, train=True)

# sample the set for empirical risk calculation (100x the training size)
Xt, yt = lin_model.sample(args.samples * 100, train=False) # 1000 for -n 10
# From here on, use the model's own beta; presumably this reflects any
# scaling/transformation applied inside linear_model — confirm.
beta = lin_model.beta.squeeze()

In [None]:
# Standardize features with the *training* statistics, applied to both sets so
# they share coordinates. A constant training feature has zero std; leave it
# unscaled instead of dividing by zero (which produced inf/NaN features).
# NOTE(review): y is not rescaled. If y was generated as X @ beta + noise
# (assumed from linear_model), the coefficients in these coordinates are
# beta * std, plus an intercept from the mean shift. The later weight-error
# comparisons against `beta` assume the std is ~1 — confirm.
train_mean = Xs.mean(axis=0)
train_std = Xs.std(axis=0)
train_std[train_std == 0] = 1.0
Xt = (Xt - train_mean) / train_std
Xs = (Xs - train_mean) / train_std

U, S, Vh = np.linalg.svd(Xs)
print(S)

In [None]:
# Transform data (masking)


# Uncomment if you want to transform after data sampling
#Xs = Xs @ F
#Xt = Xt @ F

#_, Ss, Vh = np.linalg.svd(Xs)
#_, St, _ = np.linalg.svd(np.transpose(Xt) @ Xt)

#print("train")
#print(Ss)
#print("test")
#print(St)

# Uncomment if you transform_data=False, but you want to decouple features
#Xs = Xs @ np.transpose(Vh)
#Xt = Xt @ np.transpose(Vh)

#beta = np.linalg.inv(F) @ beta


In [None]:
# Singular values of the (unnormalized) input-output cross-covariance ys^T Xs.
# NOTE(review): np.linalg.svd needs a 2-D input, so this assumes ys is 2-D here
# (e.g. (n, 1)) as returned by lin_model.sample — confirm.
_, Sxy, _ = np.linalg.svd(ys.T@Xs)
print(Sxy)

In [None]:
# Move train/test data to float32 torch tensors on `device`; targets become
# column vectors.
Xs, ys, Xt, yt = (
    torch.tensor(arr, dtype=torch.float32).to(device)
    for arr in (Xs, ys.reshape((-1, 1)), Xt, yt.reshape((-1, 1)))
)

In [None]:
# define loss functions: summed squared error for training. The risk on the
# held-out set uses summed absolute error for --risk-loss L1 and summed
# squared error for any other value (e.g. 'L2' or the default 'MSE').
loss_fn = torch.nn.MSELoss(reduction='sum')
risk_fn = (torch.nn.L1Loss if args.risk_loss == 'L1' else torch.nn.MSELoss)(reduction='sum')

### Empirical

In [None]:
# Two-layer network without biases. With the ReLU commented out it is the
# linear map x -> W2 W1 x.
model = torch.nn.Sequential(
           torch.nn.Linear(args.dim, args.hidden, bias=False),
           #torch.nn.ReLU(),
           torch.nn.Linear(args.hidden, d_out, bias=False),
         ).to(device)


# fixed_init=True fills layer k with the constant args.scales[k]; False uses a
# Kaiming init scaled by args.scales[k].
fixed_init = True
# Extra factor for the first-layer columns p: (the "small" features); 1.0 is a no-op.
scale_factor = 1.0

if args.scales:
    # Write the weights in place so they keep their device and dtype. The
    # previous `m.weight.data = torch.ones(...)` swapped in a CPU tensor, so
    # the parameters did not stay on `device`.
    linear_layers = [m for m in model if isinstance(m, torch.nn.Linear)]
    with torch.no_grad():
        for layer_idx, m in enumerate(linear_layers):
            if layer_idx == 0:
                if fixed_init:
                    m.weight.fill_(args.scales[0])
                    m.weight[:, p:] *= scale_factor
                else:
                    torch.nn.init.kaiming_normal_(m.weight, a=math.sqrt(5))
                    m.weight.mul_(args.scales[0])
            elif layer_idx == 1:
                if fixed_init:
                    m.weight.fill_(args.scales[1])
                else:
                    torch.nn.init.kaiming_uniform_(m.weight, a=math.sqrt(5))
                    m.weight.mul_(args.scales[1])


# One learning rate per weight matrix, indexed in parameter order.
# To use the same (largest) learning rate for both layers instead:
#if isinstance(args.lr, list):
#    stepsize = [max(args.lr)] * 2
stepsize = args.lr

In [None]:
# Inspect the first-layer weight matrix (hidden x dim for nn.Linear(dim, hidden)) after initialization.
print(model[0].weight.data)

In [None]:
# Hidden representation fed into the last layer: the output of every module
# except the last. Computing it through the model keeps it correct whether or
# not the ReLU in the model definition is enabled. The previous hard-coded
# ReLU did not match the current linear model. no_grad keeps the result
# detached so it can be plotted / passed to NumPy.
with torch.no_grad():
    Zs = model[:-1](Xs)
plt.plot(Zs, '*')
plt.show()

print(Zs.shape)

_, S, _ = np.linalg.svd(Zs)
print(S)

In [None]:

# Minimum-norm least-squares solution w_min = argmin ||Xs w - ys||.
# - If Xs has full column rank (n >= d), this is the ordinary OLS solution.
# - If n < d, it is the minimum-norm interpolator.
# Computed on NumPy copies: the previous version mixed np.* calls with torch
# tensors (an ndarray @ Tensor product raises TypeError). lstsq also avoids
# explicitly inverting the Gram matrix and handles rank-deficient Xs.
Xs_np = Xs.detach().cpu().numpy()
ys_np = ys.detach().cpu().numpy()
w_min = np.linalg.lstsq(Xs_np, ys_np, rcond=None)[0].squeeze()

print(w_min)


In [None]:
# Alignment of the first-layer weights W with the input's right-singular
# vectors v_j: plot ||W v_j|| against the j-th singular value of Xs. S is
# zero-padded to dim (needed when n < d, where Xs has only n singular values).
# The computation is done in NumPy instead of passing NumPy results to torch.norm.
U, S, Vh = np.linalg.svd(Xs.detach().cpu().numpy())
W1 = model[0].weight.detach().cpu().numpy()
S_padded = np.concatenate((S, np.zeros(Xs.shape[-1] - S.shape[0])))
plt.plot(S_padded, np.linalg.norm(W1 @ Vh.T, axis=0), '*')

In [None]:
# Alignment of the last layer with the right-singular vectors of its input Zs.
# Uses model[-1]: the model has only two modules while the ReLU is commented
# out, so the previous model[2] raised IndexError.
Uz, Sz, Vzh = np.linalg.svd(Zs.detach().cpu().numpy())
W_last = model[-1].weight.detach().cpu().numpy()
Sz_padded = np.concatenate((Sz, np.zeros(Zs.shape[-1] - Sz.shape[0])))
plt.plot(Sz_padded, np.linalg.norm(W_last @ Vzh.T, axis=0), '*')

In [None]:
# Jacobian spectrum at initialization, and how each right-singular vector
# splits between first-layer (W) and second-layer (v) weights.
sv, v1, v2, vTrec = compute_jacobian(Xs, ys, model, loss_fn, args)
plot_jacobian(sv, v1, v2, vTrec)

In [None]:
# train the network with full-batch gradient descent. Each iteration records:
# - the training loss and the risk on the held-out set;
# - the per-coordinate squared relative error of the end-to-end weight vector
#   w.r.t. `beta`;
# - the weight norms [||W1||, ||W1[:, :p]||, ||W1[:, p:]||, ||W2||].
losses_emp = []
risks_emp = []
mse_weights_emp = []
weight_norms_emp = []

# Iterations at which to run the slow diagnostics inside the loop; empty
# tuples disable them (replaces the former `if t in []:`). Values used before:
#   jacobian_checkpoints = (0, 100, 1000, 10000)
#   alignment_checkpoints = (0, 10, 100, 1000, 10000, 50000, 100000, 150000)
jacobian_checkpoints = ()
alignment_checkpoints = ()

for t in range(int(args.iterations)):
    y_pred = model(Xs)

    loss = loss_fn(y_pred, ys)
    losses_emp.append(loss.item())

    if not t % args.print_freq:
        print(t, loss.item())

    if t in jacobian_checkpoints:
        print(fr"Eigenvalue association at iteration {t}:")
        sv, v1, v2, vTrec = compute_jacobian(Xs, ys, model, loss_fn, args)
        plot_jacobian(sv, v1, v2, vTrec)

    if t in alignment_checkpoints:
        with torch.no_grad():
            print("Connection first layer - input:")
            U, S, Vh = np.linalg.svd(Xs.detach().cpu().numpy())
            W1 = model[0].weight.detach().cpu().numpy()
            plt.plot(np.concatenate((S, np.zeros(Xs.shape[-1] - S.shape[0]))),
                     np.linalg.norm(W1 @ Vh.T, axis=0), '*')
            plt.show()

            print("Connection second layer - latent repr.:")
            # Input of the last layer = output of every module but the last
            # (model[2] does not exist while the ReLU is commented out).
            Zs = model[:-1](Xs)
            Uz, Sz, Vzh = np.linalg.svd(Zs.detach().cpu().numpy())
            W_last = model[-1].weight.detach().cpu().numpy()
            plt.plot(np.concatenate((Sz, np.zeros(Zs.shape[-1] - Sz.shape[0]))),
                     np.linalg.norm(W_last @ Vzh.T, axis=0), '*')
            plt.show()

    model.zero_grad()
    loss.backward()
    with torch.no_grad():
        # i indexes stepsize: one learning rate per weight matrix.
        i = 0
        # End-to-end linear map, accumulated as I @ W1^T @ W2^T -> (dim, 1).
        w_tot = torch.diag(torch.ones(args.dim))
        weight_norms_it = []
        for param in model.parameters():
            param.data -= stepsize[i] * param.grad
            w_tot = w_tot @ param.data.t()

            if len(param.shape) > 1:
                weight_norms_it.append(float(torch.norm(param.data.flatten())))

                if i == 0:
                    weight_norms_it.append(float(torch.norm(param.data[:, :p].flatten())))
                    weight_norms_it.append(float(torch.norm(param.data[:, p:].flatten())))

                i += 1

        weight_norms_emp.append(weight_norms_it)

        w_tot = w_tot.squeeze()
        assert w_tot.shape == beta.shape
        # Squared relative error per coordinate (infinite where beta == 0).
        mse_weights_emp.append(((w_tot.numpy()-beta) / beta)**2)

    with torch.no_grad():
        yt_pred = model(Xt)

        risk = risk_fn(yt_pred, yt)
        risks_emp.append(risk.item())

        if not t % args.print_freq:
            print(t, risk.item())

In [None]:
# Jacobian spectrum and layer split after training (compare with the plot at initialization).
sv, v1, v2, vTrec = compute_jacobian(Xs, ys, model, loss_fn, args)
plot_jacobian(sv, v1, v2, vTrec)

In [None]:
# Alignment diagnostics after training: first layer vs. input singular
# vectors, then last layer vs. singular vectors of its own input. Singular
# values are zero-padded to the layer's input width.
with torch.no_grad():
    U, S, Vh = np.linalg.svd(Xs.detach().cpu().numpy())
    W1 = model[0].weight.detach().cpu().numpy()
    plt.plot(np.concatenate((S, np.zeros(Xs.shape[-1] - S.shape[0]))),
             np.linalg.norm(W1 @ Vh.T, axis=0), '*')
    plt.show()

    # Input of the last layer = output of every module but the last. This is
    # correct with or without the ReLU; model[2] raised IndexError in the
    # two-module model.
    Zs = model[:-1](Xs)
    Uz, Sz, Vzh = np.linalg.svd(Zs.detach().cpu().numpy())
    W_last = model[-1].weight.detach().cpu().numpy()
    plt.plot(np.concatenate((Sz, np.zeros(Zs.shape[-1] - Sz.shape[0]))),
             np.linalg.norm(W_last @ Vzh.T, axis=0), '*')
    plt.show()

In [None]:
# ~700 log-spaced iteration indices for the plots below. They are based on the
# number of iterations actually recorded (as in the later cells), not
# args.iterations, so an interrupted run does not index past the end.
geo_samples = [int(i) for i in np.geomspace(1, len(risks_emp) - 1, num=700)]

In [None]:
# Summary plots for the fixed-init empirical run: risk, loss, optional
# per-coordinate weight errors, layer weight norms, and the mean weight error.
risks = np.array(risks_emp)
losses = np.array(losses_emp)
# np.row_stack is a deprecated alias of np.vstack (NumPy 2.0).
risks_w = np.vstack(mse_weights_emp)
weight_norms = np.vstack(weight_norms_emp)

# matplotlib.cm.get_cmap was removed in Matplotlib 3.9; pyplot.get_cmap works
# on old and new versions.
cmap = plt.get_cmap('viridis')
colorList = [cmap(50/1000), cmap(350/1000), cmap(700/1000)]
labelList = ['empirical', 'theoretical']

# True adds one panel per weight coordinate.
plot_all_dims = False

extra_axs = 0
if plot_all_dims:
    extra_axs = risks_w.shape[-1] #args.dim

num_axs = 5 + extra_axs
fig, ax = plt.subplots(num_axs, 1, figsize=(16, 4 * num_axs))

# Every panel is plotted against the iteration count on a log axis. The last
# panel previously lacked the log scale.
for axis_ in ax:
    axis_.set_xscale('log')
    axis_.set_xlabel(r'$t$ iterations')

ax[0].plot(geo_samples, risks[geo_samples],
        color=colorList[1],
        label=labelList[0],
        lw=4)
ax[0].legend(loc=1, bbox_to_anchor=(1, 1), fontsize='x-large',
    frameon=False, fancybox=True, shadow=True, ncol=1)
ax[0].set_ylabel('risk')

ax[1].plot(geo_samples, losses[geo_samples],
        color=colorList[1],
        label=labelList[0],
        lw=4)
ax[1].set_ylabel('loss')

if plot_all_dims:
    for i in range(risks_w.shape[-1]):
        ax[2+i].plot(geo_samples, risks_w[geo_samples, i],
                color=colorList[2],
                label=labelList[0],
                lw=4)
        ax[2+i].set_ylabel('MSE weights, ' + str(i))

# weight_norms columns, as recorded by the training loop:
# 0 = ||W1||, 1 = ||W1[:, :p]||, 2 = ||W1[:, p:]||, 3 = ||W2||.
ax[-3].set_ylabel('Weight norm') # (SCALED BY INIT!)
ax[-3].plot(geo_samples, weight_norms[geo_samples, 0], # / weight_norms[0, 0],
        color=colorList[1],
        label="Layer 1",
        lw=4)
ax[-3].plot(geo_samples, weight_norms[geo_samples, 3], # / weight_norms[0, 3],
        color=colorList[2],
        label="Layer 2",
        lw=4)
ax[-3].legend()

ax[-2].set_ylabel('Weight norm') # (SCALED BY INIT!)
ax[-2].plot(geo_samples, weight_norms[geo_samples, 0], # / weight_norms[0, 0],
        color=colorList[1],
        label="Layer 1",
        lw=4)
ax[-2].plot(geo_samples, weight_norms[geo_samples, 1], # / weight_norms[0, 1],
        color=colorList[0],
        label="Layer 1, large",
        lw=4)
ax[-2].plot(geo_samples, weight_norms[geo_samples, 2], # / weight_norms[0, 2],
        color=colorList[2],
        label="Layer 1, small",
        lw=4)
ax[-2].legend()

ax[-1].plot(geo_samples, risks_w[geo_samples, :].mean(axis=-1),
        color=colorList[2],
        label=labelList[0],
        lw=4)
ax[-1].set_ylabel('MSE weights')

plt.show()

In [None]:

# Risk, loss, and the weight error split into "large" (first p) and "small"
# (remaining) coordinates, plus the overall mean weight error.
risks = np.array(risks_emp)
losses = np.array(losses_emp)
# np.row_stack is a deprecated alias of np.vstack (NumPy 2.0).
risks_w = np.vstack(mse_weights_emp)

# matplotlib.cm.get_cmap was removed in Matplotlib 3.9.
cmap = plt.get_cmap('viridis')
colorList = [cmap(50/1000), cmap(350/1000), cmap(700/1000)]
labelList = ['empirical', 'theoretical']

# Use the split `p` that was passed to linear_model and used for the
# first-layer weight split during training. This cell used to overwrite the
# global with int(dim / 2), which differs when dim is odd or n < d.
split = p

extra_axs = 2
fig, ax = plt.subplots(3 + extra_axs, 1, figsize=(12, 12 + 4 * extra_axs))

panels = [
    (risks, colorList[1], 'risk'),
    (losses, colorList[1], 'loss'),
    (risks_w[:, :split].mean(axis=-1), colorList[2], 'MSE weights, large'),
    (risks_w[:, split:].mean(axis=-1), colorList[2], 'MSE weights, small'),
    (risks_w.mean(axis=-1), colorList[2], 'MSE weights'),
]
for axis_, (series, color, ylabel) in zip(ax, panels):
    axis_.set_xscale('log')
    axis_.plot(geo_samples, series[geo_samples],
               color=color,
               label=labelList[0],
               lw=4)
    axis_.set_ylabel(ylabel)
    axis_.set_xlabel(r'$t$ iterations')

ax[0].legend(loc=1, bbox_to_anchor=(1, 1), fontsize='x-large',
    frameon=False, fancybox=True, shadow=True, ncol=1)

#ax[0].set_ylim([9, 13])
#ax[1].set_ylim([80, 105])
#ax[2].set_ylim([1, 2.55])
#ax[3].set_ylim([0.35, 1.05])
#ax[4].set_ylim([0.9, 1.6])

plt.show()

In [None]:
# Zoom into a window of the risk curve. The window end is clamped to the
# recorded length so the x and y arrays always have matching lengths (a run
# shorter than 10000 iterations previously made plt.plot fail).
fig, ax = plt.subplots()

min_x, max_x = 3000, 10000-1
max_x = min(max_x, len(risks))
plt.plot(np.arange(min_x, max_x), risks[min_x:max_x])
print(risks.shape)

In [None]:
# First-layer weight norm, risk and loss over the zoom window chosen above.
# The window is clamped to the shortest recorded series so x and y lengths match.
fig, ax = plt.subplots(3, 1, figsize=(5, 15))

window_end = min(max_x, len(risks), len(losses), len(weight_norms))
window = np.arange(min_x, window_end)
ax[0].plot(window, weight_norms[min_x:window_end, 0])
ax[1].plot(window, risks[min_x:window_end])
ax[2].plot(window, losses[min_x:window_end])

## Empirical, rank-one initialization

In [None]:
def rank_one_init(model, g_cpu, args):
    """Re-initialize every ``torch.nn.Linear`` in ``model`` with a rank-one weight.

    Layer ``k`` gets ``W_k = u_k * p_k q_k^T``, where ``p_k`` is a unit-norm
    Gaussian column vector and ``q_k = p_{k-1}`` for ``k > 1``:

    * first layer: ``q_1 = z ~ N(0, args.scales[0]^2)`` of shape ``(in, 1)``
      and ``u_1 = 1``;
    * second layer: ``u ~ N(0, args.scales[1]^2)``; deeper layers reuse it.

    For a two-layer net with one output, ``W_2 W_1 = u * p_2 * z^T`` with
    ``p_2 = +-1``.

    Random draws come from ``g_cpu`` in a fixed order: q then p for the first
    layer, u then p for the second, and p for later layers. A fixed seed
    therefore gives a fixed initialization.

    Args:
        model: iterable of modules (e.g. ``torch.nn.Sequential``). Linear
            weights are overwritten in place; other modules are skipped.
        g_cpu: ``torch.Generator`` used for all draws.
        args: object with a ``scales`` sequence of stds. Only ``scales[0]`` is
            needed when there is a single Linear layer.

    Returns:
        ``(model, u, z)``. ``u`` is a 0-dim tensor (1.0 if the model has a
        single Linear layer); ``z`` is the ``(in, 1)`` first-layer right factor.

    Raises:
        ValueError: if ``model`` contains no ``torch.nn.Linear``.
    """
    u = 1
    z = None
    left = None
    with torch.no_grad():
        linear_layers = (mod for mod in model if isinstance(mod, torch.nn.Linear))
        for layer_idx, m in enumerate(linear_layers):
            out_dim, in_dim = m.weight.shape

            if layer_idx == 0:
                right = torch.normal(mean=0, std=args.scales[0], size=(in_dim, 1), generator=g_cpu)
                z = right.clone()
                u = 1
            else:
                # Chain the factors so the product of consecutive layers stays rank one.
                right = left.clone()
                if layer_idx == 1:
                    u = torch.normal(mean=0, std=args.scales[1], size=(), generator=g_cpu)

            left = torch.normal(mean=0, std=1, size=(out_dim, 1), generator=g_cpu)
            left /= torch.norm(left, dim=0)  # unit norm; = +-1 for a single output

            # In-place copy keeps the parameter's device and dtype.
            m.weight.copy_(u * torch.matmul(left, right.T))

    if z is None:
        raise ValueError("rank_one_init: model contains no torch.nn.Linear layer")
    # u is still the Python int 1 when there is only one Linear layer; the
    # previous `u.clone()` raised AttributeError in that case.
    return model, torch.as_tensor(u, dtype=torch.float32).clone(), z

In [None]:
# Same two-layer bias-free architecture as before, re-initialized rank one.
model = torch.nn.Sequential(
           torch.nn.Linear(args.dim, args.hidden, bias=False),
           #torch.nn.ReLU(),
           torch.nn.Linear(args.hidden, d_out, bias=False),
         ).to(device)

# use rank one initialization; the generator is seeded only when --seed is
# given (manual_seed(None) fails, and args.seed defaults to None).
g_cpu = torch.Generator()
if args.seed is not None:
    g_cpu.manual_seed(args.seed)
model, u_init, z_init = rank_one_init(model, g_cpu, args)

print(u_init)
print(z_init)

# One learning rate per weight matrix, indexed in parameter order.
# To use the same (largest) learning rate for both layers instead:
#if isinstance(args.lr, list):
#    stepsize = [max(args.lr)] * 2
stepsize = args.lr

In [None]:
# Sanity check of the rank-one init. The end-to-end map, accumulated from the
# identity as a (dim, 1) column, should match u_init * z_init up to the sign
# of the last layer's unit vector.
Wtot = torch.eye(args.dim)
for param in filter(lambda prm: prm.dim() > 1, model.parameters()):
    Wtot = Wtot @ param.data.t()

print(Wtot)
print(u_init * z_init)

In [None]:
# train the rank-one-initialized network with full-batch gradient descent.
# Each iteration records the loss, the risk, the per-coordinate weight error,
# the gradient norm of each weight matrix, and each parameter's rank (taken
# before the update).
losses_emp = []
risks_emp = []
mse_weights_emp = []
grad_norms_emp = []
weights_rank = []
for t in range(int(args.iterations)):
    y_pred = model(Xs)

    loss = loss_fn(y_pred, ys)
    losses_emp.append(loss.item())

    if not t % args.print_freq:
        print(t, loss.item())

    # Jacobian diagnostics every 1000 iterations during the first 10000.
    if not t % 1000 and t <= 10000:
        print(fr"Eigenvalue association at iteration {t}:")
        # compute_jacobian needs `args` to size the W / v parameter blocks;
        # calling it without `args` raised TypeError at t = 0.
        sv, v1, v2, vTrec = compute_jacobian(Xs, ys, model, loss_fn, args)
        plot_jacobian(sv, v1, v2, vTrec)

    model.zero_grad()
    loss.backward()

    grad_norms_it = []
    weights_rank_it = []
    with torch.no_grad():
        # i indexes stepsize: one learning rate per weight matrix.
        i = 0
        # End-to-end linear map, accumulated as I @ W1^T @ W2^T -> (dim, 1).
        w_tot = torch.diag(torch.ones(args.dim))
        for param in model.parameters():

            if len(param.shape) > 1:
                grad_norms_it.append(float(torch.norm(param.grad.flatten())))

            weights_rank_it.append(int(torch.linalg.matrix_rank(param.data)))

            param.data -= stepsize[i] * param.grad
            w_tot = w_tot @ param.data.t()

            if len(param.shape) > 1:
                i += 1

        grad_norms_emp.append(grad_norms_it)
        weights_rank.append(weights_rank_it)

        w_tot = w_tot.squeeze()
        assert w_tot.shape == beta.shape
        # Squared relative error per coordinate, as a NumPy row (same as the
        # first training loop) instead of mixing a tensor with the ndarray beta.
        mse_weights_emp.append(((w_tot.numpy() - beta) / beta) ** 2)

    with torch.no_grad():
        yt_pred = model(Xt)

        risk = risk_fn(yt_pred, yt)
        risks_emp.append(risk.item())

        if not t % args.print_freq:
            print(t, risk.item())

In [None]:
# ~700 log-spaced (integer) iteration indices over the recorded run, for plotting.
geo_samples = list(map(int, np.geomspace(1, len(risks_emp) - 1, num=700)))

In [None]:
# Summary plots for the rank-one run: risk, loss, per-layer gradient norms,
# optional per-coordinate weight errors, mean weight error, and weight ranks.
risks = np.array(risks_emp)
losses = np.array(losses_emp)
# np.row_stack is a deprecated alias of np.vstack (NumPy 2.0).
risks_w = np.vstack(mse_weights_emp)
grad_norms = np.vstack(grad_norms_emp)
print(grad_norms[0])
weights_rank = np.vstack(weights_rank)

# matplotlib.cm.get_cmap was removed in Matplotlib 3.9.
cmap = plt.get_cmap('viridis')
colorList = [cmap(50/1000), cmap(350/1000), cmap(700/1000)]
labelList = ['empirical', 'theoretical']

# True adds one panel per weight coordinate.
plot_all_dims = False

extra_axs = 0
if plot_all_dims:
    extra_axs = risks_w.shape[-1] #args.dim

fig, ax = plt.subplots(5 + extra_axs, 1, figsize=(12, 20 + 4 * extra_axs))

# Every panel is plotted against the iteration count on a log axis.
for axis_ in ax:
    axis_.set_xscale('log')
    axis_.set_xlabel(r'$t$ iterations')

ax[0].plot(geo_samples, risks[geo_samples],
        color=colorList[1],
        label=labelList[0],
        lw=4)
ax[0].legend(loc=1, bbox_to_anchor=(1, 1), fontsize='x-large',
    frameon=False, fancybox=True, shadow=True, ncol=1)
ax[0].set_ylabel('risk')

ax[1].plot(geo_samples, losses[geo_samples],
        color=colorList[1],
        label=labelList[0],
        lw=4)
ax[1].set_ylabel('loss')

# One line per weight matrix, each with its own color/label. Colors cycle,
# so more than three layers no longer raise IndexError.
for i in range(grad_norms.shape[-1]):
    ax[2].plot(geo_samples, grad_norms[geo_samples, i],
            color=colorList[i % len(colorList)],
            label=f'layer {i + 1}',
            lw=4)
ax[2].set_ylabel('gradient norms')
ax[2].legend()

if plot_all_dims:
    for i in range(risks_w.shape[-1]):
        ax[3+i].plot(geo_samples, risks_w[geo_samples, i],
                color=colorList[2],
                label=labelList[0],
                lw=4)
        ax[3+i].set_ylabel('MSE weights, ' + str(i))

ax[-2].plot(geo_samples, risks_w[geo_samples, :].mean(axis=-1),
        color=colorList[2],
        label=labelList[0],
        lw=4)
ax[-2].set_ylabel('MSE weights')

for i in range(weights_rank.shape[-1]):
    ax[-1].plot(geo_samples, weights_rank[geo_samples, i],
            color=colorList[i % len(colorList)],
            label=f'layer {i + 1}',
            lw=4)
ax[-1].set_ylabel('rank of weights')
ax[-1].legend()

plt.show()

In [None]:
# Rank of the first parameter (first-layer weight) at every recorded iteration.
print(weights_rank[:, 0])

In [None]:
# Rank-one run: risk, loss, and the weight error split into "large" (first p)
# and "small" (remaining) coordinates, plus the overall mean weight error.
risks = np.array(risks_emp)
losses = np.array(losses_emp)
# np.row_stack is a deprecated alias of np.vstack (NumPy 2.0).
risks_w = np.vstack(mse_weights_emp)

# matplotlib.cm.get_cmap was removed in Matplotlib 3.9.
cmap = plt.get_cmap('viridis')
colorList = [cmap(50/1000), cmap(350/1000), cmap(700/1000)]
labelList = ['empirical', 'theoretical']

# Use the split `p` that was passed to linear_model, rather than overwriting
# the global with int(dim / 2), which differs when dim is odd or n < d.
split = p

extra_axs = 2
fig, ax = plt.subplots(3 + extra_axs, 1, figsize=(12, 12 + 4 * extra_axs))

panels = [
    (risks, colorList[1], 'risk'),
    (losses, colorList[1], 'loss'),
    (risks_w[:, :split].mean(axis=-1), colorList[2], 'MSE weights, large'),
    (risks_w[:, split:].mean(axis=-1), colorList[2], 'MSE weights, small'),
    (risks_w.mean(axis=-1), colorList[2], 'MSE weights'),
]
for axis_, (series, color, ylabel) in zip(ax, panels):
    axis_.set_xscale('log')
    axis_.plot(geo_samples, series[geo_samples],
               color=color,
               label=labelList[0],
               lw=4)
    axis_.set_ylabel(ylabel)
    axis_.set_xlabel(r'$t$ iterations')

ax[0].legend(loc=1, bbox_to_anchor=(1, 1), fontsize='x-large',
    frameon=False, fancybox=True, shadow=True, ncol=1)

#ax[0].set_ylim([9, 13])
#ax[1].set_ylim([80, 105])
#ax[2].set_ylim([1, 2.55])
#ax[3].set_ylim([0.35, 1.05])
#ax[4].set_ylim([0.9, 1.6])

plt.show()

## Theoretical 

In [None]:
# With actual input data
def dt(u, z, S, St):
    """Residual in the singular basis: ``St - u * z * S``.

    ``S`` and ``z`` must have the same shape (asserted).
    """
    assert S.shape == z.shape
    fitted = u * z * S
    return St - fitted


def dzdt(u, z, S, St):
    """Update direction for ``z``: ``2 * u * dt(u, z, S, St)``.

    Presumably the negative gradient of the squared training loss w.r.t. ``z``
    in the singular basis — confirm.
    """
    residual = dt(u, z, S, St)
    return 2 * u * residual


def dudt(u, z, S, St):
    """Update direction for ``u``: ``2 * (dt(u, z, S, St) @ z.T)``, squeezed.

    Presumably ``z`` is a ``(1, d)`` row vector, so the result is a scalar.
    """
    projected = dt(u, z, S, St) @ z.T
    return 2 * projected.squeeze()


# Sampling only noise in output (and assuming that we know the true weights)
def dt_s(u, z, S, beta, eps):
    """Residual for the noise-only model: ``(beta - u * z) * S + eps * sqrt(S)``.

    ``S`` and ``z`` must have the same shape (asserted).
    """
    assert S.shape == z.shape
    signal = (beta - u * z) * S
    noise = eps * S**0.5
    return signal + noise


def dzdt_s(u, z, S, beta, eps):
    """Update direction for ``z``: ``2 * u * dt_s(...)`` (added to z with step lr[0] in the simulation)."""
    residual = dt_s(u, z, S, beta, eps)
    return 2 * u * residual


def dudt_s(u, z, S, beta, eps):
    """Update direction for ``u``: ``2 * (dt_s(...) @ z.T)``, squeezed to a scalar for a row-vector ``z``."""
    projected = dt_s(u, z, S, beta, eps) @ z.T
    return 2 * projected.squeeze()


In [None]:
# For the sake of not messing anything up: work on transposed copies, so
# Xs_t / Xt_t have one column per sample.
Xs_t, ys_t, Xt_t, yt_t = Xs.T, ys.T, Xt.T, yt.T

# Basis for the dynamics. With transform_data the data are taken to be already
# decoupled, so V is the identity and the sample-side vectors come from
# lin_model (NOTE(review): the shape/convention of left_singular_vecs is
# defined in linear_utils — confirm). Otherwise use the SVD Xs_t = V s Uh.
if transform_data:
    V = np.eye(args.dim)
    Uh = np.transpose(lin_model.left_singular_vecs)
    _, s, _ = np.linalg.svd(Xs_t.numpy(), full_matrices=True)
else:
    V, s, Uh = np.linalg.svd(Xs_t.numpy(), full_matrices=True)

V_tensor, Uh_tensor = torch.tensor(V, dtype=torch.float32), torch.tensor(Uh, dtype=torch.float32)
# Squared singular values of Xs, zero-padded to length dim, as a (1, dim) row.
S = torch.tensor(np.concatenate((s**2, np.zeros(args.dim - s.shape[0]))).reshape(1, -1), dtype=torch.float32)
print(S.shape)

#eps_tensor = (torch.randn(size=(1, args.dim)) * sigma_noise)# @ torch.tensor(Uh).T)[:, :args.dim]).reshape(1, -1) #NOTE: this now depends on the input as well

# True coefficients as a (1, dim) row vector.
beta_tensor = torch.tensor(beta, dtype=torch.float32).reshape(1, -1)


In [None]:
# Consistency check of the decomposition used by the noise-only dynamics:
#   St  = y X^T V                    (from the data)
#   St2 = beta * S + eps * sqrt(S)
# eps is the residual y - beta X projected with Uh and zero-padded to dim.
# NOTE(review): St2 treats beta as already expressed in the V basis. That holds
# when transform_data sets V = I — confirm for transform_data=False.
St = ys_t @ Xs_t.T @ V_tensor
eps_tensor_0 = ((ys_t - beta_tensor @ Xs_t) @ Uh_tensor.T)[:, :args.dim]
eps_tensor = torch.concat((eps_tensor_0, torch.zeros((1, args.dim - s.shape[0]))), dim=-1)
St2 = beta_tensor * S + eps_tensor * S**0.5

print(St)
print(St2)


print(torch.abs(St - St2))

# Probe point for the residual St - u z S. This previously used a global `z`
# that only the simulation cell below creates, so running the cells in order
# raised NameError. Dedicated names avoid that; z_probe is the simulation's
# initial z.
u_probe = 2.0
z_probe = torch.ones((1, args.dim)) * torch.tensor(args.scales[0])

print((St - u_probe * z_probe * S))
print((beta_tensor * S).shape)
print((eps_tensor * S**0.5).shape)
print((u_probe * z_probe * S).shape)

# The three expressions below are algebraically identical. They differ only by
# float32 rounding, since each orders the products/sums differently, so
# compare them with a tolerance rather than expecting exact equality.
residual_a = beta_tensor * S + eps_tensor * S**0.5 - u_probe * z_probe * S
residual_b = beta_tensor * S - u_probe * z_probe * S + eps_tensor * S**0.5
residual_c = (beta_tensor - u_probe * z_probe) * S + eps_tensor * S**0.5
print(residual_a)
print(residual_b)
print(residual_c)
print(torch.allclose(residual_a, residual_b, rtol=1e-4, atol=1e-4),
      torch.allclose(residual_a, residual_c, rtol=1e-4, atol=1e-4))


In [None]:
# Simulation: discrete-time dynamics of the scalar u and row vector z under
# the noise-only model (dudt_s / dzdt_s). The end-to-end weights are
# Wtot = (u * z) @ V^T. Loss, risk and per-coordinate weight error are
# recorded exactly as in the empirical runs.
# NOTE(review): g_cpu and w_init are only used by the commented-out random
# initializations; args.seed defaults to None, which manual_seed likely
# rejects — guard this if running without --seed.
g_cpu = torch.Generator()
g_cpu.manual_seed(args.seed)

w_init = args.scales[0] 
u = torch.tensor(args.scales[1]) #torch.normal(0, torch.tensor(w_init), generator=g_cpu) # u_init.clone()  
z = torch.ones((1, args.dim)) * torch.tensor(args.scales[0]) #torch.normal(0, torch.tensor(w_init), size=(1, args.dim), generator=g_cpu) # z_init.T.clone()
#z[0, 1] *= -1999.0
print(u)
print(z)
print(u*z)

# Trajectories; they start with the initial values, so they have one more
# entry than the per-iteration metrics.
u_track, z_track = [], []
u_track.append(u)
z_track.append(z)

grad_u_track = []
grad_z_track = []

losses_teo = []
risks_teo = []
mse_weights_teo = []

for t in range(int(args.iterations)):
    
    # Both directions are evaluated at the previous iterate, so u and z are
    # updated simultaneously.
    grad_u = dudt_s(u_track[-1], z_track[-1], S, beta_tensor, eps_tensor)
    grad_z = dzdt_s(u_track[-1], z_track[-1], S, beta_tensor, eps_tensor)
    
    # "+": the d*dt_s functions already return descent directions.
    u = u + args.lr[1] * grad_u #dudt(u, z, S, St)
    z = z + args.lr[0] * grad_z #dzdt(u, z, S, St) 
    #u = (S.max() - S.min()) * z.mean()
    
    grad_u_track.append(torch.norm(grad_u))
    grad_z_track.append(torch.norm(grad_z))
    
    u_track.append(u)
    z_track.append(z)
    
    # Map the singular-basis weights back to input coordinates.
    Wtot = u * z @ V_tensor.T

    y_pred = Wtot @ Xs_t

    loss = loss_fn(y_pred.T, ys_t.T)
    losses_teo.append(loss.item())

    # Squared relative error per coordinate (infinite where beta == 0).
    mse_weights_teo.append((((Wtot.squeeze()-beta_tensor.squeeze()) / beta_tensor.squeeze())**2))

    if not t % args.print_freq:
        print(t, loss.item())
        
    yt_pred = Wtot @ Xt_t

    risk = risk_fn(yt_pred.T, yt_t.T)
    risks_teo.append(risk.item())

    if not t % args.print_freq:
        print(t, risk.item())

In [None]:
# ~700 log-spaced (integer) iteration indices over the simulated run, for plotting.
geo_samples = list(map(int, np.geomspace(1, len(risks_teo) - 1, num=700)))

In [None]:
# Summary plots for the theoretical simulation: risk, loss, gradient norms of
# u and z, optional per-coordinate weight errors, and the mean weight error.
risks = np.array(risks_teo)
losses = np.array(losses_teo)
# np.row_stack is a deprecated alias of np.vstack (NumPy 2.0).
risks_w = np.vstack(mse_weights_teo)
path_u = np.array(u_track)
path_z = np.vstack(z_track)
grad_norm_u = np.array(grad_u_track)
grad_norm_z = np.array(grad_z_track)
print(grad_norm_u[0])
print(grad_norm_z[0])

# matplotlib.cm.get_cmap was removed in Matplotlib 3.9.
cmap = plt.get_cmap('viridis')
colorList = [cmap(50/1000), cmap(350/1000), cmap(700/1000)]
labelList = ['empirical', 'theoretical']

# True adds one panel per weight coordinate.
plot_all_dims = True

extra_axs = 0
if plot_all_dims:
    extra_axs = risks_w.shape[-1]

fig, ax = plt.subplots(4 + extra_axs, 1, figsize=(12, 16 + 4 * extra_axs))

# Every panel is plotted against the iteration count on a log axis.
for axis_ in ax:
    axis_.set_xscale('log')
    axis_.set_xlabel(r'$t$ iterations')

ax[0].plot(geo_samples, risks[geo_samples],
        color=colorList[1],
        label=labelList[1],
        lw=4)
ax[0].legend(loc=1, bbox_to_anchor=(1, 1), fontsize='x-large',
    frameon=False, fancybox=True, shadow=True, ncol=1)
ax[0].set_ylabel('risk')

ax[1].plot(geo_samples, losses[geo_samples],
        color=colorList[1],
        label=labelList[1],
        lw=4)
ax[1].set_ylabel('loss')

# The two curves were previously unlabeled.
ax[2].plot(geo_samples, grad_norm_u[geo_samples], label='u', lw=4)
ax[2].plot(geo_samples, grad_norm_z[geo_samples], label='z', lw=4)
ax[2].set_ylabel('gradient norms')
ax[2].legend()

if plot_all_dims:
    for i in range(risks_w.shape[-1]):
        ax[3+i].plot(geo_samples, risks_w[geo_samples, i],
                color=colorList[2],
                label=labelList[1],
                lw=4)
        ax[3+i].set_ylabel('MSE weights, ' + str(i))

ax[-1].plot(geo_samples, risks_w[geo_samples, :].mean(axis=-1),
        color=colorList[2],
        label=labelList[1],
        lw=4)
ax[-1].set_ylabel('MSE weights')

plt.show()

In [None]:
# Ratio of the theoretical z-gradient norm to the empirical first-layer
# gradient norm (grad_norms column 0, from the rank-one run) at the plotted
# iterations. NOTE(review): assumes both runs recorded the same number of
# iterations, so geo_samples indexes both — confirm.
grad_norm_z[geo_samples] / grad_norms[geo_samples, 0]

In [None]:
# Trajectories of the scalar u and the vector z from the theoretical simulation.
# The tracking lists are rebound to arrays. Re-running the cell is harmless:
# np.array / np.vstack of an existing array return the same shape.
u_track = np.array(u_track)
# np.row_stack is a deprecated alias of np.vstack (NumPy 2.0).
z_track = np.vstack(z_track)
grad_u_track = np.array(grad_u_track)
grad_z_track = np.vstack(grad_z_track)

fig, ax = plt.subplots(2, 1, figsize=(12, 8))

ax[0].set_xscale('log')
ax[0].plot(geo_samples, u_track[geo_samples], label="u")
#ax[0].plot(geo_samples, risks_w[geo_samples, :].mean(axis=-1), label="risk")
ax[0].set_ylabel('u')
ax[0].set_xlabel(r'$t$ iterations')
ax[0].legend()


#ax[0].plot(geo_samples, losses[geo_samples] )


ax[1].set_xscale('log')

# True: average z over the first / second half of the coordinates;
# False: one line per coordinate.
mean_z = True
if mean_z:
    ax[1].plot(geo_samples, z_track[geo_samples, :int(args.dim/2)].mean(axis=-1), label="z, large")
    ax[1].plot(geo_samples, z_track[geo_samples, int(args.dim/2):].mean(axis=-1), label="z, small")

else:
    for i in range(args.dim):
        ax[1].plot(geo_samples, z_track[geo_samples, i], label=fr"$z_{i}$")

ax[1].legend()
ax[1].set_ylabel('z')
ax[1].set_xlabel(r'$t$ iterations')

# Does the interaction grow over time? Does the trend follow the loss (or does the loss follow u)?


In [None]:
# Theoretical (`*_teo`) curves: risk, training loss, and weight MSE split into
# the first ("large") and second ("small") half of the dimensions.
risks = np.array(risks_teo)
losses = np.array(losses_teo)
# `np.row_stack` is a deprecated alias of `np.vstack` (deprecated in NumPy 2.0).
risks_w = np.vstack(mse_weights_teo)

# `matplotlib.cm.get_cmap` was removed in Matplotlib 3.9; use pyplot's.
cmap = plt.get_cmap('viridis')
colorList = [cmap(50/1000), cmap(350/1000), cmap(700/1000)]
labelList = ['empirical', 'theoretical']

# Panels: risk, loss, MSE (first half), MSE (second half), mean MSE.
extra_axs = 2
fig, ax = plt.subplots(3 + extra_axs, 1, figsize=(12, 12 + 4 * extra_axs))

ax[0].set_xscale('log')
ax[0].plot(geo_samples, risks[geo_samples],
        color=colorList[1],
        label=labelList[1],
        lw=4)

ax[0].legend(loc=1, bbox_to_anchor=(1, 1), fontsize='x-large',
    frameon=False, fancybox=True, shadow=True, ncol=1)
ax[0].set_ylabel('risk')
ax[0].set_xlabel(r'$t$ iterations')

ax[1].set_xscale('log')
ax[1].plot(geo_samples, losses[geo_samples],
        color=colorList[1],
        label=labelList[1],
        lw=4)

ax[1].set_ylabel('loss')
ax[1].set_xlabel(r'$t$ iterations')

# Split point between the "large" and "small" halves of the dimensions.
p = args.dim // 2
ax[2].set_xscale('log')
ax[2].plot(geo_samples, risks_w[geo_samples, :p].mean(axis=-1),
        color=colorList[2],
        label=labelList[1],
        lw=4)

ax[2].set_ylabel('MSE weights, large')
ax[2].set_xlabel(r'$t$ iterations')

ax[3].set_xscale('log')
ax[3].plot(geo_samples, risks_w[geo_samples, p:].mean(axis=-1),
        color=colorList[2],
        label=labelList[1],
        lw=4)

ax[3].set_ylabel('MSE weights, small')
ax[3].set_xlabel(r'$t$ iterations')


ax[-1].set_xscale('log')
ax[-1].plot(geo_samples, risks_w[geo_samples, :].mean(axis=-1),
        color=colorList[2],
        label=labelList[1],
        lw=4)

ax[-1].set_ylabel('MSE weights')
ax[-1].set_xlabel(r'$t$ iterations')


# Optional fixed y-ranges for comparing runs (disabled).
#ax[0].set_ylim([9, 13])
#ax[1].set_ylim([80, 105])
#ax[2].set_ylim([1, 2.55])
#ax[3].set_ylim([0.35, 1.05])
#ax[4].set_ylim([1, 1.6])


plt.show()

In [None]:
# Zoom on the last 2000 theoretical risk values. The previous slice
# `[-2000:-1]` silently dropped the final (most recent) entry.
plt.plot(risks_teo[-2000:])

In [None]:
# Setting d_i = (beta_i - u z_i) S_i + eps_i S_i**0.5 (the field used in the
# vector-field cells below) to zero gives u z_i = beta_i + eps_i S_i**-0.5.
# This expression divides those targets by the current z via a matrix product
# with z**-1 transposed. NOTE(review): presumably a sanity check of the
# stationary u for the current z, assuming beta_tensor / eps_tensor are the
# tensor versions of beta / eps -- confirm.
stationary_target = beta_tensor + eps_tensor * S**(-0.5)
inverse_z = z**(-1)
stationary_target @ inverse_z.T

In [None]:
# Mean weight MSE over the last 2000 theoretical steps. `[-2000:]` keeps the
# final entry (the previous `[-2000:-1]` dropped it); `np.vstack` replaces the
# deprecated `np.row_stack` alias.
plt.plot(np.vstack(mse_weights_teo[-2000:]).mean(axis=-1))

# VECTOR FIELD

In [None]:
# Vector field in the (u, z_1) plane for the two-dimensional case, drawn for
# three fixed values of z_2 (one panel each).
fig, ax = plt.subplots(1, 3, figsize=(15, 4))

v_min, v_max = -10, 10
grid_size = 20
axis_values = np.linspace(v_min, v_max, grid_size)
x, y1 = np.meshgrid(axis_values, axis_values)

# Training residual of the true weights, and its projection through Uh.
# Later cells read both `eps_0` and `eps`, so they stay module-level names.
eps_0 = ys_t.numpy() - beta @ Xs_t.numpy()
eps = (eps_0 @ np.transpose(Uh)).squeeze()

# NOTE(review): S[0, i] presumably holds the i-th spectral weight of the data -- confirm.
s1 = S[0, 0].numpy()
s2 = S[0, 1].numpy()

# Per-coordinate drive d_i = (beta_i - u z_i) s_i + eps_i sqrt(s_i).
d1 = (beta[0] - x * y1) * s1 + eps[0] * s1**0.5
# z_1-component of the field; it does not depend on z_2, so all panels share it.
w1 = d1 * x

for panel, z2_fixed in zip(ax, (1, 5, 10)):
    d2_fixed = (beta[1] - x * z2_fixed) * s2 + eps[1] * s2**0.5
    # u-component of the field: d_1 z_1 + d_2 z_2.
    u_component = d1 * y1 + d2_fixed * z2_fixed
    panel.quiver(x, y1, u_component, w1)
    panel.set_xlabel("u")
    panel.set_ylabel(fr"$z_1$")
    panel.set_title(f"$z_2$ = {z2_fixed}")

plt.show()

In [None]:
# Vector field in the (u, z_2) plane for the two-dimensional case, drawn for
# three fixed values of z_1 (one panel each).
# Bug fix: the third panel previously computed d_1 with the second panel's
# z_1 (2) while being titled z_1 = 10; each panel now uses its own z_1.
fig, ax = plt.subplots(1, 3, figsize=(15, 4))

v_min, v_max = -10, 10
grid_size = 20
axis_values = np.linspace(v_min, v_max, grid_size)
x, y2 = np.meshgrid(axis_values, axis_values)

s1 = S[0, 0].numpy()
s2 = S[0, 1].numpy()

# Drive along z_2 and the z_2-component of the field; both are independent
# of z_1, so every panel shares them.
d2 = (beta[1] - x * y2) * s2 + eps[1] * s2**0.5
w2 = d2 * x

for panel, z1_fixed in zip(ax, (1, 2, 10)):
    d1_fixed = (beta[0] - x * z1_fixed) * s1 + eps[0] * s1**0.5
    # u-component of the field: d_1 z_1 + d_2 z_2.
    u_component = d1_fixed * z1_fixed + d2 * y2
    panel.quiver(x, y2, u_component, w2)
    panel.set_xlabel("u")
    panel.set_ylabel(fr"$z_2$")
    panel.set_title(f"$z_1$ = {z1_fixed}")

plt.show()

In [None]:
# Vector field in the (z_1, z_2) plane for three fixed values of u.
# Fix: the middle panel previously used the raw per-sample residual
# eps_0[0, i] instead of the projected eps[i] used by the other panels (and
# by every other vector-field cell).
fig, ax = plt.subplots(1, 3, figsize=(15, 4))

v_min, v_max = -20, 20
grid_size = 20
axis_values = np.linspace(v_min, v_max, grid_size)
y1, y2 = np.meshgrid(axis_values, axis_values)

s1 = S[0, 0].numpy()
s2 = S[0, 1].numpy()

for panel, u_fixed in zip(ax, (0.1, 1, 10)):
    d1 = (beta[0] - u_fixed * y1) * s1 + eps[0] * s1**0.5
    d2 = (beta[1] - u_fixed * y2) * s2 + eps[1] * s2**0.5
    # z-components of the field: dz_i = d_i * u.
    panel.quiver(y1, y2, d1 * u_fixed, d2 * u_fixed)
    panel.set_xlabel(fr"$z_1$")
    panel.set_ylabel(fr"$z_2$")
    panel.set_title(f"u = {u_fixed}")

plt.show()

# For kappa > 1 we mostly move in the z_1 direction, but does a large u even
# out the differences?
# At some point we should move further away from the true weights; unclear
# whether that is visible here.

In [None]:
# Rate of change of the end-to-end weights w_i = u z_i, drawn at positions of
# the (z_1, z_2) grid for three fixed values of u. By the chain rule,
# dw_i = u dz_i + z_i du = d_i u**2 + (d_1 z_1 + d_2 z_2) z_i.
# Fixes vs. the previous version:
# - the middle panel used the raw residual eps_0 instead of eps;
# - the last panel squared the middle panel's u (x_2) instead of its own;
# - the first panel labelled its axes u*z although positions are z values.
fig, ax = plt.subplots(1, 3, figsize=(15, 4))

v_min, v_max = -20, 20
grid_size = 20
axis_values = np.linspace(v_min, v_max, grid_size)
y1, y2 = np.meshgrid(axis_values, axis_values)

s1 = S[0, 0].numpy()
s2 = S[0, 1].numpy()

for panel, u_fixed in zip(ax, (0.1, 1.0, 10)):
    d1 = (beta[0] - u_fixed * y1) * s1 + eps[0] * s1**0.5
    d2 = (beta[1] - u_fixed * y2) * s2 + eps[1] * s2**0.5
    u_component = d1 * y1 + d2 * y2
    w1 = d1 * u_fixed**2 + u_component * y1
    w2 = d2 * u_fixed**2 + u_component * y2
    panel.quiver(y1, y2, w1, w2)
    panel.set_xlabel(fr"$z_1$")
    panel.set_ylabel(fr"$z_2$")
    panel.set_title(f"u = {u_fixed}")

plt.show()

In [None]:
# Full 3-D vector field over (u, z_1, z_2) on a coarse cubic grid.
fig = plt.figure(figsize=plt.figaspect(0.5))
ax = fig.add_subplot(1, 2, 1, projection='3d')

v_min, v_max = -20, 20
grid_size = 10
coords = np.linspace(v_min, v_max, grid_size)
x, y1, y2 = np.meshgrid(coords, coords, coords)

s1 = S[0, 0].numpy()
s2 = S[0, 1].numpy()
d1 = (beta[0] - x * y1) * s1 + eps[0] * s1**0.5
d2 = (beta[1] - x * y2) * s2 + eps[1] * s2**0.5

# Field components along u, z_1 and z_2.
v = d1 * y1 + d2 * y2
w1, w2 = d1 * x, d2 * x

# length=0.0001 shrinks the arrows; presumably the raw field magnitudes on
# this grid are too large to draw otherwise.
ax.quiver(x, y1, y2, v, w1, w2, length=0.0001)

ax.set_xlabel(fr"$u$")
ax.set_ylabel(fr"$z_1$")
ax.set_zlabel(fr"$z_2$")

plt.show()

In [None]:
# Wtot: vector field of the end-to-end weights w = u * z in the (w_1, w_2)
# plane, for slices where one of u, z_1, z_2 is held fixed.


def _plot_wtot_slice(panel, u, z1, z2, title):
    """Draw the end-to-end weight field for one slice of (u, z_1, z_2).

    One of ``u``, ``z1``, ``z2`` is a scalar and the other two are meshgrid
    arrays of the same shape. Positions are w_i = u * z_i; arrows combine the
    u- and z-components of the field by the chain rule,
    dw_i = (du) z_i + u (dz_i).
    """
    s1 = S[0, 0].numpy()
    s2 = S[0, 1].numpy()
    d1 = (beta[0] - u * z1) * s1 + eps[0] * s1**0.5
    d2 = (beta[1] - u * z2) * s2 + eps[1] * s2**0.5

    u_rate = d1 * z1 + d2 * z2
    z1_rate = d1 * u
    z2_rate = d2 * u

    panel.quiver(u * z1, u * z2,
                 u_rate * z1 + z1_rate * u,
                 u_rate * z2 + z2_rate * u)
    panel.set_xlabel(fr"$w_1$")
    panel.set_ylabel(fr"$w_2$")
    panel.set_title(title)


fig, ax = plt.subplots(1, 5, figsize=(20, 4))

v_min, v_max = -20, 20
grid_size = 10
grid = np.linspace(v_min, v_max, grid_size)

# Panels 0-2: u fixed, (z_1, z_2) on the grid.
for panel, u_fixed in zip(ax[:3], (1, 5, 10)):
    z1_grid, z2_grid = np.meshgrid(grid, grid)
    _plot_wtot_slice(panel, u_fixed, z1_grid, z2_grid, f"u = {u_fixed}")

# Panel 3: z_1 fixed, (u, z_2) on the grid.
z1_fixed = 5
u_grid, z2_grid = np.meshgrid(grid, grid)
_plot_wtot_slice(ax[3], u_grid, z1_fixed, z2_grid, f"$z_1$ = {z1_fixed}")

# Panel 4: z_2 fixed, (u, z_1) on the grid.
z2_fixed = 5
u_grid, z1_grid = np.meshgrid(grid, grid)
_plot_wtot_slice(ax[4], u_grid, z1_grid, z2_fixed, f"$z_2$ = {z2_fixed}")

In [None]:
# MSE: squared distance L of the end-to-end weights u*z from beta, plotted
# against one z coordinate, with arrows (dz_i, dL) along slices where one of
# u, z_1, z_2 is held fixed.


def _mse_slice(u, z1, z2):
    """Return (dz1, dz2, r, rm) for one slice of (u, z_1, z_2).

    One of ``u``, ``z1``, ``z2`` is a scalar and the other two are meshgrid
    arrays. ``r`` is the squared distance of w = u*z from beta; ``rm``
    combines the end-to-end weight rates q_i = (du) z_i + u (dz_i) with
    (w_i - beta_v_i), where beta_v = beta @ V^T.
    """
    s1 = S[0, 0].numpy()
    s2 = S[0, 1].numpy()
    d1 = (beta[0] - u * z1) * s1 + eps[0] * s1**0.5
    d2 = (beta[1] - u * z2) * s2 + eps[1] * s2**0.5

    u_rate = d1 * z1 + d2 * z2
    z1_rate = d1 * u
    z2_rate = d2 * u

    q1 = u_rate * z1 + z1_rate * u
    q2 = u_rate * z2 + z2_rate * u

    # Open question: does V cancel out here?
    r = (beta[0] - u * z1)**2 + (beta[1] - u * z2)**2

    # TODO: do all other equations assume V=I or does it not matter? See log 22/11 09:42.
    beta_v = beta @ np.transpose(V)
    rm = (u * z1 - beta_v[0]) * q1 + (u * z2 - beta_v[1]) * q2
    return z1_rate, z2_rate, r, rm


fig, ax = plt.subplots(1, 3, figsize=(15, 4))

v_min, v_max = -5, 5
grid_size = 10
grid = np.linspace(v_min, v_max, grid_size)

# Panel 0: u fixed, plotted against z_1.
u_fixed = 5
z1_grid, z2_grid = np.meshgrid(grid, grid)
dz1, _, r, rm = _mse_slice(u_fixed, z1_grid, z2_grid)
ax[0].quiver(z1_grid, r, dz1, rm)
ax[0].set_xlabel(fr"$z_1$")
ax[0].set_ylabel(fr"$L$")
ax[0].set_title(f"u = {u_fixed}")

# Panel 1: z_1 fixed, plotted against z_2.
z1_fixed = 5
u_grid, z2_grid = np.meshgrid(grid, grid)
_, dz2, r, rm = _mse_slice(u_grid, z1_fixed, z2_grid)
ax[1].quiver(z2_grid, r, dz2, rm)
ax[1].set_xlabel(fr"$z_2$")
ax[1].set_ylabel(fr"$L$")
ax[1].set_title(f"$z_1$ = {z1_fixed}")

# Panel 2: z_2 fixed, plotted against z_1.
z2_fixed = 5
u_grid, z1_grid = np.meshgrid(grid, grid)
dz1, _, r, rm = _mse_slice(u_grid, z1_grid, z2_fixed)
ax[2].quiver(z1_grid, r, dz1, rm)
ax[2].set_xlabel(fr"$z_1$")
ax[2].set_ylabel(fr"$L$")
ax[2].set_title(f"$z_2$ = {z2_fixed}")



In [None]:
# Comparison linear model
# Compares the rate of change of the end-to-end weights w = u * z for a random
# two-layer initialisation with a one-layer model evaluated at the same z.
# NOTE(review): dudt_s / dzdt_s are defined elsewhere in the notebook;
# presumably they return du/dt and dz/dt for the given (u, z) -- confirm.

# Two-layer model
# Random initialisation with standard deviation args.scales[0]; z is a
# (1, args.dim) row vector. `w_init` is reused by the next cell.
w_init = args.scales[0]
u = torch.normal(0, torch.tensor(w_init)) #, generator=g_cpu) # u_init.clone() 
z = torch.normal(0, torch.tensor(w_init), size=(1, args.dim)) #, generator=g_cpu) # z_init.T.clone()

grad_u = dudt_s(u, z, S, beta_tensor, eps_tensor)
grad_z = dzdt_s(u, z, S, beta_tensor, eps_tensor)

# Chain rule for w = u * z: dw/dt = (du/dt) z + u (dz/dt).
grad_w_tot = grad_u * z + u * grad_z
# Gap between the first two coordinates of dw/dt (requires args.dim >= 2).
diff_w = grad_w_tot[0, 0] - grad_w_tot[0, 1]
print(diff_w)

# One-layer model
# NOTE(review): presumably the one-layer model is dz/dt evaluated with u fixed
# at 1.0 -- confirm against the definition of dzdt_s.
grad_w_tot_comp = dzdt_s(1.0, z, S, beta_tensor, eps_tensor)
diff_w_comp = grad_w_tot_comp[0, 0] - grad_w_tot_comp[0, 1]
print(diff_w_comp)

In [None]:
# Comparison linear model
# Per-coordinate comparison of dw/dt between a fresh random two-layer
# initialisation and the one-layer model at the same z. Reuses `w_init` from
# the previous cell.

# Two-layer model
u = torch.normal(0, torch.tensor(w_init)) #, generator=g_cpu) # u_init.clone() 
z = torch.normal(0, torch.tensor(w_init), size=(1, args.dim)) #, generator=g_cpu) # z_init.T.clone()

grad_u = dudt_s(u, z, S, beta_tensor, eps_tensor)
grad_z = dzdt_s(u, z, S, beta_tensor, eps_tensor)

# Chain rule for w = u * z: dw/dt = (du/dt) z + u (dz/dt).
grad_w_tot = grad_u * z + u * grad_z


# One-layer model
# NOTE(review): presumably dz/dt with u fixed at 1.0 -- confirm.
grad_w_tot_comp = dzdt_s(1.0, z, S, beta_tensor, eps_tensor)

# Two-layer minus one-layer rate for coordinates 0 and 1 (suffixes l / s
# presumably mean "large" / "small" -- confirm).
diff_l = grad_w_tot[0, 0] - grad_w_tot_comp[0, 0]
diff_s = grad_w_tot[0, 1] - grad_w_tot_comp[0, 1]


print(grad_w_tot)
print(grad_w_tot_comp)

# Observation: here the gradient is actually larger for the linear model...

print(diff_l)
print(diff_s)