In [1]:
import tensorflow as tf

class OuterProd(tf.keras.layers.Layer):
    """Batched outer product ``input1 (x) conj(input2)`` over the last axis.

    For ``input1`` of shape ``(..., n)`` and ``input2`` of shape ``(..., m)``
    the result has shape ``(..., n, m)`` with entries
    ``input1[..., i] * conj(input2[..., j])``. Leading axes broadcast.
    """

    def __init__(self, *args, **kwargs):
        # super() already binds the instance. Forwarding `self` again pushed it
        # into the first positional parameter of Layer.__init__ (`trainable`).
        super().__init__(*args, **kwargs)

    def call(self, input1, input2):
        """Return the outer product of ``input1`` with the conjugate of ``input2``."""
        # expand_dims (rather than reshape to a static shape) also works when
        # some leading dimensions are unknown at graph-construction time.
        column = tf.expand_dims(input1, axis=-1)                # (..., n, 1)
        row = tf.math.conj(tf.expand_dims(input2, axis=-2))     # (..., 1, m)
        return column * row


In [2]:
import numpy as np

class MatrixAndVector(tf.keras.layers.Layer):
    """Layer that applies ``A x + V V^H x`` along the last axis of its input.

    ``V`` is a trainable vector with a leading broadcast axis of size 1.
    ``A`` is held as a non-trainable Cholesky factor ``L`` (``A = L L^H``),
    initialised in ``setupA`` to ``chol(I + V V^H)`` using the initial ``V``.

    Args:
        v: optional initial value for ``V``. When ``None`` a random normal
            initial value of shape ``(1,) + input_shape[1:]`` is drawn.
        *args, **kwargs: forwarded to ``tf.keras.layers.Layer``.
    """

    def __init__(self,v=None,*args,**kwargs):
        super().__init__(*args,**kwargs)
        self.v = v

    def _random_v(self, shape):
        """Draw a standard-normal initial value in the layer's own dtype.

        ``tf.random.normal`` defaults to float32, which ``tf.Variable`` refuses
        to convert when the layer dtype is float64 or complex, so the sample is
        generated directly in the right dtype (real and imaginary parts drawn
        independently for complex layers).
        """
        dtype = tf.as_dtype(self.dtype)
        if dtype.is_complex:
            real_dtype = dtype.real_dtype
            return tf.complex(tf.random.normal(shape, dtype=real_dtype),
                              tf.random.normal(shape, dtype=real_dtype))
        return tf.random.normal(shape, dtype=dtype)

    def build(self,input_shape):
        """Create ``V`` (from ``v`` or randomly) and then the representation of ``A``."""
        if self.v is not None:
            self.V = tf.Variable(initial_value=self.v,trainable=True,dtype=self.dtype)
        else:
            self.V = tf.Variable(initial_value=self._random_v((1,) + input_shape[1:]),trainable=True,dtype=self.dtype)
        self.setupA(input_shape,self.V)

    def setupA(self,input_shape,V):
        """Create the non-trainable Cholesky factor ``L`` of ``I + V V^H``.

        Subclasses override this to change how ``A`` is stored.
        """
        idmat = tf.eye(num_rows=input_shape[-1],batch_shape=(1,) + input_shape[1:-1],dtype=self.dtype)
        # Use the argument that build() hands in (it passes self.V).
        vvh = OuterProd(dtype=self.dtype)(V,V)
        self.L = tf.Variable(initial_value=tf.linalg.cholesky(idmat + vvh),trainable=False,dtype=self.dtype)

    def applyA(self,x):
        """Return ``A x`` computed as ``L (L^H x)``."""
        return tf.linalg.matvec(self.L,tf.linalg.matvec(self.L,x,adjoint_a=True))

    def call(self, inputs):
        """Return ``A x + V (V^H x)`` for each vector ``x`` on the last axis."""
        Ax = self.applyA(inputs)
        # sum(x * conj(V)) is the inner product V^H x; scaling V by it gives V V^H x.
        vvhx = tf.reduce_sum(inputs*tf.math.conj(self.V),axis=-1,keepdims=True)*self.V
        return  Ax + vvhx


# Start V at the first standard basis vector, stored with shape (1, 2).
v_init = np.zeros(shape=(1, 2))
v_init[:, 0] = 1

# One-layer functional model around MatrixAndVector.
x = tf.keras.layers.Input((2,))
ApVVh = MatrixAndVector(v_init)
y = ApVVh(x)
model = tf.keras.Model(x, y)
model.compile(
    optimizer=tf.keras.optimizers.SGD(),
    loss=tf.keras.losses.MeanSquaredError(),
)

# Synthetic regression problem: the target is twice the input.
x_train = np.random.randn(1000, 2)
y_train = 2 * x_train
x_val = np.random.randn(100, 2)
y_val = 2 * x_val
model.fit(
    x=x_train,
    y=y_train,
    batch_size=10,
    epochs=16,
    validation_data=(x_val, y_val),
)
# Display the trained vector.
ApVVh.V
                             

Epoch 1/16
Epoch 2/16
Epoch 3/16
Epoch 4/16
Epoch 5/16
Epoch 6/16
Epoch 7/16
Epoch 8/16
Epoch 9/16
Epoch 10/16
Epoch 11/16
Epoch 12/16
Epoch 13/16
Epoch 14/16
Epoch 15/16
Epoch 16/16


<tf.Variable 'matrix_and_vector/Variable:0' shape=(1, 2) dtype=float32, numpy=array([[1.0283290e-09, 1.0000001e+00]], dtype=float32)>

In [4]:
# Report the dtype of the layer instance built above.
ApVVh.dtype

'float32'

In [3]:
class PostProcess:
    """Holder for a class-level registry of post-training-step hooks."""
    # Maps a variable name to a zero-argument callable. The dict lives on the
    # class, so it is shared by every subclass and instance; entries are looked
    # up by `tv.name` and called in CustomTrainModel.train_step.
    update = {}
class CustomTrainModel(tf.keras.Model):
    """Keras model that runs registered hooks after every training step."""

    def train_step(self, data):
        """Run the stock Keras training step, then the per-variable hooks.

        After the base ``train_step`` has run, every trainable variable whose
        name is a key of ``PostProcess.update`` has its registered callable
        invoked with no arguments. The base step's return value is passed
        through unchanged.
        """
        step_result = tf.keras.Model.train_step(self, data)
        for variable in self.trainable_variables:
            if variable.name not in PostProcess.update:
                continue
            PostProcess.update[variable.name]()
        return step_result

In [75]:
def eigvecReversal(U,V):
    """Eigen-decompose the symmetric rank-2 matrix ``U V^T + V U^T`` (real inputs).

    Vectors live on the last axis; leading axes are treated as batch axes.
    The two (possibly) non-zero eigenvalues are ``U.V +/- |U||V|`` with
    eigenvectors proportional to ``|V|^2 U +/- |U||V| V``.

    Args:
        U, V: real tensors of shape ``(..., n)``.

    Returns:
        ``((valPlus, valMinus), (vecPlus, vecMinus))``. The eigenvalues have
        the input shape without the last axis; the eigenvectors have shape
        ``(..., n)`` and are normalised along the last axis.
    """
    vtv = tf.reduce_sum(V*V,axis=-1,keepdims=True)
    utu = tf.reduce_sum(U*U,axis=-1,keepdims=True)
    utv = tf.reduce_sum(U*V,axis=-1,keepdims=True)
    rootRadicand = tf.math.sqrt(vtv*utu)

    # Eigenvalues are reported without the kept reduction axis.
    utvs = tf.squeeze(utv,axis=-1)
    rootRadicands = tf.squeeze(rootRadicand,axis=-1)
    valPlus = utvs + rootRadicands
    valMinus = utvs - rootRadicands

    vecPlus = vtv*U + rootRadicand*V
    vecMinus = vtv*U - rootRadicand*V
    # When an unnormalised vector collapses to zero (U or V zero, or U a
    # positive/negative multiple of V), its eigenvalue is zero as well, so the
    # zero vector returned by l2_normalize contributes nothing downstream.

    # Normalise each vector separately. Without `axis`, l2_normalize divides by
    # the norm of the whole tensor, which is only right for a single vector.
    # NOTE(review): l2_normalize divides by sqrt(max(sum(x**2), epsilon)), so a
    # vector whose squared norm is below 1e-5 is not scaled to unit length.
    vecPlus = tf.math.l2_normalize(vecPlus,axis=-1,epsilon=1e-5)
    vecMinus = tf.math.l2_normalize(vecMinus,axis=-1,epsilon=1e-5)
    return ((valPlus,valMinus),(vecPlus,vecMinus))
 

In [12]:
# Two random real row vectors of shape (1, 2), then the eigenpairs of
# U V^T + V U^T.
U = tf.random.normal(shape=(1, 2))
V = tf.random.normal(shape=(1, 2))
(eigvals, eigvecs) = eigvecReversal(U, V)

In [13]:
# Direct construction of the rank-2 matrix ...
A = OuterProd()(U, V) + OuterProd()(V, U)
print(A)
# ... and its reconstruction from the two eigenpairs; both prints should agree.
rank1_terms = [val * OuterProd()(vec, vec) for val, vec in zip(eigvals, eigvecs)]
A_again = rank1_terms[0] + rank1_terms[1]
print(A_again)

tf.Tensor(
[[[-1.8185118   0.5211488 ]
  [ 0.5211488   0.77038383]]], shape=(1, 2, 2), dtype=float32)
tf.Tensor(
[[[-1.8185118  0.5211488]
  [ 0.5211488  0.7703838]]], shape=(1, 2, 2), dtype=float32)


In [9]:


import tensorflow_probability as tfp

class NewMatrixAndVector(MatrixAndVector):
    """MatrixAndVector variant that stores ``A`` densely instead of as a Cholesky factor."""

    def build(self, input_shape):
        """Defer entirely to the parent build (which calls this class's setupA)."""
        super().build(input_shape)

    def setupA(self, input_shape, V):
        """Create the non-trainable dense matrix ``A = I + V V^H``."""
        batch_shape = (1,) + input_shape[1:-1]
        identity = tf.eye(num_rows=input_shape[-1], batch_shape=batch_shape, dtype=self.dtype)
        vvh = OuterProd(dtype=self.dtype)(self.V, self.V)
        self.A = tf.Variable(initial_value=identity + vvh, trainable=False)

    def applyA(self, x):
        """Return ``A x`` with a single batched matrix-vector product."""
        return tf.linalg.matvec(self.A, x)

# Start V at the first standard basis vector, stored with shape (1, 2).
v_init = np.zeros(shape=(1, 2))
v_init[:, 0] = 1

ApVVh = NewMatrixAndVector(v_init)

# One-layer functional model using the custom train step.
x = tf.keras.layers.Input((2,))
y = ApVVh(x)
model = CustomTrainModel(x, y)
model.compile(
    optimizer=tf.keras.optimizers.SGD(),
    loss=tf.keras.losses.MeanSquaredError(),
    run_eagerly=False,
)

# Synthetic regression problem: the target is twice the input.
x_train = np.random.randn(1000, 2)
y_train = 2 * x_train
x_val = np.random.randn(100, 2)
y_val = 2 * x_val
model.fit(
    x=x_train,
    y=y_train,
    batch_size=10,
    epochs=16,
    validation_data=(x_val, y_val),
)
# Display the trained vector.
ApVVh.V

Epoch 1/16
Epoch 2/16
Epoch 3/16
Epoch 4/16
Epoch 5/16
Epoch 6/16
Epoch 7/16
Epoch 8/16
Epoch 9/16
Epoch 10/16
Epoch 11/16
Epoch 12/16
Epoch 13/16
Epoch 14/16
Epoch 15/16
Epoch 16/16


<tf.Variable 'new_matrix_and_vector/Variable:0' shape=(1, 2) dtype=float32, numpy=array([[ 1.1247319e-08, -9.9999994e-01]], dtype=float32)>

In [4]:


import tensorflow_probability as tfp

class NewMatrixAndVector(MatrixAndVector,PostProcess):
    """Dense-``A`` variant that folds each optimiser step on ``V`` back into ``A``.

    After every training step ``post_update`` resets ``V`` to its stored
    initial value ``Vfixed`` and adds to ``A`` the matrix
    ``(Vfixed + U)(Vfixed + U)^H - Vfixed Vfixed^H`` where ``U`` is the step
    the optimiser took, so that ``A + V V^H`` keeps the value it had right
    after the step.
    """

    def build(self,input_shape):
        """Build the parent state, snapshot ``V`` and register the post-step hook."""
        super().build(input_shape)
        self.Vfixed = tf.Variable(initial_value=tf.identity(self.V),trainable=False,dtype=self.dtype)
        # CustomTrainModel.train_step looks hooks up by variable name.
        PostProcess.update[self.V.name] = self.post_update

    def setupA(self,input_shape,V):
        """Create the non-trainable dense matrix ``A = I + V V^H``."""
        idmat = tf.eye(num_rows=input_shape[-1],batch_shape=(1,) + input_shape[1:-1],dtype=self.dtype)
        self.A = tf.Variable(initial_value = idmat + OuterProd(dtype=self.dtype)(self.V,self.V), trainable = False)

    def applyA(self,x):
        """Return ``A x``."""
        return tf.linalg.matvec(self.A,x)

    def updateA(self,u_vec,mult):
        """Assign ``A <- A + mult * u_vec u_vec^H`` and return the assign result.

        ``mult`` may be a Python scalar or a tensor with the batch shape of
        ``u_vec`` (its shape without the last axis).
        """
        # A batch-shaped multiplier must line up with the batch axes of the
        # (..., n, n) outer product, not with its trailing matrix axes.
        if tf.is_tensor(mult) and mult.shape.rank:
            mult = mult[..., tf.newaxis, tf.newaxis]
        return self.A.assign(self.A + mult*OuterProd(dtype=self.dtype)(u_vec,u_vec))

    def post_update(self):
        """Move the latest change of ``V`` into ``A`` and restore ``V``."""
        U = self.V - self.Vfixed
        # Read the step U before V is overwritten.
        with tf.control_dependencies([U]):
            self.V.assign(self.Vfixed)
        # U U^H term.
        newA = self.updateA(U,1.0)
        # U Vfixed^H + Vfixed U^H, applied as two rank-1 updates via its eigenpairs.
        eigvals, eigvecs = eigvecReversal(U,self.Vfixed)
        for ii in range(2):
            # Chain the assignments so they are applied one after another.
            with tf.control_dependencies([newA]):
                newA = self.updateA(eigvecs[ii],eigvals[ii])

# Start V at the first standard basis vector, stored with shape (1, 2).
v_init = np.zeros(shape=(1, 2))
v_init[:, 0] = 1

ApVVh = NewMatrixAndVector(v_init)

# One-layer functional model using the custom train step, SGD step size 0.2.
x = tf.keras.layers.Input((2,))
y = ApVVh(x)
model = CustomTrainModel(x, y)
model.compile(
    optimizer=tf.keras.optimizers.SGD(0.2),
    loss=tf.keras.losses.MeanSquaredError(),
    run_eagerly=False,
)

# Synthetic regression problem: the target is twice the input.
x_train = np.random.randn(1000, 2)
y_train = 2 * x_train
x_val = np.random.randn(100, 2)
y_val = 2 * x_val
model.fit(
    x=x_train,
    y=y_train,
    batch_size=10,
    epochs=16,
    validation_data=(x_val, y_val),
)
# Display the vector after training.
ApVVh.V

Epoch 1/16


NameError: in user code:

    /home/lrr/anaconda3/lib/python3.6/site-packages/tensorflow/python/keras/engine/training.py:806 train_function  *
        return step_function(self, iterator)
    /home/lrr/anaconda3/lib/python3.6/site-packages/tensorflow/python/keras/engine/training.py:796 step_function  **
        outputs = model.distribute_strategy.run(run_step, args=(data,))
    /home/lrr/anaconda3/lib/python3.6/site-packages/tensorflow/python/distribute/distribute_lib.py:1211 run
        return self._extended.call_for_each_replica(fn, args=args, kwargs=kwargs)
    /home/lrr/anaconda3/lib/python3.6/site-packages/tensorflow/python/distribute/distribute_lib.py:2585 call_for_each_replica
        return self._call_for_each_replica(fn, args, kwargs)
    /home/lrr/anaconda3/lib/python3.6/site-packages/tensorflow/python/distribute/distribute_lib.py:2945 _call_for_each_replica
        return fn(*args, **kwargs)
    /home/lrr/anaconda3/lib/python3.6/site-packages/tensorflow/python/keras/engine/training.py:789 run_step  **
        outputs = model.train_step(data)
    <ipython-input-3-2a29dd52d3f3>:8 train_step
        PostProcess.update[tv.name]()
    <ipython-input-4-018dd40a6d6e>:22 post_update
        eigvals, eigvecs = eigvecReversal(U,self.Vfixed)

    NameError: name 'eigvecReversal' is not defined


In [18]:
# Show the dense matrix held by the layer. This needs a NewMatrixAndVector
# variant whose setupA creates `self.A`; other variants only have `self.L`.
print(ApVVh.A)

AttributeError: 'NewMatrixAndVector' object has no attribute 'A'

In [76]:


import tensorflow_probability as tfp

class NewMatrixAndVector(MatrixAndVector,PostProcess):
    """Cholesky-factor variant that folds each optimiser step on ``V`` into ``L``.

    ``setupA``/``applyA`` come from MatrixAndVector (``A = L L^H``). After each
    training step ``post_update`` restores ``V`` to ``Vfixed`` and applies
    three rank-1 Cholesky updates to ``L`` through
    ``tfp.math.cholesky_update``.
    """

    def build(self, input_shape):
        """Build the parent state, snapshot ``V`` and register the post-step hook."""
        super().build(input_shape)
        snapshot = tf.identity(self.V)
        self.Vfixed = tf.Variable(initial_value=snapshot, trainable=False, dtype=self.dtype)
        PostProcess.update[self.V.name] = self.post_update

    def updateA(self, u_vec, mult):
        """Assign to ``L`` the rank-1 Cholesky update by ``mult * u_vec u_vec^H``."""
        updated = tfp.math.cholesky_update(self.L, update_vector=u_vec, multiplier=mult)
        return self.L.assign(updated)

    def post_update(self):
        """Move the latest change of ``V`` into ``L`` and restore ``V``."""
        step = self.V - self.Vfixed
        # Read the step before V is overwritten.
        with tf.control_dependencies([step]):
            self.V.assign(self.Vfixed)
        # step step^H term.
        factor = self.updateA(step, 1.0)
        # step Vfixed^H + Vfixed step^H, applied through its two eigenpairs.
        eigvals, eigvecs = eigvecReversal(step, self.Vfixed)
        for val, vec in zip(eigvals, eigvecs):
            # Chain the assignments so they are applied one after another.
            with tf.control_dependencies([factor]):
                factor = self.updateA(vec, val)

# v_init and b are both the first standard basis vector with shape (1, 2).
v_init = np.zeros((1, 2))
b = np.zeros(shape=(1, 2))
b[:, 0] = 1
v_init[:, 0] = 1

ApVVh = NewMatrixAndVector(v_init)

# One-layer functional model using the custom train step, SGD step size 0.2.
x = tf.keras.layers.Input((2,))
y = ApVVh(x)
model = CustomTrainModel(x, y)
model.compile(
    optimizer=tf.keras.optimizers.SGD(0.2),
    loss=tf.keras.losses.MeanSquaredError(),
    run_eagerly=False,
)

# Synthetic regression problem: the target is twice the input.
x_train = np.random.randn(1000, 2)
y_train = 2 * x_train
x_val = np.random.randn(100, 2)
y_val = 2 * x_val
model.fit(
    x=x_train,
    y=y_train,
    batch_size=10,
    epochs=16,
    validation_data=(x_val, y_val),
)
# Display the vector after training.
ApVVh.V

Epoch 1/16
Epoch 2/16
Epoch 3/16
Epoch 4/16
Epoch 5/16
Epoch 6/16
Epoch 7/16
Epoch 8/16
Epoch 9/16
Epoch 10/16
Epoch 11/16
Epoch 12/16
Epoch 13/16
Epoch 14/16
Epoch 15/16
Epoch 16/16


<tf.Variable 'new_matrix_and_vector_49/Variable:0' shape=(1, 2) dtype=float32, numpy=array([[1., 0.]], dtype=float32)>

In [8]:
# Rebuild A = L L^H from the layer's Cholesky factor to inspect the matrix
# after training.
print(tf.linalg.matmul(ApVVh.L,ApVVh.L,adjoint_b=True))

tf.Tensor(
[[[1.00850006+0.j         0.0021042 +0.00019345j]
  [0.0021042 -0.00019345j 1.34594523+0.j        ]]], shape=(1, 2, 2), dtype=complex128)


In [4]:
def complexNum(x):
    """Promote a real tensor to the matching complex dtype with zero imaginary part."""
    zero_imag = tf.cast(0.0, dtype=x.dtype)
    return tf.complex(real=x, imag=zero_imag)

def eigvecReversal(U,V):
    """Eigen-decompose the Hermitian rank-2 matrix ``U V^H + V U^H`` (complex inputs).

    Vectors live on the last axis; leading axes are treated as batch axes.
    With ``uhv = U^H V`` the eigenvalues are
    ``Re(uhv) +/- sqrt(|U|^2 |V|^2 - Im(uhv)^2)`` and the eigenvectors are
    proportional to ``|V|^2 U + (i Im(uhv) +/- root) V``.

    Args:
        U, V: complex tensors of shape ``(..., n)``.

    Returns:
        ``((valPlus, valMinus), (vecPlus, vecMinus))``. The eigenvalues are
        complex tensors with the input shape minus the last axis; the
        eigenvectors have shape ``(..., n)`` and are normalised along the last
        axis.
    """
    vhv = tf.reduce_sum(tf.math.conj(V)*V,axis=-1,keepdims=True)
    uhu = tf.reduce_sum(tf.math.conj(U)*U,axis=-1,keepdims=True)
    uhv = tf.reduce_sum(tf.math.conj(U)*V,axis=-1,keepdims=True)
    rootRadicand = tf.math.sqrt(vhv*uhu - complexNum(tf.math.imag(uhv)**2))

    valPlus = complexNum(tf.math.real(uhv)) + rootRadicand
    valMinus = complexNum(tf.math.real(uhv)) - rootRadicand
    vecPlus = vhv*U + (1j*complexNum(tf.math.imag(uhv)) + rootRadicand)*V
    vecMinus = vhv*U + (1j*complexNum(tf.math.imag(uhv))  - rootRadicand)*V
    # Degenerate case (root ~ 0): fall back to U and to V with its component
    # along U removed.
    vecPlus = tf.where(tf.abs(rootRadicand) > 1e-5,vecPlus,U)
    vecMinus = tf.where(tf.abs(rootRadicand) > 1e-5,vecMinus,-tf.math.divide_no_nan(uhv,uhu)*U + V)

    # Normalise each vector separately. Without `axis`, l2_normalize divides by
    # the norm of the whole tensor, which is only right for a single vector.
    vecPlus = tf.math.l2_normalize(vecPlus,axis=-1,epsilon=1e-5)
    vecMinus = tf.math.l2_normalize(vecMinus,axis=-1,epsilon=1e-5)
    # Eigenvalues are reported without the kept reduction axis.
    valPlus = tf.squeeze(valPlus,axis=-1)
    valMinus = tf.squeeze(valMinus,axis=-1)
    return ((valPlus,valMinus),(vecPlus,vecMinus))
 

In [64]:
# Two random complex row vectors of shape (1, 2) (independent normal real and
# imaginary parts), then the eigenpairs of U V^H + V U^H.
U = tf.complex(
    tf.random.normal((1, 2)),
    tf.random.normal((1, 2)),
)
V = tf.complex(
    tf.random.normal((1, 2)),
    tf.random.normal((1, 2)),
)
(eigvals, eigvecs) = eigvecReversal(U, V)

In [65]:
# Direct construction of the Hermitian rank-2 matrix ...
A = OuterProd()(U, V) + OuterProd()(V, U)
print(A)
# ... and its reconstruction from the two eigenpairs; both prints should agree.
rank1_terms = [val * OuterProd()(vec, vec) for val, vec in zip(eigvals, eigvecs)]
A_again = rank1_terms[0] + rank1_terms[1]
print(A_again)

tf.Tensor(
[[[ 1.3619424 +0.j         -0.34927106+0.65208745j]
  [-0.34927106-0.65208745j -0.63275087+0.j        ]]], shape=(1, 2, 2), dtype=complex64)
tf.Tensor(
[[[ 1.3619424 +0.j        -0.34927112+0.6520874j]
  [-0.34927112-0.6520874j -0.632751  +0.j       ]]], shape=(1, 2, 2), dtype=complex64)


In [1]:
def complexNum(x):
    """Promote a real tensor to the matching complex dtype with zero imaginary part."""
    zero_imag = tf.cast(0.0, dtype=x.dtype)
    return tf.complex(real=x, imag=zero_imag)

In [5]:
import tf_rewrites as tfp
class NewMatrixAndVector(MatrixAndVector,PostProcess):
    """Cholesky-factor variant that folds each optimiser step on ``V`` into ``L``.

    ``setupA``/``applyA`` come from MatrixAndVector (``A = L L^H``). After each
    training step ``post_update`` restores ``V`` to ``Vfixed`` and applies
    three rank-1 Cholesky updates to ``L``.

    Note: in this cell the name ``tfp`` is bound to the ``tf_rewrites`` module
    (``import tf_rewrites as tfp``), not to tensorflow_probability.
    """

    def build(self, input_shape):
        """Build the parent state, snapshot ``V`` and register the post-step hook."""
        super().build(input_shape)
        snapshot = tf.identity(self.V)
        self.Vfixed = tf.Variable(initial_value=snapshot, trainable=False, dtype=self.dtype)
        PostProcess.update[self.V.name] = self.post_update

    def updateA(self, u_vec, mult):
        """Assign to ``L`` the rank-1 Cholesky update by ``mult * u_vec u_vec^H``."""
        updated = tfp.cholesky_update(self.L, update_vector=u_vec, multiplier=mult)
        return self.L.assign(updated)

    def post_update(self):
        """Move the latest change of ``V`` into ``L`` and restore ``V``."""
        step = self.V - self.Vfixed
        # Read the step before V is overwritten.
        with tf.control_dependencies([step]):
            self.V.assign(self.Vfixed)
        # step step^H term.
        factor = self.updateA(step, 1.0)
        # step Vfixed^H + Vfixed step^H, applied through its two eigenpairs.
        eigvals, eigvecs = eigvecReversal(step, self.Vfixed)
        for val, vec in zip(eigvals, eigvecs):
            # Chain the assignments so they are applied one after another.
            with tf.control_dependencies([factor]):
                factor = self.updateA(vec, val)

# Complex initial V: real part is the first standard basis vector, imaginary
# part is zero.
v_a = np.zeros((1, 2))
v_a[:, 0] = 1
v_b = np.zeros(shape=(1, 2))
#v_b[slice(None),1] = 0.5
v_init = tf.complex(v_a, v_b)

ApVVh = NewMatrixAndVector(v_init, dtype=tf.complex128)

# One-layer complex128 functional model using the custom train step.
x = tf.keras.layers.Input((2,), dtype=tf.complex128)
y = ApVVh(x)
model = CustomTrainModel(x, y)
model.compile(
    optimizer=tf.keras.optimizers.SGD(.1),
    loss=tf.keras.losses.MeanSquaredError(),
    run_eagerly=False,
)

# Synthetic complex regression problem: the target is twice the input.
x_train = np.random.randn(1000, 2) + 1j * np.random.randn(1000, 2)
y_train = 2 * x_train
x_val = np.random.randn(100, 2) + 1j * np.random.randn(100, 2)
y_val = 2 * x_val
model.fit(
    x=x_train,
    y=y_train,
    batch_size=10,
    epochs=16,
    validation_data=(x_val, y_val),
)
# Display the vector after training.
ApVVh.V

Epoch 1/16
Epoch 2/16
Epoch 3/16
Epoch 4/16
Epoch 5/16
Epoch 6/16
Epoch 7/16
Epoch 8/16
Epoch 9/16
Epoch 10/16
Epoch 11/16
Epoch 12/16
Epoch 13/16
Epoch 14/16
Epoch 15/16
Epoch 16/16


<tf.Variable 'new_matrix_and_vector/Variable:0' shape=(1, 2) dtype=complex128, numpy=array([[1.+0.j, 0.+0.j]])>

In [1]:
import tensorflow as tf
import tensorflow_probability as tfp
import tf_rewrites as tfr
rho = 1.
x_shape = (4, 3, 5, 5)
dtype = tf.complex128


def cmplx_normal(x_shape, dtype):
    """Return a complex tensor whose real and imaginary parts are independent normal draws."""
    real_dtype = dtype.real_dtype
    real_part = tf.random.normal(x_shape, dtype=real_dtype)
    imag_part = tf.random.normal(x_shape, dtype=real_dtype)
    return tf.complex(real_part, imag_part)


# Batch of Hermitian matrices A = rho*I + x x^H, batch shape (4, 3).
x = cmplx_normal(x_shape, dtype)
idmat = tf.eye(num_rows=x_shape[-1], batch_shape=x_shape[:2], dtype=dtype)
A = rho * idmat + tf.linalg.matmul(x, x, adjoint_b=True)
# One update vector per matrix, and a multiplier that is 1 everywhere.
v = cmplx_normal(x_shape[:-1], dtype)
m = tf.complex(
    0. * tf.random.normal((4, 3), dtype=dtype.real_dtype) + 1.,
    0. * tf.random.normal((4, 3), dtype=dtype.real_dtype),
)

# Same rank-1 update through both implementations, for comparison.
L = tf.linalg.cholesky(A)
L2 = tfr.cholesky_update(L, v, m)
L3 = tfp.math.cholesky_update(L, v, m)

In [2]:
def addDim(x):
    """Append a trailing axis of length 1 to ``x`` (shape ``(...)`` -> ``(..., 1)``)."""
    # expand_dims gives the same result as reshaping to x.shape + (1,), but
    # does not need every dimension of x to be statically known.
    return tf.expand_dims(x, axis=-1)

In [3]:
def _max_abs_residual(target, chol):
    """Largest absolute entry of ``target - chol chol^H``."""
    rebuilt = tf.linalg.matmul(chol, chol, adjoint_b=True)
    return tf.reduce_max(tf.abs(target - rebuilt))


# Sanity check of the plain factorisation.
print(_max_abs_residual(A, L))
# Expected updated matrix: A + m * v v^H.
A2 = A + addDim(addDim(m)) * tf.linalg.matmul(addDim(v), addDim(v), adjoint_b=True)
# Residual of the tf_rewrites update, then of the tensorflow_probability one.
print(_max_abs_residual(A2, L2))
print(_max_abs_residual(A2, L3))

tf.Tensor(3.552713678800501e-15, shape=(), dtype=float64)
tf.Tensor(5.329070518200751e-15, shape=(), dtype=float64)
tf.Tensor(28.918581336141166, shape=(), dtype=float64)


In [1]:
import tensorflow as tf
import matrix_decompositions_tf as fctr
def cmplx_normal(x_shape, dtype):
    """Return a complex tensor whose real and imaginary parts are independent normal draws.

    ``dtype`` is a complex ``tf.DType``; both parts are drawn in its
    ``real_dtype`` with shape ``x_shape``.
    """
    real_dtype = dtype.real_dtype
    real_part = tf.random.normal(x_shape, dtype=real_dtype)
    imag_part = tf.random.normal(x_shape, dtype=real_dtype)
    return tf.complex(real_part, imag_part)
def addDim(x):
    """Append a trailing axis of length 1 to ``x`` (shape ``(...)`` -> ``(..., 1)``)."""
    # expand_dims gives the same result as reshaping to x.shape + (1,), but
    # does not need every dimension of x to be statically known.
    return tf.expand_dims(x, axis=-1)

u_shape = (3,4,2,6)
v_shape = (3,1,1,6)
dtype = tf.complex128

# u uses the full batch shape and v the broadcastable one, so the check also
# covers broadcasting between the two. (u was previously drawn with v_shape,
# which left u_shape unused.)
u = cmplx_normal(u_shape,dtype)
v = cmplx_normal(v_shape,dtype)

# NOTE(review): rank2eigen is defined in matrix_decompositions_tf, not shown
# here; it is used below as returning (eigenvalues, eigenvectors) sequences.
eigvals,eigvecs = fctr.rank2eigen(u,v)

# Rebuild the matrix from the eigenpairs: sum of val * vec vec^H.
eigsum = 0.
for val,vec in zip(eigvals,eigvecs):
    eigsum = eigsum + addDim(addDim(val))*tf.linalg.matmul(addDim(vec),addDim(vec),adjoint_b=True)

# Direct construction: u v^H + v u^H.
straightSum = tf.linalg.matmul(addDim(u),addDim(v),adjoint_b = True) + tf.linalg.matmul(addDim(v),addDim(u),adjoint_b=True)

# Largest absolute discrepancy between the two constructions.
print(tf.reduce_max(tf.abs(eigsum - straightSum)))

tf.Tensor(2.6645352591003757e-15, shape=(), dtype=float64)


In [3]:
# Squared norm of each first eigenvector along the last axis; every entry is 1
# when the vectors are individually unit length.
tf.reduce_sum(eigvecs[0]*tf.math.conj(eigvecs[0]),axis=-1,keepdims=True)

<tf.Tensor: shape=(3, 4, 2, 1), dtype=complex128, numpy=
array([[[[0.12376822+0.j],
         [0.20984116+0.j]],

        [[0.02174207+0.j],
         [0.0400276 +0.j]],

        [[0.02745349+0.j],
         [0.02636881+0.j]],

        [[0.06287725+0.j],
         [0.00775799+0.j]]],


       [[[0.12664901+0.j],
         [0.00206361+0.j]],

        [[0.01370283+0.j],
         [0.04698159+0.j]],

        [[0.02298322+0.j],
         [0.04410531+0.j]],

        [[0.01009955+0.j],
         [0.02290069+0.j]]],


       [[[0.00745083+0.j],
         [0.0144493 +0.j]],

        [[0.02748401+0.j],
         [0.02219594+0.j]],

        [[0.04715586+0.j],
         [0.03141533+0.j]],

        [[0.03774115+0.j],
         [0.00278519+0.j]]]])>