In [1]:
from typing import Tuple, List, Union, Any, Optional, Dict, Literal, Callable
import time

from tqdm import tqdm
import numpy as np
import torch
import torch.nn as nn
from torch import Tensor, tensor
from scipy.optimize import minimize

from models.sandwiched_least_squares import sandwiched_LS_scalar, sandwiched_LS_diag, sandwiched_LS_dense

# Scalar case

Let $R \in R^{n \times d}, W\in R^{D\times d}, \Delta \in R,$ and $X \in R^{n \times D}$. Let $\lambda > 0$. Then the minimum of 
\begin{align*}
    J(\Delta) = \frac{1}{n} \sum_{i=1}^n \big\| R_i - W^\top\Delta X_i \big\|^2 + \lambda \Delta^2
\end{align*}
is uniquely attained by
\begin{align*}
    \Delta_{\textnormal{scalar}} 
    &= \frac{\langle R, XW^\top\rangle_F}{\|X W\|_F^2 + n\lambda} 
    = \frac{\frac{1}{n}\sum_{i=1}^n \langle W^\top X_i,  R_i\rangle}{\frac{1}{n}\sum_{i=1}^n \|W^\top X_i\|^2 + \lambda }.
\end{align*}

In [2]:
# Parameters
d = 200
D = 300
N = 100
l2_reg = 0.01

# Create dummy data
np.random.seed(0)
r = np.random.randn(N, d)+2
W = np.random.randn(D, d)/100
x = np.random.randn(N, D)-1

def J(Delta):
    Wx = W.T @ x.T
    residual = r - Wx.T * Delta
    return np.mean(np.linalg.norm(residual, axis=1)**2) + l2_reg * Delta**2

def J_byhand(Delta):
    res = 0
    for i in range(N):
        res += np.linalg.norm(r[i] - W.T @ x[i] * Delta)**2 / N
    return res + l2_reg * Delta**2

In [3]:
# Closed form solution
Delta_closed_form = sandwiched_LS_scalar(tensor(r), tensor(W), tensor(x), l2_reg).numpy()

print(f"Closed form solution for Delta: {Delta_closed_form}")
print(f"Objective value for closed form solution: {J(Delta_closed_form)}")
print(f"Objective value for closed form solution (by hand): {J_byhand(Delta_closed_form)}")

Closed form solution for Delta: 0.17964255108434818
Objective value for closed form solution: 992.8392404324366
Objective value for closed form solution (by hand): 992.8392404324366


In [4]:
# Gradient descent solution using scipy.optimiz
result = minimize(J, np.random.randn(), method='BFGS')
Delta = result.x[0]
print(f"Gradient descent solution for Delta using scipy.optimize: {Delta}")
print(f"Objective value for gradient descent solution: {J(Delta)}")
print(f"Objective value for closed form solution (by hand): {J_byhand(Delta)}")

Gradient descent solution for Delta using scipy.optimize: 0.17964211768408475
Objective value for gradient descent solution: 992.8392404324388
Objective value for closed form solution (by hand): 992.8392404324385


# Diagonal Case

Let $R \in R^{n \times d}, W\in R^{D\times d}, \Delta = \textnormal{diag}(\delta_1, ..., \delta_D) \in R^{D \times D},$ and $X \in R^{n \times D}$. Let $\lambda > 0$. Then the minimum of 
\begin{align*}
    J(\Delta) = \frac{1}{n} \sum_{i=1}^n \big\| R_i - W^\top\Delta X_i \big\|^2 + \lambda \Delta^2
\end{align*}
is uniquely attained by the solution to the system of linear equations
\begin{align*}
    b = (A+ \lambda I)\Delta
\end{align*}
where
\begin{align*}
    A = W W^\top \odot X^\top X,  \qquad \qquad b = \textnormal{diag}(W^\top R^\top X).
\end{align*}


In [5]:
# Parameters
d = 30
D = 20
N = 1000
l2_reg = 10

# Create dummy data
np.random.seed(0)
r = np.random.randn(N, d)+2
W = np.random.randn(D, d)/100
x = np.random.randn(N, D)-1

def A_byhand():
    A = np.zeros((D, D))
    for k in range(D):
        for j in range(D):
            A[k, j] = np.mean([ x[i, k] * x[i, j] * np.dot(W[k], W[j]) for i in range(N)])
    return A


def A():
    return (W @ W.T) * (x.T @ x) / N

def b_byhand():
    b = np.zeros(D)
    for k in range(D):
        b[k] = np.mean([ x[i, k] * np.dot(W[k], r[i]) for i in range(N)])
    return b


def b():
    return np.mean( (r @ W.T) * x, axis=0)
    #return np.diag(W @ r.T @ x) / N
    #return np.einsum('nd,kd,nk->k', r, W, x) / N

In [6]:
(A() - A_byhand()).mean()

1.5242357885841135e-20

In [7]:
(b() - b_byhand()).mean()

1.5959455978986624e-17

In [8]:
def J(Delta):
    return np.mean(np.linalg.norm(r - x @ np.diag(Delta) @ W, axis=1)**2) + l2_reg * np.linalg.norm(Delta)**2

def J_byhand(Delta):
    res = 0
    for i in range(N):
        res += np.linalg.norm(r[i] - W.T @ (Delta*x[i]))**2 / N
    return res + l2_reg * np.sum(Delta**2)

In [9]:
Delta_closed_form = sandwiched_LS_diag(tensor(r), tensor(W), tensor(x), l2_reg).numpy()
print(f"Closed form solution for Delta: {Delta_closed_form}")
print(f"Objective value for closed form solution: {J(Delta_closed_form)}")
print(f"Objective value for closed form solution (by hand): {J_byhand(Delta_closed_form)}")

Closed form solution for Delta: [-0.00118117  0.0088751  -0.00139431 -0.02298253 -0.01659739  0.00393831
  0.01807629  0.00768456  0.00326887 -0.00505116 -0.00754621  0.00641875
  0.00796784  0.00098712 -0.00994351  0.0028343   0.01790257 -0.00202378
  0.00719184  0.01089629]
Objective value for closed form solution: 149.01900449101015
Objective value for closed form solution (by hand): 149.01900449101026


In [10]:
# Gradient descent solution using scipy.optimiz
result = minimize(J, np.random.randn(D), method='BFGS')
Delta = result.x
print(f"Gradient descent using scipy.optimize: {Delta}")
print(f"Objective value for gradient descent solution: {J(Delta)}")
print(f"Objective value for closed form solution (by hand): {J_byhand(Delta)}")

Gradient descent using scipy.optimize: [-0.00118107  0.00887488 -0.00139438 -0.02298261 -0.01659761  0.00393813
  0.01807629  0.00768457  0.00326876 -0.00505104 -0.00754647  0.00641864
  0.00796751  0.00098706 -0.00994363  0.00283416  0.0179027  -0.00202393
  0.00719158  0.01089627]
Objective value for gradient descent solution: 149.01900449101524
Objective value for closed form solution (by hand): 149.0190044910153


# Dense Case

Let $R \in R^{n \times d}, W\in R^{D\times d}, \Delta = \in R^{D \times D},$ and $X \in R^{n \times D}$. Let $\lambda > 0$. Then the minimum of 
\begin{align*}
    J(\Delta) 
        &= \frac{1}{n} \sum_{i=1}^n \big\| r_i - W^\top \Delta x_i \big\|^2 + \sum_{k=1}^D\sum_{j=1}^p \lambda \Delta_{k,j}^2 \\
        &= \frac{1}{n}\| W^\top \Delta X^\top - R^\top\|^2_F + \lambda \|\Delta\|^2_F
\end{align*}
is uniquely obtained by solving the system of linear equations given by
\begin{align*} 
    W R^\top X    =  W W^\top \Delta X^\top X + \lambda n \Delta
\end{align*}
which can be solved by spectral decomposition $W W^\top = U \Lambda^W U^\top$  and $X^\top X = V \Lambda^X V^\top$
\begin{align*}
    \Delta_{\textnormal{dense}} = U \bigg[ U^\top W R^\top X V \oslash \bigg(\lambda N 1 + \textnormal{diag}(\Lambda^W) \otimes \textnormal{diag}(\Lambda^X)\bigg) \bigg] V^\top
\end{align*}
where $\oslash$ denotes element-wise division, $\otimes$ is the outer product, and $1$ is a matrix of ones.


In [None]:
# Parameters
d = 10
D = 50
p = 20 
N = 100
l2_reg = 0.001

# Create dummy data
np.random.seed(0)
r = np.random.randn(N, d)
W = np.random.randn(D, d)
x = np.random.randn(N, p)

def J(Delta):
    Delta = Delta.reshape(D, p)
    return 1/N * np.linalg.norm(W.T @ Delta @ x.T - r.T)**2 + l2_reg * np.linalg.norm(Delta)**2

def J_byhand(Delta):
    Delta = Delta.reshape(D, p)
    res = 0
    for i in range(N):
        res += 1/N * np.linalg.norm(r[i] - W.T @ Delta @ x[i])**2
    return res + l2_reg * np.linalg.norm(Delta)**2

In [17]:
Delta_closed_form = sandwiched_LS_dense(tensor(r), tensor(W), tensor(x), l2_reg).numpy().T
print(f"Closed form solution for Delta: {Delta_closed_form}")
print(f"Objective value for closed form solution: {J(Delta_closed_form)}")
print(f"Objective value for closed form solution (by hand): {J_byhand(Delta_closed_form)}")

Closed form solution for Delta: [[-9.27318358e-01 -1.16353978e+00 -2.18291898e-01  8.70977514e-01
   4.01190322e-01  1.40880902e+00 -2.76872851e-02  3.08328422e-02
   5.00261331e-01  7.76657893e-01  1.26838401e+00 -5.16316660e-01
   9.18902048e-01  5.39832709e-01 -2.39891158e+00 -7.04312250e-01
   6.75871597e-01  8.03538711e-01  1.01942520e+00 -9.53331160e-01]
 [-1.14780813e+00  2.08064544e-01  6.69442134e-01  6.64725051e-01
   2.86719578e-02  8.20669090e-01  6.90556111e-02  2.97883820e-01
   6.28638940e-01  1.52933002e+00  1.15069743e+00 -3.36695204e-01
   1.38806864e+00 -9.82547485e-03 -2.33206064e+00  1.47705237e-01
   2.37550791e-01 -6.65855521e-01  3.83687294e-01 -6.90545604e-01]
 [-2.21642710e-01 -1.83139851e+00 -3.75775136e-01 -3.30438012e-01
   1.02902246e+00  1.76257081e+00  7.61812938e-02 -5.74793289e-01
  -6.55050899e-01  5.18852765e-02  7.00981919e-01 -9.22306386e-01
   5.47516777e-01 -3.20858811e-01 -1.76902237e+00 -1.26875345e+00
   2.00814928e+00  2.79838548e+00  2.07304

In [13]:
# Gradient descent solution using scipy.optimiz
result = minimize(J, np.random.randn(p*D), method='L-BFGS-B')
Delta = result.x.reshape(p, D)
print(f"Gradient descent using scipy.optimize: {Delta}")
print(f"Objective value for gradient descent solution: {J(Delta)}")
print(f"Objective value for closed form solution (by hand): {J_byhand(Delta)}")

Gradient descent using scipy.optimize: [[-7.33175816e-02 -8.33825418e-02 -5.97449439e-02 -3.66615719e-02
  -5.17175950e-02 -2.70165737e-02 -7.24450817e-02 -5.56272100e-02
  -5.13077034e-02 -6.05601324e-03 -3.12510831e-02 -7.44195564e-02
  -3.43837678e-02 -8.27531595e-02 -1.41142607e-01 -6.69246670e-02
  -4.02947018e-02 -3.39417974e-02 -4.84393889e-02 -1.02105738e-01
  -1.95991915e-02  4.11892192e-02  7.05723555e-02  4.37947584e-02
   4.76532208e-02  2.01575973e-02  1.29581166e-02  2.75940021e-02
   4.75722339e-02  9.10300717e-02  6.15233241e-02 -2.40547691e-02
   7.50185274e-02 -1.03014969e-02 -6.53025393e-02  2.30389451e-02
   3.65794000e-02  8.16096726e-03 -1.05956434e-02 -2.49599529e-02
   2.11699575e-01  1.11460339e-01  1.47580188e-01  1.93799864e-01
   1.79080885e-01  2.33160727e-01  1.93897799e-01  1.07460295e-01
   1.69382759e-01  1.46881980e-01]
 [ 1.91281388e-01  1.42401880e-01  2.10852407e-01  1.16469781e-01
   1.20827854e-01  1.60582046e-01  1.89168677e-01  2.24431265e-01
  