## High-dimensional dynamics of generalization error in neural networks

An attempt to reproduce Figure 5B of the paper above and, with it, the double-descent behaviour of the generalization error.


In [None]:
#!/usr/bin/env python
# coding: utf-8

%matplotlib inline
import matplotlib
import matplotlib.pyplot as plt

from mpl_toolkits.mplot3d import axes3d

import argparse
import os
import datetime
import pathlib
import random
import json
import numpy as np
import math

import torch

import sys
sys.path.append('../code/')
from linear_utils import linear_model
from train_utils import save_config

In [None]:
# Experiment configuration, written exactly as it would be typed on the command line.
cli_args = '--seed 12 --save-results --risk-loss L2 -t 20000 -w 0.1 0.1 --lr 0.001 -d 2 -n 1000 --hidden 50 --sigmas 1 --kappa 5'

# Settings passed straight to the data-generating linear model (not part of the CLI).
transform_data = True
sigma_noise = 2.0
beta = np.full(2, 1.0)  # target weights; the learned end-to-end weights are compared against these

# Alternative, higher-dimensional configuration kept for reference:
#cli_args = '--seed 12 --save-results --jacobian --risk-loss L2 -t 20000 -w 0.1 0.1 --lr 0.00001 -d 50 -n 1000 --hidden 50 --sigmas 1 --kappa 3'
#sigma_noise = 1.0



In [None]:
"""
A fully-connected network with one hidden layer, trained to predict y from x
by minimizing the MSE loss.
"""

# get CLI parameters
parser = argparse.ArgumentParser(description='CLI parameters for training')
parser.add_argument('--root', type=str, default='', metavar='DIR',
                    help='Root directory')
# argparse only runs `type` on string defaults, so give the int default directly.
parser.add_argument('-t', '--iterations', type=int, default=10000, metavar='ITERATIONS',
                    help='Iterations (default: 1e4)')
parser.add_argument('-n', '--samples', type=int, default=100, metavar='N',
                    help='Number of samples (default: 100)')
parser.add_argument('--print-freq', type=int, default=1000,
                    help='CLI output printing frequency (default: 1000)')
parser.add_argument('--gpu', type=int, default=None,
                    help='Number of GPUS to use')
parser.add_argument('--seed', type=int, default=None,
                    help='Random seed')
parser.add_argument('-d', '--dim', type=int, default=50, metavar='DIMENSION',
                    help='Feature dimension (default: 50)')
parser.add_argument('--hidden', type=int, default=200, metavar='DIMENSION',
                    help='Hidden layer dimension (default: 200)')
parser.add_argument('--sigmas', type=str, default=None,
                    help='Sigmas')
parser.add_argument('-r', '--s-range', nargs='*', type=float,
                    help='Range for sigmas')
parser.add_argument('--kappa', type=float,
                    help='Eigenvalue ratio')
parser.add_argument('-w', '--scales', nargs='*', type=float,
                    help='scale of the weights')
# With nargs='*' a user-supplied value is always a list; make the default a list
# too, so code that does max(args.lr) / args.lr[0] also works without '--lr'.
parser.add_argument('--lr', type=float, default=[1e-4], nargs='*', metavar='LR',
                    help='learning rate (default: 1e-4)')
parser.add_argument('--normalized', action='store_true', default=False,
                    help='normalize sample norm across features')
parser.add_argument('--risk-loss', type=str, default='MSE', metavar='LOSS',
                    help='Loss for validation')
parser.add_argument('--jacobian', action='store_true', default=False,
                    help='compute the SVD of the jacobian of the network')
parser.add_argument('--save-results', action='store_true', default=False,
                    help='Save the results for plots')
parser.add_argument('--details', type=str, metavar='N',
                    default='no_detail_given',
                    help='details about the experimental setup')


args = parser.parse_args(cli_args.split())

# directories: '--root' overrides the default (parent of the current working directory)
root = pathlib.Path(args.root) if args.root else pathlib.Path.cwd().parent

current_date = datetime.datetime.today().strftime('%Y-%m-%d-%H-%M-%S')
# Results are written under <root>/results/two_layer_nn/<timestamp>.
args.outpath = root / 'results' / 'two_layer_nn' / current_date

if args.save_results:
    args.outpath.mkdir(exist_ok=True, parents=True)

# Seed the Python, torch and numpy RNGs so runs are reproducible.
if args.seed is not None:
    random.seed(args.seed)
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)

device = torch.device('cpu')
# device = torch.device('cuda') # Uncomment this to run on GPU

In [None]:
d_out = 1  # output dimension: y is a scalar target

# Linear model (from linear_utils) that generates both the training and the held-out data.
lin_model = linear_model(
    args.dim,
    sigma_noise=sigma_noise,
    beta=beta,
    normalized=False,
    sigmas=args.sigmas,
    s_range=args.s_range,
    coupled_noise=False,
    transform_data=transform_data,
    kappa=args.kappa,
)


def _to_float_tensors(features, targets):
    """Convert a (features, targets) pair to float32 tensors on ``device``; targets become a column."""
    feature_tensor = torch.tensor(features, dtype=torch.float32).to(device)
    target_tensor = torch.tensor(targets.reshape((-1, 1)), dtype=torch.float32).to(device)
    return feature_tensor, target_tensor


# Training set (sampled first, so the random stream is consumed in the same order).
Xs, ys = _to_float_tensors(*lin_model.sample(args.samples, train=True))
# Held-out set used for the empirical risk.
Xt, yt = _to_float_tensors(*lin_model.sample(args.samples, train=False))


In [None]:
# Training objective: squared error summed over the training set.
loss_fn = torch.nn.MSELoss(reduction='sum')
# Held-out risk: mean absolute error for '--risk-loss L1', otherwise the training loss itself.
if args.risk_loss == 'L1':
    risk_fn = torch.nn.L1Loss(reduction='mean')
else:
    risk_fn = loss_fn

### Empirical

In [None]:
# Two bias-free linear layers: dim -> hidden -> d_out (the nonlinearity is disabled).
model = torch.nn.Sequential(
    torch.nn.Linear(args.dim, args.hidden, bias=False),
    # torch.nn.Sigmoid(),
    torch.nn.Linear(args.hidden, d_out, bias=False),
).to(device)

# Re-initialize with Kaiming schemes and rescale each layer by its '-w/--scales'
# factor: layer 0 uses kaiming_normal_, layer 1 kaiming_uniform_.
if args.scales:
    linear_layers = [m for m in model if isinstance(m, torch.nn.Linear)]
    with torch.no_grad():
        for layer_idx, layer in enumerate(linear_layers):
            if layer_idx == 0:
                torch.nn.init.kaiming_normal_(layer.weight, a=math.sqrt(5))
                layer.weight.mul_(args.scales[0])
            elif layer_idx == 1:
                torch.nn.init.kaiming_uniform_(layer.weight, a=math.sqrt(5))
                layer.weight.mul_(args.scales[1])

# use same learning rate for the two layers (the largest '--lr' value); accept a
# bare float too, so `stepsize` is always defined.
lr_values = args.lr if isinstance(args.lr, list) else [args.lr]
stepsize = [max(lr_values)] * 2

In [None]:
# train the network with full-batch gradient descent
losses_emp = []
risks_emp = []
mse_weights_emp = []
linear_layers = [m for m in model if isinstance(m, torch.nn.Linear)]
for t in range(int(args.iterations)):
    y_pred = model(Xs)

    loss = loss_fn(y_pred, ys)
    losses_emp.append(loss.item())

    if not t % args.print_freq:
        print(t, loss.item())

    model.zero_grad()
    loss.backward()
    with torch.no_grad():
        # End-to-end linear map accumulated as I @ W1^T @ W2^T, shape (dim, d_out);
        # built on the data's device so it also works when running on GPU.
        w_tot = torch.eye(args.dim, device=Xs.device)
        for layer_idx, layer in enumerate(linear_layers):
            # Every parameter of a layer uses that layer's step size.
            for param in layer.parameters():
                param -= stepsize[layer_idx] * param.grad
            # Only weights enter the end-to-end map (biases would not compose this way).
            w_tot = w_tot @ layer.weight.t()

        w_tot = w_tot.squeeze()
        assert w_tot.shape == beta.shape
        # Squared error of each end-to-end weight w.r.t. beta, stored as a numpy row.
        mse_weights_emp.append((w_tot.cpu().numpy() - beta) ** 2)

    with torch.no_grad():
        yt_pred = model(Xt)

        risk = risk_fn(yt_pred, yt)
        risks_emp.append(risk.item())

        if not t % args.print_freq:
            print(t, risk.item())

In [None]:
# ~700 log-spaced iteration indices (truncated to int, may repeat at the low end) for log-scale plots.
geo_samples = np.geomspace(1, len(risks_emp) - 1, num=700).astype(int).tolist()
risks_w = np.vstack(mse_weights_emp)

In [None]:
risks = np.array(risks_emp)
losses = np.array(losses_emp)
# One row per iteration, one column per weight (np.row_stack is a deprecated alias of vstack).
risks_w = np.vstack(mse_weights_emp)

# matplotlib.cm.get_cmap was removed in Matplotlib 3.9; pyplot.get_cmap is the supported accessor.
cmap = plt.get_cmap('viridis')
colorList = [cmap(50/1000), cmap(350/1000), cmap(700/1000)]
labelList = ['empirical', 'theoretical']

# Panels: risk, training loss, one per weight coordinate, and finally the mean over weights.
fig, ax = plt.subplots(3 + args.dim, 1, figsize=(12, 12 + 4 * args.dim))
ax[0].set_xscale('log')

ax[0].plot(geo_samples, risks[geo_samples],
        color=colorList[1],
        label=labelList[0],
        lw=4)

ax[0].legend(loc=1, bbox_to_anchor=(1, 1), fontsize='x-large',
    frameon=False, fancybox=True, shadow=True, ncol=1)
ax[0].set_ylabel('risk')
ax[0].set_xlabel(r'$t$ iterations')

ax[1].set_xscale('log')
ax[1].plot(geo_samples, losses[geo_samples],
        color=colorList[1],
        label=labelList[0],
        lw=4)

ax[1].set_ylabel('loss')
ax[1].set_xlabel(r'$t$ iterations')

for i in range(args.dim):
    ax[2+i].set_xscale('log')
    ax[2+i].plot(geo_samples, risks_w[geo_samples, i],
            color=colorList[2],
            label=labelList[0],
            lw=4)

    ax[2+i].set_ylabel('MSE weights, ' + str(i))
    ax[2+i].set_xlabel(r'$t$ iterations')


ax[-1].set_xscale('log')
ax[-1].plot(geo_samples, risks_w[geo_samples, :].mean(axis=-1),
        color=colorList[2],
        label=labelList[0],
        lw=4)

ax[-1].set_ylabel('MSE weights')
ax[-1].set_xlabel(r'$t$ iterations')


plt.show()

In [None]:
# Zoom in on the empirical risk after the first 1000 iterations.
plt.plot(risks_emp[slice(1000, None)])

## Theoretical 

In [None]:
# Dynamics driven by the actual input data (projected targets St).
def dt(u, z, S, St):
    """Return the drift term ``St - u * z * S`` shared by ``dzdt`` and ``dudt``.

    In this notebook ``S`` holds the squared singular values of the training
    inputs, laid out like ``z``; the shapes of ``z`` and ``S`` must match
    (checked by assertion).
    """
    assert z.shape == S.shape
    scaled = u * z * S
    return St - scaled


def dzdt(u, z, S, St):
    """Time derivative of the weight vector ``z``: ``u`` times the drift term ``dt``."""
    drift = dt(u, z, S, St)
    return u * drift


def dudt(u, z, S, St):
    """Time derivative of the scalar weight ``u``: the drift term projected onto ``z``, squeezed."""
    drift = dt(u, z, S, St)
    projected = drift @ z.T
    return projected.squeeze()


# Only the output noise is sampled; the true weights beta are assumed known.
def dt_s(u, z, S, beta, eps):
    """Return the drift term ``(beta - u * z) * S + eps * S**0.5``.

    ``z`` and ``S`` must have identical shapes (checked by assertion).
    """
    assert z.shape == S.shape
    signal = (beta - u * z) * S
    noise = eps * S**0.5
    return signal + noise


def dzdt_s(u, z, S, beta, eps):
    """Time derivative of ``z`` under the noise-only model: ``u`` times ``dt_s``."""
    drift = dt_s(u, z, S, beta, eps)
    return u * drift


def dudt_s(u, z, S, beta, eps):
    """Time derivative of ``u`` under the noise-only model: ``dt_s`` projected onto ``z``, squeezed."""
    drift = dt_s(u, z, S, beta, eps)
    projected = drift @ z.T
    return projected.squeeze()


In [None]:
# Transposed copies used by the theory below (X as features x samples, y as 1 x samples);
# the original Xs, ys, Xt, yt are left untouched.
Xs_t, ys_t, Xt_t, yt_t = Xs.T, ys.T, Xt.T, yt.T
Xs_np = Xs_t.numpy()

if not transform_data:
    # Rotate using the full SVD of the training inputs.
    V, s, Uh = np.linalg.svd(Xs_np, full_matrices=True)
else:
    # With transform_data the notebook takes V as the identity and Uh from the
    # linear model; only the singular values come from the sample SVD.
    V = np.eye(args.dim)
    Uh = np.transpose(lin_model.right_singular_vecs)
    _, s, _ = np.linalg.svd(Xs_np, full_matrices=True)

V_tensor = torch.tensor(V, dtype=torch.float32)
Uh_tensor = torch.tensor(Uh, dtype=torch.float32)

# Squared singular values, zero-padded to length dim, as a (1, dim) row.
n_missing = args.dim - s.shape[0]
spectrum = np.concatenate((s ** 2, np.zeros(n_missing)))
S = torch.tensor(spectrum.reshape(1, -1), dtype=torch.float32)
St = ys_t @ Xs_t.T @ V_tensor

# Unused alternative: draw eps independently as torch.randn(size=(1, args.dim)) * sigma_noise.
# The version below is computed from the training data, so it also depends on the inputs.
beta_tensor = torch.tensor(beta, dtype=torch.float32).reshape(1, -1)
residual = ys_t - beta_tensor @ Xs_t
eps_tensor = (residual @ Uh_tensor.T)[:, :args.dim]

In [None]:
# Simulation of the reduced model: end-to-end weights Wtot = u * z @ V^T.
w_init = args.scales[0]
u = torch.normal(0, torch.tensor(w_init))
z = torch.normal(0, torch.tensor(w_init), size=(1, args.dim))
print(u)
print(z)

# Same learning-rate rule as the empirical network: the largest '--lr' value.
lr = max(args.lr) if isinstance(args.lr, list) else args.lr

losses_teo = []
risks_teo = []
mse_weights_teo = []
for t in range(int(args.iterations)):
    # Evaluate both derivatives at the current (u, z) before updating either, so
    # each step is a simultaneous update, as in gradient descent on the network.
    # NOTE(review): loss_fn is MSELoss(reduction='sum'), whose gradient carries a
    # factor 2 that dt_s does not include — confirm the intended step-size scaling.
    du = dudt_s(u, z, S, beta_tensor, eps_tensor)  # dudt(u, z, S, St)
    dz = dzdt_s(u, z, S, beta_tensor, eps_tensor)  # dzdt(u, z, S, St)
    u = u + lr * du
    z = z + lr * dz

    Wtot = u * z @ V_tensor.T

    y_pred = Wtot @ Xs_t

    loss = loss_fn(y_pred.T, ys_t.T)
    losses_teo.append(loss.item())

    mse_weights_teo.append(((Wtot.squeeze() - beta) ** 2))

    if not t % args.print_freq:
        print(t, loss.item())

    yt_pred = Wtot @ Xt_t

    risk = risk_fn(yt_pred.T, yt_t.T)
    risks_teo.append(risk.item())

    if not t % args.print_freq:
        print(t, risk.item())

In [None]:
# ~700 log-spaced iteration indices (truncated to int) for the theoretical log-scale plots.
geo_samples = np.geomspace(1, len(risks_teo) - 1, num=700).astype(int).tolist()

In [None]:
risks = np.array(risks_teo)
losses = np.array(losses_teo)
# One row per iteration, one column per weight (np.row_stack is a deprecated alias of vstack).
risks_w = np.vstack(mse_weights_teo)

# matplotlib.cm.get_cmap was removed in Matplotlib 3.9; pyplot.get_cmap is the supported accessor.
cmap = plt.get_cmap('viridis')
colorList = [cmap(50/1000), cmap(350/1000), cmap(700/1000)]
labelList = ['empirical', 'theoretical']

# Panels: risk, training loss, one per weight coordinate, and finally the mean over weights.
fig, ax = plt.subplots(3 + args.dim, 1, figsize=(12, 12 + 4 * args.dim))
ax[0].set_xscale('log')

ax[0].plot(geo_samples, risks[geo_samples],
        color=colorList[1],
        label=labelList[1],
        lw=4)

ax[0].legend(loc=1, bbox_to_anchor=(1, 1), fontsize='x-large',
    frameon=False, fancybox=True, shadow=True, ncol=1)
ax[0].set_ylabel('risk')
ax[0].set_xlabel(r'$t$ iterations')

ax[1].set_xscale('log')
ax[1].plot(geo_samples, losses[geo_samples],
        color=colorList[1],
        label=labelList[1],
        lw=4)

ax[1].set_ylabel('loss')
ax[1].set_xlabel(r'$t$ iterations')

for i in range(args.dim):
    ax[2+i].set_xscale('log')
    ax[2+i].plot(geo_samples, risks_w[geo_samples, i],
            color=colorList[2],
            label=labelList[1],
            lw=4)

    ax[2+i].set_ylabel('MSE weights, ' + str(i))
    ax[2+i].set_xlabel(r'$t$ iterations')


ax[-1].set_xscale('log')
ax[-1].plot(geo_samples, risks_w[geo_samples, :].mean(axis=-1),
        color=colorList[2],
        label=labelList[1],
        lw=4)

ax[-1].set_ylabel('MSE weights')
ax[-1].set_xlabel(r'$t$ iterations')


plt.show()

In [None]:
# Tail of the theoretical risk curve: the last 2000 iterations, including the final one.
plt.plot(risks_teo[-2000:])

In [None]:
# Mean squared weight error of the theoretical model over its last 2000 iterations (final one included).
plt.plot(np.vstack(mse_weights_teo[-2000:]).mean(axis=-1))

# VECTOR FIELD

In [None]:
# Quiver plots of the reduced dynamics in the (u, z_1) plane, one panel per fixed z_2.
# Arrows are (du/dt, dz1/dt) = (d1*z1 + d2*z2, u*d1), where
# d_i = (beta_i - u*z_i) * S_i + eps_i * sqrt(S_i), the same form as dt_s.
# Trailing comments on the d lines give the St-based alternative (cf. dt).
fig, ax = plt.subplots(1, 3, figsize=(15, 4))

v_min, v_max = -10, 10
grid_size = 20
x, y1 = np.meshgrid(np.linspace(v_min, v_max, grid_size),
                    np.linspace(v_min, v_max, grid_size))

# Example 1
y2_1 = 1

# Training-set output noise: residual of ys w.r.t. beta (eps_0) and its projection
# onto the rows of Uh (eps). eps is reused by the vector-field cells below.
eps_0 = ys_t.numpy() - beta @ Xs_t.numpy()
eps = (eps_0 @ np.transpose(Uh)).squeeze()

d1 = (beta[0] - x * y1) * S[0, 0].numpy() + eps[0] * S[0, 0].numpy()**0.5  #St[0, 0].numpy() - x * y1 * S[0, 0].numpy()
d2 = (beta[1] - x * y2_1) * S[0, 1].numpy() + eps[1] * S[0, 1].numpy()**0.5  #St[0, 1].numpy() - x * y2_1 * S[0, 1].numpy()

# du/dt and dz1/dt; dz1/dt (w1) does not depend on z_2, so all three panels share it.
v = d1 * y1 + d2 * y2_1
w1 = d1 * x

ax[0].quiver(x, y1, v, w1)

ax[0].set_xlabel("u")
ax[0].set_ylabel(fr"$z_1$")
ax[0].set_title(f"$z_2$ = {y2_1}")


# Example 2: only the z_2-dependent term d2 is recomputed.
y2_2 = 5
d2_2 = (beta[1] - x * y2_2) * S[0, 1].numpy() + eps[1] * S[0, 1].numpy()**0.5  #St[0, 1].numpy() - x * y2_2 * S[0, 1].numpy()

v_2 = d1 * y1 + d2_2 * y2_2

ax[1].quiver(x, y1, v_2, w1)

ax[1].set_xlabel("u")
ax[1].set_ylabel(fr"$z_1$")

ax[1].set_title(f"$z_2$ = {y2_2}")


# Example 3: only the z_2-dependent term d2 is recomputed.
y2_3 = 10
d2_3 = (beta[1] - x * y2_3) * S[0, 1].numpy() + eps[1] * S[0, 1].numpy()**0.5  #St[0, 1].numpy() - x * y2_3 * S[0, 1].numpy()

v_3 = d1 * y1 + d2_3 * y2_3

ax[2].quiver(x, y1, v_3, w1)

ax[2].set_xlabel("u")
ax[2].set_ylabel(fr"$z_1$")

ax[2].set_title(f"$z_2$ = {y2_3}")


plt.show()

In [None]:
# Quiver plots of the reduced dynamics in the (u, z_2) plane, one panel per fixed z_1.
# Arrows are (du/dt, dz2/dt) = (d1*z1 + d2*z2, u*d2) with d_i of the dt_s form.
fig, ax = plt.subplots(1, 3, figsize=(15, 4))

v_min, v_max = -10, 10
grid_size = 20
x, y2 = np.meshgrid(np.linspace(v_min, v_max, grid_size),
                    np.linspace(v_min, v_max, grid_size))

# Example 1
y1_1 = 1

d1 = (beta[0] - x * y1_1) * S[0, 0].numpy() + eps[0] * S[0, 0].numpy()**0.5  # St[0, 0].numpy() - x * y1_1 * S[0, 0].numpy()
d2 = (beta[1] - x * y2) * S[0, 1].numpy() + eps[1] * S[0, 1].numpy()**0.5  # St[0, 1].numpy() - x * y2 * S[0, 1].numpy()

v = d1 * y1_1 + d2 * y2
# dz2/dt does not depend on z_1, so all three panels share it.
w2 = d2 * x

ax[0].quiver(x, y2, v, w2)

ax[0].set_xlabel("u")
ax[0].set_ylabel(fr"$z_2$")
ax[0].set_title(f"$z_1$ = {y1_1}")


# Example 2
y1_2 = 2
d1_2 = (beta[0] - x * y1_2) * S[0, 0].numpy() + eps[0] * S[0, 0].numpy()**0.5  # St[0, 0].numpy() - x * y1_2 * S[0, 0].numpy()

v_2 = d1_2 * y1_2 + d2 * y2

ax[1].quiver(x, y2, v_2, w2)

ax[1].set_xlabel("u")
ax[1].set_ylabel(fr"$z_2$")

ax[1].set_title(f"$z_1$ = {y1_2}")


# Example 3
y1_3 = 10
# d1 must be evaluated at this panel's own z_1 (y1_3), matching the title.
d1_3 = (beta[0] - x * y1_3) * S[0, 0].numpy() + eps[0] * S[0, 0].numpy()**0.5  # St[0, 0].numpy() - x * y1_3 * S[0, 0].numpy()

v_3 = d1_3 * y1_3 + d2 * y2

ax[2].quiver(x, y2, v_3, w2)

ax[2].set_xlabel("u")
ax[2].set_ylabel(fr"$z_2$")

ax[2].set_title(f"$z_1$ = {y1_3}")


plt.show()

In [None]:
# Quiver plots of (dz1/dt, dz2/dt) = (u*d1, u*d2) in the (z_1, z_2) plane, one panel per fixed u,
# with d_i = (beta_i - u*z_i) * S_i + eps_i * sqrt(S_i) (the dt_s form).
fig, ax = plt.subplots(1, 3, figsize=(15, 4))

v_min, v_max = -20, 20
grid_size = 20
y1, y2 = np.meshgrid(np.linspace(v_min, v_max, grid_size),
                    np.linspace(v_min, v_max, grid_size))


# Example 1
x_1 = 0.1
d1 = (beta[0] - x_1 * y1) * S[0, 0].numpy() + eps[0] * S[0, 0].numpy()**0.5  # St[0, 0].numpy() - x_1 * y1 * S[0, 0].numpy()
d2 = (beta[1] - x_1 * y2) * S[0, 1].numpy() + eps[1] * S[0, 1].numpy()**0.5  # St[0, 1].numpy() - x_1 * y2 * S[0, 1].numpy()

w1 = d1 * x_1
w2 = d2 * x_1

ax[0].quiver(y1, y2, w1, w2)

ax[0].set_xlabel(fr"$z_1$")
ax[0].set_ylabel(fr"$z_2$")
ax[0].set_title(f"u = {x_1}")


# Example 2
x_2 = 10
# Use the projected noise eps, as dt_s and every other panel do, rather than the
# raw residuals of the first two training samples.
# NOTE(review): x_2 equals x_3, so panels 2 and 3 show the same field — pick a
# distinct u here if a third setting is wanted.
d1_2 = (beta[0] - x_2 * y1) * S[0, 0].numpy() + eps[0] * S[0, 0].numpy()**0.5  # St[0, 0].numpy() - x_2 * y1 * S[0, 0].numpy()
d2_2 = (beta[1] - x_2 * y2) * S[0, 1].numpy() + eps[1] * S[0, 1].numpy()**0.5  # St[0, 1].numpy() - x_2 * y2 * S[0, 1].numpy()

w1_2 = d1_2 * x_2
w2_2 = d2_2 * x_2

ax[1].quiver(y1, y2, w1_2, w2_2)

ax[1].set_xlabel(fr"$z_1$")
ax[1].set_ylabel(fr"$z_2$")
ax[1].set_title(f"u = {x_2}")


# Example 3
x_3 = 10

d1_3 = (beta[0] - x_3 * y1) * S[0, 0].numpy() + eps[0] * S[0, 0].numpy()**0.5  # St[0, 0].numpy() - x_3 * y1 * S[0, 0].numpy()
d2_3 = (beta[1] - x_3 * y2) * S[0, 1].numpy() + eps[1] * S[0, 1].numpy()**0.5  # St[0, 1].numpy() - x_3 * y2 * S[0, 1].numpy()

w1_3 = d1_3 * x_3
w2_3 = d2_3 * x_3

ax[2].quiver(y1, y2, w1_3, w2_3)

ax[2].set_xlabel(fr"$z_1$")
ax[2].set_ylabel(fr"$z_2$")
ax[2].set_title(f"u = {x_3}")

plt.show()

# For kappa > 1 we move mainly in the z1 direction
# u shifts the minima
# At some point we should see the weights move further away from the true weights; but it is unclear whether we actually see that?

In [None]:
# 3-D quiver plot of the full reduced flow (du/dt, dz1/dt, dz2/dt) over a (u, z_1, z_2) grid.
fig = plt.figure(figsize=plt.figaspect(0.5))
ax = fig.add_subplot(1, 2, 1, projection='3d')

v_min, v_max = -20, 20
grid_size = 10
grid_axis = np.linspace(v_min, v_max, grid_size)
x, y1, y2 = np.meshgrid(grid_axis, grid_axis, grid_axis)

# Drift terms of the dt_s form, one per spectral direction.
s_first, s_second = S[0, 0].numpy(), S[0, 1].numpy()
d1 = (beta[0] - x * y1) * s_first + eps[0] * s_first**0.5
d2 = (beta[1] - x * y2) * s_second + eps[1] * s_second**0.5

v = d1 * y1 + d2 * y2    # du/dt
w1, w2 = d1 * x, d2 * x  # dz1/dt, dz2/dt

ax.quiver(x, y1, y2, v, w1, w2, length=0.0001)

ax.set_xlabel(fr"$u$")
ax.set_ylabel(fr"$z_1$")
ax.set_zlabel(fr"$z_2$")


plt.show()

In [None]:
# Wtot
# Vector fields of the end-to-end weights w = u * z in the (w_1, w_2) plane.
# Arrows are dw_i/dt = (du/dt) * z_i + u * (dz_i/dt) (product rule). Each panel
# fixes one of u, z_1, z_2 and sweeps the other two; the names x, y1, y2 are
# rebound from section to section, so statement order matters.

fig, ax = plt.subplots(1, 3, figsize=(15, 4))


# u fixed; sweep z_1 and z_2
v_min, v_max = -20, 20
grid_size = 10
x = 5
y1, y2 = np.meshgrid(np.linspace(v_min, v_max, grid_size),
                     np.linspace(v_min, v_max, grid_size))


# Drift terms of the dt_s form (trailing comments: the St-based alternative).
d1 = (beta[0] - x * y1) * S[0, 0].numpy() + eps[0] * S[0, 0].numpy()**0.5 #St[0, 0].numpy() - x * y1 * S[0, 0].numpy()
d2 = (beta[1] - x * y2) * S[0, 1].numpy() + eps[1] * S[0, 1].numpy()**0.5 #St[0, 1].numpy() - x * y2 * S[0, 1].numpy()

# du/dt, dz1/dt, dz2/dt
v = d1 * y1 + d2 * y2
w1 = d1 * x
w2 = d2 * x

# Total weights p = u * z and their time derivatives q = dp/dt
p1 = x * y1
p2 = x * y2

q1 = v * y1 + w1 * x
q2 = v * y2 + w2 * x

ax[0].quiver(p1, p2, q1, q2)

ax[0].set_xlabel(fr"$w_1$")
ax[0].set_ylabel(fr"$w_2$")
ax[0].set_title(f"u = {x}")


# z_1 fixed; sweep u and z_2
y1 = 5
x, y2 = np.meshgrid(np.linspace(v_min, v_max, grid_size),
                    np.linspace(v_min, v_max, grid_size))


# Drift terms of the dt_s form (trailing comments: the St-based alternative).
d1 = (beta[0] - x * y1) * S[0, 0].numpy() + eps[0] * S[0, 0].numpy()**0.5 #St[0, 0].numpy() - x * y1 * S[0, 0].numpy()
d2 = (beta[1] - x * y2) * S[0, 1].numpy() + eps[1] * S[0, 1].numpy()**0.5 #St[0, 1].numpy() - x * y2 * S[0, 1].numpy()

# du/dt, dz1/dt, dz2/dt
v = d1 * y1 + d2 * y2
w1 = d1 * x
w2 = d2 * x

# Total weights p = u * z and their time derivatives q = dp/dt
p1 = x * y1
p2 = x * y2

q1 = v * y1 + w1 * x
q2 = v * y2 + w2 * x

ax[1].quiver(p1, p2, q1, q2)

ax[1].set_xlabel(fr"$w_1$")
ax[1].set_ylabel(fr"$w_2$")
ax[1].set_title(f"$z_1$ = {y1}")



# z_2 fixed; sweep u and z_1
y2 = 5
x, y1 = np.meshgrid(np.linspace(v_min, v_max, grid_size),
                    np.linspace(v_min, v_max, grid_size))


# Drift terms of the dt_s form (trailing comments: the St-based alternative).
d1 = (beta[0] - x * y1) * S[0, 0].numpy() + eps[0] * S[0, 0].numpy()**0.5 #St[0, 0].numpy() - x * y1 * S[0, 0].numpy()
d2 = (beta[1] - x * y2) * S[0, 1].numpy() + eps[1] * S[0, 1].numpy()**0.5 #St[0, 1].numpy() - x * y2 * S[0, 1].numpy()

# du/dt, dz1/dt, dz2/dt
v = d1 * y1 + d2 * y2
w1 = d1 * x
w2 = d2 * x

# Total weights p = u * z and their time derivatives q = dp/dt
p1 = x * y1
p2 = x * y2

q1 = v * y1 + w1 * x
q2 = v * y2 + w2 * x

ax[2].quiver(p1, p2, q1, q2)

ax[2].set_xlabel(fr"$w_1$")
ax[2].set_ylabel(fr"$w_2$")
ax[2].set_title(f"$z_2$ = {y2}")

In [None]:
# MSE
# Phase portraits of the squared weight error r = (beta_1 - w_1)^2 + (beta_2 - w_2)^2,
# with w = u * z, plotted against one weight coordinate; arrows are
# (d coordinate/dt, dr/dt), where dr/dt = 2 * sum_i (w_i - beta_i) * dw_i/dt.
# NOTE(review): the drift terms compare beta directly with u*z; this matches the
# rotated coordinates only when V is the identity (transform_data=True sets V = I)
# — confirm before using transform_data=False.

fig, ax = plt.subplots(1, 3, figsize=(15, 4))


# u fixed; sweep z_1 and z_2
v_min, v_max = -5, 5
grid_size = 10
x = 5
y1, y2 = np.meshgrid(np.linspace(v_min, v_max, grid_size),
                     np.linspace(v_min, v_max, grid_size))


d1 = (beta[0] - x * y1) * S[0, 0].numpy() + eps[0] * S[0, 0].numpy()**0.5  # St[0, 0].numpy() - x * y1 * S[0, 0].numpy()
d2 = (beta[1] - x * y2) * S[0, 1].numpy() + eps[1] * S[0, 1].numpy()**0.5  # St[0, 1].numpy() - x * y2 * S[0, 1].numpy()

v = d1 * y1 + d2 * y2  # du/dt
w1 = d1 * x            # dz1/dt
w2 = d2 * x            # dz2/dt

# Total weights p = u * z and their time derivatives q = dp/dt
p1 = x * y1
p2 = x * y2

q1 = v * y1 + w1 * x
q2 = v * y2 + w2 * x

# r and its exact time derivative (same beta in both, including the factor 2).
r = (beta[0] - p1)**2 + (beta[1] - p2)**2
rm = 2 * ((p1 - beta[0]) * q1 + (p2 - beta[1]) * q2)

ax[0].quiver(y1, r, w1, rm)

ax[0].set_xlabel(fr"$z_1$")
ax[0].set_ylabel(fr"$L$")
ax[0].set_title(f"u = {x}")


# z_1 fixed; sweep u and z_2
y1 = 5
x, y2 = np.meshgrid(np.linspace(v_min, v_max, grid_size),
                    np.linspace(v_min, v_max, grid_size))


d1 = (beta[0] - x * y1) * S[0, 0].numpy() + eps[0] * S[0, 0].numpy()**0.5  # St[0, 0].numpy() - x * y1 * S[0, 0].numpy()
d2 = (beta[1] - x * y2) * S[0, 1].numpy() + eps[1] * S[0, 1].numpy()**0.5  # St[0, 1].numpy() - x * y2 * S[0, 1].numpy()

v = d1 * y1 + d2 * y2  # du/dt
w1 = d1 * x            # dz1/dt
w2 = d2 * x            # dz2/dt

# Total weights p = u * z and their time derivatives q = dp/dt
p1 = x * y1
p2 = x * y2

q1 = v * y1 + w1 * x
q2 = v * y2 + w2 * x

# r and its exact time derivative (same beta in both, including the factor 2).
r = (beta[0] - p1)**2 + (beta[1] - p2)**2
rm = 2 * ((p1 - beta[0]) * q1 + (p2 - beta[1]) * q2)

ax[1].quiver(y2, r, w2, rm)

ax[1].set_xlabel(fr"$z_2$")
ax[1].set_ylabel(fr"$L$")
ax[1].set_title(f"$z_1$ = {y1}")



# z_2 fixed; sweep u and z_1
y2 = 5
x, y1 = np.meshgrid(np.linspace(v_min, v_max, grid_size),
                    np.linspace(v_min, v_max, grid_size))


d1 = (beta[0] - x * y1) * S[0, 0].numpy() + eps[0] * S[0, 0].numpy()**0.5  # St[0, 0].numpy() - x * y1 * S[0, 0].numpy()
d2 = (beta[1] - x * y2) * S[0, 1].numpy() + eps[1] * S[0, 1].numpy()**0.5  # St[0, 1].numpy() - x * y2 * S[0, 1].numpy()

v = d1 * y1 + d2 * y2  # du/dt
w1 = d1 * x            # dz1/dt
w2 = d2 * x            # dz2/dt

# Total weights p = u * z and their time derivatives q = dp/dt
p1 = x * y1
p2 = x * y2

q1 = v * y1 + w1 * x
q2 = v * y2 + w2 * x

# r and its exact time derivative (same beta in both, including the factor 2).
r = (beta[0] - p1)**2 + (beta[1] - p2)**2
rm = 2 * ((p1 - beta[0]) * q1 + (p2 - beta[1]) * q2)

ax[2].quiver(y1, r, w1, rm)

ax[2].set_xlabel(fr"$z_1$")
ax[2].set_ylabel(fr"$L$")
ax[2].set_title(f"$z_2$ = {y2}")

