In [1]:
import tensorflow as tf
import matrix_decompositions_tf as fctr
import tf_rewrites as tfr

In [2]:
R = 1                                 # rank of the synthetic update and of its low-rank approximation
Khat = 64                             # FFT length used by the fft/ifft layers (and last axis of D)
K = 5                                 # length of the last axis of the generated update Dupdate
rho = tf.cast(100., dtype=tf.complex128)  # diagonal shift: factors are of rho*I + Dhat^H Dhat

In [3]:
class FFT(tf.keras.layers.Layer):
    """Keras layer applying ``tf.signal.rfft`` along the last axis.

    The transform length is fixed at construction; ``tf.signal.rfft`` zero-pads
    or crops the last axis to ``fft_length``. Subclasses override :meth:`fft`
    to swap in a different transform while reusing the constructor, ``call``
    and ``get_config``.

    Args:
        fft_length: Length of the transform along the last axis.
        *args, **kwargs: Forwarded to ``tf.keras.layers.Layer`` (e.g. ``dtype``, ``name``).
    """
    def __init__(self,fft_length,*args,**kwargs):
        # Initialise the base Layer first so its attribute tracking exists
        # before any attribute is set on the instance.
        super().__init__(*args,**kwargs)
        self.fft_length = fft_length
    def fft(self,inputs):
        """Real FFT of ``inputs`` along the last axis with length ``fft_length``."""
        return tf.signal.rfft(input_tensor=inputs,fft_length=(self.fft_length,))
    def call(self,inputs):
        return self.fft(inputs)
    def get_config(self):
        """Layer config including the base fields (name, dtype, trainable, ...).

        Merging ``super().get_config()`` keeps ``from_config`` from silently
        rebuilding the layer with the default dtype.
        """
        config = super().get_config()
        config.update({'fft_length': self.fft_length})
        return config
class IFFT(FFT):
    """Counterpart of :class:`FFT` that applies ``tf.signal.irfft`` instead.

    Only the transform is replaced; construction, ``call`` and ``get_config``
    are inherited from :class:`FFT`.
    """
    def fft(self,inputs):
        """Inverse real FFT of ``inputs`` along the last axis with length ``fft_length``."""
        length = (self.fft_length,)
        return tf.signal.irfft(inputs, fft_length=length)

class Trunc(tf.keras.layers.Layer):
    """Keras layer keeping only the first ``truncate_length`` entries along ``axis``.

    Args:
        truncate_length: Number of leading entries to keep along ``axis``.
        axis: Axis to truncate. Non-negative values count from the front;
            negative values count from the back (``-1`` is the last axis).
        *args, **kwargs: Forwarded to ``tf.keras.layers.Layer``.
    """
    def __init__(self,truncate_length,axis,*args,**kwargs):
        # Initialise the base Layer before assigning instance attributes.
        super().__init__(*args,**kwargs)
        self.truncate_length = truncate_length
        self.axis = axis
    def call(self,inputs):
        keep = slice(0,self.truncate_length)
        if self.axis >= 0:
            # Full slices on every axis before `axis`; trailing axes are kept whole implicitly.
            slices = (slice(None),)*self.axis + (keep,)
        else:
            # Anchor from the back with an Ellipsis so negative axes hit the right dimension
            # (a negative repeat count would otherwise collapse to truncating axis 0).
            slices = (Ellipsis, keep) + (slice(None),)*(-self.axis-1)
        return inputs[slices]
    def get_config(self):
        """Layer config including the base fields (name, dtype, trainable, ...)."""
        config = super().get_config()
        config.update({'truncate_length': self.truncate_length, 'axis': self.axis})
        return config
def get_lowrank_approx(A,*args,**kwargs):
    """Two-factor low-rank approximation of the matrix ``A``.

    Calls ``fctr.randomized_svd(A, *args, **kwargs)`` and folds the singular
    values into one factor: into the left factor when ``A`` has more columns
    than rows, otherwise into the right factor. Presumably
    ``A ≈ U @ tf.transpose(V)`` afterwards — TODO confirm against
    ``fctr.randomized_svd``'s return convention.

    Returns:
        Tuple ``(U, V)`` with the singular values absorbed as described above.
    """
    left, sing_vals, right = fctr.randomized_svd(A, *args, **kwargs)
    row_scale = tf.reshape(sing_vals, (1, -1))
    if A.shape[1] <= A.shape[0]:
        right = right * tf.cast(row_scale, right.dtype)
    else:
        left = left * tf.cast(row_scale, left.dtype)
    return (left, right)
def stack_svd(A,*args,**kwargs):
    """Low-rank factorisation of a tensor unfolded along its first axis.

    ``A`` of shape ``(n0, *rest)`` is flattened to ``(n0, prod(rest))`` and
    factored by :func:`get_lowrank_approx` (extra arguments are forwarded).
    The factors are then folded back.

    Returns:
        ``(U, Vt)``: ``U`` of shape ``(n0, rank, 1)`` and ``Vt`` of shape
        ``(rank, *rest)``, where ``rank`` is the number of columns returned by
        the factorisation.
    """
    full_shape = A.shape
    lead = full_shape[0]
    left, right = get_lowrank_approx(tf.reshape(A, (lead, -1)), *args, **kwargs)
    rank = left.shape[1]
    left_stacked = tf.reshape(left, (lead, rank, 1))
    # (prod(rest), rank) -> (rank, prod(rest)) -> (rank, *rest)
    right_folded = tf.reshape(tf.transpose(right, perm=(1, 0)), (rank,) + full_shape[1:])
    return left_stacked, right_folded

In [4]:
# Forward real FFT on float64 inputs and its inverse on complex128 inputs, both of length Khat.
fft = FFT(fft_length=Khat, dtype=tf.float64)
ifft = IFFT(fft_length=Khat, dtype=tf.complex128)
# Keeps the first K entries along axis 2 (not used in the cells shown here).
trunc = Trunc(truncate_length=K, axis=2, dtype=tf.float64)


In [5]:
# Problem size: each frequency slice of D is a C x M matrix (square here).
C = 2000   # first axis of D
M = C      # second axis of D

In [6]:
def update_Cholesky(L,Dupdate,Dhat,R):
    """Approximately update per-frequency Cholesky factors after a dictionary update.

    ``L`` holds factors of ``rho*I + Dhat^H Dhat`` (one per frequency, as built by
    ``compute_Cholesky``). ``Dupdate`` is approximated as ``U Vt`` with ``R``
    components; in the frequency domain the update is ``Uhat VhatH`` with
    ``VhatH = fft(Vt)``. The resulting change of the Gram matrix is applied as a
    sequence of rank-1 Cholesky updates:

    * ``VhatH^H (U^T U) VhatH``: one rank-1 term ``(u_r^T u_r) v_r v_r^H`` per
      component. This assumes the columns of ``U`` are mutually orthogonal, so
      ``U^T U`` is diagonal.
    * ``d_r v_r^H + v_r d_r^H`` with ``d_r = Dhat^H u_r``: a Hermitian rank-2 term
      per component, applied as two rank-1 terms using its closed-form eigenpairs.
      Positive terms are applied before the negative (downdate) ones.

    NOTE(review): relies on the module-level ``fft`` layer (its ``fft_length`` must
    match the one used to build ``Dhat``). Also assumes ``tfr.cholesky_update(L, x, a)``
    returns the factor of ``L L^H + a x x^H`` for an unnormalised ``x`` — TODO confirm.

    Args:
        L: Per-frequency factors, layout ``(F, M, M)``.
        Dupdate: Spatial-domain update, layout ``(C, M, K)``.
        Dhat: Frequency-domain dictionary *before* the update, layout ``(C, M, F)``.
        R: Number of components kept in the low-rank approximation of ``Dupdate``.

    Returns:
        Updated factors with the same layout as ``L``.
    """
    # low rank approximation in spatial domain
    U,Vt = stack_svd(Dupdate,n_components=R,n_oversamples=10)

    # Convert to frequency domain
    Uhat = tf.cast(U,tf.complex128)
    VhatH = fft(Vt)

    # Rank-1 terms: weights u_r^T u_r and vectors v_r = conj(fft(Vt)_r), layout (M, R, F)
    UtU = tf.cast(tf.reduce_sum(U*U,axis=0,keepdims = False),tf.complex128)
    Vhat = tf.transpose(VhatH,perm=(1,0,2),conjugate=True)

    # Rank-2 terms: d_r = Dhat^H u_r for every frequency, layout (M, R, F)
    Dhu = tf.transpose(tf.linalg.matmul(tf.transpose(Dhat,perm=(2,0,1)),tf.transpose(Uhat,perm=(2,0,1)),adjoint_a = True),perm=(1,2,0))
    eta_u = tf.math.real(tf.reduce_sum(tf.math.conj(Dhu)*Dhu,axis = 0,keepdims = True))    # |d|^2
    eta_v = tf.math.real(tf.reduce_sum(tf.math.conj(Vhat)*Vhat,axis = 0,keepdims = True))  # |v|^2
    eta_uv = tf.reduce_sum(tf.math.conj(Dhu)*Vhat,axis = 0,keepdims = True)                # d^H v

    # Eigenvalues of d v^H + v d^H are Re(d^H v) +/- sqrt(|d|^2 |v|^2 - Im(d^H v)^2).
    # By Cauchy-Schwarz the radicand is >= Re(d^H v)^2 >= 0; clamp round-off so sqrt never yields NaN.
    radicand = tf.math.sqrt(tf.maximum(eta_u*eta_v - tf.math.imag(eta_uv)**2, 0.))
    im_uv = tf.cast(tf.math.imag(eta_uv),tf.complex128)
    root = tf.cast(radicand,tf.complex128)
    eta_u_c = tf.cast(eta_u,tf.complex128)
    # Matching eigenvectors w = |d|^2 v + (lambda - d^H v) d (not unit length).
    eig_vecs_plus = eta_u_c*Vhat + (-1j*im_uv + root)*Dhu
    eig_vecs_minus = eta_u_c*Vhat - (1j*im_uv + root)*Dhu

    # Because w is not normalised, the rank-1 weight must be lambda / |w|^2 so that
    # (lambda/|w|^2) w w^H reproduces the rank-2 term. divide_no_nan turns the
    # degenerate w = 0 case (which only occurs with lambda = 0) into a no-op update.
    norm2_plus = tf.reduce_sum(tf.math.real(tf.math.conj(eig_vecs_plus)*eig_vecs_plus),axis = 0,keepdims = True)
    norm2_minus = tf.reduce_sum(tf.math.real(tf.math.conj(eig_vecs_minus)*eig_vecs_minus),axis = 0,keepdims = True)
    eig_vals_plus = tf.cast(tf.squeeze(tf.math.divide_no_nan(tf.math.real(eta_uv) + radicand,norm2_plus),axis = 0),tf.complex128)
    eig_vals_minus = tf.cast(tf.squeeze(tf.math.divide_no_nan(tf.math.real(eta_uv) - radicand,norm2_minus),axis = 0),tf.complex128)

    # Move frequency first to match L's (F, M, M) layout: vectors become (F, M, R).
    Vhat = tf.transpose(Vhat,perm=(2,0,1))
    eig_vecs_plus = tf.transpose(eig_vecs_plus,perm = (2,0,1))
    eig_vecs_minus = tf.transpose(eig_vecs_minus,perm = (2,0,1))
    for val,vec in zip(tf.unstack(UtU,axis=0),tf.unstack(Vhat,axis = -1)):
        L = tfr.cholesky_update(L,vec,val)

    # Positive-eigenvalue updates first, then the (negative) downdates.
    for vals,vecs in zip((eig_vals_plus,eig_vals_minus),(eig_vecs_plus,eig_vecs_minus)):
        for val,vec in zip(tf.unstack(vals,axis=0),tf.unstack(vecs,axis=-1)):
            L = tfr.cholesky_update(L,vec,val)
    return L


In [7]:
def generate_D_update(C,M,K,R,a = 1.,bminusa = 4.,noisefloor = 0.1,rescale=32):
    """Synthesise a random, approximately low-rank dictionary update.

    For each of the ``K`` slices, an independent rank-``R`` product
    ``U_k diag(S) V_k^T`` is drawn. The scales ``S = a + bminusa * Uniform[0, 1)``
    are shared by all slices. Gaussian noise scaled by ``noisefloor`` is added,
    and the result is divided by ``rescale``. No seed is set, so repeated calls
    differ.

    Returns:
        float64 tensor of shape ``(C, M, K)``.
    """
    left = tf.random.normal(shape=(K,C,R),dtype=tf.float64)
    right = tf.random.normal(shape=(K,M,R),dtype=tf.float64)
    scales = bminusa*tf.random.uniform(shape=(1,1,R),dtype=tf.float64) + a
    low_rank = tf.linalg.matmul(left*scales, right, transpose_b=True)   # (K, C, M)
    noise = noisefloor*tf.random.normal(shape=(K,C,M),dtype=tf.float64)
    stacked = low_rank + noise
    # (K, C, M) -> (C, M, K)
    return tf.transpose(stacked, perm=(1,2,0))/rescale

In [8]:
def compute_Cholesky(rho,Dhat):
    """Per-frequency lower-triangular Cholesky factors of ``rho*I + Dhat^H Dhat``.

    Args:
        rho: Scalar added to the diagonal (complex128 in this notebook).
        Dhat: Rank-3 tensor; the last axis indexes a batch of matrices whose
            leading two axes are rows x columns (``(C, M, F)`` in this notebook).

    Returns:
        Tensor of shape ``(F, M, M)``.
    """
    n_cols = Dhat.shape[1]
    # Put the batch axis first: (C, M, F) -> (F, C, M).
    batched = tf.transpose(Dhat, perm=(2, 0, 1))
    # batch_shape=(1,) lets the identity broadcast across every frequency.
    eye = tf.linalg.eye(num_rows=n_cols, batch_shape=(1,), dtype=tf.complex128)
    gram = tf.linalg.matmul(batched, batched, adjoint_a=True)
    return tf.linalg.cholesky(rho*eye + gram)

In [9]:
# Random base dictionary (spatial domain) and a synthetic, approximately low-rank update.
D = tf.random.normal(shape=(C, M, Khat), dtype=tf.float64)
Dupdate = generate_D_update(C, M, K, R)
# Frequency-domain versions (real FFT along the last axis).
Dhat = fft(D)
Dupdatehat = fft(Dupdate)
# Baseline factors of rho*I + Dhat^H Dhat, used as the starting point for the update.
L = compute_Cholesky(rho, Dhat)


In [10]:
# Reference: recompute the factors of rho*I + (Dhat + Dupdatehat)^H (Dhat + Dupdatehat) from scratch.
%time Lnew = compute_Cholesky(rho,Dhat + Dupdatehat)


CPU times: user 6min 14s, sys: 11.6 s, total: 6min 25s
Wall time: 54.5 s


In [11]:
# Alternative: update the existing factors L with rank-1 corrections built from a rank-R approximation of Dupdate.
# TODO(review): no cell checks accuracy; compare e.g. tf.norm(Lnewapprox - Lnew) / tf.norm(Lnew).
%time Lnewapprox = update_Cholesky(L,Dupdate,Dhat,R)

CPU times: user 2min, sys: 27.2 s, total: 2min 27s
Wall time: 49.9 s
