In [1]:
import torch
from scipy.stats.qmc import Halton

def halton_perturbations(dim, Q, n_samples=64, batch_size=1):
    """
    Generate Halton sequence based perturbations.

    Parameters:
    dim (int): Dimension of the input data.
    Q (float): The range for perturbations.
    n_samples (int): Number of perturbed samples to generate.
    batch_size (int): Size of each batch.
    
    Returns:
    np.ndarray: Array of perturbations.
    """
    sampler = Halton(dim)
    # Generate Halton points in [0, 1], then scale to [-Q, Q]
    x = (sampler.random(n_samples * batch_size) * 2 - 1) * Q
    return x.reshape(batch_size, n_samples, dim)

def model(x):
    # x.shape: B, N, D
    fc = torch.nn.Linear(x.shape[-1], 5)
    x = fc(x)
    return x

def loss(y, y_pert):
    return torch.mean((y - y_pert)**2)

In [2]:
def test():
    batch_size = 1
    Q = 0.5  # Example Q value
    n_features = 3
    H = 64
    x = torch.randn(batch_size, n_features)  # Example input batch, batch size = 1

    delta_x = torch.from_numpy(halton_perturbations(n_features, Q, H)).float()
    x_pert = x.unsqueeze(1) + delta_x

    y = model(x)
    y_pert = model(x_pert)
    SSM_loss = loss(y, y_pert)
    print(SSM_loss)

test()

tensor(0.9066, grad_fn=<MeanBackward0>)
