In [3]:
from typing import Tuple, List, Union, Any, Optional, Dict, Literal, Callable
import time
import collections
import os
import sys
sys.path.append(os.path.dirname(os.getcwd()))
sys.path.append(os.path.dirname(os.path.dirname(os.getcwd())))

from tqdm import tqdm
import numpy as np
import torch
import torch.nn as nn
from torch import Tensor, tensor
import pandas as pd
from scipy.optimize import minimize
np.set_printoptions(precision=3, threshold=5) # Print options

from utils.utils import print_name, print_shape

# Scalar case

Let $r_i \in R^d, W\in R^{d\times D}, \Delta \in R, x_i\in R^D$ for all $i\in[N]$. Fix regularization $\lambda > 0$. Then the minimum of 
\begin{align*}
    J(\Delta) = \frac{1}{N} \sum_{i=1}^N \big\| r_i - W\Delta x_i \big\|^2 + \lambda \Delta^2
\end{align*}
is attained at
\begin{align*}
    \Delta = \frac{\sum_{i=1}^N (Wx_i)^T r_i}{\sum_{i=1}^N\big( \|Wx_i\|^2 + \lambda \big)}
\end{align*}

In [11]:
# Parameters
d = 200
D = 3000
N = 1000
lambda_reg = 10

# Create dummy data
np.random.seed(0)
r = np.random.randn(N, d)+2
W = np.random.randn(d, D)/100
x = np.random.randn(N, D)-1

def J(Delta):
    Wx = W @ x.T
    residual = r - Wx.T * Delta
    return np.mean(np.linalg.norm(residual, axis=1)**2) + lambda_reg * Delta**2

def J_byhand(Delta):
    res = 0
    for i in range(N):
        res += np.linalg.norm(r[i] - W @ x[i] * Delta)**2 / N
    return res + lambda_reg * Delta**2

In [20]:
# Closed form solution numpy
Wx = W @ x.T
numerator = np.sum(r.T * Wx / N)
denominator = np.sum(Wx * Wx / N) + lambda_reg 
Delta_closed_form = numerator / denominator
print(f"Closed form solution for Delta: {Delta_closed_form}")
print(f"Objective value for closed form solution: {J(Delta_closed_form)}")
print(f"Objective value for closed form solution (by hand): {J_byhand(Delta_closed_form)}")
print("numerator, denominator", numerator, denominator)

Closed form solution for Delta: -0.16172555891809534
Objective value for closed form solution: 998.3451788766267
Objective value for closed form solution (by hand): 998.3451788766267
numerator, denominator -21.838381139270073 135.0335796355476


In [195]:
# Closed from solution by hand
numerator = sum([ np.dot(W @ x[i], r[i]) for i in range(N) ])/N
denominator = sum([ np.linalg.norm(W @ x[i])**2 for i in range(N) ])/N + lambda_reg
Delta_by_hand = numerator / denominator
print(f"Closed form solution for Delta (by hand): {Delta_by_hand}")
print(f"Objective value for closed form solution (by hand): {J(Delta_by_hand)}")
print(f"Objective value for closed form solution (by hand): {J_byhand(Delta_by_hand)}")
print("numerator, denominator", numerator, denominator)

Closed form solution for Delta (by hand): -0.16172555891809537
Objective value for closed form solution (by hand): 998.3451788766267
Objective value for closed form solution (by hand): 998.3451788766267
numerator, denominator -21.838381139270066 135.03357963554754


In [196]:
# Gradient descent solution using scipy.optimiz
result = minimize(J, np.random.randn(), method='BFGS')
Delta = result.x[0]
print(f"Gradient descent solution for Delta using scipy.optimize: {Delta}")
print(f"Objective value for gradient descent solution: {J(Delta)}")
print(f"Objective value for closed form solution (by hand): {J_byhand(Delta)}")

Gradient descent solution for Delta using scipy.optimize: -0.1617255660741791
Objective value for gradient descent solution: 998.3451788766267
Objective value for closed form solution (by hand): 998.3451788766273


# Diagonal Case

Let $r_i \in R^d, W\in R^{d\times D}, \Delta = diag(\delta_1, ..., \delta_D) \in R^{D\times D}, x_i\in R^D$ for all $i\in[N]$. Fix regularization $\lambda > 0$. Then the minimum of 
\begin{align*}
    J(\Delta) = \frac{1}{N} \sum_{i=1}^N \big\| r_i - W\Delta x_i \big\|^2 + \lambda \|\Delta\|^2
\end{align*}
is attained by solving the system of linear equations
\begin{align*}
    b = (A + \lambda I)\Delta
\end{align*}
where
\begin{align*}
    A_{k,j} = \frac{1}{N} \sum_{i=1}^D \big( W_k x_{i,k} \big)^T W_j x_{i,j}
\end{align*}
and
\begin{align*}
    b_k = \frac{1}{N}\sum_{i=1}^N r_i^T W_k x_{i,k}
\end{align*}


In [335]:
# Parameters
d = 30
D = 20
N = 1000
lambda_reg = 10

# Create dummy data
np.random.seed(0)
r = np.random.randn(N, d)+2
W = np.random.randn(d, D)/100
x = np.random.randn(N, D)-1

def A_byhand():
    A = np.zeros((D, D))
    for k in range(D):
        for j in range(D):
            A[k, j] = np.mean([ x[i, k] * x[i, j] * np.dot(W[:, k], W[:, j]) for i in range(N)])
    return A


def A():
    return (W.T @ W) * (x.T @ x) / N


def b_byhand():
    b = np.zeros(D)
    for k in range(D):
        b[k] = np.mean([ x[i, k] * np.dot(W[:, k], r[i]) for i in range(N)])
    return b


def b():
    #return np.mean( (r @ W) * x, axis=0)
    #return np.diag(W.T @ r.T @ x) / N
    return np.einsum('nd,dk,nk->k', r, W, x) / N

In [336]:
(A() - A_byhand()).mean()

-9.452887691357992e-21

In [337]:
(b() - b_byhand()).mean()

5.585809592645319e-17

In [338]:
def J(Delta):
    return np.mean(np.linalg.norm(r - x @ np.diag(Delta) @ W.T, axis=1)**2) + lambda_reg * np.linalg.norm(Delta)**2

def J_byhand(Delta):
    res = 0
    for i in range(N):
        res += np.linalg.norm(r[i] - W @ (Delta*x[i]))**2 / N
    return res + lambda_reg * np.sum(Delta**2)

In [339]:
b = np.mean( (r @ W) * x, axis=0)
A = (W.T @ W) * (x.T @ x) / N
Delta_closed_form = np.linalg.solve(A + lambda_reg * np.eye(D), b)
print(f"Closed form solution for Delta: {Delta_closed_form}")
print(f"Objective value for closed form solution: {J(Delta_closed_form)}")
print(f"Objective value for closed form solution (by hand): {J_byhand(Delta_closed_form)}")

Closed form solution for Delta: [-0.005  0.011  0.011 ...  0.006  0.013  0.017]
Objective value for closed form solution: 149.0212011905012
Objective value for closed form solution (by hand): 149.02120119050105


In [340]:
# Gradient descent solution using scipy.optimiz
result = minimize(J, np.random.randn(D), method='BFGS')
Delta = result.x
print(f"Gradient descent using scipy.optimize: {Delta}")
print(f"Objective value for gradient descent solution: {J(Delta)}")
print(f"Objective value for closed form solution (by hand): {J_byhand(Delta)}")

Gradient descent using scipy.optimize: [-0.005  0.011  0.011 ...  0.006  0.013  0.017]
Objective value for gradient descent solution: 149.02120119051318
Objective value for closed form solution (by hand): 149.02120119051318


# Dense Case

Let $r_i \in R^d, W\in R^{d\times p}, \Delta = R^{p \times D}, x_i\in R^D$ for all $i\in[N]$. Fix regularization $\lambda > 0$. Then the minimum of 
\begin{align*}
    J(\Delta) = \frac{1}{N} \sum_{i=1}^N \big\| r_i - W\Delta x_i \big\|^2 + \lambda \|\Delta\|^2_F
\end{align*}
is attained by solving the Sylvester system

...



In [3]:
# Parameters
d = 40
D = 50
p = 20 
N = 100
lambda_reg = 0.1

# Create dummy data
np.random.seed(0)
r = np.random.randn(N, d)+2
W = np.random.randn(d, D)/100
x = np.random.randn(N, p)-1

def J(Delta):
    Delta = Delta.reshape(p, D)
    return 1/N * np.linalg.norm(W @ Delta @ x.T - r.T)**2 + lambda_reg * np.linalg.norm(Delta)**2

def J_byhand(Delta):
    Delta = Delta.reshape(p, D)
    res = 0
    for i in range(N):
        res += 1/N * np.linalg.norm(r[i] - W @ Delta @ x[i])**2
    return res + lambda_reg * np.linalg.norm(Delta)**2

In [238]:
SW, U = np.linalg.eigh(W.T @ W)
SX, V = np.linalg.eigh(x.T @ x)

print(SX[:, None].shape, SW[None, :].shape)
print(SX[None, :].shape, SW[:, None].shape)
# [[ 0.032  0.04   0.044 ...  0.167  0.201  2.046]
#  [ 0.055  0.071  0.077 ...  0.293  0.353  3.588]
#  [ 0.067  0.085  0.093 ...  0.351  0.424  4.304]
#  ...
#  [ 0.119  0.151  0.165 ...  0.626  0.755  7.674]
#  [ 0.143  0.182  0.199 ...  0.756  0.911  9.262]
#  [ 0.207  0.263  0.288 ...  1.091  1.315 13.365]] (7, 20)

# prod = SX[:, None] * SW[None, :]
# print(prod, prod.shape)
prod = SX[None, :] * SW[:, None]
print(prod, prod.shape)

(50, 1) (1, 40)
(1, 50) (40, 1)
[[2.861e-06 4.036e-06 5.400e-06 ... 7.310e-05 8.073e-05 1.658e-03]
 [9.838e-05 1.388e-04 1.857e-04 ... 2.514e-03 2.776e-03 5.703e-02]
 [3.753e-04 5.295e-04 7.084e-04 ... 9.590e-03 1.059e-02 2.176e-01]
 ...
 [1.043e-01 1.471e-01 1.968e-01 ... 2.664e+00 2.942e+00 6.044e+01]
 [1.173e-01 1.655e-01 2.214e-01 ... 2.998e+00 3.311e+00 6.801e+01]
 [1.293e-01 1.823e-01 2.440e-01 ... 3.303e+00 3.648e+00 7.493e+01]] (40, 50)


In [254]:
def solution_byhand1():
    SW, U = np.linalg.eigh(W.T @ W)
    SX, V = np.linalg.eigh(x.T @ x)
    
    Delta = U.T @ W.T @ r.T @ x @ V
    for i in range(Delta.shape[0]):
        for j in range(Delta.shape[1]):
            Delta[i, j] /= (N*lambda_reg + SW[i] * SX[j])
    return U @ Delta @ V.T

def solution_byhand2():
    SW, U = np.linalg.eigh(W.T @ W)
    SX, V = np.linalg.eigh(x.T @ x)
    
    Delta = U @ W.T @ r.T @ x @ V.T
    for i in range(Delta.shape[0]):
        for j in range(Delta.shape[1]):
            Delta[i, j] /= (N * lambda_reg + SW[i] * SX[j])
    return U.T @ Delta @ V

def solution_byhand3():
    Delta = W.T @ r.T @ x / N
    SW, U = np.linalg.eigh(W.T @ W)
    SX, V = np.linalg.eigh(x.T @ x)
    
    for i in range(Delta.shape[0]):
        for j in range(Delta.shape[1]):
            Delta[i, j] /= (N * lambda_reg + SW[i] * SX[j])
    return Delta

def solution_byhand4():
    Delta = W.T @ r.T @ x / N
    SW, U = np.linalg.eigh(W.T @ W)
    SX, V = np.linalg.eigh(x.T @ x)
    
    for i in range(Delta.shape[0]):
        for j in range(Delta.shape[1]):
            Delta[i, j] /= (N * lambda_reg + U[i,i] * V[j,j])
    return Delta

def solution1():
    SW, U = np.linalg.eigh(W.T @ W)
    SX, V = np.linalg.eigh(x.T @ x)
    Delta = (W.T @ r.T @ x) / (N*lambda_reg + SW[:, None] * SX[None, :])
    return Delta

def solution2():
    SW, U = np.linalg.eigh(W.T @ W)
    SX, V = np.linalg.eigh(x.T @ x)
    Delta = (W.T @ r.T @ x) / (N*lambda_reg + SW[:, None] * SX[None, :])
    return Delta

def solution3():
    SW, U = np.linalg.eigh(W.T @ W)
    SX, V = np.linalg.eigh(x.T @ x)
    Delta = (U.T @ W.T @ r.T @ x @ V) / (N*lambda_reg + SW[:, None]*SX[None, :])
    return U @ Delta @ V.T

def solution4():
    SW, U = np.linalg.eigh(W.T @ W)
    SX, V = np.linalg.eigh(x.T @ x)
    Delta = (U.T @ W.T @ r.T @ x @ V) / (N*lambda_reg + (SW[None, :] * SX[:, None]).T)
    return U @ Delta @ V.T

In [240]:
solution1()

array([[-0.325, -0.365, -0.512, ..., -0.401, -0.309, -0.343],
       [-1.309, -1.58 , -1.77 , ..., -1.597, -1.626, -1.566],
       [-0.659, -0.712, -0.914, ..., -0.758, -0.678, -0.858],
       ...,
       [ 0.553,  0.655,  0.633, ...,  0.508,  0.396,  0.099],
       [-0.145, -0.227, -0.193, ..., -0.186, -0.209, -0.031],
       [ 0.606,  0.662,  0.868, ...,  0.589,  0.588,  0.08 ]])

In [241]:
solution2()

array([[-0.325, -0.365, -0.512, ..., -0.401, -0.309, -0.343],
       [-1.309, -1.58 , -1.77 , ..., -1.597, -1.626, -1.566],
       [-0.659, -0.712, -0.914, ..., -0.758, -0.678, -0.858],
       ...,
       [ 0.553,  0.655,  0.633, ...,  0.508,  0.396,  0.099],
       [-0.145, -0.227, -0.193, ..., -0.186, -0.209, -0.031],
       [ 0.606,  0.662,  0.868, ...,  0.589,  0.588,  0.08 ]])

In [256]:
solution3()

array([[-0.062, -0.064, -0.168, ..., -0.103,  0.007, -0.028],
       [-0.435, -0.566, -0.628, ..., -0.6  , -0.56 , -0.512],
       [-0.336, -0.35 , -0.484, ..., -0.391, -0.296, -0.486],
       ...,
       [ 0.2  ,  0.24 ,  0.186, ...,  0.231,  0.084,  0.247],
       [ 0.059,  0.013,  0.069, ..., -0.004, -0.016,  0.017],
       [ 0.04 ,  0.013,  0.132, ...,  0.125,  0.092, -0.011]])

In [255]:
solution4()

array([[-0.062, -0.064, -0.168, ..., -0.103,  0.007, -0.028],
       [-0.435, -0.566, -0.628, ..., -0.6  , -0.56 , -0.512],
       [-0.336, -0.35 , -0.484, ..., -0.391, -0.296, -0.486],
       ...,
       [ 0.2  ,  0.24 ,  0.186, ...,  0.231,  0.084,  0.247],
       [ 0.059,  0.013,  0.069, ..., -0.004, -0.016,  0.017],
       [ 0.04 ,  0.013,  0.132, ...,  0.125,  0.092, -0.011]])

In [244]:
solution_byhand1()

array([[-0.062, -0.064, -0.168, ..., -0.103,  0.007, -0.028],
       [-0.435, -0.566, -0.628, ..., -0.6  , -0.56 , -0.512],
       [-0.336, -0.35 , -0.484, ..., -0.391, -0.296, -0.486],
       ...,
       [ 0.2  ,  0.24 ,  0.186, ...,  0.231,  0.084,  0.247],
       [ 0.059,  0.013,  0.069, ..., -0.004, -0.016,  0.017],
       [ 0.04 ,  0.013,  0.132, ...,  0.125,  0.092, -0.011]])

In [245]:
solution_byhand2()

array([[-0.307, -0.321, -0.514, ..., -0.363, -0.331, -0.257],
       [-1.251, -1.491, -1.793, ..., -1.507, -1.756, -1.391],
       [-0.63 , -0.676, -0.925, ..., -0.717, -0.753, -0.794],
       ...,
       [ 0.532,  0.639,  0.653, ...,  0.611,  0.569,  0.643],
       [-0.156, -0.266, -0.202, ..., -0.269, -0.279, -0.3  ],
       [ 0.585,  0.615,  0.894, ...,  0.727,  0.85 ,  0.568]])

In [246]:
solution_byhand3()

array([[-0.003, -0.004, -0.005, ..., -0.004, -0.003, -0.003],
       [-0.013, -0.016, -0.018, ..., -0.016, -0.016, -0.016],
       [-0.007, -0.007, -0.009, ..., -0.008, -0.007, -0.009],
       ...,
       [ 0.006,  0.007,  0.006, ...,  0.005,  0.004,  0.001],
       [-0.001, -0.002, -0.002, ..., -0.002, -0.002, -0.   ],
       [ 0.006,  0.007,  0.009, ...,  0.006,  0.006,  0.001]])

In [247]:
solution_byhand4()

array([[-0.003, -0.004, -0.005, ..., -0.004, -0.003, -0.003],
       [-0.013, -0.016, -0.018, ..., -0.016, -0.016, -0.016],
       [-0.007, -0.007, -0.009, ..., -0.008, -0.007, -0.009],
       ...,
       [ 0.006,  0.007,  0.006, ...,  0.006,  0.005,  0.007],
       [-0.001, -0.002, -0.002, ..., -0.002, -0.003, -0.002],
       [ 0.006,  0.007,  0.009, ...,  0.008,  0.008,  0.007]])

In [248]:
for fun in [solution1, solution2, solution3, solution4, solution_byhand1, solution_byhand2, solution_byhand3, solution_byhand4]:
    Delta = fun()
    print(f"\n", fun, J(Delta), J_byhand(Delta))


 <function solution1 at 0x7febde0276d0> 574.1871589602058 574.1871589602058

 <function solution2 at 0x7febde025ea0> 574.1871589602058 574.1871589602058

 <function solution3 at 0x7febdd4b5900> 130.6662564949538 130.66625649495384

 <function solution4 at 0x7febdd4b52d0> 130.6662564949538 130.66625649495384

 <function solution_byhand1 at 0x7febde17fd90> 130.6662564949538 130.66625649495384

 <function solution_byhand2 at 0x7febde1db6d0> 568.6318763998942 568.6318763998942

 <function solution_byhand3 at 0x7febde17f880> 190.60470509609306 190.60470509609317

 <function solution_byhand4 at 0x7febde0263b0> 190.5014158260387 190.5014158260387


In [268]:
# Gradient descent solution using scipy.optimiz
result = minimize(J, np.random.randn(p*D), method='L-BFGS-B')
Delta = result.x.reshape(p, D)
print(f"Gradient descent using scipy.optimize: {Delta}")
print(f"Objective value for gradient descent solution: {J(Delta)}")
print(f"Objective value for closed form solution (by hand): {J_byhand(Delta)}")

Gradient descent using scipy.optimize: [[-5.799e-02 -7.421e-02 -1.442e-01 ... -9.546e-02  1.748e-02 -4.842e-02]
 [-4.240e-01 -5.658e-01 -6.290e-01 ... -5.985e-01 -5.315e-01 -5.069e-01]
 [-3.435e-01 -3.748e-01 -4.830e-01 ... -4.162e-01 -3.117e-01 -5.186e-01]
 ...
 [ 1.877e-01  2.190e-01  1.734e-01 ...  2.341e-01  6.366e-02  2.500e-01]
 [ 4.213e-02 -6.926e-03  7.133e-02 ... -6.568e-03 -1.144e-02 -3.270e-04]
 [ 3.629e-02 -5.874e-04  1.170e-01 ...  9.635e-02  7.097e-02 -3.768e-02]]
Objective value for gradient descent solution: 130.75758700013705
Objective value for closed form solution (by hand): 130.75758700013694
