In [2]:
from typing import Tuple, List, Union, Any, Optional, Dict, Literal, Callable
import time
import collections
import os
import sys
sys.path.append(os.path.dirname(os.getcwd()))
sys.path.append(os.path.dirname(os.path.dirname(os.getcwd())))

from tqdm import tqdm
import numpy as np
import torch
import torch.nn as nn
from torch import Tensor, tensor
import pandas as pd
np.set_printoptions(precision=3, threshold=5) # Print options

from utils.utils import print_name, print_shape

# Scalar case

Let $r_i \in R^d, W\in R^{d\times D}, \Delta \in R, x_i\in R^D$ for all $i\in[N]$. Fix regularization $\lambda > 0$. Then the minimum of 
\begin{align*}
    J(\Delta) = \frac{1}{N} \sum_{i=1}^N \big\| r_i - W\Delta x_i \big\|^2 + \lambda \Delta^2
\end{align*}
is attained at
\begin{align*}
    \Delta = \frac{\sum_{i=1}^N (Wx_i)^T r_i}{\sum_{i=1}^N\big( \|Wx_i\|^2 + \lambda \big)}
\end{align*}

In [286]:
# Parameters
d = 200
D = 3000
N = 1000
lambda_reg = 10

# Create dummy data
np.random.seed(0)
r = np.random.randn(N, d)+2
W = np.random.randn(d, D)/100
x = np.random.randn(N, D)-1

def J(Delta):
    Wx = W @ x.T
    residual = r - Wx.T * Delta
    return np.mean(np.linalg.norm(residual, axis=1)**2) + lambda_reg * Delta**2

def J_byhand(Delta):
    res = 0
    for i in range(N):
        res += np.linalg.norm(r[i] - W @ x[i] * Delta)**2 / N
    return res + lambda_reg * Delta**2

In [194]:
from scipy.optimize import minimize

# Closed form solution numpy
Wx = W @ x.T
numerator = np.sum(r.T * Wx)
denominator = np.sum(np.linalg.norm(Wx, axis=0)**2  + lambda_reg) 
Delta_closed_form = numerator / denominator
print(f"Closed form solution for Delta: {Delta_closed_form}")
print(f"Objective value for closed form solution: {J(Delta_closed_form)}")
print(f"Objective value for closed form solution (by hand): {J_byhand(Delta_closed_form)}")
print("numerator, denominator", numerator, denominator)

Closed form solution for Delta: -0.1617255589180953
Objective value for closed form solution: 998.3451788766267
Objective value for closed form solution (by hand): 998.3451788766267
numerator, denominator -21838.381139270066 135033.5796355476


In [195]:
# Closed from solution by hand
numerator = sum([ np.dot(W @ x[i], r[i]) for i in range(N) ])/N
denominator = sum([ np.linalg.norm(W @ x[i])**2 for i in range(N) ])/N + lambda_reg
Delta_by_hand = numerator / denominator
print(f"Closed form solution for Delta (by hand): {Delta_by_hand}")
print(f"Objective value for closed form solution (by hand): {J(Delta_by_hand)}")
print(f"Objective value for closed form solution (by hand): {J_byhand(Delta_by_hand)}")
print("numerator, denominator", numerator, denominator)

Closed form solution for Delta (by hand): -0.16172555891809537
Objective value for closed form solution (by hand): 998.3451788766267
Objective value for closed form solution (by hand): 998.3451788766267
numerator, denominator -21.838381139270066 135.03357963554754


In [196]:
# Gradient descent solution using scipy.optimiz
result = minimize(J, np.random.randn(), method='BFGS')
Delta = result.x[0]
print(f"Gradient descent solution for Delta using scipy.optimize: {Delta}")
print(f"Objective value for gradient descent solution: {J(Delta)}")
print(f"Objective value for closed form solution (by hand): {J_byhand(Delta)}")

Gradient descent solution for Delta using scipy.optimize: -0.1617255660741791
Objective value for gradient descent solution: 998.3451788766267
Objective value for closed form solution (by hand): 998.3451788766273


# Diagonal Case

Let $r_i \in R^d, W\in R^{d\times D}, \Delta = diag(\delta_1, ..., \delta_D) \in R^{D\times D}, x_i\in R^D$ for all $i\in[N]$. Fix regularization $\lambda > 0$. Then the minimum of 
\begin{align*}
    J(\Delta) = \frac{1}{N} \sum_{i=1}^N \big\| r_i - W\Delta x_i \big\|^2 + \lambda \|\Delta\|^2
\end{align*}
is attained by solving the system of linear equations
\begin{align*}
    b = (A + \lambda I)\Delta
\end{align*}
where
\begin{align*}
    A_{k,j} = \frac{1}{N} \sum_{i=1}^D \big( W_k x_{i,k} \big)^T W_j x_{i,j}
\end{align*}
and
\begin{align*}
    b_k = \frac{1}{N}\sum_{i=1}^N r_i^T W_k x_{i,k}
\end{align*}


In [335]:
# Parameters
d = 30
D = 20
N = 1000
lambda_reg = 10

# Create dummy data
np.random.seed(0)
r = np.random.randn(N, d)+2
W = np.random.randn(d, D)/100
x = np.random.randn(N, D)-1

def A_byhand():
    A = np.zeros((D, D))
    for k in range(D):
        for j in range(D):
            A[k, j] = np.mean([ x[i, k] * x[i, j] * np.dot(W[:, k], W[:, j]) for i in range(N)])
    return A


def A():
    return (W.T @ W) * (x.T @ x) / N


def b_byhand():
    b = np.zeros(D)
    for k in range(D):
        b[k] = np.mean([ x[i, k] * np.dot(W[:, k], r[i]) for i in range(N)])
    return b


def b():
    #return np.mean( (r @ W) * x, axis=0)
    #return np.diag(W.T @ r.T @ x) / N
    return np.einsum('nd,dk,nk->k', r, W, x) / N

In [336]:
(A() - A_byhand()).mean()

-9.452887691357992e-21

In [337]:
(b() - b_byhand()).mean()

5.585809592645319e-17

In [338]:
def J(Delta):
    return np.mean(np.linalg.norm(r - x @ np.diag(Delta) @ W.T, axis=1)**2) + lambda_reg * np.linalg.norm(Delta)**2

def J_byhand(Delta):
    res = 0
    for i in range(N):
        res += np.linalg.norm(r[i] - W @ (Delta*x[i]))**2 / N
    return res + lambda_reg * np.sum(Delta**2)

In [339]:
b = np.mean( (r @ W) * x, axis=0)
A = (W.T @ W) * (x.T @ x) / N
Delta_closed_form = np.linalg.solve(A + lambda_reg * np.eye(D), b)
print(f"Closed form solution for Delta: {Delta_closed_form}")
print(f"Objective value for closed form solution: {J(Delta_closed_form)}")
print(f"Objective value for closed form solution (by hand): {J_byhand(Delta_closed_form)}")

Closed form solution for Delta: [-0.005  0.011  0.011 ...  0.006  0.013  0.017]
Objective value for closed form solution: 149.0212011905012
Objective value for closed form solution (by hand): 149.02120119050105


In [340]:
# Gradient descent solution using scipy.optimiz
result = minimize(J, np.random.randn(D), method='BFGS')
Delta = result.x
print(f"Gradient descent using scipy.optimize: {Delta}")
print(f"Objective value for gradient descent solution: {J(Delta)}")
print(f"Objective value for closed form solution (by hand): {J_byhand(Delta)}")

Gradient descent using scipy.optimize: [-0.005  0.011  0.011 ...  0.006  0.013  0.017]
Objective value for gradient descent solution: 149.02120119051318
Objective value for closed form solution (by hand): 149.02120119051318


# Dense Case