In [1]:
from IPython.display import display, HTML
display(HTML("<style>.container { width:90% !important; }</style>"))

import numpy as np
import matplotlib.pyplot as plt
import pandas as pd

np.set_printoptions(precision=3, linewidth=500, threshold=500, suppress=True, edgeitems=5)

pd.set_option('display.max_rows', 200)
pd.set_option('display.min_rows', 50)
pd.set_option('display.width', 1000)
pd.set_option('display.max_columns', 100)

%load_ext autoreload
%autoreload 2
%matplotlib notebook

In [2]:
from numpy.random import randn
from numpy import diag, ndarray



In [3]:
import sys
sys.path.append('..')

In [5]:
from utils.kronecker import KroneckerOperator, KroneckerProduct, KroneckerSum, KroneckerDiag, _ProductChain, _SumChain
from utils.linalg import kronecker_product_literal, kronecker_sum_literal, vec, ten


In [87]:
N1 = 6
N2 = 5
N3 = 4
N4 = 3

A1 = randn(N1, N1)
A2 = randn(N2, N2)
A3 = randn(N3, N3)
A4 = randn(N4, N4)


A11 = randn(N1, N1)
A22 = randn(N2, N2)
A33 = randn(N3, N3)
A44 = randn(N4, N4)

X = randn(N4, N3, N2, N1)
D = randn(N4, N3, N2, N1)

In [172]:
import numpy as np
from numpy import ndarray, eye as I, diag
from typing import Union, Callable
import networkx as nx
from scipy.sparse import spmatrix

from algorithms.cgm import solve_SPCGM
from graph.filters import _FilterFunction, MultivariateFilterFunction, UnivariateFilterFunction
from graph.graphs import BaseGraph, ProductGraph
from models.reconstruction.reconstruction_utils import get_y_and_s
from utils.checks import check_compatible
from utils.linalg import vec, ten
from utils.kronecker import KroneckerBlock, KroneckerBlockDiag, KroneckerDiag

from numpy.linalg import eigh, solve
from scipy.optimize import minimize


In [173]:
class VarSolver:

    def __init__(self, Omega_Q: ndarray, Q: ndarray, X: ndarray, lam: float = 0.005):
        assert all([Omega_Q.shape == Q.shape]), f'Omega_Q and Q should all have the same shape, but they are {Omega_Q.shape} and {Q.shape} respectively'

        self.Omega_Q = Omega_Q
        self.X = X
        self.Q = Q.astype(bool)
        self.lam = lam
        self.params = None

    def rmse(self, Omega: ndarray):
        return (((Omega - self.predict()) ** 2).sum() / (np.prod(self.Q.shape))) ** 0.5

    def r_squared(self, Omega: ndarray):
        return 1 - ((Omega - self.predict()) ** 2).sum() / ((Omega - Omega.mean()) ** 2).sum()

    def _get_params(self):
        pass

    def predict(self):
        pass


class LFPVarSolver(VarSolver):

    def __init__(self,
                 Omega_Q: ndarray,
                 Q: ndarray,
                 X: ndarray,
                 graph: BaseGraph,
                 filter_function: _FilterFunction,
                 lam: float):

        super().__init__(Omega_Q, Q, X, lam)

        check_compatible(signal=Omega_Q, graph=graph, filter_function=filter_function)

        if isinstance(filter_function, UnivariateFilterFunction):
            self.x0 = np.array([0, 0, filter_function.beta, 0, filter_function.beta])
            
        elif isinstance(filter_function, MultivariateFilterFunction):
            self.x0 = np.array([0, 0] + filter_function.beta.tolist() + [0] + filter_function.beta.tolist())

        self.graph = graph
        self.filter_function = filter_function

        self.S_ = ten(self.X[:, 0], like=Q)
        self.A_ = ten(self.X[:, -2], like=Q)

        print('HEY')
        print(self.rmse(self.x0))


    def H(self, Y: ndarray, beta: float | ndarray):
        """
        Apply the operation Y -> ten(H @ vec(Y)) efficiently for a filter defined by λ -> η(λ; β)
        """

        self.filter_function.set_beta(beta)

        if isinstance(self.filter_function, MultivariateFilterFunction):
            G = self.filter_function(self.graph.lams)
        else:
            G = self.filter_function(self.graph.lam)


        return self.graph.scale_spectral(Y, G)

    def Omega(self, v: ndarray) -> ndarray:
        """
        Return the estitenor for Omega for a given objective vector
        """

        if isinstance(self.filter_function, MultivariateFilterFunction):
            beta1 = v[2:2+self.filter_function.ndim]
            beta2 = v[3+self.filter_function.ndim:]

        else:
            beta1 = v[2]
            beta2 = v[4]

        print(v)

        return v[0] + v[1] * self.H(self.S_, beta1) + v[2+self.filter_function.ndim] * self.H(self.A_, beta2)

    def objective(self, v: ndarray):
        """
        The objective function to minimise
        """
        return ((self.Omega_Q - self.Q * self.Omega(v)) ** 2).sum() + self.lam * ((v - self.x0) ** 2).sum()

    def _get_params(self, verbose=True):

        self.result = minimize(self.objective, x0=self.x0, bounds=[(None, None), (None, None)] + [(-1, None)] * self.filter_function.ndim + [(None, None)] + [(-1, None)] * self.filter_function.ndim)

        if verbose:
            print(self.result)

        self.params = [self.result.x]

    def predict(self, verbose=True):
        if self.params is None:
            self._get_params(verbose=verbose)
        return self.Omega(self.params[0])

#






In [200]:
np.random.seed(0)

# set variables
N1 = 20
N2 = 15
N = N1 * N2

graph = ProductGraph.lattice(N1, N2)
signal = np.random.rand(N2, N1)
signal[np.random.randint(0, 2, (N2, N1)).astype(bool)] = np.nan
gamma = 0.02

Y, S = get_y_and_s(signal)
filter_function = UnivariateFilterFunction.diffusion(beta=1)
G = filter_function(graph.lam)
n_neighbours = graph.A.sum(0)

X = np.array([np.ones(N),
              vec(1 - S),
              graph.U @ KroneckerDiag(G) @ graph.U.T @ vec(1 - S),
              (graph.U ** 2) @ vec(G),
              (graph.U ** 2) @ vec(G ** 2),
              n_neighbours,
              graph.U @ KroneckerDiag(G) @ graph.U.T @ n_neighbours
              ]).T

X[:, 1:] = X[:, 1:] / X[:, 1:].std(axis=0)
X[:, 1:] = X[:, 1:] - X[:, 1:].mean(axis=0)

S_ = ten(X[:, 1], like=Q)
A_ = ten(X[:, -2], like=Q)

# calculate the solution explicitly
H2 = graph.U.to_array() @ diag(vec(G ** 2)) @ graph.U.to_array().T
explicit_sigma = H2 @ np.linalg.inv(gamma * np.eye(N) + diag(vec(S)) @ H2)

Omega_true = np.log(ten(diag(explicit_sigma), like=Y))

Q = np.random.randint(0, 2, Omega_true.shape)
Omega_Q = np.zeros_like(Omega_true)
Omega_Q[Q.astype(bool)] = Omega_true[Q.astype(bool)]

P = X.shape[1]

In [202]:
def H(beta: float | ndarray):
    """
    Get a filter operator for a given value of beta
    """

    filter_function.set_beta(beta)

    if isinstance(filter_function, MultivariateFilterFunction):
        G = filter_function(graph.lams)
        
    else:
        G = filter_function(graph.lam)


    return graph.U @ KroneckerDiag(G) @ graph.U.T


def Omega(v: ndarray) -> ndarray:
    """
    Return the estitenor for Omega for a given objective vector
    """

    if isinstance(filter_function, MultivariateFilterFunction):
        beta1 = v[2:2+filter_function.ndim]
        beta2 = v[3+filter_function.ndim:]

    else:
        beta1 = v[2]
        beta2 = v[4]
        
    print(H(beta1) @ S_)
    print(H(beta2) @ A_)

    return v[0] + v[1] * H(beta1) @ S_ + v[2+filter_function.ndim] * H(beta2) @ A_

In [196]:
Omega([1, 1, 2.8, 0.5, 1])

[[1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]
 [1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]
 [1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]
 [1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]
 [1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]
 [1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]
 [1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]
 [1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]
 [1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]
 [1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]
 [1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]
 [1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]
 [1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]
 [1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]
 [1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]]
[[-1.796 -1.321 -0.91  -0.719 -0.659 -0.644 -0.641 -0.

array([[1.102, 1.339, 1.545, 1.64 , 1.671, 1.678, 1.679, 1.68 , 1.68 , 1.68 , 1.68 , 1.68 , 1.68 , 1.679, 1.678, 1.671, 1.64 , 1.545, 1.339, 1.102],
       [1.339, 1.577, 1.783, 1.878, 1.908, 1.915, 1.917, 1.917, 1.917, 1.917, 1.917, 1.917, 1.917, 1.917, 1.915, 1.908, 1.878, 1.783, 1.577, 1.339],
       [1.545, 1.783, 1.988, 2.083, 2.114, 2.121, 2.122, 2.123, 2.123, 2.123, 2.123, 2.123, 2.123, 2.122, 2.121, 2.114, 2.083, 1.988, 1.783, 1.545],
       [1.64 , 1.878, 2.083, 2.179, 2.209, 2.216, 2.218, 2.218, 2.218, 2.218, 2.218, 2.218, 2.218, 2.218, 2.216, 2.209, 2.179, 2.083, 1.878, 1.64 ],
       [1.671, 1.908, 2.114, 2.209, 2.239, 2.247, 2.248, 2.248, 2.248, 2.248, 2.248, 2.248, 2.248, 2.248, 2.247, 2.239, 2.209, 2.114, 1.908, 1.671],
       [1.678, 1.915, 2.121, 2.216, 2.247, 2.254, 2.255, 2.256, 2.256, 2.256, 2.256, 2.256, 2.256, 2.255, 2.254, 2.247, 2.216, 2.121, 1.915, 1.678],
       [1.679, 1.917, 2.122, 2.218, 2.248, 2.255, 2.257, 2.257, 2.257, 2.257, 2.257, 2.257, 2.257, 2.257, 

In [None]:
def demonstrate_lfp_regression():

    fig, axes = plt.subplots(ncols=3, nrows=4, figsize=(6, 8))

    axes[0, 0].imshow(Omega_true)
    axes[0, 0].set_title('Ω true')

    for ax, lam in zip(axes.flatten()[1:], np.logspace(-3, 1, 11)):

        estimator_rnc =  LFPVarSolver(Omega_Q, Q, X, graph, filter_function, lam=lam)
        ax.imshow(estimator_rnc.predict(), vmin=Omega_true.min(), vmax=Omega_true.max())
        ax.set_title(f'λ = {lam:.3f}')

    for ax in axes.flatten():
        ax.set_xticks([])
        ax.set_yticks([])

    fig.suptitle('LFP Estimates')

    plt.tight_layout()
    plt.show()
