<a href="https://colab.research.google.com/github/frank-morales2020/MLxDL/blob/main/MN_DEMO.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [2]:
import numpy as np
from scipy.linalg import svd, signm # Scipy for SVD and matrix sign (msign)

# --- 1. Helper Functions (Approximations of Complex Operations) ---

def matrix_sign(M):
    """
    Return the orthogonal polar factor U @ V^T of M (the "msign" map).

    With the thin SVD M = U diag(S) V^T, this discards S, i.e. it sets every
    singular value of M to one. For an m x n matrix with m >= n the result
    has orthonormal columns (Q^T Q = I_n); for m < n it has orthonormal rows.
    np.linalg.svd also accepts stacked (..., m, n) inputs, in which case the
    factor is computed independently for each matrix.

    Caveat: if M is rank-deficient, its zero singular values are also mapped
    to one (along directions chosen by the SVD routine, so the result is not
    unique), whereas the matrix sign function proper maps them to zero.

    :param M: Array of shape (m, n) (or (..., m, n)).
    :return: Array with the same shape as M.
    """
    # full_matrices=False gives the thin SVD: U is m x k and Vt is k x n with
    # k = min(m, n) (not rank(M)), so U @ Vt has exactly the shape of M.
    U, _, Vt = np.linalg.svd(M, full_matrices=False)
    return U @ Vt

def project_to_stiefel(W):
    """
    Retract W onto the Stiefel manifold: W <- msign(W) = U @ V^T.

    For a tall matrix (m >= n) with full column rank this is exactly the
    polar projection W (W^T W)^{-1/2}; it is the same map, not an
    alternative to it. It is also the nearest matrix with orthonormal
    columns in both the Frobenius and the spectral norm (indeed in any
    unitarily invariant norm). For any m >= n input the result satisfies
    W^T W = I_n.

    For a wide matrix (m < n) the result has orthonormal rows instead, so it
    does not satisfy W^T W = I_n.

    :param W: Array of shape (m, n).
    :return: Array of shape (m, n) with all singular values equal to one.
    """
    return matrix_sign(W)


# --- 2. Manifold Muon Core Logic ---

class ManifoldMuonOptimizer:
    """
    Manifold Muon optimizer for one (m x n) weight matrix kept on Stiefel(m, n).

    A call to ``step`` performs:
      1. ``dual_steps`` iterations of dual ascent on an n x n multiplier
         Lambda, starting from the Lambda left by the previous call
         (warm start);
      2. the update A = -eta * msign(G + 2 W (Lambda + Lambda^T));
      3. the additive step W + A;
      4. retraction back onto the manifold with msign.
    """

    def __init__(self, m, n, learning_rate_eta, dual_rate_alpha, dual_steps):
        """
        Store hyperparameters and start the dual variable at zero.

        :param m: Number of rows of W (not used; Lambda depends only on n).
        :param n: Number of columns of W; Lambda is n x n.
        :param learning_rate_eta: Step size eta scaling the msign update.
        :param dual_rate_alpha: Step size alpha for dual ascent on Lambda.
        :param dual_steps: Number of dual ascent iterations per ``step``.
        """
        self.eta, self.alpha, self.dual_steps = (
            learning_rate_eta,
            dual_rate_alpha,
            dual_steps,
        )
        # Dual variable (n x n), zero until the first step warms it up.
        self.Lambda = np.zeros((n, n))

    @staticmethod
    def _dual_argument(G, W, Lambda):
        """Return G + 2 W (Lambda + Lambda^T), the matrix passed to msign."""
        symmetric = Lambda + Lambda.T
        return G + 2 * W @ symmetric

    def H_gradient(self, G, W, Lambda):
        """
        Return the dual ascent direction for Lambda.

        With S = msign(G + 2 W (Lambda + Lambda^T)) this is
        -eta * (W^T S + S^T W), an n x n matrix that is symmetric up to
        rounding. It is proportional (by a positive constant that alpha
        absorbs) to the gradient, with respect to Lambda, of the dual
        function -eta * ||G + 2 W (Lambda + Lambda^T)||_nuclear.
        """
        S = matrix_sign(self._dual_argument(G, W, Lambda))
        WtS = W.T @ S
        return -self.eta * (WtS + S.T @ W)

    def step(self, W, G):
        """
        Apply one Manifold Muon update to W and return the new weights.

        :param W: Current weights (m x n), expected to lie on Stiefel(m, n).
        :param G: Gradient of the loss with respect to W (m x n).
        :return: The retracted weights msign(W + A). As a side effect,
            ``self.Lambda`` is replaced by the multiplier reached here so the
            next call warm-starts from it.
        """
        # 1) Dual ascent, working on a copy of the warm-start multiplier.
        multiplier = self.Lambda.copy()
        for _ in range(self.dual_steps):
            multiplier += self.alpha * self.H_gradient(G, W, multiplier)
        self.Lambda = multiplier

        # 2) Update: -eta * msign(G + 2 W (Lambda + Lambda^T)).
        direction = matrix_sign(self._dual_argument(G, W, multiplier))
        A = -self.eta * direction

        # 3) + 4) Take the step, then snap the singular values back to one.
        return project_to_stiefel(W + A)

# --- 3. Example Usage ---

# Seeded generator so the demo (and the printed error) is reproducible.
rng = np.random.default_rng(0)

# Define matrix dimensions (m >= n for Stiefel(m,n))
m, n = 128, 64

# Initialize a weight matrix W on the Stiefel manifold (W^T W = I_n, i.e. W
# has orthonormal columns) as the polar factor U @ V^T of a Gaussian matrix.
U_init, _, Vt_init = svd(rng.standard_normal((m, n)), full_matrices=False)
W = U_init @ Vt_init

# Mock gradient G standing in for dL/dW from a backpropagation step.
G = rng.standard_normal((m, n))

# Optimizer parameters (example values)
ETA = 1e-2      # Learning rate for the weight update
ALPHA = 1e-2    # Learning rate for the dual variable update
DUAL_STEPS = 5  # Number of inner dual ascent steps

optimizer = ManifoldMuonOptimizer(m, n, ETA, ALPHA, DUAL_STEPS)

# Perform one optimization step
W_next = optimizer.step(W, G)

# --- 4. Verification (Post-Step Sanity Check) ---

# After retraction W_next^T W_next should equal I_n up to rounding error.
I_n_check = W_next.T @ W_next
identity_matrix = np.eye(n)
max_error = np.max(np.abs(I_n_check - identity_matrix))

print(f"--- Manifold Muon Demo (m={m}, n={n}) ---")
print(f"Initial W shape: {W.shape}")
print(f"Next W shape: {W_next.shape}")
print(f"W^T W should be close to I_n. Max error: {max_error:.2e}")

# A tiny max error confirms W_next was retracted onto the Stiefel manifold.

--- Manifold Muon Demo (m=128, n=64) ---
Initial W shape: (128, 64)
Next W shape: (128, 64)
W^T W should be close to I_n. Max error: 1.33e-15


## LLM: Conceptual Modular-Norm Comparison of Two Architectures

In [3]:
# Conceptual Python Framework for Modular Manifold Comparison

import torch
import numpy as np
from collections import OrderedDict
# Assuming a hypothetical 'modula_lib' that provides tools for the modular norm
# In a real scenario, this would be a specialized library.
# import modula_lib

# ----------------------------------------------------------------------
# 1. CORE FUNCTION: The Modular Norm (Conceptual Implementation)
# ----------------------------------------------------------------------

def calculate_modular_norm(weights_dict, arch_config):
    """
    Conceptually compute per-layer modular norms for an LLM's weights.

    A full modular norm would recursively compose per-module norms (e.g. with
    max/sum rules) using scalar coefficients s_i that budget learning rates
    across layers. This sketch implements only the first stage: it computes
    the spectral norm (largest singular value) of every 2-D weight matrix and
    multiplies attention/MLP entries by a placeholder coefficient. That
    coefficient is currently 1.0, so the returned values equal the spectral
    norms.

    :param weights_dict: Mapping from parameter name to torch tensor. Only
        2-D tensors (matrices) are measured. Tensors of any other rank
        (biases, norm gains, stacked kernels) are skipped, since a single
        matrix spectral norm is not defined for them. float16/bfloat16
        matrices are upcast to float32 for the SVD.
    :param arch_config: Architectural details; currently unused.
    :return: Dict mapping each measured parameter name to a Python float, in
        the iteration order of ``weights_dict``. An empty (0 x k) matrix has
        norm 0.0.
    """
    # --- Step A: layer-wise spectral norms (the Manifold Muon norm) ---
    layer_norms = {}
    for name, W in weights_dict.items():
        if W.ndim != 2:
            # svdvals raises on 1-D input, and a batched result cannot be
            # reduced with .item(); such tensors are not single matrices.
            continue
        if W.numel() == 0:
            layer_norms[name] = 0.0
            continue
        W = W.detach()
        if W.dtype in (torch.float16, torch.bfloat16):
            # torch.linalg.svdvals only supports float/double (and complex).
            W = W.to(torch.float32)
        # svdvals returns singular values in descending order: [0] is sigma_max.
        layer_norms[name] = torch.linalg.svdvals(W)[0].item()

    # --- Step B: placeholder architectural budgeting ---
    # A real implementation would derive per-module coefficients s_i from
    # the architecture and compose them (e.g. max over modules in sequence).
    # Here attention/MLP weights get s_i_base and everything else gets 1.0.
    s_i_base = 1.0
    modular_norms = {}
    for name, norm in layer_norms.items():
        coeff = s_i_base if ("attn" in name or "mlp" in name) else 1.0
        modular_norms[name] = coeff * norm

    return modular_norms

# ----------------------------------------------------------------------
# 2. DEMO: Comparing LLM Architectures
# ----------------------------------------------------------------------

def llm_comparison_demo(seed=None):
    """
    Compare per-layer spectral norms of two mock transformer-layer weight sets.

    The "Llama-like" and "Mistral-like" weights are random Gaussian matrices
    with hand-picked scale factors standing in for real checkpoints. Their
    norms are printed, and the model with the smaller maximum norm is
    reported as *possibly* having a tighter Lipschitz constant. Exact ties
    favor neither model.

    :param seed: Optional int. When given, the mock weights are drawn from a
        private torch.Generator seeded with it, so the printed numbers are
        reproducible and the global torch RNG is left untouched. When None
        (the default), the global torch RNG is used.
    """
    # Private generator when seeded; None falls back to torch's global RNG.
    gen = None if seed is None else torch.Generator().manual_seed(seed)

    def _mock(rows, cols, scale):
        # One scaled Gaussian matrix standing in for a trained weight.
        return torch.randn(rows, cols, generator=gen) * scale

    # --- A. Mock architectures and weights ---
    # In a real study you would load the full state_dict from a Hugging Face model.

    # LLM A: conceptual Llama-like layer (W_q, W_k, W_up, W_down).
    llama_weights = OrderedDict({
        "attn.W_q": _mock(128, 128, 1.05),
        "attn.W_k": _mock(128, 128, 0.98),
        "mlp.W_up": _mock(128, 512, 1.02),  # FFN is often wider
        "mlp.W_down": _mock(512, 128, 0.95),
    })

    # LLM B: conceptual Mistral-like layer with different scalings.
    mistral_weights = OrderedDict({
        "attn.W_q": _mock(128, 128, 1.20),  # Higher unconstrained norm
        "attn.W_k": _mock(128, 128, 1.15),
        "mlp.W_up": _mock(128, 512, 0.85),  # Lower FFN norm
        "mlp.W_down": _mock(512, 128, 0.90),
    })

    # --- B. Per-layer norms (spectral norm as a proxy for the modular norm) ---
    print("--- Conceptual Modular Norm Analysis ---")

    llama_norms = calculate_modular_norm(llama_weights, "Llama_Config")
    mistral_norms = calculate_modular_norm(mistral_weights, "Mistral_Config")

    print("\n[LLM A (Llama-like) Modular Norms (Spectral Norm Proxy)]")
    for name, norm in llama_norms.items():
        print(f"  {name:<12}: {norm:.4f}")

    print("\n[LLM B (Mistral-like) Modular Norms (Spectral Norm Proxy)]")
    for name, norm in mistral_norms.items():
        print(f"  {name:<12}: {norm:.4f}")

    # --- C. Interpretation and comparison ---
    print("\n--- Interpretation based on Modular Manifolds Theory ---")

    llama_max_norm = max(llama_norms.values())
    mistral_max_norm = max(mistral_norms.values())

    print(f"LLM A Max Norm: {llama_max_norm:.4f}")
    print(f"LLM B Max Norm: {mistral_max_norm:.4f}")

    # Heuristic: the modular norm is tied to the network's Lipschitz constant,
    # so the smaller maximum per-layer norm hints at a tighter bound.
    if llama_max_norm < mistral_max_norm:
        print("LLM A *might* have a tighter Lipschitz constant in this norm, potentially leading to more stable, transferable learning rates if trained with a norm-constrained optimizer like Manifold Muon. [cite: 342, 325, 407]")
    elif mistral_max_norm < llama_max_norm:
        print("LLM B *might* have a tighter Lipschitz constant, suggesting more predictable behavior to weight perturbations in this architectural comparison. [cite: 342, 325, 407]")
    else:
        print("LLM A and LLM B have equal max norms, so this comparison does not favor either model.")

# Run the comparison demo. No seed is passed, so the mock weights (and the
# printed norms) come from torch's global RNG and vary between runs.
llm_comparison_demo()

--- Conceptual Modular Norm Analysis ---

[LLM A (Llama-like) Modular Norms (Spectral Norm Proxy)]
  attn.W_q    : 23.2828
  attn.W_k    : 22.0583
  mlp.W_up    : 34.3833
  mlp.W_down  : 31.4547

[LLM B (Mistral-like) Modular Norms (Spectral Norm Proxy)]
  attn.W_q    : 26.7807
  attn.W_k    : 25.6144
  mlp.W_up    : 28.0282
  mlp.W_down  : 31.2510

--- Interpretation based on Modular Manifolds Theory ---
LLM A Max Norm: 34.3833
LLM B Max Norm: 31.2510
LLM B *might* have a tighter Lipschitz constant, suggesting more predictable behavior to weight perturbations in this architectural comparison. [cite: 342, 325, 407]
