In [11]:
import jax
from flax import linen as nn
import jax.numpy as jnp
import jax.random as random
import nnaugment
import numpy as np

# Biases start as tiny random values (stddev 1e-6) instead of Flax's default zeros.
# Presumably this is so a permutation of the biases is detectable by the checks later in
# the notebook; an all-zero bias vector looks identical under any permutation.
bias_init = nn.initializers.normal(stddev=1e-6)

class SimpleCNN(nn.Module):
    """Small CNN: two 3x3 convolutions, 2x2 average pooling, then two dense layers.

    Parameter collections are named Conv_0, Conv_1, Dense_2, Dense_3, using one running
    index across the whole network. The dense layers are named explicitly because Flax's
    per-type auto-naming would otherwise call them Dense_0 and Dense_1. These names are
    the ones referenced by ``layers_to_permute`` later in the notebook.
    """

    @nn.compact
    def __call__(self, x):
        """Apply the network.

        Args:
            x: image(s) laid out as ``(..., height, width, channels)``. The notebook
                feeds a single unbatched ``(28, 28, 1)`` image. Any leading batch
                dimensions are preserved.

        Returns:
            Array of shape ``(..., 10)``: ten outputs per image.
        """
        x = nn.Conv(features=32, kernel_size=(3, 3), bias_init=bias_init)(x)
        x = nn.gelu(x)
        x = nn.Conv(features=16, kernel_size=(3, 3), bias_init=bias_init)(x)
        # A 28x28 input pools to 14x14x16 = 3136 features, matching Dense_2's kernel.
        x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
        # Flatten only the trailing (height, width, channels) axes. For an unbatched input
        # this is exactly ``x.flatten()``. Unlike ``flatten()``, it does not fold a batch
        # axis into the feature vector.
        x = x.reshape((*x.shape[:-3], -1))
        x = nn.Dense(features=15, name="Dense_2", bias_init=bias_init)(x)
        x = nn.Dense(features=10, name="Dense_3", bias_init=bias_init)(x)
        return x

In [12]:
# Fixed model and input for the whole notebook: one 28x28 single-channel image with no
# batch axis, plus a deterministic PRNG key so weight initialisation is reproducible.
model = SimpleCNN()
IMAGE_HEIGHT, IMAGE_WIDTH, IMAGE_CHANNELS = 28, 28, 1
input_shape = (IMAGE_HEIGHT, IMAGE_WIDTH, IMAGE_CHANNELS)
rng = random.PRNGKey(0)

In [13]:
# Initialise the weights and record the reference output of the un-augmented network.
x = jnp.ones(input_shape)
variables = model.init(rng, x)
initial_output = model.apply(variables, x)


# Augment the weights with nnaugment.random_permutation on the listed layers.
# The output layer (Dense_3) is not listed. Presumably this is because permuting its
# outputs would reorder the values that the output-invariance assertion later in the
# notebook compares element-wise.
if isinstance(model, SimpleCNN):
    layers_to_permute = ["Conv_0", "Conv_1", "Dense_2"]
else:
    raise ValueError(f"Unknown model type: {type(model)}")

augmented_params = nnaugment.random_permutation(
    variables['params'],
    layers_to_permute=layers_to_permute,
    convention="flax")
# Replace only the 'params' collection and carry over any other collections from
# model.init. SimpleCNN has none, but a model with e.g. BatchNorm would, and
# model.apply needs all of them.
augmented_variables = {**variables, 'params': augmented_params}

In [14]:
# Parameter shapes per layer. `jax.tree_map` is deprecated (and removed in newer JAX);
# `jax.tree_util.tree_map` is the supported equivalent.
jax.tree_util.tree_map(lambda leaf: leaf.shape, variables)

{'params': {'Conv_0': {'bias': (32,), 'kernel': (3, 3, 1, 32)},
  'Conv_1': {'bias': (16,), 'kernel': (3, 3, 32, 16)},
  'Dense_2': {'bias': (15,), 'kernel': (3136, 15)},
  'Dense_3': {'bias': (10,), 'kernel': (15, 10)}}}

In [15]:
# Initial Conv_0 biases: small random values (stddev 1e-6), not zeros.
conv0_params = variables["params"]["Conv_0"]
conv0_params["bias"]

Array([-3.50728271e-07, -7.06043295e-07, -6.61901595e-07, -4.50394765e-07,
       -1.64961222e-07,  1.16701165e-06, -1.41028158e-06, -5.46690103e-07,
        6.42782652e-07, -2.02183230e-07, -1.48379002e-06, -1.33969615e-06,
       -5.20142976e-07, -8.90838692e-07,  9.51444690e-07, -7.46093690e-07,
        1.38884332e-06,  5.91921662e-07, -3.42096371e-07,  5.37360563e-07,
        6.38460165e-07, -4.31932023e-07, -4.78100560e-07,  7.13702150e-07,
        9.17730558e-10,  1.94411740e-07, -3.80567258e-07, -4.25100353e-07,
       -2.56481769e-07,  1.69610064e-06, -9.35697869e-07, -1.08783155e-07],      dtype=float32)

In [16]:

# Check that every permuted layer's parameters actually changed, then that the network's
# output did not. All problems are collected and reported together, so one failing layer
# neither hides the others nor skips the output-invariance check.
#
# NOTE(review): the recorded run fails with "Bias parameters ... layer Conv_1 are almost
# identical" even though Conv_1's kernel check passed. A changed kernel alone does not
# prove that the layer's own output channels were permuted. Conv_1's kernel input axis
# (size 32) would normally also be reordered to follow a permutation of Conv_0's outputs.
# This looks like nnaugment leaving the outputs of the conv that feeds flatten + Dense_2
# unpermuted. TODO: confirm in nnaugment before treating it as a notebook bug.
model_name = type(model).__name__
original_params = variables['params']
failures = []

for name in layers_to_permute:
    if name not in augmented_params:
        failures.append(
            f"Layer {name} of model {model_name} is missing from the augmented parameters.")
        continue
    for param_key, label in (("kernel", "Kernel"), ("bias", "Bias")):
        augmented = augmented_params[name][param_key]
        original = original_params[name][param_key]
        # Close at the tighter tolerance means "almost identical". Close only at the
        # looser one means the values merely stayed within about 20% of each other.
        if np.allclose(augmented, original, rtol=5e-2):
            failures.append(
                f"{label} parameters of model {model_name}, layer {name} are almost identical after augmentation.")
        elif np.allclose(augmented, original, rtol=0.2):
            failures.append(
                f"{label} parameters of model {model_name}, layer {name} are within +-20% size of each other after augmentation.")

# Check for unchanged output.
augmented_output = model.apply(augmented_variables, x)
if not jnp.allclose(initial_output, augmented_output, atol=1e-6):
    failures.append("Outputs differ after weight augmentation.")

assert not failures, "\n".join(failures)


AssertionError: Bias parameters of model SimpleCNN, layer Conv_1 are almost identical after augmentation.