# In [None]:
import torch
import torch.nn as nn

def hard_threshold(arr, thresh=0.0):
    """Return a copy of *arr* with every element ``<= thresh`` set to zero.

    The input is never modified. NaN entries compare False against the
    threshold and are therefore kept unchanged.

    Args:
        arr: A ``torch.Tensor``, or another array type that supports
            ``.copy()`` and boolean-mask assignment (e.g. a NumPy array).
        thresh: Elements less than or equal to this value are zeroed.

    Returns:
        A new array of the same type, shape and dtype as *arr*.
    """
    if isinstance(arr, torch.Tensor):
        # Out-of-place on purpose. An in-place masked write fails on leaf
        # tensors that require grad. It also invalidates tensors that an
        # earlier op saved for backward (e.g. the output of tanh/sigmoid).
        return arr.masked_fill(arr <= thresh, 0.0)
    out = arr.copy()
    out[out <= thresh] = 0.0
    return out

class JumpReLU(nn.Module):
    """ReLU while training; thresholded ("jumped") ReLU in eval mode.

    In training mode, activations ``<= 0`` are zeroed (a plain ReLU). In eval
    mode (after ``module.eval()``), activations ``<= self.jump`` are zeroed
    instead.

    Args:
        jump: Threshold used in eval mode. With the default ``0.``, both
            modes behave like ReLU.
    """

    def __init__(self, jump=0.):
        super().__init__()
        self.jump = jump

    def forward(self, input):
        # ``self.training`` is the bool flipped by .train()/.eval(). Branching
        # on its truthiness leaves no path that falls through and returns None.
        thresh = 0.0 if self.training else self.jump
        return hard_threshold(input, thresh=thresh)

    def extra_repr(self):
        # Shown by nn.Module.__repr__, e.g. "JumpReLU(jump=0.0)".
        return f"jump={self.jump}"

class CLTFeatureActivations(nn.Module):
    """Map inputs to feature activations: a linear encoder followed by JumpReLU.

    Args:
        dim_in: Size of the last dimension of the input.
        dim_out: Number of features (size of the encoder output).
    """

    def __init__(self, dim_in, dim_out):
        super().__init__()
        # Attribute names define the state_dict keys; keep them stable.
        self.encoder = nn.Linear(in_features=dim_in, out_features=dim_out)
        self.nonlinearity = JumpReLU()

    def forward(self, x):
        """Return ``self.nonlinearity(self.encoder(x))``."""
        return self.nonlinearity(self.encoder(x))
    
class CLTReconstruction(nn.Module):
    """Reconstruct a target from per-layer feature activations.

    Layer ``i``'s activations are decoded by ``self.decoders[i]``
    (``nn.Linear(dim_out, dim_in)``), and the per-layer contributions are
    summed.

    Args:
        n_layers: Number of layers, i.e. number of decoders.
        dim_out: Number of features per layer (decoder input size).
        dim_in: Size of the reconstructed target (decoder output size).
    """

    def __init__(self, n_layers, dim_out, dim_in):
        super().__init__()
        self.n_layers = n_layers
        self.decoders = nn.ModuleList([
            nn.Linear(dim_out, dim_in) for _ in range(n_layers)
        ])
        # Store the function itself. Calling nn.functional.tanh() with no
        # argument, as before, raised TypeError at construction time.
        self.tanh = torch.tanh

    def forward(self, clt_outs):
        """Sum the decoded contribution of each layer's activations.

        Args:
            clt_outs: Sequence of activation tensors. ``clt_outs[i]`` is fed
                to ``self.decoders[i]``, so its last dimension must be
                ``dim_out``. It may hold fewer than ``n_layers`` entries.

        Returns:
            Tensor whose last dimension is ``dim_in``.

        Raises:
            ValueError: If ``clt_outs`` is empty.
            IndexError: If ``clt_outs`` has more entries than there are decoders.
        """
        y_hat = None
        for i, clt_out in enumerate(clt_outs):
            contribution = self.decoders[i](clt_out)
            # Out-of-place accumulation: the result takes its shape from the
            # first contribution, and autograd never sees an in-place update.
            y_hat = contribution if y_hat is None else y_hat + contribution
        if y_hat is None:
            raise ValueError("clt_outs must contain at least one layer")
        return y_hat

    def loss(self, y_hat, y, lambd, c, clt_outs):
        """Return MSE reconstruction loss plus a tanh sparsity penalty.

        The penalty is::

            lambd * sum_l sum_i tanh(c * ||W_l[:, i]|| * a_l[..., i])

        where ``W_l = self.decoders[l].weight``. nn.Linear stores its weight
        as (out_features, in_features) = (dim_in, dim_out), so column ``i``
        is feature ``i``'s decoder vector. ``a_l = clt_outs[l]``.

        The feature sum runs over the last dimension of ``clt_outs[l]``. Any
        leading dimensions (presumably batch) are averaged, matching the
        mean reduction of the MSE term.

        Args:
            y_hat: Reconstruction, e.g. the output of ``forward``.
            y: Target, same shape as ``y_hat``.
            lambd: Weight of the sparsity penalty.
            c: Scale applied inside the tanh.
            clt_outs: Per-layer activations; must have at least
                ``n_layers`` entries.

        Returns:
            A 0-dim tensor.
        """
        l_mse = nn.functional.mse_loss(y_hat, y)

        sum_over_layers = 0
        for layer in range(self.n_layers):
            # One L2 norm per feature (column) of the decoder weight.
            decoder_norms = self.decoders[layer].weight.norm(dim=0)
            per_feature = self.tanh(c * decoder_norms * clt_outs[layer])
            sum_over_layers = sum_over_layers + per_feature.sum(dim=-1).mean()
        l_sparsity = lambd * sum_over_layers
        return l_mse + l_sparsity