Maybe just a bunch of ANDs in every layer, then ANDs of ANDs? Issue: sparsity goes down.

Just replace ANDs with at-least-two-of-a-small-subset functions? Say each input node is on with prob
$p$; consider an output node with $k$ inputs; the prob this has exactly two inputs on is $p^2(1-p)^{k-2}
\binom{k}{2}$. The prob it has at least two inputs on is $1-(1-p)^k-kp (1-p)^{k-1}$. The difference
between these two is $O\left(\frac{(pk)^3}{1-pk}\right)$, from summing a geometric series. Setting
the first one to $p$ gives $p^2(1-p)^{k-2} \binom{k}{2}=p$, which implies
$p(1-p)^{k-2}=\frac{2}{k(k-1)}$. This should work out at about $p=\frac{2}{(k-1)^2}$, up to
lower-order corrections, which also makes the diff small. Equivalently, $k=\sqrt{\frac{2}{p}}+1$. We
might want to correct the ideal network such that it is more precisely binary though? I.e., we might
want to do it without ReLUs? I guess can try both options, but let's first try the one without ReLUs
that just computes these gates with perfect binary outputs.


Here's an alternative calculation in a slightly different setup (though probably they are the same
up to error terms in some reasonable sense). Let's say each entry of the weight matrix is Bernoulli
with probability $q$ and each input is Bernoulli with probability $p$. We want it to be the case
that taking a random matrix and a random input, the probability that an output is on is $p$. The
probability it is on is the probability that there are at least two simultaneous hits from that row
of the weight matrix and the input. Each hit has probability $pq$, so this has probability
$1-(1-pq)^m-m pq (1-pq)^{m-1}$. So to keep sparsity constant, we want $1-(1-pq)^m-m pq
(1-pq)^{m-1}=p$. Up to something like an $O((mpq)^3/(1-mpq))$ term as before, we can just solve $p=(pq)^2
m(m-1)/2$. This gives $q=\sqrt{\frac{2}{m(m-1)p}}$. A more precise solution can be found using
numerical methods (after all, fixing $m$ and $p$, it's just a matter of finding a root of a
polynomial in $q$), I think. But this should be fine for us for now.

Jake: roughly we want $p=(qp)^2 m^2\implies q=\frac 1 {m\sqrt p}$. If we decide to pick $p=q$, we have
$p=m^{-2/3}$.


In [None]:
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm

In [None]:
def custom_activation(x):
    """Saturating ramp: 0 for x <= 0, x on [0, 1], and 1 for x >= 1.

    Computed as the difference of two ReLUs, one of them shifted by 1.
    """
    ramp = F.relu(x)
    overshoot = F.relu(x - 1)
    return ramp - overshoot


def denoising_nonlinearity(x, epsilon):
    """Zero every entry of ``x`` whose absolute value is strictly below ``epsilon``.

    Args:
        x (torch.Tensor): Values to denoise.
        epsilon (float or torch.Tensor): Threshold; entries with ``|x| < epsilon`` become 0.

    Returns:
        torch.Tensor: A new tensor with the small entries zeroed. ``x`` itself is left
        untouched, so callers can keep using the raw values.
    """
    # Build the result out of place instead of overwriting the caller's tensor.
    return torch.where(torch.abs(x) < epsilon, torch.zeros_like(x), x)


# Frozen random layer: sparse signed weights, constant bias, saturating-ramp activation
class CustomLayer(nn.Module):
    """Linear layer with random {-1, 0, +1} weights and a constant bias, never trained.

    Each weight is nonzero with probability ``probability_q``; a nonzero weight is +1 with
    probability ``p_positive`` and -1 otherwise. Every bias entry equals ``bias_value``.
    All parameters have ``requires_grad`` switched off after initialisation.
    """

    def __init__(self, input_dim, output_dim, probability_q, p_positive=0.5, bias_value=-1):
        super().__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.probability_q = probability_q

        # nn.Linear only provides the storage; its default init is overwritten right below.
        self.layer = nn.Linear(input_dim, output_dim)
        self.bias_value = bias_value
        self.p_positive = p_positive
        self.reset_parameters()
        # Freeze the weight and the bias.
        self.requires_grad_(False)

    def reset_parameters(self):
        """Resample the sparse signed weight matrix and refill the bias."""
        shape = (self.output_dim, self.input_dim)
        # 0/1 mask deciding which connections exist.
        support = torch.bernoulli(torch.full(shape, self.probability_q))
        # Map Bernoulli draws {0, 1} onto signs {-1, +1}.
        signs = 2 * torch.bernoulli(self.p_positive * torch.ones(shape)) - 1
        self.layer.weight.data = support * signs
        self.layer.bias.data.fill_(self.bias_value)

    def forward(self, x):
        """Apply the affine map, then ``custom_activation``."""
        pre_activation = self.layer(x)
        return custom_activation(pre_activation)


class SummationLayer(nn.Module):
    """Fixed random readout: a +/-1-weighted sum over the last dimension of the input.

    The sign vector is drawn once at construction and is not trainable. It is stored as a
    registered buffer so that it follows the module across ``.to(device)`` calls and is saved
    in ``state_dict()`` (a plain tensor attribute would be skipped by both).
    """

    def __init__(self, input_dim):
        super().__init__()
        self.input_dim = input_dim
        # Each coefficient is +1 or -1 with 50/50 probability.
        signs = torch.where(torch.rand(input_dim) > 0.5, torch.tensor(1.0), torch.tensor(-1.0))
        self.register_buffer("coefficients", signs)

    def forward(self, x):
        """Return the signed sum of ``x`` over its last dimension (kept as size 1).

        Works for a single sample of shape ``(input_dim,)`` as well as for inputs with any
        number of leading batch dimensions.
        """
        # The coefficients broadcast against the last dimension, so that is the one to reduce.
        return torch.sum(x * self.coefficients, dim=-1, keepdim=True)


# Stack of frozen CustomLayers, optionally followed by a SummationLayer
class CustomNetwork(nn.Module):
    """Feed-forward chain of ``CustomLayer`` modules that also reports every activation."""

    def __init__(
        self, layer_dims, probability_q, p_positive=0.5, bias_value=-1, final_summation=False
    ):
        super().__init__()
        # One CustomLayer per consecutive pair of entries in layer_dims.
        self.layers = nn.ModuleList(
            [
                CustomLayer(fan_in, fan_out, probability_q, p_positive, bias_value)
                for fan_in, fan_out in zip(layer_dims, layer_dims[1:])
            ]
        )

        # Optional signed-sum readout on top of the last layer.
        self.final_summation = final_summation
        if final_summation:
            self.summation_layer = SummationLayer(layer_dims[-1])

    def forward(self, x):
        """Run ``x`` through all layers.

        Returns:
            tuple: ``(output, activations)`` where ``activations`` is a list that starts with
            the input itself, then holds the output of every layer and, if enabled, the output
            of the summation layer as its last entry.
        """
        # Sanity check: no weight or bias may be NaN.
        for layer in self.layers:
            linear = layer.layer
            assert not torch.isnan(linear.weight).any(), f"weights of layer {layer} are nan"
            assert not torch.isnan(linear.bias).any(), f"bias of layer {layer} are nan"

        current = x
        activations = [current]
        for layer in self.layers:
            current = layer(current)
            activations.append(current)

        if self.final_summation:
            current = self.summation_layer(current)
            activations.append(current)

        return current, activations

In [None]:
def create_matrix_with_unit_norm_columns(d, m):
    """
    Sample a Gaussian matrix of shape (m, d) and rescale each of its rows to unit L2 norm.

    Note that, despite the function name, the normalisation runs along dim=1, so it is the
    m rows (each of length d) of the returned matrix that end up with norm 1.

    Args:
        d (int): Number of columns of the returned matrix.
        m (int): Number of rows of the returned matrix.

    Returns:
        torch.Tensor: Tensor of shape (m, d) whose rows each have unit norm.
    """
    gaussian = torch.randn(m, d)
    row_norms = torch.norm(gaussian, dim=1, keepdim=True)
    return gaussian / row_norms

How many inputs do we need to give the small net a reasonable chance to learn the ideal algorithm? Well, for a lower bound, the number of times each output gate is active with one of the inputs being active should be at least one in expectation, I guess. But well, maybe this is a bad question to ask. We should probably just keep training on random inputs and track the loss over time. It'll plausibly be clear from the loss curve if the algo is still getting learned or if we've peaked. Should maybe still hold out a small number of inputs as a test set though!

In [None]:
class IdealNetworkDataset(Dataset):
    def __init__(self, length, m, p, W_E, ideal_network: CustomNetwork):
        """
        Pre-sample sparse binary inputs for the ideal network together with their compressed
        versions.

        Args:
            length (int): Number of items in the dataset.
            m (int): Dimension of the input vectors for the big network (U_1).
            p (float): Probability of an entry being 1 in the input vector.
            W_E (torch.Tensor): Matrix the ideal inputs are right-multiplied by to obtain the
                compressed inputs (so its first dimension must be m).
            ideal_network (CustomNetwork): The ideal (big) network to run input through and
                get activations.
        """
        self.length = length
        self.m = m
        self.p = p
        self.W_E = W_E
        self.ideal_network = ideal_network

        # All inputs are drawn once, up front: one Bernoulli(p) row per item.
        on_probabilities = torch.full((length, self.m), self.p)
        self.ideal_inputs = torch.bernoulli(on_probabilities)
        self.compressed_inputs = self.ideal_inputs @ self.W_E

    def __len__(self):
        return self.length

    def __getitem__(self, idx):
        """Return ``(compressed_input, ideal_activations)`` for the item(s) at ``idx``.

        The ideal network is evaluated on the uncompressed input each time an item is fetched;
        the second element is the list of activations that the network returns.
        """
        _, ideal_activations = self.ideal_network(self.ideal_inputs[idx])
        return (self.compressed_inputs[idx], ideal_activations)

In [None]:
class SmallNetwork(nn.Module):
    def __init__(
        self,
        layer_dims,
        readoff_dims,
        nonlinearity=F.relu,
        final_summation=False,
        readoff_nonlinearity=denoising_nonlinearity,
    ):
        """
        Initializes the small network with specific layer dimensions and readoff layers for all
        activations, including a separate readoff for the summation layer's output.

        Args:
        layer_dims (list): Dimensions of the small network's layers, including the input layer.
            Typically [d]*L, where d is the dimension of the compressed input and L is the number of
            layers.
        readoff_dims (list): Target dimensions for the readoff layers, corresponding to each
            layer in the big network. Typically [m]*L, where m is the dimension of the input to
            the big network and L is the number of layers.
        nonlinearity (callable): Applied to the output of every linear layer.
        final_summation (bool): If True, a SummationLayer is applied to the last activation
            and gets its own 1 -> 1 readoff layer and threshold.
        readoff_nonlinearity (callable): Called as ``f(readoff, epsilon)`` on every readoff.
        """
        super(SmallNetwork, self).__init__()
        self.layers = nn.ModuleList()
        self.readoff_layers = nn.ModuleList()

        # One linear map per consecutive pair of dims; the nonlinearity is applied in forward()
        for i in range(1, len(layer_dims)):
            self.layers.append(nn.Linear(layer_dims[i - 1], layer_dims[i]))

        # Final summation layer
        self.final_summation = final_summation
        if final_summation:
            self.summation_layer = SummationLayer(layer_dims[-1])

        # Ensure we have a readoff layer for each layer in layer_dims plus one for the summation
        # layer
        for i, dim in enumerate(layer_dims):
            self.readoff_layers.append(nn.Linear(dim, readoff_dims[i], bias=False))
        if final_summation:
            self.readoff_layers.append(nn.Linear(1, 1, bias=False))

        self.nonlinearity = nonlinearity
        self.readoff_nonlinearity = readoff_nonlinearity
        # epsilons is a learnable list of thresholds for the denoising nonlinearity initialised at
        # 1/sqrt(dim) of the activation being read off. There must be exactly one threshold per
        # readoff layer, so the summation readoff (a 1-dimensional activation) gets one as well;
        # without it forward() would index past the end of the list when final_summation=True.
        epsilon_dims = list(layer_dims)
        if final_summation:
            epsilon_dims.append(1)
        self.epsilons = nn.ParameterList(
            [nn.Parameter(torch.tensor(1 / np.sqrt(dim))) for dim in epsilon_dims]
        )

    def forward(self, x):
        """
        Returns:
            tuple: ``(activations, readoff_activations)``. ``activations`` starts with the input
            and holds the output of every layer (plus the summation output if enabled);
            ``readoff_activations`` holds the thresholded readoff of each of those activations.
        """
        activations = [x]  # Store activations from the small network, including input
        for layer in self.layers:
            x = self.nonlinearity(layer(x))
            activations.append(x)

        # Apply the summation layer to the output of the last linear layer
        if self.final_summation:
            summation_output = self.summation_layer(activations[-1])
            activations.append(summation_output)

        # Map activations and summation output through their respective readoff layers, then
        # threshold each readoff with that layer's own epsilon
        readoff_activations = []
        for i, (readoff_layer, activation) in enumerate(zip(self.readoff_layers, activations)):
            readoff_activation = readoff_layer(activation)
            denoised_readoff = self.readoff_nonlinearity(readoff_activation, self.epsilons[i])
            readoff_activations.append(denoised_readoff)

        return activations, readoff_activations

In [None]:
def reweighted_MSE(predictions: torch.Tensor, targets: torch.Tensor, ones_weighting=1):
    """
    Compute a class-balanced squared error between predictions and binary targets.

    For each sample (each vector along the last dimension):
    - take the squared error on the positions whose target is 1 and divide it by the number of
      ones in the target, then multiply by ones_weighting;
    - take the squared error on the positions whose target is 0 and divide it by the number of
      zeros in the target;
    - add the two terms.
    A sample with no ones (or no zeros) contributes 0 for the corresponding term. The
    per-sample values are summed (not averaged) over all samples.

    Args:
        predictions (torch.Tensor): Predicted values, same shape as targets.
        targets (torch.Tensor): Tensor whose entries are all 0 or 1.
        ones_weighting (float): Multiplier applied to the error on the target-one positions.

    Returns:
        torch.Tensor: Scalar loss.
    """
    assert predictions.shape == targets.shape
    assert torch.all((targets == 0) | (targets == 1))
    assert not torch.isnan(predictions).any(), "Predictions contain NaNs"
    assert not torch.isnan(targets).any(), "Targets contain NaNs"

    ones_count = torch.sum(targets, dim=-1)
    zeros_count = torch.sum(1 - targets, dim=-1)

    ones_predictions = predictions * targets
    zeros_predictions = predictions * (1 - targets)
    ones_sq_error = torch.sum((ones_predictions - targets) ** 2, dim=-1)
    zeros_sq_error = torch.sum(zeros_predictions**2, dim=-1)

    # Counts are whole numbers, so clamping to >= 1 only changes the empty case, where the
    # numerator is exactly 0 anyway. Dividing by the raw count and zeroing the result
    # afterwards gives the same value but turns the 0/0 into NaN gradients during backward().
    ones_mse = ones_sq_error / ones_count.clamp(min=1)
    zeros_mse = zeros_sq_error / zeros_count.clamp(min=1)

    total_MSE = ones_weighting * ones_mse + zeros_mse
    return total_MSE.sum()

In [None]:
def train(small_network, dataloader, lr=1e-3, epsilon_loss_weight=0.0):
    """
    Make one pass over dataloader, fitting the small network's readoffs to the ideal activations.

    For every batch, the loss is the sum over layers of reweighted_MSE(readoff, ideal
    activation) plus epsilon_loss_weight times the 4-norm of that layer's epsilon. The loss is
    printed after every step.

    Args:
        small_network: Module whose forward returns (activations, readoff_activations) and
            which exposes an indexable ``epsilons`` attribute.
        dataloader: Iterable yielding (compressed_input, ideal_activations) batches, where
            ideal_activations is a list of tensors, one per layer.
        lr (float): Learning rate for Adam.
        epsilon_loss_weight (float): Weight of the penalty on the epsilons.

    Raises:
        ValueError: If the accumulated loss becomes NaN.
    """
    performance_criterion = reweighted_MSE
    optimizer = torch.optim.Adam(small_network.parameters(), lr=lr)

    for compressed_input, ideal_activations in tqdm(dataloader):
        # Zero the parameter gradients
        optimizer.zero_grad()

        # Forward pass
        ideal_activations = [activation.float() for activation in ideal_activations]
        compressed_input = compressed_input.float()
        assert not torch.isnan(compressed_input).any(), "Compressed input contains NaNs"
        assert not torch.isnan(ideal_activations[0]).any(), "Ideal activation contains NaNs"

        small_activations, readoff_activations = small_network(compressed_input)
        assert not torch.isnan(small_activations[0]).any(), "Small activation contains NaNs"
        # Test the condition itself: wrapping it in a one-element list is always truthy and
        # would make this assert unable to fail.
        assert all(
            not torch.isnan(readoff_activation).any()
            for readoff_activation in readoff_activations
        ), "Readoff activation contains NaNs"

        # Compute loss
        total_loss = 0
        for layer, ideal_activation in enumerate(ideal_activations):
            readoff_activation = readoff_activations[layer]
            assert not torch.isnan(
                readoff_activation
            ).any(), "Readoff activation contains NaNs at layer " + str(layer)
            assert not torch.isnan(ideal_activation).any(), "Ideal activation contains NaNs"
            performance_loss = performance_criterion(readoff_activation, ideal_activation)
            epsilons = small_network.epsilons[layer]
            epsilon_penalty = torch.norm(epsilons, p=4)
            total_loss += performance_loss + epsilon_penalty * epsilon_loss_weight
            # print everything and break out of the entire training loop if total_loss is nan
            if torch.isnan(total_loss):
                print("Performance loss: ", performance_loss)
                print("Epsilon penalty: ", epsilon_penalty)
                print("Total loss: ", total_loss)
                raise ValueError("Total loss is NaN")
        # Backward pass
        total_loss.backward()
        print("Loss: ", total_loss.item())

        # Optimize
        optimizer.step()

In [None]:
# Experiment configuration: build the ideal (big) network, the dataset derived from it, the
# small network that should imitate it, and the DataLoader used for training.
m = 1000  # dim of ideal sparse net
# prob each input is on; should also be the prob each gate later on is on
p = 1 / 100  # math.log(m) * m ** (-1)
q = 0.1  # (2 / (m * (m - 1) * p)) ** (0.1)  # prob each weight matrix entry is nonzero
print(p, q)
b = -1  # bias value used for every layer of the ideal network
n = 100  # dim into which we'll try to compress the ideal net
dataset_length = 10000  # Number of data points
batch_size = 50  # Batch size for training
L = 8  # num of layers, including input but not the 1-neuron output
layer_dims = [m] * L  # Dimension of each layer including input and output dimension
probability_q = q  # Probability of presence of each entry in weight matrix
# k = math.sqrt(2/p)+1 # fan-in

p_positive = 0.25  # Probability of a weight being positive
ideal_network = CustomNetwork(layer_dims, probability_q, p_positive, bias_value=b)

# Random map used by the dataset to turn m-dim ideal inputs into n-dim compressed inputs
W_E = create_matrix_with_unit_norm_columns(n, m)

# Initialize the dataset
ideal_network_dataset = IdealNetworkDataset(dataset_length, m, p, W_E, ideal_network)

layer_dims = [n] * L  # Layer dimensions for the small network, excluding the summation layer
readoff_dims = [m] * L  # Readoff dimension for each layer (no summation layer is used here)

small_network = SmallNetwork(layer_dims, readoff_dims)

# Create the DataLoader once, with the configured batch size
num_workers = 1  # Number of workers for the DataLoader
dataloader = DataLoader(
    ideal_network_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers
)

In [None]:
# Run one pass of training over the dataloader; the epsilon penalty is switched off (weight 0.0)
train(small_network, dataloader, lr=1e-3, epsilon_loss_weight=0.0)

In [None]:
# Histogram, over all samples in the dataset, of the summed activation of the ideal network's
# last layer.
acts = ideal_network_dataset[:][1]  # per-layer activations for the whole dataset at once
final_layer_totals = acts[-1].sum(dim=-1)
plt.hist(final_layer_totals)

In [None]:
# Show the first entry of the small network's epsilons (the threshold used for readoff index 0)
small_network.epsilons[0]