Feature Request: Add Weight Normalization Support (weight_norm) #1888

@Blaizzy

Description

MLX currently lacks built-in support for weight normalization, which is a crucial feature for various deep learning architectures, particularly in audio processing and generative models. Weight normalization is a reparameterization technique that decouples the magnitude and direction of weight vectors, often leading to better conditioning and faster convergence.
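
Concretely, the reparameterization (Salimans & Kingma, 2016) writes each weight vector as

$$w = g \cdot \frac{v}{\lVert v \rVert_2}$$

so that gradient descent optimizes the magnitude g and the direction v separately, rather than w directly.
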
Current Situation:

  • No built-in equivalent to PyTorch's torch.nn.utils.weight_norm (see the reference snippet below)
  • Users must implement custom solutions, which may be suboptimal or inconsistent across projects
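
For reference, PyTorch's API splits a module's weight into a magnitude parameter and a direction parameter; the snippet below only illustrates the shapes this produces for a Linear layer:

import torch
import torch.nn as nn

# weight_norm wraps a module, replacing `weight` with two trainable
# parameters: `weight_g` (magnitude) and `weight_v` (direction)
linear = nn.utils.weight_norm(nn.Linear(20, 40), name="weight", dim=0)
print(linear.weight_g.shape)  # torch.Size([40, 1])
print(linear.weight_v.shape)  # torch.Size([40, 20])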

Proposed Solution:
I've developed a reference implementation that could serve as a starting point:

import mlx.core as mx
from typing import List, Optional, Union

def compute_norm(x: mx.array, 
                p: int, 
                dim: Optional[Union[int, List[int]]] = None, 
                keepdim: bool = False) -> mx.array:
    """
    Compute the p-norm of a tensor along specified dimensions.
    
    Args:
        x: Input array
        p: Order of the norm (1 or 2)
        dim: Dimension(s) along which to compute the norm
        keepdim: Whether to keep the reduced dimensions
    
    Returns:
        MLX array containing the computed norm
    """
    if p not in [1, 2]:
        raise ValueError("Only p-norms with p of 1 or 2 are supported")
    
    # Handle dimension input
    if dim is None:
        dim = tuple(range(x.ndim))
    elif isinstance(dim, int):
        dim = (dim,)
    
    if p == 1:
        # L1 norm
        return mx.sum(mx.abs(x), axis=dim, keepdims=keepdim)
    else:
        # L2 norm
        return mx.sqrt(mx.sum(x * x, axis=dim, keepdims=keepdim))

def weight_norm(weight_v: mx.array, 
                weight_g: mx.array, 
                dim: Optional[int] = None) -> mx.array:
    """
    Applies weight normalization to the input tensor.
    
    Weight normalization reparameterizes weight vectors in a neural network 
    as a magnitude scalar times a direction vector: w = g * v/||v||
    
    Args:
        weight_v: Weight direction tensor (v)
        weight_g: Weight magnitude tensor (g)
        dim: Dimension to keep; the norm is computed over all other
            dimensions. If None or -1, the norm is taken over the entire
            tensor (matching PyTorch's norm_except_dim convention)
    
    Returns:
        Normalized weight tensor
    """
    rank = weight_v.ndim

    if dim is not None:
        # Map negative dims (other than -1) to their positive index
        if dim < -1:
            dim += rank

        # The norm is taken over every axis except `dim`; dim == -1 keeps
        # all axes, i.e. a whole-tensor norm (PyTorch's convention)
        axes = list(range(rank))
        if dim != -1:
            axes.remove(dim)
    else:
        # dim=None: norm over the entire tensor
        axes = list(range(rank))
    
    # Compute L2 norm of v along specified axes
    norm_v = compute_norm(weight_v, p=2, dim=axes, keepdim=True)
    
    # Normalize and scale by g: w = g * (v / ||v||)
    normalized_weight = weight_v / (norm_v + 1e-7)  # Add epsilon for numerical stability
    return normalized_weight * weight_g

# Example usage:
def test_weight_norm():
    # Create sample tensors
    v = mx.random.normal((64, 3, 3))  # Direction tensor
    g = mx.random.normal((64, 1, 1))  # Magnitude tensor
    
    # Apply weight normalization
    w = weight_norm(v, g, dim=0)
    
    # Verify shape
    assert w.shape == v.shape
    
    # Verify the norm along the kept dimension: each filter's L2 norm
    # should equal |g| up to the stabilizing epsilon
    norm_w = compute_norm(w, p=2, dim=[1, 2], keepdim=True)
    mx.eval(norm_w)  # Force the lazy computation
    return w, norm_w

if __name__ == "__main__":
    normalized_weight, norm_w = test_weight_norm()
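
A possible integration sketch (hypothetical; it assumes mlx.nn is available and that module parameters can be reassigned, as they can in current MLX): initialize g from the norm of the existing weight so that the recomposed weight initially matches the original, mirroring PyTorch's initialization.

import mlx.nn as nn

# Hypothetical sketch: reparameterize an nn.Linear weight as (g, v)
layer = nn.Linear(20, 40)                        # layer.weight: (40, 20)
v = layer.weight                                 # direction, initialized from w
g = compute_norm(v, p=2, dim=[1], keepdim=True)  # per-output magnitude: (40, 1)

# Recompose w = g * v / ||v||; with this initialization the result matches
# the original weight (up to the stabilizing epsilon). In a full integration,
# v and g would be the trainable parameters, recombined on each forward pass.
layer.weight = weight_norm(v, g, dim=0)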

Labels: enhancement (New feature or request)