In [1]:
import jax.numpy as jnp
import jax.random as jrandom

import matplotlib.pyplot as plt
import pandas as pd
from jax import jit, grad, jacfwd, jacrev



from scipy.stats import linregress

from tqdm import tqdm
import pickle

import numpy as np

# Enable 64-bit precision in JAX (it defaults to 32-bit); this should run at
# startup, before any JAX computation. The `jax.config` *submodule* import
# (`from jax.config import config`) was removed in newer JAX releases, so the
# config object is reached through the top-level package instead. The name
# `config` stays bound for any later cell that refers to it.
import jax

config = jax.config
config.update("jax_enable_x64", True)

import matplotlib.pyplot as plt

%matplotlib inline
# Notebook-wide matplotlib defaults.
# Default figure size for every plot.
plt.rcParams['figure.figsize'] = [12,12]
# Switch to the 'ggplot' style sheet; the two rcParams assignments below run
# after it, so they override whatever the style sets for those keys.
plt.style.use('ggplot')
# Default marker size for line plots.
plt.rcParams['lines.markersize'] = 8
# Default font size for text in figures.
plt.rcParams.update({'font.size': 12})

We want to minimize the following 
$$ 
    tr(\Sigma V^T D V \Sigma)
$$
with respect to $V$ with $V$ being an orthogonal matrix.

Differentiating with respect to $V$ (using that $D$ and $\Sigma$ are symmetric; in the code below both are diagonal) we have 

$$
    \nabla_V tr(\Sigma V^T D V \Sigma) \\
    = 2 D V \Sigma^2.
$$
    
    
Now, we also have the condition that the matrix has to be orthogonal, $V^T V = I$. Writing $v_i$ for the $i$-th column of $V$, this leads to the Lagrangian 

$$
    \mathcal{L}(V, \lambda) = tr(\Sigma V^T D V \Sigma) + \sum_{i = 1}^d \sum_{j = 1}^d \lambda_{i, j}(v_i^T v_j - \delta_{ij}).
$$

The gradient of the constraint term with respect to $V$ is $V (\lambda + \lambda^T)$. Notice here that $(\lambda + \lambda^T)$ is symmetric. 

From the KKT stationarity condition (absorbing the sign into $\lambda$, and multiplying on the left by $V^T$ using $V^T V = I$) we get 

$$
    2 D V \Sigma^2 = V (\lambda + \lambda^T) \\
    V^T D V \Sigma^2 = \frac{1}{2} (\lambda + \lambda^T)
$$
and since the RHS is symmetric we have that 
$$
    V^T D V \Sigma^2 = \Sigma^2 V^T D V
$$

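implying that $V^T D V$ commutes with $\Sigma^2$. Diagonalizable matrices which commute are simultaneously diagonalizable, and since both matrices here are symmetric (for symmetric $D$ and $\Sigma$) the common eigenbasis can be chosen orthonormal. Hence there exists an orthogonal $P$ such that $P^{-1} V^T D V P$ and $P^{-1} \Sigma^2 P$ are both diagonal matrices. 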

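Assume such a $P$ exists. Then, for some diagonal matrix $K$ with diagonal entries $K_i$ (writing $P_i$ for the $i$-th column of $P$, $P_{ij}$ for the $j$-th entry of that column, and $\Sigma^2_j$ for the $j$-th diagonal entry of the diagonal matrix $\Sigma^2$), we have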
$$ 
    P^{-1} \Sigma^2 P = K \\
    \Sigma^2 P K^{-1} = P \\
    \rightarrow \frac{1}{K_i} \Sigma^2 P_i = P_i \\
    \frac{1}{K_i} \Sigma^2_j P_{ij} = P_{ij}. \\
$$

Since $P$ has to have full rank, no column $P_i$ can be zero, so for each $i$ at least one $P_{ij} \neq 0$. Assume that $P_{ij} \neq 0$ for some $j$; then we have
$$
    K_i = \Sigma^2_j
$$
which immediately implies that $P_{ij} = 0$ for any other $j$ for which $\Sigma^2_j \neq K_i$. 





In [7]:
# Problem setup for checking the closed-form gradient numerically.
dim = 5
# D and Sigma are diagonal (hence symmetric) matrices.
D = jnp.diag(jnp.linspace(10, 20, dim))
Sigma = jnp.diag(jnp.linspace(1, 10, dim))

# Explicit PRNG handling: split the root key once and draw the test point V
# from the subkey. V is a generic Gaussian matrix, not an orthogonal one.
jrandom_key = jrandom.PRNGKey(0)
jrandom_key, subkey = jrandom.split(jrandom_key)
V = jrandom.normal(subkey, shape=(dim, dim,))

# Closed-form gradient 2 D V Sigma^2. Sigma^2 is the *matrix* square; the
# element-wise `Sigma**2` only coincides with it while Sigma is diagonal, so
# the matrix product is used to stay correct if Sigma is ever changed.
predicted_grad = 2 * D @ V @ (Sigma @ Sigma)

def g(V, h):
    """Forward-difference estimate of the gradient of tr(Sigma V^T D V Sigma) w.r.t. V.

    Each entry (i, j) of the result is (f(V + h * E_ij) - f(V)) / h, where
    E_ij is the matrix with a single 1 at position (i, j) and
    f(M) = tr(Sigma M^T D M Sigma). The estimate is first-order accurate in h.

    Reads the notebook-level matrices `D` and `Sigma`.

    Args:
        V: 2-D array at which the gradient is estimated.
        h: finite-difference step size (non-zero scalar).

    Returns:
        NumPy array with the same shape as `V` holding the estimated partials.
    """
    n_rows, n_cols = V.shape

    def objective(M):
        return jnp.trace(Sigma @ M.T @ D @ M @ Sigma)

    # The unperturbed objective does not depend on (i, j): evaluate it once
    # instead of once per entry.
    baseline = objective(V)

    res = np.zeros(shape=(n_rows, n_cols))
    for i in range(n_rows):
        for j in range(n_cols):
            # Perturb a single entry of V by h.
            h_add = np.zeros(shape=(n_rows, n_cols))
            h_add[i, j] = h
            res[i, j] = (objective(V + h_add) - baseline) / h

    return res

# print(predicted_grad)

# g(V, 0.01)




In [20]:
# NOTE(review): the key split below is disabled, so `subkey` is whatever an
# earlier cell left behind. If that is still the subkey used to draw `V`
# (same key, same shape), `lmbda` is numerically identical to `V` rather than
# an independent draw -- confirm this is intended, or re-enable the split.
# jrandom_key, subkey = jrandom.split(jrandom_key)
lmbda = jrandom.normal(subkey, shape=(dim, dim,))

def g_const(V, h):
    """Forward-difference estimate of the gradient of the constraint term w.r.t. V.

    The constraint term is c(M) = sum_ij lmbda_ij * (M^T M - I)_ij. Each entry
    (i, j) of the result is (c(V + h * E_ij) - c(V)) / h, where E_ij is the
    matrix with a single 1 at position (i, j). The estimate is first-order
    accurate in h.

    Reads the notebook-level multiplier matrix `lmbda`.

    Args:
        V: 2-D array at which the gradient is estimated.
        h: finite-difference step size (non-zero scalar).

    Returns:
        NumPy array with the same shape as `V` holding the estimated partials.
    """
    n_rows, n_cols = V.shape
    # V^T V is (n_cols x n_cols), so the identity is sized from V's columns.
    identity = np.eye(n_cols)

    def constraint(M):
        return jnp.sum(lmbda * (M.T @ M - identity))

    # The unperturbed constraint value does not depend on (i, j): evaluate it
    # once instead of once per entry.
    baseline = constraint(V)

    res = np.zeros(shape=(n_rows, n_cols))
    for i in range(n_rows):
        for j in range(n_cols):
            # Perturb a single entry of V by h.
            h_add = np.zeros(shape=(n_rows, n_cols))
            h_add[i, j] = h
            res[i, j] = (constraint(V + h_add) - baseline) / h

    return res

# Closed-form gradient of the constraint term: V (lmbda + lmbda^T).
predicted_g_const = V @ (lmbda + lmbda.T) 

# Print the closed-form value, then show the finite-difference estimate
# (step h = 0.001) as the cell output so the two can be compared side by side.
print(predicted_g_const)
g_const(V, 0.001)


[[-2.24155725 -3.41796199 -2.07293438  3.90568685 -0.16045151]
 [ 2.18186568  1.3422906   3.16194906  6.81276762 -1.24828383]
 [ 2.64887435  1.72390982  1.81490638 -4.09982664 -1.89533527]
 [ 0.11279967  1.23804592  1.21735931  2.09486267 -1.05874889]
 [-0.13607657  1.78528725  0.65198301 -2.30458692 -2.61382842]]


array([[-2.24182828, -3.41927897, -2.0733179 ,  3.90614404, -0.15934457],
       [ 2.18159466,  1.34097362,  3.16156553,  6.81322481, -1.24717689],
       [ 2.64860332,  1.72259284,  1.81452286, -4.09936945, -1.89422833],
       [ 0.11252864,  1.23672894,  1.21697579,  2.09531986, -1.05764195],
       [-0.1363476 ,  1.78397027,  0.65159949, -2.30412973, -2.61272148]])

In [21]:
def constraint_grad(V, lmbda):
    """Closed-form gradient of sum_ij lmbda_ij (V^T V - I)_ij with respect to V.

    Args:
        V: 2-D array.
        lmbda: square multiplier matrix, conformable with `V` on the right.

    Returns:
        The product V (lmbda + lmbda^T).
    """
    # Only the symmetric part of the multipliers contributes to the gradient.
    symmetric_multiplier = lmbda + lmbda.T
    return V @ symmetric_multiplier

def obj_grad(D, V, Sigma):
    """Closed-form gradient of tr(Sigma V^T D V Sigma) with respect to V.

    With M = Sigma @ Sigma (the matrix square), the gradient is
    D V M + D^T V M^T, which reduces to 2 D V Sigma^2 when D and Sigma^2 are
    symmetric. Using the matrix product (rather than the element-wise
    `Sigma**2`) and both transposed terms keeps the result correct for
    non-diagonal Sigma and non-symmetric D; for diagonal inputs the values are
    unchanged.

    Args:
        D: square 2-D array.
        V: 2-D array at which the gradient is evaluated.
        Sigma: square 2-D array.

    Returns:
        Array with the same shape as `V` holding the gradient.
    """
    sigma_sq = Sigma @ Sigma
    return D @ V @ sigma_sq + D.T @ V @ sigma_sq.T


# Evaluate the closed-form objective gradient at V = I. NOTE: this rebinds the
# notebook-level name `V`, replacing its earlier value for any cell run afterwards.
V = jnp.eye(dim)

print(obj_grad(D, V, Sigma))

[[  20.        0.        0.        0.        0.    ]
 [   0.      264.0625    0.        0.        0.    ]
 [   0.        0.      907.5       0.        0.    ]
 [   0.        0.        0.     2102.1875    0.    ]
 [   0.        0.        0.        0.     4000.    ]]
