In [1]:
from typing import Tuple, List, Union, Any, Optional, Dict, Literal, Callable
import time

from tqdm import tqdm
import numpy as np
import torch
import torch.nn as nn
from torch import Tensor, tensor
from scipy.optimize import minimize

from models.sandwiched_least_squares import sandwiched_LS_scalar, sandwiched_LS_diag, sandwiched_LS_dense

# Scalar case

Let $R \in R^{n \times d}, W\in R^{D\times d}, \Delta \in R,$ and $X \in R^{n \times D}$. Let $\lambda > 0$. Then the minimum of 
\begin{align*}
    J(\Delta) = \frac{1}{n} \sum_{i=1}^n \big\| R_i - W^\top\Delta X_i \big\|^2 + \lambda \Delta^2
\end{align*}
is uniquely attained by
\begin{align*}
    \Delta_{\textnormal{scalar}} 
    &= \frac{\langle R, XW^\top\rangle_F}{\|X W\|_F^2 + n\lambda} 
    = \frac{\frac{1}{n}\sum_{i=1}^n \langle W^\top X_i,  R_i\rangle}{\frac{1}{n}\sum_{i=1}^n \|W^\top X_i\|^2 + \lambda }.
\end{align*}

In [2]:
# Parameters
d = 200
D = 300
N = 100
l2_reg = 10

# Create dummy data
np.random.seed(0)
r = np.random.randn(N, d)+2
W = np.random.randn(D, d)/100
x = np.random.randn(N, D)-1

def J(Delta):
    Wx = W.T @ x.T
    residual = r - Wx.T * Delta
    return np.mean(np.linalg.norm(residual, axis=1)**2) + l2_reg * Delta**2

def J_byhand(Delta):
    res = 0
    for i in range(N):
        res += np.linalg.norm(r[i] - W.T @ x[i] * Delta)**2 / N
    return res + l2_reg * Delta**2

In [3]:
# Closed form solution
Delta_closed_form = sandwiched_LS_scalar(tensor(r), tensor(W), tensor(x), l2_reg).numpy()

print(f"Closed form solution for Delta: {Delta_closed_form}")
print(f"Objective value for closed form solution: {J(Delta_closed_form)}")
print(f"Objective value for closed form solution (by hand): {J_byhand(Delta_closed_form)}")

Closed form solution for Delta: 0.09612819692159835
Objective value for closed form solution: 993.0117548905527
Objective value for closed form solution (by hand): 993.0117548905528


In [4]:
# Gradient descent solution using scipy.optimiz
result = minimize(J, np.random.randn(), method='BFGS')
Delta = result.x[0]
print(f"Gradient descent solution for Delta using scipy.optimize: {Delta}")
print(f"Objective value for gradient descent solution: {J(Delta)}")
print(f"Objective value for closed form solution (by hand): {J_byhand(Delta)}")

Gradient descent solution for Delta using scipy.optimize: 0.09612803569713149
Objective value for gradient descent solution: 993.0117548905532
Objective value for closed form solution (by hand): 993.0117548905529


# Diagonal Case

Let $R \in R^{n \times d}, W\in R^{D\times d}, \Delta = \textnormal{diag}(\delta_1, ..., \delta_D) \in R^{D \times D},$ and $X \in R^{n \times D}$. Let $\lambda > 0$. Then the minimum of 
\begin{align*}
    J(\Delta) = \frac{1}{n} \sum_{i=1}^n \big\| R_i - W^\top\Delta X_i \big\|^2 + \lambda \Delta^2
\end{align*}
is uniquely attained by the solution to the system of linear equations
\begin{align*}
    b = (A+ \lambda I)\Delta
\end{align*}
where
\begin{align*}
    A = W W^\top \odot X^\top X,  \qquad \qquad b = \textnormal{diag}(W^\top R^\top X).
\end{align*}


In [5]:
# Parameters
d = 30
D = 20
N = 1000
l2_reg = 10

# Create dummy data
np.random.seed(0)
r = np.random.randn(N, d)
W = np.random.randn(D, d)
x = np.random.randn(N, D)

def A_byhand():
    A = np.zeros((D, D))
    for k in range(D):
        for j in range(D):
            A[k, j] = np.mean([ x[i, k] * x[i, j] * np.dot(W[k], W[j]) for i in range(N)])
    return A


def A():
    return (W @ W.T) * (x.T @ x) / N

def b_byhand():
    b = np.zeros(D)
    for k in range(D):
        b[k] = np.mean([ x[i, k] * np.dot(W[k], r[i]) for i in range(N)])
    return b


def b():
    return np.mean( (r @ W.T) * x, axis=0)
    #return np.diag(W @ r.T @ x) / N
    #return np.einsum('nd,kd,nk->k', r, W, x) / N

In [6]:
(A() - A_byhand()).mean()

-2.0510801173809213e-17

In [7]:
(b() - b_byhand()).mean()

2.4134340359527328e-17

In [8]:
def J(Delta):
    return np.mean(np.linalg.norm(r - x @ np.diag(Delta) @ W, axis=1)**2) + l2_reg * np.linalg.norm(Delta)**2

def J_byhand(Delta):
    res = 0
    for i in range(N):
        res += np.linalg.norm(r[i] - W.T @ (Delta*x[i]))**2 / N
    return res + l2_reg * np.sum(Delta**2)

In [9]:
Delta_closed_form = sandwiched_LS_diag(tensor(r), tensor(W), tensor(x), l2_reg).numpy()
print(f"Closed form solution for Delta: {Delta_closed_form}")
print(f"Objective value for closed form solution: {J(Delta_closed_form)}")
print(f"Objective value for closed form solution (by hand): {J_byhand(Delta_closed_form)}")

Closed form solution for Delta: [ 1.15952169e-03  5.00119337e-03  8.21619145e-03  1.07988685e-02
 -2.68865051e-03  4.19528345e-03 -4.25148869e-03  5.80865727e-04
  4.97751029e-03 -7.35732222e-03  2.18509303e-06 -6.20406520e-03
  3.83358855e-03  8.42855673e-04 -5.64821215e-03  6.42419030e-03
  3.03962442e-03  1.63056427e-03  5.86754698e-03  7.41961758e-03]
Objective value for closed form solution: 29.54087641951766
Objective value for closed form solution (by hand): 29.540876419517648


In [10]:
# Gradient descent solution using scipy.optimiz
result = minimize(J, np.random.randn(D), method='BFGS')
Delta = result.x
print(f"Gradient descent using scipy.optimize: {Delta}")
print(f"Objective value for gradient descent solution: {J(Delta)}")
print(f"Objective value for closed form solution (by hand): {J_byhand(Delta)}")

Gradient descent using scipy.optimize: [ 1.15952293e-03  5.00119841e-03  8.21620435e-03  1.07988851e-02
 -2.68865000e-03  4.19529645e-03 -4.25149147e-03  5.80868866e-04
  4.97750987e-03 -7.35732630e-03  2.19717982e-06 -6.20406044e-03
  3.83358796e-03  8.42863595e-04 -5.64819249e-03  6.42419110e-03
  3.03962480e-03  1.63057012e-03  5.86755304e-03  7.41962085e-03]
Objective value for gradient descent solution: 29.540876419517705
Objective value for closed form solution (by hand): 29.54087641951767


# Dense Case

Let $R \in R^{n \times d}, W\in R^{D\times d}, \Delta = \in R^{D \times D},$ and $X \in R^{n \times D}$. Let $\lambda > 0$. Then the minimum of 
\begin{align*}
    J(\Delta) 
        &= \frac{1}{n} \sum_{i=1}^n \big\| r_i - W^\top \Delta x_i \big\|^2 + \sum_{k=1}^D\sum_{j=1}^p \lambda \Delta_{k,j}^2 \\
        &= \frac{1}{n}\| W^\top \Delta X^\top - R^\top\|^2_F + \lambda \|\Delta\|^2_F
\end{align*}
is uniquely obtained by solving the system of linear equations given by
\begin{align*} 
    W R^\top X    =  W W^\top \Delta X^\top X + \lambda n \Delta
\end{align*}
which can be solved by spectral decomposition $W W^\top = U \Lambda^W U^\top$  and $X^\top X = V \Lambda^X V^\top$
\begin{align*}
    \Delta_{\textnormal{dense}} = U \bigg[ U^\top W R^\top X V \oslash \bigg(\lambda N 1 + \textnormal{diag}(\Lambda^W) \otimes \textnormal{diag}(\Lambda^X)\bigg) \bigg] V^\top
\end{align*}
where $\oslash$ denotes element-wise division, $\otimes$ is the outer product, and $1$ is a matrix of ones.


In [11]:
# Parameters
d = 10
D = 50
p = 20 
N = 58
l2_reg = 1000

# Create dummy data
np.random.seed(0)
r = np.random.randn(N, d)
W = np.random.randn(D, d)
x = np.random.randn(N, p)

def J(Delta):
    Delta = Delta.reshape(D, p)
    return 1/N * np.linalg.norm(W.T @ Delta @ x.T - r.T)**2 + l2_reg * np.linalg.norm(Delta)**2

def J_byhand(Delta):
    Delta = Delta.reshape(D, p)
    res = 0
    for i in range(N):
        res += 1/N * np.linalg.norm(r[i] - W.T @ Delta @ x[i])**2
    return res + l2_reg * np.linalg.norm(Delta)**2

In [18]:
Delta_closed_form = sandwiched_LS_dense(tensor(r), tensor(W), tensor(x), l2_reg).numpy().T
print(f"Closed form solution for Delta: {Delta_closed_form}")
print(f"Objective value for closed form solution: {J(Delta_closed_form)}")
print(f"Objective value for closed form solution (by hand): {J_byhand(Delta_closed_form)}")

Closed form solution for Delta: [[ 3.79502611e-04  3.73698197e-04 -3.14955042e-04  1.22094528e-03
  -6.14976934e-04  1.41700688e-04  2.68869482e-04  8.15636930e-05
  -2.74746159e-04  2.42031991e-04 -7.82268829e-04  5.89967730e-04
   2.54167103e-04  4.57715190e-04  6.00741271e-04  6.48344140e-04
  -3.71850507e-05  2.36564443e-04 -4.55924228e-04  5.50561477e-04]
 [-7.68293568e-05 -2.53246322e-04  3.25699375e-04 -3.08510729e-04
   6.90006634e-05  2.31690563e-04 -1.83278252e-04 -8.05362183e-05
   1.47365158e-04  2.99572349e-05  2.07439311e-04 -2.77553164e-05
   3.67057439e-04  2.26333731e-04 -2.65277542e-04 -1.46667955e-04
  -8.26036297e-05  2.30390292e-04 -4.21209435e-04 -2.69941110e-04]
 [ 2.01248427e-04  1.82826066e-04 -3.43655514e-04  6.83085915e-04
  -4.70466820e-04  9.43317460e-05  2.11679375e-04  1.33431936e-04
   2.06597594e-04 -3.02078081e-05  1.01726338e-05  1.21991072e-04
   4.22474691e-05  2.71360714e-04  2.54763436e-04  3.05433870e-07
   4.80830751e-04 -4.33357131e-04 -4.27459

In [17]:
# Gradient descent solution using scipy.optimiz
result = minimize(J, np.random.randn(p*D), method='L-BFGS-B')
Delta = result.x.reshape(p, D)
print(f"Gradient descent using scipy.optimize: {Delta}")
print(f"Objective value for gradient descent solution: {J(Delta)}")
print(f"Objective value for closed form solution (by hand): {J_byhand(Delta)}")

Gradient descent using scipy.optimize: [[ 3.79497920e-04  3.73693304e-04 -3.14960162e-04  1.22093967e-03
  -6.14981716e-04  1.41695736e-04  2.68864832e-04  8.15588035e-05
  -2.74751155e-04  2.42027065e-04 -7.82273654e-04  5.89963086e-04
   2.54162700e-04  4.57710366e-04  6.00736255e-04  6.48339719e-04
  -3.71905714e-05  2.36560043e-04 -4.55929015e-04  5.50556956e-04
  -7.68343520e-05 -2.53251788e-04  3.25694497e-04 -3.08515563e-04
   6.89956941e-05  2.31685494e-04 -1.83283392e-04 -8.05412769e-05
   1.47360214e-04  2.99522642e-05  2.07434520e-04 -2.77604661e-05
   3.67052520e-04  2.26328498e-04 -2.65282653e-04 -1.46672933e-04
  -8.26084475e-05  2.30385072e-04 -4.21214629e-04 -2.69946381e-04
   2.01243294e-04  1.82821162e-04 -3.43660361e-04  6.83080687e-04
  -4.70471531e-04  9.43266882e-05  2.11674402e-04  1.33427046e-04
   2.06592848e-04 -3.02131803e-05]
 [ 1.01674577e-05  1.21986573e-04  4.22426910e-05  2.71356026e-04
   2.54758946e-04  3.00782997e-07  4.80825699e-04 -4.33362140e-04
  