# Low-Rank Compensation for Int8 Quantization

This notebook demonstrates how to improve Int8 quantization accuracy using Low-Rank Compensation (LoRC) via SVD.

In [None]:
import numpy as np
import time

def generate_weight_matrix(rows: int, cols: int) -> np.ndarray:
    # Generates a random weight matrix with values uniformly distributed between -1 and 1
    return np.random.uniform(-1, 1, (rows, cols)).astype(np.float32)

def calculate_mean_squared_error(original: np.ndarray, reconstructed: np.ndarray) -> float:
    # Calculates the Mean Squared Error between the original and reconstructed matrices
    return float(np.mean((original - reconstructed) ** 2))  # cast the NumPy scalar to a plain float

### Int8 Quantization

Standard symmetric quantization to 8-bit integers.
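
As a rough illustration of what "symmetric" means here, the sketch below quantizes a four-element vector by hand. The values in `w` and the names `w`, `q`, and `w_hat` are made up for this example. The largest magnitude maps to 127, and every value that falls inside the representable range is recovered to within half a quantization step (`scale / 2`).

```python
import numpy as np

# Hypothetical toy weights, chosen only for illustration
w = np.array([0.8, -1.0, 0.3, 0.0], dtype=np.float32)

# One scale for the whole tensor: the largest magnitude maps to 127
scale = float(np.max(np.abs(w))) / 127

# Quantize: divide by the scale, round to the nearest integer, clip, cast
q = np.clip(np.round(w / scale), -128, 127).astype(np.int8)

# Dequantize: multiply the integers back by the scale
w_hat = q.astype(np.float32) * scale

max_error = float(np.max(np.abs(w - w_hat)))
print(q)
print(max_error <= scale / 2 + 1e-6)
```

Because zero maps exactly to zero and the same scale is used for positive and negative values, no zero-point offset is needed.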

In [None]:
def quantize_to_int8(matrix: np.ndarray) -> tuple[np.ndarray, float]:
    # Symmetric per-tensor quantization: the largest magnitude maps to 127
    scale = float(np.max(np.abs(matrix))) / 127  # assumes the matrix is not all zeros
    # Round to the nearest integer, clip to the signed 8-bit range, and cast to int8
    quantized = np.clip(np.round(matrix / scale), -128, 127).astype(np.int8)
    return quantized, scale

def dequantize_from_int8(quantized: np.ndarray, scale: float) -> np.ndarray:
    # Dequantizes the int8 matrix back to float32
    return quantized.astype(np.float32) * scale

### Low-Rank Compensation

We use Singular Value Decomposition (SVD) to approximate the *residual error* (the original weights minus the dequantized weights) with a low-rank matrix. The goal is to factor a matrix $M$ (here, the residual) into two thin matrices $A$ and $B$ such that $M \approx A \times B$.

Here is the step-by-step breakdown:

1.  **Perform SVD**:
    ```python
    U, S, Vt = np.linalg.svd(original, full_matrices=False)
    ```
    *   This function breaks the input matrix (the parameter is named `original`, but in this notebook it receives the residual) into three components: $U$, $\Sigma$ (represented by `S`), and $V^T$ (`Vt`).
    *   Mathematically: $M = U \cdot \Sigma \cdot V^T$.
    *   `S` is a 1D array containing the **singular values**, which represent the "strength" or importance of each component, sorted from largest to smallest.

2.  **Create Matrix A (Left Factor)**:
    ```python
    A = U[:, :rank] * np.sqrt(S[:rank])
    ```
    *   `U[:, :rank]`: Takes the first `rank` columns of $U$ (the most important features).
    *   `np.sqrt(S[:rank])`: Takes the square root of the top `rank` singular values.
    *   **Why square root?** To balance the magnitude between $A$ and $B$, the singular values are split evenly. $A$ gets half the "weight" ($\sqrt{\Sigma}$).
    *   **Result**: $A$ has shape `(rows, rank)`.

3.  **Create Matrix B (Right Factor)**:
    ```python
    B = (np.sqrt(S[:rank])[:, np.newaxis] * Vt[:rank, :])
    ```
    *   `Vt[:rank, :]`: Takes the first `rank` rows of $V^T$.
    *   `np.sqrt(S[:rank])`: The other half of the "weight".
    *   `[:, np.newaxis]`: Reshapes the 1D array into a column vector so it can be multiplied correctly (broadcasted) across the rows of `Vt`.
    *   **Result**: $B$ has shape `(rank, cols)`.

**Summary**:
When you multiply $A \times B$, you get $U_{rank} \cdot \sqrt{\Sigma_{rank}} \cdot \sqrt{\Sigma_{rank}} \cdot V^T_{rank} = U_{rank} \cdot \Sigma_{rank} \cdot V^T_{rank}$. This is the truncated SVD of the input matrix: it keeps the `rank` components with the largest singular values and drops the rest, which makes it the best rank-`rank` approximation in the least-squares (Frobenius norm) sense.
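
The identity above can be checked numerically. The following is a minimal sketch on a small random matrix; the shape `(6, 4)`, the seed, and `rank = 2` are arbitrary choices for illustration. It confirms the shapes of $A$ and $B$, confirms that $A \times B$ matches the truncated SVD, and confirms that the squared Frobenius error equals the sum of the squared singular values that were dropped.

```python
import numpy as np

rng = np.random.default_rng(0)
M = rng.standard_normal((6, 4))  # arbitrary small test matrix
rank = 2

U, S, Vt = np.linalg.svd(M, full_matrices=False)

# Split each singular value evenly between the two factors
A = U[:, :rank] * np.sqrt(S[:rank])                   # shape (6, rank)
B = np.sqrt(S[:rank])[:, np.newaxis] * Vt[:rank, :]   # shape (rank, 4)

# Reference: truncated SVD written out with an explicit diagonal matrix
truncated = U[:, :rank] @ np.diag(S[:rank]) @ Vt[:rank, :]

# Squared Frobenius error of the approximation
squared_error = np.linalg.norm(M - A @ B, "fro") ** 2

print(A.shape, B.shape)
print(np.allclose(A @ B, truncated))
print(np.isclose(squared_error, np.sum(S[rank:] ** 2)))
```

The last check is useful in practice: the singular values in `S` tell you how much error a given `rank` leaves behind before you build $A$ and $B$.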

In [None]:
def low_rank_compensation(original: np.ndarray, rank: int = 8) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
    # Approximate the input matrix (in this notebook, the quantization residual) with a low-rank matrix.
    # We perform SVD and keep only the top 'rank' singular values/vectors.
    # The shape of 'original' is (m, n); let k = min(m, n).
    U, S, Vt = np.linalg.svd(original, full_matrices=False)  # U is (m, k), S is (k,), Vt is (k, n)

    # Get the top 'rank' components
    A = U[:, :rank] * np.sqrt(S[:rank])
    # Get the corresponding B matrix
    B = (np.sqrt(S[:rank])[:, np.newaxis] * Vt[:rank, :])

    # Reconstruct the low-rank approximation
    compensated = np.dot(A, B)
    return compensated, A, B

def apply_low_rank_compensation(quantized: np.ndarray, scale: float, compensation: np.ndarray) -> np.ndarray:
    # Dequantize the quantized matrix
    dequantized = dequantize_from_int8(quantized, scale)
    # Apply low-rank compensation
    compensated = dequantized + compensation
    return compensated

### Why Calculate the Residual?

It is crucial to apply the low-rank approximation to the **residual error** ($W - W_{quant}$), not the original matrix.

*   **Incorrect Approach**: If we approximate the original matrix ($W \approx AB$) and add it to the quantized matrix ($W_{quant}$), we get $W_{quant} + AB \approx 2W$. This doubles the signal magnitude and ruins accuracy.
*   **Correct Approach (Residual)**: We want to capture what was *lost* during quantization.
    1.  Calculate Residual: $R = W - W_{quant}$.
    2.  Approximate Residual: $R \approx AB$.
    3.  Compensate: $W_{final} = W_{quant} + AB \approx W_{quant} + (W - W_{quant}) = W$.

This effectively "adds back" the information lost during the quantization process.
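
The difference between the two approaches can be seen in a small experiment. This is a sketch under stated assumptions: the `low_rank` and `mse` helpers, the 64×64 matrix size, the seed, and `rank = 8` are illustrative choices rather than part of the notebook's functions, and `W_quant` here denotes the dequantized weights.

```python
import numpy as np

rng = np.random.default_rng(0)
W = rng.uniform(-1, 1, (64, 64)).astype(np.float32)

# Symmetric int8 round trip (quantize, then dequantize)
scale = float(np.max(np.abs(W))) / 127
W_quant = np.clip(np.round(W / scale), -128, 127).astype(np.float32) * scale

def low_rank(M: np.ndarray, rank: int) -> np.ndarray:
    # Hypothetical helper: truncated SVD of M, returned as a dense matrix
    U, S, Vt = np.linalg.svd(M, full_matrices=False)
    return (U[:, :rank] * S[:rank]) @ Vt[:rank, :]

def mse(X: np.ndarray) -> float:
    # Mean squared error against the original weights
    return float(np.mean((W - X) ** 2))

rank = 8
mse_plain = mse(W_quant)                                # no compensation
mse_wrong = mse(W_quant + low_rank(W, rank))            # approximates W itself
mse_right = mse(W_quant + low_rank(W - W_quant, rank))  # approximates the residual

print(f"plain: {mse_plain:.3e}  wrong: {mse_wrong:.3e}  right: {mse_right:.3e}")
```

Compensating with an approximation of $W$ itself should give an error far larger than plain quantization, while compensating with an approximation of the residual should give an error below it.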

In [None]:
np.random.seed(42)  # For reproducibility

rows, cols = 256, 256
original_matrix = generate_weight_matrix(rows, cols)

# Int8 Quantization
quantized_int8, scale_int8 = quantize_to_int8(original_matrix)
reconstructed_int8 = dequantize_from_int8(quantized_int8, scale_int8)
mse_int8 = calculate_mean_squared_error(original_matrix, reconstructed_int8)
print(f"Int8 Quantization MSE: {mse_int8}")

# Generate low-rank compensation
residual = original_matrix - reconstructed_int8
compensation, A, B = low_rank_compensation(residual, rank=8)

# Apply Low-Rank Compensation
compensated_reconstruction = apply_low_rank_compensation(quantized_int8, scale_int8, compensation)
mse_compensated = calculate_mean_squared_error(original_matrix, compensated_reconstruction)
print(f"Int8 with Low-Rank Compensation MSE: {mse_compensated}")

print(f"Original Matrix Sample:\n{original_matrix[:5, :5]}\n")
print(f"INT8 Quantized Sample:\n{quantized_int8[:5, :5]}\n")
print(f"INT8 Dequantized Sample:\n{reconstructed_int8[:5, :5]}\n")
print(f"INT8 MSE: {mse_int8}\n\n")
print(f"Compensated Reconstruction Sample:\n{compensated_reconstruction[:5, :5]}\n")
print(f"Compensated MSE: {mse_compensated}\n")