## Exact solutions, decoupling


In [None]:
#!/usr/bin/env python
# coding: utf-8

%matplotlib inline
import matplotlib
import matplotlib.pyplot as plt

from mpl_toolkits.mplot3d import axes3d

import argparse
import os
import datetime
import pathlib
import random
import json
import numpy as np
import math

import torch
from scipy.stats import ortho_group

import sys
sys.path.append('../code/')
from linear_utils import linear_model, get_modulation_matrix
from train_utils import save_config

In [None]:
# Synthetic low-rank data: X = Z B^T + eps, with latent Z ~ N(0, D) in R^r and
# a small isotropic noise term eps, so X is close to rank r.
d = 20       # ambient dimension
n = 1000     # number of samples
r = 5        # latent dimension
seed = None  # set to an int to make numpy draws and torch's global RNG reproducible

if seed is not None:
    np.random.seed(seed)
    torch.manual_seed(seed)

B = np.random.uniform(low=0, high=1, size=(d, r))
# Geometric latent spectrum halving at each step (4, 2, 1, 1/2, 1/4 for r = 5).
# Tied to r so that changing r keeps the covariance shape consistent.
D = np.diag(4.0 * 0.5 ** np.arange(r))
Z = np.random.multivariate_normal(mean=np.zeros(r), cov=D, size=(n,))
eps = 10**(-3) * np.random.standard_normal(size=(n, d))

X = Z @ B.T + eps

# Reference matrix that the training cells compare the learned weights against.
beta = B @ D @ B.T

In [None]:
# Float32 copies of the data for training. The regression target is X itself,
# so the network is fit to reproduce its own input.
# NOTE(review): targets == inputs here — confirm this is the intended setup.
Xs = torch.tensor(X, dtype=torch.float32)
ys = Xs.clone()

In [None]:
# Loss functions: both are the summed (not averaged) squared error, as two
# independent module instances.
loss_fn, risk_fn = (torch.nn.MSELoss(reduction='sum') for _ in range(2))

## One layer

In [None]:
# One-layer model: a single bias-free linear map R^d -> R^d.
model = torch.nn.Sequential(torch.nn.Linear(d, d, bias=False))

# Full-batch gradient-descent settings.
stepsize = 1e-6       # learning rate
iterations = 100_000  # number of gradient steps
print_freq = 10_000   # print the loss every print_freq steps

In [None]:
# Train the one-layer network with plain full-batch gradient descent, recording
# per-step loss, the end-to-end weight matrix, and its element-wise relative
# squared error against beta.
# NOTE: losses_emp[t] is the loss *before* step t's update, while ws[t] and
# mse_weights_emp[t] are taken *after* it.
losses_emp = []
mse_weights_emp = []
ws = []

# Convert beta (a numpy array) once instead of on every iteration.
beta_t = torch.as_tensor(beta).squeeze()

for t in range(int(iterations)):
    y_pred = model(Xs)

    loss = loss_fn(y_pred, ys)
    losses_emp.append(loss.item())

    if not t % print_freq:
        print(t, loss.item())

    model.zero_grad()
    loss.backward()

    with torch.no_grad():
        # Manual gradient step on each layer, while accumulating the product of
        # the transposed weights so that y_pred == Xs @ w_tot for this bias-free stack.
        w_tot = torch.eye(d)
        for param in model.parameters():
            param.data -= stepsize * param.grad
            w_tot = w_tot @ param.data.t()

        w_tot = w_tot.squeeze()
        assert w_tot.shape == beta.shape

        ws.append(w_tot)
        mse_weights_emp.append(((w_tot - beta_t) / beta_t) ** 2)

losses_1 = np.array(losses_emp)
# np.row_stack is a deprecated alias of np.vstack (NumPy 2.0).
risks_w_1 = np.vstack([mse_w.reshape(-1) for mse_w in mse_weights_emp])
w_norm_1 = np.vstack([np.linalg.norm(w) for w in ws])

## Two layers

In [None]:
# Two-layer (factorised) model: bias-free maps R^d -> R^r -> R^d, so the
# end-to-end map has rank at most r.
model = torch.nn.Sequential(
    torch.nn.Linear(d, r, bias=False),
    torch.nn.Linear(r, d, bias=False),
)

# Full-batch gradient-descent settings.
stepsize = 1e-6       # learning rate
iterations = 100_000  # number of gradient steps
print_freq = 10_000   # print the loss every print_freq steps

In [None]:
# Train the two-layer network with plain full-batch gradient descent, recording
# per-step loss, the end-to-end weight matrix, and its element-wise relative
# squared error against beta.
# NOTE: losses_emp[t] is the loss *before* step t's update, while ws[t] and
# mse_weights_emp[t] are taken *after* it.
losses_emp = []
mse_weights_emp = []
ws = []

# Convert beta (a numpy array) once instead of on every iteration.
beta_t = torch.as_tensor(beta).squeeze()

for t in range(int(iterations)):
    y_pred = model(Xs)

    loss = loss_fn(y_pred, ys)
    losses_emp.append(loss.item())

    if not t % print_freq:
        print(t, loss.item())

    model.zero_grad()
    loss.backward()

    with torch.no_grad():
        # Manual gradient step on each layer, while accumulating the product of
        # the transposed weights (W1^T W2^T) so that y_pred == Xs @ w_tot.
        w_tot = torch.eye(d)
        for param in model.parameters():
            param.data -= stepsize * param.grad
            w_tot = w_tot @ param.data.t()

        w_tot = w_tot.squeeze()
        assert w_tot.shape == beta.shape

        ws.append(w_tot)
        mse_weights_emp.append(((w_tot - beta_t) / beta_t) ** 2)

losses_2 = np.array(losses_emp)
# np.row_stack is a deprecated alias of np.vstack (NumPy 2.0).
risks_w_2 = np.vstack([mse_w.reshape(-1) for mse_w in mse_weights_emp])
w_norm_2 = np.vstack([np.linalg.norm(w) for w in ws])

In [None]:
# Log-spaced iteration indices for plotting. The range is based on both runs,
# not on whichever training cell last rebound `losses_emp`. Duplicate indices
# produced by int-truncating np.geomspace near the start are dropped.
num_steps = min(len(losses_1), len(losses_2))
geo_samples = np.unique(np.geomspace(1, num_steps - 1, num=700).astype(int))

In [None]:
# Compare the one- and two-layer runs on a log-scaled iteration axis. Panels:
# - the weight norm;
# - the training loss;
# - optionally, the relative squared error of the first few flattened entries
#   of the end-to-end matrix;
# - the mean of that error over all entries.

# matplotlib.cm.get_cmap was deprecated in Matplotlib 3.7 and removed in 3.9;
# pyplot.get_cmap returns the same colormap on old and new versions.
cmap = plt.get_cmap('viridis')
colorList = [cmap(50/1000), cmap(350/1000), cmap(700/1000)]
labelList = ['One layer', 'Two layers']

plot_all_dims = True
# How many individual (flattened) weight entries to plot, capped at what exists.
num_dims_plotted = min(5, risks_w_2.shape[-1])

# norm + loss + per-entry panels + mean panel
num_axs = 3 + num_dims_plotted if plot_all_dims else 3

fig, ax = plt.subplots(num_axs, 1, figsize=(12, 4 * num_axs))


def _plot_runs(axis, series_two_layers, series_one_layer, ylabel):
    """Draw the two-layer then one-layer curve against `geo_samples` on `axis`."""
    axis.set_xscale('log')
    axis.plot(geo_samples, series_two_layers,
              color=colorList[1], label=labelList[1], lw=4)
    axis.plot(geo_samples, series_one_layer,
              color=colorList[0], label=labelList[0], lw=4)
    axis.set_ylabel(ylabel)
    axis.set_xlabel(r'$t$ iterations')


_plot_runs(ax[0], w_norm_2[geo_samples], w_norm_1[geo_samples], 'norm')
ax[0].legend(loc=1, bbox_to_anchor=(1, 1), fontsize='x-large',
             frameon=False, fancybox=True, shadow=True, ncol=1)

_plot_runs(ax[1], losses_2[geo_samples], losses_1[geo_samples], 'loss')

if plot_all_dims:
    for i in range(num_dims_plotted):
        _plot_runs(ax[2 + i],
                   risks_w_2[geo_samples, i],
                   risks_w_1[geo_samples, i],
                   'MSE weights, ' + str(i))

_plot_runs(ax[-1],
           risks_w_2[geo_samples, :].mean(axis=-1),
           risks_w_1[geo_samples, :].mean(axis=-1),
           'MSE weights')

plt.show()

In [None]:
# Open question: is this related to the true beta having lower rank here, so
# that the one-layer model does not find this lower-rank matrix?
# Or is it rather that the one-layer dynamics move in a different space?

# Print beta's shape and then its numerical rank, one per line.
print(beta.shape, np.linalg.matrix_rank(beta), sep='\n')

In [None]:
# Singular values of the Gram matrix X^T X (its eigenvalues, since it is
# symmetric positive semi-definite).
# Computed from the float64 numpy data rather than the float32 tensor copy Xs.
# The leading eigenvalues are large enough that float32 round-off can be
# comparable to the small noise-level eigenvalues that reveal X's effective rank.
U, S, Vh = np.linalg.svd(X.T @ X)
print(S)

## 