Before we explore SIREN^2, let us first examine the new weight perturbation scheme and what effect it plays on the spectral properties of the intermediate network quantities.

# Explore 'Gaussian weight perturbation' scheme



In [1]:
import numpy as np
import matplotlib.pyplot as plt
import scipy.stats as stats
import numpy.fft as fft
from matplotlib.ticker import ScalarFormatter
import matplotlib.ticker as ticker
from scipy import fft, signal

Let us consider the following settings for an MLP network:
* One input ($x$) and one output ($\phi$)
* Input is uniformly distributed as $x \in \mathcal{U}[-1,1]$ with a total of $2^8$ samples
* Four hidden layers $l=1,2,3,4$ with periodic non-linearity $sin(\omega_0 x_{in})$ with $\omega_0 = 30$
* Each hidden layer has 2048 features.

In [2]:
# Parameters
mu = 0
variance = 1
sigma = np.sqrt(variance)
fan_in = 2048
n = 65536//4
n_total = n * fan_in
noise_scale = 1.0   # for Gaussian weight perturbation
omega_0 = 30.0

color = 'red'

We begin by defining the weights and biases, and then introduce the perturbation noise to the weights.

Consider the following weight initilization scheme of SIREN as per Sitzmann et al.:
* The weights connecting the inputs and the first hidden layer are sampled from a uniform distribution as $W_{10} \in \mathcal{U}[-1,1]$
* The weights connecting the intermediate hidden layers are sampled from a uniform distribution as $W_{21} \in \mathcal{U}\left(-\frac{1}{\omega_0}\sqrt{6/d},\frac{1}{\omega_0}\sqrt{6/d}\right)$ (same for $W_{32}, W_{43}, W_{54}$), where $d = \text{fan\_in}$ (in this case = 2048).
* All the bias vectors are extracted from a uniform distribution as $\mathbf{b} \in \mathcal{U}(-\sqrt{6/d}, \sqrt{6/d})$

While this weight initialization scheme is widely used and effective across many applications, it suffers from the spectral bottleneck phenomenon, where fitting signals with weak low-frequency and strong high-frequency representation leads to a collapse in the spectral energy of the Neural Tangent Kernel, resulting in reconstruction failure. We propose a simple weight perturbation strategy that mitigates this issue by broadening the PSD of pre- and post-activations, significantly improving performance on a wide range of audio fitting tasks.

The new weight perturbation scheme add the following noise to the network weights between the first hidden layer and the output layer (excluding the weight between the inputs and first hidden layer).
For example, the weights $W_{21}$ is transformed as follows:
$$W_{21} \rightarrow W_{21} + \eta_{21}$$

with noise extracted from a Gaussian distribution as,
$$\eta_{21} \in \mathcal{N}\left( 0, \left(\frac{s}{\omega_0}\right)^2 \right)$$

In [6]:
## define weight and bias matrices

# inputs to layer-1 {01}
W_10 = np.random.uniform(-1,1,size=(fan_in, 1))                                                             # size: [fan_in, 1]
bias_1 = np.random.uniform(-np.sqrt(1),np.sqrt(1),size=(fan_in, 1))                                         # size: [fan_in, 1]

# layer-1 to layer-2 {21}
W_21 = np.random.uniform(-np.sqrt(6/fan_in)/omega_0, np.sqrt(6/fan_in)/omega_0, size=(fan_in, fan_in))      # size: [fan_in, fan_in]
bias_2 = np.random.uniform(-np.sqrt(1/fan_in),np.sqrt(1/fan_in),size=(fan_in, 1))                           # size: [fan_in, 1]

# layer-2 to layer-3 {32}
W_32 = np.random.uniform(-np.sqrt(6/fan_in)/omega_0, np.sqrt(6/fan_in)/omega_0, size=(fan_in, fan_in))      # size: [fan_in, fan_in]
bias_3 = np.random.uniform(-np.sqrt(1/fan_in),np.sqrt(1/fan_in),size=(fan_in, 1))                           # size: [fan_in, 1]

# layer-3 to layer-4 {43}
W_43 = np.random.uniform(-np.sqrt(6/fan_in)/omega_0, np.sqrt(6/fan_in)/omega_0, size=(fan_in, fan_in))      # size: [fan_in, fan_in]
bias_4 = np.random.uniform(-np.sqrt(1/fan_in),np.sqrt(1/fan_in),size=(fan_in, 1))                           # size: [fan_in, 1]

# layer-4 to output {54}
W_54 = np.random.uniform(-np.sqrt(6/fan_in)/omega_0, np.sqrt(6/fan_in)/omega_0, size=(1, fan_in))           # size: [1, fan_in]
bias_5 = np.random.uniform(-np.sqrt(1/fan_in),np.sqrt(1/fan_in),size=(1, 1))                                # size: [1, 1]

# add noise to {21}
# noise = np.random.randn(*W_21.shape) * noise_scale/omega_0
# W_21 = W_21 + noise

Now compute the pre- and post-activation quantities in each layer

In [9]:
# uniform inputs
x = np.linspace(-1, 1, n).reshape(-1,1)                                                                 # size: [n, 1]

# layer-1
x_pre1  = omega_0*(np.dot(x,W_10) + bias_1)                                                             # size: [n, fan_in]
x_post1 = np.sin(x_pre1)                                                                                # size: [n, fan_in]

# layer-2
x_pre2 = omega_0*(np.dot(x_post1,W_21) + bias_2)                                                        # size: [n, fan_in]
x_post2 = np.sin(x_pre2)                                                                                # size: [n, fan_in]

# layer-3
x_pre3 = omega_0*(np.dot(x_post2,W_32) + bias_3)                                                        # size: [n, fan_in]
x_post3 = np.sin(x_pre3)                                                                                # size: [n, fan_in]

# layer-4
x_pre4 = omega_0*(np.dot(x_post3,W_32) + bias_3)                                                        # size: [n, fan_in]
x_post4 = np.sin(x_pre4)                                                                                # size: [n, fan_in]

# layer-
x_pre4 = omega_0*(np.dot(x_post3,W_32) + bias_3)                                                        # size: [n, fan_in]
x_post4 = np.sin(x_pre4)                                                                                # size: [n, fan_in]

Just an auxillary function to compute PSD of our time-series (or space-series)

In [10]:
def get_spectrum(activations, max_freq_index=1000):
    n = activations.shape[0] # Number of samples (rows)
    max_freq_index=n//2
    signal_summed = activations.reshape(-1,1).astype(np.double).sum(axis=0)
    # fs = n / 2.0
    fs = 2 * np.pi * activations.shape[0]
    frequencies, psd = signal.periodogram(signal_summed, fs=fs, scaling='density')
    limit_index = min(max_freq_index, len(frequencies))

    return frequencies[:limit_index], psd[:limit_index]

Let’s plot the distributions
(Might look intemidating, but it’s simpler than it looks)

In [11]:
## Plotting
# --- GLOBAL FONT SETTINGS using rcParams ---
size_f = 1
plt.rcParams['font.size'] = size_f         # Default font size for text
plt.rcParams['axes.labelsize'] = size_f    # Font size for x and y labels
plt.rcParams['xtick.labelsize'] = size_f   # Font size for x-axis tick labels
plt.rcParams['ytick.labelsize'] = size_f   # Font size for y-axis tick labels
plt.rcParams['legend.fontsize'] = size_f   # Font size for legends
plt.rcParams['axes.titlesize'] = size_f   # Font size for subplot titles (if you add them)
plt.rcParams['font.family'] = 'monospace' # Global monospace font (if desired)

def scientific_formatter(x, pos):
    if x == 0:
        return "0"
    s = f"{x:.0e}"  # Format with 0 decimal places for the mantissa (e.g., 1e+03)
    
    # Split into mantissa and exponent
    parts = s.split('e')
    mantissa = parts[0]
    exponent = parts[1] # e.g., '+03' or '-12'

    # Remove leading zero from single-digit exponents (e.g., '+03' -> '+3')
    if exponent.startswith('+0') and len(exponent) == 3: # Matches '+0X'
        exponent = '+' + exponent[2]
    elif exponent.startswith('-0') and len(exponent) == 3: # Matches '-0X'
        exponent = '-' + exponent[2]

    return f"{mantissa}e{exponent}"

custom_formatter = ticker.FuncFormatter(scientific_formatter)

fig, ax = plt.subplots(7, 2, figsize=(4.5, 7))

# Apply styling to all subplots
for row in ax:
    for sub_ax in row:
        # Apply the custom formatter to both x and y axes
        sub_ax.xaxis.set_major_formatter(custom_formatter)
        sub_ax.yaxis.set_major_formatter(custom_formatter)
        sub_ax.tick_params(direction='in', which='both', labelsize=8) # Set labelsize here for convenience

        # Set monospace font for tick labels
        for tick in sub_ax.get_xticklabels():
            tick.set_fontfamily('monospace')
        for tick in sub_ax.get_yticklabels():
            tick.set_fontfamily('monospace')

        sub_ax.xaxis.get_offset_text().set_visible(False)
        sub_ax.yaxis.get_offset_text().set_visible(False)


# inputs
ax[0][0].hist(x.reshape(-1,1), bins=256, density=True, alpha=0.5, color=color)
freq, spectrum = get_spectrum(x)
ax[0][1].plot(freq, spectrum, alpha=0.5, color=color)
ax[0][1].set_yscale('log')

# layer-1 (pre)
ax[1][0].hist(x_pre1.reshape(-1,1), bins=256, density=True, alpha=0.5, color=color)
freq, spectrum = get_spectrum(x_pre1)
ax[1][1].plot(freq, spectrum, alpha=0.5, color=color)
ax[1][1].set_yscale('log')

# layer-1 (post)
ax[2][0].hist(x_post1.reshape(-1,1), bins=256, density=True, alpha=0.5, color=color)
x_ref = np.linspace(-1, 1, 500)
ax[2][0].plot(x_ref, stats.arcsine.pdf(x_ref, -1, 2), linestyle=':', linewidth=3.0, color='black', markersize=0.4, zorder=2)
freq, spectrum = get_spectrum(x_post1)
ax[2][1].plot(freq, spectrum, alpha=0.5, color=color)
ax[2][1].set_yscale('log')

# layer-2 (pre)
ax[3][0].hist(x_pre2.reshape(-1,1), bins=256, density=True, alpha=0.5, color=color)
mu, sigma = 0, np.sqrt(1 + fan_in * noise_scale**2 / 2.0)
x_ref = np.linspace(mu - 3*sigma, mu + 3*sigma, 500)
ax[3][0].plot(x_ref, stats.norm.pdf(x_ref, mu, sigma), linestyle=':', linewidth=3.0, color='black', markersize=0.4, zorder=2)
ax[3][0].set_xlim(mu - 5*sigma, mu + 5*sigma)
freq, spectrum = get_spectrum(x_pre2)
ax[3][1].plot(freq, spectrum, alpha=0.5, color=color)
ax[3][1].set_yscale('log')

# layer-2 (post)
ax[4][0].hist(x_post2.reshape(-1,1), bins=256, density=True, alpha=0.5, color=color)
x_ref = np.linspace(-1, 1, 500)
ax[4][0].plot(x_ref, stats.arcsine.pdf(x_ref, -1, 2), linestyle=':', linewidth=3.0, color='black', markersize=0.4, zorder=2)
freq, spectrum = get_spectrum(x_post2)
ax[4][1].plot(freq, spectrum, alpha=0.5, color=color)
ax[4][1].set_yscale('log')

# layer-3 (pre)
ax[5][0].hist(x_pre3.reshape(-1,1), bins=256, density=True, alpha=0.5, color=color)
mu, sigma = 0, 1
x_ref = np.linspace(mu - 3*sigma, mu + 3*sigma, 500)
ax[5][0].plot(x_ref, stats.norm.pdf(x_ref, mu, sigma), linestyle=':', linewidth=3.0, color='black', markersize=0.4, zorder=2)
ax[5][0].set_xlim(mu - 5*sigma, mu + 5*sigma)
freq, spectrum = get_spectrum(x_pre3)
ax[5][1].plot(freq, spectrum, alpha=0.5, color=color)
ax[5][1].set_yscale('log')

# layer-3 (post)
ax[6][0].hist(x_post3.reshape(-1,1), bins=256, density=True, alpha=0.5, color=color)
x_ref = np.linspace(-1, 1, 500)
ax[6][0].plot(x_ref, stats.arcsine.pdf(x_ref, -1, 2), linestyle=':', linewidth=3.0, color='black', markersize=0.4, zorder=2)
freq, spectrum = get_spectrum(x_post3)
ax[6][1].plot(freq, spectrum, alpha=0.5, color=color)
ax[6][1].set_yscale('log')


ax[0][1].set_ylim(1e-8,1e+3)
ax[1][1].set_ylim(1e-8,1e+2)
ax[2][1].set_ylim(1e-8,1e-1)
ax[3][1].set_ylim(1e-8,1e+3)
ax[4][1].set_ylim(1e-8,1e-1)
ax[5][1].set_ylim(1e-8,1e+3)
ax[6][1].set_ylim(1e-8,1e-1)

pi_multiples = np.array([1e-4, 1e-2, 1]) # Example: pi/2, pi, 2*pi
pi_labels = [r'$1e-4$', r'$1e-2$', r'$\pi$'] # Corresponding LaTeX labels

for i in range(7):
    ax[i][1].set_xscale('log')
    ax[i][1].set_xticks(pi_multiples * np.pi)
    ax[i][1].set_xticklabels(pi_labels)

# --- Ensure minimal whitespace and tight plots ---
plt.tight_layout(pad=0.1, h_pad=0.1, w_pad=0.1)

# Save with bbox_inches='tight' and pad_inches=0 for absolute minimal border
# plt.savefig(f'{network}.pdf', bbox_inches='tight', pad_inches=0)
plt.show()


  ax[1][1].set_yscale('log')
  ax[2][1].set_yscale('log')
  ax[3][1].set_yscale('log')
  ax[4][1].set_yscale('log')
  ax[5][1].set_yscale('log')
  ax[6][1].set_yscale('log')
  ax[i][1].set_xscale('log')


ValueError: Data has no positive values, and therefore can not be log-scaled.

Error in callback <function _draw_all_if_interactive at 0x11c74ff40> (for post_execute), with arguments args (),kwargs {}:


ValueError: Data has no positive values, and therefore can not be log-scaled.

ValueError: Data has no positive values, and therefore can not be log-scaled.

<Figure size 450x700 with 14 Axes>

Now based on this weight perturbation scheme, let us define a SIREN architecture:
* SIREN: Uses default weight initilization scheme of Sitzmann et al. (NeurIPS 2019)
* SIREN_square: Adds noise to the default weight initilization scheme of Sitzmann et al. through the present Gaussian weight perturbation scheme

In [None]:
import torch
import torch.nn as nn
import numpy as np
import math
import sys

@torch.jit.script
def sine_block(x: torch.Tensor, w0: float, a0: float):
    return a0 * torch.sin(w0 * x)


class SineLayer(nn.Module):
    def __init__(self, w0, A=1.0):
        super().__init__()
        self.a0 = A
        self.w0 = w0

    def forward(self,x):
        return sine_block(x, self.w0, self.a0)
    

def add_noise(tensor, scale, freq_decay=0.9, with_filter=False):
    """
    Add random/filtered noise to the input tensor.
    """
    noise = torch.randn_like(tensor) * scale    # extract noise from normal distribution [mean=0, std=scale]

    if with_filter:
        for i in range(noise.shape[0]):
            noise[i] *= freq_decay ** i     # in case one wishes to trigger high-freq modes only

    return tensor + noise


In [None]:
class SIREN(nn.Module):
    def __init__(self, in_dim, HL_dim, out_dim, w0=30, first_w0=3000, n_HLs=4):
        super().__init__()
        self.net = []
        
        self.net.append(nn.Linear(in_dim, HL_dim))
        self.net.append(SineLayer(first_w0))
        for _ in range(n_HLs-1):
            self.net.append(nn.Linear(HL_dim, HL_dim))
            self.net.append(SineLayer(w0))
        self.net.append(nn.Linear(HL_dim, out_dim))
        
        self.net = nn.Sequential(*self.net)

        # init weights
        with torch.no_grad():
            self.net[0].weight.uniform_(-1.0/in_dim, 1.0/in_dim)

            for i in range(n_HLs):
                self.net[(i+1)*2].weight.uniform_(-np.sqrt(6.0/HL_dim)/w0, np.sqrt(6.0/HL_dim)/w0)

    def forward(self, x):
        return self.net(x)

In [None]:
class SIREN_square(nn.Module):
    """
    official implementation of SIREN^2

    Hyperparameters:
        S0 = noise scale for first hidden layer
        a  = exponential decay factor of noise

        use S0=1.00/omaga_0, a=100  --> for audio
        use S0=0.25/omega_0, a=100  --> for images
    """
    def __init__(self, S0=0.04, a=100, w0=30, in_dim=1, HL_dim=256, out_dim=1, first_w0=3000, n_HLs=4, with_filter=False):
        super().__init__()
        eps = sys.float_info.epsilon
        self.net = []
        
        self.net.append(nn.Linear(in_dim, HL_dim))
        self.net.append(SineLayer(first_w0))
        for _ in range(n_HLs-1):
            self.net.append(nn.Linear(HL_dim, HL_dim))
            self.net.append(SineLayer(w0))
        self.net.append(nn.Linear(HL_dim, out_dim))
        
        self.net = nn.Sequential(*self.net)

        # SIREN^2 weight initilization 
        with torch.no_grad():
            self.net[0].weight.uniform_(-1.0/in_dim, 1.0/in_dim)

            for k in range(n_HLs):
                i = 2 * (k + 1)
                scale = S0/((a + eps)**(i/2-1) + eps)
                base = torch.empty_like(self.net[i].weight).uniform_(-np.sqrt(6.0/HL_dim)/w0, np.sqrt(6.0/HL_dim)/w0)   # base
                self.net[i].weight.copy_(add_noise(base, scale=scale, with_filter=with_filter))                         # base + noise

    def forward(self,x):
        return self.net(x)

Let us train a SIREN and SIREN_square to fit a high-frequency audio file and compare their performance