# In [38]:
from tensorflow.keras.layers import Dense
class RegDense(Dense):
    """Dense layer with an orthogonalized kernel and a Gram-matrix penalty.

    After the base ``Dense.build`` creates the weights, the kernel is
    replaced by ``U @ V^T`` from its SVD, i.e. the same-shaped matrix with
    orthonormal rows or columns. When ``call`` receives a ``training`` flag,
    the layer registers :attr:`loss` via ``add_loss``.
    """

    def __init__(self, units, **kwargs):
        super().__init__(units, **kwargs)

    def build(self, input_shape):
        """Create the weights, then orthogonalize the kernel in place."""
        from tensorflow import linalg, transpose
        super().build(input_shape)
        # kernel = U diag(s) V^T; dropping the singular values leaves the
        # (semi-)orthogonal factor U V^T with the kernel's shape.
        _, u, v = linalg.svd(self.kernel, compute_uv=True)
        self.kernel.assign(u @ transpose(v))

    @property
    def loss(self):
        """Penalty pulling the kernel's Gram matrix toward the identity.

        With ``C`` the smaller of ``W W^T`` and ``W^T W`` (size ``d``), the
        value is ``(tr(C) - logdet(C)) / d - 1``, plus the mean squared bias
        when the layer has a bias. The Gram term equals the mean over the
        eigenvalues ``l`` of ``C`` of ``l - log(l) - 1``, which is 0 exactly
        when ``C`` is the identity.

        NOTE(review): ``tf.linalg.logdet`` expects ``C`` to be positive
        definite; a rank-deficient kernel would make this non-finite.
        """
        from tensorflow import matmul, linalg, square, reduce_mean
        w = self.kernel
        rows, cols = w.shape
        # Build the smaller Gram matrix, the one that can be full rank:
        # W W^T for a wide kernel, W^T W otherwise.
        wide = cols > rows
        c = matmul(w, w, transpose_a=not wide, transpose_b=wide)
        b = 0.0 if self.bias is None else reduce_mean(square(self.bias))
        return b + (linalg.trace(c) - linalg.logdet(c)) / c.shape[0] - 1

    def call(self, inputs, training=None):
        """Apply the dense transform, registering :attr:`loss` if a training flag is given."""
        # Skip add_loss entirely when no training flag was supplied instead of
        # adding a constant 0.0. The total loss is the same, but no no-op term
        # is registered and no non-tensor value (which Keras 3 may reject)
        # reaches add_loss.
        # NOTE(review): an explicit training=False still adds the penalty, as
        # in the original code; confirm that is intended.
        if training is not None:
            self.add_loss(self.loss)
        return super().call(inputs)

# In [40]:
from tensorflow.keras.layers import Conv2D
class RegConv2D(Conv2D):
    """Conv2D layer with an orthogonalized kernel and a Gram-matrix penalty.

    The kernel is viewed as a 2-D matrix (see :attr:`kmat`). After the base
    ``Conv2D.build`` creates the weights, that matrix is replaced by
    ``U @ V^T`` from its SVD and reshaped back to the kernel's shape. When
    ``call`` receives a ``training`` flag, the layer registers :attr:`loss`
    via ``add_loss``.
    """

    def __init__(self, filters, kernel_size, **kwargs):
        super().__init__(filters, kernel_size, **kwargs)

    def build(self, input_shape):
        """Create the weights, then orthogonalize the flattened kernel in place."""
        from tensorflow import reshape, linalg, transpose
        super().build(input_shape)
        # kmat = U diag(s) V^T; U V^T is the (semi-)orthogonal factor with
        # kmat's shape, so it can be reshaped back into the kernel.
        _, u, v = linalg.svd(self.kmat)
        self.kernel.assign(reshape(u @ transpose(v), self.kernel.shape))

    @property
    def kmat(self):
        """The kernel reshaped to 2-D: all axes but the last collapsed into rows.

        The column count is the kernel's last dimension. Presumably this is
        ``filters`` under Keras' Conv2D kernel layout; confirm.
        """
        from tensorflow import reshape
        from functools import reduce
        from operator import mul
        w = self.kernel
        return reshape(w, (reduce(mul, w.shape[:-1]), w.shape[-1]))

    @property
    def loss(self):
        """Penalty pulling the flattened kernel's Gram matrix toward the identity.

        With ``C`` the smaller of ``K K^T`` and ``K^T K`` (size ``d``), where
        ``K`` is :attr:`kmat`, the value is ``(tr(C) - logdet(C)) / d - 1``,
        plus the mean squared bias when the layer has a bias. The Gram term
        equals the mean over the eigenvalues ``l`` of ``C`` of
        ``l - log(l) - 1``, which is 0 exactly when ``C`` is the identity.

        NOTE(review): ``tf.linalg.logdet`` expects ``C`` to be positive
        definite; a rank-deficient kernel would make this non-finite.
        """
        from tensorflow import matmul, linalg, square, reduce_mean
        w = self.kmat
        rows, cols = w.shape
        # Build the smaller Gram matrix, the one that can be full rank:
        # K K^T for a wide matrix, K^T K otherwise.
        wide = cols > rows
        c = matmul(w, w, transpose_a=not wide, transpose_b=wide)
        b = 0.0 if self.bias is None else reduce_mean(square(self.bias))
        return b + (linalg.trace(c) - linalg.logdet(c)) / c.shape[0] - 1

    def call(self, inputs, training=None):
        """Apply the convolution, registering :attr:`loss` if a training flag is given."""
        # Skip add_loss entirely when no training flag was supplied instead of
        # adding a constant 0.0. The total loss is the same, but no no-op term
        # is registered and no non-tensor value (which Keras 3 may reject)
        # reaches add_loss.
        # NOTE(review): an explicit training=False still adds the penalty, as
        # in the original code; confirm that is intended.
        if training is not None:
            self.add_loss(self.loss)
        return super().call(inputs)