# Zhang 2025: Key Insight (Simple Pedagogical Example)

**Paper**: "Loss-Minimizing Model Compression via Joint Factorization Optimization" (Zhang et al., 2025)

This notebook illustrates the paper's central idea with a small 3×3 matrix example.

## Core Question

When we factorize a matrix $W \approx L \cdot R^T$, we introduce noise $\delta = L \cdot R^T - W$, the change from $W$ to its approximation.

**Traditional approach (SVD, etc.)**: Minimize $\|W - L \cdot R^T\|^2$ without considering any downstream objective.

**Zhang 2025 insight**: The factorization noise affects your objective function. To first order, if the noise points **opposite** to the gradient, your loss **decreases**!

$$\Delta \text{Loss} \approx \left\langle \frac{\partial \text{Loss}}{\partial W}, \delta \right\rangle = \sum_{i,j} \frac{\partial \text{Loss}}{\partial W_{ij}}\, \delta_{ij}$$

If this inner product is negative and $\delta$ is small enough for the first-order approximation to hold, then **loss decreases** ✓
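The relation above is a first-order (Taylor) approximation, so "opposite to the gradient" guarantees a decrease only for small enough noise. The minimal sketch below uses the row-sum loss from Step 1 (same matrix and targets) and moves $W$ along the negative *true* gradient $2 \cdot \text{error}_i$ with a step size `eta`. The helper names `row_sum_loss`, `row_sum_grad`, and `eta` are introduced here for illustration. Small steps behave as predicted; a large step overshoots and the loss goes up even though the noise still opposes the gradient.

```python
import numpy as np

W = np.array([[1.0, 2.0, 3.0],
              [4.0, 5.0, 6.0],
              [7.0, 8.0, 9.0]])
targets = np.array([10.0, 20.0, 30.0])

def row_sum_loss(M):
    """Sum of squared differences between row sums and targets."""
    return float(np.sum((M.sum(axis=1) - targets) ** 2))

def row_sum_grad(M):
    """True gradient: dLoss/dM_ij = 2 * error_i, repeated across row i."""
    errors = M.sum(axis=1) - targets
    return np.outer(2 * errors, np.ones(M.shape[1]))

g = row_sum_grad(W)
base = row_sum_loss(W)

for eta in [0.01, 0.05, 0.5]:
    delta = -eta * g                      # noise pointing opposite to the gradient
    predicted = float(np.sum(g * delta))  # first-order estimate <grad, delta>
    actual = row_sum_loss(W + delta) - base
    print(f"eta={eta:<5} predicted={predicted:+9.3f} actual={actual:+9.3f}")
```

For this matrix the exact change is $-924\,\eta + 2772\,\eta^2$, so moving along the negative gradient lowers the loss only for $0 < \eta < 1/3$.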

## Setup: Import Libraries

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from scipy.linalg import svd

np.random.seed(42)

## Step 1: Create Original Matrix and Define Objective

We'll use a simple 3×3 matrix and a toy objective: we want the row sums of $W$ to match target values.

**Objective**: Make row sums equal to $[10, 20, 30]$

**Loss function**: $\text{Loss} = \sum_i \text{error}_i^2$, where $\text{error}_i = \sum_j W_{ij} - \text{target}_i$

In [None]:
# Create a simple 3x3 matrix
W_original = np.array([[1.0, 2.0, 3.0],
                        [4.0, 5.0, 6.0],
                        [7.0, 8.0, 9.0]])

print("Original Matrix W:")
print(W_original)
print()

# Suppose we want to achieve some objective (e.g., row sums = targets)
targets = np.array([10.0, 20.0, 30.0])  # Desired row sums
actual = W_original.sum(axis=1)

print(f"Target row sums:  {targets}")
print(f"Current row sums: {actual}")
print(f"Errors:           {actual - targets}")

## Step 2: Compute Gradient

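The **gradient** $\frac{\partial \text{Loss}}{\partial W}$ tells us how a small change in each element of $W$ affects the loss.

For our row-sum objective:
$$\frac{\partial \text{Loss}}{\partial W_{ij}} = 2 \cdot \text{error}_i$$

Every element in row $i$ has the same gradient, because each contributes equally to that row's sum. The code below stores $\text{error}_i$ in `gradient`, i.e. half the true gradient; the direction is the same, and the missing factor of 2 is applied when predicting loss changes.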
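A quick way to confirm the formula above is a central finite-difference check: perturb one entry at a time by a small `h` (a value chosen here for illustration) and compare the numerical slope with $2 \cdot \text{error}_i$. This minimal sketch redefines the Step 1 matrix and targets so it runs on its own.

```python
import numpy as np

W = np.array([[1.0, 2.0, 3.0],
              [4.0, 5.0, 6.0],
              [7.0, 8.0, 9.0]])
targets = np.array([10.0, 20.0, 30.0])

def row_sum_loss(M):
    return float(np.sum((M.sum(axis=1) - targets) ** 2))

errors = W.sum(axis=1) - targets
analytic = np.outer(2 * errors, np.ones(3))   # dLoss/dW_ij = 2 * error_i

h = 1e-5
numeric = np.zeros_like(W)
for i, j in np.ndindex(W.shape):
    E = np.zeros_like(W)
    E[i, j] = h
    # Central difference: (Loss(W + h e_ij) - Loss(W - h e_ij)) / (2h)
    numeric[i, j] = (row_sum_loss(W + E) - row_sum_loss(W - E)) / (2 * h)

print(analytic)
print("max |numeric - analytic|:", np.max(np.abs(numeric - analytic)))
```

Because the loss is quadratic, the central difference is exact apart from floating-point round-off.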

In [None]:
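# Compute the gradient direction: how does each element affect the loss?
errors = actual - targets  # error_i = row_sum_i - target_i
# The true gradient is 2 * error_i for every element of row i; we store error_i
# (half the gradient) and apply the factor of 2 when predicting loss changes.
gradient = np.outer(errors, np.ones(3))  # copy error_i across each row

print("Gradient direction (∂Loss/∂W ÷ 2):")
print(gradient)
print()
print("→ Every element in row i has the same gradient, set by that row's error")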

## Step 3: Traditional SVD Factorization

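**Traditional approach**: Use SVD to get the low-rank approximation that minimizes $\|W - L \cdot R^T\|_F^2$ (by the Eckart–Young theorem, truncated SVD is optimal in this sense).

SVD doesn't know about our objective function; it only minimizes reconstruction error. Note that this particular $W$ already has rank 2 (row 3 = 2 × row 2 − row 1), so a rank-2 SVD reproduces it up to floating-point round-off and the SVD noise below is essentially zero.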
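Before running SVD it helps to check the rank of $W$. The sketch below (NumPy only) prints the singular values: the third is at round-off level, so a rank-2 truncation reproduces $W$ almost exactly. Truncating to rank 1 instead leaves a visible error (about 1.07), equal to the norm of the discarded singular values, as the Eckart–Young theorem predicts. The helper `truncated` is a name introduced here for illustration.

```python
import numpy as np

W = np.array([[1.0, 2.0, 3.0],
              [4.0, 5.0, 6.0],
              [7.0, 8.0, 9.0]])

U, s, Vt = np.linalg.svd(W)
print("singular values:", s)
print("numerical rank:", np.linalg.matrix_rank(W))

def truncated(k):
    """Best rank-k approximation in the Frobenius norm (Eckart–Young)."""
    return U[:, :k] @ np.diag(s[:k]) @ Vt[:k, :]

for k in (1, 2):
    err = np.linalg.norm(W - truncated(k), "fro")
    # Eckart–Young: the error equals the norm of the dropped singular values
    dropped = np.sqrt(np.sum(s[k:] ** 2))
    print(f"rank {k}: error = {err:.3e}, norm of dropped singular values = {dropped:.3e}")
```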

In [None]:
# Traditional SVD: minimize ||W - L·R^T||² without considering gradient
U, s, Vt = svd(W_original)

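# Rank-2 approximation: split the singular values evenly between L and R
L_svd = U[:, :2] @ np.diag(np.sqrt(s[:2]))
R_svd = Vt[:2, :].T @ np.diag(np.sqrt(s[:2]))
W_svd = L_svd @ R_svd.T

# Noise δ = L·R^T − W (the sign convention used in the ΔLoss formula)
noise_svd = W_svd - W_original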

print("Traditional SVD (rank=2):")
print(f"  Reconstruction error: {np.linalg.norm(W_original - W_svd, 'fro'):.2e}")
print()
print("Noise introduced by SVD:")
print(noise_svd)
print()

# Check objective
actual_svd = W_svd.sum(axis=1)
loss_original = np.sum((actual - targets) ** 2)
loss_svd = np.sum((actual_svd - targets) ** 2)

print(f"Original loss: {loss_original:.4f}")
print(f"SVD loss:      {loss_svd:.4f}")
print(f"Change:        {loss_svd - loss_original:+.4f}")

### Zhang 2025 Analysis: Check Gradient · Noise

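The key equation from Zhang 2025, written for the `gradient` array from Step 2 (which stores $\text{error}_i$, half of the true gradient $2 \cdot \text{error}_i$):
$$\Delta \text{Loss} \approx \frac{\partial \text{Loss}}{\partial W} \cdot \delta = 2 \cdot (\text{gradient} \cdot \text{noise})$$

This is a first-order approximation, so it is reliable only when the noise is small:

If $\text{gradient} \cdot \text{noise} < 0$, loss **decreases** ✓  
If $\text{gradient} \cdot \text{noise} > 0$, loss **increases** ✗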
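For this particular loss the first-order relation can be made exact. Writing $s_i = \sum_j \delta_{ij}$ for the row sums of the noise,

$$\Delta \text{Loss} = \sum_i 2\,\text{error}_i\, s_i + \sum_i s_i^2$$

The first term is the gradient inner product; the second is a non-negative curvature term that the linear estimate ignores. The sketch below evaluates both terms for the gradient-guided noise used in Step 4 (−0.5 times the `gradient` array); the numbers are specific to this toy problem.

```python
import numpy as np

W = np.array([[1.0, 2.0, 3.0],
              [4.0, 5.0, 6.0],
              [7.0, 8.0, 9.0]])
targets = np.array([10.0, 20.0, 30.0])

errors = W.sum(axis=1) - targets             # error_i
gradient = np.outer(errors, np.ones(3))      # same convention as Step 2: error_i per element
noise = -0.5 * gradient                      # the Step 4 noise

s = noise.sum(axis=1)                        # row sums of the noise
first_order = float(np.sum(2 * errors * s))  # equals 2 * (gradient · noise)
second_order = float(np.sum(s ** 2))         # always >= 0

loss_before = float(np.sum(errors ** 2))
loss_after = float(np.sum((errors + s) ** 2))

print(f"first-order term:  {first_order:+.4f}")
print(f"second-order term: {second_order:+.4f}")
print(f"sum of terms:      {first_order + second_order:+.4f}")
print(f"actual change:     {loss_after - loss_before:+.4f}")
```

Here the first-order term is −231 and the curvature term is +173.25, so the actual change is −57.75: the linear estimate gets the sign right but overstates the drop by a factor of four.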

In [None]:
# Zhang 2025 insight: gradient · noise predicts the loss change (to first order)
inner_product_svd = np.sum(gradient * noise_svd)
predicted_change_svd = 2 * inner_product_svd  # factor 2: `gradient` holds error_i, not 2·error_i

print("Zhang 2025 Analysis of SVD:")
print(f"  gradient · noise = {inner_product_svd:.4f}")
print(f"  Predicted loss change: {predicted_change_svd:+.4f}")
print(f"  Actual loss change:    {loss_svd - loss_original:+.4f}")
print()

tol = 1e-9  # treat floating-point round-off as zero
if inner_product_svd > tol:
    print("  → Noise in SAME direction as gradient → Loss INCREASES")
elif inner_product_svd < -tol:
    print("  → Noise in OPPOSITE direction to gradient → Loss DECREASES ✓")
else:
    print("  → Noise (numerically) orthogonal to gradient → Loss unchanged to first order")

## Step 4: Zhang 2025 Gradient-Guided Factorization

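**Zhang 2025 approach**: Instead of minimizing reconstruction error alone, choose the factorization noise to point **opposite to the gradient**.

This way, the factorization error can help reduce the objective loss, provided the noise is small enough for the first-order picture to hold. In this toy example we do not run a factorization algorithm: we perturb $W$ directly along the negative gradient.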
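Even though the gradient-guided matrix is built by perturbation rather than by a factorization routine, for this particular $W$ it still has rank 2 (its rows remain an arithmetic progression, so row 3 = 2 × row 2 − row 1). It can therefore be written exactly as $L \cdot R^T$ with two columns, and a rank-2 SVD recovers such factors. A minimal check, assuming the same −0.5 step; `W_guided` is a hypothetical name for the perturbed matrix:

```python
import numpy as np

W = np.array([[1.0, 2.0, 3.0],
              [4.0, 5.0, 6.0],
              [7.0, 8.0, 9.0]])
targets = np.array([10.0, 20.0, 30.0])

errors = W.sum(axis=1) - targets
gradient = np.outer(errors, np.ones(3))   # error_i per element, as in Step 2
W_guided = W - 0.5 * gradient             # hypothetical name for the perturbed matrix

print("rank of W_guided:", np.linalg.matrix_rank(W_guided))

# Split the rank-2 SVD into factors L (3x2) and R (3x2) with W_guided ≈ L @ R.T
U, s, Vt = np.linalg.svd(W_guided)
L = U[:, :2] * np.sqrt(s[:2])             # scale columns by sqrt of singular values
R = Vt[:2, :].T * np.sqrt(s[:2])
print("max |W_guided - L @ R.T|:", np.max(np.abs(W_guided - L @ R.T)))
```

This rank property is a coincidence of the chosen matrix and targets; in general a perturbation along the gradient can raise the rank, and a real method would have to enforce the factorized form.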

In [None]:
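# Perturb W opposite to the gradient. The step 0.5 is an arbitrary choice and not
# small: each row has 3 elements, so a step of 1/3 would zero every row-sum error,
# and 0.5 overshoots the targets.
noise_zhang = -0.5 * gradient
W_zhang = W_original + noise_zhang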

print("Gradient-guided noise (opposite to gradient):")
print(noise_zhang)
print()

print("New matrix W:")
print(W_zhang)
print()

# Check objective
actual_zhang = W_zhang.sum(axis=1)
loss_zhang = np.sum((actual_zhang - targets) ** 2)

print(f"Original loss: {loss_original:.4f}")
print(f"Zhang loss:    {loss_zhang:.4f}")
print(f"Change:        {loss_zhang - loss_original:+.4f}")

In [None]:
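inner_product_zhang = np.sum(gradient * noise_zhang)
predicted_change_zhang = 2 * inner_product_zhang  # first-order estimate

print("Zhang 2025 Analysis:")
print(f"  gradient · noise = {inner_product_zhang:.4f}")
print(f"  Predicted loss change: {predicted_change_zhang:+.4f}")
print(f"  Actual loss change:    {loss_zhang - loss_original:+.4f}")
print()
# The prediction overstates the drop: this noise is large, and the quadratic
# loss adds a positive second-order term that the linear estimate ignores.
if loss_zhang < loss_original:
    print("  ✓ Loss DECREASES: noise opposes gradient and the step does not overshoot too far")
else:
    print("  ✗ Loss does not decrease: the step overshoots despite opposing the gradient")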

## Step 5: Comparison

Compare the two approaches side-by-side:

In [None]:
print("=" * 70)
print("COMPARISON: SVD vs. Zhang 2025")
print("=" * 70)
print()
print(f"SVD:         Loss = {loss_svd:.4f}  (change: {loss_svd - loss_original:+.4f})")
print(f"Zhang 2025:  Loss = {loss_zhang:.4f}  (change: {loss_zhang - loss_original:+.4f})")
print()

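# Loss reduction of Zhang relative to SVD, as a percentage of the original loss
improvement = ((loss_svd - loss_zhang) / loss_original) * 100
if improvement >= 0:
    print(f"Zhang 2025 loss is lower than SVD loss by {improvement:.1f}% of the original loss")
else:
    print(f"Zhang 2025 loss is higher than SVD loss by {-improvement:.1f}% of the original loss")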

## Step 6: Visualization

In [None]:
fig, axes = plt.subplots(1, 3, figsize=(15, 4))

# Plot 1: Gradient
ax = axes[0]
im = ax.imshow(gradient, cmap='RdBu_r', vmin=-10, vmax=10)
ax.set_title('Gradient direction\n(∂Loss/∂W ÷ 2)')
ax.set_xlabel('Column')
ax.set_ylabel('Row')
for i in range(3):
    for j in range(3):
        ax.text(j, i, f'{gradient[i,j]:.1f}', 
                ha='center', va='center', color='white' if abs(gradient[i,j]) > 5 else 'black')
plt.colorbar(im, ax=ax)

# Plot 2: SVD Noise
ax = axes[1]
im = ax.imshow(noise_svd, cmap='RdBu_r', vmin=-1, vmax=1)
ax.set_title(f'SVD Noise\n(gradient·noise = {inner_product_svd:.2f})')
ax.set_xlabel('Column')
ax.set_ylabel('Row')
for i in range(3):
    for j in range(3):
        ax.text(j, i, f'{noise_svd[i,j]:.2f}', 
                ha='center', va='center', color='white' if abs(noise_svd[i,j]) > 0.5 else 'black')
plt.colorbar(im, ax=ax)

# Plot 3: Zhang Noise
ax = axes[2]
im = ax.imshow(noise_zhang, cmap='RdBu_r', vmin=-5, vmax=5)
ax.set_title(f'Zhang 2025 Noise\n(gradient·noise = {inner_product_zhang:.2f})')
ax.set_xlabel('Column')
ax.set_ylabel('Row')
for i in range(3):
    for j in range(3):
        ax.text(j, i, f'{noise_zhang[i,j]:.2f}', 
                ha='center', va='center', color='white' if abs(noise_zhang[i,j]) > 2 else 'black')
plt.colorbar(im, ax=ax)

plt.tight_layout()
plt.show()

## Key Takeaway

**Traditional factorization (SVD, etc.)** minimizes $\|W - L \cdot R^T\|^2$ without considering the objective function.

**Zhang 2025**: Choose $L \cdot R^T$ such that the noise $\delta = L \cdot R^T - W$ points in the **OPPOSITE direction to the gradient** → to first order, loss DECREASES!

This is **joint optimization**: the factorization error and the objective are optimized simultaneously, not sequentially.
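One way to express joint optimization in code, offered here as a hypothetical sketch rather than the paper's algorithm, is gradient descent on the factors $L$ and $R$ of a combined objective: reconstruction error plus a weighted task loss, $\|L R^T - W\|_F^2 + \lambda \sum_i \text{error}_i(L R^T)^2$. The weight `lam`, learning rate `lr`, and iteration count `steps` are illustrative assumptions. Starting from the rank-2 SVD factors, descent trades some reconstruction error for a lower row-sum loss.

```python
import numpy as np

W = np.array([[1.0, 2.0, 3.0],
              [4.0, 5.0, 6.0],
              [7.0, 8.0, 9.0]])
targets = np.array([10.0, 20.0, 30.0])
lam, lr, steps, rank = 1.0, 1e-3, 5000, 2   # illustrative hyperparameters

def task_loss(M):
    return float(np.sum((M.sum(axis=1) - targets) ** 2))

def joint_objective(L, R):
    M = L @ R.T
    return float(np.sum((M - W) ** 2)) + lam * task_loss(M)

# Initialise from the rank-2 SVD factors (the "traditional" solution)
U, s, Vt = np.linalg.svd(W)
L = U[:, :rank] * np.sqrt(s[:rank])
R = Vt[:rank, :].T * np.sqrt(s[:rank])
start = joint_objective(L, R)

for _ in range(steps):
    M = L @ R.T
    errors = M.sum(axis=1) - targets
    # Gradient of the joint objective with respect to M
    G = 2 * (M - W) + lam * np.outer(2 * errors, np.ones(W.shape[1]))
    # Chain rule through M = L R^T; compute both before updating either factor
    grad_L, grad_R = G @ R, G.T @ L
    L -= lr * grad_L
    R -= lr * grad_R

M = L @ R.T
print(f"joint objective: {start:.4f} -> {joint_objective(L, R):.4f}")
print(f"task loss:       {task_loss(W):.4f} -> {task_loss(M):.4f}")
print(f"reconstruction:  {np.linalg.norm(M - W, 'fro'):.4f}")
```

For `lam = 1` the exact minimizer over all matrices can be derived by hand: it adds $-\text{error}_i / 4$ to every element of row $i$, giving a task loss of $77/16 \approx 4.81$. The printed task loss can be compared with that value; how close it gets depends on `lr` and `steps`.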

---

### Application to SAXS Deconvolution

In SAXS (small-angle X-ray scattering) deconvolution, the data are modeled as $D = C \cdot S^T + \text{noise}$

**Current methods (EFA, MCR-ALS, REGALS)**:
1. Stage 1: Minimize $\|D - C \cdot S^T\|^2$ (factorization)
2. Stage 2: Apply constraints (non-negativity, smoothness, etc.)

**By analogy, Zhang 2025 suggests**: Optimize both stages **simultaneously** by choosing a factorization that inherently satisfies the physical constraints while fitting the data.

This could:
- Avoid rotation ambiguity
- Prevent rank inflation
- Find the optimal number of components
- Improve physical plausibility