In [8]:
from __future__ import print_function
from ipywidgets import interact, interactive, fixed, interact_manual
import ipywidgets as widgets
import numpy as np
import scipy.linalg as spla
import seaborn as sns
import matplotlib.pyplot as plt
import matplotlib
import numpy.random as npr

In [5]:
def create_sigma(Ur, Ul, A, B, C, D, log_yl_var, log_yr_var, log_w_var, Sig_z=None):
    """Build the covariance ``S`` of the four stacked w-blocks.

    The returned covariance is ``S = M @ Sig_z @ M.T + Sig_wz`` where
    ``M = vstack([B @ Ul, D @ Ul, A @ Ur, C @ Ur])`` and ``Sig_wz`` is
    block-diagonal with one block per stacked w.

    Parameters
    ----------
    Ur, Ul : ndarray, shape (y_dim, z_dim)
    A, B, C, D : ndarray, shape (w_dim, y_dim)
    log_yl_var, log_yr_var : ndarray, shape (y_dim,)
        Log of the diagonal variances Sigma_l / Sigma_r.
    log_w_var : ndarray, shape (w_dim,)
        Log of the diagonal variance Sigma_w.
    Sig_z : ndarray, shape (z_dim, z_dim), optional
        Covariance of z. Defaults to the identity of size ``Ur.shape[1]``.
        The notebook previously read the global ``Sig_z = np.eye(1)``, which
        this default reproduces for z_dim == 1.

    Returns
    -------
    S : ndarray, shape (4*w_dim, 4*w_dim)
    M : ndarray, shape (4*w_dim, z_dim)
    Sig_wz : ndarray, shape (4*w_dim, 4*w_dim), block-diagonal
    """
    if Sig_z is None:
        Sig_z = np.eye(Ur.shape[1])

    Sig_yl = np.diag(np.exp(log_yl_var))
    Sig_yr = np.diag(np.exp(log_yr_var))
    Sig_w = np.diag(np.exp(log_w_var))

    # z-propagated covariance of each y; each one is shared by two diagonal blocks.
    cov_yr = Ur @ Sig_z @ Ur.T + Sig_yr
    cov_yl = Ul @ Sig_z @ Ul.T + Sig_yl

    # NOTE(review): each block adds the *other* side's z-propagated covariance
    # (e.g. A Ur Sig_z Ur^T A^T in block 0), while M carries only one side's
    # z-loading (B Ul in row 0). Confirm this split against the model derivation.
    Sig_wz = spla.block_diag(
        B @ Sig_yl @ B.T + A @ cov_yr @ A.T + Sig_w,
        D @ Sig_yl @ D.T + C @ cov_yr @ C.T + Sig_w,
        A @ Sig_yr @ A.T + B @ cov_yl @ B.T + Sig_w,
        C @ Sig_yr @ C.T + D @ cov_yl @ D.T + Sig_w,
    )

    M = np.vstack([B @ Ul, D @ Ul, A @ Ur, C @ Ur])

    return M @ Sig_z @ M.T + Sig_wz, M, Sig_wz

def npr_init_params(seed, z_dim=50, y_dim=50, w_dim=50, scale=.5, w_rows=20):
    """Sample a random parameter tuple for ``create_sigma`` plus an extra matrix ``W``.

    All draws come from numpy's global RNG (``numpy.random``) in a fixed order.

    Parameters
    ----------
    seed : int or None
        If not None, ``numpy.random.seed(seed)`` is called before drawing.
        ``None`` leaves the global RNG state untouched. The existing call
        ``npr_init_params(np.random.seed(1), ...)`` passes None (after seeding
        as a side effect), so it behaves exactly as before.
    z_dim, y_dim, w_dim : int
        Sizes of z, of each y, and of each w.
    scale : float
        Multiplies every standard-normal matrix. ``log(scale)`` also offsets the
        log-variances, so it must be > 0 for finite values.
    w_rows : int
        Number of rows of ``W`` (previously hard-coded to 20).

    Returns
    -------
    tuple
        ``(Ur, Ul, A, B, C, D, log_yl_var, log_yr_var, log_w_var, W)``:
        ``Ur, Ul`` have shape (y_dim, z_dim), ``A, B, C, D`` have shape
        (w_dim, y_dim), ``log_yl_var, log_yr_var`` have shape (y_dim,),
        ``log_w_var`` has shape (w_dim,) and ``W`` has shape (w_rows, w_dim).
    """
    if seed is not None:
        npr.seed(seed)  # the argument used to be silently ignored

    # Keep the draw order fixed: reordering would change every sampled value.
    Ur = scale * npr.randn(y_dim, z_dim)
    Ul = scale * npr.randn(y_dim, z_dim)

    A = scale * npr.randn(w_dim, y_dim)
    B = scale * npr.randn(w_dim, y_dim)
    C = scale * npr.randn(w_dim, y_dim)
    D = scale * npr.randn(w_dim, y_dim)

    log_yl_var = np.log(scale) + npr.randn(y_dim)
    log_yr_var = np.log(scale) + npr.randn(y_dim)
    log_w_var = np.log(scale) + npr.randn(w_dim)

    W = scale * npr.randn(w_rows, w_dim)

    return (Ur, Ul, A, B, C, D, log_yl_var, log_yr_var, log_w_var, W)

# Notebook-global covariance matrix for z (1x1 identity, i.e. z_dim = 1);
# presumably the prior covariance of the latent z -- confirm against the model.
Sig_z = np.identity(1)

In [107]:
# Seed numpy's global RNG as an explicit statement, then draw 1-dimensional
# z, y and w parameters. This is equivalent to passing np.random.seed(1),
# which returns None, as the first argument.
np.random.seed(1)
params = npr_init_params(None, z_dim=1, y_dim=1, w_dim=1, scale=1)

# (min, max) range shared by every scale slider of the interactive plot below.
global_scale = (0.0, 2.0)
@interact(scaleA=global_scale, scaleB=global_scale, scaleC=global_scale,
          scaleD=global_scale, scaleUr=global_scale, scaleUl=global_scale)
def plot_dependency(scaleA, scaleB, scaleC, scaleD, scaleUr, scaleUl):
    """Rescale the sampled parameters and show heatmaps of Sigma_W and its inverse.

    Each slider value multiplies the matching matrix from the notebook-global
    ``params`` tuple. The log-variances are passed through unchanged and ``W``
    is dropped. The left panel shows ``S`` returned by ``create_sigma``; the
    right panel shows ``inv(S)``.
    """
    Ur, Ul, A, B, C, D, log_yl_var, log_yr_var, log_w_var = params[:-1]  # drop W
    new_params = (scaleUr * Ur, scaleUl * Ul, scaleA * A, scaleB * B,
                  scaleC * C, scaleD * D, log_yl_var, log_yr_var, log_w_var)

    S, _, _ = create_sigma(*new_params)

    mats = [S, spla.inv(S)]
    # Raw strings: "\S" / "\L" are invalid Python escapes (SyntaxWarning on 3.12+).
    titles = [r"$\Sigma_W$", r"$\Lambda_W = \Sigma^{-1}_W$"]
    cmaps = ["gist_ncar", "gist_ncar"]
    # One label per stacked w-block. These only line up when each block is 1x1
    # (w_dim == 1, as sampled in the params cell).
    tick_labels = ["$w_{-2}$", "$w_{-1}$", "$w_{1}$", "$w_{2}$"]

    # The style must be active when the axes are *created*. Entering the
    # context only after plt.subplots had no effect on the existing axes.
    with sns.axes_style("dark"):
        fig, axs = plt.subplots(1, 2, figsize=(20, 30))
        fig.subplots_adjust(wspace=0.50)
        for ax, mat, title, cmap in zip(axs.ravel(), mats, titles, cmaps):
            ax.set_title(title, fontsize='xx-large')
            ax.xaxis.tick_top()
            hm = sns.heatmap(mat, ax=ax, square=True, annot=True,
                             annot_kws={'fontsize': 'large'}, cbar=False, cmap=cmap)
            if mat.shape[0] == len(tick_labels):
                hm.set_xticklabels(tick_labels, fontsize='x-large')
                hm.set_yticklabels(tick_labels, fontsize='x-large')

    # Raw string to avoid invalid-escape warnings. The empty line after the
    # first sentence replaces the former "\n" escape, so the text is unchanged.
    desc_str = r"""
    Effect of Scaling $U_r$, $U_l$, A,B,C,D on $\Sigma_W$ and $\Lambda_W = \Sigma^{-1}_W$:

    diag($[[B\Sigma_\ell B^T + A(U_r\Sigma_z U_r^T + \Sigma_r)A^T + \Sigma_w]~,\qquad ~~~ [[BU_\ell \Sigma_z (BU_\ell)^T~, BU_\ell \Sigma_z (DU_\ell)^T~, BU_\ell \Sigma_z (AU_r)^T~, BU_\ell \Sigma_z (CU_r)^T]~,$
    $~~~~\qquad[D \Sigma_\ell D^T + C(U_r\Sigma_zU_r^T+ \Sigma_r)C^T +\Sigma_w]~, \quad + \quad [DU_\ell \Sigma_z (BU_\ell)^T~,  DU_\ell \Sigma_z (DU_\ell)^T~,  DU_\ell \Sigma_z (AU_r)^T~,  DU_\ell \Sigma_z (CU_r)^T]~,$
    $~~~~~~\quad[A\Sigma_r A^T + B(U_\ell\Sigma_zU_\ell^T + \Sigma_l)B^T + \Sigma_w],\qquad ~~~~~~[AU_r \Sigma_z (BU_\ell)^T~,    AU_r \Sigma_z (DU_\ell)^T~,    AU_r \Sigma_z (AU_r)^T~,    AU_r \Sigma_z (CU_r)^T]~,$
    $~~~\qquad[C\Sigma_r C^T + D(U_\ell\Sigma_zU_\ell^T + \Sigma_l)D^T + \Sigma_w]])\qquad ~~~~ [CU_r \Sigma_z (BU_\ell)^T~, CU_r \Sigma_z (DU_\ell)^T~, CU_r \Sigma_z (AU_r)^T~, CU_r \Sigma_z (CU_r)^T]]$"""
    fig.suptitle(desc_str, y=.73, fontsize='xx-large')

interactive(children=(FloatSlider(value=1.0, description='scaleA', max=2.0), FloatSlider(value=1.0, descriptio…