Labels: enhancement (New feature or request)
Description
MLX currently lacks built-in support for weight normalization, which is a crucial feature for various deep learning architectures, particularly in audio processing and generative models. Weight normalization is a reparameterization technique that decouples the magnitude and direction of weight vectors, often leading to better conditioning and faster convergence.
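Concretely, the reparameterization can be written as

$$w = g \cdot \frac{v}{\lVert v \rVert}$$

so the magnitude $g$ and the direction $v / \lVert v \rVert$ are learned, and receive gradient updates, independently.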
Current Situation:
- No built-in equivalent to PyTorch's torch.nn.utils.weight_norm (see the usage snippet after this list)
- Users need to implement custom solutions, which may not be optimal or consistent
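For reference, a minimal example of the PyTorch API this issue asks to mirror:

```python
import torch.nn as nn
from torch.nn.utils import weight_norm

# Wrapping a layer replaces its `weight` parameter with `weight_g`
# (magnitude) and `weight_v` (direction); the effective weight
# w = g * v / ||v|| is recomputed on every forward pass.
layer = weight_norm(nn.Linear(20, 40), name="weight", dim=0)
print(layer.weight_g.shape)  # torch.Size([40, 1])
print(layer.weight_v.shape)  # torch.Size([40, 20])
```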
Proposed Solution:
I've developed a reference implementation that could serve as a starting point:
```python
import mlx.core as mx
from typing import List, Optional, Union


def compute_norm(x: mx.array,
                 p: int,
                 dim: Optional[Union[int, List[int]]] = None,
                 keepdim: bool = False) -> mx.array:
    """
    Compute the p-norm of a tensor along specified dimensions.

    Args:
        x: Input array
        p: Order of the norm (1 or 2)
        dim: Dimension(s) along which to compute the norm
        keepdim: Whether to keep the reduced dimensions

    Returns:
        MLX array containing the computed norm
    """
    if p not in (1, 2):
        raise ValueError("Only p-norms with p of 1 or 2 are supported")

    # Normalize the dimension argument to a tuple of axes
    if dim is None:
        dim = tuple(range(x.ndim))
    elif isinstance(dim, int):
        dim = (dim,)
    else:
        dim = tuple(dim)

    if p == 1:
        # L1 norm: sum of absolute values
        return mx.sum(mx.abs(x), axis=dim, keepdims=keepdim)
    # L2 norm: square root of the sum of squares
    return mx.sqrt(mx.sum(x * x, axis=dim, keepdims=keepdim))

def weight_norm(weight_v: mx.array,
                weight_g: mx.array,
                dim: Optional[int] = None) -> mx.array:
    """
    Apply weight normalization to the input tensor.

    Weight normalization reparameterizes weight vectors in a neural network
    as a magnitude times a direction vector: w = g * v / ||v||.

    Args:
        weight_v: Weight direction tensor (v)
        weight_g: Weight magnitude tensor (g)
        dim: Dimension to keep; the norm is computed over all other
            dimensions. If None or -1, a single norm is computed over
            the entire tensor.

    Returns:
        Normalized weight tensor
    """
    rank = len(weight_v.shape)
    if dim is not None:
        # Convert a negative dim (other than -1) to its positive index
        if dim < -1:
            dim += rank
        # Normalize over every axis except `dim`
        axes = list(range(rank))
        if dim != -1:
            axes.remove(dim)
    else:
        # Default behavior: normalize over all dimensions
        axes = list(range(rank))

    # Compute the L2 norm of v along the selected axes
    norm_v = compute_norm(weight_v, p=2, dim=axes, keepdim=True)

    # Normalize and scale by g: w = g * (v / ||v||)
    # The epsilon guards against division by zero
    normalized_weight = weight_v / (norm_v + 1e-7)
    return normalized_weight * weight_g

# Example usage:
def test_weight_norm():
    # Create sample tensors
    v = mx.random.normal((64, 3, 3))  # Direction tensor
    g = mx.random.normal((64, 1, 1))  # Magnitude tensor

    # Apply weight normalization along the output-channel dimension
    w = weight_norm(v, g, dim=0)

    # Verify shape
    assert w.shape == v.shape

    # Verify the norm along the remaining dimensions
    norm_w = compute_norm(w, p=2, dim=[1, 2], keepdim=True)
    mx.eval(norm_w)  # Force computation
    return w, norm_w


if __name__ == "__main__":
    normalized_weight, norm_values = test_weight_norm()
```