# Zhang 2025: Joint Optimization for Matrix Factorization (Full Demo)

**Paper**: "Loss-Minimizing Model Compression via Joint Factorization Optimization" (Zhang et al., 2025)

This notebook illustrates the idea on synthetic toy data. It compares a sequential SVD baseline with a simple iterative joint optimization.

## Problem Setting

Given:
- Matrix $W \in \mathbb{R}^{n \times m}$ (e.g., data or model weights)
- Downstream task with objective function $\mathcal{L}(W, \text{targets})$

Goal: Find low-rank approximation $W \approx L \cdot R^T$ with rank $k \ll \min(n,m)$
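The "compression" in the paper's title comes from storing $L$ and $R$ instead of $W$. The full matrix needs $nm$ numbers, while the factors need $k(n+m)$. Below is a quick count for the sizes used later in this notebook. The helper `factorized_param_count` is a hypothetical name introduced here for illustration.

```python
def factorized_param_count(n, m, k):
    """Return (dense, factorized) parameter counts for an n x m matrix at rank k."""
    dense = n * m             # entries of W
    factorized = k * (n + m)  # entries of L (n x k) plus R (m x k)
    return dense, factorized

dense, factorized = factorized_param_count(100, 50, 5)
print(dense, factorized, factorized / dense)  # 5000 750 0.15
```

Factorizing only saves storage when $k < nm/(n+m)$. For these sizes that threshold is about 33.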

## Two Approaches

### Approach 1: Traditional (Sequential)
1. Stage 1: Minimize $\|W - L \cdot R^T\|^2$ using SVD
2. Stage 2: Use factorized $W' = L \cdot R^T$ for downstream task
3. Hope that good reconstruction → good task performance

### Approach 2: Zhang 2025 (Joint)
1. Simultaneously optimize:
   - Factorization accuracy: $\|W - L \cdot R^T\|^2 \leq \epsilon$
   - Task performance: minimize $\mathcal{L}(L \cdot R^T, \text{targets})$
2. Key insight: among factorizations with acceptable reconstruction error, choose the one whose error reduces the objective loss
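Written out, the joint problem is a constrained minimization. A common way to handle the constraint in practice is to move it into the objective with a weight $\lambda > 0$. This penalty form is our assumption and not necessarily the exact formulation used in the paper:

$$\min_{L,R}\; \mathcal{L}(L \cdot R^T, \text{targets}) \quad \text{s.t.} \quad \|W - L \cdot R^T\|^2 \leq \epsilon \qquad \longrightarrow \qquad \min_{L,R}\; \|W - L \cdot R^T\|^2 + \lambda\, \mathcal{L}(L \cdot R^T, \text{targets})$$

As $\lambda \to 0$ the minimizer approaches the truncated-SVD solution of Approach 1. Larger $\lambda$ trades reconstruction error for task loss.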

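**Claim**: Joint optimization achieves better task performance than the sequential approach. Truncated SVD already minimizes the reconstruction error, so this gain comes with a reconstruction error at least as large as SVD's.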

## Setup: Import Libraries

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from scipy.linalg import svd

np.random.seed(42)

## Step 1: Generate Toy Data

Create a matrix with low-rank structure plus Gaussian noise, as a simple stand-in for real data.

In [None]:
def generate_toy_data(n_samples=100, n_features=50, rank=5, noise=0.1):
    """Generate toy data with low-rank structure + noise"""
    # True low-rank factors
    U_true = np.random.randn(n_samples, rank)
    V_true = np.random.randn(n_features, rank)
    
    # True low-rank matrix
    W_true = U_true @ V_true.T
    
    # Add noise
    W_noisy = W_true + noise * np.random.randn(n_samples, n_features)
    
    return W_noisy, W_true, U_true, V_true

# Generate data
n_samples, n_features = 100, 50
true_rank = 5
W_noisy, W_true, _, _ = generate_toy_data(n_samples, n_features, true_rank, noise=0.5)

# Generate target values for objective function
# (In SAXS: this could be "physical plausibility scores")
targets = W_true.sum(axis=1) + 0.1 * np.random.randn(n_samples)

print(f"Data shape: {W_noisy.shape}")
print(f"True rank: {true_rank}")

## Step 2: Define Objective Function and Gradient

For demonstration, we use a simple objective: predict continuous targets from the row sums of the matrix, scored by mean squared error.

In a SAXS context, this could represent "physical plausibility" measures.
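This objective has a closed-form gradient. Write $t_i$ for the target of row $i$ and $e_i = \sum_j W_{ij} - t_i$ for its error, over $n$ rows. Then:

$$\mathcal{L}(W) = \frac{1}{n}\sum_{i=1}^{n} e_i^2, \qquad \frac{\partial \mathcal{L}}{\partial W_{ij}} = \frac{2}{n}\, e_i$$

Every entry in row $i$ receives the same gradient. The factor $1/n$ comes from averaging over the $n$ rows.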

In [None]:
def objective_function(W, targets):
    """
    Simple objective: predict continuous targets from the row sums of W
    Loss = mean squared error between row sums and targets
    """
    predictions = W.sum(axis=1)
    loss = np.mean((predictions - targets) ** 2)
    return loss


def compute_gradient(W, targets):
    """
    Gradient of objective_function with respect to W
    Tells us: "How does each element of W affect the loss?"
    """
    predictions = W.sum(axis=1)
    errors = predictions - targets
    
    # Gradient: ∂Loss/∂W[i,j] = 2 * error[i] / n_samples
    # (the mean runs over rows, so divide by the number of rows)
    gradient = np.repeat(2 * errors[:, np.newaxis] / W.shape[0], W.shape[1], axis=1)
    
    return gradient

# Test the functions
print("Original data objective loss:", objective_function(W_noisy, targets))

## Step 3: Demonstrate Gradient Direction Insight

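Zhang 2025's key insight, to first order: $\Delta\mathcal{L} \approx \left\langle \frac{\partial \mathcal{L}}{\partial W}, \delta \right\rangle = \sum_{i,j} \frac{\partial \mathcal{L}}{\partial W_{ij}}\, \delta_{ij}$

Here $\delta = L \cdot R^T - W$ is the factorization error. If $\delta$ points against the gradient, the factorization lowers the loss instead of raising it.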
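Because the row-sum objective is quadratic, the first-order relation can be checked exactly. Let $d_i = \sum_j \delta_{ij}$. Then $\Delta\mathcal{L} = \langle \nabla_W\mathcal{L}, \delta\rangle + \frac{1}{n}\sum_i d_i^2$. The second term is never negative. That is why a random perturbation, whose inner product with the gradient is near zero on average, still tends to raise the loss slightly. Below is a self-contained check on the same 2×3 example used in the next cell.

```python
import numpy as np

W = np.array([[1.0, 2.0, 3.0],
              [4.0, 5.0, 6.0]])
t = np.array([10.0, 25.0])
n = W.shape[0]

def loss(M):
    return np.mean((M.sum(axis=1) - t) ** 2)

errors = W.sum(axis=1) - t                                      # e_i
grad = np.repeat(2 * errors[:, np.newaxis] / n, W.shape[1], axis=1)  # dL/dW

rng = np.random.default_rng(2)
delta = 0.5 * rng.standard_normal(W.shape)          # arbitrary perturbation

first_order = np.sum(grad * delta)                  # <grad, delta>
second_order = np.sum(delta.sum(axis=1) ** 2) / n   # (1/n) * sum_i d_i^2
exact = loss(W + delta) - loss(W)

print(f"first-order: {first_order:.4f}, second-order: {second_order:.4f}")
print(f"exact change: {exact:.4f}, sum: {first_order + second_order:.4f}")
```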

In [None]:
# Small 2×3 example
W = np.array([[1.0, 2.0, 3.0],
               [4.0, 5.0, 6.0]])
targets_demo = np.array([10.0, 25.0])

gradient = compute_gradient(W, targets_demo)

print("=" * 70)
print("KEY INSIGHT: Gradient Direction Effect")
print("=" * 70)
print()
print(f"Original matrix W:")
print(W)
print()
print(f"Gradient ∂Loss/∂W:")
print(gradient)
print()

# Case 1: Random factorization noise (no gradient info)
delta_random = 0.5 * np.random.randn(*W.shape)
inner_product_random = np.sum(gradient * delta_random)

W_random = W + delta_random
loss_original = objective_function(W, targets_demo)
loss_random = objective_function(W_random, targets_demo)

print(f"Case 1: Random noise")
print(f"  Inner product (gradient · δ):  {inner_product_random:.4f}")
print(f"  Loss change:                    {loss_random - loss_original:.4f}")
print()

# Case 2: Gradient-guided noise (Zhang 2025 approach)
delta_guided = -0.5 * gradient  # Opposite direction to gradient
inner_product_guided = np.sum(gradient * delta_guided)

W_guided = W + delta_guided
loss_guided = objective_function(W_guided, targets_demo)

print(f"Case 2: Gradient-guided noise (opposite to gradient)")
print(f"  Inner product (gradient · δ):  {inner_product_guided:.4f}")
print(f"  Loss change:                    {loss_guided - loss_original:.4f}")
print()

print("✓ Gradient-guided factorization REDUCES loss!")

## Step 4: Traditional SVD Factorization

Stage 1: Minimize $\|W - L \cdot R^T\|^2$ using SVD  
Stage 2: Use the factorized matrix for the objective and hope it still performs well

In [None]:
def traditional_svd_factorization(W, rank):
    """
    Traditional approach: Minimize ||W - L·R^T||² using SVD
    """
    U, s, Vt = svd(W, full_matrices=False)
    
    # Keep only top 'rank' components
    L = U[:, :rank] @ np.diag(np.sqrt(s[:rank]))
    R = Vt[:rank, :].T @ np.diag(np.sqrt(s[:rank]))
    
    W_approx = L @ R.T
    
    reconstruction_error = np.linalg.norm(W - W_approx, 'fro')
    
    return L, R, W_approx, reconstruction_error

# Apply traditional SVD
print("APPROACH 1: Traditional SVD Factorization")
print("-" * 70)
L_svd, R_svd, W_svd, error_svd = traditional_svd_factorization(W_noisy, true_rank)
loss_svd = objective_function(W_svd, targets)

print(f"Factorization error: {error_svd:.4f}")
print(f"Objective loss:      {loss_svd:.4f}")

## Step 5: Zhang 2025 Joint Optimization

Simultaneously optimize:
1. Factorization accuracy: $\|W - L \cdot R^T\|^2 \leq \epsilon$
2. Objective function: minimize $\mathcal{L}(L \cdot R^T, \text{targets})$

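Key: Update $L$ and $R$ to reduce BOTH factorization error and objective loss. In this demo the constraint is handled softly rather than by enforcing a fixed $\epsilon$. Starting from the SVD factors, iterative updates of $L$ and $R$ trade reconstruction error against objective loss.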
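Assume the penalized form $J(L,R) = \|W - L \cdot R^T\|^2 + \lambda\,\mathcal{L}(L \cdot R^T, \text{targets})$. This penalty is a modeling choice for the demo, not necessarily the paper's. The chain rule then gives the gradients for $L$ and $R$ in terms of $G$, the gradient of $J$ with respect to the product $L \cdot R^T$:

$$G = -2\,(W - L \cdot R^T) + \lambda\,\nabla_W\mathcal{L}(L \cdot R^T, \text{targets}), \qquad \frac{\partial J}{\partial L} = G \cdot R, \qquad \frac{\partial J}{\partial R} = G^T \cdot L$$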

In [None]:
def zhang_joint_optimization(W, targets, rank, n_iterations=200, lr=1e-3, lam=1.0):
    """
    Joint optimization sketch (penalized form, not necessarily the paper's exact algorithm):
    gradient descent on
        J(L, R) = ||W - L R^T||_F^2 + lam * objective_function(L R^T, targets)
    starting from the truncated-SVD factors.
    """
    # Initialize with SVD (the sequential solution is a good starting point)
    L, R, _, _ = traditional_svd_factorization(W, rank)
    
    losses = []
    factorization_errors = []
    
    for _ in range(n_iterations):
        W_current = L @ R.T
        
        # Track objective loss and factorization error at the current iterate
        losses.append(objective_function(W_current, targets))
        factorization_errors.append(np.linalg.norm(W - W_current, 'fro'))
        
        # Gradient of J with respect to the product L R^T,
        # re-evaluated at the current iterate (not at the original W)
        G = -2 * (W - W_current) + lam * compute_gradient(W_current, targets)
        
        # Chain rule: dJ/dL = G R and dJ/dR = G^T L; update both simultaneously
        grad_L = G @ R
        grad_R = G.T @ L
        L = L - lr * grad_L
        R = R - lr * grad_R
    
    W_final = L @ R.T
    final_error = np.linalg.norm(W - W_final, 'fro')
    
    return L, R, W_final, final_error, losses, factorization_errors

# Apply Zhang 2025 joint optimization
# lr is kept small: the curvature of ||W - L R^T||^2 in (L, R) grows with the singular values of W
print("APPROACH 2: Zhang 2025 Joint Optimization")
print("-" * 70)
L_joint, R_joint, W_joint, error_joint, losses, errors = zhang_joint_optimization(
    W_noisy, targets, true_rank, n_iterations=200, lr=1e-3, lam=1.0
)
loss_joint = objective_function(W_joint, targets)

print(f"Factorization error: {error_joint:.4f}")
print(f"Objective loss:      {loss_joint:.4f}")

## Step 6: Comparison

In [None]:
print("=" * 70)
print("COMPARISON")
print("=" * 70)
print(f"Original matrix loss:        {objective_function(W_noisy, targets):.4f}")
print(f"SVD factorization loss:      {loss_svd:.4f}")
print(f"Joint optimization loss:     {loss_joint:.4f}")
print()
print(f"Loss improvement (SVD):      {objective_function(W_noisy, targets) - loss_svd:.4f}")
print(f"Loss improvement (Joint):    {objective_function(W_noisy, targets) - loss_joint:.4f}")
print()

improvement_ratio = (loss_svd - loss_joint) / loss_svd * 100
print(f"Joint is {improvement_ratio:.1f}% better than SVD")

## Step 7: Visualization

In [None]:
fig, axes = plt.subplots(2, 2, figsize=(12, 10))

# Plot 1: Original data
ax = axes[0, 0]
ax.imshow(W_noisy, aspect='auto', cmap='viridis')
ax.set_title(f'Original Noisy Matrix\nObjective Loss: {objective_function(W_noisy, targets):.4f}')
ax.set_xlabel('Features')
ax.set_ylabel('Samples')

# Plot 2: SVD reconstruction
ax = axes[0, 1]
ax.imshow(W_svd, aspect='auto', cmap='viridis')
ax.set_title(f'SVD Reconstruction (rank={true_rank})\nObjective Loss: {loss_svd:.4f}')
ax.set_xlabel('Features')
ax.set_ylabel('Samples')

# Plot 3: Joint optimization result
ax = axes[1, 0]
ax.imshow(W_joint, aspect='auto', cmap='viridis')
ax.set_title(f'Joint Optimization (rank={true_rank})\nObjective Loss: {loss_joint:.4f}')
ax.set_xlabel('Features')
ax.set_ylabel('Samples')

# Plot 4: Optimization trajectory
ax = axes[1, 1]
ax2 = ax.twinx()

line1 = ax.plot(losses, 'b-', label='Objective Loss', linewidth=2)
line2 = ax2.plot(errors, 'r--', label='Factorization Error', linewidth=2)

ax.set_xlabel('Iteration')
ax.set_ylabel('Objective Loss', color='b')
ax2.set_ylabel('Factorization Error', color='r')
ax.tick_params(axis='y', labelcolor='b')
ax2.tick_params(axis='y', labelcolor='r')
ax.set_title('Joint Optimization Trajectory')
ax.grid(True, alpha=0.3)

lines = line1 + line2
labels = [l.get_label() for l in lines]
ax.legend(lines, labels, loc='upper right')

plt.tight_layout()
plt.show()

## Application to SAXS Deconvolution

### Current SAXS Methods (Two-Stage Approach)

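In SAXS (for example, a SEC-SAXS series), we decompose $D = C \cdot S^T + \text{noise}$. Here the rows of $D$ index frames (e.g., elution time points) and the columns index $q$. The columns of $C$ are component concentration profiles, and the columns of $S$ are component scattering profiles $I(q)$.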

**Stage 1 (Factorization)**: 
- EFA, MCR-ALS, REGALS minimize $\|D - C \cdot S^T\|^2$

**Stage 2 (Constraints)**: 
- Apply physical constraints: non-negativity, smoothness, compact support

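**Problem**: When reconstruction and physical plausibility are handled in separate stages, the factorization is not chosen with the physical objectives in mind. Zhang 2025 argues that this kind of sequential pipeline is suboptimal for model compression. Whether the same holds for SAXS deconvolution is a hypothesis worth testing.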

### Zhang 2025 Paradigm for SAXS

**Joint optimization**:
$$\min_{C,S} \|D - C \cdot S^T\|^2 + \lambda_1 \cdot \text{Smoothness}(C) + \lambda_2 \cdot \text{Smoothness}(S)$$

Subject to: $C \geq 0, S \geq 0$
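The objective above leaves "Smoothness" unspecified. Below is a minimal sketch of one simple way to attack the problem. It assumes squared second differences along each profile: along frames for $C$, and along $q$ for $S$. Non-negativity is handled by alternating projected gradient descent, using step size $1/(\text{block Lipschitz constant})$ so that each block update cannot increase the objective. This is neither an established SAXS method nor taken from Zhang 2025. The function names and the toy Gaussian and Guinier-like profiles are hypothetical.

```python
import numpy as np

def second_diff(n):
    # (n-2) x n second-difference operator: (D x)_i = x_i - 2 x_{i+1} + x_{i+2}
    return np.diff(np.eye(n), n=2, axis=0)

def joint_objective(D, C, S, lam1, lam2, D2c, D2s):
    fit = np.linalg.norm(D - C @ S.T, "fro") ** 2
    return fit + lam1 * np.linalg.norm(D2c @ C) ** 2 + lam2 * np.linalg.norm(D2s @ S) ** 2

def projected_gradient(D, C, S, lam1=1.0, lam2=1.0, n_iter=200):
    D2c, D2s = second_diff(C.shape[0]), second_diff(S.shape[0])
    d2c_sq, d2s_sq = np.linalg.norm(D2c, 2) ** 2, np.linalg.norm(D2s, 2) ** 2
    history = [joint_objective(D, C, S, lam1, lam2, D2c, D2s)]
    for _ in range(n_iter):
        # C-step: gradient in C with S fixed, step 1 / (block Lipschitz constant)
        grad_C = -2 * (D - C @ S.T) @ S + 2 * lam1 * D2c.T @ D2c @ C
        lip_C = 2 * np.linalg.norm(S, 2) ** 2 + 2 * lam1 * d2c_sq
        C = np.maximum(C - grad_C / lip_C, 0.0)   # project onto C >= 0
        # S-step: same pattern with C fixed
        grad_S = -2 * (D - C @ S.T).T @ C + 2 * lam2 * D2s.T @ D2s @ S
        lip_S = 2 * np.linalg.norm(C, 2) ** 2 + 2 * lam2 * d2s_sq
        S = np.maximum(S - grad_S / lip_S, 0.0)   # project onto S >= 0
        history.append(joint_objective(D, C, S, lam1, lam2, D2c, D2s))
    return C, S, history

# Synthetic two-component example: 60 frames, 80 q-points
rng = np.random.default_rng(5)
frames, q = np.arange(60), np.linspace(0.01, 0.3, 80)
C_true = np.stack([np.exp(-0.5 * ((frames - c) / 6) ** 2) for c in (25, 35)], axis=1)
S_true = np.stack([np.exp(-(q * rg) ** 2 / 3) for rg in (20, 40)], axis=1)
D = C_true @ S_true.T + 0.01 * rng.standard_normal((60, 80))

C0, S0 = rng.random((60, 2)), rng.random((80, 2))
C_est, S_est, history = projected_gradient(D, C0, S0, lam1=1.0, lam2=1.0)
print(f"objective: {history[0]:.3f} -> {history[-1]:.3f}")
```

The $1/\text{Lipschitz}$ step gives up speed in exchange for a guaranteed non-increasing objective. A line search or ALS-style block solves would usually converge faster.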

**Potential benefits**:
1. **Choice of K**: Lemma 3 of Zhang 2025 may provide a framework for choosing the number of components
2. **Reduced rotation ambiguity**: Physical objectives guide the factorization directly, which can narrow the set of equally good solutions
3. **Discouraging rank inflation**: If rank is part of the loss-compression trade-off, extra components must pay for themselves in the objective
4. **Better physical plausibility**: Constraints are built into the factorization rather than applied afterward

### Potential Applications

1. **Gradient-guided SVD** for initial decomposition
2. **Joint optimization** of reconstruction + physical objectives
3. **Optimal rank determination** using loss-compression trade-off (Lemma 3)
4. **Direct incorporation** of P(r) smoothness, Kratky analysis, etc.

This could improve SAXS deconvolution beyond current methods, though that still has to be demonstrated on real data.