In [1]:
from IPython.display import display, HTML
display(HTML("<style>.container { width:90% !important; }</style>"))

In [2]:
%load_ext autoreload
%autoreload 2

In [3]:
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.ticker as mtick
from tqdm.notebook import trange, tqdm

In [4]:
%matplotlib notebook
np.set_printoptions(precision=4, linewidth=500, threshold=500, suppress=True)

In [5]:
from numpy import kron, eye as I, exp, trace as tr, diag
from numpy.linalg import inv, eigh

In [6]:
from utils import vec, mat, get_chain_graph, get_random_graph, matrix_derivative_numerical, mat_pow, diagi

In [54]:
def get_params(T, N, gamma, beta, random_graph=False, seed=True, p=0.5):
    
    if seed:
        np.random.seed(1)
    
    Y = np.random.normal(size=(N, T)) 
    S = np.random.choice([0, 1], p=[1 - p, p], replace=True, size=(N, T))
    S_ = 1 - S
    Y = Y * S

    K = np.exp(-(np.linspace(0, 3, T)[:, None] - np.linspace(0, 3, T)[None, :]) ** 2) + 1e-4 * I(T)
    
    if random_graph:
        _, LT = get_random_graph(T)
        _, LN = get_random_graph(N)
        
    else:
        _, LT = get_chain_graph(T)
        _, LN = get_chain_graph(N)

    lamLT, UT = eigh(LT)
    lamLN, UN = eigh(LN)
    lamK, V = eigh(K)
        
    lamT = exp(-beta * lamLT) ** 2
    lamN = exp(-beta * lamLN) ** 2

    HT = UT @ diag(lamT) @ UT.T
    HN = UN @ diag(lamN) @ UN.T
    
    J = np.outer(lamN, lamT) / (np.outer(lamN, lamT) + gamma)
    G = np.outer(lamN, lamT)
    
    return T, N, gamma, Y, S, S_, lamT, lamN, HT, HN, UT, UN, K, lamK, V, J, G

T, N, gamma, Y, S, S_, lamT, lamN, HT, HN, UT, UN, K, lamK, V, J, G = get_params(T=8, N=5, beta=1, gamma=1.4, random_graph=False)

# 1. Basic solution

For a cost function 

$$
\newcommand{\vecc}[1]{\text{vec}(#1)}
\newcommand{\Vec}[1]{\text{vec}\big(#1\big)}
\newcommand{\VEC}[1]{\text{vec}\Big(#1\Big)}
\newcommand{\diag}[1]{\text{diag}(#1)}
\newcommand{\Diag}[1]{\text{diag}\big(#1\big)}
\newcommand{\DIAG}[1]{\text{diag}\Big(#1\Big)}
\newcommand{\aand}{\quad \text{and} \quad}
\newcommand{\orr}{\quad \text{or} \quad}
\newcommand{\for}{\; \text{for} \;}
\newcommand{\with}{\quad \text{with} \quad}
\newcommand{\where}{\quad \text{where} \quad}
\newcommand{\iif}{\quad \text{if} \quad}
\newcommand{\SN}{\Sigma_N}
\newcommand{\ST}{\Sigma_T}
\newcommand{\SNi}{\Sigma_N^{-1}}
\newcommand{\STi}{\Sigma_T^{-1}}
\newcommand{\tr}[1]{\text{tr}\big(#1\big)}
\newcommand{\Tr}[1]{\text{tr}\Big(#1\Big)}
\newcommand{\R}{\mathbb{R}}
C(F) = \tr{(Y - S \circ F)^\top (Y- S \circ F)} + \gamma \, \tr{H_N^{-2} F H_T^{-2} F^\top}
$$

The minimising value of $F$ is 

$$
\vecc{\hat{F}} = \Big( \Diag{\vecc{S}} + \gamma H_T^{-2} \otimes H_N^{-2}\Big)^{-1} \vecc{Y}
$$

In [55]:
HNi = inv(HN)
HTi = inv(HT)

def ff(F):
    return tr((Y - S * F).T @ (Y - S * F)) + gamma * tr(HNi @ F @ HTi @ F.T)

F_star = inv(diag(vec(S)) + gamma * kron(HTi, HNi)) @ vec(Y)

np.allclose(matrix_derivative_numerical(ff, mat(F_star, like=Y)), 0, atol=1e-5)

True

# Transformed version

Define $Z$ in the following way

$$
F = U_N \, (G^{1/2} \circ Z) \, U_T^\top
$$

Then 

$$
C(Z) = \Tr{\big(Y - S \circ (U_N \, (G^{1/2} \circ Z) \, U_T^\top) \big)^\top \big(Y- S \circ (U_N \, (G^{1/2} \circ Z) \, U_T^\top) \big)} + \gamma \, \tr{Z^\top Z}
$$

In [56]:
Z = np.random.normal(size=Y.shape)
F = UN @ (G ** 0.5 * Z) @ UT.T

def fz(Z):
    return tr((Y - S * (UN @ (G ** 0.5 * Z) @ UT.T)).T @ (Y - S * (UN @ (G ** 0.5 * Z) @ UT.T))) + gamma * tr(Z.T @ Z)

def derivz(Z):
    return 2 * gamma * Z - 2 * UN.T @ Y @ UT * G ** 0.5 + 2 * (UN.T @ (S * (UN @ (G ** 0.5 * Z) @ UT.T)) @ UT) * G ** 0.5

In [57]:
np.isclose(fz(Z), ff(F))

True

In [58]:
np.allclose(matrix_derivative_numerical(fz, Z), derivz(Z), atol=1e-5)

True

In [63]:
R = 1 / (S + gamma)

Z_star = G ** - 0.5 * (UN @ (R * (UN @ (G * (UN.T @ Y @ UT)) @ UT.T)) @ UT)

In [64]:
derivz(Z_star)

array([[  0.6784,  -0.694 ,   0.5546,  -0.1929,  -0.2861,   0.046 ,   0.9709,  -1.1446],
       [  0.2833,  -1.1155,   0.2591,   0.3048,  -0.1293,  -0.828 ,  -0.735 ,   2.1988],
       [ -0.5842,   0.7905,  -0.5306,  -0.5123,   2.6439,   3.6093,  -8.9486,   1.9216],
       [  3.3063,  -8.4276,   0.6929,   2.3785,  -0.9011, -11.7847,  50.4395, -11.1765],
       [-15.5925,  14.6645, -17.1637,   4.8492,  13.8062,   3.0089,  -6.242 ,  30.6649]])