In [None]:
class Layer:
    """Base class for network layers.

    Subclasses fill ``params`` / ``grads`` and override :meth:`forward` and
    :meth:`backward`; calling either on this base class raises
    ``NotImplementedError``.
    """

    def __init__(self) -> None:
        # Trainable arrays and their gradients, keyed by parameter name.
        self.params, self.grads = {}, {}
        # Learning-rate settings (0.0 / None until a subclass sets them).
        self.lr, self.lr_type = 0.0, None
        # Regularisation settings (0.0 / None until a subclass sets them).
        self.rg, self.rg_type = 0.0, None

    def forward(self, input_data):
        """Compute the layer output for ``input_data``; must be overridden."""
        raise NotImplementedError

    def backward(self, grad, alpha):
        """Propagate ``grad`` back through the layer; must be overridden."""
        raise NotImplementedError
        
class ZLayer(Layer):
    """Fully connected (affine) layer computing ``W @ x + B``.

    ``W`` has shape ``(outSize, inSize)`` and ``B`` has shape ``(outSize, 1)``.
    ``backward`` sums the bias gradient over axis 1 and scales the
    regulariser by ``outGrad.shape[1]``, so each column of the data appears
    to be one sample (assumes inputs are ``(inSize, batch)`` — TODO confirm
    with callers).
    """

    def __init__(self, inSize, outSize, eps_W=1e-2, eps_B=0.0, wInit='Xavier', seed=None, lr=.5, lr_type=None, rg=0.0, rg_type='L2'):
        """Draw initial weights and store optimiser / regulariser settings.

        Args:
            inSize: number of input units.
            outSize: number of output units.
            eps_W, eps_B: accepted for interface compatibility; not used here.
            wInit: ``'He'`` (case-insensitive) draws W with standard deviation
                ``sqrt(2 / inSize)``; any other value (including None) uses
                the Glorot/Xavier std ``sqrt(2 / (inSize + outSize))``.
            seed: if not None, reseeds NumPy's *global* RNG before drawing W
                (this also affects later ``np.random`` calls elsewhere).
            lr: learning rate; a copy is kept in ``lr_initial``.
            lr_type: learning-rate schedule tag; stored, not read in this class.
            rg: regularisation strength applied in :meth:`backward`.
            rg_type: ``'L1'`` (case-insensitive) adds ``rg/m * sign(W)`` to the
                weight gradient; any other value adds the L2 term ``rg/m * W``.
        """
        super().__init__()
        if seed is not None:
            np.random.seed(seed)
        if isinstance(wInit, str) and wInit.upper() == 'HE':
            sigma = np.sqrt(2 / inSize)
        else:
            sigma = np.sqrt(2 / (inSize + outSize))
        self.params['W'] = sigma * np.random.randn(outSize, inSize)
        self.params['B'] = np.zeros((outSize, 1))
        self.lr = lr
        # Kept so a schedule can be computed relative to the starting rate.
        self.lr_initial = lr
        self.lr_type = lr_type
        self.rg = rg
        self.rg_type = rg_type

    def forward(self, inData):
        """Return ``W @ inData + B`` (B broadcast across columns), caching
        the input for :meth:`backward`."""
        self.input = inData
        self.output = np.dot(self.params['W'], self.input) + self.params['B']
        return self.output

    def backward(self, outGrad, alpha=1e-0):
        """Compute parameter gradients and return the gradient w.r.t. the input.

        Args:
            outGrad: gradient of the loss w.r.t. this layer's output, with
                the same shape as ``forward``'s return value.
            alpha: accepted for interface compatibility; not used here.

        Returns:
            ``W.T @ outGrad``, the gradient w.r.t. the cached input.

        Side effects:
            Sets ``grads['W']`` (data term plus regulariser) and
            ``grads['B']`` (``outGrad`` summed over axis 1).
        """
        inGrad = np.dot(self.params['W'].T, outGrad)
        # Regulariser gradient: sign(W) for L1, W itself for L2 (default).
        if isinstance(self.rg_type, str) and self.rg_type.upper() == 'L1':
            reg = np.sign(self.params['W'])
        else:
            reg = self.params['W']
        self.grads['W'] = np.dot(outGrad, self.input.T) + self.rg / outGrad.shape[1] * reg
        self.grads['B'] = np.sum(outGrad, axis=1, keepdims=True)
        return inGrad
    
class ALayer(Layer):
    """Activation layer defined by a function ``g`` and its derivative ``gPrime``.

    ``forward`` applies ``g``. ``backward`` multiplies ``gPrime`` (evaluated
    at the cached forward *input*) element-wise with the incoming gradient.
    The layer has no trainable parameters, so ``params`` and ``grads`` stay
    empty.
    """

    def __init__(self, g, gPrime):
        super().__init__()
        self.g = g
        self.gPrime = gPrime

    def forward(self, inData):
        """Return ``g(inData)``, caching input and output for the backward pass."""
        self.input = inData
        self.output = self.g(self.input)
        return self.output

    def backward(self, outGrad, alpha=1e-0):
        """Return ``gPrime(input) * outGrad`` (element-wise chain rule).

        ``alpha`` is accepted and ignored. It exists so this method matches
        ``Layer.backward(grad, alpha)`` and ``ZLayer.backward``, letting callers
        pass a step size uniformly to every layer.
        """
        return self.gPrime(self.input) * outGrad

class SLayer(ALayer):
    """Activation layer whose derivative is a per-sample Jacobian rather than
    an element-wise factor (presumably softmax — TODO confirm)."""

    def backward(self, outGrad, alpha=1e-0):
        """Contract the Jacobian ``gPrime(output)`` with ``outGrad``.

        The einsum subscripts require ``gPrime(self.output)`` to be 3-D, indexed
        ``(j, i, k)``, and ``outGrad`` to be 2-D, indexed ``(k, j)``. The result
        is indexed ``(i, j)``. For each column ``j``, it is the matrix
        ``gPrime(output)[j]`` applied to column ``j`` of ``outGrad``. The
        derivative is evaluated at the cached forward *output*, not the input.

        ``alpha`` is accepted and ignored, so the signature matches
        ``Layer.backward(grad, alpha)``.
        """
        return np.einsum('jik,kj->ij', self.gPrime(self.output), outGrad)
    
class LLayer(Layer):
    """Loss layer wrapping a loss function ``g`` and its gradient ``gPrime``.

    Both callables take ``(pred, obs)``.
    """

    def __init__(self, g, gPrime, eps=1e-8):
        # Initialise the shared Layer attributes (params, grads, lr, rg, ...)
        # so code that treats every layer uniformly does not hit AttributeError.
        super().__init__()
        self.g = g
        self.gPrime = gPrime
        # Stored for callers; not read anywhere in this class.
        self.eps = eps

    def forward(self, pred, obs):
        """Return ``g(pred, obs)``, caching ``pred`` and the loss value."""
        self.input = pred
        self.output = self.g(pred, obs)
        return self.output

    def backward(self, pred, obs):
        """Return ``gPrime(pred, obs)``, the loss gradient w.r.t. ``pred``
        (presumably; depends on the supplied ``gPrime``)."""
        return self.gPrime(pred, obs)