In [1]:
import tensorflow as tf
import matrix_decompositions_tf as fctr

In [None]:
# Experiment configuration.
R = 4        # rank for low-rank approximations — presumably fed to update_cholesky's R; confirm
Khat = 256   # FFT length used by the FFT / IFFT layers
K = 25       # number of leading entries kept by the Trunc layer
# float64 scalar 1.0; presumably the regularization weight in rho*I + D^H D — confirm
rho = tf.constant(1.0, dtype=tf.float64)

In [None]:
class FFT(tf.keras.layers.Layer):
    """Keras layer applying a real-input FFT (``tf.signal.rfft``) along the last axis.

    Args:
        fft_length: FFT length handed to ``tf.signal.rfft``; the output's last
            axis has ``fft_length // 2 + 1`` frequency bins.
        *args, **kwargs: forwarded to ``tf.keras.layers.Layer`` (``dtype``, ``name``, ...).
    """

    def __init__(self, fft_length, *args, **kwargs):
        # Initialize the base Layer first so Keras attribute tracking is set up
        # before any attribute is assigned.
        super().__init__(*args, **kwargs)
        self.fft_length = fft_length

    def fft(self, inputs):
        """Apply the transform; subclasses override this to change direction."""
        return tf.signal.rfft(input_tensor=inputs, fft_length=(self.fft_length,))

    def call(self, inputs):
        return self.fft(inputs)

    def get_config(self):
        # Merge with the base-layer config (name, dtype, trainable, ...) so that
        # ``from_config`` can rebuild an equivalent layer.
        config = super().get_config()
        config.update({'fft_length': self.fft_length})
        return config


class IFFT(FFT):
    """Keras layer applying the inverse real FFT (``tf.signal.irfft``) along the last axis.

    Reuses ``FFT``'s constructor and config; ``fft_length`` is passed to ``irfft``.
    """

    def fft(self, inputs):
        return tf.signal.irfft(input_tensor=inputs, fft_length=(self.fft_length,))

class Trunc(tf.keras.layers.Layer):
    """Keras layer keeping only the first ``truncate_length`` entries along ``axis``.

    Args:
        truncate_length: number of leading entries to keep.
        axis: axis to truncate; negative values count from the last axis.
        *args, **kwargs: forwarded to ``tf.keras.layers.Layer``.
    """

    def __init__(self, truncate_length, axis, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.truncate_length = truncate_length
        self.axis = axis

    def _slices(self):
        """Build the index tuple selecting ``[0, truncate_length)`` along ``self.axis``."""
        keep = slice(0, self.truncate_length)
        if self.axis >= 0:
            return (slice(None),) * self.axis + (keep,)
        # Negative axis: anchor the slice relative to the end with an Ellipsis
        # (a plain ``(slice(None),) * axis`` would be empty and hit axis 0).
        return (Ellipsis, keep) + (slice(None),) * (-self.axis - 1)

    def call(self, inputs):
        return inputs[self._slices()]

    def get_config(self):
        # Merge with the base-layer config so ``from_config`` round-trips.
        config = super().get_config()
        config.update({'truncate_length': self.truncate_length, 'axis': self.axis})
        return config
def get_lowrank_approx(A, *args, **kwargs):
    """Two-factor low-rank approximation of the matrix ``A`` via ``fctr.randomized_svd``.

    The singular values are folded into one factor: if ``A`` has more columns
    than rows they scale the columns of the left factor, otherwise the columns
    of the right factor.

    Assumes ``fctr.randomized_svd`` returns ``(U, s, V)`` with ``A ~= U diag(s) V^T``
    and ``V`` shaped ``(A.shape[1], k)`` — TODO confirm against matrix_decompositions_tf.

    Returns:
        Tuple ``(U, V)`` with the singular values absorbed as described above.
    """
    left, sing_vals, right = fctr.randomized_svd(A, *args, **kwargs)
    s_row = tf.reshape(sing_vals, (1, -1))  # broadcast s across the factor's rows
    if A.shape[1] <= A.shape[0]:
        right = right * tf.cast(s_row, right.dtype)
    else:
        left = left * tf.cast(s_row, left.dtype)
    return left, right
def stack_svd(A, *args, **kwargs):
    """Low-rank factorization of an N-d tensor along its first axis.

    ``A`` is flattened to the matrix ``(A.shape[0], -1)``, factored with
    ``get_lowrank_approx``, and the factors are reshaped back:

    * ``U``  -> ``(A.shape[0], R, 1)`` (trailing singleton presumably so it
      broadcasts against a trailing axis — confirm),
    * ``Vt`` -> ``(R,) + A.shape[1:]``,

    where ``R`` is the rank returned by the factorization.
    """
    orig_shape = A.shape
    as_matrix = tf.reshape(A, (A.shape[0], -1))
    left, right = get_lowrank_approx(as_matrix, *args, **kwargs)

    rank = left.shape[1]
    U = tf.reshape(left, (orig_shape[0], rank, 1))
    # right is (prod(trailing dims), rank); transpose it to (rank, prod) before unflattening.
    Vt = tf.reshape(tf.transpose(right), (rank,) + orig_shape[1:])
    return U, Vt

In [None]:
# Layer instances used by the cells below.
fft = FFT(dtype=tf.float64, fft_length=Khat)                # real input -> Khat // 2 + 1 complex bins
ifft = IFFT(dtype=tf.complex128, fft_length=Khat)           # inverse real FFT with fft_length=Khat
trunc = Trunc(dtype=tf.float64, axis=2, truncate_length=K)  # keep the first K entries of axis 2


In [None]:
# First two dimensions of the dictionary tensor D, shape (C, M, Khat); kept equal here.
C = M = 64

In [None]:
# Random initial dictionary of shape (C, M, Khat). Draw it directly in float64 so it
# matches the float64 dtype the `fft` layer was constructed with.
D = tf.random.normal(shape=(C, M, Khat), dtype=tf.float64)
# TODO: placeholder — needs some sort of skewed distribution here. stack_svd reads
# A.shape[0], so this scalar cannot be passed to update_cholesky as-is.
Dupdate = 0.
Dhat = fft(D)  # frequency-domain dictionary, shape (C, M, Khat // 2 + 1)
def update_cholesky(L, Dupdate, Dhat, R):
    """Update a per-frequency Cholesky factor after a low-rank dictionary update.

    ``Dupdate`` is approximated by a rank-``R`` sum of outer products via
    ``stack_svd``. Its left factors ``u_r`` have no frequency axis, so in the
    frequency domain the update is ``dDhat(f) = sum_r u_r v_r(f)^H``. The Gram
    matrix ``Dhat(f)^H Dhat(f)`` then changes by

        sum_r |u_r|^2 v_r v_r^H              (rank-1 terms)
      + sum_r (a_r v_r^H + v_r a_r^H)         (rank-2 terms),  a_r = Dhat(f)^H u_r,

    provided the columns of ``U`` are mutually orthogonal (presumably true since
    they come from ``fctr.randomized_svd`` — confirm). Each rank-2 Hermitian term
    is split into two signed rank-1 updates via its closed-form eigenpairs.

    Args:
        L: current Cholesky factor. Presumably ``chol(rho*I + Dhat^H Dhat)`` with
            frequency as the leading batch axis, shape (F, M, M) — confirm.
        Dupdate: spatial-domain dictionary update; assumed shape (C, M, Khat) — TODO confirm.
        Dhat: frequency-domain dictionary, shape (C, M, F); assumed complex128 to
            match ``Uhat`` below.
        R: number of rank-1 components kept in the approximation of ``Dupdate``.

    Returns:
        The updated Cholesky factor.
    """
    # Low-rank approximation in the spatial domain:
    #   U  : (C, R, 1)        left factors, no frequency dependence
    #   Vt : (R, M, Khat)     right factors (trailing shape copied from Dupdate)
    U, Vt = stack_svd(Dupdate, n_components=R, n_oversamples=10)

    # Convert to frequency domain; U has no frequency axis, so a dtype cast suffices.
    Uhat = tf.cast(U, tf.complex128)
    VhatH = fft(Vt)  # (R, M, F) via the global `fft` layer

    # Rank-1 updates: multipliers |u_r|^2 (shape (R, 1)) and vectors v_r = conj(VhatH[r]).
    UtU = tf.reduce_sum(U * U, axis=0, keepdims=False)
    Vhat = tf.transpose(VhatH, perm=(1, 0, 2), conjugate=True)  # (M, R, F)

    # Rank-2 updates: a_r(f) = Dhat(f)^H u_r as a batched matmul over frequencies,
    # (F, C, M)^H @ (1, C, R) -> (F, M, R), then back to (M, R, F).
    Dhu = tf.transpose(
        tf.linalg.matmul(
            tf.transpose(Dhat, perm=(2, 0, 1)),
            tf.transpose(Uhat, perm=(2, 0, 1)),
            adjoint_a=True,
        ),
        perm=(1, 2, 0),
    )

    # Inner products reduced over the M axis, each of shape (1, R, F).
    # |a|^2 and |v|^2 are real by construction; keep them real so no complex and
    # real tensors are mixed (TensorFlow does not promote dtypes implicitly).
    eta_u = tf.math.real(tf.reduce_sum(tf.math.conj(Dhu) * Dhu, axis=0, keepdims=True))
    eta_v = tf.math.real(tf.reduce_sum(tf.math.conj(Vhat) * Vhat, axis=0, keepdims=True))
    eta_uv = tf.reduce_sum(tf.math.conj(Dhu) * Vhat, axis=0, keepdims=True)  # a^H v (complex)
    re_uv = tf.math.real(eta_uv)
    im_uv = tf.math.imag(eta_uv)

    # Eigenvalues of a v^H + v a^H: Re(a^H v) +/- sqrt(|a|^2 |v|^2 - Im(a^H v)^2).
    # By Cauchy-Schwarz the radicand is >= Re(a^H v)^2 >= 0; clamp round-off negatives.
    radicand = tf.math.sqrt(tf.maximum(eta_u * eta_v - im_uv ** 2, 0.0))

    # Unnormalized eigenvectors x = |a|^2 v + (lambda - a^H v) a, where
    # lambda_+ - a^H v = radicand - i*Im(a^H v), lambda_- - a^H v = -radicand - i*Im(a^H v).
    eta_u_c = tf.cast(eta_u, Dhu.dtype)
    eig_vecs_plus = eta_u_c * Vhat + tf.complex(radicand, -im_uv) * Dhu
    eig_vecs_minus = eta_u_c * Vhat + tf.complex(-radicand, -im_uv) * Dhu
    eig_vals_plus = tf.squeeze(re_uv + radicand, axis=0)   # (R, F)
    eig_vals_minus = tf.squeeze(re_uv - radicand, axis=0)  # (R, F)

    def _unit(vecs):
        # Scale each eigenvector (norm over the M axis) to unit length so that
        # a v^H + v a^H == lambda_+ x_+ x_+^H + lambda_- x_- x_-^H. A zero vector
        # only arises when its eigenvalue is zero, so divide_no_nan keeps it a no-op.
        norms = tf.math.sqrt(
            tf.math.real(tf.reduce_sum(tf.math.conj(vecs) * vecs, axis=0, keepdims=True)))
        return tf.math.divide_no_nan(vecs, tf.cast(norms, vecs.dtype))

    # Move frequency to the leading (batch) axis: (M, R, F) -> (F, M, R).
    Vhat = tf.transpose(Vhat, perm=(2, 0, 1))
    eig_vecs_plus = tf.transpose(_unit(eig_vecs_plus), perm=(2, 0, 1))
    eig_vecs_minus = tf.transpose(_unit(eig_vecs_minus), perm=(2, 0, 1))

    # NOTE(review): `tfr` is not imported anywhere in the visible notebook. It presumably
    # names a module exposing cholesky_update(L, vec, multiplier), e.g.
    # tfp.math.cholesky_update — TODO import it.
    # Rank-1 terms: multiplier |u_r|^2 with vectors v_r of shape (F, M).
    for val, vec in zip(tf.unstack(UtU, axis=0), tf.unstack(Vhat, axis=2)):
        L = tfr.cholesky_update(L, vec, val)

    # Rank-2 terms as two signed rank-1 updates per component (the minus branch may downdate).
    for vals, vecs in ((eig_vals_plus, eig_vecs_plus), (eig_vals_minus, eig_vecs_minus)):
        for val, vec in zip(tf.unstack(vals, axis=0), tf.unstack(vecs, axis=2)):
            L = tfr.cholesky_update(L, vec, val)
    return L



In [9]:
# Sanity check of the rfft output size: 256 real samples -> 256 // 2 + 1 = 129 bins.
x = tf.random.normal(shape=(25, 256))
y = tf.signal.rfft(x, fft_length=(256,))
print(y.shape)

(25, 129)


In [None]:
# Initial Cholesky factor of the regularized per-frequency Gram matrix
#   G(f) = rho * I + Dhat(f)^H Dhat(f),   result shape (F, M, M), F = Khat // 2 + 1.
# Frequency is the leading batch axis, matching the (F, M) update vectors used in
# update_cholesky. `self.nof`, `self.rho` and `Dfprev` are undefined at notebook scope,
# so they are replaced by M, rho and Dhat.
# NOTE(review): presumably Dfprev was the previous frequency-domain dictionary and nof
# the number of filters — confirm this mapping.
Dhat_fcm = tf.cast(tf.transpose(Dhat, perm=(2, 0, 1)), tf.complex128)  # (C, M, F) -> (F, C, M)
idMat = tf.linalg.eye(num_rows=M, batch_shape=(1,), dtype=tf.complex128)
L = tf.linalg.cholesky(
    tf.cast(rho, tf.complex128) * idMat
    + tf.linalg.matmul(a=Dhat_fcm, b=Dhat_fcm, adjoint_a=True)
)