<a href="https://colab.research.google.com/github/nncliff/qwen-32B/blob/main/chapter-4/ipynb/quant.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Quantization Demo
This notebook demonstrates symmetric int8 and int4 quantization using NumPy and compares the reconstruction error (MSE) of each.

In [1]:
import numpy as np

def generate_random_matrix(rows: int, cols: int) -> np.ndarray:
    # Generates a random matrix with values uniformly distributed between -1 and 1
    return np.random.uniform(-1, 1, (rows, cols))

def calculate_mean_squared_error(original: np.ndarray, reconstructed: np.ndarray) -> float:
    # Calculates the Mean Squared Error between the original and reconstructed matrices
    return np.mean((original - reconstructed) ** 2)

### Int8 Quantization Explanation

The following code performs symmetric int8 quantization:

1.  **Calculate the Scale Factor**:
    ```python
    scale = np.max(np.abs(matrix)) / 127
    ```
    *   `np.abs(matrix)`: Takes the absolute value of every element.
    *   `np.max(...)`: Finds the largest absolute value.
    *   `/ 127`: Maps the range `[-max_abs, max_abs]` to `[-127, 127]`. The value -128 is left unused so the range stays symmetric.

2.  **Quantize the Values**:
    ```python
    quantized = np.round(matrix / scale).astype(np.int8)
    ```
    *   `matrix / scale`: Rescales the data into the range `[-127, 127]`.
    *   `np.round(...)`: Rounds to the nearest integer.
    *   `.astype(np.int8)`: Casts to 8-bit integer.
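
The two steps above can be traced by hand on a tiny input. The following is a minimal sketch, assuming a made-up four-element array (`example` is hypothetical and not part of the notebook's data); it shows the integers each value maps to and checks that the round-trip error stays within half a scale step.

```python
import numpy as np

# Hypothetical input chosen so the arithmetic is easy to follow by hand.
example = np.array([-1.0, -0.3, 0.25, 1.0])

# Step 1: scale factor. max_abs is 1.0, so scale is 1/127.
scale = np.max(np.abs(example)) / 127

# Step 2: divide, round to nearest integer, cast to int8.
# -0.3 * 127 = -38.1 rounds to -38; 0.25 * 127 = 31.75 rounds to 32.
quantized = np.round(example / scale).astype(np.int8)

# Round trip back to floating point.
reconstructed = quantized.astype(np.float32) * scale

print(quantized)      # expected values: -127, -38, 32, 127
print(reconstructed)
print(np.abs(example - reconstructed).max() <= scale / 2 + 1e-6)
```

Because the scale is derived from the largest absolute value, nothing is clipped, and rounding to the nearest integer moves each value by at most half a step.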

In [2]:
def quantize_to_int8(matrix: np.ndarray) -> tuple[np.ndarray, float]:
    # Symmetric per-tensor int8 quantization; returns the int8 matrix and its scale.
    # Assumes the matrix is not all zeros (the scale would be 0).
    scale = np.max(np.abs(matrix)) / 127  # Map [-max_abs, max_abs] onto [-127, 127]; -128 is unused
    quantized = np.round(matrix / scale).astype(np.int8)  # Round to the nearest integer, then cast to int8
    return quantized, scale

def dequantize_from_int8(quantized: np.ndarray, scale: float) -> np.ndarray:
    # Dequantizes the int8 matrix back to float32
    return quantized.astype(np.float32) * scale

### Int4 Quantization Explanation

The following code simulates 4-bit quantization (stored in int8):

1.  **Scale and Round**:
    ```python
    quantized = np.round(matrix / scale)
    ```
    *   **Scaling**: Divides by `scale` (calculated as `np.max(np.abs(matrix)) / 7`), which maps values into the range -7.0 to 7.0.
    *   **Rounding**: Converts to nearest whole number.

2.  **Clip and Cast**:
    ```python
    quantized = np.clip(quantized, -8, 7).astype(np.int8)
    ```
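    *   **Clipping**: Forces values to stay within -8 and 7 (the range of a signed 4-bit integer). With the symmetric scale above, rounded values already lie in -7 to 7, so the clip acts as a safety guard.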
    *   **Casting**: Stores the result in an 8-bit integer container (since NumPy lacks a native `int4` type).
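
The same steps can be traced on a small input to see how coarse the 4-bit grid is. This is a minimal sketch, assuming a made-up four-element array (`example` is hypothetical); it also checks that the clip changes nothing when the scale is `max_abs / 7`.

```python
import numpy as np

# Hypothetical input chosen so the arithmetic is easy to follow by hand.
example = np.array([-1.0, -0.3, 0.25, 1.0])

# Scale for the simulated int4 range: max_abs is 1.0, so scale is 1/7.
scale = np.max(np.abs(example)) / 7

# Step 1: scale and round. -0.3 * 7 = -2.1 rounds to -2; 0.25 * 7 = 1.75 rounds to 2.
rounded = np.round(example / scale)

# Step 2: clip to the signed 4-bit range and store in an int8 container.
quantized = np.clip(rounded, -8, 7).astype(np.int8)

print(quantized)                          # expected values: -7, -2, 2, 7
print(np.array_equal(rounded, quantized)) # True: the clip was a no-op here
print(quantized.astype(np.float32) * scale)
```

Here -0.3 and 0.25 land on -2/7 and 2/7 (about ±0.286), which illustrates how much more rounding error a 15-level grid introduces than the 255-level int8 grid.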

In [3]:
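def quantize_to_int4(matrix: np.ndarray) -> tuple[np.ndarray, float]:
    # Simulated symmetric int4 quantization; values are stored in an int8 container.
    # Assumes the matrix is not all zeros (the scale would be 0).
    scale = np.max(np.abs(matrix)) / 7  # Map [-max_abs, max_abs] onto [-7, 7]
    quantized = np.round(matrix / scale)  # Still floating point here; the cast happens after clipping
    quantized = np.clip(quantized, -8, 7).astype(np.int8)  # Clamp to the signed 4-bit range, then cast
    return quantized, scale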

def dequantize_from_int4(quantized: np.ndarray, scale: float) -> np.ndarray:
    # Dequantizes the int4 matrix back to float32
    return quantized.astype(np.float32) * scale
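
Before comparing the two schemes, it can help to know roughly what MSE to expect. If the rounding error is treated as uniform over one quantization step, the expected squared error is `scale ** 2 / 12`. For inputs in [-1, 1] that gives about 5.2e-06 for int8 (`scale ≈ 1/127`) and about 1.7e-03 for int4 (`scale ≈ 1/7`). The sketch below checks that estimate on its own random data; the helper names `predicted_mse` and `measured_mse`, the `qmax` parameter, and the seed are illustrative assumptions, not part of the notebook.

```python
import numpy as np

def predicted_mse(max_abs: float, qmax: int) -> float:
    # Uniform-error model: step**2 / 12, with step = max_abs / qmax.
    step = max_abs / qmax
    return step ** 2 / 12

def measured_mse(matrix: np.ndarray, qmax: int) -> float:
    # Symmetric quantize/dequantize round trip, kept in float64 for the comparison.
    step = np.max(np.abs(matrix)) / qmax
    reconstructed = np.round(matrix / step) * step
    return float(np.mean((matrix - reconstructed) ** 2))

rng = np.random.default_rng(0)  # hypothetical seed, only for repeatability
data = rng.uniform(-1, 1, (100, 100))
max_abs = float(np.max(np.abs(data)))

for name, qmax in [("int8", 127), ("int4", 7)]:
    print(f"{name}: predicted {predicted_mse(max_abs, qmax):.3e}, "
          f"measured {measured_mse(data, qmax):.3e}")
```

The estimate relies on the input being spread evenly across the quantization steps; data concentrated near zero or dominated by a few outliers can deviate from it.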

In [4]:
np.random.seed(42)  # For reproducibility

rows, cols = 100, 100
original_matrix = generate_random_matrix(rows, cols)

# Int8 Quantization
quantized_int8, scale_int8 = quantize_to_int8(original_matrix)
reconstructed_int8 = dequantize_from_int8(quantized_int8, scale_int8)
mse_int8 = calculate_mean_squared_error(original_matrix, reconstructed_int8)
print(f"Int8 Quantization MSE: {mse_int8}")

# Int4 Quantization
quantized_int4, scale_int4 = quantize_to_int4(original_matrix)
reconstructed_int4 = dequantize_from_int4(quantized_int4, scale_int4)
mse_int4 = calculate_mean_squared_error(original_matrix, reconstructed_int4)
print(f"Int4 Quantization MSE: {mse_int4}")

print(f"Original Matrix Sample:\n{original_matrix}\n")
print(f"INT8 Quantized Sample:\n{quantized_int8}\n")
print(f"INT8 Dequantized Sample:\n{reconstructed_int8}\n")
print(f"INT8 MSE: {mse_int8}\n\n")
print(f"INT4 Quantized Sample:\n{quantized_int4}\n")
print(f"INT4 Dequantized Sample:\n{reconstructed_int4}\n")
print(f"INT4 MSE: {mse_int4}\n")

Int8 Quantization MSE: 5.1724704854865734e-06
Int4 Quantization MSE: 0.0017039919225166954
Original Matrix Sample:
[[-0.25091976  0.90142861  0.46398788 ... -0.14491796 -0.94916175
  -0.78421715]
 [-0.93714163  0.27282082 -0.37128804 ...  0.79422052  0.77417285
   0.55975109]
 [ 0.28406329 -0.83172007 -0.67674257 ... -0.56835795  0.24578095
  -0.82930507]
 ...
 [ 0.56146546 -0.56999046  0.4312711  ...  0.29069125 -0.17052839
   0.86070235]
 [ 0.56976597 -0.59998216 -0.73831528 ...  0.35963625  0.66184498
  -0.91680809]
 [-0.95377866 -0.2539858  -0.70242343 ...  0.89341583 -0.20502402
  -0.56571919]]

INT8 Quantized Sample:
[[ -32  114   59 ...  -18 -121 -100]
 [-119   35  -47 ...  101   98   71]
 [  36 -106  -86 ...  -72   31 -105]
 ...
 [  71  -72   55 ...   37  -22  109]
 [  72  -76  -94 ...   46   84 -116]
 [-121  -32  -89 ...  113  -26  -72]]

INT8 Dequantized Sample:
[[-0.25196264  0.89761691  0.46455612 ... -0.14172899 -0.95273374
  -0.78738325]
 [-0.93698607  0.27558414 -0.37007