# Neural Network Scalability Analysis

Testing scalability properties across four neural network architectures:

1. **ReLU-based Neural Networks**
    - Neural networks with nonlinearities parametrized by ReLU functions

2. **B-Spline Neural Networks** 
    - Neural networks with nonlinearities parametrized by B-Splines

3. **Quasi-Ground Truth Neural Networks**
    - Neural networks with true nonlinearities and learnable parameters 

4. **Ground Truth Networks**
    - Neural networks with true nonlinearities and true parameters

The comparison evaluates how performance and computational cost scale across these four architectures.
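As a rough illustration of how variants 3 and 4 differ, the sketch below uses a hypothetical nonlinearity `a * sin(x + b)`. The function, the names `true_nonlinearity` and `TRUE_PARAMS`, and the parameter values are assumptions made for this example only; they are not taken from the notebook.

```python
import jax.numpy as jnp

# Hypothetical "true" parameters (amplitude, phase); assumed for illustration.
TRUE_PARAMS = jnp.array([1.0, 0.0])


def true_nonlinearity(x, p):
    # Assumed functional form: p[0] * sin(x + p[1]).
    return p[0] * jnp.sin(x + p[1])


def ground_truth_activation(x):
    # Variant 4: true nonlinearity with the true parameters baked in.
    return true_nonlinearity(x, TRUE_PARAMS)


def quasi_ground_truth_activation(x, learnable_params):
    # Variant 3: same functional form, but the parameters are supplied
    # from outside so an optimizer can update them.
    return true_nonlinearity(x, learnable_params)


x = jnp.array([0.0, jnp.pi / 2])
init_params = jnp.ones(2)  # arbitrary starting point for training

print(ground_truth_activation(x))
print(quasi_ground_truth_activation(x, init_params))
```

The two call signatures mirror the modules defined later: a fixed activation takes only `x`, while a learnable one takes `x` and a parameter vector.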

In [2]:
from typing import List, Callable
import flax.linen as nn
import jax.numpy as jnp
from jax import jit, random, config

config.update("jax_enable_x64", True)

The `LearnableActivations` class is a custom Flax module that applies several learnable activation functions to its input. The key idea is to make each activation function adaptable by giving it its own trainable parameters.

#### Key Features:
1. **Input Splitting**: The input features are split along the last axis into equal parts, one per activation function, so each activation operates on a distinct subset of the input. `input_features` must therefore be divisible by the number of activations.

2. **Learnable Parameters**: Each activation function has a fixed number of learnable parameters (`n_params`, default 2). All parameters are initialized to ones and stored in a single array of shape `(num_activations, n_params)`.

3. **Iterative Activation**: The module applies each activation function to its respective input split, concatenating the results to produce the final output.

Because the activation parameters are registered as ordinary Flax parameters, they are trained together with the rest of the model's weights, so the shape of each nonlinearity can adapt to the data.
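The three features above can be sketched in plain JAX, without the Flax module machinery. The helper `apply_split_activations` and the two-parameter activation `scaled_shift` are hypothetical names introduced only for this illustration.

```python
import jax.numpy as jnp


def scaled_shift(x, p):
    # Hypothetical two-parameter activation: p[0] scales, p[1] shifts.
    return p[0] * jnp.tanh(x + p[1])


def apply_split_activations(x, activations, params):
    # 1. Split the last axis into one equal chunk per activation.
    splits = jnp.split(x, len(activations), axis=-1)
    # 2. Each activation gets its own row of the parameter array.
    outs = [f(s, p) for f, s, p in zip(activations, splits, params)]
    # 3. Concatenate the chunks back along the feature axis.
    return jnp.concatenate(outs, axis=-1)


x = jnp.arange(12.0).reshape(2, 6)  # batch of 2, 6 features
params = jnp.ones((3, 2))           # 3 activations, 2 parameters each
y = apply_split_activations(x, [scaled_shift] * 3, params)
print(y.shape)  # same shape as the input: (2, 6)
```

Because the split is along the last axis, the output keeps the input's shape; only the values within each chunk change.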

In [3]:
class LearnableActivations(nn.Module):
    input_features: int
    activations: List[Callable]
    n_params: int = 2

    def setup(self):
        self.num_activations = len(self.activations)
        if self.input_features % self.num_activations != 0:
            raise ValueError("Input features must be divisible by num_activations")

        # All params stored in a single array: shape (num_activations, n_params)
        self.params = self.param(
            "activation_params",
            nn.initializers.ones,
            (self.num_activations, self.n_params),
        )

    def __call__(self, x):
        # Split the last axis into one equal chunk per activation function
        splits = jnp.split(x, self.num_activations, axis=-1)

        # Apply each activation to its chunk with its own parameter row
        activated = [
            act(split, self.params[i])
            for i, (act, split) in enumerate(zip(self.activations, splits))
        ]

        # Concatenate the chunks back along the feature axis
        return jnp.concatenate(activated, axis=-1)

In [4]:
class FixedActivations(nn.Module):
    input_features: int
    activations: List[Callable]

    def setup(self):
        self.num_activations = len(self.activations)
        if self.input_features % self.num_activations != 0:
            raise ValueError("Input features must be divisible by num_activations")

    def __call__(self, x):
        # Split the last axis into one equal chunk per activation function
        splits = jnp.split(x, self.num_activations, axis=-1)

        # Apply each parameter-free activation to its chunk
        activated = [act(split) for act, split in zip(self.activations, splits)]

        # Concatenate the chunks back along the feature axis
        return jnp.concatenate(activated, axis=-1)

In [5]:
class SimpleModel(nn.Module):
    input_dim: int  # Neurons per activation group
    n_funcs: int  # Number of activation functions
    output_dim: int
    activations: List[Callable]
    max_num_params: int = 2

    def setup(self):
        self.custom_activation = LearnableActivations(
            self.input_dim * self.n_funcs,
            self.activations,
            self.max_num_params,
        )
        self.output_layer = nn.Dense(self.output_dim)

    def __call__(self, x):
        x = self.custom_activation(x)
        x = self.output_layer(x)
        return x

In [6]:
class FixedModel(nn.Module):
    input_dim: int  # Neurons per activation group
    n_funcs: int  # Number of activation functions
    output_dim: int
    activations: List[Callable]

    def setup(self):
        self.custom_activation = FixedActivations(
            self.input_dim * self.n_funcs,
            self.activations,
        )
        self.output_layer = nn.Dense(self.output_dim)

    def __call__(self, x):
        x = self.custom_activation(x)
        x = self.output_layer(x)
        return x

In [7]:
@jit
def func(x, params):
    return (
        params[0] * nn.relu(x + params[1])
        + params[2] * nn.relu(x + params[3])
        + params[4] * nn.relu(x + params[5])
    )


# @jit
# def func1(x, params):
#     return x * params[0]


@jit
def func2(x):
    # Fixed, parameter-free nonlinearity: 0.5 - x^2
    return 0.5 - x * x


# activations = [func1, func2, func]
activations = [func2] * 3
# activations = [func, func, func]

L = 3  # number of activation functions (activation groups)
input_dim = 3  # features per activation group
N = input_dim
output_dim = 3
batch_size = 5

model = FixedModel(N, L, output_dim=output_dim, activations=activations)
key = random.key(0)
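# One group of input_dim features per sample: [0, pi/2, pi/2]
x1 = jnp.zeros((batch_size, 1))
x2 = jnp.ones((batch_size, 1)) * jnp.pi / 2
x3 = jnp.ones((batch_size, 1)) * jnp.pi / 2
x = jnp.concatenate([x1, x2, x3], axis=1)

# Repeat the group once per activation function: shape (batch_size, input_dim * L)
x = jnp.tile(x, L)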

params = model.init(key, x)

output = model.apply(params, x)
# print("Model parameters:")
# print(params["params"]["output_layer"]["kernel"])
# print(jnp.sum(params["params"]["output_layer"]["kernel"], axis=0))
# print("Input shape:", x.shape)
# print("Hidden layer output shape:", model.hidden_layer(x).shape)
# print("After activation shape:", model.activation(model.hidden_layer(x)).shape)
# print("Final output shape:", output.shape)
print("Input:", x)
print(output.dtype)

Input: [[0.         1.57079633 1.57079633 0.         1.57079633 1.57079633
  0.         1.57079633 1.57079633]
 [0.         1.57079633 1.57079633 0.         1.57079633 1.57079633
  0.         1.57079633 1.57079633]
 [0.         1.57079633 1.57079633 0.         1.57079633 1.57079633
  0.         1.57079633 1.57079633]
 [0.         1.57079633 1.57079633 0.         1.57079633 1.57079633
  0.         1.57079633 1.57079633]
 [0.         1.57079633 1.57079633 0.         1.57079633 1.57079633
  0.         1.57079633 1.57079633]]
float64
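The run above only exercises the fixed activation `func2`. As a rough sketch of how the six-parameter ReLU parametrization `func` could be fitted, the example below takes one gradient step toward `0.5 - x^2`. The choice of target, the loss, the learning rate, and the names `relu_sum` and `loss` are assumptions for illustration; the notebook does not specify a training setup at this point.

```python
import jax
import jax.numpy as jnp


def relu_sum(x, params):
    # Same functional form as `func`: three scaled, shifted ReLUs.
    return (
        params[0] * jax.nn.relu(x + params[1])
        + params[2] * jax.nn.relu(x + params[3])
        + params[4] * jax.nn.relu(x + params[5])
    )


def loss(params, x, target):
    # Mean squared error between the ReLU parametrization and the target.
    return jnp.mean((relu_sum(x, params) - target) ** 2)


x = jnp.linspace(-1.0, 1.0, 5)
target = 0.5 - x * x   # assumed target, same form as `func2`
params = jnp.ones(6)   # mirrors the all-ones initializer used above

# Gradient with respect to the six activation parameters.
grads = jax.grad(loss)(params, x, target)

# One plain gradient-descent step with an assumed learning rate.
new_params = params - 0.01 * grads

print(grads.shape)
print(loss(params, x, target), loss(new_params, x, target))
```

Note that `func` reads six parameters, so using it inside `LearnableActivations` or `SimpleModel` would require `n_params=6` (respectively `max_num_params=6`) rather than the default of 2.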
