In [1]:
from typing import Tuple, List, Union, Any, Optional, Dict, Literal, Callable
import time
import collections
import os
import sys
sys.path.append(os.path.dirname(os.getcwd()))
sys.path.append(os.path.dirname(os.path.dirname(os.getcwd())))

from tqdm import tqdm
import numpy as np
import torch
import torch.nn as nn
from torch import Tensor, tensor
import pandas as pd
from scipy.optimize import minimize
np.set_printoptions(precision=3, threshold=5) # Print options

from utils.utils import print_name, print_shape

# Scalar case

Let $r_i \in R^d, W\in R^{d\times D}, \Delta \in R, x_i\in R^D$ for all $i\in[N]$. Fix regularization $\lambda > 0$. Then the minimum of 
\begin{align*}
    J(\Delta) = \frac{1}{N} \sum_{i=1}^N \big\| r_i - W\Delta x_i \big\|^2 + \lambda \Delta^2
\end{align*}
is attained at
\begin{align*}
    \Delta = \frac{\sum_{i=1}^N (Wx_i)^T r_i}{\sum_{i=1}^N\big( \|Wx_i\|^2 + \lambda \big)}
\end{align*}

In [2]:
# Parameters
d = 200
D = 3000
N = 1000
lambda_reg = 10

# Create dummy data
np.random.seed(0)
r = np.random.randn(N, d)+2
W = np.random.randn(d, D)/100
x = np.random.randn(N, D)-1

def J(Delta):
    Wx = W @ x.T
    residual = r - Wx.T * Delta
    return np.mean(np.linalg.norm(residual, axis=1)**2) + lambda_reg * Delta**2

def J_byhand(Delta):
    res = 0
    for i in range(N):
        res += np.linalg.norm(r[i] - W @ x[i] * Delta)**2 / N
    return res + lambda_reg * Delta**2

In [3]:
# Closed form solution numpy
Wx = W @ x.T
numerator = np.sum(r.T * Wx / N)
denominator = np.sum(Wx * Wx / N) + lambda_reg 
Delta_closed_form = numerator / denominator
print(f"Closed form solution for Delta: {Delta_closed_form}")
print(f"Objective value for closed form solution: {J(Delta_closed_form)}")
print(f"Objective value for closed form solution (by hand): {J_byhand(Delta_closed_form)}")
print("numerator, denominator", numerator, denominator)

Closed form solution for Delta: -0.16172555891809534
Objective value for closed form solution: 998.3451788766267
Objective value for closed form solution (by hand): 998.3451788766267
numerator, denominator -21.838381139270073 135.0335796355476


In [4]:
# Closed from solution by hand
numerator = sum([ np.dot(W @ x[i], r[i]) for i in range(N) ])/N
denominator = sum([ np.linalg.norm(W @ x[i])**2 for i in range(N) ])/N + lambda_reg
Delta_by_hand = numerator / denominator
print(f"Closed form solution for Delta (by hand): {Delta_by_hand}")
print(f"Objective value for closed form solution (by hand): {J(Delta_by_hand)}")
print(f"Objective value for closed form solution (by hand): {J_byhand(Delta_by_hand)}")
print("numerator, denominator", numerator, denominator)

Closed form solution for Delta (by hand): -0.16172555891809537
Objective value for closed form solution (by hand): 998.3451788766267
Objective value for closed form solution (by hand): 998.3451788766267
numerator, denominator -21.838381139270066 135.03357963554754


In [5]:
# Gradient descent solution using scipy.optimiz
result = minimize(J, np.random.randn(), method='BFGS')
Delta = result.x[0]
print(f"Gradient descent solution for Delta using scipy.optimize: {Delta}")
print(f"Objective value for gradient descent solution: {J(Delta)}")
print(f"Objective value for closed form solution (by hand): {J_byhand(Delta)}")

Gradient descent solution for Delta using scipy.optimize: -0.1617255660741791
Objective value for gradient descent solution: 998.3451788766267
Objective value for closed form solution (by hand): 998.3451788766273


# Diagonal Case

Let $r_i \in R^d, W\in R^{d\times D}, \Delta = diag(\delta_1, ..., \delta_D) \in R^{D\times D}, x_i\in R^D$ for all $i\in[N]$. Fix regularization $\lambda > 0$. Then the minimum of 
\begin{align*}
    J(\Delta) = \frac{1}{N} \sum_{i=1}^N \big\| r_i - W\Delta x_i \big\|^2 + \lambda \|\Delta\|^2
\end{align*}
is attained by solving the system of linear equations
\begin{align*}
    b = (A + \lambda I)\Delta
\end{align*}
where
\begin{align*}
    A_{k,j} = \frac{1}{N} \sum_{i=1}^D \big( W_k x_{i,k} \big)^T W_j x_{i,j}
\end{align*}
and
\begin{align*}
    b_k = \frac{1}{N}\sum_{i=1}^N r_i^T W_k x_{i,k}
\end{align*}


In [6]:
# Parameters
d = 30
D = 20
N = 1000
lambda_reg = 10

# Create dummy data
np.random.seed(0)
r = np.random.randn(N, d)+2
W = np.random.randn(d, D)/100
x = np.random.randn(N, D)-1

def A_byhand():
    A = np.zeros((D, D))
    for k in range(D):
        for j in range(D):
            A[k, j] = np.mean([ x[i, k] * x[i, j] * np.dot(W[:, k], W[:, j]) for i in range(N)])
    return A


def A():
    return (W.T @ W) * (x.T @ x) / N


def b_byhand():
    b = np.zeros(D)
    for k in range(D):
        b[k] = np.mean([ x[i, k] * np.dot(W[:, k], r[i]) for i in range(N)])
    return b


def b():
    #return np.mean( (r @ W) * x, axis=0)
    #return np.diag(W.T @ r.T @ x) / N
    return np.einsum('nd,dk,nk->k', r, W, x) / N

In [7]:
(A() - A_byhand()).mean()

-9.452887691357992e-21

In [8]:
(b() - b_byhand()).mean()

5.585809592645319e-17

In [9]:
def J(Delta):
    return np.mean(np.linalg.norm(r - x @ np.diag(Delta) @ W.T, axis=1)**2) + lambda_reg * np.linalg.norm(Delta)**2

def J_byhand(Delta):
    res = 0
    for i in range(N):
        res += np.linalg.norm(r[i] - W @ (Delta*x[i]))**2 / N
    return res + lambda_reg * np.sum(Delta**2)

In [10]:
b = np.mean( (r @ W) * x, axis=0)
A = (W.T @ W) * (x.T @ x) / N
Delta_closed_form = np.linalg.solve(A + lambda_reg * np.eye(D), b)
print(f"Closed form solution for Delta: {Delta_closed_form}")
print(f"Objective value for closed form solution: {J(Delta_closed_form)}")
print(f"Objective value for closed form solution (by hand): {J_byhand(Delta_closed_form)}")

Closed form solution for Delta: [-0.005  0.011  0.011 ...  0.006  0.013  0.017]
Objective value for closed form solution: 149.0212011905012
Objective value for closed form solution (by hand): 149.02120119050105


In [11]:
# Gradient descent solution using scipy.optimiz
result = minimize(J, np.random.randn(D), method='BFGS')
Delta = result.x
print(f"Gradient descent using scipy.optimize: {Delta}")
print(f"Objective value for gradient descent solution: {J(Delta)}")
print(f"Objective value for closed form solution (by hand): {J_byhand(Delta)}")

Gradient descent using scipy.optimize: [-0.005  0.011  0.011 ...  0.006  0.013  0.017]
Objective value for gradient descent solution: 149.02120119051318
Objective value for closed form solution (by hand): 149.02120119051318


# Dense Case

Let $r_i \in R^d, W\in R^{d\times p}, \Delta = R^{p \times D}, x_i\in R^D$ for all $i\in[N]$. Fix regularization $\lambda > 0$. Then the minimum of 
\begin{align*}
    J(\Delta) = \frac{1}{N} \sum_{i=1}^N \big\| r_i - W\Delta x_i \big\|^2 + \lambda \|\Delta\|^2_F
\end{align*}
is attained by solving the Sylvester system ... TODO

...



In [56]:
# Parameters
d = 1
D = 50
p = 20 
N = 100
lambda_reg = 0.1

# Create dummy data
np.random.seed(0)
r = np.random.randn(N, d)+2
W = np.random.randn(d, D)/100
x = np.random.randn(N, p)-1

def J(Delta):
    Delta = Delta.reshape(D, p)
    return 1/N * np.linalg.norm(W @ Delta @ x.T - r.T)**2 + lambda_reg * np.linalg.norm(Delta)**2

def J_byhand(Delta):
    Delta = Delta.reshape(D, p)
    res = 0
    for i in range(N):
        res += 1/N * np.linalg.norm(r[i] - W @ Delta @ x[i])**2
    return res + lambda_reg * np.linalg.norm(Delta)**2

In [57]:
SW, U = np.linalg.eigh(W.T @ W)
SX, V = np.linalg.eigh(x.T @ x)

print(SX[:, None].shape, SW[None, :].shape)
print(SX[None, :].shape, SW[:, None].shape)
# [[ 0.032  0.04   0.044 ...  0.167  0.201  2.046]
#  [ 0.055  0.071  0.077 ...  0.293  0.353  3.588]
#  [ 0.067  0.085  0.093 ...  0.351  0.424  4.304]
#  ...
#  [ 0.119  0.151  0.165 ...  0.626  0.755  7.674]
#  [ 0.143  0.182  0.199 ...  0.756  0.911  9.262]
#  [ 0.207  0.263  0.288 ...  1.091  1.315 13.365]] (7, 20)

# prod = SX[:, None] * SW[None, :]
# print(prod, prod.shape)
prod = SX[None, :] * SW[:, None]
print(prod, prod.shape)

(20, 1) (1, 50)
(1, 20) (50, 1)
[[-4.160e-17 -4.730e-17 -5.720e-17 ... -1.832e-16 -2.032e-16 -2.539e-15]
 [-3.928e-17 -4.467e-17 -5.401e-17 ... -1.730e-16 -1.919e-16 -2.397e-15]
 [-3.012e-17 -3.425e-17 -4.141e-17 ... -1.327e-16 -1.471e-16 -1.838e-15]
 ...
 [ 3.239e-17  3.684e-17  4.454e-17 ...  1.427e-16  1.583e-16  1.977e-15]
 [ 9.408e-17  1.070e-16  1.294e-16 ...  4.144e-16  4.597e-16  5.741e-15]
 [ 2.037e-01  2.316e-01  2.801e-01 ...  8.972e-01  9.951e-01  1.243e+01]] (50, 20)


In [58]:
def solution_byhand1():
    SW, U = np.linalg.eigh(W.T @ W)
    SX, V = np.linalg.eigh(x.T @ x)
    
    Delta = U.T @ W.T @ r.T @ x @ V
    for i in range(Delta.shape[0]):
        for j in range(Delta.shape[1]):
            Delta[i, j] /= (N*lambda_reg + SW[i] * SX[j])
    return U @ Delta @ V.T

def solution_byhand2():
    SW, U = np.linalg.eigh(W.T @ W)
    SX, V = np.linalg.eigh(x.T @ x)
    
    Delta = U @ W.T @ r.T @ x @ V.T
    for i in range(Delta.shape[0]):
        for j in range(Delta.shape[1]):
            Delta[i, j] /= (N * lambda_reg + SW[i] * SX[j])
    return U.T @ Delta @ V

def solution_byhand3():
    Delta = W.T @ r.T @ x / N
    SW, U = np.linalg.eigh(W.T @ W)
    SX, V = np.linalg.eigh(x.T @ x)
    
    for i in range(Delta.shape[0]):
        for j in range(Delta.shape[1]):
            Delta[i, j] /= (N * lambda_reg + SW[i] * SX[j])
    return Delta

def solution_byhand4():
    Delta = W.T @ r.T @ x / N
    SW, U = np.linalg.eigh(W.T @ W)
    SX, V = np.linalg.eigh(x.T @ x)
    
    for i in range(Delta.shape[0]):
        for j in range(Delta.shape[1]):
            Delta[i, j] /= (N * lambda_reg + U[i,i] * V[j,j])
    return Delta

def solution1():
    SW, U = np.linalg.eigh(W.T @ W)
    SX, V = np.linalg.eigh(x.T @ x)
    Delta = (W.T @ r.T @ x) / (N*lambda_reg + SW[:, None] * SX[None, :])
    return Delta

def solution2():
    SW, U = np.linalg.eigh(W.T @ W)
    SX, V = np.linalg.eigh(x.T @ x)
    Delta = (W.T @ r.T @ x) / (N*lambda_reg + SW[:, None] * SX[None, :])
    return Delta

def solution3():
    SW, U = np.linalg.eigh(W.T @ W)
    SX, V = np.linalg.eigh(x.T @ x)
    Delta = (U.T @ W.T @ r.T @ x @ V) / (N*lambda_reg + SW[:, None]*SX[None, :])
    return U @ Delta @ V.T

def solution4():
    SW, U = np.linalg.eigh(W.T @ W)
    SX, V = np.linalg.eigh(x.T @ x)
    Delta = (U.T @ W.T @ r.T @ x @ V) / (N*lambda_reg + (SW[None, :] * SX[:, None]).T)
    return U @ Delta @ V.T

In [59]:
solution1()

array([[-0.454, -0.444, -0.376, ..., -0.355, -0.377, -0.347],
       [ 0.325,  0.318,  0.269, ...,  0.254,  0.27 ,  0.248],
       [ 0.306,  0.3  ,  0.253, ...,  0.24 ,  0.254,  0.234],
       ...,
       [-0.269, -0.263, -0.223, ..., -0.211, -0.224, -0.206],
       [ 0.317,  0.31 ,  0.263, ...,  0.248,  0.263,  0.243],
       [ 0.109,  0.106,  0.09 , ...,  0.08 ,  0.084,  0.038]])

In [60]:
solution2()

array([[-0.454, -0.444, -0.376, ..., -0.355, -0.377, -0.347],
       [ 0.325,  0.318,  0.269, ...,  0.254,  0.27 ,  0.248],
       [ 0.306,  0.3  ,  0.253, ...,  0.24 ,  0.254,  0.234],
       ...,
       [-0.269, -0.263, -0.223, ..., -0.211, -0.224, -0.206],
       [ 0.317,  0.31 ,  0.263, ...,  0.248,  0.263,  0.243],
       [ 0.109,  0.106,  0.09 , ...,  0.08 ,  0.084,  0.038]])

In [61]:
solution3()

array([[-0.199, -0.2  , -0.174, ..., -0.164, -0.174, -0.161],
       [ 0.142,  0.143,  0.124, ...,  0.118,  0.125,  0.115],
       [ 0.134,  0.135,  0.117, ...,  0.111,  0.118,  0.109],
       ...,
       [-0.118, -0.119, -0.103, ..., -0.097, -0.103, -0.096],
       [ 0.139,  0.14 ,  0.121, ...,  0.115,  0.122,  0.113],
       [ 0.049,  0.049,  0.043, ...,  0.04 ,  0.043,  0.039]])

In [62]:
solution4()

array([[-0.199, -0.2  , -0.174, ..., -0.164, -0.174, -0.161],
       [ 0.142,  0.143,  0.124, ...,  0.118,  0.125,  0.115],
       [ 0.134,  0.135,  0.117, ...,  0.111,  0.118,  0.109],
       ...,
       [-0.118, -0.119, -0.103, ..., -0.097, -0.103, -0.096],
       [ 0.139,  0.14 ,  0.121, ...,  0.115,  0.122,  0.113],
       [ 0.049,  0.049,  0.043, ...,  0.04 ,  0.043,  0.039]])

In [63]:
solution_byhand1()

array([[-0.199, -0.2  , -0.174, ..., -0.164, -0.174, -0.161],
       [ 0.142,  0.143,  0.124, ...,  0.118,  0.125,  0.115],
       [ 0.134,  0.135,  0.117, ...,  0.111,  0.118,  0.109],
       ...,
       [-0.118, -0.119, -0.103, ..., -0.097, -0.103, -0.096],
       [ 0.139,  0.14 ,  0.121, ...,  0.115,  0.122,  0.113],
       [ 0.049,  0.049,  0.043, ...,  0.04 ,  0.043,  0.039]])

In [64]:
solution_byhand2()

array([[-0.451, -0.444, -0.375, ..., -0.354, -0.378, -0.345],
       [ 0.326,  0.318,  0.269, ...,  0.255,  0.269,  0.25 ],
       [ 0.314,  0.301,  0.256, ...,  0.245,  0.251,  0.24 ],
       ...,
       [-0.263, -0.262, -0.221, ..., -0.208, -0.226, -0.202],
       [ 0.313,  0.31 ,  0.261, ...,  0.246,  0.265,  0.24 ],
       [ 0.108,  0.108,  0.091, ...,  0.085,  0.094,  0.083]])

In [65]:
solution_byhand3()

array([[-0.005, -0.004, -0.004, ..., -0.004, -0.004, -0.003],
       [ 0.003,  0.003,  0.003, ...,  0.003,  0.003,  0.002],
       [ 0.003,  0.003,  0.003, ...,  0.002,  0.003,  0.002],
       ...,
       [-0.003, -0.003, -0.002, ..., -0.002, -0.002, -0.002],
       [ 0.003,  0.003,  0.003, ...,  0.002,  0.003,  0.002],
       [ 0.001,  0.001,  0.001, ...,  0.001,  0.001,  0.   ]])

In [66]:
solution_byhand4()

array([[-0.005, -0.004, -0.004, ..., -0.004, -0.004, -0.003],
       [ 0.003,  0.003,  0.003, ...,  0.003,  0.003,  0.002],
       [ 0.003,  0.003,  0.003, ...,  0.002,  0.003,  0.002],
       ...,
       [-0.003, -0.003, -0.002, ..., -0.002, -0.002, -0.002],
       [ 0.003,  0.003,  0.003, ...,  0.003,  0.003,  0.002],
       [ 0.001,  0.001,  0.001, ...,  0.001,  0.001,  0.001]])

In [67]:
for fun in [solution1, solution2, solution3, solution4, solution_byhand1, solution_byhand2, solution_byhand3, solution_byhand4]:
    Delta = fun()
    print(f"\n", fun, J(Delta), J_byhand(Delta))


 <function solution1 at 0x7ffa16d56680> 6.5214046564648465 6.5214046564648465

 <function solution2 at 0x7ffa16d56950> 6.5214046564648465 6.5214046564648465

 <function solution3 at 0x7ffa16d570a0> 2.906448688411172 2.906448688411172

 <function solution4 at 0x7ffa16d568c0> 2.906448688411172 2.906448688411172

 <function solution_byhand1 at 0x7ffa16d55cf0> 2.906448688411172 2.906448688411172

 <function solution_byhand2 at 0x7ffa16d56830> 6.5043157860348995 6.5043157860348995

 <function solution_byhand3 at 0x7ffa16d56c20> 5.154570606135915 5.154570606135918

 <function solution_byhand4 at 0x7ffa16d56320> 5.1545438852400896 5.154543885240089


In [68]:
# Gradient descent solution using scipy.optimiz
result = minimize(J, np.random.randn(p*D), method='L-BFGS-B')
Delta = result.x.reshape(p, D)
print(f"Gradient descent using scipy.optimize: {Delta}")
print(f"Objective value for gradient descent solution: {J(Delta)}")
print(f"Objective value for closed form solution (by hand): {J_byhand(Delta)}")

Gradient descent using scipy.optimize: [[-0.199 -0.2   -0.174 ...  0.14   0.096  0.096]
 [ 0.13   0.108  0.118 ...  0.102  0.109  0.1  ]
 [-0.205 -0.207 -0.179 ...  0.082  0.056  0.057]
 ...
 [-0.017 -0.014 -0.015 ... -0.208 -0.221 -0.204]
 [-0.1   -0.101 -0.087 ... -0.123 -0.084 -0.084]
 [-0.114 -0.095 -0.104 ...  0.04   0.043  0.039]]
Objective value for gradient descent solution: 2.9064486885964547
Objective value for closed form solution (by hand): 2.9064486885964547
